In [1]:
# Use IPython's built-in completer instead of jedi
# (presumably because jedi-based tab completion was slow/unreliable in this environment).
%config Completer.use_jedi = False

Perform ensemble evaluation from saved logits files.

## Imports

In [2]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [3]:
import sys
sys.path.append('../')

from models import model_selector
from utils.data_augmentation import data_augmentation_selector
from utils.datasets import dataset_selector
from utils.neural import *
from utils.metrics import compute_accuracy

## Locate logits files

In [4]:
# Root folder that is searched recursively for saved logits, and the split prefix
# that appears in the logits file names (e.g. "test_logits_*.pt").
logits_dir, prefix = "../logits", "test"

In [5]:
# Recursively collect every file under `logits_dir` whose *file name* contains
# "<prefix>_logits". Matching on the name (not the full path) avoids picking up
# unrelated files that merely live in a directory whose name contains the pattern.
# os.walk yields names in arbitrary (filesystem) order, so the result is sorted to
# make the model order -- and therefore which model is `probs_list[0]` -- reproducible.
logits_pattern = f"{prefix}_logits"
logits_paths = sorted(
    os.path.join(subdir, file_name)
    for subdir, _dirs, files in os.walk(logits_dir)
    for file_name in files
    if logits_pattern in file_name
)

# An explicit exception (unlike `assert`) is not stripped when Python runs with -O.
if not logits_paths:
    raise FileNotFoundError(
        f"Could not find any file in subdirectories of '{logits_dir}' with prefix '{prefix}'"
    )

## Get logits and labels

In [6]:
# Load every logits file, report its individual accuracy, and stack the logits into
# a single tensor. Each file is read as a dict holding "logits" and "labels" tensors.
logits_list, labels_list = [], []
for lp in logits_paths:
    # "<parent dir>/<file name>", built with os.path so Windows separators are handled too.
    logits_name = f"{os.path.basename(os.path.dirname(lp))}/{os.path.basename(lp)}"
    # map_location="cpu" lets logits saved from a GPU run be loaded on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load logits files you trust.
    info = torch.load(lp, map_location="cpu")
    logits = info["logits"].cpu()
    labels = info["labels"].cpu()

    logits_accuracy = compute_accuracy(labels, logits)
    print(f"{logits_name}: {logits_accuracy}")
    logits_list.append(logits)
    labels_list.append(labels)

# Stack to shape (num_models, num_samples, num_classes); torch.stack requires every
# file to have the same logits shape. Example: torch.Size([5, 10000, 10]) for 5 CIFAR-10 models.
logits_list = torch.stack(logits_list)

model1/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9482
model2/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9438
model3/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9469
model4/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9448
model5/test_logits_model_kuangliu_resnet18_best_accuracy.pt: 0.9458


#### -- Check that the labels are in the same order for all logits files --

In [7]:
# Ensemble metrics combine predictions sample-by-sample, so every logits file must
# list its samples in the same order. Comparing each labels tensor with the first one
# is sufficient: if one differs from the first, it cannot match the rest either.
labels = labels_list[0]
for indx, other_labels in enumerate(labels_list[1:], start=1):
    # torch.equal also checks shapes, so a file with a different number of samples is
    # reported here instead of raising inside (or silently broadcasting through) `eq`.
    if not torch.equal(labels, other_labels):
        raise ValueError(
            f"Labels of '{logits_paths[indx]}' do not match those of '{logits_paths[0]}'!"
        )

## Compute Calibration Metrics

In [8]:
from utils.calibration import *

In [9]:
# Convert logits to per-class probabilities: softmax over dim=2, the class axis of the
# (num_models, num_samples, num_classes) stack. torch.nn is referenced explicitly so
# this cell does not depend on `nn` being re-exported by a star import.
softmax = torch.nn.Softmax(dim=2)
probs_list = softmax(logits_list)
probs_list.shape

torch.Size([5, 10000, 10])

In [10]:
# Calibration metrics of a single ensemble member (the first loaded model).
# `probs_list` already holds probabilities, hence apply_softmax=False.
# NOTE(review): MCE is a maximum over bins and ECE a weighted average over the same
# bins, so MCE should never be below ECE -- the recorded output shows MCE < ECE.
# Confirm the return order / definitions of compute_calibration_metrics.
print("---- First model Calibration Metrics ----")
ECE, MCE, BRIER, NNL = compute_calibration_metrics(probs_list[0], labels, apply_softmax=False, bins=15)
# `NNL` holds the negative log-likelihood (NLL); the variable name is kept in case
# later cells reference it, but the printed label uses the standard abbreviation.
for metric_name, metric_value in (("ECE", ECE), ("MCE", MCE), ("BRIER", BRIER), ("NLL", NNL)):
    print(f"{metric_name}: {metric_value}")

---- First model Calibration Metrics ----
ECE: 0.024884950649738308
MCE: 0.013317647820711135
BRIER: 0.008238566108047962
NNL: 0.19300468266010284
