In [27]:
import random
import pickle
import pennylane as qml
import torch
from data_utils import mnist_preparation
from typing import Optional, Dict, List, Any
from torch.utils.data import DataLoader, dataloader
from tqdm import tqdm
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from pennylane import numpy as np

In [28]:
# Digit classes kept for the four-class experiment.
labels = [0, 1, 2, 3]

# Download MNIST and attach the preprocessing pipeline:
# resize to 16x16, convert to a tensor, then normalise with mean 0.1307 / std 0.3081.
# (A salt-and-pepper noise transform existed here but is disabled.)
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# train_test_ratio=0.7 and vali_test_ratio=0.5: a 70/15/15 train/validation/test
# split according to the original note; see data_utils.py for further details.
train_dataloader, validation_dataloader, test_dataloader = mnist_preparation(
    dataset=mnist,
    labels=labels,
    train_test_ratio=0.7,
    batch_size=64,
    vali_test_ratio=0.5,
)

print(
    "Images in the training set: ", len(train_dataloader.dataset),
    "\n Images in the validation set: ", len(validation_dataloader.dataset),
    "\n Images in the test set: ", len(test_dataloader.dataset),
)

Images in the training set:  17327 
 Images in the validation set:  3713 
 Images in the test set:  3714


In [29]:
# Restore the pickled training run and pull out the pieces used below.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# NOTE(review): absolute, user-specific path; the notebook will not run elsewhere.
with open("/Users/jackvittori/Desktop/no-noise/norumore-training.pickle", "rb") as file:
    training_history = pickle.load(file)

loss_history, mcm_accuracy, fm_accuracy, weights = (
    training_history[key]
    for key in ('loss_history', 'mcm_accuracy', 'fm_accuracy', 'model_params')
)

  return torch.load(io.BytesIO(b))


In [30]:
# Evaluation only: switch off gradient tracking on every loaded parameter.
for param in weights.values():
    param.requires_grad_(False)

In [5]:
# Inspect the parameters stored under 'layer_0' (the output below shows an 8x3 tensor).
weights['layer_0']

Parameter containing:
tensor([[ 0.4261,  0.2626,  0.2783],
        [ 1.0484, -0.4439,  0.1328],
        [-0.1889,  0.8609,  1.3403],
        [ 0.0887,  0.4347,  0.2151],
        [ 0.4924, -0.3589,  0.9989],
        [ 0.7229,  0.6374,  1.7330],
        [-0.5014,  0.6333,  0.9119],
        [ 0.6729,  0.6782,  0.8720]])

In [31]:
# Load the second parameter set (the output below shows four layers, layer_0..layer_3).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
# NOTE(review): absolute, user-specific path; the notebook will not run elsewhere.
with open("/Users/jackvittori/Desktop/4layerokkkk/weight4layer.pickle", "rb") as file:
    weights_small = pickle.load(file)

# Evaluation only: switch off gradient tracking on every loaded parameter.
for param in weights_small.values():
    param.requires_grad_(False)

  return torch.load(io.BytesIO(b))


In [32]:
# Display the parameter container loaded from the training history.
weights

ParameterDict(
    (layer_0): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_1): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_2): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_3): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_4): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_5): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_6): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_7): Parameter containing: [torch.FloatTensor of size 8x3]
)

In [33]:
# Display the smaller parameter container loaded from weight4layer.pickle.
weights_small

ParameterDict(
    (layer_0): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_1): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_2): Parameter containing: [torch.FloatTensor of size 8x3]
    (layer_3): Parameter containing: [torch.FloatTensor of size 8x3]
)

In [54]:
# Depolarizing probability; the circuit functions below read this global
# in their DepolarizingChannel(p=p) calls.
p = 0.005

def early_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Quantum function of the early-exit classifier: four noisy layers, then a readout.

    Rotation angles come from the module-level ``weights_small`` and the
    depolarizing probability from the module-level ``p``; both are looked up
    when the function body runs.

    Args:
        state: Amplitudes prepared on wires 0-7. When ``None`` no state
            preparation is queued.

    Returns:
        The ``qml.probs`` measurement over wires 0 and 1.
    """
    first_pair = [0, 1]
    if state is not None:
        # qml.StatePrep replaces the deprecated qml.QubitStateVector alias
        # (removed in later PennyLane releases); the arguments are unchanged.
        qml.StatePrep(state, wires=range(8))
    for i in range(4):
        layer = weights_small[f'layer_{i}']
        # Per-wire RX, RY, RZ rotations taken from row j of the layer parameters.
        for j in range(8):
            qml.RX(layer[j, 0], wires=j)
            qml.RY(layer[j, 1], wires=j)
            qml.RZ(layer[j, 2], wires=j)
        # Ring of CNOTs; depolarizing noise is applied to each CNOT target.
        for j in range(8):
            target = (j + 1) % 8
            qml.CNOT(wires=[j, target])
            qml.DepolarizingChannel(p=p, wires=target)

    return qml.probs(wires=first_pair)

def fully_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Quantum function of the full classifier with a mid-circuit readout.

    Applies layers 0-3, measures wires 0 and 1 mid-circuit (no reset, no
    postselection), applies layers 4-7 and reads out wires 2 and 3. Rotation
    angles come from the module-level ``weights`` and the depolarizing
    probability from the module-level ``p``.

    Args:
        state: Amplitudes prepared on wires 0-7. When ``None`` no state
            preparation is queued.

    Returns:
        A pair ``(qml.probs(op=<mid-circuit measurements>), qml.probs(wires=[2, 3]))``.
    """
    first_pair = [0, 1]
    second_pair = [2, 3]

    def apply_noisy_layers(layer_ids):
        # One layer = RX/RY/RZ on every wire, then a ring of CNOTs with
        # depolarizing noise on each CNOT target.
        for i in layer_ids:
            layer = weights[f'layer_{i}']
            for j in range(8):
                qml.RX(layer[j, 0], wires=j)
                qml.RY(layer[j, 1], wires=j)
                qml.RZ(layer[j, 2], wires=j)
            for j in range(8):
                target = (j + 1) % 8
                qml.CNOT(wires=[j, target])
                qml.DepolarizingChannel(p=p, wires=target)

    if state is not None:
        # qml.StatePrep replaces the deprecated qml.QubitStateVector alias
        # (removed in later PennyLane releases); the arguments are unchanged.
        qml.StatePrep(state, wires=range(8))

    apply_noisy_layers(range(4))
    measurements = [
        qml.measure(wires=w, reset=False, postselect=None) for w in first_pair
    ]
    apply_noisy_layers(range(4, 8))

    return qml.probs(op=measurements), qml.probs(wires=second_pair)

# Mixed-state simulator shared by both circuits, 50 shots per execution.
# The device has 10 wires although the circuits address wires 0-7; presumably the two
# extra wires serve the deferred mid-circuit measurements. TODO confirm.
mixed_device = qml.device("default.mixed", wires=list(range(10)), shots=50)
early_qnode = qml.QNode(early_evaluation_utils, mixed_device)
late_qnode = qml.QNode(fully_evaluation_utils, mixed_device)

In [55]:
def evaluation_routine(dataloader: DataLoader, threshold: float,
                       max_samples: int = 100, shots: int = 50):
    """Run the early-exit evaluation on the first ``max_samples`` images.

    Every image is rescaled to unit norm and classified by ``early_qnode``.
    When the largest early probability is at least ``threshold`` the early
    prediction is kept; otherwise the image is classified again by
    ``late_qnode`` and its second output is used for the prediction.

    Args:
        dataloader: Loader whose ``.dataset`` yields ``(img, target)`` pairs;
            the underlying dataset is iterated directly, one sample at a time.
        threshold: Minimum early confidence required to accept the early prediction.
        max_samples: Number of leading samples to evaluate (default 100, the
            value that was previously hard-coded).
        shots: Shots passed to both QNodes on every call (default 50, as before).

    Returns:
        Tuple ``(total_accuracy, early_total_accuracy, early_exited_accuracy,
        count_1, final_exited_accuracy, count_2, executed_layers)``. Every
        accuracy is 0 when its denominator is 0 (including an empty dataset).
    """
    from itertools import islice  # local import keeps this cell self-contained

    evaluation_results = []  # outcome of the prediction actually used, per image
    early_results = []       # outcome of the early prediction, for every image
    count_1 = 0  # counter for early classified images
    count_2 = 0  # counter for final classified images
    early_correct = 0  # counter for correctly early classified images
    final_correct = 0  # counter for correctly final classified images
    executed_layers = 0

    # islice stops after max_samples items without fetching an extra sample.
    for img, target in tqdm(islice(dataloader.dataset, max_samples)):
        img = img.type(torch.float64)
        img = img / torch.linalg.norm(img).view(-1, 1)

        # mid circuit evaluation
        early_probs = early_qnode(img, shots=shots)
        early_prediction = torch.argmax(early_probs)
        confidence = early_probs[early_prediction].item()
        early_guess = (early_prediction == target).item()
        early_results.append(early_guess)

        if confidence >= threshold:
            evaluation_results.append(early_guess)
            count_1 += 1
            executed_layers += 4
            if early_guess:
                early_correct += 1
        else:
            # Only the second output (probabilities on the final readout) is used.
            _, final_full = late_qnode(img, shots=shots)
            final_guess = (torch.argmax(final_full) == target).item()
            evaluation_results.append(final_guess)
            count_2 += 1
            # 4 layers already spent on the early circuit plus 8 in the full one.
            executed_layers += 12
            if final_guess:
                final_correct += 1

    n_images = len(evaluation_results)
    total_accuracy = sum(evaluation_results) / n_images if n_images > 0 else 0
    early_total_accuracy = sum(early_results) / n_images if n_images > 0 else 0
    early_exited_accuracy = early_correct / count_1 if count_1 > 0 else 0
    final_exited_accuracy = final_correct / count_2 if count_2 > 0 else 0

    return total_accuracy, early_total_accuracy, early_exited_accuracy, count_1, final_exited_accuracy, count_2, executed_layers


In [56]:
import pandas as pd
def explain_evaluation(dataloader: DataLoader, threshold: List[float]):
    """Evaluate every threshold and tabulate the results.

    Calls ``evaluation_routine(dataloader, t)`` once per value in ``threshold``.

    Args:
        dataloader: Loader forwarded unchanged to ``evaluation_routine``.
        threshold: Confidence thresholds to sweep, in the order given.

    Returns:
        ``(summary_data, df)``: a dict mapping each column name to the list of
        per-threshold values, and the DataFrame built from that dict.
    """
    columns = (
        'Threshold',
        'Total Accuracy',
        '# early exited images',
        'Early exited Accuracy',
        'Early total accuracy',
        '# final classified images',
        'Final classified Accuracy',
        'Executed layers',
    )
    summary_data = {name: [] for name in columns}

    for t in tqdm(threshold):
        (total_acc, early_total_acc, early_exit_acc, n_early,
         final_acc, n_final, layers) = evaluation_routine(dataloader, t)
        # Values listed in the same order as `columns`.
        row = (t, total_acc, n_early, early_exit_acc,
               early_total_acc, n_final, final_acc, layers)
        for name, value in zip(columns, row):
            summary_data[name].append(value)

    return summary_data, pd.DataFrame(summary_data)

In [57]:
# Confidence thresholds to sweep: 0.41 to 0.45 in steps of 0.01.
thresholds = [round(0.41 + 0.01 * step, 2) for step in range(5)]

In [58]:
# Preview of the threshold grid (same expression as the one assigned to `thresholds`).
[round(x * 0.01 + 0.41, 2) for x in range(5)]

[0.41, 0.42, 0.43, 0.44, 0.45]

In [59]:
# Sweep the thresholds on the test split; evaluation_routine looks at the
# first 100 images of the dataset for each threshold.
summary, table = explain_evaluation(test_dataloader, thresholds)

  0%|          | 0/5 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:13, 13.72s/it][A
2it [00:27, 13.69s/it][A
3it [00:41, 13.68s/it][A
4it [00:45, 10.03s/it][A
5it [00:59, 11.34s/it][A
6it [01:12, 12.15s/it][A
7it [01:17,  9.63s/it][A
8it [01:21,  7.97s/it][A
9it [01:26,  6.86s/it][A
10it [01:39,  8.98s/it][A
11it [01:44,  7.59s/it][A
12it [01:57,  9.44s/it][A
13it [02:11, 10.72s/it][A
14it [02:16,  8.82s/it][A
15it [02:20,  7.50s/it][A
16it [02:34,  9.37s/it][A
17it [02:47, 10.67s/it][A
18it [03:01, 11.57s/it][A
19it [03:15, 12.20s/it][A
20it [03:28, 12.64s/it][A
21it [03:42, 13.05s/it][A
22it [03:56, 13.35s/it][A
23it [04:11, 13.57s/it][A
24it [04:15, 10.86s/it][A
25it [04:20,  8.97s/it][A
26it [04:34, 10.49s/it][A
27it [04:48, 11.55s/it][A
28it [05:02, 12.30s/it][A
29it [05:06,  9.97s/it][A
30it [05:11,  8.34s/it][A
31it [05:15,  7.20s/it][A
32it [05:29,  9.24s/it][A
33it [05:34,  7.83s/it][A
34it [05:48,  9.70s/it][A
35it [05:53,  8.15s/it][A
36it

In [60]:
# Show the per-threshold summary DataFrame.
table

Unnamed: 0,Threshold,Total Accuracy,# early exited images,Early exited Accuracy,Early total accuracy,# final classified images,Final classified Accuracy,Executed layers
0,0.41,0.77,51,0.980392,0.89,49,0.55102,792
1,0.42,0.8,54,0.925926,0.81,46,0.652174,768
2,0.43,0.83,39,0.974359,0.78,61,0.737705,888
3,0.44,0.81,43,0.976744,0.81,57,0.684211,856
4,0.45,0.77,37,1.0,0.82,63,0.634921,904


In [61]:
import pickle

# Persist both views of the sweep (dict of lists and DataFrame) in one pickle.
# NOTE(review): the file name says "depolarizing001" while this section sets
# p = 0.005; confirm the label. The path is absolute and user-specific.
ev_data = dict(summary=summary, table=table)
with open("/Users/jackvittori/Desktop/depolarizing001-altro.pickle", "wb") as file:
    pickle.dump(ev_data, file)

# Depolarizing noise, p = 0.02

In [13]:
# Depolarizing probability; the circuit functions below read this global
# in their DepolarizingChannel(p=p) calls.
p = 0.02

def early_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Quantum function of the early-exit classifier: four noisy layers, then a readout.

    Rotation angles come from layers 0-3 of the module-level ``weights`` and
    the depolarizing probability from the module-level ``p``; both are looked
    up when the function body runs.

    NOTE(review): the p = 0.005 variant of this function reads ``weights_small``
    instead of ``weights``; confirm which parameter set is intended here.

    Args:
        state: Amplitudes prepared on wires 0-7. When ``None`` no state
            preparation is queued.

    Returns:
        The ``qml.probs`` measurement over wires 0 and 1.
    """
    first_pair = [0, 1]
    if state is not None:
        # qml.StatePrep replaces the deprecated qml.QubitStateVector alias
        # (removed in later PennyLane releases); the arguments are unchanged.
        qml.StatePrep(state, wires=range(8))
    for i in range(4):
        layer = weights[f'layer_{i}']
        # Per-wire RX, RY, RZ rotations taken from row j of the layer parameters.
        for j in range(8):
            qml.RX(layer[j, 0], wires=j)
            qml.RY(layer[j, 1], wires=j)
            qml.RZ(layer[j, 2], wires=j)
        # Ring of CNOTs; depolarizing noise is applied to each CNOT target.
        for j in range(8):
            target = (j + 1) % 8
            qml.CNOT(wires=[j, target])
            qml.DepolarizingChannel(p=p, wires=target)

    return qml.probs(wires=first_pair)

def fully_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Quantum function of the full classifier with a mid-circuit readout.

    Applies layers 0-3, measures wires 0 and 1 mid-circuit (no reset, no
    postselection), applies layers 4-7 and reads out wires 2 and 3. Rotation
    angles come from the module-level ``weights`` and the depolarizing
    probability from the module-level ``p``.

    Args:
        state: Amplitudes prepared on wires 0-7. When ``None`` no state
            preparation is queued.

    Returns:
        A pair ``(qml.probs(op=<mid-circuit measurements>), qml.probs(wires=[2, 3]))``.
    """
    first_pair = [0, 1]
    second_pair = [2, 3]

    def apply_noisy_layers(layer_ids):
        # One layer = RX/RY/RZ on every wire, then a ring of CNOTs with
        # depolarizing noise on each CNOT target.
        for i in layer_ids:
            layer = weights[f'layer_{i}']
            for j in range(8):
                qml.RX(layer[j, 0], wires=j)
                qml.RY(layer[j, 1], wires=j)
                qml.RZ(layer[j, 2], wires=j)
            for j in range(8):
                target = (j + 1) % 8
                qml.CNOT(wires=[j, target])
                qml.DepolarizingChannel(p=p, wires=target)

    if state is not None:
        # qml.StatePrep replaces the deprecated qml.QubitStateVector alias
        # (removed in later PennyLane releases); the arguments are unchanged.
        qml.StatePrep(state, wires=range(8))

    apply_noisy_layers(range(4))
    measurements = [
        qml.measure(wires=w, reset=False, postselect=None) for w in first_pair
    ]
    apply_noisy_layers(range(4, 8))

    return qml.probs(op=measurements), qml.probs(wires=second_pair)

# Mixed-state simulator shared by both circuits, 50 shots per execution.
# The device has 10 wires although the circuits address wires 0-7; presumably the two
# extra wires serve the deferred mid-circuit measurements. TODO confirm.
mixed_device = qml.device("default.mixed", wires=list(range(10)), shots=50)
early_qnode = qml.QNode(early_evaluation_utils, mixed_device)
late_qnode = qml.QNode(fully_evaluation_utils, mixed_device)

In [14]:
def evaluation_routine(dataloader: DataLoader, threshold: float,
                       max_samples: int = 100, shots: int = 50):
    """Run the early-exit evaluation on the first ``max_samples`` images.

    Every image is rescaled to unit norm and classified by ``early_qnode``.
    When the largest early probability is at least ``threshold`` the early
    prediction is kept; otherwise the image is classified again by
    ``late_qnode`` and its second output is used for the prediction.

    Args:
        dataloader: Loader whose ``.dataset`` yields ``(img, target)`` pairs;
            the underlying dataset is iterated directly, one sample at a time.
        threshold: Minimum early confidence required to accept the early prediction.
        max_samples: Number of leading samples to evaluate (default 100, the
            value that was previously hard-coded).
        shots: Shots passed to both QNodes on every call (default 50, as before).

    Returns:
        Tuple ``(total_accuracy, early_total_accuracy, early_exited_accuracy,
        count_1, final_exited_accuracy, count_2, executed_layers)``. Every
        accuracy is 0 when its denominator is 0 (including an empty dataset).
    """
    from itertools import islice  # local import keeps this cell self-contained

    evaluation_results = []  # outcome of the prediction actually used, per image
    early_results = []       # outcome of the early prediction, for every image
    count_1 = 0  # counter for early classified images
    count_2 = 0  # counter for final classified images
    early_correct = 0  # counter for correctly early classified images
    final_correct = 0  # counter for correctly final classified images
    executed_layers = 0

    # islice stops after max_samples items without fetching an extra sample.
    for img, target in tqdm(islice(dataloader.dataset, max_samples)):
        img = img.type(torch.float64)
        img = img / torch.linalg.norm(img).view(-1, 1)

        # mid circuit evaluation
        early_probs = early_qnode(img, shots=shots)
        early_prediction = torch.argmax(early_probs)
        confidence = early_probs[early_prediction].item()
        early_guess = (early_prediction == target).item()
        early_results.append(early_guess)

        if confidence >= threshold:
            evaluation_results.append(early_guess)
            count_1 += 1
            executed_layers += 4
            if early_guess:
                early_correct += 1
        else:
            # Only the second output (probabilities on the final readout) is used.
            _, final_full = late_qnode(img, shots=shots)
            final_guess = (torch.argmax(final_full) == target).item()
            evaluation_results.append(final_guess)
            count_2 += 1
            # 4 layers already spent on the early circuit plus 8 in the full one.
            executed_layers += 12
            if final_guess:
                final_correct += 1

    n_images = len(evaluation_results)
    total_accuracy = sum(evaluation_results) / n_images if n_images > 0 else 0
    early_total_accuracy = sum(early_results) / n_images if n_images > 0 else 0
    early_exited_accuracy = early_correct / count_1 if count_1 > 0 else 0
    final_exited_accuracy = final_correct / count_2 if count_2 > 0 else 0

    return total_accuracy, early_total_accuracy, early_exited_accuracy, count_1, final_exited_accuracy, count_2, executed_layers


In [15]:
import pandas as pd
def explain_evaluation(dataloader: DataLoader, threshold: List[float]):
    """Evaluate every threshold and tabulate the results.

    Calls ``evaluation_routine(dataloader, t)`` once per value in ``threshold``.

    Args:
        dataloader: Loader forwarded unchanged to ``evaluation_routine``.
        threshold: Confidence thresholds to sweep, in the order given.

    Returns:
        ``(summary_data, df)``: a dict mapping each column name to the list of
        per-threshold values, and the DataFrame built from that dict.
    """
    columns = (
        'Threshold',
        'Total Accuracy',
        '# early exited images',
        'Early exited Accuracy',
        'Early total accuracy',
        '# final classified images',
        'Final classified Accuracy',
        'Executed layers',
    )
    summary_data = {name: [] for name in columns}

    for t in tqdm(threshold):
        (total_acc, early_total_acc, early_exit_acc, n_early,
         final_acc, n_final, layers) = evaluation_routine(dataloader, t)
        # Values listed in the same order as `columns`.
        row = (t, total_acc, n_early, early_exit_acc,
               early_total_acc, n_final, final_acc, layers)
        for name, value in zip(columns, row):
            summary_data[name].append(value)

    return summary_data, pd.DataFrame(summary_data)

In [16]:
# Confidence thresholds to sweep: 0.26 to 0.40 in steps of 0.01.
thresholds = [round(0.26 + 0.01 * step, 2) for step in range(15)]

In [17]:
# Sweep the thresholds on the test split; evaluation_routine looks at the
# first 100 images of the dataset for each threshold.
summary, table = explain_evaluation(test_dataloader, thresholds)

  0%|          | 0/15 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:04,  4.50s/it][A
2it [00:08,  4.48s/it][A
3it [00:13,  4.48s/it][A
4it [00:17,  4.48s/it][A
5it [00:22,  4.48s/it][A
6it [00:26,  4.48s/it][A
7it [00:31,  4.47s/it][A
8it [00:35,  4.48s/it][A
9it [00:40,  4.47s/it][A
10it [00:44,  4.47s/it][A
11it [00:49,  4.47s/it][A
12it [00:53,  4.48s/it][A
13it [00:58,  4.47s/it][A
14it [01:02,  4.47s/it][A
15it [01:07,  4.47s/it][A
16it [01:11,  4.47s/it][A
17it [01:16,  4.48s/it][A
18it [01:20,  4.48s/it][A
19it [01:25,  4.48s/it][A
20it [01:29,  4.48s/it][A
21it [01:33,  4.48s/it][A
22it [01:38,  4.47s/it][A
23it [01:42,  4.47s/it][A
24it [01:47,  4.48s/it][A
25it [01:51,  4.48s/it][A
26it [01:56,  4.48s/it][A
27it [02:00,  4.47s/it][A
28it [02:05,  4.47s/it][A
29it [02:09,  4.47s/it][A
30it [02:14,  4.48s/it][A
31it [02:18,  4.48s/it][A
32it [02:23,  4.48s/it][A
33it [02:27,  4.48s/it][A
34it [02:32,  4.47s/it][A
35it [02:36,  4.47s/it][A
36i

In [18]:
# Show the per-threshold summary DataFrame.
table

Unnamed: 0,Threshold,Total Accuracy,# early exited images,Early exited Accuracy,Early total accuracy,# final classified images,Final classified Accuracy,Executed layers
0,0.26,0.56,100,0.56,0.56,0,0.0,400
1,0.27,0.48,99,0.474747,0.47,1,1.0,408
2,0.28,0.46,100,0.46,0.46,0,0.0,400
3,0.29,0.58,89,0.595506,0.58,11,0.454545,488
4,0.3,0.56,91,0.571429,0.55,9,0.444444,472
5,0.31,0.56,76,0.578947,0.52,24,0.5,592
6,0.32,0.59,72,0.652778,0.59,28,0.428571,624
7,0.33,0.55,56,0.607143,0.46,44,0.477273,752
8,0.34,0.49,53,0.54717,0.47,47,0.425532,776
9,0.35,0.54,34,0.676471,0.49,66,0.469697,928


In [19]:
import pickle

# Persist both views of the sweep (dict of lists and DataFrame) in one pickle.
# NOTE(review): absolute, user-specific path; the notebook will not run elsewhere.
ev_data = dict(summary=summary, table=table)
with open("/Users/jackvittori/Desktop/depolarizing02.pickle", "wb") as file:
    pickle.dump(ev_data, file)

# Depolarizing noise, p = 0.03

In [20]:
# Depolarizing probability; the circuit functions below read this global
# in their DepolarizingChannel(p=p) calls.
p = 0.03

def early_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Quantum function of the early-exit classifier: four noisy layers, then a readout.

    Rotation angles come from layers 0-3 of the module-level ``weights`` and
    the depolarizing probability from the module-level ``p``; both are looked
    up when the function body runs.

    NOTE(review): the p = 0.005 variant of this function reads ``weights_small``
    instead of ``weights``; confirm which parameter set is intended here.

    Args:
        state: Amplitudes prepared on wires 0-7. When ``None`` no state
            preparation is queued.

    Returns:
        The ``qml.probs`` measurement over wires 0 and 1.
    """
    first_pair = [0, 1]
    if state is not None:
        # qml.StatePrep replaces the deprecated qml.QubitStateVector alias
        # (removed in later PennyLane releases); the arguments are unchanged.
        qml.StatePrep(state, wires=range(8))
    for i in range(4):
        layer = weights[f'layer_{i}']
        # Per-wire RX, RY, RZ rotations taken from row j of the layer parameters.
        for j in range(8):
            qml.RX(layer[j, 0], wires=j)
            qml.RY(layer[j, 1], wires=j)
            qml.RZ(layer[j, 2], wires=j)
        # Ring of CNOTs; depolarizing noise is applied to each CNOT target.
        for j in range(8):
            target = (j + 1) % 8
            qml.CNOT(wires=[j, target])
            qml.DepolarizingChannel(p=p, wires=target)

    return qml.probs(wires=first_pair)

def fully_evaluation_utils(state: Optional[torch.Tensor] = None):
    """Quantum function of the full classifier with a mid-circuit readout.

    Applies layers 0-3, measures wires 0 and 1 mid-circuit (no reset, no
    postselection), applies layers 4-7 and reads out wires 2 and 3. Rotation
    angles come from the module-level ``weights`` and the depolarizing
    probability from the module-level ``p``.

    Args:
        state: Amplitudes prepared on wires 0-7. When ``None`` no state
            preparation is queued.

    Returns:
        A pair ``(qml.probs(op=<mid-circuit measurements>), qml.probs(wires=[2, 3]))``.
    """
    first_pair = [0, 1]
    second_pair = [2, 3]

    def apply_noisy_layers(layer_ids):
        # One layer = RX/RY/RZ on every wire, then a ring of CNOTs with
        # depolarizing noise on each CNOT target.
        for i in layer_ids:
            layer = weights[f'layer_{i}']
            for j in range(8):
                qml.RX(layer[j, 0], wires=j)
                qml.RY(layer[j, 1], wires=j)
                qml.RZ(layer[j, 2], wires=j)
            for j in range(8):
                target = (j + 1) % 8
                qml.CNOT(wires=[j, target])
                qml.DepolarizingChannel(p=p, wires=target)

    if state is not None:
        # qml.StatePrep replaces the deprecated qml.QubitStateVector alias
        # (removed in later PennyLane releases); the arguments are unchanged.
        qml.StatePrep(state, wires=range(8))

    apply_noisy_layers(range(4))
    measurements = [
        qml.measure(wires=w, reset=False, postselect=None) for w in first_pair
    ]
    apply_noisy_layers(range(4, 8))

    return qml.probs(op=measurements), qml.probs(wires=second_pair)

# Mixed-state simulator shared by both circuits, 50 shots per execution.
# The device has 10 wires although the circuits address wires 0-7; presumably the two
# extra wires serve the deferred mid-circuit measurements. TODO confirm.
mixed_device = qml.device("default.mixed", wires=list(range(10)), shots=50)
early_qnode = qml.QNode(early_evaluation_utils, mixed_device)
late_qnode = qml.QNode(fully_evaluation_utils, mixed_device)

In [21]:
def evaluation_routine(dataloader: DataLoader, threshold: float,
                       max_samples: int = 100, shots: int = 50):
    """Run the early-exit evaluation on the first ``max_samples`` images.

    Every image is rescaled to unit norm and classified by ``early_qnode``.
    When the largest early probability is at least ``threshold`` the early
    prediction is kept; otherwise the image is classified again by
    ``late_qnode`` and its second output is used for the prediction.

    Args:
        dataloader: Loader whose ``.dataset`` yields ``(img, target)`` pairs;
            the underlying dataset is iterated directly, one sample at a time.
        threshold: Minimum early confidence required to accept the early prediction.
        max_samples: Number of leading samples to evaluate (default 100, the
            value that was previously hard-coded).
        shots: Shots passed to both QNodes on every call (default 50, as before).

    Returns:
        Tuple ``(total_accuracy, early_total_accuracy, early_exited_accuracy,
        count_1, final_exited_accuracy, count_2, executed_layers)``. Every
        accuracy is 0 when its denominator is 0 (including an empty dataset).
    """
    from itertools import islice  # local import keeps this cell self-contained

    evaluation_results = []  # outcome of the prediction actually used, per image
    early_results = []       # outcome of the early prediction, for every image
    count_1 = 0  # counter for early classified images
    count_2 = 0  # counter for final classified images
    early_correct = 0  # counter for correctly early classified images
    final_correct = 0  # counter for correctly final classified images
    executed_layers = 0

    # islice stops after max_samples items without fetching an extra sample.
    for img, target in tqdm(islice(dataloader.dataset, max_samples)):
        img = img.type(torch.float64)
        img = img / torch.linalg.norm(img).view(-1, 1)

        # mid circuit evaluation
        early_probs = early_qnode(img, shots=shots)
        early_prediction = torch.argmax(early_probs)
        confidence = early_probs[early_prediction].item()
        early_guess = (early_prediction == target).item()
        early_results.append(early_guess)

        if confidence >= threshold:
            evaluation_results.append(early_guess)
            count_1 += 1
            executed_layers += 4
            if early_guess:
                early_correct += 1
        else:
            # Only the second output (probabilities on the final readout) is used.
            _, final_full = late_qnode(img, shots=shots)
            final_guess = (torch.argmax(final_full) == target).item()
            evaluation_results.append(final_guess)
            count_2 += 1
            # 4 layers already spent on the early circuit plus 8 in the full one.
            executed_layers += 12
            if final_guess:
                final_correct += 1

    n_images = len(evaluation_results)
    total_accuracy = sum(evaluation_results) / n_images if n_images > 0 else 0
    early_total_accuracy = sum(early_results) / n_images if n_images > 0 else 0
    early_exited_accuracy = early_correct / count_1 if count_1 > 0 else 0
    final_exited_accuracy = final_correct / count_2 if count_2 > 0 else 0

    return total_accuracy, early_total_accuracy, early_exited_accuracy, count_1, final_exited_accuracy, count_2, executed_layers


In [22]:
import pandas as pd
def explain_evaluation(dataloader: DataLoader, threshold: List[float]):
    """Evaluate every threshold and tabulate the results.

    Calls ``evaluation_routine(dataloader, t)`` once per value in ``threshold``.

    Args:
        dataloader: Loader forwarded unchanged to ``evaluation_routine``.
        threshold: Confidence thresholds to sweep, in the order given.

    Returns:
        ``(summary_data, df)``: a dict mapping each column name to the list of
        per-threshold values, and the DataFrame built from that dict.
    """
    columns = (
        'Threshold',
        'Total Accuracy',
        '# early exited images',
        'Early exited Accuracy',
        'Early total accuracy',
        '# final classified images',
        'Final classified Accuracy',
        'Executed layers',
    )
    summary_data = {name: [] for name in columns}

    for t in tqdm(threshold):
        (total_acc, early_total_acc, early_exit_acc, n_early,
         final_acc, n_final, layers) = evaluation_routine(dataloader, t)
        # Values listed in the same order as `columns`.
        row = (t, total_acc, n_early, early_exit_acc,
               early_total_acc, n_final, final_acc, layers)
        for name, value in zip(columns, row):
            summary_data[name].append(value)

    return summary_data, pd.DataFrame(summary_data)

In [23]:
# Confidence thresholds to sweep: 0.26 to 0.40 in steps of 0.01.
thresholds = [round(0.26 + 0.01 * step, 2) for step in range(15)]

In [24]:
# Sweep the thresholds on the test split; evaluation_routine looks at the
# first 100 images of the dataset for each threshold.
summary, table = explain_evaluation(test_dataloader, thresholds)

  0%|          | 0/15 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:04,  4.49s/it][A
2it [00:08,  4.48s/it][A
3it [00:13,  4.47s/it][A
4it [00:17,  4.48s/it][A
5it [00:22,  4.48s/it][A
6it [00:26,  4.47s/it][A
7it [00:31,  4.48s/it][A
8it [00:35,  4.48s/it][A
9it [00:40,  4.47s/it][A
10it [00:44,  4.48s/it][A
11it [00:49,  4.47s/it][A
12it [00:53,  4.48s/it][A
13it [00:58,  4.47s/it][A
14it [01:02,  4.47s/it][A
15it [01:07,  4.47s/it][A
16it [01:11,  4.47s/it][A
17it [01:16,  4.47s/it][A
18it [01:20,  4.47s/it][A
19it [01:24,  4.47s/it][A
20it [01:29,  4.47s/it][A
21it [01:33,  4.47s/it][A
22it [01:38,  4.47s/it][A
23it [01:42,  4.47s/it][A
24it [01:47,  4.47s/it][A
25it [01:51,  4.47s/it][A
26it [01:56,  4.47s/it][A
27it [02:00,  4.47s/it][A
28it [02:05,  4.47s/it][A
29it [02:09,  4.47s/it][A
30it [02:14,  4.47s/it][A
31it [02:18,  4.46s/it][A
32it [02:23,  4.47s/it][A
33it [02:27,  4.47s/it][A
34it [02:32,  4.47s/it][A
35it [02:36,  4.47s/it][A
36i

In [25]:
# Show the per-threshold summary DataFrame.
table

Unnamed: 0,Threshold,Total Accuracy,# early exited images,Early exited Accuracy,Early total accuracy,# final classified images,Final classified Accuracy,Executed layers
0,0.26,0.52,100,0.52,0.52,0,0.0,400
1,0.27,0.4,98,0.387755,0.38,2,1.0,416
2,0.28,0.49,99,0.494949,0.49,1,0.0,408
3,0.29,0.45,89,0.483146,0.49,11,0.181818,488
4,0.3,0.57,81,0.580247,0.53,19,0.526316,552
5,0.31,0.47,64,0.5625,0.46,36,0.305556,688
6,0.32,0.51,67,0.597015,0.52,33,0.333333,664
7,0.33,0.42,44,0.545455,0.47,56,0.321429,848
8,0.34,0.53,54,0.703704,0.58,46,0.326087,768
9,0.35,0.45,39,0.666667,0.41,61,0.311475,888


In [26]:
import pickle

# Persist both views of the sweep (dict of lists and DataFrame) in one pickle.
# NOTE(review): absolute, user-specific path; the notebook will not run elsewhere.
ev_data = dict(summary=summary, table=table)
with open("/Users/jackvittori/Desktop/depolarizing03.pickle", "wb") as file:
    pickle.dump(ev_data, file)