# Bayesian convolutional layers
The goal of this notebook is to implement the Bayesian convolutional layers proposed by [Shridhar et al.](https://github.com/kumar-shridhar/PyTorch-BayesianCNN) and [Seligmann et al.](https://github.com/Feuermagier/Beyond_Deep_Ensembles/tree/main).

Rather than sampling the weights of the filter, the goal is to sample the activations obtained after applying one convolutional layer directly from the distribution that the weight distribution induces on them. This approach, the local reparameterization trick, is described in the paper "Variational Dropout and the Local Reparameterization Trick" by [Kingma et al.](https://arxiv.org/pdf/1506.02557).

Consider the input feature matrix $\mathbf{A}$ of size $M \times 1000$ and a weight matrix $\mathbf{W}$ of size $1000 \times 1000$. When multiplied together, the resulting matrix $\mathbf{B}$ is:
$$
\mathbf{B} = \mathbf{AW}    
$$
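where $\mathbf{B}$ are called the activations. Rather than sampling the weights $\mathbf{W}$ and then computing $\mathbf{B}$, we sample the activations $\mathbf{B}$ directly. If the weights follow a factorized Gaussian posterior $w_{ij} \sim \mathcal{N}(\mu_{ij}, \sigma_{ij}^2)$, every activation is a linear combination of independent Gaussians and is therefore Gaussian itself:
$$
b_{mj} \sim \mathcal{N}(\gamma_{mj}, \delta_{mj}), \qquad \gamma_{mj} = \sum_{i} a_{mi}\,\mu_{ij}, \qquad \delta_{mj} = \sum_{i} a_{mi}^2\,\sigma_{ij}^2 .
$$
For a convolutional layer the sums become convolutions. The mean is the input convolved with the weight means, and the variance is the squared input convolved with the weight variances. We still learn the parameters $\mu_{ij}$ and $\sigma_{ij}$. Sampling $\mathbf{B}$ instead of $\mathbf{W}$ is cheaper and yields lower-variance gradients.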


# Gaussian prior
The scale mixture prior that we implemented during the Master's thesis is shown below:

In [105]:
import torch
import numpy as np

class ScaleMixturePrior():
    """Scale-mixture prior of two zero-mean Gaussians (Blundell et al.).

    Each weight has density
    ``pi * N(w; 0, sigma1^2) + (1 - pi) * N(w; 0, sigma2^2)``.
    If ``sigma2`` is not strictly positive, only the first component is used.

    Args:
        pi: Weight of the first mixture component, expected in [0, 1].
        sigma1: Standard deviation of the first component (tensor).
        sigma2: Standard deviation of the second component (tensor; ``.item()``
            is called on it to decide whether the mixture is used).
        device: Stored as ``self.device``; not otherwise used by this class.
    """

    # log(sqrt(2 * pi)): the Gaussian normalising constant in log space.
    _LOG_SQRT_2PI = 0.5 * float(np.log(2.0 * np.pi))

    def __init__(self, pi=0.5, sigma1=torch.exp(torch.tensor(0)), sigma2=torch.tensor(0.3), device='cpu'):
        self.device = device
        self.pi = pi
        self.mu = torch.tensor(0)
        self.sigma1 = sigma1
        self.sigma2 = sigma2

    def prob(self, w, sigma):
        """Return the elementwise Gaussian density N(w; mu, sigma^2)."""
        return (1 / (sigma * torch.sqrt(torch.tensor(2 * np.pi)))) * torch.exp(-0.5 * torch.pow((w - self.mu), 2) / torch.pow(sigma, 2))

    def _log_normal(self, w, sigma):
        """Return the elementwise log-density log N(w; mu, sigma^2), computed without exp."""
        sigma = torch.as_tensor(sigma, dtype=w.dtype, device=w.device)
        return -torch.log(sigma) - self._LOG_SQRT_2PI - 0.5 * ((w - self.mu) / sigma) ** 2

    def log_prob(self, w):
        """Return the log-density of ``w`` under the prior, summed over all elements.

        The mixture is evaluated in log space with ``logsumexp``. Computing the
        densities first and then taking ``log`` underflows to ``log(0) = -inf``
        for weights far in the tails (e.g. |w| > ~14 with sigma1 = 1 in float32).
        """
        log_p1 = self._log_normal(w, self.sigma1)
        if not self.sigma2.item() > 0:
            return log_p1.sum()

        log_p2 = self._log_normal(w, self.sigma2)
        pi = torch.as_tensor(self.pi, dtype=w.dtype, device=w.device)
        # pi == 1 or pi == 0 gives a -inf mixture weight, which logsumexp handles.
        weighted = torch.stack([torch.log(pi) + log_p1, torch.log1p(-pi) + log_p2])
        return torch.logsumexp(weighted, dim=0).sum()

According to Seligmann et al., the scale mixture prior proposed by Blundell et al. in the original Bayes by Backprop paper is no longer standard practice. We take this into account and implement a simple Gaussian prior instead.

In [107]:
# Adapted from: https://github.com/Feuermagier/Beyond_Deep_Ensembles/tree/main

class GaussianPrior:
    """Zero-mean Gaussian prior N(0, sigma1^2), applied independently to every weight.

    Args:
        sigma1: Standard deviation of the prior (float or scalar tensor).
    """

    def __init__(self, sigma1):
        self.sigma1 = sigma1
        # The prior has no learnable parameters, so the distribution is built once.
        self.dist = torch.distributions.Normal(loc=0, scale=sigma1)

    def log_prob(self, x):
        """Return the log-density of ``x`` under the prior, summed over all elements."""
        elementwise = self.dist.log_prob(x)
        return torch.sum(elementwise)
    
class Gaussian():
    """Diagonal Gaussian variational posterior parameterised by ``(mu, rho)``.

    The standard deviation is ``sigma = softplus(rho) = log(1 + exp(rho))``,
    which keeps it positive while ``rho`` is unconstrained.

    Args:
        mu: Tensor of means (typically an ``nn.Parameter``).
        rho: Tensor of unconstrained scale parameters, same shape as ``mu``.
        device: Stored as ``self.device``; not otherwise used by this class.
    """

    def __init__(self, mu, rho, device='cpu'):
        self.device = device
        self.mu = mu
        self.rho = rho
        self.init_distribution()

    @property
    def sigma(self):
        # Same value as log1p(exp(rho)), but does not overflow to inf for large rho.
        return torch.nn.functional.softplus(self.rho)

    def init_distribution(self):
        """(Re)build ``self.normal`` from the current values of ``mu`` and ``rho``."""
        self.normal = torch.distributions.Normal(self.mu, self.sigma)

    def rsample(self):
        """Draw a reparameterised sample that is differentiable w.r.t. ``mu`` and ``rho``."""
        self.init_distribution()
        return self.normal.rsample()

    def log_prob(self, w):
        """Return log q(w) summed over all elements.

        The distribution is rebuilt first. Reusing the one built earlier would
        use a stale ``sigma`` after an optimiser step. It would also fail on a
        second ``backward()``, because that object's autograd graph has
        already been freed.
        """
        self.init_distribution()
        return self.normal.log_prob(w).sum()

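Next, we check that the scale mixture prior with $\pi = 1$ reduces to a Gaussian prior with the same $\sigma_1$, i.e. that the two priors give the same log-probability.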

In [138]:
# With pi = 1 the scale mixture keeps only its first component, so it should
# match a Gaussian prior with the same sigma.
prior1 = ScaleMixturePrior(pi=1, sigma1=torch.tensor(40.82))
prior2 = GaussianPrior(sigma1=torch.tensor(40.82))

test_tensor = torch.randn(2500)

lp_mixture = prior1.log_prob(test_tensor)
lp_gaussian = prior2.log_prob(test_tensor)

# Raw string: "\p" is an invalid escape sequence (SyntaxWarning since Python 3.12);
# the printed text is unchanged.
print(r"Scale Mixture Prior with \pi=1: ", lp_mixture.item())
print("Gaussian Prior: ", lp_gaussian.item())

# The two values can differ in the last float32 digits, so compare with a tolerance.
assert torch.isclose(lp_mixture, lp_gaussian), "Scale mixture (pi=1) and Gaussian priors disagree"

Scale Mixture Prior with \pi=1:  -11571.013671875
Gaussian Prior:  -11571.013671875


In [139]:
import torch.nn as nn
import torch.nn.functional as F

class BayesianConvLayer(nn.Module):
    """2D convolution with a diagonal Gaussian posterior over its weights and bias.

    It is trained with the local reparameterization trick (Kingma et al., 2015).
    Instead of sampling a weight tensor, each forward pass samples the outputs
    directly from the Gaussian they follow given the input, with
    ``sigma = softplus(rho)``:

    - mean ``conv(x, mu_W) + mu_b``
    - variance ``conv(x^2, sigma_W^2) + sigma_b^2``

    After every forward pass the layer stores two scalars, both in closed form:

    - ``log_variational_posterior = E_q[log q(w)]``
    - ``log_prior = E_q[log p(w)]``

    Their difference is the exact KL(q || p) for the Gaussian prior.

    Args:
        in_channels, out_channels: Number of input / output feature maps.
        kernel_size: Int (square kernel) or ``(height, width)`` pair.
        stride, padding, dilation: Forwarded to ``F.conv2d``.
        use_bias: Whether to learn a Bayesian bias.
        device: Forwarded to the ``Gaussian`` posterior objects.
        sigma1: Standard deviation of the zero-mean Gaussian prior.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, use_bias=True, device='cpu', sigma1=torch.tensor(1.0)):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.use_bias = use_bias

        # Accept an int (square kernel) or an explicit (height, width) pair.
        kernel_hw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        weight_shape = (out_channels, in_channels, *kernel_hw)

        self.weight_mu = nn.Parameter(torch.empty(weight_shape).normal_(0, 0.1))
        # NOTE(review): rho ~ U(-3, 3) gives an initial sigma = softplus(rho) of roughly
        # 0.05..3.05, large next to mu ~ N(0, 0.1); confirm this initialisation is intended.
        self.weight_rho = nn.Parameter(torch.empty(weight_shape).uniform_(-3, 3))

        if use_bias:
            self.bias_mu = nn.Parameter(torch.empty(out_channels).normal_(0, 0.1))
            self.bias_rho = nn.Parameter(torch.empty(out_channels).uniform_(-3, 3))
        else:
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_rho', None)

        # Variational posteriors. There is no bias posterior without a bias:
        # Gaussian(None, None) cannot compute a scale and raises.
        self.weight_posterior = Gaussian(self.weight_mu, self.weight_rho, device=device)
        self.bias_posterior = Gaussian(self.bias_mu, self.bias_rho, device=device) if use_bias else None

        # Priors
        self.weight_prior = GaussianPrior(sigma1=sigma1)
        self.bias_prior = GaussianPrior(sigma1=sigma1)

        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, x):
        """Sample the layer output for input ``x`` and refresh the KL terms."""
        # Local reparameterization, adapted from:
        # https://github.com/Feuermagier/Beyond_Deep_Ensembles/blob/b805d6f9de0bd2e6139237827497a2cb387de11c/src/algos/util.py#L185
        weight_sigma = F.softplus(self.weight_rho)
        bias_sigma = F.softplus(self.bias_rho) if self.use_bias else None

        activation_mean = F.conv2d(x, self.weight_mu, self.bias_mu if self.use_bias else None, self.stride, self.padding, self.dilation)
        # The clamps keep the variance away from zero, where sqrt has an infinite gradient.
        activation_var = F.conv2d(
            (x ** 2).clamp(1e-4),
            (weight_sigma ** 2).clamp(1e-4),
            (bias_sigma ** 2).clamp(1e-4) if self.use_bias else None,
            self.stride, self.padding, self.dilation,
        )
        activation_std = torch.sqrt(activation_var)

        epsilon = torch.randn_like(activation_mean)
        output = activation_mean + activation_std * epsilon

        # No weight sample exists with local reparameterization, so use closed-form
        # expectations under q. Evaluating log q / log p at the mean weights instead
        # omits the sigma^2 / (2 sigma_prior^2) term and leaves sigma unregularised.
        log_q, log_p = self._expected_log_terms(self.weight_mu, weight_sigma, self.weight_prior.sigma1)
        if self.use_bias:
            bias_log_q, bias_log_p = self._expected_log_terms(self.bias_mu, bias_sigma, self.bias_prior.sigma1)
            log_q = log_q + bias_log_q
            log_p = log_p + bias_log_p

        self.log_variational_posterior = log_q
        self.log_prior = log_p

        return output

    @staticmethod
    def _expected_log_terms(mu, sigma, prior_sigma):
        """Return ``(E_q[log q(w)], E_q[log p(w)])`` summed over all elements.

        Here q = N(mu, sigma^2) elementwise and p = N(0, prior_sigma^2).
        """
        prior_sigma = torch.as_tensor(prior_sigma, dtype=mu.dtype, device=mu.device)
        half_log_2pi = 0.5 * float(np.log(2.0 * np.pi))
        # Negative entropy of each Gaussian factor.
        expected_log_q = -torch.log(sigma) - half_log_2pi - 0.5
        # E_q[w^2] = mu^2 + sigma^2 for each factor.
        expected_log_p = -torch.log(prior_sigma) - half_log_2pi - (mu ** 2 + sigma ** 2) / (2 * prior_sigma ** 2)
        return expected_log_q.sum(), expected_log_p.sum()

In [149]:
# Smoke test: a 3x3 kernel without padding maps a 32x32 input to 30x30.
conv_layer = BayesianConvLayer(in_channels=3, out_channels=64, kernel_size=3, use_bias=True)
noise_image = torch.randn(1, 3, 32, 32)

output = conv_layer(noise_image)
expected_shape = (1, 64, 30, 30)
assert tuple(output.shape) == expected_shape, "Output shape is not correct"


# Test on CIFAR10 dataset
We test the Bayesian convolutional layer on the CIFAR-10 dataset. First, we build a standard (non-Bayesian) CNN classifier as a baseline.

In [150]:
import torch
import torchvision
import torchvision.transforms as transforms

# Download CIFAR-10 through torchvision and normalise every RGB channel
# with mean 0.5 and std 0.5 (maps ToTensor's [0, 1] range to [-1, 1]).
channel_stats = (0.5, 0.5, 0.5)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_stats, channel_stats),
])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=2,
)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:09<00:00, 17790273.34it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [151]:
# Setup convolutional neural network
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small LeNet-style CNN mapping 3x32x32 images to 10 class scores."""

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions sharing one 2x2 max-pool.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head; 16 * 5 * 5 is the flattened feature size for 32x32 inputs.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.flatten(start_dim=1)  # keep the batch dimension
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)


net = Net()

In [153]:
import torch.optim as optim

# Classification objective: cross-entropy loss, optimised with SGD + momentum.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=net.parameters(), lr=1e-3, momentum=0.9)



In [154]:
# Train the baseline for two epochs, reporting the mean loss every 2000 mini-batches.
log_every = 2000

for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data  # each batch is [inputs, labels]

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % log_every == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

print('Finished Training')

[1,  2000] loss: 2.197
[1,  4000] loss: 1.833
[1,  6000] loss: 1.685
[1,  8000] loss: 1.591
[1, 10000] loss: 1.535
[1, 12000] loss: 1.491
[2,  2000] loss: 1.440
[2,  4000] loss: 1.398
[2,  6000] loss: 1.382
[2,  8000] loss: 1.344
[2, 10000] loss: 1.332
[2, 12000] loss: 1.313
Finished Training


In [158]:
# Evaluate top-1 accuracy of the baseline on the test set.
correct = 0
total = 0
# Inference only: no gradients are needed.
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        # Predicted class = index of the largest output. Inside no_grad the
        # legacy `.data` attribute is unnecessary.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError('Test loader yielded no samples')
print(f'Accuracy of the network on the {total} test images: {100 * correct / total} %')

Accuracy of the network on the 10000 test images: 54.69 %


Then we build a Bayesian classifier using the Bayesian convolutional layer.