In [1]:
import sys
sys.path.append("..")

import time
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

import seaborn as sns # conda install seaborn
import pandas as pd # ^^ this will automatically install pandas

import pyro
from pyro.infer.mcmc import MCMC
import pyro.distributions as dist

from kernel.sghmc import SGHMC
from kernel.sgld import SGLD
from kernel.sgd import SGD
from kernel.sgnuts import NUTS as SGNUTS

pyro.set_rng_seed(101)

In [2]:
# Simple dataset wrapper class

class Dataset(torch.utils.data.Dataset):
    """Pair a collection of inputs with a parallel collection of targets.

    Indexing yields the tuple ``(data[idx], targets[idx])``; the length of the
    dataset is the length of ``data``.
    """

    def __init__(self, data, targets):
        """Store the inputs and targets without copying them."""
        self.data, self.targets = data, targets

    def __len__(self):
        """Return the number of stored inputs."""
        n_items = len(self.data)
        return n_items

    def __getitem__(self, idx):
        """Return the input and the target found at position ``idx``."""
        features = self.data[idx]
        label = self.targets[idx]
        return features, label

### Hyperparams

In [3]:
# Mini-batch size; passed to the data loaders and to each kernel in the cells below.
BATCH_SIZE = 500
# Number of epochs that are evaluated after the warm-up phase.
NUM_EPOCHS = 800
# Number of initial epochs that the training loops run before they start evaluating.
WARMUP_EPOCHS = 50
# Width of the BNN's hidden layer.
HIDDEN_SIZE = 100

### Download MNIST and setup datasets / dataloaders

In [4]:
# Download MNIST (into ./data), hold out part of the training set for validation,
# rescale the pixel values and wrap every split in a DataLoader.
train_dataset = datasets.MNIST('./data', train=True, download=True)

test_dataset = datasets.MNIST('./data', train=False, download=True)

# number of training images held out as the validation split
nvalid = 10000

# NOTE: torch.arange yields the indices in their original order (nothing is shuffled),
# so the validation split is simply the first `nvalid` training images.
perm = torch.arange(len(train_dataset))
train_idx = perm[nvalid:]
val_idx = perm[:nvalid]
    
# NOTE(review): presumably the usual MNIST normalisation statistics - they are
# defined here but not applied anywhere in this cell; confirm whether that is intended.
mean = 0.1307
std = 0.3081

# scale the datasets
# (dividing by 255.0 produces floating point images; assumes raw pixels lie in [0, 255])
X_train = train_dataset.data[train_idx] / 255.0
Y_train = train_dataset.targets[train_idx]

X_val = train_dataset.data[val_idx] / 255.0
Y_val = train_dataset.targets[val_idx]

X_test = test_dataset.data / 255.0
Y_test = test_dataset.targets

# redefine the datasets
# (train_dataset / test_dataset are rebound from the torchvision objects to the
#  lightweight tensor wrapper defined above, so this cell must run top to bottom)
train_dataset = Dataset(X_train, Y_train)
val_dataset = Dataset(X_val, Y_val)
test_dataset = Dataset(X_test, Y_test)

# setup the dataloaders
# (only the training loader shuffles; validation/test keep the tensor order so that
#  concatenated predictions line up with Y_val / Y_test)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Define the Bayesian neural network model

In [5]:
PyroLinear = pyro.nn.PyroModule[torch.nn.Linear]


class BNN(pyro.nn.PyroModule):
    """Two-layer Bayesian MLP classifier with zero-mean Gaussian priors on every weight and bias.

    Args:
        input_size: number of features per example; inputs are flattened to this width.
        hidden_size: number of hidden units.
        output_size: number of classes.
        prec: precision (inverse variance) of the Gaussian priors. A Gaussian prior
            with precision ``prec`` corresponds to L2 regularisation of strength
            ``prec / 2`` in the non-Bayesian setting, so a larger value regularises
            more strongly. It should only be changed from its default for SGD.

    Raises:
        ValueError: if ``prec`` is not strictly positive.
    """

    def __init__(self, input_size, hidden_size, output_size, prec=1.):
        super().__init__()
        if prec <= 0:
            raise ValueError("prec must be strictly positive, got {}".format(prec))

        self.input_size = input_size

        # dist.Normal is parameterised by its standard deviation, not its precision,
        # so convert: scale = 1 / sqrt(prec). For the default prec=1 the scale is 1.
        prior_scale = float(prec) ** -0.5

        # TODO add gamma priors to precision terms
        self.fc1 = PyroLinear(input_size, hidden_size)
        self.fc1.weight = pyro.nn.PyroSample(
            dist.Normal(0., prior_scale).expand([hidden_size, input_size]).to_event(2))
        self.fc1.bias = pyro.nn.PyroSample(
            dist.Normal(0., prior_scale).expand([hidden_size]).to_event(1))

        self.fc2 = PyroLinear(hidden_size, output_size)
        self.fc2.weight = pyro.nn.PyroSample(
            dist.Normal(0., prior_scale).expand([output_size, hidden_size]).to_event(2))
        self.fc2.bias = pyro.nn.PyroSample(
            dist.Normal(0., prior_scale).expand([output_size]).to_event(1))

        self.relu = torch.nn.ReLU()
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x, y=None):
        """Register the categorical likelihood of ``y`` given ``x`` at the "obs" site.

        Args:
            x: batch of inputs; reshaped to ``(-1, input_size)``.
            y: optional class labels. When ``None`` the "obs" site is sampled
                instead of being conditioned on.
        """
        # flatten using the configured width rather than a hard-coded 28*28
        x = x.reshape(-1, self.input_size)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        x = self.log_softmax(x)  # output (log) softmax probabilities of each class

        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Categorical(logits=x), obs=y)

### Run SGHMC 

We run SGHMC to sample approximately from the posterior distribution.

In [6]:
# --- SGHMC settings (all of them are forwarded to the kernel below) ---
LR = 2e-6
MOMENTUM_DECAY = 0.01
RESAMPLE_EVERY_N = 100
NUM_STEPS = 1

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10)

sghmc = SGHMC(bnn,
              subsample_positions=[0, 1],
              batch_size=BATCH_SIZE,
              learning_rate=LR,
              momentum_decay=MOMENTUM_DECAY,
              num_steps=NUM_STEPS,
              resample_every_n=RESAMPLE_EVERY_N)

# num_samples equals the number of mini-batches in one pass over the training set,
# so every run() call below is treated as one epoch.
sghmc_mcmc = MCMC(sghmc, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

# NOTE: despite the "test" naming, the metric is computed on the validation split.
sghmc_test_errs = []

# full posterior predictive: running per-class vote counts for every validation
# example, sized from the validation set instead of a hard-coded 10000
full_predictive = torch.zeros(len(val_dataset), 10)

for epoch in range(1, NUM_EPOCHS + WARMUP_EPOCHS + 1):
    sghmc_mcmc.run(X_train, Y_train)

    # Skip exactly WARMUP_EPOCHS epochs so that exactly NUM_EPOCHS epochs are
    # evaluated (the previous `>=` test evaluated one epoch too many and
    # reported an "Epoch [0/...]").
    if epoch <= WARMUP_EPOCHS:
        continue

    sghmc_samples = sghmc_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sghmc_samples)
    start = time.time()

    with torch.no_grad():
        # For every batch: one-hot encode the label drawn under each posterior
        # sample and sum over the sample dimension -> (batch, 10) vote counts.
        # val_loader is not shuffled, so the rows line up with Y_val.
        epoch_votes = torch.cat(
            [F.one_hot(predictive(x)['obs'].to(torch.int64), num_classes=10).sum(dim=0)
             for x, _ in val_loader],
            dim=0)
        full_predictive = full_predictive + epoch_votes.to(full_predictive.dtype)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    sghmc_test_errs.append(1.0 - correct/total)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 39.65it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 47.84it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 46.55it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 44.82it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 47.79it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 46.23it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 47.30it/s, lr=2.00e-06]
Sample: 100%|█████████████████████████████████████████████████████████████████| 100/100 [00:02, 46.99it/s, lr=2.00e-06]
Sample: 100%|███████████████████████████

KeyboardInterrupt: 

### Run SGLD

We run SGLD to sample approximately from the posterior distribution.

In [None]:
# --- SGLD settings ---
LR = 4e-5
LR_DECAY = True

if LR_DECAY:
    # Step-size schedule lr(t) = A / sqrt(B + t), with t counted in epochs since the
    # end of warm-up. A and B are chosen so that lr(0) = LR and lr(NUM_EPOCHS) = D * LR.
    D = 0.25 # decay by 1/4
    B = (NUM_EPOCHS * D**2) / (1 - D**2)
    A = LR * np.sqrt(B)

NUM_STEPS = 1

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10)

sgld = SGLD(bnn,
            subsample_positions=[0, 1],
            batch_size=BATCH_SIZE,
            learning_rate=LR,  # was the undefined name LR_START (NameError)
            num_steps=NUM_STEPS)

# num_samples equals the number of mini-batches in one pass over the training set,
# so every run() call below is treated as one epoch.
sgld_mcmc = MCMC(sgld, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

# NOTE: despite the "test" naming, the metric is computed on the validation split.
sgld_test_errs = []

# full posterior predictive: running (step-size weighted) per-class votes for every
# validation example, sized from the validation set instead of a hard-coded 10000
full_predictive = torch.zeros(len(val_dataset), 10)

# Step size in effect for the current epoch. Kept in its own variable so that the
# configured LR is not overwritten and the cell can be re-run as is.
current_lr = LR

for epoch in range(1, NUM_EPOCHS + WARMUP_EPOCHS + 1):
    # Exactly WARMUP_EPOCHS epochs are skipped, so exactly NUM_EPOCHS are evaluated.
    past_warmup = epoch > WARMUP_EPOCHS

    if past_warmup and LR_DECAY:
        # Index the schedule by epochs since warm-up (0, 1, ...) so that it really
        # starts at LR, and set it *before* sampling so that the weight applied to
        # this epoch's samples is the step size they were generated with.
        steps_since_warmup = epoch - WARMUP_EPOCHS - 1
        current_lr = float(A / np.sqrt(B + steps_since_warmup))
        sgld_mcmc.kernel.learning_rate = current_lr

    sgld_mcmc.run(X_train, Y_train)

    if not past_warmup:
        continue

    start = time.time()
    sgld_samples = sgld_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sgld_samples)

    # with a decaying step size every sample's vote is weighted by its step size
    vote_weight = current_lr if LR_DECAY else 1.0

    with torch.no_grad():
        # For every batch: one-hot encode the label drawn under each posterior
        # sample and sum over the sample dimension -> (batch, 10) vote counts.
        # val_loader is not shuffled, so the rows line up with Y_val.
        epoch_votes = torch.cat(
            [F.one_hot(predictive(x)['obs'].to(torch.int64), num_classes=10).sum(dim=0)
             for x, _ in val_loader],
            dim=0)
        full_predictive = full_predictive + vote_weight * epoch_votes.to(full_predictive.dtype)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    sgld_test_errs.append(1.0 - correct/total)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

### Run SGD

We run SGD to optimise the weights of the BNN and use the most recent iterate as a point estimate of the "best" parameters.

In [None]:
# --- plain SGD settings (all of them are forwarded to the kernel below) ---
LR = 6e-5
WEIGHT_DECAY=0.0
WITH_MOMENTUM=False
MOMENTUM_DECAY=1.0
REGULARIZATION_TERM=1.

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10, prec=REGULARIZATION_TERM)

sgd = SGD(bnn,
          subsample_positions=[0, 1],
          batch_size=BATCH_SIZE,
          learning_rate=LR,
          weight_decay=WEIGHT_DECAY,
          with_momentum=WITH_MOMENTUM,
          momentum_decay=MOMENTUM_DECAY)

# num_samples equals the number of mini-batches in one pass over the training set,
# so every run() call below is treated as one epoch.
sgd_mcmc = MCMC(sgd, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

# NOTE: despite the "test" naming, the metric is computed on the validation split.
sgd_test_errs = []

for epoch in range(1, NUM_EPOCHS + WARMUP_EPOCHS + 1):
    sgd_mcmc.run(X_train, Y_train)

    # Skip exactly WARMUP_EPOCHS epochs so that exactly NUM_EPOCHS epochs are
    # evaluated (the previous `>=` test evaluated one epoch too many and
    # reported an "Epoch [0/...]").
    if epoch <= WARMUP_EPOCHS:
        continue

    sgd_samples = sgd_mcmc.get_samples()
    # Keep only the most recent iterate of every site; the [-1:] slice preserves
    # the leading sample dimension (of size one) that Predictive expects.
    point_estimate = {site: values[-1:] for site, values in sgd_samples.items()}
    predictive = pyro.infer.Predictive(bnn, posterior_samples=point_estimate)
    start = time.time()

    with torch.no_grad():
        total = 0
        correct = 0
        for x, y in val_loader:
            # NOTE(review): 'obs' is a draw from the model's Categorical likelihood,
            # so accuracy is measured on sampled labels rather than on the argmax
            # class - confirm that this is intended for a point estimate.
            batch_predictive = predictive(x)['obs']
            batch_y_hat = batch_predictive.mode(0).values
            total += y.shape[0]
            correct += int((batch_y_hat == y).sum())

    end = time.time()

    sgd_test_errs.append(1.0 - correct/total)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

### Run SGD with momentum

We run SGD with momentum to optimise the weights of the BNN and use the most recent iterate as a point estimate of the "best" parameters.

In [None]:
# --- SGD-with-momentum settings (all of them are forwarded to the kernel below) ---
LR = 8e-4
WEIGHT_DECAY=1e-6
WITH_MOMENTUM=True
MOMENTUM_DECAY=0.01
REGULARIZATION_TERM=1.

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10, prec=REGULARIZATION_TERM)

sgdmom = SGD(bnn,
             subsample_positions=[0, 1],
             batch_size=BATCH_SIZE,
             learning_rate=LR,
             weight_decay=WEIGHT_DECAY,
             with_momentum=WITH_MOMENTUM,
             momentum_decay=MOMENTUM_DECAY)

# num_samples equals the number of mini-batches in one pass over the training set,
# so every run() call below is treated as one epoch.
sgdmom_mcmc = MCMC(sgdmom, num_samples=len(train_dataset)//BATCH_SIZE, warmup_steps=0)

# NOTE: despite the "test" naming, the metric is computed on the validation split.
sgdmom_test_errs = []

for epoch in range(1, NUM_EPOCHS + WARMUP_EPOCHS + 1):
    sgdmom_mcmc.run(X_train, Y_train)

    # Skip exactly WARMUP_EPOCHS epochs so that exactly NUM_EPOCHS epochs are
    # evaluated (the previous `>=` test evaluated one epoch too many and
    # reported an "Epoch [0/...]").
    if epoch <= WARMUP_EPOCHS:
        continue

    sgdmom_samples = sgdmom_mcmc.get_samples()
    # Keep only the most recent iterate of every site; the [-1:] slice preserves
    # the leading sample dimension (of size one) that Predictive expects.
    point_estimate = {site: values[-1:] for site, values in sgdmom_samples.items()}
    predictive = pyro.infer.Predictive(bnn, posterior_samples=point_estimate)
    start = time.time()

    with torch.no_grad():
        total = 0
        correct = 0
        for x, y in val_loader:
            # NOTE(review): 'obs' is a draw from the model's Categorical likelihood,
            # so accuracy is measured on sampled labels rather than on the argmax
            # class - confirm that this is intended for a point estimate.
            batch_predictive = predictive(x)['obs']
            batch_y_hat = batch_predictive.mode(0).values
            total += y.shape[0]
            correct += int((batch_y_hat == y).sum())

    end = time.time()

    sgdmom_test_errs.append(1.0 - correct/total)

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch-WARMUP_EPOCHS, NUM_EPOCHS, correct/total, end - start))

### Plot the convergence dynamics

In [None]:
# Plot the per-epoch error of every updater on a shared axis.
sns.set_style("dark")

sghmc_test_errs = np.array(sghmc_test_errs)
sgld_test_errs = np.array(sgld_test_errs)
sgd_test_errs = np.array(sgd_test_errs)
sgdmom_test_errs = np.array(sgdmom_test_errs)

err_dict = {'SGHMC' : sghmc_test_errs, 'SGLD' : sgld_test_errs, 'SGD' : sgd_test_errs, 'SGD with momentum' : sgdmom_test_errs}

# Long-format records (iteration, updater, error). Iterating over each series itself,
# capped at NUM_EPOCHS, tolerates runs that recorded fewer or more epochs than
# NUM_EPOCHS (e.g. an interrupted cell) instead of raising an IndexError.
records = [
    (iteration, updater, err)
    for updater, errs in err_dict.items()
    for iteration, err in enumerate(errs[:NUM_EPOCHS], start=1)
]

df = pd.DataFrame(records, columns=['iterations', 'updater','test error'])

# DataFrame.pivot must be called with keyword arguments on pandas >= 2.0.
sns.lineplot(data=df.pivot(index="iterations", columns="updater", values="test error"))
plt.xlabel("iterations")
plt.ylabel("test error")
plt.show() #dpi=300

### Stochastic Gradient NUTS (experimental; does not quite work yet)

In [None]:
# --- SG-NUTS settings (experimental) ---
LR = 8e-7
MOMENTUM_DECAY = 0.01
NUM_STEPS = 1
SGNUTS_WARMUP_STEPS = 5000  # warm-up steps run once, before any epoch is evaluated

pyro.clear_param_store()

bnn = BNN(28*28, HIDDEN_SIZE, 10)

sgnuts = SGNUTS(bnn,
                subsample_positions=[0, 1],
                batch_size=BATCH_SIZE,
                learning_rate=LR,
                momentum_decay=MOMENTUM_DECAY,
                resample_every_n=100,
                obs_info_noise=False,
                use_multinomial_sampling=True,
                max_tree_depth=10,
                target_accept=0.8)

# do warm up
sgnuts_mcmc = MCMC(sgnuts, num_samples=0, warmup_steps=SGNUTS_WARMUP_STEPS)
sgnuts_mcmc.run(X_train, Y_train)

# Reconfigure the same MCMC object so that every later run() call draws one sample
# per mini-batch (one epoch) without any further warm-up.
sgnuts_mcmc.num_samples = len(train_dataset)//BATCH_SIZE
sgnuts_mcmc.warmup_steps = 0

# full posterior predictive: running per-class vote counts for every validation
# example, sized from the validation set instead of a hard-coded 10000
full_predictive = torch.zeros(len(val_dataset), 10)

# Warm-up already happened above, so every epoch of this loop is evaluated and the
# epoch counter is reported as is (it used to be offset by WARMUP_EPOCHS, which
# printed negative epoch numbers).
for epoch in range(1, NUM_EPOCHS + 1):
    sgnuts_mcmc.run(X_train, Y_train)

    sgnuts_samples = sgnuts_mcmc.get_samples()
    predictive = pyro.infer.Predictive(bnn, posterior_samples=sgnuts_samples)
    start = time.time()

    with torch.no_grad():
        # For every batch: one-hot encode the label drawn under each posterior
        # sample and sum over the sample dimension -> (batch, 10) vote counts.
        # val_loader is not shuffled, so the rows line up with Y_val.
        epoch_votes = torch.cat(
            [F.one_hot(predictive(x)['obs'].to(torch.int64), num_classes=10).sum(dim=0)
             for x, _ in val_loader],
            dim=0)
        full_predictive = full_predictive + epoch_votes.to(full_predictive.dtype)

        full_y_hat = torch.argmax(full_predictive, dim=1)
        total = Y_val.shape[0]
        correct = int((full_y_hat == Y_val).sum())

    end = time.time()

    print("Epoch [{}/{}] test accuracy: {:.4f} time: {:.2f}".format(epoch, NUM_EPOCHS, correct/total, end - start))