<a href="https://colab.research.google.com/github/JamesMalkin/BayesianANN_tutorial/blob/main/BayesNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Set up dependencies and upload processed data

In [23]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import torch.distributions 

if torch.cuda.is_available():  
    dev = "cuda" 
else:  
    dev = "cpu"
print('device', dev)
device = torch.device(dev)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
trainset = torchvision.datasets.MNIST(root = '.data/trainset', train = True, transform=transform, download=True)
testset = torchvision.datasets.MNIST(root = '.data/testset', train = False, transform=transform, download=True)


torch.manual_seed(0)

device cuda


<torch._C.Generator at 0x7f3560880f70>

Bayesian neural network

In [24]:
BATCHSIZE = 20

#trainset = torch.load('./.data/trainset/MNIST/processed/training.pt')
#trainset, valset = torch.utils.data.random_split(trainset, [50000, 10000])

trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCHSIZE,
                                          shuffle=True, num_workers=0)
TRAINING_INSTANCES = len(trainloader)*BATCHSIZE


#testset = torch.load('./.data/testset/MNIST/processed/test.pt')
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCHSIZE,
                                         shuffle=False, num_workers=0)
TEST_INSTANCES = len(testloader)*BATCHSIZE

                   
class Net(nn.Module):
    def __init__(self): 
        super().__init__()
        self.firingrate = []
        self.linear1 = NetLayer(28*28, 100)
        self.linear2 = NetLayer(100, 10)
        #self.linear3 = NetLayer(100, 10)

    def ent_cost_func(self, sample, sig, mu):
        return (BATCHSIZE/TRAINING_INSTANCES)*torch.sum(torch.distributions.normal.Normal(mu, sig).log_prob(sample))
    
    def prior_cost_func(self, sample):
        return (BATCHSIZE/TRAINING_INSTANCES)*1*torch.sum(sample**2)
  
    def forward(self, x, sample=False, biosample=False, lang=False, noise=False, s=False, batch_idx=False, epoch=False):
        self.prior_loss = 0
        self.ent_loss = 0
        self.like_loss = 0
        
        x = x.view(-1, 784)
        x = F.relu(self.linear1(x, sample))
        x = self.linear2(x, sample)
        #x = F.relu(self.linear3(x, sample))
        x = F.log_softmax(x, dim=1)
        return x
    
    @staticmethod
    def loss(pred_values, true_values):
        criterion = nn.NLLLoss(reduction='mean')
        loss = criterion(pred_values, true_values)*BATCHSIZE*10
        return loss

class NetLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        
        # Weight parameters
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features, device=device,dtype=torch.double).uniform_(-0.1, 0.1))
        self.weight_phi = nn.Parameter((torch.full((out_features, in_features), torch.log(torch.exp((torch.tensor(1e-4, dtype=torch.double)))-1), device=device).double()))
        
        # Bias parameters
        self.bias_mu = nn.Parameter(torch.empty(out_features, device=device, dtype=torch.double).uniform_(-0.1, 0.1)) #(-0.2, 0.2) 
        self.bias_phi = nn.Parameter(torch.full((1, out_features), torch.log(torch.exp((torch.tensor(1e-4, dtype=torch.double)))-1), device=device).double()) #was 0.01 before january
        
    
    def forward(self, input, sample=False):
        weight_sig = F.softplus(self.weight_phi)
        bias_sig = F.softplus(self.bias_phi)
        weight_var  = torch.pow(weight_sig.detach().clone(),2)
        bias_var  = torch.pow(bias_sig.detach().clone(),2)
                          
        weight_dist = torch.distributions.Normal(self.weight_mu, weight_sig)
        bias_dist = torch.distributions.Normal(self.bias_mu, bias_sig)
        
        if sample:
          weight = weight_dist.rsample()
          bias = bias_dist.rsample()
        else:
          weight = self.weight_mu
          bias = self.bias_mu

        if sample:
            net.ent_loss += net.ent_cost_func(self.weight_mu, weight_sig, weight).sum()
            net.ent_loss += net.ent_cost_func(self.bias_mu, bias_sig, bias).sum()
        net.prior_loss += net.prior_cost_func(weight).sum()
        net.prior_loss += net.prior_cost_func(bias).sum()
                
        return F.linear(input, weight, bias)

Train and test function

In [27]:
TRAINING_SAMPLING = 1
TEST_SAMPLING = 10

def train(sample=False):
    running_loss = 0
    running_like_loss = 0
    running_ent_loss = 0
    running_prior_loss = 0

    mode_performance_list = [] 
    mode_loss_list = []
    sample_performance_list = [] 
    sample_loss_list = []
    expected_performance_list = []
    expected_loss_list = []

    for epoch in range(60):
        for batch_idx, (data, target) in enumerate(trainloader):
            if batch_idx <= len(trainloader):
                net.train()
                data = data.to(device)
                data = data.type(torch.double)
                target = target.to(device)
                target = target.type(torch.long)
                net.zero_grad()
                like_loss = 0
                ent_loss = 0
                prior_loss = 0

                for j in range(TRAINING_SAMPLING):
                    preds = net(data, sample=sample)
                    like_loss += net.loss(preds, target)
                
                loss = like_loss + net.prior_loss + net.ent_loss
                loss = loss/TRAINING_SAMPLING
                loss.backward()
                            
                running_loss += loss.item()/TRAINING_SAMPLING
                running_like_loss += like_loss.item()/TRAINING_SAMPLING
                running_ent_loss += net.ent_loss.item()/TRAINING_SAMPLING
                running_prior_loss += net.prior_loss.item()/TRAINING_SAMPLING
               
                if net.training:
                    optimiser.step()

                if batch_idx % TRAINING_INSTANCES/BATCHSIZE == (TRAINING_INSTANCES/BATCHSIZE)-1: # Print every so mini-batches (epoch)
                    print('[Epoch-{}, Batch-{} total loss= {}, like_loss={}, ent_loss= {}, prior_loss= {}'.format(epoch + 1, batch_idx + 1, running_loss / (3000*BATCHSIZE), running_like_loss/(3000*BATCHSIZE), running_ent_loss/(3000*BATCHSIZE), running_prior_loss/(3000*BATCHSIZE)))
                    running_loss = 0.0
                    running_like_loss = 0.0
                    running_prior_loss = 0.0
                    running_ent_loss = 0.0
    
                    mode_performance, mode_loss = test(sample=False, exp_accuracy=False)
                    sample_performance, sample_loss = test(sample=True, exp_accuracy=False)
                    expected_performance, expected_loss = test(sample=True, exp_accuracy=True)

                    mode_performance_list.append(mode_performance)
                    mode_loss_list.append(mode_loss)
                    sample_performance_list.append(sample_performance)
                    sample_loss_list.append(sample_loss)
                    expected_performance_list.append(expected_performance)
                    expected_loss_list.append(expected_loss)
    
    return mode_performance_list, mode_loss_list, sample_performance_list, sample_loss_list, expected_performance_list, expected_loss_list

def test(sample=False, classes=10, exp_accuracy=False):
    correct = 0
    loss = 0
    net.eval()
    if exp_accuracy:
        with torch.no_grad():
            for data in testloader:
                preds = torch.zeros(BATCHSIZE, 10).to(device)
                class_preds = torch.zeros(BATCHSIZE, 1)
                for n in range(TEST_SAMPLING):
                    images, labels = data
                    images = images.to(device)
                    images = images.type(torch.double) 
                    labels = labels.to(device)
                    labels = labels.type(torch.long)
                    preds += net(images, sample=sample)
                    
                preds /= TEST_SAMPLING
                loss += net.loss(preds, labels)
                class_preds = preds.max(1, keepdim=True)[1]
                correct += class_preds.eq(labels.view(-1, 1)).sum().item()
        accuracy = 100 * correct / (TEST_INSTANCES)
        loss /= TEST_INSTANCES
        print('Expected Accuracy', np.round(accuracy,3))
        print('Expected Loss', np.round(loss.item(),3))
    else:
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images = images.to(device) 
                labels = labels.to(device)
                images = images.type(torch.double) 
                labels = labels.type(torch.long)
                preds = net(images, sample=sample)
                loss += net.loss(preds, labels)
                class_preds = preds.max(1, keepdim=True)[1]
                correct += class_preds.eq(labels.view(-1, 1)).sum().item()
            
        loss = loss / TEST_INSTANCES
        accuracy = 100 * correct / TEST_INSTANCES
        if sample:
          print('Sample Accuracy', np.round(accuracy,3))
          print('Sample Loss', np.round(loss.item(),3))
        else:
          print('Mode Accuracy', np.round(accuracy,3))
          print('Mode Loss', np.round(loss.item(),3))
    return accuracy, loss


Run simulation


In [None]:
net = Net()
optimiser = optim.Adam(net.parameters(), lr=0.001)
mode_performance_list, mode_loss_list, sample_performance_list, sample_loss_list, expected_performance_list, expected_loss_list = train(sample=True)

np.save('mode_performance_list', np.array(mode_performance_list))
np.save('mode_loss_list', np.array(mode_loss_list))
np.save('sample_performance_list', np.array(sample_performance_list))
np.save('mode_loss_list', np.array(mode_loss_list))
np.save('expected_performance_list', np.array(expected_performance_list))
np.save('expected_loss_list', np.array(expected_loss_list))

Test the best estimate for the network, a noisy network and the average vote of a noisy network. These correspond to the mode of the posterior, single samples from the posterior, and the average of the predictive posterior, respectively.

In [22]:
mode_performance, mode_loss = test(sample=False, exp_accuracy=False)
sample_performance, sample_loss = test(sample=True, exp_accuracy=False)
expected_performance, expected_loss = test(sample=True, exp_accuracy=True)

Sample Accuracy 1869.4
Sample Loss 42.309
Sample Accuracy 1869.2
Sample Loss 42.315
Expected Accuracy 1869.4
Expected Loss 42.312
