In [9]:
import math
import copy
import os
import copy
import time
from typing import Callable

import torch
from torch.nn import functional as F
from torch.nn import Module, ModuleList, Linear, Tanh, MSELoss, GaussianNLLLoss
from torch.nn.parameter import Parameter
from torch.optim import SGD, Adam, Optimizer
from torch.utils.data import DataLoader, TensorDataset

# Tensor placement and precision shared by every model / tensor built in this notebook.
dtype = torch.float32
device = 'cpu'

import matplotlib.pyplot as plt
import numpy as np

# Custom libraries
from utils.custom_utils import set_global_random_seed

# Seed every RNG through the project helper so Restart & Run All reproduces the same results.
global_seed = 0
set_global_random_seed(global_seed)

Global seed set to 0


# Loading the perovskite dataset

In [22]:
inputs_dir = 'inputs_perov_data'
run_dir = 'bayesian_nestedae'
nn_save_dir = 'ae1'
nn = 1
mode = 'train'
kfolds = 0
X_variable_name = 'all_props'
y_variable_name = 'bg'

# Stage the raw inputs where preprocess_data.py expects them, run the script, then clean up.
# The script runs in the foreground so the saved datasets exist before torch.load below and
# 'inputs' is not deleted while the script may still be reading it.
os.system('rm -rf inputs')
os.system(f'cp -rf {inputs_dir} inputs')
preprocess_status = os.system(
    f'python3 preprocess_data.py --run_dir {run_dir} --nn_save_dir {nn_save_dir} '
    f'--nn {nn} --mode {mode} --kfolds {kfolds}'
)
os.system('rm -rf inputs')
if preprocess_status != 0:
    raise RuntimeError(f'preprocess_data.py failed with exit status {preprocess_status}')

train_dataset_path = f'../runs/{run_dir}/{nn_save_dir}/datasets/train_dataset.pt'
val_dataset_path = f'../runs/{run_dir}/{nn_save_dir}/datasets/val_dataset.pt'

train_dataset = torch.load(train_dataset_path)
val_dataset = torch.load(val_dataset_path)

print(f'Train Dataset shape {train_dataset.shape}')
print(f'Train Dataset variables {train_dataset.variable_names}')
print(f'Train Dataset variable shapes {train_dataset.variable_shapes}')

# Select features / target by the configured names rather than repeating the literals.
train_X_data = train_dataset.dataset[X_variable_name]
train_y_data = train_dataset.dataset[y_variable_name]
val_X_data = val_dataset.dataset[X_variable_name]
val_y_data = val_dataset.dataset[y_variable_name]

 --> Found Run directory already exists.
 --> nn directory already exists.
Train Dataset shape (444, 16)
Train Dataset variables ['all_props', 'bg']
Train Dataset variable shapes {'all_props': (444, 15), 'bg': (444, 1)}


# Neural Network Class

In [None]:
class BayesLinear(Module):
    """
    Fully-connected layer with a factorised Gaussian distribution over its weights and biases.

    Every weight (and bias) entry has its own mean ``*_mu`` and log standard deviation
    ``*_log_sigma``. With ``sample=True`` the forward pass draws
    ``w = mu + exp(log_sigma) * eps`` with ``eps ~ N(0, 1)`` (reparameterisation trick);
    with ``sample=False`` it uses the means only.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        prior_mu: mean of the Gaussian prior N(prior_mu, prior_sigma**2), stored for a KL term.
        prior_sigma: standard deviation of that prior; must be a positive number.
        bias: whether the layer has a (stochastic) bias.

    Raises:
        TypeError: if ``prior_sigma`` is None.
        ValueError: if ``prior_sigma`` is not positive.
    """
    def __init__(self, in_dim: int, out_dim: int, prior_mu: float = None, prior_sigma: float = None,  bias: bool = True) -> None:
        super(BayesLinear, self).__init__()
        # Fail early with a clear message instead of an opaque math.log error.
        if prior_sigma is None:
            raise TypeError('BayesLinear needs prior_sigma (standard deviation of the weight prior).')
        if prior_sigma <= 0:
            raise ValueError(f'prior_sigma must be positive, got {prior_sigma}.')
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.prior_log_sigma = math.log(prior_sigma)
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.bias = bias
        # Storage is allocated uninitialised here and filled in by reset_parameters().
        self.weight_mu = Parameter(torch.empty(out_dim, in_dim))
        self.weight_log_sigma = Parameter(torch.empty(out_dim, in_dim))
        if self.bias:
            self.bias_mu = Parameter(torch.empty(out_dim))
            self.bias_log_sigma = Parameter(torch.empty(out_dim))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_log_sigma', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Means ~ U(-1/sqrt(in_dim), 1/sqrt(in_dim)); log-sigmas ~ U(-3, -1)."""
        # From : https://github.com/Harry24k/bayesian-neural-network-pytorch
        bound = 1. / math.sqrt(self.weight_mu.size(1))
        with torch.no_grad():
            self.weight_mu.uniform_(-bound, bound)
            self.weight_log_sigma.uniform_(-3, -1)
            if self.bias:
                self.bias_mu.uniform_(-bound, bound)
                self.bias_log_sigma.uniform_(-3, -1)

    def sample_parameter(self, mu, log_sigma):
        """Draw ``mu + exp(log_sigma) * eps`` with ``eps ~ N(0, 1)``, differentiable in mu and log_sigma."""
        eps = torch.randn_like(log_sigma)
        return mu + torch.exp(log_sigma)*eps

    def forward(self, input: torch.Tensor, sample: bool = True) -> torch.Tensor:
        """Apply the layer with sampled parameters (``sample=True``) or with the posterior means."""
        if sample:
            weight = self.sample_parameter(self.weight_mu, self.weight_log_sigma)
            bias = self.sample_parameter(self.bias_mu, self.bias_log_sigma) if self.bias else None
        else:
            weight = self.weight_mu
            bias = self.bias_mu if self.bias else None
        return F.linear(input, weight, bias)

    def sample_layer(self, num_samples):
        """
        Draw ``num_samples`` independent parameter sets and return their entries as flat Python lists.

        Returns ``(weight_samples, bias_samples)`` when the layer has a bias, else ``weight_samples``.
        """
        weight_samples = []
        for _ in range(num_samples):
            weight_sample = self.sample_parameter(self.weight_mu, self.weight_log_sigma)
            weight_samples += weight_sample.view(-1).cpu().data.numpy().tolist()
        if not self.bias:
            return weight_samples
        bias_samples = []
        for _ in range(num_samples):
            bias_sample = self.sample_parameter(self.bias_mu, self.bias_log_sigma)
            bias_samples += bias_sample.view(-1).cpu().data.numpy().tolist()
        return weight_samples, bias_samples
        
class AppendLogVar(Module):
    """
    Appends a learnable log-variance to the network output.

    A single ``(1, 1)`` parameter ``log_var``, initialised to ``log(noise)``, is broadcast to the
    shape of the input and concatenated along dim 1. An ``(N, 1)`` prediction therefore becomes
    ``(N, 2)`` with columns ``[mean, log_var]``.
    """
    def __init__(self, noise=1e-3, dtype=torch.float32, device='cpu', *args, **kwargs):
        super().__init__(*args, **kwargs)
        initial_value = torch.empty((1, 1), dtype=dtype, device=device)
        self.log_var = torch.nn.Parameter(initial_value)
        torch.nn.init.constant_(self.log_var, val=math.log(noise))

    def forward(self, x):
        # Same log variance for every row (and every output column of x).
        log_var_column = torch.ones_like(x) * self.log_var
        return torch.cat([x, log_var_column], dim=1)

class BNN(Module):
    """
    Feed-forward network used for the encoder, predictor and decoder.

    Layout: ``[Linear -> hidden activation]`` for each entry of ``hidden_layers``, then an output
    Linear, the optional ``out_activation`` and, for ``model_type == 2``, an ``AppendLogVar`` layer.

    ``model_type`` selects the linear layers:
      1 -> ``BayesLinear`` with prior N(0, 0.1**2),
      2 -> ``torch.nn.Linear`` plus a learned log-variance output column,
      3 -> ``torch.nn.Linear`` only.

    Args:
        in_dim, out_dim: input / output feature sizes.
        hidden_layers: sequence of hidden widths, or None / empty for a single output layer.
        hidden_activations: one module per hidden layer, a single module shared by all hidden
            layers, or None for no hidden activation.
        out_activation: module applied after the output layer, or None.
        bias: whether the linear layers have a bias.
        model_type: 1, 2 or 3 (see above).
        noise_var: initial noise variance of the AppendLogVar layer (model_type 2 only).
        dtype, device: used for the ``torch.nn.Linear`` / ``AppendLogVar`` layers.

    Raises:
        ValueError: for an unknown ``model_type`` or when the number of hidden activations
            does not match the number of hidden layers.
    """
    def __init__(self, in_dim, out_dim,
                 hidden_layers, hidden_activations,
                 out_activation,
                 bias=True,
                 model_type=1,
                 noise_var=0.1,
                 dtype=torch.float32,
                 device='cpu'):
        super(BNN, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.hidden_layers = hidden_layers
        self.hidden_activations = hidden_activations
        self.out_activation = out_activation
        self.layer_stack = torch.nn.Sequential()
        self.model_type = model_type
        self.dtype = dtype
        self.device = device
        self.noise_var = noise_var
        self.bias = bias
        if self.model_type not in (1, 2, 3):
            raise ValueError(f'Provided model_type {model_type}. model_type can only be 1, 2 or 3.')

        hidden_dims = list(hidden_layers) if hidden_layers else []
        # Accept one shared activation module (or None) as well as one module per hidden layer.
        if hidden_activations is None or isinstance(hidden_activations, Module):
            activations = [hidden_activations]*len(hidden_dims)
        else:
            activations = list(hidden_activations)
            if len(activations) != len(hidden_dims):
                raise ValueError(f'Got {len(activations)} hidden activations for {len(hidden_dims)} hidden layers.')

        # Chain widths in -> h1 -> ... -> hk so each layer's fan-in is the previous layer's fan-out.
        dims = [self.in_dim] + hidden_dims
        for fan_in, fan_out, activation in zip(dims[:-1], dims[1:], activations):
            self.layer_stack.append(self._make_linear(fan_in, fan_out))
            if activation is not None:
                self.layer_stack.append(activation)

        self.layer_stack.append(self._make_linear(dims[-1], self.out_dim))
        if self.out_activation is not None:
            self.layer_stack.append(self.out_activation)
        if self.model_type == 2:
            self.layer_stack.append(AppendLogVar(noise=self.noise_var, dtype=self.dtype, device=self.device))

    def _make_linear(self, fan_in, fan_out):
        """Build one linear layer of the kind selected by ``model_type``."""
        if self.model_type == 1:
            return BayesLinear(fan_in, fan_out, 0, 0.1, bias=self.bias)
        return torch.nn.Linear(fan_in, fan_out, bias=self.bias, dtype=self.dtype, device=self.device)

    def forward(self, X):
        return self.layer_stack(X)
    
def kl_loss_fn(model, reduction='mean'):
    """
    KL divergence between the BayesLinear posteriors in ``model.layer_stack`` and their priors.

    For every ``BayesLinear`` layer the weight (and, if present, bias) posterior
    N(mu, exp(log_sigma)**2) is compared to the layer's prior N(prior_mu, prior_sigma**2).

    Args:
        model: object with an iterable ``layer_stack``; non-BayesLinear layers are ignored.
        reduction: 'sum' for the total KL, 'mean' for the total divided by the number of
            stochastic parameters (0 when there are none).

    Raises:
        ValueError: for any other ``reduction``.
    """
    if reduction not in ('mean', 'sum'):
        raise ValueError(reduction + " is not valid")

    def kl_div(mu1, log_sigma1, mu2, log_sigma2):
        """
            KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) summed over all elements;
            mu2 / log_sigma2 are scalars (the prior).
        """
        prior_var = math.exp(log_sigma2)**2
        return ((log_sigma2 - log_sigma1).sum()
                + 0.5*(torch.exp(log_sigma1)**2).sum()/prior_var
                + 0.5*((mu1 - mu2)**2).sum()/prior_var
                - 0.5*log_sigma1.numel())

    kl_sum = torch.tensor(0.0, dtype=torch.float32)
    n = 0
    for layer in model.layer_stack:
        if isinstance(layer, BayesLinear):
            kl_sum = kl_sum + kl_div(layer.weight_mu, layer.weight_log_sigma, layer.prior_mu, layer.prior_log_sigma)
            n += layer.weight_mu.numel()
            if layer.bias:
                # The bias term must use the bias posterior, not the weight posterior.
                kl_sum = kl_sum + kl_div(layer.bias_mu, layer.bias_log_sigma, layer.prior_mu, layer.prior_log_sigma)
                n += layer.bias_mu.numel()

    if reduction == 'sum':
        return kl_sum
    # Avoid 0/0 = nan for models without BayesLinear layers.
    return kl_sum/n if n > 0 else kl_sum
        
def train_loop(dataloader, encoder, decoder, predictor, loss_fn_list, loss_wts, loss_names, optimizer, scheduler=None, print_every=None):
    """
    Run one training epoch of the encoder -> (predictor, decoder) stack.

    For every batch ``(X, y)`` the latent ``l = encoder(X)`` is fed to the predictor (``y_pred``)
    and to the decoder (``X_pred``). Each loss in ``loss_fn_list`` is evaluated on the predictor
    according to its name in ``loss_names``:

      * ``'kl'``            -> ``loss_fn(predictor, reduction='mean')``
      * ``'neg_log_joint'`` -> ``loss_fn(y_pred, y, predictor.parameters())``, terms summed
      * anything else       -> ``loss_fn(y_pred, y)``

    The losses are weighted by ``loss_wts``, summed and backpropagated. ``optimizer`` (and
    ``scheduler``, if given) is stepped once per batch.

    Args:
        print_every: print the batch losses every ``print_every`` batches; None disables printing.

    Returns:
        ``(loss_per_epoch_list, [gnll], [neg_log_params_prior])``. The first list holds each loss
        averaged over batches. The last two hold the values from the last ``'neg_log_joint'``
        evaluation (``nan`` if none happened).
    """
    encoder.train()
    predictor.train()
    decoder.train()
    dataset_size = len(dataloader.dataset)
    batch_size = dataloader.batch_size
    # len(dataloader) also counts a final partial batch, unlike dataset_size/batch_size.
    num_batches = max(len(dataloader), 1)
    loss_per_batch_list = [0.0]*len(loss_fn_list)
    last_gnll = float('nan')
    last_neg_log_params_prior = float('nan')

    for batch, (X, y) in enumerate(dataloader):
        optimizer.zero_grad()

        l = encoder(X)
        y_pred = predictor(l)
        # NOTE(review): X_pred is not used by any loss yet — TODO add a reconstruction term.
        X_pred = decoder(l)

        loss_list = []
        for i, (loss_name, loss_fn) in enumerate(zip(loss_names, loss_fn_list)):
            if loss_name == 'kl':
                loss = loss_fn(predictor, reduction='mean')
            elif loss_name == 'neg_log_joint':
                if predictor.model_type == 2:
                    gnll, neg_log_var_prior, neg_log_params_prior = loss_fn(y_pred, y, predictor.parameters())
                    loss = gnll + neg_log_var_prior + neg_log_params_prior
                else:
                    gnll, neg_log_params_prior = loss_fn(y_pred, y, predictor.parameters())
                    loss = gnll + neg_log_params_prior
                last_gnll = gnll.item()
                last_neg_log_params_prior = neg_log_params_prior.item()
            else:
                loss = loss_fn(y_pred, y)
            loss_list.append(loss)
            loss_per_batch_list[i] += loss.item()

        total_loss = torch.tensor(0.0, dtype=torch.float32)
        for loss_wt, loss in zip(loss_wts, loss_list):
            total_loss = total_loss + loss_wt*loss

        total_loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        if print_every is not None and batch % print_every == 0:
            current = batch*batch_size + len(X)
            for loss_name, loss in zip(loss_names, loss_list):
                print(f' Train {loss_name} @ {current}/{dataset_size} is {loss.item()}. ')
            print(f' Train Total loss @ {current}/{dataset_size} is {total_loss.item()}. ')

    loss_per_epoch_list = [loss_sum/num_batches for loss_sum in loss_per_batch_list]
    return loss_per_epoch_list, [last_gnll], [last_neg_log_params_prior]

def test_loop(dataloader, model, loss_fn_list, loss_wts, loss_names):
    # Set model into evaluation mode
    model.eval()
    dataset_size = len(dataloader.dataset)
    batch_size = dataloader.batch_size
    num_batches = dataset_size/batch_size
    num_losses = len(loss_fn_list)
    loss_per_batch_list = [0.0]*num_losses
    loss_per_epoch_list = [0.0]*num_losses

    with torch.no_grad():
        for data in dataloader:
            X, y = data

            pred = model(X)

            loss_list = [torch.tensor(0.0, dtype=torch.float32)]*num_losses
            total_loss = torch.tensor(0.0, dtype=torch.float32)

            for i, (loss_name, loss_fn) in enumerate(zip(loss_names, loss_fn_list)):
                if loss_name == 'kl':
                    loss_list[i] = loss_fn(model, reduction='mean')
                elif loss_name == 'neg_log_joint':
                    if model.model_type == 2:
                        gnll, neg_log_var_prior, neg_log_params_prior = loss_fn(pred, y, model.parameters())
                        loss_list[i] = gnll + neg_log_var_prior + neg_log_params_prior
                    else:
                        gnll, neg_log_params_prior = loss_fn(pred, y, model.parameters())
                        loss_list[i] = gnll + neg_log_params_prior
                else:
                    loss_list[i] = loss_fn(pred, y)
                loss_per_batch_list[i] += loss_list[i].item()
        
            for loss_wt, loss in zip(loss_wts, loss_list):
                total_loss += loss_wt*loss
    
    for i, loss_per_batch in enumerate(loss_per_batch_list):
        loss_per_epoch_list[i] = loss_per_batch/num_batches

    return loss_per_epoch_list, [gnll.item()], [neg_log_params_prior.item()]

if __name__ == "__main__":
    # Smoke test: build a BNN with the default model_type (1, BayesLinear layers) and run one forward pass.
    print('Testing Functions ... ')
    bnn = BNN(in_dim=10, out_dim=10, hidden_layers=None, hidden_activations=None, out_activation=None)
    train_random_input_tensor = torch.randn((10, 10))

    print('Output from NN :')
    print(bnn(train_random_input_tensor))
    print('\n')

# Using Bayesian Neural Networks - Hamiltonian Monte Carlo
* Code mainly adapted from: https://github.com/automl/pybnn/tree/master
* Original paper: http://arxiv.org/abs/1502.05700


In [16]:
from typing import Iterable

class NegativeLogJoint(torch.nn.Module):
    """
    Terms of the negative log joint  -log p(D, w) = -log p(D | w) - log p(w)  of a BNN regressor.

    The likelihood is Gaussian, p(y | X, w) = N(f(X; w), sigma^2):
      * model_type 2: ``pred`` has two columns ``[mean, log_var]``, so the noise variance is learned;
      * model_type 3: ``pred`` is the mean only and the variance is fixed to ``noise_var``.
    The weight prior is Gaussian (L2, variance ``params_var``) or Laplace (L1, scale ``params_var``).

    Args:
        model_type: 2 or 3 (see above).
        params_var: variance of the Gaussian weight prior / scale of the Laplace prior.
        noise_var: fixed likelihood variance, used for model_type 3.
        reduction: 'mean' or 'sum' over the likelihood terms; anything else returns them unreduced.
        dtype, device: used for tensors built internally for the priors.

    Raises:
        ValueError: if ``model_type`` is not 2 or 3.
    """
    reduction: str

    def __init__(self, model_type, params_var=10.0, noise_var=0.1, reduction: str = 'mean', dtype=torch.float32, device='cpu') -> None:
        super(NegativeLogJoint, self).__init__()
        self.model_type = model_type
        self.params_var = params_var
        self.noise_var = noise_var
        self.reduction = reduction
        self.dtype = dtype
        self.device = device
        if self.model_type not in (2, 3):
            raise ValueError(f'Provided model_type is {self.model_type}. Should be 2 or 3.')

    def gaussian_negative_log_likelihood(self,
                                         pred: torch.Tensor,
                                         target: torch.Tensor) -> torch.Tensor:
        """
            Gaussian negative log likelihood of ``target`` under N(pred_mean, pred_var),
            reduced according to ``self.reduction``.
        """
        if self.model_type == 2:
            # Noise log variance is learned and comes out as the second prediction column.
            pred_mean = pred[:, 0].view_as(target)
            pred_log_var = pred[:, 1].view_as(target)
            nll = 0.5*pred_log_var + 0.5*math.log(2*math.pi) + ((target - pred_mean)**2)/(2*(torch.exp(pred_log_var) + 1e-16))
        else:
            # Fixed noise variance.
            pred_mean = pred.view_as(target)
            pred_log_var = torch.log(torch.tensor(self.noise_var))
            nll = 0.5*pred_log_var + 0.5*torch.log(torch.tensor(2*math.pi)) + ((target - pred_mean)**2)/(2*(torch.exp(pred_log_var) + 1e-16))

        if self.reduction == 'mean':
            return torch.mean(nll)
        elif self.reduction == 'sum':
            return torch.sum(nll)
        else:
            return nll

    def negative_log_variance_prior_fn(self,
                                       log_variance: torch.Tensor,
                                       mean: float = 1e-6,
                                       variance: float = 0.01) -> torch.Tensor:
        """
            Negative log prior of the learned noise log variance.

            The prior is Gaussian on the log variance, log_var ~ N(log(mean), variance). This is a
            log-normal prior on the noise variance with median ``mean``. The 2*pi constant is
            dropped. Per-row values are summed over output columns and averaged over rows, so a log
            variance shared by every row is counted once rather than once per data point.
        """
        mean = torch.tensor(mean, dtype=self.dtype, device=self.device)
        variance = torch.tensor(variance, dtype=self.dtype, device=self.device)
        log_prior = -0.5*torch.log(variance) - ((log_variance - torch.log(mean))**2)/(2*variance)
        if log_prior.dim() < 2:
            log_prior = log_prior.reshape(-1, 1)
        return -log_prior.sum(dim=1).mean()

    def negative_log_params_prior_fn(self, params: Iterable[torch.Tensor], fn='gaussian') -> torch.Tensor:
        """
            Negative log prior of the weights (normalising constants dropped).
            'gaussian': 0.5/params_var * sum(w^2)  (L2);  'laplace': 1/params_var * sum(|w|)  (L1).
        """
        log_params_prior = torch.tensor(0.0, dtype=self.dtype, device=self.device)
        if fn == 'gaussian':
            for param in params:
                log_params_prior += -0.5*(1/self.params_var)*torch.sum(param**2)
        elif fn == 'laplace':
            for param in params:
                log_params_prior += (-1/self.params_var)*torch.sum(torch.abs(param))
        else:
            raise ValueError(f'Provided fn {fn} is not valid. Should be gaussian or laplace.')
        return -log_params_prior

    def forward(self, pred, target, params):
        """
            Returns ``(gnll, neg_log_var_prior, neg_log_params_prior)`` for model_type 2 and
            ``(gnll, neg_log_params_prior)`` for model_type 3.
        """
        gnll = self.gaussian_negative_log_likelihood(pred, target)
        neg_log_params_prior = self.negative_log_params_prior_fn(params)
        if self.model_type == 2:
            neg_log_var_prior = self.negative_log_variance_prior_fn(pred[:, 1].view_as(target))
            return gnll, neg_log_var_prior, neg_log_params_prior
        return gnll, neg_log_params_prior


def get_params(model:torch.nn.Module) -> torch.tensor:
    """Concatenate every parameter of ``model``, flattened, into one 1-D tensor (``parameters()`` order)."""
    return torch.cat([param.flatten() for param in model.parameters()])

def get_params_grads(model:torch.nn.Module) -> torch.tensor:
    """
    Concatenate the gradients of every parameter of ``model`` into one 1-D tensor.

    The layout matches ``get_params``. A parameter without a gradient (``grad is None``, e.g.
    unused in the last backward pass or right after ``zero_grad``) contributes zeros instead of
    raising.
    """
    params_grads = []
    for param in model.parameters():
        grad = param.grad if param.grad is not None else torch.zeros_like(param)
        params_grads.append(grad.flatten())
    return torch.concatenate(params_grads)

def set_params(model:torch.nn.Module, params) -> None:
    """
    Write a flat parameter vector (layout of ``get_params``) back into ``model.parameters()``.

    Values are copied, so the model does not share storage with ``params``. In-place edits of
    ``params`` afterwards (e.g. ``q.add_`` during HMC, or a stored sample) do not leak into the
    model, and vice versa.

    Raises:
        ValueError: if ``params`` does not hold exactly one value per model parameter entry.
    """
    flat = params.detach().reshape(-1)
    expected = sum(param.numel() for param in model.parameters())
    if flat.numel() != expected:
        raise ValueError(f'Expected {expected} parameter values, got {flat.numel()}.')
    pointer = 0
    # In-place copy into leaf parameters is only allowed with autograd disabled.
    with torch.no_grad():
        for param in model.parameters():
            num_param = param.numel()
            param.copy_(flat[pointer:pointer + num_param].view_as(param))
            pointer += num_param

def calc_params_grads(model, train_X, train_y, negative_log_prob_fn, step_num, print_every=100) -> None:
    """
    Forward + backward pass of the negative log probability; gradients land in ``param.grad``.

    ``negative_log_prob_fn(pred, target, params)`` must return 3 terms when its ``model_type`` is 2
    and 2 terms otherwise; their sum is backpropagated. Gradients accumulate, so callers are
    expected to zero them between calls.

    Args:
        step_num: step index, used only for logging.
        print_every: log the loss terms when ``step_num % print_every == 0``; None disables logging.
    """
    pred = model(train_X)
    terms = negative_log_prob_fn(pred, train_y, model.parameters())
    if negative_log_prob_fn.model_type == 2:
        gnll, neg_log_var_prior, neg_log_params_prior = terms
        negative_log_prob = gnll + neg_log_var_prior + neg_log_params_prior
        named_terms = [('gnll', gnll), ('neg_log_var_prior', neg_log_var_prior), ('neg_log_params_prior', neg_log_params_prior)]
    else:
        gnll, neg_log_params_prior = terms
        negative_log_prob = gnll + neg_log_params_prior
        named_terms = [('gnll', gnll), ('neg_log_params_prior', neg_log_params_prior)]

    if print_every is not None and step_num % print_every == 0:
        summary = ', '.join(f'{name} is {round(value.item(), 4)}' for name, value in named_terms)
        print(f'Step {step_num} : {summary}, total loss is {round(negative_log_prob.item(), 4)}.')

    negative_log_prob.backward()

# Ref : https://arxiv.org/pdf/1206.1901 (MCMC Using Hamiltonian Dynamics)
def run_hmc(model, train_X, train_y, potential_energy_fn, epsilon, L):
    """
    One HMC transition with leapfrog integration on the flattened parameters of ``model``.

    Args:
        model: BNN whose parameters are the position q. On return the model holds the accepted
            proposal, or its starting parameters if the proposal was rejected.
        train_X, train_y: data used to evaluate the potential energy.
        potential_energy_fn: callable ``(pred, target, params) -> tuple of terms`` whose sum is
            the potential energy U(q) (negative log of the target density). It must expose
            ``model_type`` (2 -> 3 terms, otherwise 2 terms).
        epsilon: leapfrog step size.
        L: number of leapfrog steps.

    Returns:
        ``(q, 'accepted', accep_prob)`` or ``(None, 'rejected', accep_prob)``.
    """
    def _potential_energy(params):
        # U(q): sum of the returned terms; no graph is needed for the energy values.
        with torch.no_grad():
            terms = potential_energy_fn(model(train_X), train_y, params)
        total = terms[0]
        for term in terms[1:]:
            total = total + term
        return total

    # Position as a plain 1-D tensor (no autograd history across leapfrog steps).
    q = get_params(model).detach().clone()
    # Independent copy: q is updated in place, and a rejected proposal must restore this point.
    current_q = q.clone()

    # Momentum ~ N(0, I) (identity mass matrix).
    p = torch.randn_like(q)

    current_potential_energy = _potential_energy(current_q)
    current_kinetic_energy = 0.5*torch.sum(p**2)

    # Half step for momentum: p = p - epsilon/2 * grad_U(q). Clear stale grads first.
    model.zero_grad()
    calc_params_grads(model, train_X, train_y, potential_energy_fn, 0)
    params_grads = get_params_grads(model)
    p.add_(params_grads, alpha=-epsilon/2)
    model.zero_grad()

    for i in range(1, L+1):
        # Full step for position.
        q.add_(p, alpha=epsilon)
        set_params(model, q)
        calc_params_grads(model, train_X, train_y, potential_energy_fn, i)
        params_grads = get_params_grads(model)
        model.zero_grad()
        # Full step for momentum except at the end of the trajectory.
        if i != L:
            p.add_(params_grads, alpha=-epsilon)

    # Final half step for momentum, then negate to make the proposal symmetric.
    p.add_(params_grads, alpha=-epsilon/2)
    p = -p

    proposed_potential_energy = _potential_energy(model.parameters())
    proposed_kinetic_energy = 0.5*torch.sum(p**2)

    # Metropolis-Hastings acceptance.
    runif = torch.rand(1)
    log_accept_ratio = current_potential_energy - proposed_potential_energy + current_kinetic_energy - proposed_kinetic_energy
    if torch.isnan(log_accept_ratio):
        # A diverged trajectory must be rejected; min(1, nan) would otherwise evaluate to 1.
        accep_prob = 0.0
    else:
        accep_prob = min(1, torch.exp(log_accept_ratio))
    print(f'runif {runif}, accep prob {accep_prob}')
    if runif < accep_prob:
        return q, 'accepted', accep_prob
    set_params(model, current_q)
    return None, 'rejected', accep_prob

# Ref : https://dl.acm.org/doi/10.5555/3157382.3157560 (Bayesian Optimization with Robust Bayesian Neural Networks.)
class SGHMC(Optimizer):
    """
    Stochastic Gradient Hamiltonian Monte Carlo sampler with a torch optimizer interface.

    Each ``step()`` updates every parameter that has a gradient (``g = grad * scale_grad``):
        theta    <- theta + lr * mdecay * momentum
        momentum <- momentum - lr * g - lr * mdecay * momentum + N(0, 2 * lr * mdecay)
    The momentum is initialised from N(0, 1) on the first step.

    NOTE(review): ``num_burn_in_steps``, ``wd`` and ``epsilon`` are stored in the param-group
    defaults but ``step()`` never reads them. The update may also differ from the pybnn
    reference cited in ``step``. TODO verify the intended dynamics before relying on the samples.
    """
    def __init__(self, params, lr: float,
                 num_burn_in_steps: int,
                 scale_grad : float = 1.0,
                 mdecay: float = 0.05,
                 wd: float = 0.00002,
                 epsilon: float = 1e-16,
                 ) -> None:
        defaults = {'lr':lr,
                    'scale_grad':scale_grad,
                    'num_burn_in_steps':num_burn_in_steps,
                    'mdecay':mdecay,
                    'wd':wd,
                    'epsilon':epsilon}
        super(SGHMC, self).__init__(params, defaults)

    def step(self, closure=None):
        """Perform one SGHMC update. Returns ``closure()`` when a closure is given, else None."""
        loss = None
        if closure is not None:
            loss = closure()

        # A group of tensors can be optimized separately
        for group in self.param_groups:
            mdecay, lr = group['mdecay'], group['lr']
            scale_grad = group['scale_grad']
            # Python float arithmetic, so an integer lr does not truncate the noise scale to 0.
            sigma = math.sqrt(2*lr*mdecay)

            for parameter in group['params']:
                if parameter.grad is None:
                    continue

                state = self.state[parameter]
                if len(state) == 0:
                    state['iteration'] = 0
                    state['momentum'] = torch.randn(parameter.size(), dtype=parameter.dtype, device=parameter.device)
                state['iteration'] += 1

                momentum = state['momentum']
                grad = parameter.grad.data*scale_grad
                sample_t = torch.normal(mean=torch.zeros_like(grad), std=torch.ones_like(grad)*sigma)

                # Ref : https://github.com/automl/pybnn/blob/master/pybnn/sampler/sghmc.py
                parameter.data.add_(lr*mdecay*momentum)
                momentum.add_(-lr*grad - lr*mdecay*momentum + sample_t)

        return loss
    
class SGLD(Optimizer):
    """
    Stochastic Gradient Langevin Dynamics sampler with a torch optimizer interface.

    The backpropagated loss is treated as the potential U(theta) (negative log posterior), and
    every ``step()`` applies
        theta <- theta - (lr / 2) * scale_grad * dU/dtheta + sqrt(lr) * N(0, 1).
    Parameters without a gradient are skipped.
    """
    def __init__(self, params,
                 lr: float=1e-2,
                 scale_grad : float = 1.0) -> None:

        defaults = dict(
            lr=lr,
            scale_grad=scale_grad
        )
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform one SGLD update. Returns ``closure()`` when a closure is given, else None."""
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr, scale_grad = group['lr'], group['scale_grad']
            # Injected noise has variance lr.
            sigma = math.sqrt(lr)

            for parameter in group['params']:
                if parameter.grad is None:
                    continue

                state = self.state[parameter]
                if len(state) == 0:
                    state['iteration'] = 0

                grad = parameter.grad.data*scale_grad
                noise = torch.normal(mean=torch.zeros_like(grad), std=torch.ones_like(grad))

                # Ref : https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=56f89ce43d7e386bface3cba63e674fe748703fc
                # Descend the potential (minus sign) and add sqrt(lr)-scaled Gaussian noise.
                parameter.data.add_(-0.5*lr*grad + sigma*noise)

                state['iteration'] += 1
                state['sigma'] = sigma

        return loss
        
# This cell only defines the negative log joint, the HMC helpers and the samplers;
# nothing is executed here.


## Training the BNN by HMC

In [None]:
###### USER INPUT ######
# -----------------------------------------
# BNN parameters
# -----------------------------------------
# model_type passed to BNN: 1 = BayesLinear layers, 2 = nn.Linear + learned log-variance column,
# 3 = nn.Linear with the fixed likelihood variance `noise_var`.
bnn_type = 3
perform_hyperparam_search = False
# -----------------------------------------
# Encoder parameters
# -----------------------------------------
enc_hidden_layers = None
enc_hidden_activations = None
enc_out_dim = 2
enc_out_activation = torch.nn.Tanh()
# -----------------------------------------
# Decoder parameters
# -----------------------------------------
dec_hidden_layers = None
dec_hidden_activations = None
dec_out_dim = train_X_data.shape[1]
dec_out_activation = None
# -----------------------------------------
# Predictor parameters
# -----------------------------------------
pred_hidden_layers = None
pred_hidden_activations = None
pred_out_dim = train_y_data.shape[1]
pred_out_activation = torch.nn.Sigmoid()
# -----------------------------------------
# Pretraining BNN Parameters
# -----------------------------------------
selected_init_weights = 'normal'
# Default distribution parameters used to initialize the weights
init_wts_dist_params = {
    'uniform': {'a': -1, 'b': 1},
    'normal': {'mean': 0.0, 'std': 10.0},
    'xavier_normal': {'gain': 1.0},
}
perform_pretrain = True
pretrain_lr = 0.1
use_lr_scheduler = False
pretrain_num_epochs = 500
# Full batch gradient descent
pretrain_batch_size = train_dataset.shape[0]
# -----------------------------------------
# -log(p(w,D)) = -log(p(D|w)) - log(p(w))
# -----------------------------------------
# p(w) = N(0, params_var)
# Larger values indicate less certainty in the parameters.
# Smaller values indicate more certainty in the parameters.
params_var = init_wts_dist_params['normal']['std']**2
# p(D|w) = f(X; w) + N(0, noise_var) = N(f(X; w), noise_var)
# homoscedastic model (same variance for all data points)
# noise_var = 1.0 => -log(p(D|w)) = 0.5*sum((y - f(X; w))**2) + const
noise_var = 0.001
# -----------------------------------------
# HMC Parameters
# -----------------------------------------
perform_hmc = False
num_samples = 500
# Number of samples to burn before storing
num_samples_to_burn = 500
path_length = 0.2
# Step Size
epsilon = 0.01
# number_of_leapfrog_steps = path_length/epsilon
adapt_epsilon = False
# -----------------------------------------

# TODO : Run hyperparameter search on model
if perform_hyperparam_search:
    pass

# Initialize the model architecture:
#   encoder   : X      -> latent (enc_out_dim)
#   predictor : latent -> y
#   decoder   : latent -> X
bnn_common_kwargs = dict(bias=True, model_type=bnn_type, noise_var=noise_var, device=device, dtype=dtype)

encoder = BNN(in_dim=train_X_data.shape[1], out_dim=enc_out_dim,
              hidden_layers=enc_hidden_layers, hidden_activations=enc_hidden_activations,
              out_activation=enc_out_activation,
              **bnn_common_kwargs)

predictor = BNN(in_dim=enc_out_dim, out_dim=pred_out_dim,
                hidden_layers=pred_hidden_layers, hidden_activations=pred_hidden_activations,
                out_activation=pred_out_activation,
                **bnn_common_kwargs)

decoder = BNN(in_dim=enc_out_dim, out_dim=dec_out_dim,
              hidden_layers=dec_hidden_layers, hidden_activations=dec_hidden_activations,
              out_activation=dec_out_activation,
              **bnn_common_kwargs)

def init_weights_scheme(scheme:str, xavier_uniform_gain:float=10.0, constant_val:float=10.0) -> 'Callable[[torch.nn.Module], None]':
    """Build a weight-initialisation function suitable for ``Module.apply``.

    Parameters
    ----------
    scheme : str
        One of 'uniform', 'xavier_uniform', 'xavier_normal', 'normal', 'constant'.
        'uniform', 'normal' and 'xavier_normal' read their parameters from the
        global ``init_wts_dist_params`` dict at the time the returned function
        is called.
    xavier_uniform_gain : float, optional
        Gain used by 'xavier_uniform' (default 10.0, the previously hard-coded value).
    constant_val : float, optional
        Fill value used by 'constant' (default 10.0, the previously hard-coded value).

    Returns
    -------
    Callable[[torch.nn.Module], None]
        ``init_weights(m)``, which re-initialises ``m.weight`` in place when
        ``type(m)`` is exactly ``torch.nn.Linear``. Every other module type,
        including custom layers such as ``BayesLinear``, is left untouched.

    Raises
    ------
    TypeError
        If ``scheme`` is not one of the supported names.
    """
    # Parameters are looked up lazily inside each initialiser, so
    # init_wts_dist_params only needs to exist when the returned function runs.
    initialisers = {
        'uniform': lambda weight: torch.nn.init.uniform_(
            weight, a=init_wts_dist_params['uniform']['a'], b=init_wts_dist_params['uniform']['b']),
        'xavier_uniform': lambda weight: torch.nn.init.xavier_uniform_(weight, gain=xavier_uniform_gain),
        # Gain is an optional scaling factor
        'xavier_normal': lambda weight: torch.nn.init.xavier_normal_(
            weight, gain=init_wts_dist_params['xavier_normal']['gain']),
        'normal': lambda weight: torch.nn.init.normal_(
            weight, mean=init_wts_dist_params['normal']['mean'], std=init_wts_dist_params['normal']['std']),
        'constant': lambda weight: torch.nn.init.constant_(weight, val=constant_val),
    }
    if scheme not in initialisers:
        raise TypeError(f'Provided scheme {scheme} is not valid. Use one of {sorted(initialisers)}.')
    fill = initialisers[scheme]

    def init_weights(m:torch.nn.Module) -> None:
        # Exact type match on purpose: subclasses of Linear are not re-initialised.
        if type(m) is torch.nn.Linear:
            fill(m.weight)
    return init_weights
    
def init_bias_scheme(scheme:str, constant_val:float=10.0) -> 'Callable[[torch.nn.Module], None]':
    """Build a bias-initialisation function suitable for ``Module.apply``.

    Accepts the same scheme names as ``init_weights_scheme`` so both can be
    applied with a single ``selected_init_weights`` value.

    Parameters
    ----------
    scheme : str
        'uniform'        -> U(a, b), with a/b from ``init_wts_dist_params['uniform']``.
        'normal'         -> N(mean, std), from ``init_wts_dist_params['normal']``.
        'xavier_uniform' / 'xavier_normal' -> zeros. Xavier init needs
                            fan-in/fan-out, which PyTorch cannot compute for a
                            1-D bias tensor, so these schemes zero the bias.
        'constant'       -> filled with ``constant_val``.
    constant_val : float, optional
        Fill value for 'constant' (default 10.0, matching ``init_weights_scheme``).

    Returns
    -------
    Callable[[torch.nn.Module], None]
        ``init_bias(m)``, which re-initialises ``m.bias`` in place when
        ``type(m)`` is exactly ``torch.nn.Linear`` and the layer has a bias.
        Only the bias is touched; the weights are never modified.

    Raises
    ------
    TypeError
        If ``scheme`` is not one of the supported names.
    """
    # Distribution parameters are read lazily, when the returned function runs.
    initialisers = {
        'uniform': lambda bias: torch.nn.init.uniform_(
            bias, a=init_wts_dist_params['uniform']['a'], b=init_wts_dist_params['uniform']['b']),
        'normal': lambda bias: torch.nn.init.normal_(
            bias, mean=init_wts_dist_params['normal']['mean'], std=init_wts_dist_params['normal']['std']),
        'xavier_uniform': torch.nn.init.zeros_,
        'xavier_normal': torch.nn.init.zeros_,
        'constant': lambda bias: torch.nn.init.constant_(bias, val=constant_val),
    }
    if scheme not in initialisers:
        raise TypeError(f'Provided scheme {scheme} is not valid. Use one of {sorted(initialisers)}.')
    fill = initialisers[scheme]

    def init_bias(m:torch.nn.Module) -> None:
        # Layers built with bias=False have m.bias set to None; skip them.
        if type(m) is torch.nn.Linear and m.bias is not None:
            fill(m.bias)
    return init_bias

# Apply the selected initialisation to every sub-network of the nested model.
# 'default' keeps PyTorch's own torch.nn.Linear init: U(-k**0.5, k**0.5), k = 1/in_feats.
# NOTE(review): the init functions only touch modules whose type is exactly
# torch.nn.Linear; if BNN is built from custom layers (e.g. BayesLinear) this
# is a no-op -- confirm that is intended.
if selected_init_weights != 'default':
    for sub_model in (encoder, predictor, decoder):
        sub_model.apply(init_weights_scheme(selected_init_weights))
        sub_model.apply(init_bias_scheme(selected_init_weights))

print('\n')
print('Selected BNN Architecture')
# The nested model is three separate networks; show each one.
for sub_name, sub_model in (('Encoder', encoder), ('Predictor', predictor), ('Decoder', decoder)):
    print(f'{sub_name}:')
    print(sub_model)
print('\n')

# Define the losses for the predictor and decoder heads. Both use the same
# observation noise variance (noise_var) and weight-prior variance (params_var).
pred_negative_log_joint_fn = NegativeLogJoint(model_type=bnn_type, noise_var=noise_var, params_var=params_var, reduction='mean', dtype=dtype, device=device)
dec_negative_log_joint_fn = NegativeLogJoint(model_type=bnn_type, noise_var=noise_var, params_var=params_var, reduction='mean', dtype=dtype, device=device)

# Model pretrain to find a good set of weights to start off with
if perform_pretrain:
    # train_loop/test_loop drive encoder, predictor and decoder together, so the
    # optimizer must own the parameters of all three networks.
    nested_models = (encoder, predictor, decoder)
    pretrain_params = [param for sub_model in nested_models for param in sub_model.parameters()]
    # Since computing the negative log prob, we are minimizing.
    adam = Adam(pretrain_params, lr=pretrain_lr, maximize=False)

    if use_lr_scheduler:
        # `verbose` is deprecated for LR schedulers; the learning rate is
        # reported below through scheduler.get_last_lr() instead.
        scheduler = torch.optim.lr_scheduler.ExponentialLR(adam, gamma=0.9999)
    else:
        scheduler = None

    # Training / validation loaders
    train_dataloader = DataLoader(TensorDataset(train_X_data, train_y_data),
                                batch_size=pretrain_batch_size,
                                shuffle=True)

    test_dataloader = DataLoader(TensorDataset(val_X_data, val_y_data),
                                batch_size=pretrain_batch_size,
                                shuffle=True)

    # One negative-log-joint loss per head (predictor first, then decoder --
    # presumably matched positionally inside train_loop; confirm there).
    loss_fn_list = [pred_negative_log_joint_fn, dec_negative_log_joint_fn]
    loss_wts = [1, 1]
    loss_names = ['neg_log_joint', 'neg_log_joint']
    store_train_losses_list = []
    store_test_losses_list = []
    store_train_gnll = []
    store_train_neg_log_params_prior = []
    store_test_gnll = []
    store_test_neg_log_params_prior = []

    for epoch in range(pretrain_num_epochs):

        train_loss_per_epoch_list, gnll, neg_log_params_prior = train_loop(train_dataloader, encoder, predictor, decoder, loss_fn_list, loss_wts, loss_names, adam, scheduler, print_every=None)
        store_train_losses_list.append(train_loss_per_epoch_list)
        store_train_gnll.append(gnll)
        store_train_neg_log_params_prior.append(neg_log_params_prior)

        test_loss_per_epoch_list, gnll, neg_log_params_prior = test_loop(test_dataloader, encoder, predictor, decoder, loss_fn_list, loss_wts, loss_names)
        store_test_losses_list.append(test_loss_per_epoch_list)
        store_test_gnll.append(gnll)
        store_test_neg_log_params_prior.append(neg_log_params_prior)

        if epoch % 100 == 0:
            print(f'-------------------------- Epoch {epoch} --------------------------')
            print('--------------------- Train Epoch Stats ---------------------')
            for loss_name, train_loss in zip(loss_names, train_loss_per_epoch_list):
                print(f'Train {loss_name} / epoch is ............................. {round(train_loss, 8)}')
            print(f'Train Total Loss / epoch is ..................... {round(sum(train_loss_per_epoch_list), 8)}')

            print(f'--------------------- Test Epoch Stats ----------------------')
            for loss_name, test_loss in zip(loss_names, test_loss_per_epoch_list):
                print(f'Test {loss_name} / epoch is ............................. {round(test_loss, 8)}')
            print(f'Test Total Loss / epoch is ..................... {round(sum(test_loss_per_epoch_list), 8)}')
            if use_lr_scheduler:
                print(f'Learning rate before step : {scheduler.get_last_lr()}')
                print('\n')
            else:
                print('\n')

    # Plotting the losses
    epochs = np.arange(0, pretrain_num_epochs, 1)
    store_train_losses_array = np.array(store_train_losses_list)
    store_test_losses_array = np.array(store_test_losses_list)
    store_tot_train_losses = np.sum(store_train_losses_array, axis=1)
    store_tot_test_losses = np.sum(store_test_losses_array, axis=1)

    store_train_gnll_array = np.array(store_train_gnll)
    store_test_gnll_array = np.array(store_test_gnll)
    store_train_neg_log_params_prior_array = np.array(store_train_neg_log_params_prior)
    store_test_neg_log_params_prior_array = np.array(store_test_neg_log_params_prior)

    # Only column 0 of each stored array is plotted.
    fig, ax = plt.subplots(1,3,figsize=(12,4))
    ax[0].plot(epochs, store_train_losses_array[:,0], label='Train Neg. log. prob.')
    ax[0].plot(epochs, store_test_losses_array[:,0], label='Test Neg. log. prob.')
    ax[0].set_title('Neg. log. prob. v epochs')
    ax[1].plot(epochs, store_train_gnll_array[:,0], label='Train GNLL')
    ax[1].plot(epochs, store_test_gnll_array[:,0], label='Test GNLL')
    ax[1].set_title('GNLL v epochs')
    ax[2].plot(epochs, store_train_neg_log_params_prior_array[:,0], label='Train Neg. log. params prior')
    ax[2].plot(epochs, store_test_neg_log_params_prior_array[:,0], label='Test Neg. log. params prior')
    ax[2].set_title('Neg. log. params prior v epochs')
    for panel in ax:
        panel.set_xlabel('Epoch')
        panel.legend()
    plt.show()

    # Zero the gradients of every sub-network and free the optimizer state
    for sub_model in nested_models:
        sub_model.zero_grad()
    del adam

    # Save all three sub-networks in a single checkpoint (previously the
    # file was written twice and held only one model).
    torch.save({'encoder': encoder.state_dict(),
                'predictor': predictor.state_dict(),
                'decoder': decoder.state_dict()}, 'bnn_pretrained.pth')

    # Accuracy on the held-out (validation) set
    print('-------------------------- Test Set Accuracy --------------------------')
    # NOTE(review): assumes the property prediction is predictor(encoder(X)),
    # mirroring how train_loop chains the networks -- confirm.
    with torch.no_grad():
        preds = predictor(encoder(val_X_data))
    print(f'Predictions : {preds}')
    print(f'True values : {val_y_data}')

##########################
# HMC Starts from here
##########################

if perform_hmc:
    # NOTE(review): `bnn` and `negative_log_joint_fn` are not defined in this
    # section. The nested model is split into encoder/predictor/decoder, and
    # the losses are pred_/dec_negative_log_joint_fn. Decide which network(s)
    # HMC should sample and wire them in before enabling perform_hmc.

    # Load the pretrained model
    # bnn.load_state_dict(torch.load('bnn_pretrained.pth'))

    num_accepted = 0
    num_rejected = 0
    accepted_weight_samples = []

    # Dual-averaging step-size adaptation state. The constants match the
    # defaults recommended by Hoffman & Gelman (NUTS, 2014).
    gamma = 0.05
    t = 10.0
    kappa = 0.75
    mu = np.log(10*epsilon)
    log_best_epsilon = 0.0
    closeness = 0.0

    # Set model into training mode
    bnn.train()
    print(f'---------------------- Starting HMC ----------------------')
    # i is 1-indexed. The first num_samples_to_burn iterations are burn-in
    # (optionally adapting epsilon); the remaining num_samples are kept.
    for i in range(1, num_samples+num_samples_to_burn+1):
        # Adapted from https://github.com/yucenli/bnn-bo/blob/main/models/hmc_utils.py
        num_leapfrog_steps_per_sample = min(200, int(np.ceil(path_length/epsilon)))
        print(f'-------------------- Running sample {i} --------------------')
        print(f'Epsilon : {epsilon}, Path length : {path_length}, Num leapfrog steps /sample : {num_leapfrog_steps_per_sample}')
        q, result, accep_prob = run_hmc(bnn, train_X_data, train_y_data, negative_log_joint_fn, epsilon, num_leapfrog_steps_per_sample)

        if i <= num_samples_to_burn:
            # Code sourced from : https://github.com/yucenli/bnn-bo/blob/main/models/hmc_utils.py
            if adapt_epsilon:
                # Dual averaging with iteration counter m = i. The 1/(i + t)
                # and i**-kappa terms already use m = i, so sqrt(m) is sqrt(i).
                closeness_frac = 1.0/(i + t)
                # Running average of (target acceptance 0.75 - observed acceptance)
                closeness = (1.0 - closeness_frac)*closeness + closeness_frac*(0.75 - accep_prob)
                # The above equation is from  :  https://arxiv.org/pdf/1206.1901.pdf
                log_epsilon = mu - (math.sqrt(i)/gamma)*closeness
                epsilon = math.exp(log_epsilon)

                step_frac = math.pow(i, -kappa)
                log_best_epsilon = (step_frac*log_epsilon) + (1 - step_frac)*log_best_epsilon

                # Keep the number of leapfrog steps bounded (the cap above is 200)
                if (path_length / epsilon) > 200:
                    path_length = path_length / 2
                    print("new path length", path_length, "epsilon", epsilon)

                # At the end of burn-in, freeze epsilon at the averaged value
                if i == num_samples_to_burn:
                    epsilon = math.exp(log_best_epsilon)
                    print(f'Final Epsilon : {epsilon}')
            print(f'New epsilon : {epsilon}')
            print(f'-------------------- Burning sample {i} --------------------\n')
        else:
            # Print Accept or Reject
            if result == 'accepted':
                num_accepted += 1
                # Add to list of proposed weight samples ..
                accepted_weight_samples.append(q)
            else:
                num_rejected += 1
            print(f'Final Epsilon : {epsilon}')
            # Running acceptance ratio over the post-burn-in samples drawn so
            # far (equals num_accepted/num_samples on the last iteration).
            num_drawn = i - num_samples_to_burn
            acceptance_ratio = num_accepted/num_drawn
            print(f'--- Sample {i} {result}; Accepted {num_accepted}; Accep ratio {acceptance_ratio} ---\n')

In [None]:
# List the current parameter vector alongside feature names.
# NOTE(review): `bnn` and `features` are not defined in this section -- confirm
# which network's parameters are meant and where the feature names come from.
q = get_params(bnn)
print(q)
# Use a distinct loop variable so `q` still holds the full parameter vector afterwards.
for param, feat in zip(q, features):
    print(f'Feature {feat} : {param}')

In [None]:
# Training points (red dots), the ground-truth curve (dashed), and the fixed
# line y = 9.9215 * x (solid).
# NOTE(review): train_x / train_y / test_x / test_y are not defined in this
# section of the notebook.
plt.plot(train_x, train_y, '.r')
plt.plot(test_x, test_y, '--k', alpha=0.5, label='True')
plt.plot(test_x, 9.9215*test_x, '-k', alpha=1.0, label='loss = 1.5963')
plt.gca().set_xlabel('X')
plt.gca().set_ylabel('y')
plt.legend()

## Posterior Distribution of weights p(w|D)

In [None]:
import seaborn as sns

# Stack the accepted HMC weight samples, drop singleton dimensions and turn the
# result into a Python list for seaborn.
accepted_weight_samples_list = list(
    torch.stack(accepted_weight_samples)
    .detach()
    .numpy()
    .squeeze()
)
print(accepted_weight_samples_list)

# Kernel density estimate of p(w|D) on a fresh figure, with the sample mean
# (rounded to 4 decimals) marked by a dashed red vertical line.
fig, ax = plt.subplots(1, 1, figsize=(8, 4))
sns.kdeplot(accepted_weight_samples_list, ax=ax, label='p(w|D)')
ax.set_title('Distribution of weights')
mean = round(float(np.mean(accepted_weight_samples_list)), 4)
ax.axvline(mean, color='red', linestyle='--', label=f'Mean : {mean}')

def normal_dist(x, var=100.0):
    """Density of the zero-mean normal N(0, var) evaluated at ``x``.

    ``var`` defaults to 100.0, the value previously hard-coded here. Works on
    scalars and on NumPy arrays (element-wise).
    """
    return np.exp(-(x**2)/(2*var)) / np.sqrt(2*np.pi*var)

# Overlay the weight prior actually used by the model, p(w) = N(0, params_var),
# over +/- 3 prior standard deviations (this is +/- 30 when params_var = 100).
prior_std = params_var**0.5
x = np.linspace(-3*prior_std, 3*prior_std, 1000)
y = normal_dist(x, params_var)
plt.plot(x, y, label=f'p(w) = N(0, {params_var:g})')
plt.legend()

## Joint Distribution of weights and hyperparameters p(w,alpha|D)

In [None]:
# Joint KDE of the model output on an evenly spaced input grid against the
# accepted weight samples. jointplot needs x and y of equal length, so the
# grid size comes from the number of samples rather than a hard-coded 989.
# NOTE(review): `g` is not defined in this section -- confirm which model it is.

# Plotting the data
# randomly_selected_wts = np.random.choice(accepted_weight_samples_list, test_y.size(0))
num_weight_samples = len(accepted_weight_samples_list)
test_x_1000 = torch.linspace(0, 1, num_weight_samples, dtype=dtype, device=device).unsqueeze(1)
# Evaluate once without autograd: a stochastic model would otherwise give
# different outputs for the printed shape and the plot, and grad-tracking
# tensors cannot be converted to NumPy by seaborn.
with torch.no_grad():
    g_out = g(test_x_1000).squeeze()
print(g_out.shape)
print(num_weight_samples)
# jointplot is figure-level and creates its own axes, so no `ax` is passed.
sns.jointplot(x=g_out, y=accepted_weight_samples_list, kind='kde')


In [None]:
from scipy.stats import gaussian_kde
import plotly.graph_objects as go

# 100 x 100 evaluation grid over x in [-5, 20] and y in [8.5, 10].
x_edges = np.linspace(-5, 20, 100)
y_edges = np.linspace(8.5, 10, 100)
x_grid, y_grid = np.meshgrid(x_edges, y_edges)

# 2-D Gaussian KDE fitted to (test_y, randomly_selected_wts) pairs, then
# evaluated at every grid point and reshaped back onto the grid.
# NOTE(review): `randomly_selected_wts` is only defined in a commented-out line
# of the previous cell, and `test_y` is not defined in this section.
kde = gaussian_kde([test_y.squeeze().cpu(),
                    randomly_selected_wts])
z_density = kde(np.vstack((x_grid.ravel(), y_grid.ravel()))).reshape(x_grid.shape)

# Interactive 3-D surface of the estimated density.
fig = go.Figure(data=[go.Surface(x=x_grid, y=y_grid, z=z_density, colorscale='Viridis')])
fig.update_layout(
    title='3D Density Plot',
    scene=dict(xaxis_title='X-axis',
               yaxis_title='Y-axis',
               zaxis_title='Density'),
)
fig.show()

## Plotting the posterior distribution

In [None]:
# q = get_params(bnn)
# accepted_weight_samples = [q]

# Posterior predictive: load each accepted HMC weight sample into the network
# and record its predictions (output column 0) on the validation inputs.
# NOTE(review): `bnn` is not defined in this section; it must be the network
# whose weights HMC sampled.
if not accepted_weight_samples:
    raise ValueError('No accepted HMC weight samples; run HMC (perform_hmc=True) first.')

f_draws = []
with torch.no_grad():
    for weights in accepted_weight_samples:
        set_params(bnn, weights)
        f_draws.append(bnn(val_X_data)[:, 0].detach().numpy())

# Shape: (number of accepted samples, number of validation points)
f_draws_array = np.array(f_draws)
print(f_draws_array.shape)

# Per-point statistics across draws. The std is computed from the unrounded
# draws: taking sqrt of an already-rounded variance loses precision (e.g. a
# variance of 4e-5 rounds to 0.0, giving std 0 instead of ~0.0063).
f_mean = np.round(np.mean(f_draws_array, axis=0), 4)
f_var = np.round(np.var(f_draws_array, axis=0), 4)
f_std = np.round(np.std(f_draws_array, axis=0), 4)

print(val_y_data)
print(f_mean)
print(f_var)
print(f_std)