In [1]:
import json
import os

import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch

from eval_utils import make_eval_loader
from train_edgecls_dbg import _FallbackEdgeGraphSAGE
from xai import (
    aggregate_shap_per_class,
    aggregate_signed_shap_per_class,
    load_feature_names,
    local_shap_for_edge,
    plot_signed_topk,
    plot_spider,
    plot_topk_bar_per_class,
    sample_edges_by_class,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global RNGs so the edge sampling / SHAP background draws below are
# repeatable under Restart & Run All (assumes the xai helpers draw from the
# numpy/torch global RNGs — TODO confirm).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training hyperparameters — should match the run that produced the checkpoint
# loaded below (artifacts/best_edge_sage.pt).
# TODO improve: load from a config file
fanouts = (25, 15)   # sampling fanouts passed to make_eval_loader
batch_size = 2048
hidden = 128         # hidden width passed to _FallbackEdgeGraphSAGE
dropout = 0.3

# Load artifacts
g_val = dgl.load_graphs("graphs/val.bin")[0][0]
fs_val = "feature_store/val"

# Label map: JSON object {label_name: class_id}. Ids are coerced to int in case
# the JSON stores them as strings, and must be unique and contiguous 0..N-1
# because `classes` below is indexed positionally.
with open("artifacts/label_map.json", "r", encoding="utf-8") as f:
    label2id = json.load(f)
id2label = {int(v): k for k, v in label2id.items()}
assert sorted(id2label) == list(range(len(label2id))), \
    "artifacts/label_map.json: class ids must be unique and contiguous 0..N-1"
classes = [id2label[i] for i in range(len(label2id))]

# Edge-feature layout of the *training* store; its total width sizes the model
# input. (d_cat here is the feature-store value, not the number of classes.)
feature_names, d_num, d_cat = load_feature_names("feature_store/train")
edge_in = d_num + d_cat  # expected 601 for NF-UNSW-NB15

# 1) Build val loader (or test loader) from the val feature store
val_loader = make_eval_loader(g_val, fs_val, fanouts, batch_size)
# Attach the store's edge ids to the loader; presumably read by the xai helpers
# to map store rows back to graph edges — TODO confirm in xai.
val_loader.store_eids = np.load(f"{fs_val}/edge_indices.npy").astype(np.int64)

# 2) Load model (in_node=0 if you trained w/o explicit node features).
# Architecture arguments must match the training run or load_state_dict fails.
model = _FallbackEdgeGraphSAGE(in_node=0, edge_in=edge_in, hidden=hidden,
                               num_classes=len(label2id), dropout=dropout).to(device)
model.load_state_dict(torch.load("artifacts/best_edge_sage.pt", map_location=device))
model.eval()

# 3) XAI – global per-class importance
# The model input width (edge_in) was sized from the *train* feature store; the
# val store used for explanations must expose the same features in the same
# order, otherwise SHAP values would be attributed to the wrong feature names.
val_feature_names, val_d_num, val_d_cat = load_feature_names(fs_val)
assert list(val_feature_names) == list(feature_names), \
    "train/val feature stores disagree on feature names or order"
assert val_d_num + val_d_cat == edge_in, \
    "val feature width does not match the model's edge_in"

# NOTE(review): running this cell previously raised
#   AttributeError: '_FallbackEdgeGraphSAGE' object has no attribute 'encode'
# presumably from an xai helper calling model.encode(...); the model class (or
# the helper) must provide that entry point before these calls succeed — TODO confirm.
os.makedirs("artifacts/xai", exist_ok=True)  # the plot helpers below save here
class2_meanabs = aggregate_shap_per_class(model, g_val, val_loader, fs_val,
                                          feature_names, n_per_class=100, device=device)
# Plots (id2label is the label map built above)
plot_topk_bar_per_class(class2_meanabs, feature_names, id2label, k=10, save_dir="artifacts/xai")
plot_spider(class2_meanabs, feature_names, id2label, k_union=8, save_path="artifacts/xai/spider.png")

# 4) XAI – per-class signed effects
class2_signed = aggregate_signed_shap_per_class(model, g_val, val_loader, fs_val,
                                                feature_names, n_per_class=100, device=device)
plot_signed_topk(class2_signed, feature_names, id2label, k=10, save_dir="artifacts/xai")

# 5) XAI – local explanation for a single edge sampled from one class
local_class = 0     # class id whose sampled edge is explained
local_topk = 10     # number of features shown in the bar chart
some_local_edge = int(list(sample_edges_by_class(g_val, n_per_class=1)[local_class])[0])
sv, x_edge, y_true, base_val = local_shap_for_edge(model, g_val, val_loader, some_local_edge,
                                                   fs_val, background_size=100,
                                                   target_class=None, device=device)
# With target_class=None, sv is presumably indexed by output class first
# ((num_classes, num_features)) — TODO confirm. Show the row of the class the
# edge was sampled from.
imp = np.abs(sv[local_class])
top = np.argsort(imp)[::-1][:local_topk]

# visualize the single-edge SHAP top features (largest bar on top)
fig, ax = plt.subplots(figsize=(8, 5))
ax.barh([feature_names[i] for i in top][::-1], imp[top][::-1])
ax.set_xlabel("|SHAP|")
ax.set_title(f"Local explanation (edge {some_local_edge}, class {classes[local_class]}) — top features")
fig.tight_layout()
plt.show()




AttributeError: '_FallbackEdgeGraphSAGE' object has no attribute 'encode'

In [None]:
# 5b) XAI — improved local explanations: sample multiple edges per class,
#     aggregate SHAP (mean absolute) and produce summary plots (top-10 features).
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

fs_val = "feature_store/val"
n_per_class = 10           # how many edges to sample per class for local explanations
background_size = 100      # SHAP background size (tune for fidelity vs speed)

# `sample_edges_by_class` yields one collection of edge ids per class; cell 1
# indexes it with [0]. Accept either a sequence (position == class id) or a dict
# keyed by class id, so dict keys are never mistaken for edge collections.
sampled_raw = sample_edges_by_class(g_val, n_per_class=n_per_class)
if isinstance(sampled_raw, dict):
    sampled = sorted(sampled_raw.items(), key=lambda kv: kv[0])
else:
    sampled = list(enumerate(sampled_raw))

num_classes = len(classes)
num_features = len(feature_names)

# container: class_id -> list of |SHAP| vectors, each of shape (num_features,)
class_abs_shaps = {c: [] for c in range(num_classes)}
failures = []  # (eid, class_id, exception) for edges whose explanation failed

for class_id, edges_for_class in sampled:
    for eid in edges_for_class:
        try:
            sv, x_edge, y_true_local, base_val = local_shap_for_edge(
                model, g_val, val_loader, int(eid),
                fs_val, background_size=background_size,
                target_class=None, device=device
            )
            # With target_class=None, sv is expected to be (num_classes, num_features)
            # — TODO confirm. Keep only the row of the class the edge was sampled
            # from (presumably its true label), so each class summary shows which
            # features drive that class's score on its own edges.
            class_abs_shaps[class_id].append(np.abs(sv[class_id]))
        except Exception as e:
            # Best-effort per edge: record, report, and continue (see check below).
            failures.append((eid, class_id, e))
            print(f"warning: failed SHAP for eid={eid} class={class_id}: {e}")

n_explained = sum(len(v) for v in class_abs_shaps.values())
if n_explained == 0:
    # Without this guard every class silently gets an all-zero summary and the
    # plots below are meaningless (e.g. when every call fails the same way, as
    # with the missing `encode` method seen in cell 1).
    first_err = f"; first error: {failures[0][2]!r}" if failures else ""
    raise RuntimeError(
        f"local SHAP produced no explanations ({len(failures)} sampled edges failed){first_err}"
    )
print(f"local SHAP: {n_explained} edges explained, {len(failures)} failed")

# Mean |SHAP| per class over the edges that were explained successfully.
# A class with no successful samples gets an all-zero vector so the plotting
# code below still receives one (num_features,) array per class.
class2_meanabs = {
    cid: (np.stack(class_abs_shaps[cid], axis=0).mean(axis=0)
          if class_abs_shaps[cid]
          else np.zeros(num_features))
    for cid in range(num_classes)
}

# Plot per-class top-10 bar charts in a grid (one panel per class)
os.makedirs("artifacts/xai", exist_ok=True)
ncols = 3
nrows = int(np.ceil(num_classes / ncols))
# squeeze=False always returns a 2-D array of Axes, so .flatten() works for any grid
fig, axs = plt.subplots(nrows, ncols, figsize=(5*ncols, 3.5*nrows),
                        constrained_layout=True, squeeze=False)
axs = axs.flatten()
for cid in range(num_classes):
    ax = axs[cid]
    ax.set_title(f"Class {cid}: {classes[cid]} (top10)")
    meanabs = class2_meanabs[cid]
    if not np.any(meanabs):
        # All-zero summary (typically: no edge of this class was explained).
        # Ranking zeros would display arbitrary features, so label the panel instead.
        ax.text(0.5, 0.5, "no SHAP samples", ha="center", va="center", transform=ax.transAxes)
        ax.set_axis_off()
        continue
    top_idx = np.argsort(meanabs)[::-1][:10]
    vals = meanabs[top_idx][::-1]                  # reverse so the largest bar is on top
    names = [feature_names[i] for i in top_idx][::-1]
    ax.barh(names, vals, color=sns.color_palette("tab10"))
    ax.set_xlabel("mean|SHAP|")
for j in range(num_classes, len(axs)):
    fig.delaxes(axs[j])                            # drop unused grid cells
fig.suptitle("Per-class local SHAP — top 10 features (mean |SHAP|)", fontsize=14)
fig.savefig("artifacts/xai/local_shap_top10_per_class.png", dpi=200)
plt.show()

# Build union of top-k features across classes and plot heatmap
topk = 10
meanabs_mat = np.vstack([class2_meanabs[c] for c in range(num_classes)])  # (C, num_features)
union_idxs = set()
for row in meanabs_mat:
    union_idxs.update(np.argsort(row)[::-1][:topk].tolist())
# order columns by each feature's strongest per-class mean|SHAP| (descending)
feat_strength = meanabs_mat.max(axis=0)
union_idxs = sorted(union_idxs, key=lambda i: -feat_strength[i])

feat_names_union = [feature_names[i] for i in union_idxs]
heatmat = meanabs_mat[:, union_idxs]  # shape (C, len(union))
# normalize columns for better visual comparability (each feature's max -> 1)
col_max = heatmat.max(axis=0, keepdims=True)
col_max[col_max == 0] = 1.0
heatmat_norm = heatmat / col_max

fig, ax = plt.subplots(figsize=(max(6, 0.6*len(feat_names_union)), max(6, 0.5*num_classes)))
# Values are non-negative and normalised to [0, 1]: use a sequential colormap on a
# fixed range. A diverging map such as "vlag" puts a neutral midpoint mid-range
# and visually implies a sign that |SHAP| does not have.
sns.heatmap(heatmat_norm, ax=ax,
            xticklabels=feat_names_union, yticklabels=[classes[c] for c in range(num_classes)],
            cmap="viridis", vmin=0.0, vmax=1.0,
            cbar_kws={"label": "relative mean|SHAP| (norm per feature)"},
            annot=False)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_title("SHAP summary heatmap — union top features")
ax.set_xlabel("Feature")
ax.set_ylabel("Class")
fig.tight_layout()
fig.savefig("artifacts/xai/local_shap_summary_heatmap.png", dpi=200)
plt.show()