In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.utils import parameters_to_vector
from torch.optim import Adam
from torch.distributions import MultivariateNormal, Normal
from preds.models import MLPS

from preds.likelihoods import CategoricalLh
from preds.datasets import UCIClassificationDatasets
from preds.laplace import Laplace

This notebook includes some exploration of the Immer et al. dataset and model classes, and a script for running the UCI experiments (with hyperparameters as in the paper appendix).

Testing the dataset and model classes

In [None]:
width = 50  # hidden units per layer, as in Immer et al.
depth = 2  # number of hidden layers (MLPS receives [width] * depth), as in Immer et al.
# Prior precision taken from the grid np.logspace(-2, 2, num=10); the best grid
# point depends on the dataset (as in Immer et al.). Cast to a Python float so it
# combines with torch tensors as a plain scalar rather than a numpy float64.
prior_prec = float(np.logspace(-2, 2, num=10)[0])
lr = 1e-3  # Adam learning rate, as in Immer et al.
n_epochs = 10000
n_samples = 1000  # number of predictive samples passed to the Laplace predictives
train_size = 0.70  # train split fraction, as in Immer et al.
lh = CategoricalLh()  # must match the class imported from preds.likelihoods
uci_dataset = 'glass'
root_dir = '../data/'
device = 'cpu'

Load training data

In [None]:
ds_train = UCIClassificationDatasets(
    train=True,
    data_set=uci_dataset,
    split_train_size=train_size,
    double=False,
    root=root_dir,
)
# Full training split as one batch; targets get a trailing unit dimension
# (a column vector, assuming the stored targets are 1-D).
X_train = ds_train.data.to(device)
y_train = ds_train.targets.to(device).unsqueeze(1)
train_loader = [(X_train, y_train)]  # one-element "loader": the whole split in a single batch

Load validation data

In [None]:
ds_val = UCIClassificationDatasets(
    train=False,
    valid=True,
    data_set=uci_dataset,
    split_train_size=train_size,
    double=False,
    root=root_dir,
)
# Validation split as one batch; targets get a trailing unit dimension like the
# training targets.
X_val = ds_val.data.to(device)
y_val = ds_val.targets.to(device).unsqueeze(1)
val_loader = [(X_val, y_val)]  # one-element "loader": the whole split in a single batch

In [None]:
# Size of the training inputs; dimension 1 is used as the MLP input width below.
X_train.size()

In [None]:
# MAP training: full-batch Adam on the negative log-likelihood plus the penalty
# 0.5 * prior_prec * ||theta||^2 from an isotropic Gaussian prior.
# NOTE(review): the network has a single output unit although `lh` is a
# categorical likelihood; if CategoricalLh expects one logit per class this
# should be the number of classes — confirm against preds.likelihoods.
model = MLPS(X_train.shape[1], [width] * depth, 1, activation='tanh', flatten=False).to(device)
optim = Adam(model.parameters(), lr=lr)
losses = []
for _epoch in range(n_epochs):
    logits = model(X_train)
    theta = parameters_to_vector(model.parameters())
    penalty = 0.5 * prior_prec * theta @ theta
    loss = -lh.log_likelihood(y_train, logits) + penalty
    loss.backward()
    optim.step()
    losses.append(loss.item())
    model.zero_grad()


In [None]:
lap = Laplace(model, float(prior_prec), lh)


def get_pred_for(x, model_type='glm', cov_type='full'):
    """Fit the Laplace posterior on `train_loader`, then return predictive moments at `x`.

    Args:
        x: inputs to predict at.
        model_type: 'glm' uses `lap.predictive_samples_glm`; 'bnn' draws samples with
            `lap.predictive_samples_bnn` and summarises them.
        cov_type: posterior covariance structure passed to `lap.infer`
            (e.g. 'full', 'kron', 'diag').

    Returns:
        (mu, var) as detached CPU tensors with singleton dimensions squeezed out.
        For 'bnn', these are the element-wise mean and (unbiased) variance over the
        `n_samples` predictive samples.

    Raises:
        ValueError: if `model_type` is neither 'glm' nor 'bnn'.
    """
    #### INFERENCE (Posterior approximation) ####
    # Kronecker damping is only switched on for the BNN predictive.
    lap.infer(train_loader, cov_type=cov_type, dampen_kron=model_type == 'bnn')
    if model_type == 'glm':
        #### GLM PREDICTIVE ####
        mu, var = lap.predictive_samples_glm(x, n_samples=n_samples)
    elif model_type == 'bnn':
        #### BNN PREDICTIVE ####
        samples = lap.predictive_samples_bnn(x, n_samples=n_samples)
        # Reduce over the sample dimension (assumed to be dim 0 — TODO confirm).
        # Tensor.cov() has no `axis` argument and returns a covariance *matrix*;
        # downstream code needs an element-wise variance, so use Tensor.var().
        mu = samples.mean(dim=0)
        var = samples.var(dim=0)
    else:
        raise ValueError('unsupported model_type.')
    mu = mu.detach().cpu().squeeze()
    var = var.detach().cpu().squeeze()
    return mu, var

In [None]:
# Sanity check that inputs and model parameters share a dtype (the datasets are
# loaded with double=False). `mu` only exists inside get_pred_for, so inspect the
# model's parameters instead of an undefined global.
print(X_train.dtype)
print(X_val.dtype)
print(next(model.parameters()).dtype)

In [None]:
# GLM predictive with a full-covariance Laplace posterior.
mu_glm, var_glm = get_pred_for(X_val, model_type='glm', cov_type='full')
# NOTE: the GLM predictive with cov_type='kron' or cov_type='diag' did not run when
# tried:
#   mu_glm_kron, var_glm_kron = get_pred_for(X_val, 'glm', 'kron')
#   mu, var = get_pred_for(X_val, 'glm', 'diag')

# BNN (sampling) predictive with a full-covariance Laplace posterior.
mu_bnn, var_bnn = get_pred_for(X_val, model_type='bnn', cov_type='full')

In [None]:
# Size of the validation targets (they carry a trailing unit dimension from unsqueeze(1)).
y_val.size()

In [None]:
# Gaussian NLL of the validation labels under the GLM predictive moments.
# torch.distributions.Normal is parameterised by the standard deviation (scale),
# so take the square root of the predictive variance.
# NOTE(review): a Gaussian NLL on class labels is an unusual metric for a
# classification task — confirm this is intended.
lh_glm = Normal(mu_glm, var_glm.sqrt())
print(-torch.mean(lh_glm.log_prob(y_val.squeeze(-1))))

In [None]:
# Gaussian NLL of the validation labels under the BNN predictive moments.
# torch.distributions.Normal is parameterised by the standard deviation (scale),
# so take the square root of the predictive variance.
# NOTE(review): a Gaussian NLL on class labels is an unusual metric for a
# classification task — confirm this is intended.
lh_bnn = Normal(mu_bnn, var_bnn.sqrt())
print(-torch.mean(lh_bnn.log_prob(y_val.squeeze(-1))))

<h2>Run the classification pipeline for the chosen dataset</h2>

The classification.py script tests 10 options for the prior precision with a fixed train/val/test split. The results reported in the paper are over 10 different train/val/test splits, so the script is run once per seed below.

In [8]:
# UCI classification datasets and the seeds (one per train/val/test split).
datasets = 'australian breast_cancer digits glass ionosphere satellite vehicle waveform'.split()
seeds = [711, 1, 75, 359, 17, 420, 129, 666, 69, 36]
dataset = 'breast_cancer'  # dataset run by the loop in the next cell

In [None]:
for seed in seeds:
    if dataset in ['satellite', 'digits']:
        logmin = -1.0
    else:
        logmin = -2.0
    !python3 ../experiments/classification.py -d {dataset} --root_dir ../ --seed {seed} --n_layers 2 --activation tanh --name tanh_2 --logd_min {logmin}

Writing results to ../experiments/results
Reading data from ../data
100%|██████████████████████████████████████████| 10/10 [42:12<00:00, 253.22s/it]
Writing results to ../experiments/results
Reading data from ../data
100%|██████████████████████████████████████████| 10/10 [53:43<00:00, 322.32s/it]
Writing results to ../experiments/results
Reading data from ../data
100%|██████████████████████████████████████████| 10/10 [44:43<00:00, 268.37s/it]
Writing results to ../experiments/results
Reading data from ../data
 60%|█████████████████████████▊                 | 6/10 [24:45<16:41, 250.35s/it]

The experiment in experiments/uci_classification_commands.py could be used instead, but it does not account for the different hyperparameter settings required for the satellite and digits datasets (a different range for the prior precision).

Read the result pickles for one dataset and experiment name (one file per seed; each file holds results for all prior precisions tried).

In [28]:
import pickle
import os

dataset = 'breast_cancer'
# NOTE(review): the experiment loop in this notebook runs with --name tanh_2;
# confirm which experiment name should be loaded here.
name = 'sparse_quarter'
datasets = ['australian', 'breast_cancer', 'digits', 'glass',
            'ionosphere', 'satellite', 'vehicle', 'waveform']
seeds = [711, 1, 75, 359, 17, 420, 129, 666, 69, 36]
result_list = []
# Seed of each entry in result_list. Missing files shift positions, so
# result_list[i] is NOT necessarily seeds[i].
loaded_seeds = []
for seed in seeds:
    file_name = f'../experiments/results/uci/classification_{dataset}_{name}_{seed}.pkl'
    if os.path.isfile(file_name):
        # Only unpickle result files you produced/trust: pickle can execute code.
        with open(file_name, 'rb') as handle:
            result_list.append(pickle.load(handle))
        loaded_seeds.append(seed)
    else:
        print(f'WARNING: No results for seed {seed}')
if not result_list:
    print(f'WARNING: no result files found for dataset {dataset!r} and name {name!r}')



Find which prior precision gives the best validation NLL for the dataset (based on the NLL averaged over the loaded seeds).

In [32]:
# Method whose validation NLL drives the prior-precision selection.
selection_method = 'svgp_ntk'
valid_key = f'valid_nll_{selection_method}'

if not result_list:
    raise ValueError('result_list is empty: no result pickles were loaded.')

# Rows: loaded runs; columns: prior precisions. Missing entries stay NaN (not 0)
# so they are excluded from the average instead of biasing it downwards.
valid_nlls = np.full((len(result_list), len(result_list[0]['deltas'])), np.nan)
print(valid_nlls.shape)
print(result_list[0]['results'][0].keys())
for i, results in enumerate(result_list):
    for j, res in enumerate(results['results']):
        if valid_key not in res:
            # Report the position in result_list: it need not match seeds[i]
            # when some seeds had no result file.
            print(f'result_list[{i}], prior precision index {j}: missing {valid_key!r}')
            continue
        # Select on the *validation* NLL, not the train NLL.
        valid_nlls[i, j] = res[valid_key]

mean_nlls = np.nanmean(valid_nlls, axis=0)
print(mean_nlls)
best_prec_idx = int(np.nanargmin(mean_nlls))
print(f'Best prior precision index: {best_prec_idx} '
      f'(prior precision {result_list[0]["deltas"][best_prec_idx]})')

(9, 10)
dict_keys(['losses', 'elbos_bbb', 'train_nll_map', 'train_acc_map', 'train_ece_map', 'test_nll_map', 'test_acc_map', 'test_ece_map', 'valid_nll_map', 'valid_acc_map', 'valid_ece_map', 'train_nll_bbb', 'train_acc_bbb', 'train_ece_bbb', 'test_nll_bbb', 'test_acc_bbb', 'test_ece_bbb', 'valid_nll_bbb', 'valid_acc_bbb', 'valid_ece_bbb', 'train_nll_svgp_ntk', 'train_acc_svgp_ntk', 'train_ece_svgp_ntk', 'test_nll_svgp_ntk', 'test_acc_svgp_ntk', 'test_ece_svgp_ntk', 'valid_nll_svgp_ntk', 'valid_acc_svgp_ntk', 'valid_ece_svgp_ntk', 'train_nll_glm', 'train_acc_glm', 'train_ece_glm', 'test_nll_glm', 'test_acc_glm', 'test_ece_glm', 'valid_nll_glm', 'valid_acc_glm', 'valid_ece_glm', 'train_nll_glmd', 'train_acc_glmd', 'train_ece_glmd', 'test_nll_glmd', 'test_acc_glmd', 'test_ece_glmd', 'valid_nll_glmd', 'valid_acc_glmd', 'valid_ece_glmd', 'train_nll_nn', 'train_acc_nn', 'train_ece_nn', 'test_nll_nn', 'test_acc_nn', 'test_ece_nn', 'valid_nll_nn', 'valid_acc_nn', 'valid_ece_nn', 'train_nll_nn

Check the test set performance of the selected prior precision

In [11]:
# Test NLL (GLM predictive) at a fixed prior-precision index, one value per loaded run.
# NOTE(review): prec_idx is hard-coded, and the selection cell ranks by the
# svgp_ntk validation NLL while this reports test_nll_glm — confirm they should
# refer to the same method.
prec_idx = 7
test_nlls = np.fromiter(
    (run['results'][prec_idx]['test_nll_glm'] for run in result_list),
    dtype=float,
    count=len(result_list),
)
print(np.mean(test_nlls))
print(test_nlls)
print(np.std(test_nlls))  # population standard deviation (ddof=0) across runs

0.31400694251060485
[0.30834961 0.28497851 0.3024272  0.29026619 0.39062044 0.30780569
 0.30664775 0.29315224 0.34323138 0.31259042]
0.02970791712585548


In [None]:
# Inspect the first loaded run: split sizes, stored keys, and the validation NLLs
# recorded for each prior precision.
results = result_list[0]
n_test = results['N_test']
n_train = results['N_train']
deltas = results['deltas']
print(f'Results keys: {results.keys()}')
print(results['K'])
print(f'Train set size: {n_train}')
print(f'Test set size: {n_test}')
for idx, res in enumerate(results['results']):
    print(f'Prior precision: {deltas[idx]}')
    print(f'Validation set NLL Map: {res["valid_nll_map"]}')
    print(f'Validation set NLL GLM: {res["valid_nll_glm"]}')
    print(f'Validation set NLL BNN: {res["valid_nll_nn"]}')

    # List every validation-NLL metric stored for this prior precision.
    for key in filter(lambda k: 'valid_nll' in k, res):
        print(key)

<h2>Run Image experiments </h2>

In [92]:
# Image-classification run settings, interpolated into the command in the next cell.
# NOTE(review): `model` is a string here and shadows the torch model trained above.
ds, model, seed = 'MNIST', 'MLP', 117

In [96]:
# Run the image-classification experiment script with the dataset (-d), model (-m)
# and seed (-s) set in the previous cell.
!python3 ../experiments/imgclassification.py -d {ds} -m {model} -s {seed}

Writing results to ../experiments/results/MNIST
Reading data from ../data
Dataset: MNIST
Seed: 117
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
9920512it [00:00, 12327924.77it/s]                                              
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
32768it [00:00, 298645.29it/s]
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
1654784it [00:00, 6915425.85it/s]                                               
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
81