In [None]:
#!pip install pennylane

In [None]:
"""
Fixed Quantum MNIST (binarized): QNN, QCNN, HQNN

- Data from: qml.data.load("other", name="binarized-mnist")
- 50,000 training and 10,000 test samples; each input is a length-784 bitstring
  (a flattened 28x28 binarized image)

Design:
- QNN: 10 qubits, 4 layers, stronger downscale MLP
- QCNN: 8 qubits, 2 conv+pool levels, 4 quantum outputs, larger classical head
- HQNN: CNN feature extractor + dropout + 8-qubit, 3-layer quantum core
  (its own QNode on the 8-qubit device it shares with the QCNN)
"""

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pennylane as qml
import matplotlib.pyplot as plt

# ----------------------------
# Global config
# ----------------------------
n_classes = 10

# QNN config
n_qubits_qnn = 10
n_layers_qnn = 4

# QCNN + HQNN config
n_qubits_small = 8
n_layers_small = 3

batch_size = 128
epochs = 20
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG before any layer is built. This makes the following
# reproducible on Restart & Run All:
# - classical and quantum (TorchLayer) weight initialisation
# - dropout masks
# - DataLoader shuffling
SEED = 42
torch.manual_seed(SEED)

# choose "qnn", "qcnn", or "hqnn"
MODEL_CHOICE = "hqnn"

In [None]:
# ----------------------------
# 1. Load binarized MNIST via qml.data
# ----------------------------
print("Loading binarized MNIST from PennyLane...")
# Single-element unpacking: fails loudly if load() returns more or fewer datasets.
ds, = qml.data.load("other", name="binarized-mnist")


def _split_to_tensors(split):
    """Turn one dataset split into (float32 features, int64 labels) tensors.

    Reads the split's "inputs" entry first and its "labels" entry second.
    """
    features = torch.tensor(split["inputs"], dtype=torch.float32)
    labels = torch.tensor(split["labels"], dtype=torch.long)
    return features, labels


x_train, y_train = _split_to_tensors(ds.train)   # (50000, 784), (50000,)
x_test, y_test = _split_to_tensors(ds.test)      # (10000, 784), (10000,)

train_dataset, test_dataset = TensorDataset(x_train, y_train), TensorDataset(x_test, y_test)

# Training batches are shuffled and the ragged last batch is dropped.
# Evaluation keeps the natural order and every sample.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"Train shape: {x_train.shape}, Test shape: {x_test.shape}")

Loading binarized MNIST from PennyLane...
Train shape: torch.Size([50000, 784]), Test shape: torch.Size([10000, 784])


In [None]:
# ----------------------------
# 2. Classical helper modules
# ----------------------------
class BetterDownscale(nn.Module):
    """
    Classical pre-processor: flat 784-pixel image -> ``n_qubits`` features.

    Pipeline:
      reshape to (B, 1, 28, 28)
      -> adaptive average pool to 7x7
      -> flatten to 49 values
      -> Linear(49, hidden) + ReLU
      -> Linear(hidden, n_qubits)

    The output is left un-activated. Shared by the QNN and the QCNN, which
    ask for different ``n_qubits``.

    Args:
        n_qubits: number of output features (one per qubit).
        hidden: width of the intermediate layer.
    """
    def __init__(self, n_qubits, hidden=64):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc1 = nn.Linear(7 * 7, hidden)
        self.fc2 = nn.Linear(hidden, n_qubits)
        self.act = nn.ReLU()

    def forward(self, x):
        """Map a (B, 784) batch to (B, n_qubits)."""
        batch = x.size(0)
        pooled = self.pool(x.view(batch, 1, 28, 28)).view(batch, -1)   # (B, 49)
        return self.fc2(self.act(self.fc1(pooled)))


class SmallCNN(nn.Module):
    """
    Two-stage convolutional feature extractor used by the HQNN.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2), with
    channels 1 -> 16 -> 32. The spatial size shrinks 28 -> 14 -> 7, and a
    final Linear maps the 32*7*7 activations to ``latent_dim`` features.

    Takes flat (B, 784) inputs and reshapes them to images internally.
    """
    def __init__(self, latent_dim=32):
        super().__init__()
        stages = []
        for c_in, c_out in ((1, 16), (16, 32)):
            stages += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),   # halves height and width
            ]
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(32 * 7 * 7, latent_dim)

    def forward(self, x):
        """Map a (B, 784) batch to (B, latent_dim)."""
        batch = x.size(0)
        feature_maps = self.conv(x.view(batch, 1, 28, 28))   # (B, 32, 7, 7)
        return self.fc(feature_maps.view(batch, -1))

In [None]:
# ----------------------------
# 3. Devices
# ----------------------------
# Two state-vector simulators:
# - dev_qnn: n_qubits_qnn wires, used by the standalone QNN only
# - dev_small: n_qubits_small wires, shared by the QCNN and HQNN QNodes
dev_qnn, dev_small = (
    qml.device("default.qubit", wires=n_wires)
    for n_wires in (n_qubits_qnn, n_qubits_small)
)

In [None]:
# ----------------------------
# 4. QNN circuit (10 qubits, 4 layers)
# ----------------------------
@qml.qnode(dev_qnn, interface="torch")
def qnn_circuit(inputs, weights):
    """
    Variational circuit of the standalone QNN, on ``n_qubits_qnn`` wires.

    1. Encode ``inputs`` as RY rotation angles (AngleEmbedding).
    2. Apply trainable StronglyEntanglingLayers with ``weights`` of shape
       ``(n_layers_qnn, n_qubits_qnn, 3)`` (see ``qnn_shapes``).
    3. Return the Pauli-Z expectation value of every wire, as a list.
    """
    register = range(n_qubits_qnn)
    qml.AngleEmbedding(inputs, wires=register, rotation="Y")
    qml.StronglyEntanglingLayers(weights, wires=register)
    return [qml.expval(qml.PauliZ(wire)) for wire in register]


qnn_shapes = dict(weights=(n_layers_qnn, n_qubits_qnn, 3))
QNNLayerBig = qml.qnn.TorchLayer(qnn_circuit, qnn_shapes)

In [None]:
# ----------------------------
# 5. HQNN circuit (8 qubits, 3 layers)
# ----------------------------
@qml.qnode(dev_small, interface="torch")
def hqnn_circuit(inputs, weights):
    """
    Quantum core of the HQNN, on the ``n_qubits_small``-wire device.

    Encodes ``inputs`` as RY angles. Then applies StronglyEntanglingLayers with
    ``weights`` of shape ``(n_layers_small, n_qubits_small, 3)``. Returns the
    list of per-wire Pauli-Z expectation values.
    """
    register = range(n_qubits_small)
    qml.AngleEmbedding(inputs, wires=register, rotation="Y")
    qml.StronglyEntanglingLayers(weights, wires=register)
    return [qml.expval(qml.PauliZ(wire)) for wire in register]


hqnn_shapes = dict(weights=(n_layers_small, n_qubits_small, 3))
QNNLayerSmall = qml.qnn.TorchLayer(hqnn_circuit, hqnn_shapes)

In [None]:
# ----------------------------
# 6. QCNN circuit (8 qubits, 4 outputs)
# ----------------------------
def conv_block(params, wires):
    """
    Two-qubit QCNN "convolution" unit (consumes 3 parameters).

    Between two CNOTs controlled by ``wires[0]``, apply:
      - RY(params[0]) on ``wires[0]``
      - RY(params[1]) then RZ(params[2]) on ``wires[1]``
    """
    control, target = wires
    entangler = [control, target]

    qml.CNOT(entangler)
    qml.RY(params[0], control)
    qml.RY(params[1], target)
    qml.RZ(params[2], target)
    qml.CNOT(entangler)

def pooling_block(params, wires):
    """
    Two-qubit QCNN "pooling" unit (consumes 2 parameters).

    Keeps ``wires[0]`` and discards ``wires[1]``. Applies CNOT(kept ->
    discarded), then RY(params[0]) and RZ(params[1]) on the kept wire.

    NOTE(review): in qcnn_circuit the kept wire is measured in Z immediately
    after the last pooling level. At that level the trailing RZ cannot change
    the expectation value, so those parameters receive zero gradient.
    """
    keep, drop = wires
    qml.CNOT([keep, drop])
    qml.RY(params[0], keep)
    qml.RZ(params[1], keep)

@qml.qnode(dev_small, interface="torch")
def qcnn_circuit(inputs, weights):
    """
    Two-level QCNN on the 8-wire ``dev_small`` device.

    Layout:
      - RY angle embedding of ``inputs`` on all ``n_qubits_small`` wires
      - level 1: conv on (0,1), (2,3), (4,5), (6,7); then pool each pair
      - level 2: conv on (0,2), (4,6); then pool each pair
      - readout: Pauli-Z expectation on wires 0, 2, 4, 6 (4 outputs)

    ``weights`` is read as one flat vector, front to back:
    3 parameters per conv block and 2 per pooling block.
    Total: 4*3 + 4*2 + 2*3 + 2*2 = 30 (matches ``qcnn_shapes``).
    """
    qml.AngleEmbedding(inputs, wires=range(n_qubits_small), rotation="Y")

    flat = weights.reshape(-1)
    cursor = 0

    def take(count):
        # Hand out the next `count` parameters and advance the read position.
        nonlocal cursor
        chunk = flat[cursor:cursor + count]
        cursor += count
        return chunk

    levels = (
        [(0, 1), (2, 3), (4, 5), (6, 7)],
        [(0, 2), (4, 6)],
    )
    for pairs in levels:
        # Within a level, all conv blocks run before any pooling block.
        for pair in pairs:
            conv_block(take(3), pair)
        for pair in pairs:
            pooling_block(take(2), pair)

    return [qml.expval(qml.PauliZ(wire)) for wire in (0, 2, 4, 6)]


qcnn_shapes = {"weights": (30,)}
QCNNLayer = qml.qnn.TorchLayer(qcnn_circuit, qcnn_shapes)

In [None]:
# ----------------------------
# 7. Models: QNN, QCNN, HQNN
# ----------------------------
class QNNModel(nn.Module):
    """
    Standalone QNN classifier.

    Pipeline:
      bitstring (784) -> BetterDownscale(n_qubits_qnn) -> tanh
      -> quantum layer (qnn_circuit) -> Linear(n_qubits_qnn -> n_classes)

    Every instance builds its own ``TorchLayer``. Quantum weights are
    therefore never shared between model instances, and a model re-created
    in the same kernel starts from freshly initialised quantum weights.
    """
    def __init__(self):
        super().__init__()
        self.pre = BetterDownscale(n_qubits_qnn)
        # The module-level QNNLayerBig is a single shared object. Registering it
        # here would leak trained weights into every later QNNModel instance.
        self.q = qml.qnn.TorchLayer(qnn_circuit, qnn_shapes)
        self.fc = nn.Linear(n_qubits_qnn, n_classes)

    def forward(self, x):
        x = self.pre(x)              # (B, n_qubits_qnn)
        x = torch.tanh(x)            # bound features to [-1, 1] before RY angle embedding
        x = self.q(x)                # (B, n_qubits_qnn) Pauli-Z expectations
        return self.fc(x)            # (B, n_classes)


class QCNNModel(nn.Module):
    """
    QCNN classifier.

    Pipeline:
      bitstring (784) -> BetterDownscale(n_qubits_small) -> tanh
      -> QCNN layer (qcnn_circuit, 4 outputs)
      -> MLP head (4 -> 32 -> n_classes)

    Every instance builds its own ``TorchLayer``. Quantum weights are
    therefore never shared between model instances, and a model re-created
    in the same kernel starts from freshly initialised quantum weights.
    """
    def __init__(self):
        super().__init__()
        self.pre = BetterDownscale(n_qubits_small)
        # The module-level QCNNLayer is a single shared object. Registering it
        # here would leak trained weights into every later QCNNModel instance.
        self.q = qml.qnn.TorchLayer(qcnn_circuit, qcnn_shapes)
        self.fc = nn.Sequential(
            nn.Linear(4, 32),        # 4 = expectation values returned by qcnn_circuit
            nn.ReLU(),
            nn.Linear(32, n_classes),
        )

    def forward(self, x):
        x = self.pre(x)              # (B, n_qubits_small)
        x = torch.tanh(x)            # bound features to [-1, 1] before RY angle embedding
        x = self.q(x)                # (B, 4) Pauli-Z expectations on wires 0, 2, 4, 6
        return self.fc(x)            # (B, n_classes)


class HQNNModel(nn.Module):
    """
    HQNN (hybrid classical-quantum classifier).

    Pipeline:
      bitstring (784) -> SmallCNN -> Linear(latent_dim -> n_qubits_small)
      -> Dropout -> tanh -> quantum layer (hqnn_circuit)
      -> Linear(n_qubits_small -> n_classes)

    Args:
        latent_dim: width of the SmallCNN feature vector.
        p_dropout: dropout probability applied just before the quantum layer.

    Every instance builds its own ``TorchLayer``. Quantum weights are
    therefore never shared between model instances, and a model re-created
    in the same kernel starts from freshly initialised quantum weights.
    """
    def __init__(self, latent_dim=32, p_dropout=0.3):
        super().__init__()
        self.cnn = SmallCNN(latent_dim)
        self.to_q = nn.Linear(latent_dim, n_qubits_small)
        self.dropout = nn.Dropout(p_dropout)
        # The module-level QNNLayerSmall is a single shared object. Registering
        # it here would leak trained weights into every later HQNNModel.
        self.q = qml.qnn.TorchLayer(hqnn_circuit, hqnn_shapes)
        self.fc = nn.Linear(n_qubits_small, n_classes)

    def forward(self, x):
        z = self.cnn(x)              # (B, latent_dim)
        z = self.to_q(z)             # (B, n_qubits_small)
        z = self.dropout(z)
        z = torch.tanh(z)            # bound features to [-1, 1] before RY angle embedding
        z = self.q(z)                # (B, n_qubits_small) Pauli-Z expectations
        return self.fc(z)            # (B, n_classes)

In [None]:
# ----------------------------
# 8. Plotting metrics
# ----------------------------
def plot_metrics(history, model_name):
    """
    Save side-by-side loss and accuracy curves as ``{model_name}_metrics.png``.

    Args:
        history: dict with equal-length lists under "epoch", "train_loss",
            "test_loss", "train_acc" and "test_acc".
        model_name: prefix for the output file name.
    """
    # Pull every series up front so a missing key fails before any figure exists.
    x_axis = history["epoch"]
    series = {
        key: history[key]
        for key in ("train_loss", "test_loss", "train_acc", "test_acc")
    }

    panels = (
        ("Loss", (("train_loss", "Train loss"), ("test_loss", "Test loss"))),
        ("Accuracy", (("train_acc", "Train acc"), ("test_acc", "Test acc"))),
    )

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))
    for ax, (quantity, curves) in zip(axs, panels):
        for key, label in curves:
            ax.plot(x_axis, series[key], label=label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(quantity)
        ax.set_title(quantity)
        ax.legend()

    plt.tight_layout()
    out_name = f"{model_name}_metrics.png"
    plt.savefig(out_name, dpi=300, bbox_inches="tight")
    print(f"Saved metrics plot to {out_name}")
    plt.close(fig)

In [None]:
# ----------------------------
# 9. Confusion matrix utilities
# ----------------------------
def compute_confusion_matrix(y_true, y_pred, num_classes):
    """
    Build a ``num_classes x num_classes`` confusion matrix.

    Rows are true labels and columns are predicted labels: entry ``[t, p]``
    counts samples whose true label is ``t`` and whose prediction is ``p``.

    Args:
        y_true: 1-D integer labels (tensor or sequence of ints).
        y_pred: 1-D integer predictions, same length as ``y_true``.
        num_classes: number of classes.

    Returns:
        ``torch.int64`` CPU tensor of shape ``(num_classes, num_classes)``.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
            (Zipping them would silently drop samples.)
        IndexError: if any label or prediction lies outside
            ``[0, num_classes)``. (Negative values would otherwise wrap around.)
    """
    y_true = torch.as_tensor(y_true, dtype=torch.int64).reshape(-1).cpu()
    y_pred = torch.as_tensor(y_pred, dtype=torch.int64).reshape(-1).cpu()

    if y_true.numel() != y_pred.numel():
        raise ValueError(
            f"y_true and y_pred must have the same length "
            f"(got {y_true.numel()} and {y_pred.numel()})"
        )
    if y_true.numel() > 0:
        lowest = min(y_true.min().item(), y_pred.min().item())
        highest = max(y_true.max().item(), y_pred.max().item())
        if lowest < 0 or highest >= num_classes:
            raise IndexError(
                f"labels must lie in [0, {num_classes}), "
                f"got values in [{lowest}, {highest}]"
            )

    # Encode each (true, pred) pair as one bin index, then count all pairs in
    # a single vectorised pass instead of a per-sample Python loop.
    pair_index = y_true * num_classes + y_pred
    counts = torch.bincount(pair_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def plot_confusion_matrix(cm, model_name):
    """
    Draw an annotated confusion-matrix heat map.

    The figure is saved as ``{model_name}_confusion.png`` at 300 dpi.

    Args:
        cm: square 2-D integer tensor with rows = true label and columns =
            predicted label (as returned by ``compute_confusion_matrix``).
            The number of classes is taken from ``cm`` itself, so matrices of
            any size are labelled correctly.
        model_name: shown upper-cased in the title and used as the file prefix.

    Raises:
        ValueError: if ``cm`` is not a square 2-D matrix.
    """
    cm = cm.detach().cpu()
    if cm.dim() != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {tuple(cm.shape)}")
    num_classes = cm.shape[0]

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm.numpy(), interpolation="nearest")
    ax.figure.colorbar(im, ax=ax)
    ax.set_title(f"Confusion Matrix ({model_name.upper()})")
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    tick_marks = range(num_classes)
    ax.set_xticks(tick_marks)
    ax.set_yticks(tick_marks)
    ax.set_xticklabels(tick_marks)
    ax.set_yticklabels(tick_marks)

    # Switch the annotation to white above half the peak count for contrast.
    # Guard against an empty matrix, where cm.max() would raise.
    max_count = cm.max().item() if cm.numel() > 0 else 0
    thresh = max_count / 2.0 if max_count > 0 else 0.5
    for i in range(num_classes):
        for j in range(num_classes):
            value = cm[i, j].item()
            ax.text(
                j, i, str(value),
                ha="center", va="center",
                color="white" if value > thresh else "black",
                fontsize=8,
            )

    plt.tight_layout()
    out_name = f"{model_name}_confusion.png"
    plt.savefig(out_name, dpi=300, bbox_inches="tight")
    print(f"Saved confusion matrix to {out_name}")
    plt.close(fig)

In [None]:
# ----------------------------
# 10. Circuit figures
# ----------------------------
def save_circuit_figures():
    """
    Render each quantum circuit with ``qml.draw_mpl`` and save it as a 300-dpi PNG.

    Each circuit is drawn with all-zero inputs and all-zero weights of the
    matching shapes:
      - QNN  (qnn_circuit)  -> qnn_circuit.png
      - HQNN (hqnn_circuit) -> hqnn_circuit.png
      - QCNN (qcnn_circuit) -> qcnn_circuit.png
    """
    jobs = (
        ("QNN", qnn_circuit, (n_qubits_qnn,), (n_layers_qnn, n_qubits_qnn, 3), "qnn_circuit.png"),
        ("HQNN", hqnn_circuit, (n_qubits_small,), (n_layers_small, n_qubits_small, 3), "hqnn_circuit.png"),
        ("QCNN", qcnn_circuit, (n_qubits_small,), (30,), "qcnn_circuit.png"),
    )
    for label, circuit, input_shape, weight_shape, filename in jobs:
        fig, _ = qml.draw_mpl(circuit)(torch.zeros(input_shape), torch.zeros(weight_shape))
        fig.savefig(filename, dpi=300, bbox_inches="tight")
        print(f"Saved {label} circuit figure to {filename}")
        plt.close(fig)

In [None]:
# ----------------------------
# 11. Results table
# ----------------------------
def print_results_table(history, model, model_name):
    """
    Print a plain-text summary of the final epoch, followed by one Markdown table row.

    Args:
        history: dict with list-valued "train_loss", "test_loss", "train_acc"
            and "test_acc" entries. Only the last element of each is reported.
        model: module whose total parameter count (sum of ``numel``) is shown.
        model_name: label shown upper-cased in the summary and in the row.
    """
    tr_loss, te_loss, tr_acc, te_acc = (
        history[key][-1]
        for key in ("train_loss", "test_loss", "train_acc", "test_acc")
    )
    param_total = sum(tensor.numel() for tensor in model.parameters())
    divider = "----------------------------"

    print("\n=== Results Summary ===")
    label = model_name.upper()
    summary_lines = [
        f"Model: {label}",
        f"Parameters: {param_total}",
        divider,
        "Metric        | Value",
        divider,
        f"Train loss    | {tr_loss:.4f}",
        f"Test loss     | {te_loss:.4f}",
        f"Train acc     | {tr_acc:.4f}",
        f"Test acc      | {te_acc:.4f}",
        divider + "\n",
    ]
    for line in summary_lines:
        print(line)

    print("Markdown row for table:")
    # Column order: name | params | train acc | test acc | train loss | test loss
    cells = [label, str(param_total)] + [
        f"{metric:.4f}" for metric in (tr_acc, te_acc, tr_loss, te_loss)
    ]
    print("| " + " | ".join(cells) + " |")
    print()

In [None]:
# ----------------------------
# 12. Training / evaluation
# ----------------------------
def _train_one_epoch(model, opt, loss_fn):
    """One optimisation pass over ``train_loader``; returns (mean loss, accuracy)."""
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        opt.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        opt.step()

        # Weight by batch size so the epoch figure is a per-sample mean.
        loss_sum += batch_loss.item() * inputs.size(0)
        n_correct += (logits.argmax(1) == targets).sum().item()
        n_seen += targets.size(0)
    return loss_sum / n_seen, n_correct / n_seen


def _evaluate(model, loss_fn):
    """Gradient-free pass over ``test_loader``.

    Returns (mean loss, accuracy, CPU true labels, CPU predictions).
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    true_chunks, pred_chunks = [], []
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            loss_sum += batch_loss.item() * inputs.size(0)
            predicted = logits.argmax(1)
            n_correct += (predicted == targets).sum().item()
            n_seen += targets.size(0)

            true_chunks.append(targets.cpu())
            pred_chunks.append(predicted.cpu())
    return (
        loss_sum / n_seen,
        n_correct / n_seen,
        torch.cat(true_chunks, dim=0),
        torch.cat(pred_chunks, dim=0),
    )


def train_and_evaluate(model, model_name):
    """
    Train ``model`` and evaluate it on the test set after every epoch.

    Training uses Adam (learning rate ``lr``) with cross-entropy loss for
    ``epochs`` epochs. Each epoch prints one progress line. At the end, the
    loss and accuracy curves are saved via ``plot_metrics``.

    Returns:
        (history, y_true, y_pred)
        - history: per-epoch lists under "epoch", "train_loss", "train_acc",
          "test_loss" and "test_acc".
        - y_true, y_pred: CPU label and prediction tensors from the final
          evaluation pass (both None if no epoch ran).
    """
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    history = {key: [] for key in ("epoch", "train_loss", "train_acc", "test_loss", "test_acc")}
    last_y_true = last_y_pred = None

    for ep in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, opt, loss_fn)
        test_loss, test_acc, last_y_true, last_y_pred = _evaluate(model, loss_fn)

        # The dict preserves insertion order, so zip lines values up with keys.
        for key, value in zip(history, (ep, train_loss, train_acc, test_loss, test_acc)):
            history[key].append(value)

        print(
            f"Epoch {ep:02d} | "
            f"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | "
            f"Test loss: {test_loss:.4f}, acc: {test_acc:.4f}"
        )

    plot_metrics(history, model_name)
    return history, last_y_true, last_y_pred

In [None]:
# ----------------------------
# 13. Main
# ----------------------------
if __name__ == "__main__":
    print(f"Using device: {device}")
    print(f"Training model: {MODEL_CHOICE.upper()} on binarized MNIST")

    # Case-insensitive lookup from MODEL_CHOICE to the model class to build.
    model_factories = {"qnn": QNNModel, "qcnn": QCNNModel, "hqnn": HQNNModel}
    choice = MODEL_CHOICE.lower()
    if choice not in model_factories:
        raise ValueError("MODEL_CHOICE must be 'qnn', 'qcnn', or 'hqnn'.")
    model = model_factories[choice]()
    name = choice

    history, y_true, y_pred = train_and_evaluate(model, name)

    # Confusion matrix from the final epoch's test predictions.
    plot_confusion_matrix(compute_confusion_matrix(y_true, y_pred, n_classes), name)

    print_results_table(history, model, name)

    save_circuit_figures()

Using device: cpu
Training model: HQNN on binarized MNIST
Epoch 01 | Train loss: 2.0593, acc: 0.3962 | Test loss: 1.3881, acc: 0.5970
Epoch 02 | Train loss: 1.4362, acc: 0.5010 | Test loss: 0.8346, acc: 0.7150
Epoch 03 | Train loss: 1.1964, acc: 0.5699 | Test loss: 0.7073, acc: 0.7056
Epoch 04 | Train loss: 1.0865, acc: 0.6230 | Test loss: 0.6058, acc: 0.8527
Epoch 05 | Train loss: 0.9841, acc: 0.6915 | Test loss: 0.5079, acc: 0.8684
Epoch 06 | Train loss: 0.8928, acc: 0.7219 | Test loss: 0.4325, acc: 0.8732
Epoch 07 | Train loss: 0.8058, acc: 0.7516 | Test loss: 0.3655, acc: 0.8775
Epoch 08 | Train loss: 0.7468, acc: 0.7726 | Test loss: 0.3280, acc: 0.8812
Epoch 09 | Train loss: 0.7008, acc: 0.7837 | Test loss: 0.3100, acc: 0.9018
Epoch 10 | Train loss: 0.6627, acc: 0.7988 | Test loss: 0.2978, acc: 0.8927
Epoch 11 | Train loss: 0.6201, acc: 0.8164 | Test loss: 0.2654, acc: 0.9486
Epoch 12 | Train loss: 0.5761, acc: 0.8372 | Test loss: 0.2354, acc: 0.9665
Epoch 13 | Train loss: 0.5429,