# Analysis of AdaTempScal on LeNet5 for CIFAR10

In [1]:
%load_ext autoreload

In [2]:
%autoreload 1

In [3]:
import os
import sys
import time
sys.path.extend(['..'])

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import softmax
import torch

%aimport models
from models import AdaptiveTempScaling, TempScaling
%aimport utils
from utils import compare_results

In [4]:
%matplotlib inline

In [5]:
plt.rc('text', usetex=True)
plt.rc('text.latex', preamble=r'\usepackage{amsmath} \usepackage{amssymb}')

## Load data and precomputed logits

In [6]:
CIFAR10_PATH = '../../data/CIFAR10'
MODEL_PATH = '../../trained_models/CIFAR10/lenet5'

In [7]:
X_train = np.load(os.path.join(CIFAR10_PATH, 'train_imas.npy'))
y_train = np.load(os.path.join(CIFAR10_PATH, 'train_labels.npy'))

X_val = np.load(os.path.join(CIFAR10_PATH, 'val_imas.npy'))
y_val = np.load(os.path.join(CIFAR10_PATH, 'val_labels.npy'))

X_test = np.load(os.path.join(CIFAR10_PATH, 'test_imas.npy'))
y_test = np.load(os.path.join(CIFAR10_PATH, 'test_labels.npy'))

In [8]:
Z_train = np.load(os.path.join(MODEL_PATH, 'train_logits.npy'))
Z_val = np.load(os.path.join(MODEL_PATH, 'val_logits.npy'))
Z_test = np.load(os.path.join(MODEL_PATH, 'test_logits.npy'))

### Calibrate models

In [9]:
N, dim = Z_train.shape

### Temp-Scal as baseline:
tempScaler = TempScaling()
tempScaler.fit(Z_val, y_val);

In [10]:
aTempScaler = AdaptiveTempScaling(dim)
aTempScaler.fit(Z_val, y_val, v=True, lr=1e-6, epochs=30000);

Finding optimum Temperature
Adapting Weight vector.864e+00, Temp: 0.964, , at time: 0.30s


On epoch: 29994, loss: 1.863e+00, at time: 134.62s



In [11]:
print('##### Results on train set:')
compare_results(predictions={'Uncal': softmax(Z_train, axis=1),
                             'TempScal': tempScaler.predictive(Z_train),
                             'AdaptiveTempScal': aTempScaler.predictive(Z_train)}, target=y_train);

##### Results on train set:
  Calibrator      Accuracy           ECE   Brier Score           NLL
       Uncal          0.00         11.09     8.794e-01     5.122e+05
    TempScal          0.00         11.12     8.788e-01     5.121e+05
AdaptiveTempScal          0.00         11.11     8.790e-01     5.121e+05


In [12]:
print('##### Results on validation set:')
compare_results(predictions={'Uncal': Z_val,
                             'TempScal': tempScaler.predictive(Z_val),
                             'AdaptiveTempScal': aTempScaler.predictive(Z_val)}, target=y_val, M=20);

##### Results on validation set:
  Calibrator      Accuracy           ECE   Brier Score           NLL
       Uncal          0.00         19.83     7.991e-01     1.116e+05
    TempScal          0.00         11.11     8.791e-01     1.142e+05
AdaptiveTempScal          0.00         11.10     8.792e-01     1.142e+05
