# Classifying MNIST with a simple model and quantum embeddings

Inspired by:  https://www.kaggle.com/code/geekysaint/solving-mnist-using-pytorch

Useful imports

In [1]:
# for the Boson Sampler
import perceval as pcvl
#import perceval.providers.scaleway as scw  # Uncomment to allow running on scaleway

# for the machine learning model
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
from boson_sampler import BosonSampler
from utils import MNIST_partial, accuracy, plot_training_metrics
from model128 import *
import csv

## Definition of the Boson Sampler

In [2]:
session = None
# to run a remote session on Scaleway, uncomment the following and fill project_id and token
# session = scw.Session(
#                    platform="sim:sampling:p100",  # or sim:sampling:h100
#                    project_id=""  # Your project id,
#                    token=""  # Your personal API key
#                    )

# start session
if session is not None:
    session.start()
# definition of the BosonSampler
# here, we use 30 photons and 2 modes

bs = BosonSampler(10,2, postselect = 2, session = session)
print(f"Boson sampler defined with number of parameters = {bs.nb_parameters}, and embedding size = {bs.embedding_size}")

#to display it
# pcvl.pdisplay(bs.create_circuit())
# define device to run the model

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'DEVICE = {device}')

Boson sampler defined with number of parameters = 85, and embedding size = 45
DEVICE = cpu


## Dataset : a subset of MNIST dataset

In [3]:
# dataset from csv file, to use for the challenge
train_dataset = MNIST_partial(split = 'train')
val_dataset = MNIST_partial(split='val')

# definition of the dataloader, to process the data in the model
# here, we need a batch size of 1 to use the boson sampler
batch_size = 10
train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size, shuffle = False)

In [4]:
for images, labels in train_loader:  
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)
    break

Image batch dimensions: torch.Size([10, 1, 28, 28])
Image label dimensions: torch.Size([10])


In [5]:
def training_step(model, batch, emb = None):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    images.requires_grad = True  # Only needed for special cases
    # if self.embedding_size:
    #     out = self(images, emb.to(self.device)) ## Generate predictions
    # else:
    # images = images.reshape(-1,1,28,28)
    # print(images.shape)

    # out = model(images) ## Generate predictions
    # loss = F.cross_entropy(out, labels)
    # acc = accuracy(out, labels, task="multiclass", num_classes=10)
    loss, acc = model((images, labels)) ## Generate predictions
    
    # loss = F.cross_entropy(out, labels) ## Calculate the loss
    # acc = accuracy(out, labels)
    return loss, acc

def validation_step(model, batch, emb =None):
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    images.requires_grad = True  # Only needed for special cases
    # if self.embedding_size:
    #     out = self(images, emb.to(self.device)) ## Generate predictions
    # # else:
    # out = model(images) ## Generate predictions
    # loss = F.cross_entropy(out, labels)
    # acc = accuracy(out, labels, task="multiclass", num_classes=10)
    loss, acc = model((images,labels)) ## Generate predictions
    return({'val_loss':loss, 'val_acc': acc})

def validation_epoch_end(outputs):
    batch_losses = [x['val_loss'] for x in outputs]
    epoch_loss = torch.stack(batch_losses).mean()
    batch_accs = [x['val_acc'] for x in outputs]
    epoch_acc = torch.stack(batch_accs).mean()
    return({'val_loss': epoch_loss.item(), 'val_acc' : epoch_acc.item()})

def epoch_end(epoch, result):
    print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))
    return result['val_loss'], result['val_acc']

# training loop
def fit(epochs, lr, model, train_loader, val_loader, bs: BosonSampler, opt_func = torch.optim.SGD):
    history = []
    optimizer = opt_func([{'params': model.model.fc.parameters(), 'lr': lr},{'params': model.model.conv1.parameters(), 'lr': lr}])
    # creation of empty lists to store the training metrics
    train_loss, train_acc, val_loss, val_acc = [], [], [], []
    for epoch in range(epochs):
        model.train()
        training_losses, training_accs = 0, 0
        ## Training Phase
        for step, batch in enumerate(tqdm(train_loader)):
            # # embedding in the BS
            # if model.embedding_size:
            #     images, labs = batch
            #     images = images.squeeze(0).squeeze(0)
            #     t_s = time.time()
            #     embs = bs.embed(images,1000)
            #     loss,acc = model.training_step(batch,emb = embs.unsqueeze(0))

            # else:
            
            loss,acc = training_step(model, batch)
            # loss.requires_grad = True
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Backpropagation with boson sampling noise
            # for param in model.parameters():
            #     param.grad = param.grad + bs(param.grad.shape, scale=0.001)

            training_losses+=int(loss.detach())
            training_accs+=int(acc.detach())
            # if model.embedding_size and step%100==0:
            #     print(f"STEP {step}, Training-acc = {training_accs/(step+1)}, Training-losses = {training_losses/(step+1)}")
        
        ## Validation phase
        outputs = [validation_step(model, batch) for batch in val_loader]
        result = (validation_epoch_end(outputs))
        # result = evaluate(model, val_loader, bs)
        validation_loss, validation_acc = result['val_loss'], result['val_acc']
        epoch_end(epoch, result)
        history.append(result)

        ## summing up all the training and validation metrics
        training_loss = training_losses/len(train_loader)
        training_accs = training_accs/len(train_loader)
        train_loss.append(training_loss)
        train_acc.append(training_accs)
        val_loss.append(validation_loss)
        val_acc.append(validation_acc)

        # plot training curves
        

    with open('tmp_file.txt', 'w') as f:
        csv.writer(f, delimiter=' ').writerows([train_acc,val_acc,train_loss,val_loss])
        # plot_training_metrics(train_acc,val_acc,train_loss,val_loss)
    return(history)


## Check original accuracy

In [6]:
# from pre_trained import *
# trainer = Trainer(
#     fast_dev_run=False,
#     # logger=TensorBoardLogger("cifar10", name='resnet18'),
#     deterministic=True,
#     log_every_n_steps=1,
#     max_epochs=100,
#     precision=32,
# )

# model = CIFAR10Module(minst=True)
# model.load_state_dict(torch.load('state_dicts/resnet18.pt'), strict=False)
# trainer.test(model, val_loader)

# [{'acc/test': 10.833333015441895}]

In [7]:
# model = CIFAR10Module(minst=False)
# data = CIFAR10Data()
# model.model.load_state_dict(torch.load('state_dicts/resnet18.pt'))
# trainer.test(model, data.val_dataloader())

# [{'acc/test': 93.06890869140625}]

## Training loop

### Classical model

In [8]:
# model = CIFAR10Module(minst=True)
# # data = CIFAR10Data()

# for param in model.parameters():
#     param.requires_grad = False  # Freeze all layers

# # Unfreeze the last layer
# for param in model.model.conv1.parameters():
#     param.requires_grad = True
# for param in model.model.fc.parameters():
#     param.requires_grad = True
# pretrained_dict = torch.load("state_dicts/resnet18.pt")  # Path to saved weights
# model_dict = model.state_dict()

# # Remove 'conv1' weights from pretrained_dict to avoid shape mismatch
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if "conv1" or 'fc' not in k}

# # Update the current model dictionary
# model_dict.update(pretrained_dict)

# # Load the modified weights
# model.load_state_dict(model_dict, strict=False) 
# model.to(device)


In [9]:
# experiment = fit(epochs = 20, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

# 100%|██████████| 600/600 [00:28<00:00, 20.78it/s]
# Epoch [0], val_loss: 1.3817, val_acc: 62.0000
# 100%|██████████| 600/600 [00:22<00:00, 26.77it/s]
# Epoch [1], val_loss: 1.0354, val_acc: 69.3333
# 100%|██████████| 600/600 [00:22<00:00, 26.83it/s]
# Epoch [2], val_loss: 0.8341, val_acc: 77.1667
# 100%|██████████| 600/600 [00:22<00:00, 26.98it/s]
# Epoch [3], val_loss: 0.7472, val_acc: 77.5000
# 100%|██████████| 600/600 [00:21<00:00, 27.38it/s]
# Epoch [4], val_loss: 0.6718, val_acc: 80.1667
# 100%|██████████| 600/600 [00:20<00:00, 29.88it/s]
# Epoch [5], val_loss: 0.6378, val_acc: 81.1667
# 100%|██████████| 600/600 [00:20<00:00, 29.55it/s]
# Epoch [6], val_loss: 0.6083, val_acc: 81.1667
# 100%|██████████| 600/600 [00:20<00:00, 29.37it/s]
# Epoch [7], val_loss: 0.6031, val_acc: 80.3333
# 100%|██████████| 600/600 [00:21<00:00, 27.43it/s]
# Epoch [8], val_loss: 0.5747, val_acc: 81.3333
# 100%|██████████| 600/600 [00:21<00:00, 27.53it/s]
# Epoch [9], val_loss: 0.5491, val_acc: 83.1667
# 100%|██████████| 600/600 [00:21<00:00, 28.47it/s]
# Epoch [10], val_loss: 0.5424, val_acc: 83.1667
# 100%|██████████| 600/600 [00:21<00:00, 28.44it/s]
# Epoch [11], val_loss: 0.5516, val_acc: 83.3333
# 100%|██████████| 600/600 [00:22<00:00, 26.27it/s]
# Epoch [12], val_loss: 0.5268, val_acc: 82.0000
# 100%|██████████| 600/600 [00:22<00:00, 26.40it/s]
# Epoch [13], val_loss: 0.5035, val_acc: 83.5000
# 100%|██████████| 600/600 [00:22<00:00, 26.24it/s]
# Epoch [14], val_loss: 0.5045, val_acc: 83.6667
# 100%|██████████| 600/600 [00:23<00:00, 25.35it/s]
# Epoch [15], val_loss: 0.4765, val_acc: 84.3333
# 100%|██████████| 600/600 [00:21<00:00, 28.32it/s]
# Epoch [16], val_loss: 0.4787, val_acc: 84.3333
# 100%|██████████| 600/600 [00:21<00:00, 27.69it/s]
# Epoch [17], val_loss: 0.4639, val_acc: 84.5000
# 100%|██████████| 600/600 [00:22<00:00, 26.24it/s]
# Epoch [18], val_loss: 0.4560, val_acc: 87.3333
# 100%|██████████| 600/600 [00:25<00:00, 23.85it/s]
# Epoch [19], val_loss: 0.4453, val_acc: 86.6667

### different initial condition

In [10]:
# model = CIFAR10Module()
# # data = CIFAR10Data()

# for param in model.parameters():
#     param.requires_grad = False  # Freeze all layers
# pretrained_dict = torch.load("state_dicts/resnet18.pt")  # Path to saved weights
# model.load_state_dict(pretrained_dict, strict=False) 
# model.model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 
# # Unfreeze the last layer
# for param in model.model.conv1.parameters():
#     param.requires_grad = True
# for param in model.model.fc.parameters():
#     param.requires_grad = True
# model.to(device)
# experiment = fit(epochs = 2, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

# 100%|██████████| 600/600 [00:25<00:00, 23.76it/s]
# Epoch [0], val_loss: 1.6154, val_acc: 50.3333
# 100%|██████████| 600/600 [00:24<00:00, 24.72it/s]
# Epoch [1], val_loss: 1.2136, val_acc: 65.1667

### without training the last layer

In [11]:
# train the model with the chosen parameters
# experiment = fit(epochs = 20, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

# 100%|██████████| 600/600 [00:23<00:00, 25.89it/s]
# Epoch [0], val_loss: 2.2501, val_acc: 19.1667
# 100%|██████████| 600/600 [00:22<00:00, 27.08it/s]
# Epoch [1], val_loss: 2.1102, val_acc: 23.1667
# 100%|██████████| 600/600 [00:24<00:00, 24.76it/s]
# Epoch [2], val_loss: 2.0289, val_acc: 30.0000
# 100%|██████████| 600/600 [00:23<00:00, 25.88it/s]
# Epoch [3], val_loss: 1.9465, val_acc: 33.0000
# 100%|██████████| 600/600 [00:21<00:00, 28.09it/s]
# Epoch [4], val_loss: 1.9684, val_acc: 34.6667
# 100%|██████████| 600/600 [00:20<00:00, 28.82it/s]
# Epoch [5], val_loss: 1.9208, val_acc: 34.8333
# 100%|██████████| 600/600 [00:21<00:00, 28.14it/s]
# Epoch [6], val_loss: 1.8927, val_acc: 37.5000
# 100%|██████████| 600/600 [00:22<00:00, 26.74it/s]
# Epoch [7], val_loss: 1.8597, val_acc: 36.8333
# 100%|██████████| 600/600 [00:24<00:00, 24.80it/s]
# Epoch [8], val_loss: 1.8191, val_acc: 39.1667
# 100%|██████████| 600/600 [00:22<00:00, 26.70it/s]
# Epoch [9], val_loss: 1.8407, val_acc: 40.1667
# 100%|██████████| 600/600 [00:22<00:00, 27.04it/s]
# Epoch [10], val_loss: 1.8462, val_acc: 38.5000
# 100%|██████████| 600/600 [00:21<00:00, 27.78it/s]
# Epoch [11], val_loss: 1.8358, val_acc: 35.5000
# 100%|██████████| 600/600 [00:21<00:00, 27.85it/s]
# Epoch [12], val_loss: 1.8387, val_acc: 36.3333
# 100%|██████████| 600/600 [00:20<00:00, 29.34it/s]
# Epoch [13], val_loss: 1.8178, val_acc: 40.0000
# 100%|██████████| 600/600 [00:19<00:00, 31.53it/s]
# Epoch [14], val_loss: 1.7815, val_acc: 40.3333
# 100%|██████████| 600/600 [00:18<00:00, 31.98it/s]
# Epoch [15], val_loss: 1.7803, val_acc: 43.6667
# 100%|██████████| 600/600 [00:19<00:00, 31.00it/s]
# Epoch [16], val_loss: 1.7600, val_acc: 44.0000
# 100%|██████████| 600/600 [00:18<00:00, 33.10it/s]
# Epoch [17], val_loss: 1.7525, val_acc: 43.3333
# 100%|██████████| 600/600 [00:19<00:00, 31.11it/s]
# Epoch [18], val_loss: 1.7547, val_acc: 44.5000
# 100%|██████████| 600/600 [00:18<00:00, 32.83it/s]
# Epoch [19], val_loss: 1.7766, val_acc: 44.0000

### Quantum model

In [12]:
from pre_trained import *
model = CIFAR10Module(minst=True, quantum=bs)
model.model.fc = DressedQuantumNet(bs, bs.nb_parameters, dropout=True)

for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

# Unfreeze the last layer
for param in model.model.fc.parameters():
    param.requires_grad = True

for param in model.model.conv1.parameters():
    param.requires_grad = True

pretrained_dict = torch.load("state_dicts/resnet18.pt")  # Path to saved weights
model_dict = model.state_dict()

# Remove 'conv1' weights from pretrained_dict to avoid shape mismatch
pretrained_dict = {k: v for k, v in pretrained_dict.items() if "conv1" or "fc" not in k}

# Update the current model dictionary
model_dict.update(pretrained_dict)

# Load the modified weights
model.load_state_dict(model_dict, strict=False) 
model.to(device)


CIFAR10Module(
  (criterion): CrossEntropyLoss()
  (model): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05

In [13]:
# train the model with the chosen parameters
experiment = fit(epochs = 3, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

100%|██████████| 600/600 [04:56<00:00,  2.03it/s]


Epoch [0], val_loss: 2.3021, val_acc: 12.1667


100%|██████████| 600/600 [04:49<00:00,  2.08it/s]


Epoch [1], val_loss: 2.3021, val_acc: 12.3333


100%|██████████| 600/600 [04:42<00:00,  2.12it/s]


Epoch [2], val_loss: 2.3023, val_acc: 12.3333


In [14]:
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

### Quantum Model with dropout

In [None]:
model = CIFAR10Module(minst=True, quantum=bs)
model.model.fc = DressedQuantumNet(bs, bs.nb_parameters, dropout=True, pos=False)

for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

# Unfreeze the last layer
for param in model.model.fc.parameters():
    param.requires_grad = True

for param in model.model.conv1.parameters():
    param.requires_grad = True

pretrained_dict = torch.load("state_dicts/resnet18.pt")  # Path to saved weights
model_dict = model.state_dict()

# Remove 'conv1' weights from pretrained_dict to avoid shape mismatch
pretrained_dict = {k: v for k, v in pretrained_dict.items() if "conv1" or "fc" not in k}

# Update the current model dictionary
model_dict.update(pretrained_dict)

# Load the modified weights
model.load_state_dict(model_dict, strict=False) 
model.to(device)


In [None]:
# train the model with the chosen parameters
experiment = fit(epochs = 20, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

### Quantum Model with last layer

In [None]:
model = CIFAR10Module(minst=True, quantum=bs)
model.fc = DressedQuantumNet(bs, bs.nb_parameters, dropout=False, pos=True)

for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

# Unfreeze the last layer
for param in model.model.fc.parameters():
    param.requires_grad = True

for param in model.model.conv1.parameters():
    param.requires_grad = True

pretrained_dict = torch.load("state_dicts/resnet18.pt")  # Path to saved weights
model_dict = model.state_dict()

# Remove 'conv1' weights from pretrained_dict to avoid shape mismatch
pretrained_dict = {k: v for k, v in pretrained_dict.items() if "conv1" or "fc" not in k}

# Update the current model dictionary
model_dict.update(pretrained_dict)

# Load the modified weights
model.load_state_dict(model_dict, strict=False) 
model.to(device)


In [None]:
# train the model with the chosen parameters
experiment = fit(epochs = 20, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

In [None]:
break

### Quantum Model with both

In [None]:
model = CIFAR10Module(minst=True, quantum=bs)
model.model.fc = DressedQuantumNet(bs, bs.nb_parameters, dropout=True, pos=True)

for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

# Unfreeze the last layer
for param in model.model.fc.parameters():
    param.requires_grad = True

for param in model.model.conv1.parameters():
    param.requires_grad = True

pretrained_dict = torch.load("state_dicts/resnet18.pt")  # Path to saved weights
model_dict = model.state_dict()

# Remove 'conv1' weights from pretrained_dict to avoid shape mismatch
pretrained_dict = {k: v for k, v in pretrained_dict.items() if "conv1" or "fc" not in k}

# Update the current model dictionary
model_dict.update(pretrained_dict)

# Load the modified weights
model.load_state_dict(model_dict, strict=False) 
model.to(device)


In [None]:
# train the model with the chosen parameters
experiment = fit(epochs = 20, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

### Vanilla model

In [None]:
# end session if needed
if session is not None:
    session.stop()