In [None]:
#!pip install pennylane

In [None]:
"""
Quantum MNIST (binarized): QNN, QCNN, HQNN

- Data from: qml.data.load("other", name="binarized-mnist")
- 50'000 train, 10'000 test, each input: bitstring of length 784

Features:
- Performance logging (loss + accuracy)
- Metric plots per model: <model_name>_metrics.png
- Circuit figures: qnn_circuit.png, qcnn_circuit.png
"""


import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import TensorDataset
import pennylane as qml
import matplotlib.pyplot as plt

# ----------------------------
# Config
# ----------------------------
# Number of wires / encoded features. NOTE: qcnn_circuit hard-codes wire pairs
# up to wire 7, so the QCNN model only works with n_qubits == 8.
n_qubits = 8
# Number of output logits (digits 0-9).
n_classes = 10
# Depth of the StronglyEntanglingLayers ansatz (first axis of qnn_shapes["weights"]).
n_layers = 3
batch_size = 128
epochs = 20
# Learning rate passed to Adam in train_and_evaluate.
lr = 1e-3
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible on Restart & Run All.
seed = 42
torch.manual_seed(seed)

# choose "qnn", "qcnn", or "hqnn"
MODEL_CHOICE = "hqnn"

In [None]:
# ----------------------------
# 1. Load binarized MNIST via qml.data
# ----------------------------
print("Loading binarized MNIST from PennyLane...")
[ds] = qml.data.load("other", name="binarized-mnist")


def _split_to_tensors(split):
    """Convert one dataset split into an (inputs, labels) pair of tensors.

    `split` is indexed with the keys "inputs" and "labels"; inputs become
    float32 and labels become int64 (torch.long).
    """
    inputs = torch.tensor(split["inputs"], dtype=torch.float32)
    labels = torch.tensor(split["labels"], dtype=torch.long)
    return inputs, labels


# ds.train["inputs"]: (50000, 784) array of 0/1
# ds.train["labels"]: (50000,) array of digits 0–9
x_train, y_train = _split_to_tensors(ds.train)
x_test, y_test = _split_to_tensors(ds.test)

# Wrap the tensors so they can be batched by PyTorch loaders.
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)

# Training batches are reshuffled every epoch and the last partial batch is
# dropped; evaluation keeps the original order and every sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

print(f"Train shape: {x_train.shape}, Test shape: {x_test.shape}")

Loading binarized MNIST from PennyLane...
Train shape: torch.Size([50000, 784]), Test shape: torch.Size([10000, 784])


In [None]:
# ----------------------------
# 2. Classical helper modules
# ----------------------------
class SimpleDownscale(nn.Module):
    """
    Classical front-end that compresses a flattened 28x28 image to n_qubits features.

    Pipeline: (B, 784) -> (B, 1, 28, 28) -> adaptive average pool to 4x4
    -> flatten to (B, 16) -> linear projection to (B, n_qubits).
    Used as the pre-processing stage of QNNModel and QCNNModel.
    """
    def __init__(self, n_qubits):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(16, n_qubits)

    def forward(self, x):
        """Map a batch of flat 784-long vectors to a (B, n_qubits) tensor."""
        batch = x.shape[0]
        images = x.view(batch, 1, 28, 28)             # restore the image layout
        pooled = self.pool(images)                    # (B, 1, 4, 4)
        flat = torch.flatten(pooled, start_dim=1)     # (B, 16)
        return self.fc(flat)                          # (B, n_qubits)


class SmallCNN(nn.Module):
    """
    Two-stage convolutional feature extractor used by HQNNModel.

    Takes flat (B, 784) inputs, reshapes them to (B, 1, 28, 28) internally and
    returns a (B, latent_dim) feature vector.
    """
    def __init__(self, latent_dim=32):
        super().__init__()
        stages = [
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 28x28 -> 14x14
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),                  # 14x14 -> 7x7
        ]
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(32 * 7 * 7, latent_dim)

    def forward(self, x):
        """Return latent features of shape (B, latent_dim) for flat inputs (B, 784)."""
        batch = x.shape[0]
        feature_maps = self.conv(x.view(batch, 1, 28, 28))   # (B, 32, 7, 7)
        flattened = feature_maps.view(batch, -1)             # (B, 32*7*7)
        return self.fc(flattened)

In [None]:
# ----------------------------
# 3. PennyLane device
# ----------------------------
# Single PennyLane "default.qubit" device with n_qubits wires. Both QNodes in this
# notebook (qnn_circuit and qcnn_circuit) are bound to it via @qml.qnode(dev, ...).
dev = qml.device("default.qubit", wires=n_qubits)

In [None]:
# ----------------------------
# 4. QNN circuit (AngleEmbedding + variational ansatz)
# ----------------------------
@qml.qnode(dev, interface="torch")
def qnn_circuit(inputs, weights):
    """
    Variational circuit shared by the QNN and HQNN models.

    Steps:
    - encode `inputs` (n_qubits features) as RY angles via AngleEmbedding
    - apply StronglyEntanglingLayers parameterised by `weights`
    - return the PauliZ expectation value of every wire (n_qubits values)
    """
    all_wires = range(n_qubits)
    qml.AngleEmbedding(inputs, wires=all_wires, rotation="Y")
    qml.StronglyEntanglingLayers(weights, wires=all_wires)
    return [qml.expval(qml.PauliZ(wire)) for wire in all_wires]

# Shape of the trainable tensor fed to StronglyEntanglingLayers in qnn_circuit:
# one (n_qubits, 3) slab of parameters per layer.
qnn_shapes = {"weights": (n_layers, n_qubits, 3)}
# NOTE: this is a single TorchLayer instance. Any module that stores this object
# (rather than constructing its own TorchLayer) shares its trainable weights.
QNNLayer = qml.qnn.TorchLayer(qnn_circuit, qnn_shapes)

In [None]:
# ----------------------------
# 5. QCNN circuit blocks & circuit
# ----------------------------
def conv_block(params, wires):
    """
    Two-qubit "convolution" unit.

    Applies CNOT(first -> second), then RY(params[0]) on the first wire,
    RY(params[1]) and RZ(params[2]) on the second wire, then the same CNOT again.
    `wires` must unpack into exactly two wire labels; three parameters are read.
    """
    first, second = wires
    pair = [first, second]
    qml.CNOT(pair)
    qml.RY(params[0], first)
    qml.RY(params[1], second)
    qml.RZ(params[2], second)
    qml.CNOT(pair)

def pooling_block(params, wires):
    """
    Two-qubit "pooling" unit: the first wire is kept, the second is the pooled one.

    Applies CNOT with the kept wire as control and the pooled wire as target,
    then RY(params[0]) and RZ(params[1]) on the kept wire. The pooled wire is
    not reset or measured here; two parameters are read.
    """
    kept, pooled = wires
    qml.CNOT([kept, pooled])
    qml.RY(params[0], kept)
    qml.RZ(params[1], kept)

@qml.qnode(dev, interface="torch")
def qcnn_circuit(inputs, weights):
    """
    QCNN-style circuit on wires 0..7.

    - AngleEmbedding (RY) of `inputs` on all n_qubits wires
    - stage 1: conv blocks on (0,1),(2,3),(4,5),(6,7), then pooling blocks on the same pairs
    - stage 2: conv blocks on (0,2),(4,6), then pooling blocks on the same pairs
    - returns <Z> on wires 0 and 4

    `weights` is flattened and consumed sequentially: 3 values per conv block,
    2 per pooling block, in the order listed above (12 + 8 + 6 + 4 = 30 values).
    """
    qml.AngleEmbedding(inputs, wires=range(n_qubits), rotation="Y")

    flat = weights.reshape(-1)
    cursor = 0

    stages = (
        [(0, 1), (2, 3), (4, 5), (6, 7)],
        [(0, 2), (4, 6)],
    )
    # Within a stage every conv block runs before any pooling block.
    for pairs in stages:
        for block, n_params in ((conv_block, 3), (pooling_block, 2)):
            for pair in pairs:
                block(flat[cursor:cursor + n_params], pair)
                cursor += n_params

    return [qml.expval(qml.PauliZ(wire)) for wire in (0, 4)]

# conv1: 4×3 = 12, pool1: 4×2 = 8, conv2: 2×3 = 6, pool2: 2×2 = 4 => 30 params
# This count must match the number of entries qcnn_circuit slices out of its
# flattened weight vector.
qcnn_shapes = {"weights": (30,)}
# NOTE: this is a single TorchLayer instance. Any module that stores this object
# (rather than constructing its own TorchLayer) shares its trainable weights.
QCNNLayer = qml.qnn.TorchLayer(qcnn_circuit, qcnn_shapes)

In [None]:
# ----------------------------
# 6. Models: QNN, QCNN, HQNN
# ----------------------------
class QNNModel(nn.Module):
    """
    QNN:
    bitstring (784) -> SimpleDownscale -> tanh -> quantum layer -> Linear -> 10 logits

    Each instance owns its own TorchLayer built from qnn_circuit / qnn_shapes,
    so its quantum weights are freshly initialised and are not shared with any
    other model created in the same kernel session.
    """
    def __init__(self):
        super().__init__()
        self.pre = SimpleDownscale(n_qubits)
        # Build a new TorchLayer instead of reusing the module-level QNNLayer:
        # reusing that singleton would tie this model's quantum weights to every
        # other model holding it and carry trained weights across re-instantiation.
        self.q = qml.qnn.TorchLayer(qnn_circuit, qnn_shapes)
        self.fc = nn.Linear(n_qubits, n_classes)

    def forward(self, x):
        """Return class logits of shape (B, n_classes) for flat inputs (B, 784)."""
        x = self.pre(x)        # -> (B,n_qubits)
        x = torch.tanh(x)      # squash features to (-1, 1) before angle encoding
        x = self.q(x)          # -> (B,n_qubits) (<Z>)
        return self.fc(x)      # -> (B,10)


class QCNNModel(nn.Module):
    """
    QCNN:
    bitstring (784) -> SimpleDownscale -> tanh -> QCNN quantum layer -> Linear(2->10)

    Each instance owns its own TorchLayer built from qcnn_circuit / qcnn_shapes,
    so its quantum weights are freshly initialised and are not shared with any
    other model created in the same kernel session.
    """
    def __init__(self):
        super().__init__()
        self.pre = SimpleDownscale(n_qubits)
        # Build a new TorchLayer instead of reusing the module-level QCNNLayer:
        # reusing that singleton would carry trained quantum weights over to any
        # QCNNModel instantiated later.
        self.q = qml.qnn.TorchLayer(qcnn_circuit, qcnn_shapes)
        # qcnn_circuit returns two expectation values (wires 0 and 4).
        self.fc = nn.Linear(2, n_classes)

    def forward(self, x):
        """Return class logits of shape (B, n_classes) for flat inputs (B, 784)."""
        x = self.pre(x)        # -> (B,n_qubits)
        x = torch.tanh(x)      # squash features to (-1, 1) before angle encoding
        x = self.q(x)          # -> (B,2)
        return self.fc(x)      # -> (B,10)


class HQNNModel(nn.Module):
    """
    HQNN (hybrid):
    bitstring (784) -> SmallCNN -> Linear(latent->n_qubits) -> tanh
                      -> quantum layer -> Linear -> 10 logits

    Each instance owns its own TorchLayer built from qnn_circuit / qnn_shapes,
    so its quantum weights are freshly initialised and are not shared with any
    other model created in the same kernel session.
    """
    def __init__(self, latent_dim=32):
        super().__init__()
        self.cnn = SmallCNN(latent_dim)
        self.to_q = nn.Linear(latent_dim, n_qubits)
        # Build a new TorchLayer instead of reusing the module-level QNNLayer:
        # reusing that singleton would tie this model's quantum weights to every
        # other model holding it and carry trained weights across re-instantiation.
        self.q = qml.qnn.TorchLayer(qnn_circuit, qnn_shapes)
        self.fc = nn.Linear(n_qubits, n_classes)

    def forward(self, x):
        """Return class logits of shape (B, n_classes) for flat inputs (B, 784)."""
        z = self.cnn(x)        # -> (B, latent_dim)
        z = self.to_q(z)       # -> (B, n_qubits)
        z = torch.tanh(z)      # squash features to (-1, 1) before angle encoding
        z = self.q(z)          # -> (B, n_qubits)
        return self.fc(z)      # -> (B,10)

In [None]:
# ----------------------------
# 7. Plotting metrics
# ----------------------------
def plot_metrics(history, model_name):
    """
    Draw loss and accuracy curves side by side and write them to disk.

    `history` must provide the lists "epoch", "train_loss", "test_loss",
    "train_acc" and "test_acc". The figure is saved as <model_name>_metrics.png
    and then closed (it is not displayed inline).
    """
    wanted = ("epoch", "train_loss", "test_loss", "train_acc", "test_acc")
    # Read every series up front so a missing key fails before a figure exists.
    series = {key: history[key] for key in wanted}
    xs = series["epoch"]

    fig, axs = plt.subplots(1, 2, figsize=(10, 4))

    # (axis, title / y-label, train key, train legend, test key, test legend)
    panels = (
        (axs[0], "Loss", "train_loss", "Train loss", "test_loss", "Test loss"),
        (axs[1], "Accuracy", "train_acc", "Train acc", "test_acc", "Test acc"),
    )
    for ax, heading, train_key, train_label, test_key, test_label in panels:
        ax.plot(xs, series[train_key], label=train_label)
        ax.plot(xs, series[test_key], label=test_label)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(heading)
        ax.set_title(heading)
        ax.legend()

    fig.tight_layout()
    out_name = f"{model_name}_metrics.png"
    fig.savefig(out_name, dpi=300, bbox_inches="tight")
    print(f"Saved metrics plot to {out_name}")
    plt.close(fig)

In [None]:
# ----------------------------
# 8. Circuit figures (gate diagrams)
# ----------------------------
def save_circuit_figures():
    """
    Draw and save circuit diagrams for QNN and QCNN using qml.draw_mpl.

    Both circuits are drawn with all-zero features and all-zero weights; the
    weight shapes are taken from qnn_shapes / qcnn_shapes so the diagrams always
    match what the corresponding TorchLayers are built with.

    Files:
    - qnn_circuit.png
    - qcnn_circuit.png
    """
    diagrams = (
        ("QNN", qnn_circuit, qnn_shapes["weights"], "qnn_circuit.png"),
        ("QCNN", qcnn_circuit, qcnn_shapes["weights"], "qcnn_circuit.png"),
    )
    for label, circuit, weight_shape, out_name in diagrams:
        features = torch.zeros(n_qubits)
        weights = torch.zeros(weight_shape)
        fig, _ = qml.draw_mpl(circuit)(features, weights)
        fig.savefig(out_name, dpi=300, bbox_inches="tight")
        print(f"Saved {label} circuit figure to {out_name}")
        # Close explicitly so figures do not accumulate in the kernel.
        plt.close(fig)

In [None]:
# ----------------------------
# 9. Training / evaluation with logging
# ----------------------------
def train_and_evaluate(model, model_name):
    """
    Train `model` for `epochs` epochs and evaluate it on the test set after each one.

    Uses Adam (learning rate `lr`) with cross-entropy loss on `train_loader`,
    and reports sample-weighted mean loss and accuracy for both `train_loader`
    and `test_loader`. Per-epoch numbers are printed, collected in a history
    dict (keys "epoch", "train_loss", "train_acc", "test_loss", "test_acc"),
    plotted via plot_metrics(history, model_name), and the dict is returned.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    def train_pass():
        """One optimisation pass over train_loader; returns (mean loss, accuracy)."""
        model.train()
        loss_sum, hits, seen = 0.0, 0, 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            batch_loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size to get a per-sample mean.
            loss_sum += batch_loss.item() * inputs.size(0)
            hits += (outputs.argmax(1) == targets).sum().item()
            seen += targets.size(0)
        return loss_sum / seen, hits / seen

    @torch.no_grad()
    def eval_pass():
        """One gradient-free pass over test_loader; returns (mean loss, accuracy)."""
        model.eval()
        loss_sum, hits, seen = 0.0, 0, 0
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)
            loss_sum += batch_loss.item() * inputs.size(0)
            hits += (outputs.argmax(1) == targets).sum().item()
            seen += targets.size(0)
        return loss_sum / seen, hits / seen

    history = {
        key: []
        for key in ("epoch", "train_loss", "train_acc", "test_loss", "test_acc")
    }

    for ep in range(1, epochs + 1):
        train_loss, train_acc = train_pass()
        test_loss, test_acc = eval_pass()

        epoch_record = {
            "epoch": ep,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
        }
        for key, value in epoch_record.items():
            history[key].append(value)

        print(
            f"Epoch {ep:02d} | "
            f"Train loss: {train_loss:.4f}, acc: {train_acc:.4f} | "
            f"Test loss: {test_loss:.4f}, acc: {test_acc:.4f}"
        )

    plot_metrics(history, model_name)
    return history

In [None]:
# ----------------------------
# 10. Main
# ----------------------------
if __name__ == "__main__":
    print(f"Using device: {device}")
    print(f"Training model: {MODEL_CHOICE.upper()} on binarized MNIST")

    # Map each accepted (lower-cased) MODEL_CHOICE to its model class.
    model_registry = {
        "qnn": QNNModel,
        "qcnn": QCNNModel,
        "hqnn": HQNNModel,
    }
    choice = MODEL_CHOICE.lower()
    if choice not in model_registry:
        raise ValueError("MODEL_CHOICE must be 'qnn', 'qcnn', or 'hqnn'.")

    model = model_registry[choice]()
    name = choice

    # Train first, then write the circuit diagrams for both quantum circuits.
    history = train_and_evaluate(model, name)
    save_circuit_figures()

Using device: cpu
Training model: HQNN on binarized MNIST
Epoch 01 | Train loss: 1.9465, acc: 0.3739 | Test loss: 1.4079, acc: 0.7233
Epoch 02 | Train loss: 0.9967, acc: 0.7629 | Test loss: 0.7571, acc: 0.7775
Epoch 03 | Train loss: 0.6679, acc: 0.7857 | Test loss: 0.5980, acc: 0.7857
Epoch 04 | Train loss: 0.5487, acc: 0.8075 | Test loss: 0.5106, acc: 0.8377
Epoch 05 | Train loss: 0.4395, acc: 0.8739 | Test loss: 0.4247, acc: 0.8695
Epoch 06 | Train loss: 0.3575, acc: 0.8839 | Test loss: 0.3522, acc: 0.8789
Epoch 07 | Train loss: 0.3112, acc: 0.8908 | Test loss: 0.3206, acc: 0.8805
Epoch 08 | Train loss: 0.2662, acc: 0.9354 | Test loss: 0.2648, acc: 0.9661
Epoch 09 | Train loss: 0.2009, acc: 0.9748 | Test loss: 0.2043, acc: 0.9691
Epoch 10 | Train loss: 0.1481, acc: 0.9783 | Test loss: 0.1820, acc: 0.9664
Epoch 11 | Train loss: 0.1227, acc: 0.9794 | Test loss: 0.1518, acc: 0.9732
Epoch 12 | Train loss: 0.1031, acc: 0.9825 | Test loss: 0.1518, acc: 0.9690
Epoch 13 | Train loss: 0.0901,