# 07 — Interpretability (Kaggle GPU)

Notebook layout mirrors the earlier GCN/GNN notebooks:
1. Install PyG dependencies.
2. Copy the repo snapshot (`elliptic-gnn-baselines`) into `/kaggle/working`.
3. Copy the Elliptic dataset (`elliptic-fraud-detection`) into `data/Elliptic++ Dataset`.
4. Set deterministic seeds.
5. Run the automated scripts (XGBoost SHAP, GraphSAGE local-only run, GraphSAGE saliency).
6. Visualize the outputs, record observations, and package the artifacts for download.


## 1. Install Dependencies

In [None]:
!pip install torch-geometric -q

## 2. Imports & Workspace Setup

In [None]:
import os
import shutil
import sys
from pathlib import Path

REPO_DATASET = Path('/kaggle/input/elliptic-gnn-baselines')
WORKDIR = Path('/kaggle/working/elliptic-gnn-baselines')
if REPO_DATASET.exists():
    if WORKDIR.exists():
        shutil.rmtree(WORKDIR)
    shutil.copytree(REPO_DATASET, WORKDIR)
    nested = WORKDIR / 'elliptic-gnn-baselines'
    if nested.exists():
        os.chdir(nested)
    else:
        os.chdir(WORKDIR)
else:
    print("WARNING: add the 'elliptic-gnn-baselines' dataset in Kaggle's Add Data panel.")

if str(Path.cwd()) not in sys.path:
    sys.path.insert(0, str(Path.cwd()))

DATASET_SRC = Path('/kaggle/input/elliptic-fraud-detection')
DATASET_DST = Path('data') / 'Elliptic++ Dataset'
DATASET_DST.parent.mkdir(parents=True, exist_ok=True)
if DATASET_SRC.exists():
    if DATASET_DST.exists():
        shutil.rmtree(DATASET_DST)
    shutil.copytree(DATASET_SRC, DATASET_DST)
else:
    print("WARNING: add the 'elliptic-fraud-detection' dataset as well.")

print(f"Working directory: {Path.cwd()}")
print(f"Dataset directory: {DATASET_DST}")
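
The two copy steps above follow the same pattern: remove any stale destination, then copy the read-only Kaggle input into a writable location. A minimal sketch of that pattern as a reusable helper follows; `refresh_copy` is a hypothetical name, not something defined in the repo.

```python
import shutil
from pathlib import Path


def refresh_copy(src: Path, dst: Path) -> bool:
    """Replace dst with a fresh copy of src; return False if src is missing.

    Hypothetical helper, not part of the elliptic-gnn-baselines repo.
    """
    if not src.exists():
        return False
    if dst.exists():
        shutil.rmtree(dst)  # drop stale files from a previous run
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(src, dst)
    return True
```

Called as `refresh_copy(DATASET_SRC, DATASET_DST)`, it returns `False` instead of printing a warning, so the caller decides how to react to a missing input.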

## 3. Set Seed

In [None]:
import random
import numpy as np
import torch

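def set_seed(seed=42):
    """Seed the Python, NumPy, and PyTorch RNGs for this notebook process."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable
    # Force deterministic cuDNN kernels and disable autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Scripts launched with `!python` run in separate processes and must seed themselves.
set_seed(42)
print("Seeds set to 42")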


## 4. Run Automated Scripts

In [None]:
!python scripts/run_m8_interpretability.py

In [None]:
!python scripts/run_m8_graphsage_local_only.py

In [None]:
!python scripts/run_m8_graphsage_saliency.py
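
A script launched with `!python` that exits with an error does not stop the notebook, so it can help to confirm that the expected outputs exist before the loading cells run. The sketch below assumes the artifact paths used later in this notebook; `missing_artifacts` is a hypothetical helper, and which script writes which file is not checked here.

```python
from pathlib import Path

# Paths taken from the loading and packaging cells of this notebook.
EXPECTED_ARTIFACTS = [
    'reports/m8_xgb_shap_importance.csv',
    'reports/graphsage_local_only_metrics.json',
    'reports/m8_graphsage_saliency.json',
    'checkpoints/graphsage_local_only_best.pt',
]


def missing_artifacts(paths, root=Path('.')):
    """Return the subset of paths that do not exist under root."""
    return [p for p in paths if not (Path(root) / p).exists()]


missing = missing_artifacts(EXPECTED_ARTIFACTS)
if missing:
    print("Missing artifacts:", *missing, sep="\n  ")
else:
    print("All expected artifacts are present.")
```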

## 5. Load SHAP Results

In [None]:
import pandas as pd
from pathlib import Path
from IPython.display import Image, display

REPORTS_DIR = Path('reports')
PLOTS_DIR = REPORTS_DIR / 'plots'

shap_csv = REPORTS_DIR / 'm8_xgb_shap_importance.csv'
shap_df = pd.read_csv(shap_csv)
shap_df.head(10)


In [None]:
plot_path = PLOTS_DIR / 'm8_xgb_shap_summary.png'
if plot_path.exists():
    display(Image(filename=str(plot_path)))
else:
    print(f"Plot not found: {plot_path}")


## 6. Load GraphSAGE Saliency Outputs

In [None]:
import json
saliency_json = REPORTS_DIR / 'm8_graphsage_saliency.json'
with open(saliency_json, 'r', encoding='utf-8') as fp:
    saliency_data = json.load(fp)
print(f"Loaded {len(saliency_data)} node explanations")

rows = []
for entry in saliency_data:
    for feat in entry['top_features']:
        rows.append({'node_id': entry['node_id'], 'feature': feat['feature'], 'importance': feat['importance']})
agg_df = pd.DataFrame(rows)
agg_df.groupby('feature')['importance'].mean().sort_values(ascending=False).head(15)
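
The cell above expects each JSON entry to carry a `node_id` and a `top_features` list of `feature`/`importance` pairs. The sketch below reproduces the same mean-importance ranking in plain Python on a made-up two-node sample, which may help when checking the aggregation without pandas. The feature names and values are illustrative only, not results from the scripts.

```python
from collections import defaultdict

# Made-up sample with the same shape as the entries read by the cell above.
sample = [
    {'node_id': 1, 'top_features': [
        {'feature': 'feat_a', 'importance': 0.5},
        {'feature': 'feat_b', 'importance': 0.25},
    ]},
    {'node_id': 2, 'top_features': [
        {'feature': 'feat_a', 'importance': 0.25},
        {'feature': 'feat_c', 'importance': 0.125},
    ]},
]


def mean_importance(entries):
    """Average each feature's importance over the nodes that list it."""
    values = defaultdict(list)
    for entry in entries:
        for feat in entry['top_features']:
            values[feat['feature']].append(feat['importance'])
    means = {name: sum(v) / len(v) for name, v in values.items()}
    # Highest mean first, matching sort_values(ascending=False).
    return sorted(means.items(), key=lambda kv: kv[1], reverse=True)


print(mean_importance(sample))
```

As with the `groupby` call, a feature is averaged only over the nodes whose `top_features` include it, so a feature that appears once with a large value can outrank one that appears in every node.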


In [None]:
for entry in saliency_data:
    plot = Path(entry['plot_path'])
    if plot.exists():
        print(f"Node {entry['node_id']} (label={entry['label']})")
        display(Image(filename=str(plot)))
    else:
        print(f"Missing plot: {plot}")


## 7. Observations

- SHAP confirms that the tabular XGBoost model leans on late-index local features plus a few select aggregate features.
- GraphSAGE saliency highlights local features AF80–AF93 and high-probability neighbors.
- Taken together, these views help explain the M7 results and motivate the discussion in `docs/M8_INTERPRETABILITY.md`.


## 8. Package Artifacts for Download

In [None]:
import zipfile
from pathlib import Path

OUTPUT_ZIP = Path('kaggle_results.zip')
with zipfile.ZipFile(OUTPUT_ZIP, 'w', compression=zipfile.ZIP_DEFLATED) as zf:
    files = [
        'reports/m8_xgb_shap_importance.csv',
        'reports/graphsage_local_only_metrics.json',
        'reports/m8_graphsage_saliency.json',
        'checkpoints/graphsage_local_only_best.pt',
    ]
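    for f in files:
        if Path(f).exists():
            zf.write(f)
        else:
            print(f"Skipping missing artifact: {f}")
    plots_dir = Path('reports/plots')
    # The SHAP summary is a single file; saliency plots come one per explained node.
    shap_plot = plots_dir / 'm8_xgb_shap_summary.png'
    if shap_plot.exists():
        zf.write(shap_plot)
    for plot in sorted(plots_dir.glob('m8_graphsage_saliency_node*.png')):
        zf.write(plot)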

print(f"Bundled artifacts into {OUTPUT_ZIP}")
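
Because files that do not exist are skipped rather than treated as errors, it can be worth listing what actually went into the archive before downloading it. A minimal sketch using only `zipfile` follows; `list_bundle` is a hypothetical helper name.

```python
import zipfile
from pathlib import Path


def list_bundle(zip_path):
    """Return (name, uncompressed size in bytes) for each archive member."""
    with zipfile.ZipFile(zip_path) as zf:
        return [(info.filename, info.file_size) for info in zf.infolist()]


bundle = Path('kaggle_results.zip')
if bundle.exists():
    for name, size in list_bundle(bundle):
        print(f"{size:>10}  {name}")
```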
