# Model Interpretability: SHAP Analysis
## MSDS692 - Data Science Practicum
### Sai Teja Lakkapally

This notebook focuses on model interpretability using SHAP values to understand feature impacts and ensure model transparency.

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import warnings
# Silence every warning so plot output stays readable.
# NOTE(review): this blanket filter also hides deprecation warnings from
# pandas/shap/matplotlib — consider narrowing it once the notebook is stable.
warnings.filterwarnings('ignore')

# Set style
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8*'; older
# releases only know the un-suffixed 'seaborn'. Use whichever one is installed
# so this cell does not raise on an older environment.
_preferred_styles = ('seaborn-v0_8', 'seaborn')
_plot_style = next(
    (style for style in _preferred_styles if style in plt.style.available),
    'default',
)
plt.style.use(_plot_style)
sns.set_palette("husl")
pd.set_option('display.max_columns', 50)

print("Libraries imported successfully!")
print(f"SHAP version: {shap.__version__}")

In [None]:
# Import project modules
import sys
sys.path.append('../src')

from etl import DataETL
from features import FeatureEngineer
from model import ReadmissionModel

## 1. Load Trained Models and Data

In [None]:
# Load pre-trained models, or train and persist them when no artifacts exist.
import joblib
import os

MODEL_DIR = '../models/'
FEATURE_NAMES_PATH = os.path.join(MODEL_DIR, 'feature_names.pkl')
PREPROCESSOR_PATH = os.path.join(MODEL_DIR, 'preprocessor.pkl')

print("Loading pre-trained models...")

model_trainer = ReadmissionModel()

# Take the load path only when the artifacts this cell reads are actually
# present; an empty or half-written models directory falls through to training.
if os.path.exists(FEATURE_NAMES_PATH) and os.path.exists(PREPROCESSOR_PATH):
    model_trainer.load_models(MODEL_DIR)
    # NOTE: joblib files are pickle-based — only load artifacts you produced yourself.
    feature_names = joblib.load(FEATURE_NAMES_PATH)
    preprocessor = joblib.load(PREPROCESSOR_PATH)
    print("✓ Models and features loaded successfully!")
else:
    print("Training new models...")
    # Train models if not saved
    etl = DataETL()
    data = etl.run_pipeline()

    feature_engineer = FeatureEngineer()
    X, y, feature_names = feature_engineer.prepare_features(data)

    # Split BEFORE resampling: oversampling the full dataset first would put
    # synthetic points derived from test rows into the training fold and would
    # evaluate on a resampled test set, which inflates the reported metrics.
    from sklearn.model_selection import train_test_split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # Balance the training fold only; the test fold keeps its original class mix.
    X_train, y_train = feature_engineer.handle_imbalance(X_train, y_train, method='smote')

    model_trainer.train_baseline_models(X_train, y_train)
    model_trainer.train_advanced_models(X_train, y_train)
    test_results = model_trainer.evaluate_models(X_test, y_test)

    # Save models
    os.makedirs(MODEL_DIR, exist_ok=True)
    model_trainer.save_models(MODEL_DIR)
    # Bind `preprocessor` here as well so both branches define the same names.
    preprocessor = feature_engineer.preprocessor
    joblib.dump(feature_names, FEATURE_NAMES_PATH)
    joblib.dump(preprocessor, PREPROCESSOR_PATH)

print(f"Best model: {model_trainer.best_model_name}")
print(f"Number of features: {len(feature_names)}")

## 2. SHAP Analysis for Best Model

In [None]:
# Prepare data for SHAP analysis
print("Preparing data for SHAP analysis...")

# SHAP can be computationally intensive, so only the first 1000 rows are used.
if 'X_test' in locals():
    # A held-out split exists in this session (the training path ran above).
    X_sample = X_test.iloc[:1000]
else:
    # No split in memory: rebuild the feature matrix and sample from it instead.
    etl = DataETL()
    data = etl.run_pipeline()
    feature_engineer = FeatureEngineer()
    X, y, feature_names = feature_engineer.prepare_features(data)
    X_sample = X.iloc[:1000]

print(f"SHAP analysis sample shape: {X_sample.shape}")

In [None]:
# Perform SHAP analysis for best model
best_model = model_trainer.best_model
best_model_name = model_trainer.best_model_name

print(f"Performing SHAP analysis for {best_model_name}...")

explainer, shap_values, shap_importance = model_trainer.shap_analysis(
    X_sample, feature_names, best_model_name
)

# Later cells treat shap_values as a 2-D (samples x features) array: they call
# np.abs(shap_values).mean(axis=0) next to feature_names and slice
# shap_values[:, feature_idx]. SHAP explainers can instead return one array per
# class (a list) or a 3-D (samples x features x classes) array for classifiers.
# shap_analysis lives in the project's model module and is not visible here, so
# normalise defensively; this does nothing when the result is already 2-D.
if isinstance(shap_values, list):
    # Keep the positive (readmitted) class when per-class arrays are returned.
    shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
elif getattr(shap_values, 'ndim', None) == 3:
    positive_class = 1 if shap_values.shape[2] > 1 else 0
    shap_values = shap_values[:, :, positive_class]

## 3. Comprehensive SHAP Visualizations

In [None]:
# 1. SHAP Summary Plot (Bee Swarm)
print("1. SHAP Summary Plot (Bee Swarm)")
print("-" * 40)

plt.figure(figsize=(12, 8))
# summary_plot draws on the current figure. Its default plot_size="auto"
# resizes that figure and discards the figsize above; plot_size=None keeps it.
shap.summary_plot(
    shap_values,
    X_sample,
    feature_names=feature_names,
    plot_size=None,
    show=False,
)
plt.title(f'SHAP Summary Plot - {best_model_name.replace("_", " ").title()}',
          fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

In [None]:
# 2. SHAP Bar Plot (Feature Importance)
print("2. SHAP Feature Importance")
print("-" * 40)

plt.figure(figsize=(12, 8))
# summary_plot draws on the current figure. Its default plot_size="auto"
# resizes that figure and discards the figsize above; plot_size=None keeps it.
shap.summary_plot(
    shap_values,
    X_sample,
    feature_names=feature_names,
    plot_type="bar",
    plot_size=None,
    show=False,
)
plt.title(f'SHAP Feature Importance - {best_model_name.replace("_", " ").title()}',
          fontsize=16, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

In [None]:
# 3. Custom Feature Importance Comparison
print("3. Feature Importance Comparison: SHAP vs Model")
print("-" * 40)

# Only compare when the trainer recorded native importances for the best model.
if best_model_name in model_trainer.feature_importance:
    model_importance = model_trainer.feature_importance[best_model_name]

    # Re-order the model's own importances to follow feature_names; features the
    # model did not report come back as NaN from the reindex.
    aligned_model_importance = (
        model_importance.set_index('feature')['importance'].reindex(feature_names).values
    )

    comparison_df = pd.DataFrame({
        'feature': feature_names,
        'shap_importance': np.abs(shap_values).mean(axis=0),
        'model_importance': aligned_model_importance,
    })

    # Rescale each importance column by its own sum so the two are comparable.
    for raw_col in ('shap_importance', 'model_importance'):
        comparison_df[f'{raw_col}_norm'] = comparison_df[raw_col] / comparison_df[raw_col].sum()

    # Top 15 features ranked by normalised SHAP importance
    top_features = comparison_df.nlargest(15, 'shap_importance_norm')

    # Side-by-side horizontal bars: SHAP above, model importance below
    fig, ax = plt.subplots(figsize=(14, 10))
    x = np.arange(len(top_features))
    width = 0.35

    bars1 = ax.barh(x - width/2, top_features['shap_importance_norm'], width,
                    label='SHAP Importance', color='skyblue', alpha=0.8)
    bars2 = ax.barh(x + width/2, top_features['model_importance_norm'], width,
                    label='Model Importance', color='lightcoral', alpha=0.8)

    ax.set_yticks(x)
    ax.set_yticklabels(top_features['feature'])
    ax.set_xlabel('Normalized Importance')
    ax.set_title('Feature Importance: SHAP vs Model', fontsize=16, fontweight='bold')
    ax.legend()

    # Put the highest-ranked feature at the top of the chart
    ax.invert_yaxis()
    fig.tight_layout()
    plt.show()

    # Agreement between the two rankings across all features
    correlation = comparison_df['shap_importance_norm'].corr(comparison_df['model_importance_norm'])
    print(f"Correlation between SHAP and Model importance: {correlation:.4f}")

## 4. Feature Impact Analysis by Category

In [None]:
# Categorize features
def categorize_feature(feature_name):
    """Categorize a feature as 'Medical', 'SDOH', 'Demographic', or 'Other'.

    Categories are checked in that order and the first match wins. Medical and
    SDOH keywords are matched as case-insensitive substrings. The demographic
    keywords ('age', 'gender', 'race') are short enough to occur inside
    unrelated words ('average', 'dosage', 'trace'), so they only match when
    they are not immediately preceded by another letter.

    Args:
        feature_name: Feature/column name. Non-string values are converted
            with ``str()`` before matching.

    Returns:
        One of 'Medical', 'SDOH', 'Demographic', 'Other'.
    """
    import re  # local import keeps this cell self-contained

    medical_keywords = ('length_of_stay', 'num_', 'number_', 'prior_admissions',
                        'comorbidity', 'utilization', 'procedure', 'medications',
                        'diagnoses', 'lab_procedures')

    sdoh_keywords = ('svi', 'rpl', 'theme', 'vulnerability', 'economic',
                     'hardship', 'access_barrier', 'isolation', 'pov', 'unemp',
                     'nohsdp', 'noveh')

    demographic_keywords = ('age', 'gender', 'race')

    feature_lower = str(feature_name).lower()

    if any(keyword in feature_lower for keyword in medical_keywords):
        return 'Medical'

    if any(keyword in feature_lower for keyword in sdoh_keywords):
        return 'SDOH'

    # (?<![a-z]) rejects a match that starts in the middle of a word, so
    # 'age_group' and 'patient_age' match but 'average' and 'dosage' do not.
    if any(re.search(r'(?<![a-z])' + re.escape(keyword), feature_lower)
           for keyword in demographic_keywords):
        return 'Demographic'

    return 'Other'

# Build the per-feature SHAP importance table (mean |SHAP| per feature), tag
# each feature with its category, and rank from most to least important.
shap_importance_df = (
    pd.DataFrame({
        'feature': feature_names,
        'shap_importance': np.abs(shap_values).mean(axis=0),
    })
    .assign(category=lambda frame: frame['feature'].apply(categorize_feature))
    .sort_values('shap_importance', ascending=False)
)

print("Feature Importance by Category:")
print("=" * 50)

# Total, mean, and number of features per category, rounded for display
category_summary = (
    shap_importance_df
    .groupby('category')['shap_importance']
    .agg(['sum', 'mean', 'count'])
    .round(4)
)
category_summary.columns = ['Total Importance', 'Average Importance', 'Feature Count']
print(category_summary)

In [None]:
# Horizontal bar chart of summed SHAP importance per feature category
fig, ax = plt.subplots(figsize=(12, 8))

# Ascending sort puts the largest category at the top of a barh chart
category_importance = (
    shap_importance_df
    .groupby('category')['shap_importance']
    .sum()
    .sort_values(ascending=True)
)

# Fixed colour per category; unknown categories fall back to grey
colors = {'Medical': '#3498db', 'SDOH': '#e74c3c', 'Demographic': '#2ecc71', 'Other': '#f39c12'}
bar_colors = [colors.get(cat, '#95a5a6') for cat in category_importance.index]

bars = ax.barh(category_importance.index, category_importance.values, color=bar_colors, alpha=0.8)

ax.set_xlabel('Total SHAP Importance')
ax.set_title('Total Feature Importance by Category', fontsize=16, fontweight='bold')
ax.grid(axis='x', alpha=0.3)

# Print each bar's value just past its right edge
for bar in bars:
    width = bar.get_width()
    label_y = bar.get_y() + bar.get_height()/2
    ax.text(width + 0.001, label_y, f'{width:.4f}', ha='left', va='center')

fig.tight_layout()
plt.show()

## 5. Individual Prediction Explanations

In [None]:
# Explain individual predictions using their per-feature SHAP values
print("5. Individual Prediction Explanations")
print("-" * 40)

# Walk through the first three rows of the SHAP sample
sample_indices = [0, 1, 2]

for sample_number, idx in enumerate(sample_indices, start=1):
    # True labels are only available when y_test exists in this session
    if 'y_test' in locals():
        actual_label = y_test.iloc[idx]
    else:
        actual_label = "Unknown"

    # Double brackets keep the row as a one-row frame for predict_proba
    single_row = X_sample.iloc[[idx]]
    prediction = best_model.predict_proba(single_row)[0, 1]

    print(f"\nSample {sample_number}:")
    print(f"  Predicted probability of readmission: {prediction:.4f}")
    print(f"  Actual readmission status: {actual_label}")

    # SHAP contributions for this one row
    shap_value_single = shap_values[idx]

    # Rank features by the magnitude of their contribution to this prediction
    feature_effects = pd.DataFrame({
        'feature': feature_names,
        'shap_value': shap_value_single,
        'abs_effect': np.abs(shap_value_single)
    }).sort_values('abs_effect', ascending=False)

    print("  Top features influencing prediction:")
    strongest = feature_effects.head(5)
    for feature_label, effect in zip(strongest['feature'], strongest['shap_value']):
        # Positive SHAP pushes the prediction towards readmission
        direction = "increases" if effect > 0 else "decreases"
        print(f"    • {feature_label}: {direction} risk (impact: {effect:.4f})")

In [None]:
# Force plot for a single prediction
print("\nForce Plot for Sample Prediction:")
print("-" * 40)

# Use the first sample for demonstration
sample_idx = 0

# expected_value may be a plain scalar, a length-1 array, or one value per
# class. Flatten it first so that indexing [1] is only attempted when a second
# (positive-class) entry really exists; a bare hasattr(..., '__len__') check
# would raise IndexError on a length-1 array.
expected_values = np.ravel(explainer.expected_value)
base_value = expected_values[1] if expected_values.size > 1 else expected_values[0]

shap.force_plot(
    base_value,
    shap_values[sample_idx],
    X_sample.iloc[sample_idx],
    feature_names=feature_names,
    matplotlib=True,
    show=False
)

plt.title(f'SHAP Force Plot - Sample {sample_idx}', fontweight='bold')
plt.tight_layout()
plt.show()

## 6. Dependence Plots for Key Features

In [None]:
# Dependence plots for top features
print("6. Feature Dependence Plots")
print("-" * 40)

top_features = shap_importance_df.head(6)['feature'].tolist()

for feature in top_features[:3]:  # Plot first 3 for clarity
    print(f"\nDependence plot for: {feature}")

    # dependence_plot opens its own figure unless an Axes is passed in. Passing
    # ax= makes it draw on the figure created here, so the requested size is
    # used and no empty extra figure is left behind.
    fig, ax = plt.subplots(figsize=(10, 6))
    shap.dependence_plot(
        feature,
        shap_values,
        X_sample,
        feature_names=feature_names,
        ax=ax,
        show=False
    )

    ax.set_title(f'SHAP Dependence Plot: {feature}', fontweight='bold')
    fig.tight_layout()
    plt.show()

## 7. SDOH Impact Analysis

In [None]:
# Focus on the features categorised as SDOH
print("7. Social Determinants of Health Impact Analysis")
print("=" * 50)

is_sdoh = shap_importance_df['category'] == 'SDOH'
sdoh_features = shap_importance_df[is_sdoh]

print("Top SDOH Features by Impact:")
print("-" * 40)

# shap_importance_df is already sorted, so head(10) is the ten strongest
top_sdoh = sdoh_features.head(10)
for name, importance in zip(top_sdoh['feature'], top_sdoh['shap_importance']):
    print(f"{name}: {importance:.4f}")

# Visualize SDOH feature impacts
fig, ax = plt.subplots(figsize=(12, 8))

# The signed mean SHAP value gives the average direction of each feature's
# effect; bar length still shows the mean absolute value.
sdoh_impacts = [
    shap_values[:, feature_names.index(name)].mean()
    for name in top_sdoh['feature']
]

# Red = raises predicted risk on average, blue = lowers it
colors = ['red' if impact > 0 else 'blue' for impact in sdoh_impacts]

bars = ax.barh(top_sdoh['feature'], top_sdoh['shap_importance'], color=colors, alpha=0.7)

ax.set_xlabel('Mean |SHAP Value| (Impact Magnitude)')
ax.set_title('Top SDOH Features: Impact on Readmission Prediction', fontsize=16, fontweight='bold')
ax.invert_yaxis()
ax.grid(axis='x', alpha=0.3)

# Label each bar with its magnitude and average direction
for bar, impact in zip(bars, sdoh_impacts):
    width = bar.get_width()
    direction = "↑ increases" if impact > 0 else "↓ decreases"
    ax.text(width + 0.001, bar.get_y() + bar.get_height()/2,
            f'{width:.4f} ({direction})', ha='left', va='center', fontsize=9)

fig.tight_layout()
plt.show()

## 8. Waterfall Plot for High-Risk Patient

In [None]:
# Find a high-risk prediction for detailed explanation
print("8. Waterfall Plot for High-Risk Patient")
print("-" * 40)

if 'y_test' in locals():
    # Score the sample once and reuse the probabilities for both the argmax
    # and the reported value.
    readmission_probs = best_model.predict_proba(X_sample)[:, 1]
    high_risk_idx = readmission_probs.argmax()
    high_risk_prob = readmission_probs[high_risk_idx]
    # X_sample is the leading slice of X_test, so positions line up with y_test
    actual_status = y_test.iloc[high_risk_idx]

    print(f"High-risk patient analysis:")
    print(f"  Predicted readmission probability: {high_risk_prob:.4f}")
    print(f"  Actual readmission status: {actual_status}")

    # expected_value may be a plain scalar, a length-1 array, or one value per
    # class. Flatten it first so that indexing [1] is only attempted when a
    # second (positive-class) entry really exists; a bare hasattr(..., '__len__')
    # check would raise IndexError on a length-1 array.
    expected_values = np.ravel(explainer.expected_value)
    base_value = expected_values[1] if expected_values.size > 1 else expected_values[0]

    # Create waterfall plot
    plt.figure(figsize=(14, 8))

    shap.waterfall_plot(
        shap.Explanation(
            values=shap_values[high_risk_idx],
            base_values=base_value,
            data=X_sample.iloc[high_risk_idx],
            feature_names=feature_names
        ),
        max_display=15,
        show=False
    )

    plt.title(f'SHAP Waterfall Plot: High-Risk Patient (Probability: {high_risk_prob:.3f})',
              fontsize=16, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.show()
else:
    print("Test labels not available for high-risk patient analysis")

## 9. Summary and Key Insights

In [None]:
# Final interpretability summary: prints headline numbers from the SHAP tables
# built in earlier cells.
print("MODEL INTERPRETABILITY SUMMARY")
print("=" * 60)

print("\n1. OVERALL INSIGHTS:")
print(f"   • Best Model: {best_model_name.replace('_', ' ').title()}")
print(f"   • Total Features Analyzed: {len(feature_names)}")
print(f"   • Samples for SHAP: {X_sample.shape[0]}")

print("\n2. FEATURE CATEGORY IMPACT:")
category_impact = shap_importance_df.groupby('category')['shap_importance'].sum()
total_impact = category_impact.sum()

for category, impact in category_impact.items():
    # Guard against an all-zero attribution table (would otherwise print nan%)
    percentage = (impact / total_impact) * 100 if total_impact > 0 else 0.0
    print(f"   • {category}: {percentage:.1f}% of total impact")

print("\n3. TOP 5 MOST IMPORTANT FEATURES:")
top_5 = shap_importance_df.head(5)
for i, (_, row) in enumerate(top_5.iterrows(), 1):
    print(f"   {i}. {row['feature']} ({row['category']}): {row['shap_importance']:.4f}")

print("\n4. SDOH IMPACT CONFIRMATION:")
sdoh_total_impact = category_impact.get('SDOH', 0)
sdoh_percentage = (sdoh_total_impact / total_impact) * 100 if total_impact > 0 else 0.0
print(f"   • SDOH features contribute {sdoh_percentage:.1f}% to model predictions")
# Only state the conclusion when the numbers support it, rather than printing
# it unconditionally.
sdoh_has_impact = sdoh_total_impact > 0
if sdoh_has_impact:
    print("   • This confirms the importance of social factors in readmission risk")
else:
    print("   • No SDOH-categorised feature contributed to the SHAP attributions in this sample")

print("\n5. MODEL TRANSPARENCY:")
print("   ✓ Feature impacts quantified and visualized")
print("   ✓ Individual predictions explainable")
if sdoh_has_impact:
    print("   ✓ SDOH impact clearly demonstrated")
print("   ✓ Model decisions are interpretable")

print("\n" + "=" * 60)
print("INTERPRETABILITY ANALYSIS COMPLETED!")
print("=" * 60)

## Key Business Implications

1. **Clinical Decision Support**: The model identifies both clinical and social risk factors, enabling comprehensive patient assessment.

2. **Resource Allocation**: Healthcare systems can prioritize interventions based on both medical needs and social vulnerabilities.

3. **Policy Insights**: Where the SHAP analysis above attributes a substantial share of importance to SDOH features, the results support policies addressing social determinants.

4. **Health Equity**: Transparent model explanations help ensure fair and equitable care decisions.