# Stroke Prediction: Data Preprocessing and EDA

This notebook focuses on the following preprocessing steps:
1. Loading and exploring the stroke dataset
2. Visualizing key features and relationships
3. Handling missing values with multiple imputation strategies
4. Creating feature encodings and transformations
5. Saving processed data for the modeling phase

## 1. Import Libraries

In [None]:
# General libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
from tqdm import tqdm

# Specific libraries for preprocessing and imputation
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer # MICE
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression

# Fix NumPy's global RNG state so NumPy-driven randomness repeats between runs
np.random.seed(42)

# pandas display options: every column, at most 100 rows, floats with two decimals
display_options = {
    'display.max_columns': None,
    'display.max_rows': 100,
    'display.float_format': '{:.2f}'.format,
}
for option_name, option_value in display_options.items():
    pd.set_option(option_name, option_value)

# Silence all warnings to keep cell output readable
warnings.filterwarnings('ignore')

# Seaborn grid style first, then the matplotlib style sheet (same order as before)
sns.set_style('whitegrid')
plt.style.use('fivethirtyeight')

# Make sure the output folders exist before anything is written to them
for output_dir in ('data/processed', 'figures'):
    os.makedirs(output_dir, exist_ok=True)

## 2. Load and Explore Dataset

In [None]:
# Read the raw stroke records from the CSV in the working directory
df = pd.read_csv('healthcare-dataset-stroke-data.csv')

# Report the (rows, columns) shape, then preview the first five records
print(f"Dataset shape: {df.shape}", "\nFirst 5 rows:", sep="\n")
df.head(5)

In [None]:
# Column dtypes as inferred by pandas
print("Data types:")
print(df.dtypes)

# Per-column null counts and their share of all rows; only columns with gaps are printed
print("\nMissing values:")
missing_values = df.isnull().sum()
missing_percent = (missing_values / len(df)) * 100
missing_data = pd.concat(
    [missing_values.rename('Missing Values'), missing_percent.rename('Percentage')],
    axis=1,
)
has_gaps = missing_data['Missing Values'] > 0
print(missing_data[has_gaps])

In [None]:
# Descriptive statistics of the numeric columns, transposed to one row per feature
print("Statistical summary of numerical features:")
df.describe().transpose()

In [None]:
# Absolute and relative frequency tables for every object-dtype column
print("Categorical features summary:")
categorical_features = df.select_dtypes(include=['object']).columns

for feature in categorical_features:
    column_values = df[feature]
    print(f"\n{feature}:")
    print(column_values.value_counts())
    print(f"Percentage:\n{column_values.value_counts(normalize=True) * 100}")

In [None]:
# Absolute and relative frequency of each value of the target column
stroke_labels = df['stroke']
print("Target variable distribution:")
print(stroke_labels.value_counts())
print(f"Percentage:\n{stroke_labels.value_counts(normalize=True) * 100}")

In [None]:
# Range check: print the minimum and maximum of every column whose dtype is not object
print("Check for unusual values:")
non_object_cols = [name for name in df.columns if df[name].dtype != 'object']
for col in non_object_cols:
    col_values = df[col]
    print(f"{col}: Min={col_values.min()}, Max={col_values.max()}")

## 3. Exploratory Data Analysis (EDA)

### 3.1 Target Variable Analysis

In [None]:
# Bar chart of how many records carry each stroke label
plt.figure(figsize=(10, 6))
ax = sns.countplot(x='stroke', data=df, palette='Set2')

# Write each bar's height just above its centre
for bar in ax.patches:
    bar_height = bar.get_height()
    bar_centre = bar.get_x() + bar.get_width() / 2.
    ax.annotate(f'{bar_height}', (bar_centre, bar_height),
                ha='center', va='bottom', fontsize=12)

plt.title('Stroke Distribution (Target Variable)', fontsize=15)
plt.xlabel('Stroke (0=No, 1=Yes)', fontsize=12)
plt.ylabel('Count', fontsize=12)
plt.savefig('figures/stroke_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

# Ratio of label-0 rows to label-1 rows
stroke_counts = df['stroke'].value_counts()
imbalance_ratio = stroke_counts[0] / stroke_counts[1]
print(f"Class imbalance ratio (No:Yes): {imbalance_ratio:.2f}:1")

### 3.2 Numerical Features Analysis

In [None]:
# Numeric columns profiled against the stroke label
numerical_features = ['age', 'avg_glucose_level', 'bmi']

# 3x3 grid, one column of panels per feature:
# row 1 = histogram + KDE, row 2 = boxplot, row 3 = violin plot
plt.figure(figsize=(18, 12))
for column_no, feature in enumerate(numerical_features, start=1):
    # Row 1: distribution split by stroke label
    plt.subplot(3, 3, column_no)
    sns.histplot(data=df, x=feature, hue='stroke', kde=True, bins=30, alpha=0.6, element='step')
    plt.title(f'{feature} Distribution by Stroke')

    # Row 2: boxplot per stroke label
    plt.subplot(3, 3, column_no + 3)
    sns.boxplot(x='stroke', y=feature, data=df, palette='Set2')
    plt.title(f'{feature} by Stroke Status')

    # Row 3: violin plot with quartile lines
    plt.subplot(3, 3, column_no + 6)
    sns.violinplot(x='stroke', y=feature, data=df, palette='Set2', inner='quartile')
    plt.title(f'{feature} Distribution (Violin) by Stroke Status')

plt.tight_layout()
plt.savefig('figures/numerical_features_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Pairwise correlations among the numeric features and the stroke label
plt.figure(figsize=(10, 8))
corr_columns = [*numerical_features, 'stroke']
corr_matrix = df[corr_columns].corr()

# Hide the diagonal and upper triangle so each pair is shown once
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
sns.heatmap(corr_matrix, mask=mask, annot=True, fmt='.2f', cmap='coolwarm',
            linewidths=0.5, vmin=-1, vmax=1)
plt.title('Correlation Heatmap of Numerical Features', fontsize=15)
plt.savefig('figures/correlation_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Pairplot of numerical features by stroke status.
# sns.pairplot builds its own figure (sized through height/aspect), so no
# plt.figure() call is made here; one would only add an empty figure to the
# cell output. The grid object returned by pairplot is used for saving so the
# file is written from the figure that was actually drawn.
pair_grid = sns.pairplot(df[numerical_features + ['stroke']], hue='stroke', diag_kind='kde',
                         palette='Set2', height=2.5, aspect=1.2)
pair_grid.savefig('figures/pairplot_numerical_features.png', dpi=300, bbox_inches='tight')
plt.show()

### 3.3 Categorical Features Analysis

In [None]:
# Analyze categorical features and their relation with stroke
categorical_features = ['gender', 'hypertension', 'heart_disease', 'ever_married', 
                       'work_type', 'Residence_type', 'smoking_status']

plt.figure(figsize=(20, 15))
for i, feature in enumerate(categorical_features, 1):
    plt.subplot(3, 3, i)

    # Stroke rate and group size from one groupby, so both statistics share the
    # same category index. (value_counts() is frequency-ordered and must not be
    # paired positionally with a label-ordered groupby result.)
    per_category = df.groupby(feature)['stroke'].agg(['mean', 'size'])

    # Create a DataFrame for plotting, highest stroke rate first
    plot_df = pd.DataFrame({
        'Category': per_category.index,
        'Stroke_Percentage': per_category['mean'].values * 100,
        'Count': per_category['size'].values
    })
    plot_df = plot_df.sort_values('Stroke_Percentage', ascending=False).reset_index(drop=True)

    # Bar plot; `order` pins the bar sequence to plot_df's row order. Without it
    # seaborn re-sorts numeric categories (e.g. 0/1 flags) and bars no longer
    # line up with the rows used for the annotations below.
    ax = sns.barplot(x='Category', y='Stroke_Percentage', data=plot_df,
                     order=plot_df['Category'].tolist(), palette='coolwarm')

    # Add count labels: bar j sits at x position j and has height Stroke_Percentage
    for j, row in enumerate(plot_df.itertuples(index=False)):
        ax.annotate(f'n={row.Count}',
                    (j, row.Stroke_Percentage + 0.3),
                    ha='center', va='bottom', fontsize=9)

    plt.title(f'Stroke Rate by {feature}')
    plt.ylabel('Stroke Percentage (%)')
    plt.xticks(rotation=45, ha='right')
    plt.ylim(0, max(plot_df['Stroke_Percentage']) * 1.2)

plt.tight_layout()
plt.savefig('figures/categorical_features_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# One stacked bar chart per categorical feature: share of each stroke label within a category
plt.figure(figsize=(20, 15))
for panel_no, feature in enumerate(categorical_features, start=1):
    panel_ax = plt.subplot(3, 3, panel_no)

    # Contingency table (category x stroke label), normalised so every row sums to 1
    label_counts = df.groupby([feature, 'stroke']).size().unstack()
    row_totals = label_counts.sum(axis=1)
    props = label_counts.div(row_totals, axis=0)

    # Draw onto the panel that was just made current
    props.plot(kind='bar', stacked=True, ax=panel_ax,
               color=['#3498db', '#e74c3c'], width=0.8)

    plt.title(f'Stroke Distribution by {feature}')
    plt.ylabel('Proportion')
    plt.xlabel(feature)
    plt.xticks(rotation=45, ha='right')
    plt.legend(['No Stroke (0)', 'Stroke (1)'])

plt.tight_layout()
plt.savefig('figures/categorical_stacked_plots.png', dpi=300, bbox_inches='tight')
plt.show()

### 3.4 Age-specific Analysis

In [None]:
# Create age groups.
# include_lowest=True closes the first bin on the left, so an age of exactly 0
# is placed in '0-18' instead of silently becoming NaN.
df['age_group'] = pd.cut(df['age'], bins=[0, 18, 30, 40, 50, 60, 70, 80, 100],
                         labels=['0-18', '19-30', '31-40', '41-50', '51-60', '61-70', '71-80', '81+'],
                         include_lowest=True)

# Analyze stroke rate by age group.
# Rate and group size come from one groupby so they share the same index;
# observed=False keeps every age band (even an empty one) in category order.
plt.figure(figsize=(12, 6))
age_summary = df.groupby('age_group', observed=False)['stroke'].agg(['mean', 'size'])
age_stroke = age_summary['mean'] * 100
counts = age_summary['size']

# Plot bar chart
ax = sns.barplot(x=age_stroke.index, y=age_stroke.values, palette='rocket')

# Add count labels: bar i sits at x position i; bands with no rows have no bar to label
for i, (rate, n_rows) in enumerate(zip(age_stroke.values, counts.values)):
    if pd.isna(rate):
        continue
    ax.annotate(f'n={n_rows}',
                (i, rate + 0.3),
                ha='center', va='bottom', fontsize=10)

plt.title('Stroke Rate by Age Group', fontsize=15)
plt.xlabel('Age Group', fontsize=12)
plt.ylabel('Stroke Percentage (%)', fontsize=12)
# Series.max() skips NaN, unlike the builtin max() on an array
plt.ylim(0, age_stroke.max() * 1.2)
plt.savefig('figures/stroke_rate_by_age.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# BMI vs Age with Stroke indication (rows without a BMI value are left out)
plt.figure(figsize=(12, 8))
ax = sns.scatterplot(data=df.dropna(subset=['bmi']), x='age', y='bmi', 
                     hue='stroke', palette={0: '#3498db', 1: '#e74c3c'}, 
                     size='stroke', sizes={0: 30, 1: 100}, alpha=0.7)

plt.title('BMI vs Age with Stroke Indication', fontsize=15)
plt.xlabel('Age', fontsize=12)
plt.ylabel('BMI', fontsize=12)
plt.grid(True, alpha=0.3)

# Relabel the legend seaborn already built instead of calling
# plt.legend(labels=[...]). That call rebuilds the legend from whatever artists
# are on the axes, so 'No'/'Yes' are not guaranteed to land on the colour and
# marker size of the matching stroke label.
stroke_names = {'0': 'No', '1': 'Yes'}
legend = ax.get_legend()
if legend is not None:
    legend.set_title('Stroke')
    for entry in legend.get_texts():
        entry.set_text(stroke_names.get(entry.get_text(), entry.get_text()))

plt.savefig('figures/bmi_vs_age.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Glucose level vs Age with Stroke indication
plt.figure(figsize=(12, 8))
ax = sns.scatterplot(data=df, x='age', y='avg_glucose_level', 
                     hue='stroke', palette={0: '#3498db', 1: '#e74c3c'}, 
                     size='stroke', sizes={0: 30, 1: 100}, alpha=0.7)

plt.title('Glucose Level vs Age with Stroke Indication', fontsize=15)
plt.xlabel('Age', fontsize=12)
plt.ylabel('Average Glucose Level', fontsize=12)
plt.grid(True, alpha=0.3)

# Relabel the legend seaborn already built instead of calling
# plt.legend(labels=[...]). That call rebuilds the legend from whatever artists
# are on the axes, so 'No'/'Yes' are not guaranteed to land on the colour and
# marker size of the matching stroke label.
stroke_names = {'0': 'No', '1': 'Yes'}
legend = ax.get_legend()
if legend is not None:
    legend.set_title('Stroke')
    for entry in legend.get_texts():
        entry.set_text(stroke_names.get(entry.get_text(), entry.get_text()))

plt.savefig('figures/glucose_vs_age.png', dpi=300, bbox_inches='tight')
plt.show()

### 3.5 Additional Analysis: Hypertension & Heart Disease

In [None]:
# Combine hypertension and heart disease to analyze comorbidity.
# The two 0/1 flags are summed (0, 1 or 2 conditions) and mapped to a label.
# NOTE(review): the label 'None' is among the strings pandas.read_csv treats as
# missing by default in recent pandas versions; confirm how the saved CSVs are
# read back by downstream notebooks.
comorbidity_order = ['None', 'One Condition', 'Both Conditions']
df['comorbidity'] = df['hypertension'] + df['heart_disease']
df['comorbidity'] = df['comorbidity'].map({0: 'None', 1: 'One Condition', 2: 'Both Conditions'})

# Analyze stroke rate by comorbidity status.
# groupby returns labels alphabetically ('Both Conditions' first); reindex both
# series to the same severity order so rate i and count i describe the same group.
plt.figure(figsize=(10, 6))
comorbidity_stroke = (df.groupby('comorbidity')['stroke'].mean() * 100).reindex(comorbidity_order)
counts = df['comorbidity'].value_counts().reindex(comorbidity_order)

# Plot bar chart with the bar sequence pinned to the same order
ax = sns.barplot(x=comorbidity_stroke.index, y=comorbidity_stroke.values,
                 order=comorbidity_order, palette='YlOrRd')

# Add count labels: bar i sits at x position i; a level absent from the data has no bar
for i, (rate, n_rows) in enumerate(zip(comorbidity_stroke.values, counts.values)):
    if pd.isna(rate):
        continue
    ax.annotate(f'n={int(n_rows)}',
                (i, rate + 0.3),
                ha='center', va='bottom', fontsize=10)

plt.title('Stroke Rate by Comorbidity Status', fontsize=15)
plt.xlabel('Comorbidity Status', fontsize=12)
plt.ylabel('Stroke Percentage (%)', fontsize=12)
# Series.max() skips NaN, unlike the builtin max() on an array
plt.ylim(0, comorbidity_stroke.max() * 1.2)
plt.savefig('figures/stroke_rate_by_comorbidity.png', dpi=300, bbox_inches='tight')
plt.show()

## 4. Feature Engineering

In [None]:
# Working frame for feature engineering: all of df except the 'id' column,
# which is an identifier and carries no predictive value.
# DataFrame.drop returns a new frame, so df itself is left untouched.
df_fe = df.drop(columns=['id'])

# Create BMI categories based on standard ranges
def categorize_bmi(bmi):
    """Map a BMI value to a weight category.

    Returns 'Underweight' below 18.5, 'Normal' below 25, 'Overweight' below 30
    and 'Obese' otherwise. A missing value (as judged by pd.isna) is returned
    as NaN so it stays missing in the derived column.
    """
    if pd.isna(bmi):
        return np.nan

    # Exclusive upper bounds, checked from the lowest category upwards
    upper_bounds = ((18.5, 'Underweight'), (25, 'Normal'), (30, 'Overweight'))
    for upper_bound, category in upper_bounds:
        if bmi < upper_bound:
            return category
    return 'Obese'

# Element-wise categorisation; rows without a BMI stay NaN in bmi_category
df_fe['bmi_category'] = df_fe['bmi'].map(categorize_bmi)

# Create glucose level categories
def categorize_glucose(glucose):
    """Map an average glucose reading to a coarse category.

    Returns 'Low' below 70, 'Normal' below 100, 'Prediabetes' below 126 and
    'Diabetes' otherwise.

    A missing reading (NaN or None, as judged by pd.isna) is returned as NaN,
    the same way categorize_bmi treats a missing BMI. Without this guard every
    comparison with NaN is False, so a missing reading would fall through to
    the final branch and be labelled 'Diabetes'.
    """
    if pd.isna(glucose):
        return np.nan
    elif glucose < 70:
        return 'Low'
    elif glucose < 100:
        return 'Normal'
    elif glucose < 126:
        return 'Prediabetes'
    else:
        return 'Diabetes'

df_fe['glucose_category'] = df_fe['avg_glucose_level'].map(categorize_glucose)

# Interaction terms: age scaled by each 0/1 condition flag, and glucose times BMI
# (glucose_bmi is NaN wherever BMI is missing)
df_fe['age_hypertension'] = df_fe['age'].mul(df_fe['hypertension'])
df_fe['age_heart_disease'] = df_fe['age'].mul(df_fe['heart_disease'])
df_fe['glucose_bmi'] = df_fe['avg_glucose_level'].mul(df_fe['bmi'])

# 1 for patients aged 65 or older, 0 otherwise
df_fe['is_senior'] = df_fe['age'].ge(65).astype(int)

# Columns created in this section, listed once for both reports below
engineered_cols = ['bmi_category', 'glucose_category', 'age_hypertension',
                   'age_heart_disease', 'glucose_bmi', 'is_senior']

# Preview the new columns
print("New features added:")
print(df_fe[engineered_cols].head())

# Null counts per new column
print("\nMissing values in new features:")
print(df_fe[engineered_cols].isnull().sum())

## 5. Handle Missing Values (BMI)

### 5.1 Analyze BMI Missing Pattern

In [None]:
# Analyze missing BMI values pattern
missing_bmi = df_fe['bmi'].isnull()
print(f"Number of records with missing BMI: {missing_bmi.sum()} ({missing_bmi.mean()*100:.2f}%)")

# Compare stroke rate in records with and without BMI
stroke_rate_with_bmi = df_fe[~missing_bmi]['stroke'].mean() * 100
stroke_rate_without_bmi = df_fe[missing_bmi]['stroke'].mean() * 100

print(f"Stroke rate in records with BMI: {stroke_rate_with_bmi:.2f}%")
print(f"Stroke rate in records without BMI: {stroke_rate_without_bmi:.2f}%")

# Use readable hue labels with an explicit hue_order and let seaborn build the
# legend. Overriding it afterwards with plt.legend([...]) assigns the strings
# positionally to whatever artists are on the axes, which is not tied to the
# hue levels and can attach 'BMI Present' to the 'missing' series.
status_order = ['BMI Present', 'BMI Missing']
status_palette = ['#3498db', '#e74c3c']  # blue = present, red = missing
bmi_status = missing_bmi.map({False: 'BMI Present', True: 'BMI Missing'}).rename('BMI status')

# Visualize comparison of other features between records with and without BMI
plt.figure(figsize=(15, 10))

# Age distribution by BMI missing status (each group normalised on its own)
plt.subplot(2, 2, 1)
sns.histplot(data=df_fe, x='age', hue=bmi_status, hue_order=status_order, kde=True, 
             common_norm=False, palette=status_palette)
plt.title('Age Distribution by BMI Missing Status')

# Glucose distribution by BMI missing status
plt.subplot(2, 2, 2)
sns.histplot(data=df_fe, x='avg_glucose_level', hue=bmi_status, hue_order=status_order, kde=True, 
             common_norm=False, palette=status_palette)
plt.title('Glucose Distribution by BMI Missing Status')

# Gender distribution by BMI missing status
plt.subplot(2, 2, 3)
sns.countplot(data=df_fe, x='gender', hue=bmi_status, hue_order=status_order, palette=status_palette)
plt.title('Gender Distribution by BMI Missing Status')

# Stroke distribution by BMI missing status
plt.subplot(2, 2, 4)
sns.countplot(data=df_fe, x='stroke', hue=bmi_status, hue_order=status_order, palette=status_palette)
plt.title('Stroke Distribution by BMI Missing Status')

plt.tight_layout()
plt.savefig('figures/bmi_missing_pattern.png', dpi=300, bbox_inches='tight')
plt.show()

### 5.2 Implement Multiple Imputation Strategies

In [None]:
# Independent working copies of the engineered frame, one per imputation strategy
df_mean = df_fe.copy()       # simple mean imputation
df_mice = df_fe.copy()       # MICE imputation
df_age_group = df_fe.copy()  # age-group median imputation

# --- 1. Simple Mean Imputation: every missing BMI becomes the mean of the observed BMIs
mean_imputer = SimpleImputer(strategy='mean')
df_mean['bmi'] = mean_imputer.fit_transform(df_mean[['bmi']])
print("1. Mean Imputation - Completed")
print(f"   BMI mean after imputation: {df_mean['bmi'].mean():.2f}")

# --- 2. MICE (Multiple Imputation by Chained Equations)
# Predictor matrix: numeric columns plus dummy-coded categoricals (first level dropped)
mice_features = ['age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi']
mice_categorical_cols = ['gender', 'ever_married', 'work_type', 'Residence_type', 'smoking_status']
mice_data = df_mice[mice_features].copy()
mice_categorical = df_mice[mice_categorical_cols]
mice_categorical_encoded = pd.get_dummies(mice_categorical, drop_first=True)
mice_data_full = pd.concat([mice_data, mice_categorical_encoded], axis=1)

# Iterative imputer with a linear-regression estimator and a fixed seed
mice_imputer = IterativeImputer(
    estimator=LinearRegression(),
    random_state=42,
    max_iter=10,
    verbose=0,
)
mice_imputed = mice_imputer.fit_transform(mice_data_full)

# Copy only the BMI column of the imputed matrix back into the frame
bmi_position = mice_data_full.columns.get_loc('bmi')
df_mice['bmi'] = mice_imputed[:, bmi_position]
print("2. MICE Imputation - Completed")
print(f"   BMI mean after imputation: {df_mice['bmi'].mean():.2f}")

# --- 3. Age-group based imputation
# Median of the observed BMIs within each age band
age_group_bmi_median = df_age_group.groupby('age_group')['bmi'].median()
print("\nMedian BMI by age group:")
print(age_group_bmi_median)

# Fill each missing BMI with the median of the row's own age band
band_median_per_row = df_age_group.groupby('age_group')['bmi'].transform('median')
df_age_group['bmi'] = df_age_group['bmi'].fillna(band_median_per_row)

# Fallback for rows still missing (e.g. an age band with no observed BMI at all):
# use the overall median of the column as it stands now
still_missing = df_age_group['bmi'].isna()
if still_missing.any():
    df_age_group.loc[still_missing, 'bmi'] = df_age_group['bmi'].median()

print("3. Age-group Imputation - Completed")
print(f"   BMI mean after imputation: {df_age_group['bmi'].mean():.2f}")

In [None]:
# 4. KNN Imputation as an additional method
df_knn = df_fe.copy()

# Prepare data for KNN imputation
knn_features = ['age', 'hypertension', 'heart_disease', 'avg_glucose_level', 'bmi']
knn_data = df_knn[knn_features].copy()

# Normalize numerical features for KNN so no single column dominates the distance.
# The gaps are deliberately left in place: StandardScaler disregards NaNs when
# fitting and keeps them in transform. Filling them with the median before
# scaling would leave KNNImputer with nothing to impute and turn this method
# into plain median imputation.
# (StandardScaler is already imported in the imports cell at the top.)
scaler = StandardScaler()
knn_data_scaled = pd.DataFrame(
    scaler.fit_transform(knn_data),
    columns=knn_data.columns,
    index=knn_data.index
)

# Apply KNN imputation: each gap is filled from the 5 nearest rows, weighted by inverse distance
knn_imputer = KNNImputer(n_neighbors=5, weights='distance')
knn_imputed = knn_imputer.fit_transform(knn_data_scaled)

# Inverse transform to get original scale
knn_imputed_original = scaler.inverse_transform(knn_imputed)

# Update BMI values
df_knn['bmi'] = knn_imputed_original[:, knn_data.columns.get_loc('bmi')]
assert df_knn['bmi'].notna().all(), "KNN imputation left missing BMI values"
print("4. KNN Imputation - Completed")
print(f"   BMI mean after imputation: {df_knn['bmi'].mean():.2f}")

In [None]:
# 2x2 grid of BMI histograms: observed values vs. three imputed versions.
# The dashed red line in each panel marks that panel's mean.
plt.figure(figsize=(15, 10))

observed_bmi = df_fe['bmi'].dropna()
mean_filled_bmi = df_mean['bmi']
mice_filled_bmi = df_mice['bmi']
knn_filled_bmi = df_knn['bmi']

# Top-left: observed BMI only (missing rows excluded)
plt.subplot(2, 2, 1)
sns.histplot(observed_bmi, kde=True, color='#3498db')
plt.axvline(observed_bmi.mean(), color='r', linestyle='--')
plt.title(f'Original BMI Distribution (Mean: {observed_bmi.mean():.2f})')

# Top-right: mean imputation
plt.subplot(2, 2, 2)
sns.histplot(mean_filled_bmi, kde=True, color='#2ecc71')
plt.axvline(mean_filled_bmi.mean(), color='r', linestyle='--')
plt.title(f'Mean Imputation (Mean: {mean_filled_bmi.mean():.2f})')

# Bottom-left: MICE imputation
plt.subplot(2, 2, 3)
sns.histplot(mice_filled_bmi, kde=True, color='#e74c3c')
plt.axvline(mice_filled_bmi.mean(), color='r', linestyle='--')
plt.title(f'MICE Imputation (Mean: {mice_filled_bmi.mean():.2f})')

# Bottom-right: KNN imputation
plt.subplot(2, 2, 4)
sns.histplot(knn_filled_bmi, kde=True, color='#9b59b6')
plt.axvline(knn_filled_bmi.mean(), color='r', linestyle='--')
plt.title(f'KNN Imputation (Mean: {knn_filled_bmi.mean():.2f})')

plt.tight_layout()
plt.savefig('figures/bmi_imputation_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Compare imputation methods by age group

# Calculate mean BMI by age group for each imputation method
# ('Original' averages the observed values only; pandas skips NaN in mean())
original_bmi_by_age = df_fe.groupby('age_group')['bmi'].mean()
mean_bmi_by_age = df_mean.groupby('age_group')['bmi'].mean()
mice_bmi_by_age = df_mice.groupby('age_group')['bmi'].mean()
knn_bmi_by_age = df_knn.groupby('age_group')['bmi'].mean()
age_group_bmi_by_age = df_age_group.groupby('age_group')['bmi'].mean()

# Create dataframe for plotting: one row per age group, one column per method
bmi_comparison = pd.DataFrame({
    'Original': original_bmi_by_age,
    'Mean Imputation': mean_bmi_by_age,
    'MICE Imputation': mice_bmi_by_age,
    'KNN Imputation': knn_bmi_by_age,
    'Age-Group Imputation': age_group_bmi_by_age
})

# Plot onto an explicitly created axes. Calling plt.figure() first and then
# DataFrame.plot(figsize=...) opens a second figure and leaves the first one
# empty in the cell output.
fig, ax = plt.subplots(figsize=(15, 8))
bmi_comparison.plot(kind='bar', ax=ax)
ax.set_title('Average BMI by Age Group Across Imputation Methods', fontsize=15)
ax.set_xlabel('Age Group', fontsize=12)
ax.set_ylabel('Average BMI', fontsize=12)
ax.grid(True, alpha=0.3)
ax.legend(title='Imputation Method')
fig.savefig('figures/bmi_imputation_by_age.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# 4x2 grid comparing observed BMI (left column) with the imputed values of the
# originally-missing rows (right column) against age, glucose, hypertension and
# heart disease.
plt.figure(figsize=(15, 20))

# Rows whose BMI was missing before imputation, and the complementary observed rows
mask = df_fe['bmi'].isna()
observed_rows = df_fe.dropna(subset=['bmi'])

# (frame, legend label, colour) for each imputed version, in drawing order
imputed_versions = [
    (df_mice, 'MICE Imputed', '#e74c3c'),
    (df_knn, 'KNN Imputed', '#9b59b6'),
    (df_age_group, 'Age-Group Imputed', '#2ecc71'),
]

# Grid rows 1-2: scatter plots against the continuous features
scatter_panels = [
    ('age', 'Original BMI vs Age', 'Imputed BMI vs Age (Missing Values Only)'),
    ('avg_glucose_level', 'Original BMI vs Glucose', 'Imputed BMI vs Glucose (Missing Values Only)'),
]
for grid_row, (x_col, observed_title, imputed_title) in enumerate(scatter_panels):
    plt.subplot(4, 2, 2 * grid_row + 1)
    sns.scatterplot(data=observed_rows, x=x_col, y='bmi', alpha=0.5, label='Original', color='#3498db')
    plt.title(observed_title)

    plt.subplot(4, 2, 2 * grid_row + 2)
    for frame, legend_label, colour in imputed_versions:
        sns.scatterplot(data=frame[mask], x=x_col, y='bmi', alpha=0.5, label=legend_label, color=colour)
    plt.title(imputed_title)

# Grid rows 3-4: boxplots against the binary condition flags
box_panels = [
    ('hypertension', 'Original BMI by Hypertension', 'Imputed BMI by Hypertension (Missing Values Only)'),
    ('heart_disease', 'Original BMI by Heart Disease', 'Imputed BMI by Heart Disease (Missing Values Only)'),
]
for grid_row, (flag_col, observed_title, imputed_title) in enumerate(box_panels, start=2):
    plt.subplot(4, 2, 2 * grid_row + 1)
    sns.boxplot(data=observed_rows, x=flag_col, y='bmi', palette='Set2')
    plt.title(observed_title)

    plt.subplot(4, 2, 2 * grid_row + 2)
    # Long format: one row per (originally-missing record, method) pair
    boxplot_data = pd.melt(pd.DataFrame({
        flag_col: df_fe.loc[mask, flag_col],
        'MICE': df_mice.loc[mask, 'bmi'],
        'KNN': df_knn.loc[mask, 'bmi'],
        'Age-Group': df_age_group.loc[mask, 'bmi']
    }), id_vars=[flag_col], var_name='Method', value_name='BMI')
    sns.boxplot(data=boxplot_data, x=flag_col, y='BMI', hue='Method', palette='Dark2')
    plt.title(imputed_title)

plt.tight_layout()
plt.savefig('figures/bmi_imputation_relationships.png', dpi=300, bbox_inches='tight')
plt.show()

### 5.3 Select Best Imputation Method for Final Dataset

In [None]:
# MICE is taken as the final BMI imputation in this example, on the expectation
# that it preserves relationships between variables best. In a real project the
# methods should be compared by downstream model performance.

# Final frame starts from the MICE-imputed copy
df_final = df_mice.copy()

# Recompute the two BMI-derived columns, which were built before imputation
# and therefore still hold NaN for the originally-missing rows
df_final['bmi_category'] = df_final['bmi'].map(categorize_bmi)
df_final['glucose_bmi'] = df_final['avg_glucose_level'].mul(df_final['bmi'])

# Null count per column, to confirm nothing is left missing
print("Missing values in final dataset:")
print(df_final.isna().sum())

# Descriptive statistics of the three continuous features
print("\nSummary statistics for numerical features:")
print(df_final.loc[:, ['age', 'avg_glucose_level', 'bmi']].describe())

## 6. Encode Categorical Features

In [None]:
# One-hot encode categorical features.
# 'comorbidity' is included: it is a string column carried over from the EDA
# section, and leaving it out would keep a non-numeric column in the frame that
# is saved as ready for modeling.
categorical_cols = ['gender', 'ever_married', 'work_type', 'Residence_type', 
                    'smoking_status', 'bmi_category', 'glucose_category', 'age_group',
                    'comorbidity']

# Apply one-hot encoding (first level of each feature dropped as the reference level)
df_encoded = pd.get_dummies(df_final, columns=categorical_cols, drop_first=True)

# Guard: no string/categorical column may remain after encoding
unencoded_cols = df_encoded.select_dtypes(include=['object', 'category']).columns.tolist()
assert not unencoded_cols, f"Columns still not numeric after encoding: {unencoded_cols}"

# Show the new columns created
print(f"Original dataset columns: {len(df_final.columns)}")
print(f"Encoded dataset columns: {len(df_encoded.columns)}")

# Display first few rows of encoded dataset
print("\nFirst 5 rows of encoded dataset:")
df_encoded.head()

## 7. Save Processed Dataset

In [None]:
# Write the preprocessed frames to data/processed/ (row index not written)

# 1. Imputed and feature-engineered, but not encoded: a starting point for
#    model-specific preprocessing
processed_path = 'data/processed/stroke_dataset_processed.csv'
df_final.to_csv(processed_path, index=False)
print(f"Saved processed dataset with {df_final.shape[1]} columns to '{processed_path}'")

# 2. Same data after one-hot encoding, intended as direct model input
encoded_path = 'data/processed/stroke_dataset_encoded.csv'
df_encoded.to_csv(encoded_path, index=False)
print(f"Saved encoded dataset with {df_encoded.shape[1]} columns to '{encoded_path}'")

# 3. A separate copy of the unencoded frame for EDA use
eda_path = 'data/processed/stroke_dataset_eda.csv'
df_eda = df_final.copy()
df_eda.to_csv(eda_path, index=False)
print(f"Saved EDA dataset to '{eda_path}'")

## 8. Summary of Preprocessing Steps

In [None]:
# Display summary of preprocessing steps
print("PREPROCESSING SUMMARY")
print("=" * 50)

# df was extended in place during EDA (age_group, comorbidity), so its current
# shape is no longer the shape of the raw file. Recover the raw column set by
# excluding those derived columns, and count as "added" every column of the
# final frame that was not in the raw file.
eda_derived_cols = [col for col in ('age_group', 'comorbidity') if col in df.columns]
original_cols = [col for col in df.columns if col not in eda_derived_cols]
original_shape = (len(df), len(original_cols))
added_cols = [col for col in df_final.columns if col not in original_cols]

# Computed once and reused by the class-distribution lines below
stroke_counts = df_final['stroke'].value_counts()
stroke_share = df_final['stroke'].value_counts(normalize=True) * 100
bmi_was_missing = df['bmi'].isna()

print("1. Dataset Information:")
print(f"   - Original shape: {original_shape}")
print(f"   - Processed shape: {df_final.shape}")
print(f"   - Encoded shape: {df_encoded.shape}")

print("\n2. Missing Values:")
print(f"   - Original missing BMI values: {bmi_was_missing.sum()} ({bmi_was_missing.mean()*100:.2f}%)")
print(f"   - Imputation method used: MICE (Multiple Imputation by Chained Equations)")
print(f"   - Final missing values: {df_final.isna().sum().sum()}")

print("\n3. Class Distribution:")
print(f"   - Negative class (No Stroke): {stroke_counts[0]} ({stroke_share[0]:.2f}%)")
print(f"   - Positive class (Stroke): {stroke_counts[1]} ({stroke_share[1]:.2f}%)")
print(f"   - Class imbalance ratio: {stroke_counts[0]/stroke_counts[1]:.2f}:1")

print("\n4. Feature Engineering:")
print(f"   - Original features: {len(original_cols)}")
print(f"   - Added features: {len(added_cols)}")
print(f"   - Key added features: bmi_category, glucose_category, age_group, comorbidity, age_hypertension, etc.")

print("\n5. Categorical Encoding:")
print(f"   - Categorical columns encoded: {len(categorical_cols)}")
print(f"   - Total features after encoding: {len(df_encoded.columns)}")

print("\n6. Output Files:")
print("   - data/processed/stroke_dataset_processed.csv (Processed without encoding)")
print("   - data/processed/stroke_dataset_encoded.csv (Processed with encoding)")
print("   - data/processed/stroke_dataset_eda.csv (For EDA purposes)")

print("\nPreprocessing completed successfully!")