<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/hybrid_parity_full_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pennylane torch tqdm

In [None]:
#!/usr/bin/env python3
"""
hybrid_parity_full.py

A self-contained script that trains and evaluates an n-bit parity classifier
using a hybrid quantum-classical model with:

  1. MC-Dropout for uncertainty quantification
  2. Temperature scaling + reliability diagram calibration
  3. Noise injection via default.mixed device
  4. Scaling to arbitrary n-qubit parity via CLI args

Unknown args (like Jupyter’s “-f …json”) are silently ignored.
"""

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import pennylane as qml
from pennylane import numpy as np

import matplotlib.pyplot as plt


def parse_args(argv=None):
    """Parse CLI options for the hybrid parity experiment.

    Uses ``parse_known_args`` so unrecognised flags (e.g. Jupyter's
    ``-f <kernel>.json``) are silently ignored instead of aborting.

    Args:
        argv: list of argument strings; ``None`` (default) reads
            ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: via ``parser.error`` when a numeric option is out of
            range, so bad values fail immediately with a usage message
            instead of deep inside torch/PennyLane.
    """
    p = argparse.ArgumentParser(
        description="Train a hybrid quantum-classical n-bit parity classifier."
    )
    p.add_argument("--n_qubits", type=int, default=4,
                   help="number of input bits / qubits")
    p.add_argument("--n_layers", type=int, default=6,
                   help="number of BasicEntanglerLayers in the circuit")
    p.add_argument("--hidden_dim", type=int, default=32,
                   help="width of the classical hidden layers")
    p.add_argument("--dropout_p", type=float, default=0.1,
                   help="dropout probability (also used for MC-Dropout)")
    p.add_argument("--mc_samples", type=int, default=20,
                   help="number of MC-Dropout forward passes; 0 disables")
    p.add_argument(
        "--noise_model",
        type=str,
        choices=["none", "depolarizing", "amplitude_damping"],
        default="none",
        help="single-qubit noise channel applied in simulation",
    )
    p.add_argument("--noise_prob", type=float, default=0.01,
                   help="noise channel strength in [0, 1]")
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--lr", type=float, default=0.005,
                   help="Adam learning rate")
    p.add_argument("--calibrate", action="store_true",
                   help="fit temperature scaling and save a reliability diagram")
    p.add_argument("--output_dir", type=str, default="results/",
                   help="directory for saved figures")

    # parse_known_args(): ignore any extra flags (e.g., Jupyter's "-f")
    args, _unknown = p.parse_known_args(argv)

    range_checks = (
        (args.n_qubits >= 1, "--n_qubits must be >= 1"),
        (args.n_layers >= 1, "--n_layers must be >= 1"),
        (args.hidden_dim >= 1, "--hidden_dim must be >= 1"),
        (0.0 <= args.dropout_p <= 1.0, "--dropout_p must be in [0, 1]"),
        (args.mc_samples >= 0, "--mc_samples must be >= 0"),
        (0.0 <= args.noise_prob <= 1.0, "--noise_prob must be in [0, 1]"),
        (args.batch_size >= 1, "--batch_size must be >= 1"),
        (args.epochs >= 0, "--epochs must be >= 0"),
        (args.lr > 0.0, "--lr must be > 0"),
    )
    for ok, message in range_checks:
        if not ok:
            p.error(message)
    return args


def generate_parity_data(n_samples, n_qubits):
    """Sample random +/-1 bit-strings and label each with its parity.

    Args:
        n_samples: number of rows to draw.
        n_qubits: number of bits per row.

    Returns:
        (features, targets): ``features`` is a float32 tensor of shape
        [n_samples, n_qubits] with entries in {-1, +1}; ``targets`` is an
        int64 tensor of shape [n_samples] equal to 1 when the row has an odd
        number of +1 entries and 0 otherwise.
    """
    bits = np.random.randint(0, 2, size=(n_samples, n_qubits))
    spins = bits * 2 - 1  # {0, 1} -> {-1, +1}
    parity = (spins == 1).sum(axis=1) % 2
    features = torch.tensor(spins, dtype=torch.float32)
    targets = torch.tensor(parity.astype(int), dtype=torch.long)
    return features, targets


def create_device(n_qubits, noise_model, noise_prob):
    """Create the PennyLane device that executes the variational circuit.

    Args:
        n_qubits: number of wires.
        noise_model: ``"none"``, ``"depolarizing"`` or ``"amplitude_damping"``.
        noise_prob: strength of the single-qubit noise channel (ignored for
            ``"none"``).

    Returns:
        A noiseless ``default.qubit`` device for ``"none"``; otherwise a
        ``default.mixed`` (density-matrix) device transformed with
        ``qml.transforms.insert`` so the selected channel is applied after
        the circuit's gates.

    Raises:
        ValueError: if ``noise_model`` is not one of the supported names.
    """
    if noise_model == "none":
        return qml.device("default.qubit", wires=n_qubits)

    channels = {
        "depolarizing": qml.DepolarizingChannel,
        "amplitude_damping": qml.AmplitudeDamping,
    }
    try:
        channel = channels[noise_model]
    except KeyError:
        raise ValueError(
            f"Unknown noise_model {noise_model!r}; expected 'none' or one of "
            f"{sorted(channels)}"
        ) from None

    # default.mixed has no `noise`/`prob` constructor options (passing them
    # raises TypeError), so noise has to be added as channel operations.
    # Mixed-state simulation is required for these non-unitary channels.
    dev = qml.device("default.mixed", wires=n_qubits)
    return qml.transforms.insert(dev, channel, noise_prob, position="all")


def create_qnode(dev, n_qubits, n_layers):
    """Build the variational circuit as a Torch-interface QNode.

    The circuit angle-encodes each input feature as an RX rotation, applies
    ``BasicEntanglerLayers`` with the trainable ``weights`` and measures
    the Pauli-Z expectation on every wire.

    Args:
        dev: PennyLane device that executes the circuit.
        n_qubits: number of wires / input features.
        n_layers: kept for interface compatibility; the circuit depth is
            taken from the first dimension of ``weights``.

    Returns:
        QNode mapping ``(inputs, weights)`` to a list of ``n_qubits``
        expectation values, differentiated by backprop through torch.
    """
    wires = range(n_qubits)

    def circuit(inputs, weights):
        qml.templates.AngleEmbedding(inputs, wires=wires, rotation="X")
        qml.templates.BasicEntanglerLayers(weights, wires=wires)
        return [qml.expval(qml.PauliZ(w)) for w in wires]

    return qml.qnode(dev, interface="torch", diff_method="backprop")(circuit)


class HybridParityModel(nn.Module):
    """Hybrid classifier: variational quantum layer + residual MLP head.

    Data flow for an input batch ``x`` with ``n_qubits`` features per row:
    quantum layer -> Linear/ReLU -> Dropout -> residual Linear/ReLU ->
    Linear to 2 class logits.

    Args:
        n_qubits: number of wires; also the quantum layer's output width.
        n_layers: depth of the entangling circuit (first dim of its weights).
        hidden_dim: width of the classical hidden layers.
        dropout_p: dropout probability after the first hidden layer.
        dev: PennyLane device executing the circuit.
    """

    def __init__(self, n_qubits, n_layers, hidden_dim, dropout_p, dev):
        super().__init__()
        circuit = create_qnode(dev, n_qubits, n_layers)
        # Keep this construction order: it fixes parameter registration
        # order (seen by the optimizer) and the RNG draws made at init.
        self.qlayer = qml.qnn.TorchLayer(circuit, {"weights": (n_layers, n_qubits)})
        self.res_fc1 = nn.Linear(n_qubits, hidden_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.res_fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.classifier = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        """Return unnormalised class logits with 2 columns."""
        hidden = self.dropout(F.relu(self.res_fc1(self.qlayer(x))))
        # Out-of-place residual connection around the second hidden layer.
        return self.classifier(hidden + F.relu(self.res_fc2(hidden)))


def temperature_scaling(logits, labels, device):
    """Fit a single softmax temperature by minimising NLL on held-out data.

    The temperature is optimised in log-space, ``T = exp(log_T)``, so it
    stays strictly positive. A raw ``T`` parameter can be pushed to zero or
    below by LBFGS, which divides by zero or flips every prediction.

    Args:
        logits: [N, C] uncalibrated logits. They are detached here, so the
            network that produced them is never back-propagated into (and
            the repeatedly-called LBFGS closure never reuses a freed graph).
        labels: [N] integer class targets.
        device: torch device on which the temperature parameter is created.

    Returns:
        float: the fitted temperature ``T > 0``; calibrated logits are
        ``logits / T``.
    """
    logits = logits.detach()
    log_T = torch.zeros(1, requires_grad=True, device=device)  # start at T = 1
    nll = nn.CrossEntropyLoss()
    opt = optim.LBFGS(
        [log_T], lr=0.1, max_iter=50, line_search_fn="strong_wolfe"
    )

    def closure():
        opt.zero_grad()
        # logits * exp(-log_T) == logits / T, but its gradient stays finite
        # even when T becomes very large.
        loss = nll(logits * torch.exp(-log_T), labels)
        loss.backward()
        return loss

    opt.step(closure)
    return log_T.exp().item()


def plot_reliability_diagram(probs, labels, n_bins, save_path):
    """Save a reliability diagram (per-bin accuracy and mean confidence).

    Confidence is the max class probability and the prediction is its
    argmax. Confidences are grouped into ``n_bins`` equal-width bins on
    [0, 1]. Every bin is half-open ``[lo, hi)`` except the last, which is
    closed so that a confidence of exactly 1.0 is still counted. Empty bins
    are recorded as NaN and appear as gaps in the plot. Plotting them at 0
    would look like badly miscalibrated bins; for binary problems every bin
    below 0.5 is always empty.

    Args:
        probs: [N, C] tensor of class probabilities.
        labels: [N] tensor of true class indices.
        n_bins: number of equal-width confidence bins.
        save_path: path of the image file to write.
    """
    confs = probs.max(1).values.detach().cpu().numpy()
    preds = probs.argmax(1).detach().cpu().numpy()
    truths = labels.detach().cpu().numpy()

    edges = np.linspace(0, 1, n_bins + 1)
    bin_centers = (edges[:-1] + edges[1:]) / 2
    accs, avg_conf = [], []

    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        below_hi = (confs <= hi) if i == n_bins - 1 else (confs < hi)
        mask = (confs >= lo) & below_hi
        if mask.sum() > 0:
            accs.append(float((preds[mask] == truths[mask]).mean()))
            avg_conf.append(float(confs[mask].mean()))
        else:
            accs.append(float("nan"))
            avg_conf.append(float("nan"))

    plt.figure(figsize=(5, 5))
    plt.plot(bin_centers, avg_conf, 's-', label="Confidence")
    plt.plot(bin_centers, accs, 'o-', label="Accuracy")
    plt.plot([0, 1], [0, 1], '--', color='gray', label="Ideal")
    plt.xlabel("Confidence")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def _accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` gets right.

    Switches ``model`` to eval mode (dropout off) and runs without autograd.
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for Xb, yb in loader:
            correct += (model(Xb).argmax(1) == yb).sum().item()
    return correct / len(loader.dataset)


def _train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Returns:
        (mean training loss, training accuracy) over the whole dataset.
    """
    model.train()
    running_loss, correct = 0.0, 0
    for Xb, yb in loader:
        optimizer.zero_grad()
        logits = model(Xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; re-weight by batch size so the
        # epoch average is exact even when the last batch is smaller.
        running_loss += loss.item() * Xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
    n = len(loader.dataset)
    return running_loss / n, correct / n


def _mc_dropout_probs(model, loader, n_samples):
    """Return softmax outputs of ``n_samples`` stochastic forward passes.

    ``model.train()`` keeps dropout active at inference time. ``loader``
    must not shuffle, so that every pass lists samples in the same order.
    The result is stacked along a new leading axis: [n_samples, N, C].
    """
    model.train()
    passes = []
    with torch.no_grad():
        for _ in range(n_samples):
            batch_probs = [F.softmax(model(Xb), dim=1) for Xb, _yb in loader]
            passes.append(torch.cat(batch_probs, dim=0))
    return torch.stack(passes)


def main():
    """Train and evaluate the hybrid parity classifier.

    Steps:
      1. Train, printing loss / train accuracy / test accuracy per epoch.
      2. If ``--mc_samples > 0``: MC-Dropout accuracy and, for at least two
         samples, a predictive-variance histogram.
      3. If ``--calibrate``: fit a temperature on a held-out calibration set
         and save a reliability diagram for the test set.
      4. Print the final (deterministic, eval-mode) test accuracy.
    Figures are written to ``--output_dir``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # A PennyLane device, not a torch device: tensors here all stay on CPU.
    qdevice = create_device(args.n_qubits, args.noise_model, args.noise_prob)

    model = HybridParityModel(
        n_qubits=args.n_qubits,
        n_layers=args.n_layers,
        hidden_dim=args.hidden_dim,
        dropout_p=args.dropout_p,
        dev=qdevice,
    )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    # datasets: train / calibrate / test (independent random draws)
    X_train, y_train = generate_parity_data(2000, args.n_qubits)
    X_calib, y_calib = generate_parity_data(500, args.n_qubits)
    X_test, y_test = generate_parity_data(500, args.n_qubits)

    train_loader = DataLoader(
        TensorDataset(X_train, y_train),
        batch_size=args.batch_size, shuffle=True
    )
    # Not shuffled: MC-Dropout below relies on outputs lining up with y_test.
    test_loader = DataLoader(
        TensorDataset(X_test, y_test),
        batch_size=args.batch_size
    )

    # 1) Train
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, optimizer, criterion
        )
        test_acc = _accuracy(model, test_loader)
        print(
            f"Epoch {epoch:2d} | Loss: {train_loss:.4f} | "
            f"Train Acc: {train_acc:.3f} | Test Acc: {test_acc:.3f}"
        )

    # 2) MC-Dropout
    if args.mc_samples > 0:
        mc_stack = _mc_dropout_probs(model, test_loader, args.mc_samples)  # [mc, N, 2]
        mean_p = mc_stack.mean(0)
        mc_acc = (mean_p.argmax(1) == y_test).float().mean().item()
        print(f"MC-Dropout mean-prediction accuracy: {mc_acc:.3f}")

        if args.mc_samples < 2:
            # The unbiased variance of a single sample is NaN, so a
            # histogram of it would be meaningless.
            print("Skipping MC-Dropout variance histogram (needs --mc_samples >= 2)")
        else:
            var_p = mc_stack.var(0).mean(1)
            plt.figure()
            plt.hist(var_p.cpu().numpy(), bins=20)
            plt.xlabel("Predictive Variance")
            plt.ylabel("Count")
            plt.tight_layout()
            plt.savefig(os.path.join(args.output_dir, "mc_variance.png"))
            plt.close()
            print("Saved MC-Dropout variance histogram")

    # 3) Calibration
    if args.calibrate:
        model.eval()
        with torch.no_grad():
            logits_calib = model(X_calib)
        T_opt = temperature_scaling(logits_calib, y_calib, logits_calib.device)
        print(f"Fitted temperature: {T_opt:.3f}")

        with torch.no_grad():
            logits_test = model(X_test) / T_opt
            probs_test = F.softmax(logits_test, dim=1)
        plot_reliability_diagram(
            probs_test, y_test, n_bins=10,
            save_path=os.path.join(args.output_dir, "reliability.png")
        )
        print("Saved reliability diagram")

    # 4) Final Test Acc (temperature scaling does not change the argmax)
    print(f"Final Test Accuracy: {_accuracy(model, test_loader):.3f}")


if __name__ == "__main__":
    # Script entry point. In a notebook cell __name__ is also "__main__", so
    # main() runs there too; parse_args() ignores the extra kernel argv.
    main()