# Hospital Readmission Prediction Project
## Phase 8: Explainability (SHAP Analysis)

**Author:** Vindya Siriwardhana  
**Final Model:** Logistic Regression (Tuned)  
**Model Performance:**
- AUC-ROC: 0.6380
- Recall: 51.36% (catches half of readmissions)
- Precision: 16.66% (intentional trade-off for healthcare)

---
## SETUP & LOAD MODEL

In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import shap
import warnings

# Suppress all Python warnings for the rest of the session. This keeps the
# cell output uncluttered, but it also hides any warnings raised by libraries.
warnings.filterwarnings('ignore')
# Use the seaborn dark-grid stylesheet for every matplotlib figure that follows.
plt.style.use('seaborn-v0_8-darkgrid')
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

# Initialize SHAP's JavaScript visualizations
shap.initjs()

print("‚úÖ Libraries imported successfully!")

In [None]:
# Load the prepared dataset bundle from disk.
# NOTE(review): the path is absolute and machine-specific.
PREPARED_DATA_PATH = '/home/claude/hospital_readmission_prepared_data.pkl'
with open(PREPARED_DATA_PATH, 'rb') as f:
    data = pickle.load(f)

# Unpack the splits and the feature-name list in a single step
# (the keys are read in the same order as a one-per-line assignment would).
X_train, X_test, y_train, y_test, feature_names = (
    data[key]
    for key in ('X_train', 'X_test', 'y_train', 'y_test', 'feature_names')
)

print("‚úÖ Data loaded!")
print(f"\nüìä Test set: {X_test.shape}")
print(f"üìä Features: {len(feature_names)}")

In [None]:
# Load the persisted model bundle: the estimator, its display name and its metrics.
# NOTE(review): the path is absolute and machine-specific.
FINAL_MODEL_PATH = '/home/claude/hospital_readmission_final_model.pkl'
with open(FINAL_MODEL_PATH, 'rb') as f:
    model_artifacts = pickle.load(f)

# Pull the three entries used by the rest of the notebook out of the bundle.
final_model, model_name, metrics = (
    model_artifacts[key] for key in ('model', 'model_name', 'metrics')
)

print("‚úÖ Model loaded!")
print(f"\nüèÜ Model: {model_name}")
print(f"üìä AUC-ROC: {metrics['auc']:.4f}")
print(f"üìä Recall: {metrics['recall']:.4f}")

---
## PHASE 8: EXPLAINABILITY (Steps 23-25)

### Step 23: Feature Importance Analysis

In [None]:
rule = "=" * 80
print("\n" + rule)
print("STEP 23: FEATURE IMPORTANCE ANALYSIS")
print(rule)

# Choose whichever importance score the fitted estimator exposes:
#   * linear models    -> absolute value of the first row of coef_
#   * tree-based models -> feature_importances_
#   * anything else    -> no score available
if hasattr(final_model, 'coef_'):
    importance = np.abs(final_model.coef_[0])
elif hasattr(final_model, 'feature_importances_'):
    importance = final_model.feature_importances_
else:
    importance = None

if importance is None:
    print("\n‚ö†Ô∏è Model doesn't have feature_importances_ or coef_ attribute")
    feature_importance = None
else:
    # One row per feature, strongest feature first.
    unsorted_table = pd.DataFrame({
        'Feature': feature_names,
        'Importance': importance
    })
    feature_importance = unsorted_table.sort_values('Importance', ascending=False)

    print("\nüìä TOP 15 MOST IMPORTANT FEATURES:")
    print(feature_importance.head(15).to_string(index=False))

In [None]:
# Horizontal bar chart of the 20 highest-ranked features (skipped when no
# importance table could be built for the model).
if feature_importance is not None:
    fig, ax = plt.subplots(figsize=(12, 8))
    top_features = feature_importance.head(20)
    positions = range(len(top_features))

    ax.barh(positions, top_features['Importance'], color='steelblue')
    ax.set_yticks(positions)
    ax.set_yticklabels(top_features['Feature'])
    ax.set_xlabel('Importance (Absolute Coefficient)', fontsize=12)
    ax.set_title('Top 20 Feature Importances', fontsize=14, fontweight='bold')
    # Flip the y-axis so the most important feature is drawn at the top.
    ax.invert_yaxis()
    ax.grid(axis='x', alpha=0.3)
    fig.tight_layout()
    plt.show()

    print("\n‚úÖ Feature importance visualized!")

### Step 24: SHAP Values Implementation

In [None]:
print("\n" + "="*80)
print("STEP 24: SHAP VALUES IMPLEMENTATION")
print("="*80)

print("\nüîß Creating SHAP explainer...")
print("‚è≥ This may take 2-3 minutes...")

# Create SHAP explainer based on model type (dispatch is on the stored model name)
if 'Logistic' in model_name:
    # For linear models, use LinearExplainer with the training data as background
    explainer = shap.LinearExplainer(final_model, X_train)
    print("‚úÖ Using LinearExplainer for Logistic Regression")

elif 'XGBoost' in model_name or 'Random Forest' in model_name:
    # For tree-based models, use TreeExplainer
    explainer = shap.TreeExplainer(final_model)
    print(f"‚úÖ Using TreeExplainer for {model_name}")

else:
    # Fallback to KernelExplainer (slower but works for all models);
    # a 100-row sample of the training data serves as the background set
    explainer = shap.KernelExplainer(final_model.predict_proba, shap.sample(X_train, 100))
    print("‚úÖ Using KernelExplainer (general purpose)")

print("\nüîß Calculating SHAP values for test set...")
print("‚è≥ This may take 3-5 minutes...")

# Calculate SHAP values for test set
shap_values = explainer.shap_values(X_test)

# Reduce per-class output to the positive class (readmission) so that every
# later cell can rely on a 2-D (n_samples, n_features) array.
#   * Older SHAP releases return a list [class0, class1].
#   * SHAP 0.45+ returns one array shaped (n_samples, n_features, n_classes).
# A plain 2-D result (e.g. from LinearExplainer) is left untouched.
if isinstance(shap_values, list):
    shap_values = shap_values[1]
elif getattr(shap_values, 'ndim', 2) == 3:
    shap_values = shap_values[:, :, 1]

print("\n‚úÖ SHAP values calculated!")
print(f"üìä SHAP values shape: {shap_values.shape}")

### Step 25: SHAP Visualizations

In [None]:
# Section banner for the visualization step.
rule = "=" * 80
print("\n" + rule)
print("STEP 25: SHAP VISUALIZATIONS")
print(rule)

#### 1. SHAP Summary Plot (Feature Importance)

In [None]:
print("\nüìä 1. SHAP Summary Plot (Feature Importance)")
print("   Shows which features have the biggest impact on predictions")

plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_test, feature_names=feature_names, show=False, max_display=20)
plt.title('SHAP Feature Importance', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("‚úÖ Summary plot created!")

#### 2. SHAP Summary Plot (Detailed with Values)

In [None]:
print("\nüìä 2. SHAP Summary Plot (Detailed)")
print("   Shows how feature values affect predictions")
print("   Red = High feature value, Blue = Low feature value")

plt.figure(figsize=(12, 8))
shap.summary_plot(shap_values, X_test, feature_names=feature_names, 
                 plot_type='dot', show=False, max_display=20)
plt.title('SHAP Values by Feature (Detailed)', fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("‚úÖ Detailed summary plot created!")

#### 3. SHAP Dependence Plots (Top Features)

In [None]:
print("\nüìä 3. SHAP Dependence Plots (Top 4 Features)")
print("   Shows relationship between feature value and prediction impact")

# Get top 4 features by absolute SHAP value
mean_abs_shap = np.abs(shap_values).mean(axis=0)
top_feature_indices = np.argsort(mean_abs_shap)[-4:][::-1]

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.ravel()

for idx, feat_idx in enumerate(top_feature_indices):
    shap.dependence_plot(feat_idx, shap_values, X_test, 
                        feature_names=feature_names,
                        show=False, ax=axes[idx])
    axes[idx].set_title(f'Dependence Plot: {feature_names[feat_idx]}', 
                       fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

print("‚úÖ Dependence plots created!")

#### 4. Individual Patient Explanations (Force Plots)

In [None]:
print("\nüìä 4. Individual Patient Explanations (Force Plots)")
print("   Shows why model predicted high/low risk for specific patients")

# Get predictions
y_pred_proba = final_model.predict_proba(X_test)[:, 1]

# Find examples: high risk patient (correctly identified readmission)
high_risk_correct = np.where((y_test == 1) & (y_pred_proba > 0.7))[0]
if len(high_risk_correct) > 0:
    high_risk_idx = high_risk_correct[0]
else:
    # Fallback: highest predicted probability for readmission
    high_risk_idx = np.argmax(y_pred_proba)

# Find low risk patient (correctly identified non-readmission)
low_risk_correct = np.where((y_test == 0) & (y_pred_proba < 0.3))[0]
if len(low_risk_correct) > 0:
    low_risk_idx = low_risk_correct[0]
else:
    # Fallback: lowest predicted probability
    low_risk_idx = np.argmin(y_pred_proba)

print(f"\nüî¥ HIGH RISK PATIENT (Index {high_risk_idx}):")
print(f"   Predicted probability: {y_pred_proba[high_risk_idx]:.2%}")
print(f"   Actual outcome: {'Readmitted' if y_test.iloc[high_risk_idx] == 1 else 'Not readmitted'}")

In [None]:
# Force plot for high risk patient
# Baseline for the plot: the explainer's expected value as-is, or its
# index-1 entry when the explainer stores it as a NumPy array.
expected = explainer.expected_value
if isinstance(expected, np.ndarray):
    base_value = expected[1]
else:
    base_value = expected

shap.force_plot(
    base_value,
    shap_values[high_risk_idx],
    X_test.iloc[high_risk_idx],
    feature_names=feature_names,
    matplotlib=True,
    show=False,
)
plt.title(f'High Risk Patient Explanation (Index {high_risk_idx})', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
print(f"\nüü¢ LOW RISK PATIENT (Index {low_risk_idx}):")
print(f"   Predicted probability: {y_pred_proba[low_risk_idx]:.2%}")
print(f"   Actual outcome: {'Readmitted' if y_test.iloc[low_risk_idx] == 1 else 'Not readmitted'}")

# Force plot for low risk patient
shap.force_plot(
    explainer.expected_value if not isinstance(explainer.expected_value, np.ndarray) else explainer.expected_value[1],
    shap_values[low_risk_idx],
    X_test.iloc[low_risk_idx],
    feature_names=feature_names,
    matplotlib=True,
    show=False
)
plt.title(f'Low Risk Patient Explanation (Index {low_risk_idx})', fontsize=12, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n‚úÖ Force plots created!")

#### 5. SHAP Waterfall Plot (Alternative Individual Explanation)

In [None]:
print("\nüìä 5. SHAP Waterfall Plots")
print("   Alternative way to show individual predictions")

fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# High risk patient
plt.sca(axes[0])
shap.waterfall_plot(
    shap.Explanation(
        values=shap_values[high_risk_idx],
        base_values=explainer.expected_value if not isinstance(explainer.expected_value, np.ndarray) else explainer.expected_value[1],
        data=X_test.iloc[high_risk_idx].values,
        feature_names=feature_names
    ),
    show=False,
    max_display=15
)
axes[0].set_title(f'High Risk Patient (Index {high_risk_idx})', fontsize=12, fontweight='bold')

# Low risk patient
plt.sca(axes[1])
shap.waterfall_plot(
    shap.Explanation(
        values=shap_values[low_risk_idx],
        base_values=explainer.expected_value if not isinstance(explainer.expected_value, np.ndarray) else explainer.expected_value[1],
        data=X_test.iloc[low_risk_idx].values,
        feature_names=feature_names
    ),
    show=False,
    max_display=15
)
axes[1].set_title(f'Low Risk Patient (Index {low_risk_idx})', fontsize=12, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n‚úÖ Waterfall plots created!")

---
## 💾 SAVE SHAP VALUES

In [None]:
rule = "=" * 80
print("\n" + rule)
print("SAVING SHAP VALUES")
print(rule)

# Bundle the explainer, the SHAP matrix, the evaluated test data and the two
# example patient indices into a single dictionary.
shap_artifacts = {
    'explainer': explainer,
    'shap_values': shap_values,
    'X_test': X_test,
    'y_test': y_test,
    'feature_names': feature_names,
    'high_risk_example': high_risk_idx,
    'low_risk_example': low_risk_idx,
}

# NOTE(review): the output path is absolute and machine-specific.
SHAP_OUTPUT_PATH = '/home/claude/hospital_readmission_shap.pkl'
with open(SHAP_OUTPUT_PATH, 'wb') as f:
    pickle.dump(shap_artifacts, f)

for line in (
    "\nüíæ Saved:",
    "   ‚úì SHAP explainer",
    "   ‚úì SHAP values for all test samples",
    "   ‚úì Example patient indices",
    "\nüìÅ File: hospital_readmission_shap.pkl",
    "\n‚úÖ SHAP analysis complete and saved!",
):
    print(line)

---
## 📊 PHASE 8 SUMMARY & KEY INSIGHTS

In [None]:
rule = "=" * 80
print("\n" + rule)
print("PHASE 8 COMPLETED SUMMARY")
print(rule)

# Checklist of what this phase produced.
for line in (
    "\n‚úÖ COMPLETED STEPS:",
    "   Step 23: ‚úì Feature importance analysis",
    "   Step 24: ‚úì SHAP values calculated",
    "   Step 25: ‚úì SHAP visualizations created",
    "             ‚Ä¢ Summary plots",
    "             ‚Ä¢ Dependence plots",
    "             ‚Ä¢ Force plots (individual patients)",
    "             ‚Ä¢ Waterfall plots",
):
    print(line)

# Get top 5 features: rank by mean absolute SHAP value, largest first.
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
top_5_indices = np.argsort(mean_abs_shap)[::-1][:5]
top_5_features = [feature_names[position] for position in top_5_indices]

print("\nüìä TOP 5 MOST IMPORTANT FEATURES (by SHAP):")
for rank, feat in enumerate(top_5_features, start=1):
    print(f"   {rank}. {feat}")

# Closing narrative: takeaways, then the plan for the following phase.
for line in (
    "\nüí° KEY INSIGHTS:",
    "   ‚Ä¢ Model predictions are now explainable to doctors",
    "   ‚Ä¢ Each patient's risk score can be justified with specific factors",
    "   ‚Ä¢ SHAP values show which features push risk up or down",
    "   ‚Ä¢ Transparent AI = builds trust with clinicians",
    "\nüéØ NEXT STEPS (Phase 9):",
    "   Step 26: Build Streamlit dashboard",
    "   Step 27: Add file upload functionality",
    "   Step 28: Display predictions",
    "   Step 29: Add visualizations",
    "   Step 30: Add download functionality",
):
    print(line)

print("\n" + rule)
print("Ready to proceed to Phase 9: Dashboard Creation!")
print(rule)