In [None]:
from main import load_mnist
import torch
import torchvision

In [None]:
from deterministic.vanilla_net import VanillaNetLinear
from probabilistic.distribution_utils import to_discrete_distribution
from globals import TORCH_DEVICE
import numpy as np
import matplotlib.pyplot as plt

is_cuda_available = torch.cuda.is_available()
device = TORCH_DEVICE
print(device)


def collect_weights(net):
    """Flatten every weight tensor of ``net`` into one 1-D float64 NumPy array.

    Only parameters whose name contains 'weight' are included, so biases are
    skipped. Each tensor is detached and copied to host memory before the
    NumPy conversion, so this works regardless of the device the network is on.
    """
    # Seeding with an empty float64 array keeps the result float64 (NumPy type
    # promotion) and well-defined even when no parameter name matches.
    chunks = [np.array([])]
    chunks.extend(
        param.detach().cpu().numpy().ravel()
        for name, param in net.named_parameters()
        if 'weight' in name
    )
    # One concatenation instead of re-copying the growing array per layer.
    return np.concatenate(chunks, axis=None)


vanilla_net = VanillaNetLinear()

# Get all the initial weights and plot their distribution
default_weights = collect_weights(vanilla_net)
print(default_weights.max())

# Re-draw every weight tensor in place from N(0, 0.05^2). Initialising in place
# keeps each parameter on the device (and with the dtype) it already has, so the
# weights cannot end up on a different device than the biases.
for name, param in vanilla_net.named_parameters():
    if 'weight' in name:
        torch.nn.init.normal_(param, mean=0.0, std=0.05)
custom_weights = collect_weights(vanilla_net)

domain_values_default, range_values_default = to_discrete_distribution(default_weights)
domain_values_custom, range_values_custom = to_discrete_distribution(custom_weights)

# Side-by-side comparison of the two initialisations.
# (For a raw histogram instead, use e.g. axs[0].hist(default_weights, bins=100).)
fig, axs = plt.subplots(1, 2, figsize=(14, 6))
axs[0].plot(domain_values_default, range_values_default)
axs[0].set_title('Default weights')
axs[1].plot(domain_values_custom, range_values_custom)
axs[1].set_title('Custom weights');

In [None]:
from scipy.stats import gaussian_kde

# Now in a similar way plot the weights distribution after training
trained_vanilla_net = VanillaNetLinear()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on
# a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
trained_vanilla_net.load_state_dict(torch.load('vanilla_network.pt', map_location=device))
trained_vanilla_net.to(device)
trained_vanilla_net.eval()

# Flatten all weight tensors (parameters whose name contains 'weight') into one
# 1-D array. The leading empty float64 array keeps the result float64 and valid
# even if no parameter matches; everything is concatenated in a single pass.
optimized_weights = np.concatenate(
    [np.array([])]
    + [
        param.detach().cpu().numpy().ravel()
        for name, param in trained_vanilla_net.named_parameters()
        if 'weight' in name
    ],
    axis=None,
)

# first fit a distribution to the data
density = gaussian_kde(optimized_weights)

domain_values_optimized, range_values_optimized = to_discrete_distribution(optimized_weights)

# Left: discretised distribution of the trained weights.
# Right: Gaussian KDE of the same weights, evaluated on 200 evenly spaced points
# spanning the discretised domain.
# (For a raw histogram instead, use e.g. axs[0].hist(optimized_weights, bins=500).)
fig, axs = plt.subplots(1, 2, figsize=(14, 6))
axs[0].plot(domain_values_optimized, range_values_optimized)
axs[0].set_title('Trained weights (discrete distribution)')
cts_domain = np.linspace(np.min(domain_values_optimized), np.max(domain_values_optimized), 200)
axs[1].plot(cts_domain, density(cts_domain), 'k', linewidth=2)
axs[1].set_title('Trained weights (Gaussian KDE)');

In [None]:
import probabilistic.bnn as bnn
import probabilistic.models as models
import probabilistic.hamiltonian as ham
import probabilistic.pipeline as pipeline
import importlib
importlib.reload(models)
importlib.reload(ham)
importlib.reload(pipeline)
importlib.reload(bnn)
from probabilistic.models import WeightModel
from probabilistic.hamiltonian import Hamiltonian, HyperparamsHMC
from probabilistic.pipeline import hmc
from probabilistic.bnn import VanillaBNN
from dataset_utils import load_mnist


# First create the BNN and place it on the configured device
vanilla_bnn = VanillaBNN()
vanilla_bnn.to(TORCH_DEVICE)
train_data, test_data = load_mnist()

# Prior over the weights: a single-element mean and a 1x1 covariance.
# should broadcast to the correct shape
# NOTE(review): these tensors (and ll_variance / momentum_variances below) are
# created on the CPU while the network sits on TORCH_DEVICE -- confirm that
# WeightModel / Hamiltonian move them to the right device.
w_prior_mean, w_prior_cov = torch.zeros(1), torch.eye(1)
weight_probabilistic_model = WeightModel(prior_mean=w_prior_mean, prior_variance=w_prior_cov, ll_variance=torch.tensor(0.1))
w_prior_func = weight_probabilistic_model.log_gaussian_prior
w_likelihood_func = weight_probabilistic_model.log_gaussian_likelihood

# The Hamiltonian is built from the log-prior and log-likelihood callables, the
# momentum variances, the training data and the network being sampled.
momentum_variances = torch.ones(1)
hamiltonian = Hamiltonian(w_prior_func, w_likelihood_func, momentum_variances, train_data, net = vanilla_bnn)

# Sampler settings: epoch count, burn-in epochs, leapfrog step size and steps per epoch.
hps = HyperparamsHMC(num_epochs=300, num_burnin_epochs=20, lf_step=0.01, steps_per_epoch=30)

param_samples = hmc(hamiltonian, hps)
# Show only how many samples were drawn; displaying the samples themselves would
# flood the notebook output with every sampled parameter tensor.
len(param_samples)

In [None]:
import probabilistic.pipeline as pipeline
importlib.reload(pipeline)
from probabilistic.pipeline import test_hmc

# Number of leading HMC samples used for the evaluation.
NUM_EVAL_SAMPLES = 100

# Evaluate on the CPU: move the sampled parameters and the network off the GPU.
cpu_device = torch.device('cpu')
param_samples = [w.to(cpu_device) for w in param_samples]
vanilla_bnn = vanilla_bnn.to(cpu_device)

# Release cached GPU memory only after the transfer: empty_cache() can free only
# blocks that no live tensor occupies, so calling it while the samples and the
# model are still on the GPU would leave their memory allocated.
# (GPU copies still referenced elsewhere stay allocated.)
torch.cuda.empty_cache()

accuracy = test_hmc(vanilla_bnn, param_samples[:NUM_EVAL_SAMPLES], test_data)
accuracy
