# In [None]:  (notebook cell marker left over from the Jupyter export)
# Comprehensive Stroke Prediction with Multiple Models and SHAP Explainability
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import pickle
import os

# For data preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer, KNNImputer
from sklearn.model_selection import train_test_split, StratifiedKFold, GridSearchCV, cross_val_score

# For handling imbalance
from imblearn.over_sampling import SMOTE
from imblearn.pipeline import Pipeline as ImbPipeline

# For metrics
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    roc_auc_score, confusion_matrix, classification_report, 
    roc_curve, auc, precision_recall_curve, average_precision_score
)

# Models to evaluate
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
import xgboost as xgb
import lightgbm as lgb

# For SHAP explanations
import shap

# Silence library warnings so the console report stays readable
warnings.filterwarnings('ignore')

# Make sure every output folder exists before anything is written to it
for _output_dir in ('models', 'plots', 'results', 'shap_plots'):
    os.makedirs(_output_dir, exist_ok=True)

# Single seed shared by every randomized step, for reproducibility
RANDOM_SEED = 42

def load_and_explore_data(file_path="healthcare-dataset-stroke-data.csv"):
    """
    Load the stroke dataset, print a short exploratory report and save plots

    The report covers the dataset shape, columns with missing values, the
    distribution of the ``stroke`` target and the number of duplicated rows.
    A 2x2 overview figure is written to ``plots/exploratory_analysis.png``.
    The data itself is not modified.

    Parameters:
    -----------
    file_path : str
        Path to the dataset

    Returns:
    --------
    pd.DataFrame
        The dataframe exactly as read from ``file_path``
    """
    print("\n" + "="*80)
    print("STEP 1: LOADING AND EXPLORING DATA")
    print("="*80)

    # Load the dataset
    print(f"\nLoading data from {file_path}...")
    df = pd.read_csv(file_path)

    # Basic info
    print(f"\nDataset shape: {df.shape}")

    # Only list the columns that actually contain missing values
    missing_values = df.isnull().sum()
    print("\nMissing values by column:")
    print(missing_values[missing_values > 0])

    # Target distribution
    print("\nTarget (stroke) distribution:")
    stroke_counts = df['stroke'].value_counts()
    print(stroke_counts)
    # .get() keeps the report working when the file holds no positive cases,
    # and the length guard avoids a division by zero on an empty file
    positive_cases = stroke_counts.get(1, 0)
    stroke_pct = 100 * positive_cases / len(df) if len(df) else 0.0
    print(f"Percentage of stroke cases: {stroke_pct:.2f}%")

    # Overview figure: target counts plus three numeric features by outcome
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    sns.countplot(x='stroke', data=df, palette='Set2', ax=axes[0, 0])
    axes[0, 0].set_title('Stroke Distribution')

    boxplot_panels = [
        (axes[0, 1], 'age', 'Age Distribution by Stroke Status'),
        (axes[1, 0], 'avg_glucose_level', 'Glucose Level by Stroke Status'),
        (axes[1, 1], 'bmi', 'BMI by Stroke Status'),
    ]
    for ax, column, title in boxplot_panels:
        sns.boxplot(x='stroke', y=column, data=df, ax=ax)
        ax.set_title(title)

    fig.tight_layout()
    fig.savefig('plots/exploratory_analysis.png')
    plt.close(fig)

    # Check for duplicates
    duplicate_count = df.duplicated().sum()
    print(f"\nNumber of duplicated rows: {duplicate_count}")

    return df

def preprocess_data(df):
    """
    Clean the raw stroke dataframe and turn it into model-ready features

    Steps, in order: drop the ``id`` column, merge the rare ``'Other'``
    gender into ``'Female'``, impute missing ``bmi`` with KNN, one-hot encode
    the object columns (first level dropped), report IQR outliers (they are
    only counted, never removed) and standardize the numerical columns.

    Parameters:
    -----------
    df : pd.DataFrame
        Raw dataframe; it is copied, not modified

    Returns:
    --------
    tuple
        ``(X, y, scaler)``: the feature dataframe, the ``stroke`` target
        series and the ``StandardScaler`` fitted on the numerical columns
    """
    print("\n" + "="*80)
    print("STEP 2: DATA PREPROCESSING")
    print("="*80)

    # Create a copy to avoid modifying the original dataframe
    data = df.copy()

    # Drop ID column if it exists
    if 'id' in data.columns:
        print("\nDropping ID column...")
        data = data.drop(columns='id')

    # Handle the 'Other' gender category (very small number of samples)
    if 'gender' in data.columns:
        other_count = (data['gender'] == 'Other').sum()
        if other_count > 0:
            print(f"\nReplacing 'Other' gender category ({other_count} samples) with 'Female'...")
            data['gender'] = data['gender'].replace('Other', 'Female')

    # Handle missing values in bmi using KNN imputation
    if 'bmi' in data.columns and data['bmi'].isnull().any():
        missing_bmi_count = data['bmi'].isnull().sum()
        print(f"\nImputing {missing_bmi_count} missing BMI values using KNN...")

        # The target is excluded from the imputation frame: using `stroke` to
        # choose neighbours would leak label information into the features
        impute_df = data.drop(columns='stroke', errors='ignore')

        # KNNImputer needs numeric input, so object columns get integer codes
        for col in impute_df.select_dtypes(include=['object']).columns:
            impute_df[col] = pd.factorize(impute_df[col])[0]

        imputer = KNNImputer(n_neighbors=5)
        imputed_data = imputer.fit_transform(impute_df)
        # Keep the original index: the assignment below aligns on labels, so
        # a default RangeIndex would produce NaN for any non-default index
        imputed_df = pd.DataFrame(
            imputed_data, columns=impute_df.columns, index=impute_df.index
        )

        # Update BMI values in the working dataframe
        data['bmi'] = imputed_df['bmi']

    # Handle categorical features
    print("\nEncoding categorical features...")
    categorical_cols = data.select_dtypes(include=['object']).columns
    data = pd.get_dummies(data, columns=categorical_cols, drop_first=True)

    # Check for outliers in numerical features (report only)
    print("\nChecking for outliers in numerical features...")
    numerical_cols = ['age', 'avg_glucose_level', 'bmi']

    for col in numerical_cols:
        # Tukey fences: 1.5 * IQR beyond the quartiles
        q1 = data[col].quantile(0.25)
        q3 = data[col].quantile(0.75)
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr

        outliers = ((data[col] < lower_bound) | (data[col] > upper_bound)).sum()
        print(f"  {col}: {outliers} outliers detected")

    # Split into features and target
    print("\nSplitting into features and target...")
    X = data.drop('stroke', axis=1)
    y = data['stroke']

    # Scale numerical features
    # NOTE(review): the scaler is fitted on every row, i.e. before any
    # train/validation split done by the caller - confirm this is acceptable
    print("\nScaling numerical features...")
    scaler = StandardScaler()
    X[numerical_cols] = scaler.fit_transform(X[numerical_cols])

    print(f"\nFinal feature set: {X.shape[1]} features")
    print(f"Feature names: {X.columns.tolist()}")

    return X, y, scaler

def handle_imbalance(X, y, method='smote'):
    """
    Oversample the minority class with SMOTE and plot the class counts

    Prints the class distribution before and after resampling and saves a
    side-by-side bar chart to ``plots/class_balance.png``.

    Parameters:
    -----------
    X : pd.DataFrame
        Feature dataframe
    y : pd.Series
        Target series with classes 0 (no stroke) and 1 (stroke)
    method : str
        Method to use for handling imbalance; only ``'smote'``
        (case-insensitive) is supported

    Returns:
    --------
    tuple
        Balanced X and y

    Raises:
    -------
    ValueError
        If ``method`` is not ``'smote'``
    """
    print("\n" + "="*80)
    print("STEP 3: HANDLING CLASS IMBALANCE")
    print("="*80)

    # Display class distribution before balancing
    print("\nClass distribution before balancing:")
    class_counts = y.value_counts()
    print(class_counts)
    # .get() avoids a KeyError when one of the classes is absent
    negatives = class_counts.get(0, 0)
    positives = class_counts.get(1, 0)
    if positives:
        print(f"Class imbalance ratio: 1:{negatives/positives:.1f}")
    else:
        print("Class imbalance ratio: undefined (no positive samples)")

    # Apply SMOTE for oversampling
    print(f"\nApplying {method.upper()} to balance classes...")
    if method.lower() == 'smote':
        smote = SMOTE(random_state=RANDOM_SEED)
        X_balanced, y_balanced = smote.fit_resample(X, y)
    else:
        raise ValueError(f"Unknown method: {method}")

    # Display new class distribution
    print("\nClass distribution after balancing:")
    balanced_counts = pd.Series(y_balanced).value_counts()
    print(balanced_counts)

    # value_counts() orders by frequency, not by label; reindexing pins the
    # bars to [class 0, class 1] so they always match the tick labels
    class_labels = ['No Stroke', 'Stroke']
    bar_colors = ['skyblue', 'salmon']
    counts_before = class_counts.reindex([0, 1], fill_value=0).values
    counts_after = balanced_counts.reindex([0, 1], fill_value=0).values

    # Plot class distribution before and after balancing
    fig, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(12, 5))

    ax_before.bar(class_labels, counts_before, color=bar_colors)
    ax_before.set_title('Class Distribution Before Balancing')
    ax_before.set_ylabel('Count')

    ax_after.bar(class_labels, counts_after, color=bar_colors)
    ax_after.set_title('Class Distribution After Balancing')

    fig.tight_layout()
    fig.savefig('plots/class_balance.png')
    plt.close(fig)

    return X_balanced, y_balanced

def _build_candidate_models():
    """Return the name -> unfitted estimator mapping compared by the pipeline."""
    return {
        'Logistic Regression': LogisticRegression(
            C=1.0,
            class_weight='balanced',
            max_iter=1000,
            solver='liblinear',
            random_state=RANDOM_SEED
        ),
        'Decision Tree': DecisionTreeClassifier(
            max_depth=10,
            min_samples_split=5,
            min_samples_leaf=2,
            class_weight='balanced',
            random_state=RANDOM_SEED
        ),
        'Random Forest': RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            min_samples_split=5,
            min_samples_leaf=2,
            class_weight='balanced',
            random_state=RANDOM_SEED
        ),
        'KNN': KNeighborsClassifier(
            n_neighbors=7,
            weights='distance'
        ),
        'XGBoost': xgb.XGBClassifier(
            learning_rate=0.1,
            n_estimators=100,
            max_depth=4,
            min_child_weight=2,
            subsample=0.8,
            colsample_bytree=0.8,
            scale_pos_weight=10,  # Helps with class imbalance
            random_state=RANDOM_SEED
        ),
        'LightGBM': lgb.LGBMClassifier(
            learning_rate=0.05,
            n_estimators=100,
            num_leaves=31,
            max_depth=5,
            min_child_samples=20,
            subsample=0.8,
            colsample_bytree=0.8,
            is_unbalance=True,  # Handle class imbalance
            random_state=RANDOM_SEED
        )
    }


def _plot_file_stem(name):
    """Return the lowercase, underscore-separated stem used in plot file names."""
    return name.lower().replace(" ", "_")


def _positive_class_probabilities(model, X):
    """Return P(class 1) for each row of X, or None if the model has no usable predict_proba."""
    try:
        return model.predict_proba(X)[:, 1]
    except (AttributeError, NotImplementedError, IndexError):
        # No predict_proba, or only a single probability column
        return None


def _cross_validate_model(model, X_train, y_train, cv):
    """Run stratified k-fold CV on the training split and return the mean of each metric."""
    skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=RANDOM_SEED)

    accuracies, precisions, recalls, f1_scores, roc_aucs = [], [], [], [], []

    for train_idx, test_idx in skf.split(X_train, y_train):
        X_train_fold, X_test_fold = X_train.iloc[train_idx], X_train.iloc[test_idx]
        y_train_fold, y_test_fold = y_train.iloc[train_idx], y_train.iloc[test_idx]

        model.fit(X_train_fold, y_train_fold)
        y_pred = model.predict(X_test_fold)
        y_pred_proba = _positive_class_probabilities(model, X_test_fold)

        accuracies.append(accuracy_score(y_test_fold, y_pred))
        precisions.append(precision_score(y_test_fold, y_pred))
        recalls.append(recall_score(y_test_fold, y_pred))
        f1_scores.append(f1_score(y_test_fold, y_pred))

        # AUC needs scores, so it is only collected for probabilistic models
        if y_pred_proba is not None:
            roc_aucs.append(roc_auc_score(y_test_fold, y_pred_proba))

    return {
        'accuracy': np.mean(accuracies),
        'precision': np.mean(precisions),
        'recall': np.mean(recalls),
        'f1': np.mean(f1_scores),
        'roc_auc': np.mean(roc_aucs) if roc_aucs else None
    }


def _save_confusion_matrix_plot(name, y_true, y_pred):
    """Save the validation confusion-matrix heatmap for one model."""
    fig = plt.figure(figsize=(8, 6))
    # Fixed labels keep the matrix 2x2 even if a class is never predicted
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['No Stroke', 'Stroke'],
                yticklabels=['No Stroke', 'Stroke'])
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title(f'Confusion Matrix - {name}')
    plt.tight_layout()
    plt.savefig(f'plots/{_plot_file_stem(name)}_confusion_matrix.png')
    plt.close(fig)


def _save_probability_curve_plots(name, y_true, y_proba, roc_auc_value):
    """Save the ROC and precision-recall curves for one model."""
    fig = plt.figure(figsize=(8, 6))
    fpr, tpr, _ = roc_curve(y_true, y_proba)
    plt.plot(fpr, tpr, label=f'ROC Curve (AUC = {roc_auc_value:.3f})')
    plt.plot([0, 1], [0, 1], 'k--', label='Random')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curve - {name}')
    plt.legend(loc='lower right')
    plt.grid(True, alpha=0.3)
    plt.savefig(f'plots/{_plot_file_stem(name)}_roc_curve.png')
    plt.close(fig)

    fig = plt.figure(figsize=(8, 6))
    precision_vals, recall_vals, _ = precision_recall_curve(y_true, y_proba)
    avg_precision = average_precision_score(y_true, y_proba)
    plt.plot(recall_vals, precision_vals, label=f'AP = {avg_precision:.3f}')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curve - {name}')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(f'plots/{_plot_file_stem(name)}_pr_curve.png')
    plt.close(fig)


def _save_feature_importance_plot(name, model, feature_names):
    """Save a top-15 importance bar chart; does nothing for models exposing neither importances nor coefficients."""
    if hasattr(model, 'feature_importances_'):
        # Tree-based models expose their own importance scores
        importances = model.feature_importances_
        title = f'Feature Importance - {name}'
    elif hasattr(model, 'coef_'):
        # Linear models: rank features by coefficient magnitude
        importances = np.abs(model.coef_[0])
        title = f'Feature Importance (Coefficient Magnitude) - {name}'
    else:
        return

    feature_importance = pd.DataFrame({
        'Feature': feature_names,
        'Importance': importances
    }).sort_values('Importance', ascending=False)

    fig = plt.figure(figsize=(12, 8))
    sns.barplot(x='Importance', y='Feature', data=feature_importance.head(15))
    plt.title(title)
    plt.tight_layout()
    plt.savefig(f'plots/{_plot_file_stem(name)}_feature_importance.png')
    plt.close(fig)


def train_and_evaluate_models(X, y, cv=5):
    """
    Train and evaluate multiple classification models

    The data is split 80/20 (stratified) into a training and a validation
    part. Every candidate model is cross-validated on the training part,
    refitted on the whole training part and scored on the validation part.
    Per-model plots (confusion matrix, ROC / PR curves, feature importance)
    and a comparison chart are written to ``plots/``.

    NOTE(review): ``main`` passes SMOTE-balanced data, so the validation
    split contains synthetic samples and the reported scores are likely
    optimistic - consider resampling the training part only.

    Parameters:
    -----------
    X : pd.DataFrame
        Feature dataframe
    y : pd.Series
        Target series
    cv : int
        Number of cross-validation folds

    Returns:
    --------
    tuple
        ``(results, performance_df, (X_val, y_val))`` where ``results`` maps
        each model name to its fitted model plus ``cv_metrics`` and
        ``val_metrics`` dicts, ``performance_df`` holds one row per model
        sorted by validation F1 (best first), and the last item is the
        held-out validation split
    """
    print("\n" + "="*80)
    print("STEP 4: MODEL TRAINING AND EVALUATION")
    print("="*80)

    models = _build_candidate_models()

    # Dictionaries to store results
    results = {}
    model_performances = {}

    # Split data for validation set (to be used for SHAP explanations)
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=RANDOM_SEED, stratify=y
    )

    for name, model in models.items():
        print(f"\n{'-'*50}")
        print(f"Training and evaluating {name}...")

        cv_metrics = _cross_validate_model(model, X_train, y_train, cv)

        print(f"\nCross-validation results ({cv} folds):")
        print(f"  Accuracy:  {cv_metrics['accuracy']:.4f}")
        print(f"  Precision: {cv_metrics['precision']:.4f}")
        print(f"  Recall:    {cv_metrics['recall']:.4f}")
        print(f"  F1 Score:  {cv_metrics['f1']:.4f}")
        # `is not None` so that a legitimate score of 0.0 is still reported
        if cv_metrics['roc_auc'] is not None:
            print(f"  ROC AUC:   {cv_metrics['roc_auc']:.4f}")

        # Train final model on full training set
        print("\nTraining final model on full training set...")
        model.fit(X_train, y_train)

        # Evaluate on validation set
        y_val_pred = model.predict(X_val)
        y_val_pred_proba = _positive_class_probabilities(model, X_val)

        val_roc_auc = None
        if y_val_pred_proba is not None:
            try:
                val_roc_auc = roc_auc_score(y_val, y_val_pred_proba)
            except ValueError:
                # AUC is undefined, e.g. when y_val holds a single class
                val_roc_auc = None
        has_val_proba = val_roc_auc is not None

        val_metrics = {
            'accuracy': accuracy_score(y_val, y_val_pred),
            'precision': precision_score(y_val, y_val_pred),
            'recall': recall_score(y_val, y_val_pred),
            'f1': f1_score(y_val, y_val_pred),
            'roc_auc': val_roc_auc
        }

        print("\nValidation set results:")
        print(f"  Accuracy:  {val_metrics['accuracy']:.4f}")
        print(f"  Precision: {val_metrics['precision']:.4f}")
        print(f"  Recall:    {val_metrics['recall']:.4f}")
        print(f"  F1 Score:  {val_metrics['f1']:.4f}")
        if val_roc_auc is not None:
            print(f"  ROC AUC:   {val_roc_auc:.4f}")

        # Print classification report
        print("\nClassification Report (Validation Set):")
        print(classification_report(y_val, y_val_pred))

        # Diagnostic plots
        _save_confusion_matrix_plot(name, y_val, y_val_pred)
        if has_val_proba:
            _save_probability_curve_plots(name, y_val, y_val_pred_proba, val_roc_auc)
        _save_feature_importance_plot(name, model, X.columns)

        # Store results
        results[name] = {
            'model': model,
            'cv_metrics': cv_metrics,
            'val_metrics': val_metrics
        }

        # Flat row for the comparison table; a missing AUC is stored as 0
        model_performances[name] = {
            'CV Accuracy': cv_metrics['accuracy'],
            'CV Precision': cv_metrics['precision'],
            'CV Recall': cv_metrics['recall'],
            'CV F1': cv_metrics['f1'],
            'CV ROC AUC': cv_metrics['roc_auc'] if cv_metrics['roc_auc'] is not None else 0,
            'Val Accuracy': val_metrics['accuracy'],
            'Val Precision': val_metrics['precision'],
            'Val Recall': val_metrics['recall'],
            'Val F1': val_metrics['f1'],
            'Val ROC AUC': val_roc_auc if val_roc_auc is not None else 0
        }

    # Create a dataframe for model comparison, best validation F1 first
    performance_df = pd.DataFrame(model_performances).T
    performance_df = performance_df.sort_values('Val F1', ascending=False)

    # Display model comparison
    print("\n" + "="*80)
    print("MODEL COMPARISON (SORTED BY VALIDATION F1 SCORE)")
    print("="*80)
    print(performance_df.round(4))

    # Plot model comparison. The axes are handed to pandas explicitly;
    # otherwise DataFrame.plot opens its own figure and the sized one
    # created here would stay empty and never be closed
    fig, ax = plt.subplots(figsize=(12, 8))
    validation_columns = ['Val Accuracy', 'Val Precision', 'Val Recall', 'Val F1', 'Val ROC AUC']
    performance_df[validation_columns].plot(kind='bar', ax=ax)
    ax.set_title('Model Performance Comparison (Validation Set)')
    ax.set_xlabel('Model')
    ax.set_ylabel('Score')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=5)
    fig.tight_layout()
    fig.savefig('plots/model_comparison.png')
    plt.close(fig)

    # Return results and the validation data for SHAP analysis
    return results, performance_df, (X_val, y_val)

def _positive_class_shap_values(shap_values):
    """Reduce SHAP output to a 2-D (samples, features) array for class 1."""
    if isinstance(shap_values, list):
        # Older SHAP releases: one array per class
        return np.asarray(shap_values[1] if len(shap_values) > 1 else shap_values[0])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # Newer SHAP releases: (samples, features, classes)
        return values[:, :, 1] if values.shape[2] > 1 else values[:, :, 0]
    return values


def _positive_class_base_value(expected_value):
    """Return the explainer base value for class 1 as a plain float."""
    base = np.ravel(expected_value)
    return float(base[1] if base.size > 1 else base[0])


def create_shap_explanations(results, X_val, feature_names, y_val=None):
    """
    Create SHAP explanations for the best models

    The three models with the highest validation F1 (read from
    ``results[name]['val_metrics']['f1']``) are explained. For each one a
    summary plot, a bar importance plot, dependence plots for the three most
    influential features and one force plot are saved to ``shap_plots/``.
    A failure for one model is reported and does not stop the others.

    Parameters:
    -----------
    results : dict
        Results for each model
    X_val : pd.DataFrame
        Validation feature dataframe
    feature_names : list
        List of feature names
    y_val : array-like, optional
        True validation labels, positionally aligned with ``X_val``. They
        are used to pick the first positive case for the force plot; when
        omitted, the first case the model itself predicts as positive is
        used instead
    """
    print("\n" + "="*80)
    print("STEP 5: MODEL EXPLAINABILITY WITH SHAP")
    print("="*80)

    feature_names = list(feature_names)

    # Models to explain (top 3 based on validation F1). The sort is stable,
    # so entries without a stored F1 keep their original relative order
    ranked_names = sorted(
        results,
        key=lambda name: results[name].get('val_metrics', {}).get('f1') or 0,
        reverse=True
    )
    models_to_explain = ranked_names[:3]

    for model_name in models_to_explain:
        print(f"\nGenerating SHAP explanations for {model_name}...")

        model = results[model_name]['model']
        file_stem = model_name.lower().replace(" ", "_")

        try:
            # Create explainer based on model type
            if hasattr(model, 'feature_importances_'):
                # For tree-based models (Random Forest, XGBoost, etc.)
                explainer = shap.TreeExplainer(model)
                X_explain = X_val
            else:
                # For other models (Logistic Regression, KNN, etc.) use the
                # model-agnostic explainer with a sampled background set
                background = shap.sample(X_val, 100)
                explainer = shap.KernelExplainer(model.predict_proba, background)
                # KernelExplainer is slow, so only a subset is explained;
                # the same subset must be used for every plot below
                X_explain = X_val.iloc[:100]

            shap_values = explainer.shap_values(X_explain)

            # Keep only the contributions towards the positive class (stroke = 1)
            shap_values_to_plot = _positive_class_shap_values(shap_values)
            base_value = _positive_class_base_value(explainer.expected_value)

            # Summary plot (beeswarm plot)
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values_to_plot, X_explain, feature_names=feature_names, show=False)
            plt.title(f'SHAP Summary Plot - {model_name}')
            plt.tight_layout()
            plt.savefig(f'shap_plots/{file_stem}_summary.png')
            plt.close()
            print(f"  Summary plot saved to shap_plots/{file_stem}_summary.png")

            # Bar plot (feature importance)
            plt.figure(figsize=(12, 8))
            shap.summary_plot(shap_values_to_plot, X_explain, feature_names=feature_names, plot_type='bar', show=False)
            plt.title(f'SHAP Feature Importance - {model_name}')
            plt.tight_layout()
            plt.savefig(f'shap_plots/{file_stem}_importance.png')
            plt.close()
            print(f"  Feature importance plot saved to shap_plots/{file_stem}_importance.png")

            # Dependence plots for top 3 features, ranked by mean |SHAP|
            feature_importance = np.abs(shap_values_to_plot).mean(0)
            most_important_features = np.argsort(feature_importance)[-3:][::-1]

            for rank, feature_idx in enumerate(most_important_features, start=1):
                feature_idx = int(feature_idx)
                feature_name = feature_names[feature_idx]
                # Passing our own axes stops dependence_plot from opening a
                # second figure that would never be closed
                fig, ax = plt.subplots(figsize=(12, 8))
                shap.dependence_plot(feature_idx, shap_values_to_plot, X_explain,
                                     feature_names=feature_names, ax=ax, show=False)
                ax.set_title(f'SHAP Dependence Plot - {feature_name} ({model_name})')
                fig.tight_layout()
                fig.savefig(f'shap_plots/{file_stem}_dependence_{rank}.png')
                plt.close(fig)
                print(f"  Dependence plot for {feature_name} saved")

            # Force plot for a single prediction (first positive case among
            # the explained rows)
            if y_val is not None:
                labels = np.asarray(y_val)[:len(X_explain)]
            else:
                labels = np.asarray(model.predict(X_explain))
            positive_indices = np.flatnonzero(labels == 1)

            if len(positive_indices) > 0:
                idx = int(positive_indices[0])
                # force_plot draws on its own figure, sized via `figsize`
                shap.force_plot(base_value, shap_values_to_plot[idx], X_explain.iloc[idx],
                                feature_names=feature_names, matplotlib=True, show=False,
                                figsize=(12, 3))
                plt.title(f'SHAP Force Plot - Single Prediction ({model_name})')
                plt.tight_layout()
                plt.savefig(f'shap_plots/{file_stem}_force_plot.png')
                plt.close()
                print(f"  Force plot saved to shap_plots/{file_stem}_force_plot.png")

        except Exception as e:
            # Explanations are best-effort: report and continue with the next model
            print(f"Error creating SHAP explanations for {model_name}: {str(e)}")

    print("\nSHAP explanations completed!")

def save_best_model(results, performance_df):
    """
    Save the best model based on F1 score

    The best model is the one in the first row of ``performance_df``, which
    is expected to be sorted by validation F1 (best first). The fitted model
    is pickled to ``models/best_stroke_model.pkl`` and its validation
    metrics are written to ``models/model_info.json``.

    Parameters:
    -----------
    results : dict
        Results for each model; ``results[name]['model']`` is the fitted model
    performance_df : pd.DataFrame
        Dataframe with model performance metrics, indexed by model name
    """
    import json

    print("\n" + "="*80)
    print("STEP 6: SAVING BEST MODEL")
    print("="*80)

    # Get the best model (first row in performance dataframe, sorted by F1 score)
    best_model_name = performance_df.index[0]
    best_model = results[best_model_name]['model']

    print(f"\nBest model: {best_model_name}")
    print(f"Validation F1 Score: {performance_df.loc[best_model_name, 'Val F1']:.4f}")
    print(f"Validation Recall: {performance_df.loc[best_model_name, 'Val Recall']:.4f}")

    # Save model; the folder is created here so the function also works
    # when called from a different working directory
    model_path = 'models/best_stroke_model.pkl'
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, 'wb') as f:
        pickle.dump(best_model, f)

    print(f"\nBest model saved to {model_path}")

    # Save model info. Values are cast to built-in float because json cannot
    # serialize numpy integer scalars (e.g. an all-zero ROC AUC column)
    metric_columns = {
        'accuracy': 'Val Accuracy',
        'precision': 'Val Precision',
        'recall': 'Val Recall',
        'f1': 'Val F1',
        'roc_auc': 'Val ROC AUC'
    }
    model_info = {
        'model_name': best_model_name,
        'metrics': {
            metric: float(performance_df.loc[best_model_name, column])
            for metric, column in metric_columns.items()
        }
    }

    model_info_path = 'models/model_info.json'
    with open(model_info_path, 'w') as f:
        json.dump(model_info, f, indent=4)

    print(f"Model info saved to {model_info_path}")

def main():
    """
    Main function to run the stroke prediction pipeline

    Runs the six steps in order: load, preprocess, balance, train/evaluate,
    explain with SHAP and save the best model. The fitted scaler is also
    written to ``models/scaler.pkl`` so that new raw data can be transformed
    the same way before it is passed to the saved model.
    """
    print("\n" + "="*80)
    print("STROKE PREDICTION PIPELINE")
    print("="*80)

    # Step 1: Load and explore data
    df = load_and_explore_data("healthcare-dataset-stroke-data.csv")

    # Step 2: Preprocess data
    X, y, scaler = preprocess_data(df)

    # Persist the scaler next to the model; without it the pickled model
    # cannot be applied to unscaled inputs later on
    os.makedirs('models', exist_ok=True)
    scaler_path = 'models/scaler.pkl'
    with open(scaler_path, 'wb') as f:
        pickle.dump(scaler, f)
    print(f"\nFitted scaler saved to {scaler_path}")

    # Step 3: Handle class imbalance
    # NOTE(review): oversampling happens before the train/validation split
    # made in step 4, so synthetic samples reach the validation set and the
    # reported metrics are likely optimistic - consider resampling the
    # training part only
    X_balanced, y_balanced = handle_imbalance(X, y, method='smote')

    # Step 4: Train and evaluate models
    results, performance_df, (X_val, y_val) = train_and_evaluate_models(X_balanced, y_balanced, cv=5)

    # Step 5: Create SHAP explanations
    create_shap_explanations(results, X_val, list(X.columns))

    # Step 6: Save best model
    save_best_model(results, performance_df)

    print("\n" + "="*80)
    print("PIPELINE COMPLETED!")
    print("="*80)


if __name__ == "__main__":
    main()