In [1]:
# # -*- coding: utf-8 -*-
# # ===============================================================
# # One-cell notebook: EDUQGC scaling study + GraphLIME explanations
# # ===============================================================

# import os, sys, time, pickle, random, math, itertools, json
# from pathlib import Path
# from datetime import datetime

# import numpy as np
# import torch
# import torch.nn.functional as F
# from torch import nn
# import matplotlib.pyplot as plt
# from sklearn.decomposition import PCA
# from sklearn.linear_model import Ridge

# import pennylane as qml

# # Torch Geometric
# try:
#     from torch_geometric.datasets import Planetoid
#     import torch_geometric.utils as pyg_utils
# except Exception as e:
#     raise RuntimeError("Need torch-geometric installed.")

# # =====================================================
# # Utils for determinism
# # =====================================================
# def set_seed(seed=42):
#     os.environ["PYTHONHASHSEED"] = str(seed)
#     random.seed(seed); np.random.seed(seed)
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)
#     torch.use_deterministic_algorithms(False)

# # =====================================================
# # EDUQGC Model
# # =====================================================
# class EDUQGC(nn.Module):
#     def __init__(self, n_nodes, in_feats, T=2, seed=0, use_gpu_qnode=True, use_feat_skip=True, num_classes=3):
#         super().__init__()
#         self.n_nodes = n_nodes
#         self.T = T
#         self.use_feat_skip = use_feat_skip

#         self.enc_W = nn.Parameter(torch.randn(T, 2, in_feats) * 0.08)
#         self.enc_b = nn.Parameter(torch.randn(T, 2) * 0.02)

#         self.edge_phase  = nn.Parameter(torch.randn(T) * 0.08)
#         self.pre_theta   = nn.Parameter(torch.randn(T) * 0.08)
#         self.pre_psi     = nn.Parameter(torch.randn(T) * 0.08)
#         self.post_theta  = nn.Parameter(torch.randn(T) * 0.08)
#         self.post_psi    = nn.Parameter(torch.randn(T) * 0.08)

#         readin_dim = 1 + in_feats if use_feat_skip else 1
#         self.readout = nn.Linear(readin_dim, num_classes)

#         use_cuda = torch.cuda.is_available()
#         qdev_name = "lightning.gpu" if (use_gpu_qnode and use_cuda) else "default.qubit"
#         self.dev = qml.device(qdev_name, wires=n_nodes, shots=None)

#         @qml.qnode(self.dev, interface="torch", diff_method="best")
#         def circuit(edge_index, X, enc_W, enc_b,
#                     edge_phase, pre_theta, pre_psi, post_theta, post_psi):
#             for t in range(self.T):
#                 enc_out = X @ enc_W[t].T + enc_b[t]
#                 alphas = enc_out[:, 0]; betas = enc_out[:, 1]
#                 for i in range(self.n_nodes):
#                     qml.RX(alphas[i], wires=i)
#                     qml.RY(betas[i], wires=i)

#                 for i in range(self.n_nodes):
#                     qml.RZ(pre_psi[t], wires=i)
#                     qml.RX(pre_theta[t], wires=i)

#                 E = edge_index.shape[1]
#                 for e in range(E):
#                     u = int(edge_index[0, e].item()); v = int(edge_index[1, e].item())
#                     if u != v:
#                         qml.ControlledPhaseShift(edge_phase[t], wires=[u, v])

#                 for i in range(self.n_nodes):
#                     qml.RZ(post_psi[t], wires=i)
#                     qml.RX(post_theta[t], wires=i)

#             return [qml.expval(qml.Z(i)) for i in range(self.n_nodes)]

#         self._circuit = circuit

#     def forward(self, edge_index_torch, x_torch):
#         device = next(self.parameters()).device
#         edge_index = edge_index_torch.to(device)
#         X = x_torch.to(device).float()

#         out = self._circuit(edge_index, X,
#                             self.enc_W, self.enc_b,
#                             self.edge_phase,
#                             self.pre_theta, self.pre_psi,
#                             self.post_theta, self.post_psi)

#         expvals = torch.stack(out, dim=0).float().to(device)
#         if expvals.dim() == 1:
#             expvals = expvals.unsqueeze(1)
#         elif expvals.dim() == 2 and expvals.shape[1] == 1:
#             pass
#         else:
#             expvals = expvals.squeeze(-1).unsqueeze(1)

#         readin = torch.cat([expvals, X], dim=1) if self.use_feat_skip else expvals
#         logits = self.readout(readin)
#         return logits

# # =====================================================
# # Dataset builder
# # =====================================================
# def build_cora_ego_dataset(num_graphs=20, target_nodes=20, pca_dim=10,
#                            chosen_labels=(0,1,2), seed=42, save_path=None):
#     rng = np.random.default_rng(seed)
#     dataset = Planetoid(root="./cora_multi_exp/raw", name="Cora")
#     data = dataset[0]
#     x_all = data.x.numpy().astype(np.float32)
#     y_all = data.y.numpy().astype(np.int64)
#     edge_index_all = data.edge_index.numpy()

#     chosen_labels_set = set(chosen_labels)
#     eligible = np.array([i for i,y in enumerate(y_all) if y in chosen_labels_set])
#     pca = PCA(n_components=pca_dim, random_state=seed)
#     pca.fit(x_all[eligible])

#     def pick_subgraph(center):
#         nodes = pyg_utils.k_hop_subgraph(
#             torch.tensor([center], dtype=torch.long),
#             2,
#             torch.as_tensor(edge_index_all, dtype=torch.long),
#             relabel_nodes=False
#         )[0].numpy()

#         nodes = np.array([n for n in nodes if y_all[n] in chosen_labels_set])
#         unique = set(nodes.tolist())
#         if len(unique) < target_nodes:
#             extra = [n for n in eligible if n not in unique]
#             rng.shuffle(extra)
#             nodes = list(unique) + extra[:target_nodes-len(unique)]
#         nodes = np.array(nodes[:target_nodes])

#         relabel = {old:i for i,old in enumerate(nodes)}
#         mask = [(u in relabel and v in relabel)
#                 for u,v in zip(edge_index_all[0], edge_index_all[1])]
#         sub_edges = edge_index_all[:,mask]
#         if sub_edges.size > 0:
#             remapped = np.array([[relabel[int(u)], relabel[int(v)]]
#                                  for u,v in sub_edges.T]).T
#         else:
#             remapped = np.zeros((2,0),dtype=int)
#         edges = set()
#         for u,v in remapped.T:
#             edges.add((u,v)); edges.add((v,u))
#         if len(edges)==0:
#             edge_index = np.vstack([np.arange(target_nodes), np.arange(target_nodes)])
#         else:
#             u_list,v_list = zip(*edges)
#             edge_index = np.vstack([u_list,v_list])
#         X = pca.transform(x_all[nodes])
#         y = np.array([ {old:i for i,old in enumerate(chosen_labels)}[int(lbl)]
#                        for lbl in y_all[nodes]])
#         return dict(edge_index=edge_index, X=X.astype(np.float32), y=y)

#     rng.shuffle(eligible)
#     centers = eligible[:num_graphs]
#     graphs = [pick_subgraph(c) for c in centers]

#     if save_path:
#         Path(save_path).parent.mkdir(parents=True, exist_ok=True)
#         with open(save_path,"wb") as f:
#             pickle.dump(dict(graphs=graphs,meta=dict(N=target_nodes,F=pca_dim,seed=seed)),f)
#     return graphs

# # =====================================================
# # Train/Eval helpers
# # =====================================================
# def train_epoch(model, graphs, opt, device="cpu"):
#     model.train(); total_loss=0; nodes=0
#     for g in graphs:
#         e = torch.from_numpy(g["edge_index"]).long().to(device)
#         X = torch.from_numpy(g["X"]).float().to(device)
#         y = torch.from_numpy(g["y"]).long().to(device)
#         logits = model(e,X)
#         loss = F.cross_entropy(logits,y)
#         opt.zero_grad(); loss.backward(); opt.step()
#         total_loss += loss.item()*X.shape[0]; nodes+=X.shape[0]
#     return total_loss/max(1,nodes)

# @torch.no_grad()
# def evaluate(model, graphs, device="cpu"):
#     model.eval(); total_loss=0; nodes=0; correct=0
#     for g in graphs:
#         e = torch.from_numpy(g["edge_index"]).long().to(device)
#         X = torch.from_numpy(g["X"]).float().to(device)
#         y = torch.from_numpy(g["y"]).long().to(device)
#         logits = model(e,X)
#         loss = F.cross_entropy(logits,y)
#         pred = logits.argmax(1)
#         correct += (pred==y).sum().item()
#         total_loss += loss.item()*X.shape[0]; nodes+=X.shape[0]
#     return total_loss/max(1,nodes), correct/max(1,nodes)

# # =====================================================
# # GraphLIME (classical explanation model)
# # =====================================================
# class GraphLIME:
#     def __init__(self, num_samples=512, sigma=0.1, hop_k=0, kernel_width=0.75,
#                  alpha_ridge=1e-2, random_state=42):
#         self.num_samples=num_samples; self.sigma=sigma; self.hop_k=hop_k
#         self.kernel_width=kernel_width; self.alpha_ridge=alpha_ridge
#         self.random_state=random_state

#     def _softmax(self, logits): return torch.softmax(logits, dim=-1)

#     def _neighbors_khop(self, edge_index_np, node, k, N):
#         if k<=0: return np.array([node],dtype=int)
#         adj=[[] for _ in range(N)]
#         for u,v in edge_index_np.T: adj[int(u)].append(int(v))
#         hopset={node}; frontier={node}
#         for _ in range(k):
#             nxt=set()
#             for u in frontier: nxt|=set(adj[u])
#             hopset|=nxt; frontier=nxt
#         return np.array(sorted(hopset),dtype=int)

#     def explain(self, model, gdict, node, device):
#         rng=np.random.default_rng(self.random_state)
#         model.eval()
#         edge_index=torch.from_numpy(gdict["edge_index"]).long().to(device)
#         X_np=gdict["X"].astype(np.float32); y_np=gdict["y"]
#         N,F=X_np.shape
#         X_base=torch.from_numpy(X_np).float().to(device)
#         with torch.no_grad():
#             base_logits=model(edge_index,X_base)
#             base_probs=self._softmax(base_logits)[node].cpu().numpy()
#             base_pred=int(np.argmax(base_probs))
#         target_class=base_pred

#         P_nodes=self._neighbors_khop(gdict["edge_index"],node,self.hop_k,N)
#         P_mask=np.zeros(N,dtype=bool); P_mask[P_nodes]=True
#         feat_std=X_np.std(axis=0)+1e-8

#         Z=np.zeros((self.num_samples,F),dtype=np.float32)
#         y=np.zeros((self.num_samples,),dtype=np.float32)
#         w=np.zeros((self.num_samples,),dtype=np.float32)

#         for s in range(self.num_samples):
#             X_pert=X_np.copy()
#             noise_all=rng.normal(0.0,self.sigma,size=(N,F)).astype(np.float32)
#             X_pert[P_mask]=X_np[P_mask]+noise_all[P_mask]*feat_std
#             X_pert_t=torch.from_numpy(X_pert).float().to(device)
#             with torch.no_grad():
#                 logits=model(edge_index,X_pert_t)
#                 prob=self._softmax(logits)[node,target_class].item()
#             Z[s]=X_pert[node]; y[s]=prob
#             d=np.linalg.norm((X_pert-X_np)[P_mask].ravel())
#             w[s]=np.exp(-(d**2)/(2*(self.kernel_width**2)+1e-12))

#         reg=Ridge(alpha=self.alpha_ridge,fit_intercept=True,random_state=self.random_state)
#         reg.fit(Z,y,sample_weight=w); coefs=reg.coef_
#         imp=coefs/(np.linalg.norm(coefs,ord=2)+1e-12)

#         return dict(node=node,target_class=target_class,base_pred=base_pred,
#                     base_probs=base_probs.tolist(),importances=imp.tolist())

#     def save_plot(self, expl, out_png, feature_names=None, topk=20):
#         imps=np.array(expl["importances"]); idx=np.argsort(np.abs(imps))[::-1]
#         if topk: idx=idx[:topk]
#         labels=[f"PC{j}" for j in range(len(imps))] if feature_names is None else feature_names
#         labels_sel=[labels[j] for j in idx]; vals=imps[idx]
#         plt.figure(figsize=(8,max(3,0.35*len(idx))))
#         plt.barh(range(len(idx)),vals); plt.yticks(range(len(idx)),labels_sel)
#         plt.gca().invert_yaxis(); plt.title(f"GraphLIME Node {expl['node']}")
#         plt.xlabel("importance"); plt.tight_layout()
#         Path(out_png).parent.mkdir(parents=True,exist_ok=True)
#         plt.savefig(out_png,dpi=200); plt.close()

# # =====================================================
# # Example: Run scaling exp + GraphLIME explain
# # =====================================================
# set_seed(42)
# device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Build dataset & train one EDUQGC
# graphs=build_cora_ego_dataset(num_graphs=5,target_nodes=20,pca_dim=10,seed=42)
# train, val, test=graphs[:3],graphs[3:4],graphs[4:]
# model=EDUQGC(n_nodes=20,in_feats=10,num_classes=3).to(device)
# opt=torch.optim.Adam(model.parameters(),lr=0.03)

# for ep in range(5):
#     tr_loss=train_epoch(model,train,opt,device)
#     val_loss,val_acc=evaluate(model,val,device)
#     print(f"Epoch {ep+1} | tr_loss={tr_loss:.3f} | val_acc={val_acc:.3f}")

# # Run GraphLIME explanation
# lime=GraphLIME(num_samples=128,sigma=0.1,hop_k=0)
# expl=lime.explain(model,graphs[0],node=0,device=device)
# print("Explanation importances:",expl["importances"][:5])
# lime.save_plot(expl,"cora_multi_exp/explanations/example_node0.png")
# print("Saved GraphLIME plot at cora_multi_exp/explanations/example_node0.png")


In [2]:
# Single-cell notebook: EDUQGC + scaling + GraphLIME + reproducibility + saving + visualization
# Run this entire cell in one notebook cell.

# ---------------------------
# ENV (must be set before torch import for deterministic cuBLAS)
# ---------------------------
import os
# Recommended config for reproducible cuBLAS behavior when using deterministic algorithms on CUDA:
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# ---------------------------
# Imports (after env above)
# ---------------------------
import sys, time, json, math, random, pickle
from pathlib import Path
from datetime import datetime
import itertools
import warnings

import numpy as np
import matplotlib.pyplot as plt

# Torch & related (import after CUBLAS var set)
import torch
import torch.nn.functional as F
from torch import nn

# PennyLane (optional GPU lightning backend)
import pennylane as qml

# PyG + sklearn (must be installed)
try:
    from torch_geometric.datasets import Planetoid
    import torch_geometric.utils as pyg_utils
except Exception as e:
    raise RuntimeError("This notebook requires torch-geometric. Install per https://pytorch-geometric.readthedocs.io/")

from sklearn.decomposition import PCA
from sklearn.linear_model import Ridge

# ---------------------------
# Reproducibility utility (per PyTorch docs)
# ---------------------------
def set_reproducibility(seed:int = 42, strict: bool = True):
    """
    Seed every RNG used by this notebook and configure determinism flags.

    Seeds Python's `random`, NumPy's legacy global RNG and torch (plus all
    CUDA devices when CUDA is available), and exports PYTHONHASHSEED.

    strict=True : try to turn on torch.use_deterministic_algorithms(True) and
                  cuDNN deterministic mode (benchmark off). If that raises, a
                  warning is issued and the relaxed settings below are applied.
    strict=False: relaxed settings -- deterministic algorithms off, cuDNN
                  deterministic off, cuDNN benchmark on.

    Note: CUBLAS_WORKSPACE_CONFIG is read when the process starts using CUDA,
    so changing it requires a kernel restart.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    def _relaxed_mode():
        # Shared by the non-strict path and the strict-mode fallback.
        torch.use_deterministic_algorithms(False)
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True

    if not strict:
        _relaxed_mode()
        print("[repro] Best-effort reproducibility (non-strict).")
        return

    try:
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        print("[repro] Strict deterministic mode enabled.")
    except Exception as e:
        warnings.warn(
            "Could not enable strict deterministic algorithms. "
            "If using CUDA ensure CUBLAS_WORKSPACE_CONFIG was set before Python start. "
            f"Exception: {e}"
        )
        # Could not go strict: fall back to the relaxed configuration.
        _relaxed_mode()

# ---------------------------
# EDUQGC model (unchanged semantics; parameterizable input dims & classes)
# ---------------------------
class EDUQGCNodeClassifier(nn.Module):
    """
    Node classifier built on a parameterized circuit with one wire per graph
    node; each wire's Pauli-Z expectation value feeds a linear readout.

    Every one of the T layers applies, in this order:
      1. per-node RX/RY rotations whose angles are an affine map of the
         node's feature vector (enc_W[t], enc_b[t]),
      2. a shared RZ(pre_psi[t]) / RX(pre_theta[t]) on every wire,
      3. ControlledPhaseShift(edge_phase[t]) for every column of edge_index
         that is not a self-loop,
      4. a shared RZ(post_psi[t]) / RX(post_theta[t]) on every wire.

    Note: `seed` is accepted but not used by this class; parameters are drawn
    from torch's global RNG.
    """

    def __init__(self, n_nodes, in_feats, T=2, seed=0, use_gpu_qnode=True, use_feat_skip=True, num_classes=3):
        super().__init__()
        self.n_nodes = n_nodes
        self.T = T
        self.use_feat_skip = use_feat_skip

        # Feature encoders: one (2 x in_feats) weight and one 2-vector bias per layer.
        self.enc_W = nn.Parameter(torch.randn(T, 2, in_feats) * 0.08)
        self.enc_b = nn.Parameter(torch.randn(T, 2) * 0.02)

        # One scalar per layer for each shared angle. The order below fixes both
        # the RNG draw order and the parameter registration order.
        for angle_name in ("edge_phase", "pre_theta", "pre_psi", "post_theta", "post_psi"):
            setattr(self, angle_name, nn.Parameter(torch.randn(T) * 0.08))

        # Readout sees the expectation value, optionally followed by the raw features.
        readin_dim = 1
        if use_feat_skip:
            readin_dim += in_feats
        self.readout = nn.Linear(readin_dim, num_classes)

        cuda_ok = torch.cuda.is_available()
        qdev_name = "default.qubit"
        if use_gpu_qnode and cuda_ok:
            qdev_name = "lightning.gpu"
        # If the requested backend cannot be created, warn and use default.qubit.
        try:
            self.dev = qml.device(qdev_name, wires=n_nodes, shots=None)
        except Exception as e:
            warnings.warn(f"Could not create device '{qdev_name}': {e}. Falling back to default.qubit (CPU).")
            self.dev = qml.device("default.qubit", wires=n_nodes, shots=None)

        @qml.qnode(self.dev, interface="torch", diff_method="best")
        def circuit(edge_index, X, enc_W, enc_b,
                    edge_phase, pre_theta, pre_psi, post_theta, post_psi):
            # edge_index: [2, E], X: [N, F]

            def shared_rotation(psi, theta):
                # Same RZ-then-RX pair on every wire.
                for wire in range(self.n_nodes):
                    qml.RZ(psi, wires=wire)
                    qml.RX(theta, wires=wire)

            for layer in range(self.T):
                angles = X @ enc_W[layer].T + enc_b[layer]  # [N, 2]
                rx_angles = angles[:, 0]
                ry_angles = angles[:, 1]
                # data-encoding rotations
                for wire in range(self.n_nodes):
                    qml.RX(rx_angles[wire], wires=wire)
                    qml.RY(ry_angles[wire], wires=wire)

                shared_rotation(pre_psi[layer], pre_theta[layer])

                # entanglers: one controlled phase per listed edge, self-loops skipped
                n_edges = edge_index.shape[1] if edge_index.ndim == 2 else 0
                for col in range(n_edges):
                    src = int(edge_index[0, col].item())
                    dst = int(edge_index[1, col].item())
                    if src == dst:
                        continue
                    qml.ControlledPhaseShift(edge_phase[layer], wires=[src, dst])

                shared_rotation(post_psi[layer], post_theta[layer])

            return [qml.expval(qml.Z(wire)) for wire in range(self.n_nodes)]

        self._circuit = circuit

    def forward(self, edge_index_torch, x_torch):
        """
        Run the circuit on one graph and return the readout logits.

        Inputs are moved to the device of the module's parameters. The per-wire
        expectation values are arranged as a single column and, when
        use_feat_skip is set, concatenated with the node features before the
        linear readout.
        """
        target_device = next(self.parameters()).device
        edges = edge_index_torch.to(target_device)
        feats = x_torch.to(target_device).float()

        measured = self._circuit(
            edges, feats,
            self.enc_W, self.enc_b,
            self.edge_phase,
            self.pre_theta, self.pre_psi,
            self.post_theta, self.post_psi,
        )
        expvals = torch.stack(measured, dim=0).float().to(target_device)

        # Normalize to one column per node; an existing [*, 1] layout is kept as is.
        if expvals.dim() == 1:
            expvals = expvals.unsqueeze(1)
        elif not (expvals.dim() == 2 and expvals.shape[1] == 1):
            expvals = expvals.squeeze(-1).unsqueeze(1)

        if self.use_feat_skip:
            readin = torch.cat([expvals, feats], dim=1)
        else:
            readin = expvals
        return self.readout(readin)

# ---------------------------
# Dataset builder: Cora -> k-hop ego graphs -> PCA -> choose 3 labels
# Returns list of dicts {edge_index: [2,E], X: [N,F], y: [N]}
# ---------------------------
def build_cora_ego_dataset(num_graphs=20, chosen_labels=(0,1,2), n_hops=2,
                           target_nodes=20, pca_dim=10, seed=42, save_path=None):
    """
    Build fixed-size subgraphs of Cora restricted to `chosen_labels`.

    Cora is loaded through Planetoid (root "./cora_data"). A PCA with
    `pca_dim` components is fitted on the features of all nodes whose label is
    in `chosen_labels`. For each of `num_graphs` randomly chosen center nodes,
    the `n_hops` neighbourhood is taken, filtered to the chosen labels, padded
    with random eligible nodes or randomly trimmed to exactly `target_nodes`
    nodes, and the induced edges are symmetrized and deduplicated.

    Returns:
        list of dicts with keys
          edge_index: int64 array [2, E] (self-loops on every node if the
                      subgraph has no edges),
          X:          float32 array [target_nodes, pca_dim],
          y:          int64 array [target_nodes], labels re-indexed by their
                      position in `chosen_labels`.

    Raises:
        RuntimeError: no node carries any of `chosen_labels`.
        ValueError:   fewer than `target_nodes` eligible nodes exist, so a
                      subgraph cannot be padded to the requested size.

    If `save_path` is given, the graphs and the build parameters are pickled
    there (parent directories are created).
    """
    rng = np.random.default_rng(seed)
    dataset = Planetoid(root="./cora_data", name="Cora")
    data = dataset[0]
    x_all = data.x.numpy().astype(np.float32)
    y_all = data.y.numpy().astype(np.int64)
    edge_index_all = data.edge_index.numpy()
    # Loop-invariant: build the tensor view of the full edge list once.
    edge_index_all_t = torch.as_tensor(edge_index_all, dtype=torch.long)

    chosen_set = set(chosen_labels)
    eligible = np.array([i for i in range(len(y_all)) if y_all[i] in chosen_set], dtype=np.int64)
    if len(eligible) == 0:
        raise RuntimeError("No eligible nodes for chosen_labels.")

    # Fit PCA on eligible nodes features for stable transform
    pca = PCA(n_components=pca_dim, random_state=seed)
    pca.fit(x_all[eligible])

    # Loop-invariant: original label -> 0..len(chosen_labels)-1
    map_lbl = {old: new for new, old in enumerate(chosen_labels)}

    def pick_subgraph(center_id):
        # k-hop neighbourhood of the center, in original node ids
        hood = pyg_utils.k_hop_subgraph(torch.tensor([center_id], dtype=torch.long), n_hops,
                                        edge_index_all_t, relabel_nodes=False)[0].numpy()
        # keep only nodes with chosen labels; preserve order, drop duplicates
        unique = list(dict.fromkeys(int(n) for n in hood if y_all[n] in chosen_set))

        # pad if necessary (set membership keeps this linear in len(eligible))
        if len(unique) < target_nodes:
            present = set(unique)
            extras = [n for n in eligible if n not in present]
            rng.shuffle(extras)
            unique = unique + extras[: target_nodes - len(unique)]
        # trim to exact size
        if len(unique) > target_nodes:
            rng.shuffle(unique)
            unique = unique[:target_nodes]
        if len(unique) != target_nodes:
            # Explicit error instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"Cannot build a subgraph with target_nodes={target_nodes}: "
                f"only {len(unique)} eligible nodes are available."
            )
        nodes = np.array(unique, dtype=np.int64)

        # induced edges: both endpoints must be selected nodes
        mask = np.isin(edge_index_all[0], nodes) & np.isin(edge_index_all[1], nodes)
        sub_edges = edge_index_all[:, mask]
        if sub_edges.shape[1] > 0:
            relabel = {int(old): new for new, old in enumerate(nodes)}
            # make undirected pairs and deduplicate
            edges = set()
            for u, v in sub_edges.T:
                a, b = relabel[int(u)], relabel[int(v)]
                edges.add((a, b)); edges.add((b, a))
            u_list, v_list = zip(*sorted(edges))
            edge_index = np.vstack([u_list, v_list]).astype(np.int64)
        else:
            # No edges => self-loops to keep circuit valid
            edge_index = np.vstack([np.arange(target_nodes), np.arange(target_nodes)]).astype(np.int64)

        # PCA compress
        X_sub = pca.transform(x_all[nodes]).astype(np.float32)
        # relabel y to 0..len(chosen_labels)-1
        y_sub = np.array([map_lbl[int(lbl)] for lbl in y_all[nodes]], dtype=np.int64)
        return dict(edge_index=edge_index, X=X_sub, y=y_sub)

    # Centers are the first num_graphs entries of a seeded shuffle of the
    # eligible nodes (cycled via np.resize if there are too few).
    rng.shuffle(eligible)
    centers = eligible[:num_graphs] if len(eligible) >= num_graphs else np.resize(eligible, num_graphs)
    graphs = [pick_subgraph(int(c)) for c in centers]

    if save_path:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        with open(save_path, "wb") as f:
            pickle.dump(dict(graphs=graphs, meta=dict(dataset="Cora", chosen_labels=tuple(chosen_labels),
                                                       n_hops=n_hops, target_nodes=target_nodes,
                                                       pca_dim=pca_dim, seed=seed)), f)
        print(f"[dataset saved] {Path(save_path).resolve()}  (graphs={len(graphs)})")

    return graphs

# ---------------------------
# GraphLIME: local explanations by perturbing node features
# - For a target node, perturb its features, keep other nodes same.
# - Query model probabilities for the predicted class, weight by exponential kernel on L2,
# - Fit a weighted Ridge linear model: perturbed_features -> model_score
# - Return coefficients as importances (signed).
# ---------------------------
def explain_node_graphlime(model, graph, node_idx, device='cpu', num_samples=500, sigma=0.1, alpha=1.0, random_state=0):
    """
    Local linear explanation of the model's prediction for one node.

    The target node's feature vector is perturbed with Gaussian noise
    (std `sigma`) while all other nodes keep their features. For every
    perturbed sample the model is queried and the probability it assigns to
    the target class is recorded, where the target class is the class the
    model predicts for the node on the unperturbed graph. A Ridge regression
    (penalty `alpha`) is then fitted from perturbed features to that
    probability, weighted by exp(-d^2 / (2 * sigma^2)) with d the L2 distance
    to the original feature vector.

    Args:
        model: callable as model(edge_index, X) returning per-node logits.
               It is queried under torch.no_grad(); its train/eval mode is
               not changed here, so put it in eval mode beforehand if needed.
        graph: dict with numpy arrays 'edge_index' and 'X'.
        node_idx: row index into graph['X'] of the node to explain.
        device: torch device the inputs are moved to.
        num_samples: number of perturbed samples (one model call each).
        sigma: noise std and kernel width; must be > 0.
        alpha: Ridge regularization strength.
        random_state: seed for the perturbation RNG.

    Returns:
        dict with keys
          coeffs (signed per-feature coefficients), intercept,
          samples (perturbed feature vectors), scores (target-class
          probabilities), weights (kernel weights), x0 (original features),
          target_class (class index being explained).

    Raises:
        ValueError: if sigma is not strictly positive.
    """
    if sigma <= 0:
        # sigma == 0 would give 0/0 kernel weights; fail before querying the model.
        raise ValueError(f"sigma must be > 0, got {sigma}")

    rng = np.random.default_rng(random_state)
    edge_index = torch.from_numpy(graph['edge_index']).long().to(device)
    X_all = graph['X'].astype(np.float32).copy()
    n_feats = X_all.shape[1]  # local name avoids shadowing the module-level `F`

    # baseline feature vector for node
    x0 = X_all[node_idx].copy()

    # sample perturbations around x0 (normal noise) in the already-compressed space
    noise = rng.normal(scale=sigma, size=(num_samples, n_feats)).astype(np.float32)
    samples = x0[None, :] + noise  # [S, n_feats]

    def _node_probs(features):
        # Class probabilities of the explained node for one full feature matrix.
        Xt = torch.from_numpy(features).float().to(device)
        logits = model(edge_index, Xt)
        return torch.softmax(logits[node_idx], dim=0)

    scores = np.zeros(num_samples, dtype=np.float32)
    # No gradients are needed for explanation queries.
    with torch.no_grad():
        # Fix the explained class once, from the prediction on the original
        # features, so every sample is scored against the same class.
        target_class = int(_node_probs(X_all).argmax().item())

        # Sequential calls: only node_idx's row changes between queries.
        for i in range(num_samples):
            Xp = X_all.copy()
            Xp[node_idx] = samples[i]
            scores[i] = _node_probs(Xp)[target_class].item()

    # kernel weights from L2 distance between each sample and x0
    dists = np.linalg.norm(samples - x0[None, :], axis=1)
    weights = np.exp(-(dists ** 2) / (2 * (sigma ** 2)))

    # fit weighted linear model: samples -> scores
    clf = Ridge(alpha=alpha)
    clf.fit(samples, scores, sample_weight=weights)
    coeffs = clf.coef_.astype(float)   # shape (n_feats,)
    intercept = float(clf.intercept_)

    return dict(coeffs=coeffs, intercept=intercept, samples=samples, scores=scores,
                weights=weights, x0=x0, target_class=target_class)

# ---------------------------
# Train / evaluate helpers
# ---------------------------
def train_epoch(model, graphs, optimizer, device="cpu"):
    """
    Run one optimization pass over `graphs`, taking one optimizer step per graph.

    Each graph is a dict with numpy arrays 'edge_index', 'X' and 'y'. The
    model is switched to train mode. Returns the cross-entropy loss averaged
    over all nodes seen (each graph's mean loss weighted by its node count);
    0.0 when `graphs` is empty.
    """
    model.train()
    loss_sum = 0.0
    node_count = 0
    for graph in graphs:
        edges = torch.from_numpy(graph['edge_index']).long().to(device)
        feats = torch.from_numpy(graph['X']).float().to(device)
        labels = torch.from_numpy(graph['y']).long().to(device)

        graph_loss = F.cross_entropy(model(edges, feats), labels)

        optimizer.zero_grad()
        graph_loss.backward()
        optimizer.step()

        n_here = feats.shape[0]
        loss_sum += graph_loss.item() * n_here
        node_count += n_here

    denom = node_count if node_count > 0 else 1
    return loss_sum / denom

@torch.no_grad()
def evaluate(model, graphs, device="cpu"):
    """
    Score the model on `graphs` without tracking gradients.

    Each graph is a dict with numpy arrays 'edge_index', 'X' and 'y'. The
    model is switched to eval mode. Returns (loss, accuracy): cross-entropy
    averaged over all nodes and the fraction of nodes whose arg-max logit
    equals the label. Both are 0.0 when `graphs` is empty.
    """
    model.eval()
    loss_sum = 0.0
    hits = 0
    node_count = 0
    for graph in graphs:
        edges = torch.from_numpy(graph['edge_index']).long().to(device)
        feats = torch.from_numpy(graph['X']).float().to(device)
        labels = torch.from_numpy(graph['y']).long().to(device)

        logits = model(edges, feats)
        graph_loss = F.cross_entropy(logits, labels)
        hits += (logits.argmax(dim=1) == labels).sum().item()

        n_here = feats.shape[0]
        loss_sum += graph_loss.item() * n_here
        node_count += n_here

    denom = node_count if node_count > 0 else 1
    return loss_sum / denom, hits / denom

# ---------------------------
# Scaling experiment runner + GraphLIME explanation aggregation
# ---------------------------
def _save_graphlime_bar_plot(coeffs, title, png_path, top_k=10):
    """Save a horizontal bar chart of the largest-magnitude coefficients.

    Args:
        coeffs: 1-D NumPy array of surrogate-model coefficients.
        title: Figure title.
        png_path: Destination image path.
        top_k: Maximum number of coefficients to draw.
    """
    k = min(top_k, len(coeffs))
    # Indices ordered by descending |coefficient|.
    idxs = np.argsort(np.abs(coeffs))[::-1][:k]
    labels = [f"PC{i}" for i in idxs]
    vals = coeffs[idxs]
    fig = plt.figure(figsize=(8, 4))
    try:
        y_pos = np.arange(len(labels))
        # Reversed so the largest-magnitude coefficient ends up as the top bar.
        plt.barh(y_pos, vals[::-1])
        plt.yticks(y_pos, labels[::-1])
        plt.xlabel("importance")
        plt.title(title)
        plt.tight_layout()
        plt.savefig(png_path, dpi=200)
    finally:
        # Many figures are produced per sweep; always release them.
        plt.close(fig)


def _plot_scaling_heatmaps(summary, fig_path):
    """Draw runtime and test-accuracy heatmaps over the (N, F) grid and save them.

    Args:
        summary: List of per-run dicts with keys ``'N'``, ``'F'``,
            ``'runtime_sec'`` and ``'test_acc'``.
        fig_path: Destination image path.
    """
    node_sizes = sorted({r['N'] for r in summary})
    feat_dims = sorted({r['F'] for r in summary})
    row_of = {n: i for i, n in enumerate(node_sizes)}
    col_of = {d: j for j, d in enumerate(feat_dims)}

    runtime_mat = np.zeros((len(node_sizes), len(feat_dims)))
    testacc_mat = np.zeros((len(node_sizes), len(feat_dims)))
    for r in summary:
        i, j = row_of[r['N']], col_of[r['F']]
        runtime_mat[i, j] = r['runtime_sec']
        testacc_mat[i, j] = r['test_acc']

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    panels = (
        (axes[0], runtime_mat, "Runtime (sec) per N,F", {}),
        # Accuracy is a fraction, so pin the colour scale to [0, 1].
        (axes[1], testacc_mat, "Test accuracy per N,F", dict(vmin=0, vmax=1)),
    )
    for ax, mat, title, extra in panels:
        im = ax.imshow(mat, aspect='auto', **extra)
        ax.set_xticks(range(len(feat_dims)))
        ax.set_yticks(range(len(node_sizes)))
        ax.set_xticklabels(feat_dims)
        ax.set_yticklabels(node_sizes)
        ax.set_xlabel("F (PCA dim)")
        ax.set_ylabel("N (nodes)")
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.savefig(fig_path, dpi=200)
    plt.show()
    plt.close(fig)


def run_full_experiment(
    N_list = (20, 25, 30, 35),
    F_list = (10, 15, 20, 25),
    num_graphs = 20,
    epochs = 10,
    lr = 0.03,
    seed = 42,
    base_dir = "./cora_multi_exp",
    use_gpu_qnode = True,
    graphlime_nodes = 10,   # number of nodes to explain per dataset (for aggregation)
    graphlime_samples = 400
):
    """Train and evaluate one model per (N, F) combination and explain test nodes.

    For every pair in ``itertools.product(N_list, F_list)`` this loads (or
    builds) a dataset, trains an ``EDUQGCNodeClassifier`` for ``epochs`` epochs,
    restores the weights with the best validation accuracy, evaluates on the
    train/val/test splits (60/20/20 by graph order), saves the model and a
    per-run JSON log, and writes GraphLIME bar charts for the first
    ``graphlime_nodes`` nodes of the first test graph. After all runs it writes
    a summary JSON, a runtime/accuracy heatmap figure and an aggregated
    explanations JSON.

    Args:
        N_list: Node counts to sweep (passed as ``target_nodes`` / ``n_nodes``).
        F_list: Feature dimensions to sweep (passed as ``pca_dim`` / ``in_feats``).
        num_graphs: Number of graphs requested when a dataset must be built.
        epochs: Training epochs per run.
        lr: Adam learning rate.
        seed: Seed used for reproducibility setup, dataset naming and the model.
        base_dir: Root directory; ``datasets/``, ``models/``, ``explanations/``
            and ``logs/`` are created beneath it.
        use_gpu_qnode: Forwarded unchanged to ``EDUQGCNodeClassifier``.
        graphlime_nodes: Maximum number of nodes explained per run.
        graphlime_samples: ``num_samples`` forwarded to ``explain_node_graphlime``.

    Returns:
        Dict with keys ``summary`` (list of per-run dicts), ``timings`` (a
        separate list holding the same per-run dicts), ``explanations``
        (run tag -> list of per-node dicts), ``plot_path`` and ``summary_path``.
    """
    set_reproducibility(seed, strict=False)  # strict can be toggled
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    base = Path(base_dir)
    base.mkdir(parents=True, exist_ok=True)

    summary = []
    explanations_summary = {}

    # Loop names deliberately avoid ``F`` so the file-level
    # ``torch.nn.functional as F`` alias is not shadowed in this scope.
    for n_nodes, n_feats in itertools.product(N_list, F_list):
        run_tag = f"N{n_nodes}_F{n_feats}_seed{seed}"
        print(f"\n[RUN] {run_tag}  device={device}")

        ds_path = base/"datasets"/f"cora_N{n_nodes}_F{n_feats}_seed{seed}.pkl"
        ds_path.parent.mkdir(parents=True, exist_ok=True)
        if ds_path.exists():
            # NOTE: unpickling can execute arbitrary code; base_dir must only
            # contain cache files produced by this pipeline.
            with open(ds_path, "rb") as f:
                payload = pickle.load(f)
            graphs = payload['graphs']
            print(f"Loaded dataset from {ds_path} (graphs={len(graphs)})")
        else:
            graphs = build_cora_ego_dataset(num_graphs=num_graphs, chosen_labels=(0,1,2),
                                           n_hops=2, target_nodes=n_nodes, pca_dim=n_feats,
                                           seed=seed, save_path=str(ds_path))
        # Split
        n = len(graphs)
        ntr = int(0.6 * n); nval = int(0.2 * n)
        train_graphs = graphs[:ntr]
        val_graphs = graphs[ntr:ntr+nval]
        test_graphs = graphs[ntr+nval:]

        # model
        model = EDUQGCNodeClassifier(n_nodes=n_nodes, in_feats=n_feats, T=2, seed=seed,
                                     use_gpu_qnode=use_gpu_qnode, num_classes=3).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=lr)

        start_time = time.time()
        history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
        best_val_acc = -1.0; best_state = None
        for ep in range(1, epochs+1):
            tr_loss = train_epoch(model, train_graphs, opt, device=device)
            val_loss, val_acc = evaluate(model, val_graphs, device=device)
            history['train_loss'].append(tr_loss)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)
            print(f"Epoch {ep:02d} | tr_loss={tr_loss:.3f} | val_loss={val_loss:.3f} | val_acc={val_acc:.3f}")
            # With an empty validation split evaluate() reports 0.0 every
            # epoch, which would freeze the "best" checkpoint at epoch 1 and
            # discard all later training. Only checkpoint on real val data.
            if val_graphs and val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}
        run_time = time.time() - start_time

        # load best
        if best_state is not None:
            model.load_state_dict(best_state)

        # eval final
        tr_loss, tr_acc = evaluate(model, train_graphs, device=device)
        va_loss, va_acc = evaluate(model, val_graphs, device=device)
        te_loss, te_acc = evaluate(model, test_graphs, device=device)

        # save model
        model_dir = base/"models"
        model_dir.mkdir(exist_ok=True, parents=True)
        model_path = model_dir / f"eduqgc_{run_tag}.pt"
        torch.save(model.state_dict(), model_path)

        # save run info
        run_info = dict(tag=run_tag, N=n_nodes, F=n_feats, seed=seed,
                        tr_loss=tr_loss, tr_acc=tr_acc,
                        val_loss=va_loss, val_acc=va_acc,
                        test_loss=te_loss, test_acc=te_acc,
                        runtime_sec=run_time, model_path=str(model_path), dataset_path=str(ds_path))
        summary.append(run_info)

        # GraphLIME explanations (first test graph only)
        expl_dir = base/"explanations"/run_tag
        expl_dir.mkdir(parents=True, exist_ok=True)
        explanations_summary[run_tag] = []
        if len(test_graphs) > 0:
            gtest = test_graphs[0]
            for node_idx in range(min(graphlime_nodes, gtest['X'].shape[0])):
                out = explain_node_graphlime(model, gtest, node_idx, device=str(device),
                                             num_samples=graphlime_samples, sigma=0.2,
                                             alpha=1.0, random_state=seed+node_idx)
                coeffs = out['coeffs']
                explanations_summary[run_tag].append(
                    dict(node=node_idx, coeffs=coeffs.tolist(), intercept=out['intercept']))
                _save_graphlime_bar_plot(coeffs, f"GraphLIME {run_tag} node {node_idx}",
                                         expl_dir / f"glime_node{node_idx}.png")

        # --- ensure logs dir exists ---
        log_dir = base/"logs"
        log_dir.mkdir(parents=True, exist_ok=True)

        # flush to disk per run so an interrupted sweep keeps finished results
        with open(log_dir/f"{run_tag}.json", "w") as f:
            json.dump(run_info, f, indent=2)
        with open(base/"explanations"/f"{run_tag}_explanations.json", "w") as f:
            json.dump(explanations_summary.get(run_tag, []), f, indent=2)

    # One timestamp for both artefacts so the summary JSON and its plot
    # always share the same suffix.
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')

    # Save summary table
    out_summary = base/"logs"/f"scaling_summary_{stamp}.json"
    out_summary.parent.mkdir(parents=True, exist_ok=True)
    with open(out_summary, "w") as f:
        json.dump(summary, f, indent=2)
    print(f"\nAll runs finished. Summary saved at {out_summary}")

    # Plot runtime heatmap and accuracy heatmap aggregated across combos
    fig_path = base/"logs"/f"scaling_summary_plot_{stamp}.png"
    _plot_scaling_heatmaps(summary, fig_path)
    print(f"Saved plot: {fig_path}")

    # Save explanations summary
    expl_sum_path = base/"explanations"/"explanations_summary.json"
    expl_sum_path.parent.mkdir(parents=True, exist_ok=True)
    with open(expl_sum_path, "w") as f:
        json.dump(explanations_summary, f, indent=2)

    # ``timings`` has always mirrored ``summary``; keep the key for callers.
    return dict(summary=summary, timings=list(summary),
                explanations=explanations_summary,
                plot_path=str(fig_path),
                summary_path=str(out_summary))


# ---------------------------
# Example: run everything with small defaults
# ---------------------------
if __name__ == "__main__":
    # Demo-sized sweep so the cell finishes quickly.
    # For the full study use N_list = [20,25,30,35] and F_list = [10,15,20,25].
    N_list, F_list = [20,25], [10,15]
    NUM_GRAPHS, EPOCHS = 20, 10
    LR, SEED = 0.03, 42
    BASE_DIR = "./cora_multi_exp"

    # Fewer explained nodes / perturbation samples than the function defaults,
    # again to keep the demo short.
    results = run_full_experiment(
        N_list=N_list, F_list=F_list,
        num_graphs=NUM_GRAPHS, epochs=EPOCHS, lr=LR, seed=SEED,
        base_dir=BASE_DIR,
        use_gpu_qnode=True,
        graphlime_nodes=5, graphlime_samples=300,
    )
    print("Done. Keys in results:", list(results.keys()))


[repro] Best-effort reproducibility (non-strict).

[RUN] N20_F10_seed42  device=cuda
Loaded dataset from cora_multi_exp/datasets/cora_N20_F10_seed42.pkl (graphs=20)
Epoch 01 | tr_loss=0.953 | val_loss=0.804 | val_acc=0.688
Epoch 02 | tr_loss=0.661 | val_loss=0.634 | val_acc=0.738
Epoch 03 | tr_loss=0.508 | val_loss=0.540 | val_acc=0.812
Epoch 04 | tr_loss=0.423 | val_loss=0.486 | val_acc=0.812
Epoch 05 | tr_loss=0.376 | val_loss=0.449 | val_acc=0.838
Epoch 06 | tr_loss=0.346 | val_loss=0.419 | val_acc=0.850
Epoch 07 | tr_loss=0.327 | val_loss=0.394 | val_acc=0.863
Epoch 08 | tr_loss=0.312 | val_loss=0.375 | val_acc=0.863
Epoch 09 | tr_loss=0.300 | val_loss=0.349 | val_acc=0.875
Epoch 10 | tr_loss=0.288 | val_loss=0.327 | val_acc=0.887

[RUN] N20_F15_seed42  device=cuda
Loaded dataset from cora_multi_exp/datasets/cora_N20_F15_seed42.pkl (graphs=20)
Epoch 01 | tr_loss=0.927 | val_loss=0.785 | val_acc=0.725
Epoch 02 | tr_loss=0.620 | val_loss=0.639 | val_acc=0.762
Epoch 03 | tr_loss=0.479

KeyboardInterrupt: 