# In [None]:
from xai import load_feature_names, aggregate_shap_per_class
from xai import plot_topk_bar_per_class, plot_spider
from xai import aggregate_signed_shap_per_class, plot_signed_topk
from xai import local_shap_for_edge, sample_edges_by_class
from eval_utils import make_eval_loader

import numpy as np
import json, torch, dgl
import matplotlib.pyplot as plt
from train_edgecls_dbg import _FallbackEdgeGraphSAGE

device = "cuda" if torch.cuda.is_available() else "cpu"
# Training hyperparameters: these must match the run that produced the checkpoint
# loaded below. `hidden` in particular must match or load_state_dict will fail on
# shape mismatches; the others should match so evaluation mirrors training.
# TODO improve: load from a config file
fanouts = (25, 15)  # passed to make_eval_loader; presumably per-layer sampling fanouts — confirm
batch_size = 2048
hidden = 128
dropout = 0.3

# Load artifacts
g_val = dgl.load_graphs("graphs/val.bin")[0][0]  # first graph stored in the file
fs_val = "feature_store/val"

# label map: JSON object mapping label name -> class id
with open("artifacts/label_map.json", "r", encoding="utf-8") as f:
    label2id = json.load(f)
id2label = {v: k for k, v in label2id.items()}
# Class ids are used as 0..n-1 indices below (and len(label2id) sizes the
# classifier), so reject gaps/duplicates up front with a readable message.
if set(id2label) != set(range(len(label2id))):
    raise ValueError(
        f"artifacts/label_map.json must map {len(label2id)} labels onto ids "
        f"0..{len(label2id) - 1} exactly once; got ids {list(id2label)}"
    )
classes = [id2label[i] for i in range(len(label2id))]

# (d_num, d_cat) come from the train feature store; their sum is the model's
# edge input width. Note d_cat is a feature dimension, not the class count.
feature_names, d_num, d_cat = load_feature_names("feature_store/train")
edge_in = d_num + d_cat  # sanity check == 601 for NF-UNSW-NB15

# 1) Build val loader (or test loader)
val_loader = make_eval_loader(g_val, fs_val, fanouts, batch_size)
# Attach the edge ids saved alongside the val feature store.
# NOTE(review): presumably read by the xai helpers to align features with graph edges — confirm.
val_loader.store_eids = np.load(f"{fs_val}/edge_indices.npy").astype(np.int64)

# 2) Load model (in_node=0 if you trained w/o explicit node features)
model = _FallbackEdgeGraphSAGE(in_node=0, edge_in=edge_in, hidden=hidden,
                               num_classes=len(label2id), dropout=dropout).to(device)
# The checkpoint is a plain state_dict (it is fed straight to load_state_dict),
# so restrict unpickling to tensors and primitive containers.
state_dict = torch.load("artifacts/best_edge_sage.pt", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for dropout etc. during the explanation passes

# 3) XAI – global per-class importance
import os

xai_dir = "artifacts/xai"
os.makedirs(xai_dir, exist_ok=True)  # the plot helpers below write into this directory

# The model's input layout is defined by the train feature store loaded above.
# The val store is re-read only to verify it uses the same columns in the same
# order; otherwise SHAP values would be attributed to the wrong feature names.
val_feature_names, val_d_num, val_d_cat = load_feature_names(fs_val)
if list(val_feature_names) != list(feature_names) or (val_d_num, val_d_cat) != (d_num, d_cat):
    raise ValueError(
        f"{fs_val} feature layout differs from feature_store/train: "
        f"train has {len(feature_names)} features (d_num={d_num}, d_cat={d_cat}), "
        f"val has {len(val_feature_names)} (d_num={val_d_num}, d_cat={val_d_cat})"
    )

class2_meanabs = aggregate_shap_per_class(model, g_val, val_loader, fs_val,
                                          feature_names, n_per_class=100, device=device)
# Plots (id2label is the inverse of label_map.json, built when loading artifacts)
plot_topk_bar_per_class(class2_meanabs, feature_names, id2label, k=10, save_dir=xai_dir)
plot_spider(class2_meanabs, feature_names, id2label, k_union=8,
            save_path=f"{xai_dir}/spider.png")

# 4) XAI – per-class signed effects
class2_signed = aggregate_signed_shap_per_class(model, g_val, val_loader, fs_val,
                                                feature_names, n_per_class=100, device=device)
plot_signed_topk(class2_signed, feature_names, id2label, k=10, save_dir=xai_dir)

# 5) XAI – local explanation for a single edge
# Sample one edge per class and take the first edge of the entry at key 0
# (presumably class id 0 — confirm against sample_edges_by_class).
class0_sample = sample_edges_by_class(g_val, n_per_class=1)[0]
some_local_edge = list(class0_sample)[0]
sv, x_edge, y_true, base_val = local_shap_for_edge(
    model,
    g_val,
    val_loader,
    some_local_edge,
    "feature_store/val",
    background_size=100,
    target_class=None,
    device=device,
)

# visualize the single-edge SHAP top features:
# rank features by |SHAP| in descending order and keep the ten largest
imp = np.abs(sv[0])
top = np.argsort(imp)[::-1][:10]
top_labels = [feature_names[i] for i in top]

plt.figure(figsize=(8, 5))
# barh draws the first entry at the bottom, so reverse to put the largest on top
plt.barh(top_labels[::-1], imp[top][::-1])
plt.title("Local explanation — top features")
plt.tight_layout()
plt.show()
