In [8]:
import os, torch
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_networkx
from torch_geometric.explain import Explainer, ModelConfig
from torch_geometric.explain.algorithm import GNNExplainer
import networkx as nx
import matplotlib.pyplot as plt
# Your model defs must match training-time classes
# from your_models import GCNNet, GATNet
# Assuming they were defined as GCNNet(in_dim) and GATNet(in_dim)
# and output logits of shape [num_graphs, num_classes] for graph classification.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Keep the same label order you used in training:
label_order = ["Benign", "InSitu", "Invasive",  "Normal"]  # <-- change if different
num_classes = len(label_order)
ckpt_dir = "checkpoints"  # where you saved them


In [4]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATv2Conv, global_mean_pool

class GCNNet(nn.Module):
    def __init__(self, in_dim, hidden=64, num_classes=4, dropout=0.3):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.lin = nn.Linear(hidden, num_classes)
        self.dropout = dropout
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

class GATNet(nn.Module):
    def __init__(self, in_dim, hidden=32, heads=4, num_classes=4, dropout=0.3):
        super().__init__()
        self.gat1 = GATv2Conv(in_dim, hidden, heads=heads, dropout=dropout)
        self.gat2 = GATv2Conv(hidden*heads, hidden, heads=1, dropout=dropout)
        self.lin = nn.Linear(hidden, num_classes)
        self.dropout = dropout
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)

In [5]:
# 2) Load checkpoints -> models
def load_model(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    arch, in_dim, state = ckpt['arch'], ckpt['in_dim'], ckpt['state_dict']
    if arch == 'GCN':
        model = GCNNet(in_dim)
    elif arch == 'GAT':
        model = GATNet(in_dim)
    else:
        raise ValueError(f"Unknown arch: {arch}")
    model.load_state_dict(state)
    model.to(device).eval()
    return arch, model

gcn_arch, gcn_model = load_model(os.path.join(ckpt_dir, "GCN_best.pth"))
gat_arch, gat_model = load_model(os.path.join(ckpt_dir, "GAT_best.pth"))
print("Loaded:", gcn_arch, gat_arch)


Loaded: GCN GAT


In [6]:
print(gat_model)
print(gcn_model)

GATNet(
  (gat1): GATv2Conv(768, 32, heads=4)
  (gat2): GATv2Conv(128, 32, heads=1)
  (lin): Linear(in_features=32, out_features=4, bias=True)
)
GCNNet(
  (conv1): GCNConv(768, 64)
  (conv2): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=4, bias=True)
)


# Load saved graphs and run predictions

We'll now:
1. Point to the `graphs_delaunay` directory containing saved `.pt` graph objects.
2. Select a small sample of graph files (you can adjust the pattern/limit).
3. Load them as `torch_geometric.data.Data` objects.
4. Batch them and obtain logits & softmax probabilities from both GCN and GAT models.
5. Display top predicted label (with probability) for each graph.

You can later adjust `sample_limit` or specify explicit filenames.

In [10]:
import glob
from typing import List

# Directory with graphs (.pt). Adjust if different.
graphs_dir = "graphs_delaunay"
assert os.path.isdir(graphs_dir), f"Directory not found: {graphs_dir} (adjust path)"

# Collect .pt files (excluding ones that may not be graph Data objects if needed)
all_graph_files = sorted(glob.glob(os.path.join(graphs_dir, '*.pt')))
print(f"Found {len(all_graph_files)} graph files.")

# Limit sample for speed (adjust or set to None for all)
sample_limit = 8  # change as needed
sample_files = all_graph_files[200:200+sample_limit]
print("Using files:")
for f in sample_files:
    print(" -", os.path.basename(f))

loaded_graphs: List[Data] = []
for path in sample_files:
    obj = torch.load(path, map_location='cpu', weights_only=False)  # removed weights_only for broader compatibility
    # Expect either a Data object or a dict with 'data'
    if isinstance(obj, Data):
        data_obj = obj
    elif isinstance(obj, dict) and 'data' in obj and isinstance(obj['data'], Data):
        data_obj = obj['data']
    else:
        raise TypeError(f"Unsupported graph file format: {path}")

    # Ensure required attributes
    if getattr(data_obj, 'x', None) is None or getattr(data_obj, 'edge_index', None) is None:
        raise ValueError(f"Graph {path} missing x or edge_index")

    # Add a dummy y if missing (not needed for inference but some utilities expect it)
    if getattr(data_obj, 'y', None) is None:
        data_obj.y = torch.tensor([-1])

    loaded_graphs.append(data_obj)

print(f"Loaded {len(loaded_graphs)} graphs.")

# Sanity check feature dimension vs models
in_dims = {g.x.size(-1) for g in loaded_graphs}
print("Unique feature dims in sample:", in_dims)

# Create batch
data_batch = Batch.from_data_list(loaded_graphs).to(device)
print(data_batch)

@torch.no_grad()
def predict(model, batch: Batch):
    logits = model(batch)
    probs = logits.softmax(dim=-1)
    return logits, probs

# Run both models
logits_gcn, probs_gcn = predict(gcn_model, data_batch)
logits_gat, probs_gat = predict(gat_model, data_batch)

# Display predictions
print("GCN logits shape:", logits_gcn.shape, "GAT logits shape:", logits_gat.shape)

pred_indices_gcn = probs_gcn.argmax(dim=-1).tolist()
pred_indices_gat = probs_gat.argmax(dim=-1).tolist()
print(pred_indices_gat)
print("Predictions per graph (index: GCN_label(prob) | GAT_label(prob)):")
for i, (pi_gcn, pi_gat) in enumerate(zip(pred_indices_gcn, pred_indices_gat)):
    label_gcn = label_order[pi_gcn] if pi_gcn < len(label_order) else f"idx{pi_gcn}"
    label_gat = label_order[pi_gat] if pi_gat < len(label_order) else f"idx{pi_gat}"
    prob_gcn = probs_gcn[i, pi_gcn].item()
    prob_gat = probs_gat[i, pi_gat].item()
    print(f"Graph {i}: {label_gcn} ({prob_gcn:.3f}) | {label_gat} ({prob_gat:.3f})")

# Keep reference to batch + probabilities for later explanation steps
current_batch = data_batch
current_probs = { 'gcn': probs_gcn, 'gat': probs_gat }

Found 374 graph files.
Using files:
 - iv008.pt
 - iv009.pt
 - iv010.pt
 - iv011.pt
 - iv012.pt
 - iv013.pt
 - iv015.pt
 - iv016.pt
Loaded 8 graphs.
Unique feature dims in sample: {768}
DataBatch(x=[125, 768], edge_index=[2, 580], edge_attr=[580], pos=[125, 2], y=[8], batch=[125], ptr=[9])
GCN logits shape: torch.Size([8, 4]) GAT logits shape: torch.Size([8, 4])
[2, 2, 2, 2, 2, 2, 2, 2]
Predictions per graph (index: GCN_label(prob) | GAT_label(prob)):
Graph 0: Invasive (1.000) | Invasive (1.000)
Graph 1: Invasive (1.000) | Invasive (1.000)
Graph 2: Invasive (1.000) | Invasive (1.000)
Graph 3: Invasive (1.000) | Invasive (1.000)
Graph 4: Invasive (1.000) | Invasive (0.999)
Graph 5: Invasive (1.000) | Invasive (0.999)
Graph 6: Invasive (1.000) | Invasive (1.000)
Graph 7: Invasive (0.999) | Invasive (0.997)


# Define 4 batches of graphs and run GNNExplainer

We'll construct four batches using index ranges over the already sorted `all_graph_files` list:
- b1: indices 0–3
- b2: indices 101–104
- b3: indices 201–204
- b4: indices 301–304

Steps:
1. Ensure `all_graph_files` exists (rebuild if notebook restarted).
2. Load Data objects for each batch (skip if a file index is out of bounds).
3. Pick one representative graph from each batch (default = first index).
4. Predict with both models for those representative graphs.
5. Run GNNExplainer (edge + node masks) for GCN and GAT separately.
6. Visualize side‑by‑side: node coloration by node importance, edge width/color by edge importance.

You can tweak:
- `representative_choice` (choose 'first' or 'mid')
- `explainer_epochs` for fidelity vs speed
- `max_draw_nodes` to subsample large graphs for plotting only (explanations still run on full graph).

In [21]:
from math import floor
from copy import deepcopy

# Rebuild all_graph_files if not present (e.g., fresh kernel)
if 'all_graph_files' not in globals():
    import glob
    graphs_dir = 'graphs_delaunay'
    all_graph_files = sorted(glob.glob(os.path.join(graphs_dir, '*.pt')))
    print(f"Rebuilt list: {len(all_graph_files)} files")

# Batch definitions (inclusive ranges)
batch_ranges = {
    'b1': (0, 3),
    'b2': (101, 104),
    'b3': (201, 204),
    'b4': (301, 304),
}

# Explainer / visualization parameters
explainer_epochs = 40  # lowered for faster iteration; raise later if needed
representative_choice = 'first'  # 'first' or 'mid'
max_draw_nodes = 300  # if graph larger, we will still explain full but may prune for drawing

# Helper: load a single graph Data object from index
def load_graph_by_index(idx: int) -> Data:
    if idx < 0 or idx >= len(all_graph_files):
        raise IndexError(f"Index {idx} out of bounds for {len(all_graph_files)} files")
    path = all_graph_files[idx]
    obj = torch.load(path, map_location='cpu', weights_only=False)  # removed weights_only
    if isinstance(obj, Data):
        data_obj = obj
    elif isinstance(obj, dict) and 'data' in obj and isinstance(obj['data'], Data):
        data_obj = obj['data']
    else:
        raise TypeError(f"Unsupported graph file at {path}")
    if getattr(data_obj, 'y', None) is None:
        data_obj.y = torch.tensor([-1])
    return data_obj

# Prepare batches
def build_batch(idx_range):
    s, e = idx_range
    valid_indices = [i for i in range(s, e+1) if i < len(all_graph_files)]
    datas = [load_graph_by_index(i) for i in valid_indices]
    batch = Batch.from_data_list(datas)
    return valid_indices, datas, batch

batches = {}
for k, rng in batch_ranges.items():
    try:
        idxs, datas, batch = build_batch(rng)
        batches[k] = {'indices': idxs, 'datas': datas, 'batch': batch}
        print(f"{k}: loaded indices {idxs[0]}..{idxs[-1]} (count={len(idxs)})")
    except Exception as ex:
        print(f"{k}: FAILED -> {ex}")

# Representative selection
def pick_representative(datas, choice='first'):
    if not datas: return None
    if choice == 'mid':
        return datas[len(datas)//2]
    return datas[0]

representatives = {k: pick_representative(v['datas'], representative_choice) for k,v in batches.items()}

# Build Explainers (reuse)
explainer_gcn = Explainer(
    model=gcn_model,
    algorithm=GNNExplainer(epochs=explainer_epochs),
    model_config=ModelConfig(task_level='graph', mode='multiclass', return_type='logits')
)
explainer_gat = Explainer(
    model=gat_model,
    algorithm=GNNExplainer(epochs=explainer_epochs),
    model_config=ModelConfig(task_level='graph', mode='multiclass', return_type='logits')
)

# Visualization helpers
import matplotlib.colors as mcolors
import numpy as np

def visualize_explanation(data: Data, explanation, title: str, max_nodes=max_draw_nodes, create_figure=False):
    data_cpu = data.cpu()
    edge_index = data_cpu.edge_index
    G = to_networkx(data_cpu, to_undirected=True)

    # Node importance
    node_mask = None
    for attr_name in ['node_mask', 'node_imp', 'node_scores']:
        if hasattr(explanation, attr_name) and getattr(explanation, attr_name) is not None:
            node_mask = getattr(explanation, attr_name)
            break
    if node_mask is None:
        node_mask = torch.ones(data_cpu.x.size(0))

    # Edge importance
    edge_mask = None
    for attr_name in ['edge_mask', 'edge_imp', 'edge_scores']:
        if hasattr(explanation, attr_name) and getattr(explanation, attr_name) is not None:
            edge_mask = getattr(explanation, attr_name)
            break
    if edge_mask is None:
        edge_mask = torch.ones(edge_index.size(1))

    node_vals = node_mask.detach().cpu().numpy()
    edge_vals = edge_mask.detach().cpu().numpy()

    # Optional pruning for drawing
    if G.number_of_nodes() > max_nodes:
        k = max_nodes
        top_idx = np.argsort(-node_vals)[:k]
        keep = set(top_idx.tolist())
        G = G.subgraph(keep).copy()
        new_edge_vals = []
        for ei, (u, v) in enumerate(edge_index.t().tolist()):
            if u in keep and v in keep and G.has_edge(u, v):
                new_edge_vals.append(edge_vals[ei])
        edge_vals = np.array(new_edge_vals) if new_edge_vals else np.array([])
        node_vals = node_vals[list(keep)]

    cmap = plt.cm.viridis
    norm = mcolors.Normalize(vmin=node_vals.min() if len(node_vals) else 0.0,
                             vmax=node_vals.max() if len(node_vals) else 1.0)
    pos = nx.spring_layout(G, seed=42)

    if create_figure:
        plt.figure(figsize=(4.5,4.5))

    nx.draw_networkx_nodes(G, pos, node_size=120,
                           node_color=[cmap(norm(node_vals[n])) if n < len(node_mask) else (0.5,0.5,0.5,1.0) for n in G.nodes()])

    if len(edge_vals):
        widths = []
        colors = []
        edge_iter = list(G.edges())
        for idx_e, e in enumerate(edge_iter):
            val = edge_vals[idx_e % len(edge_vals)] if len(edge_vals) else 1.0
            widths.append(1.0 + 2.5 * (val / (edge_vals.max() + 1e-9)))
            colors.append(cmap(norm(val)))
        nx.draw_networkx_edges(G, pos, width=widths, edge_color=colors)
    else:
        nx.draw_networkx_edges(G, pos, width=1.0, edge_color='#999999')

    nx.draw_networkx_labels(G, pos, font_size=6)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, shrink=0.6)
    plt.title(title)
    plt.axis('off')

# Run explanations for representatives
rep_explanations = {}
for tag, data_obj in representatives.items():
    if data_obj is None:
        print(f"Skipping {tag}: no data")
        continue
    data_obj = data_obj.to(device)
    with torch.no_grad():
        logits_gcn = gcn_model(Batch.from_data_list([data_obj]).to(device))
        logits_gat = gat_model(Batch.from_data_list([data_obj]).to(device))
        pred_gcn = logits_gcn.softmax(-1).argmax(-1).item()
        pred_gat = logits_gat.softmax(-1).argmax(-1).item()
    print(f"Representative {tag}: GCN -> {label_order[pred_gcn]} | GAT -> {label_order[pred_gat]}")

    # Some PyG versions may not require / accept index for graph task; wrap in try
    try:
        explanation_gcn = explainer_gcn(data_obj, index=0)
    except TypeError:
        explanation_gcn = explainer_gcn(data_obj)
    try:
        explanation_gat = explainer_gat(data_obj, index=0)
    except TypeError:
        explanation_gat = explainer_gat(data_obj)

    rep_explanations[tag] = {'gcn': explanation_gcn, 'gat': explanation_gat, 'data': data_obj.cpu()}

    plt.figure(figsize=(9,4.8))
    plt.subplot(1,2,1)
    visualize_explanation(data_obj.cpu(), explanation_gcn, f"{tag} - GCN")
    plt.subplot(1,2,2)
    visualize_explanation(data_obj.cpu(), explanation_gat, f"{tag} - GAT")
    plt.suptitle(f"Representative {tag} explanations")
    plt.tight_layout()
    plt.show()

print("Done generating explanations for representatives.")

b1: loaded indices 0..3 (count=4)
b2: loaded indices 101..104 (count=4)
b3: loaded indices 201..204 (count=4)
b4: loaded indices 301..304 (count=4)


ValueError: 'multiclass' is not a valid ModelMode

# Official PyG Explainer usage on current batch predictions

Using PyG 2.5.2 Explainer API per docs:
- Define a ModelConfig with `task_level='graph'`, `mode='multiclass'`, `return_type='logits'` (our models output raw logits).
- Instantiate an `Explainer` wrapping each model with `GNNExplainer` algorithm.
- For each graph in a Batch, obtain its predicted class and call `explainer(batch, index=i, target=predicted_class)`.
- Collect node and edge masks (`explanation.node_mask`, `explanation.edge_mask`).

Below we run this on the previously built `data_batch` (from the sample file subset) and store explanations in dictionaries for later analysis/visualization.

In [24]:
# Ensure we have a batch (reuse current_batch if available)
if 'current_batch' not in globals():
    raise RuntimeError("current_batch not found. Re-run the earlier cell that builds data_batch.")

batch_for_explain = current_batch.to(device)
num_graphs = batch_for_explain.num_graphs
print(f"Explaining {num_graphs} graphs in current_batch")

# Build fresh Explainers following docs (in case previous ones differ)
model_config = ModelConfig(task_level='graph', mode='multiclass_classification', return_type='log_probs')
explainer_official_gcn = Explainer(model=gcn_model, algorithm=GNNExplainer(epochs=60), model_config=model_config)
explainer_official_gat = Explainer(model=gat_model, algorithm=GNNExplainer(epochs=60), model_config=model_config)

@torch.no_grad()
def predict_logits(model, batch):
    return model(batch)

logits_gcn_full = predict_logits(gcn_model, batch_for_explain)
logits_gat_full = predict_logits(gat_model, batch_for_explain)
probs_gcn_full = logits_gcn_full.softmax(-1)
probs_gat_full = logits_gat_full.softmax(-1)
preds_gcn_full = probs_gcn_full.argmax(-1).tolist()
preds_gat_full = probs_gat_full.argmax(-1).tolist()

print("GCN predictions:")
for i, cls in enumerate(preds_gcn_full):
    print(f"  graph {i}: {label_order[cls]} ({probs_gcn_full[i, cls].item():.3f})")
print("GAT predictions:")
for i, cls in enumerate(preds_gat_full):
    print(f"  graph {i}: {label_order[cls]} ({probs_gat_full[i, cls].item():.3f})")

# Run explanations; store results
explanations_gcn = []
explanations_gat = []
for i in range(num_graphs):
    target_gcn = preds_gcn_full[i]
    target_gat = preds_gat_full[i]
    # According to docs: explainer(batch, index=i, target=class_idx)
    exp_gcn = explainer_official_gcn(batch_for_explain, index=i, target=target_gcn)
    exp_gat = explainer_official_gat(batch_for_explain, index=i, target=target_gat)
    explanations_gcn.append(exp_gcn)
    explanations_gat.append(exp_gat)
print("Collected explanations for all graphs.")

# Simple visualization of first graph explanation for each model using previously defined helper (if present)
if 'visualize_explanation' in globals():
    from torch_geometric.data import Batch as _Batch
    # Extract single graph Data from batch (PyG Explanation has a subgraph property sometimes; we still use original)
    # We rebuild from slice indices for clarity
    first_graph_mask = (batch_for_explain.batch == 0)
    data0 = Data(x=batch_for_explain.x[first_graph_mask].cpu(),
                 edge_index=batch_for_explain.edge_index.clone().cpu())
    plt.figure(figsize=(9,4.5))
    plt.subplot(1,2,1)
    visualize_explanation(data0, explanations_gcn[0], f"Graph 0 GCN (target={label_order[preds_gcn_full[0]]})", create_figure=False)
    plt.subplot(1,2,2)
    visualize_explanation(data0, explanations_gat[0], f"Graph 0 GAT (target={label_order[preds_gat_full[0]]})", create_figure=False)
    plt.suptitle("First graph explanations (official API)")
    plt.tight_layout()
    plt.show()
else:
    print("visualize_explanation helper not found; skipping quick plot.")

# Keep references for later analysis
official_explanations = { 'gcn': explanations_gcn, 'gat': explanations_gat }
print("Done.")

Explaining 8 graphs in current_batch


TypeError: Explainer.__init__() missing 1 required positional argument: 'explanation_type'