<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/parity_qnn_reupload_tuned_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pennylane torch tqdm

In [None]:
#!/usr/bin/env python3
"""
parity_qnn_reupload_tuned.py

Tuned hybrid parity QNN pipeline with:
- Increased expressivity (6 layers, dual‐axis embedding)
- Larger classical post‐net (batch norm + 64‐unit hidden layer)
- Dropout disabled by default (set --dropout > 0 to make MC‐Dropout meaningful)
- AdamW optimizer + gradient clipping + cosine annealing LR scheduler
- MC‐Dropout, temperature scaling & reliability diagram
- Robust argparse (ignores Colab/IPython flags)
"""

import argparse
import numpy as np
import pennylane as qml
from pennylane.templates.layers import StronglyEntanglingLayers
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
from torch.utils.data import TensorDataset, DataLoader

def parse_args():
    """Build the command-line interface and return the parsed options.

    Flags that the parser does not recognise (for example the ones injected
    by an IPython/Colab kernel) are silently discarded instead of causing an
    error.

    Returns:
        argparse.Namespace with the attributes ``n_qubits``, ``n_layers``,
        ``hidden_dim``, ``dropout``, ``mc_runs``, ``epochs`` and ``lr``.
    """
    # (flag, type, default, help text) — one row per supported option.
    option_table = (
        ("--n_qubits", int, 3, "Number of qubits / input bits"),
        ("--n_layers", int, 6, "Number of quantum re‐uploading layers"),
        ("--hidden_dim", int, 64, "Hidden dimension of classical MLP post‐net"),
        ("--dropout", float, 0.0, "Dropout probability in classical post‐net"),
        ("--mc_runs", int, 50, "Number of Monte Carlo dropout runs"),
        ("--epochs", int, 30, "Number of training epochs"),
        ("--lr", float, 2e-2, "Initial learning rate for AdamW"),
    )

    cli = argparse.ArgumentParser(
        description="Tuned Parity QNN with data re‐uploading"
    )
    for flag, caster, fallback, blurb in option_table:
        cli.add_argument(flag, type=caster, default=fallback, help=blurb)

    # parse_known_args returns (namespace, leftovers); the leftovers are the
    # unrecognised kernel/launcher flags we deliberately drop.
    known, _leftovers = cli.parse_known_args()
    return known

def generate_parity_dataset(n_qubits):
    """Enumerate every bitstring of width ``n_qubits`` with its parity label.

    Args:
        n_qubits: Width of each bitstring.

    Returns:
        A tuple ``(X, y)`` where ``X`` is a float32 array holding one
        bitstring per row (most significant bit first, rows in increasing
        integer order) and ``y`` is an int64 array with the sum of each row
        modulo 2.
    """
    rows = []
    for value in range(2 ** n_qubits):
        digits = np.binary_repr(value, width=n_qubits)
        rows.append([int(ch) for ch in digits])

    features = np.asarray(rows, dtype=np.float32)
    parity = np.mod(features.sum(axis=1), 2)
    return features, parity.astype(np.int64)

class HybridParityReupload(nn.Module):
    """Three-stage classifier: classical pre-net -> ``qlayer`` -> classical post-net.

    Args:
        qlayer: Callable module applied between the two classical networks.
            It receives the pre-net output (``n_qubits`` features per sample)
            and must return ``n_qubits`` features per sample for the post-net.
        n_qubits: Input feature width, also the width fed to the post-net.
        hidden_dim: Width of the post-net hidden layer.
        dropout_p: Dropout probability applied to the ``qlayer`` output.
    """

    def __init__(self, qlayer, n_qubits, hidden_dim, dropout_p):
        super().__init__()

        # Square linear map + ReLU applied before the features reach qlayer.
        encoder_stages = [nn.Linear(n_qubits, n_qubits), nn.ReLU()]
        self.pre_net = nn.Sequential(*encoder_stages)

        self.qlayer = qlayer
        self.dropout = nn.Dropout(dropout_p)

        # Batch-norm, one hidden layer, then a 2-way logit output.
        head_stages = [
            nn.BatchNorm1d(n_qubits),
            nn.Linear(n_qubits, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.post_net = nn.Sequential(*head_stages)

    def forward(self, x):
        """Return the 2-class logits for the batch ``x``."""
        quantum_features = self.qlayer(self.pre_net(x))
        return self.post_net(self.dropout(quantum_features))

def main():
    """Train, calibrate and evaluate the hybrid parity classifier end to end.

    Steps:
        1. Build the full parity dataset and split it 80/20 (stratified).
        2. Create a ``default.qubit`` device and a data re-uploading QNode.
        3. Assemble the hybrid model with AdamW, cosine LR annealing and
           cross-entropy loss.
        4. Train with gradient-norm clipping, printing per-epoch metrics.
        5. Estimate predictive variance with MC-Dropout and save a histogram
           to ``mc_variance.png``.
        6. Fit a scalar temperature on the training logits with L-BFGS.
        7. Save a reliability diagram of the temperature-scaled test
           probabilities to ``reliability_diagram.png``.
        8. Print the final test accuracy.

    All hyper-parameters come from ``parse_args()``; nothing is returned.
    """
    args = parse_args()

    # 1) Prepare data
    X, y = generate_parity_dataset(args.n_qubits)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    train_ds = TensorDataset(
        torch.tensor(X_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.int64),
    )
    test_ds = TensorDataset(
        torch.tensor(X_test, dtype=torch.float32),
        torch.tensor(y_test, dtype=torch.int64),
    )
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    # Not shuffled, so concatenated test predictions line up with y_test.
    test_loader = DataLoader(test_ds, batch_size=32)

    # 2) Quantum device & QNode
    dev = qml.device("default.qubit", wires=args.n_qubits)

    @qml.qnode(dev, interface="torch", diff_method="backprop")
    def circuit(inputs, weights):
        wires = range(args.n_qubits)
        for layer in range(args.n_layers):
            # Dual‐axis embedding for extra nonlinearity
            qml.templates.AngleEmbedding(inputs, wires=wires, rotation="Y")
            qml.templates.AngleEmbedding(inputs, wires=wires, rotation="Z")
            w = weights[layer : layer + 1]  # shape (1, n_qubits, 3)
            StronglyEntanglingLayers(w, wires=wires)
        # final embedding boost
        qml.templates.AngleEmbedding(inputs, wires=wires, rotation="Y")
        return [qml.expval(qml.PauliZ(i)) for i in wires]

    weight_shapes = {"weights": (args.n_layers, args.n_qubits, 3)}
    qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)

    # 3) Model, optimizer, loss, scheduler
    model = HybridParityReupload(
        qlayer, args.n_qubits, args.hidden_dim, args.dropout
    )
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    criterion = nn.CrossEntropyLoss()

    # 4) Training loop with gradient clipping
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Weight the batch-mean loss by batch size for a per-sample average.
            total_loss += loss.item() * xb.size(0)
            correct += (logits.argmax(1) == yb).sum().item()

        scheduler.step()
        train_acc = correct / len(train_loader.dataset)

        model.eval()
        correct = 0
        with torch.no_grad():
            for xb, yb in test_loader:
                correct += (model(xb).argmax(1) == yb).sum().item()
        test_acc = correct / len(test_loader.dataset)

        print(
            f"Epoch {epoch:2d} | "
            f"Loss {total_loss/len(train_loader.dataset):.4f} | "
            f"Train {train_acc:.3f} | Test {test_acc:.3f}"
        )

    # 5) MC‐Dropout uncertainty (dropout=0 → zero variance here)
    # Only the dropout modules are switched to training mode. Calling
    # model.train() would also put BatchNorm1d in training mode, which
    # overwrites its running statistics with test-set batches (even under
    # no_grad) and raises on a batch of size 1.
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    all_probs = []
    with torch.no_grad():
        for _ in range(args.mc_runs):
            batch_ps = []
            for xb, _ in test_loader:
                ps = torch.softmax(model(xb), dim=1).cpu().numpy()
                batch_ps.append(ps)
            all_probs.append(np.vstack(batch_ps))
    # Variance across MC runs, averaged over the two classes per sample.
    var_est = np.stack(all_probs).var(axis=0).mean(axis=1)
    plt.figure()
    plt.hist(var_est, bins=20)
    plt.title("MC‐Dropout Variance")
    plt.savefig("mc_variance.png")
    print("Saved MC‐Dropout variance histogram")

    # 6) Temperature scaling on training set
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for xb, yb in train_loader:
            logits_list.append(model(xb))
            labels_list.append(yb)
    logits_stack = torch.cat(logits_list)
    labels_stack = torch.cat(labels_list)

    # Optimise log(T) so the temperature stays strictly positive.
    log_T = torch.zeros(1, requires_grad=True)
    temp_optimizer = optim.LBFGS([log_T], lr=0.1, max_iter=50)

    def temperature_closure():
        # LBFGS needs the closure to populate gradients itself; without the
        # backward() call it sees a zero gradient and stops immediately.
        temp_optimizer.zero_grad()
        nll = criterion(logits_stack / log_T.exp(), labels_stack)
        nll.backward()
        return nll

    temp_optimizer.step(temperature_closure)
    T = log_T.detach().exp()
    print(f"Optimal temperature T = {T.item():.3f}")

    # 7) Reliability diagram on test set
    model.eval()
    with torch.no_grad():  # inference only; no autograd graph needed
        logits_test = torch.cat([model(xb) for xb, _ in test_loader])
    probs_test = (
        torch.softmax(logits_test / T, dim=1)
        .detach()
        .cpu()
        .numpy()[:, 1]
    )
    frac_pos, mean_pred = calibration_curve(y_test, probs_test, n_bins=10)
    plt.figure()
    plt.plot(mean_pred, frac_pos, "s-", label="Model")
    plt.plot([0, 1], [0, 1], "--", label="Ideal")
    plt.xlabel("Mean predicted probability")
    plt.ylabel("Fraction of positives")
    plt.title("Reliability Diagram")
    plt.legend()
    plt.savefig("reliability_diagram.png")
    print("Saved reliability diagram")

    # 8) Final test accuracy
    final_acc = (logits_test.argmax(1).cpu().numpy() == y_test).mean()
    print(f"Final Test Accuracy: {final_acc:.3f}")

# Run the full pipeline only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    main()