In [1]:
import os
import sys
import time
sys.path.extend(['..'])

import torch
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

In [2]:
# Render all figure text through LaTeX and load the AMS math packages in its preamble.
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath} \usepackage{amssymb}'

In [3]:
# Directory (relative path) that holds the result tables read below.
res_path = '../../results/pretrained'

# One CSV per metric; they are loaded into the ECE, NLL and Bri frames respectively.
ece_file_reg = 'logitsJuan_ECE_reg.csv'
nll_file_reg = 'logitsJuan_NLL_reg.csv'
bri_file_reg = 'logitsJuan_BRI_reg.csv'

In [4]:
# Read the three metric tables in one pass. Each file carries an 'Unnamed: 0'
# column (presumably an index written out by `to_csv` -- confirm) that is dropped.
ECE, NLL, Bri = (
    pd.read_csv(os.path.join(res_path, fname)).drop(columns=['Unnamed: 0'])
    for fname in (ece_file_reg, nll_file_reg, bri_file_reg)
)

In [5]:
# Distinct dataset names found in the ECE table.
datasets = ECE.Dataset.unique()
print(datasets)

['cifar10' 'cifar100' 'cars' 'birds' 'svhn']


In [6]:
def highlight_min(s):
    '''
    Bold the smallest value(s) of ``s``.

    Returns a list with one CSS string per element of ``s``:
    'font-weight: bold' where the element equals ``s.min()`` (ties are all
    marked) and '' everywhere else. Intended for use with ``Styler.apply``.
    '''
    smallest = s.min()
    at_minimum = s == smallest
    css = []
    for hit in at_minimum:
        css.append('font-weight: bold' if hit else '')
    return css

In [7]:
def highlight_min_br(s):
    '''
    Build a Styler that bolds the minimum of each metric in a combined table.

    Parameters
    ----------
    s : pandas.DataFrame
        Frame whose top column level contains the keys 'ECE', 'NLL' and
        'Brier', e.g. the result of
        ``pd.concat([...], axis=1, keys=['ECE', 'NLL', 'Brier'])``.

    Returns
    -------
    pandas.io.formats.style.Styler
        Styler over ``s``. For every metric, `highlight_min` is applied
        row-wise (``axis=1``) to that metric's float columns, so the lowest
        value of each row is bold within each metric.
        NOTE(review): row-wise comparison assumes rows are models and columns
        are calibration methods, as in the concatenated tables of this
        notebook -- confirm this is the intended direction.

    Raises
    ------
    KeyError
        If one of the three metric keys is missing from ``s``.
    '''
    # A single Styler is threaded through the loop and returned; styles
    # registered on a throw-away `frame.style` object would simply be lost.
    styler = s.style
    for metric in ('ECE', 'NLL', 'Brier'):
        # Only float columns are compared, which leaves out non-numeric ones
        # such as 'Dataset'. `subset` must be column labels, so each column is
        # addressed by its full (metric, column) label.
        float_cols = s[metric].select_dtypes(float).columns
        subset = [(metric, col) for col in float_cols]
        if subset:
            styler = styler.apply(highlight_min, axis=1, subset=subset)
    return styler

### CIFAR-10

In [8]:
# CIFAR-10 slice of each metric table, indexed by model name. Every table is
# filtered with its own `Dataset` column rather than the mask built from ECE,
# so the selection does not rely on the three CSVs sharing the same row order.
curr_ece = ECE.loc[ECE.Dataset == 'cifar10'].set_index('Model')
curr_nll = NLL.loc[NLL.Dataset == 'cifar10'].set_index('Model')
curr_bri = Bri.loc[Bri.Dataset == 'cifar10'].set_index('Model')

In [9]:
# ECE per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_ece.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101,resnext-29_8x16,vgg-19,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,2.83513,2.791066,4.450539,2.883827,4.605083,1.929727,3.11525
TS,1.386326,1.650474,1.373541,1.02603,2.514881,0.697266,1.04178
ETS,2.257274,2.30978,1.49105,0.852116,3.304092,1.593208,1.137043
MIR,1.057799,1.04118,1.285645,0.694775,1.221535,1.169693,0.586361
BTS,1.092914,1.266436,1.69876,1.185606,1.376058,1.0719,1.275609
PTS,1.344637,1.464498,1.774396,1.539388,2.186014,1.256031,1.784367
LinearTS,1.264614,1.305006,1.368699,1.202503,2.124044,0.910587,1.151095
HTS,1.345383,1.201751,1.691112,0.915007,1.635808,0.746984,1.116338
HnLinearTS,1.381722,1.033772,1.512913,1.111944,1.740951,0.969388,1.088227


In [10]:
# NLL per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_nll.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101,resnext-29_8x16,vgg-19,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,10.267942,9.881465,12.523792,19.334122,9.894833,10.44722,12.135914
TS,0.161812,0.160774,0.216332,0.162096,0.235521,0.136182,0.163392
ETS,0.167319,0.168556,0.219902,0.161691,0.242422,0.142409,0.164236
MIR,0.164836,0.16493,0.220817,0.168679,0.228012,0.147148,0.167439
BTS,0.195247,0.236109,0.256819,0.204134,0.290354,0.210356,0.183968
PTS,0.166108,inf,0.225441,inf,0.239876,0.153402,0.177601
LinearTS,0.160565,0.159271,0.213106,0.162965,0.231306,0.134881,0.163118
HTS,0.16132,0.159185,0.217971,0.162146,0.233448,0.136645,0.16333
HnLinearTS,0.160669,0.158659,0.215214,0.16325,0.230335,0.135725,0.163352


In [11]:
# Brier score per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_bri.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101,resnext-29_8x16,vgg-19,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,0.076359,0.075375,0.110169,0.082816,0.110094,0.062887,0.081958
TS,0.072888,0.071723,0.101096,0.078307,0.101859,0.060798,0.076777
ETS,0.07244,0.07121,0.101027,0.078233,0.100686,0.060941,0.076449
MIR,0.072515,0.071589,0.102072,0.078626,0.099855,0.061612,0.076991
BTS,0.072231,0.071064,0.101906,0.079086,0.100336,0.061111,0.077036
PTS,0.072851,0.073794,0.101424,0.080835,0.102472,0.062112,0.079564
LinearTS,0.072632,0.071348,0.100353,0.078265,0.101058,0.060599,0.076699
HTS,0.072546,0.071131,0.101072,0.078301,0.100401,0.060694,0.076617
HnLinearTS,0.07244,0.07094,0.100321,0.078262,0.099874,0.060553,0.076596


In [12]:
# All three metrics side by side under a top-level 'Metric' column label.
pd.concat(
    [curr_ece, curr_nll, curr_bri],
    keys=['ECE', 'NLL', 'Brier'],
    names=['Metric', 'Model'],
    axis=1,
)

Metric,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,...,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier
Model,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS,...,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
densenet-121,cifar10,2.83513,1.386326,2.257274,1.057799,1.092914,1.344637,1.264614,1.345383,1.381722,...,cifar10,0.076359,0.072888,0.07244,0.072515,0.072231,0.072851,0.072632,0.072546,0.07244
densenet-169,cifar10,2.791066,1.650474,2.30978,1.04118,1.266436,1.464498,1.305006,1.201751,1.033772,...,cifar10,0.075375,0.071723,0.07121,0.071589,0.071064,0.073794,0.071348,0.071131,0.07094
resnet-101,cifar10,4.450539,1.373541,1.49105,1.285645,1.69876,1.774396,1.368699,1.691112,1.512913,...,cifar10,0.110169,0.101096,0.101027,0.102072,0.101906,0.101424,0.100353,0.101072,0.100321
resnext-29_8x16,cifar10,2.883827,1.02603,0.852116,0.694775,1.185606,1.539388,1.202503,0.915007,1.111944,...,cifar10,0.082816,0.078307,0.078233,0.078626,0.079086,0.080835,0.078265,0.078301,0.078262
vgg-19,cifar10,4.605083,2.514881,3.304092,1.221535,1.376058,2.186014,2.124044,1.635808,1.740951,...,cifar10,0.110094,0.101859,0.100686,0.099855,0.100336,0.102472,0.101058,0.100401,0.099874
wide-resnet-28x10,cifar10,1.929727,0.697266,1.593208,1.169693,1.0719,1.256031,0.910587,0.746984,0.969388,...,cifar10,0.062887,0.060798,0.060941,0.061612,0.061111,0.062112,0.060599,0.060694,0.060553
wide-resnet-40x10,cifar10,3.11525,1.04178,1.137043,0.586361,1.275609,1.784367,1.151095,1.116338,1.088227,...,cifar10,0.081958,0.076777,0.076449,0.076991,0.077036,0.079564,0.076699,0.076617,0.076596


### CIFAR-100

In [13]:
# CIFAR-100 slice of each metric table, indexed by model name. Every table is
# filtered with its own `Dataset` column rather than the mask built from ECE,
# so the selection does not rely on the three CSVs sharing the same row order.
curr_ece = ECE.loc[ECE.Dataset == 'cifar100'].set_index('Model')
curr_nll = NLL.loc[NLL.Dataset == 'cifar100'].set_index('Model')
curr_bri = Bri.loc[Bri.Dataset == 'cifar100'].set_index('Model')

# All three metrics side by side under a top-level 'Metric' column label.
pd.concat([curr_ece, curr_nll, curr_bri], axis=1, keys=['ECE', 'NLL', 'Brier'], names=['Metric', 'Model'])

Metric,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,...,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier
Model,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS,...,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
densenet-121,cifar100,8.76025,3.933633,3.039655,1.673477,2.756739,5.620568,4.14419,3.322253,3.362073,...,cifar100,0.317058,0.304773,0.304354,0.303697,0.305231,0.315193,0.304155,0.304919,0.303477
densenet-169,cifar100,8.93241,3.948776,2.897083,1.144476,3.187848,6.243349,4.03347,3.308287,3.695375,...,cifar100,0.314196,0.30167,0.301037,0.299443,0.301962,0.31306,0.299977,0.301601,0.299484
resnet-101,cifar100,11.446499,2.246794,2.243738,2.795522,2.224122,4.071449,2.718405,2.222823,2.473884,...,cifar100,0.405327,0.381723,0.381624,0.385367,0.38247,0.390749,0.382248,0.381687,0.382036
resnext-29_8x16,cifar100,9.692433,3.13951,2.675645,1.924365,2.061209,5.179378,3.546601,1.949264,2.579883,...,cifar100,0.327484,0.309558,0.309294,0.310697,0.309646,0.321346,0.309996,0.308996,0.309261
vgg-19,cifar100,17.631318,5.133481,5.364392,1.917828,3.89307,4.556667,3.73748,3.58957,3.497068,...,cifar100,0.443283,0.391811,0.391021,0.390587,0.388023,0.39198,0.386529,0.389495,0.387019
wide-resnet-28x10,cifar100,5.187939,4.629841,3.546959,1.84259,3.105428,5.28562,4.439579,3.515935,3.69836,...,cifar100,0.289225,0.288609,0.287947,0.285278,0.287132,0.291066,0.285156,0.287152,0.283941
wide-resnet-40x10,cifar100,14.784534,4.201734,2.739724,1.83233,3.551389,5.767309,4.423875,3.731363,3.93878,...,cifar100,0.370023,0.327633,0.326992,0.328069,0.329314,0.33311,0.324495,0.328604,0.324299


In [14]:
# ECE per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_ece.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101,resnext-29_8x16,vgg-19,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,8.76025,8.93241,11.446499,9.692433,17.631318,5.187939,14.784534
TS,3.933633,3.948776,2.246794,3.13951,5.133481,4.629841,4.201734
ETS,3.039655,2.897083,2.243738,2.675645,5.364392,3.546959,2.739724
MIR,1.673477,1.144476,2.795522,1.924365,1.917828,1.84259,1.83233
BTS,2.756739,3.187848,2.224122,2.061209,3.89307,3.105428,3.551389
PTS,5.620568,6.243349,4.071449,5.179378,4.556667,5.28562,5.767309
LinearTS,4.14419,4.03347,2.718405,3.546601,3.73748,4.439579,4.423875
HTS,3.322253,3.308287,2.222823,1.949264,3.58957,3.515935,3.731363
HnLinearTS,3.362073,3.695375,2.473884,2.579883,3.497068,3.69836,3.93878


In [15]:
# NLL per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_nll.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101,resnext-29_8x16,vgg-19,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,14.603166,14.712597,23.377213,35.042618,11.451286,17.796827,21.331923
TS,0.835487,0.81557,1.000677,0.822038,1.199659,0.813459,0.905464
ETS,0.859196,0.832331,1.009053,0.839845,1.206897,0.830835,0.926263
MIR,0.84818,0.827204,1.021776,0.844593,1.195741,0.804593,0.94482
BTS,0.827012,0.808159,1.00396,0.818996,1.19256,0.787616,0.903716
PTS,inf,0.90239,1.057343,0.929688,inf,0.83603,0.940038
LinearTS,0.836219,0.816024,1.008558,0.829108,1.178349,0.806298,0.898442
HTS,0.82537,0.806239,1.000181,0.812858,1.194119,0.787081,0.899245
HnLinearTS,0.821233,0.802643,1.006788,0.81733,1.17748,0.778204,0.88498


In [16]:
# Brier score per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_bri.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101,resnext-29_8x16,vgg-19,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,0.317058,0.314196,0.405327,0.327484,0.443283,0.289225,0.370023
TS,0.304773,0.30167,0.381723,0.309558,0.391811,0.288609,0.327633
ETS,0.304354,0.301037,0.381624,0.309294,0.391021,0.287947,0.326992
MIR,0.303697,0.299443,0.385367,0.310697,0.390587,0.285278,0.328069
BTS,0.305231,0.301962,0.38247,0.309646,0.388023,0.287132,0.329314
PTS,0.315193,0.31306,0.390749,0.321346,0.39198,0.291066,0.33311
LinearTS,0.304155,0.299977,0.382248,0.309996,0.386529,0.285156,0.324495
HTS,0.304919,0.301601,0.381687,0.308996,0.389495,0.287152,0.328604
HnLinearTS,0.303477,0.299484,0.382036,0.309261,0.387019,0.283941,0.324299


### Cars

In [17]:
# Cars slice of each metric table, indexed by model name. Every table is
# filtered with its own `Dataset` column rather than the mask built from ECE,
# so the selection does not rely on the three CSVs sharing the same row order.
curr_ece = ECE.loc[ECE.Dataset == 'cars'].set_index('Model')
curr_nll = NLL.loc[NLL.Dataset == 'cars'].set_index('Model')
curr_bri = Bri.loc[Bri.Dataset == 'cars'].set_index('Model')

# All three metrics side by side under a top-level 'Metric' column label.
pd.concat([curr_ece, curr_nll, curr_bri], axis=1, keys=['ECE', 'NLL', 'Brier'], names=['Metric', 'Model'])

Metric,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,...,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier
Model,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS,...,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
densenet-121,cars,5.865755,2.34257,1.912941,2.051523,2.160813,12.009591,4.276551,1.557261,4.213984,...,cars,0.17345,0.164603,0.165007,0.166828,0.166408,0.230406,0.177202,0.164132,0.176487
densenet-169,cars,5.821672,2.380512,1.951687,1.788371,2.568568,13.447186,4.380754,1.760003,4.631363,...,cars,0.172543,0.163421,0.162912,0.166504,0.167407,0.239733,0.178926,0.162921,0.179122
resnet-18,cars,7.038811,1.874214,2.357757,1.657376,2.808024,14.882453,4.694481,1.762694,4.879349,...,cars,0.207543,0.194467,0.194553,0.198104,0.199191,0.283024,0.208075,0.194221,0.208159
resnet-50,cars,5.190266,2.481778,1.506835,1.194913,2.607814,11.154547,4.557648,1.758101,4.559414,...,cars,0.160202,0.154095,0.153287,0.154946,0.156918,0.210636,0.167519,0.153273,0.166631
resnet-101,cars,5.400418,2.312753,1.78631,1.389031,2.133802,12.088839,4.205114,1.630611,4.307877,...,cars,0.16174,0.15365,0.15304,0.155587,0.157109,0.221673,0.166742,0.152951,0.167044


In [18]:
# ECE per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_ece.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-18,resnet-50,resnet-101
Uncalibrated,5.865755,5.821672,7.038811,5.190266,5.400418
TS,2.34257,2.380512,1.874214,2.481778,2.312753
ETS,1.912941,1.951687,2.357757,1.506835,1.78631
MIR,2.051523,1.788371,1.657376,1.194913,1.389031
BTS,2.160813,2.568568,2.808024,2.607814,2.133802
PTS,12.009591,13.447186,14.882453,11.154547,12.088839
LinearTS,4.276551,4.380754,4.694481,4.557648,4.205114
HTS,1.557261,1.760003,1.762694,1.758101,1.630611
HnLinearTS,4.213984,4.631363,4.879349,4.559414,4.307877


In [19]:
# NLL per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_nll.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-18,resnet-50,resnet-101
Uncalibrated,26.357357,28.58099,38.234886,27.692736,33.542385
TS,0.417259,0.416619,0.492074,0.390955,0.384694
ETS,0.443188,0.434737,0.50993,0.396151,0.395463
MIR,0.425034,0.430797,0.509265,0.403214,0.390006
BTS,inf,0.510133,0.659495,inf,inf
PTS,inf,inf,inf,inf,inf
LinearTS,0.597905,inf,0.727421,inf,inf
HTS,0.412039,0.410754,0.490105,0.384607,0.378864
HnLinearTS,0.575785,inf,0.739974,inf,inf


In [20]:
# Brier score per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_bri.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-18,resnet-50,resnet-101
Uncalibrated,0.17345,0.172543,0.207543,0.160202,0.16174
TS,0.164603,0.163421,0.194467,0.154095,0.15365
ETS,0.165007,0.162912,0.194553,0.153287,0.15304
MIR,0.166828,0.166504,0.198104,0.154946,0.155587
BTS,0.166408,0.167407,0.199191,0.156918,0.157109
PTS,0.230406,0.239733,0.283024,0.210636,0.221673
LinearTS,0.177202,0.178926,0.208075,0.167519,0.166742
HTS,0.164132,0.162921,0.194221,0.153273,0.152951
HnLinearTS,0.176487,0.179122,0.208159,0.166631,0.167044


### Birds

In [21]:
# Birds slice of each metric table, indexed by model name. Every table is
# filtered with its own `Dataset` column rather than the mask built from ECE,
# so the selection does not rely on the three CSVs sharing the same row order.
curr_ece = ECE.loc[ECE.Dataset == 'birds'].set_index('Model')
curr_nll = NLL.loc[NLL.Dataset == 'birds'].set_index('Model')
curr_bri = Bri.loc[Bri.Dataset == 'birds'].set_index('Model')

# All three metrics side by side under a top-level 'Metric' column label.
pd.concat([curr_ece, curr_nll, curr_bri], axis=1, keys=['ECE', 'NLL', 'Brier'], names=['Metric', 'Model'])

Metric,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,...,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier
Model,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS,...,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
densenet-121,birds,12.429901,2.940005,3.145051,2.988313,3.561318,21.169717,7.818938,3.191521,8.673096,...,birds,0.352798,0.323056,0.323294,0.327902,0.329046,0.441787,0.356361,0.323032,0.35883
densenet-169,birds,12.649702,3.089718,3.098617,3.098784,4.54902,22.374788,9.200208,2.521765,10.609864,...,birds,0.347038,0.314235,0.314541,0.321847,0.324452,0.447073,0.350503,0.314137,0.356461
resnet-101,birds,12.642531,2.83746,3.252857,2.667781,3.864773,20.754587,8.692656,2.809471,8.442265,...,birds,0.337751,0.305229,0.305641,0.310696,0.311545,0.426614,0.340839,0.305312,0.335553


In [22]:
# ECE per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_ece.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101
Uncalibrated,12.429901,12.649702,12.642531
TS,2.940005,3.089718,2.83746
ETS,3.145051,3.098617,3.252857
MIR,2.988313,3.098784,2.667781
BTS,3.561318,4.54902,3.864773
PTS,21.169717,22.374788,20.754587
LinearTS,7.818938,9.200208,8.692656
HTS,3.191521,2.521765,2.809471
HnLinearTS,8.673096,10.609864,8.442265


In [23]:
# NLL per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_nll.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101
Uncalibrated,34.11092,29.486126,41.483795
TS,0.900072,0.867165,0.837691
ETS,0.901822,0.867915,0.839361
MIR,0.920553,0.89365,0.861558
BTS,inf,1.269112,1.108864
PTS,inf,inf,inf
LinearTS,inf,inf,inf
HTS,0.897414,0.863451,0.832567
HnLinearTS,inf,inf,inf


In [24]:
# Brier score per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_bri.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,resnet-101
Uncalibrated,0.352798,0.347038,0.337751
TS,0.323056,0.314235,0.305229
ETS,0.323294,0.314541,0.305641
MIR,0.327902,0.321847,0.310696
BTS,0.329046,0.324452,0.311545
PTS,0.441787,0.447073,0.426614
LinearTS,0.356361,0.350503,0.340839
HTS,0.323032,0.314137,0.305312
HnLinearTS,0.35883,0.356461,0.335553


### SVHN

In [25]:
# SVHN slice of each metric table, indexed by model name. Every table is
# filtered with its own `Dataset` column rather than the mask built from ECE,
# so the selection does not rely on the three CSVs sharing the same row order.
curr_ece = ECE.loc[ECE.Dataset == 'svhn'].set_index('Model')
curr_nll = NLL.loc[NLL.Dataset == 'svhn'].set_index('Model')
curr_bri = Bri.loc[Bri.Dataset == 'svhn'].set_index('Model')

# All three metrics side by side under a top-level 'Metric' column label.
pd.concat([curr_ece, curr_nll, curr_bri], axis=1, keys=['ECE', 'NLL', 'Brier'], names=['Metric', 'Model'])

Metric,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,ECE,...,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier,Brier
Model,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS,...,Dataset,Uncalibrated,TS,ETS,MIR,BTS,PTS,LinearTS,HTS,HnLinearTS
Model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
densenet-121,svhn,2.0588,1.44586,2.573994,1.04385,0.7238,0.875031,1.541808,0.976599,0.987368,...,svhn,0.054979,0.052406,0.052608,0.052778,0.052084,0.053008,0.052463,0.052129,0.052118
densenet-169,svhn,0.491083,1.037108,1.156383,1.096983,0.933284,0.864652,1.020738,0.980455,0.999451,...,svhn,0.051542,0.052033,0.052127,0.052481,0.052088,0.052393,0.052086,0.051822,0.05191
wide-resnet-28x10,svhn,1.553618,1.083095,1.589714,1.14695,1.096086,0.877172,1.049163,1.08699,1.040062,...,svhn,0.054144,0.052815,0.053107,0.053342,0.053086,0.053571,0.05261,0.052786,0.052653
wide-resnet-40x10,svhn,1.330812,1.276223,2.535067,1.201622,0.880731,0.757974,1.298697,1.085646,1.006949,...,svhn,0.049306,0.048192,0.048666,0.04879,0.048211,0.049079,0.048141,0.048096,0.048017


In [26]:
# ECE per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_ece.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,2.0588,0.491083,1.553618,1.330812
TS,1.44586,1.037108,1.083095,1.276223
ETS,2.573994,1.156383,1.589714,2.535067
MIR,1.04385,1.096983,1.14695,1.201622
BTS,0.7238,0.933284,1.096086,0.880731
PTS,0.875031,0.864652,0.877172,0.757974
LinearTS,1.541808,1.020738,1.049163,1.298697
HTS,0.976599,0.980455,1.08699,1.085646
HnLinearTS,0.987368,0.999451,1.040062,1.006949


In [27]:
# NLL per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_nll.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,10.556016,11.286348,11.52866,10.728325
TS,0.141412,0.12926,0.134269,0.126885
ETS,0.151157,0.129777,0.13757,0.137088
MIR,0.140835,0.129976,0.135055,0.126872
BTS,0.168135,0.151598,0.144884,0.136994
PTS,inf,0.131866,0.136933,inf
LinearTS,0.14158,0.1293,0.132851,0.125792
HTS,0.14017,0.12862,0.13405,0.1271
HnLinearTS,0.139976,0.128676,0.133045,0.125753


In [28]:
# Brier score per model: methods as rows, models as columns; the lowest value in each column is bold.
curr_bri.drop(columns='Dataset').T.style.apply(highlight_min, axis=0)

Model,densenet-121,densenet-169,wide-resnet-28x10,wide-resnet-40x10
Uncalibrated,0.054979,0.051542,0.054144,0.049306
TS,0.052406,0.052033,0.052815,0.048192
ETS,0.052608,0.052127,0.053107,0.048666
MIR,0.052778,0.052481,0.053342,0.04879
BTS,0.052084,0.052088,0.053086,0.048211
PTS,0.053008,0.052393,0.053571,0.049079
LinearTS,0.052463,0.052086,0.05261,0.048141
HTS,0.052129,0.051822,0.052786,0.048096
HnLinearTS,0.052118,0.05191,0.052653,0.048017
