<a href="https://colab.research.google.com/github/JamesMalkin/BayesianANN_tutorial/blob/main/BayesNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Set up dependencies and load the MNIST data (downloaded by torchvision on first run)

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import torch.distributions 



# Pick the compute device: the first CUDA GPU when present, otherwise the CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)

# Convert PIL images to tensors, then standardise with the usual MNIST pixel
# mean / std so inputs are roughly zero-mean, unit-variance.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Train and test splits live under separate roots; download=True fetches the
# files only if they are not already present on disk.
trainset, testset = (
    torchvision.datasets.MNIST(root=data_root, train=is_train, transform=transform, download=True)
    for data_root, is_train in (('.data/trainset', True), ('.data/testset', False))
)

# Fix the global RNG so weight initialisation and shuffling are reproducible.
torch.manual_seed(0)

Bayesian neural network

In [None]:
# Mini-batch size, shared by both loaders and by the loss scaling in Net.
BATCHSIZE = 20

# Alternative kept for reference: load the processed tensors directly and hold
# out a validation split from the training data.
# trainset = torch.load('./.data/trainset/MNIST/processed/training.pt')
# trainset, valset = torch.utils.data.random_split(trainset, [50000, 10000])
# testset = torch.load('./.data/testset/MNIST/processed/test.pt')

# Shuffle only the training data; evaluation order stays fixed.
trainloader, testloader = (
    torch.utils.data.DataLoader(dataset, batch_size=BATCHSIZE, shuffle=reshuffle, num_workers=0)
    for dataset, reshuffle in ((trainset, True), (testset, False))
)

# NOTE(review): len(DataLoader) is the number of *batches* (ceil(len(dataset) / BATCHSIZE)),
# not the number of examples, despite the *_INSTANCES names. Net scales its
# complexity costs by BATCHSIZE / TRAINING_INSTANCES, which is therefore roughly
# BATCHSIZE**2 / len(trainset) -- confirm that is the intended weighting.
TRAINING_INSTANCES, TEST_INSTANCES = len(trainloader), len(testloader)


class NetLayer(nn.Module):
    """Fully connected layer with a factorised Gaussian posterior over weights and biases.

    Every weight and bias has a mean (``*_mu``) and an unconstrained scale
    parameter (``*_phi``); the standard deviation is ``softplus(phi)``. Each
    forward pass draws a fresh reparameterised sample (``rsample``), so gradients
    reach both ``mu`` and ``phi``. Parameters are float64 on the module-level
    ``device``.
    """

    def __init__(self, in_features, out_features, init_sigma=1e-4, init_bound=0.1):
        """
        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            init_sigma: initial posterior standard deviation of every weight/bias.
            init_bound: means are initialised from U(-init_bound, init_bound).
        """
        super().__init__()

        # Inverse softplus, so softplus(phi) == init_sigma at initialisation.
        # expm1 avoids the cancellation in exp(x) - 1 for small x.
        phi_init = math.log(math.expm1(init_sigma))

        # Weight parameters
        self.weight_mu = nn.Parameter(
            torch.empty(out_features, in_features, device=device, dtype=torch.double)
            .uniform_(-init_bound, init_bound))
        self.weight_phi = nn.Parameter(
            torch.full((out_features, in_features), phi_init, device=device, dtype=torch.double))

        # Bias parameters. bias_phi keeps its historical (1, out_features) shape so
        # existing state_dicts still load; forward() reshapes it to match bias_mu.
        self.bias_mu = nn.Parameter(
            torch.empty(out_features, device=device, dtype=torch.double)
            .uniform_(-init_bound, init_bound))
        self.bias_phi = nn.Parameter(
            torch.full((1, out_features), phi_init, device=device, dtype=torch.double))

    def forward(self, input, sample=False, net=None):
        """Apply the layer using one posterior sample of its weights and bias.

        Args:
            input: tensor whose last dimension is ``in_features``. It is cast to
                the parameters' dtype and device (float64 on ``device``), so
                float32 image tensors can be passed directly.
            sample: if True, add this layer's entropy and prior costs for the
                drawn sample to ``net.ent_loss`` and ``net.prior_loss``.
            net: the owning ``Net``; provides the cost functions and the running
                accumulators. Required when ``sample`` is True.

        Returns:
            ``input @ weight.T + bias`` for the sampled ``weight`` and ``bias``.

        Raises:
            ValueError: if ``sample`` is True but no ``net`` is given.
        """
        weight_sig = F.softplus(self.weight_phi)
        # Flatten (1, out) -> (out,) so the sampled bias is 1-D, as F.linear expects.
        bias_sig = F.softplus(self.bias_phi).reshape(self.bias_mu.shape)

        weight = torch.distributions.Normal(self.weight_mu, weight_sig).rsample()
        bias = torch.distributions.Normal(self.bias_mu, bias_sig).rsample()

        if sample:
            if net is None:
                raise ValueError("sample=True requires the owning Net passed as net=...")
            # Accumulate on the Net *instance* (Net.forward resets these each call).
            net.ent_loss = (net.ent_loss
                            + net.ent_cost_func(weight, weight_sig, self.weight_mu)
                            + net.ent_cost_func(bias, bias_sig, self.bias_mu))
            net.prior_loss = (net.prior_loss
                              + net.prior_cost_func(weight)
                              + net.prior_cost_func(bias))

        return F.linear(input.to(self.weight_mu), weight, bias)


class Net(nn.Module):
    """Three-layer Bayesian MLP (784 -> 100 -> 100 -> 10) returning class log-probabilities.

    A forward pass with ``sample=True`` also sets ``self.ent_loss`` (sum of
    log q(w) of the drawn weights) and ``self.prior_loss`` (sum of squared drawn
    weights), both scaled by BATCHSIZE / TRAINING_INSTANCES, for use as the
    complexity part of the training objective.
    """

    def __init__(self, power=2, scale=1):
        super().__init__()
        # NOTE(review): firingrate, p and s are stored but not read in this class.
        self.firingrate = []
        self.p = power
        self.s = torch.tensor(scale, device=device)
        self.linear1 = NetLayer(28*28, 100)
        self.linear2 = NetLayer(100, 100)
        self.linear3 = NetLayer(100, 10)

    def ent_cost_func(self, sample, sig, mu):
        """Scaled sum of log N(sample; mu, sig) over all elements."""
        return (BATCHSIZE/TRAINING_INSTANCES)*torch.sum(torch.distributions.Normal(mu, sig).log_prob(sample))

    def prior_cost_func(self, sample):
        """Scaled sum of squares of ``sample`` (an L2 / Gaussian-prior penalty)."""
        return (BATCHSIZE/TRAINING_INSTANCES)*1*torch.sum(sample**2)

    def forward(self, x, sample=False, biosample=False, lang=False, noise=False, s=False, batch_idx=False, epoch=False):
        """Classify a batch of 28x28 images.

        Args:
            x: tensor with 784 elements per example (e.g. (N, 1, 28, 28)).
            sample: if True, populate ``self.ent_loss`` and ``self.prior_loss``.
            biosample, lang, noise, s, batch_idx, epoch: accepted for backward
                compatibility; currently ignored.

        Returns:
            (N, 10) tensor of log-probabilities.
        """
        self.prior_loss = 0
        self.ent_loss = 0
        self.like_loss = 0

        x = x.reshape(-1, 784)
        # ReLU after each hidden layer; the output layer stays linear so that
        # log_softmax sees unclamped logits.
        x = F.relu(self.linear1(x, sample, net=self))
        x = F.relu(self.linear2(x, sample, net=self))
        x = self.linear3(x, sample, net=self)
        return F.log_softmax(x, dim=1)

    @staticmethod
    def loss(pred_values, true_values):
        """Data-fit term: mean NLL times BATCHSIZE (i.e. a per-batch sum) times 10.

        NOTE(review): the extra factor of 10 is not explained in this file.
        """
        criterion = nn.NLLLoss(reduction='mean')
        loss = criterion(pred_values, true_values)*BATCHSIZE*10
        return loss