**Bayes-by-backprop prototype**

In this notebook, we implement a small prototype for Bayes-by-backprop introduced by Blundell et al.

In [49]:
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset
import math
import torch.nn.functional as F

def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(1871)

In [50]:
# Regression data function
f = lambda x, epsilon: x + 0.3 * np.sin(2*np.pi * (x+epsilon)) + 0.3 * np.sin(4 * np.pi * (x+epsilon)) + epsilon

def generate_data(N, lower, upper, std, f=f):
    # create data
    x = np.linspace(lower, upper, N)

    y = []
    for i in range(N):
        epsilon = np.random.normal(0, std)
        y.append(f(x[i], epsilon))
    return x, y

# Generate train data
N_train = 2000
x, y = generate_data(N_train, lower=-0.25, upper=1, std=0.02)

# Generate validation data
N_val = 500
x_val, y_val = generate_data(N_val, lower=-0.25, upper=1, std=0.02)

# Generate test data
N_test = 500
x_test, y_test = generate_data(N_test, lower=-0.5, upper=1.5, std=0.02)

line = f(x_test, 0)

In [69]:
class ToyDataset(Dataset):
    """Custom toy dataset"""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __getitem__(self, idx):

        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

def collate_fn(batch):

    x, y = zip(*batch)

    return torch.tensor(x)[:,None].T, torch.tensor(y)

We use a diagonal Gaussian distribution (zero covariance) as the variational posterior. Rather than using just $\mu$ and $\sigma$ as the variational parameters, the standard deviation is parameterised as:
$$
\sigma = \log{(1 + \exp{(\rho)})}
$$
such that $\sigma$ is always non-negative. The variational parameters are then $\mathbf{\theta} = (\mu, \rho)$. 

The code blocks in the following sections are inspired by:

https://github.com/nitarshan/bayes-by-backprop/blob/master/Weight%20Uncertainty%20in%20Neural%20Networks.ipynb
https://colab.research.google.com/drive/1K1I_UNRFwPt9l6RRkp8IYg1504PR9q4L#scrollTo=ASGi2Ecx5G-F


In [70]:
class Gaussian(object):
    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu
        self.rho = rho
        self.normal = torch.distributions.Normal(0, 1)

    @property # @property so we can call self.sigma directly rather than self.sigma()
    def sigma(self):
        return torch.log1p(torch.exp(self.rho))
    
    def sample(self):  
        # sample noise from normal distribution 
        epsilon = self.normal.sample(self.rho.size())
        return self.mu + self.sigma * epsilon # scale with mu and sigma
    
    def log_prob(self, w):
        # log pdf for Gaussian distribution
        return torch.sum(-torch.log(self.sigma) - 0.5*np.log(2*np.pi) - 0.5 * ((w - self.mu) / self.sigma)**2)

The ``Gaussian`` class is a simple class that allows us to sample from a Gaussian distribution, representing the variational posterior.

The function ``sigma`` computes the standard deviation $\sigma$ for a given $\rho$ value.
The function ``sample`` allows us to sample from the approximate posterior, using the reparametrisation trick with $\mu$ and $\sigma$. 
The function ``log_prob`` computes the log-probability density function for a normal distribution wtih mean $\mu$ and standard devation $\sigma$ (derivation below):

The probability density function for the weights $\mathbf{w}$ given the variational parameters $\mathbf{\theta} = (\mu, \rho)$ of a Gaussian distribution is given as
\begin{align*}
q(\mathbf{w|\mathbf{\theta}}) &= \prod_j \mathcal{N}(w_j | \mu, \sigma) \\
& = \prod_j \frac{1}{\sigma \sqrt{(2\pi)}} \exp{-\frac{1}{2} \left( \frac{w_j - \mu}{\sigma}    \right)^2}
\end{align*}
Then taking the log, we get:
\begin{align*}
\log{q(\mathbf{w}|\mathbf{\theta})} &=  \log{  \left( \prod_j \frac{1}{\sigma \sqrt{(2\pi)}} \exp{-\frac{1}{2} \left( \frac{w_j - \mu}{\sigma}    \right)^2} \right)} \\
           &= \sum_j \log{(1)} - \log{\left(\sigma \sqrt(2 \pi)\right) - \frac12 \left( \frac{w_j - \mu}{\sigma} \right)^2 } \\
           &= \sum_j -\log{(\sigma)} - \frac12 \log{(2\pi)} - \frac12 \left( \frac{w_j - \mu}{\sigma} \right)^2
\end{align*}

The prior that was proposed in the paper by Blundell et al. is a Gaussian mixture prior over the weights $\mathbf{w}$. 
$$
P(\mathbf{w}) = \prod_j \pi \mathcal{N}(w_j|0, \sigma_1^2) + (1-\pi) \mathcal{N}(w_j| 0, \sigma_2^2) 
$$
where $\pi \in [0,1]$ is the mixture weight.

In [117]:
class ScaleMixturePrior(object):
    def __init__(self, pi, sigma1, sigma2):
        super().__init__()

        """
        Implementing the scale mixture prior in equation 7 of the paper.
        From the paper: sigma1 > sigma2 and sigma2 << 1.
        """
        assert sigma1 > sigma2, "Error: sigma1 must be greater than sigma2."
        assert sigma2 < 1, "Error: sigma2 must be less than 1."

        self.pi = pi
        
        self.sigma1 = sigma1
        self.sigma2 = sigma2

        self.gaussian1 = torch.distributions.Normal(0,sigma1)
        self.gaussian2 = torch.distributions.Normal(0,sigma2)

    def log_prob(self, w):
        """
        Implementing the log pdf for the scale mixture prior
        """

        p1 = torch.exp(self.gaussian1.log_prob(w)) # torch.exp of log pdf so we get the pdf
        p2 = torch.exp(self.gaussian2.log_prob(w))
        return torch.log(self.pi * p1 + (1-self.pi) * p2)

    

We implement a Bayesian linear layer for our Bayesian neural network. 

In [170]:
# inspired by
# https://github.com/nitarshan/bayes-by-backprop/blob/master/Weight%20Uncertainty%20in%20Neural%20Networks.ipynb
# https://colab.research.google.com/drive/1K1I_UNRFwPt9l6RRkp8IYg1504PR9q4L#scrollTo=ASGi2Ecx5G-F


class BayesianLinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim, pi=0.5, sigma1=torch.exp(torch.tensor([-0])), sigma2=torch.exp(torch.tensor([-6]))):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2

        # initialise mu and rho parameters so they get updated in backpropagation
        self.weight_mus = nn.Parameter(torch.Tensor(input_dim, output_dim).uniform_(-0.05, 0.05))
        self.weight_rhos = nn.Parameter(torch.Tensor(input_dim, output_dim).uniform_(-2, -1)) 
        self.bias_mus = nn.Parameter(torch.Tensor(output_dim).uniform_(-0.05, 0.05))
        self.bias_rhos = nn.Parameter(torch.Tensor(output_dim).uniform_(-2, -1))

        # create approximate posterior distribution
        self.weight_posterior = Gaussian(self.weight_mus, self.weight_rhos)
        self.bias_posterior = Gaussian(self.bias_mus, self.bias_rhos)

        # scale mixture posterior
        self.weight_prior = ScaleMixturePrior(pi=pi, sigma1=sigma1, sigma2=sigma2)
        self.bias_prior = ScaleMixturePrior(pi=pi, sigma1=sigma1, sigma2=sigma2)

        self.log_prior = 0.0
        self.log_variational_posterior = 0.0

    def forward(self, x, test=False):
        if test:
            # during inference, we simply use the mus of the weights and biases
            w = self.weight_mus
            b = self.bias_mus

            self.log_prior = 0.0
            self.log_variational_posterior = 0.0
        else:
            # sample from approximate posterior distribution
            w = self.weight_posterior.sample()
            b = self.bias_posterior.sample()

            # compute log prior and log variational posterior
            self.log_prior = self.weight_prior.log_prob(w) + self.bias_prior.log_prob(b)
            self.log_variational_posterior = self.weight_posterior.log_prob(w) + self.bias_posterior.log_prob(b)
            # print("log_prior:", self.log_prior)
        
        return torch.mm(x, w) + b # matrix multiply input by weights and add bias


class BayesianNeuralNetwork(nn.Module):

    def __init__(self, hidden_units1, hidden_units2, pi=0.5, sigma1=torch.exp(torch.tensor([-0])), sigma2=torch.exp(torch.tensor([-6]))):
        super().__init__()
        self.model = nn.Sequential(BayesianLinearLayer(1, hidden_units1),
                                   nn.ReLU(),
                                   BayesianLinearLayer(hidden_units1, hidden_units2),
                                   nn.ReLU(),
                                   BayesianLinearLayer(hidden_units2, 2)) # output channel = 2 for mean and variance

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x, test=False):

        x = self.model(x)

        return x
    
    def compute_log_prior(self):
        for layer in self.model:
            if isinstance(layer, BayesianLinearLayer):
                self.log_prior += layer.log_prior.sum()

    def compute_log_variational_posterior(self):
        for layer in self.model:
            if isinstance(layer, BayesianLinearLayer):
                # print(layer.log_variational_posterior)
                self.log_variational_posterior += layer.log_variational_posterior.sum()

    def get_sigma(self, rho):
        return torch.log1p(torch.exp(rho))


    def compute_ELBO(self, input, target):
        # formula from Blundell: loss = log_variational_posterior - log_prior - log_likelihood
        #                        loss = log_variational_posterior - log_prior + NLL
        
        # compute log prior and log variational posterio

        output = self.forward(input, test=False)
        mu = output[:,0]
        rho = output[:,1]

        self.compute_log_prior()
        self.compute_log_variational_posterior()
        NLL = self.NLL_loss(target, mu, rho)
        # print(self.log_variational_posterior)
        # print(self.log_prior)
        # print(NLL)

        return torch.sum(self.log_variational_posterior - self.log_prior + NLL)

    def NLL_loss(self, target, mu, rho):
        # negative log likelihood loss
        sigma = self.get_sigma(rho)

        return 1 / (2 * sigma.pow(2)) * (target - mu).pow(2) + torch.log(sigma)
    


In [171]:
# seed workers for reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

traindata = ToyDataset(x, y)
trainloader = DataLoader(traindata, batch_size=60, shuffle=True, worker_init_fn=seed_worker, generator=g, collate_fn=collate_fn)

valdata = ToyDataset(x_val, y_val)
valloader = DataLoader(valdata, batch_size=60, shuffle=False, collate_fn=collate_fn)

testdata = ToyDataset(x_test, y_test)
testloader = DataLoader(testdata, batch_size=N_test, shuffle=False, collate_fn=collate_fn)

# create model
BNN_model = BayesianNeuralNetwork(32, 128)
optimizer = torch.optim.Adam(BNN_model.parameters(), lr=3e-4)



In [172]:
# useful functions 🤖

def train(model, optimizer, trainloader, valloader, epochs=500, model_name='BNN', val_every_n_epochs=10):

    losses = []
    val_losses = []

    best_val_loss = np.inf

    for e in tqdm(range(epochs)):
        
        for x_, y_ in trainloader:

            model.train()

            x_,y_ = x_.float().T, y_.float()

            optimizer.zero_grad()

            loss = model.compute_ELBO(x_, y_)
            # print(loss)

            loss.backward()
            optimizer.step()

            losses.append(loss.item())  

        if (e+1) % val_every_n_epochs == 0:
            model.eval()

            val_loss_list = []
            with torch.no_grad():
                for val_x, val_y in valloader:
                    val_x, val_y = val_x.float(), val_y.float()
                
                    val_loss = model.compute_ELBO(val_x, val_y)
                    val_loss_list.append(val_loss.item())

            val_losses.extend(val_loss_list)
            mean_val_loss = np.mean(val_loss_list)
            if mean_val_loss < best_val_loss:
                best_val_loss = mean_val_loss
                torch.save(model, f'{model_name}.pt')
            # print(f"Mean validation loss at epoch {e}: {mean_val_loss}")

    return losses, val_losses


def plot_loss(losses, val_losses):

    fig, ax = plt.subplots(1,2, figsize=(12,6))

    ax[0].plot(losses, label='Train loss')
    ax[0].set_title('Train loss')
    ax[0].set_xlabel('Iterations')
    ax[0].set_ylabel('Loss')

    ax[1].plot(val_losses, label='Validation loss', color='orange')
    ax[1].set_title('Validation loss')
    ax[1].set_xlabel('Iterations')
    ax[1].set_ylabel('Loss')

    plt.show()

In [173]:
losses, val_losses = train(BNN_model, optimizer, trainloader, valloader, epochs=500, model_name='BNN', val_every_n_epochs=10)

plot_loss(losses, val_losses)

  0%|          | 0/500 [00:00<?, ?it/s]

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.