In [143]:
!pip install cirq



In [144]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import cirq
from math import pi

In [145]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [146]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=30, shuffle=True)

In [147]:
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=30, shuffle=False)

In [155]:
class QuantamConv2d(nn.Module):
    def __init__(self):
        super(QuantamConv2d, self).__init__()
        self.weight = nn.Parameter(torch.ones(4))*100
        self.bias = nn.Parameter(torch.zeros(4))
        self.simulator = cirq.Simulator()


    def forward(self, x):
        kernel_height, kernel_width = (2,2)
        assert x.size() == ( 30, 1, 28, 28 )
        batch, _, image_height, image_width = x.size()
        result = torch.zeros(30, 1, 14, 14)
        for z in range(30):
            for i in range(0,image_height - kernel_height + 1,2):
                for j in range(0,image_width - kernel_width + 1,2):
                    P = x[z][0][i:i+kernel_height, j:j+kernel_width]
                    P = [P[0][0], P[0][1], P[1][0], P[1][1]]
                    circuit, keys = self.kernel(P)
                    res = self.simulator.run(circuit, repetitions=10)
                    # print(res.histogram(key=keys[3]))
                    try:
                        result[z][0][i//2][j//2] = res.histogram(key=keys[3])[1] * 0.1
                    except:
                        result[z][0][i//2][j//2] = 0
        return result
    
    def backward():
        pass

    def kernel(self, P):
        Q = [cirq.GridQubit(i,0) for i in range(4)]
        W = [cirq.GridQubit(i,1) for i in range(3)]
        keys = ["q0", "q1", "q2", "q3"]

        circuit = cirq.Circuit()
        # for i in range(4):
        #     circuit.append(cirq.H(Q[i]))

        weight = self.weight.tolist()

        for i in range(4):
            circuit.append(cirq.ry(P[i].item()/255 * pi).on(Q[i]))

        for i in range(3):
            circuit.append(cirq.rx(weight[i]/255 * pi).on(W[i]))

        for i in range(3):
            circuit.append(cirq.TOFFOLI(W[i], Q[i], Q[i+1]))

        for i in range(3):
            circuit.append(cirq.ZZ(Q[i], Q[i+1]))

        for i in range(4):
            circuit.append(cirq.measure(Q[i], key=keys[i]))
        return circuit, keys



In [149]:
class CustomLayer(nn.Module):
    def __init__(self):
        super(CustomLayer, self).__init__()
        self.conv1 = QuantamConv2d()

    def forward(self, x):
        # print("before : ", x.shape)
        x = torch.relu(self.conv1(x))
        # print("after : ", x.shape)
        return x

In [150]:
# Define the CNN model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.custom_layer = CustomLayer()
        self.fc1 = nn.Linear(14*14, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=2)

    def forward(self, x):
        # x = torch.relu(self.custom_layer(x))
        x = torch.relu(self.conv(x))
        x = x.view(-1, 1*14*14)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [181]:
# Initialize the network and optimizer
net = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [152]:
print(dir(trainloader))

['_DataLoader__initialized', '_DataLoader__multiprocessing_context', '_IterableDataset_len_called', '__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_auto_collation', '_dataset_kind', '_get_iterator', '_index_sampler', '_is_protocol', '_iterator', 'batch_sampler', 'batch_size', 'check_worker_number_rationality', 'collate_fn', 'dataset', 'drop_last', 'generator', 'multiprocessing_context', 'num_workers', 'persistent_workers', 'pin_memory', 'pin_memory_device', 'prefetch_factor', 'sampler', 'timeout', 'worker_init_fn']


In [182]:
# Training loop
for epoch in range(100):  # Change the number of epochs as needed
    running_loss = 0.0
    print(f"Running epoc: {epoch}")
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        break

print('Finished Training')
print("Running loss: ", running_loss)

Running epoc: 0
Running epoc: 1
Running epoc: 2
Running epoc: 3
Running epoc: 4
Running epoc: 5
Running epoc: 6
Running epoc: 7
Running epoc: 8
Running epoc: 9
Running epoc: 10
Running epoc: 11
Running epoc: 12
Running epoc: 13
Running epoc: 14
Running epoc: 15
Running epoc: 16
Running epoc: 17
Running epoc: 18
Running epoc: 19
Running epoc: 20
Running epoc: 21
Running epoc: 22
Running epoc: 23
Running epoc: 24
Running epoc: 25
Running epoc: 26
Running epoc: 27
Running epoc: 28
Running epoc: 29
Running epoc: 30
Running epoc: 31
Running epoc: 32
Running epoc: 33
Running epoc: 34
Running epoc: 35
Running epoc: 36
Running epoc: 37
Running epoc: 38
Running epoc: 39
Running epoc: 40
Running epoc: 41
Running epoc: 42
Running epoc: 43
Running epoc: 44
Running epoc: 45
Running epoc: 46
Running epoc: 47
Running epoc: 48
Running epoc: 49
Running epoc: 50
Running epoc: 51
Running epoc: 52
Running epoc: 53
Running epoc: 54
Running epoc: 55
Running epoc: 56
Running epoc: 57
Running epoc: 58
Running

In [183]:
# Testing the model
correct = 0
total = 0
count = 0
with torch.no_grad():
    for data in testloader:
        count += 1
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        print(f"total: {total}, correct = {correct}")
        if total > 100 : break
        
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%')


total: 30, correct = 6
total: 60, correct = 7
total: 90, correct = 9
total: 120, correct = 10
Accuracy of the network on the 120 test images: 8.33%
