In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
sns.set_context("paper")
sns.set_style("ticks")
import tikzplotlib
from pathlib import Path
import sklearn.model_selection
import torch
from torch import nn
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.models import (
    resnet50, ResNet50_Weights,
    vit_b_16, ViT_B_16_Weights,
)
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import numpy as np


from calibrators import fit_scaling_model
from utils import logits_labels_from_dataloader, LogitsDataset
from netcal.presentation import ReliabilityDiagram

# Number of confidence bins used for reliability diagrams.
n_bins = 15

# One seed drives every random number generator used in this notebook
# (torch, the legacy numpy global state, and a dedicated numpy Generator).
seed = 0

torch.manual_seed(seed)
np.random.seed(seed)
rng = np.random.default_rng(seed)

# Results directory: a 'results/' folder next to the current working directory.
path_results = os.path.dirname(os.getcwd()) + '/results/'

# Root of the ImageNet dataset. Defaults to '/imagenet'; set the IMAGENET_DIR
# environment variable to point at another location without editing the notebook.
path_dataset = Path(os.environ.get('IMAGENET_DIR', '/imagenet'))

# Batch size used when running the classifier over images.
BATCH_SIZE = 64

In [None]:
# Name of the pretrained classifier to calibrate; must be a key of the table below.
model_name = 'ViT-B/16'

# Display name -> (torchvision constructor, pretrained weights).
models_and_weights_torchvision = {
    'ResNet-50': (resnet50, ResNet50_Weights.DEFAULT),
    'ViT-B/16': (vit_b_16, ViT_B_16_Weights.DEFAULT), 
}

architecture, weights = models_and_weights_torchvision[model_name]

# Preprocessing pipeline that ships with the selected weights.
transforms = weights.transforms()

# Build the network with its pretrained weights, switch it to evaluation mode
# and move it to the GPU (both calls modify the module in place).
classifier = architecture(weights=weights)
classifier.eval()
classifier.cuda()

# Number of classes passed to the calibrators, and number of epochs used to fit them.
num_classes, num_epochs = 1000, 200

In [None]:
# Number of images set aside to fit the calibrators; the remainder is the test set.
valid_size = 25000

dataset_val = ImageFolder(path_dataset/'val', transform=transforms)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True, shuffle=True)

# Split the indices into a test part and a validation (calibration) part,
# stratified on the dataset targets so both parts keep the same class balance.
all_indices = np.arange(len(dataset_val))
test_indices, valid_indices = sklearn.model_selection.train_test_split(
    all_indices,
    train_size=len(dataset_val) - valid_size,
    stratify=dataset_val.targets,
    random_state=seed,
)

# Both subset loaders share everything except the indices they sample from.
subset_loader_kwargs = dict(pin_memory=True, batch_size=BATCH_SIZE, num_workers=4)
valid_loader = DataLoader(dataset_val, sampler=SubsetRandomSampler(valid_indices), **subset_loader_kwargs)
test_loader = DataLoader(dataset_val, sampler=SubsetRandomSampler(test_indices), **subset_loader_kwargs)

# Create calibration data: run the classifier once over each subset and keep
# the resulting logits and labels, so calibrators can be fitted on logits directly.
logits_val, labels_val = logits_labels_from_dataloader(classifier, valid_loader)
logits_test, labels_test = logits_labels_from_dataloader(classifier, test_loader)
dataset_logits_val = LogitsDataset(logits_val, labels_val)
dataloader_logits_val = DataLoader(dataset_logits_val, batch_size=512)

# Reliability diagrams

In [None]:
def save_reliability_diagram(confidences, labels, suffix):
    """Write a reliability diagram of `confidences` against `labels` to a .tikz file.

    The file is named ``ReliabilityDiagram_<model>_<suffix>.tikz`` where <model>
    is `model_name` with every '/' removed, so the result is a valid file name.
    The diagram uses `n_bins` bins.
    """
    filename = f"ReliabilityDiagram_{model_name.replace('/', '')}_{suffix}.tikz"
    diagram = ReliabilityDiagram(bins=n_bins)
    diagram.plot(confidences, labels, tikz=True, filename=filename, axis_height ='4cm', axis_width='8cm')


def scaled_test_probabilities(method, binary_loss, regularization):
    """Fit a scaling calibrator on the validation logits and return the
    calibrated class probabilities of the test set as a numpy array."""
    scaler = fit_scaling_model(method, dataloader_logits_val, num_classes,
                               binary_loss=binary_loss, regularization=regularization,
                               num_epochs=num_epochs)
    logits_scaled = scaler(logits_test.cuda()).detach().cpu()
    return torch.softmax(logits_scaled, dim=1).numpy()


# figure 1: uncalibrated classifier
probs = torch.softmax(logits_test, dim=1).numpy()
ground_truth = labels_test.numpy()
save_reliability_diagram(probs, ground_truth, 'uncalibrated')

# figures 2 and 3: scaling method fitted with the standard loss, then with the
# binary loss. Each entry is (method, binary_loss, regularization, file suffix).
# A model name missing from this table simply produces no figure 2/3.
scaling_figures = {
    'ResNet-50': [
        ('temperature', False, False, 'TS'),
        ('temperature', True, False, 'TStva'),
    ],
    'ViT-B/16': [
        ('vector', False, False, 'VS'),
        ('vector', True, True, 'VSregtva'),
    ],
}
for method, binary_loss, regularization, suffix in scaling_figures.get(model_name, []):
    probs_scaled = scaled_test_probabilities(method, binary_loss, regularization)
    save_reliability_diagram(probs_scaled, ground_truth, suffix)

# figure 4: histogram binning fitted on (confidence, is-prediction-correct) pairs
# NOTE(review): HB_binary was never imported anywhere in this notebook, so this
# cell raised NameError on a fresh kernel. It is assumed to live in `calibrators`
# next to fit_scaling_model — confirm the module.
from calibrators import HB_binary

certainties_val, y_pred_val = torch.softmax(logits_val, dim=1).max(dim=1)
correct_val = y_pred_val == labels_val
hb_calibrator = HB_binary()
hb_calibrator.fit(certainties_val.numpy(), correct_val.numpy())

certainties_test, y_pred_test = torch.softmax(logits_test, dim=1).max(dim=1)
certainties_scaled = hb_calibrator.predict_proba(certainties_test.cpu().numpy())
correct_test = (y_pred_test == labels_test).numpy()

save_reliability_diagram(certainties_scaled, correct_test, 'HBtva')

# Regularization

In [None]:
from calibrators import Temperature, Vector, Dirichlet

def compute_ECE(model, num_bins=15):
    """Return the expected calibration error of `model` on the calibration and test logits.

    `model` is a calibrator mapping a batch of logits to scaled logits. It is
    applied to the notebook-level `logits_val` and `logits_test` on the GPU; the
    scaled logits are turned into probabilities with a softmax and scored
    against `labels_val` / `labels_test`.

    Args:
        model: callable taking a CUDA tensor of logits and returning scaled logits.
        num_bins: number of confidence bins used by the ECE estimator.

    Returns:
        (ece_train, ece_test): ECE on the calibration set and on the test set.

    NOTE(review): the notebook called an `ece(labels, probs, num_bins=...)` helper
    that was never imported or defined. netcal (already a dependency of this
    notebook) is used instead; confirm it matches the intended estimator.
    """
    from netcal.metrics import ECE

    ece_metric = ECE(bins=num_bins)

    def _ece_on(logits, labels):
        # Inference only: no autograd graph is needed for evaluation.
        with torch.no_grad():
            logits_scaled = model(logits.cuda()).cpu()
        probs_scaled = torch.softmax(logits_scaled, dim=1).numpy()
        return ece_metric.measure(probs_scaled, labels.numpy())

    ece_train = _ece_on(logits_val, labels_val)
    ece_test = _ece_on(logits_test, labels_test)

    return ece_train, ece_test


def fit_scaling_model_log(method, dataloader_logits_calib, num_classes, binary_loss, regularization, temperature_ref=1, num_epochs=200, lr=0.001):
    """Fit a scaling calibrator with Adam and log the test ECE after every epoch.

    Args:
        method: 'temperature', 'vector' or 'dirichlet'; selects the calibrator class.
        dataloader_logits_calib: iterable of (logits, labels) batches used for fitting.
        num_classes: number of classes, passed to the vector and Dirichlet calibrators.
        binary_loss: if True, minimise the binary cross-entropy between the
            confidence of the predicted class and whether that prediction is
            correct; otherwise minimise the usual multiclass cross-entropy.
        regularization: if True, add a penalty pulling the vector entries (for
            'vector') or the diagonal of the weight matrix (for 'dirichlet')
            towards 1.
        temperature_ref: passed to the vector and Dirichlet calibrators.
        num_epochs: number of passes over `dataloader_logits_calib`.
        lr: Adam learning rate.

    Returns:
        numpy array of length `num_epochs` holding the test ECE measured by
        `compute_ECE` at the end of each epoch.

    Raises:
        ValueError: if `method` is not one of the supported names.
    """
    if method == 'temperature':
        calibrator = Temperature().cuda()
    elif method == 'vector':
        calibrator = Vector(num_classes, temperature_ref).cuda()
    elif method == 'dirichlet':
        calibrator = Dirichlet(num_classes, temperature_ref).cuda()
    else:
        raise ValueError('Unknown method')

    optimizer = torch.optim.Adam(calibrator.parameters(), lr=lr)

    ece_test_per_epoch = np.zeros(num_epochs)
    for epoch in range(num_epochs):
        for x, y in dataloader_logits_calib:
            optimizer.zero_grad()
            x, y = x.cuda(), y.cuda()
            logits_scaled = calibrator(x)

            if binary_loss:
                # Only the top prediction matters: its confidence is trained to
                # match the 0/1 indicator of the prediction being correct.
                probas = torch.softmax(logits_scaled, dim=1)
                confidence, y_pred = torch.max(probas, dim=1)
                correct = (y_pred == y).float()
                loss = nn.functional.binary_cross_entropy(confidence, correct)
            else:
                loss = nn.functional.cross_entropy(logits_scaled, y)

            if method == 'dirichlet':
                linear = calibrator.model[0]
                # Off-diagonal and bias penalties are always applied for Dirichlet;
                # the clone keeps fill_diagonal_ from touching the real weights.
                loss += calibrator.off_diag_reg * linear.weight.clone().fill_diagonal_(0).square().sum()
                loss += calibrator.bias_reg * linear.bias.square().sum()
                if regularization:
                    loss += calibrator.diag_reg * (torch.diagonal(linear.weight) - 1).square().mean()
            elif method == 'vector' and regularization:
                loss += calibrator.vec_reg * (calibrator.vec - 1).square().mean()

            loss.backward()
            optimizer.step()

        # compute_ECE also returns the calibration-set ECE; only the test value is logged.
        _, ece_test_per_epoch[epoch] = compute_ECE(calibrator)

    return ece_test_per_epoch

# One row per calibration variant, fitted in this order:
# (result key, scaling method, binary_loss, regularization)
experiments = [
    ('TS_tva', 'temperature', True, False),
    ('VS', 'vector', False, False),
    ('VS_reg', 'vector', False, True),
    ('VS_tva', 'vector', True, False),
    ('VS_reg_tva', 'vector', True, True),
]

# Per-epoch test ECE of every variant.
ECE_test = {}
for key, method, use_binary_loss, use_regularization in experiments:
    ECE_test[key] = fit_scaling_model_log(method, dataloader_logits_val, num_classes,
                                          binary_loss=use_binary_loss, regularization=use_regularization)

# The first epochs are left out of the plot.
epoch_start = 20

# Line style of each curve: colour distinguishes the loss, dashes the absence of regularization.
curve_styles = {
    'TS_tva': dict(c='k', ls=':', label=r'TS\textsubscript{TvA}'),
    'VS': dict(c='C0', ls='--', label='VS'),
    'VS_reg': dict(c='C0', ls='-', label=r'VS\textsubscript{reg}'),
    'VS_tva': dict(c='C1', ls='--', label=r'VS\textsubscript{TvA}'),
    'VS_reg_tva': dict(c='C1', ls='-', label=r'VS\textsubscript{reg\_TvA}'),
}

plt.figure()
for key, style in curve_styles.items():
    curve = ECE_test[key]
    # ECE is plotted in percent against the epoch index.
    plt.plot(np.arange(epoch_start, len(curve)), 100*curve[epoch_start:], **style)
plt.legend()
plt.xlabel('epochs')
plt.ylabel('ECE test [%]')
tikzplotlib.save('regularization.tikz')

# Histogram for classwise ECE

In [None]:
# Uncalibrated class probabilities of the test set, one column per class.
probas = torch.softmax(logits_test, 1)
# Draw class indices from the actual number of columns rather than a
# hard-coded 1000, so the cell still works if the classifier changes.
n_probability_columns = probas.shape[1]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,3))
axes[0].set_ylabel('number of Samples (log scale)')
for ax in axes:
    # Classes are drawn independently from the seeded global numpy state,
    # so the same class may in principle be picked more than once.
    random_class_idx = np.random.randint(n_probability_columns)
    # Histogram (log-scaled counts) of the probability assigned to this class.
    ax.hist(probas[:, random_class_idx].numpy(), log=True)
    ax.set_xlabel(f'class probability for class {random_class_idx}')
tikzplotlib.save('classwiseECE_histograms.tikz')