# Demo - Convert to Bayesian Neural Network

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import torchbnn as bnn
from torchbnn.utils import nonbayes_to_bayes, bayes_to_nonbayes

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

## 2. Define Model

In [3]:
class CNN(nn.Module):
    """Small convolutional network used as the conversion example.

    Layout:
      * ``conv_layer``: Conv2d(1 -> 3 channels, 3x3 kernel), ReLU,
        2x2 max-pooling with stride 2.
      * ``fc_layer``: Linear(12 -> 6), ReLU, Linear(6 -> 2).

    The first linear layer takes 3 * 2 * 2 = 12 features, i.e. a 3-channel
    2x2 map after pooling, which is what a single-channel 6x6 input produces.
    """

    def __init__(self):
        super().__init__()

        self.conv_layer = nn.Sequential(
            nn.Conv2d(1, 3, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(3 * 2 * 2, 3 * 2),
            nn.ReLU(),
            nn.Linear(3 * 2, 2),
        )

    def forward(self, x):
        """Run the network on ``x`` and return one row of 2 outputs per sample.

        Args:
            x: input of shape (batch, 1, 6, 6); a single unbatched sample of
                shape (1, 6, 6) is treated as a batch of one.

        Returns:
            Tensor of shape (batch, 2).

        Raises:
            RuntimeError: if the spatial size does not yield 12 features per
                sample (raised by the first linear layer).
        """
        # Promote a lone (C, H, W) sample to a batch of one so the per-sample
        # flatten below keeps producing a (1, 2) result for it.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        out = self.conv_layer(x)
        # Flatten everything except the batch dimension. Reshaping with
        # view(-1, 12) would let the batch dimension absorb a feature-count
        # mismatch and silently return the wrong number of rows.
        out = out.flatten(start_dim=1)
        out = self.fc_layer(out)

        return out

In [4]:
# Build the ordinary (non-Bayesian) network that the conversions operate on.
model = CNN()

In [5]:
# Show the architecture before any conversion: plain Conv2d / Linear layers.
model

CNN(
  (conv_layer): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): Linear(in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=2, bias=True)
  )
)

## 3. Convert Model

### 3.1. Nonbayes to Bayes

In [6]:
# With inplace=False the conversion returns a whole new network and leaves
# `model` as it was, so about twice the memory is needed.
# inplace=True is the default.
nonbayes_to_bayes(model, prior_mu=0, prior_sigma=0.1, inplace=False)

CNN(
  (conv_layer): Sequential(
    (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)
  )
)

In [7]:
# `model` still holds the plain Conv2d / Linear layers after the
# inplace=False conversion.
model

CNN(
  (conv_layer): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): Linear(in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=2, bias=True)
  )
)

In [8]:
# With inplace=True the input network itself is modified: its layers are
# replaced by their Bayesian counterparts.
nonbayes_to_bayes(model, prior_mu=0, prior_sigma=0.1, inplace=True)

CNN(
  (conv_layer): Sequential(
    (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)
  )
)

In [9]:
# `model` now contains BayesConv2d / BayesLinear layers.
model

CNN(
  (conv_layer): Sequential(
    (0): BayesConv2d(0, 0.1, 1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=6, out_features=2, bias=True)
  )
)

### 3.2. Bayes to Nonbayes

In [10]:
# Convert the Bayesian layers back to ordinary ones. No `inplace` argument is
# passed, and the following cell shows that `model` itself was changed.
bayes_to_nonbayes(model)

CNN(
  (conv_layer): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): Linear(in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=2, bias=True)
  )
)

In [11]:
# `model` is back to plain Conv2d / Linear layers.
model

CNN(
  (conv_layer): Sequential(
    (0): Conv2d(1, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc_layer): Sequential(
    (0): Linear(in_features=12, out_features=6, bias=True)
    (1): ReLU()
    (2): Linear(in_features=6, out_features=2, bias=True)
  )
)