In [9]:
!rm -rdf __pycache__ *.pyc

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from bnn import BayesianNeuralNetwork, KLDivergence
from torchvision import datasets, transforms

epochs = 5
batch_size = 256
learning_rate = 1e-3
validation_frequency = 100
samples = 10

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)

n_batches = int(np.ceil(len(train_loader.dataset) / batch_size))


In [4]:
model = BayesianNeuralNetwork(784, 10).to(device)
loss_function = torch.nn.CrossEntropyLoss()
kld = KLDivergence(n_batches)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [5]:
def plot_lines(df, columns, colors, ax, alpha=0.25, show_range=False, window_size=1):
    for color, column in zip(colors, columns):
        agg_df = df.groupby('epoch')[column]

        if window_size > 1:
            agg_df = agg_df.mean().rolling(window_size, min_periods=1)

        means = agg_df.mean()
        ax.plot(np.arange(len(means)), means, c=color)

        if show_range:
            mins = agg_df.min()
            maxs = agg_df.max()
            ax.fill_between(x=np.arange(len(means)),
                            y1=mins, y2=maxs, alpha=alpha)

    ax.legend(columns)

In [6]:
def entropy(x):
    return -(x * torch.log(x + 1e-10)).sum(dim=-1)

In [8]:
history = pd.DataFrame()

epochs_logger = tqdm(range(1, epochs + 1), desc='epoch', position=0, leave=True) 
for epoch in epochs_logger:
    steps_logger = tqdm(train_loader, desc='step', total=n_batches, position=0, leave=True) 
    for step, (x, y) in enumerate(steps_logger):
        model.train()

        x, y = x.to(device), y.to(device)
        x = x.view(x.shape[0], -1)
        optimizer.zero_grad()

        preds = []
        likelihood = 0
        divergence = 0
        for _ in range(samples):
            pred = model(x)
            preds.append(pred)
            likelihood += loss_function(pred, y)
            divergence += kld(model)

        loss = (likelihood + divergence) / samples
        loss.backward()
        optimizer.step()
        ent = entropy(torch.stack(preds, dim=-1))

        if step % validation_frequency == 0:
            # todo idk why but the behaviour is weird
            model.eval()
            test_loss = 0
            with torch.no_grad():
                for test_x, test_y in test_loader:
                    test_x, test_y = test_x.to(device), test_y.to(device)
                    test_x = test_x.view(test_x.shape[0], -1)

                    for _ in range(samples):
                        pred_test = model(test_x)
                        test_loss += loss_function(pred_test, test_y)

            test_loss /= (len(test_loader.dataset) / batch_size) * samples
            
        history = history.append({
            'epoch': epoch,
            'step': step,
            'loss': loss.item(),
            'entropy': ent.mean().item(),
            'likelihood': likelihood.item(),
            'divergence': divergence.item(),
            'test_loss': test_loss.item()
        }, ignore_index=True)

        history_for_this_epoch = history.query(f'epoch == {epoch}')
        mean_values = history_for_this_epoch.mean(axis=0)
        mean_loss = mean_values['loss']
        mean_entropy = mean_values['entropy']
        mean_likelihood = mean_values['likelihood']
        mean_divergence = mean_values['divergence']
        mean_test_loss = mean_values['test_loss']

        log_str = f'loss: {mean_loss:.5f}, entropy: {mean_entropy:.5f}, likelihood: {mean_likelihood:.5f},' + \
                  f'divergence: {mean_divergence:.5f}, test_loss: {mean_test_loss:.5f}'
        steps_logger.set_postfix_str(log_str)
    epochs_logger.set_postfix_str(log_str)

fig, axs = plt.subplots(1, 2, figsize=(5, 10))
plot_lines(history, ['loss', 'test_loss'], ['blue', 'orange'], axs[0], show_range=True)
plot_lines(history, ['likelihood', 'divergence'], ['green', 'red'], axs[1], show_range=True)

step: 100%|██████████| 235/235 [00:47<00:00,  5.65it/s, loss: 2.32853, entropy: 0.07266, likelihood: 46.43641,divergence: 0.13420, test_loss: 2.37901]
step: 100%|██████████| 235/235 [00:48<00:00,  5.67it/s, loss: 2.32492, entropy: 0.07238, likelihood: 46.36428,divergence: 0.13405, test_loss: 2.37444]
step:  37%|███▋      | 86/235 [00:17<00:27,  5.47it/s, loss: 2.32262, entropy: 0.07213, likelihood: 46.31852,divergence: 0.13393, test_loss: 2.37260]

KeyboardInterrupt: 

In [None]:
model.to('cpu')
with torch.no_grad():
    for images, _ in test_loader:
        for im in images[:10]:
            plot_preds(im, model)
            plt.show()
        plot_preds(torch.randn(1, 28, 28), model)
        plt.show()
        break


def plot_preds(im, model):
    preds = [model(im.view(784)).detach().numpy()
             for _ in range(50)]
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(im, [1, 2, 0])[:, :, 0], )
    plt.subplot(1, 2, 2)
    plt.hist(np.argmax(preds, axis=1), bins=10, density=True)
    plt.ylim(0, 1)
    plt.xlim(0, 9)