# CNN Extrapolation to 6-Qubit Circuits
 
This notebook evaluates the ability of the CNN model trained on 5 qubit quantum circuits to generalize to unseen 6-qubit circuits. The test includes both:

- **Zero-shot extrapolation**: direct evaluation without additional training.
- **Few-shot fine-tuning**: limited adaptation using a small subset of 6-qubit circuits.

Datasets include Class A (variational) and Class B (QAOA-like) circuits under both noiseless and noisy conditions. Performance is evaluated using KL divergence, classical fidelity, mean squared error (MSE), and Wasserstein.

In [1]:
# Suppress warnings
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Imports
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from scipy.stats import wasserstein_distance
from collections import defaultdict

In [3]:
# Reproducibility and device setup
def set_all_seeds(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_all_seeds(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


## Metrics

In [4]:
def kl_divergence_vec(p, q, eps=1e-8):
    """Computes element-wise KL divergence between two probability distributions p and q."""
    p = p + eps
    q = q + eps
    return (p * (p.log() - q.log())).sum(dim=1)


def classical_fidelity_vec(p, q, eps=1e-8):
    """Computes classical fidelity between two distributions as the inner product of square roots."""
    p = p + eps
    q = q + eps
    return (p.sqrt() * q.sqrt()).sum(dim=1)


def wasserstein_vec(p, q):
    """Computes Wasserstein-1 distance for each pair of distributions in a batch, skipping invalid samples."""
    x = np.arange(p.shape[1])
    p_np = p.detach().cpu().numpy()
    q_np = q.detach().cpu().numpy()
    results = []
    for i in range(p_np.shape[0]):
        if np.sum(p_np[i]) > 0 and np.sum(q_np[i]) > 0:
            try:
                d = wasserstein_distance(x, x, p_np[i], q_np[i])
                if np.isfinite(d):
                    results.append(d)
            except Exception:
                continue
    return np.array(results) if len(results) > 0 else np.array([np.nan])


def mse_vec(p, q):
    """Computes mean squared error between two batches of vectors with shape correction."""
    if p.shape != q.shape:
        if p.ndim == 2 and p.shape[0] == 1 and q.ndim == 1:
            q = q.unsqueeze(0)
        elif q.ndim == 2 and q.shape[0] == 1 and p.ndim == 1:
            p = p.unsqueeze(0)
        else:
            raise ValueError(f"Shape mismatch in mse_vec: p {p.shape}, q {q.shape}")
    return ((p - q) ** 2).mean(dim=1)


def normalize_distribution(tensor, dim=1, eps=1e-8):
    """Normalizes a tensor along the specified dimension to form a probability distribution."""
    if tensor.dim() == 1:
        return tensor / (tensor.sum() + eps)
    return tensor / (tensor.sum(dim=dim, keepdim=True) + eps)

## Dataset Loading

In [5]:
def load_single_dataset_6q(noise_type, circuit_class):
    """
    Loads a single 6-qubit dataset (class + noise group).
    """
    fname = f"dataset_6q_{noise_type}_{circuit_class}.pt"
    fpath = os.path.join("../datasets/6-qubit", noise_type, circuit_class, fname)
    if not os.path.exists(fpath):
        raise FileNotFoundError(f"Missing 6q dataset: {fpath}")
    data_list = torch.load(fpath)
    for g in data_list:
        g.circuit_class = circuit_class
        g.noise_regime = noise_type
        g.n_qubits = 6
    return data_list


def load_6q_all_groups():
    """
    Loads all 4 groups (A/B × noisy/noiseless) of 6-qubit circuit datasets.
    Returns a combined list.
    """
    all_data = []
    for cls in ["classA", "classB"]:
        for noise in ["noiseless", "noisy"]:
            group = load_single_dataset_6q(noise, cls)
            all_data.extend(group)
    random.shuffle(all_data)
    return all_data

## Sequence Conversion for CNN Input

In [6]:
def convert_graph_to_sequence(graph_data, max_len=40):
    """
    Converts a graph-based quantum circuit into a fixed-length sequence of gate feature vectors.

    Each gate in the circuit is represented by a feature vector (from graph_data.x).
    """
    x = graph_data.x
    n_gates, feat_dim = x.shape
    pad_len = max(0, max_len - n_gates)
    if pad_len > 0:
        padding = torch.zeros((pad_len, feat_dim), dtype=x.dtype)
        x_padded = torch.cat([x, padding], dim=0)
    else:
        x_padded = x[:max_len]
    return x_padded, graph_data.u, graph_data.y


class SequenceCircuitDataset(torch.utils.data.Dataset):
    """
    A dataset that transforms a list of quantum circuit graphs into sequences suitable for 1D CNN input.

    Each circuit is represented as a padded sequence of gate-level feature vectors,
    along with global features and target probability distributions.
    """
    def __init__(self, data_list, max_len=40):
        self.data_list = data_list
        self.max_len = max_len

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        g = self.data_list[idx]
        sequence, global_feat, target = convert_graph_to_sequence(g, self.max_len)
        return {
            'sequence': sequence,
            'global': global_feat,
            'target': target,
            'circuit_class': getattr(g, 'circuit_class', 'unknown'),
            'noise_regime': getattr(g, 'noise_regime', 'unknown'),
            'n_qubits': g.n_qubits
        }


def collate_sequence(batch):
    """
    Collates a batch of sequence-encoded circuit samples into batched tensors.
    """
    sequences = torch.stack([item['sequence'] for item in batch])
    global_feats = torch.stack([item['global'] for item in batch])
    targets = torch.stack([item['target'] for item in batch])

    return {
        'sequence': sequences,
        'global': global_feats,
        'target': targets
    }

## CNN Model

In [7]:
class SequenceCNN(nn.Module):
    """
    CNN model for predicting output distributions of quantum circuits.
    Applies a 1D convolution over gate-level sequences, combines with global features,
    and passes through a shallow MLP head.
    """
    def __init__(self, input_dim, global_dim, output_dim, hidden_2q=58, n_qubits=2):
        super().__init__()
        scale = (n_qubits - 1) * 1.3
        hidden = int(hidden_2q * scale)

        self.conv = nn.Conv1d(input_dim, hidden, kernel_size=5, padding=2)
        self.global_proj = nn.Linear(global_dim, hidden)
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden, 4 * hidden),
            nn.ReLU(),
            nn.Linear(4 * hidden, 2 * hidden),
            nn.ReLU(),
            nn.Linear(2 * hidden, output_dim)
        )

    def forward(self, sequence_data, global_features):
        x = sequence_data.permute(0, 2, 1)
        x = self.conv(x).mean(dim=2)
        u_proj = self.global_proj(global_features)
        return self.mlp(torch.cat([x, u_proj], dim=1))

In [8]:
def build_cnn_for_6q(sample, hidden_dim_2q=64):
    """
    Builds CNN model with 6q-compatible output size and scaled hidden size.
    """
    input_dim = sample.x.shape[1]
    global_dim = sample.u.shape[0]
    output_dim = 64
    hidden_dim = int(hidden_dim_2q * (1 + 0.75 * (5 - 2)))

    return SequenceCNN(input_dim=input_dim,
                       output_dim=output_dim,
                       global_dim=global_dim)


def load_cnn_weights(model, path):
    """
    Loads CNN model weights from 5q model, interpolating if needed.
    """
    state_dict = torch.load(path, map_location='cpu')
    model_state = model.state_dict()
    for k in model_state:
        if k in state_dict and model_state[k].shape == state_dict[k].shape:
            model_state[k] = state_dict[k]
    model.load_state_dict(model_state)
    return model

## Zero-Shot

In [9]:
# Load and transform 6-qubit dataset
data_6q = load_6q_all_groups()
print(f"Loaded {len(data_6q)} samples.")
cnn_dataset = SequenceCircuitDataset(data_6q)
cnn_loader = DataLoader(cnn_dataset, batch_size=32, shuffle=False, collate_fn=collate_sequence)

# Build and load CNN model
sample_6q = data_6q[0]
cnn_model = build_cnn_for_6q(sample_6q)
cnn_model = load_cnn_weights(cnn_model, "../models/cnn_models/5q_cnn.pt")
cnn_model.to(device)

Loaded 2000 samples.


SequenceCNN(
  (conv): Conv1d(30, 75, kernel_size=(5,), stride=(1,), padding=(2,))
  (global_proj): Linear(in_features=9, out_features=75, bias=True)
  (mlp): Sequential(
    (0): Linear(in_features=150, out_features=300, bias=True)
    (1): ReLU()
    (2): Linear(in_features=300, out_features=150, bias=True)
    (3): ReLU()
    (4): Linear(in_features=150, out_features=64, bias=True)
  )
)

In [10]:
def evaluate_cnn_zero_shot(model, loader):
    model.eval()
    all_mse, all_kl, all_fi, all_wass = [], [], [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            pred = model(batch['sequence'], batch['global'])
            pred_prob = F.softmax(pred, dim=1)
            target_prob = normalize_distribution(batch['target'])

            if torch.isnan(pred_prob).any() or torch.isnan(target_prob).any():
                continue

            all_mse.append(mse_vec(pred_prob, target_prob).cpu().numpy())
            all_kl.append(kl_divergence_vec(target_prob, pred_prob).cpu().numpy())
            all_fi.append(classical_fidelity_vec(target_prob, pred_prob).cpu().numpy())
            all_wass.append(wasserstein_vec(pred_prob, target_prob))

    return {
        "MSE": float(np.mean(np.concatenate(all_mse))) if all_mse else np.nan,
        "KL": float(np.mean(np.concatenate(all_kl))) if all_kl else np.nan,
        "Fidelity": float(np.mean(np.concatenate(all_fi))) if all_fi else np.nan,
        "Wasserstein": float(np.mean(np.concatenate(all_wass))) if all_wass else np.nan,
    }


# Run zero-shot
print("Zero-shot evaluation")
zero_metrics = evaluate_cnn_zero_shot(cnn_model, cnn_loader)
for k, v in zero_metrics.items():
    print(f"{k}: {v:.6f}")

Zero-shot evaluation
MSE: 0.000864
KL: 0.891041
Fidelity: 0.758899
Wasserstein: 7.008146


## Few-Shot

In [11]:
def stratified_few_shot_split(data_list, n_per_group=50):
    """
    Creates a stratified few-shot train/val split across circuit class and noise regime.
    """
    group_buckets = defaultdict(list)
    for g in data_list:
        key = (getattr(g, 'noise_regime', 'unknown'), getattr(g, 'circuit_class', 'unknown'))
        group_buckets[key].append(g)

    train_set, val_set = [], []
    for key, items in group_buckets.items():
        random.shuffle(items)
        train_set.extend(items[:n_per_group])
        val_set.extend(items[n_per_group:])

    return train_set, val_set


few_shot_train_graphs, few_shot_val_graphs = stratified_few_shot_split(data_6q, n_per_group=50)

train_dataset = SequenceCircuitDataset(few_shot_train_graphs)
val_dataset = SequenceCircuitDataset(few_shot_val_graphs)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_sequence)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_sequence)

In [12]:
def kl_loss(pred_prob, target_prob, eps=1e-8):
    """Computes mean KL divergence loss between predicted and target probability distributions."""
    pred_prob = pred_prob + eps
    target_prob = target_prob + eps
    return torch.mean(torch.sum(target_prob * (target_prob.log() - pred_prob.log()), dim=1))


def finetune_cnn(model, train_loader, val_loader, epochs=20, lr=1e-4):
    """Performs few-shot fine-tuning of the CNN using KL divergence loss."""
    model.train()
    optimizer = Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        for batch in train_loader:
            seq = batch['sequence'].to(device)
            glob = batch['global'].to(device)
            target = batch['target'].to(device)

            pred = model(seq, glob)
            pred_prob = F.softmax(pred, dim=1)
            target_prob = normalize_distribution(target)

            loss = kl_loss(pred_prob, target_prob)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

        # Validation
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                seq = batch['sequence'].to(device)
                glob = batch['global'].to(device)
                target = batch['target'].to(device)

                pred = model(seq, glob)
                pred_prob = F.softmax(pred, dim=1)
                target_prob = normalize_distribution(target)

                val_loss = kl_loss(pred_prob, target_prob)
                val_losses.append(val_loss.item())

        print(f"Epoch {epoch+1} | Train KL Loss: {np.mean(train_losses):.6f} | Val KL Loss: {np.mean(val_losses):.6f}")

    return model

In [13]:
model_finetuned = finetune_cnn(cnn_model, train_loader, val_loader)

model_finetuned.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for batch in val_loader:
        seq = batch['sequence'].to(device)
        glob = batch['global'].to(device)
        target = batch['target'].to(device)

        pred = model_finetuned(seq, glob)
        pred_prob = F.softmax(pred, dim=1)
        target_prob = normalize_distribution(target)

        all_preds.append(pred_prob)
        all_targets.append(target_prob)

preds = torch.cat(all_preds, dim=0)
targets = torch.cat(all_targets, dim=0)

few_shot_metrics = {
    "MSE": float(mse_vec(preds, targets).mean().item()),
    "KL": float(kl_divergence_vec(targets, preds).mean().item()),
    "Fidelity": float(classical_fidelity_vec(targets, preds).mean().item()),
    "Wasserstein": float(wasserstein_vec(preds, targets).mean().item())
}

print("Few-shot Evaluation Metrics:")
for k, v in few_shot_metrics.items():
    print(f"{k}: {v:.6f}")

Epoch 1 | Train KL Loss: 0.876760 | Val KL Loss: 0.881790
Epoch 2 | Train KL Loss: 0.943856 | Val KL Loss: 0.880293
Epoch 3 | Train KL Loss: 0.865485 | Val KL Loss: 0.878945
Epoch 4 | Train KL Loss: 0.927782 | Val KL Loss: 0.877659
Epoch 5 | Train KL Loss: 0.884384 | Val KL Loss: 0.876356
Epoch 6 | Train KL Loss: 0.856771 | Val KL Loss: 0.874962
Epoch 7 | Train KL Loss: 0.870564 | Val KL Loss: 0.873521
Epoch 8 | Train KL Loss: 0.935998 | Val KL Loss: 0.871899
Epoch 9 | Train KL Loss: 0.859464 | Val KL Loss: 0.870308
Epoch 10 | Train KL Loss: 0.879475 | Val KL Loss: 0.868730
Epoch 11 | Train KL Loss: 0.845983 | Val KL Loss: 0.867163
Epoch 12 | Train KL Loss: 0.872125 | Val KL Loss: 0.865835
Epoch 13 | Train KL Loss: 0.879678 | Val KL Loss: 0.864694
Epoch 14 | Train KL Loss: 0.841939 | Val KL Loss: 0.863759
Epoch 15 | Train KL Loss: 0.883417 | Val KL Loss: 0.862919
Epoch 16 | Train KL Loss: 0.863109 | Val KL Loss: 0.862319
Epoch 17 | Train KL Loss: 0.868518 | Val KL Loss: 0.862033
Epoch 