<a href="https://colab.research.google.com/github/asolnn2a8/Bayesian-Deep-Learning/blob/main/Bayesian_CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pyro-ppl



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from IPython import display
import os
from PIL import Image
from torch.utils.data.dataset import Dataset

## Load MNIST dataset

In [3]:
# Build the MNIST train/test loaders. Images are only converted to tensors
# (no normalisation); both loaders use batches of 128 and shuffle.
to_tensor = transforms.Compose([transforms.ToTensor(),])

mnist_train = datasets.MNIST('mnist-data/', train=True, download=True, transform=to_tensor)
# The test split is read from the same directory and is not downloaded here.
mnist_test = datasets.MNIST('mnist-data/', train=False, transform=to_tensor)

train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=128, shuffle=True)

## Utils

In [4]:
# Shared activation modules used by the guide (softplus keeps scales positive)
# and by the model (log-softmax over dim 1 of the network output).
softplus = nn.Softplus()
log_softmax = torch.nn.LogSoftmax(dim=1)

In [19]:
def imshow(img):
    """Display a tensor as a small (1x1 inch) grayscale image.

    The tensor is rescaled as ``img / 2 + 0.5`` before plotting, converted to a
    NumPy array and drawn with nearest-neighbour interpolation.

    Args:
        img: tensor to display; it must support ``.numpy()``.
    """
    # NOTE(review): this rescaling undoes a mean-0.5/std-0.5 normalisation, but
    # the loaders in this notebook only apply ToTensor -- confirm it is wanted.
    pixels = (img / 2 + 0.5).numpy()

    fig, ax = plt.subplots(figsize=(1, 1))
    ax.imshow(pixels, cmap='gray', interpolation='nearest')
    plt.show()

def give_uncertainities(x, num_samples=100, network=None):
    """Collect class log-probabilities from networks sampled via ``guide``.

    Args:
        x: input batch, passed unchanged to every sampled network.
        num_samples: number of networks to draw from ``guide``.
        network: module whose weights ``guide`` lifts to random variables.
            Defaults to the module-level ``net`` (the original behaviour);
            pass it explicitly to avoid depending on whichever model ``net``
            currently names, since this notebook rebinds it.

    Returns:
        NumPy array holding, for each draw, the log-softmax (over dim 1) of the
        sampled network's output; the first axis indexes the draws.
    """
    if network is None:
        # Backward-compatible fallback: resolved at call time, like the original.
        network = net

    yhats = []
    for _ in range(num_samples):
        # Sample, evaluate and discard one network at a time so that only a
        # single sampled copy of the weights is alive at any moment.
        sampled_model = guide(network, None, None)
        log_probs = F.log_softmax(sampled_model(x).detach(), 1)
        yhats.append(log_probs.numpy())
    return np.asarray(yhats)


def predict(net, x, num_samples=10):
    """Predict class indices by averaging the outputs of sampled networks.

    Args:
        net: module whose weights ``guide`` lifts to random variables.
        x: input batch, passed unchanged to every sampled network.
        num_samples: number of networks to draw from ``guide``.

    Returns:
        NumPy array with the argmax (over axis 1) of the mean network output.
    """
    # Draw all networks first, then evaluate each of them on the batch.
    drawn_networks = [guide(net, None, None) for _ in range(num_samples)]

    outputs = []
    for sampled in drawn_networks:
        outputs.append(sampled(x).data)

    averaged = torch.stack(outputs).mean(dim=0)
    return np.argmax(averaged.numpy(), axis=1)


# Bayesian CNN for MNIST classification

In [7]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The flatten step hard-codes 320 features per sample, which is what the two
    5x5 convolutions and 2x2 poolings yield for a 1x28x28 input (20 * 4 * 4).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Classifier head
        self.fc1 = nn.Linear(320, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the 10 unnormalised class scores for each input image."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        flat = features.view(-1, 320)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


net = CNN()

In [6]:
def model(net, x_data, y_data):
    """
    Generative model: the architecture of "net" with random weights.

    Every entry of ``net.state_dict()`` receives an element-wise Normal(0, 1)
    prior; a network is sampled from those priors and its log-softmax output
    parameterises a Categorical likelihood observed at ``y_data``.
    """
    # Standard-normal prior with the same shape as each state_dict entry
    priors = {
        name: Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))
        for name, param in net.state_dict().items()
    }

    # Lift the module's parameters to random variables and draw one network
    # (calling the lifted module samples all weights and biases).
    sampled_net = pyro.random_module("module", net, priors)()

    class_log_probs = log_softmax(sampled_net(x_data))

    # Condition the categorical likelihood on the observed labels
    pyro.sample("obs", Categorical(logits=class_log_probs), obs=y_data)


def guide(net, x_data, y_data):
    """
    Variational family approximating the posterior p(w|X, Y): an independent
    Gaussian for every entry of ``net.state_dict()``.

    For each entry two Pyro parameters are registered, ``<name>_mu`` and
    ``<name>_sigma``, both initialised from a standard normal; the scale used is
    the softplus of the ``_sigma`` parameter. ``x_data`` and ``y_data`` are not
    used. Returns one network sampled from these distributions.
    """
    variational = {}
    for name, param in net.state_dict().items():
        loc = pyro.param(name+'_mu', torch.randn_like(param))
        raw_scale = pyro.param(name+'_sigma', torch.randn_like(param))
        # softplus maps the unconstrained parameter to a positive scale
        variational[name] = Normal(loc=loc, scale=softplus(raw_scale))

    return pyro.random_module("module", net, variational)()

In [10]:
# Train the Bayesian CNN with stochastic variational inference.

# Pyro keeps variational parameters in a global store that survives across
# cells; start from an empty store so re-running this cell trains from scratch
# instead of silently continuing from earlier parameters.
pyro.clear_param_store()

# Named `optimizer` so it does not shadow the `torch.optim as optim` import.
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

num_iterations = 5
# Number of training examples, used to report the epoch loss per example.
normalizer_train = len(train_loader.dataset)

for j in range(num_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step; the arguments are
        # forwarded to both model and guide as (net, x_data, y_data)
        loss += svi.step(net, data[0], data[1])
    total_epoch_loss_train = loss / normalizer_train

    print("Epoch {}/{} - Loss: {:.2f} ".format(j,
                                               num_iterations,
                                               total_epoch_loss_train))



Epoch 0/5 - Loss: 129.96 
Epoch 1/5 - Loss: 43.86 
Epoch 2/5 - Loss: 29.35 
Epoch 3/5 - Loss: 22.85 
Epoch 4/5 - Loss: 20.79 


### Prediction

In [23]:
# Test-set accuracy of the CNN, averaging `num_samples` sampled networks per batch.
num_samples = 100

print('Prediction when network is forced to predict')
correct = 0
total = 0
for images, labels in test_loader:
    predicted = predict(net, images, num_samples)
    hits = predicted == labels.numpy()
    correct += hits.sum().item()
    total += labels.size(0)
print("accuracy: %d %%" % (100 * correct / total))

Prediction when network is forced to predict




accuracy: 10 %


# Bayesian MLP 

## Model

#### NN architecture

In [9]:
class NN(nn.Module):
    """
    Multi-layer perceptron with a single ReLU hidden layer.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output scores.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the unnormalised output scores for the batch ``x``."""
        hidden = F.relu(self.fc1(x))
        return self.out(hidden)


# Create model: flattened 28x28 inputs, 1024 hidden units, 10 outputs
net = NN(28*28, 1024, 10)

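#### Bayesian settings

The `model` and `guide` functions defined in the Bayesian CNN section are reused here: both take the network as their first argument, so they place the same Gaussian priors and Gaussian variational posterior on the weights of the MLP.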

## Training

In [12]:
# Train the Bayesian MLP with stochastic variational inference.

# The guide registers parameters as "<state_dict name>_mu" / "_sigma". The CNN
# trained earlier in this notebook also has layers called fc1, so its entries
# (with different shapes) would be picked up for this network unless the global
# parameter store is emptied first.
pyro.clear_param_store()

# Named `optimizer` so it does not shadow the `torch.optim as optim` import.
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

num_iterations = 5
# Number of training examples, used to report the epoch loss per example.
normalizer_train = len(train_loader.dataset)

for j in range(num_iterations):
    loss = 0
    for batch_id, data in enumerate(train_loader):
        # calculate the loss and take a gradient step; model and guide both
        # take (net, x_data, y_data), so the network must be passed first.
        # Images are flattened to 784 features for the MLP.
        loss += svi.step(net, data[0].view(-1, 28*28), data[1])
    total_epoch_loss_train = loss / normalizer_train

    print("Epoch {}/{} - Loss: {:.2f} ".format(j,
                                               num_iterations,
                                               total_epoch_loss_train))



Epoch 0/5 - Loss: 2055.53 
Epoch 1/5 - Loss: 354.31 
Epoch 2/5 - Loss: 152.26 
Epoch 3/5 - Loss: 109.38 
Epoch 4/5 - Loss: 95.57 


## Prediction

In [13]:
# Test-set accuracy of the MLP, averaging `num_samples` sampled networks per batch.
num_samples = 10

print('Prediction when network is forced to predict')
correct = 0
total = 0
for j, data in enumerate(test_loader):
    images, labels = data
    # predict is defined as predict(net, x, num_samples): pass the network
    # explicitly and actually use the num_samples set above. Images are
    # flattened to 784 features to match the MLP input.
    predicted = predict(net, images.view(-1, 28*28), num_samples)
    total += labels.size(0)
    correct += (predicted == labels.numpy()).sum().item()
print("accuracy: %d %%" % (100 * correct / total))

Prediction when network is forced to predict




accuracy: 89 %
