<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/parity_qnn_pipeline_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pennylane torch tqdm scikit-learn matplotlib

In [None]:
#!/usr/bin/env python3

import argparse
import numpy as np
import pennylane as qml
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve
from torch.optim.lr_scheduler import StepLR

# 1. CLI Arg Parsing (ignores Jupyter’s extra flags)
def parse_args(argv=None):
    """Parse command-line options for the parity-classification pipeline.

    Unrecognised arguments are collected and discarded instead of causing an
    error, so extra flags injected into ``sys.argv`` by a notebook kernel do
    not abort the run.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, which is the original
            behaviour. Passing an explicit list makes the function usable
            programmatically and in tests.

    Returns:
        argparse.Namespace with the attributes ``n_qubits``, ``epochs``,
        ``batch_size``, ``lr``, ``shots`` and ``mc_runs``.
    """
    parser = argparse.ArgumentParser(
        description="n-Qubit Parity Classification with MC-Dropout & Calibration"
    )
    parser.add_argument("--n_qubits",   type=int,   default=3,   help="Number of qubits")
    parser.add_argument("--epochs",     type=int,   default=15,  help="Training epochs")
    parser.add_argument("--batch_size", type=int,   default=16,  help="Batch size")
    parser.add_argument("--lr",         type=float, default=1e-2,help="Learning rate")
    parser.add_argument("--shots",      type=int,   default=500, help="Measurement shots")
    parser.add_argument("--mc_runs",    type=int,   default=50,  help="MC-Dropout runs")
    # parse_known_args returns (namespace, leftovers); the leftovers are
    # deliberately ignored so unknown flags are tolerated.
    args, _unknown = parser.parse_known_args(argv)
    return args

# 2. Parity dataset generation
def generate_parity_dataset(n_qubits):
    """Enumerate every ``n_qubits``-bit string together with its parity.

    Args:
        n_qubits: Width of the bit strings; all ``2**n_qubits`` values are
            generated in ascending integer order, most significant bit first.

    Returns:
        A tuple ``(X, y)`` where ``X`` is a float32 array holding one bit
        string per row and ``y`` is an int64 array with the parity
        (sum of bits modulo 2) of each row.
    """
    bit_strings = (
        np.binary_repr(state, width=n_qubits)
        for state in range(2**n_qubits)
    )
    # One row per basis state, one column per bit character.
    features = np.array(
        [[int(bit) for bit in bits] for bits in bit_strings],
        dtype=np.float32,
    )
    parity = features.sum(axis=1) % 2
    labels = parity.astype(np.int64)
    return features, labels

# 3. Model definitions
class HybridParityQNN(nn.Module):
    """Hybrid classifier: linear pre-processing -> quantum layer -> small MLP.

    The quantum layer is called once per sample and is expected to yield a
    single value per sample. Dropout is applied to that value so that
    MC-Dropout sampling is possible by running the model in train mode.

    Args:
        qlayer: Callable (typically an ``nn.Module``) mapping a 1-D tensor of
            ``n_qubits`` features to one value, returned either as a 0-dim
            tensor or as a one-element tensor.
        n_qubits: Number of input features (and of pre-processing outputs).
        dropout_rate: Drop probability passed to ``nn.Dropout``.
    """

    def __init__(self, qlayer, n_qubits, dropout_rate=0.1):
        super().__init__()
        # classical pre-processing
        self.pre_fc = nn.Sequential(
            nn.Linear(n_qubits, n_qubits),
            nn.ReLU(),
        )
        self.qlayer  = qlayer
        self.dropout = nn.Dropout(dropout_rate)
        # classical post-processing: 1 quantum feature -> 2 class logits
        self.post_fc = nn.Sequential(
            nn.Linear(1, 4),
            nn.ReLU(),
            nn.Linear(4, 2),
        )

    def forward(self, x):
        """Return class logits of shape ``[batch, 2]`` for ``x`` of shape ``[batch, n_qubits]``."""
        x = self.pre_fc(x)                  # -> [batch, n_qubits]
        # manual batch: run the quantum layer once per sample
        q_outs = torch.stack([self.qlayer(sample) for sample in x])
        # Normalise to exactly one feature per sample. An explicit reshape
        # accepts both 0-dim and one-element per-sample outputs and raises if
        # the layer returns more than one value, instead of silently building
        # a [batch, 1, 1] tensor.
        q_outs = q_outs.reshape(x.shape[0], 1)
        # Quantum simulators may hand back float64; match the classical
        # layers' dtype so post_fc does not fail on a dtype mismatch.
        q_outs = q_outs.to(x.dtype)
        dropped = self.dropout(q_outs)      # MC-Dropout
        logits  = self.post_fc(dropped)     # -> [batch, 2]
        return logits

class TemperatureScaler(nn.Module):
    """Divide logits by a single learnable temperature.

    The temperature is a one-element parameter initialised to 1, so a freshly
    constructed scaler returns its input unchanged.
    """

    def __init__(self):
        super().__init__()
        initial_temperature = torch.ones(1)
        self.temperature = nn.Parameter(initial_temperature)

    def forward(self, logits):
        """Return ``logits`` scaled by ``1 / temperature``."""
        scaled = torch.div(logits, self.temperature)
        return scaled

def optimize_temperature(logits, labels, criterion):
    """Fit a softmax temperature that minimises ``criterion`` on given logits.

    The temperature ``T`` is found with L-BFGS by minimising
    ``criterion(logits / T, labels)``, starting from ``T = 1``.

    The optimisation variable is ``log(T)`` rather than ``T`` itself, so the
    returned temperature is always strictly positive. An unconstrained ``T``
    can be driven to zero or below (for example when every sample is already
    classified correctly and the loss keeps falling as ``T`` shrinks), which
    produces inf/NaN logits or inverts the predicted ranking.

    Args:
        logits: Tensor of uncalibrated class scores, ``[n_samples, n_classes]``.
            It is detached internally; no gradient flows back to its producer.
        labels: Target tensor in whatever form ``criterion`` expects.
        criterion: Callable ``criterion(scaled_logits, labels)`` returning a
            scalar loss tensor (e.g. ``nn.CrossEntropyLoss()``).

    Returns:
        The fitted temperature as a Python float in ``[1e-3, 1e3]``.
    """
    import math  # local import: only this function needs it

    # L-BFGS calls the closure repeatedly; detaching keeps each backward pass
    # confined to the temperature and avoids "backward through the graph a
    # second time" errors when the logits still carry an autograd graph.
    logits = logits.detach()

    # Bounds on log(T). Clamping in log-space (before exp) keeps both the
    # value and its gradient finite even if a quasi-Newton step overshoots.
    log_low, log_high = math.log(1e-3), math.log(1e3)

    # log(T) = 0  <=>  T = 1, the same starting point as an unscaled model.
    log_temperature = torch.zeros(
        1, dtype=logits.dtype, device=logits.device, requires_grad=True
    )
    opt = optim.LBFGS([log_temperature], lr=0.1, max_iter=50)

    def closure():
        opt.zero_grad()
        temperature = log_temperature.clamp(log_low, log_high).exp()
        loss = criterion(logits / temperature, labels)
        loss.backward()
        return loss

    opt.step(closure)
    return log_temperature.detach().clamp(log_low, log_high).exp().item()

# 4. Main pipeline
def main():
    """Run the full parity-classification experiment end to end.

    Steps, in order:
      1. Parse options and build a stratified 80/20 train/test split of the
         complete parity dataset.
      2. Build a finite-shot PennyLane device and a QNode with per-wire
         depolarizing noise, wrap it in a TorchLayer and embed it in
         ``HybridParityQNN``.
      3. Train with Adam + StepLR, printing loss and accuracies every epoch.
      4. Estimate MC-Dropout predictive variance on the test set and save a
         histogram to ``mc_dropout_variance.png``.
      5. Fit a softmax temperature, apply it to the test logits, and save a
         reliability diagram to ``reliability_diagram.png``.
      6. Print the final test accuracy.

    Side effects: prints progress to stdout and writes the two PNG files
    into the current working directory.
    """
    args = parse_args()

    # Prepare data loaders
    X, y = generate_parity_dataset(args.n_qubits)
    # Fixed random_state keeps the split reproducible; stratify=y keeps the
    # even/odd parity ratio the same in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    train_ds = torch.utils.data.TensorDataset(
        torch.tensor(X_train), torch.tensor(y_train)
    )
    test_ds  = torch.utils.data.TensorDataset(
        torch.tensor(X_test),  torch.tensor(y_test)
    )
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=args.batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=args.batch_size, shuffle=False
    )

    # Build quantum device (one wire per qubit, finite number of shots)
    dev = qml.device(
        "default.mixed",
        wires=args.n_qubits,
        shots=args.shots
    )

    # Noisy QNode with two-layer ansatz
    @qml.qnode(dev, interface="torch")
    def circuit(inputs, weights1, weights2):
        """Encode ``inputs`` as RY angles, apply the ansatz, measure Z on wire 0."""
        # encoding + per-wire depolarizing noise
        for i in range(args.n_qubits):
            qml.RY(inputs[i] * np.pi, wires=i)
            qml.DepolarizingChannel(0.02, wires=i)
        # two layers of entangling ansatz
        qml.templates.StronglyEntanglingLayers(weights1, wires=range(args.n_qubits))
        qml.templates.StronglyEntanglingLayers(weights2, wires=range(args.n_qubits))
        return qml.expval(qml.PauliZ(0))

    # Shapes of the trainable circuit weights handed to TorchLayer; each
    # entry corresponds to one StronglyEntanglingLayers call above.
    weight_shapes = {
        "weights1": (1, args.n_qubits, 3),
        "weights2": (1, args.n_qubits, 3),
    }
    qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)

    # Instantiate hybrid model
    model     = HybridParityQNN(qlayer, args.n_qubits, dropout_rate=0.1)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # Halve the learning rate every 5 epochs.
    scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
    criterion = nn.CrossEntropyLoss()

    # Training & eval funcs
    def train_epoch():
        """Run one pass over the training set; return (mean loss, accuracy)."""
        model.train()
        total_loss, total_correct = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss   = criterion(logits, yb)
            loss.backward()
            optimizer.step()
            # Weight the batch-mean loss by batch size so the epoch average
            # is per-sample even if the last batch is smaller.
            total_loss    += loss.item() * xb.size(0)
            total_correct += (logits.argmax(1) == yb).sum().item()
        # The scheduler is stepped once per epoch, not once per batch.
        scheduler.step()
        return total_loss / len(train_ds), total_correct / len(train_ds)

    def eval_acc(loader):
        """Return the model's accuracy over ``loader`` in eval mode (dropout off)."""
        model.eval()
        correct = 0
        with torch.no_grad():
            for xb, yb in loader:
                correct += (model(xb).argmax(1) == yb).sum().item()
        return correct / len(loader.dataset)

    # Train loop
    print(f"Training {args.n_qubits}-qubit QNN for {args.epochs} epochs")
    for epoch in range(1, args.epochs + 1):
        loss, tr_acc = train_epoch()
        te_acc = eval_acc(test_loader)
        print(
            f"Epoch {epoch:2d} | Loss: {loss:.4f} | "
            f"Train Acc: {tr_acc:.3f} | Test Acc: {te_acc:.3f}"
        )

    # MC-Dropout uncertainty
    def mc_dropout_preds(X_tensor, mc_runs):
        """Sample ``mc_runs`` stochastic forward passes with dropout active.

        Returns a tuple ``(mean_probs, variance)``: class probabilities
        averaged over the runs, shape [batch, 2], and the per-sample variance
        across runs averaged over the two classes, shape [batch].
        """
        # train() (not eval()) is intentional: it keeps dropout enabled so
        # repeated passes differ. The model stays in train mode on return.
        model.train()
        probs = []
        with torch.no_grad():
            for _ in range(mc_runs):
                logits = model(X_tensor)
                probs.append(torch.softmax(logits, dim=1).cpu().numpy())
        stack = np.stack(probs)            # [mc_runs, batch, 2]
        return stack.mean(axis=0), stack.var(axis=0).mean(axis=-1)

    X_test_tensor = torch.tensor(X_test)
    # Only the variance estimate is used; the mean probabilities are discarded.
    _, var_est = mc_dropout_preds(X_test_tensor, args.mc_runs)

    plt.hist(var_est, bins=20)
    plt.title("MC-Dropout Predictive Variance")
    plt.xlabel("Variance")
    plt.ylabel("Count")
    plt.savefig("mc_dropout_variance.png")
    print("Saved MC-Dropout variance histogram")

    # Temperature scaling
    # Switch back to eval mode (mc_dropout_preds left the model in train mode).
    model.eval()
    # NOTE(review): the temperature is fitted on the training-set logits; no
    # separate validation split is created in this function.
    logits_tr, labels_tr = [], []
    with torch.no_grad():
        for xb, yb in train_loader:
            logits_tr.append(model(xb))
            labels_tr.append(yb)
    logits_tr = torch.cat(logits_tr)
    labels_tr = torch.cat(labels_tr)

    opt_temp = optimize_temperature(logits_tr, labels_tr, criterion)
    print(f"Optimal temperature: {opt_temp:.3f}")

    # Apply to test set
    logits_te = []
    with torch.no_grad():
        for xb, _ in test_loader:
            logits_te.append(model(xb))
    logits_te = torch.cat(logits_te)
    probs_te  = torch.softmax(logits_te / opt_temp, dim=1).cpu().numpy()

    # Reliability diagram
    # Column 1 holds the probability assigned to class 1 (odd parity).
    frac_pos, mean_pred = calibration_curve(y_test, probs_te[:, 1], n_bins=10)
    # New figure so the reliability plot is not drawn over the histogram.
    plt.figure()
    plt.plot(mean_pred, frac_pos, "s-", label="Model")
    plt.plot([0, 1], [0, 1], "--", label="Ideal")
    plt.xlabel("Mean Predicted Value")
    plt.ylabel("Fraction of Positives")
    plt.title("Reliability Diagram")
    plt.legend()
    plt.savefig("reliability_diagram.png")
    print("Saved reliability diagram")

    # Final test accuracy
    # test_loader is not shuffled, so logits_te rows line up with y_test.
    final_acc = (logits_te.argmax(1).numpy() == y_test).mean()
    print(f"Final Test Accuracy: {final_acc:.3f}")

# Run the pipeline only when this file is executed directly (script or
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()