In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pyro
from pyro.optim import ClippedAdam
from pyro.infer import SVI, Trace_ELBO, ELBO,  TraceMeanField_ELBO
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.autoguide import AutoDiagonalNormal, AutoGuide

from pyro.contrib.bnn import HiddenLayer

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torch.distributions import constraints

from tqdm import trange, tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons, load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

np.random.seed(42)
pyro.set_rng_seed(42)
torch.set_default_dtype(torch.float64)

In [2]:
def fit(model,
        loader,
        epochs = 2,
        optimizer=None,
        criterion=nn.BCEWithLogitsLoss(),
        lr=0.001):
    
    epochs = trange(epochs)
    
    if optimizer == None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in epochs:
        running_loss = []
        for inputs, labels in loader:

            inputs = inputs.double()          
            optimizer.zero_grad()

            
            outputs = model(inputs)
            labels = labels.type_as(outputs)
            
            labels = labels.unsqueeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss.append(loss.item())

        loss = sum(running_loss) / len(running_loss)
        string = f"Loss = {loss:.5f}"
        epochs.set_postfix_str(s=string)

def score(model, dataloader, type_ = 'test'):
    correct=0
    total = 0
    for test_inputs, test_labels in dataloader:

        outputs = model(test_inputs)
        test_inputs = test_inputs.type_as(outputs)
        predicted = torch.round(outputs).detach().numpy()
        acc = accuracy_score(np.array(predicted), np.array(test_labels))
    print(f"Accuracy on {type_}: {100 * acc}%")
    

In [3]:
X_moon, y_moon = make_moons(n_samples=150, noise=.05)

class MoonDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
    
        return self.X[idx,:], self.y[idx]

moon_split = 100
train_moon = DataLoader(MoonDataset(X_moon[:moon_split], y_moon[:moon_split]),batch_size=moon_split, shuffle=True)
test_moon = DataLoader(MoonDataset(X_moon[moon_split:], y_moon[moon_split:]),batch_size=moon_split, shuffle=True)

In [6]:
class MoonModel(nn.Module):
    def __init__(self, num_in=2):
        super(MoonModel, self).__init__()
        self.fc1 = nn.Linear(num_in, 2)
        self.fc2 = nn.Linear(2, 2)
        
        self.out = nn.Linear(2, 1)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        X = torch.tanh(self.fc2(x))
        
        x = torch.sigmoid(self.out(x))
        return x

moon_model = MoonModel(2)

In [7]:
fit(moon_model, train_moon, 10, criterion=nn.BCELoss(), lr=0.001)
score(moon_model, test_moon)

100%|██████████| 10/10 [00:00<00:00, 155.69it/s, Loss = 0.62667]

Accuracy on train: 65.0%
Accuracy on test: 82.0%





In [11]:
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.n_indims = 2
        self.n_hidden = 10
        self.n_classes = 1
       
    def model(self, data, labels=None):
        
        n_data = len(data)
        if labels is not None:
            labels = labels.float()
            
        l1_mean = torch.zeros(self.n_indims, self.n_hidden)
        l1_scale = torch.ones(self.n_indims, self.n_hidden) 
        
        l2_mean = torch.zeros(self.n_hidden + 1, self.n_classes)
        l2_scale = torch.ones(self.n_hidden + 1, self.n_classes)
        
        
        with pyro.plate('data', size=n_data):
            h1 = pyro.sample('h1', HiddenLayer(data, l1_mean,l1_scale, 
                                                   non_linearity=F.relu,))
           
            logits = pyro.sample('logits', HiddenLayer(h1, l2_mean, l2_scale,
                                                           non_linearity=F.sigmoid,
                                                           include_hidden_bias=False))
            
            
            return pyro.sample('label', dist.Bernoulli(logits=logits), obs=labels) 
    
    def guide(self, data, labels=None):
        n_data = len(data)
        if labels is not None:
            labels = labels.float()

        l1_mean = pyro.param('l1_mean', 0.1 * torch.randn(self.n_indims, self.n_hidden))
        l1_scale = pyro.param('l1_scale', 0.1 * torch.ones(self.n_indims, self.n_hidden),
                              constraint=constraints.greater_than(0.01))
        
        l2_mean = pyro.param('l2_mean', 0.1 * torch.randn(self.n_hidden + 1, self.n_classes))
        l2_scale = pyro.param('l2_scale', 0.1 * torch.ones(self.n_hidden + 1, self.n_classes),
                              constraint=constraints.greater_than(0.01))
        
        with pyro.plate('data', size=n_data):
            h1 = pyro.sample('h1', HiddenLayer(data,
                                               l1_mean,
                                               l1_scale, 
                                               non_linearity=torch.tanh))
            
            logits = pyro.sample('logits', HiddenLayer(h1,
                                                       l2_mean,
                                                       l2_scale,
                                                       non_linearity=torch.sigmoid,
                                                       include_hidden_bias=False))
    
    
    def infer(self, loader, lr=0.01, momentum=0.9,
                         num_epochs=30):
        optim = ClippedAdam({'lr': lr})
        elbo =  TraceMeanField_ELBO()
        svi = SVI(self.model, self.guide, optim, elbo)
        epochs = trange(num_epochs)
        for i in epochs:
            for data, labels in loader:
                
                loss = svi.step(data.double(), labels)  / len(labels)                
            string = f"Loss = {loss:.5f}"
            epochs.set_postfix_str(s=string)

    def forward(self, images, n_samples=10):
        res = []
        for i in range(n_samples):
            t = poutine.trace(self.guide).get_trace(images)
            res.append(t.nodes['logits']['value'])
        return torch.stack(res, dim=0) 

    def score(self, dataloader, type_ = 'test'):
        for data, labels in dataloader:            
            predicted = bayesnn.forward(data, 1)
            predicted = predicted.detach().numpy().mean(axis=0).round().squeeze()
            acc = accuracy_score(np.array(predicted), np.array(labels))

        print(f"Accuracy on {type_}: {acc * 100:.2f}%")

In [12]:
pyro.clear_param_store()
bayesnn = BNN()
bayesnn.infer(train_moon, num_epochs=100, lr=0.001)

100%|██████████| 100/100 [00:00<00:00, 232.02it/s, Loss = 72.02252]


In [13]:
bayesnn.score(train_moon, 'train')
bayesnn.score(test_moon)

Accuracy on train: 47.00%
Accuracy on test: 28.00%
