In [11]:
####################################################################
# Scenario0: static ensemble
# Exhaustively try every combination of exits and check whether any of them beats the final exit
####################################################################
# Scenario1: Entropy vs Temperature Scaling + Entropy 
# build a dynamic ensemble from the exits whose entropy is below a threshold
####################################################################
# Scenario2: MC Dropout -> find confident EE -> static ensemble 
#step1: make MC Dropout model
#step2: find confident EE from experiment
#step3: sum of softmax vector of each good model(under threshold) -> final inference from softmax vector sum
####################################################################
# Scenario3: train new block to choose which exit to inference 
# (JUST SCENARIO, NOT VERIFIED)
####################################################################

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models,datasets, transforms
from mevit_model import MultiExitViT
from tqdm import tqdm
from itertools import combinations
import numpy as np
from scipy.stats import entropy
from temperature_scaling import TemperatureScaling
####################################################################
# Side length passed to transforms.Resize when building the input pipeline.
IMG_SIZE = 224
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Dataset key -> torchvision dataset class.
dataset_name = {
    'cifar10': datasets.CIFAR10,
    'cifar100': datasets.CIFAR100,
    'imagenet': datasets.ImageNet,
}
# Dataset key -> number of target classes (size of the classification head).
dataset_outdim = {
    'cifar10': 10,
    'cifar100': 100,
    'imagenet': 1000,
}
##############################################################
################ 0. Hyperparameters ##########################
##############################################################
# Dataset to evaluate; it is also interpolated into every checkpoint path below.
data_choice = 'cifar10'
batch_size = 1024

# Checkpoint locations, keyed by the chosen dataset.
backbone_path = f'models/{data_choice}/vit_{data_choice}_backbone.pth'
mevit_pretrained_path = f"models/{data_choice}/integrated_ee.pth"
# NOTE(review): this flag is defined but not read by any of the visible cells.
mevit_isload = True

# Optimisation settings.
start_lr = 1e-4
max_iter = 200

# Early-exit list, e.g. [0, 1, 2, 3, 4, 5, 6, 7, 8, 9].
ee_list = list(range(10))
# One loss weight per exit.
exit_loss_weights = [1] * 11
# Number of model outputs collected per sample.
# NOTE(review): presumably len(ee_list) early exits plus the final head - confirm.
exit_num = 11
##############################################################
# Preprocessing shared by both splits: resize, convert to a tensor, then
# normalise every channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])

# Resolve the dataset class once and build both splits from it.
dataset_cls = dataset_name[data_choice]

train_dataset = dataset_cls(root='./data', train=True, download=True, transform=transform)
test_dataset = dataset_cls(root='./data', train=False, download=True, transform=transform)

# The test loader is not shuffled, so batches follow the dataset's own order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Build a ViT-B/16 with torchvision's default pretrained weights.
pretrained_vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

if data_choice != 'imagenet':
    # Replace the classification head so its output size matches the dataset's class count.
    pretrained_vit.heads.head = nn.Linear(pretrained_vit.heads.head.in_features, dataset_outdim[data_choice])

    # Load the backbone weights. map_location lets a checkpoint that was saved
    # from a GPU be restored on a CPU-only machine as well.
    pretrained_vit.load_state_dict(torch.load(backbone_path, map_location=device))
    pretrained_vit = pretrained_vit.to(device)
#from torchinfo import summary
#summary(pretrained_vit,input_size= (64, 3, IMG_SIZE, IMG_SIZE))

# Wrap the backbone in the multi-exit model and move it to the target device.
model = MultiExitViT(pretrained_vit,num_classes=dataset_outdim[data_choice],ee_list=ee_list,exit_loss_weights=exit_loss_weights).to(device)

# Restore the trained multi-exit weights (same map_location reasoning as above).
# NOTE(review): the load is unconditional; the `mevit_isload` flag is not consulted.
model.load_state_dict(torch.load(mevit_pretrained_path, map_location=device))
# Cache file for the per-exit logits collected on the test set.
file_path = f'cache_result_mevit_{data_choice}.pt'

Files already downloaded and verified
Files already downloaded and verified


In [13]:
# Set to 1 to run inference on the test set and (re)write the logit cache;
# leave at 0 to skip inference and reuse the cache file.
TRAIN_AND_SAVE = 0
if TRAIN_AND_SAVE:
    # One list of per-batch outputs for every exit, plus the matching labels.
    output_list_list = [[] for _ in range(exit_num)]
    labels_list = []

    # Run inference on the test set and collect the outputs of each exit.
    model.eval()  # evaluation mode
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Collecting logits", leave=False):
            # Use the configured device instead of a hard-coded .cuda(), so the
            # cell also runs when no GPU is available.
            images, labels = images.to(device), labels.to(device)
            output_list = model(images)  # outputs of all exits for this batch

            # Store the output of each exit.
            for i in range(exit_num):
                output_list_list[i].append(output_list[i])

            # Store the labels.
            labels_list.append(labels)

    # Concatenate the batches of each exit, then stack the exits along a new
    # leading dimension (no NumPy round trip needed).
    output_tensor = torch.stack([torch.cat(exit_batches) for exit_batches in output_list_list]).to(device)
    labels = torch.cat(labels_list).to(device)
    # Save a CPU copy so the cache can be loaded on machines without a GPU.
    torch.save(output_tensor.cpu(), file_path)


In [14]:
# Reload the cached per-exit outputs. map_location places the tensor on the
# active device even if the cache was written from a different one.
output_tensor = torch.load(file_path, map_location=device)
# Labels come straight from the dataset object; this lines up with the cached
# outputs as long as they were collected in dataset order.
labels_list = test_dataset.targets
labels=torch.tensor(labels_list).to(device)
# Sanity-check the loaded data.
print(output_tensor.shape)  # shape of the full stacked tensor (first dimension indexes the exit)
print(output_tensor[0].shape)  # shape of the first exit's outputs

torch.Size([11, 10000, 10])
torch.Size([10000, 10])


In [15]:
# Every subset of exit indices with at least two members, ordered by subset
# size and, within one size, in the order itertools.combinations yields them.
combinations_list = [
    exit_subset
    for subset_size in range(2, exit_num + 1)
    for exit_subset in combinations(range(exit_num), subset_size)
]

In [16]:
# Static ensemble, variant "sum before softmax": add up the raw outputs of the
# selected exits, apply softmax to the sum, and score the argmax.
bef_softmax = {}
total_predictions = labels.size(0)
for exit_subset in combinations_list:
    # Pick the selected exits along the first dimension and sum them.
    summed_logits = output_tensor[list(exit_subset)].sum(dim=0)
    summed_probabilities = F.softmax(summed_logits, dim=1)
    predicted_labels = summed_probabilities.max(dim=1).indices
    correct_predictions = (predicted_labels == labels).sum().item()
    # Accuracy in percent.
    bef_softmax[exit_subset] = correct_predictions / total_predictions * 100

# Show every subset with its accuracy, best first.
sorted(bef_softmax.items(), key=lambda item: item[1], reverse=True)

[((7, 10), 97.68),
 ((9, 10), 97.65),
 ((3, 9, 10), 97.63),
 ((4, 9, 10), 97.63),
 ((0, 9, 10), 97.61999999999999),
 ((6, 8, 10), 97.61),
 ((6, 9, 10), 97.61),
 ((0, 6, 9, 10), 97.61),
 ((1, 6, 8, 10), 97.61),
 ((1, 6, 9, 10), 97.61),
 ((4, 7, 9, 10), 97.61),
 ((0, 7, 8, 9, 10), 97.61),
 ((3, 10), 97.6),
 ((8, 9, 10), 97.6),
 ((1, 8, 9, 10), 97.6),
 ((0, 10), 97.59),
 ((4, 10), 97.59),
 ((8, 10), 97.59),
 ((7, 9, 10), 97.59),
 ((0, 7, 9, 10), 97.59),
 ((3, 7, 9, 10), 97.59),
 ((6, 8, 9, 10), 97.59),
 ((0, 3, 7, 8, 9, 10), 97.59),
 ((6, 10), 97.58),
 ((3, 7, 10), 97.58),
 ((0, 7, 8, 10), 97.58),
 ((1, 4, 9, 10), 97.58),
 ((1, 7, 9, 10), 97.58),
 ((1, 6, 7, 9, 10), 97.58),
 ((1, 6, 8, 9, 10), 97.58),
 ((1, 6, 10), 97.57000000000001),
 ((1, 5, 9, 10), 97.57000000000001),
 ((3, 5, 9, 10), 97.57000000000001),
 ((3, 8, 9, 10), 97.57000000000001),
 ((4, 8, 9, 10), 97.57000000000001),
 ((0, 6, 7, 8, 9, 10), 97.57000000000001),
 ((0, 7, 10), 97.56),
 ((1, 9, 10), 97.56),
 ((4, 7, 10), 97.56),
 

In [17]:
# Static ensemble, variant "sum after softmax": apply softmax to each selected
# exit separately, add up the resulting vectors, and score the argmax.
aft_softmax = {}
total_predictions = labels.size(0)
for exit_subset in combinations_list:
    # Pick the selected exits along the first dimension.
    selected_outputs = output_tensor[list(exit_subset)]
    summed_probabilities = F.softmax(selected_outputs, dim=2).sum(dim=0)
    predicted_labels = summed_probabilities.max(dim=1).indices
    correct_predictions = (predicted_labels == labels).sum().item()
    # Accuracy in percent.
    aft_softmax[exit_subset] = correct_predictions / total_predictions * 100

# Show every subset with its accuracy, best first.
sorted(aft_softmax.items(), key=lambda item: item[1], reverse=True)

[((7, 10), 97.69),
 ((9, 10), 97.65),
 ((4, 9, 10), 97.65),
 ((1, 8, 9, 10), 97.65),
 ((8, 10), 97.64),
 ((1, 7, 9, 10), 97.64),
 ((6, 10), 97.63),
 ((3, 7, 8, 9, 10), 97.63),
 ((0, 9, 10), 97.61999999999999),
 ((1, 5, 9, 10), 97.61999999999999),
 ((3, 8, 9, 10), 97.61999999999999),
 ((1, 7, 8, 9, 10), 97.61999999999999),
 ((0, 7, 9, 10), 97.61),
 ((4, 7, 9, 10), 97.61),
 ((2, 10), 97.6),
 ((8, 9, 10), 97.6),
 ((0, 4, 9, 10), 97.6),
 ((0, 3, 8, 9, 10), 97.6),
 ((0, 7, 8, 9, 10), 97.6),
 ((1, 4, 7, 9, 10), 97.6),
 ((3, 10), 97.59),
 ((4, 10), 97.59),
 ((0, 8, 10), 97.59),
 ((1, 9, 10), 97.59),
 ((0, 8, 9, 10), 97.59),
 ((0, 4, 7, 9, 10), 97.59),
 ((1, 5, 8, 9, 10), 97.59),
 ((0, 4, 7, 8, 9, 10), 97.59),
 ((0, 10), 97.58),
 ((5, 9, 10), 97.58),
 ((6, 9, 10), 97.58),
 ((1, 6, 9, 10), 97.58),
 ((6, 8, 9, 10), 97.58),
 ((7, 8, 9, 10), 97.58),
 ((0, 3, 7, 9, 10), 97.58),
 ((1, 6, 7, 9, 10), 97.58),
 ((3, 9, 10), 97.57000000000001),
 ((0, 5, 9, 10), 97.57000000000001),
 ((1, 4, 9, 10), 97.570

In [None]:
# Dynamic ensemble: for every sample, ensemble only the exits whose softmax
# entropy is at or below a threshold, and sweep that threshold.
# Two variants are scored: summing the exits before softmax or after softmax.

step_range = 10000
aft_sftmx = F.softmax(output_tensor,dim=2)
# scipy's entropy needs CPU data; the result has one entropy per (exit, sample).
# NOTE(review): the log base is exit_num, not the number of classes. This only
# rescales the entropies (and therefore the thresholds) - confirm it is intended.
entropy_array= torch.tensor(entropy(aft_sftmx.to('cpu'), base=exit_num, axis=2))

min_entropy = entropy_array.min()
# entropy_array is already a tensor; wrapping it in torch.tensor() again only
# makes a copy and raises a copy-construct UserWarning.
median_entropy = torch.median(entropy_array)

# Thresholds run from the minimum entropy up to (but excluding) the median.
step_size = (median_entropy - min_entropy) / step_range

# Loop-invariant pieces, built once instead of on every iteration.
entropy_on_device = entropy_array.to(device)
last_row = torch.zeros_like(entropy_on_device, dtype=torch.bool)
last_row[-1, :] = True  # selects the last exit for every sample
total_predictions = labels.size(0)

# Best (threshold, accuracy) pair seen so far for each variant.
d_bef_softmax=(0,0);d_aft_softmax=(0,0)
for mul in range(step_range):
    threshold = min_entropy + mul * step_size
    mask = entropy_on_device <= threshold.to(device)

    # Samples for which no exit passed the threshold fall back to the last exit.
    zero_columns = ~mask.any(dim=0)
    mask = (mask | (last_row & zero_columns)).unsqueeze(dim=2)

    # Sum before softmax: add the raw outputs of the selected exits.
    ensemble_logits = (mask * output_tensor).sum(dim=0)
    ensemble_probabilities = F.softmax(ensemble_logits, dim=1)
    _, predicted_labels = torch.max(ensemble_probabilities, dim=1)
    correct_predictions = (predicted_labels == labels).sum().item()
    accuracy = correct_predictions / total_predictions * 100
    d_bef_softmax=max(d_bef_softmax,(threshold,accuracy),key=lambda x:x[1])

    # Sum after softmax: mask the per-exit probabilities themselves, so that
    # excluded exits contribute nothing (softmax of zeroed outputs would add a
    # uniform vector for each excluded exit instead).
    ensemble_probabilities = (mask * aft_sftmx).sum(dim=0)
    _, predicted_labels = torch.max(ensemble_probabilities, dim=1)
    correct_predictions = (predicted_labels == labels).sum().item()
    accuracy = correct_predictions / total_predictions * 100
    d_aft_softmax=max(d_aft_softmax,(threshold,accuracy),key=lambda x:x[1])
print(f"d_bef_softmax: {d_bef_softmax}\nd_aft_softmax: {d_aft_softmax}")

  median_entropy = torch.median(torch.tensor(entropy_array))


d_bef_softmax: (tensor(0.0010), 97.64)
d_aft_softmax: (tensor(0.0010), 97.64)


In [None]:
# Load the per-exit temperature scalers and divide each exit's outputs by its
# own temperature.
t_scalers_path = f"models/{data_choice}/temperature_scaler.pth"
# map_location lets a file saved from a GPU be loaded on a CPU-only machine.
# NOTE(review): this file holds pickled objects (each exposes `.temperature`),
# so only load it from a trusted source; newer torch releases may also need
# weights_only=False here - confirm against the installed version.
t_scalers = torch.load(t_scalers_path, map_location=device)
# One temperature value per exit.
t_scalers_values = torch.tensor([t_scalers[i].temperature for i in range(exit_num)]).to(device)
print(t_scalers_values)
# Add two trailing singleton dimensions and expand to output_tensor's shape,
# so that every exit is divided by its own temperature.
t_scalers_values=t_scalers_values.unsqueeze(dim=1).unsqueeze(dim=2).expand_as(output_tensor)
ts_output_tensor=output_tensor/t_scalers_values

In [20]:
# Dynamic ensemble with temperature scaling: for every sample, ensemble only
# the exits whose entropy - computed from the temperature-scaled outputs - is
# at or below a threshold, and sweep that threshold.
# Two variants are scored: summing the exits before softmax or after softmax.

step_range = 10000
aft_sftmx = F.softmax(ts_output_tensor,dim=2)
# scipy's entropy needs CPU data; the result has one entropy per (exit, sample).
# NOTE(review): the log base is exit_num, not the number of classes. This only
# rescales the entropies (and therefore the thresholds) - confirm it is intended.
entropy_array= torch.tensor(entropy(aft_sftmx.to('cpu'), base=exit_num, axis=2))

min_entropy = entropy_array.min()
# entropy_array is already a tensor; wrapping it in torch.tensor() again only
# makes a copy and raises a copy-construct UserWarning.
median_entropy = torch.median(entropy_array)

# Thresholds run from the minimum entropy up to (but excluding) the median.
step_size = (median_entropy - min_entropy) / step_range

# Loop-invariant pieces, built once instead of on every iteration.
# NOTE(review): temperature scaling is applied only to the entropy gate above;
# the ensembled values below are the unscaled `output_tensor`. If the scaled
# outputs were meant to be ensembled, use `ts_output_tensor` here - confirm.
exit_probabilities = F.softmax(output_tensor, dim=2)
entropy_on_device = entropy_array.to(device)
last_row = torch.zeros_like(entropy_on_device, dtype=torch.bool)
last_row[-1, :] = True  # selects the last exit for every sample
total_predictions = labels.size(0)

# Best (threshold, accuracy) pair seen so far for each variant.
d_bef_softmax=(0,0);d_aft_softmax=(0,0)
for mul in range(step_range):
    threshold = min_entropy + mul * step_size
    mask = entropy_on_device <= threshold.to(device)

    # Samples for which no exit passed the threshold fall back to the last exit.
    zero_columns = ~mask.any(dim=0)
    mask = (mask | (last_row & zero_columns)).unsqueeze(dim=2)

    # Sum before softmax: add the raw outputs of the selected exits.
    ensemble_logits = (mask * output_tensor).sum(dim=0)
    ensemble_probabilities = F.softmax(ensemble_logits, dim=1)
    _, predicted_labels = torch.max(ensemble_probabilities, dim=1)
    correct_predictions = (predicted_labels == labels).sum().item()
    accuracy = correct_predictions / total_predictions * 100
    d_bef_softmax=max(d_bef_softmax,(threshold,accuracy),key=lambda x:x[1])

    # Sum after softmax: mask the per-exit probabilities themselves, so that
    # excluded exits contribute nothing (softmax of zeroed outputs would add a
    # uniform vector for each excluded exit instead).
    ensemble_probabilities = (mask * exit_probabilities).sum(dim=0)
    _, predicted_labels = torch.max(ensemble_probabilities, dim=1)
    correct_predictions = (predicted_labels == labels).sum().item()
    accuracy = correct_predictions / total_predictions * 100
    d_aft_softmax=max(d_aft_softmax,(threshold,accuracy),key=lambda x:x[1])
print(f"d_bef_softmax: {d_bef_softmax}\nd_aft_softmax: {d_aft_softmax}")

  median_entropy = torch.median(torch.tensor(entropy_array))


d_bef_softmax: (tensor(0.0020), 97.64)
d_aft_softmax: (tensor(0.0020), 97.64)
