# Stroke Prediction: Model Explainability with SHAP

This notebook explains the stroke prediction model using SHAP (SHapley Additive exPlanations). It covers:
1. Loading the trained model and preprocessed data
2. Understanding the SHAP framework for explainability
3. Generating global feature importance explanations
4. Creating local explanations for individual predictions
5. Visualizing feature interactions and their effects

## 1. Import Libraries

In [None]:
# General libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import pickle
import json
from tqdm import tqdm

# SHAP explainer library
import shap

# ML libraries
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

# Set random seed for reproducibility
SEED = 42
np.random.seed(SEED)

# Display settings
pd.set_option('display.max_columns', None)
pd.set_option('display.float_format', '{:.4f}'.format)
sns.set_style('whitegrid')
plt.style.use('fivethirtyeight')

# Ignore warnings for cleaner output
warnings.filterwarnings('ignore')

# Create directories for outputs
os.makedirs('figures/shap_explanations', exist_ok=True)

## 2. Load Model and Data

In [None]:
# Load the saved model, scaler, and metadata
def load_model_assets(model_path='models/stroke_prediction_model.pkl',
                      scaler_path='models/stroke_prediction_scaler.pkl',
                      metadata_path='models/stroke_prediction_metadata.json'):
    """
    Load the saved model, scaler, and metadata.
    
    Returns:
    --------
    tuple
        model, scaler, metadata
    """
    # Load model
    with open(model_path, 'rb') as f:
        model = pickle.load(f)
    print(f"Loaded model from {model_path}")
    
    # Load scaler
    with open(scaler_path, 'rb') as f:
        scaler = pickle.load(f)
    print(f"Loaded scaler from {scaler_path}")
    
    # Load metadata
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    print(f"Loaded metadata from {metadata_path}")
    
    return model, scaler, metadata

# Load model assets
model, scaler, metadata = load_model_assets()

# Display model metadata
print("\nModel Metadata:")
print(f"Model name: {metadata['model_name']}")
print(f"Number of features: {metadata['num_features']}")
print("\nTest metrics:")
for metric, value in metadata['test_metrics'].items():
    if value is not None:  # Some metrics might be None
        print(f"  {metric}: {value:.4f}")
print("\nClass distribution:")
print(f"  Negative (No Stroke): {metadata['class_distribution']['negative']}")
print(f"  Positive (Stroke): {metadata['class_distribution']['positive']}")

In [None]:
# Load the preprocessed data for SHAP analysis
# We'll use the encoded dataset to match the model's input format
df = pd.read_csv('data/processed/stroke_dataset_encoded.csv')

# Split features and target
X = df.drop('stroke', axis=1)
y = df['stroke']

# Split into train and test sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.20, random_state=SEED, stratify=y
)

# Scale features
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Convert scaled data back to DataFrames with feature names
X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_train.columns)
X_test_scaled_df = pd.DataFrame(X_test_scaled, columns=X_test.columns)

print(f"Data loaded and prepared: {X_train.shape[0]} training samples, {X_test.shape[0]} test samples")

In [None]:
# Also load the unencoded dataset for more interpretable feature names
df_interpretable = pd.read_csv('data/processed/stroke_dataset_eda.csv')

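# Create a mapping from encoded column names to readable labels.
# Splitting on the first underscore mislabels names such as 'ever_married_Yes'
# or 'avg_glucose_level', so match against the categorical columns of the
# unencoded dataset instead (assumes one-hot names follow '<column>_<value>').
categorical_cols = df_interpretable.select_dtypes(include=['object', 'category']).columns
# Try longer names first so a longer column name wins over a shorter prefix
prefixes = sorted(categorical_cols, key=len, reverse=True)

feature_mapping = {}
for col in X.columns:
    base = next((p for p in prefixes if col.startswith(f"{p}_")), None)
    if base is not None:
        # One-hot encoded column: "<feature> = <category>"
        feature_mapping[col] = f"{base} = {col[len(base) + 1:]}"
    else:
        # Numeric or engineered column: keep the original name
        feature_mapping[col] = col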
        
print("Sample of feature mapping for better interpretability:")
for encoded_name, readable_name in list(feature_mapping.items())[:10]:
    print(f"{encoded_name} → {readable_name}")

## 3. Understanding SHAP Framework

### 3.1 Introduction to SHAP

SHAP (SHapley Additive exPlanations) is a unified approach to explain the output of any machine learning model. It connects optimal credit allocation with local explanations using the classic Shapley values from game theory and their related extensions.
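
For reference, the classic Shapley value that SHAP builds on can be written as follows. The notation is chosen here for illustration: $N$ is the set of features, $n = |N|$, and $v(S)$ is the expected model output when only the features in $S$ are known.

```latex
\phi_i \;=\; \sum_{S \subseteq N \setminus \{i\}}
  \frac{|S|!\,(n - |S| - 1)!}{n!}\,
  \bigl[\, v(S \cup \{i\}) - v(S) \,\bigr],
\qquad
f(x) \;=\; \phi_0 + \sum_{i=1}^{n} \phi_i ,
\quad \phi_0 = v(\varnothing)
```

The second identity is the additivity property: the base value plus all feature contributions recovers the model output for the instance being explained.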

Key features of SHAP:
- **Global interpretability**: Understand which features are most important for the model overall
- **Local interpretability**: Explain individual predictions, showing how each feature contributes
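- **Model-agnostic**: `KernelExplainer` works with any model that exposes a prediction function, and faster model-specific explainers exist for tree-based and linear models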
- **Solid theoretical foundation**: Based on Shapley values from cooperative game theory

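SHAP values represent the contribution of each feature to a prediction, relative to the average prediction over the background data (the explainer's expected value). A positive SHAP value means the feature pushed the predicted stroke risk up, while a negative value means it pushed it down. The units depend on the explainer: some report probabilities, while others (for example `TreeExplainer` on gradient-boosted models, or `LinearExplainer` on logistic regression) report log-odds by default.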
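
As a minimal sketch of this idea, the snippet below computes exact Shapley values for a three-feature toy model by enumerating every feature subset. Everything in it is hypothetical: `toy_model`, its coefficients, and the `baseline` and `patient` profiles are invented for illustration, and a single reference patient stands in for the background dataset that SHAP would normally average over.

```python
from itertools import combinations
from math import factorial

# Hypothetical toy "model": a risk score with an age x hypertension interaction.
def toy_model(age, hypertension, glucose):
    return 0.02 * age + 0.5 * hypertension + 0.01 * glucose + 0.01 * age * hypertension

FEATURES = ["age", "hypertension", "glucose"]
baseline = {"age": 40.0, "hypertension": 0.0, "glucose": 100.0}  # reference patient
patient = {"age": 70.0, "hypertension": 1.0, "glucose": 180.0}   # patient to explain

def value(subset):
    """Model output when the features in `subset` take the patient's values
    and all remaining features stay at the baseline."""
    inputs = {f: (patient[f] if f in subset else baseline[f]) for f in FEATURES}
    return toy_model(**inputs)

def shapley_values():
    """Exact Shapley values by enumerating all subsets of the other features."""
    n = len(FEATURES)
    phi = {}
    for f in FEATURES:
        others = [g for g in FEATURES if g != f]
        total = 0.0
        for size in range(n):
            for subset in combinations(others, size):
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                total += weight * (value(set(subset) | {f}) - value(set(subset)))
        phi[f] = total
    return phi

phi = shapley_values()
base_value = value(set())          # output for the reference patient
prediction = value(set(FEATURES))  # output for the explained patient

for name, contribution in phi.items():
    print(f"{name}: {contribution:+.3f}")
print(f"base value + contributions = {base_value + sum(phi.values()):.3f}")
print(f"model output               = {prediction:.3f}")
```

The contributions sum exactly to the gap between the patient's output and the baseline output, and the interaction term is split between `age` and `hypertension` rather than credited to one of them. Exhaustive enumeration grows exponentially with the number of features, which is why SHAP relies on model-specific algorithms or sampling in practice.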

### 3.2 Setting Up the SHAP Explainer

In [None]:
# Function to create appropriate SHAP explainer based on model type
def create_shap_explainer(model):
    """
    Create an appropriate SHAP explainer based on the model type.
    
    Parameters:
    -----------
    model : model object
        The trained model
        
    Returns:
    --------
    shap.Explainer
        The appropriate SHAP explainer
    """
    try:
        # Check if it's an ensemble model (StackingClassifier or VotingClassifier)
        if hasattr(model, 'estimators_') and hasattr(model, 'final_estimator_'):  # StackingClassifier
            print("Detected StackingClassifier, using KernelExplainer on the full ensemble...")
            # The meta-estimator only sees the base models' predictions, so explaining it
            # directly would not attribute the output to the original input features.
            # Explain the whole ensemble through predict_proba instead.
            background_data = shap.sample(X_train_scaled_df, 100)  # Sample 100 background examples
            return shap.KernelExplainer(model.predict_proba, background_data)
        
        elif hasattr(model, 'estimators_') and hasattr(model, 'weights'):  # VotingClassifier
            print("Detected VotingClassifier, using KernelExplainer...")
            # For VotingClassifier, use KernelExplainer
            background_data = shap.sample(X_train_scaled_df, 100)  # Sample 100 background examples
            return shap.KernelExplainer(model.predict_proba, background_data)
        
        # Check if it's a tree-based model
        elif hasattr(model, 'feature_importances_'):
            print("Detected tree-based model, using TreeExplainer...")
            # For tree-based models (RandomForest, XGBoost, etc.)
            return shap.TreeExplainer(model)
        
        # Check if it's a linear model
        elif hasattr(model, 'coef_'):
            print("Detected linear model, using LinearExplainer...")
            # For linear models (LogisticRegression, etc.)
            return shap.LinearExplainer(model, X_train_scaled_df)
        
        else:
            print("Model type not specifically supported, using Kernel explainer...")
            # For all other model types, use KernelExplainer
            background_data = shap.sample(X_train_scaled_df, 100)  # Sample 100 background examples
            return shap.KernelExplainer(model.predict_proba, background_data)
            
    except Exception as e:
        print(f"Error creating SHAP explainer: {e}")
        print("Falling back to Kernel explainer...")
        # Fallback to KernelExplainer
        background_data = shap.sample(X_train_scaled_df, 100)  # Sample 100 background examples
        return shap.KernelExplainer(model.predict_proba, background_data)

# Create SHAP explainer
print("Creating SHAP explainer...")
explainer = create_shap_explainer(model)
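
To make the linear-model branch above concrete: when features are treated as independent, the SHAP value of a feature in a linear model has a closed form, namely its weight times the feature's deviation from the background mean. The sketch below uses NumPy only. The weights, bias, and background data are made up for illustration and do not come from the stroke model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical standardized background data: 200 rows, 3 features
background = rng.normal(size=(200, 3))
weights = np.array([1.2, 0.6, -0.3])  # illustrative coefficients
bias = -2.0

def log_odds(X):
    """Linear score (log-odds) of a logistic-regression-style model."""
    return X @ weights + bias

x = np.array([1.5, 0.0, -1.0])  # instance to explain

feature_means = background.mean(axis=0)
phi = weights * (x - feature_means)   # closed-form SHAP values (log-odds units)
base_value = log_odds(feature_means)  # equals the mean log-odds over the background

probability = 1.0 / (1.0 + np.exp(-log_odds(x)))
print("SHAP values (log-odds):", np.round(phi, 3))
print("base + sum(phi):", round(float(base_value + phi.sum()), 3))
print("model log-odds: ", round(float(log_odds(x)), 3))
print("probability:    ", round(float(probability), 3))
```

The contributions add up in log-odds space, not in probability space, which matters when reading force or waterfall plots produced from a linear explainer.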

## 4. Global Feature Importance

In [None]:
# Calculate SHAP values for a sample of the test set to speed up computation
print("Calculating SHAP values for a sample of the test set...")
sample_size = min(100, X_test_scaled_df.shape[0])  # Use at most 100 samples
sample_indices = np.random.choice(X_test_scaled_df.shape[0], sample_size, replace=False)
X_sample = X_test_scaled_df.iloc[sample_indices]

# Calculate SHAP values
shap_values = explainer.shap_values(X_sample)

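# Extract values for positive class (stroke=1)
# Depending on the SHAP version and explainer, classifiers return either a list with
# one array per class or a single array of shape (n_samples, n_features, n_classes)
if isinstance(shap_values, list) and len(shap_values) > 1:
    shap_values_class1 = shap_values[1]  # For binary classification, second element is for class 1
elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
    shap_values_class1 = shap_values[:, :, 1]
else:
    shap_values_class1 = shap_values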

print(f"SHAP values calculated for {sample_size} samples.")

In [None]:
# Create a summary plot of SHAP values
plt.figure(figsize=(14, 10))
shap.summary_plot(shap_values_class1, X_sample, max_display=20, show=False)
plt.title("Feature Importance (Impact on Stroke Prediction)", fontsize=16)
plt.tight_layout()
plt.savefig('figures/shap_explanations/global_feature_importance.png', dpi=300, bbox_inches='tight')
plt.show()

# Create a bar plot of mean absolute SHAP values
plt.figure(figsize=(14, 10))
shap.summary_plot(shap_values_class1, X_sample, plot_type="bar", max_display=20, show=False)
plt.title("Feature Importance (Mean Absolute SHAP Values)", fontsize=16)
plt.tight_layout()
plt.savefig('figures/shap_explanations/global_feature_importance_bar.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Calculate and display mean absolute SHAP values for each feature
mean_abs_shap = np.abs(shap_values_class1).mean(0)

# Create DataFrame with feature names and SHAP values
feature_importance = pd.DataFrame({
    'Feature': X_sample.columns,
    'SHAP Importance': mean_abs_shap
})

# Add interpretable feature names
feature_importance['Interpretable Name'] = feature_importance['Feature'].map(feature_mapping)

# Sort by importance
feature_importance = feature_importance.sort_values('SHAP Importance', ascending=False).reset_index(drop=True)

# Display top 20 features
print("Top 20 features by importance (mean absolute SHAP value):")
feature_importance.head(20)
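
The ranking above uses the mean of absolute SHAP values rather than the plain mean because positive and negative contributions would otherwise cancel out. A small sketch with a made-up SHAP matrix (the numbers and feature names are hypothetical) shows the difference:

```python
import numpy as np

# Hypothetical SHAP matrix: 4 samples x 3 features
shap_matrix = np.array([
    [ 0.30, -0.10,  0.00],
    [-0.20,  0.05,  0.01],
    [ 0.40, -0.15,  0.00],
    [-0.10,  0.10, -0.01],
])
feature_names = ["age", "bmi", "gender_Male"]

mean_abs = np.abs(shap_matrix).mean(axis=0)  # global importance
mean_signed = shap_matrix.mean(axis=0)       # average direction, can cancel out

# Rank features from most to least important
ranking = [feature_names[i] for i in np.argsort(mean_abs)[::-1]]

for name, importance, signed in zip(feature_names, mean_abs, mean_signed):
    print(f"{name}: mean |SHAP| = {importance:.3f}, mean SHAP = {signed:+.3f}")
print("Ranking:", ranking)
```

Here `age` has a signed mean of only 0.10 but a mean absolute value of 0.25, so the signed average would understate how strongly it moves individual predictions.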

## 5. Feature Interactions and Dependence Plots

In [None]:
# Get the top 5 most important features
top_features = feature_importance['Feature'].head(5).tolist()
print(f"Top 5 features: {top_features}")

# Create dependence plots for each top feature
for feature in top_features:
    plt.figure(figsize=(12, 8))
    shap.dependence_plot(
        feature, shap_values_class1, X_sample,
        interaction_index='auto',  # Auto-detect the strongest interaction (None would disable coloring)
        show=False
    )
    interpretable_name = feature_mapping.get(feature, feature)
    plt.title(f"Dependence Plot: {interpretable_name}", fontsize=16)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/dependence_{feature.replace(" ", "_")}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Create interaction plots for the top 3 features
# As a simple heuristic, color each plot by the highest-ranked other top feature
for i, feature in enumerate(top_features[:3]):
    # Candidate interaction features (excluding the feature itself)
    potential_interactions = [f for f in top_features if f != feature]
    
    # top_features is sorted by importance, so the first candidate ranks highest
    interaction_feature = potential_interactions[0]
    
    plt.figure(figsize=(12, 8))
    shap.dependence_plot(
        feature, shap_values_class1, X_sample,
        interaction_index=interaction_feature,
        show=False
    )
    
    interpretable_name1 = feature_mapping.get(feature, feature)
    interpretable_name2 = feature_mapping.get(interaction_feature, interaction_feature)
    plt.title(f"Interaction: {interpretable_name1} vs {interpretable_name2}", fontsize=16)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/interaction_{feature}_{interaction_feature}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

## 6. Local Explanations for Individual Predictions

In [None]:
# Function to explain individual predictions
def explain_prediction(instance_idx, X_data, y_data, explainer, feature_mapping=None):
    """
    Explain an individual prediction with SHAP.
    
    Parameters:
    -----------
    instance_idx : int
        Index of the instance to explain
    X_data : pd.DataFrame
        Features dataframe
    y_data : pd.Series
        Target series
    explainer : shap.Explainer
        SHAP explainer
    feature_mapping : dict, optional
        Mapping of feature names to interpretable names
    """
    # Get the instance
    instance = X_data.iloc[[instance_idx]]
    true_label = y_data.iloc[instance_idx]
    
    # Get prediction
    if hasattr(model, 'predict_proba'):
        prediction_proba = model.predict_proba(instance)[0, 1]
        prediction = int(prediction_proba >= 0.5)
    else:
        prediction = model.predict(instance)[0]
        prediction_proba = None
    
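    # Calculate SHAP values
    instance_shap_values = explainer.shap_values(instance)
    if isinstance(instance_shap_values, list) and len(instance_shap_values) > 1:
        # For classifiers that return a list of SHAP values per class
        instance_shap_class1 = instance_shap_values[1]
    elif isinstance(instance_shap_values, np.ndarray) and instance_shap_values.ndim == 3:
        # Some SHAP versions return (n_samples, n_features, n_classes)
        instance_shap_class1 = instance_shap_values[:, :, 1]
    else:
        # For other models
        instance_shap_class1 = instance_shap_values
    
    # expected_value may be a scalar, a list, or a NumPy array with one entry per class
    base_values = np.atleast_1d(explainer.expected_value)
    expected_value = float(base_values[1]) if base_values.size > 1 else float(base_values[0])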
    
    # Display prediction information
    print(f"Instance Index: {instance_idx}")
    print(f"True Label: {'Stroke' if true_label == 1 else 'No Stroke'} ({true_label})")
    print(f"Predicted Label: {'Stroke' if prediction == 1 else 'No Stroke'} ({prediction})")
    if prediction_proba is not None:
        print(f"Predicted Probability: {prediction_proba:.4f}")
    print(f"Prediction: {'Correct' if prediction == true_label else 'Incorrect'}")
    
    # Display key feature values
    print("\nKey Feature Values:")
    for feature in top_features[:5]:  # Display top 5 important features
        interpretable_name = feature_mapping.get(feature, feature) if feature_mapping else feature
        value = instance[feature].values[0]
        print(f"  {interpretable_name}: {value:.4f}")
    
    # Create force plot
    plt.figure(figsize=(20, 3))
    force_plot = shap.force_plot(expected_value, instance_shap_class1[0], instance.iloc[0], 
                                feature_names=instance.columns.tolist(),
                                matplotlib=True, show=False)
    plt.title(f"SHAP Force Plot - Instance {instance_idx}", fontsize=14)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/force_plot_instance_{instance_idx}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    # Create waterfall plot
    plt.figure(figsize=(14, 10))
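    # Use the public waterfall API rather than the private legacy helper
    shap.plots.waterfall(
        shap.Explanation(
            values=instance_shap_class1[0],
            base_values=expected_value,
            data=instance.iloc[0].values,
            feature_names=instance.columns.tolist()
        ),
        max_display=10, show=False
    )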
    plt.title(f"SHAP Waterfall Plot - Instance {instance_idx}", fontsize=14)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/waterfall_plot_instance_{instance_idx}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    # Create decision plot
    plt.figure(figsize=(12, 10))
    shap.decision_plot(expected_value, instance_shap_class1[0], 
                      instance.iloc[0], feature_names=instance.columns.tolist(),
                      feature_display_range=slice(-1, -11, -1),  # 10 most important features
                      show=False)
    plt.title(f"SHAP Decision Plot - Instance {instance_idx}", fontsize=14)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/decision_plot_instance_{instance_idx}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    # Return the SHAP values and expected value for further analysis
    return {
        'instance': instance,
        'true_label': true_label,
        'prediction': prediction,
        'prediction_proba': prediction_proba,
        'shap_values': instance_shap_class1,
        'expected_value': expected_value
    }

In [None]:
# Select some interesting instances to explain
# 1. Find a true positive (correctly predicted stroke)
# 2. Find a true negative (correctly predicted no stroke)
# 3. Find a false positive (incorrectly predicted stroke)
# 4. Find a false negative (incorrectly predicted no stroke)

# Get predictions on test set
y_pred = model.predict(X_test_scaled_df)
try:
    y_pred_proba = model.predict_proba(X_test_scaled_df)[:, 1]
except AttributeError:  # Model does not expose predict_proba
    y_pred_proba = None

# Find indices for each case
true_positive_indices = np.where((y_test == 1) & (y_pred == 1))[0]
true_negative_indices = np.where((y_test == 0) & (y_pred == 0))[0]
false_positive_indices = np.where((y_test == 0) & (y_pred == 1))[0]
false_negative_indices = np.where((y_test == 1) & (y_pred == 0))[0]

print(f"True Positives: {len(true_positive_indices)}")
print(f"True Negatives: {len(true_negative_indices)}")
print(f"False Positives: {len(false_positive_indices)}")
print(f"False Negatives: {len(false_negative_indices)}")

# Select one instance from each category (if available)
instances_to_explain = []

if len(true_positive_indices) > 0:
    instances_to_explain.append((true_positive_indices[0], "True Positive"))
    
if len(true_negative_indices) > 0:
    instances_to_explain.append((true_negative_indices[0], "True Negative"))
    
if len(false_positive_indices) > 0:
    instances_to_explain.append((false_positive_indices[0], "False Positive"))
    
if len(false_negative_indices) > 0:
    instances_to_explain.append((false_negative_indices[0], "False Negative"))

print(f"\nSelected {len(instances_to_explain)} instances to explain")
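
The index selection above relies on boolean masks returning positional indices, which is why the instances are later retrieved with `.iloc`. A minimal NumPy sketch with made-up labels (the arrays `y_true` and `y_hat` are hypothetical) shows how the four categories fall out:

```python
import numpy as np

# Hypothetical true labels and predictions for six patients
y_true = np.array([1, 0, 0, 1, 0, 1])
y_hat = np.array([1, 0, 1, 0, 0, 1])

masks = {
    "True Positive": (y_true == 1) & (y_hat == 1),
    "True Negative": (y_true == 0) & (y_hat == 0),
    "False Positive": (y_true == 0) & (y_hat == 1),
    "False Negative": (y_true == 1) & (y_hat == 0),
}

# flatnonzero returns the positions where each mask is True
indices = {name: np.flatnonzero(mask) for name, mask in masks.items()}

for name, idx in indices.items():
    print(f"{name}: positions {idx.tolist()}")
```

Because the positions refer to row order rather than to the original DataFrame index, they must be used with positional accessors on data that is in the same order as the predictions.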

In [None]:
# Explain each selected instance
explanations = {}

for idx, case_type in instances_to_explain:
    print(f"\n{'='*80}")
    print(f"Explaining {case_type} (Instance {idx})")
    print(f"{'='*80}")
    
    explanation = explain_prediction(
        idx, X_test_scaled_df, y_test, 
        explainer, feature_mapping
    )
    
    explanations[case_type] = explanation

## 7. Comparing Explanations Across Different Cases

In [None]:
# Function to extract top features contributing to a prediction
def get_top_contributors(explanation, top_n=5):
    """
    Get the top features contributing to a prediction.
    
    Parameters:
    -----------
    explanation : dict
        Explanation dictionary with SHAP values
    top_n : int, optional
        Number of top contributors to return
        
    Returns:
    --------
    pd.DataFrame
        DataFrame with top contributing features
    """
    # Get feature names and SHAP values
    feature_names = explanation['instance'].columns
    shap_values = explanation['shap_values'][0]
    
    # Create DataFrame
    contributors = pd.DataFrame({
        'Feature': feature_names,
        'SHAP Value': shap_values,
        'Absolute Impact': np.abs(shap_values)
    })
    
    # Add interpretable names if available
    if feature_mapping:
        contributors['Interpretable Name'] = contributors['Feature'].map(feature_mapping)
    
    # Sort by absolute impact
    contributors = contributors.sort_values('Absolute Impact', ascending=False).reset_index(drop=True)
    
    return contributors.head(top_n)

In [None]:
# Compare top contributors across different cases
for case_type, explanation in explanations.items():
    # Get top contributors
    top_contributors = get_top_contributors(explanation)
    
    print(f"\n{'='*80}")
    print(f"Top Contributors for {case_type} Prediction")
    print(f"{'='*80}")
    print(f"True Label: {'Stroke' if explanation['true_label'] == 1 else 'No Stroke'}")
    print(f"Predicted Label: {'Stroke' if explanation['prediction'] == 1 else 'No Stroke'}")
    if explanation['prediction_proba'] is not None:
        print(f"Predicted Probability: {explanation['prediction_proba']:.4f}")
    print("\nTop Contributing Features:")
    print(top_contributors)
    
    # Create a horizontal bar plot of SHAP values
    plt.figure(figsize=(12, 8))
    
    # Sort by SHAP value
    top_contributors_sorted = top_contributors.sort_values('SHAP Value')
    
    # Plot bars
    bars = plt.barh(
        top_contributors_sorted['Interpretable Name'] if 'Interpretable Name' in top_contributors_sorted.columns else top_contributors_sorted['Feature'],
        top_contributors_sorted['SHAP Value'],
        color=['red' if x > 0 else 'blue' for x in top_contributors_sorted['SHAP Value']]
    )
    
    # Add values to bars
    for bar in bars:
        width = bar.get_width()
        label_x_pos = width if width > 0 else width - 0.05
        plt.text(label_x_pos, bar.get_y() + bar.get_height()/2, 
                f'{width:.4f}', va='center', ha='left' if width > 0 else 'right')
    
    plt.axvline(x=0, color='black', linestyle='-', alpha=0.3)
    plt.xlabel('SHAP Value (Impact on Prediction)', fontsize=12)
    plt.title(f"Top Features Impacting {case_type} Prediction", fontsize=14)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/top_contributors_{case_type.replace(" ", "_").lower()}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

## 8. Build an Interactive Stroke Risk Simulator

In [None]:
# Create a function to simulate stroke risk for different patient profiles
def simulate_stroke_risk(age, hypertension, heart_disease, bmi, glucose_level, gender='Male', ever_married='Yes',
                         work_type='Private', residence_type='Urban', smoking_status='never smoked'):
    """
    Simulate stroke risk for a given patient profile.
    
    Parameters:
    -----------
    age : float
        Patient age in years
    hypertension : int
        Whether the patient has hypertension (0=No, 1=Yes)
    heart_disease : int
        Whether the patient has heart disease (0=No, 1=Yes)
    bmi : float
        Body Mass Index
    glucose_level : float
        Average glucose level (mg/dL)
    gender : str, optional
        Patient gender ('Male' or 'Female')
    ever_married : str, optional
        Whether the patient was ever married ('Yes' or 'No')
    work_type : str, optional
        Type of work ('Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked')
    residence_type : str, optional
        Type of residence ('Urban' or 'Rural')
    smoking_status : str, optional
        Smoking status ('never smoked', 'formerly smoked', 'smokes', 'Unknown')
        
    Returns:
    --------
    dict
        Dictionary with risk assessment and SHAP explanation
    """
    # Create a DataFrame with the patient profile
    # First, we'll create a patient profile with the same columns as df_interpretable
    patient_data = pd.DataFrame({
        'age': [age],
        'hypertension': [hypertension],
        'heart_disease': [heart_disease],
        'bmi': [bmi],
        'avg_glucose_level': [glucose_level],
        'gender': [gender],
        'ever_married': [ever_married],
        'work_type': [work_type],
        'Residence_type': [residence_type],
        'smoking_status': [smoking_status]
    })
    
    # Encode the profile the way the model expects.
    # The fitted one-hot encoder from the preprocessing notebook is not available here,
    # so the encoding is reproduced manually and must stay in sync with that notebook.
    
    # Create a DataFrame with the same columns as X (encoded data)
    encoded_patient = pd.DataFrame(np.zeros((1, len(X.columns))), columns=X.columns)
    
    # Set numerical features
    encoded_patient['age'] = age
    encoded_patient['bmi'] = bmi
    encoded_patient['avg_glucose_level'] = glucose_level
    encoded_patient['hypertension'] = hypertension
    encoded_patient['heart_disease'] = heart_disease
    
    # Set categorical features
    if gender == 'Male':
        encoded_patient['gender_Male'] = 1
        
    if ever_married == 'Yes':
        encoded_patient['ever_married_Yes'] = 1
        
    if residence_type == 'Urban':
        encoded_patient['Residence_type_Urban'] = 1
    
    # Work type
    if work_type in ['Private', 'Self-employed', 'children', 'Never_worked']:
        encoded_patient[f'work_type_{work_type}'] = 1
    
    # Smoking status
    if smoking_status in ['formerly smoked', 'never smoked', 'smokes']:
        encoded_patient[f'smoking_status_{smoking_status}'] = 1
    
    # Fill in derived features based on business logic from preprocessing
    # Age group
    if age <= 18:
        encoded_patient['age_group_0-18'] = 1
    elif age <= 30:
        encoded_patient['age_group_19-30'] = 1
    elif age <= 40:
        encoded_patient['age_group_31-40'] = 1
    elif age <= 50:
        encoded_patient['age_group_41-50'] = 1
    elif age <= 60:
        encoded_patient['age_group_51-60'] = 1
    elif age <= 70:
        encoded_patient['age_group_61-70'] = 1
    elif age <= 80:
        encoded_patient['age_group_71-80'] = 1
    else:
        encoded_patient['age_group_81+'] = 1
    
    # BMI category
    if bmi < 18.5:
        encoded_patient['bmi_category_Underweight'] = 1
    elif bmi < 25:
        encoded_patient['bmi_category_Normal'] = 1
    elif bmi < 30:
        encoded_patient['bmi_category_Overweight'] = 1
    else:
        encoded_patient['bmi_category_Obese'] = 1
    
    # Glucose category
    if glucose_level < 70:
        encoded_patient['glucose_category_Low'] = 1
    elif glucose_level < 100:
        encoded_patient['glucose_category_Normal'] = 1
    elif glucose_level < 126:
        encoded_patient['glucose_category_Prediabetes'] = 1
    else:
        encoded_patient['glucose_category_Diabetes'] = 1
    
    # Interaction features
    encoded_patient['age_hypertension'] = age * hypertension
    encoded_patient['age_heart_disease'] = age * heart_disease
    encoded_patient['glucose_bmi'] = glucose_level * bmi
    encoded_patient['is_senior'] = 1 if age >= 65 else 0
    encoded_patient['comorbidity_None'] = 1 if (hypertension == 0 and heart_disease == 0) else 0
    encoded_patient['comorbidity_One Condition'] = 1 if (hypertension + heart_disease == 1) else 0
    encoded_patient['comorbidity_Both Conditions'] = 1 if (hypertension == 1 and heart_disease == 1) else 0
    
    # Align with the training columns: the assignments above may have created columns
    # that the encoded data does not contain (for example a dropped baseline category).
    # reindex drops those and restores the column order of X.
    encoded_patient = encoded_patient.reindex(columns=X.columns, fill_value=0)
    
    # Scale features
    scaled_patient = pd.DataFrame(scaler.transform(encoded_patient), columns=encoded_patient.columns)
    
    # Make prediction
    if hasattr(model, 'predict_proba'):
        risk_proba = model.predict_proba(scaled_patient)[0, 1]
        risk_level = 1 if risk_proba >= 0.5 else 0
    else:
        risk_level = model.predict(scaled_patient)[0]
        risk_proba = None
    
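    # Get SHAP values
    patient_shap_values = explainer.shap_values(scaled_patient)
    if isinstance(patient_shap_values, list) and len(patient_shap_values) > 1:
        patient_shap_class1 = patient_shap_values[1]
    elif isinstance(patient_shap_values, np.ndarray) and patient_shap_values.ndim == 3:
        # Some SHAP versions return (n_samples, n_features, n_classes)
        patient_shap_class1 = patient_shap_values[:, :, 1]
    else:
        patient_shap_class1 = patient_shap_values
    
    # expected_value may be a scalar, a list, or a NumPy array with one entry per class
    base_values = np.atleast_1d(explainer.expected_value)
    expected_value = float(base_values[1]) if base_values.size > 1 else float(base_values[0])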
    
    # Get top contributing features
    shap_df = pd.DataFrame({
        'Feature': scaled_patient.columns,
        'SHAP Value': patient_shap_class1[0],
        'Absolute Value': np.abs(patient_shap_class1[0])
    })
    if feature_mapping:
        shap_df['Interpretable Name'] = shap_df['Feature'].map(feature_mapping)
    
    # Sort by absolute impact
    shap_df = shap_df.sort_values('Absolute Value', ascending=False).reset_index(drop=True)
    top_factors = shap_df.head(5)
    
    # Prepare risk assessment
    risk_assessment = {
        'patient_profile': patient_data.iloc[0].to_dict(),
        'risk_level': 'High' if risk_level == 1 else 'Low',
        'risk_probability': risk_proba,
        'top_risk_factors': top_factors,
        'shap_values': patient_shap_class1,
        'expected_value': expected_value,
        'scaled_patient': scaled_patient
    }
    
    # Display risk assessment
    print(f"\n{'='*80}")
    print(f"STROKE RISK ASSESSMENT")
    print(f"{'='*80}")
    print(f"Patient Profile:")
    print(f"  Age: {age}")
    print(f"  Gender: {gender}")
    print(f"  BMI: {bmi:.2f}")
    print(f"  Glucose Level: {glucose_level:.2f} mg/dL")
    print(f"  Hypertension: {'Yes' if hypertension == 1 else 'No'}")
    print(f"  Heart Disease: {'Yes' if heart_disease == 1 else 'No'}")
    print(f"  Smoking Status: {smoking_status}")
    print(f"\nRisk Assessment:")
    print(f"  Risk Level: {risk_assessment['risk_level']}")
    if risk_proba is not None:
        print(f"  Risk Probability: {risk_proba:.2%}")
    
    print(f"\nTop Risk Factors:")
    for i, (_, row) in enumerate(top_factors.iterrows(), 1):
        feature_name = row['Interpretable Name'] if 'Interpretable Name' in row else row['Feature']
        impact = 'Increases' if row['SHAP Value'] > 0 else 'Decreases'
        print(f"  {i}. {feature_name}: {impact} risk by {abs(row['SHAP Value']):.4f}")
    
    # Create force plot
    plt.figure(figsize=(20, 3))
    shap.force_plot(expected_value, patient_shap_class1[0], scaled_patient.iloc[0], 
                   matplotlib=True, show=False)
    plt.title(f"SHAP Force Plot - Risk Factors", fontsize=14)
    plt.tight_layout()
    plt.savefig(f'figures/shap_explanations/simulated_patient_force_plot.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    return risk_assessment

In [None]:
# Example 1: High-risk patient profile
high_risk_assessment = simulate_stroke_risk(
    age=78,
    hypertension=1,
    heart_disease=1,
    bmi=32.5,
    glucose_level=190.5,
    gender='Male',
    smoking_status='formerly smoked'
)

In [None]:
# Example 2: Low-risk patient profile
low_risk_assessment = simulate_stroke_risk(
    age=35,
    hypertension=0,
    heart_disease=0,
    bmi=24.8,
    glucose_level=92.3,
    gender='Female',
    smoking_status='never smoked'
)

In [None]:
# Example 3: Moderate-risk patient profile
moderate_risk_assessment = simulate_stroke_risk(
    age=62,
    hypertension=1,
    heart_disease=0,
    bmi=28.5,
    glucose_level=145.3,
    gender='Male',
    smoking_status='smokes'
)
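
The simulator encodes each patient by hand, which is easy to get out of sync with the preprocessing notebook. One possible alternative, sketched here under assumptions, is to one-hot encode with `pd.get_dummies` and then align to the training columns with `reindex`. The `training_columns` list below is a hypothetical subset; in this notebook the real list would be `X.columns`.

```python
import pandas as pd

# Hypothetical subset of the encoded training columns
training_columns = [
    "age", "bmi", "gender_Male",
    "smoking_status_never smoked", "smoking_status_smokes",
]

# Raw patient profile with unencoded categorical values
patient = pd.DataFrame([
    {"age": 62, "bmi": 28.5, "gender": "Male", "smoking_status": "smokes"}
])

# One-hot encode only the categorical columns, using floats instead of booleans
encoded = pd.get_dummies(patient, columns=["gender", "smoking_status"], dtype=float)

# Add training columns missing from this row (as 0) and drop anything unseen in training
encoded = encoded.reindex(columns=training_columns, fill_value=0.0)

print(encoded.T)
```

One caveat with this approach: a category that never appeared in training is dropped silently, so the patient is treated as belonging to the baseline category. Validating inputs against the known categories before encoding avoids that.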

## 9. Summary and Key Insights

### 9.1 Global Model Insights

Based on our SHAP analysis, we can draw the following key insights about stroke prediction:

1. **Top Predictors**: The most important features for predicting stroke were:
   - Age (especially being in older age groups)
   - Hypertension status
   - Glucose levels (especially diabetic levels)
   - Heart disease status
   - BMI (especially being obese)

2. **Feature Relationships**:
   - Age and hypertension interact strongly: the model assigns substantially higher stroke risk to older patients with hypertension
   - Glucose levels contribute more as they rise above the diabetic threshold (≥126 mg/dL)
   - The presence of both hypertension and heart disease (comorbidity) sharply increases predicted stroke risk

3. **Non-modifiable vs. Modifiable Risk Factors**:
   - Non-modifiable: Age, gender
   - Modifiable: Hypertension management, glucose control, BMI, smoking status
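
To make the idea of an interaction concrete, the sketch below computes exact Shapley values by brute force for a two-feature toy risk function. The function `toy_risk`, its coefficients, and the helper `exact_shapley` are hypothetical and are not taken from the trained model. The sketch also uses a single baseline patient, whereas SHAP explainers typically average over background data. It only illustrates how an interaction term is credited to the features involved.

```python
from itertools import combinations
from math import factorial


def toy_risk(features):
    """Hypothetical risk score with an age x hypertension interaction term."""
    older = features["older"]
    hyper = features["hypertension"]
    return 0.02 + 0.10 * older + 0.05 * hyper + 0.15 * older * hyper


def exact_shapley(value_fn, patient, baseline):
    """Exact Shapley values by enumerating every coalition of the other features."""
    names = list(patient)
    n = len(names)
    phi = {}
    for name in names:
        others = [m for m in names if m != name]
        total = 0.0
        for size in range(n):
            for coalition in combinations(others, size):
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                # Features in the coalition take the patient's value, the rest the baseline's
                without = {m: (patient[m] if m in coalition else baseline[m]) for m in names}
                with_feature = dict(without, **{name: patient[name]})
                total += weight * (value_fn(with_feature) - value_fn(without))
        phi[name] = total
    return phi


patient = {"older": 1, "hypertension": 1}
baseline = {"older": 0, "hypertension": 0}
phi = exact_shapley(toy_risk, patient, baseline)
print(phi)
```

In this toy setting the interaction term (0.15) is split equally between the two features, so each Shapley value is larger than the feature's standalone coefficient, and the values sum to the difference between the patient's score and the baseline score.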

### 9.2 Individual Risk Assessment

Our SHAP-based risk assessment tool demonstrates how individual risk profiles can be explained:

1. For high-risk patients, age is typically the strongest predictor, followed by comorbidities
2. For low-risk patients, young age and absence of comorbidities contribute most to the low-risk assessment
3. For moderate-risk patients, a mix of risk-increasing and risk-decreasing factors leads to a more nuanced prediction
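
The three patterns above follow from SHAP's additivity property: the base value plus the sum of a patient's SHAP values equals the model output, in whatever units the explainer works in (probability or log-odds). The sketch below separates contributions by direction and reconstructs that output. The helper `summarize_contributions` and all numbers are hypothetical and for illustration only; they are not results from this notebook.

```python
def summarize_contributions(base_value, contributions):
    """Split SHAP contributions by direction and reconstruct the model output."""
    increasing = sorted(
        ((name, value) for name, value in contributions.items() if value > 0),
        key=lambda item: item[1],
        reverse=True,
    )
    decreasing = sorted(
        ((name, value) for name, value in contributions.items() if value < 0),
        key=lambda item: item[1],
    )
    # Additivity: base value + sum of SHAP values = model output for this patient
    model_output = base_value + sum(contributions.values())
    return {
        "increasing": increasing,
        "decreasing": decreasing,
        "model_output": model_output,
    }


# Hypothetical values, not taken from the notebook's results
example = summarize_contributions(
    base_value=0.05,
    contributions={"age": 0.20, "hypertension": 0.08, "bmi": -0.02, "smoking_status": -0.01},
)
print(example["increasing"])
print(example["decreasing"])
```

A moderate-risk profile would show non-empty lists in both directions, which is what makes its prediction harder to summarize with a single dominant factor.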

### 9.3 Clinical Applications

These SHAP explanations could be valuable in clinical settings by:

1. **Personalized Risk Communication**: Providing patients with clear visualizations of their personal risk factors
2. **Intervention Planning**: Identifying which modifiable risk factors would most effectively reduce stroke risk
3. **Clinical Decision Support**: Helping clinicians understand why the model made specific predictions
4. **Educational Tool**: Teaching medical students and residents about stroke risk factors and their relative importance

### 9.4 Future Directions

To further enhance model explainability and clinical utility:

1. Integrate the SHAP explanations into a user-friendly interface for clinical use
2. Explore how interventions (e.g., reducing BMI or glucose levels) affect predicted stroke risk
3. Combine SHAP explanations with existing clinical risk scores for validation
4. Study how SHAP explanations affect patient understanding and behavior
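
The second direction, exploring interventions, could start from a simple what-if comparison: score the same patient before and after changing one input. The sketch below assumes a prediction function that accepts a dict of features. `stand_in_predict` is a hypothetical logistic function with made-up coefficients, used only so the example runs on its own; in practice it would be replaced by the trained model's pipeline. The result describes how the model responds to the change, not the causal effect of a real intervention.

```python
import math


def what_if(predict_fn, patient, changes):
    """Return predicted risk before and after applying `changes` to `patient`."""
    modified = {**patient, **changes}  # copy, so the original profile is untouched
    before = predict_fn(patient)
    after = predict_fn(modified)
    return {"before": before, "after": after, "delta": after - before}


def stand_in_predict(patient):
    """Hypothetical logistic risk function (made-up coefficients, not the trained model)."""
    score = (
        -6.0
        + 0.05 * patient["age"]
        + 0.01 * patient["glucose_level"]
        + 0.03 * patient["bmi"]
    )
    return 1.0 / (1.0 + math.exp(-score))


# Numeric inputs taken from the moderate-risk example profile above
patient = {"age": 62, "bmi": 28.5, "glucose_level": 145.3}
result = what_if(stand_in_predict, patient, {"glucose_level": 100.0})
print(f"before={result['before']:.4f} after={result['after']:.4f} delta={result['delta']:+.4f}")
```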

## 10. Saving SHAP Explainer for Deployment

In [None]:
# Save the SHAP explainer for deployment
explainer_path = 'models/stroke_prediction_explainer.pkl'
with open(explainer_path, 'wb') as f:
    pickle.dump(explainer, f)
print(f"SHAP explainer saved to {explainer_path}")

# Save feature mapping for interpretation
feature_mapping_path = 'models/feature_mapping.json'
with open(feature_mapping_path, 'w') as f:
    # Convert keys to strings for JSON serialization
    mapping_str_keys = {str(k): v for k, v in feature_mapping.items()}
    json.dump(mapping_str_keys, f, indent=4)
print(f"Feature mapping saved to {feature_mapping_path}")
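
One consequence of storing the mapping as JSON is that every key comes back as a string, so lookups after loading must use `str(key)`. The sketch below shows the round trip with a small hypothetical mapping written to a temporary directory rather than to `models/`.

```python
import json
import os
import tempfile

# Hypothetical mapping with a non-string key, used only for illustration
mapping = {0: "Age", "avg_glucose_level": "Average glucose level"}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "feature_mapping.json")
    with open(path, "w") as f:
        json.dump({str(k): v for k, v in mapping.items()}, f, indent=4)
    with open(path) as f:
        restored = json.load(f)

print(restored.get(str(0)))  # the string key matches
print(restored.get(0))       # the original integer key no longer matches
```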

In [None]:
# Create a utility function for generating explanations in deployment
def generate_explanation(input_data, feature_names=None):
    """
    Generate SHAP explanation for a prediction in a deployment setting.
    
    Parameters:
    -----------
    input_data : array-like
        Input features (must be scaled and in the same format as during training)
    feature_names : list, optional
        Names of the features
        
    Returns:
    --------
    dict
        Dictionary with explanation details
    """
    # Load assets
    with open('models/stroke_prediction_model.pkl', 'rb') as f:
        model = pickle.load(f)
    
    with open('models/stroke_prediction_explainer.pkl', 'rb') as f:
        explainer = pickle.load(f)
    
    with open('models/feature_mapping.json', 'r') as f:
        feature_mapping = json.load(f)
    
    # Convert input data to correct format if needed
    if not isinstance(input_data, pd.DataFrame):
        if feature_names is not None:
            input_data = pd.DataFrame([input_data], columns=feature_names)
        else:
            input_data = pd.DataFrame([input_data])
    
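    # Get prediction (cast NumPy scalars to built-in types so the result is JSON-serializable)
    if hasattr(model, 'predict_proba'):
        pred_proba = float(model.predict_proba(input_data)[0, 1])
        pred_class = int(pred_proba >= 0.5)
    else:
        pred_class = int(model.predict(input_data)[0])
        pred_proba = None
    
    # Get SHAP values for the positive (stroke) class.
    # Depending on the SHAP version and model type, shap_values() may return a list
    # with one array per class, a 3-D array (samples, features, classes),
    # or a single 2-D array (samples, features).
    shap_values = explainer.shap_values(input_data)
    if isinstance(shap_values, list):
        # For a binary model the last entry is the positive class
        shap_values_class1 = np.asarray(shap_values[-1])
    else:
        shap_values = np.asarray(shap_values)
        shap_values_class1 = shap_values[..., -1] if shap_values.ndim == 3 else shap_values
    # expected_value may be a scalar, a list, or an array with one entry per class
    expected_value = np.atleast_1d(explainer.expected_value).ravel()[-1]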
    
    # Get top contributing features
    feature_contribution = []
    for i, col in enumerate(input_data.columns):
        feature_contribution.append({
            'feature': col,
            'interpretable_name': feature_mapping.get(str(col), col),
            'shap_value': float(shap_values_class1[0][i]),
            'abs_value': float(abs(shap_values_class1[0][i]))
        })
    
    # Sort by absolute contribution
    feature_contribution.sort(key=lambda x: x['abs_value'], reverse=True)
    
    # Prepare explanation
    explanation = {
        'prediction': {
            'class': pred_class,
            'probability': pred_proba,
            'label': 'Stroke' if pred_class == 1 else 'No Stroke'
        },
        'base_value': float(expected_value),
        'feature_contribution': feature_contribution,
        'top_contributors': feature_contribution[:5],
        'risk_level': 'High' if pred_class == 1 else 'Low'
    }
    
    return explanation

# Save the utility function as a Python module
# inspect.getsource reuses the definition above, so the saved module cannot drift
# out of sync with the notebook version of the function.
import inspect

# Imports the function body relies on (the shap package must also be installed
# wherever the pickled explainer is loaded)
module_imports = "import json\nimport pickle\n\nimport numpy as np\nimport pandas as pd\n\n\n"
with open('models/stroke_explainer_utils.py', 'w') as f:
    f.write(module_imports + inspect.getsource(generate_explanation))

print("Utility function saved to models/stroke_explainer_utils.py")
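
As a possible next step, the dictionary returned by `generate_explanation` can be turned into a short text report. The sketch below assumes only the keys shown in the function above (`prediction`, `top_contributors`, `interpretable_name`, `shap_value`, `abs_value`). The helper `format_explanation` and the sample values are hypothetical and are not model output.

```python
def format_explanation(explanation, top_n=3):
    """Render an explanation dict as a short, human-readable report."""
    prediction = explanation["prediction"]
    probability = prediction["probability"]
    probability_text = "n/a" if probability is None else f"{probability:.1%}"
    lines = [f"Prediction: {prediction['label']} (probability: {probability_text})"]
    for item in explanation["top_contributors"][:top_n]:
        direction = "increases" if item["shap_value"] > 0 else "decreases"
        lines.append(
            f"- {item['interpretable_name']}: {direction} risk by {item['abs_value']:.4f}"
        )
    return "\n".join(lines)


# Hypothetical explanation with the same structure generate_explanation returns
sample = {
    "prediction": {"class": 1, "probability": 0.82, "label": "Stroke"},
    "top_contributors": [
        {"feature": "age", "interpretable_name": "Age", "shap_value": 0.21, "abs_value": 0.21},
        {"feature": "bmi", "interpretable_name": "BMI", "shap_value": -0.03, "abs_value": 0.03},
    ],
}
report = format_explanation(sample)
print(report)
```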

## Conclusion

This notebook has demonstrated how to use SHAP to explain the predictions of our stroke prediction model. We've explored:

1. **Global interpretability**: Understanding which features are most important across the dataset
2. **Local interpretability**: Explaining individual predictions for specific patients
3. **Feature interactions**: Visualizing how features interact to influence stroke risk
4. **Risk simulation**: Creating a tool to simulate and explain stroke risk for different patient profiles

These explainability techniques are essential for healthcare applications, where understanding model decisions is as important as the accuracy of those decisions. SHAP makes the reasoning behind our stroke prediction model's outputs more transparent, which may increase clinicians' trust in the model and its utility in clinical settings.