# In [None]:
# SHAP Analysis Notebook
from pathlib import Path

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap

from src.explainability import SHAPExplainer
from src.utils import load_config, setup_logging

# Setup
# Load the run parameters and configure logging before any heavy work starts.
config = load_config("../config/params.yaml")
setup_logging()

# Output locations. They are created up front so that a fresh checkout
# (where ../results does not exist yet) does not fail halfway through
# with FileNotFoundError on the first to_csv/savefig call.
REPORTS_DIR = Path("../results/reports")
SHAP_FIGURES_DIR = Path("../results/figures/shap")
for output_dir in (REPORTS_DIR, SHAP_FIGURES_DIR):
    output_dir.mkdir(parents=True, exist_ok=True)

# Upper bound on the number of individual (per-instance) explanations.
MAX_FORCE_PLOTS = 3

# Load data and model
X_train = pd.read_csv("../data/processed/train_features.csv")
X_test = pd.read_csv("../data/processed/test_features.csv")
# NOTE(security): joblib.load unpickles arbitrary objects and can execute
# code while doing so. Only load artifacts produced by this project's own
# training pipeline, never files from an untrusted source.
model = joblib.load("../models/best_model.pkl")
preprocessor = joblib.load("../models/preprocessor.pkl")

print("SHAP analysis starting...")

# Initialize SHAP explainer
# The project wrapper receives the model, the preprocessor and the training
# column names, then builds its explainer from the training frame.
explainer = SHAPExplainer(model, preprocessor, X_train.columns.tolist())
explainer.create_explainer(X_train)

# 1. Global feature importance
# The result exposes to_csv (presumably a DataFrame - confirm in
# src.explainability); it is written without the index column.
shap_importance = explainer.get_feature_importance(X_test)
shap_importance.to_csv(REPORTS_DIR / "shap_importance.csv", index=False)

# 2. Summary plot
# show=False is passed so the figure can be saved and closed here instead
# of being displayed.
plt.figure(figsize=(10, 8))
explainer.summary_plot(X_test, show=False)
plt.savefig(SHAP_FIGURES_DIR / "summary_plot.png", bbox_inches='tight')
plt.close()

# 3. Individual explanations
# Explain the first MAX_FORCE_PLOTS instances, capped at the number of rows
# actually available so a very small test set does not over-run.
# (Presumably the first argument of force_plot is a row position in X_test -
# confirm in src.explainability.)
n_force_plots = min(MAX_FORCE_PLOTS, len(X_test))
for i in range(n_force_plots):
    plt.figure(figsize=(12, 4))
    explainer.force_plot(i, X_test, show=False)
    plt.savefig(SHAP_FIGURES_DIR / f"force_plot_{i}.png", bbox_inches='tight')
    plt.close()

print("SHAP analysis completed!")