In [1]:
from OriginalModel import FullQuantumModel, QuantumCircuit
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from data_utils import mnist_preparation 
import math
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import numpy as np
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Enumerate the 2**8 = 256 computational-basis states of an 8-qubit register
# as zero-padded 8-bit strings, one per line.
for basis_state in map('{:08b}'.format, range(256)):
    print(basis_state)

In [None]:
# Model instantiation: 8-qubit register, 6 layers, 4 output classes.
# num_classes must agree with the number of digits listed in `labels` below.
num_qubits, num_layers = 8, 6
model = FullQuantumModel(
    qubits=num_qubits,
    layers=num_layers,
    num_classes=4,
)

In [None]:
# Draw the model (presumably its circuit diagram) using the 'sketch' style.
model.draw(style='sketch')

In [None]:
# Display the result of trainable_parameters() before any training/freezing (cell output).
model.trainable_parameters()

# Dataset preparation

In [None]:
# Digits kept for this 4-class task (must match num_classes=4 in the model above).
labels = list(range(4))

In [None]:
# Download the MNIST training split. Images are downsampled to 16x16 = 256
# pixels (the same count as the 2**8 basis states of the 8-qubit register,
# presumably for amplitude encoding) and standardized with the commonly used
# MNIST mean/std.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

In [None]:
# Keep only the digits in `labels`; presumably 80% of them go to training and
# the rest to testing, with mini-batches of 64.
train_dataloader, test_dataloader = mnist_preparation(
    dataset=mnist,
    labels=labels,
    train_test_ratio=0.8,
    batch_size=64,
)

# Model training

The following cells show how to train the model with each of its training modes: full training of all layers, and training up to an intermediate exit via `num_layers_to_execute`.

In [None]:
# Standard training of the whole model: all layers, 8 epochs, learning rate 0.01.
model.fit(
    dataloader=train_dataloader,
    learning_rate=0.01,
    epochs=8,
    show_plot=True,
)

In [18]:
# Train the intermediate exit after the fourth layer: the prediction is read
# out after 4 layers, and the intent is to train only the layers before that
# exit. Uses a much smaller learning rate than the full training above.
model.fit(
    dataloader=train_dataloader,
    learning_rate=0.0001,
    epochs=1,
    num_layers_to_execute=4,
    show_plot=True,
)

In [19]:
# Train the intermediate exit after the second layer: the prediction is read
# out after 2 layers, and the intent is to train only the layers before that
# exit.
model.fit(
    dataloader=train_dataloader,
    learning_rate=0.0001,
    epochs=1,
    num_layers_to_execute=2,
    show_plot=True,
)

# Test

In [20]:
# After the 2-layer fit above, only the first two layers are expected to still
# be trainable; freeze layers 0 and 1 as well, then show what remains trainable.
model.freeze_layers([0, 1])
model.trainable_parameters()

In [21]:
#set a threshold for confidence value
threshold = 0.3

# Early-exit inference on the test set. Each image is first run through the
# first 2 layers; if the highest output probability exceeds `threshold`, the
# prediction is accepted at that exit. Otherwise the image is re-run through
# 4 layers and, if still not confident enough, through all 6 layers (the last
# exit always accepts). Every entry of `results` is a (prediction, label) pair.
EXIT_LAYERS = (2, 4, 6)
results = {f'{n_layers}_layers': [] for n_layers in EXIT_LAYERS}

# Pure inference: no_grad avoids building an autograd graph for each of the
# (up to three) forward passes per test image.
with torch.no_grad():
    for img, label in tqdm(test_dataloader.dataset):
        # Scale to unit L2 norm (presumably so the image is a valid amplitude-encoded state).
        img = img / torch.linalg.norm(img).view(-1, 1)

        for n_layers in EXIT_LAYERS:
            probs = model.forward(state=img, num_layers_to_execute=n_layers)
            prediction = torch.argmax(probs, dim=1)
            confidence = probs[0, prediction]
            # The deepest exit takes every sample no earlier exit was confident about.
            if n_layers == EXIT_LAYERS[-1] or confidence > threshold:
                results[f'{n_layers}_layers'].append((prediction, label))
                break

In [22]:
#accuracy computation

def calculate_accuracy(data):
    """Count correct predictions in a list of ``(prediction, label)`` pairs.

    Predictions and labels may be Python ints or one-element tensors.

    Returns
    -------
    (int, float)
        Number of correct predictions and the fraction that is correct. The
        fraction is ``nan`` when ``data`` is empty (an exit that received no
        samples) instead of raising ZeroDivisionError.
    """
    if not data:
        return 0, float('nan')
    correct = sum(1 for prediction, label in data if int(prediction) == int(label))
    return correct, correct / len(data)

correct_2_layers, accuracy_2_layers = calculate_accuracy(results['2_layers'])
correct_4_layers, accuracy_4_layers = calculate_accuracy(results['4_layers'])
correct_6_layers, accuracy_6_layers = calculate_accuracy(results['6_layers'])

# How many test samples left the model at each exit, and how accurate those
# decisions were (the 6-layer line previously said "2 layers" by mistake).
for n_layers, accuracy in ((2, accuracy_2_layers), (4, accuracy_4_layers), (6, accuracy_6_layers)):
    n_samples = len(results[f'{n_layers}_layers'])
    print(f"{n_samples} elements in {n_layers} layers with Accuracy : {accuracy}")

print(f"Overall accuracy: {(correct_2_layers+correct_4_layers+correct_6_layers)/len(test_dataloader.dataset)}")

In [25]:
#pie chart to display the distribution of labels that are classified in each exit
def plot_pie_charts(results):
    """Draw one pie chart per exit showing which true labels left the model there.

    Parameters
    ----------
    results : dict
        Maps '2_layers', '4_layers' and '6_layers' to lists of
        ``(prediction, label)`` pairs; ``label`` may be a Python int or a
        one-element tensor.
    """
    exit_keys = ['2_layers', '4_layers', '6_layers']
    fig, axes = plt.subplots(1, len(exit_keys), figsize=(24, 8))

    for ax, key in zip(axes, exit_keys):
        n_layers = key.split('_')[0]
        ax.set_title(f"Label distribution in {n_layers} layers")
        # int() accepts both plain ints and one-element tensors.
        label_counts = Counter(int(label) for _, label in results[key])
        if not label_counts:
            # Unpacking zip(*{}.items()) fails, so show a placeholder for an
            # exit that received no samples instead of crashing the figure.
            ax.text(0.5, 0.5, "no samples", ha='center', va='center')
            ax.axis('off')
            continue
        # Sorted so every pie lists the classes in the same order.
        class_ids = sorted(label_counts)
        ax.pie([label_counts[c] for c in class_ids], labels=class_ids,
               autopct='%1.1f%%', startangle=140)
        ax.axis('equal')

    plt.tight_layout()
    plt.show()


plot_pie_charts(results)

In [26]:
#confusion matrix

def plot_confusion_matrix(results, layer_name, ax, class_labels=None):
    """Plot the confusion matrix of ``(prediction, label)`` pairs on ``ax``.

    Parameters
    ----------
    results : list of (prediction, label)
        Predictions and labels may be Python ints or one-element tensors.
    layer_name : str
        Name shown in the subplot title.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    class_labels : list of int, optional
        Classes used as rows/columns, in this order. Pass the same list for
        every panel so the matrices are directly comparable. If None, only the
        classes occurring in ``results`` are used (sklearn's default).
    """
    if not results:
        # An exit with no samples cannot produce a confusion matrix.
        ax.set_title(f'Confusion Matrix for {layer_name}')
        ax.text(0.5, 0.5, 'no samples', ha='center', va='center')
        ax.axis('off')
        return

    y_pred = [int(prediction) for prediction, _ in results]
    y_true = [int(label) for _, label in results]
    if class_labels is None:
        # Same set and order sklearn uses when `labels` is not given.
        class_labels = sorted(set(y_true) | set(y_pred))

    cm = confusion_matrix(y_true, y_pred, labels=class_labels)
    # Without display_labels the ticks are numbered 0..n-1, which mislabels the
    # classes whenever some digit never appears at this exit.
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)

    disp.plot(cmap=plt.cm.Blues, ax=ax)
    ax.set_title(f'Confusion Matrix for {layer_name}')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')


# `results` is read as-is (no tensor-wrapped copy is written back), so this
# cell can be re-run safely.
all_results = results['2_layers'] + results['4_layers'] + results['6_layers']
# One shared class axis so every panel has the same rows and columns.
class_ids = sorted({int(pred) for pred, _ in all_results} | {int(lab) for _, lab in all_results})

fig, axes = plt.subplots(2, 3, figsize=(18, 12))

for ax, n_layers in zip(axes[0], (2, 4, 6)):
    plot_confusion_matrix(results[f'{n_layers}_layers'], f'{n_layers} Layers', ax, class_ids)
plot_confusion_matrix(all_results, 'Overall Model', axes[1, 1], class_ids)

# Only the centre panel of the bottom row is used.
fig.delaxes(axes[1, 0])
fig.delaxes(axes[1, 2])

plt.tight_layout()
plt.show()