# Fraud Explainability (SHAP)
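A simple, fast run that trains a fraud model and generates a SHAP summary plot. Requires `data/raw/fraud/creditcard.csv`. The SHAP step also needs `shap` and `matplotlib` installed (see requirements); it is best effort and is skipped with a message if it fails.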

In [None]:
from pathlib import Path
import sys

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Resolve the repository root relative to the current working directory
# (the notebook is expected to be run from a directory one level below it).
project_root = Path('..').resolve()

# Make the project's `src` directory importable. Guard the append so that
# re-running this cell does not keep adding duplicate entries to sys.path.
src_path = str(project_root / 'src')
if src_path not in sys.path:
    sys.path.append(src_path)

from uais.data.load_fraud_data import load_fraud_data
from uais.features.fraud_features import build_fraud_feature_table
from uais.supervised.train_fraud_supervised import FraudModelConfig, train_fraud_model
from uais.utils.metrics import best_f1_threshold

# Optional SHAP imports are handled in the final (SHAP summary) cell

In [None]:
# Load the raw fraud data and build the feature table used for modelling.
_df_raw = load_fraud_data()
df_feats = build_fraud_feature_table(_df_raw, time_column='Time', amount_column='Amount', target_column='Class')

# Separate the features from the target column; the target is cast to int.
target_col = 'Class'
X = df_feats.drop(columns=[target_col])
y = df_feats[target_col].astype(int)

# 80/20 split, stratified on the target so both parts keep the same class
# ratio; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Train the model with the 'hist_gb' configuration and collect the metrics
# returned for the held-out split.
config = FraudModelConfig(model_type='hist_gb', max_depth=4, learning_rate=0.1, max_iter=200)
model, test_metrics = train_fraud_model(X_train, y_train, X_test, y_test, config)

print('Test metrics:')
for k, v in test_metrics.items():
    print(f"{k}: {v:.4f}")

# The threshold is searched on the same held-out test split as the metrics
# above (this cell creates no separate validation split), so it is labelled as
# a test-set value. NOTE(review): a threshold tuned on the test split is likely
# optimistic; carve out a validation split if it is going to be reused.
# Column 1 of predict_proba is presumably the fraud (Class == 1) probability.
test_scores = model.predict_proba(X_test)[:, 1]
print('Best F1 threshold (test):', best_f1_threshold(y_test.values, test_scores))

In [None]:
# SHAP summary (best effort) and save plot to experiments/fraud/plots/shap_summary.png
# Any failure (missing shap/matplotlib, unsupported model, plotting error) is
# reported with a message instead of stopping the notebook.
try:
    import shap
    import matplotlib.pyplot as plt

    # Explain a reproducible sample of at most 500 test rows.
    sample = X_test.sample(n=min(500, len(X_test)), random_state=42)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(sample)

    # Depending on the shap version and model type, a binary classifier may
    # yield one array per class (a list of two arrays, or a 3-D array whose
    # last axis has two outputs). Keep only the positive-class attributions so
    # the values line up with `sample`; any other shape is passed through as-is.
    if isinstance(shap_values, list):
        if len(shap_values) == 2:
            shap_values = shap_values[1]
    elif np.ndim(shap_values) == 3 and np.shape(shap_values)[-1] == 2:
        shap_values = shap_values[..., 1]

    plots_dir = project_root / 'experiments' / 'fraud' / 'plots'
    plots_dir.mkdir(parents=True, exist_ok=True)
    out_path = plots_dir / 'shap_summary.png'

    plt.figure()
    try:
        # show=False keeps the figure open so it can be saved below.
        shap.summary_plot(shap_values, sample, show=False)
        plt.tight_layout()
        plt.savefig(out_path, dpi=150)
    finally:
        # Always release the current figure, even when plotting or saving fails.
        plt.close()
    print('Saved SHAP plot to', out_path)
except Exception as exc:
    print('SHAP generation skipped:', exc)