# Stroke Prediction: Model Training & Evaluation

This notebook trains and evaluates classification models for stroke prediction. It covers:
1. Loading the preprocessed dataset
2. Handling class imbalance with SMOTE
3. Training multiple classification models
4. Building a dense stacking ensemble model
5. Evaluating models with robust metrics
6. Saving the best model

## 1. Import Libraries

In [None]:
# General libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import time
import pickle
from tqdm import tqdm

# ML libraries
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline

# Class imbalance handling
from imblearn.over_sampling import SMOTE, ADASYN, BorderlineSMOTE
from imblearn.pipeline import Pipeline as ImbPipeline

# Models
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (
    RandomForestClassifier, GradientBoostingClassifier, 
    AdaBoostClassifier, ExtraTreesClassifier,
    VotingClassifier, StackingClassifier
)
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
import xgboost as xgb
import lightgbm as lgb
import catboost as cb

# Metrics
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report,
    precision_recall_curve, auc, average_precision_score, roc_curve
)

# Reproducibility: a single seed for NumPy's global RNG, reused as random_state below
SEED = 42
np.random.seed(SEED)

# Pandas display: never truncate columns, show floats with four decimals
pd.set_option('display.max_columns', None)
pd.set_option('display.float_format', '{:.4f}'.format)

# Silence all library warnings so cell output stays readable
warnings.filterwarnings('ignore')

# Plot appearance: seaborn grid first, then the matplotlib style sheet on top of it
sns.set_style('whitegrid')
plt.style.use('fivethirtyeight')

# Make sure every output folder written to later in the notebook exists
for output_dir in ('models', 'results', 'figures/model_evaluation'):
    os.makedirs(output_dir, exist_ok=True)

## 2. Load Processed Dataset

In [None]:
# Read the encoded dataset written by the preprocessing stage
df = pd.read_csv('data/processed/stroke_dataset_encoded.csv')

# Summary: overall size, number of predictors (all columns but the target),
# and the share of each 'stroke' class in percent
print(f"Dataset shape: {df.shape}")
print(f"Number of features: {len(df.columns) - 1}")
target_share = df['stroke'].value_counts(normalize=True) * 100
print(f"Target distribution:\n{target_share}")

# Sanity check: list only the columns that actually contain missing values
missing_values = df.isnull().sum()
columns_with_missing = missing_values[missing_values > 0]
if columns_with_missing.empty:
    print("\nNo missing values found.")
else:
    print("\nMissing values:")
    print(columns_with_missing)

# Preview the first rows (rendered as the cell output)
df.head()

## 3. Prepare Data for Modeling

In [None]:
# Separate the 'stroke' label from the predictor columns
y = df['stroke']
X = df.drop(columns='stroke')

# Step 1: hold out 20% of all rows as the test set, stratified on the label
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=0.20, random_state=SEED, stratify=y
)

# Step 2: carve the validation set out of the remaining rows
# (0.25 * 0.8 = 0.2 of the original data)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.25, random_state=SEED, stratify=y_trainval
)

# Report the size of each split
for split_label, split_features in (("Training", X_train), ("Validation", X_val), ("Test", X_test)):
    print(f"{split_label} set shape: {split_features.shape}")

# Check class distribution in the splits
def print_class_distribution(name, y_data):
    """Print the per-class counts and percentage shares of a label vector.

    Parameters
    ----------
    name : str
        Label of the split, used as the prefix of the printed line.
    y_data : array-like of non-negative ints
        Class labels; counted with ``np.bincount``.
    """
    class_counts = np.bincount(y_data)
    # One formatted percentage string per class, in class-index order
    shares = [f"{share:.2f}%" for share in class_counts / len(y_data) * 100]
    print(f"{name} set: {class_counts.tolist()} ({shares})")

# Report the label balance of every split, in train / validation / test order
print("\nClass distribution:")
for split_label, split_target in (("Training", y_train), ("Validation", y_val), ("Test", y_test)):
    print_class_distribution(split_label, split_target)

## 4. Handle Class Imbalance with SMOTE

In [None]:
# Standardize features: the scaler is fitted on the training split only and then
# applied unchanged to the validation and test splits.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)


def _format_class_balance(y_resampled):
    """Return the class counts and their percentage shares as printable text."""
    counts = np.bincount(y_resampled)
    # Each share is formatted individually: applying ':.2f' to the whole ndarray
    # raises TypeError (ndarray.__format__ rejects non-empty format specs).
    shares = ", ".join(f"{share:.2f}%" for share in counts / len(y_resampled) * 100)
    return f"{counts} ({shares})"


# Apply SMOTE to handle class imbalance (training split only)
print("Applying SMOTE to balance the training set...")
smote = SMOTE(random_state=SEED, sampling_strategy='auto')
X_train_smote, y_train_smote = smote.fit_resample(X_train_scaled, y_train)

print(f"Shape after SMOTE: {X_train_smote.shape}")
print(f"Class distribution after SMOTE: {_format_class_balance(y_train_smote)}")

# Also try BorderlineSMOTE for comparison
print("\nApplying BorderlineSMOTE to balance the training set...")
bsmote = BorderlineSMOTE(random_state=SEED, kind='borderline-1')
X_train_bsmote, y_train_bsmote = bsmote.fit_resample(X_train_scaled, y_train)

print(f"Shape after BorderlineSMOTE: {X_train_bsmote.shape}")
print(f"Class distribution after BorderlineSMOTE: {_format_class_balance(y_train_bsmote)}")

In [None]:
# Side-by-side label counts: original training split vs. each oversampled variant
plt.figure(figsize=(12, 6))

balance_panels = [
    ('Original Class Distribution', y_train),
    ('After SMOTE', y_train_smote),
    ('After BorderlineSMOTE', y_train_bsmote),
]
for panel_index, (panel_title, panel_labels) in enumerate(balance_panels, start=1):
    plt.subplot(1, 3, panel_index)
    sns.countplot(x=panel_labels, palette='Set2')
    plt.title(panel_title)
    plt.xlabel('Stroke')
    if panel_index == 1:
        # Only the left-most panel gets an explicit y-axis label
        plt.ylabel('Count')

plt.tight_layout()
plt.savefig('figures/model_evaluation/class_balance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 5. Define Evaluation Metrics and Helper Functions

In [None]:
def evaluate_model(model, X_val, y_val, model_name, threshold=0.5):
    """
    Evaluate a fitted classifier, print a short report, and return its metrics.

    Positive-class probabilities from ``predict_proba`` are thresholded to obtain
    hard labels. If the model has no usable ``predict_proba`` (the call raises
    ``AttributeError`` or ``NotImplementedError``), ``predict`` is used instead and
    the probability-based metrics are left out. Any other error propagates instead
    of being silently swallowed.

    Parameters:
    -----------
    model : classifier object
        The trained model to evaluate
    X_val : array-like
        Features to evaluate on
    y_val : array-like
        True binary labels for ``X_val``
    model_name : str
        Name of the model, used in the printed report
    threshold : float, default=0.5
        Probability at or above which a sample is labelled positive

    Returns:
    --------
    dict
        Keys 'accuracy', 'precision', 'recall', 'f1' and 'confusion_matrix';
        additionally 'roc_auc' and 'average_precision' when probabilities
        were available.
    """
    # Prefer probabilities so that the custom threshold can be applied
    try:
        y_pred_proba = model.predict_proba(X_val)[:, 1]
    except (AttributeError, NotImplementedError):
        # Model exposes no probability estimates: fall back to hard predictions
        y_pred_proba = None

    if y_pred_proba is not None:
        y_pred = (y_pred_proba >= threshold).astype(int)
    else:
        y_pred = model.predict(X_val)

    # Threshold-dependent metrics; zero_division=0 keeps the score at 0 (the value
    # scikit-learn already returns) when no positive is predicted, without a warning
    metrics = {
        'accuracy': accuracy_score(y_val, y_pred),
        'precision': precision_score(y_val, y_pred, zero_division=0),
        'recall': recall_score(y_val, y_pred, zero_division=0),
        'f1': f1_score(y_val, y_pred, zero_division=0),
    }

    # Ranking metrics need scores, so they exist only when probabilities do
    if y_pred_proba is not None:
        metrics['roc_auc'] = roc_auc_score(y_val, y_pred_proba)
        metrics['average_precision'] = average_precision_score(y_val, y_pred_proba)

    # Confusion matrix
    metrics['confusion_matrix'] = confusion_matrix(y_val, y_pred)

    # Print results
    print(f"\n{model_name} Evaluation:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    if 'roc_auc' in metrics:
        print(f"ROC AUC: {metrics['roc_auc']:.4f}")
        print(f"Average Precision: {metrics['average_precision']:.4f}")

    print("\nConfusion Matrix:")
    print(metrics['confusion_matrix'])

    return metrics

def plot_roc_curve(models, X_val, y_val, figsize=(10, 8)):
    """
    Plot ROC curves for multiple models on one figure, save it, and show it.

    Models without a usable ``predict_proba`` (the call raises ``AttributeError``
    or ``NotImplementedError``) are skipped with a message. Any other error, for
    example one raised while computing the curve, propagates so it is not
    misreported as missing probability support.

    Parameters:
    -----------
    models : dict
        Dictionary of {model_name: model}
    X_val : array-like
        Validation features
    y_val : array-like
        Validation target
    figsize : tuple, default=(10, 8)
        Figure size
    """
    plt.figure(figsize=figsize)

    for name, model in models.items():
        try:
            y_pred_proba = model.predict_proba(X_val)[:, 1]
        except (AttributeError, NotImplementedError):
            print(f"{name} does not support predict_proba, skipping ROC curve.")
            continue

        fpr, tpr, _ = roc_curve(y_val, y_pred_proba)
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, lw=2, label=f'{name} (AUC = {roc_auc:.4f})')

    # Random classifier line
    plt.plot([0, 1], [0, 1], 'k--', lw=2)

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12)
    plt.ylabel('True Positive Rate', fontsize=12)
    plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=15)
    plt.legend(loc="lower right")
    plt.grid(True, alpha=0.3)
    plt.savefig('figures/model_evaluation/roc_curve_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_precision_recall_curve(models, X_val, y_val, figsize=(10, 8)):
    """
    Plot precision-recall curves for multiple models, save the figure, and show it.

    A horizontal "No Skill" line is drawn at the positive-class prevalence of
    ``y_val``. Models without a usable ``predict_proba`` (the call raises
    ``AttributeError`` or ``NotImplementedError``) are skipped with a message;
    any other error propagates.

    Parameters:
    -----------
    models : dict
        Dictionary of {model_name: model}
    X_val : array-like
        Validation features
    y_val : array-like
        Validation target (binary, positive class encoded as 1)
    figsize : tuple, default=(10, 8)
        Figure size
    """
    plt.figure(figsize=figsize)

    # Baseline (no skill classifier): fraction of positives. Converting to an
    # ndarray first makes this work for plain lists as well as arrays/Series.
    y_true = np.asarray(y_val)
    no_skill = np.count_nonzero(y_true == 1) / len(y_true)
    plt.plot([0, 1], [no_skill, no_skill], 'k--', lw=2, label='No Skill')

    for name, model in models.items():
        try:
            y_pred_proba = model.predict_proba(X_val)[:, 1]
        except (AttributeError, NotImplementedError):
            print(f"{name} does not support predict_proba, skipping PR curve.")
            continue

        precision, recall, _ = precision_recall_curve(y_val, y_pred_proba)
        pr_auc = auc(recall, precision)
        avg_precision = average_precision_score(y_val, y_pred_proba)
        plt.plot(recall, precision, lw=2,
                 label=f'{name} (AP = {avg_precision:.4f}, AUC = {pr_auc:.4f})')

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall', fontsize=12)
    plt.ylabel('Precision', fontsize=12)
    plt.title('Precision-Recall Curve', fontsize=15)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig('figures/model_evaluation/precision_recall_curve_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_confusion_matrix(cm, model_name, figsize=(8, 6)):
    """
    Draw a 2x2 confusion matrix as an annotated heatmap, save it, and show it.

    Parameters:
    -----------
    cm : array-like
        Confusion matrix of integer counts (rows: true label, columns: predicted)
    model_name : str
        Name of the model; used in the title and, lower-cased with spaces
        replaced by underscores, in the output file name
    figsize : tuple, default=(8, 6)
        Figure size
    """
    class_labels = ['No Stroke', 'Stroke']

    plt.figure(figsize=figsize)
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_labels,
        yticklabels=class_labels,
    )
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.title(f'Confusion Matrix - {model_name}', fontsize=15)

    # File name is derived from the model name
    file_stem = model_name.replace(" ", "_").lower()
    plt.savefig(f'figures/model_evaluation/confusion_matrix_{file_stem}.png',
                dpi=300, bbox_inches='tight')
    plt.show()

def plot_feature_importance(model, feature_names, model_name, figsize=(12, 10), top_n=20):
    """
    Plot the highest-ranked feature importances of a model, if it exposes any.

    Importances come from ``feature_importances_`` when present, otherwise from
    the absolute values of the first row of ``coef_``.

    Parameters:
    -----------
    model : classifier object
        The trained model
    feature_names : array-like
        Names of the features, aligned with the model's importance vector
    model_name : str
        Name of the model (used in messages, title, and output file name)
    figsize : tuple, default=(12, 10)
        Figure size
    top_n : int, default=20
        Number of top features to display

    Returns:
    --------
    pandas.DataFrame or None
        All features sorted by descending importance, or None when the model
        has no importances or when any error occurred (the error is printed).
    """
    try:
        has_tree_scores = hasattr(model, 'feature_importances_')
        if not has_tree_scores and not hasattr(model, 'coef_'):
            print(f"{model_name} does not have feature importances.")
            return None

        # Tree-based models expose importances directly; linear models use |coef|
        scores = model.feature_importances_ if has_tree_scores else np.abs(model.coef_[0])

        # Rank every feature, most important first
        ranked = pd.DataFrame(
            {'Feature': feature_names, 'Importance': scores}
        ).sort_values('Importance', ascending=False)

        # Bar chart of the top_n rows
        plt.figure(figsize=figsize)
        sns.barplot(x='Importance', y='Feature', data=ranked.head(top_n))
        plt.title(f'Top {top_n} Feature Importance - {model_name}', fontsize=15)
        plt.xlabel('Importance', fontsize=12)
        plt.tight_layout()
        file_stem = model_name.replace(" ", "_").lower()
        plt.savefig(f'figures/model_evaluation/feature_importance_{file_stem}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

        return ranked
    except Exception as e:
        print(f"Error plotting feature importance for {model_name}: {e}")
        return None

## 6. Train Base Models

In [None]:
# Candidate classifiers, all seeded with SEED where the estimator accepts a seed
models = {
    'Logistic Regression': LogisticRegression(max_iter=1000, class_weight='balanced', random_state=SEED),
    'Random Forest': RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=SEED),
    'XGBoost': xgb.XGBClassifier(scale_pos_weight=10, random_state=SEED),
    'LightGBM': lgb.LGBMClassifier(class_weight='balanced', random_state=SEED),
    'CatBoost': cb.CatBoostClassifier(random_state=SEED, verbose=0),
    'SVM': SVC(probability=True, class_weight='balanced', random_state=SEED),
    'KNN': KNeighborsClassifier(n_neighbors=5, weights='distance'),
    'Gradient Boosting': GradientBoostingClassifier(random_state=SEED),
    'Extra Trees': ExtraTreesClassifier(n_estimators=100, class_weight='balanced', random_state=SEED),
    'AdaBoost': AdaBoostClassifier(random_state=SEED)
}

# Fitted estimators and their validation metrics, keyed by model name
trained_models = {}
model_metrics = {}

# Fit every model on the SMOTE-balanced training data, score on the untouched validation set
print("Training base models with SMOTE...")
for name, model in models.items():
    print(f"\nTraining {name}...")

    # Time the fit only: the clock stops before evaluation so that 'train_time'
    # does not include prediction, metric computation, or printing.
    start_time = time.perf_counter()
    model.fit(X_train_smote, y_train_smote)
    train_time = time.perf_counter() - start_time

    # Evaluate on validation set
    metrics = evaluate_model(model, X_val_scaled, y_val, name)
    metrics['train_time'] = train_time
    print(f"Training time: {train_time:.2f} seconds")

    # Store trained model and metrics
    trained_models[name] = model
    model_metrics[name] = metrics

In [None]:
# Train a subset of the models on BorderlineSMOTE data for comparison
trained_models_bsmote = {}
model_metrics_bsmote = {}

# Factories that build a fresh, unfitted copy of each model re-trained here,
# with the same settings as in the SMOTE run. Insertion order is the training order.
bsmote_model_factories = {
    'Logistic Regression': lambda: LogisticRegression(max_iter=1000, class_weight='balanced', random_state=SEED),
    'Random Forest': lambda: RandomForestClassifier(n_estimators=100, class_weight='balanced', random_state=SEED),
    'XGBoost': lambda: xgb.XGBClassifier(scale_pos_weight=10, random_state=SEED),
    'LightGBM': lambda: lgb.LGBMClassifier(class_weight='balanced', random_state=SEED),
    'CatBoost': lambda: cb.CatBoostClassifier(random_state=SEED, verbose=0),
}

print("Training base models with BorderlineSMOTE...")
for name, build_model in bsmote_model_factories.items():
    print(f"\nTraining {name} with BorderlineSMOTE...")

    # Fresh instance so the SMOTE-trained estimator in `trained_models` is not refitted
    model = build_model()

    # Time the fit only; evaluation is deliberately kept outside the timed region
    start_time = time.perf_counter()
    model.fit(X_train_bsmote, y_train_bsmote)
    train_time = time.perf_counter() - start_time

    # Evaluate on validation set
    metrics = evaluate_model(model, X_val_scaled, y_val, f"{name} (BorderlineSMOTE)")
    metrics['train_time'] = train_time
    print(f"Training time: {train_time:.2f} seconds")

    # Store trained model and metrics
    trained_models_bsmote[name] = model
    model_metrics_bsmote[name] = metrics

## 7. Compare Base Model Performance

In [None]:
def _summary_row(model_name, sampling, metrics):
    """Flatten one evaluate_model() result into a row of the comparison table."""
    return {
        'Model': model_name,
        'Sampling': sampling,
        'Accuracy': metrics['accuracy'],
        'Precision': metrics['precision'],
        'Recall': metrics['recall'],
        'F1 Score': metrics['f1'],
        # Probability-based metrics are absent for models without predict_proba
        'ROC AUC': metrics.get('roc_auc', np.nan),
        'Avg Precision': metrics.get('average_precision', np.nan),
        'Training Time': metrics['train_time'],
    }


# Collect all rows first and build the frame once: growing a DataFrame with
# pd.concat inside a loop is quadratic and, starting from an empty frame, can
# leave the numeric columns with object dtype.
summary_rows = [
    _summary_row(model_name, 'SMOTE', model_result)
    for model_name, model_result in model_metrics.items()
]
summary_rows += [
    _summary_row(model_name, 'BorderlineSMOTE', model_result)
    for model_name, model_result in model_metrics_bsmote.items()
]

metrics_columns = [
    'Model', 'Sampling', 'Accuracy', 'Precision', 'Recall',
    'F1 Score', 'ROC AUC', 'Avg Precision', 'Training Time',
]

# Best F1 first; later cells rely on this order to pick the top models
metrics_df = (
    pd.DataFrame(summary_rows, columns=metrics_columns)
    .sort_values('F1 Score', ascending=False)
    .reset_index(drop=True)
)

# Display all metrics
print("Model Performance Comparison (Sorted by F1 Score):")
metrics_df

In [None]:
# 2x2 grid of bar charts, one metric per panel, bars split by sampling strategy
plt.figure(figsize=(14, 10))

for panel_index, metric_column in enumerate(['F1 Score', 'ROC AUC', 'Precision', 'Recall'], start=1):
    plt.subplot(2, 2, panel_index)
    sns.barplot(x=metric_column, y='Model', hue='Sampling', data=metrics_df)
    plt.title(f'{metric_column} Comparison', fontsize=14)
    # Every metric lives in [0, 1]; a shared axis keeps the panels comparable
    plt.xlim(0, 1)
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('figures/model_evaluation/model_performance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# metrics_df is sorted by F1 (descending), so the first row of each model is its
# better sampling variant; keep those rows for the five best-ranked model names.
best_rows = metrics_df.drop_duplicates(subset='Model', keep='first').head(5)

top_models = {}
for name, best_sampling in zip(best_rows['Model'], best_rows['Sampling']):
    # Fetch the fitted estimator from the run that produced the better score
    fitted_source = trained_models if best_sampling == 'SMOTE' else trained_models_bsmote
    top_models[name] = fitted_source[name]

# ROC curves, then precision-recall curves, on the validation set
plot_roc_curve(top_models, X_val_scaled, y_val)
plot_precision_recall_curve(top_models, X_val_scaled, y_val)

In [None]:
# Validation-set confusion matrices for the first three entries of top_models
for name in list(top_models)[:3]:
    predicted_labels = top_models[name].predict(X_val_scaled)
    plot_confusion_matrix(confusion_matrix(y_val, predicted_labels), name)

In [None]:
# Feature-importance tables for the first three entries of top_models;
# models for which the helper returns None are left out
feature_importances = {}

for name in list(top_models)[:3]:
    importance = plot_feature_importance(top_models[name], X_train.columns, name)
    if importance is None:
        continue
    feature_importances[name] = importance

## 8. Hyperparameter Tuning for Best Models

In [None]:
# The three best-ranked model names (metrics_df is sorted by F1, descending)
best_models = metrics_df['Model'].unique()[:3]
print(f"Best models for tuning: {best_models}")

# Search spaces for every model family that can be tuned
grid_catalogue = {
    'XGBoost': {
        'learning_rate': [0.01, 0.05, 0.1],
        'max_depth': [3, 5, 7],
        'min_child_weight': [1, 3, 5],
        'subsample': [0.7, 0.8, 0.9],
        'colsample_bytree': [0.7, 0.8, 0.9],
        'n_estimators': [100, 200],
        'scale_pos_weight': [1, 5, 10]
    },
    'Random Forest': {
        'n_estimators': [100, 200, 300],
        'max_depth': [None, 10, 20, 30],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4],
        'class_weight': ['balanced']
    },
    'LightGBM': {
        'learning_rate': [0.01, 0.05, 0.1],
        'num_leaves': [31, 63, 127],
        'max_depth': [-1, 5, 10],
        'min_child_samples': [5, 10, 20],
        'subsample': [0.7, 0.8, 0.9],
        'colsample_bytree': [0.7, 0.8, 0.9],
        'n_estimators': [100, 200],
        'class_weight': ['balanced']
    },
    'CatBoost': {
        'learning_rate': [0.01, 0.05, 0.1],
        'depth': [4, 6, 8],
        'l2_leaf_reg': [1, 3, 5],
        'iterations': [100, 200],
        'verbose': [0]
    },
    'Gradient Boosting': {
        'n_estimators': [100, 200],
        'learning_rate': [0.01, 0.05, 0.1],
        'max_depth': [3, 5, 7],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4],
        'subsample': [0.7, 0.8, 0.9]
    },
    'Logistic Regression': {
        'C': [0.01, 0.1, 1.0, 10.0],
        'solver': ['liblinear', 'saga'],
        'penalty': ['l1', 'l2'],
        'class_weight': ['balanced'],
        'max_iter': [1000]
    },
}

# Keep only the grids of the selected models; report the ones without a grid
param_grids = {}
for model_name in best_models:
    if model_name in grid_catalogue:
        param_grids[model_name] = grid_catalogue[model_name]
    else:
        print(f"No parameter grid defined for {model_name}, skipping.")

# Tuned estimators and their validation metrics, keyed by model name
tuned_models = {}
tuned_model_metrics = {}

# Factories for the untuned estimators; grid values are applied on top of these
base_model_factories = {
    'XGBoost': lambda: xgb.XGBClassifier(random_state=SEED),
    'Random Forest': lambda: RandomForestClassifier(random_state=SEED),
    'LightGBM': lambda: lgb.LGBMClassifier(random_state=SEED),
    'CatBoost': lambda: cb.CatBoostClassifier(random_state=SEED),
    'Gradient Boosting': lambda: GradientBoostingClassifier(random_state=SEED),
    'Logistic Regression': lambda: LogisticRegression(random_state=SEED),
}

# Perform grid search for each model that has a parameter grid
for model_name in best_models:
    if model_name not in param_grids:
        continue

    print(f"\nPerforming hyperparameter tuning for {model_name}...")
    base_model = base_model_factories[model_name]()

    # Oversample INSIDE each CV fold. Running the search on data that was already
    # SMOTE-resampled puts synthetic points interpolated from training rows into
    # the validation folds, which inflates the CV score. With the imblearn
    # pipeline, SMOTE is fitted on the training part of each fold only and the
    # validation part keeps its real class distribution.
    tuning_pipeline = ImbPipeline(steps=[
        ('smote', SMOTE(random_state=SEED, sampling_strategy='auto')),
        ('clf', base_model),
    ])
    pipeline_grid = {
        f'clf__{param}': values for param, values in param_grids[model_name].items()
    }

    # Set up grid search
    grid_search = GridSearchCV(
        estimator=tuning_pipeline,
        param_grid=pipeline_grid,
        cv=3,  # Use a smaller CV to speed up the search
        scoring='f1',
        n_jobs=-1,
        verbose=1
    )

    # Fit on the scaled but NOT resampled training data
    grid_search.fit(X_train_scaled, y_train)

    # The refitted pipeline resamples the full training split with the same SMOTE
    # settings as X_train_smote, so its classifier step is directly comparable
    # with the base models and usable on X_val_scaled.
    best_model = grid_search.best_estimator_.named_steps['clf']
    best_params = {
        param.split('__', 1)[1]: value for param, value in grid_search.best_params_.items()
    }
    print(f"Best parameters: {best_params}")
    print(f"Best CV score: {grid_search.best_score_:.4f}")

    # Evaluate best model on validation set
    metrics = evaluate_model(best_model, X_val_scaled, y_val, f"{model_name} (Tuned)")

    # Store tuned model and metrics
    tuned_models[model_name] = best_model
    tuned_model_metrics[model_name] = metrics

In [None]:
# Compare tuned models with their best base version.
# Rows are collected in a list and the frame is built once, instead of growing a
# DataFrame with pd.concat inside a loop.
compare_rows = []

# Base model rows: metrics_df is sorted by F1, so the first row of each model is
# its better sampling variant (SMOTE or BorderlineSMOTE)
for name in tuned_models:
    model_metrics_filtered = metrics_df[metrics_df['Model'] == name]
    base_row = model_metrics_filtered.iloc[0]
    compare_rows.append({
        'Model': name,
        'Version': f"Base ({base_row['Sampling']})",
        'Accuracy': base_row['Accuracy'],
        'Precision': base_row['Precision'],
        'Recall': base_row['Recall'],
        'F1 Score': base_row['F1 Score'],
        'ROC AUC': base_row['ROC AUC'],
        'Avg Precision': base_row['Avg Precision'],
    })

# Tuned model rows, taken from the evaluate_model() results
for name, tuned_result in tuned_model_metrics.items():
    compare_rows.append({
        'Model': name,
        'Version': 'Tuned',
        'Accuracy': tuned_result['accuracy'],
        'Precision': tuned_result['precision'],
        'Recall': tuned_result['recall'],
        'F1 Score': tuned_result['f1'],
        'ROC AUC': tuned_result.get('roc_auc', np.nan),
        'Avg Precision': tuned_result.get('average_precision', np.nan),
    })

compare_columns = ['Model', 'Version', 'Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC', 'Avg Precision']

# Sort by model and then version so each base/tuned pair is adjacent
compare_df = (
    pd.DataFrame(compare_rows, columns=compare_columns)
    .sort_values(['Model', 'Version'])
    .reset_index(drop=True)
)

# Display comparison
print("Base vs. Tuned Models Performance Comparison:")
compare_df

In [None]:
# Visualize base vs. tuned model performance: one panel per metric
plt.figure(figsize=(14, 10))

for panel_index, metric_column in enumerate(['F1 Score', 'ROC AUC', 'Precision', 'Recall'], start=1):
    plt.subplot(2, 2, panel_index)
    sns.barplot(x=metric_column, y='Model', hue='Version', data=compare_df)
    plt.title(f'{metric_column}: Base vs. Tuned Models', fontsize=14)
    # Full [0, 1] range, as in the base-model comparison figure: an axis starting
    # at 0.5 hides every bar whose value is below 0.5.
    plt.xlim(0, 1)
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('figures/model_evaluation/tuned_vs_base_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. Build Dense Stacking Ensemble Model

In [None]:
# Stacking ensemble built on the tuned models, plus two extra learners for diversity
print("Building Dense Stacking Ensemble Model...")

# Start from every tuned model as a (name, estimator) pair
base_models = list(tuned_models.items())

# Add AdaBoost and Extra Trees from the base run unless a tuned version already exists
for extra_name in ('AdaBoost', 'Extra Trees'):
    if extra_name in trained_models and extra_name not in tuned_models:
        base_models.append((extra_name, trained_models[extra_name]))

print(f"Base models for stacking: {[estimator_name for estimator_name, _ in base_models]}")

# Meta-learner (final estimator) that combines the base models' probabilities
meta_learner = LogisticRegression(max_iter=1000, class_weight='balanced', random_state=SEED)

stacking_ensemble = StackingClassifier(
    estimators=base_models,
    final_estimator=meta_learner,
    cv=5,
    stack_method='predict_proba',
    n_jobs=-1
)

# Fit on the SMOTE-balanced training data and time the fit
print("Training Stacking Ensemble...")
start_time = time.time()
stacking_ensemble.fit(X_train_smote, y_train_smote)
training_time = time.time() - start_time
print(f"Training time: {training_time:.2f} seconds")

# Score on the validation set and attach the fit time to the metrics
stacking_metrics = evaluate_model(stacking_ensemble, X_val_scaled, y_val, "Dense Stacking Ensemble")
stacking_metrics['train_time'] = training_time

# Alternative ensemble: soft voting over the same base models
print("\nBuilding Weighted Voting Ensemble...")

# Each model's vote is weighted by its validation F1: the tuned score when the
# model was tuned, otherwise the score from the SMOTE base run
voting_weights = [
    (tuned_model_metrics if name in tuned_model_metrics else model_metrics)[name]['f1']
    for name, _ in base_models
]

print(f"Voting weights: {voting_weights}")

voting_ensemble = VotingClassifier(
    estimators=base_models,
    voting='soft',
    weights=voting_weights,
    n_jobs=-1
)

# Fit on the SMOTE-balanced training data and time the fit
print("Training Voting Ensemble...")
start_time = time.time()
voting_ensemble.fit(X_train_smote, y_train_smote)
training_time = time.time() - start_time
print(f"Training time: {training_time:.2f} seconds")

# Score on the validation set and attach the fit time to the metrics
voting_metrics = evaluate_model(voting_ensemble, X_val_scaled, y_val, "Weighted Voting Ensemble")
voting_metrics['train_time'] = training_time

In [None]:
# Compare the ensembles with the best individual (tuned) model

def _comparison_row(label, metrics):
    """Flatten one evaluate_model() result into a row of the ensemble comparison."""
    return {
        'Model': label,
        'Accuracy': metrics['accuracy'],
        'Precision': metrics['precision'],
        'Recall': metrics['recall'],
        'F1 Score': metrics['f1'],
        'ROC AUC': metrics.get('roc_auc', np.nan),
        'Avg Precision': metrics.get('average_precision', np.nan),
    }


if not tuned_model_metrics:
    raise ValueError("No tuned models available: run the hyperparameter tuning cell first.")

# Best tuned model by validation F1. max() always returns a model, even when
# every F1 is 0.0; the previous "> 0" scan left the name as None in that case.
# Ties resolve to the first model in insertion order.
best_tuned_model = max(tuned_model_metrics, key=lambda name: tuned_model_metrics[name]['f1'])
best_metrics = tuned_model_metrics[best_tuned_model]
best_f1 = best_metrics['f1']

# Build the table in one go rather than growing it with repeated pd.concat calls
ensemble_columns = ['Model', 'Accuracy', 'Precision', 'Recall', 'F1 Score', 'ROC AUC', 'Avg Precision']
ensemble_comparison = (
    pd.DataFrame(
        [
            _comparison_row(f"Best Individual ({best_tuned_model})", best_metrics),
            _comparison_row("Dense Stacking Ensemble", stacking_metrics),
            _comparison_row("Weighted Voting Ensemble", voting_metrics),
        ],
        columns=ensemble_columns,
    )
    # Best F1 first; the final-evaluation cell picks row 0 as the winner
    .sort_values('F1 Score', ascending=False)
    .reset_index(drop=True)
)

# Display comparison
print("Ensemble vs. Best Individual Model Performance Comparison:")
ensemble_comparison

In [None]:
# Visualize ensemble comparison: one panel per metric
plt.figure(figsize=(14, 10))

metrics_to_plot = ['F1 Score', 'ROC AUC', 'Precision', 'Recall']
for i, metric in enumerate(metrics_to_plot, 1):
    plt.subplot(2, 2, i)
    sns.barplot(x=metric, y='Model', data=ensemble_comparison, palette='viridis')
    plt.title(f'{metric} Comparison', fontsize=14)
    # Full [0, 1] range: an axis starting at 0.5 hides every bar whose value is
    # below 0.5 and exaggerates differences between the rest.
    plt.xlim(0, 1)
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('figures/model_evaluation/ensemble_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Both ensembles and the best tuned single model, in plotting order
ensemble_models = dict([
    ('Dense Stacking Ensemble', stacking_ensemble),
    ('Weighted Voting Ensemble', voting_ensemble),
    (f'Best Individual ({best_tuned_model})', tuned_models[best_tuned_model]),
])

# ROC curves, then precision-recall curves, on the validation set
for draw_curves in (plot_roc_curve, plot_precision_recall_curve):
    draw_curves(ensemble_models, X_val_scaled, y_val)

## 10. Final Evaluation on Test Set

In [None]:
# Select the best model based on validation performance
# (row 0 of the comparison table, which is sorted by F1 score, best first)
best_model_name = ensemble_comparison.iloc[0]['Model']
print(f"Best model: {best_model_name}")

# Map the display name back to the corresponding fitted model object
if best_model_name == "Dense Stacking Ensemble":
    best_model = stacking_ensemble
elif best_model_name == "Weighted Voting Ensemble":
    best_model = voting_ensemble
else:
    # Label has the form "Best Individual (ModelName)". Take everything between
    # the FIRST '(' and the LAST ')' so that a model name which itself contains
    # parentheses is recovered intact instead of being cut at its inner ')'.
    name_start = best_model_name.index('(') + 1
    name_end = best_model_name.rindex(')')
    model_name = best_model_name[name_start:name_end]
    best_model = tuned_models[model_name]

# Evaluate the best model on the test set
print("\nEvaluating best model on test set...")
test_metrics = evaluate_model(best_model, X_test_scaled, y_test, f"Best Model ({best_model_name}) - Test Set")

# Plot confusion matrix for test set
y_test_pred = best_model.predict(X_test_scaled)
test_cm = confusion_matrix(y_test, y_test_pred)
plot_confusion_matrix(test_cm, f"Best Model ({best_model_name}) - Test Set")

In [None]:
# Per-class precision / recall / F1 for the test-set predictions
print("Detailed Classification Report (Test Set):")
print(classification_report(y_test, y_test_pred))

# Probability-based curves; any failure is reported instead of aborting the cell
try:
    y_test_proba = best_model.predict_proba(X_test_scaled)[:, 1]

    # ---- ROC curve ----
    fpr, tpr, thresholds = roc_curve(y_test, y_test_proba)
    test_auc = auc(fpr, tpr)

    roc_fig, roc_ax = plt.subplots(figsize=(10, 8))
    roc_ax.plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {test_auc:.4f})')
    roc_ax.plot([0, 1], [0, 1], 'k--', lw=2)  # chance diagonal
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('False Positive Rate', fontsize=12)
    roc_ax.set_ylabel('True Positive Rate', fontsize=12)
    roc_ax.set_title('ROC Curve - Test Set', fontsize=15)
    roc_ax.legend(loc="lower right")
    roc_ax.grid(True, alpha=0.3)
    roc_fig.savefig('figures/model_evaluation/test_set_roc_curve.png', dpi=300, bbox_inches='tight')
    plt.show()

    # ---- Precision-Recall curve ----
    precision, recall, _ = precision_recall_curve(y_test, y_test_proba)
    test_avg_precision = average_precision_score(y_test, y_test_proba)

    pr_fig, pr_ax = plt.subplots(figsize=(10, 8))
    # Baseline: share of positive samples in the test labels
    no_skill = len(y_test[y_test == 1]) / len(y_test)
    pr_ax.plot([0, 1], [no_skill, no_skill], 'k--', lw=2, label='No Skill')
    pr_ax.plot(recall, precision, lw=2, label=f'PR curve (AP = {test_avg_precision:.4f})')
    pr_ax.set_xlim([0.0, 1.0])
    pr_ax.set_ylim([0.0, 1.05])
    pr_ax.set_xlabel('Recall', fontsize=12)
    pr_ax.set_ylabel('Precision', fontsize=12)
    pr_ax.set_title('Precision-Recall Curve - Test Set', fontsize=15)
    pr_ax.legend()
    pr_ax.grid(True, alpha=0.3)
    pr_fig.savefig('figures/model_evaluation/test_set_pr_curve.png', dpi=300, bbox_inches='tight')
    plt.show()
except Exception as e:
    print(f"Error generating probability-based curves: {e}")

## 11. Save Final Model

In [None]:
# Persist the best model, the fitted scaler, and a JSON summary of the run
import json

print(f"Saving best model: {best_model_name}")

model_filename = 'models/stroke_prediction_model.pkl'
scaler_filename = 'models/stroke_prediction_scaler.pkl'

# Pickle the model first, then the scaler, announcing each path once written
for artifact_label, artifact_path, artifact in (
    ("Model", model_filename, best_model),
    ("Scaler", scaler_filename, scaler),
):
    with open(artifact_path, 'wb') as f:
        pickle.dump(artifact, f)
    print(f"{artifact_label} saved to {artifact_path}")

# Metadata: model identity, feature list, test-set scores, class counts, timestamp
required_scores = ('accuracy', 'precision', 'recall', 'f1')
model_metadata = {
    'model_name': best_model_name,
    'feature_names': list(X.columns),
    'num_features': len(X.columns),
    'test_metrics': {
        **{score_name: test_metrics[score_name] for score_name in required_scores},
        # Probability-based scores are optional and stored as null when missing
        'roc_auc': test_metrics.get('roc_auc', None),
        'avg_precision': test_metrics.get('average_precision', None)
    },
    'class_distribution': {
        'negative': int(np.sum(y == 0)),
        'positive': int(np.sum(y == 1))
    },
    'timestamp': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')
}

metadata_filename = 'models/stroke_prediction_metadata.json'
with open(metadata_filename, 'w') as f:
    json.dump(model_metadata, f, indent=4)
print(f"Model metadata saved to {metadata_filename}")

In [None]:
# Create a helper function for making predictions with the saved model
def load_model_and_predict(input_data, model_path='models/stroke_prediction_model.pkl', 
                           scaler_path='models/stroke_prediction_scaler.pkl'):
    """
    Load the saved model and scaler, then make predictions on new data.
    
    Parameters:
    -----------
    input_data : pd.DataFrame
        Input data with the same features as the training data
    model_path : str
        Path to the pickled model
    scaler_path : str
        Path to the pickled scaler
        
    Returns:
    --------
    tuple
        ``(pred_class, pred_proba)``. ``pred_proba`` is column 1 of the
        model's ``predict_proba`` output, or ``None`` when the model does not
        provide ``predict_proba``.

    Raises:
    -------
    Any error raised while reading the files, scaling, or predicting is
    propagated. Only the absence of a probability API is tolerated.
    """
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # artifacts that were produced by a trusted source.
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
    
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    
    # Scale the input data with the loaded scaler
    input_scaled = scaler.transform(input_data)
    
    # Class labels are always available
    pred_class = model.predict(input_scaled)
    
    # Probabilities are optional. Catch only "this estimator has no
    # predict_proba" so genuine failures (bad shapes, corrupt model, ...) are
    # not silently turned into a None probability.
    try:
        pred_proba = model.predict_proba(input_scaled)[:, 1]
    except (AttributeError, NotImplementedError):
        pred_proba = None
    
    return pred_class, pred_proba

# Smoke-test the helper on one randomly chosen row of the (unscaled) test split
sample_idx = np.random.randint(0, len(X_test))
sample_input, true_label = X_test.iloc[[sample_idx]], y_test.iloc[sample_idx]

print(f"Testing prediction function with sample {sample_idx}...")
print(f"True label: {true_label}")

# Round-trip through the artifacts saved on disk
pred_class, pred_proba = load_model_and_predict(sample_input)

print(f"Predicted class: {pred_class[0]}")
# The probability is None when the saved model has no predict_proba
if pred_proba is not None:
    print(f"Predicted probability: {pred_proba[0]:.4f}")

print("\nModel training and evaluation completed successfully!")

## Summary of Model Training Process

In this notebook, we've developed a comprehensive stroke prediction model using the following approach:

1. **Data Preparation**:
   - Loaded preprocessed dataset with encoded features
   - Split data into training, validation, and test sets
   - Standardized numerical features

2. **Class Imbalance Handling**:
   - Applied SMOTE and BorderlineSMOTE to balance the training set
   - Compared performance of both approaches

3. **Base Model Training**:
   - Trained multiple classification models, including Logistic Regression, Random Forest, XGBoost, and LightGBM
   - Evaluated models on validation set using multiple metrics

4. **Hyperparameter Tuning**:
   - Applied grid search to tune the best performing models
   - Compared tuned models with base models

5. **Ensemble Learning**:
   - Built a Dense Stacking Ensemble using the best tuned models
   - Created a Weighted Voting Ensemble as an alternative
   - Compared ensemble models with individual models

6. **Final Evaluation**:
   - Evaluated the best model on the test set
   - Generated detailed metrics and visualizations

7. **Model Saving**:
   - Saved the best model, scaler, and metadata for future use
   - Created a helper function for making predictions

In [None]:
#!/usr/bin/env python3
# Enhanced Stroke Prediction Model
# This script implements advanced techniques to achieve high performance for stroke prediction

# General libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import time
import pickle
from tqdm import tqdm

# ML libraries
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectFromModel, RFECV
from sklearn.base import clone

# Class imbalance handling
from imblearn.over_sampling import SMOTE, ADASYN, BorderlineSMOTE
from imblearn.under_sampling import TomekLinks, EditedNearestNeighbours
from imblearn.combine import SMOTETomek, SMOTEENN
from imblearn.pipeline import Pipeline as ImbPipeline
from imblearn.ensemble import BalancedRandomForestClassifier, RUSBoostClassifier, EasyEnsembleClassifier

# Models
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import (
    RandomForestClassifier, GradientBoostingClassifier, 
    AdaBoostClassifier, ExtraTreesClassifier,
    VotingClassifier, StackingClassifier, BaggingClassifier
)
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.calibration import CalibratedClassifierCV
from sklearn.preprocessing import PolynomialFeatures

# For XGBoost, LightGBM, and CatBoost
import xgboost as xgb
import lightgbm as lgb
import catboost as cb

# Metrics
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report,
    precision_recall_curve, auc, average_precision_score, roc_curve,
    balanced_accuracy_score, cohen_kappa_score, matthews_corrcoef
)

# Fixed seed so stochastic steps are repeatable
SEED = 42
np.random.seed(SEED)

# Pandas display: show every column and format floats with four decimals
for option_name, option_value in (
    ('display.max_columns', None),
    ('display.float_format', '{:.4f}'.format),
):
    pd.set_option(option_name, option_value)

# Silence library warnings to keep the output readable
warnings.filterwarnings('ignore')

# Plot appearance
sns.set_style('whitegrid')
plt.style.use('fivethirtyeight')

# Make sure every output directory exists before anything is written
for output_dir in ('models', 'results', 'figures/model_evaluation'):
    os.makedirs(output_dir, exist_ok=True)

# Function to evaluate model with multiple metrics
def evaluate_model_comprehensive(model, X_val, y_val, model_name, threshold=0.5, print_report=True):
    """
    Comprehensive evaluation of a fitted binary classifier.

    Parameters
    ----------
    model : estimator
        Fitted model exposing ``predict`` and, optionally, ``predict_proba``.
    X_val, y_val
        Evaluation features and true labels (labels are expected to be 0/1).
    model_name : str
        Label used in the printed report.
    threshold : float, default 0.5
        Cut-off applied to the positive-class probability. Ignored when the
        model has no ``predict_proba`` (``predict`` is used instead).
    print_report : bool, default True
        Whether to print the metrics, confusion matrix and classification report.

    Returns
    -------
    dict
        Keys: accuracy, balanced_accuracy, precision, recall, f1, cohen_kappa,
        matthews_corrcoef, confusion_matrix, specificity, npv, g_mean, and,
        only when probabilities are available, roc_auc and average_precision.
    """
    # Prefer thresholded probabilities. Only the absence of a probability API
    # triggers the fallback; any other failure is a real error and propagates.
    try:
        y_pred_proba = model.predict_proba(X_val)[:, 1]
    except (AttributeError, NotImplementedError):
        # Some models might not have predict_proba
        y_pred_proba = None

    if y_pred_proba is not None:
        y_pred = (y_pred_proba >= threshold).astype(int)
    else:
        y_pred = model.predict(X_val)
    
    # Calculate standard metrics
    metrics = {}
    metrics['accuracy'] = accuracy_score(y_val, y_pred)
    metrics['balanced_accuracy'] = balanced_accuracy_score(y_val, y_pred)
    metrics['precision'] = precision_score(y_val, y_pred)
    metrics['recall'] = recall_score(y_val, y_pred)
    metrics['f1'] = f1_score(y_val, y_pred)
    metrics['cohen_kappa'] = cohen_kappa_score(y_val, y_pred)
    metrics['matthews_corrcoef'] = matthews_corrcoef(y_val, y_pred)
    
    if y_pred_proba is not None:
        metrics['roc_auc'] = roc_auc_score(y_val, y_pred_proba)
        metrics['average_precision'] = average_precision_score(y_val, y_pred_proba)
    
    # Confusion matrix. Pinning labels keeps it 2x2 even when only one class
    # shows up in y_val/y_pred, so the four-way unpack below cannot fail.
    cm = confusion_matrix(y_val, y_pred, labels=[0, 1])
    metrics['confusion_matrix'] = cm
    
    # Additional metrics derived from confusion matrix (0 when undefined)
    tn, fp, fn, tp = cm.ravel()
    metrics['specificity'] = tn / (tn + fp) if (tn + fp) > 0 else 0
    metrics['npv'] = tn / (tn + fn) if (tn + fn) > 0 else 0  # Negative Predictive Value
    
    # G-mean: geometric mean of recall and specificity
    metrics['g_mean'] = np.sqrt(metrics['recall'] * metrics['specificity'])
    
    # Print results if requested
    if print_report:
        print(f"\n{model_name} Evaluation (threshold={threshold:.2f}):")
        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall/Sensitivity: {metrics['recall']:.4f}")
        print(f"Specificity: {metrics['specificity']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")
        print(f"Cohen's Kappa: {metrics['cohen_kappa']:.4f}")
        print(f"Matthews Correlation Coefficient: {metrics['matthews_corrcoef']:.4f}")
        print(f"G-mean: {metrics['g_mean']:.4f}")
        
        if 'roc_auc' in metrics:
            print(f"ROC AUC: {metrics['roc_auc']:.4f}")
            print(f"Average Precision: {metrics['average_precision']:.4f}")
        
        print("\nConfusion Matrix:")
        print(cm)
        print("\nClassification Report:")
        print(classification_report(y_val, y_pred))
    
    return metrics

# Function for finding optimal threshold
def find_optimal_threshold(model, X_val, y_val, metric='f1', thresholds=None):
    """
    Find the probability threshold that maximizes a given metric.

    Parameters
    ----------
    model : estimator
        Fitted classifier. If it has no ``predict_proba``, 0.5 is returned.
    X_val, y_val
        Validation features and true labels (labels are expected to be 0/1).
    metric : str, default 'f1'
        One of 'f1', 'recall', 'precision', 'g_mean', 'balanced_accuracy'.
        Any other value falls back to F1.
    thresholds : sequence of float, optional
        Candidate cut-offs; defaults to 99 evenly spaced values in [0.01, 0.99].

    Returns
    -------
    float
        The first threshold achieving the highest score. A plot of score vs.
        threshold is saved under ``figures/model_evaluation`` and shown.
    """
    if thresholds is None:
        thresholds = np.linspace(0.01, 0.99, 99)
        
    # Only the absence of a probability API triggers the 0.5 fallback;
    # other errors (e.g. malformed input) are real and must surface.
    try:
        y_pred_proba = model.predict_proba(X_val)[:, 1]
    except (AttributeError, NotImplementedError):
        print("Model doesn't support predict_proba. Using default threshold of 0.5.")
        return 0.5
    
    # Resolve the scoring function once instead of re-branching per threshold
    scorers = {
        'f1': f1_score,
        'recall': recall_score,
        'precision': precision_score,
        'g_mean': _g_mean_score,
        'balanced_accuracy': balanced_accuracy_score,
    }
    score_fn = scorers.get(metric, f1_score)  # Default to F1
    
    scores = [
        score_fn(y_val, (y_pred_proba >= threshold).astype(int))
        for threshold in thresholds
    ]
    
    best_idx = np.argmax(scores)
    best_threshold = thresholds[best_idx]
    best_score = scores[best_idx]
    
    # Plot threshold vs metric
    plt.figure(figsize=(10, 6))
    plt.plot(thresholds, scores, marker='o', markersize=3, linewidth=2)
    plt.axvline(x=best_threshold, color='r', linestyle='--', label=f'Best threshold: {best_threshold:.2f}')
    plt.axhline(y=best_score, color='g', linestyle='--', label=f'Best {metric}: {best_score:.4f}')
    plt.title(f'{metric.capitalize()} vs. Threshold', fontsize=14)
    plt.xlabel('Threshold', fontsize=12)
    plt.ylabel(metric.capitalize(), fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.savefig(f'figures/model_evaluation/threshold_tuning_{metric}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"Optimal threshold for {metric}: {best_threshold:.4f}, {metric}: {best_score:.4f}")
    return best_threshold


def _g_mean_score(y_true, y_pred):
    """Geometric mean of recall and specificity for 0/1 labels."""
    recall = recall_score(y_true, y_pred)
    # labels=[0, 1] keeps the matrix 2x2 even if a threshold yields one class only
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    return np.sqrt(recall * specificity)

# Enhanced plotting functions
def plot_roc_curve_enhanced(models, X_val, y_val, figsize=(12, 8)):
    """
    Overlay the ROC curve of every model in ``models`` on a single figure.

    ``models`` maps a display name to a fitted estimator. Each curve is
    labelled with its AUC. A model that fails (for example because it has no
    ``predict_proba``) is reported on stdout and left out of the plot. The
    figure is saved to ``figures/model_evaluation/enhanced_roc_curves.png``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Chance-level diagonal for reference
    ax.plot([0, 1], [0, 1], 'k--', lw=2, label='Random')

    for label, estimator in models.items():
        try:
            positive_scores = estimator.predict_proba(X_val)[:, 1]
            fpr, tpr, _ = roc_curve(y_val, positive_scores)
            area = auc(fpr, tpr)
            ax.plot(fpr, tpr, lw=2, label=f'{label} (AUC = {area:.4f})')
        except Exception as e:
            print(f"Could not plot ROC curve for {label}: {e}")

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontsize=12)
    ax.set_ylabel('True Positive Rate', fontsize=12)
    ax.set_title('Receiver Operating Characteristic (ROC) Curves', fontsize=15)
    ax.legend(loc="lower right")
    ax.grid(True, alpha=0.3)
    fig.savefig('figures/model_evaluation/enhanced_roc_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_precision_recall_curve_enhanced(models, X_val, y_val, figsize=(12, 8)):
    """
    Overlay the precision-recall curve of every model in ``models``.

    ``models`` maps a display name to a fitted estimator. Each curve is
    labelled with its average precision and the area under the PR curve. A
    dashed "No Skill" line marks the share of positive labels in ``y_val``.
    Models that fail are reported on stdout and skipped. The figure is saved
    to ``figures/model_evaluation/enhanced_pr_curves.png``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    # Baseline: fraction of samples labelled 1
    no_skill = len(y_val[y_val == 1]) / len(y_val)
    ax.plot([0, 1], [no_skill, no_skill], 'k--', lw=2, label=f'No Skill ({no_skill:.3f})')

    for label, estimator in models.items():
        try:
            positive_scores = estimator.predict_proba(X_val)[:, 1]
            precision, recall, _ = precision_recall_curve(y_val, positive_scores)
            pr_auc = auc(recall, precision)
            avg_precision = average_precision_score(y_val, positive_scores)
            curve_label = f'{label} (AP = {avg_precision:.4f}, AUC = {pr_auc:.4f})'
            ax.plot(recall, precision, lw=2, label=curve_label)
        except Exception as e:
            print(f"Could not plot PR curve for {label}: {e}")

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('Recall', fontsize=12)
    ax.set_ylabel('Precision', fontsize=12)
    ax.set_title('Precision-Recall Curves', fontsize=15)
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    fig.savefig('figures/model_evaluation/enhanced_pr_curves.png', dpi=300, bbox_inches='tight')
    plt.show()

def plot_confusion_matrix_enhanced(cm, model_name, figsize=(10, 8)):
    """
    Draw a 2x2 confusion matrix as a heatmap with summary metrics underneath.

    ``cm`` must unpack into (tn, fp, fn, tp) via ``ravel()``. Accuracy,
    precision, recall, specificity and F1 are derived from those counts
    (0 where the denominator is 0) and written below the axes. The figure is
    saved under ``figures/model_evaluation`` using a lower-cased,
    underscore-separated version of ``model_name``.
    """
    fig, ax = plt.subplots(figsize=figsize)

    tn, fp, fn, tp = cm.ravel()

    # Summary metrics computed straight from the four cell counts
    predicted_positive = tp + fp
    actual_positive = tp + fn
    actual_negative = tn + fp
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    precision = tp / predicted_positive if predicted_positive > 0 else 0
    recall = tp / actual_positive if actual_positive > 0 else 0
    specificity = tn / actual_negative if actual_negative > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    class_names = ['No Stroke', 'Stroke']
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names,
                ax=ax)

    # Two caption lines placed below the axes (axes-fraction coordinates)
    caption_lines = (
        (-0.15, f"Accuracy: {accuracy:.4f} | Precision: {precision:.4f}"),
        (-0.2, f"Recall: {recall:.4f} | Specificity: {specificity:.4f} | F1: {f1:.4f}"),
    )
    for y_offset, caption in caption_lines:
        ax.text(0.5, y_offset, caption, ha='center', fontsize=12, transform=ax.transAxes)

    ax.set_xlabel('Predicted Label', fontsize=12)
    ax.set_ylabel('True Label', fontsize=12)
    ax.set_title(f'Confusion Matrix - {model_name}', fontsize=15)
    fig.tight_layout()
    file_stem = model_name.replace(" ", "_").lower()
    fig.savefig(f'figures/model_evaluation/enhanced_cm_{file_stem}.png',
                dpi=300, bbox_inches='tight')
    plt.show()

def plot_feature_importance_enhanced(model, feature_names, model_name, figsize=(12, 10), top_n=20):
    """
    Plot the ``top_n`` most important features of ``model`` as annotated bars.

    Importances come from ``model.feature_importances_`` when present,
    otherwise from the absolute values of ``model.coef_[0]``.

    Returns the full importance table (columns 'Feature' and 'Importance',
    sorted descending), or ``None`` when the model exposes neither attribute
    or when any step raises (the error is printed).
    """
    try:
        if hasattr(model, 'feature_importances_'):
            raw_scores = model.feature_importances_
        elif hasattr(model, 'coef_'):
            raw_scores = np.abs(model.coef_[0])
        else:
            print(f"{model_name} does not have feature importances.")
            return None

        # Full ranking, most important first; only the head is plotted
        feature_importance = pd.DataFrame(
            {'Feature': feature_names, 'Importance': raw_scores}
        ).sort_values('Importance', ascending=False)
        top_features = feature_importance.head(top_n)

        fig, ax = plt.subplots(figsize=figsize)
        sns.barplot(x='Importance', y='Feature', data=top_features, palette='viridis', ax=ax)

        # Print each value just to the right of its bar
        for row_pos, score in enumerate(top_features['Importance']):
            ax.text(score + 0.001, row_pos, f"{score:.4f}", va='center', fontsize=10)

        ax.set_title(f'Top {top_n} Feature Importance - {model_name}', fontsize=15)
        ax.set_xlabel('Importance', fontsize=12)
        ax.set_ylabel('Feature', fontsize=12)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        file_stem = model_name.replace(" ", "_").lower()
        fig.savefig(f'figures/model_evaluation/enhanced_importance_{file_stem}.png',
                    dpi=300, bbox_inches='tight')
        plt.show()

        return feature_importance
    except Exception as e:
        print(f"Error plotting feature importance for {model_name}: {e}")
        return None

# Function to apply and evaluate different resampling methods
def apply_resampling_methods(X_train_scaled, y_train):
    """
    Resample the training data with six strategies and return all results.

    Returns a dict mapping a strategy key ('smote', 'bsmote', 'adasyn',
    'smotetomek', 'smoteenn', 'smote_11') to an ``(X_resampled, y_resampled)``
    tuple. A one-line size / positive-share summary is printed per strategy.
    """
    # (result key, progress message, sampler class, constructor arguments)
    resampling_plan = (
        ('smote', "Applying SMOTE...", SMOTE,
         dict(sampling_strategy=0.5, random_state=SEED, k_neighbors=5)),
        ('bsmote', "Applying BorderlineSMOTE...", BorderlineSMOTE,
         dict(sampling_strategy=0.5, random_state=SEED, k_neighbors=5)),
        ('adasyn', "Applying ADASYN...", ADASYN,
         dict(sampling_strategy=0.5, random_state=SEED, n_neighbors=5)),
        ('smotetomek', "Applying SMOTETomek...", SMOTETomek,
         dict(sampling_strategy=0.5, random_state=SEED)),
        ('smoteenn', "Applying SMOTEENN...", SMOTEENN,
         dict(sampling_strategy=0.5, random_state=SEED)),
        ('smote_11', "Applying aggressive SMOTE (1:1 ratio)...", SMOTE,
         dict(sampling_strategy=1.0, random_state=SEED, k_neighbors=3)),
    )

    resampled_data = {}
    for key, progress_message, sampler_cls, sampler_kwargs in resampling_plan:
        print(progress_message)
        # Build each sampler only when its turn comes, then resample
        sampler = sampler_cls(**sampler_kwargs)
        X_balanced, y_balanced = sampler.fit_resample(X_train_scaled, y_train)
        resampled_data[key] = (X_balanced, y_balanced)

    # Report resulting shape and class balance for every strategy
    for name, (X_resampled, y_resampled) in resampled_data.items():
        pos_count = np.sum(y_resampled == 1)
        total = len(y_resampled)
        print(f"{name}: {X_resampled.shape}, {pos_count}/{total} positive cases ({pos_count/total*100:.2f}%)")

    return resampled_data

# Function to train models on different resampled data
def train_models_on_resampled_data(models, resampled_data, X_val, y_val):
    """
    Fit a fresh clone of every model on every resampled training set.

    ``models`` maps names to unfitted estimators; ``resampled_data`` maps a
    resampling key to ``(X, y)``. Each fitted clone is scored on the
    validation data with ``evaluate_model_comprehensive``.

    Returns ``{resampling_key: {'models': {...}, 'metrics': {...}}}``. Each
    metrics dict also carries 'training_time' in seconds. A model that raises
    during fitting or evaluation is reported and omitted from both dicts.
    """
    separator = '=' * 50
    results = {}

    for resampling_method, (X_resampled, y_resampled) in resampled_data.items():
        print(f"\n{separator}")
        print(f"Training models on {resampling_method} resampled data")
        print(f"{separator}")

        fitted_models = {}
        model_scores = {}

        for model_name, model in models.items():
            print(f"\nTraining {model_name} on {resampling_method} data...")

            # Work on a copy so the template estimator is never fitted
            candidate = clone(model)

            started_at = time.time()
            try:
                candidate.fit(X_resampled, y_resampled)
                training_time = time.time() - started_at

                metrics = evaluate_model_comprehensive(
                    candidate, X_val, y_val,
                    f"{model_name} ({resampling_method})"
                )
                metrics['training_time'] = training_time

                fitted_models[model_name] = candidate
                model_scores[model_name] = metrics

                print(f"Training time: {training_time:.2f} seconds")
            except Exception as e:
                print(f"Error training {model_name} on {resampling_method} data: {e}")

        results[resampling_method] = {'models': fitted_models, 'metrics': model_scores}

    return results

# Create optimized models
def create_optimized_models():
    """
    Build the ten unfitted classifiers configured for imbalanced data.

    Returns a dict keyed by: 'LR_Optimized', 'RF_Optimized', 'XGB_Optimized',
    'LGBM_Optimized', 'CatBoost_Optimized', 'Balanced_RF', 'RUSBoost',
    'EasyEnsemble', 'GB_Calibrated', 'SVM_Calibrated' (in that order).
    """
    # Keyword groups shared by several estimators
    leaf_limits = dict(min_samples_split=5, min_samples_leaf=2, max_features='sqrt')
    undersampling = dict(sampling_strategy='auto', replacement=False, random_state=SEED)
    sigmoid_calibration = dict(method='sigmoid', cv=5)

    # Estimators that get wrapped in a calibrator below
    gb = GradientBoostingClassifier(
        learning_rate=0.01,
        n_estimators=1000,
        max_depth=3,
        subsample=0.8,
        random_state=SEED,
        **leaf_limits
    )
    svc = SVC(
        C=1.0,
        kernel='rbf',
        gamma='scale',
        class_weight='balanced',
        probability=True,
        random_state=SEED
    )

    return {
        # 1. L1-regularized logistic regression, minority class weighted 10x
        'LR_Optimized': LogisticRegression(
            C=0.1,
            penalty='l1',
            solver='liblinear',
            class_weight={0: 1, 1: 10},
            max_iter=5000,
            random_state=SEED
        ),
        # 2. Random forest with class weights rebalanced per bootstrap sample
        'RF_Optimized': RandomForestClassifier(
            n_estimators=500,
            max_depth=10,
            bootstrap=True,
            oob_score=True,
            class_weight='balanced_subsample',
            random_state=SEED,
            n_jobs=-1,
            **leaf_limits
        ),
        # 3. XGBoost with the positive class up-weighted
        'XGB_Optimized': xgb.XGBClassifier(
            learning_rate=0.01,
            n_estimators=1000,
            max_depth=4,
            min_child_weight=2,
            gamma=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            scale_pos_weight=20,
            objective='binary:logistic',
            tree_method='exact',
            random_state=SEED,
            verbosity=0
        ),
        # 4. LightGBM (GOSS boosting) with is_unbalance enabled
        'LGBM_Optimized': lgb.LGBMClassifier(
            boosting_type='goss',
            learning_rate=0.01,
            n_estimators=1000,
            num_leaves=31,
            max_depth=5,
            min_child_samples=20,
            subsample=0.8,
            colsample_bytree=0.8,
            is_unbalance=True,
            random_state=SEED,
            verbose=-1
        ),
        # 5. CatBoost with the positive class up-weighted
        'CatBoost_Optimized': cb.CatBoostClassifier(
            iterations=500,
            learning_rate=0.02,
            depth=6,
            l2_leaf_reg=3,
            bootstrap_type='Bernoulli',
            subsample=0.8,
            scale_pos_weight=15,
            random_state=SEED,
            verbose=0
        ),
        # 6. Balanced random forest from imbalanced-learn
        'Balanced_RF': BalancedRandomForestClassifier(
            n_estimators=500,
            max_depth=8,
            bootstrap=True,
            n_jobs=-1,
            **leaf_limits,
            **undersampling
        ),
        # 7. RUSBoost
        'RUSBoost': RUSBoostClassifier(
            n_estimators=500,
            learning_rate=0.1,
            algorithm='SAMME.R',
            **undersampling
        ),
        # 8. EasyEnsemble with its default base learner
        'EasyEnsemble': EasyEnsembleClassifier(
            n_estimators=10,
            n_jobs=-1,
            **undersampling
        ),
        # 9. Gradient boosting wrapped in sigmoid calibration
        'GB_Calibrated': CalibratedClassifierCV(estimator=gb, **sigmoid_calibration),
        # 10. Class-weighted RBF SVM wrapped in sigmoid calibration
        'SVM_Calibrated': CalibratedClassifierCV(estimator=svc, **sigmoid_calibration),
    }

# Create stack/ensemble models
def create_ensemble_models(base_models):
    """
    Create advanced (unfitted) ensemble models from base models.

    Parameters
    ----------
    base_models : dict
        Maps a name to an estimator. Every entry is used by the voting and
        flat stacking ensembles. The deep stack looks up 'LR_Optimized',
        'RF_Optimized', 'XGB_Optimized', 'LGBM_Optimized' and
        'CatBoost_Optimized', substituting a default estimator for any
        missing key.

    Returns
    -------
    dict
        Keys: 'Voting_Hard', 'Voting_Soft', 'Stacking_LR', 'Stacking_RF',
        'Deep_Stacking', 'Bagging_Adjusted'.
    """
    ensemble_models = {}
    
    # 1. Simple Voting Ensemble (hard voting)
    ensemble_models['Voting_Hard'] = VotingClassifier(
        estimators=list(base_models.items()),
        voting='hard',
        n_jobs=-1
    )
    
    # 2. Soft Voting Ensemble with weights
    # Use inverse of log loss or F1 as weights
    # For simplicity, using equal weights here
    ensemble_models['Voting_Soft'] = VotingClassifier(
        estimators=list(base_models.items()),
        voting='soft',
        weights=[1] * len(base_models),
        n_jobs=-1
    )
    
    # 3. Stacking Ensemble with LogisticRegression meta-learner
    ensemble_models['Stacking_LR'] = StackingClassifier(
        estimators=list(base_models.items()),
        final_estimator=LogisticRegression(class_weight='balanced'),
        cv=5,
        stack_method='predict_proba',
        n_jobs=-1
    )
    
    # 4. Stacking Ensemble with Random Forest meta-learner
    ensemble_models['Stacking_RF'] = StackingClassifier(
        estimators=list(base_models.items()),
        final_estimator=RandomForestClassifier(
            n_estimators=100, 
            class_weight='balanced',
            random_state=SEED
        ),
        cv=5,
        stack_method='predict_proba',
        n_jobs=-1
    )
    
    # 5. Multi-level stacking (creating a more complex ensemble)
    # First level stacking
    level1_stack = StackingClassifier(
        estimators=[
            ('lr', base_models.get('LR_Optimized', LogisticRegression(class_weight='balanced'))),
            ('rf', base_models.get('RF_Optimized', RandomForestClassifier(class_weight='balanced'))),
            ('xgb', base_models.get('XGB_Optimized', xgb.XGBClassifier())),
        ],
        final_estimator=LogisticRegression(),
        cv=5,
        stack_method='predict_proba'
    )
    
    # Second level stacking: the first-level stack is itself a base estimator
    ensemble_models['Deep_Stacking'] = StackingClassifier(
        estimators=[
            ('level1', level1_stack),
            ('lgbm', base_models.get('LGBM_Optimized', lgb.LGBMClassifier())),
            ('cb', base_models.get('CatBoost_Optimized', cb.CatBoostClassifier(verbose=0))),
        ],
        final_estimator=LogisticRegression(class_weight='balanced'),
        cv=5,
        stack_method='predict_proba'
    )
    
    # 6. Bagging with sampling adjustment
    # `estimator` is the scikit-learn >= 1.2 keyword (already used for
    # CalibratedClassifierCV in this file); the old `base_estimator` spelling
    # was removed in 1.4 and raises TypeError there.
    ensemble_models['Bagging_Adjusted'] = BaggingClassifier(
        estimator=LogisticRegression(class_weight='balanced'),
        n_estimators=100,
        max_samples=0.8,  # Use 80% of samples
        max_features=0.8,  # Use 80% of features
        bootstrap=True,
        bootstrap_features=True,
        random_state=SEED,
        n_jobs=-1
    )
    
    return ensemble_models

# Function to extract best models based on a specific metric
def extract_best_models(training_results, metric='f1', top_n=3):
    """
    Extract the top N models based on a specific metric.

    Parameters
    ----------
    training_results : dict
        Maps a resampling-method name to a dict with two entries:
        ``'metrics'`` (model name -> dict of metric values) and
        ``'models'`` (model name -> model object).
    metric : str, default 'f1'
        Key looked up in each model's metrics dict and used for ranking.
    top_n : int, default 3
        Maximum number of records to return.

    Returns
    -------
    list of dict
        Records sorted by ``metric`` in descending order. Each record holds
        ``'resampling_method'``, ``'model_name'``, ``'model'``, the score
        stored under the ``metric`` key, and the full ``'metrics'`` dict.
        Models whose metrics dict lacks ``metric`` (or stores NaN for it)
        cannot be ranked and are left out, so fewer than ``top_n`` records
        may be returned.
    """
    ranked = []

    for resampling_method, result in training_results.items():
        for model_name, metrics in result['metrics'].items():
            score = metrics.get(metric)

            # Some metrics are optional for a given model; a missing score
            # would raise KeyError and a NaN score (NaN != NaN) would make
            # the sort order undefined, so such models are skipped.
            if score is None or score != score:
                continue

            ranked.append({
                'resampling_method': resampling_method,
                'model_name': model_name,
                'model': result['models'][model_name],
                metric: score,
                'metrics': metrics
            })

    # Highest score first
    ranked.sort(key=lambda entry: entry[metric], reverse=True)

    # Return top N models
    return ranked[:top_n]

# Create a prediction function for new data
def predict_stroke_risk(input_data, model_path='models/final_stroke_prediction_model.pkl', 
                        scaler_path='models/final_stroke_prediction_scaler.pkl',
                        features_path='models/selected_features.pkl',
                        threshold_path='models/optimal_threshold.pkl'):
    """
    Make stroke predictions on new data using the trained model.

    Parameters
    ----------
    input_data : pandas.DataFrame
        Unscaled feature rows. When a selected-features file exists, the
        frame must contain every column listed in it.
    model_path : str
        Pickle file holding the classifier (must expose ``predict_proba``).
    scaler_path : str
        Pickle file holding the scaler (must expose ``transform``).
    features_path : str
        Pickle file holding the list of selected feature names. The file is
        optional: if it does not exist, all columns of ``input_data`` are used.
    threshold_path : str
        Pickle file holding the decision threshold.

    Returns
    -------
    tuple
        ``(predictions, probabilities)``: the 0/1 labels obtained by
        comparing the positive-class probability against the threshold, and
        the positive-class probabilities themselves.

    Raises
    ------
    FileNotFoundError
        If the model, scaler, or threshold file is missing.
    ValueError
        If ``input_data`` lacks a column listed in the selected features.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code while
    # deserializing. Only point these paths at artifacts you produced or trust.
    # Load the model, scaler, and optimal threshold
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
    
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    
    with open(threshold_path, 'rb') as f:
        threshold = pickle.load(f)
    
    # Check if feature selection was used. Only a missing file means
    # "no feature selection"; any other failure (corrupt pickle, permission
    # error) must surface instead of being silently ignored.
    selected_features = None
    try:
        with open(features_path, 'rb') as f:
            selected_features = pickle.load(f)
    except FileNotFoundError:
        print("No feature selection was used or file not found.")
    
    if selected_features is not None:
        # Fail loudly on missing columns; continuing would feed the scaler
        # a frame whose columns do not match what it was fitted on.
        missing = [col for col in selected_features if col not in input_data.columns]
        if missing:
            raise ValueError(
                f"input_data is missing required feature columns: {missing}"
            )
        
        # Select the features from input data
        input_data = input_data[selected_features]
    
    # Scale the input data
    input_data_scaled = scaler.transform(input_data)
    
    # Get prediction probabilities
    probabilities = model.predict_proba(input_data_scaled)[:, 1]
    
    # Apply the optimal threshold
    predictions = (probabilities >= threshold).astype(int)
    
    return predictions, probabilities

def main():
    # Load the preprocessed and encoded dataset
    print("Loading dataset...")
    df = pd.read_csv('data/processed/stroke_dataset_encoded.csv')

    # Display basic information
    print(f"Dataset shape: {df.shape}")
    print(f"Number of features: {df.shape[1] - 1}")

    # Target variable distribution
    print("\nTarget distribution:")
    print(df['stroke'].value_counts())
    print(df['stroke'].value_counts(normalize=True) * 100)

    # Check for any missing values
    missing_values = df.isnull().sum()
    if missing_values.sum() > 0:
        print("\nMissing values:")
        print(missing_values[missing_values > 0])
    else:
        print("\nNo missing values found except for comorbidity.")

    # Check the comorbidity column if it exists
    if 'comorbidity' in df.columns:
        print(f"\nComorbidity column unique values: {df['comorbidity'].unique()}")
        print(f"Comorbidity column missing values: {df['comorbidity'].isna().sum()}")

    # Get a list of all numerical features
    numerical_features = df.select_dtypes(include=['float64', 'int64']).columns.tolist()
    if 'stroke' in numerical_features:
        numerical_features.remove('stroke')
    if 'comorbidity' in numerical_features:
        numerical_features.remove('comorbidity')

    print(f"\nNumerical features: {numerical_features}")

    # Get a list of all categorical features
    categorical_features = df.select_dtypes(include=['object', 'bool']).columns.tolist()
    print(f"Categorical features: {categorical_features}")

    # Additional binary features (encoded as 0/1 but actually categorical)
    binary_features = [col for col in numerical_features if df[col].nunique() == 2]
    print(f"Binary features: {binary_features}")

    # Perform advanced feature engineering
    print("\nPerforming advanced feature engineering...")

    # Create a copy of the dataframe
    df_engineered = df.copy()

    # Drop problematic columns if they exist
    if 'comorbidity' in df_engineered.columns:
        df_engineered = df_engineered.drop('comorbidity', axis=1)
    if 'id' in df_engineered.columns:
        df_engineered = df_engineered.drop('id', axis=1)

    # Create advanced features

    # 1. Age-related transformations
    df_engineered['age_squared'] = df_engineered['age'] ** 2
    df_engineered['age_cube'] = df_engineered['age'] ** 3
    df_engineered['log_age'] = np.log1p(df_engineered['age'])

    # 2. BMI-related features
    if 'bmi' in df_engineered.columns:
        df_engineered['bmi_squared'] = df_engineered['bmi'] ** 2
        df_engineered['age_bmi_ratio'] = df_engineered['age'] / df_engineered['bmi']
        df_engineered['bmi_age_product'] = df_engineered['age'] * df_engineered['bmi'] / 100
        
        # BMI risk categories (more detailed)
        df_engineered['bmi_risk'] = pd.cut(
            df_engineered['bmi'], 
            bins=[0, 18.5, 25, 30, 35, 100], 
            labels=[1, 0, 2, 3, 4]
        ).astype(float)
        
        # Age-adjusted BMI
        df_engineered['age_adj_bmi'] = df_engineered['bmi'] * (1 + (df_engineered['age'] / 100))

    # 3. Glucose-related features
    if 'avg_glucose_level' in df_engineered.columns:
        df_engineered['glucose_squared'] = df_engineered['avg_glucose_level'] ** 2
        df_engineered['log_glucose'] = np.log1p(df_engineered['avg_glucose_level'])
        df_engineered['glucose_bmi_product'] = df_engineered['avg_glucose_level'] * df_engineered['bmi'] / 100
        df_engineered['glucose_age_ratio'] = df_engineered['avg_glucose_level'] / df_engineered['age']
        
        # Glucose risk categories
        df_engineered['glucose_risk'] = pd.cut(
            df_engineered['avg_glucose_level'],
            bins=[0, 70, 100, 126, 200, 500],
            labels=[2, 0, 1, 3, 4]
        ).astype(float)

    # 4. Hypertension and heart disease interactions
    if 'hypertension' in df_engineered.columns and 'heart_disease' in df_engineered.columns:
        df_engineered['health_risk_combined'] = df_engineered['hypertension'] + df_engineered['heart_disease']
        df_engineered['hypertension_heart_disease'] = df_engineered['hypertension'] * df_engineered['heart_disease']
        
        # Age-adjusted health risks
        df_engineered['age_hypertension_risk'] = df_engineered['age'] * df_engineered['hypertension'] / 10
        df_engineered['age_heart_disease_risk'] = df_engineered['age'] * df_engineered['heart_disease'] / 10
        
        # Combined age and health risks
        df_engineered['combined_health_age_risk'] = df_engineered['age'] * (1 + df_engineered['hypertension'] + df_engineered['heart_disease'])

    # 5. Complex feature combinations
    # Risk score based on all factors
    df_engineered['comprehensive_risk_score'] = (
        df_engineered['age'] / 10 +
        df_engineered['bmi'] / 5 +
        df_engineered['avg_glucose_level'] / 20
    )

    if 'hypertension' in df_engineered.columns and 'heart_disease' in df_engineered.columns:
        df_engineered['comprehensive_risk_score'] += (
            df_engineered['hypertension'] * 5 +
            df_engineered['heart_disease'] * 5
        )

    if 'ever_married_Yes' in df_engineered.columns:
        df_engineered['comprehensive_risk_score'] += df_engineered['ever_married_Yes'] * 1

    # 6. Create interaction features between numerical and binary features
    for num_feat in ['age', 'bmi', 'avg_glucose_level']:
        if num_feat in df_engineered.columns:
            for bin_feat in binary_features:
                if bin_feat in df_engineered.columns and bin_feat != num_feat:
                    df_engineered[f'{num_feat}_{bin_feat}_interaction'] = df_engineered[num_feat] * df_engineered[bin_feat]

    # 7. Create polynomial features for key numerical variables
    # Select key numerical features for polynomial expansion
    poly_features = ['age', 'bmi', 'avg_glucose_level']
    poly_features = [f for f in poly_features if f in df_engineered.columns]

    if poly_features:
        # Create polynomial features (degree 2)
        poly = PolynomialFeatures(degree=2, include_bias=False, interaction_only=True)
        poly_features_array = poly.fit_transform(df_engineered[poly_features])
        
        # Get the feature names
        poly_feature_names = poly.get_feature_names_out(poly_features)
        
        # Add polynomial features to the dataframe
        # Skip the first few that are just the original features
        for i, name in enumerate(poly_feature_names[len(poly_features):], len(poly_features)):
            df_engineered[f'poly_{name}'] = poly_features_array[:, i]

    # Display the engineered dataframe info
    print(f"\nOriginal dataframe shape: {df.shape}")
    print(f"Engineered dataframe shape: {df_engineered.shape}")
    print(f"Number of new features added: {df_engineered.shape[1] - df.shape[1]}")

    # List the new features
    new_features = set(df_engineered.columns) - set(df.columns)
    print(f"\nNew features added: {sorted(new_features)}")

    # Separate features and target
    X = df_engineered.drop('stroke', axis=1)
    y = df_engineered['stroke']

    # Split the data into training, validation, and test sets
    X_train_full, X_test, y_train_full, y_test = train_test_split(
        X, y, test_size=0.2, random_state=SEED, stratify=y
    )

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full, y_train_full, test_size=0.25, random_state=SEED, stratify=y_train_full
    )

    print(f"Training set: {X_train.shape}, {y_train.value_counts()[1]}/{len(y_train)} positive cases ({y_train.mean()*100:.2f}%)")
    print(f"Validation set: {X_val.shape}, {y_val.value_counts()[1]}/{len(y_val)} positive cases ({y_val.mean()*100:.2f}%)")
    print(f"Test set: {X_test.shape}, {y_test.value_counts()[1]}/{len(y_test)} positive cases ({y_test.mean()*100:.2f}%)")

    # Feature scaling - try different scalers
    scalers = {
        'standard': StandardScaler(),
        'minmax': MinMaxScaler(),
        'robust': RobustScaler()
    }

    scaled_data = {}
    for scaler_name, scaler in scalers.items():
        # Fit on training data
        X_train_scaled = scaler.fit_transform(X_train)
        X_val_scaled = scaler.transform(X_val)
        X_test_scaled = scaler.transform(X_test)
        
        scaled_data[scaler_name] = {
            'train': X_train_scaled,
            'val': X_val_scaled,
            'test': X_test_scaled,
            'scaler': scaler
        }

    print("Data scaled with multiple methods: Standard, MinMax, and Robust scaling")

    # Convert back to DataFrames with column names for easier handling
    for scaler_name in scaled_data:
        scaled_data[scaler_name]['train_df'] = pd.DataFrame(
            scaled_data[scaler_name]['train'], 
            columns=X_train.columns
        )
        scaled_data[scaler_name]['val_df'] = pd.DataFrame(
            scaled_data[scaler_name]['val'], 
            columns=X_val.columns
        )
        scaled_data[scaler_name]['test_df'] = pd.DataFrame(
            scaled_data[scaler_name]['test'], 
            columns=X_test.columns
        )

    # Apply resampling methods
    print("Applying advanced resampling techniques...")
    resampled_data = apply_resampling_methods(scaled_data['standard']['train'], y_train)

    # Visualize class distributions after resampling
    plt.figure(figsize=(15, 8))
    plt.subplot(2, 4, 1)
    plt.title("Original")
    plt.pie([len(y_train) - sum(y_train), sum(y_train)], labels=["0", "1"], autopct='%1.1f%%', colors=['#82CAFA', '#FE6B64'])

    positions = [(2, 4, 2), (2, 4, 3), (2, 4, 4), (2, 4, 5), (2, 4, 6), (2, 4, 7)]
    for (i, j, k), (method, (_, y_resampled)) in zip(positions, list(resampled_data.items())[:6]):
        plt.subplot(i, j, k)
        plt.title(method)
        plt.pie([len(y_resampled) - sum(y_resampled), sum(y_resampled)], 
                labels=["0", "1"], autopct='%1.1f%%', colors=['#82CAFA', '#FE6B64'])

    plt.tight_layout()
    plt.savefig('figures/model_evaluation/resampling_distribution.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Create optimized models
    print("Creating optimized models...")
    optimized_models = create_optimized_models()

    # Train models on resampled data
    print("Training models on different resampled datasets...")
    # Use a subset of models to speed up the process
    models_to_train = {
        'LR_Optimized': optimized_models['LR_Optimized'],
        'RF_Optimized': optimized_models['RF_Optimized'],
        'XGB_Optimized': optimized_models['XGB_Optimized'],
        'LGBM_Optimized': optimized_models['LGBM_Optimized'],
        'Balanced_RF': optimized_models['Balanced_RF']
    }

    # Use a subset of resampling methods
    resampling_methods_to_use = {
        'smote': resampled_data['smote'],
        'smotetomek': resampled_data['smotetomek'],
        'smoteenn': resampled_data['smoteenn']
    }

    # Train models
    training_results = train_models_on_resampled_data(
        models_to_train, 
        resampling_methods_to_use,
        scaled_data['standard']['val'], 
        y_val
    )

    # Extract best models based on F1 score
    best_models_f1 = extract_best_models(training_results, metric='f1', top_n=5)

    print("\nTop 5 models based on F1 score:")
    for i, model_info in enumerate(best_models_f1, 1):
        print(f"{i}. {model_info['model_name']} with {model_info['resampling_method']}: F1 = {model_info['f1']:.4f}")

    # Extract best models based on balanced accuracy
    best_models_bal_acc = extract_best_models(training_results, metric='balanced_accuracy', top_n=5)

    print("\nTop 5 models based on balanced accuracy:")
    for i, model_info in enumerate(best_models_bal_acc, 1):
        print(f"{i}. {model_info['model_name']} with {model_info['resampling_method']}: Balanced Accuracy = {model_info['balanced_accuracy']:.4f}")

    # Extract best models based on ROC AUC
    best_models_roc_auc = extract_best_models(training_results, metric='roc_auc', top_n=5)

    print("\nTop 5 models based on ROC AUC:")
    for i, model_info in enumerate(best_models_roc_auc, 1):
        print(f"{i}. {model_info['model_name']} with {model_info['resampling_method']}: ROC AUC = {model_info['roc_auc']:.4f}")

    # Compile top performing models for ensemble
    top_performing_models = {}
    for model_info in best_models_f1[:3]:  # Take top 3 from F1
        model_key = f"{model_info['model_name']}_{model_info['resampling_method']}"
        top_performing_models[model_key] = model_info['model']

    for model_info in best_models_roc_auc[:2]:  # Take top 2 from ROC AUC
        if model_info not in best_models_f1[:3]:  # Avoid duplicates
            model_key = f"{model_info['model_name']}_{model_info['resampling_method']}"
            top_performing_models[model_key] = model_info['model']

    print(f"\nCollected {len(top_performing_models)} unique top-performing models for ensemble creation.")

    # Create ensembles from top performing models
    print("Creating ensemble models from top performers...")
    ensemble_models = create_ensemble_models(top_performing_models)

    # Train and evaluate ensemble models
    ensemble_results = {}
    for ensemble_name, ensemble_model in ensemble_models.items():
        print(f"\nTraining {ensemble_name}...")
        
        # Train the ensemble
        start_time = time.time()
        try:
            # Use the original training data for ensemble model training
            ensemble_model.fit(scaled_data['standard']['train'], y_train)
            training_time = time.time() - start_time
            
            # Evaluate on validation set
            metrics = evaluate_model_comprehensive(
                ensemble_model, scaled_data['standard']['val'], y_val, 
                ensemble_name
            )
            metrics['training_time'] = training_time
            
            # Store ensemble and metrics
            ensemble_results[ensemble_name] = {
                'model': ensemble_model,
                'metrics': metrics
            }
            
            print(f"Training time: {training_time:.2f} seconds")
        except Exception as e:
            print(f"Error training {ensemble_name}: {e}")

    # Find the best ensemble model
    best_ensemble = None
    best_ensemble_f1 = 0

    for ensemble_name, result in ensemble_results.items():
        f1 = result['metrics']['f1']
        if f1 > best_ensemble_f1:
            best_ensemble_f1 = f1
            best_ensemble = (ensemble_name, result['model'])

    print(f"\nBest ensemble: {best_ensemble[0]} with F1 = {best_ensemble_f1:.4f}")

    # Compare top individual models with ensemble models
    compare_models = {}

    # Add top individual models
    for model_info in best_models_f1[:3]:
        model_key = f"{model_info['model_name']} ({model_info['resampling_method']})"
        compare_models[model_key] = model_info['model']

    # Add top ensemble models
    for ensemble_name, result in ensemble_results.items():
        if result['metrics']['f1'] >= 0.8 * best_ensemble_f1:  # Include ensembles with at least 80% of the best F1
            compare_models[ensemble_name] = result['model']

    # Plot ROC and PR curves for comparison
    plot_roc_curve_enhanced(compare_models, scaled_data['standard']['val'], y_val)
    plot_precision_recall_curve_enhanced(compare_models, scaled_data['standard']['val'], y_val)

    # Create a performance comparison dataframe
    performance_df = pd.DataFrame(columns=['Model', 'Type', 'Accuracy', 'Balanced_Accuracy', 'Precision', 
                                         'Recall', 'F1', 'ROC_AUC', 'G_mean'])

    # Add individual models
    for model_info in best_models_f1:
        model_key = f"{model_info['model_name']} ({model_info['resampling_method']})"
        metrics = model_info['metrics']
        performance_df = pd.concat([performance_df, pd.DataFrame([{
            'Model': model_key,
            'Type': 'Individual',
            'Accuracy': metrics['accuracy'],
            'Balanced_Accuracy': metrics['balanced_accuracy'],
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1': metrics['f1'],
            'ROC_AUC': metrics.get('roc_auc', np.nan),
            'G_mean': metrics['g_mean']
        }])], ignore_index=True)

    # Add ensemble models
    for ensemble_name, result in ensemble_results.items():
        metrics = result['metrics']
        performance_df = pd.concat([performance_df, pd.DataFrame([{
            'Model': ensemble_name,
            'Type': 'Ensemble',
            'Accuracy': metrics['accuracy'],
            'Balanced_Accuracy': metrics['balanced_accuracy'],
            'Precision': metrics['precision'],
            'Recall': metrics['recall'],
            'F1': metrics['f1'],
            'ROC_AUC': metrics.get('roc_auc', np.nan),
            'G_mean': metrics['g_mean']
        }])], ignore_index=True)

    # Sort by F1 score
    performance_df = performance_df.sort_values('F1', ascending=False).reset_index(drop=True)

    # Save the performance comparison
    performance_df.to_csv('results/model_comparison.csv', index=False)

    # Display the performance comparison
    print("\nModel Performance Comparison:")
    print(performance_df)

    # Visualize the performance comparison
    plt.figure(figsize=(14, 10))

    metrics_to_plot = ['F1', 'ROC_AUC', 'Balanced_Accuracy', 'G_mean']
    for i, metric in enumerate(metrics_to_plot, 1):
        plt.subplot(2, 2, i)
        ax = sns.barplot(x=metric, y='Model', hue='Type', data=performance_df.head(10))
        plt.title(f'Top 10 Models by {metric}', fontsize=14)
        plt.xlim(0.5, 1.0)
        plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('figures/model_evaluation/performance_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Optimize thresholds for best models
    print("Optimizing decision thresholds for best models...")

    # Select models for threshold optimization
    threshold_models = {}

    # Add top individual model
    top_individual = best_models_f1[0]
    threshold_models[f"{top_individual['model_name']} ({top_individual['resampling_method']})"] = top_individual['model']

    # Add best ensemble
    threshold_models[best_ensemble[0]] = best_ensemble[1]

    # Add another top model
    second_best = best_models_f1[1]
    threshold_models[f"{second_best['model_name']} ({second_best['resampling_method']})"] = second_best['model']

    # Metrics to optimize for
    metrics_to_optimize = ['f1', 'balanced_accuracy', 'g_mean']

    # Find optimal thresholds
    optimized_thresholds = {}

    for model_name, model in threshold_models.items():
        optimized_thresholds[model_name] = {}
        
        for metric in metrics_to_optimize:
            print(f"\nFinding optimal threshold for {model_name} based on {metric}...")
            
            threshold = find_optimal_threshold(
                model, scaled_data['standard']['val'], y_val, 
                metric=metric
            )
            
            optimized_thresholds[model_name][metric] = threshold
            
            # Evaluate with the optimized threshold
            metrics_dict = evaluate_model_comprehensive(
                model, scaled_data['standard']['val'], y_val,
                f"{model_name} (optimized for {metric}, threshold={threshold:.2f})",
                threshold=threshold
            )
            
            # Store the metrics
            optimized_thresholds[model_name][f"{metric}_metrics"] = metrics_dict

    # Create a threshold comparison dataframe
    threshold_df = pd.DataFrame(columns=['Model', 'Optimization_Metric', 'Threshold', 
                                       'Accuracy', 'Balanced_Accuracy', 'Precision', 
                                       'Recall', 'F1', 'ROC_AUC', 'G_mean'])

    for model_name, thresholds in optimized_thresholds.items():
        for metric in metrics_to_optimize:
            threshold = thresholds[metric]
            metrics_dict = thresholds[f"{metric}_metrics"]
            
            threshold_df = pd.concat([threshold_df, pd.DataFrame([{
                'Model': model_name,
                'Optimization_Metric': metric,
                'Threshold': threshold,
                'Accuracy': metrics_dict['accuracy'],
                'Balanced_Accuracy': metrics_dict['balanced_accuracy'],
                'Precision': metrics_dict['precision'],
                'Recall': metrics_dict['recall'],
                'F1': metrics_dict['f1'],
                'ROC_AUC': metrics_dict.get('roc_auc', np.nan),
                'G_mean': metrics_dict['g_mean']
            }])], ignore_index=True)

    # Save the threshold comparison
    threshold_df.to_csv('results/threshold_optimization.csv', index=False)

    # Display the threshold comparison
    print("\nThreshold Optimization Results:")
    print(threshold_df)

    # Visualize the threshold optimization results
    plt.figure(figsize=(12, 8))
    sns.lineplot(x='Threshold', y='F1', hue='Model', data=threshold_df)
    plt.title('F1 Score vs. Threshold by Model', fontsize=14)
    plt.xlabel('Threshold', fontsize=12)
    plt.ylabel('F1 Score', fontsize=12)
    plt.ylim(0.5, 1.0)
    plt.grid(True, alpha=0.3)
    plt.savefig('figures/model_evaluation/f1_vs_threshold.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Perform feature selection to identify the most important features
    print("Performing feature selection to identify key predictors...")

    # Use the best model for feature importance
    best_model_info = best_models_f1[0]
    best_model = best_model_info['model']
    best_model_name = best_model_info['model_name']

    # Extract feature importance if available
    feature_importance = plot_feature_importance_enhanced(
        best_model, X_train.columns, 
        f"{best_model_name} ({best_model_info['resampling_method']})"
    )

    if feature_importance is not None:
        # Select top features based on importance
        top_features = feature_importance['Feature'].head(15).tolist()
        
        print(f"\nSelected top {len(top_features)} features:")
        print(top_features)
        
        # Create datasets with selected features
        X_train_selected = X_train[top_features]
        X_val_selected = X_val[top_features]
        X_test_selected = X_test[top_features]
        
        # Scale the selected features
        scaler_selected = StandardScaler()
        X_train_selected_scaled = scaler_selected.fit_transform(X_train_selected)
        X_val_selected_scaled = scaler_selected.transform(X_val_selected)
        X_test_selected_scaled = scaler_selected.transform(X_test_selected)
        
        # Convert to DataFrames
        X_train_selected_df = pd.DataFrame(X_train_selected_scaled, columns=top_features)
        X_val_selected_df = pd.DataFrame(X_val_selected_scaled, columns=top_features)
        X_test_selected_df = pd.DataFrame(X_test_selected_scaled, columns=top_features)
        
        # Apply best resampling technique
        best_resampling = best_model_info['resampling_method']
        print(f"\nApplying {best_resampling} to selected features...")
        
        if best_resampling == 'smote':
            resampler = SMOTE(random_state=SEED)
        elif best_resampling == 'smotetomek':
            resampler = SMOTETomek(random_state=SEED)
        elif best_resampling == 'smoteenn':
            resampler = SMOTEENN(random_state=SEED)
        else:
            resampler = SMOTE(random_state=SEED)
        
        X_train_selected_resampled, y_train_selected_resampled = resampler.fit_resample(
            X_train_selected_scaled, y_train
        )
        
        # Train the best model on selected features
        print(f"\nTraining final model ({best_model_name}) on selected features with {best_resampling}...")
        
        # Clone the best model
        final_model = clone(best_model)
        
        # Train the final model
        final_model.fit(X_train_selected_resampled, y_train_selected_resampled)
        
        # Find the optimal threshold for the final model
        print("\nOptimizing threshold for final model...")
        optimal_threshold = find_optimal_threshold(
            final_model, X_val_selected_scaled, y_val, 
            metric='f1'
        )
        
        # Evaluate the final model on validation set with optimal threshold
        final_val_metrics = evaluate_model_comprehensive(
            final_model, X_val_selected_scaled, y_val,
            "Final Model (Selected Features) - Validation",
            threshold=optimal_threshold
        )
        
        # Evaluate the final model on test set
        final_test_metrics = evaluate_model_comprehensive(
            final_model, X_test_selected_scaled, y_test,
            "Final Model (Selected Features) - Test",
            threshold=optimal_threshold
        )
        
        # Plot confusion matrix for test set
        y_test_pred = (final_model.predict_proba(X_test_selected_scaled)[:, 1] >= optimal_threshold).astype(int)
        test_cm = confusion_matrix(y_test, y_test_pred)
        plot_confusion_matrix_enhanced(
            test_cm, "Final Model (Selected Features) - Test"
        )
        
        # Print final test results
        print("\nFinal Test Results:")
        print(f"Accuracy: {final_test_metrics['accuracy']:.4f}")
        print(f"Balanced Accuracy: {final_test_metrics['balanced_accuracy']:.4f}")
        print(f"Precision: {final_test_metrics['precision']:.4f}")
        print(f"Recall: {final_test_metrics['recall']:.4f}")
        print(f"F1 Score: {final_test_metrics['f1']:.4f}")
        print(f"ROC AUC: {final_test_metrics['roc_auc']:.4f}")
        print(f"G-mean: {final_test_metrics['g_mean']:.4f}")
        
        # Save the final model
        final_model_path = 'models/final_stroke_prediction_model.pkl'
        with open(final_model_path, 'wb') as f:
            pickle.dump(final_model, f)
        
        # Save the scaler
        scaler_path = 'models/final_stroke_prediction_scaler.pkl'
        with open(scaler_path, 'wb') as f:
            pickle.dump(scaler_selected, f)
        
        # Save the selected features
        with open('models/selected_features.pkl', 'wb') as f:
            pickle.dump(top_features, f)
        
        # Save the optimal threshold
        with open('models/optimal_threshold.pkl', 'wb') as f:
            pickle.dump(optimal_threshold, f)
        
        print(f"\nFinal model saved to {final_model_path}")
        print(f"Scaler saved to {scaler_path}")
        print(f"Selected features: {top_features}")
        print(f"Optimal threshold: {optimal_threshold:.4f}")
    else:
        print("Could not extract feature importance. Using the best model without feature selection.")
        
        # Use the best model directly
        best_ensemble_name, best_ensemble_model = best_ensemble
        
        # Find the optimal threshold
        print("\nOptimizing threshold for best ensemble model...")
        optimal_threshold = find_optimal_threshold(
            best_ensemble_model, scaled_data['standard']['val'], y_val, 
            metric='f1'
        )
        
        # Evaluate on test set
        final_test_metrics = evaluate_model_comprehensive(
            best_ensemble_model, scaled_data['standard']['test'], y_test,
            f"Best Ensemble Model ({best_ensemble_name}) - Test",
            threshold=optimal_threshold
        )
        
        # Plot confusion matrix for test set
        y_test_pred = (best_ensemble_model.predict_proba(scaled_data['standard']['test'])[:, 1] >= optimal_threshold).astype(int)
        test_cm = confusion_matrix(y_test, y_test_pred)
        plot_confusion_matrix_enhanced(
            test_cm, f"Best Ensemble Model ({best_ensemble_name}) - Test"
        )
        
        # Save the best model
        final_model_path = 'models/final_stroke_prediction_model.pkl'
        with open(final_model_path, 'wb') as f:
            pickle.dump(best_ensemble_model, f)
        
        # Save the scaler
        scaler_path = 'models/final_stroke_prediction_scaler.pkl'
        with open(scaler_path, 'wb') as f:
            pickle.dump(scaled_data['standard']['scaler'], f)
        
        # Save the optimal threshold
        with open('models/optimal_threshold.pkl', 'wb') as f:
            pickle.dump(optimal_threshold, f)
        
        print(f"\nBest ensemble model saved to {final_model_path}")
        print(f"Scaler saved to {scaler_path}")
        print(f"Optimal threshold: {optimal_threshold:.4f}")

    # Test the prediction function on a sample from the test set
    sample_size = min(5, len(X_test))
    sample_indices = np.random.choice(len(X_test), sample_size, replace=False)
    sample_data = X_test.iloc[sample_indices]
    sample_labels = y_test.iloc[sample_indices]

    print(f"\nTesting prediction function on {sample_size} random samples:")
    try:
        predictions, probabilities = predict_stroke_risk(sample_data)
        
        print("\nSample  True Label  Predicted  Probability")
        print("----------------------------------------")
        for i, (pred, prob, true) in enumerate(zip(predictions, probabilities, sample_labels)):
            print(f"{i+1:6d}  {true:10d}  {pred:9d}  {prob:.6f}")
        
        print("\nPrediction function works correctly!")
    except Exception as e:
        print(f"Error testing prediction function: {e}")

    # Create a summary of the model development process
    print("\n" + "="*80)
    print("STROKE PREDICTION MODEL DEVELOPMENT SUMMARY")
    print("="*80)

    print("\n1. Data Preparation and Engineering:")
    print(f"   - Original dataset shape: {df.shape}")
    print(f"   - Advanced feature engineering added {df_engineered.shape[1] - df.shape[1]} new features")
    print("   - Features included age transformations, risk scores, and interaction terms")

    print("\n2. Class Imbalance Handling:")
    print("   - Original class distribution: {:.2f}% positive cases".format(y_train.mean() * 100))
    print("   - Applied multiple resampling techniques (SMOTE, SMOTETomek, SMOTEENN)")
    if 'best_model_info' in locals():
        print("   - Best resampling method: " + best_model_info['resampling_method'])
    else:
        print("   - Best resampling method: N/A")

    print("\n3. Model Development:")
    print("   - Trained optimized models with imbalanced learning configurations")
    print("   - Created ensemble models to improve performance")
    if 'best_model_info' in locals():
        print(f"   - Best individual model: {best_model_info['model_name']} (F1={best_model_info['f1']:.4f})")
    else:
        print("   - Best individual model: N/A")
    if 'best_ensemble' in locals() and 'best_ensemble_f1' in locals():
        print(f"   - Best ensemble model: {best_ensemble[0]} (F1={best_ensemble_f1:.4f})")
    else:
        print("   - Best ensemble model: N/A")

    print("\n4. Threshold Optimization:")
    print("   - Optimized decision thresholds for performance metrics (F1, G-mean)")
    if 'optimal_threshold' in locals():
        print(f"   - Optimal threshold: {optimal_threshold:.4f}")
    else:
        print("   - Optimal threshold: N/A")

    print("\n5. Feature Selection:")
    if 'top_features' in locals():
        print(f"   - Selected {len(top_features)} top features based on importance")
        print(f"   - Key predictors: {', '.join(top_features[:5])}")
    else:
        print("   - Used all available features")

    print("\n6. Final Test Performance:")
    if 'final_test_metrics' in locals():
        print(f"   - Accuracy: {final_test_metrics['accuracy']:.4f}")
        print(f"   - Balanced Accuracy: {final_test_metrics['balanced_accuracy']:.4f}")
        print(f"   - Precision: {final_test_metrics['precision']:.4f}")
        print(f"   - Recall: {final_test_metrics['recall']:.4f}")
        print(f"   - F1 Score: {final_test_metrics['f1']:.4f}")
        print(f"   - ROC AUC: {final_test_metrics['roc_auc']:.4f}")
        print(f"   - G-mean: {final_test_metrics['g_mean']:.4f}")
    else:
        print("   - Final test metrics not available")

    print("\n7. Model Deployment:")
    print(f"   - Final model saved to: models/final_stroke_prediction_model.pkl")
    print(f"   - Prediction function created for new data")
    print("   - Ready for integration into applications")

    # Save the summary to a file
    with open('results/model_development_summary.txt', 'w') as f:
        f.write("STROKE PREDICTION MODEL DEVELOPMENT SUMMARY\n")
        f.write("="*80 + "\n\n")
        
        f.write("1. Data Preparation and Engineering:\n")
        f.write(f"   - Original dataset shape: {df.shape}\n")
        f.write(f"   - Advanced feature engineering added {df_engineered.shape[1] - df.shape[1]} new features\n")
        f.write("   - Features included age transformations, risk scores, and interaction terms\n\n")
        
        f.write("2. Class Imbalance Handling:\n")
        f.write("   - Original class distribution: {:.2f}% positive cases\n".format(y_train.mean() * 100))
        f.write("   - Applied multiple resampling techniques (SMOTE, SMOTETomek, SMOTEENN)\n")
        if 'best_model_info' in locals():
            f.write("   - Best resampling method: " + best_model_info['resampling_method'] + "\n\n")
        else:
            f.write("   - Best resampling method: N/A\n\n")
        
        f.write("3. Model Development:\n")
        f.write("   - Trained optimized models with imbalanced learning configurations\n")
        f.write("   - Created ensemble models to improve performance\n")
        if 'best_model_info' in locals():
            f.write(f"   - Best individual model: {best_model_info['model_name']} (F1={best_model_info['f1']:.4f})\n")
        else:
            f.write("   - Best individual model: N/A\n")
        if 'best_ensemble' in locals() and 'best_ensemble_f1' in locals():
            f.write(f"   - Best ensemble model: {best_ensemble[0]} (F1={best_ensemble_f1:.4f})\n\n")
        else:
            f.write("   - Best ensemble model: N/A\n\n")
        
        f.write("4. Threshold Optimization:\n")
        f.write("   - Optimized decision thresholds for performance metrics (F1, G-mean)\n")
        if 'optimal_threshold' in locals():
            f.write(f"   - Optimal threshold: {optimal_threshold:.4f}\n\n")
        else:
            f.write("   - Optimal threshold: N/A\n\n")
        
        f.write("5. Feature Selection:\n")
        if 'top_features' in locals():
            f.write(f"   - Selected {len(top_features)} top features based on importance\n")
            f.write(f"   - Key predictors: {', '.join(top_features[:5])}\n\n")
        else:
            f.write("   - Used all available features\n\n")
        
        f.write("6. Final Test Performance:\n")
        if 'final_test_metrics' in locals():
            f.write(f"   - Accuracy: {final_test_metrics['accuracy']:.4f}\n")
            f.write(f"   - Balanced Accuracy: {final_test_metrics['balanced_accuracy']:.4f}\n")
            f.write(f"   - Precision: {final_test_metrics['precision']:.4f}\n")
            f.write(f"   - Recall: {final_test_metrics['recall']:.4f}\n")
            f.write(f"   - F1 Score: {final_test_metrics['f1']:.4f}\n")
            f.write(f"   - ROC AUC: {final_test_metrics['roc_auc']:.4f}\n")
            f.write(f"   - G-mean: {final_test_metrics['g_mean']:.4f}\n\n")
        else:
            f.write("   - Final test metrics not available\n\n")
        
        f.write("7. Model Deployment:\n")
        f.write(f"   - Final model saved to: models/final_stroke_prediction_model.pkl\n")
        f.write(f"   - Prediction function created for new data\n")
        f.write("   - Ready for integration into applications\n")

    print("\nModel development summary saved to: results/model_development_summary.txt")
    print("\nStroke prediction model development complete!")

# Entry point: call main() only when this code is executed directly
# (i.e. __name__ is "__main__"), so that importing it as a module does not
# trigger the run as a side effect.
if __name__ == "__main__":
    main()