In [1]:
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.utils.data import DataLoader

from bnn_priors.models import DenseNet
from bnn_priors.data import CIFAR10
from bnn_priors import prior
from bnn_priors.models import RegressionModel, LinearPrior, ClassificationModel

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG (used for weight initialisation and DataLoader shuffling)
# so that re-running the notebook gives comparable results.
SEED = 0
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Let cuDNN benchmark convolution algorithms for the fixed input shape. This is
    # faster, but GPU results may not be bit-reproducible even with a fixed seed.
    torch.backends.cudnn.benchmark = True

## Testing a deterministic PreActResNet18 baseline

In [3]:
# Based on https://github.com/kuangliu/pytorch-cifar/blob/master/models/preact_resnet.py

class PreActBlock(nn.Module):
    '''Pre-activation basic residual block: two (norm -> ReLU -> 3x3 conv) stages plus a skip.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of each 3x3 conv.
        stride: stride of the first conv (and of the projection shortcut, if any).
        bn: use BatchNorm2d layers when True, nn.Identity placeholders when False.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True):
        super(PreActBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if bn else nn.Identity
        self.bn1 = norm_layer(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        out_planes = self.expansion * planes
        # A 1x1 projection is only needed when the skip path changes shape.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        pre = F.relu(self.bn1(x))
        # The projection consumes the pre-activated input; the identity skip uses raw x.
        skip = self.shortcut(pre) if hasattr(self, 'shortcut') else x
        h = self.conv1(pre)
        h = self.conv2(F.relu(self.bn2(h)))
        h += skip
        return h


class PreActBottleneck(nn.Module):
    '''Pre-activation bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, each preceded by norm + ReLU.

    Args:
        in_planes: number of input channels.
        planes: bottleneck width; the block outputs `expansion * planes` channels.
        stride: stride of the 3x3 conv (and of the projection shortcut, if any).
        bn: use BatchNorm2d layers when True, nn.Identity placeholders when False.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, bn=True):
        super(PreActBottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if bn else nn.Identity
        out_planes = self.expansion * planes

        self.bn1 = norm_layer(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = norm_layer(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)

        # A 1x1 projection is only needed when the skip path changes shape.
        if not (stride == 1 and in_planes == out_planes):
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        activated = F.relu(self.bn1(x))
        residual = self.shortcut(activated) if hasattr(self, 'shortcut') else x
        out = self.conv1(activated)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += residual
        return out


class PreActResNet(nn.Module):
    '''Pre-activation ResNet for RGB images (designed around 32x32 CIFAR inputs).

    Args:
        block: residual block class (e.g. PreActBlock / PreActBottleneck). It must expose an
            `expansion` attribute and accept `(in_planes, planes, stride, bn=...)`.
        num_blocks: four ints, the number of blocks in each of the four stages.
        num_classes: size of the output logit vector.
        bn: forwarded to every block; selects BatchNorm2d (True) or nn.Identity (False).
    '''
    def __init__(self, block, num_blocks, num_classes=10, bn=True):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.bn = bn

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Stack `num_blocks` blocks; only the first one uses `stride`, the rest use 1.

        Updates `self.in_planes` so the next stage starts from this stage's width.
        '''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, bn=self.bn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        '''Map a batch of images (N, 3, H, W) to logits (N, num_classes).'''
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool over whatever spatial size remains. For 32x32 inputs with the
        # standard blocks the map is 4x4 here, so this matches `F.avg_pool2d(out, 4)`, but it
        # no longer breaks (or silently drops borders) for other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out


def PreActResNet18(bn=True):
    """Pre-activation ResNet-18: basic blocks, stage depths (2, 2, 2, 2)."""
    return PreActResNet(block=PreActBlock, num_blocks=[2, 2, 2, 2], bn=bn)

def PreActResNet34(bn=True):
    """Pre-activation ResNet-34: basic blocks, stage depths (3, 4, 6, 3)."""
    return PreActResNet(block=PreActBlock, num_blocks=[3, 4, 6, 3], bn=bn)

def PreActResNet50(bn=True):
    """Pre-activation ResNet-50: bottleneck blocks, stage depths (3, 4, 6, 3)."""
    return PreActResNet(block=PreActBottleneck, num_blocks=[3, 4, 6, 3], bn=bn)

def PreActResNet101(bn=True):
    """Pre-activation ResNet-101: bottleneck blocks, stage depths (3, 4, 23, 3)."""
    return PreActResNet(block=PreActBottleneck, num_blocks=[3, 4, 23, 3], bn=bn)

def PreActResNet152(bn=True):
    """Pre-activation ResNet-152: bottleneck blocks, stage depths (3, 8, 36, 3)."""
    return PreActResNet(block=PreActBottleneck, num_blocks=[3, 8, 36, 3], bn=bn)

In [4]:
data = CIFAR10(device=device)

# Same batching for both splits. NOTE: drop_last=True on the test loader discards any
# final partial batch, so evaluation may not cover every test image.
loader_kwargs = dict(batch_size=32, shuffle=True, drop_last=True)
dataloader_train = DataLoader(data.norm.train, **loader_kwargs)
dataloader_test = DataLoader(data.norm.test, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
num_epochs = 5
lr = 5e-4

net = PreActResNet18(bn=True).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

for epoch in range(num_epochs):
    net.train()
    with tqdm(desc=f"Epoch {epoch}", total=len(dataloader_train), leave=False) as pbar:
        for batch_x, batch_y in dataloader_train:
            # Mirror the evaluation loop: put batches on the model's device
            # (a no-op when the dataset tensors already live there).
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            y_pred = net(batch_x)
            loss = criterion(y_pred, batch_y)
            loss.backward()
            optimizer.step()
            pbar.update()
            pbar.set_postfix({"loss": f"{loss.item():.2f}"})

    # Exact accuracy = correct / evaluated. An unweighted mean of per-batch accuracies
    # would over-weight a shorter final batch if the test loader ever keeps one.
    num_correct = 0
    num_seen = 0
    net.eval()
    with torch.no_grad():
        for batch_x, batch_y in dataloader_test:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            y_pred = net(batch_x)
            num_correct += y_pred.argmax(dim=1).eq(batch_y).sum().item()
            num_seen += batch_y.numel()
    acc = num_correct / num_seen
    print(f"Epoch {epoch}: Test accuracy = {acc*100:.1f} %")

Epoch 1:   0%|          | 7/1562 [00:00<00:40, 37.93it/s, loss=0.92]   

Epoch 0: Test accuracy = 54.6 %


Epoch 2:   0%|          | 6/1562 [00:00<00:41, 37.09it/s, loss=0.70]   

Epoch 1: Test accuracy = 68.9 %


Epoch 3:   0%|          | 6/1562 [00:00<00:42, 36.94it/s, loss=0.60]   

Epoch 2: Test accuracy = 76.6 %


Epoch 4:   0%|          | 6/1562 [00:00<00:48, 31.95it/s, loss=0.60]   

Epoch 3: Test accuracy = 77.8 %


                                                                       

Epoch 4: Test accuracy = 80.5 %


## Building Bayesian CNNs analogous to the DenseNet class

In [6]:
class Conv2d(nn.Conv2d):
    '''`nn.Conv2d` whose weight and bias come from prior modules instead of stored Parameters.

    It reuses `nn.Conv2d`'s forward, but `weight` / `bias` are read-only properties. They
    call `weight_prior()` / `bias_prior()` on every access.

    Args:
        weight_prior: module with a 4-D tensor attribute `.p` of shape
            (out_channels, in_channels // groups, kH, kW); calling it returns the weight.
        bias_prior: optional module returning the bias when called; None means no bias.
        stride, padding, dilation, groups, padding_mode: same meaning as for `nn.Conv2d`.
    '''
    def __init__(self, weight_prior, bias_prior=None, stride=1,
            padding=0, dilation=1, groups=1, padding_mode='zeros'):
        # Deliberately skip nn.Conv2d.__init__ (it would allocate its own weight/bias
        # Parameters) and set by hand the attributes nn.Conv2d's forward/repr read.
        nn.Module.__init__(self)

        self.stride = nn.modules.utils._pair(stride)
        self.padding = nn.modules.utils._pair(padding)
        self.dilation = nn.modules.utils._pair(dilation)
        self.groups = groups
        self.padding_mode = padding_mode
        self.transposed = False
        self.output_padding = nn.modules.utils._pair(0)
        # Normally created by nn.Conv2d.__init__. nn.Conv2d's forward passes it to F.pad
        # when padding_mode != 'zeros' (layout: padW, padW, padH, padH). Without it, e.g.
        # 'circular' padding fails with an AttributeError.
        self._reversed_padding_repeated_twice = tuple(
            p for p in reversed(self.padding) for _ in range(2))

        (self.out_channels, in_channels, ksize_0, ksize_1) = weight_prior.p.shape
        self.in_channels = in_channels * self.groups
        self.kernel_size = (ksize_0, ksize_1)
        self.weight_prior = weight_prior
        self.bias_prior = bias_prior

    @property
    def weight(self):
        '''Current convolution weight, as returned by the weight prior module.'''
        return self.weight_prior()

    @property
    def bias(self):
        '''Current bias from the bias prior module, or None when there is no bias prior.'''
        return (None if self.bias_prior is None else self.bias_prior())

In [7]:
def Conv2dPrior(in_channels, out_channels, kernel_size, stride=1,
            padding=0, dilation=1, groups=1, padding_mode='zeros',
            prior_w=prior.Normal, loc_w=0., std_w=1., prior_b=prior.Normal,
            loc_b=0., std_b=1., scaling_fn=None):
    """Build a `Conv2d` whose weight (and optionally bias) are prior modules.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, groups,
            padding_mode: same meaning as for `nn.Conv2d`.
        prior_w: weight prior class, called as `prior_w(shape, loc, scale)`.
        loc_w, std_w: location of the weight prior, and the std passed to `scaling_fn`.
        prior_b: bias prior class, or None for a bias-free convolution.
        loc_b, std_b: location and scale of the bias prior.
        scaling_fn: `(std, dim) -> scale` for the weight prior; defaults to std / sqrt(dim).

    Returns:
        A `Conv2d` module wrapping the constructed priors.
    """
    if scaling_fn is None:
        def scaling_fn(std, dim):
            return std/dim**0.5
    kernel_size = nn.modules.utils._pair(kernel_size)
    # Centre the bias prior on `loc_b`. Previously the location was hard-coded to 0.,
    # silently ignoring the caller's `loc_b`.
    bias_prior = prior_b((out_channels,), loc_b, std_b) if prior_b is not None else None
    # NOTE(review): the weight scale uses only `in_channels` as the fan-in. He/Kaiming-style
    # scaling would use in_channels // groups * kH * kW. Confirm this choice is intended.
    weight_prior = prior_w((out_channels, in_channels//groups, kernel_size[0], kernel_size[1]),
                           loc_w, scaling_fn(std_w, in_channels))
    return Conv2d(weight_prior=weight_prior,
                  bias_prior=bias_prior,
                  stride=stride, padding=padding, dilation=dilation,
                  groups=groups, padding_mode=padding_mode)

In [8]:
class PreActBlock(nn.Module):
    '''Pre-activation basic residual block whose convolutions all carry weight priors.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of each 3x3 conv.
        stride: stride of the first conv (and of the projection shortcut, if any).
        bn: use BatchNorm2d layers when True, nn.Identity placeholders when False.
        prior_w, loc_w, std_w: weight prior class and parameters, forwarded to every conv.
        prior_b, loc_b, std_b: accepted for interface symmetry. The convs here are
            bias-free, so these are not used.
        scaling_fn: optional `(std, dim) -> scale` override forwarded to every conv.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, bn=True,
                 prior_w=prior.Normal, loc_w=0., std_w=2**.5,
                 prior_b=prior.Normal, loc_b=0., std_b=1.,
                scaling_fn=None):
        super(PreActBlock, self).__init__()
        if bn:
            batchnorm = nn.BatchNorm2d
        else:
            batchnorm = nn.Identity
        self.bn1 = batchnorm(in_planes)
        self.conv1 = Conv2dPrior(in_planes, planes, kernel_size=3, stride=stride, padding=1,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)
        self.bn2 = batchnorm(planes)
        self.conv2 = Conv2dPrior(planes, planes, kernel_size=3, stride=1, padding=1,
                                 prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                                 prior_b=None, scaling_fn=scaling_fn)

        if stride != 1 or in_planes != self.expansion*planes:
            # The projection shortcut is a learned convolution as well, so it gets the same
            # weight prior as conv1/conv2. A plain nn.Conv2d here would leave its weights as
            # ordinary Parameters outside any prior, unlike every other conv in the block.
            self.shortcut = nn.Sequential(
                Conv2dPrior(in_planes, self.expansion*planes, kernel_size=1, stride=stride,
                            prior_w=prior_w, loc_w=loc_w, std_w=std_w,
                            prior_b=None, scaling_fn=scaling_fn)
            )

    def forward(self, x):
        out = x
        out = self.bn1(out)
        out = F.relu(out)
        # The projection consumes the pre-activated input; the identity skip uses raw x.
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.conv2(F.relu(out))
        out += shortcut
        return out


class PreActResNet(nn.Module):
    '''Pre-activation ResNet whose conv and linear weights are prior modules.

    Args:
        block: residual block class; must expose `expansion` and accept
            `(in_planes, planes, stride, bn=..., prior_w=..., loc_w=..., std_w=...,
            prior_b=..., loc_b=..., std_b=..., scaling_fn=...)`.
        num_blocks: four ints, the number of blocks per stage.
        num_classes: size of the output logit vector.
        bn: forwarded to every block; selects BatchNorm2d (True) or nn.Identity (False).
        prior_w, loc_w, std_w: weight prior settings forwarded to the stem conv, every
            block and the final LinearPrior.
        prior_b, loc_b, std_b: bias prior settings forwarded to the blocks and the final
            LinearPrior (the stem conv is built with prior_b=None, i.e. without bias).
        scaling_fn: optional `(std, dim) -> scale` override forwarded to every layer.
    '''
    def __init__(self, block, num_blocks, num_classes=10, bn=True,
                 prior_w=prior.Normal, loc_w=0., std_w=2**.5,
                 prior_b=prior.Normal, loc_b=0., std_b=1.,
                scaling_fn=None):
        super(PreActResNet, self).__init__()
        self.in_planes = 64
        self.bn = bn
        self.prior_w = prior_w
        self.loc_w = loc_w
        self.std_w = std_w
        self.prior_b = prior_b
        self.loc_b = loc_b
        self.std_b = std_b
        self.scaling_fn = scaling_fn

        self.conv1 = Conv2dPrior(3, 64, kernel_size=3, stride=1, padding=1, prior_b=None,
                           prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                           scaling_fn=self.scaling_fn)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = LinearPrior(512*block.expansion, num_classes,
                            prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                            prior_b=self.prior_b, loc_b=self.loc_b, std_b=self.std_b,
                            scaling_fn=self.scaling_fn)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Stack `num_blocks` blocks; only the first uses `stride`, the rest use 1.

        Updates `self.in_planes` so the next stage starts from this stage's width.
        '''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride, bn=self.bn,
                                prior_w=self.prior_w, loc_w=self.loc_w, std_w=self.std_w,
                                prior_b=self.prior_b, loc_b=self.loc_b, std_b=self.std_b,
                                scaling_fn=self.scaling_fn))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        '''Map a batch of images (N, 3, H, W) to logits (N, num_classes).'''
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool. For 32x32 inputs the map is 4x4 here, so this matches
        # `F.avg_pool2d(out, 4)`, but it also handles other input resolutions.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out


def PreActResNet18(softmax_temp=1.,
             prior_w=prior.Normal, loc_w=0., std_w=2**.5,
             prior_b=prior.Normal, loc_b=0., std_b=1.,
            scaling_fn=None, bn=True):
    """Bayesian pre-activation ResNet-18 (basic blocks, depths 2/2/2/2) wrapped for classification.

    All prior arguments are forwarded unchanged to `PreActResNet`; `softmax_temp`
    goes to `ClassificationModel`.
    """
    backbone = PreActResNet(
        PreActBlock, [2, 2, 2, 2], bn=bn,
        prior_w=prior_w, loc_w=loc_w, std_w=std_w,
        prior_b=prior_b, loc_b=loc_b, std_b=std_b,
        scaling_fn=scaling_fn,
    )
    return ClassificationModel(backbone, softmax_temp)

def PreActResNet34(softmax_temp=1.,
             prior_w=prior.Normal, loc_w=0., std_w=2**.5,
             prior_b=prior.Normal, loc_b=0., std_b=1.,
            scaling_fn=None, bn=True):
    """Bayesian pre-activation ResNet-34 (basic blocks, depths 3/4/6/3) wrapped for classification.

    All prior arguments are forwarded unchanged to `PreActResNet`; `softmax_temp`
    goes to `ClassificationModel`.
    """
    backbone = PreActResNet(
        PreActBlock, [3, 4, 6, 3], bn=bn,
        prior_w=prior_w, loc_w=loc_w, std_w=std_w,
        prior_b=prior_b, loc_b=loc_b, std_b=std_b,
        scaling_fn=scaling_fn,
    )
    return ClassificationModel(backbone, softmax_temp)

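## Test new model class

Note: the next cell imports `PreActResNet18` from `bnn_priors.models`, which shadows the prototype defined above; the runs below use the packaged version.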

In [9]:
from bnn_priors.models import PreActResNet18

In [10]:
data = CIFAR10(device=device)

# Rebuild the loaders with the same settings as the baseline section. NOTE: drop_last=True
# on the test loader discards any final partial batch, so evaluation may skip a few images.
loader_kwargs = dict(batch_size=32, shuffle=True, drop_last=True)
dataloader_train = DataLoader(data.norm.train, **loader_kwargs)
dataloader_test = DataLoader(data.norm.test, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


#### Stochastic-gradient training (Adam optimizer)

In [11]:
num_epochs = 5
lr = 5e-4

net = PreActResNet18(bn=True).to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
n_train = len(dataloader_train.dataset)

# NOTE(review): `net` and `optimizer` are shared by both criteria, so the "potential"
# phase continues from the weights reached by the "nll" phase rather than from a fresh
# initialisation. Re-create them inside the loop if the objectives should be compared
# independently.
for criterion in ["nll", "potential"]:
    print(f"Training criterion: {criterion}")
    for epoch in range(num_epochs):
        net.train()
        with tqdm(desc=f"Epoch {epoch}", total=len(dataloader_train), leave=False) as pbar:
            for batch_x, batch_y in dataloader_train:
                # Mirror the evaluation loop (a no-op if the data is already on `device`).
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                optimizer.zero_grad()
                if criterion == "nll":
                    loss = -net.log_likelihood_avg(batch_x, batch_y, n_train)
                elif criterion == "potential":
                    loss = net.potential_avg(batch_x, batch_y, n_train)
                else:
                    # Fail loudly instead of back-propagating a stale/undefined `loss`.
                    raise ValueError(f"Unknown training criterion: {criterion!r}")
                loss.backward()
                optimizer.step()
                pbar.update()
                pbar.set_postfix({"loss": f"{loss.item():.2f}"})

        # Exact accuracy = correct / evaluated, robust to a shorter final batch.
        num_correct = 0
        num_seen = 0
        net.eval()
        with torch.no_grad():
            for batch_x, batch_y in dataloader_test:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                y_pred = net(batch_x).probs.argmax(dim=1)
                num_correct += y_pred.eq(batch_y).sum().item()
                num_seen += batch_y.numel()
        acc = num_correct / num_seen
        print(f"Epoch {epoch}: Test accuracy = {acc*100:.1f} %")

Epoch 0:   0%|          | 7/1562 [00:00<00:42, 36.39it/s, loss=12.97]

Training criterion: nll


Epoch 1:   0%|          | 6/1562 [00:00<00:48, 32.15it/s, loss=1.26]   

Epoch 0: Test accuracy = 54.2 %


Epoch 2:   0%|          | 6/1562 [00:00<00:44, 35.20it/s, loss=0.77]   

Epoch 1: Test accuracy = 61.6 %


Epoch 3:   0%|          | 6/1562 [00:00<00:49, 31.13it/s, loss=0.76]   

Epoch 2: Test accuracy = 72.7 %


Epoch 4:   0%|          | 6/1562 [00:00<00:45, 34.16it/s, loss=0.45]   

Epoch 3: Test accuracy = 74.2 %


Epoch 0:   0%|          | 4/1562 [00:00<01:06, 23.46it/s, loss=247.38] 

Epoch 4: Test accuracy = 76.1 %
Training criterion: potential


Epoch 1:   0%|          | 4/1562 [00:00<01:01, 25.21it/s, loss=-57403.28]   

Epoch 0: Test accuracy = 77.7 %


Epoch 2:   0%|          | 5/1562 [00:00<00:54, 28.52it/s, loss=-157267.64]   

Epoch 1: Test accuracy = 74.9 %


Epoch 3:   0%|          | 4/1562 [00:00<01:02, 24.87it/s, loss=-288153.00]   

Epoch 2: Test accuracy = 76.0 %


Epoch 4:   0%|          | 4/1562 [00:00<01:04, 24.22it/s, loss=-449243.53]   

Epoch 3: Test accuracy = 77.4 %


                                                                             

Epoch 4: Test accuracy = 77.3 %


#### SGLD inference

In [12]:
from bnn_priors.inference import SGLDRunner

In [13]:
model = PreActResNet18(bn=True).to(device)

n_samples = 20
skip = 5
cycles = 2
warmup = 1000
burnin = 1000
lr = 5e-4
temperature = 1.0
momentum = 0.9
precond_update = None

# The per-cycle sampling phase is sized as n_samples * skip // cycles epochs (presumably
# one kept sample every `skip` epochs). If n_samples is not divisible by cycles, the floor
# division would silently yield fewer samples than requested, so reject that up front.
if n_samples % cycles != 0:
    raise ValueError(f"n_samples={n_samples} must be divisible by cycles={cycles}")
sample_epochs = n_samples * skip // cycles
# `burnin` is not passed to the runner directly; it only enters through epochs_per_cycle.
epochs_per_cycle = warmup + burnin + sample_epochs

In [14]:
# Cyclical SGLD schedule assembled from the configuration cell above.
sgld_kwargs = dict(
    epochs_per_cycle=epochs_per_cycle,
    warmup_epochs=warmup,
    sample_epochs=sample_epochs,
    learning_rate=lr,
    skip=skip,
    sampling_decay=True,
    cycles=cycles,
    temperature=temperature,
    momentum=momentum,
    precond_update=precond_update,
)
mcmc = SGLDRunner(model=model, dataloader=dataloader_train, **sgld_kwargs)

In [None]:
# Run the SGLD sampler configured above. NOTE(review): the recorded progress bar shows
# 2050 iterations per cycle at roughly 64 s each (ETA > 35 h for one cycle), so reduce
# `warmup` / `burnin` for quick experiments.
mcmc.run(progressbar=True)

Cycle 0, Sampling:   3%|▎         | 57/2050 [1:01:45<35:35:42, 64.30s/it]