# Benchmark NICE as calibration method

In [1]:
%load_ext autoreload

In [2]:
%autoreload 1

In [3]:
import os
import sys
import time
import importlib
import collections
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.special import softmax
from sklearn.isotonic import IsotonicRegression
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Input

pd.set_option('colheader_justify', 'center')

%aimport utils
%aimport utils.ops
%aimport utils.metrics
%aimport utils.visualization
%aimport utils.data
%aimport flows.nice
%aimport flows.normalizing_flows
%aimport calibrators
from utils.ops import onehot_encode, optim_temperature, detection_log_likelihood_ratios
from utils.metrics import neg_log_likelihood, accuracy, expected_calibration_error
from utils.visualization import plot_pdf_simplex, plot_prob_simplex, reliability_plot, ECE_plot
from utils.data import get_cifar10
from flows.nice import NiceFlow
from calibrators import PAVCalibrator, NiceCalibrator, TempScalingCalibrator, MLRCalibrator, PlanarFlowCalibrator, DummyCalibrator, RadialFlowCalibrator

In [4]:
def highlight_max(s):
    is_max = s == s.max()
    return ['font-weight: bold' if v else '' for v in is_max]

In [5]:
def highlight_min(s):
    is_min = s == s.min()
    return ['font-weight: bold' if v else '' for v in is_min]

## CIFAR-100

In [6]:
models = [
    'wide-resnet-28x10',
    'densenet-121',
    'densenet-169',
    'resnet-101',
    'vgg-19',
    'preactresnet-18',
    'preactresnet-164',
    'resnext-29_8x16',
    'wide-resnet-40x10',
]

models = [
    'wide-resnet-28x10',
    'densenet-121',
    'resnet-101',
    'vgg-19',
    'resnext-29_8x16',
]

In [7]:
def score(calibrator, logits, target):
    probs = calibrator.predict(logits)
    nll = neg_log_likelihood(probs, target)
    ece = expected_calibration_error(probs, target, bins=15)
    acc = accuracy(probs, target)
    
    metrics = {
        'NLL': nll,
        'ECE': ece,
        'Accuracy': acc,
    }
    
    return metrics

In [8]:
def train_calibrator(Calibrator, logits, target):
    cal = Calibrator(logits, target)
    return cal

In [9]:
def train_and_evaluate_calibrators(logits, target, test_logits, test_target, calibrators, **kwargs):
    ## Train NICE on test set to obtain NLLmin
    nice_ref_cal = NiceCalibrator(test_logits, test_target, **kwargs['nice_ref_args'])
    ref_results = score(nice_ref_cal, test_logits, test_target)

    results = collections.OrderedDict()
    for cal, Calibrator in calibrators.items():
        t0 = time.time()
        model = train_calibrator(Calibrator, logits, target)
        t1 = time.time() - t0
        print("Calibrator {} fitted in {:.2f}s".format(cal, t1))
        results[cal] = {'Training time': t1,
                        'Validation': score(model, logits, target),
                        'Test': score(model, test_logits, test_target)}
    
    return results, ref_results

In [10]:
results = collections.OrderedDict()
ref_results = collections.OrderedDict()

calibrators = {
    'Uncalibrated': DummyCalibrator,
    'Temp-Scaling': TempScalingCalibrator,
    'MLR': MLRCalibrator,
}

nice_args = [
    {
        'layers': 2,
        'hidden_size': [2],
        'epochs': 500,
    }, {
        'layers': 2,
        'hidden_size': [5],
        'epochs': 500,
    }, {
        'layers': 4,
        'hidden_size': [100, 100],
        'epochs': 1000,
    },
]

planar_args = [
    {
        'layers': 5,
        'epochs': 2000,
    },{
        'layers': 10,
        'epochs': 2000,
    },{
        'layers': 20,
        'epochs': 2000,
    }
]

radial_args = [
    {
        'layers': 5,
        'epochs': 2000,
    },{
        'layers': 10,
        'epochs': 2000,
    },{
        'layers': 20,
        'epochs': 2000,
    }
]
    

for nice in nice_args:
    name = 'NICE_l{}_hs{}'.format(nice['layers'], nice['hidden_size'])
    calibrators[name] = lambda logits, target: NiceCalibrator(logits, target, **nice)

for planar in planar_args:
    name = 'Planar_l{}'.format(planar['layers'])
    calibrators[name] = lambda logits, target: PlanarFlowCalibrator(logits, target, **planar)
    
for radial in radial_args:
    name = 'Radial_l{}'.format(radial['layers'])
    calibrators[name] = lambda logits, target: RadialFlowCalibrator(logits, target, **radial)
    

nice_ref_args = {
    'layers': 4,
    'hidden_size': [100, 100],
    'epochs': 1000,
}

kwargs = {
    'nice_ref_args': nice_ref_args,
}

for model in models:
    print("Calibrating model: {}".format(model))
    data_path = os.path.join('../data', model+'_cifar100')
    prefix = os.path.join(data_path, 'cifar100_'+model)
    
    logits = np.load(prefix + '_logit_prediction_valid.npy')
    test_logits = np.load(prefix + '_logit_prediction_test.npy')
    
    target = np.load(prefix + '_true_valid.npy')
    test_target = np.load(prefix + '_true_test.npy')

    results[model], ref_results[model] = train_and_evaluate_calibrators(logits, 
                                                                        target, 
                                                                        test_logits, 
                                                                        test_target, 
                                                                        calibrators,
                                                                        **kwargs)
    

Calibrating model: wide-resnet-28x10
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.22s
Calibrator MLR fitted in 0.89s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 200.43s
Calibrator NICE_l2_hs[5] fitted in 202.42s
Calibrator NICE_l4_hs[100, 100] fitted in 205.25s
Calibrator Planar_l5 fitted in 857.10s
Calibrator Planar_l10 fitted in 874.91s
Calibrator Planar_l20 fitted in 895.35s
Calibrator Radial_l5 fitted in 972.56s
Calibrator Radial_l10 fitted in 980.99s
Calibrator Radial_l20 fitted in 996.60s
Calibrating model: densenet-121
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.79s
Calibrator MLR fitted in 1.24s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 258.57s
Calibrator NICE_l2_hs[5] fitted in 263.00s
Calibrator NICE_l4_hs[100, 100] fitted in 265.08s
Calibrator Planar_l5 fitted in 978.71s
Calibrator Planar_l10 fitted in 996.17s
Calibrator Planar_l20 fitted in 1012.81s
Calibrator Radial_l5 fitted in 1072.62s
Calibrator Radial_l10 fitted in 1086.77s
Calibrator Radial_l20 fitted in 1098.64s
Calibrating model: resnet-101
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.70s
Calibrator MLR fitted in 1.56s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 326.68s
Calibrator NICE_l2_hs[5] fitted in 330.33s
Calibrator NICE_l4_hs[100, 100] fitted in 337.86s
Calibrator Planar_l5 fitted in 1154.37s
Calibrator Planar_l10 fitted in 1172.33s
Calibrator Planar_l20 fitted in 1197.01s
Calibrator Radial_l5 fitted in 1192.57s
Calibrator Radial_l10 fitted in 1199.76s
Calibrator Radial_l20 fitted in 1215.53s
Calibrating model: vgg-19
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.59s
Calibrator MLR fitted in 1.15s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 397.86s
Calibrator NICE_l2_hs[5] fitted in 401.54s
Calibrator NICE_l4_hs[100, 100] fitted in 404.91s
Calibrator Planar_l5 fitted in 1337.63s
Calibrator Planar_l10 fitted in 1364.52s
Calibrator Planar_l20 fitted in 1376.36s
Calibrator Radial_l5 fitted in 1289.49s
Calibrator Radial_l10 fitted in 1301.51s
Calibrator Radial_l20 fitted in 1318.75s
Calibrating model: resnext-29_8x16
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.73s
Calibrator MLR fitted in 1.46s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 470.43s
Calibrator NICE_l2_hs[5] fitted in 473.97s
Calibrator NICE_l4_hs[100, 100] fitted in 478.83s
Calibrator Planar_l5 fitted in 1511.57s
Calibrator Planar_l10 fitted in 1528.09s
Calibrator Planar_l20 fitted in 1549.40s
Calibrator Radial_l5 fitted in 1402.44s
Calibrator Radial_l10 fitted in 1417.76s
Calibrator Radial_l20 fitted in 1430.37s


In [11]:
ece_val_results = {}
ece_test_results = {}

acc_val_results = {}
acc_test_results = {}

nll_val_results = {}
nll_test_results = {}

for model, model_results in results.items():
    ece_val_results[model] = {}
    ece_test_results[model] = {}
    acc_val_results[model] = {}
    acc_test_results[model] = {}
    nll_val_results[model] = {}
    nll_test_results[model] = {}
    for cal, cal_results in model_results.items():
        ece_test_results[model][cal] = cal_results['Test']['ECE']
        ece_val_results[model][cal] = cal_results['Validation']['ECE']
        
        acc_test_results[model][cal] = cal_results['Test']['Accuracy']
        acc_val_results[model][cal] = cal_results['Validation']['Accuracy']
        
        nll_test_results[model][cal] = cal_results['Test']['NLL']
        nll_val_results[model][cal] = cal_results['Validation']['NLL']

**Results on the test set:**

In [12]:
df = pd.concat([pd.DataFrame.from_dict(acc_test_results, orient='columns'),
                pd.DataFrame.from_dict(ece_test_results, orient='columns'),
                pd.DataFrame.from_dict(nll_test_results, orient='columns')],
               axis=1,keys=['ACC','ECE', 'NLL']).swaplevel(0,1,axis=1).sort_index(axis=1)

df = df[models].loc[list(calibrators.keys())]
df.style.set_properties(**{'text-align': 'center'})\
    .set_caption('CIFAR100 Test set')\
    .apply(highlight_max, subset=[(model, 'ACC') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'ECE') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'NLL') for model in df.columns.levels[0]])\
    .set_table_styles([dict(selector="th", props=[("text-align", "center")]),
                       dict(selector="caption", props=[("text-align", "center"),
                                                      ("font-size", "200%"),
                                                      ("color", "black")])])

Unnamed: 0_level_0,wide-resnet-28x10,wide-resnet-28x10,wide-resnet-28x10,densenet-121,densenet-121,densenet-121,resnet-101,resnet-101,resnet-101,vgg-19,vgg-19,vgg-19,resnext-29_8x16,resnext-29_8x16,resnext-29_8x16
Unnamed: 0_level_1,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL
Uncalibrated,0.8039,0.0485303,0.817315,0.788,0.0872402,0.893708,0.72,0.114132,1.1341,0.727,0.176313,1.54045,0.7788,0.0967838,0.938889
Temp-Scaling,0.8039,0.0428475,0.813448,0.788,0.0352327,0.83548,0.72,0.0150577,1.00067,0.727,0.0480832,1.19965,0.7788,0.0281907,0.822031
MLR,0.8041,0.0422879,0.801043,0.7859,0.0356487,0.837193,0.7196,0.0213178,1.0025,0.7245,0.0394123,1.20897,0.7754,0.0274381,0.824304
NICE_l2_hs[2],0.7318,0.235007,3.1209,0.7264,0.20312,2.19934,0.6231,0.275744,2.96796,0.6429,0.313602,4.43316,0.6556,0.2923,3.72466
NICE_l2_hs[5],0.7344,0.232594,3.12112,0.7017,0.259215,3.39011,0.598,0.343734,4.52631,0.6451,0.310299,4.40277,0.6638,0.284071,3.61833
"NICE_l4_hs[100, 100]",0.736,0.23129,3.12681,0.7118,0.250039,3.3034,0.6285,0.265682,2.72584,0.6469,0.31115,4.40665,0.6597,0.289329,3.64702
Planar_l5,0.7142,0.145612,1.56565,0.7,0.152762,1.62687,0.6072,0.208931,1.99647,0.6599,0.135866,1.82431,0.6688,0.190314,1.84781
Planar_l10,0.7269,0.136137,1.48058,0.6952,0.157786,1.68608,0.61,0.206482,1.97574,0.6638,0.131986,1.79552,0.6692,0.191464,1.83067
Planar_l20,0.7239,0.138433,1.50653,0.6975,0.156408,1.69207,0.615,0.200117,1.96602,0.6644,0.129499,1.79204,0.666,0.194069,1.87198
Radial_l5,0.8004,0.0410493,0.795275,0.7856,0.0343245,0.821066,0.7174,0.029446,1.01191,0.7246,0.0368086,1.1883,0.7741,0.0276305,0.825373


In [13]:
nll_df = pd.DataFrame.from_dict(nll_test_results, orient='columns')
min_nll_df = pd.DataFrame.from_dict(ref_results, orient='columns').loc['NLL']
nll_cal_df = nll_df.subtract(min_nll_df, axis='columns')

df = pd.concat([nll_df, nll_cal_df], axis=1,keys=['NLL','NLL_cal']).swaplevel(0,1,axis=1).sort_index(axis=1)
df = df[models].loc[list(calibrators.keys())]

df.rename(columns={model: model + ' NLL_min={:.4f}'.format(min_nll_df[model]) 
                   for model in df.columns.levels[0]}, level=0, inplace=True)

df.style.set_properties(**{'text-align': 'center'})\
    .format("{:.4f}")\
    .set_caption('CIFAR100 NLL decomposition')\
    .apply(highlight_min)\
    .set_table_styles([dict(selector="th", props=[("text-align", "center")]),
                       dict(selector="caption", props=[("text-align", "center"),
                                                      ("font-size", "200%"),
                                                      ("color", "black")])])

Unnamed: 0_level_0,wide-resnet-28x10 NLL_min=0.0006,wide-resnet-28x10 NLL_min=0.0006,densenet-121 NLL_min=0.0026,densenet-121 NLL_min=0.0026,resnet-101 NLL_min=0.0025,resnet-101 NLL_min=0.0025,vgg-19 NLL_min=0.0033,vgg-19 NLL_min=0.0033,resnext-29_8x16 NLL_min=0.0026,resnext-29_8x16 NLL_min=0.0026
Unnamed: 0_level_1,NLL,NLL_cal,NLL,NLL_cal,NLL,NLL_cal,NLL,NLL_cal,NLL,NLL_cal
Uncalibrated,0.8173,0.8167,0.8937,0.8911,1.1341,1.1316,1.5404,1.5371,0.9389,0.9362
Temp-Scaling,0.8134,0.8128,0.8355,0.8328,1.0007,0.9981,1.1997,1.1963,0.822,0.8194
MLR,0.801,0.8004,0.8372,0.8346,1.0025,1.0,1.209,1.2056,0.8243,0.8217
NICE_l2_hs[2],3.1209,3.1203,2.1993,2.1967,2.968,2.9654,4.4332,4.4298,3.7247,3.722
NICE_l2_hs[5],3.1211,3.1205,3.3901,3.3875,4.5263,4.5238,4.4028,4.3995,3.6183,3.6157
"NICE_l4_hs[100, 100]",3.1268,3.1262,3.3034,3.3008,2.7258,2.7233,4.4066,4.4033,3.647,3.6444
Planar_l5,1.5656,1.565,1.6269,1.6242,1.9965,1.9939,1.8243,1.821,1.8478,1.8452
Planar_l10,1.4806,1.48,1.6861,1.6834,1.9757,1.9732,1.7955,1.7922,1.8307,1.828
Planar_l20,1.5065,1.5059,1.6921,1.6894,1.966,1.9635,1.792,1.7887,1.872,1.8693
Radial_l5,0.7953,0.7947,0.8211,0.8184,1.0119,1.0094,1.1883,1.185,0.8254,0.8227


**Results on the validation set:**

In [14]:
df = pd.concat([pd.DataFrame.from_dict(acc_val_results, orient='columns'),
                pd.DataFrame.from_dict(ece_val_results, orient='columns'),
                pd.DataFrame.from_dict(nll_val_results, orient='columns')],
               axis=1,keys=['ACC','ECE', 'NLL']).swaplevel(0,1,axis=1).sort_index(axis=1)

df = df[models].loc[list(calibrators.keys())]
df.style.set_properties(**{'text-align': 'center'})\
    .set_caption('CIFAR100 Validation set')\
    .apply(highlight_max, subset=[(model, 'ACC') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'ECE') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'NLL') for model in df.columns.levels[0]])\
    .set_table_styles([dict(selector="th", props=[("text-align", "center")]),
                       dict(selector="caption", props=[("text-align", "center"),
                                                      ("font-size", "200%"),
                                                      ("color", "black")])])

Unnamed: 0_level_0,wide-resnet-28x10,wide-resnet-28x10,wide-resnet-28x10,densenet-121,densenet-121,densenet-121,resnet-101,resnet-101,resnet-101,vgg-19,vgg-19,vgg-19,resnext-29_8x16,resnext-29_8x16,resnext-29_8x16
Unnamed: 0_level_1,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL
Uncalibrated,0.7994,0.053124,0.819501,0.7826,0.0887537,0.940488,0.7218,0.110705,1.13483,0.7162,0.189706,1.5726,0.775,0.101819,0.981964
Temp-Scaling,0.7994,0.0484992,0.815698,0.7826,0.0417907,0.873531,0.7218,0.0265887,1.00431,0.7162,0.0541274,1.2135,0.775,0.0292295,0.852484
MLR,0.8064,0.0437405,0.778041,0.784,0.0395841,0.852308,0.725,0.0208451,0.984427,0.7188,0.0483676,1.19587,0.7806,0.028018,0.829799
NICE_l2_hs[2],1.0,0.0,-1.19209e-07,0.9998,0.00031564,0.00373948,0.9998,0.000504596,0.00357627,1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07
NICE_l2_hs[5],1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07
"NICE_l4_hs[100, 100]",1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,1.0,0.00138545,0.0013884,1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07
Planar_l5,0.983,0.064075,0.116461,0.9634,0.0687487,0.175855,0.9528,0.0924162,0.24394,0.8692,0.0632886,0.497001,0.9826,0.0655583,0.124318
Planar_l10,0.9854,0.0665541,0.113894,0.991,0.0691636,0.104787,0.953,0.0965378,0.250392,0.8722,0.0691382,0.508222,0.992,0.0687405,0.10193
Planar_l20,0.9868,0.0655517,0.10743,0.9844,0.0664013,0.112976,0.9462,0.0917488,0.257541,0.8642,0.0615565,0.523635,0.9854,0.0623751,0.110922
Radial_l5,0.814,0.0297793,0.70409,0.7856,0.039598,0.829357,0.7252,0.030811,0.992632,0.7194,0.0540512,1.16677,0.7816,0.0303428,0.826352
