In [1]:
import os
import json
import pickle
from glob import glob
from collections import namedtuple
from pathlib import Path

import torch
import numpy as np
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
from labellines import labelLines

from bbb.utils.pytorch_setup import DEVICE
from bbb.utils.plotting import plot_weight_samples
from bbb.config.constants import KL_REWEIGHTING_TYPES, PRIOR_TYPES, VP_VARIANCE_TYPES
from bbb.config.parameters import Parameters, PriorParameters
from bbb.models.dnn import ClassificationDNN
from bbb.models.bnn import ClassificationBNN
from bbb.data import load_mnist, load_adverserial_mnist
from bbb.models.layers import BaseBFC

from evaluate_classifier import evaluate_classifier
from model_details import MODEL_DETAILS_DICT, load_model, ModelDetails

In [12]:
# Running 50 inference samples across MoG 600 1200 ("bnn_mog_600_1200").
# Seed torch's global RNG first so the shuffled loader order (and, presumably, the BNN's
# weight sampling) is reproducible under Restart & Run All.
torch.manual_seed(0)
net, params = load_model("bnn_mog_600_1200")
net.inference_samples = 50
# train=False split, presumably the MNIST test set.
X_val = load_mnist(train=False, batch_size=params.batch_size, shuffle=True)
# Bind the result so later cells can refer to it by name instead of IPython's `_`/`___` history.
eval_results = evaluate_classifier(net, X_val)
eval_results

2022-03-29 19:09:49,280 - bbb.models.layers - INFO - Weights Prior: Gaussian mixture with means (0, 0), variances (1.0, 0.0009118819655545162) and weight 0.5
2022-03-29 19:09:49,281 - bbb.models.layers - INFO - Biases Prior: Gaussian mixture with means (0, 0), variances (1.0, 0.0009118819655545162) and weight 0.5
2022-03-29 19:09:49,291 - bbb.models.layers - INFO - Weights Prior: Gaussian mixture with means (0, 0), variances (1.0, 0.0009118819655545162) and weight 0.5
2022-03-29 19:09:49,292 - bbb.models.layers - INFO - Biases Prior: Gaussian mixture with means (0, 0), variances (1.0, 0.0009118819655545162) and weight 0.5
2022-03-29 19:09:49,293 - bbb.models.layers - INFO - Weights Prior: Gaussian mixture with means (0, 0), variances (1.0, 0.0009118819655545162) and weight 0.5
2022-03-29 19:09:49,293 - bbb.models.layers - INFO - Biases Prior: Gaussian mixture with means (0, 0), variances (1.0, 0.0009118819655545162) and weight 0.5
  return torch.from_numpy(parsed.astype(m[2], copy=Fals

{'labels': [1,
  3,
  1,
  6,
  6,
  0,
  6,
  8,
  6,
  7,
  1,
  5,
  0,
  0,
  7,
  8,
  4,
  1,
  6,
  1,
  1,
  8,
  0,
  4,
  7,
  8,
  9,
  7,
  5,
  6,
  4,
  2,
  8,
  5,
  3,
  2,
  8,
  5,
  4,
  6,
  3,
  2,
  4,
  0,
  9,
  9,
  3,
  3,
  8,
  4,
  2,
  9,
  6,
  6,
  2,
  3,
  5,
  2,
  4,
  1,
  3,
  9,
  9,
  2,
  1,
  0,
  6,
  7,
  5,
  3,
  2,
  2,
  3,
  7,
  4,
  8,
  9,
  2,
  6,
  2,
  0,
  8,
  6,
  4,
  3,
  3,
  2,
  3,
  5,
  4,
  2,
  0,
  2,
  8,
  7,
  0,
  6,
  1,
  7,
  1,
  7,
  0,
  4,
  8,
  0,
  0,
  3,
  6,
  5,
  3,
  1,
  8,
  0,
  3,
  4,
  5,
  1,
  6,
  2,
  4,
  6,
  2,
  2,
  4,
  1,
  6,
  0,
  7,
  6,
  0,
  9,
  3,
  6,
  0,
  6,
  8,
  4,
  8,
  9,
  9,
  8,
  7,
  1,
  7,
  1,
  3,
  5,
  9,
  1,
  4,
  8,
  3,
  7,
  7,
  1,
  2,
  7,
  9,
  8,
  4,
  1,
  0,
  6,
  0,
  3,
  4,
  4,
  2,
  2,
  8,
  9,
  6,
  9,
  1,
  7,
  9,
  1,
  4,
  3,
  1,
  9,
  6,
  2,
  8,
  1,
  4,
  8,
  1,
  5,
  8,
  9,
  5,
  6,
  0,
  7,
  6,
  7,
  1,


In [21]:
# Test error = 1 - eval_score (eval_score presumably the classification accuracy — confirm).
# NOTE(review): `___` is IPython's third-most-recent output. This only works if that output is
# the dict returned by evaluate_classifier in the evaluation cell above; it breaks under
# Restart & Run All. Bind the evaluate_classifier result to a variable and index that instead.
(1 - ___['eval_score'])

tensor(0.0137)

In [2]:
# Plot styling for the figures below: reset to matplotlib's defaults, then pin title/label sizes.
plt.style.use('default')
plt.rcParams['axes.titlesize'] = 'large'
plt.rcParams['axes.labelsize'] = 'medium'

# Upper-case hex codes for blue / orange / green.
colors_hex = dict(blue='#1F77B4', orange='#FF7F0E', green='#2CA02C')
# Matplotlib's default colour cycle ('tab10'), in cycle order.
colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]

In [6]:
mog_directory = os.fsencode("../saved_models/BBB_classification/baseline/mog_prior/")
sgp_directory = os.fsencode("../saved_models/BBB_classification/baseline/single_gaussian_prior/")

eval_dict = {}
dirs = [sgp_directory, mog_directory]
for directory in dirs:
    for root, subdirs, files in os.walk(directory):
        root = os.fsdecode(root)

        if root.split('/')[-1].__contains__(".") and root.split('/')[-2].__contains__("0"): # dst folder (time stamp, hidden units)
            
            model_name = "/".join(root.split('/')[-4:])
            
            with open(os.path.join(root, 'params.txt'), 'r') as f:
                params_dict = json.load(f)

            # Need to deserialise the prior_params into a PriorParameters object
            if params_dict['prior_params']:
                params_dict['prior_params'] = PriorParameters(**params_dict['prior_params'])

            params = Parameters(**params_dict)

            net = ClassificationBNN(params=params, eval_mode=True)
            net.model.load_state_dict(torch.load(os.path.join(root, 'model.pt'), map_location=torch.device('cpu')))

            eval_metric = np.load(os.path.join(root, 'eval_metric.npy'))
            eval_dict[model_name] = round(1-eval_metric[-1], 4)

2022-03-29 18:13:03,144 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,145 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,155 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,156 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,157 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,157 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,166 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,166 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,168 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:13:03,168 - bbb.models.layers - INFO

In [7]:
# Final test error of each baseline run, keyed by run path (shallow copy; same repr).
dict(eval_dict)

{'single_gaussian_prior/300_epochs/1200/2022-03-15-09.18.07': 0.0136,
 'single_gaussian_prior/300_epochs/400/2022-03-15-14.26.34': 0.0144,
 'single_gaussian_prior/300_epochs/800/2022-03-15-14.25.46': 0.015,
 'single_gaussian_prior/600_epochs/1200/2022-03-27-20.30.22': 0.0115,
 'mog_prior/sigma_1_exp_2_sigma_2_exp_6/1200/2022-03-27-00.06.29': 0.0213,
 'mog_prior/sigma_1_exp_2_sigma_2_exp_6/400/2022-03-27-07.52.03': 0.0211,
 'mog_prior/sigma_1_exp_2_sigma_2_exp_6/800/2022-03-26-23.55.35': 0.021,
 'mog_prior/sigma_1_exp_1_sigma_2_exp_6/1200/2022-03-26-19.41.39': 0.0155,
 'mog_prior/sigma_1_exp_1_sigma_2_exp_6/400/2022-03-26-19.39.24': 0.0157,
 'mog_prior/sigma_1_exp_1_sigma_2_exp_6/800/2022-03-26-20.37.06': 0.0161,
 'mog_prior/sigma_1_1_sigma_2_0.2/1200/2022-03-15-18.38.27': 0.0699,
 'mog_prior/sigma_1_1_sigma_2_0.2/400/2022-03-15-21.48.39': 0.0147,
 'mog_prior/sigma_1_1_sigma_2_0.2/800/2022-03-15-18.43.25': 0.0158,
 'sigma_1_1_sigma_2_exp_7/1200/300_epochs/2022-03-27-07.54.00': 0.0148,
 

In [9]:
# Print the final test error (1 - last recorded eval metric, 5 d.p.) of every registered model.
for model_name, details in MODEL_DETAILS_DICT.items():
    # NOTE(review): the loaded model is not used here; `net`/`params` stay bound to the last
    # model afterwards. Presumably kept as a load sanity check — drop if not needed.
    net, params = load_model(model_name)
    eval_metric = np.load(os.path.join(details.dir, 'eval_metric.npy'))
    print(model_name, round(1 - eval_metric[-1], 5))

2022-03-29 18:19:43,675 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,675 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,686 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,686 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,688 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,688 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,699 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,699 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,704 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-29 18:19:43,704 - bbb.models.layers - INFO

bnn_sgp_1200 0.01362
bnn_sgp_800 0.01502
bnn_sgp_400 0.01442
bnn_sgp_600_1200 0.01152
bnn_mog_600_1200 0.01482
bnn_mog_1200 0.01482
bnn_mog_800 0.01562
bnn_mog_400 0.01472
dnn_1200 0.01663
dnn_800 0.01683
dnn_400 0.02043
dnn_do_400 0.01332
dnn_do_800 0.01412
dnn_do_1200 0.01522
