In [43]:
import os
import json
from glob import glob
from collections import namedtuple

import numpy as np
import matplotlib.pyplot as plt

from bbb.utils.pytorch_setup import DEVICE
from bbb.config.parameters import Parameters, PriorParameters
from bbb.models.dnn import ClassificationDNN
from bbb.models.bnn import ClassificationBNN
from bbb.data import load_mnist

## Helper Classes

In [22]:
# Immutable record pairing a saved-run directory (`dir`) with the model class (`mclass`)
# that should be instantiated for it.
ModelDetails = namedtuple('ModelDetails', ['dir', 'mclass'])

## Trained Classification Models

In [31]:
# Registry of trained classification runs: key -> ModelDetails(saved-run dir, model class).
# The numeric suffix presumably records the hidden-layer width of each run — TODO confirm
# against each run's params.txt.
#
# Built from (key, details) pairs instead of a dict literal: a dict literal silently keeps
# only the last value for a repeated key, which hides runs without any error. Duplicate
# keys are rejected explicitly below.
_MODEL_DETAILS_ENTRIES = [
    # BNN
    ("bnn_1200", ModelDetails("saved_models/BBB_classification/2022-03-15-09.18.07", ClassificationBNN)),
    ("bnn_800", ModelDetails("saved_models/BBB_classification/2022-03-15-14.25.46", ClassificationBNN)),
    ("bnn_400", ModelDetails("saved_models/BBB_classification/2022-03-15-14.26.34", ClassificationBNN)),
    # DNN - no dropout
    # NOTE(review): these three runs were all keyed "dnn_1200" (only the last was reachable).
    # The 1200/800/400 labels follow the BNN ordering above — verify against each params.txt.
    ("dnn_1200", ModelDetails("saved_models/DNN_classification/2022-03-15-14.28.25", ClassificationDNN)),
    ("dnn_800", ModelDetails("saved_models/DNN_classification/2022-03-15-16.06.09", ClassificationDNN)),
    ("dnn_400", ModelDetails("saved_models/DNN_classification/2022-03-15-16.10.34", ClassificationDNN)),
    # DNN - dropout
    # NOTE(review): these three runs were all keyed "dnn_do_1200" (only the last was reachable).
    # The 1200/800/400 labels follow the BNN ordering above — verify against each params.txt.
    ("dnn_do_1200", ModelDetails("saved_models/DNN_classification/2022-03-15-15.21.46", ClassificationDNN)),
    ("dnn_do_800", ModelDetails("saved_models/DNN_classification/2022-03-15-15.58.04", ClassificationDNN)),
    ("dnn_do_400", ModelDetails("saved_models/DNN_classification/2022-03-15-16.26.18", ClassificationDNN)),
]

# Fail loudly on a repeated key rather than letting a later entry shadow an earlier one.
_model_keys = [key for key, _ in _MODEL_DETAILS_ENTRIES]
_duplicate_model_keys = sorted({key for key in _model_keys if _model_keys.count(key) > 1})
if _duplicate_model_keys:
    raise ValueError(f"Duplicate model keys in MODEL_DETAILS_DICT: {_duplicate_model_keys}")

MODEL_DETAILS_DICT = dict(_MODEL_DETAILS_ENTRIES)

## Set Model

In [38]:
# Which trained run to load and evaluate; must be a key of MODEL_DETAILS_DICT.
MODEL = "bnn_1200"

if MODEL not in MODEL_DETAILS_DICT:
    # Same exception type as a plain lookup, but tells the reader which keys are valid.
    raise KeyError(f"Unknown MODEL {MODEL!r}; expected one of {sorted(MODEL_DETAILS_DICT)}")
MODEL_DETAILS = MODEL_DETAILS_DICT[MODEL]

## Load Parameters

In [52]:
# Rebuild the training-time Parameters for the selected run from its saved params.txt (JSON).
# JSON is UTF-8 by specification; don't depend on the platform's default encoding.
with open(os.path.join(MODEL_DETAILS.dir, 'params.txt'), 'r', encoding='utf-8') as f:
    params_dict = json.load(f)

# json.load returns prior_params as a plain dict (or a falsy value when unset), so it must be
# turned back into a PriorParameters object. .get() also tolerates params files that were
# saved without a prior_params key instead of raising KeyError.
if params_dict.get('prior_params'):
    params_dict['prior_params'] = PriorParameters(**params_dict['prior_params'])

params = Parameters(**params_dict)

## Load Data

In [None]:
# MNIST evaluation split (train=False), batched with the run's training batch size.
# Presumably a data loader rather than a raw array, despite the X_ prefix — TODO confirm.
# NOTE(review): shuffling is not needed for a full-pass evaluation and is unseeded here;
# consider shuffle=False (or a fixed seed) for reproducible results.
X_val = load_mnist(
    train=False,
    shuffle=True,
    batch_size=params.batch_size,
)

## Load Model

In [53]:
# Instantiate the selected model class from the loaded params, in eval mode, on DEVICE.
# NOTE(review): no cell in this notebook visibly loads trained weights from MODEL_DETAILS.dir;
# unless eval_mode=True makes the class restore them itself, evaluate() below scores a freshly
# initialised model — confirm.
model = (
    MODEL_DETAILS
    .mclass(params=params, eval_mode=True)
    .to(DEVICE)
)

2022-03-15 16:55:51,988 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-15 16:55:51,993 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-15 16:55:52,048 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-15 16:55:52,051 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0
2022-03-15 16:55:52,052 - bbb.models.layers - INFO - Weights Prior: Gaussian with mean 0 and variance 1.0
2022-03-15 16:55:52,052 - bbb.models.layers - INFO - Biases Prior: Gaussian with mean 0 and variance 1.0


## Evaluate

In [54]:
# Evaluate on the held-out split. The returned value (if any) is bound to a name so later cells
# need not rely on Out[...], and is displayed as this cell's output.
# NOTE(review): the saved run of this cell was interrupted (KeyboardInterrupt); re-run to completion
# before sharing.
val_evaluation = model.evaluate(X_val)
val_evaluation

KeyboardInterrupt: 