# MNIST

In [1]:
pip install pyro-ppl

Collecting pyro-ppl
  Downloading pyro_ppl-1.8.4-py3-none-any.whl (730 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m730.7/730.7 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pyro-api>=0.1.1
  Downloading pyro_api-0.1.2-py3-none-any.whl (11 kB)
Installing collected packages: pyro-api, pyro-ppl
Successfully installed pyro-api-0.1.2 pyro-ppl-1.8.4
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import MNIST
import torch.nn.functional as F
import torchvision.transforms as transforms

In [3]:
# MNIST
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_set = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)

test_set = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 85991729.31it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 42031816.04it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 22470389.04it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9283883.42it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



## 1.Gaussian weights

### 1.1Variable Inference

In [4]:
# Model with Gaussian prior 
class BNN(nn.Module):
    def __init__(self, n_in, n_out,n_hide):
        super().__init__()
        self.w = nn.Parameter(torch.randn(n_in, n_out))
        self.fc1 = nn.Linear(n_in, n_hide) 
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5) 

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  

    def kl_divergence(self):
        mean, std = self.w.mean(), self.w.std()
        return ((mean ** 2 + std ** 2 - torch.log(std **2)-1)/2).sum()

In [5]:
model = BNN(784,500,10)  

In [6]:
# Training  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(50):
    for x, y in train_loader: 
        x = x.view(-1, 784)
        loss = F.nll_loss(model(x), y)
        loss += model.kl_divergence() * 0.1  
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch+1}, Train Loss: {loss.item():.3f}')

Epoch: 1, Train Loss: 6.126
Epoch: 2, Train Loss: 5.833
Epoch: 3, Train Loss: 5.738
Epoch: 4, Train Loss: 5.476
Epoch: 5, Train Loss: 5.142
Epoch: 6, Train Loss: 4.885
Epoch: 7, Train Loss: 4.888
Epoch: 8, Train Loss: 4.614
Epoch: 9, Train Loss: 4.245
Epoch: 10, Train Loss: 4.000
Epoch: 11, Train Loss: 4.192
Epoch: 12, Train Loss: 3.790
Epoch: 13, Train Loss: 3.609
Epoch: 14, Train Loss: 3.746
Epoch: 15, Train Loss: 3.743
Epoch: 16, Train Loss: 3.165
Epoch: 17, Train Loss: 2.998
Epoch: 18, Train Loss: 3.128
Epoch: 19, Train Loss: 2.856
Epoch: 20, Train Loss: 2.996
Epoch: 21, Train Loss: 2.711
Epoch: 22, Train Loss: 2.883
Epoch: 23, Train Loss: 2.501
Epoch: 24, Train Loss: 2.767
Epoch: 25, Train Loss: 2.735
Epoch: 26, Train Loss: 2.525
Epoch: 27, Train Loss: 2.530
Epoch: 28, Train Loss: 2.552
Epoch: 29, Train Loss: 2.087
Epoch: 30, Train Loss: 2.504
Epoch: 31, Train Loss: 2.326
Epoch: 32, Train Loss: 2.136
Epoch: 33, Train Loss: 1.969
Epoch: 34, Train Loss: 2.239
Epoch: 35, Train Loss: 

In [7]:
# Test accuracy  
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1)
        correct += (y_pred == y).sum().item() 
print(f'Test accuracy: {correct/len(test_set):.3f}') 

Test accuracy: 0.478


### 1.2MCMC

In [8]:
# Get train data
train_x = []  
train_y = []
for x, y in train_loader:
    train_x.append(x.view(-1, 784))
    train_y.append(y)
train_x = torch.cat(train_x)  
train_y = torch.cat(train_y)

In [9]:
# Model with Gaussian prior 
class BNN(nn.Module):
    def __init__(self, n_in, n_hide, n_out):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)
        self.w = torch.zeros(n_in, n_out)  # Weight parameter

    def forward(self, x,w_proposal=None):
        if w_proposal is None:
              w = self.w
        else:
              w = w_proposal
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        pred = F.log_softmax(x, dim=1)
        dist = torch.distributions.Categorical(logits=pred)
        return dist

In [10]:
# MCMC Sampling
model = BNN(784, 500, 10) 
w = model.w.data  # Get weight tensor 

In [11]:
for i in range(2000):  # 2000 iterations
    # Proposal distribution
    w_proposal = w + torch.randn(w.size())*0.1 

    # Acceptance ratio
    ap = torch.exp(model(train_x).log_prob(train_y).sum() - model(train_x, w_proposal).log_prob(train_y).sum()) 

    # Metropolis acceptance
    u = torch.rand(1)
    if u < ap:
        w = w_proposal  # Accept proposal 

    model.w.data = w  # Set weight to sampled value

In [12]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).sample()
        correct += (y_pred == y).sum().item()
    print(f'Test accuracy: {correct/len(test_set):.3f}') 

Test accuracy: 0.094


## 2.Laplace weights

### 2.1Variable Inference

In [13]:
# Model with Laplace prior 
class BNN(nn.Module):
    def __init__(self, n_in, n_out,n_hide):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(n_in, n_out))
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  

    def kl_divergence(self):
        return (self.w.abs().sum() / 2).sum()  # Laplace prior  

In [14]:
model = BNN(784,500,10)  

In [15]:
# Training  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(50):
    for x, y in train_loader: 
        x = x.view(-1, 784)
        loss = F.nll_loss(model(x), y)
        loss += model.kl_divergence() * 0.1   # Weight decay as KL term
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch+1}, Train Loss: {loss.item():.3f}')

Epoch: 1, Train Loss: 5.947
Epoch: 2, Train Loss: 5.825
Epoch: 3, Train Loss: 5.602
Epoch: 4, Train Loss: 5.308
Epoch: 5, Train Loss: 4.997
Epoch: 6, Train Loss: 4.989
Epoch: 7, Train Loss: 4.923
Epoch: 8, Train Loss: 4.276
Epoch: 9, Train Loss: 4.003
Epoch: 10, Train Loss: 4.480
Epoch: 11, Train Loss: 3.782
Epoch: 12, Train Loss: 3.833
Epoch: 13, Train Loss: 3.691
Epoch: 14, Train Loss: 4.087
Epoch: 15, Train Loss: 3.338
Epoch: 16, Train Loss: 3.167
Epoch: 17, Train Loss: 3.277
Epoch: 18, Train Loss: 3.164
Epoch: 19, Train Loss: 3.211
Epoch: 20, Train Loss: 2.908
Epoch: 21, Train Loss: 2.956
Epoch: 22, Train Loss: 2.956
Epoch: 23, Train Loss: 2.594
Epoch: 24, Train Loss: 2.434
Epoch: 25, Train Loss: 2.398
Epoch: 26, Train Loss: 2.483
Epoch: 27, Train Loss: 2.545
Epoch: 28, Train Loss: 2.453
Epoch: 29, Train Loss: 2.456
Epoch: 30, Train Loss: 2.548
Epoch: 31, Train Loss: 2.558
Epoch: 32, Train Loss: 2.463
Epoch: 33, Train Loss: 2.121
Epoch: 34, Train Loss: 2.273
Epoch: 35, Train Loss: 

In [16]:
# Test accuracy  
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1)
        correct += (y_pred == y).sum().item() 
print(f'Test accuracy: {correct/len(test_set):.3f}') 

Test accuracy: 0.470


### 2.2MCMC

In [17]:
# Get train data
train_x = []  
train_y = []
for x, y in train_loader:
    train_x.append(x.view(-1, 784))
    train_y.append(y)
train_x = torch.cat(train_x)  
train_y = torch.cat(train_y)

In [18]:
# Model with Laplace prior
class BNN(nn.Module):
    def __init__(self, n_in, n_hide, n_out):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)
        self.w = torch.zeros(n_in, n_out)  

    def forward(self, x, w_proposal=None):
        if w_proposal is None:  
              w = self.w
        else:
              w = w_proposal
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        pred = F.log_softmax(x, dim=1)
        dist = torch.distributions.Categorical(logits=pred)
        return dist

    def log_prior(self, w):
        return -w.abs().sum() 

In [19]:
# MCMC Sampling
model = BNN(784, 500, 10) 
w = model.w.data  # Get weight tensor 

In [20]:
for i in range(2000): 
    # Proposal distribution
    w_proposal = w + torch.randn(w.size())*0.1

    # Prior ratio 
    lp = model.log_prior(w_proposal)  
    ap = torch.exp(lp - model.log_prior(w))

    # Metropolis acceptance 
    u = torch.rand(1)
    if u < ap:
        w = w_proposal  

    model.w.data = w 

In [21]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).sample()
        correct += (y_pred == y).sum().item()
    print(f'Test accuracy: {correct/len(test_set):.3f}') 

Test accuracy: 0.091


## 3.Equalisation weights

### 3.1Variable Inference

In [22]:
# Model with Uniform prior
class BNN(nn.Module):
    def __init__(self, n_in, n_out,n_hide):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(n_in, n_out))
        self.fc1 = nn.Linear(n_in, n_hide)
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  

    def kl_divergence(self):
        return 0.5*(self.w**2).sum()  # Uniform prior  

In [23]:
model = BNN(784,500,10)  

In [24]:
# Training  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(50):
    for x, y in train_loader: 
        x = x.view(-1, 784)
        loss = F.nll_loss(model(x), y)
        loss += model.kl_divergence() * 0.1  
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch: {epoch+1}, Train Loss: {loss.item():.3f}')

Epoch: 1, Train Loss: 5.939
Epoch: 2, Train Loss: 5.554
Epoch: 3, Train Loss: 5.293
Epoch: 4, Train Loss: 5.365
Epoch: 5, Train Loss: 5.016
Epoch: 6, Train Loss: 4.291
Epoch: 7, Train Loss: 4.648
Epoch: 8, Train Loss: 4.227
Epoch: 9, Train Loss: 3.973
Epoch: 10, Train Loss: 3.773
Epoch: 11, Train Loss: 4.025
Epoch: 12, Train Loss: 3.557
Epoch: 13, Train Loss: 3.611
Epoch: 14, Train Loss: 3.416
Epoch: 15, Train Loss: 2.967
Epoch: 16, Train Loss: 3.783
Epoch: 17, Train Loss: 3.161
Epoch: 18, Train Loss: 3.268
Epoch: 19, Train Loss: 3.242
Epoch: 20, Train Loss: 2.768
Epoch: 21, Train Loss: 2.356
Epoch: 22, Train Loss: 2.816
Epoch: 23, Train Loss: 2.693
Epoch: 24, Train Loss: 2.332
Epoch: 25, Train Loss: 2.488
Epoch: 26, Train Loss: 2.321
Epoch: 27, Train Loss: 2.283
Epoch: 28, Train Loss: 2.267
Epoch: 29, Train Loss: 2.152
Epoch: 30, Train Loss: 2.359
Epoch: 31, Train Loss: 2.036
Epoch: 32, Train Loss: 2.138
Epoch: 33, Train Loss: 2.218
Epoch: 34, Train Loss: 2.208
Epoch: 35, Train Loss: 

In [25]:
# Test accuracy  
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).argmax(dim=1)
        correct += (y_pred == y).sum().item() 
print(f'Test accuracy: {correct/len(test_set):.3f}') 

Test accuracy: 0.472


### 3.2MCMC

In [26]:
# Get train data
train_x = []  
train_y = []
for x, y in train_loader:
    train_x.append(x.view(-1, 784))
    train_y.append(y)
train_x = torch.cat(train_x)  
train_y = torch.cat(train_y)

In [27]:
# Model with Uniform prior
class BNN(nn.Module):
    def __init__(self, n_in, n_hide, n_out):
        super().__init__()
        self.fc1 = nn.Linear(n_in, n_hide)  
        self.fc2 = nn.Linear(n_hide, n_out)
        self.dropout = nn.Dropout(0.5)
        self.w = torch.zeros(n_in, n_out)  # Weight parameter

    def forward(self, x, w_proposal=None):
        if w_proposal is None:
              w = self.w
        else:
              w = w_proposal
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        pred = F.log_softmax(x, dim=1)
        dist = torch.distributions.Categorical(logits=pred)
        return dist

    def log_prior(self, w):
        return -torch.sum(torch.abs(w))

In [28]:
# MCMC Sampling
model = BNN(784, 500, 10)  
w = model.w.data 

In [29]:
for i in range(2000): 
    # Evenly distributed
    w_proposal = torch.rand(w.size())

    # A priori ratio - logarithmic difference of uniform distribution
    ap = torch.tensor(0.)  

    u = torch.rand(1)
    if u < ap:  
          w = w_proposal

    model.w.data = w  

    logits = model(train_x, w_proposal).logits
    loss = F.cross_entropy(logits, train_y)  

In [30]:
# Test accuracy
with torch.no_grad():
    correct = 0
    for x, y in test_loader:
        x = x.view(-1, 784)
        y_pred = model(x).sample()
        correct += (y_pred == y).sum().item()
    print(f'Test accuracy: {correct/len(test_set):.3f}') 

Test accuracy: 0.101
