In [None]:
## Environment

### Main imports

# !pip install --upgrade "numpy==1.20.2"
import os
import numpy as np
import pandas as pd
import torch

### Plots

from matplotlib.cm import get_cmap
import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats
from cycler import cycler

plt.rcParams.update({
    # dotted grey grid, drawn underneath the plotted data
    'axes.axisbelow': True,
    'axes.grid': True,
    'grid.color': 'grey',
    'grid.linestyle': ':',
    'font.family': 'serif',
    # fixed line-color order: b, m, r, c, g, y, k
    'axes.prop_cycle': cycler(color='bmrcgyk'),
    'image.cmap': 'gist_rainbow',
    # savefig defaults: tightly cropped PDF
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.1,
    'savefig.format': 'pdf',
})
# to render inline figures as PNG instead: set_matplotlib_formats('png')

### Also

rs = {'random_state': 0}
rng = np.random.default_rng(seed=0)

# Metrics, visualization

In [None]:
def bins_reliability_binary(y_true, y_confs, n_bins=10):
    '''
    Group binary predictions into equal-width confidence bins on [0, 1].

    Args:
        y_true: np.array (n,) of 0 and 1 (or bool), real classes
        y_confs: np.array (n,), predicted probabilities of positive class;
                 values below 0 / above 1 are put into the first / last bin
        n_bins: int, number of bins

    Returns:
        bin_confs: np.array (n_bins,), mean confidence for each bin
                   (0 for empty bins)
        bin_accs: np.array (n_bins,), frequency of positives for each bin
                  (0 for empty bins)
        weights: np.array (n_bins,), fraction of samples in each bin
    '''
    bins = np.linspace(0, 1, n_bins + 1)
    # [0, 0.1), [0.1, 0.2), ..., [0.9, 1.0 + eps): widen the last edge so
    # that a confidence of exactly 1.0 lands in the last bin
    bins[-1] = 1 + 1e-10
    # find which bin each sample is assigned to
    bin_inds = np.digitize(y_confs, bins, right=False) - 1
    # out-of-range confidences would give index -1 (bincount raises) or
    # n_bins (every output array would grow by one element)
    bin_inds = np.clip(bin_inds, 0, n_bins - 1)
    # count number of samples in each bin
    total = np.bincount(bin_inds, minlength=n_bins)
    # mean confidence per bin; empty bins keep their bincount value 0
    bin_confs = np.bincount(bin_inds, y_confs, minlength=n_bins)
    np.divide(bin_confs, total, out=bin_confs, where=total != 0)
    # frequency of positives per bin; empty bins stay 0
    bin_accs = np.bincount(bin_inds, y_true, minlength=n_bins)
    np.divide(bin_accs, total, out=bin_accs, where=total != 0)
    weights = total / total.sum()
    return bin_confs, bin_accs, weights

def bins_reliability_multiclass(true_classes, confs, n_bins=10):
    '''
    Top-label binning for multiclass predictions: every sample contributes its
    highest predicted probability, and it counts as a "hit" when the arg-max
    class equals the true class.

    Args:
        true_classes: np.array (n,) of integers in range(0, n_classes)
        confs: np.array (n, n_classes) of predicted probabilities
        n_bins: int, number of bins

    Returns:
        bin_confs: np.array (n_bins,), mean top-label confidence per bin
        bin_accs: np.array (n_bins,), accuracy per bin
        weights: np.array (n_bins,), fraction of samples per bin
    '''
    top_conf = np.max(confs, axis=1)
    hit = np.argmax(confs, axis=1) == true_classes
    return bins_reliability_binary(hit, top_conf, n_bins=n_bins)

In [None]:
from scipy.special import softmax
from sklearn.metrics import log_loss


def ECE(bin_confs, bin_accs, weights):
    '''
    Expected Calibration Error: weighted mean over bins of
    |mean confidence - accuracy|.

    Args (the three outputs of bins_reliability_*):
        bin_confs: np.array (n_bins,), mean confidence for each bin
        bin_accs: np.array (n_bins,), accuracy for each bin
        weights: np.array (n_bins,), weight (sample fraction) of each bin

    Returns:
        ece: expected calibration error, in the units of the inputs
    '''
    return np.average(np.abs(bin_confs - bin_accs), weights=weights)

def MCE(bin_confs, bin_accs, weights=None):
    '''
    Maximum Calibration Error: largest |mean confidence - accuracy| over bins.

    Args (the outputs of bins_reliability_*):
        bin_confs: np.array (n_bins,), mean confidence for each bin
        bin_accs: np.array (n_bins,), accuracy for each bin
        weights: ignored; accepted so the call MCE(*rel) works

    Returns:
        mce: maximum calibration error, in the units of the inputs
    '''
    return np.max(np.abs(bin_confs - bin_accs))

def BS(true_classes, confs):
    '''
    Multiclass Brier score: mean over samples of the squared Euclidean
    distance between the predicted probability vector and the one-hot label.

    Args:
        true_classes: np.array (n,) of integers in range(0, n_classes)
        confs: np.array (n, n_classes) of predicted probabilities

    Returns:
        Brier score, computed in the dtype of confs
    '''
    probs = np.asarray(confs)
    # row i of the identity matrix is the one-hot vector of class i
    target = np.eye(probs.shape[1], dtype=probs.dtype)[true_classes]
    per_sample = np.square(target - probs).sum(axis=1)
    return per_sample.mean()

def all_metrics(true_classes, confs=None, n_bins=15, mul=100,
                return_rel=False, logits=None):
    '''
    Accuracy and calibration metrics for multiclass predictions.

    Args:
        true_classes: np.array (n,) of integers in range(0, n_classes)
        confs: np.array (n, n_classes) of predicted probabilities; if None,
               they are computed as softmax(logits)
        n_bins: int, number of bins (for computing binning metrics)
        mul: float, multiplier for metrics with values in range [0, 1],
             default is 100
        return_rel: bool, whether to return (bin_confs, bin_accs, weights) -
                    returns of bins_reliability_multiclass, default is False
        logits: np.array (n, n_classes) of raw scores, used only when
                confs is None

    Returns:
        dictionary with metrics:
            'ACC': accuracy times mul,
            'ECE': expected calibration error times mul,
            'MCE': maximum calibration error times mul,
            'BS': Brier score (not multiplied),
            'NLL': negative log likelihood (not multiplied)
        rel: tuple (check Args), only when return_rel is True

    Raises:
        ValueError: if both confs and logits are None
    '''
    if confs is None:
        # The original referred to an undefined name `logits` here, which in
        # a notebook silently picked up whatever global `logits` existed.
        if logits is None:
            raise ValueError('all_metrics needs either confs or logits')
        confs = softmax(logits, axis=1)
    metrics = {}
    metrics['ACC'] = (confs.argmax(axis=1) == true_classes).mean() * mul
    rel = bins_reliability_multiclass(true_classes, confs, n_bins)
    metrics['ECE'] = ECE(*rel) * mul
    metrics['MCE'] = MCE(*rel) * mul
    metrics['BS'] = BS(true_classes, confs)
    # Explicit labels: log_loss otherwise infers the classes from
    # true_classes and fails when a split happens to miss a class.
    metrics['NLL'] = log_loss(true_classes, confs,
                              labels=np.arange(confs.shape[1]))
    if return_rel:
        return metrics, rel
    return metrics

In [None]:
def _hist_plot_adjust(pad=0.00):
    '''
    Format the current axes as a square [0, 1] x [0, 1] panel (optionally
    padded by `pad`) with fixed ticks and hidden tick marks.
    '''
    ax = plt.gca()
    ax.set_xlim(-pad, 1 + pad)
    ax.set_ylim(-pad, 1 + pad)
    ax.set_xticks(np.linspace(0, 1, 6))
    # y ticks start at 0.2: no "0" label on the y axis
    ax.set_yticks(np.linspace(0.2, 1.0, 5))
    ax.set_aspect('equal')
    ax.tick_params(length=0)

def _reliability_plot(bin_confs, bin_accs, weights=None,
                      name='reliability plot', acc_label='accuracy',
                      show=False, path=None):
    '''
    Draw a reliability diagram from precomputed bins on the current axes.

    Assumes n equal-width bins on [0, 1]. For each bin it overlays a red
    translucent bar of mean confidence and a blue bar of bin_accs; when
    weights are given, a narrower green bar shows the bin's weight. The
    dashed diagonal marks perfect calibration.

    Args:
        bin_confs: np.array (n,), mean confidence for each bin
        bin_accs: np.array (n,), accuracy for each bin or class frequencies
        weights: np.array (n,), optional
        name: str, plot title
        acc_label: str, meaning of bin_accs (legend label)
        show: bool, if True, plt.show() is called after saving
        path: str, location to save figure, default is None (not saved)
    '''
    n_bins = len(bin_confs)
    edges = np.linspace(0, 1, n_bins + 1)
    centers = (edges[:-1] + edges[1:]) / 2
    plt.bar(centers, bin_confs, color=(1, 0, 0, 0.5), edgecolor='black',
            label='confidence', width=1 / n_bins)
    plt.bar(centers, bin_accs, color=(0, 0, 1, 0.5), edgecolor='black',
            label=acc_label, width=1 / n_bins)
    if weights is not None:
        plt.bar(centers, weights, color=(0, 1.0, 0.5, 0.8), edgecolor='black',
                label='weight', width=0.5 / n_bins)
    plt.plot([0, 1], [0, 1], color='silver', linestyle='--')
    plt.xlabel('confidence')
    # no y label: bin_accs may be an accuracy or a class frequency
    plt.legend()
    plt.title(name)
    _hist_plot_adjust()
    # Save BEFORE showing: with the inline backend plt.show() closes the
    # current figure, so a savefig issued afterwards writes a blank image.
    if path is not None:
        plt.savefig(path)
    if show:
        plt.show()

def reliability_plot(confs, true_classes, n_bins=10, **kwargs):
    '''
    Top-label reliability diagram for multiclass predictions.

    Args:
        confs: np.array (n, n_classes) of predicted probabilities
        true_classes: np.array (n,) of integers in range(0, n_classes)
        n_bins: int, number of equal-width confidence bins
        kwargs: forwarded to _reliability_plot (name, acc_label, show, path)
    '''
    bin_confs, bin_accs, weights = bins_reliability_multiclass(
        true_classes, confs, n_bins)
    _reliability_plot(bin_confs, bin_accs, weights, **kwargs)

# Методы калибровки

## Transform confidences
Input: uncalibrated confidences (predicted class probabilities);
output: calibrated confidences.

In [None]:
from functools import partial
from sklearn.isotonic import IsotonicRegression

class HistogramBinningBinary:
    '''
    Histogram-binning calibrator for a binary problem.

    The positive-class confidence range [0, 1] is split into n_bins
    equal-width bins; each bin is mapped to the frequency of positives
    observed in it during fit. Bins that were empty during fit map to their
    center.
    '''
    def __init__(self, n_bins=10):
        self.n_bins = n_bins
        self.thetas = None  # calibrated value per bin, set by fit
        self.bins = np.linspace(0, 1, n_bins + 1)
        # widen the right edge so a confidence of exactly 1.0 is in the last bin
        self.bins[-1] += 1e-10

    def _bin_indices(self, y_confs):
        '''Bin index of every confidence, clipped into [0, n_bins - 1].'''
        inds = np.digitize(y_confs, self.bins) - 1
        return np.clip(inds, 0, self.n_bins - 1)

    def fit(self, y_confs, y_true):
        '''
        Args:
            y_confs: np.array estimated probabilities of positive class
            y_true: np.array of 0 and 1
        Returns:
            self
        '''
        # out-of-range scores are put into the edge bins
        _, thetas, weights = bins_reliability_binary(
            y_true, np.clip(y_confs, 0, 1), n_bins=self.n_bins)
        centers = (self.bins[:-1] + self.bins[1:]) * 0.5
        empty = weights == 0
        thetas[empty] = centers[empty]
        self.thetas = thetas
        return self

    def transform(self, y_confs):
        '''
        Args:
            y_confs: uncalibrated estimated probability of positive class
        Returns:
            y_confs_calib: calibrated probabilities
        Raises:
            RuntimeError: if called before fit
        '''
        if self.thetas is None:
            raise RuntimeError('HistogramBinningBinary.transform called before fit')
        # clipping matters: an unclipped index of -1 would silently wrap to
        # the last bin's value for negative confidences
        return self.thetas[self._bin_indices(y_confs)]

class IRBinary(IsotonicRegression):
    '''
    Binary calibrator based on sklearn's IsotonicRegression.

    Fits a non-decreasing map from the positive-class confidence to the
    positive frequency, with outputs bounded to [0, 1]; inputs outside the
    range seen during fit are clipped (out_of_bounds='clip').
    '''
    def __init__(self):
        super().__init__(y_min=0.0, y_max=1.0,
                         increasing=True, out_of_bounds='clip')

class CalibratorOvR:
    '''
    One-vs-rest multiclass calibrator: fits one binary calibrator per class
    on that class's confidence column, then renormalizes the calibrated
    columns so every row sums to 1.
    '''
    def __init__(self, base, **kwargs):
        '''
        Args:
            base: class of binary calibrator (needs fit(x, y) and transform(x))
            kwargs: keyword arguments to initialize each calibrator
        '''
        self.base = partial(base, **kwargs)
        self.ovr_calibrators = []

    def fit(self, confs, true_classes):
        '''
        Args:
            confs: np.array (n, n_classes) of predicted probabilities (uncalibrated)
            true_classes: np.array (n,) of integers in range(0, n_classes)
        Returns:
            self
        '''
        true_classes = np.asarray(true_classes)
        self.ovr_calibrators = []
        for class_ in range(confs.shape[1]):
            calibrator = self.base()
            # binary target: 1 where this class is the true one
            calibrator.fit(confs[:, class_], (true_classes == class_).astype(int))
            self.ovr_calibrators.append(calibrator)
        return self

    def transform(self, confs):
        '''
        Args:
            confs: np.array (n, n_classes) of predicted probabilities
        Returns:
            calibrated_confs: np.array (n, n_classes) of calibrated
            probabilities; rows where every per-class calibrator returned 0
            are replaced by the uniform distribution
        '''
        cal_confs = np.stack([calibrator.transform(confs[:, class_]) for
            class_, calibrator in enumerate(self.ovr_calibrators)], axis=1)
        row_sums = cal_confs.sum(axis=1, keepdims=True)
        # Histogram binning and isotonic regression can output exactly 0, so
        # a row may sum to 0; dividing it would give NaN (0/0).
        zero_rows = row_sums[:, 0] == 0
        cal_confs[zero_rows] = 1.0
        row_sums[zero_rows] = cal_confs.shape[1]
        cal_confs /= row_sums
        return cal_confs

class HistogramBinningMulticlass(CalibratorOvR):
    '''One-vs-rest histogram binning: one HistogramBinningBinary per class.'''
    def __init__(self, n_bins=15):
        super().__init__(base=HistogramBinningBinary, n_bins=n_bins)

class IsotonicRegressionMulticlass(CalibratorOvR):
    '''One-vs-rest isotonic regression: one IRBinary per class.'''
    def __init__(self):
        super().__init__(base=IRBinary)

## Transform logits
Input: logits;
output: calibrated probabilities (softmax of the rescaled and shifted logits).

In [None]:
from torch.nn.functional import cross_entropy

def tt(np_array):
    '''Wrap a numpy array as a torch tensor that shares its memory (no copy).'''
    tensor = torch.from_numpy(np_array)
    return tensor

class LogitScaling:
    '''
    Parametric logit calibration fitted by minimizing cross-entropy on a
    validation set with LBFGS. Calibrated probabilities are
    softmax(logits * scale + bias), or softmax(logits @ scale + bias) for
    matrix scaling.
    '''
    _SCALE_TYPES = ('temperature', 'vector', 'matrix')

    def __init__(self, scale_type='temperature', bias_type='none'):
        '''
        scale_type: str:
            'temperature': single-parameter logits' scaling
            'vector': vector scaling (logits multiplied by diagonal matrix)
            'matrix': matrix scaling
        bias_type: str
            'intercept': one bias (number) for all classes
            'vector': individual biases for each class
            'none' (or any other value): no bias term
        '''
        self.scale_type = scale_type
        self.bias_type = bias_type
        self.scale = None
        self.bias = None

    def _apply(self, logits, scale, bias):
        '''Affine logit map shared by fit (torch tensors) and transform (numpy).'''
        if self.scale_type == 'matrix':
            return logits @ scale + bias
        return logits * scale + bias

    def fit(self, logits_val, targets_val, device='cpu', lr=1.0, max_iter=1000):
        '''
        Args:
            logits_val: np.array (n, n_classes) of logits
            targets_val: np.array (n,) of integer class labels
            device: torch device used for optimization
            lr, max_iter: passed to torch.optim.LBFGS
        Returns:
            self (fitted scale / bias are stored as numpy arrays)
        Raises:
            ValueError: if scale_type is not one of _SCALE_TYPES
        '''
        if self.scale_type not in self._SCALE_TYPES:
            raise ValueError(f'unknown scale_type: {self.scale_type!r}')
        logits_cal = torch.as_tensor(np.asarray(logits_val), device=device)
        # cross_entropy requires int64 class indices
        targets_cal = torch.as_tensor(np.asarray(targets_val), dtype=torch.long,
                                      device=device)
        n_classes = logits_cal.shape[1]
        kwargs = {'dtype': logits_cal.dtype, 'requires_grad': True, 'device': device}
        if self.scale_type == 'temperature':
            scale = torch.tensor(1.0, **kwargs)
        elif self.scale_type == 'vector':
            scale = torch.ones(n_classes, **kwargs)
        else:
            # 'matrix' (the original compared the undefined name `scale_type`)
            scale = torch.eye(n_classes, **kwargs)
        params = [scale]
        if self.bias_type == 'intercept':
            bias = torch.tensor(0.0, **kwargs)
            params.append(bias)
        elif self.bias_type == 'vector':
            bias = torch.zeros(n_classes, **kwargs)
            params.append(bias)
        else:
            # constant zero, not optimized
            bias = torch.tensor(0, requires_grad=False, dtype=logits_cal.dtype,
                                device=device)

        optimizer = torch.optim.LBFGS(params, lr=lr, max_iter=max_iter)

        def closure():
            optimizer.zero_grad()
            loss = cross_entropy(self._apply(logits_cal, scale, bias), targets_cal)
            loss.backward()
            return loss

        optimizer.step(closure)
        # LBFGS keeps its iteration counter in the state of the first parameter
        self.n_iter = optimizer.state[scale]['n_iter']
        self.scale = scale.detach().cpu().numpy()
        self.bias = bias.detach().cpu().numpy()
        return self

    def transform(self, logits_test):
        '''Return calibrated probabilities softmax(affine(logits_test)).'''
        return softmax(self._apply(logits_test, self.scale, self.bias), axis=1)

# Compute results

In [None]:
def get_logits(path):
    '''
    Load every model's logits and the shared targets from one dataset folder.

    `targets.txt` is read as whitespace-separated integer labels; each
    `<model_name>.pt` file is loaded with torch.load and converted to numpy.
    Other files are ignored.

    Args:
        path: str, directory to scan
    Returns:
        logitss: dict model_name -> np.array, in sorted file-name order
        targets: np.array of int labels, or None if targets.txt is absent
    '''
    logitss = {}
    targets = None
    # sorted(): os.listdir order is arbitrary, which made the model order
    # (and everything printed / plotted per model) non-reproducible
    for fname in sorted(os.listdir(path)):
        fpath = os.path.join(path, fname)
        if fname == 'targets.txt':
            with open(fpath, 'r') as fin:
                targets = np.array([int(target) for target in fin.read().split()])
        elif fname.endswith('.pt'):
            # NOTE: torch.load unpickles data; only use on trusted files.
            # map_location='cpu' lets GPU-saved tensors load on CPU-only hosts.
            tensor = torch.load(fpath, map_location='cpu')
            logitss[fname[:-3]] = tensor.detach().numpy()
    return logitss, targets

def upd_metrics(metrics_dict, new_metrics,
                calib_name, model_dataset_tuple):
    '''
    Store one evaluation's metrics in a nested results table:
    metrics_dict[metric][calib_name][model_dataset_tuple] = value for every
    (metric, value) in new_metrics. metrics_dict must create missing inner
    levels itself, e.g. defaultdict(lambda: defaultdict(dict)).
    '''
    for metric_name in new_metrics:
        by_calibrator = metrics_dict[metric_name]
        by_calibrator[calib_name][model_dataset_tuple] = new_metrics[metric_name]

In [None]:
from sklearn.model_selection import train_test_split
from collections import defaultdict

# Root folder with one sub-folder of saved logits per dataset
logit_path = '../input/calibration/logits/'
dataset_names = ['cifar10_v1', 'cifar100', 'imagenet']

# Calibrators that map predicted probabilities (softmax outputs)
calibrators_confs = dict([
    ('Hist-binning', HistogramBinningMulticlass(n_bins=15)),
    ('Isotonic', IsotonicRegressionMulticlass()),
])

# Calibrators that map raw logits
calibrators_logits = dict([
    ('T-scaling', LogitScaling(scale_type='temperature', bias_type='none')),
    ('V-scaling', LogitScaling(scale_type='vector', bias_type='none')),
    ('V-scaling + bias', LogitScaling(scale_type='vector', bias_type='vector')),
])

In [None]:
metrics_val = defaultdict(lambda: defaultdict(dict))
metrics_test = defaultdict(lambda: defaultdict(dict))

for dataset_name in dataset_names:
    print(f'== dataset {dataset_name} ==')
    fpath = os.path.join(logit_path, dataset_name)
    logitss, targets = get_logits(fpath)
    # results for the first CIFAR-10 variant are reported as plain 'cifar10'
    if dataset_name == 'cifar10_v1':
        dataset_name = 'cifar10'
    for model_name, logits in logitss.items():
        print(f'{model_name}, ', end='')
        model_dataset_tuple = (dataset_name, model_name)
        # stratified 50/50 split: first half fits calibrators, second evaluates
        logits_val, logits_test, targets_val, targets_test = train_test_split(
            logits, targets, test_size=0.5, stratify=targets, **rs)
        confs_val = softmax(logits_val, axis=1)
        confs_test = softmax(logits_test, axis=1)

        # (row name, calibrator or None, validation input, test input):
        # the uncalibrated baseline, then probability-space calibrators,
        # then logit-space calibrators
        stages = [('До калибровки', None, confs_val, confs_test)]
        stages += [(name, cal, confs_val, confs_test)
                   for name, cal in calibrators_confs.items()]
        stages += [(name, cal, logits_val, logits_test)
                   for name, cal in calibrators_logits.items()]

        for cal_name, calibrator, inp_val, inp_test in stages:
            if calibrator is None:
                confs_val_cal, confs_test_cal = inp_val, inp_test
            else:
                calibrator.fit(inp_val, targets_val)
                confs_val_cal = calibrator.transform(inp_val)
                confs_test_cal = calibrator.transform(inp_test)
            upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                        cal_name, model_dataset_tuple)
            upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                        cal_name, model_dataset_tuple)

    print('DONE')

In [None]:
def format_tex(data, mode='min', format_string='%.2f'):
    '''
    Format a pandas Series of numbers as strings and wrap the best value(s)
    in BOLDLEFT ... BOLDRIGHT placeholders, which are later replaced by a
    bold LaTeX command.

    Args:
        data: pd.Series of numbers (one table row)
        mode: 'min' or 'max' -- which value counts as best; any other value
              disables highlighting
        format_string: printf-style format applied to every value
    Returns:
        pd.Series of str with the same index as data
    '''
    plain = data.apply(lambda value: format_string % value)
    highlighted = data.apply(
        lambda value: ('BOLDLEFT' + format_string + 'BOLDRIGHT') % value)
    if mode == 'min':
        is_best = data == data.min()
    elif mode == 'max':
        is_best = data == data.max()
    else:
        is_best = np.zeros_like(data, dtype=bool)
    return plain.mask(is_best, highlighted)

In [None]:
# Per-metric table formatting: which extreme is best, and number format
format_kwargs = {
    metric: {'mode': mode, 'format_string': fmt}
    for metric, mode, fmt in [
        ('ACC', 'max', '%.3f'),
        ('ECE', 'min', '%.2f'),
        ('MCE', 'min', '%.2f'),
        ('BS', 'min', '%.3f'),
        ('NLL', 'min', '%.3f'),
    ]
}

# Table captions (Russian): metric description + a shared trailing sentence
app = 'Значения метрики приводятся для тестовой выборки до и после калибровки.'
_caption_heads = {
    'ACC': r'Accuracy, \% -- доля правильных ответов (больше -- лучше). ',
    'ECE': r'ECE, \% -- Expected Calibration Error, 15 бинов (меньше -- лучше). ',
    'MCE': r'MCE, \% -- Maximum Calibration Error, 15 бинов (меньше -- лучше). ',
    'BS': r'Brier Score (меньше -- лучше). ',
    'NLL': r'Negative Log-Likelihood (меньше -- лучше). ',
}
captions = {metric: head + app for metric, head in _caption_heads.items()}

In [None]:
# -p: no error if the folder already exists, so the cell can be re-run
!mkdir -p tabs

In [None]:
# make sure the output folder exists even if the !mkdir cell was skipped
os.makedirs('tabs', exist_ok=True)
for metric_name in metrics_test:
    # rows: (dataset, model) tuples, columns: calibrator names
    df = pd.DataFrame(metrics_test[metric_name])
    # format each row and mark its best calibrator for bolding
    df = df.sort_index().apply(format_tex, **format_kwargs[metric_name], axis=1)
    df = df.reset_index().rename({'level_0': 'Данные', 'level_1': 'Модель'}, axis=1)
    df_latex = df.to_latex(label=f'tab:metrics:{metric_name}', index=False,
                           position='h!', caption=captions[metric_name])
    # bold extreme values
    df_latex = df_latex.replace('BOLDLEFT', r'\textbf{').replace('BOLDRIGHT', r'}')
    # Adjust table to text width
    df_latex = df_latex.replace(r'\begin{tabular}', r'\resizebox{\textwidth}{!}{\begin{tabular}')
    df_latex = df_latex.replace(r'\end{tabular}', r'\end{tabular}}')
    # Left-align the two key columns and center the metric columns. The spec
    # is derived from the table width instead of hard-coding 8 columns, so
    # adding or removing a calibrator keeps the values centered.
    n_cols = df.shape[1]
    df_latex = df_latex.replace(r'\begin{tabular}{' + 'l' * n_cols + '}',
                                r'\begin{tabular}{ll' + 'c' * (n_cols - 2) + '}')
    # explicit UTF-8: headers and captions contain Cyrillic text
    with open(f'tabs/metrics_{metric_name}.tex', 'w', encoding='utf-8') as fout:
        fout.write(df_latex)

In [None]:
# Bundle the generated LaTeX tables in tabs/ into tabs.zip
!zip -r tabs.zip tabs

# Визуализации

In [None]:
# -p: no error if the folder already exists, so the cell can be re-run
!mkdir -p vis

## Модельные

In [None]:
from sklearn.svm import LinearSVC
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic binary task. The LinearSVC margin, min-max rescaled into [0, 1],
# serves as an (uncalibrated) pseudo-probability of class 1.
X, y = make_classification(n_samples=20000, n_features=19,
                           n_informative=3, n_redundant=10,
                           random_state=0)

# half of the data for test; the other half is split 70/30 into train / val
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.3, random_state=0)

model = LinearSVC(random_state=0, max_iter=10000)
_ = model.fit(X_train, y_train)

# rescaling constants come from the validation margins only
val_margin = model.decision_function(X_val)
y_min, y_max = val_margin.min(), val_margin.max()
y_val_confs = (val_margin - y_min) / (y_max - y_min)
# test margins can fall outside the validation range, hence the clip
y_test_confs = np.clip((model.decision_function(X_test) - y_min) / (y_max - y_min), 0, 1)

In [None]:
def binconf2mc(y_confs):
    '''Turn positive-class probabilities (n,) into a two-column (n, 2) array [P(0), P(1)].'''
    complement = 1 - y_confs
    return np.stack((complement, y_confs), axis=1)

# 10 equal-width bins of the test-set positive-class confidence
bin_confs, bin_accs, weights = bins_reliability_binary(y_true=y_test, y_confs=y_test_confs)

In [None]:
plt.figure(figsize=(11, 5))

# (a) Calibration curve. Empty bins come back from bins_reliability_binary
# with confidence = frequency = 0, which would add a spurious point at the
# origin; keep only the bins that actually contain samples.
nonempty = weights > 0
plt.subplot(131)
plt.plot(bin_confs[nonempty], bin_accs[nonempty], marker='o',
         label='calibration curve', color='m')
plt.plot([0, 1], [0, 1], color='silver', linestyle='--',
         label='perfect calibration')
_hist_plot_adjust(pad=0.0)
plt.xlabel('confidence')
plt.ylabel('positive frequency')
plt.title('Reliability plot (binary)')
plt.legend()
plt.text(0.5, -0.2, '(a)', ha='center')

# (b) the same bins drawn as a reliability diagram
plt.subplot(132)
_reliability_plot(bin_confs, bin_accs, weights, name='Reliability diagram (binary)',
                  acc_label='positive frequency')
plt.text(0.5, -0.2, '(b)', ha='center')

# (c) top-label diagram of the same predictions in two-column form
plt.subplot(133)
reliability_plot(binconf2mc(y_test_confs), y_test, name='Reliability diagram (multiclass)')
plt.text(0.5, -0.2, '(c)', ha='center')

plt.tight_layout()
os.makedirs('vis', exist_ok=True)
plt.savefig('vis/rel_intro')

## Нейросети

### <font color=blue>WRAP IT TO THE GENERAL FUNCTION</font>

In [None]:
metrics_val = defaultdict(lambda: defaultdict(dict))
metrics_test = defaultdict(lambda: defaultdict(dict))

plt.figure(figsize=(11, 8))
dataset_name = 'cifar100'  # !
fpath = os.path.join(logit_path, dataset_name)
logitss, targets = get_logits(fpath)
model_name = 'shufflenetv2_x0_5'  # !
logits = logitss[model_name]

model_dataset_tuple = (dataset_name, model_name)
logits_val, logits_test, targets_val, targets_test = train_test_split(
logits, targets, test_size=0.5, stratify=targets, **rs)
confs_val = softmax(logits_val, axis=1)
confs_test = softmax(logits_test, axis=1)

# No calibration
upd_metrics(metrics_val, all_metrics(targets_val, confs_val),
            'До калибровки', model_dataset_tuple)
upd_metrics(metrics_test, all_metrics(targets_test, confs_test),
            'До калибровки', model_dataset_tuple)

i = 1
plt.subplot(2, 3, i)
reliability_plot(confs_test, targets_test, name='До калибровки')

# Transforming confs
for cal_name, calibrator in calibrators_confs.items():
    calibrator.fit(confs_val, targets_val)
    confs_val_cal = calibrator.transform(confs_val)
    confs_test_cal = calibrator.transform(confs_test)
    upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                cal_name, model_dataset_tuple)
    upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                cal_name, model_dataset_tuple)
    i += 1
    plt.subplot(2, 3, i)
    reliability_plot(confs_test_cal, targets_test, name=cal_name)

# Transforming logits
for cal_name, calibrator in calibrators_logits.items():
    calibrator.fit(logits_val, targets_val)
    confs_val_cal = calibrator.transform(logits_val)
    confs_test_cal = calibrator.transform(logits_test)
    upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                cal_name, model_dataset_tuple)
    upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                cal_name, model_dataset_tuple)
    i += 1
    plt.subplot(2, 3, i)
    reliability_plot(confs_test_cal, targets_test, name=cal_name)
    
plt.tight_layout()
plt.savefig(f'calibration_{dataset_name}_{model_name}')

In [None]:
metrics_val = defaultdict(lambda: defaultdict(dict))
metrics_test = defaultdict(lambda: defaultdict(dict))

plt.figure(figsize=(11, 8))
dataset_name = 'imagenet'  # !
fpath = os.path.join(logit_path, dataset_name)
logitss, targets = get_logits(fpath)
model_name = 'tf_efficientnet_b8'  # !
logits = logitss[model_name]

model_dataset_tuple = (dataset_name, model_name)
logits_val, logits_test, targets_val, targets_test = train_test_split(
logits, targets, test_size=0.5, stratify=targets, **rs)
confs_val = softmax(logits_val, axis=1)
confs_test = softmax(logits_test, axis=1)

# No calibration
upd_metrics(metrics_val, all_metrics(targets_val, confs_val),
            'До калибровки', model_dataset_tuple)
upd_metrics(metrics_test, all_metrics(targets_test, confs_test),
            'До калибровки', model_dataset_tuple)

i = 1
plt.subplot(2, 3, i)
reliability_plot(confs_test, targets_test, name='До калибровки')

# Transforming confs
for cal_name, calibrator in calibrators_confs.items():
    calibrator.fit(confs_val, targets_val)
    confs_val_cal = calibrator.transform(confs_val)
    confs_test_cal = calibrator.transform(confs_test)
    upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                cal_name, model_dataset_tuple)
    upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                cal_name, model_dataset_tuple)
    i += 1
    plt.subplot(2, 3, i)
    reliability_plot(confs_test_cal, targets_test, name=cal_name)

# Transforming logits
for cal_name, calibrator in calibrators_logits.items():
    calibrator.fit(logits_val, targets_val)
    confs_val_cal = calibrator.transform(logits_val)
    confs_test_cal = calibrator.transform(logits_test)
    upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                cal_name, model_dataset_tuple)
    upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                cal_name, model_dataset_tuple)
    i += 1
    plt.subplot(2, 3, i)
    reliability_plot(confs_test_cal, targets_test, name=cal_name)
    
plt.tight_layout()
plt.savefig(f'calibration_{dataset_name}_{model_name}')

In [None]:
metrics_val = defaultdict(lambda: defaultdict(dict))
metrics_test = defaultdict(lambda: defaultdict(dict))

plt.figure(figsize=(11, 8))
dataset_name = 'cifar10_v1'  # !
fpath = os.path.join(logit_path, dataset_name)
logitss, targets = get_logits(fpath)
model_name = 'googlenet'  # !
logits = logitss[model_name]

model_dataset_tuple = (dataset_name, model_name)
logits_val, logits_test, targets_val, targets_test = train_test_split(
logits, targets, test_size=0.5, stratify=targets, **rs)
confs_val = softmax(logits_val, axis=1)
confs_test = softmax(logits_test, axis=1)

# No calibration
upd_metrics(metrics_val, all_metrics(targets_val, confs_val),
            'До калибровки', model_dataset_tuple)
upd_metrics(metrics_test, all_metrics(targets_test, confs_test),
            'До калибровки', model_dataset_tuple)

i = 1
plt.subplot(2, 3, i)
reliability_plot(confs_test, targets_test, name='До калибровки')

# Transforming confs
for cal_name, calibrator in calibrators_confs.items():
    calibrator.fit(confs_val, targets_val)
    confs_val_cal = calibrator.transform(confs_val)
    confs_test_cal = calibrator.transform(confs_test)
    upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                cal_name, model_dataset_tuple)
    upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                cal_name, model_dataset_tuple)
    i += 1
    plt.subplot(2, 3, i)
    reliability_plot(confs_test_cal, targets_test, name=cal_name)

# Transforming logits
for cal_name, calibrator in calibrators_logits.items():
    calibrator.fit(logits_val, targets_val)
    confs_val_cal = calibrator.transform(logits_val)
    confs_test_cal = calibrator.transform(logits_test)
    upd_metrics(metrics_val, all_metrics(targets_val, confs_val_cal),
                cal_name, model_dataset_tuple)
    upd_metrics(metrics_test, all_metrics(targets_test, confs_test_cal),
                cal_name, model_dataset_tuple)
    i += 1
    plt.subplot(2, 3, i)
    reliability_plot(confs_test_cal, targets_test, name=cal_name)
    
plt.tight_layout()
plt.savefig(f'calibration_cifar10_{model_name}')

In [None]:
# Bundle the figures saved in vis/ into vis.zip
!zip -r vis.zip vis

# OvR-checks

### No calibration

In [None]:
plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_test, y_test_confs, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_test, binconf2mc(y_test_confs), n_bins=10)
_reliability_plot(*rel)

In [None]:
plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_val, y_val_confs, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_val, binconf2mc(y_val_confs), n_bins=10)
_reliability_plot(*rel)

### Histogram binning (binary)

In [None]:
calibrator = HistogramBinningBinary(n_bins=10)
calibrator.fit(y_val_confs, y_val)
y_val_confs_calib = calibrator.transform(y_val_confs)

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_val, y_val_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_val, binconf2mc(y_val_confs_calib), n_bins=10)
_reliability_plot(*rel, show=False)

In [None]:
calibrator = HistogramBinningBinary(n_bins=10)
calibrator.fit(y_val_confs, y_val)
y_test_confs_calib = calibrator.transform(y_test_confs)

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_test, y_test_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_test, binconf2mc(y_test_confs_calib), n_bins=10)
_reliability_plot(*rel)

### Isotonic Regression (binary)

In [None]:
calibrator = IRBinary()
calibrator.fit(y_val_confs, y_val)
y_val_confs_calib = calibrator.transform(y_val_confs)

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_val, y_val_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_val, binconf2mc(y_val_confs_calib), n_bins=10)
_reliability_plot(*rel, show=False)

In [None]:
calibrator = IRBinary()
calibrator.fit(y_val_confs, y_val)
y_test_confs_calib = calibrator.transform(y_test_confs)

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_test, y_test_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_test, binconf2mc(y_test_confs_calib), n_bins=10)
_reliability_plot(*rel, show=False)

### Histogram binning (multiclass)

In [None]:
calibrator_mc = HistogramBinningMulticlass(n_bins=10)
calibrator_mc.fit(binconf2mc(y_val_confs), y_val)
y_val_confs_calib_mc = calibrator_mc.transform(binconf2mc(y_val_confs))
y_val_confs_calib = y_val_confs_calib_mc[:, 1]

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_val, y_val_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_val, y_val_confs_calib_mc, n_bins=10)
_reliability_plot(*rel, show=False)

In [None]:
calibrator_mc = HistogramBinningMulticlass(n_bins=10)
calibrator_mc.fit(binconf2mc(y_val_confs), y_val)
y_test_confs_calib_mc = calibrator_mc.transform(binconf2mc(y_test_confs))
y_test_confs_calib = y_test_confs_calib_mc[:, 1]

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_test, y_test_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_test, y_test_confs_calib_mc, n_bins=10)
_reliability_plot(*rel, show=False)

### Isotonic regression (multiclass)

In [None]:
calibrator_mc = IsotonicRegressionMulticlass()
calibrator_mc.fit(binconf2mc(y_val_confs), y_val)
y_val_confs_calib_mc = calibrator_mc.transform(binconf2mc(y_val_confs))
y_val_confs_calib = y_val_confs_calib_mc[:, 1]

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_val, y_val_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_val, y_val_confs_calib_mc, n_bins=10)
_reliability_plot(*rel, show=False)

In [None]:
calibrator_mc = IsotonicRegressionMulticlass()
calibrator_mc.fit(binconf2mc(y_val_confs), y_val)
y_test_confs_calib_mc = calibrator_mc.transform(binconf2mc(y_test_confs))
y_test_confs_calib = y_test_confs_calib_mc[:, 1]

plt.figure(figsize=(10, 5))
plt.subplot(121)
rel = bins_reliability_binary(y_test, y_test_confs_calib, n_bins=10)
_reliability_plot(*rel, acc_label='class 1 frequency', show=False)
plt.subplot(122)
rel = bins_reliability_multiclass(y_test, y_test_confs_calib_mc, n_bins=10)
_reliability_plot(*rel, show=False)