# Mid-Circuit Measurement (MCM): 4-Class Debug

In [2]:
import torch
import pandas as pd
import pennylane as qml
from pennylane import numpy as np
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from data_utils import mnist_preparation 
from evaluationUtils import calculate_mcm_accuracy
from tqdm import tqdm
import matplotlib as plt
from OriginalModel import FullQuantumModel, QuantumCircuit
from mcmModel import MCMQuantumModel, MCMCircuit
from pennylane import Device
from pennylane.measurements import StateMP
from torch.nn import Module, ParameterDict
import matplotlib.pyplot as plt
import warnings
from typing import Optional, Dict, List, Any
from torch.utils.data import DataLoader, dataloader
from time import time
import math
from pennylane.measurements import MidMeasureMP
# Seed PyTorch's global RNG for reproducibility (NumPy / PennyLane RNGs are not seeded here).
torch.manual_seed(seed=1234)

In [2]:
# Digit classes kept for this 4-class experiment.
labels = [0, 1, 2, 3]

# Download MNIST, downscale every image to 16x16 and standardise it
# (0.1307 / 0.3081 are presumably the usual MNIST mean / std).
_mnist_transform = transforms.Compose(
    [
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=_mnist_transform)

# 70/15/15 train/validation/test split; see data_utils.py for the details.
train_dataloader, validation_dataloader, test_dataloader = mnist_preparation(
    dataset=mnist,
    labels=labels,
    train_test_ratio=0.7,
    batch_size=64,
    vali_test_ratio=0.5,
)

print(
    "Images in the training set: ", len(train_dataloader.dataset),
    "\n Images in the validation set: ", len(validation_dataloader.dataset),
    "\n Images in the test set: ", len(test_dataloader.dataset),
)

# Baseline Model

In [3]:
# Baseline: 8-qubit, 8-layer full quantum classifier over the 4 selected digits.
baseline = FullQuantumModel(num_classes=4, qubits=8, layers=8)
baseline.trainable_parameters()
baseline.draw(style="sketch")

In [4]:
# Train the baseline for 20 epochs; fit() returns an accuracy history and a loss history.
accuracy_history, loss_history = baseline.fit(
    learning_rate=0.001,
    epochs=20,
    dataloader=train_dataloader,
    show_plot=True,
)

# Baseline Evaluation

In [14]:
# Freeze layers 0-7 before evaluating, then call trainable_parameters()
# (presumably reports what is still trainable).
baseline.freeze_layers(list(range(8)))
baseline.trainable_parameters()

# Simplified per-image test-set evaluation: collect (prediction, label) pairs.
result = []
for img, label in tqdm(test_dataloader.dataset):
    # L2-normalise the image; .view(-1, 1) reshapes the scalar norm to (1, 1).
    img = img / torch.linalg.norm(img).view(-1, 1)
    probs = baseline.forward(state=img)
    prediction = probs.argmax(dim=1)  # index of the most probable class
    result.append((prediction, label))
    
def calculate_accuracy(data):
    """Count correct predictions and compute the resulting accuracy.

    Args:
        data: sequence of ``(prediction, label)`` pairs, the order in which the
            evaluation loop appends them. A single-element tensor prediction
            compared with an int label is fine, because the comparison result's
            truthiness is used.

    Returns:
        ``(n_correct, accuracy)``. ``accuracy`` is 0.0 for an empty ``data``
        instead of raising ``ZeroDivisionError``.
    """
    n_total = len(data)
    correct = sum(1 for prediction, label in data if prediction == label)
    return correct, (correct / n_total if n_total else 0.0)

# Summarise the baseline's test-set performance.
test_results = calculate_accuracy(result)
n_correct, accuracy = test_results
print(
    n_correct,
    "elements have been correctly classified over",
    len(test_dataloader.dataset),
    "total images with an accuracy of ",
    accuracy,
)

# MCM Model (4 classes)

In [3]:
# MCM model: 8 qubits, 8 layers, 'ansatz_1'.
# NOTE(review): unlike the baseline, no num_classes is passed here - confirm the
# constructor's default matches the 4 labels used in this notebook.
mcm_model4 = MCMQuantumModel(ansatz='ansatz_1', qubits=8, layers=8)

In [4]:
# Render the MCM circuit in sketch style, passing path='mcm_model4class' to draw().
mcm_model4.draw(path='mcm_model4class', style='sketch')

In [27]:
# Train the MCM model for 50 epochs. Note: this rebinds loss_history, overwriting
# the baseline's history from the earlier fit.
mcm_accuracy, fm_accuracy, loss_history = mcm_model4.fit(
    dataloader=train_dataloader,
    epochs=50,
    learning_rate=0.001,
    show_plot=True,
)

In [28]:
#import pickle
#model4_params = mcm_model4.params
#with open("/Users/jackvittori/Desktop/pesimcm4.pickle", "wb") as file:
#    pickle.dump(model4_params, file)

In [5]:
import os
import pickle

# Saved parameters of the 4-class MCM model (same path as the commented-out save cell).
# The default is a machine-specific absolute path; set the MCM4_PARAMS_PATH environment
# variable to load the file from elsewhere without editing the notebook.
# NOTE(security): pickle.load can execute arbitrary code - only load files you trust.
params_path = os.path.expanduser(
    os.environ.get("MCM4_PARAMS_PATH", "/Users/jackvittori/Desktop/pesimcm4.pickle")
)
with open(params_path, "rb") as file:
    model4_params = pickle.load(file)

In [6]:
# Display the loaded 'layer_1' entry as a quick sanity check on the pickle contents.
model4_params['layer_1']

In [7]:
# Hand the loaded parameters to the model, then display its 'layer_1' entry so it
# can be compared with the raw pickle contents shown in the previous cell.
mcm_model4.set_parameters(model4_params)
mcm_model4.params['layer_1']

## Early Exit simulated from a full evaluation

In [8]:
# Offline early-exit analysis: run the full MCM model once per test image and record,
# for every image, whether the mid-circuit (early) and final predictions are correct.
prediction_results = {"early": [], "final": []}
with torch.no_grad():  # inference only: no need to build an autograd graph per image
    for img, target in tqdm(test_dataloader.dataset):
        # L2-normalise the image; .view(-1, 1) turns the scalar norm into a (1, 1)
        # tensor, so a flat image gains a leading batch dimension of 1
        # (assumed by argmax(dim=1) below - TODO confirm the dataset yields flat vectors).
        img = img / torch.linalg.norm(img).view(-1, 1)
        mcm_probs, final_probs = mcm_model4.forward(state=img)

        # Early (mid-circuit) prediction, its correctness and the probability assigned to it.
        mcm_predictions = torch.argmax(mcm_probs, dim=1)
        mcm_correct = mcm_predictions == target
        early_confidence = mcm_probs[0, mcm_predictions]
        prediction_results["early"].append((mcm_correct, early_confidence))

        # Final prediction. Stored as a bare single-element bool tensor (not a tuple);
        # post_evaluation_threshold reads it with [0].item().
        final_predictions = torch.argmax(final_probs, dim=1)
        final_correct = final_predictions == target
        prediction_results["final"].append(final_correct)

In [9]:
def post_evaluation_threshold(early_results, final_results, threshold):
    """Choose, per image, the early or the final prediction based on early confidence.

    This simulates the early-exit policy offline, from predictions that were all
    produced by a full forward pass. The routing rule is ``confidence >= threshold``,
    the same rule ``evaluation_routine`` uses. The offline and the real early-exit
    tables are therefore comparable at every threshold.

    Args:
        early_results: sequence of ``(early_correct, early_confidence)`` pairs, one per
            image. Both are single-element tensors (``.item()`` is called on them).
        final_results: sequence aligned with ``early_results``. Each entry, indexed with
            ``[0]``, gives a single-element tensor telling whether the final
            prediction was correct.
        threshold: minimum early confidence needed to accept the early prediction.

    Returns:
        ``(results, early_correct, n_early, final_correct, n_final)``. ``results``
        holds one bool per image (True if the chosen prediction was correct).

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(early_results) != len(final_results):
        raise ValueError(
            f"early_results ({len(early_results)}) and final_results "
            f"({len(final_results)}) must have the same length"
        )

    results = []  # chosen prediction's correctness, per image
    n_early = n_final = 0  # images routed to the early / final prediction
    early_correct = final_correct = 0  # correct predictions within each group

    for (early_bool, confidence), final_entry in zip(early_results, final_results):
        if confidence.item() >= threshold:
            is_correct = bool(early_bool.item())  # confident enough: exit early
            n_early += 1
            early_correct += is_correct
        else:
            is_correct = bool(final_entry[0].item())  # fall back to the final prediction
            n_final += 1
            final_correct += is_correct
        results.append(is_correct)

    return results, early_correct, n_early, final_correct, n_final

In [10]:
def explain_results(results: Dict, threshold: List[float]):
    """Summarise the offline (full-evaluation) early-exit simulation per threshold.

    Args:
        results: dict with an ``"early"`` list of ``(early_correct, early_confidence)``
            pairs and an aligned ``"final"`` list of final-correctness entries, as
            built by the full-evaluation loop.
        threshold: confidence thresholds to try; each one yields one table row.

    Returns:
        pandas DataFrame with columns Threshold, Total Accuracy, Early Classified,
        Early Accuracy, Final Classified and Final Accuracy. An accuracy whose
        denominator is zero (an empty group, or no images at all) is reported as 0.
    """
    summary_data = {
        'Threshold': [],
        'Total Accuracy': [],
        'Early Classified': [],
        'Early Accuracy': [],
        'Final Classified': [],
        'Final Accuracy': []}

    for t in threshold:
        prediction_result, mcm_correct, n_early, final_correct, n_final = post_evaluation_threshold(
            results['early'], results['final'], t
        )
        n_total = len(prediction_result)

        # Avoid division by zero for empty inputs or empty groups.
        tot_accuracy = sum(1 for ok in prediction_result if ok) / n_total if n_total > 0 else 0
        early_accuracy = mcm_correct / n_early if n_early > 0 else 0
        final_accuracy = final_correct / n_final if n_final > 0 else 0

        summary_data['Threshold'].append(t)
        summary_data['Total Accuracy'].append(tot_accuracy)
        summary_data['Early Classified'].append(n_early)
        summary_data['Early Accuracy'].append(early_accuracy)
        summary_data['Final Classified'].append(n_final)
        summary_data['Final Accuracy'].append(final_accuracy)

    return pd.DataFrame(summary_data)

In [15]:
# 31 confidence thresholds from 0.20 to 0.80 in steps of 0.02 (rounded to 2 decimals
# to strip floating-point noise).
threshold = [round(0.2 + 0.02 * step, 2) for step in range(31)]

In [16]:
# Offline early-exit summary table over the threshold grid.
explain_results(results=prediction_results, threshold=threshold)

# Early Exit without full execution

In [17]:
def _prepare_state(state: torch.Tensor, wires) -> None:
    """Load ``state`` as the initial state vector on ``wires``.

    Prefers ``qml.StatePrep``. Falls back to ``qml.QubitStateVector`` on PennyLane
    releases that predate ``StatePrep`` (newer releases deprecate the old name).
    """
    state_prep = getattr(qml, "StatePrep", None) or qml.QubitStateVector
    state_prep(state, wires=wires)


def _apply_layers(params: Dict, layer_indices, n_wires: int = 8) -> None:
    """Queue the ansatz layers whose indices are listed in ``layer_indices``.

    Each layer ``i`` applies RX, RY, RZ to every wire with angles
    ``params[f'layer_{i}'][wire, 0..2]``. It then applies a ring of CNOTs
    (wire -> wire + 1, with the last wire wrapping around to wire 0).
    """
    for i in layer_indices:
        layer = params[f'layer_{i}']
        for wire in range(n_wires):
            qml.RX(layer[wire, 0], wires=wire)
            qml.RY(layer[wire, 1], wires=wire)
            qml.RZ(layer[wire, 2], wires=wire)
        for wire in range(n_wires):
            qml.CNOT(wires=[wire, (wire + 1) % n_wires])


def early_evaluation_utils(params: Dict, state: torch.Tensor = None):
    """Build the early-exit part of the 8-wire circuit.

    Optionally loads ``state``, applies layers 0-3, then measures wires 0 and 1
    mid-circuit. The operations are queued into the active PennyLane tape, so
    this is meant to be called inside a QNode function.

    Args:
        params: mapping with at least keys ``'layer_0'`` ... ``'layer_3'``, each
            indexable as ``[wire, k]`` for wire 0-7 and k 0-2.
        state: optional input state; when None no state preparation is added.

    Returns:
        List of the two mid-circuit measurement values (wires 0 and 1).
    """
    if state is not None:
        _prepare_state(state, wires=range(8))
    _apply_layers(params, range(4))
    return [qml.measure(wires=w) for w in (0, 1)]  # first pair of qubits


def fully_evaluation_utils(params: Dict, state: torch.Tensor = None):
    """Build the complete 8-layer circuit with both measurement pairs.

    The early-exit prefix is identical to ``early_evaluation_utils`` (layers 0-3,
    then measure wires 0 and 1). After it come layers 4-7 and measurements of
    wires 2 and 3. Must be called inside a QNode function.

    Args:
        params: mapping with keys ``'layer_0'`` ... ``'layer_7'``.
        state: optional input state; when None no state preparation is added.

    Returns:
        List of four measurement values: wires 0, 1 (mid-circuit), then wires 2, 3.
    """
    measurements = early_evaluation_utils(params=params, state=state)
    _apply_layers(params, range(4, 8))
    measurements.extend(qml.measure(wires=w) for w in (2, 3))  # second pair of qubits
    return measurements

In [18]:
# Simulator shared by both standalone evaluation circuits.
dev = qml.device("default.qubit", wires=8)


@qml.qnode(dev)
def early_evaluation_ansatz(params: Dict, state: torch.Tensor = None):
    """Early-exit circuit: outcome probabilities of the wire-0/1 mid-circuit measurements."""
    return qml.probs(op=early_evaluation_utils(params=params, state=state))


@qml.qnode(dev)
def fully_evaluation_ansatz(params: Dict, state: torch.Tensor = None):
    """Full circuit: returns (probabilities of the wire-0/1 mid-circuit measurements,
    probabilities of the wire-2/3 measurements taken after the last layer)."""
    all_measurements = fully_evaluation_utils(params=params, state=state)
    mid_probs = qml.probs(op=all_measurements[:2])
    final_probs = qml.probs(op=all_measurements[2:])
    return mid_probs, final_probs

In [19]:
# Alias (not a copy) of the MCM model's parameters, passed to the standalone evaluation circuits below.
parameters4classes = mcm_model4.params

In [1]:
# Draw the early-exit circuit (no state is passed, so no state preparation is drawn) and save it.
early_evaluate_model, ax1 = qml.draw_mpl(early_evaluation_ansatz)(params=parameters4classes)
early_evaluate_model.savefig('early_evaluate_model.png')

In [3]:
# Draw the full circuit (both measurement pairs, no state preparation) and save it.
final_evaluate_model, ax2 = qml.draw_mpl(fully_evaluation_ansatz)(params=parameters4classes)
final_evaluate_model.savefig('final_evaluate_model.png')

## Evaluation Routine Definition

In [22]:
def evaluation_routine(dataloader: DataLoader, parameters: Dict, threshold: float):
    """Run the early-exit policy for real on every image of ``dataloader.dataset``.

    For each image the early circuit (``early_evaluation_ansatz``) runs first. If the
    probability of its most likely outcome is ``>= threshold``, that outcome is the
    prediction. Otherwise the full circuit (``fully_evaluation_ansatz``) is executed
    from scratch and its final probabilities decide.

    Args:
        dataloader: loader whose ``.dataset`` yields ``(image, target)`` pairs.
        parameters: parameter mapping (``'layer_0'`` ... ``'layer_7'``) for both circuits.
        threshold: minimum early confidence needed to exit early.

    Returns:
        ``(total_accuracy, early_accuracy, n_early, final_accuracy, n_final)``.
        An accuracy whose denominator is zero (an empty group, or an empty dataset)
        is reported as 0.
    """
    evaluation_results = []  # correctness of the prediction actually used, per image
    count_1 = count_2 = 0  # images classified early / by the full circuit
    early_correct = final_correct = 0  # correct predictions within each group

    with torch.no_grad():  # pure inference: skip autograd bookkeeping
        for img, target in dataloader.dataset:
            # L2-normalise the image; .view(-1, 1) gives the norm a (1, 1) shape.
            img = img / torch.linalg.norm(img).view(-1, 1)

            # Mid-circuit (early) evaluation.
            early_probs = early_evaluation_ansatz(params=parameters, state=img)
            early_prediction = torch.argmax(early_probs, dim=1)
            confidence = early_probs[0, early_prediction].item()

            if confidence >= threshold:
                is_correct = bool((early_prediction == target).item())
                count_1 += 1
                early_correct += is_correct
            else:
                # The full circuit restarts from the input, so layers 0-3 are simulated again.
                _, final_full = fully_evaluation_ansatz(params=parameters, state=img)
                final_prediction = torch.argmax(final_full, dim=1)
                is_correct = bool((final_prediction == target).item())
                count_2 += 1
                final_correct += is_correct
            evaluation_results.append(is_correct)

    n_total = len(evaluation_results)
    total_accuracy = sum(evaluation_results) / n_total if n_total > 0 else 0
    early_accuracy = early_correct / count_1 if count_1 > 0 else 0
    final_accuracy = final_correct / count_2 if count_2 > 0 else 0

    return total_accuracy, early_accuracy, count_1, final_accuracy, count_2

In [23]:
def explain_evaluation(dataloader: DataLoader, parameters: Dict, threshold: List[float]):
    """Run the real early-exit evaluation once per threshold and tabulate the results.

    Args:
        dataloader: loader whose ``.dataset`` is evaluated image by image.
        parameters: circuit parameters forwarded to ``evaluation_routine``.
        threshold: confidence thresholds, one table row each, in the given order
            (progress is shown with tqdm).

    Returns:
        pandas DataFrame with columns Threshold, Total Accuracy, Early Classified,
        Early Accuracy, Final Classified and Final Accuracy.
    """
    columns = ('Threshold', 'Total Accuracy', 'Early Classified',
               'Early Accuracy', 'Final Classified', 'Final Accuracy')
    rows = []
    for t in tqdm(threshold):
        total_acc, early_acc, n_early, final_acc, n_final = evaluation_routine(dataloader, parameters, t)
        rows.append((t, total_acc, n_early, early_acc, n_final, final_acc))

    # Transpose the per-threshold rows into one list per column.
    summary = {name: [row[k] for row in rows] for k, name in enumerate(columns)}
    return pd.DataFrame(summary)

In [25]:
#threshold = [round(x * 0.02 + 0.3, 2) for x in range(31)]
# Thresholds for the real early-exit evaluation (sampled more densely between 0.35 and 0.40).
threshold = [0.2, 0.3, 0.35, 0.36, 0.38, 0.40, 0.45, 0.5]
explain_evaluation(
    dataloader=test_dataloader,
    parameters=parameters4classes,
    threshold=threshold,
)