In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pyro
from pyro.optim import Adam, ClippedAdam
from pyro.infer import SVI, Trace_ELBO, ELBO, TraceMeanField_ELBO
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.autoguide import AutoDiagonalNormal, AutoGuide, AutoNormal
from pyro.contrib.bnn import HiddenLayer
from pyro.nn import PyroSample, PyroModule
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import Predictive

import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributions import constraints

from tqdm import trange, tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons, load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
np.random.seed(42)
pyro.set_rng_seed(42)
# torch.set_default_dtype(torch.float64)

## Helper functions

In [2]:
def fit(model,
        loader,
        epochs = 2,
        multi=False,
        optimizer=None,
        criterion=nn.BCEWithLogitsLoss(),
        lr=0.001):
    """ Train the Neural network """
    epochs = trange(epochs)
    if optimizer == None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in epochs:
        
        running_loss = []
        
        for inputs, labels in loader:
            
            inputs = inputs.float()            
            optimizer.zero_grad()            
            outputs = model(inputs)
            
            
            if multi is False:
                labels = labels.type_as(outputs)
                labels = labels.unsqueeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss.append(loss.item())
            
        loss = sum(running_loss) / len(running_loss)
        string = f"Loss = {loss:.5f}"
        epochs.set_postfix_str(s=string)


def score(model, dataloader, type_ = 'test', multi=False):
    """ Score the Neural Network"""
    correct=0
    total = 0
    for test_inputs, test_labels in dataloader:

        test_inputs = test_inputs.float()
        outputs = model(test_inputs)
        if multi:
            _, predicted = torch.max(outputs, 1)
            acc = accuracy_score(np.array(predicted), np.array(test_labels))
        else:
            test_inputs = test_inputs.type_as(outputs)
            predicted = torch.round(outputs).detach().numpy()
            acc = accuracy_score(np.array(predicted), np.array(test_labels))
    print(f"Accuracy on {type_}: {100 * acc:.3f}%") 
    
 ## Bayesian NN functions   
def score_bnn(model, guide, dataloader, num_samples=800, type_='test'):
    """ Score the Bayesian network using the Predictive class"""
    predictive = Predictive(model, guide=guide, num_samples=num_samples,
                    return_sites=("linear.weight", "obs", "_RETURN", "data"))
    for data, labels in dataloader:
        pred = predictive(data)
        pred_vals = pred["_RETURN"].detach().numpy().mean(axis=0).round()
        lables = labels.unsqueeze(1)
        acc = accuracy_score(lables, pred_vals)
    
    print(f"Accuracy on {type_}: {100 * acc}%")
    

def predict(data,labels,  guide):
    num_samples = 10
    sampled_models = [guide(None, None) for _ in range(num_samples)]
    predictions = [m(data).data for m in sampled_models]
    mean = torch.mean(torch.stack(predictions), 0)
    return np.argmax(mean.numpy(), axis=1)


def score_bnn2(guide, dataloader, type_ = 'test'):
    """ Score the Bayesian Network using the guide to predict"""
    for  data, labels in dataloader:
        predicted = predict(data.float(),labels.float(), guide)
        acc = accuracy_score(np.array(predicted),np.array(labels))
#         correct += (np.array(predicted) == np.array(labels)).sum().item()

    print(f"Accuracy on {type_}: {100 * acc:.2f}%")

# Classic Model


In [92]:
X_sanity = np.random.randn(200,2)

class SanityDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, X):

        self.X = X

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        y = np.tanh(self.X[idx, 0] + self.X[idx, 1])
        y =  1. / ( 1. + np.exp(-( y + y)))
        y = (y > 0.5).astype(np.int)
        return (self.X[idx,:]).astype(np.float32), y
    
split = 100
train_sanity = torch.utils.data.DataLoader(SanityDataset(X_sanity[:split]), batch_size=split, shuffle=True)
test_sanity = torch.utils.data.DataLoader(SanityDataset(X_sanity[split:]), batch_size=split,shuffle=True)

In [93]:
class SanityModel(nn.Module):
    def __init__(self, num_in=2):
        super(SanityModel, self).__init__()
        self.fc1 = nn.Linear(num_in, 2)
        self.fc2 = nn.Linear(2, 1)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return x

In [94]:
sanity = SanityModel()
fit(sanity, train_sanity, 100, criterion=nn.BCELoss(), lr=0.01)
score(sanity, test_sanity)

100%|██████████| 100/100 [00:00<00:00, 246.56it/s, Loss = 0.24957]

Accuracy on test: 99.000%





In [95]:
class BayesianSanityCheck(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        param = torch.sigmoid(self.linear(x))
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Bernoulli(param), obs=y)
        return param
    


In [96]:
# Trace_ELBO, ELBO, TraceMeanField_ELBO
bnn = BayesianSanityCheck(2,1)
guide = AutoNormal(bnn)
adam = pyro.optim.Adam({"lr": 0.01})
svi = SVI(bnn, guide, adam, loss=TraceMeanField_ELBO())

In [97]:
epochs = trange(20)
bnn.train()
for i in epochs:
    for data, labels in train_sanity:
        bnn.zero_grad()
        labels = labels.float()
        loss = svi.step(data, labels)
        loss = loss / len(labels)


    string = f"Loss = {loss:.5f}"
    epochs.set_postfix_str(s=string)


100%|██████████| 20/20 [00:00<00:00, 119.66it/s, Loss = 68.53079]


In [98]:
score_bnn(bnn, guide, test_sanity)

Accuracy on test: 48.0%


## Make Moons dataset

- uses self created guide, instead of AutoGuide
- uses poutine for prediction, instead of the PredictClass

In [99]:
X_moon, y_moon = make_moons(n_samples=150, noise=.05)

class MoonDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
    
        return self.X[idx,:], self.y[idx]

moon_split = 100
train_moon = DataLoader(MoonDataset(X_moon[:moon_split], y_moon[:moon_split]),batch_size=moon_split, shuffle=True)
test_moon = DataLoader(MoonDataset(X_moon[moon_split:], y_moon[moon_split:]),batch_size=moon_split, shuffle=True)

In [100]:
class MoonModel(nn.Module):
    def __init__(self, num_in=2):
        super(MoonModel, self).__init__()
        self.fc1 = nn.Linear(num_in, 4)
        self.fc2 = nn.Linear(4, 10)
        
        self.out = nn.Linear(10, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        
        x = self.fc2(x)
        X = torch.relu(x)
        
        x = self.out(x)
        x = torch.sigmoid(x)
        return x

moon_model = MoonModel(2)

In [101]:
fit(moon_model, train_moon, 100, criterion=nn.BCELoss(), lr=0.001)
score(moon_model, test_moon)

100%|██████████| 100/100 [00:00<00:00, 352.73it/s, Loss = 0.64794]


Accuracy on test: 84.000%


In [102]:
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.n_indims = 2
        self.n_hidden = 10
        self.n_classes = 1
       
    def model(self, data, labels=None):
        
        n_data = len(data)
        if labels is not None:
            labels = labels.float()
            
        # Create the tensors for the normal distributions
        # The Hidden Layer uses Normal distribution
        
        l1_mean = torch.zeros(self.n_indims, self.n_hidden)
        l1_scale = torch.ones(self.n_indims, self.n_hidden) 
        
        l2_mean = torch.zeros(self.n_hidden + 1, self.n_classes)
        l2_scale = torch.ones(self.n_hidden + 1, self.n_classes)
        
        
        with pyro.plate('data', size=n_data):
            # Connect the layers
            h1 = pyro.sample('h1', HiddenLayer(data, l1_mean,l1_scale, 
                                                   non_linearity=F.relu,))
           
            logits = pyro.sample('logits', HiddenLayer(h1, l2_mean, l2_scale,
                                                           non_linearity=F.sigmoid,
                                                           include_hidden_bias=False))
            
            
            return pyro.sample('label', dist.Bernoulli(logits=logits), obs=labels) 
    
    def guide(self, data, labels=None):
        n_data = len(data)
        if labels is not None:
            labels = labels.float()

        l1_mean = pyro.param('l1_mean', 0.1 * torch.randn(self.n_indims, self.n_hidden))
        l1_scale = pyro.param('l1_scale', 0.1 * torch.ones(self.n_indims, self.n_hidden),
                              constraint=constraints.greater_than(0.01))
        
        l2_mean = pyro.param('l2_mean', 0.1 * torch.randn(self.n_hidden + 1, self.n_classes))
        l2_scale = pyro.param('l2_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_classes),
                              constraint=constraints.greater_than(0.01))
        
        with pyro.plate('data', size=n_data):
            h1 = pyro.sample('h1', HiddenLayer(data,
                                               l1_mean,
                                               l1_scale, 
                                               non_linearity=torch.tanh))
            
            logits = pyro.sample('logits', HiddenLayer(h1,
                                                       l2_mean,
                                                       l2_scale,
                                                       non_linearity=torch.sigmoid,
                                                       include_hidden_bias=False))
    
    
    def infer(self, loader, lr=0.01, momentum=0.9,
                         num_epochs=30):
        optim = ClippedAdam({'lr': lr})
        elbo =  TraceMeanField_ELBO()
        svi = SVI(self.model, self.guide, optim, elbo)
        epochs = trange(num_epochs)
        for i in epochs:
            for data, labels in loader:
                loss = svi.step(data.float(), labels)  / len(labels)                
            string = f"Loss = {loss:.5f}"
            epochs.set_postfix_str(s=string)

    def forward(self, images, n_samples=10):
        res = []
        for i in range(n_samples):
            t = poutine.trace(self.guide).get_trace(images)
            res.append(t.nodes['logits']['value'])
        return torch.stack(res, dim=0) 

    def score(self, dataloader, type_ = 'test'):
        """ Score the BNN using the poutine library"""
        for data, labels in dataloader:            
            predicted = bayesnn.forward(data.float(), 10)
            predicted = predicted.detach().numpy().mean(axis=0).round().squeeze()
            acc = accuracy_score(np.array(predicted), np.array(labels))

        print(f"Accuracy on {type_}: {acc * 100:.2f}%")

In [103]:
pyro.clear_param_store()
bayesnn = BNN()
bayesnn.infer(train_moon, num_epochs=100, lr=0.001)
bayesnn.score(test_moon)

100%|██████████| 100/100 [00:00<00:00, 166.05it/s, Loss = 70.16169]


Accuracy on test: 48.00%


## Multi class classification using Transfer learning

- The guide fuction samples from the weights of the pretrained model

In [104]:
X_wine, y_wine = load_wine(return_X_y=True)
size=100
X_train, X_test, y_train, y_test = train_test_split(X_wine, y_wine, train_size=size, random_state=42)

class WineDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return self.X[index,:], self.y[index]
    
train_wine = DataLoader(WineDataset(X_train, y_train),batch_size=100, shuffle=True)
test_wine = DataLoader(WineDataset(X_test,y_test),batch_size=len(X_wine) - 100, shuffle=True)

In [105]:
class WineModel(nn.Module):
    def __init__(self, num_in, n_classes):

        super(WineModel, self).__init__()

        self.fc1 = nn.Linear(num_in, 16)
        self.fc2 = nn.Linear(16, 4)
        self.fc3 = nn.Linear(4, n_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.tanh(x)
        x = self.fc2(x)
        x = torch.tanh(x)
        x = self.fc3(x)
        return x

In [106]:
wine_model = WineModel(13, 3)
fit(wine_model, train_wine, 200, criterion=torch.nn.CrossEntropyLoss(), multi=True, lr=0.01)
score(wine_model, test_wine, multi=True)

100%|██████████| 200/200 [00:00<00:00, 328.05it/s, Loss = 0.70001]


Accuracy on test: 67.949%


In [107]:
def model_wine(data, labels):
    fc1w_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc1.weight),
                             scale=torch.ones_like(wine_model.fc1.weight))
    fc1b_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc1.bias),
                             scale=torch.ones_like(wine_model.fc1.bias))

    fc2w_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc2.weight),
                             scale=torch.ones_like(wine_model.fc2.weight))
    fc2b_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc2.bias),
                             scale=torch.ones_like(wine_model.fc2.bias))


    fc3w_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc3.weight),
                             scale=torch.ones_like(wine_model.fc3.weight))
    fc3b_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc3.bias),
                             scale=torch.ones_like(wine_model.fc3.bias))

    priors = {"fc1w": fc1w_prior,
              "fc1b": fc1b_prior,
              "fc2w": fc2w_prior,
              "fc2b": fc2b_prior,
              "fc3w": fc3w_prior,
              "fc3b": fc3b_prior}

    lifted_module = pyro.random_module("module", wine_model, priors)
    lifted_reg_model = lifted_module()

    probs = torch.nn.functional.log_softmax(lifted_reg_model(data),dim=1)

    pyro.sample("obs", dist.Categorical(logits=probs), obs=labels)
    
def guide_wine(data, labels):
    
    # FC1 weights
    fc1w_mu = torch.randn_like(wine_model.fc1.weight)
    fc1w_sigma = torch.randn_like(wine_model.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = F.softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_approx_post = dist.Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # FC1 bias
    fc1b_mu = torch.randn_like(wine_model.fc1.bias)
    fc1b_sigma = torch.randn_like(wine_model.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = F.softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_approx_post = dist.Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    
    # FC2 weights
    fc2w_mu = torch.randn_like(wine_model.fc2.weight)
    fc2w_sigma = torch.randn_like(wine_model.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = F.softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_approx_post = dist.Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
    # FC2 bias
    fc2b_mu = torch.randn_like(wine_model.fc2.bias)
    fc2b_sigma = torch.randn_like(wine_model.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = F.softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_approx_post = dist.Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)
        
    # FC3 weights
    fc3w_mu = torch.randn_like(wine_model.fc3.weight)
    fc3w_sigma = torch.randn_like(wine_model.fc3.weight)
    fc3w_mu_param = pyro.param("fc3w_mu", fc3w_mu)
    fc3w_sigma_param = F.softplus(pyro.param("fc3w_sigma", fc3w_sigma))
    fc3w_approx_post = dist.Normal(loc=fc3w_mu_param, scale=fc3w_sigma_param).independent(1)
    # FC2 bias
    fc3b_mu = torch.randn_like(wine_model.fc3.bias)
    fc3b_sigma = torch.randn_like(wine_model.fc3.bias)
    fc3b_mu_param = pyro.param("fc3b_mu", fc3b_mu)
    fc3b_sigma_param = F.softplus(pyro.param("fc3b_sigma", fc3b_sigma))
    fc3b_approx_post = dist.Normal(loc=fc3b_mu_param, scale=fc3b_sigma_param)
    
    posterior = {"fc1w": fc1w_approx_post,
                "fc1b": fc1b_approx_post,
                "fc2w": fc2w_approx_post,
                "fc2b": fc2b_approx_post,
                "fc3w": fc3w_approx_post,
                "fc3b": fc3b_approx_post,}

    lifted_module = pyro.random_module("module", wine_model, posterior)
    
    return lifted_module()

In [108]:
optim = pyro.optim.Adam({"lr": 0.001})
svi_wine = SVI(model_wine, guide_wine, optim, loss=Trace_ELBO())
# train the model
epochs = trange(10)
for i in epochs:
    loss = 0
    for data, labels in train_wine:
        loss = svi_wine.step(data.float(), labels)

    string = f"Loss = {loss:.5f}"
    epochs.set_postfix_str(s=string)
    
score_bnn2(guide_wine, test_wine)

100%|██████████| 10/10 [00:00<00:00, 119.52it/s, Loss = 68.10476]

Accuracy on test: 70.51%





# References

- [Pyro Docs](https://docs.pyro.ai/en/stable/index.html)
- [Pyro Examples](https://pyro.ai/examples/bayesian_regression.html#)
- [Pyro MNIST](https://alsibahi.xyz/snippets/2019/06/15/pyro_mnist_bnn_kl.html)
- [Making Your Neural Network Say “I Don’t Know” — Bayesian NNs using Pyro and PyTorch](https://towardsdatascience.com/making-your-neural-network-say-i-dont-know-bayesian-nns-using-pyro-and-pytorch-b1c24e6ab8cd)
- [Bayesian Neural Networks: 1 Why Bother?](https://towardsdatascience.com/bayesian-neural-networks-1-why-bother-b585375b38ec)
- [Bayesian Neural Networks: 2 Fully Connected in TensorFlow and Pytorch](https://towardsdatascience.com/bayesian-neural-networks-2-fully-connected-in-tensorflow-and-pytorch-7bf65fb4697)
- [Bayesian Neural Networks: 3 Bayesian CNN](https://towardsdatascience.com/bayesian-neural-networks-3-bayesian-cnn-6ecd842eeff3)