In [None]:
import sys
sys.path.append('./stylegan/stylegan2')

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from pathlib import Path
import pickle
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10
from torchvision import transforms as T
from torch.nn import functional as F



from cifar10_models.vgg import vgg11_bn
from cifar10_models.mobilenetv2 import mobilenet_v2
from cifar10_models.resnet import resnet18, resnet50
from temperature_scaling import ModelWithTemperature, _ECELoss
from general_calibration_error import gce
from dirichletcal.calib.vectorscaling import VectorScaling

# Machine-specific roots: probe each known location and keep the *last* one that exists.
# If none of the candidates exists, the corresponding name is left undefined.
_existing_main = [
    candidate
    for candidate in (
        Path('/d/alecoz/projects'),  # DeepLab
        Path(os.path.expandvars('$WORK')),  # Jean Zay
        Path('w:/'),  # local
    )
    if os.path.exists(candidate)
]
if _existing_main:
    path_main = _existing_main[-1]
# path_results = path_main / 'uncertainty-conditioned-gan/results'

_existing_dataset = [
    candidate
    for candidate in (
        Path('/scratchf/CIFAR'),  # DeepLab
        Path(os.path.expandvars('$DSDIR')),  # Jean Zay
    )
    if os.path.exists(candidate)
]
if _existing_dataset:
    path_dataset = _existing_dataset[-1]

# Trained models live next to the notebook's parent directory.
path_models = Path.cwd().parent / 'models' / 'CIFAR10'

In [None]:
def postprocess_synthetic_images(images):
    """Rescale generator output from [-1, 1] to the [0, 1] image range.

    Args:
        images: 4D tensor (B x C x H x W) produced by the generator.

    Returns:
        Tensor of the same shape, mapped by ``(x + 1) / 2`` and clamped to [0, 1].

    Raises:
        AssertionError: if ``images`` is not 4-dimensional.
    """
    ndim = images.dim()
    assert ndim == 4, "Expected 4D (B x C x H x W) image tensor, got {}D".format(ndim)
    return images.add(1).div(2).clamp(0, 1)

def preprocess_images_classifier(images):
    """Normalize [0, 1] images with the per-channel CIFAR-10 statistics used by the classifier.

    The mean/std values are the same ones used by the real-data transforms of this notebook.

    Args:
        images: image tensor accepted by ``torchvision.transforms.Normalize`` (e.g. B x C x H x W).

    Returns:
        The normalized tensor ``(images - mean) / std`` per channel.
    """
    normalize = T.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2471, 0.2435, 0.2616),
    )
    return normalize(images)

# Load

In [None]:
# DATA
batch_size = 512

idx_to_label = {
    0: 'airplane',
    1: 'car',
    2: 'bird',
    3: 'cat',
    4: 'deer',
    5: 'dog',
    6: 'frog',
    7: 'horse',
    8: 'ship',
    9: 'truck'}

# Per-channel CIFAR-10 statistics applied to every real image fed to the classifier.
mean = (0.4914, 0.4822, 0.4465)
std = (0.2471, 0.2435, 0.2616)
transforms = T.Compose(
    [T.ToTensor(),
    T.Normalize(mean, std)])
dataset_train = CIFAR10(root=path_dataset, train=True, transform=transforms)
dataset_val = CIFAR10(root=path_dataset, train=False, transform=transforms)
# Split the 10k-image test set into a fixed (seeded) calibration half and test half.
dataset_calib, dataset_test = torch.utils.data.random_split(dataset_val, [5000, 5000], generator=torch.Generator().manual_seed(123))
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True)
# The calib/test loaders are only used to compute calibration metrics and to fit calibrators,
# where sample order is irrelevant: don't shuffle, so every pass is deterministic and does not
# consume the global RNG.
dataloader_calib = DataLoader(dataset_calib, batch_size=batch_size, shuffle=False, pin_memory=True)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, pin_memory=True)
    
# GENERATOR
# NOTE: unpickling executes arbitrary code; only load snapshots from a trusted source.
# The snapshot dict's 'G_ema' entry is the generator used throughout (a torch.nn.Module).
with (path_models / 'cifar10.pkl').open('rb') as snapshot_file:
    G = pickle.load(snapshot_file)['G_ema'].cuda()

# CLASSIFIER: pretrained CIFAR-10 ResNet-50, frozen (no gradients) and in eval mode.
classifier = resnet50(pretrained=True)
classifier = classifier.eval().requires_grad_(False).cuda()

In [None]:
# test generator
# Sanity check: generate one image of a chosen class and look at the classifier's prediction.
label_gen = 4  # class index to generate (see idx_to_label)
# Pure inference: without no_grad, a generator whose parameters require grad would return a tensor
# that carries grad, which matplotlib cannot convert to NumPy.
with torch.no_grad():
    z = torch.randn([1, G.z_dim]).cuda()  # latent codes
    c = torch.nn.functional.one_hot(torch.tensor([label_gen]).cuda(), num_classes=G.c_dim)  # class labels
    img = G(z, c, truncation_psi=0.7)  # NCHW, float32, dynamic range [-1, +1]
    img = postprocess_synthetic_images(img)

    # test classifier
    logits = classifier(preprocess_images_classifier(img))
    probas = torch.softmax(logits, dim=1)
    proba, label_pred = torch.max(probas, 1)

# plot
fig, ax = plt.subplots()
ax.imshow(img[0].permute(1, 2, 0).cpu())
ax.set_title(f'generated {idx_to_label[label_gen]}, predicted {idx_to_label[label_pred.item()]} {100*proba.item():.2f} %');

In [None]:
# def get_MSP_correct(dataloader, classifier):

#     classifier.eval()
#     msp = torch.zeros((len(dataloader.dataset)))
#     correct = torch.zeros((len(dataloader.dataset)))
#     idx = 0
#     for X, y in dataloader:
#         batch_size = X.shape[0]
#         X, y = X.cuda(), y.cuda()

#         with torch.no_grad():
#             logits = classifier(X)
#             probas, class_pred = torch.max(torch.softmax(logits, axis=1), axis=1)
#         msp[idx:idx+batch_size] = probas
#         correct[idx:idx+batch_size] = class_pred == y.squeeze()
#         idx += batch_size

#     return msp, correct


# msp_train, correct_train = get_MSP_correct(dataloader_train, classifier)
# msp_calib, correct_calib = get_MSP_correct(dataloader_calib, classifier)
# msp_test, correct_test = get_MSP_correct(dataloader_test, classifier)

# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
# ax1.set_xlabel('MSP value')
# ax1.hist(msp_test, alpha=0.5, bins=50, log=True, density=True, label='test')
# ax1.hist(msp_train, alpha=0.5, bins=50, log=True, density=True, label='train')
# ax1.legend()
# ax2.set_xlabel('MSP value')
# ax2.hist(msp_test, alpha=0.5, bins=50, log=True, density=True, label='test')
# ax2.hist(msp_calib, alpha=0.5, bins=50, log=True, density=True, label='calib')
# ax2.legend()

In [None]:
# print(f'train accuracy: {correct_train.sum() / len(correct_train):.3f}')
# print(f'calib accuracy: {correct_calib.sum() / len(correct_calib):.3f}')
# print(f'test accuracy: {correct_test.sum() / len(correct_test):.3f}')

# Calibration

In [None]:
def metrics_from_dataloader(model, dataloader, vector_scale=None):
    """Calibration metrics of ``model`` over every ``(input, label)`` batch of ``dataloader``.

    Inputs are moved to CUDA and evaluated without gradients. The logits are gathered on the CPU
    and turned into probabilities with a softmax, or with ``vector_scale.predict_proba`` when a
    fitted calibrator is given. Every metric is computed by ``gce``.

    Args:
        model: callable mapping an input batch to logits.
        dataloader: iterable of ``(input, label)`` batches.
        vector_scale: optional fitted calibrator with a ``predict_proba`` method returning a NumPy array.

    Returns:
        dict with keys 'ece', 'sce', 'rmsce', 'ace', 'tace', each the ``gce`` value obtained with the
        keyword arguments listed in ``metric_kwargs`` below.
    """
    collected_logits, collected_labels = [], []
    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.cuda()
        with torch.no_grad():
            collected_logits.append(model(batch_inputs))
        collected_labels.append(batch_labels)

    logits = torch.cat(collected_logits).cpu()
    if vector_scale is None:
        probs = torch.softmax(logits, dim=1)
    else:
        # .copy() hands torch.from_numpy a fresh, contiguous array
        probs = torch.from_numpy(vector_scale.predict_proba(logits).copy())
    labels = torch.cat(collected_labels).cpu()

    # gce keyword arguments for each reported metric (evaluated in this order).
    metric_kwargs = {
        'ece': dict(binning_scheme='even', class_conditional=False, max_prob=True, norm='l1', num_bins=15),
        'sce': dict(binning_scheme='even', class_conditional=False, max_prob=False, norm='l1', num_bins=15),
        'rmsce': dict(binning_scheme='adaptive', class_conditional=False, max_prob=True, norm='l2', datapoints_per_bin=100),
        'ace': dict(binning_scheme='adaptive', class_conditional=True, max_prob=False, norm='l1'),
        'tace': dict(binning_scheme='adaptive', class_conditional=True, max_prob=False, norm='l1', threshold=0.01),
    }
    return {name: gce(labels, probs, **kwargs) for name, kwargs in metric_kwargs.items()}

def ece_from_dataloader(model, dataloader, unnormalize=False):
    """Expected calibration error of ``model`` on ``dataloader`` using ``_ECELoss``.

    Args:
        model: callable mapping an input batch to logits; inputs are moved to CUDA.
        dataloader: iterable of ``(input, label)`` batches.
        unnormalize: forwarded unchanged to ``_ECELoss``.

    Returns:
        The ECE as a Python float.
    """
    all_logits, all_labels = [], []
    for batch_inputs, batch_labels in dataloader:
        batch_inputs = batch_inputs.cuda()
        with torch.no_grad():
            all_logits.append(model(batch_inputs))
        all_labels.append(batch_labels)

    logits = torch.cat(all_logits).cuda()
    labels = torch.cat(all_labels).cuda()
    return _ECELoss(unnormalize=unnormalize)(logits, labels).item()

def hist_from_dataloader(model, dataloader, vector_scale=None, n_bins=15, device='cuda'):
    """Per-bin statistics for a reliability diagram with equal-width confidence bins.

    ``model`` is run without gradients on every ``(input, label)`` batch of ``dataloader``. Its logits
    are turned into class probabilities with a softmax, or with ``vector_scale.predict_proba`` when a
    fitted calibrator is given. A sample's confidence is its maximum probability, and samples are
    assigned to the bins ``(lower, upper]`` whose edges are ``torch.linspace(0, 1, n_bins + 1)``.

    Args:
        model: callable mapping an input batch to a (batch, num_classes) tensor of logits.
        dataloader: iterable of ``(input, label)`` batches. Labels may be class indices of shape (N,)
            or (N, 1), or one-hot of shape (N, num_classes).
        vector_scale: optional fitted calibrator with a ``predict_proba(logits)`` method returning
            a NumPy array of probabilities (e.g. dirichletcal's ``VectorScaling``).
        n_bins: number of confidence bins (default 15).
        device: device the inputs are moved to before calling ``model`` (default ``'cuda'``).

    Returns:
        ``(avg_confidences_in_bin, accuracies_in_bin, data_in_bin, bin_boundaries)``:
        - the mean confidence per bin (NaN for empty bins);
        - the accuracy per bin (NaN for empty bins);
        - the percentage of samples per bin;
        - the ``n_bins + 1`` bin edges.
        The first three are CPU tensors of length ``n_bins``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    bin_boundaries = torch.linspace(0, 1, n_bins + 1)

    logits_list, labels_list = [], []
    with torch.no_grad():
        for inputs, batch_labels in dataloader:
            logits_list.append(model(inputs.to(device)))
            labels_list.append(batch_labels)
    if not logits_list:
        raise ValueError('hist_from_dataloader: dataloader yielded no batches')
    logits = torch.cat(logits_list).to(device)
    labels = torch.cat(labels_list).to(device)

    if vector_scale is not None:
        # .copy() gives torch.from_numpy a fresh contiguous array; from_numpy rejects arrays
        # with negative strides.
        probs = vector_scale.predict_proba(logits.cpu()).copy()
        softmaxes = torch.from_numpy(probs).to(logits.device)
    else:
        softmaxes = F.softmax(logits, dim=1)
    confidences, predictions = torch.max(softmaxes, 1)

    if labels.dim() > 1:
        if labels.shape[1] > 1:  # one-hot targets -> class indices
            labels = labels.argmax(1)
        else:  # (N, 1) index targets -> (N,), otherwise `eq` would broadcast to (N, N)
            labels = labels.reshape(-1)
    accuracies = predictions.eq(labels)

    n_samples = confidences.numel()
    avg_confidences_in_bin = torch.full((n_bins,), float('nan'))
    accuracies_in_bin = torch.full((n_bins,), float('nan'))
    data_in_bin = torch.zeros(n_bins)
    for i in range(n_bins):
        lower, upper = bin_boundaries[i].item(), bin_boundaries[i + 1].item()
        in_bin = confidences.gt(lower) & confidences.le(upper)
        count = int(in_bin.sum().item())
        data_in_bin[i] = 100 * count / n_samples  # in %
        if count > 0:
            accuracies_in_bin[i] = accuracies[in_bin].float().mean().item()
            avg_confidences_in_bin[i] = confidences[in_bin].mean().item()

    return avg_confidences_in_bin, accuracies_in_bin, data_in_bin, bin_boundaries

### From real validation data

In [None]:
# Baseline: reliability statistics of the uncalibrated classifier on the real calibration and test splits.
metrics_calib_before_TS = metrics_from_dataloader(classifier, dataloader_calib)
metrics_test_before_TS = metrics_from_dataloader(classifier, dataloader_test)
# hist_from_dataloader always uses the same 15-bin linspace edges, so reassigning bin_boundaries is harmless.
confid_calib_before_TS, acc_calib_before_TS, in_bin_calib, bin_boundaries = hist_from_dataloader(classifier, dataloader_calib)
confid_test_before_TS, acc_test_before_TS, in_bin_test, bin_boundaries = hist_from_dataloader(classifier, dataloader_test)

# Performing temperature scaling
# set_temperature is expected to fit `model.temperature` on the real calibration split;
# `model` is read again (its temperature) by the plotting cell that follows.
model = ModelWithTemperature(classifier).cuda()
model.set_temperature(dataloader_calib)

# Same statistics after scaling; the test split is not passed to set_temperature.
metrics_calib_after_TS = metrics_from_dataloader(model, dataloader_calib)
metrics_test_after_TS = metrics_from_dataloader(model, dataloader_test)
confid_calib_after_TS, acc_calib_after_TS, _, bin_boundaries = hist_from_dataloader(model, dataloader_calib)
confid_test_after_TS, acc_test_after_TS, _, bin_boundaries = hist_from_dataloader(model, dataloader_test)

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(10, 15))

# Top row: share of samples falling in each confidence bin, per split.
for ax, bin_share, title in zip(axs[0], (in_bin_calib, in_bin_test), ('on calib data', 'on test data')):
    ax.stairs(bin_share, bin_boundaries, fill=True)
    ax.set_title(title)
    ax.set_xlabel('confidence')
    ax.set_ylabel('% of samples')

# Remaining rows: reliability diagrams (mean confidence vs. accuracy per bin) annotated with the metrics.
list_names = ['on calib data before TS', 'on test data before TS', 'on calib data after TS', 'on test data after TS']
list_avg_confid = [confid_calib_before_TS, confid_test_before_TS, confid_calib_after_TS, confid_test_after_TS]
list_acc = [acc_calib_before_TS, acc_test_before_TS, acc_calib_after_TS, acc_test_after_TS]
list_metrics = [metrics_calib_before_TS, metrics_test_before_TS, metrics_calib_after_TS, metrics_test_after_TS]
metric_rows = ((0.8, 'ECE', 'ece'), (0.7, 'SCE', 'sce'), (0.6, 'RMSCE', 'rmsce'), (0.5, 'ACE', 'ace'))
for ax, name, avg_confid, acc, metrics in zip(axs.flatten()[2:], list_names, list_avg_confid, list_acc, list_metrics):
    ax.set_xlabel('confidence')
    ax.set_ylabel('accuracy')
    ax.stairs(avg_confid, bin_boundaries, fill=True, alpha=0.8, label='perfect')
    ax.stairs(acc, bin_boundaries, fill=True, alpha=0.8, label='real')
    ax.legend(loc='lower right')
    if 'after TS' in name:
        ax.text(0, 0.1, f'Temp: {model.temperature.item():.3f}', fontsize='large')
    for text_y, metric_label, metric_key in metric_rows:
        ax.text(0, text_y, f'{metric_label}: {100*metrics[metric_key]:.3f}%', fontsize='large')
    ax.set_title(name)

### From real train data

In [None]:
# metrics_calib_before_TS = metrics_from_dataloader(classifier, dataloader_train)
# metrics_test_before_TS = metrics_from_dataloader(classifier, dataloader_test)
# confid_calib_before_TS, acc_calib_before_TS, in_bin_calib, bin_boundaries = hist_from_dataloader(classifier, dataloader_train)
# confid_test_before_TS, acc_test_before_TS, in_bin_test, bin_boundaries = hist_from_dataloader(classifier, dataloader_test)

# # Performing temperature scaling
# model = ModelWithTemperature(classifier).cuda()
# model.set_temperature(dataloader_train)

# metrics_calib_after_TS = metrics_from_dataloader(model, dataloader_train)
# metrics_test_after_TS = metrics_from_dataloader(model, dataloader_test)
# confid_calib_after_TS, acc_calib_after_TS, _, bin_boundaries = hist_from_dataloader(model, dataloader_train)
# confid_test_after_TS, acc_test_after_TS, _, bin_boundaries = hist_from_dataloader(model, dataloader_test)

In [None]:
# fig, axs = plt.subplots(3, 2, figsize=(10, 15))

# # histo of data
# axs[0, 0].stairs(in_bin_calib, bin_boundaries, fill=True)
# axs[0, 0].set_title('on calib data')
# axs[0, 1].stairs(in_bin_test, bin_boundaries, fill=True)
# axs[0, 1].set_title('on test data')
# for ax in [axs[0, 0], axs[0, 1]]:
#     ax.set_xlabel('confidence')
#     ax.set_ylabel('% of samples')

# # reliability diagrams
# for ax in axs.flatten()[2:]:
#     ax.set_xlabel('confidence')
#     ax.set_ylabel('accuracy')

# list_names = ['on calib data before TS', 'on test data before TS', 'on calib data after TS', 'on test data after TS']
# list_avg_confid = [confid_calib_before_TS, confid_test_before_TS, confid_calib_after_TS, confid_test_after_TS]
# list_acc = [acc_calib_before_TS, acc_test_before_TS, acc_calib_after_TS, acc_test_after_TS]
# list_metrics = [metrics_calib_before_TS, metrics_test_before_TS, metrics_calib_after_TS, metrics_test_after_TS]
# for ax, name, avg_confid, acc, metrics in zip(axs.flatten()[2:], list_names, list_avg_confid, list_acc, list_metrics):
#     ax.stairs(avg_confid, bin_boundaries, fill=True, alpha=0.8, label='perfect')
#     ax.stairs(acc, bin_boundaries, fill=True, alpha=0.8, label='real')
#     ax.legend(loc='lower right')
#     if 'after TS' in name: ax.text(0, 0.1, f'Temp: {model.temperature.item():.3f}', fontsize='large')
#     ax.text(0, 0.8, f'ECE: {100*metrics["ece"]:.3f}%', fontsize='large')
#     ax.text(0, 0.7, f'SCE: {100*metrics["sce"]:.3f}%', fontsize='large')
#     ax.text(0, 0.6, f'RMSCE: {100*metrics["rmsce"]:.3f}%', fontsize='large')
#     ax.text(0, 0.5, f'ACE: {100*metrics["ace"]:.3f}%', fontsize='large')
#     ax.set_title(name)

### From synthetic data

In [None]:
class SyntheticImageDataset(Dataset):
    """Lazily generated, class-conditional synthetic images prepared for the classifier.

    Each ``__getitem__`` call ignores ``idx``. It draws a new latent code and a uniformly random
    class label, then runs the generator. It returns the image normalized for the classifier
    together with its conditioning label, so accessing the same index twice gives two different
    samples.

    Args:
        generator: conditional generator exposing ``z_dim``, ``c_dim`` and
            ``generator(z, c, truncation_psi=...)`` returning NCHW images in [-1, +1].
        max_len: number of samples reported by ``__len__``.
        truncation_psi: truncation value passed to the generator (default 1).
    """

    def __init__(self, generator, max_len, truncation_psi=1):
        self.G = generator
        self.max_len = max_len
        self.truncation_psi = truncation_psi

    def __len__(self):
        return self.max_len

    def __getitem__(self, idx):
        # Use the generator passed to __init__, not a notebook-global variable.
        G = self.G
        with torch.no_grad():  # pure inference, no autograd graph through the generator
            z = torch.randn([1, G.z_dim]).cuda()  # latent codes
            label = torch.randint(G.c_dim, (1,)).cuda()
            c = torch.nn.functional.one_hot(label, num_classes=G.c_dim)  # class labels
            img = G(z, c, truncation_psi=self.truncation_psi)  # NCHW, float32, dynamic range [-1, +1]
            img = postprocess_synthetic_images(img)
            img = preprocess_images_classifier(img).squeeze(0)  # drop the batch dimension -> CHW
        return img, label.squeeze()

class FilteredSyntheticImageDataset(Dataset):
    """Synthetic-image dataset intended to filter generator samples with a classifier.

    NOTE(review): no filtering is implemented yet. ``classifier`` is only stored, and
    ``__getitem__`` returns a fresh random class-conditional sample per access (``idx`` ignored),
    normalized for the classifier, together with its conditioning label. The class is not used in
    the visible cells of this notebook.

    Args:
        generator: conditional generator exposing ``z_dim``, ``c_dim`` and
            ``generator(z, c, truncation_psi=...)`` returning NCHW images in [-1, +1].
        max_len: number of samples reported by ``__len__``.
        classifier: model meant to drive the (not yet implemented) filtering.
    """

    def __init__(self, generator, max_len, classifier):
        self.G = generator
        self.max_len = max_len
        self.classifier = classifier  # TODO: use for filtering once the criterion is decided

    def __len__(self):
        return self.max_len

    def __getitem__(self, idx):
        # Use the generator passed to __init__, not a notebook-global variable.
        G = self.G
        with torch.no_grad():  # pure inference, no autograd graph through the generator
            z = torch.randn([1, G.z_dim]).cuda()  # latent codes
            label = torch.randint(G.c_dim, (1,)).cuda()
            c = torch.nn.functional.one_hot(label, num_classes=G.c_dim)  # class labels
            img = G(z, c, truncation_psi=1)  # NCHW, float32, dynamic range [-1, +1]
            img = postprocess_synthetic_images(img)
            img = preprocess_images_classifier(img).squeeze(0)  # drop the batch dimension -> CHW
        return img, label.squeeze()
    
dataset_synthetic = SyntheticImageDataset(G, 10000)

# SyntheticImageDataset draws a *new* random image on every access. Iterating it repeatedly
# (metrics, histograms, temperature fitting, ...) would therefore use a different synthetic set
# each time. Draw the samples once, with a fixed seed, and serve that frozen set to every consumer.
SYNTHETIC_SEED = 0
torch.manual_seed(SYNTHETIC_SEED)
with torch.no_grad():
    synthetic_samples = [dataset_synthetic[i] for i in range(len(dataset_synthetic))]
dataset_synthetic_frozen = torch.utils.data.TensorDataset(
    torch.stack([img for img, _ in synthetic_samples]),
    torch.stack([label for _, label in synthetic_samples]))
del synthetic_samples
dataloader_calib_synthetic = DataLoader(dataset_synthetic_frozen, batch_size)

In [None]:
# def get_MSP_correct(dataloader, classifier):

#     classifier.eval()
#     msp = torch.zeros((len(dataloader.dataset)))
#     correct = torch.zeros((len(dataloader.dataset)))
#     idx = 0
#     for X, y in dataloader:
#         batch_size = X.shape[0]
#         X, y = X.cuda(), y.cuda()

#         with torch.no_grad():
#             logits = classifier(X)
#             probas, class_pred = torch.max(torch.softmax(logits, axis=1), axis=1)
#         msp[idx:idx+batch_size] = probas
#         correct[idx:idx+batch_size] = class_pred == y.squeeze()
#         idx += batch_size

#     return msp, correct

# msp_synthetic, correct_synthetic = get_MSP_correct(dataloader_calib_synthetic, classifier)

# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# ax.set_xlabel('MSP value')
# ax.hist(msp_test, alpha=0.5, bins=50, log=True, density=True, label='test')
# ax.hist(msp_synthetic, alpha=0.5, bins=50, log=True, density=True, label='synthetic')
# ax.legend()

# print(f'generator - classifier agreement: {correct_synthetic.sum() / len(correct_synthetic):.3f}')

In [None]:
# Temperature scaling fitted on *synthetic* calibration data; evaluation still uses the real test split.
# NOTE(review): dataloader_calib_synthetic is iterated several times in this cell (metrics, histogram,
# set_temperature, ...). If it is backed directly by SyntheticImageDataset, every pass draws new random
# images, so the "calib" numbers and the fitted temperature then come from different synthetic samples.
metrics_calib_before_TS = metrics_from_dataloader(classifier, dataloader_calib_synthetic)
metrics_test_before_TS = metrics_from_dataloader(classifier, dataloader_test)
confid_calib_before_TS, acc_calib_before_TS, in_bin_calib, bin_boundaries = hist_from_dataloader(classifier, dataloader_calib_synthetic)
confid_test_before_TS, acc_test_before_TS, in_bin_test, bin_boundaries = hist_from_dataloader(classifier, dataloader_test)

# Performing temperature scaling
# `model` is rebound here; the plotting cell below reads this synthetic-fitted temperature.
model = ModelWithTemperature(classifier).cuda()
model.set_temperature(dataloader_calib_synthetic)

metrics_calib_after_TS = metrics_from_dataloader(model, dataloader_calib_synthetic)
metrics_test_after_TS = metrics_from_dataloader(model, dataloader_test)
confid_calib_after_TS, acc_calib_after_TS, _, bin_boundaries = hist_from_dataloader(model, dataloader_calib_synthetic)
confid_test_after_TS, acc_test_after_TS, _, bin_boundaries = hist_from_dataloader(model, dataloader_test)

In [None]:
fig, axs = plt.subplots(3, 2, figsize=(10, 15))

# Top row: share of samples falling in each confidence bin, per split.
for ax, bin_share, title in zip(axs[0], (in_bin_calib, in_bin_test), ('on calib data', 'on test data')):
    ax.stairs(bin_share, bin_boundaries, fill=True)
    ax.set_title(title)
    ax.set_xlabel('confidence')
    ax.set_ylabel('% of samples')

# Remaining rows: reliability diagrams (mean confidence vs. accuracy per bin) annotated with the metrics.
list_names = ['on calib data before TS', 'on test data before TS', 'on calib data after TS', 'on test data after TS']
list_avg_confid = [confid_calib_before_TS, confid_test_before_TS, confid_calib_after_TS, confid_test_after_TS]
list_acc = [acc_calib_before_TS, acc_test_before_TS, acc_calib_after_TS, acc_test_after_TS]
list_metrics = [metrics_calib_before_TS, metrics_test_before_TS, metrics_calib_after_TS, metrics_test_after_TS]
metric_rows = ((0.8, 'ECE', 'ece'), (0.7, 'SCE', 'sce'), (0.6, 'RMSCE', 'rmsce'), (0.5, 'ACE', 'ace'))
for ax, name, avg_confid, acc, metrics in zip(axs.flatten()[2:], list_names, list_avg_confid, list_acc, list_metrics):
    ax.set_xlabel('confidence')
    ax.set_ylabel('accuracy')
    ax.stairs(avg_confid, bin_boundaries, fill=True, alpha=0.8, label='perfect')
    ax.stairs(acc, bin_boundaries, fill=True, alpha=0.8, label='real')
    ax.legend(loc='lower right')
    if 'after TS' in name:
        ax.text(0, 0.1, f'Temp: {model.temperature.item():.3f}', fontsize='large')
    for text_y, metric_label, metric_key in metric_rows:
        ax.text(0, text_y, f'{metric_label}: {100*metrics[metric_key]:.3f}%', fontsize='large')
    ax.set_title(name)

In [None]:
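# Notes from earlier runs of the synthetic-data temperature scaling above, per number of synthetic
# samples (metrics apparently in %, as displayed in the plots; the split they were read on is not recorded).
# 10000 samples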
# ece 2.658
# sce 0.468
# rmsce 4.701
# ace 0.797
# temp 1.113

# 20000 samples
# almost same (temp 1.114)

# 5000 samples
# a bit worse (temp 1.120)

## Vector scaling from real validation data

In [None]:
# Vector scaling fitted on the real calibration split.
# NOTE: results are stored under the *_TS names of the temperature-scaling cells so the plotting
# code can be reused; in this cell "after TS" means "after vector scaling".
metrics_calib_before_TS = metrics_from_dataloader(classifier, dataloader_calib)
metrics_test_before_TS = metrics_from_dataloader(classifier, dataloader_test)
confid_calib_before_TS, acc_calib_before_TS, in_bin_calib, bin_boundaries = hist_from_dataloader(classifier, dataloader_calib)
confid_test_before_TS, acc_test_before_TS, in_bin_test, bin_boundaries = hist_from_dataloader(classifier, dataloader_test)

# Fit vector scaling on the raw calibration logits (logit_input=True, since the inputs are logits).
vs = VectorScaling(logit_input=True, logit_constant=0.0)
logits_list = []
labels_list = []
with torch.no_grad():
    # Loop names chosen so the builtin `input` is not shadowed in the kernel namespace.
    for batch_inputs, batch_labels in dataloader_calib:
        logits_list.append(classifier(batch_inputs.cuda()))
        labels_list.append(batch_labels)
# `logits` / `labels` (calibration split, on CPU) are also reused by later cells.
logits = torch.cat(logits_list).cpu()
labels = torch.cat(labels_list).cpu()
vs.fit(logits.numpy(), labels.numpy())

metrics_calib_after_TS = metrics_from_dataloader(classifier, dataloader_calib, vs)
metrics_test_after_TS = metrics_from_dataloader(classifier, dataloader_test, vs)
confid_calib_after_TS, acc_calib_after_TS, _, bin_boundaries = hist_from_dataloader(classifier, dataloader_calib, vs)
confid_test_after_TS, acc_test_after_TS, _, bin_boundaries = hist_from_dataloader(classifier, dataloader_test, vs)


fig, axs = plt.subplots(3, 2, figsize=(10, 15))

# Top row: share of samples falling in each confidence bin, per split.
for ax, bin_share, title in zip(axs[0], (in_bin_calib, in_bin_test), ('on calib data', 'on test data')):
    ax.stairs(bin_share, bin_boundaries, fill=True)
    ax.set_title(title)
    ax.set_xlabel('confidence')
    ax.set_ylabel('% of samples')

# Remaining rows: reliability diagrams before/after vector scaling, annotated with the metrics.
# (No temperature annotation: the calibrator here is vector scaling, not ModelWithTemperature.)
list_names = ['on calib data before calib', 'on test data before calib', 'on calib data after calib', 'on test data after calib']
list_avg_confid = [confid_calib_before_TS, confid_test_before_TS, confid_calib_after_TS, confid_test_after_TS]
list_acc = [acc_calib_before_TS, acc_test_before_TS, acc_calib_after_TS, acc_test_after_TS]
list_metrics = [metrics_calib_before_TS, metrics_test_before_TS, metrics_calib_after_TS, metrics_test_after_TS]
metric_rows = ((0.8, 'ECE', 'ece'), (0.7, 'SCE', 'sce'), (0.6, 'RMSCE', 'rmsce'), (0.5, 'ACE', 'ace'))
for ax, name, avg_confid, acc, metrics in zip(axs.flatten()[2:], list_names, list_avg_confid, list_acc, list_metrics):
    ax.set_xlabel('confidence')
    ax.set_ylabel('accuracy')
    ax.stairs(avg_confid, bin_boundaries, fill=True, alpha=0.8, label='perfect')
    ax.stairs(acc, bin_boundaries, fill=True, alpha=0.8, label='real')
    ax.legend(loc='lower right')
    for text_y, metric_label, metric_key in metric_rows:
        ax.text(0, text_y, f'{metric_label}: {100*metrics[metric_key]:.3f}%', fontsize='large')
    ax.set_title(name)

In [None]:
# Display the calibration-split labels gathered in the vector-scaling cell (a CPU tensor).
labels

In [None]:

from sklearn.model_selection import (train_test_split,
                                     StratifiedKFold,
                                     GridSearchCV,
                                     cross_val_score)
from dirichletcal.calib.fulldirichlet import FullDirichletCalibrator

# Dirichlet calibration on the calibration split. The regularisation strength `reg_lambda` is
# chosen by 3-fold stratified cross-validation on log-loss.
# `logits` / `labels` are the calibration-split CPU tensors collected in the vector-scaling cell.
# FullDirichletCalibrator expects class *probabilities* as input, unlike VectorScaling, which was
# told its inputs are logits via logit_input=True. So feed it softmax probabilities, not raw logits.
probs_calib = torch.softmax(logits, dim=1).numpy()
labels_calib = labels.numpy()

skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)

reg = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5]
# Full Dirichlet
calibrator = FullDirichletCalibrator(reg_lambda=reg, reg_mu=None)
# ODIR Dirichlet
#calibrator = FullDirichletCalibrator(reg_lambda=reg, reg_mu=reg)
gscv = GridSearchCV(calibrator, param_grid={'reg_lambda':  reg,
                                            'reg_mu': [None]},
                    cv=skf, scoring='neg_log_loss')
gscv.fit(probs_calib, labels_calib)

print('Grid of parameters cross-validated')
print(gscv.param_grid)
print('Best parameters: {}'.format(gscv.best_params_))