In [1]:
import torch
import torch.nn as nn
import torch.utils.data as data

import torchvision

import pyro
import pyro.distributions as dist
import pyro.contrib.easybnn as ezbnn

In [2]:
resnet = torchvision.models.resnet18(pretrained=True)

In [3]:
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
x = torch.randn(2, 3, 256, 256)

In [5]:
resnet(x)

tensor([[ 0.1847, -1.0893,  0.7534,  ..., -0.2881,  1.3715,  1.6030],
        [-1.2186,  0.2616, -1.7894,  ..., -0.6291,  1.1722,  0.6086]],
       grad_fn=<AddmmBackward>)

In [29]:
bayesian_resnet = ezbnn.BNN(resnet, ezbnn.priors.IIDPrior(dist.Normal(0., 1.), hide_modules=[nn.BatchNorm2d]),
                            ezbnn.observation_models.Categorical(), ezbnn.guides.ParameterwiseDiagonalNormal,
                            init_loc_fn=ezbnn.guides.SitewiseInitializer.from_net(resnet, prefix="net"),
                            init_scale=1e-3)

In [30]:
dataset = data.TensorDataset(torch.randn(10, 3, 256, 256), torch.randint(1000, size=(10,)))
dataloader = data.DataLoader(dataset, batch_size=5)

In [31]:
optim = pyro.optim.Adam({"lr": 1e-3})

In [34]:
svi = bayesian_resnet.fit(dataloader, optim, 50)

In [35]:
bayesian_resnet.evaluate(*dataset.tensors, num_predictions=5, reduction="mean")

(tensor(0.9000), tensor(-3.0290))

In [37]:
predictions = bayesian_resnet.predict(dataset.tensors[0], num_predictions=5)
p = predictions.softmax(-1).mean(0)
top5 = p.topk(5, dim=-1).indices
top5_acc = dataset.tensors[1].unsqueeze(-1).eq(top5).max(dim=1).values.float().mean()
top5_acc

tensor(0.5000)