In [None]:
import os
import pickle
import random
from typing import Any, Dict, List, Optional

import pennylane as qml
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from pennylane import numpy as np
from torch.utils.data import DataLoader, dataloader
from tqdm import tqdm

from data_utils import mnist_preparation

In [None]:
# Digit classes kept for this experiment.
labels = [0, 1, 2, 3]

# Preprocessing: downscale to 16x16, convert to a tensor, then standardize
# (0.1307 / 0.3081 are presumably the usual MNIST mean / std).
mnist_transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
    # Optional salt-and-pepper noise augmentation (currently disabled):
    # transforms.Lambda(lambda img: add_salt_and_pepper_noise(img, salt_prob=0.1, pepper_prob=0.1)),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)

# Train / validation / test split described as 70/15/15 (see data_utils.py for details).
train_dataloader, validation_dataloader, test_dataloader = mnist_preparation(
    dataset=mnist,
    labels=labels,
    train_test_ratio=0.7,
    batch_size=64,
    vali_test_ratio=0.5,
)

split_sizes = [len(loader.dataset) for loader in (train_dataloader, validation_dataloader, test_dataloader)]
print("Images in the training set: ", split_sizes[0],
      "\n Images in the validation set: ", split_sizes[1],
      "\n Images in the test set: ", split_sizes[2])

In [None]:
# Location of the saved training history. Set the TRAINING_HISTORY_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
TRAINING_HISTORY_PATH = os.environ.get(
    "TRAINING_HISTORY_PATH",
    "/Users/jackvittori/Desktop/no-noise/norumore-training.pickle",
)

# WARNING: pickle.load can execute arbitrary code; only load files you produced or trust.
with open(TRAINING_HISTORY_PATH, "rb") as file:
    training_history = pickle.load(file)

loss_history = training_history['loss_history']
mcm_accuracy = training_history['mcm_accuracy']
fm_accuracy = training_history['fm_accuracy']
weights = training_history['model_params']  # mapping 'layer_{i}' -> parameter tensor

# Freeze the loaded parameters: the cells below only run inference with them.
for param in weights.values():
    param.requires_grad = False

In [None]:
# Depolarizing-noise probability applied after every CNOT (0.0 = noiseless simulation).
p = 0.0

# qml.QubitStateVector is deprecated (and removed in recent PennyLane releases) in favour
# of qml.StatePrep; fall back to it only on versions that predate StatePrep.
_StatePrep = getattr(qml, "StatePrep", None) or qml.QubitStateVector


def _apply_layers(layer_indices) -> None:
    """Apply the trained variational layers ``weights['layer_{i}']`` for each i in ``layer_indices``.

    Each layer applies RX/RY/RZ rotations (columns 0/1/2 of the layer's weights) to every
    one of the 8 qubits, then a ring of CNOTs j -> (j + 1) % 8, each followed by a
    depolarizing channel of strength ``p`` on the CNOT target. The module-level
    ``weights`` and ``p`` are read at call time.
    """
    for i in layer_indices:
        layer = weights[f'layer_{i}']
        for j in range(8):
            qml.RX(layer[j, 0], wires=j)
            qml.RY(layer[j, 1], wires=j)
            qml.RZ(layer[j, 2], wires=j)
        for j in range(8):
            target = (j + 1) % 8
            qml.CNOT(wires=[j, target])
            qml.DepolarizingChannel(p=p, wires=target)


def early_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Early-exit circuit: the first 4 trained layers, then the probabilities of wires 0 and 1.

    Args:
        state: amplitudes to load onto wires 0-7 (presumably a normalized vector of
            2**8 = 256 amplitudes; TODO confirm). If None, the circuit starts from |0...0>.

    Returns:
        ``qml.probs`` over wires [0, 1] (4 outcomes; the evaluation loop uses the argmax
        as the predicted label).
    """
    if state is not None:
        _StatePrep(state, wires=range(8))
    _apply_layers(range(4))
    return qml.probs(wires=[0, 1])


def fully_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Full circuit: 4 layers, a mid-circuit measurement of wires 0 and 1, then 4 more layers.

    Args:
        state: same convention as in ``early_evaluation_utils``.

    Returns:
        A tuple ``(mid_probs, final_probs)``. ``mid_probs`` holds the probabilities of the
        mid-circuit measurement outcomes on wires 0 and 1. ``final_probs`` holds the
        probabilities of wires [2, 3] at the end of the circuit.
    """
    if state is not None:
        _StatePrep(state, wires=range(8))
    _apply_layers(range(4))
    # Measure the early-exit pair without reset or postselection; the remaining
    # layers still act on all 8 wires.
    mid_measurements = [qml.measure(wires=w, reset=False, postselect=None) for w in (0, 1)]
    _apply_layers(range(4, 8))
    return qml.probs(op=mid_measurements), qml.probs(wires=[2, 3])


# 8 circuit wires + 2 spare wires. The spares presumably hold the two mid-circuit
# measurements when default.mixed simulates them via deferred measurement; TODO confirm.
mixed_device = qml.device("default.mixed", wires=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
late_qnode = qml.QNode(fully_evaluation_utils, mixed_device)
early_qnode = qml.QNode(early_evaluation_utils, mixed_device)

In [None]:
# Evaluation settings.
N_EVAL_IMAGES = 100          # number of test images to evaluate
CONFIDENCE_THRESHOLD = 0.31  # minimum early-exit probability needed to stop early
# Layers executed per image: an early exit runs 4 layers; falling through also runs the
# full 8-layer circuit (which restarts from layer 0), i.e. 4 + 8 = 12.
EARLY_EXIT_LAYERS = 4
FULL_PATH_LAYERS = 12

evaluation_results = []  # per-image correctness of the prediction actually used
early_results = []       # per-image correctness of the early prediction, exit taken or not
count_1 = 0              # images classified at the early exit
count_2 = 0              # images classified at the final exit
early_correct = 0        # correct early-exit classifications
final_correct = 0        # correct final-exit classifications
executed_layers = 0

n_eval = min(N_EVAL_IMAGES, len(test_dataloader.dataset))
for i, (img, target) in tqdm(enumerate(test_dataloader.dataset), total=n_eval):
    if i == n_eval:
        break

    # Flatten to a single amplitude vector and L2-normalize it with a *scalar* norm.
    # Dividing by a (1, 1)-shaped norm would broadcast in an extra leading dimension.
    # PennyLane reads that dimension as a parameter-broadcast batch, so the returned
    # probs would gain a batch axis and the indexing below would break.
    state = img.reshape(-1).to(torch.float64)
    state = state / torch.linalg.norm(state)

    # Mid-circuit (early-exit) evaluation.
    early_probs = early_qnode(state)
    early_prediction = torch.argmax(early_probs)
    confidence = early_probs[early_prediction].item()
    early_guess = bool(early_prediction == target)
    early_results.append(early_guess)

    if confidence >= CONFIDENCE_THRESHOLD:
        evaluation_results.append(early_guess)
        count_1 += 1
        executed_layers += EARLY_EXIT_LAYERS
        if early_guess:
            early_correct += 1
    else:
        # The first output (mid-circuit measurement probs) is not needed here.
        _, final_full = late_qnode(state)
        final_guess = bool(torch.argmax(final_full) == target)
        evaluation_results.append(final_guess)
        count_2 += 1
        executed_layers += FULL_PATH_LAYERS
        if final_guess:
            final_correct += 1

# Accuracies; each falls back to 0 when its denominator is empty.
total_accuracy = sum(evaluation_results) / len(evaluation_results) if evaluation_results else 0
early_total_accuracy = sum(early_results) / len(early_results) if early_results else 0
early_exited_accuracy = early_correct / count_1 if count_1 > 0 else 0
final_exited_accuracy = final_correct / count_2 if count_2 > 0 else 0

{
    "total_accuracy": total_accuracy,
    "early_total_accuracy": early_total_accuracy,
    "early_exited_accuracy": early_exited_accuracy,
    "final_exited_accuracy": final_exited_accuracy,
    "early_exits": count_1,
    "final_exits": count_2,
    "executed_layers": executed_layers,
}