## Environment

In [1]:
!cp -R ../input/calibration/calib calib

In [2]:
### Main imports

# !pip install --upgrade "numpy==1.20.2"
import os
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

### Plots

from matplotlib.cm import get_cmap
import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats
from cycler import cycler

# Global plot style: dotted grey grid drawn behind the data, serif fonts,
# a fixed colour cycle, a default colormap, and tight PDF output for savefig().
plt.rcParams.update({
    'axes.axisbelow': True,
    'axes.grid': True,
    'grid.color': 'grey',
    'grid.linestyle': ':',
    'font.family': 'serif',
    'axes.prop_cycle': cycler(color='bmrcgyk'),
    'image.cmap': 'gist_rainbow',
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
    'savefig.format': 'pdf',
})
# Optional: switch inline figures to PNG with set_matplotlib_formats('png')

# Calibration utils
from calib.eval import *
from calib.calibrators import *
from calib.utils import *

# Data splitting, containers and activation helpers
from sklearn.model_selection import train_test_split
from collections import defaultdict
from scipy.special import softmax
from scipy.special import expit as sigmoid

### Reproducibility

SEED = 0
# Seed torch too (CPU and CUDA generators): the logit-scaling calibrators are
# fit on a torch `device`, so any torch randomness should be reproducible.
torch.manual_seed(SEED)
# Shared kwargs for sklearn-style functions, e.g. train_test_split(**rs).
rs = {'random_state': SEED}
rng = np.random.default_rng(seed=SEED)

## Utils

In [3]:
def get_logits(path):
    """Load every model's saved logits and the shared targets for one dataset.

    ``path`` must contain only ``targets.txt`` (whitespace-separated integer
    targets) and one ``<model_name>.pt`` file per model holding a tensor saved
    with ``torch.save``.

    Returns:
        (logitss, targets): dict mapping model name -> numpy array of logits
        (in sorted file-name order), and a numpy array of integer targets.

    Raises:
        ValueError: if ``path`` contains any other file, or has no ``targets.txt``.
    """
    logitss = {}
    targets = None
    # Sort so models are processed in the same order on every run;
    # os.listdir order is arbitrary and filesystem-dependent.
    for fname in sorted(os.listdir(path)):
        fpath = os.path.join(path, fname)
        if fname == 'targets.txt':
            with open(fpath, 'r') as fin:
                targets = np.array([int(target) for target in fin.read().split()])
        elif fname.endswith('.pt'):
            # map_location='cpu' lets GPU-saved tensors load on CPU-only hosts;
            # detach() avoids the numpy() error on tensors that require grad.
            tensor = torch.load(fpath, map_location='cpu')
            logitss[fname[:-len('.pt')]] = tensor.detach().numpy()
        else:
            raise ValueError(f'Invalid name in logits path: {fname!r}')
    if targets is None:
        raise ValueError(f'targets.txt not found in logits path {path!r}')
    return logitss, targets

def upd_metrics(metrics_dict, new_metrics,
                calib_name, model_dataset_tuple):
    """Merge one calibrator's metrics for one model into a nested metrics store.

    After the call, ``metrics_dict[m][calib_name][model_dataset_tuple]`` equals
    ``new_metrics[m]`` for every metric name ``m`` in ``new_metrics``.
    ``metrics_dict`` must create missing levels on access (e.g. a nested
    ``defaultdict``); otherwise unseen metric names raise KeyError.
    """
    for name in new_metrics:
        by_calibrator = metrics_dict[name]
        by_calibrator[calib_name][model_dataset_tuple] = new_metrics[name]

In [4]:
# Root folder holding one sub-folder of saved logits per dataset.
logit_path = '../input/calibration/logits/'

# Dataset sub-folders of `logit_path`, processed in this order.
dataset_names = [
    'cifar10_v1', 'cifar100', 'imagenet',
    'focal_CIFAR10', 'focal_CIFAR100', 'focal_TinyImageNet',
]

# Calibrators that are fit on and applied to softmax confidences.
# NOTE(review): these instances are re-fit for every dataset/model pair;
# this assumes fit() fully resets their state — confirm in calib.calibrators.
calibrators_confs = dict([
    ('Hist-binning', HistogramBinningMulticlass(n_bins=20)),
    ('Isotonic', IsotonicRegressionMulticlass()),
])

# Logit-scaling family: (name, scale_type, bias_type) for each calibrator.
_logit_scaling_specs = [
    ('T-scaling', 'temperature', 'none'),
    ('V-scaling', 'vector', 'none'),
    ('V-scaling-b', 'vector', 'vector'),
    ('M-scaling-b', 'matrix', 'vector'),
]
calibrators_logits = {
    name: LogitScaling(scale_type=scale, bias_type=bias)
    for name, scale, bias in _logit_scaling_specs
}

In [5]:
# Prefer the GPU for fitting the logit-based calibrators when one is present.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

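## Compute all metrics

For every dataset/model pair, the logits are split 50/50 into a validation half (used to fit each calibrator) and a test half. Metrics and reliability bins are recorded on both halves for the uncalibrated baseline and for every calibrator.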

In [6]:
%%time

N_RELIABILITY_BINS = 10  # bins per reliability diagram
TEST_FRACTION = 0.5      # share of each model's samples held out as the test half
# Key for the uncalibrated baseline (Russian for "Before calibration"); kept
# verbatim because it is stored in the saved results.
UNCALIBRATED = 'До калибровки'

# metrics_*[metric_name][calibrator_name][(dataset, model)] -> value
# rels_*[(dataset, model)][calibrator_name] -> reliability bins
metrics_val = defaultdict(lambda: defaultdict(dict))
metrics_test = defaultdict(lambda: defaultdict(dict))
rels_val = defaultdict(dict)
rels_test = defaultdict(dict)


def _record_results(cal_name, model_dataset_tuple,
                    targets_val, confs_val, targets_test, confs_test):
    """Store metrics and reliability bins of one calibrator on both halves.

    Writes into the notebook-level ``metrics_val``/``metrics_test`` and
    ``rels_val``/``rels_test`` containers under ``cal_name`` and
    ``model_dataset_tuple``.
    """
    splits = (
        (metrics_val, rels_val, targets_val, confs_val),
        (metrics_test, rels_test, targets_test, confs_test),
    )
    for metrics, rels, split_targets, split_confs in splits:
        upd_metrics(metrics, all_metrics(split_targets, split_confs),
                    cal_name, model_dataset_tuple)
        rels[model_dataset_tuple][cal_name] = bins_reliability_multiclass(
            split_targets, split_confs, N_RELIABILITY_BINS)


for dataset_name in dataset_names:
    print(f'== dataset {dataset_name} ==')
    fpath = os.path.join(logit_path, dataset_name)
    logitss, targets = get_logits(fpath)
    for model_name, logits in logitss.items():
        print(f'{model_name}, ', end='')
        model_dataset_tuple = (dataset_name, model_name)
        # Calibrators are fit on the "val" half and evaluated on both halves.
        logits_val, logits_test, targets_val, targets_test = train_test_split(
            logits, targets, test_size=TEST_FRACTION, **rs)
        confs_val = softmax(logits_val, axis=1)
        confs_test = softmax(logits_test, axis=1)

        # Baseline: raw softmax confidences, no calibration.
        _record_results(UNCALIBRATED, model_dataset_tuple,
                        targets_val, confs_val, targets_test, confs_test)

        # Calibrators that take confidences and return recalibrated confidences.
        for cal_name, calibrator in calibrators_confs.items():
            calibrator.fit(confs_val, targets_val)
            _record_results(cal_name, model_dataset_tuple,
                            targets_val, calibrator.transform(confs_val),
                            targets_test, calibrator.transform(confs_test))

        # Calibrators that are fit on logits (on `device`) and produce confidences.
        for cal_name, calibrator in calibrators_logits.items():
            calibrator.fit(logits_val, targets_val, device=device)
            _record_results(cal_name, model_dataset_tuple,
                            targets_val, calibrator.transform(logits_val),
                            targets_test, calibrator.transform(logits_test))
    print('DONE')

== dataset cifar10_v1 ==
densenet121, googlenet, mobilenet_v2, densenet169, resnet34, vgg13_bn, vgg16_bn, resnet50, vgg19_bn, inception_v3, vgg11_bn, resnet18, densenet161, DONE
== dataset cifar100 ==
shufflenetv2_x1_0, shufflenetv2_x2_0, resnet56, resnet32, vgg13_bn, mobilenetv2_x1_0, mobilenetv2_x0_5, vgg16_bn, resnet44, vgg19_bn, vgg11_bn, shufflenetv2_x0_5, resnet20, shufflenetv2_x1_5, mobilenetv2_x1_4, DONE
== dataset imagenet ==
tf_efficientnet_b8, mobilenetv2_120d, repvgg_b3, vgg19_bn, DONE
== dataset focal_CIFAR10 ==
wide_resnet_focal_loss_gamma_3.0, resnet110_cross_entropy, densenet121_focal_loss_gamma_2.0, densenet121_cross_entropy_smoothed_smoothing_0.05, densenet121_cross_entropy, resnet50_focal_loss_gamma_1.0, resnet110_focal_loss_gamma_1.0, resnet50_cross_entropy, resnet50_focal_loss_gamma_3.0, densenet121_focal_loss_gamma_3.0, wide_resnet_focal_loss_gamma_1.0, wide_resnet_cross_entropy_smoothed_smoothing_0.05, resnet110_cross_entropy_smoothed_smoothing_0.05, densenet121_

## Save checkpoint

In [7]:
import pickle

# Persist the four result containers, in a fixed order. metrics_* use a lambda
# default_factory, which pickle cannot serialise, so the outer level of each
# container is converted to a plain dict before dumping.
_checkpoint_files = [
    ('metrics_test.pickle', metrics_test),
    ('metrics_val.pickle', metrics_val),
    ('rels_test.pickle', rels_test),
    ('rels_val.pickle', rels_val),
]
for _fname, _container in _checkpoint_files:
    with open(_fname, 'wb') as fout:
        pickle.dump(dict(_container), fout)

In [8]:
!zip exp_results.zip *.pickle
!rm *.pickle

  adding: metrics_test.pickle (deflated 54%)
  adding: metrics_val.pickle (deflated 54%)
  adding: rels_test.pickle (deflated 51%)
  adding: rels_val.pickle (deflated 52%)
