<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/parity_qnn_reupload_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pennylane torch tqdm

In [None]:
#!/usr/bin/env python3

import argparse
import numpy as np
import pennylane as qml

# Try EfficientSU2, otherwise fall back
try:
    from pennylane.templates.layers import EfficientSU2 as AnsatzLayer
    ANSATZ_NAME = "EfficientSU2"
    HAS_EFFICIENT = True
except ImportError:
    from pennylane.templates.layers import StronglyEntanglingLayers as AnsatzLayer
    ANSATZ_NAME = "StronglyEntanglingLayers"
    HAS_EFFICIENT = False

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.calibration import calibration_curve

def parse_args(argv=None):
    """Parse command-line options for the parity QNN experiment.

    Unknown arguments are ignored (``parse_known_args``) so the script can also
    run inside Jupyter/Colab, where the kernel passes extra flags such as
    ``-f <connection-file>``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``n_qubits``, ``epochs``, ``batch_size``,
        ``lr``, ``mc_runs`` and ``reps``.
    """
    parser = argparse.ArgumentParser(description="Parity QNN (analytic, backprop)")
    parser.add_argument("--n_qubits", type=int, default=3,
                        help="number of qubits (= number of input bits of the parity task)")
    parser.add_argument("--epochs", type=int, default=15,
                        help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="mini-batch size for training and evaluation")
    parser.add_argument("--lr", type=float, default=1e-2,
                        help="Adam learning rate")
    parser.add_argument("--mc_runs", type=int, default=50,
                        help="number of stochastic forward passes for MC-dropout")
    parser.add_argument("--reps", type=int, default=2,
                        help="number of ansatz layers (leading dimension of the weight tensor)")
    # parse_known_args: silently drop flags we do not define (e.g. notebook kernels).
    args, _ = parser.parse_known_args(argv)
    return args

def generate_parity_dataset(n_qubits):
    """Build the complete parity truth table for ``n_qubits``-bit inputs.

    Row ``i`` of the feature matrix is the binary expansion of ``i`` (most
    significant bit first), zero-padded to ``n_qubits`` digits. Its label is
    1 when the row contains an odd number of ones and 0 otherwise.

    Returns:
        (features, labels): a float32 array with ``2**n_qubits`` rows and an
        int64 label vector of matching length.
    """
    bit_rows = []
    for idx in range(2 ** n_qubits):
        digits = format(idx, f"0{n_qubits}b")
        bit_rows.append([int(ch) for ch in digits])
    features = np.asarray(bit_rows, dtype=np.float32)
    labels = np.mod(features.sum(axis=1), 2)
    return features, labels.astype(np.int64)

class HybridQNN(nn.Module):
    """Classical -> quantum -> classical model producing 2-class logits.

    Stages, applied in order by ``forward``:
      * ``pre``  -- Linear(n_qubits, n_qubits) followed by ReLU; its output is
        fed to the quantum layer.
      * ``q``    -- the supplied quantum module (e.g. a ``qml.qnn.TorchLayer``).
      * ``drop`` -- Dropout on the quantum layer's output.
      * ``post`` -- Linear(n_qubits, hidden_dim), ReLU, Linear(hidden_dim, 2).

    Args:
        qlayer: module applied between ``pre`` and ``post``. It is assumed to
            map (batch, n_qubits) -> (batch, n_qubits), because ``post``
            expects ``n_qubits`` input features -- confirm for custom layers.
        n_qubits: width of the classical input and of the quantum layer.
        hidden_dim: width of the hidden layer in the classical head.
        dropout: dropout probability applied to the quantum outputs.
    """

    def __init__(self, qlayer, n_qubits, hidden_dim=4, dropout=0.1):
        super().__init__()
        # Creation order (pre, q, drop, post) fixes parameter-initialisation
        # order under a seed and the state_dict key layout; keep it.
        encoder_layers = [nn.Linear(n_qubits, n_qubits), nn.ReLU()]
        self.pre = nn.Sequential(*encoder_layers)
        self.q = qlayer
        self.drop = nn.Dropout(dropout)
        self.post = nn.Sequential(
            nn.Linear(n_qubits, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, x):
        """Return unnormalised logits of shape (batch, 2) for input ``x``."""
        angles = self.pre(x)
        quantum_features = self.q(angles)
        regularised = self.drop(quantum_features)
        return self.post(regularised)

def _accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    Runs under ``torch.no_grad``; the caller chooses train/eval mode.
    """
    correct = 0
    with torch.no_grad():
        for xb, yb in loader:
            correct += (model(xb).argmax(1) == yb).sum().item()
    return correct / len(loader.dataset)


def _mc_dropout_variance(model, loader, n_runs):
    """Return per-sample MC-dropout variance of the softmax outputs.

    Puts ``model`` in train mode so dropout stays active. It then performs
    ``n_runs`` stochastic passes over ``loader`` and returns, for every
    sample, the variance across runs averaged over the classes. The result
    is a 1-D array in loader order.
    """
    model.train()
    runs = []
    with torch.no_grad():
        for _ in range(n_runs):
            batch_probs = [torch.softmax(model(xb), dim=1).cpu().numpy()
                           for xb, _ in loader]
            runs.append(np.vstack(batch_probs))
    return np.stack(runs).var(axis=0).mean(axis=1)


def _fit_temperature(logits, labels, lr=0.1, max_iter=50):
    """Fit a scalar temperature ``T > 0`` minimising cross-entropy of ``logits / T``.

    ``T`` is optimised as ``exp(log_t)`` so it can never reach zero or go
    negative during the search. Returns ``T`` as a detached 1-element tensor.
    """
    log_t = torch.zeros(1, requires_grad=True)  # T starts at exp(0) = 1
    optimizer = optim.LBFGS([log_t], lr=lr, max_iter=max_iter)
    nll = nn.CrossEntropyLoss()

    def closure():
        # LBFGS reads the gradient from log_t.grad after calling the closure.
        # Without zero_grad()/backward() that gradient is zero, LBFGS stops on
        # its first iteration and T silently stays at 1.
        optimizer.zero_grad()
        loss = nll(logits / log_t.exp(), labels)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.detach().exp()


def main():
    """Train the hybrid parity QNN, then report uncertainty and calibration.

    Pipeline:
      1. Build the n-bit parity truth table and split it 80/20, stratified by label.
      2. Train ``HybridQNN`` (classical pre-layer -> PennyLane circuit ->
         classical head) with Adam and cross-entropy, printing per-epoch
         loss and train/test accuracy.
      3. Estimate per-sample MC-dropout variance on the test set and save a
         histogram to ``mc_variance.png``.
      4. Fit a softmax temperature ``T`` and save a reliability diagram of the
         temperature-scaled test probabilities to ``reliability_diagram.png``.
      5. Print the final test accuracy.
    """
    args = parse_args()
    print(f"Using ansatz: {ANSATZ_NAME}")

    # --- Data -------------------------------------------------------------
    X, y = generate_parity_dataset(args.n_qubits)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    train_ds = torch.utils.data.TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
    test_ds = torch.utils.data.TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    # Not shuffled: test predictions are compared position-by-position with y_test.
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=args.batch_size)

    # --- Quantum circuit ----------------------------------------------------
    # Analytic (shot-free) simulator, so backprop differentiation is available.
    dev = qml.device("default.qubit", wires=args.n_qubits)

    @qml.qnode(dev, interface="torch", diff_method="backprop")
    def circuit(inputs, weights):
        # TorchLayer requires the data argument to be named ``inputs``.
        qml.templates.AngleEmbedding(inputs, wires=range(args.n_qubits), rotation="Y")
        if HAS_EFFICIENT:
            # NOTE(review): PennyLane does not appear to ship an EfficientSU2
            # template (the kwargs mirror Qiskit's), so this branch is likely
            # never taken -- confirm before relying on it.
            AnsatzLayer(weights, wires=range(args.n_qubits),
                        reps=args.reps, entanglement="full")
        else:
            AnsatzLayer(weights, wires=range(args.n_qubits))
        return [qml.expval(qml.PauliZ(i)) for i in range(args.n_qubits)]

    # Weight-tensor shape expected by the selected ansatz.
    if HAS_EFFICIENT:
        weight_shapes = {"weights": (args.reps, args.n_qubits)}
    else:
        # StronglyEntanglingLayers: (layers, wires, 3 rotation angles).
        weight_shapes = {"weights": (args.reps, args.n_qubits, 3)}

    # --- Hybrid model -------------------------------------------------------
    qlayer = qml.qnn.TorchLayer(circuit, weight_shapes)
    model = HybridQNN(qlayer, args.n_qubits)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.CrossEntropyLoss()

    # --- Training -----------------------------------------------------------
    n_train = len(train_loader.dataset)
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct = 0.0, 0
        for xb, yb in train_loader:
            optimizer.zero_grad()
            logits = model(xb)
            loss = loss_fn(logits, yb)
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * xb.size(0)
            correct += (logits.argmax(1) == yb).sum().item()
        train_acc = correct / n_train

        model.eval()
        test_acc = _accuracy(model, test_loader)

        print(f"Epoch {epoch:2d} | Loss {total_loss/n_train:.4f}"
              f" | Train {train_acc:.3f} | Test {test_acc:.3f}")

    # --- MC-Dropout uncertainty ----------------------------------------------
    var_probs = _mc_dropout_variance(model, test_loader, args.mc_runs)
    plt.figure()  # own figure, so nothing previously plotted leaks into it
    plt.hist(var_probs, bins=20)
    plt.title("MC-Dropout Variance")
    plt.savefig("mc_variance.png")
    print("Saved MC variance histogram")

    # --- Temperature scaling ---------------------------------------------------
    # NOTE(review): T is fitted on the training data. On a tiny truth table
    # the model may fit those samples perfectly, which pushes T towards 0
    # (over-confidence). A held-out calibration split would be preferable if
    # 2**n_qubits is large enough to afford one.
    model.eval()
    logits_list, labels_list = [], []
    with torch.no_grad():
        for xb, yb in train_loader:
            logits_list.append(model(xb))
            labels_list.append(yb)
    T = _fit_temperature(torch.cat(logits_list), torch.cat(labels_list))
    print(f"Optimal temperature T = {T.item():.3f}")

    # --- Reliability diagram on the test set -----------------------------------
    model.eval()
    with torch.no_grad():
        logits_test = torch.cat([model(xb) for xb, _ in test_loader])
        probs_test = torch.softmax(logits_test / T, dim=1).cpu().numpy()[:, 1]
    frac_pos, mean_pred = calibration_curve(y_test, probs_test, n_bins=10)

    plt.figure()
    plt.plot(mean_pred, frac_pos, "s-")
    plt.plot([0, 1], [0, 1], "--")
    plt.xlabel("Mean predicted probability")
    plt.ylabel("Fraction of positives")
    plt.title("Reliability Diagram")
    plt.savefig("reliability_diagram.png")
    print("Saved reliability diagram")

    # --- Final test accuracy (temperature does not change the argmax) ---------
    final_acc = (logits_test.argmax(1).cpu().numpy() == y_test).mean()
    print(f"Final Test Accuracy: {final_acc:.3f}")

if __name__ == "__main__":
    # Entry point: run the full train / uncertainty / calibration pipeline.
    main()