In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, ELBO
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.autoguide import AutoDiagonalNormal, AutoGuide

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from tqdm import trange, tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_moons, load_wine
from sklearn.model_selection import train_test_split

np.random.seed(42)
pyro.set_rng_seed(42)
torch.set_default_dtype(torch.float64)

In [2]:
def fit(model,
        loader,
        epochs = 2,
        optimizer=None,
        criterion=nn.BCEWithLogitsLoss()):
    
    learning_rate = 0.001
    epochs = trange(epochs)

    if optimizer == None:
        optimizer=torch.optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in epochs:
        
        running_loss = []
        for i, batch in enumerate(loader):
            inputs, labels = batch['X'], batch['y']
            
            
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss.append(loss.item())

            freq_checks = 5
        loss = sum(running_loss) / len(running_loss)
        string = f"Loss = {loss:.5f}"
        epochs.set_postfix_str(s=string)

    print(f"Final loss = {loss:.5f}")
    
def score(model, dataloader, type_ = 'test'):
    for data in dataloader:
        test_inputs, test_labels = data['X'], data['y']

#         print(data)
        outputs = model(test_inputs)
#         print()
#         print(outputs)
        _, predicted = torch.max(outputs, 1)
#         print(predicted)
#         print(test_labels)
    accuracy = sum(np.array(predicted) == np.array(test_labels))/len(predicted)
    print(f"Accuracy on {type_}: {accuracy*100}%")
    
def predict(data, guide):
    num_samples = 10
    sampled_models = [guide(None) for _ in range(num_samples)]
    predictions = [m(data).data for m in sampled_models]
    mean = torch.mean(torch.stack(predictions), 0)
    return np.argmax(mean.numpy(), axis=1)

def score_bnn(guide, dataloader, type_ = 'test'):
    correct = 0
    total = 0
    for  data in dataloader:
        test_inputs, test_labels = data['X'], data['y']

        predicted = predict(test_inputs, guide)
        total += test_labels.size(0)
        correct += (np.array(predicted) == np.array(test_labels)).sum().item()

    print(f"Accuracy on {type_}: {100 * correct / total:.2f}%")

In [8]:
X_wine, y_wine = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X_wine, y_wine, train_size=100, random_state=42)

class WineDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return {"X": self.X[index,:], "y": self.y[index]}
    
train_wine = DataLoader(WineDataset(X_train, y_train), shuffle=True)
test_wine = DataLoader(WineDataset(X_test,y_test), shuffle=True)

In [9]:
class WineModel(nn.Module):
    def __init__(self, num_in, n_classes):

        super(WineModel, self).__init__()

        self.fc1 = nn.Linear(num_in, 10)
        self.fc2 = nn.Linear(10, 5)
        self.fc3 = nn.Linear(5, n_classes)

    def forward(self, x):
        x = torch.tanh(self.fc1(x))
        x = torch.tanh(self.fc2(x))
        x = self.fc3(x)
        return x

In [10]:
wine_model = WineModel(13, 3)
fit(wine_model, train_wine, 10, criterion=torch.nn.CrossEntropyLoss())
score(wine_model, train_wine, 'train')
score(wine_model, test_wine)

100%|██████████| 10/10 [00:01<00:00,  9.98it/s, Loss = 0.84979]


Final loss = 0.84979
Accuracy on train: 100.0%
Accuracy on test: 100.0%


In [6]:
def model_wine(data):
    fc1w_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc1.weight),
                             scale=torch.ones_like(wine_model.fc1.weight))
    fc1b_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc1.bias),
                             scale=torch.ones_like(wine_model.fc1.bias))

    fc2w_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc2.weight),
                             scale=torch.ones_like(wine_model.fc2.weight))
    fc2b_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc2.bias),
                             scale=torch.ones_like(wine_model.fc2.bias))


    fc3w_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc3.weight),
                             scale=torch.ones_like(wine_model.fc3.weight))
    fc3b_prior = dist.Normal(loc=torch.zeros_like(wine_model.fc3.bias),
                             scale=torch.ones_like(wine_model.fc3.bias))

    priors = {"fc1w": fc1w_prior,
              "fc1b": fc1b_prior,
              "fc2w": fc2w_prior,
              "fc2b": fc2b_prior,
              "fc3w": fc3w_prior,
              "fc3b": fc3b_prior}

    lifted_module = pyro.random_module("module", wine_model, priors)
    lifted_reg_model = lifted_module()

    probs = torch.nn.functional.log_softmax(lifted_reg_model(data["X"]),dim=1)

    pyro.sample("obs", dist.Categorical(logits=probs), obs=data["y"])
    
def guide_wine(data):
    
    # FC1 weights
    fc1w_mu = torch.randn_like(wine_model.fc1.weight)
    fc1w_sigma = torch.randn_like(wine_model.fc1.weight)
    fc1w_mu_param = pyro.param("fc1w_mu", fc1w_mu)
    fc1w_sigma_param = F.softplus(pyro.param("fc1w_sigma", fc1w_sigma))
    fc1w_approx_post = dist.Normal(loc=fc1w_mu_param, scale=fc1w_sigma_param)
    # FC1 bias
    fc1b_mu = torch.randn_like(wine_model.fc1.bias)
    fc1b_sigma = torch.randn_like(wine_model.fc1.bias)
    fc1b_mu_param = pyro.param("fc1b_mu", fc1b_mu)
    fc1b_sigma_param = F.softplus(pyro.param("fc1b_sigma", fc1b_sigma))
    fc1b_approx_post = dist.Normal(loc=fc1b_mu_param, scale=fc1b_sigma_param)
    
    # FC2 weights
    fc2w_mu = torch.randn_like(wine_model.fc2.weight)
    fc2w_sigma = torch.randn_like(wine_model.fc2.weight)
    fc2w_mu_param = pyro.param("fc2w_mu", fc2w_mu)
    fc2w_sigma_param = F.softplus(pyro.param("fc2w_sigma", fc2w_sigma))
    fc2w_approx_post = dist.Normal(loc=fc2w_mu_param, scale=fc2w_sigma_param)
    # FC2 bias
    fc2b_mu = torch.randn_like(wine_model.fc2.bias)
    fc2b_sigma = torch.randn_like(wine_model.fc2.bias)
    fc2b_mu_param = pyro.param("fc2b_mu", fc2b_mu)
    fc2b_sigma_param = F.softplus(pyro.param("fc2b_sigma", fc2b_sigma))
    fc2b_approx_post = dist.Normal(loc=fc2b_mu_param, scale=fc2b_sigma_param)
        
    # FC3 weights
    fc3w_mu = torch.randn_like(wine_model.fc3.weight)
    fc3w_sigma = torch.randn_like(wine_model.fc3.weight)
    fc3w_mu_param = pyro.param("fc3w_mu", fc3w_mu)
    fc3w_sigma_param = F.softplus(pyro.param("fc3w_sigma", fc3w_sigma))
    fc3w_approx_post = dist.Normal(loc=fc3w_mu_param, scale=fc3w_sigma_param).independent(1)
    # FC2 bias
    fc3b_mu = torch.randn_like(wine_model.fc3.bias)
    fc3b_sigma = torch.randn_like(wine_model.fc3.bias)
    fc3b_mu_param = pyro.param("fc3b_mu", fc3b_mu)
    fc3b_sigma_param = F.softplus(pyro.param("fc3b_sigma", fc3b_sigma))
    fc3b_approx_post = dist.Normal(loc=fc3b_mu_param, scale=fc3b_sigma_param)
    
    posterior = {"fc1w": fc1w_approx_post,
                        "fc1b": fc1b_approx_post,
                        "fc2w": fc2w_approx_post,
                        "fc2b": fc2b_approx_post,
                        "fc3w": fc3w_approx_post,
                        "fc3b": fc3b_approx_post,}

    lifted_module = pyro.random_module("module", wine_model, posterior)
    
    return lifted_module()

In [7]:
optim = pyro.optim.Adam({"lr": 0.001})
svi_wine = SVI(model_wine, guide_wine, optim, loss=Trace_ELBO())
loss = 0
epochs = trange(10)
total_epoch_loss_train = 0
for i in epochs:
    loss = 0
    for batch_id, data in enumerate(train_wine):
        # calculate the loss and take a gradient step
        loss += svi_wine.step(data)
    total_epoch_loss_train = loss / len(train_wine.dataset)

    string = f"Loss = {total_epoch_loss_train:.5f}"
    epochs.set_postfix_str(s=string)
    
score_bnn(guide_wine, train_wine, 'train')
score_bnn(guide_wine, test_wine)

100%|██████████| 10/10 [00:05<00:00,  1.81it/s, Loss = 1.08650]


Accuracy on train: 42.00%
Accuracy on test: 37.18%
