In [None]:
# import math
# import random
# import numpy as np
# import torch
# import torch.nn.functional as F
# from torch import nn
# import networkx as nx
# import pennylane as qml

# # -----------------------------
# # Synthetic dataset
# # -----------------------------
# def make_synthetic_ring(N=12, noise_std=0.1, seed=0):
#     """
#     Creates an N-node ring graph.
#     First half nodes label=0, second half label=1.
#     Features = one-hot label + Gaussian noise.
#     """
#     rng = np.random.default_rng(seed)
#     G = nx.cycle_graph(N)
#     labels = np.zeros(N, dtype=np.int64)
#     labels[N // 2:] = 1

#     X = np.zeros((N, 2), dtype=np.float32)
#     for i in range(N):
#         oh = np.array([1.0, 0.0], dtype=np.float32) if labels[i] == 0 else np.array([0.0, 1.0], dtype=np.float32)
#         X[i] = oh + rng.normal(0.0, noise_std, size=2).astype(np.float32)

#     edges = np.array(list(G.edges()), dtype=np.int64).T
#     edges_rev = edges[::-1]
#     edge_index = np.concatenate([edges, edges_rev], axis=1)
#     return edge_index, X, labels

# # -----------------------------
# # Quantum Message-Passing Node Classifier
# # -----------------------------
# class QMessagePassingNodeClassifier(nn.Module):
#     def __init__(self, n_nodes, in_feats=2, T=2, seed=0, verbose=False):
#         super().__init__()
#         self.verbose = verbose
#         self.n_nodes = n_nodes
#         self.n_qubits_per_node = 1
#         self.total_qubits = n_nodes * self.n_qubits_per_node
#         self.T = T

#         torch.manual_seed(seed)
#         np.random.seed(seed)
#         random.seed(seed)

#         # Classical encoders: feature -> rotation angles
#         self.encoders = nn.ModuleList([
#             nn.Linear(in_feats, 2) for _ in range(T)
#         ])

#         # Shared edge phase parameter per layer
#         self.edge_phases = nn.ParameterList([
#             nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)
#         ])

#         # Per-node readout head
#         self.readout = nn.ModuleList([
#             nn.Linear(1 + in_feats, 2) for _ in range(1)
#         ])

#         # Quantum device
#         self.dev = qml.device("default.qubit", wires=self.total_qubits, shots=None)

#         @qml.qnode(self.dev, interface="torch", diff_method="best")
#         def circuit(x_feat, edge_index, enc_alphas, enc_betas, edge_phi):
#             # Encode features
#             for i in range(self.n_nodes):
#                 qml.RX(enc_alphas[i], wires=i)
#                 qml.RY(enc_betas[i], wires=i)

#             # Entangle neighbors
#             E = edge_index.shape[1]
#             for e in range(E):
#                 u = int(edge_index[0, e].item())
#                 v = int(edge_index[1, e].item())
#                 qml.ControlledPhaseShift(edge_phi, wires=[u, v])

#             # Mixing
#             for i in range(self.n_nodes):
#                 qml.Hadamard(wires=i)

#             # Per-node measurement
#             return [qml.expval(qml.Z(i)) for i in range(self.n_nodes)]

#         self._circuit = circuit

#     def forward(self, edge_index_torch, x_torch):
#         if edge_index_torch.is_cuda:
#             edge_index_torch = edge_index_torch.cpu()
#         if x_torch.is_cuda:
#             x_torch = x_torch.cpu()

#         expvals = None
#         for t in range(self.T):
#             enc = self.encoders[t](x_torch)  # [N,2] float32
#             enc_alphas = enc[:, 0]
#             enc_betas = enc[:, 1]
#             edge_phi = self.edge_phases[t].squeeze()

#             # QNode returns float64 — cast to float32
#             layer_out = self._circuit(x_torch, edge_index_torch, enc_alphas, enc_betas, edge_phi)
#             layer_out = torch.stack(layer_out, dim=0).float()

#             expvals = layer_out

#         expvals = expvals.unsqueeze(1)  # [N,1]
#         readin = torch.cat([expvals, x_torch], dim=1).float()
#         logits = self.readout[0](readin)  # [N,2]
#         return logits

# # -----------------------------
# # Train & evaluate
# # -----------------------------
# def train_and_eval(N=12, T=2, epochs=150, lr=0.05, seed=0, verbose=False):
#     edge_index_np, X_np, y_np = make_synthetic_ring(N=N, seed=seed)
#     edge_index = torch.from_numpy(edge_index_np).long()
#     X = torch.from_numpy(X_np).float()
#     y = torch.from_numpy(y_np).long()

#     model = QMessagePassingNodeClassifier(n_nodes=N, in_feats=X.shape[1], T=T, seed=seed, verbose=verbose)
#     opt = torch.optim.Adam(model.parameters(), lr=lr)

#     all_idx = np.arange(N)
#     rng = np.random.default_rng(seed)
#     rng.shuffle(all_idx)
#     n_train = int(0.7 * N)
#     train_idx = torch.from_numpy(all_idx[:n_train])
#     val_idx = torch.from_numpy(all_idx[n_train:])

#     for ep in range(1, epochs + 1):
#         model.train()
#         logits = model(edge_index, X)
#         loss = F.cross_entropy(logits[train_idx], y[train_idx])

#         opt.zero_grad()
#         loss.backward()
#         opt.step()

#         if verbose and (ep % 20 == 0 or ep == 1):
#             model.eval()
#             with torch.no_grad():
#                 pred = logits.argmax(dim=1)
#                 train_acc = (pred[train_idx] == y[train_idx]).float().mean().item()
#                 val_acc = (pred[val_idx] == y[val_idx]).float().mean().item()
#             print(f"Epoch {ep:03d}  loss={loss.item():.4f}  train_acc={train_acc:.3f}  val_acc={val_acc:.3f}")

#     model.eval()
#     with torch.no_grad():
#         logits = model(edge_index, X)
#         pred = logits.argmax(dim=1)
#         acc = (pred == y).float().mean().item()

#     print("\nSynthetic dataset description:")
#     print(f"- Nodes: {N}")
#     print("- Graph: ring (cycle)")
#     print("- Labels: first half=0, second half=1")
#     print("- Features: one-hot label + noise")
#     print("\nNode predictions:")
#     for i in range(N):
#         print(f"Node {i:02d}  label={y[i].item()}  pred={pred[i].item()}")
#     print(f"\nOverall node accuracy: {acc:.3f}")

# if __name__ == "__main__":
#     _ = train_and_eval(N=12, T=2, epochs=150, lr=0.05, seed=7, verbose=True)


In [None]:
# import math
# import random
# import numpy as np
# import torch
# import torch.nn.functional as F
# from torch import nn
# import networkx as nx
# import pennylane as qml

# # =========================================================
# # Utilities
# # =========================================================

# def set_seeds(seed=0):
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)

# def confusion_matrix(y_true, y_pred, num_classes=2):
#     cm = np.zeros((num_classes, num_classes), dtype=int)
#     for t, p in zip(y_true, y_pred):
#         cm[int(t), int(p)] += 1
#     return cm

# def macro_accuracy(y_true, y_pred, num_classes=2):
#     cm = confusion_matrix(y_true, y_pred, num_classes=num_classes)
#     per_class = []
#     for c in range(num_classes):
#         support = cm[c, :].sum()
#         correct = cm[c, c]
#         acc_c = correct / support if support > 0 else 0.0
#         per_class.append(acc_c)
#     return float(np.mean(per_class))

# def print_confusion(cm):
#     num_classes = cm.shape[0]
#     print("Confusion matrix (rows=true, cols=pred):")
#     for i in range(num_classes):
#         row = " ".join(f"{cm[i, j]:4d}" for j in range(num_classes))
#         print(f"class {i}: {row}")

# # =========================================================
# # Synthetic multi-graph dataset (node classification)
# # =========================================================

# def make_synthetic_graph(n_min=10, n_max=30, p_range=(0.08, 0.2), q_range=(0.01, 0.08),
#                          feat_noise=0.2, weak_signal_scale=0.4, seed=None):
#     """
#     Generate a single synthetic graph:
#     - Two communities with intra/inter probs p,q (p>q).
#     - Labels: community id (0/1).
#     - Features per node: [deg_norm, clustering, weak_signal] + noise
#       where weak_signal = +weak_signal_scale for class 1, -weak_signal_scale for class 0, plus noise.
#     Returns:
#       edge_index [2, E], X [N, 3], y [N]
#     """
#     rng = np.random.default_rng(seed)
#     N = rng.integers(n_min, n_max + 1)
#     n0 = N // 2
#     n1 = N - n0

#     p = rng.uniform(*p_range)
#     q = rng.uniform(*q_range)
#     if p < q:
#         p, q = q, p

#     G = nx.Graph()
#     G.add_nodes_from(range(N))
#     comm = np.zeros(N, dtype=int)
#     comm[n0:] = 1

#     for i in range(N):
#         for j in range(i + 1, N):
#             prob = p if comm[i] == comm[j] else q
#             if rng.random() < prob:
#                 G.add_edge(i, j)

#     edges = np.array(list(G.edges()), dtype=np.int64).T if G.number_of_edges() > 0 else np.zeros((2,0), dtype=np.int64)
#     edges_rev = edges[::-1]
#     edge_index = np.concatenate([edges, edges_rev], axis=1) if edges.shape[1] > 0 else edges

#     y = comm.astype(np.int64)

#     degs = np.array([G.degree(i) for i in range(N)], dtype=np.float32)
#     deg_norm = (degs / max(1, N - 1)).astype(np.float32)
#     clustering = np.array(list(nx.clustering(G).values()), dtype=np.float32)
#     weak_signal = np.where(y == 1, +weak_signal_scale, -weak_signal_scale).astype(np.float32)

#     noise = rng.normal(0.0, feat_noise, size=(N, 3)).astype(np.float32)
#     X = np.stack([deg_norm, clustering, weak_signal], axis=1) + noise

#     return edge_index, X.astype(np.float32), y

# def make_dataset(num_graphs=30, seed=0, n_min=10, n_max=30):
#     set_seeds(seed)
#     ds = []
#     for g in range(num_graphs):
#         edge_index, X, y = make_synthetic_graph(n_min=n_min, n_max=n_max, seed=seed + 1000 + g)
#         ds.append(dict(edge_index=edge_index, X=X, y=y))
#     return ds

# def train_val_test_split_graphs(num_graphs, splits=(0.6, 0.2, 0.2), seed=0):
#     idx = np.arange(num_graphs)
#     rng = np.random.default_rng(seed)
#     rng.shuffle(idx)
#     n_train = int(splits[0] * num_graphs)
#     n_val = int(splits[1] * num_graphs)
#     train_idx = idx[:n_train]
#     val_idx = idx[n_train:n_train + n_val]
#     test_idx = idx[n_train + n_val:]
#     return train_idx, val_idx, test_idx

# # =========================================================
# # Quantum Message-Passing Node Classifier (PennyLane)
# # =========================================================

# class QMessagePassingNodeClassifier(nn.Module):
#     """
#     Quantum message-passing node classifier (1 qubit/node):
#     - T layers: encode (RX/RY) -> edge entanglers (ControlledPhaseShift) -> mixing (H).
#     - Per-node readout: expval(Z) -> linear head with original features to logits.
#     """
#     def __init__(self, n_nodes, in_feats=3, T=2, seed=0, verbose=False, use_gpu_qnode=False):
#         super().__init__()
#         self.verbose = verbose
#         self.n_nodes = n_nodes
#         self.total_qubits = n_nodes
#         self.T = T

#         torch.manual_seed(seed)
#         np.random.seed(seed)
#         random.seed(seed)

#         # Classical parts (will be moved to device by .to(device))
#         self.encoders = nn.ModuleList([nn.Linear(in_feats, 2) for _ in range(T)])
#         self.edge_phases = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])
#         self.readout = nn.Linear(1 + in_feats, 2)

#         # Quantum device (statevector)
#         dev_name = "default.qubit"
#         if use_gpu_qnode:
#             dev_name = "lightning.gpu"  # requires pennylane-lightning[gpu]
#         self.dev = qml.device(dev_name, wires=self.total_qubits, shots=None)

#         @qml.qnode(self.dev, interface="torch", diff_method="best")
#         def circuit(edge_index, enc_alphas, enc_betas, edge_phi):
#             # Encode
#             for i in range(self.n_nodes):
#                 qml.RX(enc_alphas[i], wires=i)
#                 qml.RY(enc_betas[i], wires=i)
#             # Edge entanglers
#             E = edge_index.shape[1]
#             for e in range(E):
#                 u = int(edge_index[0, e].item())
#                 v = int(edge_index[1, e].item())
#                 if u != v:
#                     qml.ControlledPhaseShift(edge_phi, wires=[u, v])
#             # Mixing
#             for i in range(self.n_nodes):
#                 qml.Hadamard(wires=i)
#             # Readout
#             return [qml.expval(qml.Z(i)) for i in range(self.n_nodes)]

#         self._circuit = circuit

#     def forward(self, edge_index_torch, x_torch):
#         """
#         edge_index_torch: [2, E] long tensor (can be on CUDA); will be cast to CPU for QNode
#         x_torch: [N, F] float tensor on model device (CPU/CUDA)
#         """
#         # Model/device context
#         model_device = next(self.parameters()).device

#         # 1) Run encoders ON MODEL DEVICE (avoid device mismatch)
#         # Ensure x is on model device and correct dtype for Linear
#         x_model = x_torch.to(model_device).float()

#         # We’ll need both the model-device version (for readout concat)
#         # and CPU versions of encoder outputs for the QNode.
#         last_expvals_cpu = None

#         for t in range(self.T):
#             enc = self.encoders[t](x_model)           # [N,2] on model_device
#             enc_alphas_model = enc[:, 0]              # model_device
#             enc_betas_model  = enc[:, 1]              # model_device
#             edge_phi_model   = self.edge_phases[t].squeeze()  # model_device

#             # 2) Prepare CPU copies for QNode call
#             # PennyLane Torch interface is most robust with CPU tensors
#             edge_index_cpu = edge_index_torch.detach().to("cpu")
#             enc_alphas_cpu = enc_alphas_model.detach().to("cpu")
#             enc_betas_cpu  = enc_betas_model.detach().to("cpu")
#             edge_phi_cpu   = edge_phi_model.detach().to("cpu")

#             # 3) QNode returns list of expvals -> stack and cast to float32
#             layer_out = self._circuit(edge_index_cpu, enc_alphas_cpu, enc_betas_cpu, edge_phi_cpu)
#             last_expvals_cpu = torch.stack(layer_out, dim=0).float()  # [N] on CPU

#         # 4) Move expvals back to model device for readout
#         expvals = last_expvals_cpu.to(model_device).unsqueeze(1)  # [N,1] on model_device

#         # 5) Concatenate with original features (on model device) and classify
#         readin = torch.cat([expvals, x_model], dim=1)   # [N, 1+F]
#         logits = self.readout(readin)                   # [N,2] on model_device
#         return logits


# # =========================================================
# # Training/evaluation over graphs (graph-level split; node-level supervision)
# # =========================================================

# def train_epoch(model, graphs, optimizer, device="cpu"):
#     model.train()
#     total_loss = 0.0
#     total_nodes = 0
#     for g in graphs:
#         edge_index = torch.from_numpy(g["edge_index"]).long().to(device)
#         X = torch.from_numpy(g["X"]).float().to(device)
#         y = torch.from_numpy(g["y"]).long().to(device)

#         logits = model(edge_index, X)
#         loss = F.cross_entropy(logits, y)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item() * X.shape[0]
#         total_nodes += X.shape[0]
#     return total_loss / max(1, total_nodes)

# @torch.no_grad()
# def evaluate(model, graphs, device="cpu"):
#     model.eval()
#     all_true = []
#     all_pred = []
#     total_loss = 0.0
#     total_nodes = 0

#     for g in graphs:
#         edge_index = torch.from_numpy(g["edge_index"]).long().to(device)
#         X = torch.from_numpy(g["X"]).float().to(device)
#         y = torch.from_numpy(g["y"]).long().to(device)

#         logits = model(edge_index, X)
#         loss = F.cross_entropy(logits, y)

#         pred = logits.argmax(dim=1).cpu().numpy()
#         all_pred.extend(list(pred))
#         all_true.extend(list(y.cpu().numpy()))

#         total_loss += loss.item() * X.shape[0]
#         total_nodes += X.shape[0]

#     avg_loss = total_loss / max(1, total_nodes)
#     all_true = np.array(all_true)
#     all_pred = np.array(all_pred)
#     mac_acc = macro_accuracy(all_true, all_pred, num_classes=2)
#     cm = confusion_matrix(all_true, all_pred, num_classes=2)
#     return avg_loss, mac_acc, cm, all_true, all_pred

# # =========================================================
# # Main experiment
# # =========================================================

# def run_experiment(
#     num_graphs=45,
#     splits=(0.6, 0.2, 0.2),
#     T=2,
#     epochs=50,
#     lr=0.03,
#     seed=42,
#     device="cpu",
#     use_gpu_qnode=False,
#     verbose=True
# ):
#     set_seeds(seed)

#     # To share one quantum circuit across graphs, fix N (qubit count)
#     N_fixed = 20
#     dataset = []
#     for g in range(num_graphs):
#         edge_index, X, y = make_synthetic_graph(n_min=N_fixed, n_max=N_fixed, seed=seed + 2000 + g)
#         dataset.append(dict(edge_index=edge_index, X=X, y=y))
#     train_idx, val_idx, test_idx = train_val_test_split_graphs(num_graphs, splits=splits, seed=seed)
#     train_graphs = [dataset[i] for i in train_idx]
#     val_graphs   = [dataset[i] for i in val_idx]
#     test_graphs  = [dataset[i] for i in test_idx]

#     model = QMessagePassingNodeClassifier(
#         n_nodes=N_fixed, in_feats=3, T=T, seed=seed, verbose=verbose, use_gpu_qnode=use_gpu_qnode
#     ).to(device)
#     opt = torch.optim.Adam(model.parameters(), lr=lr)

#     best_val = (-1.0, None)  # (macro-acc, state_dict)
#     for ep in range(1, epochs + 1):
#         tr_loss = train_epoch(model, train_graphs, opt, device=device)
#         val_loss, val_mac, val_cm, _, _ = evaluate(model, val_graphs, device=device)
#         if verbose:
#             print(f"Epoch {ep:03d} | train_loss={tr_loss:.4f} | val_loss={val_loss:.4f} | val_macro_acc={val_mac:.3f}")

#         if val_mac > best_val[0]:
#             best_val = (val_mac, {k: v.detach().cpu().clone() for k, v in model.state_dict().items()})

#     if best_val[1] is not None:
#         model.load_state_dict(best_val[1])

#     tr_loss, tr_mac, tr_cm, tr_y, tr_pred = evaluate(model, train_graphs, device=device)
#     va_loss, va_mac, va_cm, va_y, va_pred = evaluate(model, val_graphs, device=device)
#     te_loss, te_mac, te_cm, te_y, te_pred = evaluate(model, test_graphs, device=device)

#     print("\nResults:")
#     print(f"- Train: macro-acc={tr_mac:.3f}, loss={tr_loss:.4f}")
#     print_confusion(tr_cm)
#     print(f"- Val:   macro-acc={va_mac:.3f}, loss={va_loss:.4f}")
#     print_confusion(va_cm)
#     print(f"- Test:  macro-acc={te_mac:.3f}, loss={te_loss:.4f}")
#     print_confusion(te_cm)

#     return {
#         "model": model,
#         "splits": (train_idx, val_idx, test_idx),
#         "metrics": {
#             "train": dict(loss=tr_loss, macro_acc=tr_mac, cm=tr_cm, y=tr_y, pred=tr_pred),
#             "val":   dict(loss=va_loss, macro_acc=va_mac, cm=va_cm, y=va_y, pred=va_pred),
#             "test":  dict(loss=te_loss, macro_acc=te_mac, cm=te_cm, y=te_y, pred=te_pred),
#         }
#     }

# if __name__ == "__main__":
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     # Set use_gpu_qnode=True to run the quantum simulator on GPU (requires pennylane-lightning[gpu])
#     results = run_experiment(
#         num_graphs=45,
#         splits=(0.6, 0.2, 0.2),
#         T=2,
#         epochs=50,
#         lr=0.03,
#         seed=42,
#         device=device,
#         use_gpu_qnode=True,  # set True if you installed a GPU-backed PennyLane device
#         verbose=True
#     )


In [None]:
# import math
# import random
# import numpy as np
# import torch
# import torch.nn.functional as F
# from torch import nn
# import networkx as nx
# import pennylane as qml

# # =========================================================
# # Utilities
# # =========================================================

# def set_seeds(seed=0):
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     random.seed(seed)

# def confusion_matrix(y_true, y_pred, num_classes=2):
#     cm = np.zeros((num_classes, num_classes), dtype=int)
#     for t, p in zip(y_true, y_pred):
#         cm[int(t), int(p)] += 1
#     return cm

# def macro_accuracy(y_true, y_pred, num_classes=2):
#     cm = confusion_matrix(y_true, y_pred, num_classes=num_classes)
#     per_class = []
#     for c in range(num_classes):
#         support = cm[c, :].sum()
#         correct = cm[c, c]
#         acc_c = correct / support if support > 0 else 0.0
#         per_class.append(acc_c)
#     return float(np.mean(per_class))

# def print_confusion(cm):
#     num_classes = cm.shape[0]
#     print("Confusion matrix (rows=true, cols=pred):")
#     for i in range(num_classes):
#         row = " ".join(f"{cm[i, j]:4d}" for j in range(num_classes))
#         print(f"class {i}: {row}")

# # =========================================================
# # Synthetic multi-graph dataset (node classification)
# # =========================================================

# def make_synthetic_graph(n_min=10, n_max=30, p_range=(0.08, 0.2), q_range=(0.01, 0.08),
#                          feat_noise=0.2, weak_signal_scale=0.4, seed=None):
#     """
#     Generate a single synthetic graph:
#     - Two communities with intra/inter probs p,q (p>q).
#     - Labels: community id (0/1).
#     - Features per node: [deg_norm, clustering, weak_signal] + noise
#     """
#     rng = np.random.default_rng(seed)
#     N = rng.integers(n_min, n_max + 1)
#     n0 = N // 2
#     n1 = N - n0

#     p = rng.uniform(*p_range)
#     q = rng.uniform(*q_range)
#     if p < q:
#         p, q = q, p

#     G = nx.Graph()
#     G.add_nodes_from(range(N))
#     comm = np.zeros(N, dtype=int)
#     comm[n0:] = 1

#     for i in range(N):
#         for j in range(i + 1, N):
#             prob = p if comm[i] == comm[j] else q
#             if rng.random() < prob:
#                 G.add_edge(i, j)

#     edges = np.array(list(G.edges()), dtype=np.int64).T if G.number_of_edges() > 0 else np.zeros((2,0), dtype=np.int64)
#     edges_rev = edges[::-1]
#     edge_index = np.concatenate([edges, edges_rev], axis=1) if edges.shape[1] > 0 else edges

#     y = comm.astype(np.int64)

#     degs = np.array([G.degree(i) for i in range(N)], dtype=np.float32)
#     deg_norm = (degs / max(1, N - 1)).astype(np.float32)
#     clustering = np.array(list(nx.clustering(G).values()), dtype=np.float32)
#     weak_signal = np.where(y == 1, +weak_signal_scale, -weak_signal_scale).astype(np.float32)

#     noise = rng.normal(0.0, feat_noise, size=(N, 3)).astype(np.float32)
#     X = np.stack([deg_norm, clustering, weak_signal], axis=1) + noise

#     return edge_index, X.astype(np.float32), y

# def make_dataset(num_graphs=30, seed=0, n_min=10, n_max=30):
#     set_seeds(seed)
#     ds = []
#     for g in range(num_graphs):
#         edge_index, X, y = make_synthetic_graph(n_min=n_min, n_max=n_max, seed=seed + 1000 + g)
#         ds.append(dict(edge_index=edge_index, X=X, y=y))
#     return ds

# def train_val_test_split_graphs(num_graphs, splits=(0.6, 0.2, 0.2), seed=0):
#     idx = np.arange(num_graphs)
#     rng = np.random.default_rng(seed)
#     rng.shuffle(idx)
#     n_train = int(splits[0] * num_graphs)
#     n_val = int(splits[1] * num_graphs)
#     train_idx = idx[:n_train]
#     val_idx = idx[n_train:n_train + n_val]
#     test_idx = idx[n_train + n_val:]
#     return train_idx, val_idx, test_idx

# # =========================================================
# # EDU-QGC Node Classifier (PennyLane)
# # =========================================================

# class EDUQGCNodeClassifier(nn.Module):
#     """
#     EDU-QGC-style quantum graph node classifier (1 qubit/node).

#     Each of T layers:
#       - L_node: shared feature encoder (W,b) -> shared node unitary U_node on every node
#       - L_edge: shared two-qubit EDU on every edge:
#                 (U_pre ⊗ U_pre) -> CRZ(phi) -> (U_post ⊗ U_post)

#     Readout: expval(Z_i) per node -> (optional feat skip) -> shared linear head -> logits.
#     """
#     def __init__(self, n_nodes, in_feats=3, T=2, seed=0, verbose=False, use_gpu_qnode=False):
#         super().__init__()
#         self.verbose = verbose
#         self.n_nodes = n_nodes
#         self.T = T
#         self.in_feats = in_feats

#         torch.manual_seed(seed)
#         np.random.seed(seed)
#         random.seed(seed)

#         # ---------- Shared feature encoders per layer ----------
#         # angles_i = X_i @ W^T + b  -> [alpha_i, beta_i]
#         self.enc_W = nn.ParameterList([nn.Parameter(torch.randn(2, in_feats) * 0.2) for _ in range(T)])
#         self.enc_b = nn.ParameterList([nn.Parameter(torch.zeros(2)) for _ in range(T)])

#         # ---------- Shared node unitary per layer ----------
#         # U_node = RZ(gamma) RX(delta)  (same for all nodes in the layer)
#         self.node_gamma = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])
#         self.node_delta = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])

#         # ---------- Shared EDU parameters per layer ----------
#         # EDU(u,v) = (U_pre ⊗ U_pre) -> CRZ(phi) -> (U_post ⊗ U_post)
#         # U_pre  = RY(theta_pre) RZ(psi_pre)
#         # U_post = RZ(psi_post)  RY(theta_post)
#         self.edge_phi   = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])
#         self.pre_theta  = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])
#         self.pre_psi    = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])
#         self.post_theta = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])
#         self.post_psi   = nn.ParameterList([nn.Parameter(torch.randn(1) * 0.1) for _ in range(T)])

#         # ---------- Readout ----------
#         self.use_feat_skip = True
#         readin_dim = 1 + (in_feats if self.use_feat_skip else 0)
#         self.readout = nn.Linear(readin_dim, 2)

#         # ---------- Quantum device ----------
#         dev_name = "default.qubit"
#         if use_gpu_qnode:
#             dev_name = "lightning.gpu"  # requires pennylane-lightning[gpu]
#         self.dev = qml.device(dev_name, wires=n_nodes, shots=None)

#         @qml.qnode(self.dev, interface="torch", diff_method="best")
#         def circuit(edge_index, X,
#                     enc_W_list, enc_b_list,
#                     node_gamma_list, node_delta_list,
#                     edge_phi_list, pre_theta_list, pre_psi_list, post_theta_list, post_psi_list):
#             """
#             edge_index: [2, E] (long)
#             X: [N, F] (float)
#             All parameter lists: torch tensors (on CPU) with gradients.
#             """
#             N = X.shape[0]

#             def shared_node_unitary(layer_idx, wire):
#                 qml.RZ(node_gamma_list[layer_idx], wires=wire)
#                 qml.RX(node_delta_list[layer_idx], wires=wire)

#             def U_pre(layer_idx, wire):
#                 qml.RY(pre_theta_list[layer_idx], wires=wire)
#                 qml.RZ(pre_psi_list[layer_idx], wires=wire)

#             def U_post(layer_idx, wire):
#                 qml.RZ(post_psi_list[layer_idx], wires=wire)
#                 qml.RY(post_theta_list[layer_idx], wires=wire)

#             def EDU_edge(layer_idx, u, v):
#                 U_pre(layer_idx, u); U_pre(layer_idx, v)
#                 qml.CRZ(edge_phi_list[layer_idx], wires=[u, v])
#                 U_post(layer_idx, u); U_post(layer_idx, v)

#             # Stack T equivariant layers
#             for t in range(self.T):
#                 # L_node: shared encoder + shared single-qubit unitary
#                 W = enc_W_list[t]      # [2, F]
#                 b = enc_b_list[t]      # [2]
#                 # angles = X @ W^T + b   -> [N,2]
#                 angles = qml.math.dot(X, qml.math.transpose(W)) + b

#                 for i in range(N):
#                     qml.RX(angles[i, 0], wires=i)
#                     qml.RY(angles[i, 1], wires=i)
#                     shared_node_unitary(t, i)

#                 # L_edge: shared EDU on every edge
#                 E = edge_index.shape[1]
#                 for e in range(E):
#                     u = int(edge_index[0, e].item())
#                     v = int(edge_index[1, e].item())
#                     if u != v:
#                         EDU_edge(t, u, v)

#             # Nodewise Z-basis readout
#             return [qml.expval(qml.Z(i)) for i in range(N)]

#         self._circuit = circuit

#     def forward(self, edge_index_torch, x_torch):
#         """
#         edge_index_torch: [2, E] long tensor (can be on CUDA); passed to CPU for QNode
#         x_torch: [N, F] float tensor
#         """
#         model_device = next(self.parameters()).device
#         x_model = x_torch.to(model_device).float()

#         # Move data + params to CPU for the QNode WITHOUT detaching (preserve grads)
#         edge_index_cpu = edge_index_torch.to("cpu")
#         X_cpu = x_model.to("cpu")

#         enc_W_list   = [p.to("cpu") for p in self.enc_W]
#         enc_b_list   = [p.to("cpu") for p in self.enc_b]
#         node_g_list  = [p.to("cpu") for p in self.node_gamma]
#         node_d_list  = [p.to("cpu") for p in self.node_delta]
#         edge_phi_list  = [p.to("cpu") for p in self.edge_phi]
#         pre_th_list    = [p.to("cpu") for p in self.pre_theta]
#         pre_psi_list   = [p.to("cpu") for p in self.pre_psi]
#         post_th_list   = [p.to("cpu") for p in self.post_theta]
#         post_psi_list  = [p.to("cpu") for p in self.post_psi]

#         # Run quantum circuit
#         layer_out = self._circuit(
#             edge_index_cpu, X_cpu,
#             enc_W_list, enc_b_list,
#             node_g_list, node_d_list,
#             edge_phi_list, pre_th_list, pre_psi_list, post_th_list, post_psi_list
#         )

#         # Convert list of expvals → [N,1] tensor
#         expvals = torch.stack(layer_out, dim=0).float().to(model_device)  # [N] or [N,1?]
#         if expvals.dim() == 1:
#             expvals = expvals.unsqueeze(1)        # [N,1]
#         elif expvals.dim() == 2 and expvals.shape[1] == 1:
#             pass  # already [N,1]
#         else:
#             expvals = expvals.squeeze(-1).unsqueeze(1)  # force [N,1]

#         # Readout
#         if self.use_feat_skip:
#             readin = torch.cat([expvals, x_model], dim=1)  # [N, 1+F]
#         else:
#             readin = expvals

#         logits = self.readout(readin)                     # [N,2]
#         return logits

# # =========================================================
# # Training/evaluation over graphs (graph-level split; node-level supervision)
# # =========================================================

# def train_epoch(model, graphs, optimizer, device="cpu"):
#     model.train()
#     total_loss = 0.0
#     total_nodes = 0
#     for g in graphs:
#         edge_index = torch.from_numpy(g["edge_index"]).long().to(device)
#         X = torch.from_numpy(g["X"]).float().to(device)
#         y = torch.from_numpy(g["y"]).long().to(device)

#         logits = model(edge_index, X)
#         loss = F.cross_entropy(logits, y)

#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         total_loss += loss.item() * X.shape[0]
#         total_nodes += X.shape[0]
#     return total_loss / max(1, total_nodes)

# @torch.no_grad()
# def evaluate(model, graphs, device="cpu"):
#     model.eval()
#     all_true = []
#     all_pred = []
#     total_loss = 0.0
#     total_nodes = 0

#     for g in graphs:
#         edge_index = torch.from_numpy(g["edge_index"]).long().to(device)
#         X = torch.from_numpy(g["X"]).float().to(device)
#         y = torch.from_numpy(g["y"]).long().to(device)

#         logits = model(edge_index, X)
#         loss = F.cross_entropy(logits, y)

#         pred = logits.argmax(dim=1).cpu().numpy()
#         all_pred.extend(list(pred))
#         all_true.extend(list(y.cpu().numpy()))

#         total_loss += loss.item() * X.shape[0]
#         total_nodes += X.shape[0]

#     avg_loss = total_loss / max(1, total_nodes)
#     all_true = np.array(all_true)
#     all_pred = np.array(all_pred)
#     mac_acc = macro_accuracy(all_true, all_pred, num_classes=2)
#     cm = confusion_matrix(all_true, all_pred, num_classes=2)
#     return avg_loss, mac_acc, cm, all_true, all_pred

# # =========================================================
# # Main experiment (same task as before)
# # =========================================================

# def run_experiment(
#     num_graphs=45,
#     splits=(0.6, 0.2, 0.2),
#     T=2,
#     epochs=50,
#     lr=0.03,
#     seed=42,
#     device="cpu",
#     use_gpu_qnode=False,
#     verbose=True
# ):
#     set_seeds(seed)

#     # Fix N to share the same circuit size across graphs
#     N_fixed = 20
#     dataset = []
#     for g in range(num_graphs):
#         edge_index, X, y = make_synthetic_graph(n_min=N_fixed, n_max=N_fixed, seed=seed + 2000 + g)
#         dataset.append(dict(edge_index=edge_index, X=X, y=y))
#     train_idx, val_idx, test_idx = train_val_test_split_graphs(num_graphs, splits=splits, seed=seed)
#     train_graphs = [dataset[i] for i in train_idx]
#     val_graphs   = [dataset[i] for i in val_idx]
#     test_graphs  = [dataset[i] for i in test_idx]

#     model = EDUQGCNodeClassifier(
#         n_nodes=N_fixed, in_feats=3, T=T, seed=seed, verbose=verbose, use_gpu_qnode=use_gpu_qnode
#     ).to(device)
#     opt = torch.optim.Adam(model.parameters(), lr=lr)

#     best_val = (-1.0, None)  # (macro-acc, state_dict)
#     for ep in range(1, epochs + 1):
#         tr_loss = train_epoch(model, train_graphs, opt, device=device)
#         val_loss, val_mac, val_cm, _, _ = evaluate(model, val_graphs, device=device)
#         if verbose:
#             print(f"Epoch {ep:03d} | train_loss={tr_loss:.4f} | val_loss={val_loss:.4f} | val_macro_acc={val_mac:.3f}")

#         if val_mac > best_val[0]:
#             best_val = (val_mac, {k: v.detach().cpu().clone() for k, v in model.state_dict().items()})

#     if best_val[1] is not None:
#         model.load_state_dict(best_val[1])

#     tr_loss, tr_mac, tr_cm, tr_y, tr_pred = evaluate(model, train_graphs, device=device)
#     va_loss, va_mac, va_cm, va_y, va_pred = evaluate(model, val_graphs, device=device)
#     te_loss, te_mac, te_cm, te_y, te_pred = evaluate(model, test_graphs, device=device)

#     print("\nResults:")
#     print(f"- Train: macro-acc={tr_mac:.3f}, loss={tr_loss:.4f}")
#     print_confusion(tr_cm)
#     print(f"- Val:   macro-acc={va_mac:.3f}, loss={va_loss:.4f}")
#     print_confusion(va_cm)
#     print(f"- Test:  macro-acc={te_mac:.3f}, loss={te_loss:.4f}")
#     print_confusion(te_cm)

#     return {
#         "model": model,
#         "splits": (train_idx, val_idx, test_idx),
#         "metrics": {
#             "train": dict(loss=tr_loss, macro_acc=tr_mac, cm=tr_cm, y=tr_y, pred=tr_pred),
#             "val":   dict(loss=va_loss, macro_acc=va_mac, cm=va_cm, y=va_y, pred=va_pred),
#             "test":  dict(loss=te_loss, macro_acc=te_mac, cm=te_cm, y=te_y, pred=te_pred),
#         }
#     }

# if __name__ == "__main__":
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     # Set use_gpu_qnode=True to use a GPU-backed PennyLane device if installed (pennylane-lightning[gpu])
#     results = run_experiment(
#         num_graphs=45,
#         splits=(0.6, 0.2, 0.2),
#         T=2,
#         epochs=50,
#         lr=0.03,
#         seed=42,
#         device=device,
#         use_gpu_qnode=True,
#         verbose=True
#     )


# EDU-QGC with Reproducibility

In [None]:
# edu_qgc_repro_full.py
# -*- coding: utf-8 -*-
"""
EDU-QGC Node Classification with strict reproducibility and dataset saving.

Run:
  python edu_qgc_repro_full.py

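Requirements:
  pip install torch pennylane "pennylane-lightning[gpu]" networkx matplotlib
(Without CUDA, or with use_gpu_qnode=False, the model uses PennyLane's default.qubit CPU
simulator. With CUDA available and use_gpu_qnode=True it requests lightning.gpu, which must
then be installed; there is no automatic fallback in that case.)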
"""

import os
# MUST set CUBLAS_WORKSPACE_CONFIG before importing torch to enable deterministic cuBLAS
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import sys
import math
import random
import pickle
from pathlib import Path
from datetime import datetime

import numpy as np

# Now import torch after setting env var
import torch
import torch.nn.functional as F
from torch import nn

import networkx as nx
import pennylane as qml
import matplotlib.pyplot as plt

# ---------------------------
# Reproducibility helper
# ---------------------------
def set_global_determinism(seed: int):
    """
    Seed every RNG the script relies on and put PyTorch in deterministic mode.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU, plus
    all CUDA devices when CUDA is available), requests deterministic algorithms
    and pins cuDNN to deterministic, non-benchmarking kernels.

    Notes:
        * ``PYTHONHASHSEED`` is written to ``os.environ`` but is only read when
          an interpreter starts, so it affects subprocesses, not this process.
        * ``torch.use_deterministic_algorithms(True)`` only switches the mode
          on; an op lacking a deterministic implementation raises later, when
          it is actually called.
        * On CUDA, ``CUBLAS_WORKSPACE_CONFIG`` must be in the environment before
          cuBLAS is first used; this file sets it before importing torch.

    Args:
        seed: Seed applied to every RNG.

    Raises:
        Re-raises whatever ``torch.use_deterministic_algorithms`` raises, after
        printing troubleshooting hints.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Same order as always: Python, NumPy, then torch (CPU).
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    try:
        torch.use_deterministic_algorithms(True)
    except Exception as err:
        print("Failed to enable deterministic algorithms:", err)
        print("Make sure CUBLAS_WORKSPACE_CONFIG is set before Python starts.")
        print("If running in Jupyter, restart the kernel after setting the env var.")
        raise

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# ---------------------------
# Utilities
# ---------------------------
def confusion_matrix(y_true, y_pred, num_classes=2):
    """
    Count (true label, predicted label) pairs into a square matrix.

    Args:
        y_true: Iterable of true class ids (anything ``int()`` accepts).
        y_pred: Iterable of predicted class ids, paired element-wise with
            ``y_true``; pairing stops at the shorter of the two.
        num_classes: Side length of the matrix.

    Returns:
        ``np.ndarray`` of shape ``(num_classes, num_classes)``, dtype ``int``,
        where entry ``[t, p]`` counts samples of true class ``t`` predicted as ``p``.
    """
    counts = np.zeros((num_classes, num_classes), dtype=int)
    label_pairs = ((int(t), int(p)) for t, p in zip(y_true, y_pred))
    for row, col in label_pairs:
        counts[row, col] += 1
    return counts

def macro_accuracy(y_true, y_pred, num_classes=2):
    """
    Macro-averaged (balanced) accuracy: the mean per-class recall.

    Only classes that actually occur in ``y_true`` enter the average. The
    recall of a class with no true samples is undefined; scoring it as 0.0
    would cap even a perfect classifier at ``present / num_classes`` whenever
    a split happens to miss a class.

    Args:
        y_true: Iterable of true class ids.
        y_pred: Iterable of predicted class ids, paired with ``y_true``.
        num_classes: Number of classes (confusion-matrix side length).

    Returns:
        float in ``[0, 1]``; ``0.0`` when ``y_true`` is empty.
    """
    cm = confusion_matrix(y_true, y_pred, num_classes=num_classes)
    support = cm.sum(axis=1)          # number of true samples per class
    present = support > 0
    if not present.any():
        return 0.0
    recall = np.diag(cm)[present] / support[present]
    return float(recall.mean())

def print_confusion(cm):
    """
    Print a confusion matrix as a small right-aligned text table.

    Args:
        cm: 2-D integer array indexed ``[true, pred]``; rows and columns
            ``0..cm.shape[0]-1`` are printed.
    """
    size = cm.shape[0]
    print("Confusion matrix (rows=true, cols=pred):")
    for label in range(size):
        cells = [f"{cm[label, col]:4d}" for col in range(size)]
        row = " ".join(cells)
        print(f"class {label}: {row}")

def save_dataset(dataset, path: Path):
    """
    Pickle ``dataset`` to ``path``, creating missing parent directories.

    Args:
        dataset: Any picklable sized collection (here: a list of graph dicts).
        path: Destination file.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as fh:
        pickle.dump(dataset, fh)
    print(f"[dataset saved] {path.resolve()} (num_graphs={len(dataset)})")

def load_dataset(path: Path):
    """
    Unpickle and return a dataset previously written by ``save_dataset``.

    Args:
        path: File to read.

    Returns:
        The unpickled object.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only load files this
    # script produced, never files from untrusted sources.
    with open(path, "rb") as fh:
        loaded = pickle.load(fh)
    print(f"[dataset loaded] {path.resolve()} (num_graphs={len(loaded)})")
    return loaded

# ---------------------------
# Dataset generation
# ---------------------------
def make_synthetic_graph(n_min=10, n_max=30, p_range=(0.08, 0.2), q_range=(0.01, 0.08),
                         feat_noise=0.2, weak_signal_scale=0.4, seed=None, shuffle_labels=False):
    """
    Generate one random two-community graph for node classification.

    Nodes ``0..N//2-1`` form community 0 and the rest community 1. Each node
    pair is connected with probability ``p`` (same community) or ``q``
    (different communities); ``p`` is drawn from ``p_range`` and ``q`` from
    ``q_range``, swapped if needed so that ``p >= q``.

    Node features (3 float32 columns, each plus Gaussian noise with std
    ``feat_noise``):
      0. degree divided by ``max(1, N - 1)``,
      1. local clustering coefficient,
      2. ``+weak_signal_scale`` for label 1, ``-weak_signal_scale`` for label 0.

    All randomness comes from one ``np.random.default_rng(seed)``. The order of
    the RNG calls below determines the graph produced for a given seed, so
    reordering them changes the dataset.

    Args:
        n_min, n_max: Inclusive bounds for the node count ``N``.
        p_range: ``(low, high)`` for the intra-community edge probability.
        q_range: ``(low, high)`` for the inter-community edge probability.
        feat_noise: Std of the Gaussian noise added to every feature.
        weak_signal_scale: Magnitude of the label-dependent feature.
        seed: Seed for ``np.random.default_rng`` (``None`` = non-reproducible).
        shuffle_labels: If True, labels are randomly permuted after the graph
            is built, so they no longer follow the communities; the weak-signal
            feature is then computed from the permuted labels.

    Returns:
        dict with
          - ``edge_index``: int64 array of shape ``(2, 2*E)`` listing every
            undirected edge in both directions (shape ``(2, 0)`` if no edges),
          - ``X``: float32 array of shape ``(N, 3)``,
          - ``y``: int64 array of shape ``(N,)`` with labels in ``{0, 1}``.
    """
    rng = np.random.default_rng(seed)
    # integers() excludes its upper bound, hence n_max + 1 to make n_max inclusive
    N = int(rng.integers(n_min, n_max + 1))
    n0 = N // 2

    p = float(rng.uniform(*p_range))
    q = float(rng.uniform(*q_range))
    # Keep intra-community edges at least as likely as inter-community ones
    if p < q:
        p, q = q, p

    G = nx.Graph()
    G.add_nodes_from(range(N))
    comm = np.zeros(N, dtype=int)
    comm[n0:] = 1

    # One Bernoulli draw per unordered pair (i < j)
    for i in range(N):
        for j in range(i + 1, N):
            prob = p if comm[i] == comm[j] else q
            if rng.random() < prob:
                G.add_edge(i, j)

    # (2, E) edge list; an explicit empty (2, 0) array when the graph has no edges
    edges = np.array(list(G.edges()), dtype=np.int64).T if G.number_of_edges() > 0 else np.zeros((2,0), dtype=np.int64)
    # Swapping the two rows gives the reverse direction of every edge
    edges_rev = edges[::-1]
    edge_index = np.concatenate([edges, edges_rev], axis=1) if edges.shape[1] > 0 else edges

    # astype returns a copy, so shuffling y leaves comm untouched
    y = comm.astype(np.int64)
    if shuffle_labels:
        rng.shuffle(y)

    degs = np.array([G.degree(i) for i in range(N)], dtype=np.float32)
    deg_norm = (degs / max(1, N - 1)).astype(np.float32)
    # NOTE(review): relies on the dict order of nx.clustering matching node ids
    # 0..N-1 (nodes were added in that order); indexing by node id would make
    # this explicit.
    clustering = np.array(list(nx.clustering(G).values()), dtype=np.float32)
    weak_signal = np.where(y == 1, +weak_signal_scale, -weak_signal_scale).astype(np.float32)

    noise = rng.normal(0.0, feat_noise, size=(N, 3)).astype(np.float32)
    X = np.stack([deg_norm, clustering, weak_signal], axis=1) + noise

    return dict(edge_index=edge_index.astype(np.int64), X=X.astype(np.float32), y=y.astype(np.int64))

def train_val_test_split_graphs(num_graphs, splits=(0.6, 0.2, 0.2), seed=0):
    """
    Randomly partition graph indices ``0..num_graphs-1`` into train/val/test.

    The train and validation sizes are ``int(fraction * num_graphs)``
    (truncated). The test split receives every remaining index, so
    ``splits[2]`` is not used directly.

    Args:
        num_graphs: Number of graphs to split.
        splits: ``(train_frac, val_frac, test_frac)``.
        seed: Seed for the shuffling ``np.random.default_rng``.

    Returns:
        Tuple ``(train_idx, val_idx, test_idx)`` of disjoint integer arrays.
    """
    order = np.arange(num_graphs)
    np.random.default_rng(seed).shuffle(order)
    train_frac, val_frac = splits[0], splits[1]
    n_train = int(train_frac * num_graphs)
    val_end = n_train + int(val_frac * num_graphs)
    return order[:n_train], order[n_train:val_end], order[val_end:]

# ---------------------------
# EDU-QGC model
# ---------------------------
class EDUQGCNodeClassifier(nn.Module):
    """
    Hybrid quantum/classical node classifier for graphs with a fixed node count.

    One qubit (wire) per node. A single PennyLane QNode applies ``T`` rounds of:
      1. feature encoding: ``X @ enc_W[t].T + enc_b[t]`` gives two angles per
         node, applied as ``RX`` then ``RY`` on that node's wire;
      2. a shared pre-rotation ``RZ(pre_psi[t])``, ``RX(pre_theta[t])`` on
         every wire;
      3. one ``ControlledPhaseShift(edge_phase[t])`` per column of
         ``edge_index`` (self-loops skipped);
      4. a shared post-rotation ``RZ(post_psi[t])``, ``RX(post_theta[t])``.
    The circuit returns ``<Z>`` per wire. A linear readout maps that value,
    optionally concatenated with the raw node features, to 2 logits.

    NOTE(review): ``seed`` is accepted but never used. Parameter initialisation
    draws from torch's global RNG, so reproducibility relies on the caller
    seeding torch beforehand.
    NOTE(review): with a symmetrised ``edge_index`` (both directions stored),
    each undirected edge gets two controlled-phase gates, i.e. an effective
    phase of ``2 * edge_phase[t]`` -- confirm this is intended.
    """

    def __init__(self, n_nodes, in_feats=3, T=2, seed=0, use_gpu_qnode=True, use_feat_skip=True):
        """
        Args:
            n_nodes: Number of nodes = number of wires. Graphs passed to
                ``forward`` are expected to have exactly this many nodes.
            in_feats: Node feature dimension.
            T: Number of encode/rotate/entangle rounds in the circuit.
            seed: Unused (see class note).
            use_gpu_qnode: Request the ``lightning.gpu`` device when CUDA is
                available; otherwise ``default.qubit`` is used.
                NOTE(review): whether the lightning.gpu plugin is installed is
                not checked here.
            use_feat_skip: Concatenate raw node features to the quantum output
                before the readout layer.
        """
        super().__init__()
        self.n_nodes = n_nodes
        self.T = T
        self.use_feat_skip = use_feat_skip

        # Parameter packing
        self.enc_W = nn.Parameter(torch.randn(T, 2, in_feats) * 0.08)  # [T,2,F]
        self.enc_b = nn.Parameter(torch.randn(T, 2) * 0.02)           # [T,2]

        # One scalar per round, shared by all nodes / edges
        self.edge_phase  = nn.Parameter(torch.randn(T) * 0.08)
        self.pre_theta   = nn.Parameter(torch.randn(T) * 0.08)
        self.pre_psi     = nn.Parameter(torch.randn(T) * 0.08)
        self.post_theta  = nn.Parameter(torch.randn(T) * 0.08)
        self.post_psi    = nn.Parameter(torch.randn(T) * 0.08)

        # Readout input: 1 expectation value per node (+ raw features if skip is on)
        readin_dim = 1 + in_feats if use_feat_skip else 1
        self.readout = nn.Linear(readin_dim, 2)

        use_cuda = torch.cuda.is_available()
        qdev_name = "lightning.gpu" if (use_gpu_qnode and use_cuda) else "default.qubit"
        # shots=None -> analytic expectation values
        self.dev = qml.device(qdev_name, wires=n_nodes, shots=None)

        @qml.qnode(self.dev, interface="torch", diff_method="best")
        def circuit(edge_index, X, enc_W, enc_b,
                    edge_phase, pre_theta, pre_psi, post_theta, post_psi):
            # Loop layers
            for t in range(self.T):
                # Per-node encoding angles for this round (features re-uploaded every round)
                enc_out = X @ enc_W[t].T + enc_b[t]  # [N,2]
                alphas = enc_out[:, 0]; betas = enc_out[:, 1]
                for i in range(self.n_nodes):
                    qml.RX(alphas[i], wires=i)
                    qml.RY(betas[i], wires=i)

                for i in range(self.n_nodes):
                    qml.RZ(pre_psi[t], wires=i)
                    qml.RX(pre_theta[t], wires=i)

                # Entangle along graph edges; wire ids are pulled out as Python ints.
                # NOTE(review): .item() per edge per round may be slow for tensors on GPU.
                E = edge_index.shape[1]
                for e in range(E):
                    u = int(edge_index[0, e].item()); v = int(edge_index[1, e].item())
                    if u != v:
                        qml.ControlledPhaseShift(edge_phase[t], wires=[u, v])

                for i in range(self.n_nodes):
                    qml.RZ(post_psi[t], wires=i)
                    qml.RX(post_theta[t], wires=i)

            # One Pauli-Z expectation per node/wire
            return [qml.expval(qml.Z(i)) for i in range(self.n_nodes)]

        self._circuit = circuit

    def forward(self, edge_index_torch, x_torch):
        """
        Run the circuit on one graph and return per-node logits.

        Args:
            edge_index_torch: Integer tensor of shape ``(2, E)`` whose entries
                are wire indices in ``[0, n_nodes)``.
            x_torch: Node features of shape ``(n_nodes, in_feats)``.

        Returns:
            Tensor of shape ``(n_nodes, 2)`` with unnormalised class scores.
        """
        model_device = next(self.parameters()).device
        edge_index = edge_index_torch.to(model_device)
        X = x_torch.to(model_device).float()

        # Call the QNode with all trainable circuit parameters.
        # NOTE(review): assumes PennyLane accepts tensors on model_device as given.
        layer_out = self._circuit(
            edge_index, X,
            self.enc_W, self.enc_b,
            self.edge_phase,
            self.pre_theta, self.pre_psi,
            self.post_theta, self.post_psi
        )

        # Stack the per-wire expectation values and cast to float32 on the model device,
        # then normalise the layout to a column [n_nodes, 1].
        expvals = torch.stack(layer_out, dim=0).float().to(model_device)
        if expvals.dim() == 1:
            expvals = expvals.unsqueeze(1)
        elif expvals.dim() == 2 and expvals.shape[1] == 1:
            pass
        else:
            # NOTE(review): defensive reshape for other return layouts; which cases
            # reach this branch is not visible here.
            expvals = expvals.squeeze(-1).unsqueeze(1)

        # Skip connection: [n_nodes, 1 + in_feats]
        if self.use_feat_skip:
            readin = torch.cat([expvals, X], dim=1)
        else:
            readin = expvals

        logits = self.readout(readin)
        return logits

# ---------------------------
# Train / Eval
# ---------------------------
def train_epoch(model, graphs, optimizer, device="cpu"):
    """
    Run one optimisation pass, taking one optimizer step per graph.

    Args:
        model: Module called as ``model(edge_index, X)`` returning per-node logits.
        graphs: Iterable of dicts holding NumPy arrays ``edge_index``, ``X``, ``y``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the per-graph tensors are moved to.

    Returns:
        Node-weighted mean of each graph's cross-entropy (measured before that
        graph's optimizer step); 0.0 when ``graphs`` is empty.
    """
    model.train()
    loss_sum, node_count = 0.0, 0
    for graph in graphs:
        edges = torch.from_numpy(graph["edge_index"]).long().to(device)
        feats = torch.from_numpy(graph["X"]).float().to(device)
        labels = torch.from_numpy(graph["y"]).long().to(device)

        loss = F.cross_entropy(model(edges, feats), labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_nodes = feats.shape[0]
        loss_sum += loss.item() * n_nodes
        node_count += n_nodes
    return loss_sum / max(1, node_count)

@torch.no_grad()
def evaluate(model, graphs, device="cpu"):
    """
    Score ``model`` on ``graphs`` in eval mode without tracking gradients.

    Args:
        model: Module called as ``model(edge_index, X)`` returning per-node logits.
        graphs: Iterable of dicts holding NumPy arrays ``edge_index``, ``X``, ``y``.
        device: Device the per-graph tensors are moved to.

    Returns:
        Tuple ``(avg_loss, macro_acc, cm, y_true, y_pred)``: node-weighted mean
        cross-entropy (0.0 for no graphs), ``macro_accuracy`` and the 2x2
        ``confusion_matrix`` over all nodes, and the concatenated label and
        argmax-prediction arrays.
    """
    model.eval()
    truths, preds = [], []
    loss_sum, node_count = 0.0, 0
    for graph in graphs:
        edges = torch.from_numpy(graph["edge_index"]).long().to(device)
        feats = torch.from_numpy(graph["X"]).float().to(device)
        labels = torch.from_numpy(graph["y"]).long().to(device)

        logits = model(edges, feats)
        loss = F.cross_entropy(logits, labels)

        preds.extend(logits.argmax(dim=1).cpu().numpy())
        truths.extend(labels.cpu().numpy())

        n_nodes = feats.shape[0]
        loss_sum += loss.item() * n_nodes
        node_count += n_nodes

    avg_loss = loss_sum / max(1, node_count)
    y_true = np.array(truths)
    y_pred = np.array(preds)
    mac_acc = macro_accuracy(y_true, y_pred, num_classes=2)
    cm = confusion_matrix(y_true, y_pred, num_classes=2)
    return avg_loss, mac_acc, cm, y_true, y_pred

# ---------------------------
# Experiment runner (full)
# ---------------------------
def _load_or_build_dataset(ds_path, num_graphs, n_nodes, seed):
    """Return the cached dataset at ``ds_path`` if it matches the request, else generate and cache a new one."""
    if ds_path.exists():
        dataset = load_dataset(ds_path)
        # The cache file name only encodes the seed, so check that its contents
        # match the requested graph count and size before reusing it.
        if len(dataset) == num_graphs and all(g["X"].shape[0] == n_nodes for g in dataset):
            return dataset
        print(f"[dataset mismatch] cached file does not match num_graphs={num_graphs}, "
              f"n_nodes={n_nodes}; regenerating")
    dataset = [
        make_synthetic_graph(n_min=n_nodes, n_max=n_nodes, seed=seed + 2000 + g, shuffle_labels=False)
        for g in range(num_graphs)
    ]
    save_dataset(dataset, ds_path)
    return dataset


def _plot_history(history, plt_path):
    """Save and show train/val loss and validation macro-accuracy curves, then free the figure."""
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    ax[0].plot(history["train_loss"], label="train loss")
    ax[0].plot(history["val_loss"], label="val loss")
    ax[0].set_xlabel("epoch"); ax[0].set_ylabel("loss"); ax[0].legend(); ax[0].set_title("Loss")

    ax[1].plot(history["val_acc"], label="val macro-acc")
    ax[1].set_xlabel("epoch"); ax[1].set_ylabel("macro-accuracy"); ax[1].legend(); ax[1].set_title("Validation Macro-Accuracy")

    plt.tight_layout()
    plt.savefig(plt_path, dpi=200)
    print(f"[saved plot] {plt_path.resolve()}")
    plt.show()
    # Close explicitly so repeated runs do not accumulate open figures.
    plt.close(fig)


def run_experiment(
    num_graphs=45,
    splits=(0.6, 0.2, 0.2),
    T=2,
    epochs=50,
    lr=0.03,
    seed=42,
    device="cuda",
    use_gpu_qnode=True,
    save_dataset_path: str = "./data/eduqgc_dataset_seed{seed}.pkl",
    plot_curves: bool = True,
    n_nodes: int = 20,
):
    """
    Train and evaluate an EDU-QGC node classifier end to end.

    Steps:
      1. Seed everything.
      2. Load (or generate and cache) the synthetic dataset.
      3. Split the graphs into train/val/test.
      4. Train with Adam for ``epochs`` epochs, keeping the weights with the
         best validation macro-accuracy.
      5. Report metrics on all splits.
      6. Save the model (and optionally the training curves) under ``./outputs``.

    Args:
        num_graphs: Number of graphs in the dataset.
        splits: Train/val/test fractions (the test split gets the remainder).
        T: Number of circuit rounds in the model.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        seed: Global seed. It also selects the dataset cache file and the
            per-graph generator seeds.
        device: Torch device; falls back to ``"cpu"`` when CUDA is unavailable.
        use_gpu_qnode: Forwarded to ``EDUQGCNodeClassifier``.
        save_dataset_path: Cache path template in which ``{seed}`` is
            substituted. A cached file is reused only if it holds
            ``num_graphs`` graphs of ``n_nodes`` nodes; otherwise it is
            regenerated and overwritten.
        plot_curves: Whether to save and show the loss/accuracy curves.
        n_nodes: Nodes per graph, which equals the number of qubits in the
            circuit (default 20, the previous hard-coded value).

    Returns:
        dict with the following keys:
          - ``model``;
          - ``splits`` (index arrays);
          - ``history``;
          - ``metrics`` (per split: loss, macro_acc, cm, y, pred);
          - ``dataset_path``;
          - ``model_path``;
          - ``plot_path`` (the plot file exists only when ``plot_curves`` is True).
    """
    # Reproducibility
    set_global_determinism(seed)
    print("[repro] seed:", seed)
    print("[repro] deterministic algos enabled:", torch.are_deterministic_algorithms_enabled())
    print("[repro] cudnn.deterministic:", torch.backends.cudnn.deterministic)
    print("[repro] cudnn.benchmark:", torch.backends.cudnn.benchmark)
    print("[repro] CUBLAS_WORKSPACE_CONFIG:", os.environ.get("CUBLAS_WORKSPACE_CONFIG"))

    # Data
    ds_path = Path(save_dataset_path.format(seed=seed))
    dataset = _load_or_build_dataset(ds_path, num_graphs, n_nodes, seed)

    train_idx, val_idx, test_idx = train_val_test_split_graphs(len(dataset), splits=splits, seed=seed)
    train_graphs = [dataset[i] for i in train_idx]
    val_graphs = [dataset[i] for i in val_idx]
    test_graphs = [dataset[i] for i in test_idx]

    # Device selection
    device = device if (device == "cpu" or torch.cuda.is_available()) else "cpu"
    device = torch.device(device)
    model = EDUQGCNodeClassifier(n_nodes=n_nodes, in_feats=3, T=T, seed=seed,
                                 use_gpu_qnode=use_gpu_qnode, use_feat_skip=True).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    history = {"train_loss": [], "val_loss": [], "val_acc": []}
    best_val = (-1.0, None)
    for ep in range(1, epochs + 1):
        tr_loss = train_epoch(model, train_graphs, opt, device=device)
        val_loss, val_mac, _, _, _ = evaluate(model, val_graphs, device=device)

        print(f"Epoch {ep:03d} | train_loss={tr_loss:.4f} | val_loss={val_loss:.4f} | val_macro_acc={val_mac:.3f}")

        history["train_loss"].append(tr_loss)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_mac)

        # Strict ">" keeps the earliest epoch among ties; snapshot on CPU.
        if val_mac > best_val[0]:
            best_val = (val_mac, {k: v.detach().cpu().clone() for k, v in model.state_dict().items()})

    # Load best
    if best_val[1] is not None:
        model.load_state_dict(best_val[1])

    # Final eval
    tr_loss, tr_mac, tr_cm, tr_y, tr_pred = evaluate(model, train_graphs, device=device)
    va_loss, va_mac, va_cm, va_y, va_pred = evaluate(model, val_graphs, device=device)
    te_loss, te_mac, te_cm, te_y, te_pred = evaluate(model, test_graphs, device=device)

    print("\nResults:")
    print(f"- Train: macro-acc={tr_mac:.3f}, loss={tr_loss:.4f}")
    print_confusion(tr_cm)
    print(f"- Val:   macro-acc={va_mac:.3f}, loss={va_loss:.4f}")
    print_confusion(va_cm)
    print(f"- Test:  macro-acc={te_mac:.3f}, loss={te_loss:.4f}")
    print_confusion(te_cm)

    # Save model + plots
    out_dir = Path("outputs")
    out_dir.mkdir(exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d-%H%M%S")
    model_path = out_dir / f"eduqgc_model_seed{seed}_{ts}.pt"
    torch.save(model.state_dict(), model_path)
    print(f"[saved model] {model_path.resolve()}")

    plt_path = out_dir / f"training_curves_seed{seed}_{ts}.png"
    if plot_curves:
        _plot_history(history, plt_path)

    return {
        "model": model,
        "splits": (train_idx, val_idx, test_idx),
        "history": history,
        "metrics": {
            "train": dict(loss=tr_loss, macro_acc=tr_mac, cm=tr_cm, y=tr_y, pred=tr_pred),
            "val":   dict(loss=va_loss, macro_acc=va_mac, cm=va_cm, y=va_y, pred=va_pred),
            "test":  dict(loss=te_loss, macro_acc=te_mac, cm=te_cm, y=te_y, pred=te_pred),
        },
        "dataset_path": str(ds_path.resolve()),
        "model_path": str(model_path.resolve()),
        "plot_path": str(plt_path.resolve()),
    }

# ---------------------------
# Main
# ---------------------------
if __name__ == "__main__":
    # Run configuration (module-level names, also visible to later notebook cells).
    SEED = 42
    NUM_GRAPHS = 45
    USE_GPU_QNODE = True
    DEVICE = "cpu"
    if torch.cuda.is_available():
        DEVICE = "cuda"

    results = run_experiment(
        num_graphs=NUM_GRAPHS,
        T=2,
        epochs=50,
        lr=0.03,
        seed=SEED,
        device=DEVICE,
        use_gpu_qnode=USE_GPU_QNODE,
        save_dataset_path="./data/eduqgc_dataset_seed{seed}.pkl",
        plot_curves=True,
    )
    summary_keys = list(results.keys())
    print("Done. Results summary keys:", summary_keys)


# EDU-QGC with Visualization

In [None]:
# make_synthetic_graph(seed=42)

In [None]:
# y = np.zeros(10, dtype=np.int64)
# y[len(y) // 2:] = 1
# y

# rng = np.random.default_rng(42)
# rng.shuffle(y)

# y