# Benchmark NICE as calibration method

In [1]:
%load_ext autoreload

In [2]:
%autoreload 1

In [3]:
import os
import sys
import time
import importlib
import collections
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.special import softmax
from sklearn.isotonic import IsotonicRegression
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Activation, Input

pd.set_option('colheader_justify', 'center')

%aimport utils
%aimport utils.ops
%aimport utils.metrics
%aimport utils.visualization
%aimport utils.data
%aimport flows.nice
%aimport flows.normalizing_flows
%aimport calibrators
from utils.ops import onehot_encode, optim_temperature, detection_log_likelihood_ratios
from utils.metrics import neg_log_likelihood, accuracy, expected_calibration_error
from utils.visualization import plot_pdf_simplex, plot_prob_simplex, reliability_plot, ECE_plot
from utils.data import get_cifar10
from flows.nice import NiceFlow
from calibrators import PAVCalibrator, NiceCalibrator, TempScalingCalibrator, MLRCalibrator, PlanarFlowCalibrator, DummyCalibrator, RadialFlowCalibrator

In [4]:
def highlight_max(s):
    is_max = s == s.max()
    return ['font-weight: bold' if v else '' for v in is_max]

In [5]:
def highlight_min(s):
    is_min = s == s.min()
    return ['font-weight: bold' if v else '' for v in is_min]

## CIFAR-100

In [6]:
models = [
    'wide-resnet-28x10',
    'densenet-121',
    'densenet-169',
    'resnet-101',
    'vgg-19',
    'preactresnet-18',
    'preactresnet-164',
    'resnext-29_8x16',
    'wide-resnet-40x10',
]

models = [
    'wide-resnet-28x10',
    'densenet-121',
    'resnet-101',
    'vgg-19',
    'resnext-29_8x16',
]

In [7]:
def score(calibrator, logits, target):
    probs = calibrator.predict(logits)
    nll = neg_log_likelihood(probs, target)
    ece = expected_calibration_error(probs, target, bins=15)
    acc = accuracy(probs, target)
    
    metrics = {
        'NLL': nll,
        'ECE': ece,
        'Accuracy': acc,
    }
    
    return metrics

In [8]:
def train_calibrator(Calibrator, logits, target):
    cal = Calibrator(logits, target)
    return cal

In [9]:
def train_and_evaluate_calibrators(logits, target, test_logits, test_target, calibrators, **kwargs):
    ## Train NICE on test set to obtain NLLmin
    nice_ref_cal = NiceCalibrator(test_logits, test_target, **kwargs['nice_ref_args'])
    ref_results = score(nice_ref_cal, test_logits, test_target)

    results = collections.OrderedDict()
    for cal, Calibrator in calibrators.items():
        t0 = time.time()
        model = train_calibrator(Calibrator, logits, target)
        t1 = time.time() - t0
        print("Calibrator {} fitted in {:.2f}s".format(cal, t1))
        results[cal] = {'Training time': t1,
                        'Validation': score(model, logits, target),
                        'Test': score(model, test_logits, test_target)}
    
    return results, ref_results

In [10]:
results = collections.OrderedDict()
ref_results = collections.OrderedDict()

calibrators = {
    'Uncalibrated': DummyCalibrator,
    'Temp-Scaling': TempScalingCalibrator,
    'MLR': MLRCalibrator,
}

nice_args = [
    {
        'layers': 2,
        'hidden_size': [2],
        'epochs': 500,
    }, {
        'layers': 2,
        'hidden_size': [5],
        'epochs': 500,
    }, {
        'layers': 4,
        'hidden_size': [100, 100],
        'epochs': 1000,
    },
]

planar_args = [
    {
        'layers': 5,
        'epochs': 1000,
    },{
        'layers': 10,
        'epochs': 1000,
    },{
        'layers': 20,
        'epochs': 1000,
    }
]

radial_args = [
    {
        'layers': 5,
        'epochs': 2000,
    },{
        'layers': 10,
        'epochs': 2000,
    },{
        'layers': 20,
        'epochs': 2000,
    }
]
    

for nice in nice_args:
    name = 'NICE_l{}_hs{}'.format(nice['layers'], nice['hidden_size'])
    calibrators[name] = lambda logits, target: NiceCalibrator(logits, target, **nice)

for planar in planar_args:
    name = 'Planar_l{}'.format(planar['layers'])
    calibrators[name] = lambda logits, target: PlanarFlowCalibrator(logits, target, **planar)
    
for radial in radial_args:
    name = 'Radial_l{}'.format(radial['layers'])
    calibrators[name] = lambda logits, target: RadialFlowCalibrator(logits, target, **radial)
    

nice_ref_args = {
    'layers': 4,
    'hidden_size': [100, 100],
    'epochs': 1000,
}

kwargs = {
    'nice_ref_args': nice_ref_args,
}

for model in models:
    print("Calibrating model: {}".format(model))
    data_path = os.path.join('../data', model+'_cifar100')
    prefix = os.path.join(data_path, 'cifar100_'+model)
    
    logits = np.load(prefix + '_logit_prediction_valid.npy')
    test_logits = np.load(prefix + '_logit_prediction_test.npy')
    
    target = np.load(prefix + '_true_valid.npy')
    test_target = np.load(prefix + '_true_test.npy')

    results[model], ref_results[model] = train_and_evaluate_calibrators(logits, 
                                                                        target, 
                                                                        test_logits, 
                                                                        test_target, 
                                                                        calibrators,
                                                                        **kwargs)
    

Calibrating model: wide-resnet-28x10
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.25s
Calibrator MLR fitted in 0.93s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 199.89s
Calibrator NICE_l2_hs[5] fitted in 202.73s
Calibrator NICE_l4_hs[100, 100] fitted in 205.49s
Calibrator Planar_l5 fitted in 432.23s
Calibrator Planar_l10 fitted in 442.64s
Calibrator Planar_l20 fitted in 452.01s
Calibrator Radial_l20 fitted in 972.17s
Calibrating model: densenet-121
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.81s
Calibrator MLR fitted in 1.26s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 245.52s
Calibrator NICE_l2_hs[5] fitted in 247.67s
Calibrator NICE_l4_hs[100, 100] fitted in 253.22s
Calibrator Planar_l5 fitted in 481.31s
Calibrator Planar_l10 fitted in 489.19s
Calibrator Planar_l20 fitted in 497.81s
Calibrator Radial_l20 fitted in 1045.71s
Calibrating model: resnet-101
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.71s
Calibrator MLR fitted in 1.59s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 293.94s
Calibrator NICE_l2_hs[5] fitted in 297.78s
Calibrator NICE_l4_hs[100, 100] fitted in 301.55s
Calibrator Planar_l5 fitted in 539.81s
Calibrator Planar_l10 fitted in 548.68s
Calibrator Planar_l20 fitted in 568.17s
Calibrator Radial_l20 fitted in 1124.04s
Calibrating model: vgg-19
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.64s
Calibrator MLR fitted in 1.21s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 346.60s
Calibrator NICE_l2_hs[5] fitted in 349.60s
Calibrator NICE_l4_hs[100, 100] fitted in 353.12s
Calibrator Planar_l5 fitted in 611.26s
Calibrator Planar_l10 fitted in 628.57s
Calibrator Planar_l20 fitted in 633.44s
Calibrator Radial_l20 fitted in 1209.54s
Calibrating model: resnext-29_8x16
Calibrator Uncalibrated fitted in 0.00s
Calibrator Temp-Scaling fitted in 0.76s
Calibrator MLR fitted in 1.48s


  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


Calibrator NICE_l2_hs[2] fitted in 398.74s
Calibrator NICE_l2_hs[5] fitted in 404.49s
Calibrator NICE_l4_hs[100, 100] fitted in 407.80s
Calibrator Planar_l5 fitted in 684.55s
Calibrator Planar_l10 fitted in 695.81s
Calibrator Planar_l20 fitted in 708.64s
Calibrator Radial_l20 fitted in 1292.41s


In [15]:
ece_val_results = {}
ece_test_results = {}

acc_val_results = {}
acc_test_results = {}

nll_val_results = {}
nll_test_results = {}

for model, model_results in results.items():
    ece_val_results[model] = {}
    ece_test_results[model] = {}
    acc_val_results[model] = {}
    acc_test_results[model] = {}
    nll_val_results[model] = {}
    nll_test_results[model] = {}
    for cal, cal_results in model_results.items():
        ece_test_results[model][cal] = cal_results['Test']['ECE']
        ece_val_results[model][cal] = cal_results['Validation']['ECE']
        
        acc_test_results[model][cal] = cal_results['Test']['Accuracy']
        acc_val_results[model][cal] = cal_results['Validation']['Accuracy']
        
        nll_test_results[model][cal] = cal_results['Test']['NLL']
        nll_val_results[model][cal] = cal_results['Validation']['NLL']

**Results on the test set:**

In [30]:
df = pd.concat([pd.DataFrame.from_dict(acc_test_results, orient='columns'),
                pd.DataFrame.from_dict(ece_test_results, orient='columns'),
                pd.DataFrame.from_dict(nll_test_results, orient='columns')],
               axis=1,keys=['ACC','ECE', 'NLL']).swaplevel(0,1,axis=1).sort_index(axis=1)

df = df[models].loc[list(calibrators.keys())]
df.style.set_properties(**{'text-align': 'center'})\
    .set_caption('CIFAR100 Test set')\
    .apply(highlight_max, subset=[(model, 'ACC') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'ECE') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'NLL') for model in df.columns.levels[0]])\
    .set_table_styles([dict(selector="th", props=[("text-align", "center")]),
                       dict(selector="caption", props=[("text-align", "center"),
                                                      ("font-size", "200%"),
                                                      ("color", "black")])])

Unnamed: 0_level_0,wide-resnet-28x10,wide-resnet-28x10,wide-resnet-28x10,densenet-121,densenet-121,densenet-121,resnet-101,resnet-101,resnet-101,vgg-19,vgg-19,vgg-19,resnext-29_8x16,resnext-29_8x16,resnext-29_8x16
Unnamed: 0_level_1,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL
Uncalibrated,0.8039,0.0485303,0.817315,0.788,0.0872402,0.893708,0.72,0.114132,1.1341,0.727,0.176313,1.54045,0.7788,0.0967838,0.938889
Temp-Scaling,0.8039,0.0428475,0.813448,0.788,0.0352327,0.83548,0.72,0.0150577,1.00067,0.727,0.0480832,1.19965,0.7788,0.0281907,0.822031
MLR,0.8041,0.0422879,0.801043,0.7859,0.0356487,0.837193,0.7196,0.0213178,1.0025,0.7245,0.0394123,1.20897,0.7754,0.0274381,0.824304
NICE_l2_hs[2],0.7304,0.237901,3.19681,0.7093,0.252102,3.36619,0.599,0.338926,4.39794,0.6507,0.275675,3.35686,0.6611,0.287655,3.6687
NICE_l2_hs[5],0.7359,0.231491,3.06784,0.7102,0.253259,3.39388,0.6442,0.223539,2.1221,0.6459,0.305994,4.20209,0.6506,0.296666,3.73448
"NICE_l4_hs[100, 100]",0.7342,0.233021,3.08131,0.7031,0.259213,3.43177,0.5974,0.343827,4.52525,0.6418,0.317228,4.45487,0.6553,0.293781,3.78128
Planar_l5,0.7277,0.128492,1.39895,0.7088,0.136383,1.50458,0.6263,0.183335,1.7942,0.6669,0.124055,1.72924,0.6796,0.17589,1.69595
Planar_l10,0.7286,0.126464,1.39511,0.7012,0.140918,1.54297,0.6228,0.184793,1.82116,0.6632,0.124014,1.73947,0.6727,0.181377,1.75127
Planar_l20,0.7286,0.128713,1.41723,0.7046,0.142966,1.53171,0.6245,0.188896,1.85539,0.6639,0.127813,1.74236,0.6789,0.17947,1.71025
Radial_l20,0.8026,0.0355248,0.768442,0.7825,0.0359166,0.818158,0.7168,0.0302682,1.01191,0.7256,0.0373307,1.1875,0.7742,0.0285126,0.826861


In [17]:
nll_df = pd.DataFrame.from_dict(nll_test_results, orient='columns')
min_nll_df = pd.DataFrame.from_dict(ref_results, orient='columns').loc['NLL']
nll_cal_df = nll_df.subtract(min_nll_df, axis='columns')

df = pd.concat([nll_df, nll_cal_df], axis=1,keys=['NLL','NLL_cal']).swaplevel(0,1,axis=1).sort_index(axis=1)
df = df[models].loc[list(calibrators.keys())]

df.rename(columns={model: model + ' NLL_min={:.4f}'.format(min_nll_df[model]) 
                   for model in df.columns.levels[0]}, level=0, inplace=True)

df.style.set_properties(**{'text-align': 'center'})\
    .format("{:.4f}")\
    .set_caption('CIFAR100 NLL decomposition')\
    .apply(highlight_min)\
    .set_table_styles([dict(selector="th", props=[("text-align", "center")]),
                       dict(selector="caption", props=[("text-align", "center"),
                                                      ("font-size", "200%"),
                                                      ("color", "black")])])

Unnamed: 0_level_0,wide-resnet-28x10 NLL_min=0.0008,wide-resnet-28x10 NLL_min=0.0008,densenet-121 NLL_min=0.0027,densenet-121 NLL_min=0.0027,resnet-101 NLL_min=0.0008,resnet-101 NLL_min=0.0008,vgg-19 NLL_min=0.0031,vgg-19 NLL_min=0.0031,resnext-29_8x16 NLL_min=0.0036,resnext-29_8x16 NLL_min=0.0036
Unnamed: 0_level_1,NLL,NLL_cal,NLL,NLL_cal,NLL,NLL_cal,NLL,NLL_cal,NLL,NLL_cal
Uncalibrated,0.8173,0.8166,0.8937,0.891,1.1341,1.1333,1.5404,1.5373,0.9389,0.9353
Temp-Scaling,0.8134,0.8127,0.8355,0.8328,1.0007,0.9999,1.1997,1.1965,0.822,0.8184
MLR,0.801,0.8003,0.8372,0.8345,1.0025,1.0017,1.209,1.2059,0.8243,0.8207
NICE_l2_hs[2],3.1968,3.1961,3.3662,3.3635,4.3979,4.3971,3.3569,3.3537,3.6687,3.6651
NICE_l2_hs[5],3.0678,3.0671,3.3939,3.3912,2.1221,2.1213,4.2021,4.199,3.7345,3.7309
"NICE_l4_hs[100, 100]",3.0813,3.0806,3.4318,3.4291,4.5253,4.5245,4.4549,4.4518,3.7813,3.7777
Planar_l5,1.399,1.3982,1.5046,1.5019,1.7942,1.7934,1.7292,1.7261,1.696,1.6923
Planar_l10,1.3951,1.3944,1.543,1.5403,1.8212,1.8204,1.7395,1.7364,1.7513,1.7477
Planar_l20,1.4172,1.4165,1.5317,1.529,1.8554,1.8546,1.7424,1.7392,1.7103,1.7066
Radial_l20,0.7684,0.7677,0.8182,0.8155,1.0119,1.0111,1.1875,1.1844,0.8269,0.8233


**Results on the validation set:**

In [29]:
df = pd.concat([pd.DataFrame.from_dict(acc_val_results, orient='columns'),
                pd.DataFrame.from_dict(ece_val_results, orient='columns'),
                pd.DataFrame.from_dict(nll_val_results, orient='columns')],
               axis=1,keys=['ACC','ECE', 'NLL']).swaplevel(0,1,axis=1).sort_index(axis=1)

df = df[models].loc[list(calibrators.keys())]
df.style.set_properties(**{'text-align': 'center'})\
    .set_caption('CIFAR100 Validation set')\
    .apply(highlight_max, subset=[(model, 'ACC') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'ECE') for model in df.columns.levels[0]])\
    .apply(highlight_min, subset=[(model, 'NLL') for model in df.columns.levels[0]])\
    .set_table_styles([dict(selector="th", props=[("text-align", "center")]),
                       dict(selector="caption", props=[("text-align", "center"),
                                                      ("font-size", "200%"),
                                                      ("color", "black")])])

Unnamed: 0_level_0,wide-resnet-28x10,wide-resnet-28x10,wide-resnet-28x10,densenet-121,densenet-121,densenet-121,resnet-101,resnet-101,resnet-101,vgg-19,vgg-19,vgg-19,resnext-29_8x16,resnext-29_8x16,resnext-29_8x16
Unnamed: 0_level_1,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL,ACC,ECE,NLL
Uncalibrated,0.7994,0.053124,0.819501,0.7826,0.0887537,0.940488,0.7218,0.110705,1.13483,0.7162,0.189706,1.5726,0.775,0.101819,0.981964
Temp-Scaling,0.7994,0.0484992,0.815698,0.7826,0.0417907,0.873531,0.7218,0.0265887,1.00431,0.7162,0.0541274,1.2135,0.775,0.0292295,0.852484
MLR,0.8064,0.0437405,0.778041,0.784,0.0395841,0.852308,0.725,0.0208451,0.984427,0.7188,0.0483676,1.19587,0.7806,0.028018,0.829799
NICE_l2_hs[2],1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,0.9998,0.00018168,0.0032235,1.0,0.000147879,0.000147806,0.9998,0.0002,0.0032235
NICE_l2_hs[5],1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,0.9974,0.023483,0.0336758,1.0,4.17233e-07,2.80572e-07,0.9998,0.0002,0.0032235
"NICE_l4_hs[100, 100]",1.0,0.0,-1.19209e-07,1.0,0.0,-1.19209e-07,0.9998,0.0002,0.0032235,1.0,0.0,-1.19209e-07,0.9998,0.0002,0.0032235
Planar_l5,0.9782,0.0737056,0.141445,0.9806,0.0773443,0.142612,0.9256,0.0864651,0.313639,0.8646,0.0648571,0.525361,0.9892,0.0675931,0.108938
Planar_l10,0.9582,0.0638565,0.1881,0.9738,0.0770086,0.153926,0.9324,0.0902419,0.29512,0.8566,0.0597846,0.548387,0.9906,0.0653567,0.10044
Planar_l20,0.9816,0.0719755,0.134239,0.9764,0.0771389,0.151749,0.941,0.0958206,0.278289,0.8628,0.0593693,0.519137,0.9894,0.0673455,0.106569
Radial_l20,0.8126,0.0274304,0.715562,0.7888,0.0326073,0.796329,0.7254,0.0304592,0.992393,0.72,0.0537175,1.1666,0.782,0.0287883,0.82807
