# 03 - Interpretation

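This notebook explains how the best model makes its predictions.
- Uses **SHAP** to compute per-patient feature attributions and summarizes them in global importance plots (beeswarm, bar, and a dependence plot).
- Optionally uses **LIME** for local, single-patient interpretation.
- Compares ROC AUC across demographic subgroups (sex, age) when those columns are present.
- Saves interpretability outputs (PNG plots and a LIME HTML report) to `../images/`.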


In [ ]:
# --- Imports ---
import os

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from lime.lime_tabular import LimeTabularExplainer
from sklearn.metrics import roc_auc_score

# Load the cleaned dataset and the serialized model.
# joblib.load unpickles the file, so only point it at model files you trust.
df = pd.read_csv('../data/processed/heart_clean.csv')
model = joblib.load('../models/final_xgb_optuna.pkl')

# Features are every column except the label; the label lives in 'target'.
X, y = df.drop(columns=['target']), df['target']
print('Data shape:', X.shape)


In [ ]:
# --- SHAP Interpretation ---
# shap.Explainer selects an explanation algorithm for the model and uses X as
# background data. The resulting Explanation holds one SHAP value per patient
# and feature; the plots below aggregate those per-patient attributions into
# global importance views.
# savefig raises if the target folder is missing (e.g. on a fresh clone).
os.makedirs('../images', exist_ok=True)

explainer = shap.Explainer(model, X)
shap_values = explainer(X)

# Global summary (beeswarm) plot. The title is set after drawing, matching the bar plot below.
shap.summary_plot(shap_values, X, show=False)
plt.title('SHAP Summary Plot')
plt.savefig('../images/shap_summary.png', bbox_inches='tight')
plt.show()

# Bar plot of the mean absolute SHAP value per feature
shap.summary_plot(shap_values, X, plot_type='bar', show=False)
plt.title('Mean Absolute SHAP Values')
plt.savefig('../images/shap_importance_bar.png', bbox_inches='tight')
plt.show()


In [ ]:
# --- Dependence Plot for Top Features ---
# Rank features by mean |SHAP| (the same ordering as the importance bar plot)
# and draw a dependence plot for the highest-ranked ones. Taking X.columns[0]
# would just pick whichever column happens to come first in the CSV.
# Assumes shap_values.values is 2-D (patients x features), as dependence_plot requires.
N_DEPENDENCE_PLOTS = 1  # raise to also plot the 2nd, 3rd, ... most important features

mean_abs_shap = pd.DataFrame(shap_values.values, columns=X.columns).abs().mean()
for top_feature in mean_abs_shap.nlargest(N_DEPENDENCE_PLOTS).index:
    shap.dependence_plot(top_feature, shap_values.values, X, show=False)
    plt.title(f'SHAP Dependence for {top_feature}')
    plt.savefig(f'../images/shap_dependence_{top_feature}.png', bbox_inches='tight')
    plt.show()


In [ ]:
# --- LIME Example (optional, for one patient) ---
# LIME fits a simple local surrogate model around one patient. The patient choice
# and LIME's perturbation sampling are both seeded, so re-running the notebook
# reproduces the same saved explanation.
LIME_SEED = 42
lime_rng = np.random.default_rng(LIME_SEED)

explainer_lime = LimeTabularExplainer(
    training_data=X.to_numpy(),
    feature_names=list(X.columns),
    class_names=['No Disease', 'Disease'],
    mode='classification',
    random_state=LIME_SEED,
)


def predict_proba_df(rows):
    """Score LIME's perturbed rows with the model, restoring X's column names.

    LIME passes a plain 2-D numpy array. Rebuilding a DataFrame with X's columns
    means the model receives input shaped and labelled like X. This matters if the
    model validates feature names or selects columns by name.
    NOTE(review): whether the model needs this depends on how it was fitted.
    """
    return model.predict_proba(pd.DataFrame(rows, columns=X.columns))


# Pick one patient (seeded, so the same row is explained on every run)
patient_idx = int(lime_rng.integers(0, len(X)))
exp = explainer_lime.explain_instance(X.iloc[patient_idx].to_numpy(), predict_proba_df)
exp.save_to_file('../images/lime_patient_example.html')
print(f'LIME explanation saved for patient #{patient_idx}')


In [ ]:
# --- Fairness / Subgroup Comparison (if demographic columns exist) ---
# Compare ROC AUC across demographic subgroups to see whether the model ranks
# patients comparably well in each group. Disease probabilities are computed once
# for all patients and then sliced per subgroup.


def _subgroup_auc(y_true, y_score, mask):
    """ROC AUC of `y_score` against `y_true` on the rows selected by boolean `mask`.

    Returns NaN instead of raising when the subgroup is empty or contains only one
    class, because roc_auc_score is undefined (raises ValueError) in that case.
    """
    mask = np.asarray(mask, dtype=bool)
    y_sub = np.asarray(y_true)[mask]
    if np.unique(y_sub).size < 2:
        return float('nan')
    return roc_auc_score(y_sub, np.asarray(y_score)[mask])


# 'gender' is treated as an alias of 'sex' so that it no longer triggers the
# comparison without being evaluated.
sex_col = next((col for col in ('sex', 'gender') if col in X.columns), None)

if sex_col is not None or 'age' in X.columns:
    print('Running subgroup comparison...')
    disease_proba = model.predict_proba(X)[:, 1]
    if sex_col is not None:
        # NOTE(review): assumes the column is encoded 1 = male, 0 = female; confirm
        # (especially for 'gender', or if the column was rescaled during processing).
        male_auc = _subgroup_auc(y, disease_proba, X[sex_col] == 1)
        female_auc = _subgroup_auc(y, disease_proba, X[sex_col] == 0)
        print(f'AUC Male: {male_auc:.3f}, Female: {female_auc:.3f}')
    if 'age' in X.columns:
        # Split at the median age: patients at the median count as "younger".
        median_age = X['age'].median()
        young_auc = _subgroup_auc(y, disease_proba, X['age'] <= median_age)
        old_auc = _subgroup_auc(y, disease_proba, X['age'] > median_age)
        print(f'AUC Younger vs Older: {young_auc:.3f} / {old_auc:.3f}')


In [ ]:
# --- Summary of Insights ---
# Plain-text recap of how to read the SHAP, LIME and subgroup outputs above.
# NOTE(review): the first bullet is a general expectation written as a literal,
# not computed from shap_values — confirm it against the plots for this model.
key_insights = (
    '- Features like age, cholesterol, and chest pain type typically show strong SHAP values.',
    '- Positive SHAP value → higher heart disease risk contribution.',
    '- LIME helps explain individual predictions in an easy-to-read format.',
    '- Fairness metrics suggest whether the model behaves similarly across groups.',
)
print('\nKey Insights:')
print('\n'.join(key_insights))
