# In[5]:


import torch
import torch.nn as nn
import torchquantum as tq
import torchvision
import torchvision.transforms as transforms

class QuantumConvLayer(nn.Module):
    """Classical convolution followed by a per-pixel quantum layer over channels.

    The convolution maps ``(B, input_channels, H, W)`` to
    ``(B, output_channels, H', W')``. With ``padding=kernel_size // 2``,
    ``H' == H`` and ``W' == W`` for odd kernel sizes. The quantum layer is
    then applied independently at every spatial position to that position's
    ``output_channels``-long feature vector. The result is again a 4-D
    feature map, so downstream spatial layers (pooling, residual blocks)
    can consume it.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels produced by the convolution and kept by
            the quantum layer.
        kernel_size: Square convolution kernel size.
        q_layer: Optional module mapping ``(N, output_channels)`` to
            ``(N, output_channels)``. It defaults to
            ``tq.QuantumLayer(output_channels, output_channels)``.
            NOTE(review): confirm that class exists in the installed
            torchquantum version and behaves like ``nn.Linear`` over the
            last dim; otherwise pass a real torchquantum module here.
    """

    def __init__(self, input_channels, output_channels, kernel_size, q_layer=None):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channels, output_channels, kernel_size, padding=kernel_size // 2
        )
        if q_layer is None:
            # Sized to the conv's *output* channels: forward feeds it one
            # channel vector per pixel (previously it was sized for
            # input_channels*k*k, which never matched what forward passed).
            q_layer = tq.QuantumLayer(output_channels, output_channels)
        self.q_layer = q_layer

    def forward(self, x):
        """Return a ``(B, C, H', W')`` map after conv + per-pixel quantum layer."""
        x = self.conv(x)
        bsz, channels, height, width = x.shape
        # Move channels last and fold batch and spatial dims together, so the
        # quantum layer sees one (channels,) feature vector per pixel.
        feats = x.permute(0, 2, 3, 1).reshape(bsz * height * width, channels)
        feats = self.q_layer(feats)
        # Restore the (B, C, H, W) layout. Make it contiguous so later
        # .view() calls downstream are safe.
        return feats.reshape(bsz, height, width, -1).permute(0, 3, 1, 2).contiguous()

class QuantumPoolingLayer(nn.Module):
    """Non-overlapping max pooling over ``pool_size`` x ``pool_size`` windows.

    A classical placeholder with the same results as
    ``nn.MaxPool2d(pool_size)``. Trailing rows or columns that do not fill a
    complete window are dropped (floor behaviour).

    Args:
        pool_size: Side length of the square pooling window (must be >= 1).

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """

    def __init__(self, pool_size):
        super().__init__()
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.pool_size = pool_size

    def forward(self, x):
        """Map ``(B, C, H, W)`` to ``(B, C, H // pool_size, W // pool_size)``."""
        p = self.pool_size
        bsz, channels, height, width = x.size()
        out_h, out_w = height // p, width // p
        # Crop to a multiple of p so the window reshape below is always valid.
        x = x[:, :, : out_h * p, : out_w * p]
        # reshape (not view) also accepts non-contiguous input.
        x = x.reshape(bsz, channels, out_h, p, out_w, p)
        # Dims 3 and 5 are the row/column offsets *inside* each window.
        # Reducing dim 4 (the window index) instead would mix windows.
        return x.amax(dim=(3, 5))

class QuantumReLU(nn.Module):
    """Rectifier placeholder: keep non-negative entries, replace negatives by 0.

    Implemented as a lower-bound clamp at ``0.0``; it has no parameters.
    """

    def forward(self, x):
        return x.clamp(min=0.0)

class QuantumResidualBlock(nn.Module):
    """Residual block: conv -> per-pixel quantum layer -> conv, plus shortcut, then ReLU.

    Both convolutions use ``kernel_size=3, padding=1``, so the spatial size is
    preserved. The shortcut is the identity when ``in_channels ==
    out_channels``. Otherwise it is a 1x1 convolution, so the addition has
    matching channel counts.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the block's output.
        q_layer: Optional module mapping ``(N, out_channels)`` to
            ``(N, out_channels)``. It is applied at every pixel and defaults
            to ``tq.QuantumLayer(out_channels, out_channels)``.
            NOTE(review): confirm that class exists in the installed
            torchquantum version; otherwise pass a real module here.
    """

    def __init__(self, in_channels, out_channels, q_layer=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        if q_layer is None:
            # The bare name `QuantumLayer` is not defined in this file (it
            # raised NameError), so use the torchquantum-qualified name.
            q_layer = tq.QuantumLayer(out_channels, out_channels)
        self.q_layer = q_layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # Without a projection, adding an in_channels input to an
        # out_channels output fails whenever the two differ.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, out_channels, H, W)``."""
        residual = self.shortcut(x)
        out = self.conv1(x)
        bsz, channels, height, width = out.shape
        # Apply the quantum layer to each pixel's channel vector, then
        # restore the 4-D layout that conv2 requires.
        out = out.permute(0, 2, 3, 1).reshape(bsz * height * width, channels)
        out = self.q_layer(out)
        out = out.reshape(bsz, height, width, -1).permute(0, 3, 1, 2)
        out = self.conv2(out)
        out = out + residual
        return self.relu(out)

class QuantumNeuralNetwork(nn.Module):
    """Hybrid classifier: quantum conv -> 2x2 pooling -> quantum residual block -> linear head.

    The linear head expects a 64-channel 14x14 feature map. Given the
    single input channel, the size-preserving convolutions and one 2x2 pool,
    that corresponds to 28x28 single-channel inputs (presumably MNIST).
    Adjust ``fc`` if the input size changes.
    """

    def __init__(self):
        super().__init__()
        self.q_conv1 = QuantumConvLayer(input_channels=1, output_channels=32, kernel_size=3)
        self.q_pool = QuantumPoolingLayer(pool_size=2)
        self.q_residual = QuantumResidualBlock(in_channels=32, out_channels=64)
        # 64 channels from q_residual over a 14x14 grid -> 10 outputs.
        flat_features = 64 * 14 * 14
        self.fc = nn.Linear(flat_features, 10)

    def forward(self, x):
        for stage in (self.q_conv1, self.q_pool, self.q_residual):
            x = stage(x)
        # Collapse every dimension except the batch before the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)


# Load the MNIST train and test splits into ./data (download=True fetches
# them if needed). ToTensor is the only preprocessing; no normalization.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Batch both splits identically; only the training split is shuffled, so
# evaluation order stays fixed.
batch_size = 64
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=shuffle_split)
    for split, shuffle_split in ((train_dataset, True), (test_dataset, False))
)
