In [1]:
import torch
import matplotlib.pyplot as plt

from data.gp_data import generate_segp_dataset
from data.cubic_data import generate_cubic_dataset
from models import MeanFieldBNN
from training import train
from weight_masks import apply_mask, get_mask

torch.set_default_dtype(torch.float64)

%load_ext autoreload
%autoreload 2

In [None]:
# Network definition and likelihood / prior settings shared by every model below.
architecture = [1, 50, 50, 1]
scale_prior = False
likelihood_std = 0.05

# Optimisation schedule handed to `train`
# (presumably decayed from lr towards final_lr inside `train` — confirm).
lr, final_lr = 1e-2, 1e-3
epochs = 10000

# Values passed as `c=` to the heavily / lightly asymmetric "fixed nonzero" models.
heavy_fixed_nonzero, light_fixed_nonzero = 5.0, 15.0

# Seed torch's global RNG, then draw the synthetic training set.
# NOTE(review): `gap` presumably leaves (-0.75, 0.75) without training inputs — confirm in data.gp_data.
torch.manual_seed(36)
x, y = generate_segp_dataset(
    input_lower=-2.5,
    input_upper=2.5,
    num_points=30,
    lengthscale=0.2,
    gap=(-0.75, 0.75),
)

In [None]:
print("Obtaining MAP weights now...")

# Non-variational (MAP) fit of the network. Its per-layer `w` tensors are reused
# later as `map_weights` for the "(MAP)" asymmetric models.
map_mlp = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
)

map_tracker = train(
    map_mlp,
    x,
    y,
    epochs=epochs,
    learning_rate=lr,
    final_learning_rate=final_lr,
)

# One small panel per tracked training quantity. `squeeze=False` always returns a
# 2-D array of Axes; without it a tracker with a single entry yields a bare Axes
# object and `ax[j]` raises a TypeError.
num_curves = len(map_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(map_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

# MAP predictions on a grid wider than the training inputs (generated on [-2.5, 2.5]).
x_test = torch.linspace(-4.0, 4.0, 500).unsqueeze(-1)
with torch.no_grad():  # inference only: no autograd graph needed
    map_preds = map_mlp(x_test, variational=False)

fig_pred, ax_pred = plt.subplots()
ax_pred.plot(x_test.squeeze(), map_preds.squeeze(), color='grey', label='MAP prediction')
ax_pred.scatter(x, y, zorder=1e5, label='Datapoints')
ax_pred.grid()
ax_pred.set_ylim([-3, 3])
ax_pred.set_xlim([-4.0, 4.0])
ax_pred.set_xlabel("Input Variable", fontsize=12)
ax_pred.set_ylabel("Output Variable", fontsize=12)
ax_pred.set_title("MAP Network Prediction", fontsize=15)
ax_pred.legend(loc=2)
plt.show()

In [None]:
# "Heavily Asymmetric (Fixed Nonzero)" model: minimal_mask=False, c=heavy_fixed_nonzero.
asymmetric_mlp_nonzeros = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    c=heavy_fixed_nonzero,
)

print("Performing VI now...")
mlp_tracker = train(
    asymmetric_mlp_nonzeros,
    x,
    y,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    num_samples=16,  # presumably Monte Carlo samples per training step — confirm in `train`
    final_learning_rate=final_lr,
)

# One panel per tracked quantity. squeeze=False keeps `ax` indexable even when the
# tracker has a single entry (plt.subplots would otherwise return a bare Axes).
num_curves = len(mlp_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(mlp_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

In [None]:
# "Lightly Asymmetric (Fixed Nonzero)" model: minimal_mask=True, c=light_fixed_nonzero.
lightly_asymmetric_mlp_nonzeros = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=True,
    c=light_fixed_nonzero,
)

# This model uses asymmetric weights, so the run is not labelled "standard" VI.
print("Performing lightly asymmetric (fixed nonzero) VI now...")
mlp_tracker = train(
    lightly_asymmetric_mlp_nonzeros,
    x,
    y,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    num_samples=16,  # presumably Monte Carlo samples per training step — confirm in `train`
    final_learning_rate=final_lr,
)

# One panel per tracked quantity. squeeze=False keeps `ax` indexable even when the
# tracker has a single entry (plt.subplots would otherwise return a bare Axes).
num_curves = len(mlp_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(mlp_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

In [None]:
# "Heavily Asymmetric (MAP)" model: minimal_mask=False, seeded with the MAP
# network's per-layer `w` tensors. `.detach()` alone shares storage with
# map_mlp's parameters. Whether MeanFieldBNN ever writes to these tensors in
# place is not visible here, so `.clone()` guarantees that training this model
# cannot alter map_mlp, which a later cell reuses.
asymmetric_mlp_map_nonzeros = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    map_weights=[layer.w.detach().clone() for layer in map_mlp.layers],
)

# This model uses asymmetric weights, so the run is not labelled "standard" VI.
print("Performing heavily asymmetric (MAP) VI now...")
mlp_tracker = train(
    asymmetric_mlp_map_nonzeros,
    x,
    y,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    num_samples=16,  # presumably Monte Carlo samples per training step — confirm in `train`
    final_learning_rate=final_lr,
)

# One panel per tracked quantity. squeeze=False keeps `ax` indexable even when the
# tracker has a single entry (plt.subplots would otherwise return a bare Axes).
num_curves = len(mlp_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(mlp_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

In [None]:
# "Lightly Asymmetric (MAP)" model: minimal_mask=True, seeded with the MAP
# network's per-layer `w` tensors. `.detach()` alone shares storage with
# map_mlp's parameters. Whether MeanFieldBNN ever writes to these tensors in
# place is not visible here, so `.clone()` guarantees that training this model
# cannot alter map_mlp.
lightly_asymmetric_mlp_map_nonzeros = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=True,
    map_weights=[layer.w.detach().clone() for layer in map_mlp.layers],
)

# This model uses asymmetric weights, so the run is not labelled "standard" VI.
print("Performing lightly asymmetric (MAP) VI now...")
mlp_tracker = train(
    lightly_asymmetric_mlp_map_nonzeros,
    x,
    y,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    num_samples=16,  # presumably Monte Carlo samples per training step — confirm in `train`
    final_learning_rate=final_lr,
)

# One panel per tracked quantity. squeeze=False keeps `ax` indexable even when the
# tracker has a single entry (plt.subplots would otherwise return a bare Axes).
num_curves = len(mlp_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(mlp_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

In [None]:
# "Heavily Asymmetric (Pruned)" model: minimal_mask=False with c=0.0.
asymmetric_mlp_zeros = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=True,
    minimal_mask=False,
    c=0.0,
)

# This model uses asymmetric weights, so the run is not labelled "standard" VI.
print("Performing heavily asymmetric (pruned) VI now...")
mlp_tracker = train(
    asymmetric_mlp_zeros,
    x,
    y,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    num_samples=16,  # presumably Monte Carlo samples per training step — confirm in `train`
    final_learning_rate=final_lr,
)

# One panel per tracked quantity. squeeze=False keeps `ax` indexable even when the
# tracker has a single entry (plt.subplots would otherwise return a bare Axes).
num_curves = len(mlp_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(mlp_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

In [None]:
# "Standard MFVI" baseline: ordinary (symmetric) mean-field weights.
symmetric_mlp = MeanFieldBNN(
    architecture,
    scale_prior=scale_prior,
    likelihood_std=likelihood_std,
    asymmetric_weights=False,
)

print("Performing standard VI now...")
mlp_tracker = train(
    symmetric_mlp,
    x,
    y,
    variational=True,
    epochs=epochs,
    learning_rate=lr,
    num_samples=16,  # presumably Monte Carlo samples per training step — confirm in `train`
    final_learning_rate=final_lr,
)

# One panel per tracked quantity. squeeze=False keeps `ax` indexable even when the
# tracker has a single entry (plt.subplots would otherwise return a bare Axes).
num_curves = len(mlp_tracker.items())
fig, ax = plt.subplots(1, num_curves, figsize=(4 * num_curves, 2), sharex=True, squeeze=False)
ax = ax.ravel()

for j, (key, value) in enumerate(mlp_tracker.items()):
    ax[j].plot(value)
    ax[j].tick_params(axis='both', labelsize=6)
    ax[j].grid()
    ax[j].set_xlabel(key, fontsize=10)

plt.show()

In [None]:
# Toggle for the model-comparison figure in the next cell:
#   True  -> draw every posterior predictive sample as a thin grey line;
#   False -> draw the predictive mean with a shaded mean ± 2 std band.
display_samples = False

In [None]:
x_test = torch.linspace(-4.0, 4.0, 500).unsqueeze(-1)
num_samps = 300

# Each trained model paired with its panel title, so the models and the titles
# cannot drift out of sync.
models_and_titles = [
    (asymmetric_mlp_nonzeros, "Heavily Asymmetric (Fixed Nonzero) MFVI"),
    (lightly_asymmetric_mlp_nonzeros, "Lightly Asymmetric (Fixed Nonzero) MFVI"),
    (asymmetric_mlp_map_nonzeros, "Heavily Asymmetric (MAP) MFVI"),
    (lightly_asymmetric_mlp_map_nonzeros, "Lightly Asymmetric (MAP) MFVI"),
    (asymmetric_mlp_zeros, "Heavily Asymmetric (Pruned) MFVI"),
    (symmetric_mlp, "Standard MFVI"),
]
titles = [title for _, title in models_and_titles]

# `num_samps` sampled predictions per model. Each is a rank-3 tensor whose
# leading dim indexes samples; presumably (num_samps, len(x_test), 1) — confirm.
with torch.no_grad():  # inference only: no autograd graph needed
    preds = [model(x_test, variational=True, num_samples=num_samps) for model, _ in models_and_titles]

# Per-model (mean, mean + 2 std, mean - 2 std) over the sample dimension. Always
# computed, so `processed_preds` never depends on which branch ran earlier.
# NOTE(review): the "95% Confidence" label assumes the sampled outputs are
# roughly Gaussian at each input.
processed_preds = []
for pred in preds:
    pred_mean, pred_std = pred.mean(0), pred.std(0)
    processed_preds.append((pred_mean, pred_mean + 2 * pred_std, pred_mean - 2 * pred_std))

n_cols = 3
n_rows = -(-len(preds) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, sharex=True, figsize=(7.5 * n_cols, 5 * n_rows), squeeze=False)
x_plot = x_test.squeeze()

# axes.flat walks the grid row by row, matching the original (j // 3, j % 3) layout.
for j, (ax, title) in enumerate(zip(axes.flat, titles)):
    if display_samples:
        # All samples but the last in one vectorised call (one line per column);
        # the last one is drawn separately so it carries the legend entry.
        ax.plot(x_plot, preds[j][:-1].squeeze(-1).T, color='grey', alpha=0.3, linewidth=0.4)
        ax.plot(x_plot, preds[j][-1].squeeze(), color='grey', alpha=0.3, linewidth=0.4, label='Posterior predictive sample')
    else:
        band_mean, band_upper, band_lower = processed_preds[j]
        ax.plot(x_plot, band_mean.squeeze(), color='black', label="Predictive mean")
        ax.fill_between(x_plot, band_upper.squeeze(), band_lower.squeeze(), color='grey', label="95% Confidence")
    ax.scatter(x, y, zorder=1e5, label='Datapoints')
    ax.grid()
    ax.set_ylim([-3, 3])
    ax.set_xlim([-4.0, 4.0])
    ax.set_xlabel("Input Variable", fontsize=12)
    ax.set_ylabel("Output Variable", fontsize=12)
    ax.set_title(title, fontsize=15)
    ax.legend(loc=2)

# Hide grid cells left empty when the model count is not a multiple of n_cols.
for unused_ax in axes.flat[len(preds):]:
    unused_ax.set_visible(False)
plt.show()

In [None]:
x_test = torch.linspace(-4.0, 4.0, 500).unsqueeze(-1)
num_samps = 300

# Trained models keyed by display title. The model to plot is chosen by title, so
# the figure title always names the model actually shown.
models_by_title = {
    "Heavily Asymmetric (Fixed Nonzero) MFVI": asymmetric_mlp_nonzeros,
    "Lightly Asymmetric (Fixed Nonzero) MFVI": lightly_asymmetric_mlp_nonzeros,
    "Heavily Asymmetric (MAP) MFVI": asymmetric_mlp_map_nonzeros,
    "Lightly Asymmetric (MAP) MFVI": lightly_asymmetric_mlp_map_nonzeros,
    "Heavily Asymmetric (Pruned) MFVI": asymmetric_mlp_zeros,
    "Standard MFVI": symmetric_mlp,
}
titles = list(models_by_title)
plot_title = "Standard MFVI"  # set to any key of models_by_title

with torch.no_grad():  # inference only: no autograd graph needed
    preds = models_by_title[plot_title](x_test, variational=True, num_samples=num_samps)

# Mean and mean ± 2 std over the sample dimension (dim 0).
pred_mean = preds.mean(0)
pred_std = preds.std(0)
pred_upper = pred_mean + 2 * pred_std
pred_lower = pred_mean - 2 * pred_std
processed_preds = (pred_mean, pred_upper, pred_lower)

fig, axes = plt.subplots(1, 1, figsize=(8, 4))
axes.plot(x_test.squeeze(), processed_preds[0].squeeze(), color='black', label="Predictive mean")
axes.fill_between(x_test.squeeze(), processed_preds[1].squeeze(), processed_preds[2].squeeze(), color='grey', label="95% Confidence", alpha=0.6)
axes.scatter(x, y, zorder=1e5, label='Datapoints')
axes.grid()
axes.set_ylim([-2, 2])
axes.set_xlim([-4.0, 4.0])
axes.set_xlabel("Input Variable", fontsize=12)
axes.set_ylabel("Output Variable", fontsize=12)
axes.set_title(plot_title, fontsize=15)
axes.legend(loc=2)
plt.show()