# Analysis of AdaTempScal on ResNet50 for CIFAR10

In [1]:
# Load IPython's autoreload extension so edits to the local helper modules can be picked up without restarting the kernel.
%load_ext autoreload

In [2]:
# Mode 1: only modules registered with `%aimport` (utils, adats_utils, mixNmatch_cal in the import cell) are reloaded automatically.
# Note that names pulled in with `from module import name` are not rebound on reload.
%autoreload 1

In [3]:
import os
import sys
import time
sys.path.extend(['..'])

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from mixNmatch_cal import temperature_scaling
from scipy_models import LTS, HTS, HistTS, TS, HnLTS, BTS
from models import AdaTS, DNNbasedT
%aimport utils
from utils import compare_results, calib_split, get_CIFAR10_C, NumpyDataset, load_model, predict_logits, compute_metrics, onehot_encode, softmax
%aimport adats_utils
from adats_utils import fitAdaTS, fitCV_AdaTS, fitHistTS
%aimport mixNmatch_cal
from mixNmatch_cal import ets_calibrate, mir_calibrate

In [4]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [5]:
# Typeset all figure text with LaTeX; amsmath/amssymb are needed for the \mathbb used in plot titles.
plt.rcParams.update({
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath} \usepackage{amssymb}',
})

In [6]:
# Device used to fit the torch-based calibrators. Fall back to the CPU so the
# notebook still runs on machines without a CUDA GPU.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load data and precomputed logits

In [7]:
# Relative locations of the clean CIFAR-10 .npy arrays, the CIFAR-10-C corruption data,
# and the ResNet50 folder that holds the precomputed logits loaded below.
CIFAR10_PATH = '../../data/CIFAR10'
CIFAR10C_PATH = '../../data/CIFAR-10-C'
MODEL_PATH = '../../trained_models/CIFAR10/resnet50'

In [8]:
# Images and labels of each CIFAR-10 split, stored as .npy files under CIFAR10_PATH.
# NOTE(review): the X_* image arrays are not used anywhere in this notebook (only the
# precomputed logits are); consider dropping them to save memory.
def _load_split_arrays(images_file, labels_file):
    """Return the (images, labels) arrays stored under CIFAR10_PATH."""
    images = np.load(os.path.join(CIFAR10_PATH, images_file))
    labels = np.load(os.path.join(CIFAR10_PATH, labels_file))
    return images, labels


X_train, y_train = _load_split_arrays('train_imas.npy', 'train_labels.npy')
X_val, y_val = _load_split_arrays('val_imas.npy', 'val_labels.npy')
X_test, y_test = _load_split_arrays('test_imas.npy', 'test_labels.npy')

In [9]:
# Precomputed ResNet50 logits for each split (rows are samples, columns are classes).
Z_train, Z_val, Z_test = (
    np.load(os.path.join(MODEL_PATH, logits_file))
    for logits_file in ('train_logits.npy', 'val_logits.npy', 'test_logits.npy')
)

### Calibrate models

In [10]:
N, dim = Z_train.shape  # N: number of train samples (not used later in this notebook); dim: number of classes (logit width)

### Temp-Scal as baseline:
# Temperature-scaling baseline fitted on the validation logits; v=True prints the optimizer summary shown below.
tempScaler = TS(dim)
tempScaler.fit(Z_val, y_val, v=True);

Optimization terminated successfully.
         Current function value: 2171.642972
         Iterations: 7
         Function evaluations: 8
         Gradient evaluations: 8


In [11]:
# HTS calibrator from scipy_models (presumably entropy-based adaptive TS -- confirm), fitted on the validation logits.
hts = HTS(dim)
hts.fit(Z_val, y_val, v=True)

In [12]:
# LTS calibrator from scipy_models (presumably logit-based adaptive TS -- confirm), fitted on the validation logits.
lts = LTS(dim)
lts.fit(Z_val, y_val, v=True)

In [13]:
# HnLTS calibrator from scipy_models (presumably the combination of the HTS and LTS inputs -- confirm), fitted on the validation logits.
hlts = HnLTS(dim)
hlts.fit(Z_val, y_val, v=True)

In [14]:
# Bin-wise temperature scaling: as the output below shows, one temperature is fitted per bin on the validation logits.
bts = BTS()
bts.fit(Z_val, y_val)

Fitted bin 0, with T: 2.79Fitted bin 1, with T: 2.76Fitted bin 2, with T: 3.23Fitted bin 3, with T: 3.34Fitted bin 4, with T: 2.93Fitted bin 5, with T: 2.51Fitted bin 6, with T: 2.06Fitted bin 7, with T: 3.21Fitted bin 8, with T: 2.46Fitted bin 9, with T: 2.09Fitted bin 10, with T: 2.87Fitted bin 11, with T: 3.00Fitted bin 12, with T: 2.83Fitted bin 13, with T: 2.73Fitted bin 14, with T: 3.01Fitted bin 15, with T: 3.80Fitted bin 16, with T: 2.58Fitted bin 17, with T: 2.27Fitted bin 18, with T: 2.91Fitted bin 19, with T: 2.65Fitted bin 20, with T: 2.35Fitted bin 21, with T: 3.19Fitted bin 22, with T: 3.75Fitted bin 23, with T: 3.15Fitted bin 24, with T: 2.63Fitted bin 25, with T: 3.76Fitted bin 26, with T: 2.62Fitted bin 27, with T: 2.91Fitted bin 28, with T: 3.11Fitted bin 29, with T: 2.53Fitted bin 30, with T: 2.98Fitted bin 31, with T: 2.82Fitted bin 32, with T: 2.82Fitted bin 33, with T: 3.29Fitted bin 34, with T: 2.49Fitted bin 35, with T: 3.16Fi

In [15]:
# Adaptive TS with a small DNN temperature model (hs=[5, 5]), fitted on the validation logits.
# NOTE(review): this run was interrupted (KeyboardInterrupt in the output). About 174 epochs took ~5.6 s,
# so 30000 epochs would take roughly 16 minutes. `PTS` is not used by any later cell in this notebook.
PTS = AdaTS(DNNbasedT(dim, hs=[5, 5]))
PTS = fitAdaTS(PTS, Z_val, y_val, epochs=30000, batch_size=1000, lr=1e-4, v=True, weight_decay=1e-2, dev=dev)

On epoch: 174, NLL: 2.974e+03, at time: 5.65s

KeyboardInterrupt: 

In [15]:
# HistTS calibrator fitted on the validation logits. It is not part of the compare_results tables below;
# it is only used in the CIFAR10-C metrics.
hisTS = HistTS()
hisTS.fit(Z_val, y_val)

In [16]:
print('##### Results on train set:')
# ETS and MIR are presumably fitted on (Z_val, y_val) inside each call and then applied to Z_train.
compare_results(
    predictions=dict(
        Uncal=softmax(Z_train, axis=1),
        TempScal=tempScaler.predictive(Z_train),
        BTS=bts.predictive(Z_train),
        ETS=ets_calibrate(Z_val, onehot_encode(y_val), Z_train, dim),
        MIR=mir_calibrate(Z_val, onehot_encode(y_val), Z_train),
        HTS=hts.predictive(Z_train),
        LTS=lts.predictive(Z_train),
        HnLTS=hlts.predictive(Z_train),
    ),
    target=y_train,
    M=50,
    from_logits=False,
);

##### Results on train set:
Calibrator      Accuracy           ECE           MCE   Brier Score           NLL
Uncal             99.94%         0.12%        47.21%    1.104e-03     2.701e-03
TempScal          99.94%         5.36%        71.37%    8.604e-03     5.812e-02
BTS               99.94%         4.75%        72.62%    1.003e-02     5.309e-02
ETS               99.94%         7.37%        72.93%    1.287e-02     8.031e-02
MIR               99.94%         4.70%        78.25%    1.078e-02     5.266e-02
HTS               99.94%         4.79%        74.31%    9.582e-03     5.312e-02
LTS               99.94%         5.32%        72.07%    9.458e-03     5.824e-02
HnLTS             99.94%         4.79%        74.61%    1.012e-02     5.339e-02


In [17]:
print('##### Results on val set:')
# In-sample check: every calibrator was fitted on these same validation logits.
compare_results(
    predictions=dict(
        Uncal=softmax(Z_val, axis=1),
        TempScal=tempScaler.predictive(Z_val),
        BTS=bts.predictive(Z_val),
        ETS=ets_calibrate(Z_val, onehot_encode(y_val), Z_val, dim),
        MIR=mir_calibrate(Z_val, onehot_encode(y_val), Z_val),
        HTS=hts.predictive(Z_val),
        LTS=lts.predictive(Z_val),
        HnLTS=hlts.predictive(Z_val),
    ),
    target=y_val,
    M=50,
    from_logits=False,
);

##### Results on val set:
Calibrator      Accuracy           ECE           MCE   Brier Score           NLL
Uncal             86.72%        10.31%        73.80%    2.300e-01     7.737e-01
TempScal          86.72%         3.06%        79.67%    1.967e-01     4.343e-01
BTS               86.72%         2.20%        80.17%    1.938e-01     4.275e-01
ETS               86.72%         3.40%        80.22%    1.958e-01     4.382e-01
MIR               86.72%         1.67%        82.23%    1.935e-01     4.188e-01
HTS               86.72%         2.16%        80.79%    1.953e-01     4.305e-01
LTS               86.72%         2.84%        78.86%    1.947e-01     4.278e-01
HnLTS             86.72%         2.07%        79.96%    1.937e-01     4.248e-01


In [18]:
print('##### Results on test set:')
# Held-out evaluation of the calibrators fitted on the validation logits.
compare_results(
    predictions=dict(
        Uncal=softmax(Z_test, axis=1),
        TempScal=tempScaler.predictive(Z_test),
        BTS=bts.predictive(Z_test),
        ETS=ets_calibrate(Z_val, onehot_encode(y_val), Z_test, dim),
        MIR=mir_calibrate(Z_val, onehot_encode(y_val), Z_test),
        HTS=hts.predictive(Z_test),
        LTS=lts.predictive(Z_test),
        HnLTS=hlts.predictive(Z_test),
    ),
    target=y_test,
    M=50,
    from_logits=False,
);

##### Results on test set:
Calibrator      Accuracy           ECE           MCE   Brier Score           NLL
Uncal             86.13%        10.71%        49.03%    2.392e-01     7.897e-01
TempScal          86.13%         2.54%        51.53%    2.037e-01     4.473e-01
BTS               86.13%         1.86%        60.38%    2.036e-01     4.498e-01
ETS               86.13%         2.84%        43.68%    2.029e-01     4.515e-01
MIR               86.13%         1.33%        80.74%    2.023e-01     4.435e-01
HTS               86.13%         1.45%        45.17%    2.024e-01     4.448e-01
LTS               86.13%         2.41%        43.20%    2.019e-01     4.399e-01
HnLTS             86.13%         1.58%        45.22%    2.011e-01     4.386e-01


### Temperature distribution

In [1]:
# Histograms of the per-sample temperatures predicted by each adaptive scaler (rows)
# on the train / validation / test logits (columns).
# TODO(review): lhaTempScaler, dnnaTempScaler and bdnnaTempScaler are only created in the
# "Exposure to corruptions" section further down, so this cell raises NameError on a
# fresh Restart & Run All. Fit them (or their intended replacements) before this cell.
fig, ax = plt.subplots(3, 3, figsize=(25, 25))

t_scalers = [
    (lhaTempScaler, r'$\log H$ Temp-Scaling'),
    (dnnaTempScaler, 'DNN Temp-Scaling'),
    (bdnnaTempScaler, 'Big DNN Temp-Scaling'),
]
splits = [(Z_train, 'Train set'), (Z_val, 'Validation set'), (Z_test, 'Test set')]

for row, (scaler, scaler_name) in enumerate(t_scalers):
    for col, (logits, split_name) in enumerate(splits):
        _ax = ax[row, col]
        Ts = scaler.get_T(logits)
        _ax.hist(Ts, bins=50)
        _ax.set_xlabel('T', fontsize=22)
        _ax.set_ylabel('Count', fontsize=22)
        # Put the scaler name in the title so each row is identifiable.
        _ax.set_title('{}: {}'.format(scaler_name, split_name), fontsize=26)
        _ax.tick_params(axis='both', labelsize=18)

plt.show()

NameError: name 'plt' is not defined

### Predicted vs. optimal temperature for high- and low-confidence test samples

In [None]:
# Split the test set with calib_split: hc/hi are the high-confidence masks and lc/li the
# low-confidence masks (c/i presumably correct/incorrect -- see utils.calib_split).
# Then fit one reference "optimal" temperature on each confidence group.
hc, lc, hi, li = calib_split(Z_test, y_test)


def fit_global_temperature(logits, labels, max_iter=100):
    """Fit a single temperature T > 0 that minimises the NLL of softmax(logits / T).

    Replaces the `TempScaling` class this cell used to call; that class is not
    imported anywhere in the notebook.

    Args:
        logits: (n_samples, n_classes) array of uncalibrated logits.
        labels: (n_samples,) array of integer class labels.
        max_iter: maximum number of LBFGS iterations.

    Returns:
        Length-1 float64 numpy array holding T, so callers can use ``T[0]``.
    """
    z = torch.as_tensor(np.asarray(logits), dtype=torch.float64)
    y = torch.as_tensor(np.asarray(labels), dtype=torch.long)
    # Optimise log T so the temperature stays strictly positive.
    log_t = torch.zeros(1, dtype=torch.float64, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], max_iter=max_iter, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(z / log_t.exp(), y)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().detach().numpy()


T_hc = fit_global_temperature(Z_test[hc | hi], y_test[hc | hi])
T_lc = fit_global_temperature(Z_test[lc | li], y_test[lc | li])


In [None]:
# Per-sample test temperatures predicted by each adaptive scaler (rows: log-H, DNN, big DNN),
# split into high- and low-confidence samples (columns). Each panel is compared with the
# optimal single temperature of that group (red dashed line).
# TODO(review): lhaTempScaler, dnnaTempScaler and bdnnaTempScaler are only created in the
# "Exposure to corruptions" section below, so this cell raises NameError on a fresh run.
TslH = lhaTempScaler.get_T(Z_test)
TsDNN = dnnaTempScaler.get_T(Z_test)
TsBDNN = bdnnaTempScaler.get_T(Z_test)

fig, ax = plt.subplots(3, 2, sharex=True, sharey=True, figsize=(23, 25))

confidence_groups = [('High', hc | hi, T_hc), ('Low', lc | li, T_lc)]

for row, Ts in enumerate([TslH, TsDNN, TsBDNN]):
    for col, (level, mask, T_opt) in enumerate(confidence_groups):
        _ax = ax[row, col]
        _ax.hist(Ts[mask])
        _ax.axvline(T_opt, ls='--', lw=2, c='red',
                    label='Optimum Temperature $T = {:.3f}$'.format(T_opt[0]))
        # The backslash is escaped explicitly: a bare '\m' is an invalid escape sequence
        # (DeprecationWarning; SyntaxWarning since Python 3.12). The rendered text is unchanged.
        _ax.set_title('{} Confidence samples.\n $\\mathbb{{E}}[T] = {:.3f}$'.format(level, np.mean(Ts[mask])),
                      fontsize=26)
        _ax.legend(fontsize=18)

for _ax in ax.flatten():
    # sharex/sharey hide inner tick labels by default; show them on every panel.
    _ax.yaxis.set_tick_params(labelleft=True)
    _ax.xaxis.set_tick_params(labelbottom=True)
    _ax.set_xlabel('T', fontsize=22)
    _ax.set_ylabel('Count', fontsize=22)
    _ax.tick_params(axis='both', labelsize=18)

plt.tight_layout()
plt.show();

### Predicted vs. optimal temperature per confidence quartile

In [None]:
# Uncalibrated class probabilities and the top-class (maximum) probability of each test sample.
test_probs = softmax(Z_test, axis=1)
test_confs = np.amax(test_probs, axis=1)

In [None]:
# Distribution of the uncalibrated top-class confidence on the test set.
fig, ax = plt.subplots(figsize=(25, 8))
ax.hist(test_confs, bins=200)

ax.tick_params(axis='both', labelsize=18)
for set_axis_label, axis_text in ((ax.set_xlabel, 'Confidence'), (ax.set_ylabel, 'Count')):
    set_axis_label(axis_text, fontsize=22)

plt.show();

In [None]:
# Sort the test samples by uncalibrated confidence (ascending) and split them into four
# equal-size quartiles: q1 is the least confident, q4 the most confident.
ix = np.argsort(test_confs)
n_test = len(test_confs)
quartile_bounds = [0, n_test // 4, n_test // 2, 3 * n_test // 4, n_test]
q1, q2, q3, q4 = (ix[start:stop] for start, stop in zip(quartile_bounds[:-1], quartile_bounds[1:]))


def fit_global_temperature(logits, labels, max_iter=100):
    """Fit a single temperature T > 0 that minimises the NLL of softmax(logits / T).

    Replaces the `TempScaling` class this cell used to call; that class is not
    imported anywhere in the notebook.

    Args:
        logits: (n_samples, n_classes) array of uncalibrated logits.
        labels: (n_samples,) array of integer class labels.
        max_iter: maximum number of LBFGS iterations.

    Returns:
        Length-1 float64 numpy array holding T, so callers can use ``T[0]``.
    """
    z = torch.as_tensor(np.asarray(logits), dtype=torch.float64)
    y = torch.as_tensor(np.asarray(labels), dtype=torch.long)
    # Optimise log T so the temperature stays strictly positive.
    log_t = torch.zeros(1, dtype=torch.float64, requires_grad=True)
    optimizer = torch.optim.LBFGS([log_t], max_iter=max_iter, line_search_fn='strong_wolfe')

    def closure():
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(z / log_t.exp(), y)
        loss.backward()
        return loss

    optimizer.step(closure)
    return log_t.exp().detach().numpy()


# Optimal single temperature of each quartile.
T_q1, T_q2, T_q3, T_q4 = (fit_global_temperature(Z_test[q], y_test[q]) for q in (q1, q2, q3, q4))

# TODO(review): these scalers are only created in the "Exposure to corruptions" section below.
TslH = lhaTempScaler.get_T(Z_test)
TsDNN = dnnaTempScaler.get_T(Z_test)
TsBDNN = bdnnaTempScaler.get_T(Z_test)

In [None]:
# Rows: log-H, DNN and big-DNN adaptive scalers. Columns: confidence quartiles Q1-Q4.
# Each panel compares the predicted per-sample temperatures with the optimal single
# temperature of that quartile (red dashed line).
fig, ax = plt.subplots(3, 4, sharex=True, sharey=True, figsize=(30, 25))

quartiles = [('Q1', q1, T_q1), ('Q2', q2, T_q2), ('Q3', q3, T_q3), ('Q4', q4, T_q4)]

for row, Ts in enumerate([TslH, TsDNN, TsBDNN]):
    for col, (quartile_name, idx, T_opt) in enumerate(quartiles):
        _ax = ax[row, col]
        _ax.hist(Ts[idx])
        _ax.axvline(T_opt, ls='--', lw=2, c='red',
                    label='Optimum Temperature $T = {:.3f}$'.format(T_opt[0]))
        # The backslash is escaped explicitly; a bare '\m' is an invalid escape sequence.
        _ax.set_title('Samples in {}.\n $\\mathbb{{E}}[T] = {:.3f}$'.format(quartile_name, np.mean(Ts[idx])),
                      fontsize=26)

for _ax in ax.flatten():
    # sharex/sharey hide inner tick labels by default; show them on every panel.
    _ax.yaxis.set_tick_params(labelleft=True)
    _ax.xaxis.set_tick_params(labelbottom=True)
    _ax.legend(fontsize=18)
    _ax.set_xlabel('T', fontsize=22)
    _ax.set_ylabel('Count', fontsize=22)
    _ax.tick_params(axis='both', labelsize=18)

plt.show();

## Corruption robustness: CIFAR-10-C

In [None]:
# Load CIFAR-10-C. Every key except 'labels' is a corruption category.
cifar10c = get_CIFAR10_C(CIFAR10C_PATH)

categories = list(cifar10c)
categories.remove('labels')

In [None]:
# Test-time preprocessing: convert to a tensor, then normalise each channel.
# The constants are presumably the CIFAR-10 channel statistics used when training the model (must match).
cifar10_transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2023, 0.1994, 0.2010)),
])

In [None]:
# Load the trained ResNet50 used to compute logits on the corrupted images.
# NOTE(review): model_path differs from MODEL_PATH (which includes CIFAR10/resnet50); presumably load_model appends that itself -- confirm.
net = load_model('resnet50', 'cifar10', model_path='../../trained_models') 

### Robustness without exposure to corrupted samples

In [None]:
# ResNet50 logits on every corruption category and severity level (1-5):
# cifar10c_logits[category][severity] -> logits array.
cifar10c_logits = {}
# The same first 10000 labels are paired with every category/severity.
cifar10c_labels = cifar10c['labels'][:10000]

for category in categories:
    cifar10c_logits[category] = {}
    print('Computing predictions for corruption: {}'.format(category))
    for severity in range(1, 6):
        test_data = NumpyDataset(cifar10c[category][severity], cifar10c_labels, transform=cifar10_transforms_test)
        test_dataloader = DataLoader(test_data, batch_size=256)
        # Use the notebook-wide device instead of hard-coding CUDA.
        cifar10c_logits[category][severity] = predict_logits(net, test_dataloader, dev)

In [None]:
# Metrics of the uncalibrated model and of each calibrator on every corruption category
# and severity level: cifar10c_metrics_*[category][severity] -> compute_metrics output.
# NOTE(review): lhaTempScaler, dnnaTempScaler and bdnnaTempScaler are only created further
# down the notebook; confirm they exist before running this cell.
cifar10c_metrics_uncal, cifar10c_metrics_tscal, cifar10c_metrics_lhtscal = {}, {}, {}
cifar10c_metrics_dnntscal, cifar10c_metrics_bdnntscal, cifar10c_metrics_histtscal = {}, {}, {}

for category in categories:
    for per_category in (cifar10c_metrics_uncal, cifar10c_metrics_tscal, cifar10c_metrics_lhtscal,
                         cifar10c_metrics_dnntscal, cifar10c_metrics_bdnntscal, cifar10c_metrics_histtscal):
        per_category[category] = {}
    print('Computing metrics for corruption: {}'.format(category))
    for severity in range(1, 6):
        logits = cifar10c_logits[category][severity]
        labels = cifar10c['labels'][:10000]
        cifar10c_metrics_uncal[category][severity] = compute_metrics(logits, labels)
        for per_category, scaler in ((cifar10c_metrics_tscal, tempScaler),
                                     (cifar10c_metrics_lhtscal, lhaTempScaler),
                                     (cifar10c_metrics_dnntscal, dnnaTempScaler),
                                     (cifar10c_metrics_bdnntscal, bdnnaTempScaler),
                                     (cifar10c_metrics_histtscal, hisTS)):
            per_category[category][severity] = compute_metrics(scaler.predictive(logits), labels, from_logits=False)

In [None]:
# Average each metric vector over all corruption categories, per severity level.
# Row k holds severity k+1; the 4 columns are the values returned by compute_metrics.
_mean_tables = []
for per_category in (cifar10c_metrics_uncal, cifar10c_metrics_tscal, cifar10c_metrics_lhtscal,
                     cifar10c_metrics_dnntscal, cifar10c_metrics_bdnntscal, cifar10c_metrics_histtscal):
    table = np.zeros((5, 4))
    for severity_level in range(1, 6):
        for cat in categories:
            table[severity_level - 1] += per_category[cat][severity_level]
    table /= len(categories)
    _mean_tables.append(table)

(mean_sev_uncal, mean_sev_tscal, mean_sev_lhtscal,
 mean_sev_dnntscal, mean_sev_bdnntscal, mean_sev_histtscal) = _mean_tables

In [None]:
# Mean ECE / NLL / Brier score vs. corruption severity for each calibrator fitted on clean data.
fig, ax = plt.subplots(3, 1, figsize=(28, 33))

# Raw string for the LaTeX label: '\l' is an invalid escape sequence in a normal literal.
curves = [
    (mean_sev_tscal, 'Temp-Scaling'),
    (mean_sev_lhtscal, r'$\log H$ Temp-Scaling'),
    (mean_sev_dnntscal, 'DNN Temp-Scaling'),
    (mean_sev_bdnntscal, 'Big DNN Temp-Scaling'),
    (mean_sev_histtscal, 'H Hist Temp-Scaling'),
]
# (axis, column of the mean_sev_* tables, y-label).
panels = [(ax[0], 1, 'ECE'), (ax[1], 3, 'NLL'), (ax[2], 2, 'Brier Score')]

for _ax, metric_col, ylabel in panels:
    for values, label in curves:
        _ax.plot(values[:, metric_col], ls='--', marker='*', label=label)
    _ax.set_ylabel(ylabel, fontsize=22)
    # Severities are stored 0-based; label the ticks 1..5.
    _ax.set_xticks(np.arange(5))
    _ax.set_xticklabels(np.arange(5) + 1)
    _ax.legend(fontsize=18)
    _ax.tick_params(axis='both', labelsize=18)
    _ax.set_xlabel('Corruption level', fontsize=22)

plt.show()

In [None]:
# Per-sample temperatures predicted by the log-H and DNN scalers on every corruption/severity:
# cifar10c_Ts*[category][severity] -> get_T output.
cifar10c_TslH, cifar10c_TsDNN = {}, {}

for category in categories:
    lh_by_severity = cifar10c_TslH[category] = {}
    dnn_by_severity = cifar10c_TsDNN[category] = {}
    print('Computing Ts for corruption: {}'.format(category))
    for severity in range(1, 6):
        logits = cifar10c_logits[category][severity]
        lh_by_severity[severity] = lhaTempScaler.get_T(logits)
        dnn_by_severity[severity] = dnnaTempScaler.get_T(logits)

In [None]:
# Mean predicted temperature per severity level, averaged first over samples and then over categories.
mean_TslH, mean_TsDNN = np.zeros(5), np.zeros(5)

for Ts_by_category, mean_Ts in ((cifar10c_TslH, mean_TslH), (cifar10c_TsDNN, mean_TsDNN)):
    for severity_level in range(1, 6):
        for cat in categories:
            mean_Ts[severity_level - 1] += np.mean(Ts_by_category[cat][severity_level])
    mean_Ts /= len(categories)

In [None]:
# Mean log-H temperatures per severity, then mean DNN temperatures, one array per line.
print(mean_TslH, mean_TsDNN, sep='\n')

### Robustness after exposure to corruptions (calibrators refitted on corrupted samples)

In [None]:
# Random split of the 10000 CIFAR-10-C sample indices: idx_train is used to fit the
# calibrators on corrupted logits, idx_test to evaluate them.
# Seeded so the split, and every number derived from it, is reproducible.
SPLIT_SEED = 0
N_CORRUPTION_SAMPLES = 10000
N_CORRUPTION_TRAIN = 3000

rnd_idx = np.random.default_rng(SPLIT_SEED).permutation(N_CORRUPTION_SAMPLES)
idx_train = rnd_idx[:N_CORRUPTION_TRAIN]
idx_test = rnd_idx[N_CORRUPTION_TRAIN:]

In [None]:
# Training data for the "exposure" calibrators: the idx_train rows of every corruption
# category at every severity, stacked category-major, with the labels repeated to match.
# WARNING(review): this rebinds y_train (the clean CIFAR-10 train labels loaded at the top).
# Re-running the earlier "Results on train set" cell after this one would pair Z_train with
# mismatched targets; a distinct name would be safer.
corruption_keys = [(cat, sev) for cat in categories for sev in range(1, 6)]
train_set = np.vstack([cifar10c_logits[cat][sev][idx_train] for cat, sev in corruption_keys])
y_train = np.hstack([cifar10c['labels'][idx_train] for _cat, _sev in corruption_keys])

In [None]:
### Temp-Scal as baseline (refit on the corrupted training logits):
# `TempScaling` is not defined anywhere in this notebook. Use the same `TS` calibrator as the
# clean-data baseline (same fit(..., v=True) / predictive interface used elsewhere).
# Note: this rebinds `tempScaler`, replacing the clean-data baseline.
tempScaler = TS(dim)
tempScaler.fit(train_set, y_train, v=True);

In [None]:
# Adaptive TS with an HlogbasedT temperature model (plotted as '$\log H$ Temp-Scaling'), fitted on the corrupted training logits.
# TODO(review): HlogbasedT is never imported (the import cell only brings AdaTS and DNNbasedT from `models`),
# so this raises NameError on a fresh kernel.
lhaTempScaler = AdaTS(HlogbasedT(dim))
lhaTempScaler = fitAdaTS(lhaTempScaler, train_set, y_train, epochs=10000, batch_size=1000, lr=1e-4, v=True, dev=dev)

In [None]:
# Adaptive TS with a DNN temperature model (default hidden sizes), fitted on the corrupted
# training logits with weight decay.
# The untrained model is bound first, so an interrupted fit still leaves dnnaTempScaler defined.
dnnaTempScaler = AdaTS(DNNbasedT(dim))
dnnaTempScaler = fitAdaTS(
    dnnaTempScaler, train_set, y_train,
    lr=1e-4, weight_decay=1e-2,
    epochs=10000, batch_size=1000,
    dev=dev, v=True,
)

In [None]:
# "Big" DNN temperature model (hs=[2*dim, 2*dim], presumably two hidden layers of width 2*dim),
# fitted on the corrupted training logits with weight decay.
# The untrained model is bound first, so an interrupted fit still leaves bdnnaTempScaler defined.
bdnnaTempScaler = AdaTS(DNNbasedT(dim, hs=[2 * dim, 2 * dim]))
bdnnaTempScaler = fitAdaTS(
    bdnnaTempScaler, train_set, y_train,
    lr=1e-4, weight_decay=1e-2,
    epochs=10000, batch_size=1000,
    dev=dev, v=True,
)

In [None]:
# HistTS refitted on the corrupted training logits; this rebinds hisTS, replacing the clean-data fit.
hisTS = HistTS()
hisTS.fit(train_set, y_train)

In [None]:
# Metrics on the held-out idx_test samples of every corruption category and severity, for the
# calibrators fitted on corrupted data: cifar10c_metrics_*[category][severity] -> compute_metrics output.
cifar10c_metrics_uncal, cifar10c_metrics_tscal, cifar10c_metrics_lhtscal = {}, {}, {}
cifar10c_metrics_dnntscal, cifar10c_metrics_bdnntscal, cifar10c_metrics_histtscal = {}, {}, {}

for category in categories:
    for per_category in (cifar10c_metrics_uncal, cifar10c_metrics_tscal, cifar10c_metrics_lhtscal,
                         cifar10c_metrics_dnntscal, cifar10c_metrics_bdnntscal, cifar10c_metrics_histtscal):
        per_category[category] = {}
    print('Computing metrics for corruption: {}'.format(category))
    for severity in range(1, 6):
        logits = cifar10c_logits[category][severity][idx_test]
        labels = cifar10c['labels'][idx_test]
        cifar10c_metrics_uncal[category][severity] = compute_metrics(logits, labels)
        for per_category, scaler in ((cifar10c_metrics_tscal, tempScaler),
                                     (cifar10c_metrics_lhtscal, lhaTempScaler),
                                     (cifar10c_metrics_dnntscal, dnnaTempScaler),
                                     (cifar10c_metrics_bdnntscal, bdnnaTempScaler),
                                     (cifar10c_metrics_histtscal, hisTS)):
            per_category[category][severity] = compute_metrics(scaler.predictive(logits), labels, from_logits=False)

In [None]:
# Average each metric vector over all corruption categories, per severity level.
# Row k holds severity k+1; the 4 columns are the values returned by compute_metrics.
_mean_tables = []
for per_category in (cifar10c_metrics_uncal, cifar10c_metrics_tscal, cifar10c_metrics_lhtscal,
                     cifar10c_metrics_dnntscal, cifar10c_metrics_bdnntscal, cifar10c_metrics_histtscal):
    table = np.zeros((5, 4))
    for severity_level in range(1, 6):
        for cat in categories:
            table[severity_level - 1] += per_category[cat][severity_level]
    table /= len(categories)
    _mean_tables.append(table)

(mean_sev_uncal, mean_sev_tscal, mean_sev_lhtscal,
 mean_sev_dnntscal, mean_sev_bdnntscal, mean_sev_histtscal) = _mean_tables

In [None]:
# Mean ECE / NLL / Brier score vs. corruption severity for the calibrators fitted on
# corrupted training logits (idx_train) and evaluated on idx_test.
fig, ax = plt.subplots(3, 1, figsize=(28, 33))

# Raw string for the LaTeX label: '\l' is an invalid escape sequence in a normal literal.
curves = [
    (mean_sev_tscal, 'Temp-Scaling'),
    (mean_sev_lhtscal, r'$\log H$ Temp-Scaling'),
    (mean_sev_dnntscal, 'DNN Temp-Scaling'),
    (mean_sev_bdnntscal, 'Big DNN Temp-Scaling'),
    (mean_sev_histtscal, 'H Hist Temp-Scaling'),
]
# (axis, column of the mean_sev_* tables, y-label).
panels = [(ax[0], 1, 'ECE'), (ax[1], 3, 'NLL'), (ax[2], 2, 'Brier Score')]

for _ax, metric_col, ylabel in panels:
    for values, label in curves:
        _ax.plot(values[:, metric_col], ls='--', marker='*', label=label)
    _ax.set_ylabel(ylabel, fontsize=22)
    # Severities are stored 0-based; label the ticks 1..5.
    _ax.set_xticks(np.arange(5))
    _ax.set_xticklabels(np.arange(5) + 1)
    _ax.legend(fontsize=18)
    _ax.tick_params(axis='both', labelsize=18)
    _ax.set_xlabel('Corruption level', fontsize=22)

plt.show()