In [None]:
#Quantum Neural Network for MNIST Classification (0 vs 1)
# All imports needed to run this notebook
# Import PennyLane (Quantum library) and its NumPy version
import pennylane as qml
from pennylane import numpy as np
# Import PyTorch for machine learning utilities
import torch 
import torchvision 
from torchvision import transforms 
from torch.utils.data import DataLoader, Subset
#-----------------------------


In [None]:
# ----------------------------
# Data Loading and Preprocessing
# ----------------------------
# Load the MNIST train/test splits (they are reduced to digits 0 and 1 in the next cell).
# ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 and std 0.5
# then computes (x - 0.5) / 0.5, which maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
train_data = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_data = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
#-----------------------------

In [None]:
#Filter digits 0 and 1 only
def filter_digits(dataset, digits=(0, 1)):
    """Keep only the samples whose label is one of ``digits`` (0 and 1 by default).

    ``dataset.data`` and ``dataset.targets`` are filtered with the same boolean
    mask, so images and labels stay aligned. Filtering only the targets would
    leave every image in place and pair images with the wrong labels.
    The dataset object is modified in place and also returned.

    Args:
        dataset: object exposing tensor attributes ``data`` and ``targets``
            indexed along dim 0 (as ``torchvision.datasets.MNIST`` does).
        digits: iterable of label values to keep.

    Returns:
        The same ``dataset`` object, now containing only the selected digits.
    """
    keep = torch.zeros_like(dataset.targets, dtype=torch.bool)
    for digit in digits:
        keep |= dataset.targets == digit
    dataset.data = dataset.data[keep]
    dataset.targets = dataset.targets[keep]
    return dataset
# Keep only digits 0 and 1 in both the training and the test split.
train_data, test_data = filter_digits(train_data), filter_digits(test_data)
#----------------------------

In [None]:
# Downsample each image to a 4x4 grid = 16 pixels, flattened to a 16-element vector.
# NOTE: angle_embedding below encodes only the first n_qubits (= 4) of these features.
def downsample(img, size=(4, 4)):
    """Resize an image tensor to ``size`` and flatten it into a 1-D feature vector.

    Uses ``torch.nn.functional.interpolate`` with its default mode ('nearest').

    Args:
        img: image tensor of shape (C, H, W), e.g. the output of
            ``transforms.ToTensor()``, or a bare (H, W) image.
        size: target (height, width); defaults to (4, 4).

    Returns:
        1-D tensor of length C * size[0] * size[1] (16 for a single-channel
        image with the default size).
    """
    # interpolate resizes the last two dims only when given a 4-D (N, C, H, W) tensor.
    batched = img.unsqueeze(0)
    if batched.dim() == 3:  # img was a bare (H, W) image: add a channel dim as well
        batched = batched.unsqueeze(0)
    resized = torch.nn.functional.interpolate(batched, size=size)
    # Flatten the resized tensor (not the `size` argument).
    return resized.reshape(-1)

#Creating a custom dataset to prepare quantum-friendly inputs
class QuantumMNISTDataset(torch.utils.data.Dataset):
    """Wrap an image dataset and yield ``(features, label)`` pairs for the quantum model.

    ``features`` is ``downsample(img)`` scaled to unit L2 norm. It stays a torch
    tensor throughout, rather than round-tripping through NumPy / PennyLane's
    autograd NumPy. That way a DataLoader can batch it and it can be passed
    directly to the torch-interface QNode.
    """

    def __init__(self, dataset):
        # Wrapped dataset; ``dataset[idx]`` must return an ``(image, label)`` pair.
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        features = downsample(img)
        norm = features.norm()
        # Guard against an all-zero vector, which would otherwise divide by
        # zero and fill the feature vector with NaN.
        if norm > 0:
            features = features / norm
        return features, label


batch_size = 10  # samples per mini-batch
#-----------------------------

In [None]:
# ----------------------------
# Quantum Circuit Definition
# ----------------------------
# Quantum circuit parameters: 4 qubits, each mapped to one wire of PennyLane's
# "default.qubit" simulator device.
n_qubits = 4
dev = qml.device("default.qubit", wires=n_qubits)
# Feature map: encode classical values as RY rotation angles, one value per wire.
def angle_embedding(x):
    """Rotate wire ``w`` about the Y axis by ``x[w]`` for each w in 0..n_qubits-1.

    Only the first ``n_qubits`` entries of ``x`` are encoded; any further
    entries are ignored.

    NOTE(review): the dataset downsamples images to 4x4 = 16 features, so with
    n_qubits = 4 twelve of them never reach the circuit (and the first four come
    from the image's top row). Consider producing n_qubits features, or using
    qml.AmplitudeEmbedding (16 unit-norm amplitudes fit exactly 4 qubits).
    TODO: confirm the intended encoding.
    """
    for wire in range(n_qubits):
        angle = x[wire]
        qml.RY(angle, wires=wire)
# Trainable layer: a general single-qubit rotation on each wire, then a CNOT chain.
def variational_block(weights):
    """Apply one trainable layer across all ``n_qubits`` wires.

    Args:
        weights: indexable per wire; ``weights[w]`` must unpack into the three
            angles (phi, theta, omega) taken by ``qml.Rot``.
    """
    for wire in range(n_qubits):
        # qml.Rot(phi, theta, omega) = RZ(omega) RY(theta) RZ(phi): a Z-Y-Z Euler
        # rotation, not three separate X/Y/Z rotations.
        qml.Rot(*weights[wire], wires=wire)
    # CNOTs 0->1, 1->2, ... entangle each pair of neighbouring wires.
    for control in range(n_qubits - 1):
        target = control + 1
        qml.CNOT(wires=[control, target])
# Full QNode: embed the features, apply the trainable layer, then measure wire 0.
@qml.qnode(dev, interface="torch")
def quantum_net(inputs, weights):
    """Evaluate the variational circuit on a single sample.

    Args:
        inputs: feature vector; ``angle_embedding`` encodes its first
            ``n_qubits`` entries.
        weights: per-wire rotation angles for ``variational_block``
            (``QNet`` creates them with shape (n_qubits, 3)).

    Returns:
        Expectation value of Pauli-Z on wire 0, which lies in [-1, 1].
    """
    angle_embedding(inputs)
    variational_block(weights)
    observable = qml.PauliZ(0)
    return qml.expval(observable)
# ----------------------------

In [None]:
# ----------------------------
# Defining the Full Quantum Neural Network in PyTorch
# ----------------------------
#Define Torch Layer
class QNet(torch.nn.Module):
    """PyTorch module that wraps ``quantum_net`` with trainable circuit weights.

    The circuit is evaluated once per sample, and the scalar results are stacked
    into a 1-D output tensor.
    """

    def __init__(self):
        super().__init__()
        # Three rotation angles per qubit, consumed by qml.Rot in variational_block.
        # Scaling by 0.01 keeps the initial rotations close to the identity.
        self.q_params = torch.nn.Parameter(torch.randn((n_qubits, 3)) * 0.01)

    def forward(self, x):
        """Run the circuit on every row of ``x`` and stack the results.

        Args:
            x: batch of feature vectors indexed along dim 0.

        Returns:
            Tensor with one expectation value per sample (length ``x.shape[0]``).
        """
        per_sample = []
        for row in range(x.shape[0]):
            per_sample.append(quantum_net(x[row], self.q_params))
        return torch.stack(per_sample)


In [None]:
# ----------------------------
# Model Training
# ----------------------------
# TODO: define the model (QNet), loss function, and optimizer — this cell is currently empty.

In [None]:
# TODO: evaluate the trained model on test_data — this cell is currently empty.