In [23]:
import torch
import pandas as pd
import pennylane as qml
from pennylane import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from data_utils import mnist_preparation, add_salt_and_pepper_noise 
from evaluationUtils import calculate_mcm_accuracy
from tqdm import tqdm
import matplotlib as plt
from mcmadaptablemodel import MCMQuantumModel, MCMCircuit
from pennylane import Device
from pennylane.measurements import StateMP
from torch.nn import Module, ParameterDict
import matplotlib.pyplot as plt
from OriginalModel import FullQuantumModel, QuantumCircuit
import warnings
from typing import Optional, Dict, List, Any
from torch.utils.data import DataLoader, dataloader
from time import time
import math
from pennylane.measurements import MidMeasureMP
# Seed torch's global RNG so torch-based randomness in this notebook is reproducible.
# NOTE(review): this only covers torch; if add_salt_and_pepper_noise or mnist_preparation draw from
# NumPy / `random`, those generators need seeding too — confirm in data_utils.py.
torch.manual_seed(1234)

In [24]:
def early_evaluation_utils(params: Dict, state: Optional[torch.Tensor] = None) -> List:
    """Queue the first (early-exit) half of the 8-qubit MCM ansatz; call inside a QNode.

    Applies layers 0-3 (per-wire RX/RY/RZ rotations followed by a ring of CNOTs), then
    measures wires 0 and 1 mid-circuit.

    Args:
        params: mapping ``'layer_0'`` ... ``'layer_3'`` to rotation angles indexed as
            ``params['layer_i'][qubit, k]`` with k = 0/1/2 for RX/RY/RZ.
            (Presumably an (8, 3) tensor per layer — TODO confirm against MCMQuantumModel.)
        state: optional amplitude vector loaded onto the 8 wires with ``qml.StatePrep``
            (2**8 amplitudes, L2-normalised). When None the circuit starts from |0...0>.

    Returns:
        List with the two mid-circuit measurement values, for wires 0 and 1.
    """
    n_qubits = 8
    measured_wires = (0, 1)
    if state is not None:
        # qml.StatePrep supersedes the deprecated qml.QubitStateVector (same semantics).
        qml.StatePrep(state, wires=range(n_qubits))
    for i in range(4):
        layer = params[f'layer_{i}']
        for j in range(n_qubits):
            qml.RX(layer[j, 0], wires=j)
            qml.RY(layer[j, 1], wires=j)
            qml.RZ(layer[j, 2], wires=j)
        # Ring of CNOTs; (j + 1) % n_qubits wraps the last wire back to wire 0.
        for j in range(n_qubits):
            qml.CNOT(wires=[j, (j + 1) % n_qubits])
    # Mid-circuit measurements of the first qubit pair: the early-exit read-out.
    return [qml.measure(wires=w) for w in measured_wires]

def fully_evaluation_utils(params: Dict, state: Optional[torch.Tensor] = None) -> List:
    """Queue the complete 8-layer MCM ansatz with both read-outs; call inside a QNode.

    Stage 1 applies layers 0-3 and measures wires 0 and 1 mid-circuit (early exit).
    Stage 2 applies layers 4-7 to all 8 wires, including the already-measured wires 0 and 1.
    It then measures wires 2 and 3. Each layer is per-wire RX/RY/RZ rotations followed by a
    ring of CNOTs.

    Args:
        params: mapping ``'layer_0'`` ... ``'layer_7'`` to rotation angles indexed as
            ``params['layer_i'][qubit, k]`` with k = 0/1/2 for RX/RY/RZ.
        state: optional amplitude vector loaded onto the 8 wires with ``qml.StatePrep``.
            When None the circuit starts from |0...0>.

    Returns:
        List of four mid-circuit measurement values: wires 0, 1 (after layer 3), then
        wires 2, 3 (after layer 7).

    NOTE(review): wires 0/1 are acted on again after being measured. Under the
    deferred-measurement strategy this may require extra device wires — confirm the
    8-wire device handles it.
    """
    n_qubits = 8
    # (layer indices in the stage, wires measured at the end of the stage)
    stages = ((range(0, 4), (0, 1)), (range(4, 8), (2, 3)))
    measurements = []
    if state is not None:
        # qml.StatePrep supersedes the deprecated qml.QubitStateVector (same semantics).
        qml.StatePrep(state, wires=range(n_qubits))
    for layer_ids, measured_wires in stages:
        for i in layer_ids:
            layer = params[f'layer_{i}']
            for j in range(n_qubits):
                qml.RX(layer[j, 0], wires=j)
                qml.RY(layer[j, 1], wires=j)
                qml.RZ(layer[j, 2], wires=j)
            # Ring of CNOTs; (j + 1) % n_qubits wraps the last wire back to wire 0.
            for j in range(n_qubits):
                qml.CNOT(wires=[j, (j + 1) % n_qubits])
        measurements.extend(qml.measure(wires=w) for w in measured_wires)
    return measurements

# 8-wire simulator shared by the evaluation QNodes.
dev = qml.device("default.qubit", wires=8)


@qml.qnode(dev)
def early_evaluation_ansatz(params: Dict, state: torch.Tensor = None):
    """Run only the first half of the MCM circuit and return the early-exit distribution.

    Args:
        params: per-layer rotation angles keyed ``'layer_0'`` ... ``'layer_3'``.
        state: optional amplitude vector loaded onto the 8 wires before the layers.

    Returns:
        ``qml.probs`` over the two mid-circuit measurements taken on wires 0 and 1.
    """
    mcm_values = early_evaluation_utils(params=params, state=state)
    return qml.probs(op=mcm_values)

@qml.qnode(dev)
def fully_evaluation_ansatz(params: Dict, state: torch.Tensor = None):
    """Run the full 8-layer MCM circuit and return both read-outs.

    Args:
        params: per-layer rotation angles keyed ``'layer_0'`` ... ``'layer_7'``.
        state: optional amplitude vector loaded onto the 8 wires before the layers.

    Returns:
        Tuple ``(mid_probs, final_probs)``:
        - ``mid_probs``: joint ``qml.probs`` over the mid-circuit measurements on wires 0/1.
        - ``final_probs``: joint ``qml.probs`` over the measurements on wires 2/3 taken after layer 7.
    """
    mcm_values = fully_evaluation_utils(params=params, state=state)
    # First two values come from the early exit, the last two from the end of the circuit.
    head, tail = mcm_values[:2], mcm_values[2:]
    return qml.probs(op=head), qml.probs(op=tail)

In [25]:
# Restrict MNIST to the digits 0-3 (a 4-class problem).
labels = list(range(4))

# 16x16 = 256 pixels = 2**8 amplitudes, matching the 8-qubit register used by the circuits.
mnist_transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
    # Salt-and-pepper corruption with salt_prob = pepper_prob = 0.2 (the "noise 0.2" setting).
    transforms.Lambda(lambda img: add_salt_and_pepper_noise(img, salt_prob=0.2, pepper_prob=0.2)),
    # Standard MNIST mean / std normalisation.
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)

# 70% train; the remaining 30% is split in half into validation and test (70/15/15).
# See data_utils.py for details.
train_dataloader, validation_dataloader, test_dataloader = mnist_preparation(
    dataset=mnist,
    labels=labels,
    train_test_ratio=0.7,
    batch_size=64,
    vali_test_ratio=0.5,
)

n_train, n_validation, n_test = (
    len(loader.dataset) for loader in (train_dataloader, validation_dataloader, test_dataloader)
)
print("Images in the training set: ", n_train,
      "\n Images in the validation set: ", n_validation,
      "\n Images in the test set: ", n_test)

In [26]:
import pickle

# Machine-specific locations of the trained parameter snapshots — adjust for your environment.
ORIGINAL_WEIGHTS_PATH = '/Users/jackvittori/Desktop/modello-originale-training/weights02.pickle'
MCM_HISTORY_PATH = "/Users/jackvittori/Desktop/highnoise/traininghistory.pickle"

# Baseline: the original 8-qubit, 8-layer, 4-class model.
om_model = FullQuantumModel(qubits=8, layers=8, num_classes=4)
# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted source.
with open(ORIGINAL_WEIGHTS_PATH, 'rb') as file:
    parameters = pickle.load(file)
om_model.set_parameters(parameters)

# MCM model; early_exits=[3] presumably matches the early exit after layer index 3 used by
# early_evaluation_utils / fully_evaluation_utils — TODO confirm in mcmadaptablemodel.
mcm_model = MCMQuantumModel(qubits=8, layers=8, early_exits=[3])
with open(MCM_HISTORY_PATH, "rb") as file:
    training_history = pickle.load(file)

# The same parameter dict is passed straight to fully_evaluation_ansatz later on.
mcm_parameters = training_history['model_params']
mcm_model.set_parameters(mcm_parameters)

In [27]:
def _unit_norm(img: torch.Tensor) -> torch.Tensor:
    """L2-normalise an image so it can be loaded as a quantum state.

    The norm is reshaped to (1, 1), so a 1-D input of length n broadcasts to shape (1, n).
    That adds a leading batch axis, which is why outputs are indexed with row 0 below.
    """
    return img / torch.linalg.norm(img).view(-1, 1)


def _top_confidence(probs: torch.Tensor) -> float:
    """Probability assigned to the most likely outcome in the first (only) batch row."""
    # Equivalent to probs[0, probs.argmax(dim=1)].item() for a single-row batch.
    return probs[0].max().item()


early_probs_distribution = []
final_probs_distribution = []
probs_distribution = []

# Evaluate both models in ONE pass over the validation set. If the dataset applies the random
# salt-and-pepper transform lazily in __getitem__, two separate passes would score each model on a
# different noisy copy of every image; a single pass keeps the comparison paired.
# no_grad: pure inference, so no autograd graph is needed.
with torch.no_grad():
    for img, _target in tqdm(validation_dataloader.dataset):
        state = _unit_norm(img)

        early_full, final_full = fully_evaluation_ansatz(params=mcm_parameters, state=state)
        early_probs_distribution.append(_top_confidence(early_full))
        final_probs_distribution.append(_top_confidence(final_full))

        original_probs = om_model.forward(state=state)  # extract probabilities
        probs_distribution.append(_top_confidence(original_probs))

In [28]:
# Rebinds `np` from pennylane.numpy (imported at the top) to plain NumPy for plotting.
import numpy as np

plt.style.use('ggplot')
fig, ax = plt.subplots(figsize=(10, 5))

# (data, colour, alpha, legend label) per distribution. The original model is drawn more
# transparently so the two MCM histograms stay visible underneath it.
histograms = [
    (early_probs_distribution, 'yellow', 0.4, 'Mid circuit probability distribution'),
    (final_probs_distribution, 'red', 0.4, 'Final circuit probability distribution'),
    (probs_distribution, 'green', 0.2, 'Original model probability distribution'),
]

# Shared bin edges across all three histograms. With a per-call `bins=60`, each histogram got its
# own bin width, so bar heights were not comparable between the overlaid distributions.
bin_edges = np.histogram_bin_edges(np.concatenate([data for data, *_ in histograms]), bins=60)
for data, color, alpha, label in histograms:
    ax.hist(data, bins=bin_edges, density=False, alpha=alpha, color=color,
            edgecolor='black', label=label)

# Title and axis labels
ax.set_title('Prediction confidence distribution noise 0.2', fontsize=16)
ax.set_xlabel('Prediction confidence', fontsize=14)
ax.set_ylabel('Occurrences', fontsize=14)

# X ticks every 0.05 from 0.25 to 0.60 inclusive. Ticks were previously set twice, and the call
# that took effect, np.arange(0.25, 0.6, 0.05), stopped at 0.55. linspace avoids float-step
# arange endpoint surprises.
ax.set_xticks(np.linspace(0.25, 0.60, 8))
# Only pin the bottom. Shared (wider) bins change the bar heights, so the previous hand-tuned
# top of 230 could clip bars.
ax.set_ylim(bottom=0)

# Legend to tell the distributions apart
ax.legend()

# Grid and layout
ax.grid(True)
plt.tight_layout()

# Save the figure
plt.savefig('/Users/jackvittori/Desktop/highnoise/noise02probs.png', dpi=300)
plt.show()