In [6]:
import os
import json
from glob import glob
from collections import namedtuple

import torch
import numpy as np
import matplotlib.pyplot as plt

from bbb.utils.pytorch_setup import DEVICE
from bbb.utils.plotting import plot_weight_samples
from bbb.config.constants import KL_REWEIGHTING_TYPES, PRIOR_TYPES, VP_VARIANCE_TYPES
from bbb.config.parameters import Parameters, PriorParameters
from bbb.models.dnn import ClassificationDNN
from bbb.models.bnn import ClassificationBNN
from bbb.data import load_mnist
from bbb.models.layers import BFC, BFC_LRT

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
ModelDetails = namedtuple('ModelDetails', 'dir mclass')

MODEL_DETAILS_DICT = {
    # BNN
    "bnn_1200": ModelDetails("../saved_models/BBB_classification/2022-03-15-09.18.07", ClassificationBNN),
    "bnn_800": ModelDetails("../saved_models/BBB_classification/2022-03-15-14.25.46", ClassificationBNN),
    "bnn_400": ModelDetails("../saved_models/BBB_classification/2022-03-15-14.26.34", ClassificationBNN),
    # DNN - no dropout
    "dnn_1200": ModelDetails("../saved_models/DNN_classification/2022-03-15-14.28.25", ClassificationDNN),
    "dnn_800": ModelDetails("../saved_models/DNN_classification/2022-03-15-16.06.09", ClassificationDNN),
    "dnn_400": ModelDetails("../saved_models/DNN_classification/2022-03-15-16.10.34", ClassificationDNN),
    # DNN - dropout
    "dnn_do_400": ModelDetails("../saved_models/DNN_classification/2022-03-15-15.21.46", ClassificationDNN),
    "dnn_do_800": ModelDetails("../saved_models/DNN_classification/2022-03-15-15.58.04", ClassificationDNN),
    "dnn_do_1200": ModelDetails("../saved_models/DNN_classification/2022-03-15-16.26.18", ClassificationDNN),
}

In [15]:
results_dict = {}

for MODEL in MODEL_DETAILS_DICT.keys():
    print("Loading {}".format(MODEL))
    results_dict[MODEL] = {}

    MODEL_DETAILS = MODEL_DETAILS_DICT[MODEL]

    # Load parameters
    with open(os.path.join(MODEL_DETAILS.dir, 'params.txt'), 'r') as f:
        params_dict = json.load(f)

    # Need to deserialise the prior_params into a PriorParameters object
    if params_dict['prior_params']:
        params_dict['prior_params'] = PriorParameters(**params_dict['prior_params'])

    params = Parameters(**params_dict)

    # Load model
    net = MODEL_DETAILS.mclass(params=params, eval_mode=True) # .to(DEVICE) 
    net.model.load_state_dict(torch.load(os.path.join(MODEL_DETAILS.dir, 'model.pt'), map_location=torch.device('cpu')))

    # Load evaluation metric results across epochs (list)
    eval_metric = np.load(os.path.join(MODEL_DETAILS.dir, 'eval_metric.npy'))

    # Add last epoch value
    results_dict[MODEL]['eval_metric'] = eval_metric[-1]

    weight_tensors = []
    for layer in [l for l in net.model if isinstance(l, BFC)]:

        # Average weight mean in layer, average weight variance in layer
        mu = layer.w_var_post.mu.cpu().detach().numpy()
        sigma = torch.log1p(torch.exp(layer.w_var_post.rho)).cpu().detach().numpy()
        
        print("{}: {} \t {}".format(MODEL, np.mean(mu), np.mean(sigma)))
        
    


2022-03-16 00:27:02,182 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,182 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,193 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,193 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,195 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,195 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,215 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,215 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,220 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-16 00:27:02,220 - bbb.models.layers - INFO

Loading bnn_1200
bnn_1200: -0.0010805814526975155 	 0.022610129788517952
bnn_1200: 0.0005043221171945333 	 0.02260366827249527
bnn_1200: -0.0022730641067028046 	 0.022348301485180855
Loading bnn_800
bnn_800: -0.00013174352352507412 	 0.01478757243603468
bnn_800: 0.000435502064647153 	 0.014787760563194752
bnn_800: -0.0034469771198928356 	 0.014695613645017147
Loading bnn_400
bnn_400: -0.000349174631992355 	 0.018243536353111267
bnn_400: 0.0024643607903271914 	 0.01825784333050251
bnn_400: -0.01089491881430149 	 0.017943862825632095
Loading dnn_1200
Loading dnn_800
Loading dnn_400
Loading dnn_do_400
Loading dnn_do_800
Loading dnn_do_1200


In [16]:
results_dict

{'bnn_1200': {'eval_metric': 0.9863781929016113},
 'bnn_800': {'eval_metric': 0.9849759936332703},
 'bnn_400': {'eval_metric': 0.9855769276618958},
 'dnn_1200': {'eval_metric': 0.9833734035491943},
 'dnn_800': {'eval_metric': 0.9831730723381042},
 'dnn_400': {'eval_metric': 0.9795673489570618},
 'dnn_do_400': {'eval_metric': 0.9866787195205688},
 'dnn_do_800': {'eval_metric': 0.9858773946762085},
 'dnn_do_1200': {'eval_metric': 0.9847756624221802}}

In [21]:
print("Test error")
for model, results in results_dict.items():
    print("{}: {}".format(model, np.round(100*(1-results['eval_metric']), 3)))

Test error
bnn_1200: 1.362
bnn_800: 1.502
bnn_400: 1.442
dnn_1200: 1.663
dnn_800: 1.683
dnn_400: 2.043
dnn_do_400: 1.332
dnn_do_800: 1.412
dnn_do_1200: 1.522
