# Task 3: Model Explainability (SHAP)

In this notebook, we will:
1. Load the trained Random Forest model.
2. Visualize built-in feature importance.
3. Perform SHAP analysis to understand model decisions.
4. Visualize the key drivers behind individual predictions: a true positive (TP), a false positive (FP), and a false negative (FN).

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import joblib
import shap
import os
from IPython.display import display

# Set style
sns.set(style="whitegrid")
shap.initjs()

## 1. Setup: Load Data and Model

In [None]:
model_path = "../models/fraud_detection_RandomForest.pkl"
test_path = "../data/processed/test_enc.csv"

if not os.path.exists(model_path) or not os.path.exists(test_path):
    raise FileNotFoundError(f"Model or processed data not found (expected {model_path} and {test_path}).")

model = joblib.load(model_path)
test_df = pd.read_csv(test_path)

X_test = test_df.drop(columns=['class'])
y_test = test_df['class']

# Explain a random sample of 100 rows: computing SHAP values for the full test set can be slow
X_shap = X_test.sample(100, random_state=42)

print("Model and data loaded.")
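Fraud is usually a small minority class, so a purely random 100-row sample may contain only a handful of fraud cases. One option, sketched below as an assumption rather than the method used in this notebook, is to draw a fixed number of rows per class. The helper name `stratified_sample`, the `n_per_class` parameter, and the toy DataFrame are all hypothetical.

```python
import numpy as np
import pandas as pd


def stratified_sample(df, label_col, n_per_class, random_state=42):
    """Hypothetical helper: draw up to n_per_class rows from each class."""
    parts = []
    for _, group in df.groupby(label_col):
        # Never ask for more rows than the class has (sampling without replacement)
        parts.append(group.sample(n=min(len(group), n_per_class), random_state=random_state))
    return pd.concat(parts)


# Toy data: 20 "fraud" rows (class 1) and 980 legitimate rows (class 0)
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "amount": rng.normal(50, 10, size=1000),
    "class": np.r_[np.ones(20, dtype=int), np.zeros(980, dtype=int)],
})
sample = stratified_sample(toy, "class", n_per_class=50)
print(sample["class"].value_counts().sort_index())
```

In this notebook it could be used as `stratified_sample(test_df, 'class', 50).drop(columns=['class'])`. The trade-off is that a class-balanced sample over-represents fraud, so the resulting SHAP summary describes that mix rather than the production class ratio.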

## 2. Feature Importance (Baseline)

In [None]:
feature_importance = model.feature_importances_
features = X_test.columns

fi_df = pd.DataFrame({'Feature': features, 'Importance': feature_importance})
fi_df = fi_df.sort_values(by='Importance', ascending=False).head(10)

plt.figure(figsize=(10, 6))
sns.barplot(x='Importance', y='Feature', data=fi_df, palette='viridis')
plt.title('Top 10 Feature Importance (Random Forest Built-in)')
plt.show()

## 3. SHAP Analysis

In [None]:
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_shap)

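# The output format depends on the SHAP version:
# - older releases return a list of arrays, one per class ([class 0, class 1]);
# - newer releases return one array of shape (n_samples, n_features, n_classes).
# We keep only the values for class 1 (fraud).
if isinstance(shap_values, list):
    shap_values_fraud = shap_values[1]
elif np.ndim(shap_values) == 3:
    shap_values_fraud = shap_values[:, :, 1]
else:
    shap_values_fraud = shap_values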

print("SHAP values calculated.")
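SHAP values are Shapley values: each feature receives its weighted average marginal contribution over all coalitions of the other features, and the contributions add up to the difference between the model output for the instance and a reference output. The sketch below computes them exactly by brute force for a tiny hypothetical model, treating "absent" features as taking a single baseline value. This is an illustrative simplification: `TreeExplainer` uses a much faster tree-specific algorithm and, when no background data is passed (as in the cell above), a path-dependent definition of absent features, so its numbers will differ. The additivity property shown at the end holds in both cases.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shapley(f, x, baseline):
    """Exact Shapley values of f at x; 'absent' features take their baseline value.

    Enumerates every coalition, so the cost grows as n * 2**(n - 1).
    """
    n = len(x)
    phi = np.zeros(n)

    def value(coalition):
        # Features in the coalition come from x, all others from the baseline
        z = np.array(baseline, dtype=float)
        for j in coalition:
            z[j] = x[j]
        return f(z)

    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                phi[i] += weight * (value(coalition + (i,)) - value(coalition))
    return phi


def toy_model(z):
    # Hypothetical model: a main effect on feature 0 plus an interaction of features 1 and 2
    return 2 * z[0] + z[1] * z[2]


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)
phi = exact_shapley(toy_model, x, baseline)
print(phi)  # approximately [2. 3. 3.]: the interaction's 6.0 is split evenly
print(toy_model(baseline) + phi.sum(), toy_model(x))  # both approximately 8.0
```

The exponential cost of this enumeration is why tree-specific algorithms matter for a Random Forest with many features.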

### SHAP Summary Plot

In [None]:
# summary_plot resizes the current figure itself, so set the size via plot_size
shap.summary_plot(shap_values_fraud, X_shap, plot_size=(10, 8), show=True)
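The summary plot orders features by their overall absolute SHAP magnitude, which is equivalent to ranking by mean |SHAP value|. Computing that ranking directly gives a table that can sit next to the built-in `feature_importances_` table. The sketch assumes a 2D array of class-1 SHAP values (rows = samples, columns = features); the helper name, the toy values, and the feature names are made up.

```python
import numpy as np
import pandas as pd


def shap_global_importance(shap_matrix, feature_names):
    """Rank features by mean absolute SHAP value across the explained rows."""
    mean_abs = np.abs(np.asarray(shap_matrix, dtype=float)).mean(axis=0)
    return (
        pd.DataFrame({"Feature": list(feature_names), "MeanAbsSHAP": mean_abs})
        .sort_values("MeanAbsSHAP", ascending=False)
        .reset_index(drop=True)
    )


# Made-up class-1 SHAP values: 3 rows (samples) x 3 columns (features)
toy_shap = np.array([
    [0.30, -0.05, 0.01],
    [-0.20, 0.10, 0.00],
    [0.10, -0.15, -0.02],
])
ranking = shap_global_importance(toy_shap, ["feature_a", "feature_b", "feature_c"])
print(ranking)
```

Applied here, `shap_global_importance(shap_values_fraud, X_shap.columns).head(10)` would produce a top-10 list comparable to `fi_df`. Taking the absolute value before averaging matters: positive and negative contributions would otherwise cancel out.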

### SHAP Force Plots (TP, FP, FN)

In [None]:
# Find indices for TP, FP, FN
y_pred = model.predict(X_test)

# Create a DataFrame to easily filter
results = pd.DataFrame({
    'Actual': y_test,
    'Predicted': y_pred
})

tp_indices = results[(results['Actual'] == 1) & (results['Predicted'] == 1)].index
fp_indices = results[(results['Actual'] == 0) & (results['Predicted'] == 1)].index
fn_indices = results[(results['Actual'] == 1) & (results['Predicted'] == 0)].index

print(f"Found {len(tp_indices)} TPs, {len(fp_indices)} FPs, {len(fn_indices)} FNs")

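def plot_force(index, title):
    """Show a SHAP force plot (fraud class) for one row of X_test."""
    if index not in X_test.index:
        print(f"Index {index} is not in X_test; skipping.")
        return

    # Calculate SHAP values for this single instance
    instance = X_test.loc[[index]]
    shap_single = explainer.shap_values(instance)
    # Select class 1 (fraud), whichever output format this SHAP version uses
    if isinstance(shap_single, list):
        shap_single = shap_single[1]
    elif np.ndim(shap_single) == 3:
        shap_single = shap_single[:, :, 1]

    print(f"--- {title} (Index: {index}) ---")
    # Pass the single row as 1-D values so the plot explains exactly this instance
    display(shap.force_plot(explainer.expected_value[1], shap_single[0], instance.iloc[0], matplotlib=False))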

# Pick one of each if available
if len(tp_indices) > 0:
    plot_force(tp_indices[0], "True Positive (Correct Fraud)")

if len(fp_indices) > 0:
    plot_force(fp_indices[0], "False Positive (Legit flagged as Fraud)")

if len(fn_indices) > 0:
    plot_force(fn_indices[0], "False Negative (Missed Fraud)")
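The three boolean filters above can also be written as one vectorized labeling step, which makes it easy to count outcomes or pick several examples per category. The sketch below uses `np.select`, which takes the first matching condition and falls back to the default (true negative) when none match; the helper name and toy arrays are illustrative.

```python
import numpy as np
import pandas as pd


def outcome_labels(y_true, y_pred):
    """Label each prediction as TP, FP, FN or TN (positive class = 1, fraud)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    conditions = [
        (y_true == 1) & (y_pred == 1),  # fraud correctly flagged
        (y_true == 0) & (y_pred == 1),  # legitimate flagged as fraud
        (y_true == 1) & (y_pred == 0),  # fraud missed
    ]
    return np.select(conditions, ["TP", "FP", "FN"], default="TN")


y_true = pd.Series([1, 0, 1, 0, 1, 0])
y_pred = np.array([1, 1, 0, 0, 1, 0])
labels = pd.Series(outcome_labels(y_true, y_pred), index=y_true.index)
print(labels.value_counts())
```

With the notebook's objects this would be `results['Outcome'] = outcome_labels(results['Actual'], results['Predicted'])`, after which `results[results['Outcome'] == 'FN'].index` gives the same indices as `fn_indices`.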

## 4. Interpretation & Recommendations

**Comparison:**
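- Compare the top features from the built-in importance (impurity-based, i.e. mean decrease in impurity) with the SHAP ranking (mean absolute SHAP value, the order used in the summary plot), and note where the two disagree.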
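Some disagreement is expected: scikit-learn's impurity-based importances are computed from the training data and can favor features with many distinct values, while mean |SHAP| reflects contributions on the rows that were explained. To make the comparison concrete, a hypothetical helper can report which features the two top-k lists share and where they differ. It takes two ranked lists of feature names (for example `fi_df['Feature']` and a list sorted by mean |SHAP|); the toy lists are illustrative.

```python
def compare_top_k(ranking_a, ranking_b, k=10):
    """Compare the top-k features of two importance rankings (lists of names)."""
    top_a, top_b = list(ranking_a)[:k], list(ranking_b)[:k]
    shared = [f for f in top_a if f in top_b]
    return {
        "shared": shared,
        "only_a": [f for f in top_a if f not in top_b],
        "only_b": [f for f in top_b if f not in top_a],
        # Fraction of ranking A's top-k that also appears in ranking B's top-k
        "overlap": len(shared) / len(top_a) if top_a else 0.0,
    }


builtin_rank = ["feature_a", "feature_b", "feature_c", "feature_d"]
shap_rank = ["feature_b", "feature_a", "feature_e", "feature_c"]
result = compare_top_k(builtin_rank, shap_rank, k=3)
print(result)
```

Features in `only_a` or `only_b` are the ones worth discussing: they are where the two importance definitions tell different stories.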

**Top Fraud Drivers:**
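- Identify which features push the prediction toward fraud: these have positive SHAP values, drawn in red in the force plots. In the summary plot, red instead means a high feature value, so read the direction of the effect from the horizontal SHAP axis.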
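A force plot is a picture of one additive explanation: it starts at the base value (`explainer.expected_value[1]`), positive SHAP values push the output toward fraud, negative values push it away, and the end point is the explained output. The sketch below does the same bookkeeping numerically for one row. The helper name, the base value, the SHAP values, and the feature names are made-up stand-ins.

```python
import numpy as np


def decompose_explanation(base_value, shap_row, feature_names, top_k=3):
    """Split one additive explanation into pushes toward and away from fraud."""
    shap_row = np.asarray(shap_row, dtype=float)
    output = base_value + shap_row.sum()  # additivity: base value + contributions
    order = np.argsort(shap_row)  # ascending: most negative first
    toward = [(feature_names[i], shap_row[i]) for i in order[::-1][:top_k] if shap_row[i] > 0]
    away = [(feature_names[i], shap_row[i]) for i in order[:top_k] if shap_row[i] < 0]
    return output, toward, away


base_value = 0.10  # hypothetical expected fraud score
shap_row = [0.45, -0.05, 0.20, -0.02]
names = ["feature_a", "feature_b", "feature_c", "feature_d"]
output, toward, away = decompose_explanation(base_value, shap_row, names)
print(f"explained output: {output:.2f}")
for name, value in toward:
    print(f"  pushes toward fraud: {name} (+{value:.2f})")
for name, value in away:
    print(f"  pushes away from fraud: {name} ({value:.2f})")
```

For a scikit-learn Random Forest explained with `TreeExplainer`, the explained output is expected to match the model's predicted fraud probability for that row up to floating-point error, which is a useful sanity check before drawing conclusions from a force plot.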

**Business Recommendations:**
1. ...
2. ...
3. ...