In [None]:
import torch
import matplotlib.pyplot as plt
from ucimlrepo import fetch_ucirepo
from data.custom_dataset import uci_to_normalised_ttsplit
from models import MeanFieldBNN
from training import train

# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# IPython autoreload, mode 2: reload imported modules before each cell runs, so
# edits to the local `data`, `models` and `training` packages are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# --- Model ------------------------------------------------------------------
# Passed as the first positional argument of every MeanFieldBNN below
# (presumably layer widths: 8 inputs -> 50 -> 50 -> 2 outputs; confirm).
architecture = [8, 50, 50, 2]
scale_prior = False
likelihood_std = 0.1

# Initial value of `c`: "heavy" is used with minimal_mask=False models,
# "light" with minimal_mask=True models.
heavy_fixed_nonzero = 1.0
light_fixed_nonzero = 4.0

# --- Optimisation -----------------------------------------------------------
# Forwarded to train() as learning_rate / final_learning_rate / epochs.
lr = 0.01
final_lr = 0.003
epochs = 10000

# --- Data -------------------------------------------------------------------
# Forwarded to uci_to_normalised_ttsplit() as train_proportion.
train_proportion = 0.8

In [None]:
# Seed torch's global RNG once, before any model is built or trained.
# NOTE(review): only torch is seeded here; if uci_to_normalised_ttsplit shuffles
# with numpy/random, those would need seeding too -- confirm.
torch.manual_seed(0)

In [None]:
# SECURITY NOTE(review): this swaps the ssl module's default HTTPS context
# factory for the *unverified* one, so HTTPS connections that rely on the
# default context stop validating server certificates for the rest of the
# kernel session. Presumably a workaround for certificate errors when
# downloading the dataset below -- prefer fixing the local CA bundle, and do not
# reuse this kernel for anything that talks to untrusted hosts.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [None]:
# Fetch UCI repository dataset id 242 (presumably the Energy Efficiency set,
# given the variable name and the 8-in / 2-out architecture -- confirm).
energy = fetch_ucirepo(id=242)
X_uci = energy.data.features
y_uci = energy.data.targets

# Split into train / test portions; `train_proportion` comes from the config cell.
X_train, y_train, X_test, y_test = uci_to_normalised_ttsplit(
    X_uci,
    y_uci,
    train_proportion=train_proportion,
)

In [None]:
# Baseline network trained without `variational=True`; its layer weights are
# reused further down as `map_weights` for the HAMAP / LAMAP models.
map_mlp = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
)
# No batch_size is passed here, so train() falls back to its own default.
map_tracker = train(
    map_mlp,
    X_train,
    y_train,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    x_test=X_test,
    y_test=y_test,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(map_tracker.items()),
    figsize=(4 * len(map_tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(map_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label
plt.show()

In [None]:
# HAFN: asymmetric weights, minimal_mask=False, `c` starting at
# heavy_fixed_nonzero and flagged trainable via train_c=True.
HAFN = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    c=heavy_fixed_nonzero,
    train_c=True,
)

tracker = train(
    HAFN,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# LAFN: asymmetric weights, minimal_mask=True, `c` starting at
# light_fixed_nonzero and flagged trainable via train_c=True.
LAFN = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=True,
    c=light_fixed_nonzero,
    train_c=True,
)

tracker = train(
    LAFN,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# HApruned: asymmetric weights, minimal_mask=False, `c` fixed at 0.0
# (train_c is not passed, so MeanFieldBNN's default applies).
HApruned = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    c=0.0,
)

tracker = train(
    HApruned,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# LApruned: asymmetric weights, minimal_mask=True, `c` fixed at 0.0
# (train_c is not passed, so MeanFieldBNN's default applies).
LApruned = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=True,
    c=0.0,
)

tracker = train(
    LApruned,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# HAMAP: asymmetric weights, minimal_mask=False, with the trained weights of
# `map_mlp` handed over as `map_weights`. Requires the map_mlp cell to have run.
# detach() cuts the tensors from map_mlp's autograd graph; they still share
# storage with map_mlp's parameters.
HAMAP = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    map_weights=[layer.w.detach() for layer in map_mlp.layers],
)

tracker = train(
    HAMAP,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# LAMAP: asymmetric weights, minimal_mask=True, with the trained weights of
# `map_mlp` handed over as `map_weights`. Requires the map_mlp cell to have run.
#
# FIX: this cell previously passed minimal_mask=False, which made LAMAP an exact
# duplicate of HAMAP. Every other "L*" model in this notebook (LAFN, LApruned,
# LRFN) uses minimal_mask=True, so the light variant is what is intended here.
#
# detach() cuts the tensors from map_mlp's autograd graph; they still share
# storage with map_mlp's parameters.
LAMAP = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=True,
    map_weights=[layer.w.detach() for layer in map_mlp.layers],
)

tracker = train(
    LAMAP,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# HRFN: same settings as HAFN (minimal_mask=False, c=heavy_fixed_nonzero,
# train_c=True) but with random_mask=True.
HRFN = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    c=heavy_fixed_nonzero,
    train_c=True,
    random_mask=True,
)

tracker = train(
    HRFN,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# LRFN: same settings as LAFN (minimal_mask=True, c=light_fixed_nonzero,
# train_c=True) but with random_mask=True.
LRFN = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=True,
    c=light_fixed_nonzero,
    train_c=True,
    random_mask=True,
)

tracker = train(
    LRFN,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# vanilla: the only variationally trained model here with
# asymmetric_weights=False; no mask / c / map_weights options are passed.
vanilla = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=False,
)

tracker = train(
    vanilla,
    X_train,
    y_train,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
    num_samples=16,
    x_test=X_test,
    y_test=y_test,
    batch_size=512,
)

# One panel per tracked metric. squeeze=False keeps the axes in an array even
# when the tracker holds a single metric; otherwise plt.subplots returns a bare
# Axes and ax[j] raises TypeError.
fig, ax = plt.subplots(
    1,
    len(tracker.items()),
    figsize=(4 * len(tracker.items()), 2),
    sharex=True,
    squeeze=False,
)
ax = ax[0]  # the single row: a 1-D array of Axes

for j, (key, value) in enumerate(tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='x', labelsize=6)
    ax[j].tick_params(axis='y', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)  # metric name doubles as the panel label

plt.show()

In [None]:
# Evaluate every variationally trained model on the held-out split and print
# one summary line per model. Requires all training cells above to have run.
num_samps = 1000

# `titles` and `models` are parallel lists: keep them in the same order.
titles = [
    "HAFN", "LAFN",
    "HApruned", "LApruned",
    "HAMAP", "LAMAP",
    "HRFN", "LRFN",
    "vanilla",
]
models = [
    HAFN, LAFN,
    HApruned, LApruned,
    HAMAP, LAMAP,
    HRFN, LRFN,
    vanilla,
]

results = []
for model in models:
    result = model.evaluate(X_test, y_test, variational=True, num_samples=num_samps)
    results.append(result)

# Element 0 of each result is reported as rmse, element 1 as mlpp.
for title, result in zip(titles, results):
    print(f"{title}:      rmse=", result[0], "  mlpp=", result[1])