In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from matplotlib import pyplot as plt
from tqdm import trange
from torch.distributions.normal import Normal
from bnn import BayesBaseModule, BayesConv2d, BayesLinear, BayesModel

## Distributions and model

In [2]:
# Standard normal N(0, 1); passed below as both the weight and the bias distribution.
distr = Normal(loc=torch.tensor(0.0), scale=torch.tensor(1.0))

In [3]:
class NormalModel(nn.Module):
    """Small Bayesian CNN built from `bnn` layers, sized for 1x28x28 inputs.

    Layout: BayesConv2d(1->64, 3x3) -> ReLU -> MaxPool(2)
            -> BayesConv2d(64->64, 3x3) -> ReLU -> MaxPool(2)
            -> flatten -> BayesLinear(FLATTENED_FEATURES -> 10).

    Args:
        weight_distribution: distribution object handed to every Bayes layer
            for its weights.
        bias_distribution: distribution object handed to every Bayes layer
            for its biases.
    """

    # Flattened size after the conv stack for a 28x28 input:
    # 28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5, i.e. 64 * 5 * 5 = 1600.
    # Assumes BayesConv2d defaults to stride 1 / no padding like nn.Conv2d — TODO confirm.
    FLATTENED_FEATURES = 64 * 5 * 5

    def __init__(self,
                 weight_distribution,
                 bias_distribution):
        super().__init__()
        self.conv_0 = BayesConv2d(weight_distribution, bias_distribution,
                                  in_channels=1, out_channels=64, kernel_size=3)
        self.conv_1 = BayesConv2d(weight_distribution, bias_distribution,
                                  in_channels=64, out_channels=64, kernel_size=3)
        self.relu = nn.ReLU()
        self.pooling = nn.MaxPool2d(kernel_size=2)
        self.fc = BayesLinear(weight_distribution, bias_distribution,
                              in_features=self.FLATTENED_FEATURES, out_features=10)
        self._weight_distribution = weight_distribution
        self._bias_distribution = bias_distribution

    def weight_distribution(self):
        """Return the distribution passed to the constructor for weights."""
        return self._weight_distribution

    def bias_distribution(self):
        """Return the distribution passed to the constructor for biases."""
        return self._bias_distribution

    def weight(self):
        # NOTE(review): this returns the bound method `self.weight` itself, not any
        # layer weights. Kept unchanged for backward compatibility — confirm intent.
        return self.weight

    def log_prior(self):
        """Sum the `log_prior` of every BayesLinear / BayesConv2d submodule.

        Returns:
            The accumulated log-prior; ``0`` when the model has no Bayes layers.
        """
        log_p = 0
        for m in self.modules():
            if isinstance(m, (BayesLinear, BayesConv2d)):
                layer_log_p = m.log_prior
                # The `bnn` layer API is not visible here: accept both an
                # attribute/property and a zero-argument method.
                if callable(layer_log_p):
                    layer_log_p = layer_log_p()
                # Out-of-place add so a tensor returned by a layer is never mutated.
                log_p = log_p + layer_log_p
        return log_p

    def forward(self, x):
        """Run the conv stack and the linear classifier; returns 10 logits per sample."""
        x = self.pooling(self.relu(self.conv_0(x)))
        x = self.pooling(self.relu(self.conv_1(x)))
        if x.dim() == 4:
            # Batched input: keep the batch axis explicit so an unexpected spatial
            # size makes `fc` fail loudly instead of silently changing batch size.
            x = torch.flatten(x, start_dim=1)
        else:
            x = x.view(-1, self.FLATTENED_FEATURES)
        return self.fc(x)

In [4]:
# Seed torch's global RNG before building the model so that any random
# initialisation inside the Bayes layers (not visible here) is reproducible.
torch.manual_seed(0)
mdl = NormalModel(weight_distribution=distr, bias_distribution=distr)

## Dataset

In [5]:
# MNIST training images. (0.1307,), (0.3081,) are presumably the MNIST pixel
# mean/std used for normalisation.
dataset = torchvision.datasets.MNIST(
    './files/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Hold out a fixed-size validation set. The seeded generator makes the split
# identical across kernel restarts. The train size is derived from the dataset
# length rather than hard-coded.
N_VAL = 10000
SPLIT_SEED = 0
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [len(dataset) - N_VAL, N_VAL],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [6]:
# NormalModel's classifier expects 64 * 5 * 5 = 1600 features, which only holds
# for 1x28x28 inputs, so assert the assumption instead of just eyeballing it.
sample_image = dataset[0][0]
assert sample_image.shape == (1, 28, 28), sample_image.shape
sample_image.shape

torch.Size([1, 28, 28])

## Train

In [7]:
# Hand the model and both data splits to the bnn BayesModel wrapper.
# Note: the held-out validation split is what BayesModel receives as `test_dataset`.
trainer = BayesModel(
    architecture=mdl,
    train_dataset=train_set,
    test_dataset=val_set,
    batch_size=128,
    lr=1e-3,
)

In [8]:
# Train for 30 epochs; progress is reported by the bar captured in the cell output.
trainer.fit(n_epochs=30)

  0%|          | 0/30 [00:00<?, ?it/s]