# Classifying MNIST with a simple model and quantum embeddings

Inspired by: https://www.kaggle.com/code/geekysaint/solving-mnist-using-pytorch

Useful imports

In [11]:
# for the Boson Sampler
import perceval as pcvl
#import perceval.providers.scaleway as scw  # Uncomment to allow running on scaleway

# for the machine learning model
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
from boson_sampler import BosonSampler
from utils import MNIST_partial, accuracy, plot_training_metrics
from model128 import *
import csv

## Definition of the Boson Sampler

In [12]:
session = None
# To run a remote session on Scaleway, uncomment the following (and the scaleway import in the
# imports cell) and fill in project_id and token.
# session = scw.Session(
#                    platform="sim:sampling:p100",  # or sim:sampling:h100
#                    project_id="",  # Your project id
#                    token="",  # Your personal API key (prefer reading it from an environment variable)
#                    )

# start session
if session is not None:
    session.start()

# Definition of the BosonSampler.
# NOTE(review): the first two positional arguments are 24 and 2. The printed embedding size of
# 276 = C(24, 2) suggests 24 modes and 2 photons (not "30 photons and 2 modes" as previously
# stated) -- confirm against BosonSampler's signature.
bs = BosonSampler(24,2, postselect = 2, session = session)
print(f"Boson sampler defined with number of parameters = {bs.nb_parameters}, and embedding size = {bs.embedding_size}")

# to display it
# pcvl.pdisplay(bs.create_circuit())

# define device to run the model: first CUDA GPU if available, otherwise CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'DEVICE = {device}')

Boson sampler defined with number of parameters = 540, and embedding size = 276
DEVICE = cpu


## Dataset: a subset of the MNIST dataset

In [13]:
# dataset from csv file, to use for the challenge
train_dataset = MNIST_partial(split = 'train')
val_dataset = MNIST_partial(split='val')

# Dataloaders that feed the model. Each whole mini-batch of `batch_size` images is passed to
# the model at once. The batch-size-1 constraint only applied to the per-image
# `bs.embed(...)` path, which appears to be disabled (commented out) in the training loop.
batch_size = 10
train_loader = DataLoader(train_dataset, batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size, shuffle = False)

In [14]:
# Peek at the first training batch to check the tensor shapes the model will receive.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print('Image batch dimensions:', images.shape)
    print('Image label dimensions:', labels.shape)

Image batch dimensions: torch.Size([10, 1, 28, 28])
Image label dimensions: torch.Size([10])


In [15]:
def training_step(model, batch, emb = None):
    """Run ``model`` on one training batch and return its ``(loss, acc)`` pair.

    Args:
        model: called as ``model((images, labels))``; must return ``(loss, acc)``.
        batch: ``(images, labels)`` pair as yielded by the training DataLoader.
        emb: accepted but ignored (appears to be a leftover from a disabled
            embedding-based variant).

    Returns:
        The ``(loss, acc)`` tuple produced by ``model``.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    # NOTE(review): marked "only needed for special cases" -- presumably so that `loss` keeps
    # a grad_fn even when every layer is frozen; confirm whether it is still required.
    inputs.requires_grad = True
    loss, acc = model((inputs, targets))
    return loss, acc

def validation_step(model, batch, emb =None):
    """Compute loss and accuracy of ``model`` on one validation batch without autograd.

    Args:
        model: called as ``model((images, labels))``; must return ``(loss, acc)``.
        batch: ``(images, labels)`` pair as yielded by the validation DataLoader.
        emb: unused; kept for signature compatibility with ``training_step``.

    Returns:
        ``{'val_loss': loss, 'val_acc': acc}`` holding the values returned by ``model``.
    """
    images, labels = batch
    images, labels = images.to(device), labels.to(device)
    # Evaluation needs no gradients. Running under no_grad avoids recording an autograd graph
    # for every batch; the losses are stacked later, which would otherwise keep all graphs alive.
    with torch.no_grad():
        loss, acc = model((images, labels))
    return {'val_loss': loss, 'val_acc': acc}

def validation_epoch_end(outputs):
    """Average per-batch validation metrics into plain Python floats.

    Args:
        outputs: list of ``{'val_loss': tensor, 'val_acc': tensor}`` dicts, one per batch.

    Returns:
        ``{'val_loss': float, 'val_acc': float}``: the mean of each metric over all batches.
    """
    def _mean_of(key):
        # Stack the per-batch scalars into one tensor and reduce it to its mean.
        return torch.stack([out[key] for out in outputs]).mean().item()

    return {'val_loss': _mean_of('val_loss'), 'val_acc': _mean_of('val_acc')}

def epoch_end(epoch, result):
    """Print a one-line validation summary for ``epoch``.

    Args:
        epoch: epoch index shown in the log line.
        result: dict with ``'val_loss'`` and ``'val_acc'`` entries.

    Returns:
        ``(val_loss, val_acc)`` taken from ``result``.
    """
    loss, acc = result['val_loss'], result['val_acc']
    print(f"Epoch [{epoch}], val_loss: {loss:.4f}, val_acc: {acc:.4f}")
    return loss, acc

# training loop
def _train_one_epoch(model, loader, optimizer):
    """Run one optimisation pass over ``loader`` and return the per-batch mean ``(loss, acc)``."""
    model.train()  # make sure dropout is active while training
    running_loss, running_acc = 0.0, 0.0
    for batch in tqdm(loader):
        loss, acc = training_step(model, batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # .item() keeps the fractional part; int(...) would truncate e.g. a loss of 2.31 to 2.
        running_loss += loss.item()
        running_acc += acc.item()
    n_batches = len(loader)
    if n_batches == 0:
        return float('nan'), float('nan')
    return running_loss / n_batches, running_acc / n_batches


def _evaluate(model, loader):
    """Validate ``model`` on ``loader`` with dropout disabled and autograd off.

    The model's previous train/eval mode is restored afterwards, even if validation is interrupted.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = [validation_step(model, batch) for batch in loader]
    finally:
        model.train(was_training)
    return validation_epoch_end(outputs)


def fit(epochs, lr, model, train_loader, val_loader, bs: BosonSampler, opt_func = torch.optim.SGD,
        metrics_path = 'tmp_file.txt'):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Only the parameters of ``model.model.fc`` and ``model.model.conv1`` are given to the
    optimizer (both with learning rate ``lr``); all other parameters are never updated.

    Args:
        epochs: number of passes over ``train_loader``.
        lr: learning rate for both optimizer parameter groups.
        model: module run through ``training_step`` / ``validation_step``.
        train_loader: DataLoader yielding ``(images, labels)`` training batches.
        val_loader: DataLoader yielding ``(images, labels)`` validation batches.
        bs: the BosonSampler; currently unused here, kept for interface compatibility.
        opt_func: optimizer class, called with a list of parameter-group dicts.
        metrics_path: file that receives four space-delimited rows:
            train_acc, val_acc, train_loss, val_loss (one value per epoch).

    Returns:
        List with one ``{'val_loss': float, 'val_acc': float}`` dict per epoch.
    """
    history = []
    optimizer = opt_func([{'params': model.model.fc.parameters(), 'lr': lr},
                          {'params': model.model.conv1.parameters(), 'lr': lr}])
    # per-epoch metrics written to metrics_path at the end
    train_loss, train_acc, val_loss, val_acc = [], [], [], []
    for epoch in range(epochs):
        epoch_train_loss, epoch_train_acc = _train_one_epoch(model, train_loader, optimizer)

        result = _evaluate(model, val_loader)
        epoch_end(epoch, result)
        history.append(result)

        train_loss.append(epoch_train_loss)
        train_acc.append(epoch_train_acc)
        val_loss.append(result['val_loss'])
        val_acc.append(result['val_acc'])

    # newline='' is what the csv module expects for files handed to csv.writer.
    with open(metrics_path, 'w', newline='') as f:
        csv.writer(f, delimiter=' ').writerows([train_acc, val_acc, train_loss, val_loss])
    return history


In [16]:
from pre_trained import *

### Vanilla model

In [17]:
# model = MnistModel()
# pretrained_dict = torch.load("cifar_net.pth")  # Path to saved weights
# model.load_state_dict(pretrained_dict, strict=False) 

# for param in model.parameters():
#     param.requires_grad = False  # Freeze all layers

# model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) 
# # Unfreeze the last layer
# for param in model.conv1.parameters():
#     param.requires_grad = True
# for param in model.fc2.parameters():
#     param.requires_grad = True

# model.to(device)

# model = CIFAR10Module(minst=True,vanilla=True)
# # data = CIFAR10Data()

# for param in model.parameters():
#     param.requires_grad = False  # Freeze all layers

# # Unfreeze the last layer
# for param in model.model.conv1.parameters():
#     param.requires_grad = True
# for param in model.model.fc.parameters():
#     param.requires_grad = True
# pretrained_dict = torch.load("cifar_vanilla.pth")  # Path to saved weights
# model_dict = model.state_dict()

# # Remove 'conv1' weights from pretrained_dict to avoid shape mismatch
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if "conv1" or 'fc' not in k}

# # Update the current model dictionary
# model_dict.update(pretrained_dict)

# # Load the modified weights
# model.load_state_dict(model_dict, strict=False) 
# model.to(device)


In [18]:
# experiment = fit(epochs = 20, lr = 0.001, model = model, train_loader = train_loader, val_loader = val_loader, bs=bs)

In [19]:
# Build the MNIST model and replace its classifier head with the dressed quantum net.
model = CIFAR10Module(minst=True, vanilla=True)
model.model.fc = DressedQuantumNet(bs, bs.nb_parameters, dropout=True)

# Freeze everything, then unfreeze only the layers that fit() hands to the optimizer.
for param in model.parameters():
    param.requires_grad = False  # Freeze all layers

for param in model.model.fc.parameters():
    param.requires_grad = True

for param in model.model.conv1.parameters():
    param.requires_grad = True

# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# load_state_dict then copies the values into the model's own parameters.
pretrained_dict = torch.load("cifar_vanilla.pth", map_location="cpu")  # Path to saved weights
model_dict = model.state_dict()

# Keep only pretrained tensors whose key exists in the current model with the same shape.
# This drops conv1/fc tensors when their shapes differ (load_state_dict raises on a size
# mismatch even with strict=False), plus keys of the old head that the quantum head lacks.
# (A name test like `"conv1" or "fc" not in k` is always true, because "conv1" is truthy.)
compatible_dict = {
    k: v for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}
skipped_keys = sorted(set(pretrained_dict) - set(compatible_dict))
if skipped_keys:
    print(f"Not loading {len(skipped_keys)} pretrained tensors (missing or shape mismatch): {skipped_keys}")

# Update the current model dictionary and load it back
model_dict.update(compatible_dict)
model.load_state_dict(model_dict, strict=False)
model.to(device)


CIFAR10Module(
  (criterion): CrossEntropyLoss()
  (model): MnistModel(
    (conv1): Conv2d(1, 6, kernel_size=(7, 7), stride=(1, 1))
    (pool): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
    (linear): Linear(in_features=294, out_features=512, bias=True)
    (fc): DressedQuantumNet(
      (pre_net): Linear(in_features=512, out_features=540, bias=True)
      (drop): Dropout(p=0.2, inplace=False)
      (linear): Linear(in_features=276, out_features=128, bias=True)
      (post_net): Linear(in_features=128, out_features=10, bias=True)
    )
  )
)

In [20]:
# Train the dressed quantum model for 20 epochs with SGD (lr = 1e-3); returns per-epoch validation metrics.
experiment = fit(
    epochs=20,
    lr=0.001,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    bs=bs,
)

100%|██████████| 600/600 [37:46<00:00,  3.78s/it]


Epoch [0], val_loss: 2.3063, val_acc: 9.0000


100%|██████████| 600/600 [33:20<00:00,  3.33s/it]


Epoch [1], val_loss: 2.3057, val_acc: 9.0000


100%|██████████| 600/600 [33:13<00:00,  3.32s/it]


Epoch [2], val_loss: 2.3052, val_acc: 9.0000


100%|██████████| 600/600 [33:06<00:00,  3.31s/it]


Epoch [3], val_loss: 2.3048, val_acc: 9.0000


 50%|█████     | 303/600 [1:03:26<1:02:10, 12.56s/it]


KeyboardInterrupt: 

In [None]:
# Release the remote session opened at the top of the notebook.
# Nothing to do when running locally (session is None).
remote_session = session
if remote_session is not None:
    remote_session.stop()