# GNN Extrapolation to 6-Qubit Circuits

This notebook evaluates the ability of the GNN model trained on 5-qubit quantum circuits to generalize to unseen 6-qubit circuits. The test includes both:

- **Zero-shot extrapolation**: direct evaluation without additional training.
- **Few-shot fine-tuning**: limited adaptation using a small subset of 6-qubit circuits.

Datasets include Class A (variational) and Class B (QAOA-like) circuits under both noiseless and noisy conditions. Performance is evaluated using KL divergence, classical fidelity, mean squared error (MSE), and Wasserstein distance.

In [1]:
# Silence every warning category for the rest of the notebook session.
# NOTE: this also hides deprecation and numerical warnings raised by the libraries used below.
import warnings
warnings.simplefilter("ignore")

In [2]:
# Imports
import os
import random
import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import Linear, Sequential, ReLU, Dropout, BatchNorm1d, LayerNorm
from torch_geometric.nn import TransformerConv, global_mean_pool, global_max_pool
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
import networkx as nx
from scipy.stats import wasserstein_distance
from collections import defaultdict

## Seeding and device

In [3]:
def set_all_seeds(seed=42):
    """Seeds Python's `random`, NumPy and PyTorch (plus all CUDA devices when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    # Only touch the CUDA generators when a CUDA runtime is actually present.
    torch.cuda.manual_seed_all(seed)

set_all_seeds(42)
# Prefer the GPU, but fall back to the CPU so the notebook still runs end to end
# on machines without CUDA (a hard-coded "cuda" device only fails later, at the
# first `.to(device)` call).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


## Metrics

In [4]:
def kl_divergence_vec(p, q, eps=1e-8):
    """
    Computes KL(p || q) for every row of a batch of distributions.

    `eps` is added to both inputs before taking logarithms so that zero
    entries do not produce infinities. Returns one value per row (sum over dim 1).
    """
    p_smooth = p + eps
    q_smooth = q + eps
    log_ratio = torch.log(p_smooth) - torch.log(q_smooth)
    return torch.sum(p_smooth * log_ratio, dim=1)


def classical_fidelity_vec(p, q, eps=1e-8):
    """
    Computes the per-row overlap sum_i sqrt(p_i) * sqrt(q_i) between two batches of distributions.

    `eps` is added to both inputs before the square roots are taken.
    """
    root_p = torch.sqrt(p + eps)
    root_q = torch.sqrt(q + eps)
    overlap = root_p * root_q
    return overlap.sum(dim=1)


def wasserstein_vec(p, q):
    """
    Computes the Wasserstein-1 distance for each pair of rows in two batches.

    The support is the integer index of each outcome (0 .. p.shape[1] - 1) and the
    rows are used as weights. Pairs where either row does not have a positive sum,
    where SciPy raises, or where the distance is not finite are skipped, so the
    result may be shorter than the batch. Returns a NumPy array; `[nan]` when no
    pair is valid.
    """
    support = np.arange(p.shape[1])
    p_rows = p.detach().cpu().numpy()
    q_rows = q.detach().cpu().numpy()

    distances = []
    for idx, p_row in enumerate(p_rows):
        if not np.sum(p_row) > 0:
            continue
        q_row = q_rows[idx]
        if not np.sum(q_row) > 0:
            continue
        try:
            dist = wasserstein_distance(support, support, p_row, q_row)
        except Exception:
            # Best effort: a pair SciPy cannot handle is dropped from the result.
            continue
        if np.isfinite(dist):
            distances.append(dist)

    if not distances:
        return np.array([np.nan])
    return np.array(distances)


def mse_vec(p, q):
    """
    Computes the per-row mean squared error between two batches of vectors.

    When the shapes differ, a 1-D tensor paired with a single-row 2-D tensor is
    promoted to shape (1, n); any other mismatch raises ValueError.
    """
    if p.shape != q.shape:
        p_is_single_row = p.ndim == 2 and p.shape[0] == 1
        q_is_single_row = q.ndim == 2 and q.shape[0] == 1
        if p_is_single_row and q.ndim == 1:
            q = q.unsqueeze(0)
        elif q_is_single_row and p.ndim == 1:
            p = p.unsqueeze(0)
        else:
            raise ValueError(f"Shape mismatch in mse_vec: p {p.shape}, q {q.shape}")
    squared_error = (p - q).pow(2)
    return squared_error.mean(dim=1)


def normalize_distribution(tensor, dim=1, eps=1e-8):
    """
    Rescales a tensor so that it sums to (approximately) one.

    A 1-D tensor is divided by its total; otherwise each slice along `dim` is
    divided by its own sum. `eps` is added to the denominator to avoid division by zero.
    """
    if tensor.dim() == 1:
        total = tensor.sum()
    else:
        total = tensor.sum(dim=dim, keepdim=True)
    return tensor / (total + eps)

## Topological Position

In [5]:
def add_topological_position_feature(data):
    """
    Appends one extra node-feature column holding each node's normalized rank
    in a topological ordering of the (directed) graph.

    Ranks are scaled to [0, 1]; a graph with at most one node gets 0.0.
    `data.x` is modified in place and the same `data` object is returned.
    """
    graph = to_networkx(data, to_undirected=False)
    ordering = list(nx.topological_sort(graph))
    last_rank = len(ordering) - 1

    position = torch.zeros((data.num_nodes, 1), dtype=torch.float32)
    if last_rank > 0:
        for rank, node in enumerate(ordering):
            position[node] = rank / last_rank
    # With zero or one node there is nothing to rank; the column stays at 0.0.

    data.x = torch.cat([data.x, position], dim=1)
    return data

## Dataset Loading

In [6]:
def load_single_dataset_6q(noise_type, circuit_class):
    """
    Loads one 6-qubit dataset file (one noise regime + circuit class).

    Every graph is tagged with `circuit_class`, `noise_regime` and `n_qubits = 6`
    and gets the topological position feature appended to its node features.
    Raises FileNotFoundError when the dataset file does not exist.
    """
    dataset_dir = os.path.join("../datasets/6-qubit", noise_type, circuit_class)
    fpath = os.path.join(dataset_dir, f"dataset_6q_{noise_type}_{circuit_class}.pt")
    if not os.path.exists(fpath):
        raise FileNotFoundError(f"Missing 6q dataset: {fpath}")

    graphs = torch.load(fpath)
    for graph in graphs:
        graph.circuit_class = circuit_class
        graph.noise_regime = noise_type
        graph.n_qubits = 6
        # Mutates the graph in place, so the list already holds the updated object.
        add_topological_position_feature(graph)
    return graphs


def load_6q_all_groups():
    """
    Loads the four 6-qubit groups (classA/classB × noiseless/noisy) and returns
    them as one list, shuffled with Python's global `random` generator.
    """
    combined = [
        graph
        for cls in ("classA", "classB")
        for noise in ("noiseless", "noisy")
        for graph in load_single_dataset_6q(noise, cls)
    ]
    random.shuffle(combined)
    return combined

## GNN MODEL

In [7]:
class GraphTransformer(torch.nn.Module):
    """
    Graph network that maps a circuit graph to a vector of `output_dim` logits.

    Architecture: `n_qubits` TransformerConv layers (each followed by BatchNorm1d,
    ReLU and dropout, with a residual connection from the second layer on), a
    LayerNorm, concatenated global mean/max pooling and a three-layer MLP head.

    Args:
        node_in_dim: number of per-node input features.
        global_u_dim: size of the per-graph vector `u` appended to every node.
        edge_attr_dim: number of edge features passed to TransformerConv.
        hidden_dim: per-head output width of each TransformerConv layer.
        output_dim: size of the model output.
        n_qubits: number of TransformerConv / BatchNorm1d layers.
        heads: number of attention heads.
        dropout: dropout probability used in the conv layers, after each block and in the MLP.
    """
    def __init__(self, node_in_dim, global_u_dim, edge_attr_dim,
                 hidden_dim, output_dim, n_qubits, heads=3, dropout=0.1):
        super().__init__()
        self.input_dim = node_in_dim + global_u_dim
        width = hidden_dim * heads  # concatenated multi-head output size

        # The first layer consumes the raw node+global features; later layers
        # consume the previous layer's `width`-dimensional output.
        conv_layers = []
        for layer_idx in range(n_qubits):
            in_channels = self.input_dim if layer_idx == 0 else width
            conv_layers.append(
                TransformerConv(in_channels, hidden_dim, heads=heads,
                                dropout=dropout, edge_dim=edge_attr_dim)
            )
        self.convs = torch.nn.ModuleList(conv_layers)
        self.bns = torch.nn.ModuleList([BatchNorm1d(width) for _ in range(n_qubits)])
        self.norm = LayerNorm(width)
        self.dropout = Dropout(dropout)
        self.relu = ReLU()

        # MLP head; its input is 2 * width because mean- and max-pooled
        # graph embeddings are concatenated.
        head_layers = [
            Linear(2 * width, 2 * width),
            ReLU(),
            Dropout(dropout),
            Linear(2 * width, width),
            ReLU(),
            Dropout(dropout),
            Linear(width, output_dim),
        ]
        self.mlp = Sequential(*head_layers)

    def forward(self, data):
        """Returns one row of `output_dim` raw outputs (no softmax applied) per graph in `data`."""
        edge_index, edge_attr, batch = data.edge_index, data.edge_attr, data.batch

        u = data.u
        if u.dim() == 1:
            # A flat `u` is reshaped to one row per graph in the batch.
            num_graphs = batch.max().item() + 1
            u = u.view(num_graphs, -1)

        # Give every node a copy of its graph's global vector.
        h = torch.cat([data.x, u[batch]], dim=1)

        for layer_idx, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            h_next = self.dropout(self.relu(bn(conv(h, edge_index, edge_attr))))
            if layer_idx > 0:
                # Residual connection; skipped for the first layer, whose
                # input and output widths differ.
                h_next = h_next + h
            h = h_next
        h = self.norm(h)

        pooled = torch.cat(
            [global_mean_pool(h, batch), global_max_pool(h, batch)],
            dim=1,
        )
        return self.mlp(pooled)

In [8]:
def build_5q_model_for_6q_output(sample, hidden_dim_2q=48, heads=3, dropout=0.1):
    """
    Builds a GraphTransformer with the 5-qubit architecture but a 64-wide output.

    Input dimensions are read from `sample` (`x`, `u`, `edge_attr`). The hidden
    width is `hidden_dim_2q` scaled by 1 + 0.75 * (n_qubits - 2) with n_qubits = 5.
    """
    n_qubits = 5  # architecture of the pretrained model (number of conv layers)
    width_scale = 1 + 0.75 * (n_qubits - 2)

    config = dict(
        node_in_dim=sample.x.shape[1],
        global_u_dim=sample.u.shape[0],
        edge_attr_dim=sample.edge_attr.shape[1],
        hidden_dim=int(hidden_dim_2q * width_scale),
        output_dim=64,  # presumably 2**6 outcome probabilities for 6 qubits
        n_qubits=n_qubits,
        heads=heads,
        dropout=dropout,
    )
    return GraphTransformer(**config)

## Load 5-qubit model weights

In [9]:
def load_5q_weights(model, path):
    """
    Loads every shape-compatible tensor from a pretrained checkpoint into `model`.

    Checkpoint entries whose name is unknown to `model` or whose shape differs
    are skipped, and the corresponding model tensors keep their current
    (initial) values. Anything skipped or left untouched is reported on stdout,
    so a partially initialised model is never used unnoticed.

    Args:
        model: the module to populate (modified in place).
        path: path to a checkpoint containing a state dict.

    Returns:
        The same `model` object.
    """
    pretrained = torch.load(path, map_location='cpu')
    model_state = model.state_dict()

    compatible, skipped = {}, []
    for key, tensor in pretrained.items():
        if key in model_state and tensor.shape == model_state[key].shape:
            compatible[key] = tensor
        else:
            skipped.append(key)

    # strict=False: only the compatible subset is provided; the returned record
    # lists the model tensors that received no value.
    result = model.load_state_dict(compatible, strict=False)

    if skipped or result.missing_keys:
        # print rather than warnings.warn: the notebook silences all warnings.
        print(
            f"load_5q_weights: loaded {len(compatible)}/{len(model_state)} tensors; "
            f"skipped checkpoint keys: {skipped}; "
            f"left at initial values: {list(result.missing_keys)}"
        )
    return model

## Zero-shot

In [10]:
def evaluate_model_zero_shot(model, data_list):
    """
    Evaluates `model` on `data_list` without any training.

    Predictions are passed through a softmax and targets are normalized before
    the metrics are computed. Batches containing NaNs in either are skipped.
    Returns a dict with the mean MSE, KL, Fidelity and Wasserstein values
    (NaN for a metric with no collected batches).
    """
    model.eval()
    per_batch = {"MSE": [], "KL": [], "Fidelity": [], "Wasserstein": []}

    with torch.no_grad():
        for batch in DataLoader(data_list, batch_size=32, shuffle=False):
            batch = batch.to(device)
            logits = model(batch)
            pred_prob = F.softmax(logits, dim=1)
            # Reshape the collated targets to one row per graph, like the logits.
            target_prob = normalize_distribution(batch.y.view(logits.shape))

            has_nan = torch.isnan(pred_prob).any() or torch.isnan(target_prob).any()
            if has_nan:
                continue

            per_batch["MSE"].append(mse_vec(pred_prob, target_prob).cpu().numpy())
            per_batch["KL"].append(kl_divergence_vec(target_prob, pred_prob).cpu().numpy())
            per_batch["Fidelity"].append(classical_fidelity_vec(target_prob, pred_prob).cpu().numpy())
            per_batch["Wasserstein"].append(wasserstein_vec(pred_prob, target_prob))

    return {
        name: float(np.mean(np.concatenate(chunks))) if chunks else np.nan
        for name, chunks in per_batch.items()
    }

In [11]:
# Load and shuffle all four 6-qubit groups
data_6q = load_6q_all_groups()
print(f"Loaded {len(data_6q)} samples.")
sample_6q = data_6q[0]  # used only to read the input feature dimensions

# Build the 5-qubit architecture, copy in the pretrained weights, move to the device
model_6q = load_5q_weights(
    build_5q_model_for_6q_output(sample_6q),
    "../models/gnn_models/5q_gnn.pt",
).to(device)

# Zero-shot evaluation on the full 6-qubit set
print("Evaluation")
zero_shot_metrics = evaluate_model_zero_shot(model_6q, data_6q)

for metric_name, metric_value in zero_shot_metrics.items():
    print(f"{metric_name}: {metric_value:.6f}")

Loaded 2000 samples.
Evaluation
MSE: 0.000866
KL: 0.896023
Fidelity: 0.757905
Wasserstein: 7.031290


## Few-Shot

In [12]:
def stratified_few_shot_split(data_list, n_per_group=100):
    """
    Splits `data_list` into few-shot train and validation sets, stratified by
    (noise_regime, circuit_class).

    Each group is shuffled with Python's global `random` generator; its first
    `n_per_group` items go to the train set and the remainder to the validation
    set. `data_list` itself is not reordered.
    """
    buckets = {}
    for graph in data_list:
        group_key = (graph.noise_regime, graph.circuit_class)
        buckets.setdefault(group_key, []).append(graph)

    train_set, val_set = [], []
    for members in buckets.values():
        random.shuffle(members)
        train_set += members[:n_per_group]
        val_set += members[n_per_group:]

    return train_set, val_set


# Number of circuits taken from each (noise regime, circuit class) group for fine-tuning;
# everything else in data_6q becomes the validation set.
FEW_SHOT_PER_GROUP = 50
few_shot_train, few_shot_val = stratified_few_shot_split(data_6q, n_per_group=FEW_SHOT_PER_GROUP)

In [13]:
def kl_loss(pred_prob, target_prob, eps=1e-8):
    """
    Returns the batch mean of KL(target || prediction).

    `eps` is added to both inputs before the logarithms are taken. The result is
    a scalar tensor suitable for back-propagation.
    """
    pred_smooth = pred_prob + eps
    target_smooth = target_prob + eps
    log_ratio = torch.log(target_smooth) - torch.log(pred_smooth)
    per_sample_kl = (target_smooth * log_ratio).sum(dim=1)
    return per_sample_kl.mean()


def _target_distribution(batch, pred):
    """Returns the batch's targets as normalized distributions shaped like `pred`."""
    # Same target handling as evaluate_model_zero_shot: reshape the collated
    # `batch.y` instead of rebuilding every graph with `to_data_list()`.
    return normalize_distribution(batch.y.view(pred.shape))


def finetune_gnn(model, train_data, val_data, epochs=20, lr=1e-4):
    """
    Fine-tunes `model` on `train_data` with a KL-divergence loss and evaluates it on `val_data`.

    Args:
        model: the GNN to adapt; its parameters are updated in place.
        train_data: list of graphs used for the gradient updates.
        val_data: list of graphs used for the per-epoch validation loss and the final predictions.
        epochs: number of passes over `train_data`.
        lr: Adam learning rate.

    Returns:
        (model, preds, targets): the same model object (left in eval mode), the
        softmax predictions on `val_data` and the normalized targets, each
        concatenated over all validation batches.
    """
    optimizer = Adam(model.parameters(), lr=lr)
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=32)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch in train_loader:
            batch = batch.to(device)
            pred = model(batch)
            pred_prob = F.softmax(pred, dim=1)
            target_prob = _target_distribution(batch, pred)

            loss = kl_loss(pred_prob, target_prob)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # Validation KL loss for this epoch
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                batch = batch.to(device)
                pred = model(batch)
                pred_prob = F.softmax(pred, dim=1)
                target_prob = _target_distribution(batch, pred)
                val_losses.append(kl_loss(pred_prob, target_prob).item())

        print(f"Epoch {epoch+1} | Train KL Loss: {np.mean(train_losses):.6f} | Val KL Loss: {np.mean(val_losses):.6f}")

    # Final predictions on the validation set. Switch to eval mode explicitly so
    # dropout is off and BatchNorm statistics stay frozen even when epochs == 0
    # and the loop above never ran.
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for batch in val_loader:
            batch = batch.to(device)
            pred = model(batch)
            all_preds.append(F.softmax(pred, dim=1))
            all_targets.append(_target_distribution(batch, pred))

    preds = torch.cat(all_preds, dim=0)
    targets = torch.cat(all_targets, dim=0)

    return model, preds, targets

In [14]:
# finetune_gnn updates the model it receives and returns that same object,
# so model_finetuned and model_6q refer to the same (now fine-tuned) network.
model_finetuned, preds, targets = finetune_gnn(model_6q, few_shot_train, few_shot_val)

# Mean of each metric over the held-out validation circuits
few_shot_metrics = {
    "MSE": mse_vec(preds, targets).mean().item(),
    "KL": kl_divergence_vec(targets, preds).mean().item(),
    "Fidelity": classical_fidelity_vec(targets, preds).mean().item(),
    "Wasserstein": float(wasserstein_vec(preds, targets).mean()),
}
print("Metrics:", few_shot_metrics)

Epoch 1 | Train KL Loss: 0.914996 | Val KL Loss: 0.883256
Epoch 2 | Train KL Loss: 0.906543 | Val KL Loss: 0.879700
Epoch 3 | Train KL Loss: 0.890341 | Val KL Loss: 0.877100
Epoch 4 | Train KL Loss: 0.906791 | Val KL Loss: 0.873458
Epoch 5 | Train KL Loss: 0.874387 | Val KL Loss: 0.869623
Epoch 6 | Train KL Loss: 0.867210 | Val KL Loss: 0.865018
Epoch 7 | Train KL Loss: 0.844165 | Val KL Loss: 0.860535
Epoch 8 | Train KL Loss: 0.913162 | Val KL Loss: 0.857326
Epoch 9 | Train KL Loss: 0.863229 | Val KL Loss: 0.855532
Epoch 10 | Train KL Loss: 0.812205 | Val KL Loss: 0.852619
Epoch 11 | Train KL Loss: 0.861643 | Val KL Loss: 0.849727
Epoch 12 | Train KL Loss: 0.867958 | Val KL Loss: 0.847630
Epoch 13 | Train KL Loss: 0.878863 | Val KL Loss: 0.847009
Epoch 14 | Train KL Loss: 0.836234 | Val KL Loss: 0.845868
Epoch 15 | Train KL Loss: 0.843282 | Val KL Loss: 0.845993
Epoch 16 | Train KL Loss: 0.846929 | Val KL Loss: 0.846275
Epoch 17 | Train KL Loss: 0.885862 | Val KL Loss: 0.847277
Epoch 