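# Analysis of calibration in NIST LRE17

Compares post-hoc calibrators (temperature scaling, HTS, PTS, HistTS and psrcal's affine log-loss calibration) fitted on the LRE17 calibration (dev) scores and evaluated on the LRE17 evaluation scores.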

In [1]:
%load_ext autoreload

In [2]:
%autoreload 1

In [30]:
import os
import sys
import time
import h5py
sys.path.extend(['..'])

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.nn.functional import softmax as torch_softmax

from mixNmatch_cal import temperature_scaling
from scipy_models import LTS, HTS, HistTS, TS, HnLTS, BTS
from models import AdaTS, PTS
from models import LTS as LTS_torch
from models import HTS as HTS_torch
from models import HnLTS as HnLTS_torch
%aimport utils
from utils import compare_results, calib_split, get_CIFAR10_C, NumpyDataset, load_model, predict_logits, compute_metrics, onehot_encode, softmax, torch_entropy, torch
%aimport adats_utils
from adats_utils import fitAdaTS, fitCV_AdaTS, fitHistTS
%aimport mixNmatch_cal
from mixNmatch_cal import ets_calibrate, mir_calibrate

In [34]:
from psrcal.calibration import calibrate, AffineCalLogLoss
from psrcal.optim.vecmodule import lbfgs

### Data

In [5]:
# LRE17 score/embedding file. The handle is intentionally left open: later
# cells read datasets from `f` on demand.
EMBEDDINGS_H5 = "../../data/LRE17/exp_001_w75_embeddings_lre17_post_eval.h5"
f = h5py.File(EMBEDDINGS_H5, mode='r')

In [6]:
def read_key_lines(path):
    """Return the non-blank lines of a LRE17 ``.lst_and_key`` file.

    Uses ``splitlines()`` instead of ``split("\\n")[:-1]`` so that a missing
    trailing newline does not drop the last key, and CRLF line endings do not
    leave a stray ``\\r`` on the language label. Blank lines are skipped so
    that the "name label" parsing downstream does not choke on them.
    The file is closed even if reading fails.
    """
    with open(path, "r", encoding="utf-8") as key_file:
        return [line for line in key_file.read().splitlines() if line.strip()]


dev_keys = read_key_lines("../../data/LRE17/lre17_primary_full_dev.lst_and_key")
eval_keys = read_key_lines("../../data/LRE17/lre17_primary_full_evl.lst_and_key")

In [7]:
def parse_key_lines(key_lines):
    """Build a {segment name: language label} dict from 'name label' lines.

    Each line is split on a single space and must yield exactly two fields,
    otherwise tuple unpacking raises ValueError. Later duplicates overwrite
    earlier ones.
    """
    segment2lang = {}
    for entry in key_lines:
        seg_name, lang_label = entry.split(" ")
        segment2lang[seg_name] = lang_label
    return segment2lang


dev_target = parse_key_lines(dev_keys)
eval_target = parse_key_lines(eval_keys)

In [8]:
# Language name (stored as bytes in the HDF5 'classes' dataset) -> class index.
lang2ix = dict((raw_name.decode("utf-8"), ix) for ix, raw_name in enumerate(f['classes'][:]))

In [9]:
X_val = f['cal_datamat'][:]
X_test = f['eval_datamat'][:]


def segnames_to_labels(segnames, name2lang, class_index):
    """Return an int64 array with the class index of every segment.

    Args:
        segnames: iterable of bytes segment names read from the HDF5 file.
        name2lang: dict mapping decoded segment name -> language label.
        class_index: dict mapping language label -> class index.

    Raises:
        KeyError: if a segment has no key entry, or its label is unknown.
    """
    return np.array(
        [class_index[name2lang[name.decode("utf-8")]] for name in segnames],
        dtype=np.int64,
    )


Y_val = segnames_to_labels(f['cal_segnames'][:], dev_target, lang2ix)
Y_test = segnames_to_labels(f['eval_segnames'][:], eval_target, lang2ix)

# Labels are built from the segment names, so they must line up 1:1 with the
# score rows; pre-sizing from the datamat (as before) would silently leave
# unmatched rows labelled as class 0.
assert Y_val.shape[0] == X_val.shape[0], "cal_segnames and cal_datamat differ in length"
assert Y_test.shape[0] == X_test.shape[0], "eval_segnames and eval_datamat differ in length"

### Calibration

In [10]:
# Width of the score matrices (one column per class, presumably the 14 LRE17
# target languages) instead of a hard-coded 14.
dim = X_val.shape[1]
assert dim == X_test.shape[1] == len(lang2ix), "score width must match the class list"

# Temperature scaling (TS) baseline
tempScaler = TS(dim)
tempScaler.fit(X_val, Y_val, v=True);

Optimization terminated successfully.
         Current function value: 4306.638409
         Iterations: 14
         Function evaluations: 15
         Gradient evaluations: 15


In [45]:
# Inspect the fitted temperature parameter of the scipy TS baseline.
tempScaler.t

array([13.95839057])

In [36]:
# psrcal affine calibration trained with the log-loss (cross-entropy).
# NOTE: despite the historical name `focal`, this is NOT focal loss.
# NOTE(review): the second positional argument to psrcal's `lbfgs` (100) is
# presumably an iteration budget — TODO confirm against psrcal.
affine_cal = AffineCalLogLoss(torch.as_tensor(X_val), torch.as_tensor(Y_val))
paramvec, value, curve = lbfgs(affine_cal, 100, quiet=True)
focal = affine_cal  # backward-compatible alias; the comparison cells use `focal`

In [43]:
# Learned parameter vector of the psrcal affine calibrator.
# NOTE(review): it has 15 entries for 14 classes — presumably one scale
# followed by a per-class bias vector; confirm the layout in psrcal before
# interpreting individual values.
paramvec

array([ 0.09604245, -1.71607284, -2.3634115 ,  1.40027348, -1.48369855,
       -2.02408611, -0.96301439,  2.05101832,  2.88106684,  4.4172876 ,
       -1.11706341,  0.82904941, -1.11411781, -1.70555521,  0.90832418])

In [23]:
# Parametrized temperature scaling (PTS) wrapped in AdaTS; `.double()` casts
# the module to float64 (presumably to match the numpy score dtype — confirm).
# The fit is minibatched (batch_size=100), so seed torch and numpy to make the
# result reproducible on Restart & Run All. `epochs` acts as an upper bound:
# training stops once fitAdaTS reports convergence.
torch.manual_seed(0)
np.random.seed(0)
pts = AdaTS(PTS(dim)).double()
pts = fitAdaTS(pts, X_val, Y_val, epochs=30000, batch_size=100, lr=1e-3, v=True)

On epoch: 1624, loss: 4.069e+03, at time: 169.46s
Finish training, convergence reached. Loss: 4069.10 



In [24]:
# Torch HTS calibrator wrapped in AdaTS. Seeded for the same reproducibility
# reason as the PTS fit (minibatched training, batch_size=100).
# NOTE(review): unlike `pts`, this model is not cast with `.double()`; check
# whether that dtype difference is intentional.
torch.manual_seed(0)
np.random.seed(0)
hts_t = AdaTS(HTS_torch(dim))
hts_t = fitAdaTS(hts_t, X_val, Y_val, epochs=30000, batch_size=100, lr=1e-3, v=True)

On epoch: 1924, loss: 4.291e+03, at time: 135.26s
Finish training, convergence reached. Loss: 4290.76 



In [25]:
# Histogram-based temperature scaling (scipy_models.HistTS), fitted on the cal set.
# NOTE(review): in the recorded result tables 'histTS' matches 'Uncal' exactly
# (same ECE/MCE/Brier/NLL on both val and test), which suggests this fit is a
# no-op or that `predictive` ignores the torch tensor it is later given.
# `fitHistTS` is imported from adats_utils but never used — check whether it
# was meant to be used here.
hisTS = HistTS()
hisTS.fit(X_val, Y_val)

In [39]:
print('##### Results on val set:')
# NOTE: every calibrator above was fitted on (X_val, Y_val), so these numbers
# are in-sample; the test-set table is the fair comparison.
# 'focal' (psrcal affine calibration) can change the argmax, so its accuracy
# differs from the other calibrators.
X_val_t = torch.as_tensor(X_val)  # convert once; shared by the torch-based calibrators
with torch.no_grad():  # inference only: no autograd graphs needed
    val_predictions = {
        'Uncal': softmax(X_val, axis=1),
        'TempScal': tempScaler.predictive(X_val),
        'HTS_torch': hts_t.predictive(X_val_t),
        'PTS': pts.predictive(X_val_t),
        'histTS': hisTS.predictive(X_val_t),
        'focal': torch_softmax(focal.calibrate(X_val_t), dim=1),
    }
# M=50 is passed through to utils.compare_results (presumably the number of
# ECE/MCE bins — confirm).
compare_results(predictions=val_predictions, target=Y_val, M=50, from_logits=False);

##### Results on val set:
Calibrator      Accuracy           ECE           MCE   Brier Score           NLL
Uncal              77.39%        20.66%        59.47%    4.238e-01     3.014e+00
TempScal           77.39%         5.51%        18.98%    3.101e-01     7.072e-01
HTS_torch          77.39%         4.67%        17.87%    3.092e-01     7.046e-01
PTS                77.39%         2.66%        22.60%    3.007e-01     6.682e-01
histTS             77.39%        20.66%        59.47%    4.238e-01     3.014e+00
focal              83.55%         4.38%        22.09%    2.379e-01     5.029e-01


In [41]:
print('##### Results on test set:')
# Held-out evaluation: none of the calibrators saw X_test during fitting.
X_test_t = torch.as_tensor(X_test)  # convert once; shared by the torch-based calibrators
with torch.no_grad():  # inference only: no autograd graphs needed
    test_predictions = {
        'Uncal': softmax(X_test, axis=1),
        'TempScal': tempScaler.predictive(X_test),
        'HTS_torch': hts_t.predictive(X_test_t),
        'PTS': pts.predictive(X_test_t),
        'histTS': hisTS.predictive(X_test_t),
        'focal': torch_softmax(focal.calibrate(X_test_t), dim=1),
    }
# M=50 is passed through to utils.compare_results (presumably the number of
# ECE/MCE bins — confirm).
compare_results(predictions=test_predictions, target=Y_test, M=50, from_logits=False);

##### Results on test set:
Calibrator      Accuracy           ECE           MCE   Brier Score           NLL
Uncal              70.28%        27.47%        59.83%    5.640e-01     4.223e+00
TempScal           70.28%         4.29%        15.44%    3.904e-01     8.863e-01
HTS_torch          70.28%         3.74%        15.63%    3.903e-01     8.823e-01
PTS                70.28%         4.48%        11.68%    3.918e-01     9.064e-01
histTS             70.28%        27.47%        59.83%    5.640e-01     4.223e+00
focal              75.74%         3.14%        17.17%    3.303e-01     7.022e-01
