In [1]:
from model import QuantumCircuit, FullQuantumModel
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

In [2]:
# Model instantiation

# Seed torch's global RNG before anything stochastic runs, so that a
# Restart & Run All reproduces the same train/test split and batch order
# (random_split and DataLoader shuffling both draw from this RNG by default).
# NOTE(review): whether FullQuantumModel draws its initial parameters from
# torch's RNG is not visible from this notebook -- confirm in model.py.
SEED = 0
torch.manual_seed(SEED)

# 8 qubits span 2**8 = 256 amplitudes, which equals the 16*16 = 256 pixels of
# the flattened images prepared in the dataset section.
# NOTE(review): presumably the pixels are amplitude-encoded -- confirm in model.py.
num_qubits = 8
num_layers = 3
model = FullQuantumModel(num_qubits, num_layers)

In [3]:
# Render the circuit diagram through the model's own draw helper.
# NOTE(review): the saved output of this cell ends in a QuantumFunctionError
# ("All measurements must be returned in the order they are measured").
# Presumably the cause lies in how the circuit in model.py returns its
# measurements rather than in this call -- confirm there and re-run.
model.draw(style = 'sketch')

probs:  probs(wires=[0])

state:  state(wires=[])


QuantumFunctionError: All measurements must be returned in the order they are measured.

# Dataset preparation

In [None]:
# Download MNIST and prepare transforms.
# NOTE(review): transforms.Normalize is an affine rescale with positive scale,
# and every image is min-max rescaled to [0, 1] further down, so Normalize has
# no effect on the final pixel values beyond float rounding.
mnist_train = datasets.MNIST(root='./data', train=True, download=True,
                             transform=transforms.Compose([
                                transforms.Resize((16, 16)),  # Resize to 16x16
                                transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))  # Normalize
                             ]))

# Filter for zeros and ones.
# The raw label tensor is used to pick the wanted indices first, so only those
# images go through the (comparatively slow) transform pipeline. The order of
# the kept samples is the same as when iterating over the whole dataset.
keep_mask = (mnist_train.targets == 0) | (mnist_train.targets == 1)
keep_indices = keep_mask.nonzero(as_tuple=True)[0].tolist()

data = []
targets = []
for index in keep_indices:
    image, label = mnist_train[index]
    data.append(image.squeeze())  # (1, 16, 16) -> (16, 16)
    targets.append(label)

data = torch.stack(data)
targets = torch.tensor(targets)

# Split by class. Every zero and every one is kept: no per-class subsampling
# is performed here.
zeros_indices = (targets == 0)
ones_indices = (targets == 1)

zeros = data[zeros_indices]
ones = data[ones_indices]

# normalize between 0 and 1

def normalize(imgs):
    """Min-max rescale every image of a batch to the range [0, 1].

    Each image (one entry along dim 0) is rescaled independently using its
    own minimum and maximum.

    Args:
        imgs: tensor of shape (N, ...) with at least two dimensions; all
            dimensions after the first are treated as one image.

    Returns:
        Tensor with the same shape as ``imgs``. A constant image (max == min)
        maps to all zeros instead of NaN.
    """
    flat = imgs.flatten(start_dim=1)

    # Shape (N, 1, ..., 1) so the per-image statistics broadcast against imgs
    # whatever the image rank or resolution is.
    stat_shape = (imgs.shape[0],) + (1,) * (imgs.dim() - 1)
    mins = flat.min(dim=1).values.reshape(stat_shape)
    maxes = flat.max(dim=1).values.reshape(stat_shape)

    # Guard against division by zero for constant images; non-constant images
    # are divided by exactly (max - min), as before.
    span = maxes - mins
    span = torch.where(span > 0, span, torch.ones_like(span))

    return (imgs - mins) / span

zeros = normalize(zeros)
ones = normalize(ones)

# Assert that every single image (not just each class as a whole) has
# min 0 and max 1 within an error of 1e-5.
for class_images in (zeros, ones):
    per_image = class_images.flatten(start_dim=1)
    per_image_min = per_image.min(dim=1).values
    per_image_max = per_image.max(dim=1).values
    assert torch.allclose(per_image_min, torch.zeros_like(per_image_min), atol=1e-5)
    assert torch.allclose(per_image_max, torch.ones_like(per_image_max), atol=1e-5)

# Flatten each image to a vector and concatenate the two classes
# (all zeros first, then all ones).
zeros = zeros.flatten(start_dim = 1)
ones = ones.flatten(start_dim = 1)
features = torch.cat((zeros, ones), dim = 0)

# Add labels: float 0.0 for zeros, 1.0 for ones, in the same order as `features`.
# Built directly as 1-D tensors so the result stays 1-D for any sample count.
labels = torch.cat((torch.zeros(zeros.shape[0]), torch.ones(ones.shape[0])), dim = 0)
assert features.shape[0] == labels.shape[0]

# Build the dataset of (flattened image, label) pairs.
dataset = torch.utils.data.TensorDataset(features, labels)

In [None]:
# train/test split 80/20

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# A dedicated, seeded generator keeps the split identical across re-runs, so
# the same samples always end up in the test set.
SPLIT_SEED = 0
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(train_dataset) + len(test_dataset) == len(dataset)

# dataloader
BATCH_SIZE = 32

train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)

# Evaluation does not benefit from shuffling; a fixed order makes the reported
# metric repeatable. drop_last=True is kept so every evaluation batch holds
# exactly BATCH_SIZE samples.
# NOTE(review): this skips up to BATCH_SIZE - 1 test samples.
test_dataloader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False, drop_last = True)

In [None]:
# Number of samples in the training split (80% of the dataset, rounded down).
len(train_dataset)

In [None]:
# Number of samples in the test split (whatever the training split left over).
len(test_dataset)

# Model training

In [9]:
# Train the model on the training loader: 2 epochs, learning rate 0.01;
# show_plot=True is forwarded to fit.
# NOTE(review): the saved output of this cell aborts on the first batch with a
# QuantumFunctionError ("All measurements must be returned in the order they
# are measured") -- re-run after fixing the circuit so the stored output
# matches the evaluation results further down.
# NOTE(review): the evaluation cell divides each input by its L2 norm before
# calling the model; nothing in this notebook does that for the training
# batches -- presumably fit handles it internally, verify in model.py.
model.fit(dataloader=train_dataloader, learning_rate=0.01, epochs=2, show_plot=True)

Epoch 1/2:   0%|          | 0/316 [00:00<?, ?it/s]


probs:  probs(wires=[])

state:  state(wires=[])


QuantumFunctionError: All measurements must be returned in the order they are measured.

In [13]:
# Freeze every layer and switch the model to evaluation mode before scoring.
model.freeze_layers(list(range(model.num_layers)))
model.eval()

correct = 0             # correctly classified samples seen so far
total = 0               # samples seen so far
batch_accuracies = []   # per-batch accuracy, kept for inspection

with torch.no_grad():  # inference only: no autograd graph is needed
    for batch_inputs, batch_targets in test_dataloader:
        # Scale every flattened image to unit L2 norm before feeding the model.
        batch_inputs = batch_inputs / torch.linalg.norm(batch_inputs, dim=1).view(-1, 1)
        output = model(batch_inputs)

        # Threshold at 0.5 to obtain the predicted class.
        # NOTE(review): assumes `output` has the same shape as `batch_targets`
        # (one value per sample) -- TODO confirm against model.py.
        batch_correct = torch.sum((output > 0.5) == batch_targets).item()
        batch_count = batch_targets.numel()

        correct += batch_correct
        total += batch_count
        batch_accuracies.append(batch_correct / batch_count)

# Overall accuracy is counted per sample, so it stays correct even when a batch
# is not full; an empty loader yields NaN instead of a ZeroDivisionError.
accuracy = correct / total if total else float('nan')
print(accuracy)

0.9909018987341772


In [12]:
# Display the value bound to `accuracy`.
# NOTE(review): the evaluation cell above finishes by binding `accuracy` to a
# single float, yet the saved output of this cell is a list of per-batch
# values. The execution counts (this cell is In [12], the evaluation cell
# In [13]) indicate the output is stale -- re-run the notebook top to bottom.
accuracy

[1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 0.96875,
 0.96875,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 1.0,
 0.96875,
 1.0,
 0.9375,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 0.96875,
 1.0,
 1.0,
 1.0,
 0.96875,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 0.96875,
 1.0,
 1.0,
 0.96875,
 1.0,
 1.0,
 0.96875,
 0.96875,
 0.96875,
 1.0,
 0.96875,
 0.96875,
 1.0,
 1.0,
 1.0,
 1.0,
 1.0,
 0.96875,
 1.0,
 0.96875,
 1.0,
 1.0,
 0.96875,
 1.0]