# 1) Import required libraries

import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from IPython.display import IFrame, HTML, display

import joblib
import shap

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from xgboost import XGBClassifier

# Project util
from src.explainability.shap_explainer import SHAPExplainer

# Notebook display settings
# NOTE: `%matplotlib inline` is an IPython magic, not Python syntax, so this
# cell only runs inside a notebook / IPython kernel.
%matplotlib inline
# Apply seaborn's 'whitegrid' style to the plots drawn later in the notebook.
sns.set(style='whitegrid')
# Seed NumPy's global RNG. The sampling, splitting and model-training calls
# below additionally pass their own explicit seed of 42.
np.random.seed(42)


In [None]:
# 2) Constants and configuration

# Name of the label column in the processed dataset.
TARGET_COL = 'class'

# Every location can be overridden through an environment variable; the second
# argument is the default used when the variable is not set.
DATA_PATH = Path(os.environ.get('FRAUD_PROCESSED', 'data/processed/Fraud_Data_processed.csv'))
MODEL_PATH = Path(os.environ.get('FRAUD_MODEL', 'modelsave/best_model.joblib'))
OUTPUT_DIR = Path(os.environ.get('EXPLAIN_OUTPUT', 'outputs/explainability'))

# Make sure the directories this notebook writes into exist up front.
for _required_dir in (OUTPUT_DIR, Path('modelsave'), Path('docs')):
    _required_dir.mkdir(parents=True, exist_ok=True)

print(f"DATA_PATH={DATA_PATH}, MODEL_PATH={MODEL_PATH}, OUTPUT_DIR={OUTPUT_DIR}")


In [None]:
# 3) Load and preprocess data (with fallback to small synthetic sample)

if DATA_PATH.exists():
    df = pd.read_csv(DATA_PATH)
    print('Loaded processed data:', df.shape)
else:
    # No processed file available: fabricate a small, seeded dataset so the
    # rest of the notebook can still be executed end to end.
    print('Processed data not found; creating a small synthetic dataset for demo')
    rng = np.random.RandomState(42)
    size = 2000
    # The columns are drawn in this order from the single seeded RNG; keep the
    # order stable so the synthetic data is reproducible.
    df = pd.DataFrame({
        'amount': np.abs(rng.normal(100, 50, size)),
        'time_since_signup_h': rng.exponential(scale=24, size=size),
        'device_is_new': rng.randint(0, 2, size=size),
        'ip_risk_score': rng.uniform(0, 1, size),
        'transaction_hour': rng.randint(0, 24, size=size),
        'class': (rng.rand(size) < 0.05).astype(int)
    })

# Validate with an explicit raise rather than `assert`: assert statements are
# removed when Python runs with -O, which would let a bad file slip through.
if TARGET_COL not in df.columns:
    raise ValueError(f"Target column '{TARGET_COL}' not present")

# Small preprocessing: drop rows with missing target, sample for speed
df = df.dropna(subset=[TARGET_COL])
if len(df) > 5000:
    df = df.sample(5000, random_state=42)

# Features are every column except the target.
X = df.drop(columns=[TARGET_COL])
y = df[TARGET_COL]

# Train/test split (80/20, stratified on the target so both parts keep the class ratio)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
print('Train/Test shapes:', X_train.shape, X_test.shape)


In [None]:
# 4) Load or train model

if MODEL_PATH.exists():
    print('Loading model from', MODEL_PATH)
    # NOTE(security): joblib.load unpickles the file, which can execute
    # arbitrary code. MODEL_PATH can be redirected through the FRAUD_MODEL
    # environment variable, so only point it at model files you trust.
    model = joblib.load(MODEL_PATH)
else:
    print('Model not found; training a small XGBoost model for this notebook')
    model = XGBClassifier(use_label_encoder=False, eval_metric='logloss', random_state=42, n_jobs=4)
    model.fit(X_train, y_train)
    # MODEL_PATH may point somewhere other than the default 'modelsave/'
    # directory (FRAUD_MODEL override); create its parent so the dump below
    # cannot fail on a missing directory.
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, MODEL_PATH)
    print('Saved model to', MODEL_PATH)

# Feature names, in the column order of the feature frame X
feature_names = list(X.columns)


In [None]:
# 5) Built-in feature importance (top 10)

import numpy as np
import matplotlib.pyplot as plt

# Pick whichever importance measure the model exposes: tree-style
# `feature_importances_` first, then absolute linear coefficients.
if hasattr(model, 'feature_importances_'):
    importances = np.asarray(model.feature_importances_)
elif hasattr(model, 'coef_'):
    coef_abs = np.abs(model.coef_)
    # A 2-D coef_ holds one row of coefficients per class/output. Flattening it
    # would produce more entries than there are features, so average the
    # magnitudes over the rows instead (a single row is returned unchanged).
    importances = coef_abs.mean(axis=0) if np.ndim(coef_abs) > 1 else coef_abs
else:
    importances = None

# A model fitted on a different column set would otherwise be plotted with the
# wrong labels (or fail with an opaque IndexError); stop with a clear message.
if importances is not None and len(importances) != len(feature_names):
    raise ValueError(
        f"Model exposes {len(importances)} importances but the data has "
        f"{len(feature_names)} features; the model does not match this dataset"
    )

if importances is not None:
    # Indices of the features sorted by descending importance; keep the top 10.
    idx = np.argsort(importances)[::-1]
    top_idx = idx[:10]
    top_feats = [feature_names[i] for i in top_idx]
    top_vals = importances[top_idx]

    plt.figure(figsize=(8,6))
    # Values and labels are both reversed so the most important feature ends up
    # at the top of the horizontal bar chart.
    plt.barh(range(len(top_feats)), top_vals[::-1])
    plt.yticks(range(len(top_feats)), top_feats[::-1])
    plt.xlabel('Importance')
    plt.title('Top 10 Feature Importances')
    plt.tight_layout()
    fp = OUTPUT_DIR / 'feature_importance_top10.png'
    plt.savefig(fp, dpi=150)
    plt.show()
    print('Saved feature importance to', fp)
else:
    print('Model does not expose importances')


In [None]:
# 6) SHAP analysis: compute and save summary plot

from IPython.display import Image

expl = SHAPExplainer(model, feature_names)

# Explain at most 2000 test rows so the cell stays quick in a notebook.
X_shap = X_test.sample(2000, random_state=42) if len(X_test) > 2000 else X_test.copy()

# Time only the SHAP computation itself.
start = time.perf_counter()
shap_values = expl.explain(X_shap)
end = time.perf_counter()
elapsed = end - start
print(f"Computed SHAP values for {len(X_shap)} rows in {elapsed:.2f}s")

# Write the summary plot to disk ...
summary_fp = OUTPUT_DIR / 'shap_summary.png'
expl.plot_summary_save(shap_values, X_shap, str(summary_fp))
print('Saved SHAP summary to', summary_fp)

# ... and show the saved image inline.
display(Image(str(summary_fp)))


In [None]:
# 7) Identify TP, FP, FN examples and show rows

# Probability of the positive class for every test row, thresholded at 0.5.
probs = model.predict_proba(X_test)[:, 1]
preds = (probs >= 0.5).astype(int)

# convert y_test to aligned index
y_test_aligned = y_test.reset_index(drop=True)

# shap_values was computed on X_shap, which is either all of X_test or a
# 2000-row sample of it, and the force-plot cell indexes shap_values / X_shap
# with the positions found here. The positions therefore have to refer to rows
# of X_shap, not X_test: map every X_shap row to its position in X_test and
# pull the matching label / prediction.
if X_shap.index.equals(X_test.index):
    shap_rows_in_test = np.arange(len(X_test))
else:
    shap_rows_in_test = X_test.index.get_indexer(X_shap.index)

y_shap = y_test_aligned.to_numpy()[shap_rows_in_test]
preds_shap = preds[shap_rows_in_test]


def find_index(condition):
    """Return the first row position in X_shap for which `condition(position)` is true.

    Args:
        condition: callable taking an integer row position and returning a
            truthy value when that row qualifies.

    Returns:
        The first qualifying position as an int, or None when no row qualifies.
    """
    return next((i for i in range(len(shap_rows_in_test)) if condition(i)), None)


idx_tp = find_index(lambda i: y_shap[i] == 1 and preds_shap[i] == 1)
idx_fp = find_index(lambda i: y_shap[i] == 0 and preds_shap[i] == 1)
idx_fn = find_index(lambda i: y_shap[i] == 1 and preds_shap[i] == 0)

print('Indices found - TP:', idx_tp, 'FP:', idx_fp, 'FN:', idx_fn)

# Show the feature rows of the selected examples (positional view of X_shap,
# built once instead of once per example).
X_shap_rows = X_shap.reset_index(drop=True)
for _example_idx in (idx_tp, idx_fp, idx_fn):
    if _example_idx is not None:
        display(X_shap_rows.iloc[_example_idx])


In [None]:
# 8) Force plots for selected examples (saved as HTML)

# Positional view of the rows the SHAP values were computed for; built once
# and shared by all three plots.
shap_rows = X_shap.reset_index(drop=True)

fp_tp = OUTPUT_DIR / 'shap_force_tp.html'
fp_fp = OUTPUT_DIR / 'shap_force_fp.html'
fp_fn = OUTPUT_DIR / 'shap_force_fn.html'

# One (label, row position, output file) entry per example type.
for _label, _idx, _out_fp in (
    ('TP', idx_tp, fp_tp),
    ('FP', idx_fp, fp_fp),
    ('FN', idx_fn, fp_fn),
):
    if _idx is None:
        # No example of this type was found.
        continue
    if _idx >= len(shap_rows):
        # The position does not exist among the explained rows (X_shap can be
        # smaller than X_test); report it instead of failing inside the plot call.
        print(f"Skipping {_label} force plot: index {_idx} is outside the "
              f"{len(shap_rows)} rows SHAP values were computed for")
        continue
    expl.plot_force_save(shap_values, shap_rows, _idx, str(_out_fp))
    display(HTML(f"<a href='{_out_fp}' target='_blank'>Open {_label} force plot</a>"))

print('Saved force plot files to', OUTPUT_DIR)


In [None]:
# 9) Top drivers and short interpretation

# Average magnitude of the SHAP values per feature (mean over axis 0).
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)

# Feature positions ordered from the largest to the smallest mean |SHAP|.
order = np.argsort(mean_abs_shap)[::-1]
top5_idx = order[:5]

# Leaving the frame as the last expression makes the notebook render it.
pd.DataFrame(
    {
        'feature': [feature_names[pos] for pos in top5_idx],
        'mean_abs_shap': mean_abs_shap[top5_idx],
    }
)


In [None]:
# 10) Simple unit tests (inline)

def _test_shap_small_flow():
    """Smoke test: explain the first 50 training rows and check the result's shape.

    Asserts that the explainer output has as many rows as the input frame and
    as many columns as the input has features.
    """
    sample_rows = X_train.head(50)
    n_rows, n_cols = sample_rows.shape
    local_shap = SHAPExplainer(model, feature_names).explain(sample_rows)
    assert local_shap.shape[0] == n_rows
    assert local_shap.shape[1] == n_cols

print('Running quick inline test...')
_test_shap_small_flow()
print('Inline tests passed')


In [None]:
# 11) Save outputs and provide quick run instructions

# Persist the five strongest SHAP drivers as a CSV (no index column).
drivers_fp = OUTPUT_DIR / 'top5_shap_drivers.csv'
top5_drivers = pd.DataFrame(
    {
        'feature': [feature_names[pos] for pos in top5_idx],
        'mean_abs_shap': mean_abs_shap[top5_idx],
    }
)
top5_drivers.to_csv(drivers_fp, index=False)
print('Saved top 5 drivers to', drivers_fp)

# Enumerate everything currently in the output directory, in sorted order.
print('\nSaved files:')
for out_file in sorted(OUTPUT_DIR.glob('*')):
    print('-', out_file)

# Closing pointers: the batch script invocation and the written report.
display(HTML("<h4>Run the batch explainability script:</h4><pre>python scripts/explain_shap.py --data data/processed/Fraud_Data_processed.csv --output outputs/explainability</pre>"))
display(HTML("<p>Report file: <code>docs/TASK_3_SHAP.md</code></p>"))


In [None]:
# 12) Recommendations and interpretation (markdown)

# The HTML object is the cell's last expression, so the notebook renders it.
# Recommendation 2 previously described a short time since signup as a *high*
# time_since_signup_h; a short time corresponds to a low value of that feature.
HTML('''
<h3>Interpretation & Business Recommendations</h3>
<ul>
<li><b>Top drivers:</b> The top features by mean(|SHAP|) indicate the variables that most strongly push predictions toward fraud vs. legit.</li>
<li><b>Recommendation 1:</b> Transactions with high positive SHAP contributions from <i>ip_risk_score</i> or <i>device_is_new</i> should receive additional verification (e.g., 2FA) before approval.</li>
<li><b>Recommendation 2:</b> Flag transactions occurring within a short time since signup (low <i>time_since_signup_h</i>) combined with high amount for manual review.</li>
<li><b>Recommendation 3:</b> Use SHAP-informed thresholds to add targeted friction (step-up authentication) only for high-risk patterns to minimize false positives and customer friction.</li>
</ul>
''')
