In [116]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.distributions import MultivariateNormal, Normal
from torch.nn.utils import parameters_to_vector
from torch.optim import Adam

from preds.datasets import UCIClassificationDatasets
from preds.laplace import Laplace
from preds.likelihoods import CategoricalLh
from preds.models import MLPS

This notebook includes some experimentation with the Immer et al. dataset and model classes, and a script for running the UCI experiments (with hyperparameters as in the paper's appendix).

Testing the dataset and model classes

In [None]:
# Hyperparameters for the single-dataset sanity check in the following cells.
width = 50  # width of each hidden layer, as in Immer et al.
depth = 2  # number of hidden layers, as in Immer et al.
# Smallest value (1e-2) of the 10-point log grid used by Immer et al.;
# the best prior precision depends on the dataset.
prior_prec = np.logspace(-2, 2, num=10)[0]
lr = 1e-3  # as in Immer et al.
n_epochs = 10000
n_samples = 1000  # number of samples requested from the Laplace predictive
train_size = 0.70  # as in Immer et al.
lh = CategoricalLh()
uci_dataset = 'glass'
root_dir = '../data/'
device = 'cpu'

Load training data

In [None]:
# Training split; a list holding one (X, y) tuple serves as a single full-batch "loader".
ds_train = UCIClassificationDatasets(
    train=True,
    data_set=uci_dataset,
    split_train_size=train_size,
    double=False,
    root=root_dir,
)
X_train = ds_train.data.to(device)
y_train = ds_train.targets.to(device).unsqueeze(1)  # adds a trailing singleton dim
train_loader = [(X_train, y_train)]

Load validation data

In [None]:
# Validation split, built the same way as the training split.
ds_val = UCIClassificationDatasets(
    train=False,
    valid=True,
    data_set=uci_dataset,
    split_train_size=train_size,
    double=False,
    root=root_dir,
)
X_val = ds_val.data.to(device)
y_val = ds_val.targets.to(device).unsqueeze(1)  # adds a trailing singleton dim
val_loader = [(X_val, y_val)]

In [None]:
X_train.size()  # (n_points, n_features) of the training inputs

In [None]:
# Fit the MAP network with full-batch Adam on the negative log-joint:
# negative log-likelihood plus the L2 penalty of an isotropic Gaussian prior.
# NOTE(review): the network has a single output unit while `lh` is a
# categorical likelihood; multi-class datasets such as 'glass' presumably
# need one output per class — confirm against CategoricalLh.
model = MLPS(X_train.shape[1], [width] * depth, 1, activation='tanh', flatten=False)
model = model.to(device)
optim = Adam(model.parameters(), lr=lr)
losses = []
for i in range(n_epochs):
    f = model(X_train)  # full-batch forward pass
    w = parameters_to_vector(model.parameters())
    reg = 0.5 * prior_prec * w @ w  # -log N(w; 0, I / prior_prec), up to a constant
    loss = reg - lh.log_likelihood(y_train, f)
    loss.backward()
    optim.step()
    losses.append(loss.item())
    model.zero_grad()


In [None]:
# Laplace approximation built around the MAP network trained above;
# the NumPy prior-precision scalar is converted to a plain Python float.
lap = Laplace(model, float(prior_prec), lh)


def get_pred_for(x, model_type='glm', cov_type='full'):
    """Return the Laplace predictive mean and variance for inputs ``x``.

    Refits the posterior approximation of the global ``lap`` on the global
    ``train_loader`` with the requested covariance structure, then forms the
    predictive using the global ``n_samples``.

    Args:
        x: inputs to predict on.
        model_type: 'glm' for the GLM (linearised) predictive, 'bnn' for the
            BNN predictive.
        cov_type: posterior covariance structure passed to ``lap.infer``
            (e.g. 'full', 'kron', 'diag').

    Returns:
        (mu, var): detached CPU tensors with singleton dimensions squeezed.

    Raises:
        ValueError: if ``model_type`` is neither 'glm' nor 'bnn'.
    """
    # Validate first so an unsupported type does not pay for a posterior fit.
    if model_type not in ('glm', 'bnn'):
        raise ValueError('unsupported model_type.')

    #### INFERENCE (Posterior approximation) ####
    lap.infer(train_loader, cov_type=cov_type, dampen_kron=model_type == 'bnn')
    if model_type == 'glm':
        #### GLM PREDICTIVE ####
        mu, var = lap.predictive_samples_glm(x, n_samples=n_samples)
    else:
        #### BNN PREDICTIVE ####
        # assumes samples are stacked along dim 0 — TODO confirm against preds.
        samples = lap.predictive_samples_bnn(x, n_samples=n_samples)
        mu = samples.mean(dim=0)
        # Tensor.cov accepts no axis/dim argument; the per-point predictive
        # variance over the sample dimension is what callers use.
        var = samples.var(dim=0)
    mu = mu.detach().cpu().squeeze()
    var = var.detach().cpu().squeeze()
    return mu, var

In [None]:
# Sanity check that the inputs and the network parameters share a dtype
# (the datasets are loaded with double=False). `mu` is not defined at notebook
# level (it only exists inside get_pred_for), so report the model dtype instead.
print(X_train.dtype)
print(X_val.dtype)
print(next(model.parameters()).dtype)

In [None]:
# GLM predictive with a full-covariance posterior.
# The 'kron' and 'diag' GLM variants did not run at the time of writing:
#   get_pred_for(X_val, 'glm', 'kron')
#   get_pred_for(X_val, 'glm', 'diag')
mu_glm, var_glm = get_pred_for(X_val, model_type='glm', cov_type='full')

# BNN predictive with a full-covariance posterior.
mu_bnn, var_bnn = get_pred_for(X_val, model_type='bnn', cov_type='full')

In [None]:
y_val.size()  # validation targets, with the trailing singleton dim added at load time

In [None]:
# Gaussian NLL of the validation targets under the GLM predictive.
# torch.distributions.Normal is parameterised by the standard deviation, so the
# predictive variance (assuming var_glm holds variances, as named) is square-rooted.
lh_glm = Normal(mu_glm, var_glm.sqrt())
print(-torch.mean(lh_glm.log_prob(y_val.squeeze(-1))))

In [None]:
# Gaussian NLL of the validation targets under the BNN predictive.
# Normal takes a standard deviation, so square-root the predictive variance.
lh_bnn = Normal(mu_bnn, var_bnn.sqrt())
print(-torch.mean(lh_bnn.log_prob(y_val.squeeze(-1))))

<h2>Run the classification pipeline for the chosen dataset</h2>

The classification.py script tests 10 values of the prior precision on a fixed train/val/test split. The results reported in the paper are computed over 10 different train/val/test splits (one per seed below).

In [8]:
# UCI classification benchmarks and one seed per train/val/test split.
datasets = [
    'australian', 'breast_cancer', 'digits', 'glass',
    'ionosphere', 'satellite', 'vehicle', 'waveform',
]
seeds = [711, 1, 75, 359, 17, 420, 129, 666, 69, 36]
dataset = 'breast_cancer'  # the dataset launched in the next cell

In [None]:
for seed in seeds:
    if dataset in ['satellite', 'digits']:
        logmin = -1.0
    else:
        logmin = -2.0
    !python3 ../experiments/classification.py -d {dataset} --root_dir ../ --seed {seed} --n_layers 2 --activation tanh --name tanh_2 --logd_min {logmin}

Writing results to ../experiments/results
Reading data from ../data
100%|██████████████████████████████████████████| 10/10 [42:12<00:00, 253.22s/it]
Writing results to ../experiments/results
Reading data from ../data
100%|██████████████████████████████████████████| 10/10 [53:43<00:00, 322.32s/it]
Writing results to ../experiments/results
Reading data from ../data
100%|██████████████████████████████████████████| 10/10 [44:43<00:00, 268.37s/it]
Writing results to ../experiments/results
Reading data from ../data
 60%|█████████████████████████▊                 | 6/10 [24:45<16:41, 250.35s/it]

The experiment in experiments/uci_classification_commands.py could be used instead, but it does not account for the different hyperparameter settings required for the satellite and digits datasets (a different range for the prior precision).

<h2> Manual reading of the result pickle</h2>
Read the result pickle (for a single experiment with various prior precisions)

Find which prior precision gives the best validation NLL for the dataset (based on the NLL averaged over seeds).

In [179]:
def get_val_best(result_list, model_name='svgp_ntk'):
    """Return the grid index with the lowest validation NLL averaged over seeds.

    Args:
        result_list: non-empty list with one loaded result dict per seed. Each
            dict has a 'deltas' list (presumably the prior-precision grid) and a
            parallel 'results' list of per-delta metric dicts.
        model_name: metric suffix; the key read is ``f'valid_nll_{model_name}'``.

    Returns:
        int index into the 'deltas' grid (and into each 'results' list).

    Raises:
        IndexError: if ``result_list`` is empty.
        ValueError: if no entry carries the requested metric at all.

    Missing metrics are recorded as NaN and ignored when averaging. Previously
    they stayed at 0, which made an absent result look like a perfect NLL.
    """
    print(model_name)
    key = f'valid_nll_{model_name}'
    n_deltas = len(result_list[0]['deltas'])
    valid_nlls = np.full((len(result_list), n_deltas), np.nan)
    for i, results in enumerate(result_list):
        for j, res in enumerate(results['results']):
            if key in res:
                valid_nlls[i, j] = res[key]
    # Average over seeds, skipping missing entries (an all-missing column stays NaN).
    mean_nlls = np.nanmean(valid_nlls, axis=0)
    print(mean_nlls)
    # nanargmin skips NaN columns and raises ValueError if every column is NaN.
    return int(np.nanargmin(mean_nlls))

def get_result_list(dataset, folder_name, name, seeds,
                    results_dir='../experiments/results/uci'):
    """Load the pickled classification results for every seed that has a file.

    Files are read from
    ``{results_dir}/{folder_name}/classification_{dataset}_{name}_{seed}.pkl``.

    Args:
        dataset: dataset name used in the file name.
        folder_name: results sub-folder (e.g. an experiment date tag).
        name: experiment name used in the file name.
        seeds: seeds to look for, in order.
        results_dir: root directory of the UCI results.

    Returns:
        list of unpickled result objects in ``seeds`` order. Seeds without a
        file are skipped, with a warning printed.
    """
    result_list = []
    for seed in seeds:
        file_name = os.path.join(results_dir, folder_name,
                                 f'classification_{dataset}_{name}_{seed}.pkl')
        if not os.path.isfile(file_name):
            print(f'WARNING: No results for seed {seed}, dataset {dataset}')
            continue
        # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
        with open(file_name, 'rb') as handle:
            result_list.append(pickle.load(handle))
    return result_list

Check the test set performance of the selected prior precision

In [185]:
def get_test_mean_std(result_list, min_idx, name='svgp_ntk'):
    """Mean and population std (ddof=0) over seeds of the test NLL at ``min_idx``.

    Args:
        result_list: one loaded result dict per seed, each with a 'results' list
            of per-delta metric dicts.
        min_idx: index of the selected prior precision (e.g. from get_val_best).
        name: metric suffix; the key read is ``f'test_nll_{name}'``.

    Returns:
        (mean, std) as NumPy floats.
    """
    metric = f'test_nll_{name}'
    per_seed = np.fromiter(
        (seed_result['results'][min_idx][metric] for seed_result in result_list),
        dtype=float,
        count=len(result_list),
    )
    return (np.mean(per_seed), np.std(per_seed))

<h2>Automatic table creation </h2>

In [187]:
# (mean, std) per method, one tuple per dataset column (presumably the test NLLs
# reported in the paper).
# NOTE(review): column order presumably follows `data_names` of the LaTeX table
# (australian, cancer, ionosphere, glass, vehicle, waveform, digits, satellite).
# The SVGP / GP rows are placeholders; the result-collection loops overwrite the
# entries for the experiments they process.
method_map = {
    'NN MAP': [(0.31, 0.01), (0.11, 0.02), (0.35, 0.02), (0.95, 0.03),
               (0.420, 0.007), (0.335, 0.004), (0.094, 0.003), (0.230, 0.002)],
    'MFVI': [(0.34, 0.01), (0.11, 0.01), (0.41, 0.01), (1.06, 0.01),
             (0.504, 0.006), (0.393, 0.003), (0.219, 0.004), (0.307, 0.002)],
    'BNN': [(0.42, 0.00), (0.19, 0.00), (0.50, 0.00), (1.41, 0.00),
            (0.885, 0.002), (0.516, 0.002), (0.875, 0.002), (0.482, 0.001)],
    'GLM': [(0.32, 0.02), (0.10, 0.01), (0.29, 0.01), (0.86, 0.01),
            (0.428, 0.005), (0.339, 0.004), (0.250, 0.002), (0.241, 0.001)],
    'GLM diag': [(0.33, 0.01), (0.11, 0.01), (0.35, 0.01), (0.99, 0.01),
                 (0.618, 0.003), (0.388, 0.003), (0.409, 0.002), (0.327, 0.002)],
    'GLM refine': [(0.32, 0.02), (0.11, 0.01), (0.35, 0.05), (0.98, 0.07),
                   (0.402, 0.007), (0.335, 0.004), (0.150, 0.002), (0.227, 0.002)],
    'GLM refine d': [(0.31, 0.01), (0.12, 0.02), (0.32, 0.03), (0.83, 0.02),
                     (0.432, 0.005), (0.364, 0.008), (0.149, 0.008), (0.248, 0.002)],
    'SVGP (quarter)': [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 3), (5, 4), (5, 5)],
    'SVGP (half)': [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 3), (5, 4), (5, 5)],
    'GP': [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 3), (5, 4), (5, 5)],
}


In [161]:
# Spot-check that one loaded result entry carries the SVGP validation metric.
'valid_nll_svgp_ntk' in result_list[5]['results'][4]

True

In [188]:
# Collect SVGP / GP-subset UCI results and add them to method_map as table columns.
exper_names = ['sparse_quarter']
folder_name = '08'
# '<experiment>_<method>' -> column label. Each key appears once; note that
# 'sparse_full_svgp_ntk' maps to 'GP svgp', not to the 'GP' placeholder column.
pretty_names = {'sparse_half_svgp_ntk': 'SVGP (half)',
                'sparse_quarter_svgp_ntk': 'SVGP (quarter)',
                'sparse_eigth_svgp_ntk': 'SVGP (eigth)',
                'sparse_full_svgp_ntk': 'GP svgp',
                'sparse_half_gp_subset': 'GP (subset half)',
                'sparse_quarter_gp_subset': 'GP (subset quarter)',
                'sparse_eigth_gp_subset': 'GP (subset eigth)',
                'sparse_full_gp_subset': 'GP GP'}
method_names = ['svgp_ntk', 'gp_subset']
datasets = ['australian', 'breast_cancer', 'ionosphere', 'glass',
            'vehicle', 'waveform', 'digits', 'satellite']
data_names = ['australian', 'cancer', 'ionosphere', 'glass', 'vehicle', 'waveform', 'digits', 'satellite']
seeds = [711, 1, 75, 359, 17, 420, 129, 666, 69, 36]
for exp_name in exper_names:
    for method_name in method_names:
        table_list = []
        print(method_name)
        for dataset in datasets:
            print(dataset)
            result_list = get_result_list(dataset, folder_name, exp_name, seeds)
            print('N seeds:')
            print(len(result_list))
            if not result_list:
                # keep the column aligned with data_names when a dataset has no results
                table_list.append((0, 0))
                continue
            min_idx = get_val_best(result_list, method_name)
            (mean, std) = get_test_mean_std(result_list, min_idx, name=method_name)
            table_list.append((round(mean, 3), round(std, 3)))
        key = f'{exp_name}_{method_name}'
        # fall back to the raw key for combinations without a pretty label
        method_map[pretty_names.get(key, key)] = table_list
        

svgp_ntk
australian
svgp_ntk
[0.65564473 0.63371981 0.60722066 0.57084976 0.53251547 0.48553099
 0.41566438 0.33250019 0.35363318 0.43021106]
breast_cancer
svgp_ntk
[0.25154706 0.21345501 0.18994961 0.16697388 0.14399672 0.11359814
 0.08834994 0.09061514 0.11181515 0.18345583]
ionosphere
svgp_ntk
[0.58947544 0.54589691 0.49798007 0.45041892 0.40060972 0.35066446
 0.30512517 0.29317578 0.35900308 0.67097958]
glass
svgp_ntk
[1.33929729 1.22860153 1.08646501 0.91187803 0.76298074 0.66918979
 0.86367606 1.06289082 1.37234914 1.75561462]
vehicle
waveform
digits
satellite
gp_subset
australian
gp_subset
[0.67541825 0.66377576 0.64843921 0.62889761 0.60527028 0.57636394
 0.52836131 0.35656998 0.36756735 0.44032043]
breast_cancer
gp_subset
[0.42250337 0.3883824  0.35026844 0.32042465 0.30236671 0.25664138
 0.14236547 0.12331066 0.13268846 0.19311334]
ionosphere
gp_subset
[0.64232751 0.62134006 0.59638852 0.56997281 0.54168769 0.51060334
 0.45794102 0.38328496 0.39539573 0.67091556]
glass
gp_sub

In [189]:
# Inspect the collected (mean, std) table data before writing the LaTeX file.
method_map

{'NN MAP': [(0.31, 0.01),
  (0.11, 0.02),
  (0.35, 0.02),
  (0.95, 0.03),
  (0.42, 0.007),
  (0.335, 0.004),
  (0.094, 0.003),
  (0.23, 0.002)],
 'MFVI': [(0.34, 0.01),
  (0.11, 0.01),
  (0.41, 0.01),
  (1.06, 0.01),
  (0.504, 0.006),
  (0.393, 0.003),
  (0.219, 0.004),
  (0.307, 0.002)],
 'BNN': [(0.42, 0.0),
  (0.19, 0.0),
  (0.5, 0.0),
  (1.41, 0.0),
  (0.885, 0.002),
  (0.516, 0.002),
  (0.875, 0.002),
  (0.482, 0.001)],
 'GLM': [(0.32, 0.02),
  (0.1, 0.01),
  (0.29, 0.01),
  (0.86, 0.01),
  (0.428, 0.005),
  (0.339, 0.004),
  (0.25, 0.002),
  (0.241, 0.001)],
 'GLM diag': [(0.33, 0.01),
  (0.11, 0.01),
  (0.35, 0.01),
  (0.99, 0.01),
  (0.618, 0.003),
  (0.388, 0.003),
  (0.409, 0.002),
  (0.327, 0.002)],
 'GLM refine': [(0.32, 0.02),
  (0.11, 0.01),
  (0.35, 0.05),
  (0.98, 0.07),
  (0.402, 0.007),
  (0.335, 0.004),
  (0.15, 0.002),
  (0.227, 0.002)],
 'GLM refine d': [(0.31, 0.01),
  (0.12, 0.02),
  (0.32, 0.03),
  (0.83, 0.02),
  (0.432, 0.005),
  (0.364, 0.008),
  (0.149, 0.00

In [110]:
import os  # kept: other cells in this notebook (result loading) use os

# Write method_map as a LaTeX tabular: one row per entry of data_names, one
# column per method, each cell rendered as \val{mean}{std}. \val, \tblw and the
# C{...} column type are assumed to be defined in the paper's preamble.
lines = []
# One C column per method, so columns appended by the result loops are covered.
col_spec = ' '.join([r'C{0.6\tblw}'] * len(method_map))
lines.append(r'\begin{tabular}{l ' + col_spec + '}')
lines.append(r'\toprule')
lines.append(''.join(f'& {method_name} ' for method_name in method_map) + r' \\')
lines.append(r'\midrule')
for i, data_name in enumerate(data_names):
    cells = []
    for method_name in method_map:
        mean, std = method_map[method_name][i]
        cells.append(r'\val{' + str(mean) + '}{' + str(std) + '}')
    # raw f-string: '\s' is not a valid Python escape sequence
    lines.append(rf'\sc {data_name} & ' + ' & '.join(cells) + r' \\')
lines.append(r'\bottomrule')
lines.append(r'\end{tabular}')
tex_file = 'uci.tex'
# 'w' truncates any previous table; os.remove() failed when the file did not exist yet.
with open(tex_file, 'w') as file:
    for line in lines:
        file.write(line + '\n')

print(lines)

['\\begin{tabular}{l C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw} C{0.6\\tblw}}', '\\toprule', '& NN MAP & MFVI & BNN & GLM & GLM diag & GLM refine & GLM refine d & SVGP (quarter) & SVGP (half) & GP  \\\\', '\\midrule', '\\sc australian & \\val{0.31}{0.01} & \\val{0.34}{0.01} & \\val{0.42}{0.0} & \\val{0.32}{0.02} & \\val{0.33}{0.01} & \\val{0.32}{0.02} & \\val{0.31}{0.01} & \\val{0.339}{0.088} & \\val{0.341}{0.093} & \\val{5}{2} \\\\', '\\sc cancer & \\val{0.11}{0.02} & \\val{0.11}{0.01} & \\val{0.19}{0.0} & \\val{0.1}{0.01} & \\val{0.11}{0.01} & \\val{0.11}{0.01} & \\val{0.12}{0.02} & \\val{0.095}{0.047} & \\val{0.095}{0.047} & \\val{5}{2} \\\\', '\\sc ionosphere & \\val{0.35}{0.02} & \\val{0.41}{0.01} & \\val{0.5}{0.0} & \\val{0.29}{0.01} & \\val{0.35}{0.01} & \\val{0.35}{0.05} & \\val{0.32}{0.03} & \\val{0.334}{0.127} & \\val{0.307}{0.077} & \\val{5}{2} \\\\', '\\sc glass & \\val{0.95}{0.03} & \\val{1.06}{0.01}

<h2>Create image result table</h2>

In [None]:
def get_val_best_img(result_list, deltas, model_name='svgp_ntk'):
    """Return the index into ``deltas`` with the lowest mean validation NLL.

    Args:
        result_list: one dict per seed mapping delta -> result dict with an
            'nll_va' entry.
        deltas: the delta keys to compare; the returned index refers to this list.
        model_name: only printed, as a label.

    Returns:
        int index into ``deltas``.

    Raises:
        ValueError: if no seed has a result for any delta.

    Deltas missing for a seed are recorded as NaN and ignored when averaging.
    """
    print(model_name)
    valid_nlls = np.full((len(result_list), len(deltas)), np.nan)
    for i, results in enumerate(result_list):
        for j, delta in enumerate(deltas):
            if delta in results:
                valid_nlls[i, j] = results[delta]['nll_va']
    mean_nlls = np.nanmean(valid_nlls, axis=0)
    print(mean_nlls)
    return int(np.nanargmin(mean_nlls))

def get_result_list_img(dataset, folder_name, seeds, model='MLP',
                        results_root='../experiments/results', key='svgp_ntk_3200'):
    """Load per-seed image-classification results saved with torch.

    Looks in ``{results_root}/{dataset}/models/{folder_name}`` for files named
    ``<dataset>_<model>_<seed>_<delta>.pt`` (assumed naming — TODO confirm) and
    keeps the ``key`` entry of each matching checkpoint.

    Args:
        dataset: dataset directory name.
        folder_name: results sub-folder.
        seeds: seeds to collect, in order.
        model: architecture name that must appear in the file name.
        results_root: root results directory.
        key: entry of the loaded state dict holding the metrics.

    Returns:
        (result_list, deltas): one dict delta -> result per seed that had files,
        and the sorted list of all deltas seen.
    """
    dir_name = os.path.join(results_root, dataset, 'models', folder_name)
    file_names = sorted(os.listdir(dir_name))
    result_list = []
    deltas = set()
    for seed in seeds:
        seed_res = dict()
        for fname in file_names:
            if not fname.endswith('.pt'):
                continue
            # rsplit keeps dataset names that contain '_' intact
            parts = fname[:-len('.pt')].rsplit('_', 3)
            if len(parts) != 4:
                continue
            _, m, s, delta = parts
            # the seed parsed from the file name is a string
            if m != model or s != str(seed):
                continue
            # NOTE: torch.load unpickles; only load trusted result files.
            state = torch.load(os.path.join(dir_name, fname))
            seed_res[delta] = state[key]
            deltas.add(delta)
        if seed_res:
            result_list.append(seed_res)
        else:
            print(f'WARNING: No results for seed {seed}, dataset {dataset}')
    # sorted: set iteration order of strings is not stable across runs
    return result_list, sorted(deltas)
                          
def get_test_mean_std_img(result_list, deltas, min_idx):
    """Mean and population std (ddof=0) over seeds of the test NLL at ``deltas[min_idx]``.

    Args:
        result_list: one dict per seed mapping delta -> result dict with 'nll_te'.
        deltas: delta keys; ``min_idx`` indexes into this list.
        min_idx: index of the selected delta (e.g. from get_val_best_img).

    Returns:
        (mean, std) as NumPy floats.
    """
    best_delta = deltas[min_idx]
    per_seed = np.fromiter(
        (seed_res[best_delta]['nll_te'] for seed_res in result_list),
        dtype=float,
        count=len(result_list),
    )
    return (np.mean(per_seed), np.std(per_seed))

In [None]:
# Collect image-classification results and add them to method_map as table columns.
exper_names = ['sparse_3200']
folder_name = '09'
pretty_names = {'sparse_3200_svgp_ntk': 'SVGP (3200)',
                'sparse_3200_gp_subset': 'GP subset (3200)'}
method_names = ['svgp_ntk']
datasets = ['FMNIST', 'CIFAR10']
# NOTE(review): rebinds `model` (earlier the trained MLPS network) to an architecture name.
model = 'MLP'
data_names = ['FMNIST', 'CIFAR10']
seeds = [711, 1, 75]
for exp_name in exper_names:
    for method_name in method_names:
        table_list = []
        print(method_name)
        for dataset in datasets:
            print(dataset)
            # NOTE(review): get_result_list_img reads the 'svgp_ntk_3200' entry by default,
            # independent of method_name.
            result_list, deltas = get_result_list_img(dataset, folder_name, seeds, model)
            print('N seeds:')
            print(len(result_list))
            if not result_list:
                table_list.append((0, 0))
                continue
            min_idx = get_val_best_img(result_list, deltas, model_name=method_name)
            # signature is (result_list, deltas, min_idx); there is no `name` argument
            (mean, std) = get_test_mean_std_img(result_list, deltas, min_idx)
            table_list.append((round(mean, 3), round(std, 3)))
        key = f'{exp_name}_{method_name}'
        method_map[pretty_names.get(key, key)] = table_list

<h2>Run image experiments</h2>

In [92]:
# Settings for a single image-classification run.
# NOTE(review): this rebinds `model`, which earlier cells use for the trained network.
ds, model, seed = 'MNIST', 'MLP', 117

In [96]:
# Launch one image-classification run for dataset `ds`, architecture `model` and
# seed `seed`; on first use the script downloads the dataset into ../data.
!python3 ../experiments/imgclassification.py -d {ds} -m {model} -s {seed}

Writing results to ../experiments/results/MNIST
Reading data from ../data
Dataset: MNIST
Seed: 117
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz
9920512it [00:00, 12327924.77it/s]                                              
Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz
32768it [00:00, 298645.29it/s]
Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz
1654784it [00:00, 6915425.85it/s]                                               
Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz
81