In [1]:
# Restrict this kernel to physical GPU 1 (PyTorch will see it as cuda:0).
# Only takes effect if set before CUDA is initialized in this process,
# so keep this as the first executed cell.
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [11]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from matplotlib import pyplot as plt
from tqdm import trange
from torch.distributions.normal import Normal
from bnn import BayesBaseModule, BayesConv2d, BayesLinear, BayesModel, CNN, FCNN
from pathlib import Path

## Distributions

In [3]:
# Zero-mean Normal distribution with standard deviation 0.05 (scalar float32 tensors).
distr = Normal(loc=torch.tensor(0.0), scale=torch.tensor(0.05))

## Dataset

In [4]:
# MNIST training images, normalized with the commonly used MNIST mean/std.
dataset = torchvision.datasets.MNIST(
    './files/',
    train=True,
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]),
)

# Seeded generator so the 50k/10k train/validation split is identical across runs.
# It must be passed to random_split explicitly: without it, random_split draws
# from the global RNG and the seed set here has no effect.
generator = torch.Generator()
generator.manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    dataset, [50000, 10000], generator=generator
)

## Train

### Multivariate Normal

In [20]:
# Temperature sweep: for each temperature, train a Bayesian CNN whose weight and
# bias distribution is N(0, 0.05), log to ./runs/<run_name>, and save the trained
# state_dict to ./models/<run_name>/model.pth.
temperatures = [1, 5e-1, 1e-1, 5e-2, 1e-2, 5e-3, 1e-3]
errors = np.zeros(len(temperatures))  # one slot per temperature; not filled in this cell
distr = Normal(torch.tensor(0.), torch.tensor(0.05))
n_epochs = 100

for temperature in temperatures:
    # NOTE: '.3f' maps temperatures below 5e-4 to the same name; fine for this list.
    run_name = 'cnn_normal_{temp:.3f}'.format(temp=temperature)
    model_save_path = Path('./models') / run_name
    model_file = model_save_path / 'model.pth'

    # Each run takes ~20 minutes, so skip temperatures that already have a saved
    # model. This lets Restart & Run All resume the sweep instead of relying on a
    # hand-edited slice of `temperatures` to decide what to skip.
    if model_file.exists():
        continue

    mdl = CNN(weight_distribution=distr, bias_distribution=distr)
    trainer = BayesModel(
        train_dataset=train_set,
        test_dataset=val_set,
        batch_size=128,
        architecture=mdl,
        lr=1e-3,
        temperature=temperature,
    )
    trainer.fit(n_epochs=n_epochs, log_dir='./runs/' + run_name)

    model_save_path.mkdir(exist_ok=True, parents=True)
    torch.save(trainer.architecture.state_dict(), model_file)

100%|█████████████████████████████████████████| 100/100 [17:45<00:00, 10.66s/it]
100%|█████████████████████████████████████████| 100/100 [20:30<00:00, 12.31s/it]
100%|█████████████████████████████████████████| 100/100 [21:57<00:00, 13.18s/it]
100%|█████████████████████████████████████████| 100/100 [22:03<00:00, 13.23s/it]
100%|█████████████████████████████████████████| 100/100 [21:17<00:00, 12.77s/it]
100%|█████████████████████████████████████████| 100/100 [22:17<00:00, 13.38s/it]
