In [1]:
import torch
import pyro
import numpy as np
from torch.nn import ReLU, LeakyReLU
from pyro.infer import SVI, Trace_ELBO, Predictive
from tqdm import tqdm

# Simple model with one hidden layer
from src.models.regression import MLERegression
from src.metrics.regression import  RootMeanSquaredError, PCIP, MPIW
# Six UCI datasets: concrete, power, energy, wine, yacht, housing
from src.data.regression import UCI
from src.commons.pyro_training import to_bayesian_model
from src.commons.utils import seed_everything

In [2]:
# Number of input features passed to the network (`in_size`) for each UCI task.
input_sizes = dict(
    concrete=7,
    power=3,
    energy=8,
    wine=10,
    yacht=5,
    housing=12,
)

# Candidate hidden-layer activations; one of them is chosen in the parameter cell.
relu = ReLU()
leaky = LeakyReLU(negative_slope=0.5)

In [3]:
# Reference example: http://pyro.ai/examples/bayesian_regression.html

def training(svi, train_loader, epoch, device):
    """Run one epoch of SVI over ``train_loader`` and return the summed loss.

    Args:
        svi: object exposing ``step(X, y)`` that returns a scalar loss
            (e.g. ``pyro.infer.SVI``).
        train_loader: iterable of ``(X, y)`` batches; each element must support
            ``.to(device)``. It does not need to define ``len()``.
        epoch: index of the current epoch. Unused; kept for call compatibility.
        device: device the batches are moved to before each step.

    Returns:
        Sum of the per-batch losses returned by ``svi.step`` over the epoch.
    """
    loss = 0
    for X, y in train_loader:
        loss += svi.step(X.to(device), y.to(device))
    return loss

def evaluation(predictive, dataloader, metrics, device):
    """Update every metric in ``metrics`` with posterior-predictive samples per batch.

    Non-finite feature values (``inf``, ``-inf`` and ``nan``) are replaced by 0
    before prediction, and a warning is printed for each affected batch. The
    caller's tensors are not modified.

    Args:
        predictive: callable mapping a feature batch to a dict with an ``"obs"``
            entry (e.g. ``pyro.infer.Predictive`` built with ``return_sites=("obs",)``).
        dataloader: iterable of ``(X, y)`` tensor batches.
        metrics: dict of metric objects exposing ``update(predictions, targets)``.
        device: device the inputs and targets are moved to.
    """
    for X, y in dataloader:
        # torch.isfinite works on any device and, unlike a numpy round-trip,
        # needs no host copy. It flags NaN as well as +/-inf.
        finite = torch.isfinite(X)
        if not finite.all():
            print("WARNING: NON-FINITE VALUES IN DATA, REPLACING WITH 0")
            X = torch.where(finite, X, torch.zeros_like(X))

        y = y.to(device)
        # "obs" samples are transposed before scoring; presumably
        # (num_samples, batch) -> (batch, num_samples). TODO confirm against the
        # metric implementations.
        out = predictive(X.to(device))["obs"].T
        for metric in metrics.values():
            metric.update(out, y)


In [4]:
def train_loop(
    model,
    guide,
    train_loader,
    valid_loader,
    svi,
    epochs,
    num_samples,
    metrics,
    device,
):
    """Alternate one SVI training epoch with one validation pass, ``epochs`` times.

    After each epoch a fresh ``Predictive`` is built from ``model`` and ``guide``.
    Every metric is updated on ``valid_loader``, printed, and then reset so the
    next epoch starts from a clean state.

    Returns:
        The same ``(model, guide)`` objects that were passed in.
    """
    for epoch_idx in range(epochs):
        # The epoch loss is computed but currently not reported.
        _epoch_loss = training(svi, train_loader, epoch_idx, device)

        posterior_predictive = Predictive(
            model,
            guide=guide,
            num_samples=num_samples,
            return_sites=("obs",),
        )
        evaluation(posterior_predictive, valid_loader, metrics, device)

        print("\nEpoch:", epoch_idx)
        for name, metric in metrics.items():
            print(f"{name}: {metric.compute().cpu()}")
            metric.reset()

    return model, guide

In [5]:
def bayesian_training(
    seed,
    max_epochs,
    num_samples,
    batch_size,
    task,
    device,
    optimizer,
    criterion,
    sigma_bound,
    param_mean,
    param_std,
    hidden_size,
    activation,
    percentile=50,
    dataloader_args=None,
):
    """Build a Bayesian one-hidden-layer regressor for a UCI ``task`` and train it with SVI.

    Args:
        seed: seed passed to ``seed_everything`` before anything random happens.
        max_epochs: number of training epochs.
        num_samples: posterior-predictive samples drawn per validation pass.
        batch_size: batch size for both the train and validation loaders.
        task: UCI dataset name; must be a key of ``input_sizes``.
        device: device for the model, the batches and the metrics.
        optimizer: Pyro optimizer used by ``SVI``.
        criterion: ELBO loss used by ``SVI`` (e.g. ``Trace_ELBO()``).
        sigma_bound: forwarded to ``to_bayesian_model``; per the config comment it
            bounds the Uniform prior on the output noise std.
        param_mean: mean of the Bayesian weight priors.
        param_std: std of the Bayesian weight priors.
        hidden_size: width of the hidden layer.
        activation: activation module for the hidden layer.
        percentile: interval percentile used by the PCIP and MPIW metrics
            (default 50, the previously hard-coded value).
        dataloader_args: extra kwargs for the UCI dataloaders. Defaults to 4
            persistent workers, with memory pinning enabled only for CUDA devices.

    Returns:
        ``(model, guide)`` as returned by ``train_loop``.
    """
    seed_everything(seed)
    # Start from an empty param store so repeated calls don't reuse old parameters.
    pyro.clear_param_store()

    net = MLERegression(activation=activation, in_size=input_sizes[task], hidden_size=hidden_size, out_size=1)
    # The net is converted by code in src.models.bnn; in particular, the BNNRegression
    # class defines the Bayesian model.
    model = to_bayesian_model(net, param_mean, param_std, device=device, sigma_bound=sigma_bound)
    svi = SVI(model.model, model.guide, optimizer, loss=criterion)

    if dataloader_args is None:
        dataloader_args = {
            "num_workers": 4,
            # Pinned memory only helps host-to-GPU copies; skip it when training on CPU.
            "pin_memory": torch.device(device).type == "cuda",
            "persistent_workers": True,
        }

    uci = UCI(task, train_batch_size=batch_size, test_batch_size=batch_size, train_ratio=0.8, validation_ratio=0.2, dataloader_args=dataloader_args)
    train_loader = uci.train_dataloader()
    valid_loader = uci.validation_dataloader()

    metrics = {
        "RMSE": RootMeanSquaredError(input_type="samples"),
        "PCIP": PCIP(input_type="none", percentile=percentile),
        "MPIW": MPIW(input_type="none", percentile=percentile),
    }
    for metric in metrics.values():
        metric.set_device(device)

    return train_loop(model.model, model.guide, train_loader, valid_loader, svi, max_epochs, num_samples, metrics, device)

In [6]:
# A fresh seed is drawn each run; it is printed so a run can be reproduced exactly.
SEED = int(np.random.randint(10000))
print(f"seed: {SEED}")

param_dict = {
    "seed": SEED,
    "max_epochs": 500,
    "num_samples": 100,
    "batch_size": 132,
    # Exactly one "task" entry may be active: a repeated dict key silently
    # overrides the earlier one.
    # "task": 'housing',
    # "task": 'concrete',
    "task": 'power',
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),

    # "optimizer": pyro.optim.Adam({"lr": 1e-5 ,"betas": [0.9, 0.999], "eps":1e-4}),
    "optimizer": pyro.optim.Adam({"lr": 1e-2}),
    # "optimizer": pyro.optim.SGD({"lr": 1e-8}),
    "criterion": Trace_ELBO(),

    # The model's output is Normal(mean, std), where std has prior Uniform(0, sigma_bound).
    "sigma_bound": 5.0,

    # Mean and std of the Bayesian model's weight priors.
    "param_mean": 0.0,
    "param_std": 1.0,

    "hidden_size": 50,
    "activation": relu,
}

The data is normalized, so the RMSE of a naive model is ~1.

The PCIP metric is reported relative to the nominal coverage of the interval. For example, for an 80% confidence interval, a value of 5 means that 85% of the datapoints fall into the interval, and a value of -3 means that 77% of them do. In other words, the closer PCIP is to 0, the better.

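Here we can see that the model starts with a relatively reasonable PCIP, but as it trains, PCIP rises to 18. That means 98% of the datapoints fall into the 80% confidence interval, even though RMSE goes down (so the model is converging). Note that the run recorded below uses `percentile=50` for PCIP and MPIW, so its values are relative to a 50% interval. A PCIP of about 40 there means roughly 90% of the datapoints fall into that interval, which is the same over-coverage pattern.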

By changing `param_std`, `sigma_bound`, the learning rate, `batch_size`, or the shape of the net, one can affect the course of training. The initial PCIP can be negative or positive, and it can increase or decrease during training. However, I did not manage to find parameters for which it converges to 0.

In [None]:
# Keep the trained model and guide so they can be inspected after training.
bayesian_model, bayesian_guide = bayesian_training(**param_dict)

Train data shape
torch.Size([7654, 3])
torch.Size([7654])
Test data shape
torch.Size([1914, 3])
torch.Size([1914])
NAIVE RMSE: 0.9996809450431293

Epoch: 0
RMSE: 0.4473884701728821
PCIP: 41.895423889160156
MPIW: 1.7906368970870972

Epoch: 1
RMSE: 0.5076503157615662
PCIP: 38.69281005859375
MPIW: 1.585242748260498

Epoch: 2
RMSE: 0.41440704464912415
PCIP: 38.932464599609375
MPIW: 1.500036597251892

Epoch: 3
RMSE: 0.4539216458797455
PCIP: 39.65686798095703
MPIW: 1.5613961219787598

Epoch: 4
RMSE: 0.3456208109855652
PCIP: 40.418296813964844
MPIW: 1.5067243576049805

Epoch: 5
RMSE: 0.3187048137187958
PCIP: 40.34858703613281
MPIW: 1.4362764358520508

Epoch: 6
RMSE: 0.33088061213493347
PCIP: 40.80298614501953
MPIW: 1.4135652780532837

Epoch: 7
RMSE: 0.35690540075302124
PCIP: 40.30229187011719
MPIW: 1.3733047246932983

Epoch: 8
RMSE: 0.4047638475894928
PCIP: 39.970947265625
MPIW: 1.3506795167922974

Epoch: 9
RMSE: 0.3228450119495392
PCIP: 39.8431396484375
MPIW: 1.3177143335342407

Epoch: 10
RM