Check https://github.com/kumar-shridhar/PyTorch-BayesianCNN

In [1]:
from __future__ import print_function

import os
import argparse

import torch
import numpy as np
from torch.optim import Adam
from torch.nn import functional as F

import data
import utils
import metrics
import config_bayesian as cfg
from models.BayesianModels.Bayesian3Conv3FC import BBB3Conv3FC
from models.BayesianModels.BayesianAlexNet import BBBAlexNet
from models.BayesianModels.BayesianLeNet import BBBLeNet

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
def getModel(net_type, inputs, outputs):
    if (net_type == 'lenet'):
        return BBBLeNet(outputs,inputs)
    elif (net_type == 'alexnet'):
        return BBBAlexNet(outputs, inputs)
    elif (net_type == '3conv3fc'):
        return BBB3Conv3FC(outputs,inputs)
    else:
        raise ValueError('Network should be either [LeNet / AlexNet / 3Conv3FC')


def train_model(net, optimizer, criterion, trainloader, num_ens=1):
    net.train()
    training_loss = 0.0
    accs = []
    kl_list = []
    freq = cfg.recording_freq_per_epoch
    freq = len(trainloader)//freq
    for i, (inputs, labels) in enumerate(trainloader, 1):
        cfg.curr_batch_no = i
        if i%freq==0:
            cfg.record_now = True
        else:
            cfg.record_now = False

        optimizer.zero_grad()

        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)

        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1)
        
        kl = kl / num_ens
        kl_list.append(kl.item())
        log_outputs = utils.logmeanexp(outputs, dim=2)

        loss = criterion(log_outputs, labels, kl)
        loss.backward()
        optimizer.step()

        accs.append(metrics.acc(log_outputs.data, labels))
        training_loss += loss.cpu().data.numpy()
    return training_loss/len(trainloader), np.mean(accs), np.mean(kl_list)


def validate_model(net, criterion, validloader, num_ens=1):
    """Calculate ensemble accuracy and NLL Loss"""
    net.eval()
    valid_loss = 0.0
    accs = []

    for i, (inputs, labels) in enumerate(validloader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = torch.zeros(inputs.shape[0], net.num_classes, num_ens).to(device)
        kl = 0.0
        for j in range(num_ens):
            net_out, _kl = net(inputs)
            kl += _kl
            outputs[:, :, j] = F.log_softmax(net_out, dim=1).data

        log_outputs = utils.logmeanexp(outputs, dim=2)
        valid_loss += criterion(log_outputs, labels, kl).item()
        accs.append(metrics.acc(log_outputs, labels))

    return valid_loss/len(validloader), np.mean(accs)


def run(dataset, net_type):

    # Hyper Parameter settings
    train_ens = cfg.train_ens
    valid_ens = cfg.valid_ens
    n_epochs = cfg.n_epochs
    lr_start = cfg.lr_start
    num_workers = cfg.num_workers
    valid_size = cfg.valid_size
    batch_size = cfg.batch_size

    trainset, testset, inputs, outputs = data.getDataset(dataset)
    train_loader, valid_loader, test_loader = data.getDataloader(
        trainset, testset, valid_size, batch_size, num_workers)
    net = getModel(net_type, inputs, outputs).to(device)

    ckpt_dir = f'checkpoints/{dataset}/bayesian'
    ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}.pt'

    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir, exist_ok=True)

    criterion = metrics.ELBO(len(trainset)).to(device)
    optimizer = Adam(net.parameters(), lr=lr_start)
    valid_loss_max = np.Inf
    for epoch in range(n_epochs):  # loop over the dataset multiple times
        cfg.curr_epoch_no = epoch
        utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))

        train_loss, train_acc, train_kl = train_model(net, optimizer, criterion, train_loader, num_ens=train_ens)
        valid_loss, valid_acc = validate_model(net, criterion, valid_loader, num_ens=valid_ens)

        print('Epoch: {} \tTraining Loss: {:.4f} \tTraining Accuracy: {:.4f} \tValidation Loss: {:.4f} \tValidation Accuracy: {:.4f} \ttrain_kl_div: {:.4f}'.format(
            epoch, train_loss, train_acc, valid_loss, valid_acc, train_kl))

        # save model if validation accuracy has increased
        if valid_loss <= valid_loss_max:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_max, valid_loss))
            torch.save(net.state_dict(), ckpt_name)
            valid_loss_max = valid_loss

In [4]:
device

device(type='cpu')

In [5]:
net_type = 'lenet'
dataset = 'MNIST'

In [6]:
if cfg.record_mean_var:
        mean_var_dir = f"checkpoints/{dataset}/bayesian/{net_type}/"
        cfg.mean_var_dir = mean_var_dir
        if not os.path.exists(mean_var_dir):
            os.makedirs(mean_var_dir, exist_ok=True)
        for file in os.listdir(mean_var_dir):
            os.remove(mean_var_dir + file)

In [7]:
# Hyper Parameter settings
train_ens = cfg.train_ens
valid_ens = cfg.valid_ens
n_epochs = cfg.n_epochs
lr_start = cfg.lr_start
num_workers = cfg.num_workers
valid_size = cfg.valid_size
batch_size = cfg.batch_size

In [8]:
num_workers

0

In [9]:
batch_size

500

In [10]:
trainset, testset, inputs, outputs = data.getDataset(dataset)

In [11]:
train_loader, valid_loader, test_loader = data.getDataloader(
    trainset, testset, valid_size, batch_size, num_workers)

In [12]:
net = getModel(net_type, inputs, outputs).to(device)

In [13]:
net

BBBLeNet(
  (conv1): BBBConv2d()
  (soft1): Softplus(beta=1, threshold=20)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): BBBConv2d()
  (soft2): Softplus(beta=1, threshold=20)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): FlattenLayer()
  (fc1): BBBLinear()
  (soft3): Softplus(beta=1, threshold=20)
  (fc2): BBBLinear()
  (soft4): Softplus(beta=1, threshold=20)
  (fc3): BBBLinear()
)

In [14]:
ckpt_dir = f'checkpoints/{dataset}/bayesian'
ckpt_name = f'checkpoints/{dataset}/bayesian/model_{net_type}.pt'

In [15]:
criterion = metrics.ELBO(len(trainset)).to(device)

In [16]:
optimizer = Adam(net.parameters(), lr=lr_start)
valid_loss_max = np.Inf

In [17]:
epoch = 1

In [18]:
cfg.curr_epoch_no = epoch

In [19]:
utils.adjust_learning_rate(optimizer, metrics.lr_linear(epoch, 0, n_epochs, lr_start))

In [20]:
net.train()

BBBLeNet(
  (conv1): BBBConv2d()
  (soft1): Softplus(beta=1, threshold=20)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): BBBConv2d()
  (soft2): Softplus(beta=1, threshold=20)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): FlattenLayer()
  (fc1): BBBLinear()
  (soft3): Softplus(beta=1, threshold=20)
  (fc2): BBBLinear()
  (soft4): Softplus(beta=1, threshold=20)
  (fc3): BBBLinear()
)

In [63]:
train_loss, train_acc, train_kl = train_model(net, optimizer,
                                              criterion,
                                              train_loader, 
                                              num_ens=train_ens)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py", line 97, in __getitem__
    img = self.transform(img)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 70, in __call__
    img = t(img)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 1003, in __call__
    return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 729, in rotate
    return img.rotate(angle, resample, expand, center, fillcolor=fill)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/PIL/Image.py", line 2005, in rotate
    return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/PIL/Image.py", line 2299, in transform
    im = new(self.mode, size, fillcolor)
  File "/Users/karimimohammedbelhal/opt/anaconda3/lib/python3.7/site-packages/PIL/Image.py", line 2505, in new
    return im._new(core.fill(mode, size, color))
TypeError: function takes exactly 1 argument (3 given)


In [21]:
training_loss = 0.0
accs = []
kl_list = []
freq = cfg.recording_freq_per_epoch
freq = len(train_loader)//freq

In [22]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x108905490>

In [23]:
num_workers

0

In [47]:
trainset

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(32, 32), interpolation=PIL.Image.BILINEAR)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=(-10, 10), resample=False, expand=False)
               ToTensor()
           )

In [51]:
for i in enumerate(train_loader, 0):
    print(i)

TypeError: function takes exactly 1 argument (3 given)

In [34]:
test = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
        sampler=train_sampler, num_workers=num_workers)

NameError: name 'train_sampler' is not defined

In [33]:
for i in enumerate(train_loader, 0):
    print(i)

TypeError: function takes exactly 1 argument (3 given)

In [29]:
for i,  (inputs, labels) in enumerate(train_loader, 0):
    # get the inputs
    inputs, labels = data
    print(i)

TypeError: function takes exactly 1 argument (3 given)