In [1]:
# See https://github.com/nitarshan/bayes-by-backprop/blob/master/Weight%20Uncertainty%20in%20Neural%20Networks.ipynb
# The network structure is basically copied, while the learning code has been written mostly by me

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import itertools

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    print("No cuda device available; using the CPU")
    device = torch.device("cpu")

No cuda device available; using the CPU


In [3]:
batch_size = 5

transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize(0,5, 0.5)
])

trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=True, num_workers=2)

In [4]:
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# dataiter = iter(trainloader)
# images, labels = dataiter.next()
# imshow(torchvision.utils.make_grid(images))

In [5]:
# Multidimensional gaussian distribution; for sampling weights and biases of a nn layer

class GaussianParameters:
    def __init__(self, mu: torch.Tensor, rho: torch.Tensor):
        self.mu, self.rho = mu, rho
        self.distribution = torch.distributions.Normal(0, 1)

    @property
    def sigma(self):
        return torch.log1p(torch.exp(self.rho))

    def sample(self):
        epsilon = self.distribution.sample(self.rho.size()).to(device)
        return self.mu + self.sigma * epsilon

    def log_prob(self, value):
        return (-np.log(np.sqrt(2 * np.pi)) - torch.log(self.sigma) - ((value - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()

In [6]:
# Weighted sum of two gaussian distributions
class GaussianMixture:
    def __init__(self, pi, sigma1, sigma2):
        self.pi, self.sigma1, self.sigma2 = pi, sigma1, sigma2
        self.gaussian1 = torch.distributions.Normal(0, sigma1)
        self.gaussian2 = torch.distributions.Normal(0, sigma2)

    def log_prob(self, value):
        p1 = torch.exp(self.gaussian1.log_prob(value))
        p2 = torch.exp(self.gaussian2.log_prob(value))
        return torch.log((self.pi * p1 + (1 - self.pi) * p2)).sum()

In [7]:
# Single layer of a BBB network. In essence this is a normal fc layer, but the weights and biases
# are drawn from a gaussian distribution and the corresponding means and variances are learned

# This class also stores the prior and variational posterior of the weights and biases of the last
# forward pass through the layer

class BayesianLayer(nn.Module):
    def __init__(self, in_features: int, out_features: int, weight_prior: GaussianMixture, bias_prior: GaussianMixture):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight_prior, self.bias_prior = weight_prior, bias_prior
        self.log_prior, self.log_posterior = 0, 0

        # Weights
        self.weight_mu = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(torch.Tensor(out_features, in_features).uniform_(-5, -4))
        self.weight = GaussianParameters(self.weight_mu, self.weight_rho)

        # Biases
        self.bias_mu = nn.Parameter(torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(torch.Tensor(out_features).uniform_(-5, -4))
        self.bias = GaussianParameters(self.bias_mu, self.bias_rho)

    def forward(self, input):
        if self.training:
            weight = self.weight.sample()
            bias = self.bias.sample()
            # Log probability of the drawn parameters given the prior
            self.log_prior = self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            # Log probability of the drawn parameters given the parameter distribution
            self.log_posterior = self.weight.log_prob(weight) + self.bias.log_prob(bias)
        else:
            # If we are not in training mode use deterministic parameters (i.e. the means)
            weight = self.weight.mu
            bias = self.bias.mu
        return F.relu(F.linear(input, weight, bias))


In [8]:
# BBB network with three layers (and currently fixed input/hidden/output dimensions). Stores also
# the prior and variational posterior of the last forward pass

class BayesianNetwork(nn.Module):
    def __init__(self, weight_prior: GaussianMixture, bias_prior: GaussianMixture):
        super().__init__()
        self.first = BayesianLayer(28*28, 400, weight_prior, bias_prior)
        self.second = BayesianLayer(400, 400, weight_prior, bias_prior)
        self.third = BayesianLayer(400, 10, weight_prior, bias_prior)
    
    @property
    def log_prior(self):
        return self.first.log_prior + self.second.log_prior + self.third.log_prior

    @property
    def log_posterior(self):
        return self.first.log_posterior + self.second.log_posterior + self.third.log_posterior

    def forward(self, input: torch.Tensor):
        input = self.first(input)
        input = self.second(input)
        input = self.third(input)
        return F.log_softmax(input, dim=1) # TODO

In [9]:
def train(net: BayesianNetwork, optimizer: torch.optim.Optimizer, loader, batch_count, monte_carlo_samples=1):
    net.train()
    for i, (data, target) in enumerate(loader):
        total_loss = 0
        batch_size = data.shape[0]

        data = torch.flatten(data, start_dim=1).to(device)
        target = target.to(device)
        net.zero_grad()
        for sample, target in zip(data, target):
            input = torch.unsqueeze(sample, 0)
            target = torch.unsqueeze(target, 0)
            output = net(input)
            log_prior = net.log_prior
            log_posterior = net.log_posterior
            loss = (log_posterior - log_prior) / batch_count + F.nll_loss(output, target, reduction="sum")
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        if i % 50 == 0:
            print(f"Completed minibatch #{i}; loss is {total_loss / batch_size}")


In [10]:
pi = 0.5 # 0.25, 0.5, 0.75
sigma1 = np.exp(-1) # 0, 1, 2
sigma2 = np.exp(-7) # 6, 7, 8
prior = GaussianMixture(pi, sigma1, sigma2)

net = BayesianNetwork(prior, prior).to(device)
#optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
optimizer = torch.optim.Adam(net.parameters())
train(net, optimizer, trainloader, len(trainloader))
#torch.save(net.state_dict(), "models/bbb_mnist_adam.pth")

Completed minibatch #0; loss is -11.086641311645508


KeyboardInterrupt: 

In [None]:
with torch.no_grad():
    net.eval()
    corrects = 0
    samples = 0
    for data, target in testloader:
        data = torch.flatten(data, start_dim=1).to(device)
        target = target.to(device)
        outputs = net(data)
        preds = torch.argmax(outputs, dim=1)
        corrects += ((preds - target) == 0).sum()
        samples += len(data)
    print(f"Test accuracy: {corrects / samples}")

Test accuracy: 0.9185999631881714


In [None]:
net2 = BayesianNetwork(prior, prior).to(device)
net2.load_state_dict(torch.load("models/bbb_mnist_adam.pth"))

with torch.no_grad():
    net2.train(False)
    corrects = 0
    samples = 0
    for data, target in testloader:
        data = torch.flatten(data, start_dim=1).to(device)
        target = target.to(device)
        
        outputs = net2(data)
        preds = torch.argmax(outputs, dim=1)
        corrects += ((preds - target) == 0).sum()
        samples += len(data)
    print(f"Test accuracy: {corrects / samples}")

Test accuracy: 0.9228999614715576
