In [1]:
%config Completer.use_jedi = False

Perform ensemble evaluation from logits files

### Import stuff and create arguments

In [2]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [3]:
import sys
sys.path.append('../')

from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from utils.datasets import dataset_selector
from utils.neural import *
from utils.metrics import compute_accuracy
from utils.calibration import *

# Load logits

#### Validation

In [4]:
logits_dir="../logits"
prefix = "val"

In [5]:
val_logits_paths = []
for subdir, dirs, files in os.walk(logits_dir):
    for file in files:
        file_path = os.path.join(subdir, file)
        if f"{prefix}_logits" in file_path:
            val_logits_paths.append(file_path)

if not len(val_logits_paths):
    assert False, f"Could not find any file at subdirectoreis of '{logits_dir}' with prefix '{prefix}'"

#### Test

In [6]:
logits_dir="../logits"
prefix = "test"

In [35]:
test_logits_paths = []
for subdir, dirs, files in os.walk(logits_dir):
    for file in files:
        file_path = os.path.join(subdir, file)
        if f"{prefix}_logits" in file_path:
            test_logits_paths.append(file_path)

if not len(test_logits_paths):
    assert False, f"Could not find any file at subdirectories of '{logits_dir}' with prefix '{prefix}'"

### Check order

In [8]:
val_logits_paths

['../logits/model1/val_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model2/val_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model3/val_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model4/val_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model5/val_logits_model_kuangliu_resnet18_best_accuracy.pt']

In [9]:
test_logits_paths

['../logits/model1/test_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model2/test_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model3/test_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model4/test_logits_model_kuangliu_resnet18_best_accuracy.pt',
 '../logits/model5/test_logits_model_kuangliu_resnet18_best_accuracy.pt']

# Get logits and labels

### Validation

In [10]:
val_logits_list, val_labels_list = [], []
for lp in val_logits_paths:
    logits_name = "/".join(lp.split("/")[-2:])
    info = torch.load(lp)
    logits = info["logits"].cpu()
    labels = info["labels"].cpu()

    logits_accuracy = compute_accuracy(labels, logits)
    print(f"{logits_name}: {logits_accuracy}")
    val_logits_list.append(logits)
    val_labels_list.append(labels)

# logits_list shape: torch.Size([N, 10000, 10]) (CIFAR10 example)
val_logits_list = torch.stack(val_logits_list)

model1/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9478
model2/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9488
model3/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.95
model4/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9472
model5/val_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.949


#### -- Check if al labels has the same order for all logits --

In [11]:
val_labels = val_labels_list[0]
for indx, label_list in enumerate(val_labels_list[1:]):
    # Si alguno difiere del primero es que no es igual al resto tampoco
    if not torch.all(labels.eq(label_list)):
        assert False, f"Labels list does not match!"

### Test

In [12]:
test_logits_list, test_labels_list = [], []
for lp in test_logits_paths:
    logits_name = "/".join(lp.split("/")[-2:])
    info = torch.load(lp)
    logits = info["logits"].cpu()
    labels = info["labels"].cpu()

    logits_accuracy = compute_accuracy(labels, logits)
    print(f"{logits_name}: {logits_accuracy}")
    test_logits_list.append(logits)
    test_labels_list.append(labels)

# logits_list shape: torch.Size([N, 10000, 10]) (CIFAR10 example)
test_logits_list = torch.stack(test_logits_list)

model1/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9482
model2/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9438
model3/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9469
model4/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9448
model5/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9458


#### -- Check if al labels has the same order for all logits --

In [13]:
test_labels = test_labels_list[0]
for indx, label_list in enumerate(test_labels_list[1:]):
    # Si alguno difiere del primero es que no es igual al resto tampoco
    if not torch.all(labels.eq(label_list)):
        assert False, f"Labels list does not match!"

# Temperature Scaling
https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py

In [14]:
class TempScaling(nn.Module):
    """
    A thin decorator, which wraps a model with temperature scaling
    model (nn.Module):
        A classification neural network
        NB: Output of the neural network should be the classification logits, NOT the softmax (or log softmax)!
    """
    def __init__(self):
        super(TempScaling, self).__init__()
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, logits):
        return self.temperature_scale(logits)

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))
        return logits * temperature

In [15]:
def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

### Train temperature parameter over first validation model

In [16]:
val_first_model_logits = val_logits_list[0]
print(f"All models logits: {val_logits_list.shape}")
print(f"First model logits: {val_first_model_logits.shape}")
print(f"All labels are shared by validation sets: {val_labels.shape}")

All models logits: torch.Size([5, 5000, 10])
First model logits: torch.Size([5000, 10])
All labels are shared by validation sets: torch.Size([5000])


In [19]:
lr = 0.1
T1 = TempScaling()
optimizer = torch.optim.SGD(T1.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss().cuda()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=np.arange(0,100,20), gamma=0.1)

In [21]:
n_epochs = 100
batch_size = 128

header = "| {:{align}{widthL}} | {:{align}{widthA}} | {:{align}{widthA}} | {:{align}{widthLL}} | {:{align}{widthA}} | {:{align}{widthM}} | {:{align}{widthM}} | {:{align}{widthM}} | {:{align}{widthM}} |".format(
    "Epoch", "LR", "Loss", "Temp Param", "Accuracy", "ECE", "MCE", "BRIER", "NNL", align='^', widthL=8, widthLL=10, widthA=8, widthM=6,
)

print("".join(["_"] * len(header)))
print(header)
print("".join(["_"]*len(header)))

for epoch in range(n_epochs):
    
    T1.train()
    train_loss, correct, total = [], 0, 0
    c_ece, c_mce, c_brier, c_nnl = [], [], [], []
    
    for c_logits, c_labels in zip(chunks(val_first_model_logits, batch_size), chunks(val_labels, batch_size)):

        # Train
        optimizer.zero_grad()
        new_logits = T1(c_logits)
        loss = criterion(new_logits, c_labels)
        loss.backward()
        optimizer.step()

        # Metrics
        train_loss.append(loss.item())
        _, predicted = new_logits.max(1)
        total += len(c_labels)
        correct += predicted.eq(c_labels).sum().item()
        
        softmax = nn.Softmax(dim=1)
        new_probs_list = softmax(new_logits)
        ece, mce, brier, nnl = compute_calibration_metrics(new_probs_list, c_labels, apply_softmax=False, bins=15)
        c_ece.append(ece); c_mce.append(mce); c_brier.append(brier.item()); c_nnl.append(nnl.item()); 
    
    c_train_loss = np.array(train_loss).mean()
    c_accuracy = correct/total
    c_ece = np.array(c_ece).mean()
    c_mce = np.array(c_mce).mean()
    c_brier = np.array(c_brier).mean()
    c_nnl = np.array(c_nnl).mean()
    current_lr = get_current_lr(optimizer)
    
    line = "| {:{align}{widthL}} | {:{align}{widthA}.6f} | {:{align}{widthA}.4f} | {:{align}{widthLL}.4f} | {:{align}{widthA}.4f} | {:{align}{widthM}.4f} | {:{align}{widthM}.4f} | {:{align}{widthM}.4f} | {:{align}{widthM}.4f} |".format(
           epoch+1, current_lr, c_train_loss, T1.temperature.item(), c_accuracy, c_ece, c_mce, c_brier, c_nnl,
           align='^', widthL=8, widthA=8, widthM=6, widthLL=10
    )
    print(line)
    
    scheduler.step()

______________________________________________________________________________________________
|  Epoch   |    LR    |   Loss   | Temp Param | Accuracy |  ECE   |  MCE   | BRIER  |  NNL   |
______________________________________________________________________________________________
|    1     | 0.010000 |  0.1931  |   0.7663   |  0.9478  | 4.5543 | 1.3636 | 0.0087 | 0.1931 |
|    2     | 0.010000 |  0.1928  |   0.7627   |  0.9478  | 4.4926 | 1.3238 | 0.0087 | 0.1928 |
|    3     | 0.010000 |  0.1928  |   0.7634   |  0.9478  | 4.4716 | 1.3390 | 0.0087 | 0.1928 |
|    4     | 0.010000 |  0.1928  |   0.7633   |  0.9478  | 4.4647 | 1.3241 | 0.0087 | 0.1928 |
|    5     | 0.010000 |  0.1928  |   0.7633   |  0.9478  | 4.4647 | 1.3240 | 0.0087 | 0.1928 |
|    6     | 0.010000 |  0.1928  |   0.7633   |  0.9478  | 4.4647 | 1.3240 | 0.0087 | 0.1928 |
|    7     | 0.010000 |  0.1928  |   0.7633   |  0.9478  | 4.4647 | 1.3240 | 0.0087 | 0.1928 |
|    8     | 0.010000 |  0.1928  |   0.7633   |  0

|    86    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    87    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    88    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    89    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    90    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    91    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    92    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    93    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    94    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    95    | 0.000001 |  0.1922  |   0.7418   |  0.9478  | 4.2894 | 1.3129 | 0.0087 | 0.1922 |
|    96    | 0.000001 |  0.1922  |   0.7418   |  0

In [27]:
T1.eval()
softmax = nn.Softmax(dim=1)

TempScaling()

In [32]:
new_val_first_logits = T1(val_first_model_logits)
val_new_first_probs = softmax(new_val_first_logits)
val_first_probs = softmax(val_first_model_logits)

print("---- First model Calibration Metrics over Validation ----")
ECE, MCE, BRIER, NNL = compute_calibration_metrics(val_first_probs, val_labels, apply_softmax=False, bins=15)
print(f"ECE: {ECE}"); print(f"MCE: {MCE}"); print(f"BRIER: {BRIER}"); print(f"NNL: {NNL}")

print("\n---- First model Calibration Metrics over Validation (Temp Scal) ----")
ECE, MCE, BRIER, NNL = compute_calibration_metrics(val_new_first_probs, val_labels, apply_softmax=False, bins=15)
print(f"ECE: {ECE}"); print(f"MCE: {MCE}"); print(f"BRIER: {BRIER}"); print(f"NNL: {NNL}")

---- First model Calibration Metrics over Validation ----
ECE: 2.4941030383110045
MCE: 1.2195143759250642
BRIER: 0.008098339661955833
NNL: 0.19149670004844666

---- First model Calibration Metrics over Validation (Temp Scal) ----
ECE: 0.942747556567192
MCE: 0.1852864122390747
BRIER: 0.007771474774926901
NNL: 0.17561443150043488


In [34]:
test_first_model_logits = test_logits_list[0]
new_test_first_logits = T1(test_first_model_logits)
test_new_first_probs = softmax(new_test_first_logits)
test_first_probs = softmax(test_first_model_logits)

print("---- First model Calibration Metrics over Test ----")
ECE, MCE, BRIER, NNL = compute_calibration_metrics(test_first_probs, test_labels, apply_softmax=False, bins=15)
print(f"ECE: {ECE}"); print(f"MCE: {MCE}"); print(f"BRIER: {BRIER}"); print(f"NNL: {NNL}")

print("\n---- First model Calibration Metrics over Test (Temp Scal) ----")
ECE, MCE, BRIER, NNL = compute_calibration_metrics(test_new_first_probs, test_labels, apply_softmax=False, bins=15)
print(f"ECE: {ECE}"); print(f"MCE: {MCE}"); print(f"BRIER: {BRIER}"); print(f"NNL: {NNL}")

---- First model Calibration Metrics over Test ----
ECE: 2.4884950649738307
MCE: 1.3317647820711134
BRIER: 0.008238566108047962
NNL: 0.19300468266010284

---- First model Calibration Metrics over Test (Temp Scal) ----
ECE: 0.8139485321938993
MCE: 0.1761002254486084
BRIER: 0.007907265797257423
NNL: 0.17737245559692383
