# 06 - Model Explainability and SHAP Analysis

This notebook focuses on model explainability and interpretability:
- SHAP values for tree-based models
- Feature importance comparison
- Individual prediction explanations
- Global interpretability analysis
- Clinical feature impact assessment


In [None]:
# Make the project package importable (`import app.*`) by prepending the
# repository root to sys.path, unless it is already listed there.
# The root is taken to be the parent of the current working directory, so this
# cell assumes the kernel was started one level below the repo root.
import sys
from pathlib import Path

repo_root = Path.cwd().joinpath('..').resolve()
repo_root_entry = str(repo_root)
if repo_root_entry not in sys.path:
    sys.path.insert(0, repo_root_entry)
print('Repo root:', repo_root)


In [None]:
# Imports (stdlib -> third-party -> local project modules) and global notebook
# configuration.
# NOTE(review): some imports (CoxPHFitter, RandomForestClassifier,
# train_test_split, xgb, compute_hazard_ratios, compute_shap_values,
# engineer_features) are not used in the opening sections of this notebook;
# presumably later cells need them — verify before pruning.
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import xgboost as xgb
from lifelines import CoxPHFitter
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from app.feature_engineering import engineer_features
from app.interpret import compute_hazard_ratios, compute_shap_values

# Hide only deprecation-style noise from library internals. A blanket
# `warnings.filterwarnings('ignore')` would also hide warnings that bear on
# model validity (e.g. lifelines convergence warnings), which matters in a
# notebook whose purpose is to interpret those models.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Global seaborn/matplotlib defaults for figures drawn in this notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

print("Libraries loaded successfully!")


## Load Data and Model Results


In [None]:
# Load the fitted models and the held-out test split saved by notebook 3.
# NOTE(review): pickle.load can execute code embedded in the file — only load
# artifacts produced by this project, never files from untrusted sources.
import pickle

MODEL_RESULTS_PATH = '../models/model_results.pkl'
with open(MODEL_RESULTS_PATH, 'rb') as results_file:
    model_results = pickle.load(results_file)

# Unpack the entries used below (a missing entry raises KeyError).
models, X_test, y_test = (model_results[key] for key in ('models', 'X_test', 'y_test'))
rf, xgb_cox, xgb_aft = (model_results[key] for key in ('rf', 'xgb_cox', 'xgb_aft'))

print("Model results loaded successfully!")
print(f"Available models: {list(models.keys())}")
print(f"Test set shape: {X_test.shape}")
print(f"Features: {list(X_test.columns)}")


## 1. SHAP Values for Tree-Based Models


In [None]:
# Compute SHAP values for each tree-based model and compare global feature
# importance (mean |SHAP|) across models.

# Number of top-ranked features shown in each importance bar chart.
TOP_N_FEATURES = 15


def to_2d_shap_values(shap_values):
    """Return SHAP values as a single 2-D (n_samples, n_features) array.

    ``TreeExplainer.shap_values`` returns a 2-D array for single-output models.
    For classifiers it returns either a list with one array per class (older
    shap releases) or a 3-D (n_samples, n_features, n_classes) array (newer
    releases). For such multi-output results the last output is kept, i.e.
    the positive class of a binary classifier. 2-D input is returned as-is.

    Parameters
    ----------
    shap_values : array-like or list of array-like
        Raw output of ``explainer.shap_values(X)``.

    Returns
    -------
    numpy.ndarray
        2-D array of SHAP values, one row per sample.
    """
    if isinstance(shap_values, list):
        shap_values = shap_values[-1]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        values = values[..., -1]
    return values


def plot_shap_importance(shap_values_by_model, feature_names, top_n=TOP_N_FEATURES):
    """Plot mean |SHAP value| per feature, one horizontal bar panel per model.

    Parameters
    ----------
    shap_values_by_model : dict[str, numpy.ndarray]
        Model display name -> 2-D (n_samples, n_features) SHAP values.
    feature_names : sequence of str
        Column names matching the second axis of every SHAP array.
    top_n : int
        Number of most important features drawn per panel.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding one panel per model (not yet shown).
    """
    n_models = len(shap_values_by_model)
    # squeeze=False keeps `axes` 2-D even for a single model.
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 6), squeeze=False)
    for ax, (name, values) in zip(axes[0], shap_values_by_model.items()):
        importance_df = pd.DataFrame({
            'feature': feature_names,
            'importance': np.abs(values).mean(axis=0),
        }).sort_values('importance', ascending=True)
        # Ascending sort + tail puts the most important feature at the top.
        top_features = importance_df.tail(top_n)
        positions = range(len(top_features))
        ax.barh(positions, top_features['importance'], alpha=0.7)
        ax.set_yticks(positions)
        ax.set_yticklabels(top_features['feature'], fontsize=8)
        ax.set_xlabel('Mean |SHAP value|')
        ax.set_title(f'{name} Feature Importance')
        ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return fig


try:
    import shap

    # Models to explain, keyed by the display name used in plots and dicts.
    models_to_explain = {
        'Random Forest': rf,
        'XGBoost Cox': xgb_cox,
        'XGBoost AFT': xgb_aft,
    }

    explainers = {}
    raw_shap_values = {}   # explainer output exactly as returned by shap
    shap_values_dict = {}  # 2-D (n_samples, n_features) version used for plots

    for name, model in models_to_explain.items():
        print(f"Computing SHAP values for {name}...")
        # Explain each model independently so that one model TreeExplainer
        # cannot handle does not discard the results of the others.
        try:
            explainer = shap.TreeExplainer(model)
            model_shap_values = explainer.shap_values(X_test)
        except Exception as e:
            print(f"  ✗ {name}: SHAP computation failed ({type(e).__name__}: {e})")
            continue
        explainers[name] = explainer
        raw_shap_values[name] = model_shap_values
        shap_values_dict[name] = to_2d_shap_values(model_shap_values)

    # Per-model names kept for later cells. They hold the raw explainer output
    # (or None if that model failed); shap_values_dict holds the 2-D version.
    rf_explainer = explainers.get('Random Forest')
    rf_shap_values = raw_shap_values.get('Random Forest')
    xgb_cox_explainer = explainers.get('XGBoost Cox')
    xgb_cox_shap_values = raw_shap_values.get('XGBoost Cox')
    xgb_aft_explainer = explainers.get('XGBoost AFT')
    xgb_aft_shap_values = raw_shap_values.get('XGBoost AFT')

    if not shap_values_dict:
        print("No SHAP values could be computed; skipping plots.")
    else:
        print(f"✓ SHAP values computed for {len(shap_values_dict)}/{len(models_to_explain)} "
              f"models: {', '.join(shap_values_dict)}")

        # Global feature importance from SHAP
        plot_shap_importance(shap_values_dict, X_test.columns)
        plt.show()

        # SHAP summary (beeswarm) plots. `shap.summary_plot` has no `ax`
        # argument: it draws on the current figure and resizes it, so each
        # model gets a figure of its own.
        for name, values in shap_values_dict.items():
            plt.figure()
            shap.summary_plot(values, X_test, show=False)
            plt.gca().set_title(f'{name} SHAP Summary')
            plt.tight_layout()
            plt.show()

except ImportError:
    print("SHAP library not available. Install it with: pip install shap")
except Exception as e:
    print(f"SHAP analysis failed: {type(e).__name__}: {e}")
    print("This might be due to compatibility issues or missing dependencies.")
