In [1]:
!rm -rdf __pycache__ *.pyc

import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
# from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as plt
from bnn import BayesianNeuralNetwork, KLDivergence
from torchvision import datasets, transforms

epochs = 3
batch_size = 256
learning_rate = 1e-3
validation_frequency = 30
samples = 10

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)

n_batches = len(train_loader.dataset) / batch_size


In [3]:
model = BayesianNeuralNetwork(784, 10).to(device)
loss_function = torch.nn.CrossEntropyLoss()
kld = KLDivergence(n_batches)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [5]:
history = pd.DataFrame()

epochs_logger = tqdm(range(1, epochs + 1), desc='epoch')
for epoch in epochs_logger:
    # steps_logger = tqdm(train_loader, desc='step', total=n_batches)
    for step, (x, y) in enumerate(train_loader):
        model.train()

        x, y = x.to(device), y.to(device)
        x = x.view(x.shape[0], -1)
        optimizer.zero_grad()

        # std = []
        loss = 0
        for _ in range(samples):
            pred = model(x)
            loss += loss_function(pred, y) + kld(model)
            # std += [pred.detach().cpu().numpy().argmax(axis=1)]

        # std = np.array(std).std(axis=0).mean()

        loss.backward()
        optimizer.step()

        loss = loss.item()

        if step % validation_frequency == 0:
            # stds = []
            model.eval()
            correct = 0
            test_loss = 0
            with torch.no_grad():
                for test_x, test_y in test_loader:
                    test_x, test_y = test_x.to(device), test_y.to(device)
                    test_x = test_x.view(test_x.shape[0], -1)

                    # std = []
                    for _ in range(samples):
                        pred_test = model(test_x)
                        test_loss += loss_function(pred_test, test_y)
                        # std += [pred_test.detach().cpu().numpy().argmax(axis=1)]

                    # stds.append(np.std(std, axis=0).mean())

                    correct += pred_test.argmax(dim=1,
                                                 keepdim=True).eq(test_y).sum().item()

            # stds = np.std(stds)

        history = history.append({
            'epoch': epoch,
            'step': step,
            'loss': loss,
            'test_loss': test_loss
        }, ignore_index=True)

        history_for_this_epoch = history.query(f'epoch == {epoch}')
        mean_values = history_for_this_epoch.mean(axis=0)
        mean_loss = mean_values['loss']
        mean_test_loss = mean_values['test_loss']

        log_str = f'loss: {mean_loss:.5f}, test_loss: {mean_test_loss:.5f}'
        epochs_logger.set_postfix_str(log_str)
        # steps_logger.set_postfix_str(log_str)


43694][A[A

epoch:   0%|          | 0/3 [00:19<?, ?it/s, loss: 17.63067, test_loss: 702.34354][A[A

epoch:   0%|          | 0/3 [00:19<?, ?it/s, loss: 17.62683, test_loss: 702.25178][A[A

epoch:   0%|          | 0/3 [00:19<?, ?it/s, loss: 17.62452, test_loss: 702.16162][A[A

epoch:   0%|          | 0/3 [00:19<?, ?it/s, loss: 17.61952, test_loss: 702.07301][A[A

epoch:   0%|          | 0/3 [00:20<?, ?it/s, loss: 17.61615, test_loss: 701.98591][A[A

epoch:   0%|          | 0/3 [00:20<?, ?it/s, loss: 17.61548, test_loss: 701.90029][A[A

epoch:   0%|          | 0/3 [00:20<?, ?it/s, loss: 17.61456, test_loss: 701.81611][A[A

epoch:   0%|          | 0/3 [00:20<?, ?it/s, loss: 17.61431, test_loss: 701.73333][A[A

epoch:   0%|          | 0/3 [00:22<?, ?it/s, loss: 17.61048, test_loss: 701.61099][A[A

epoch:   0%|          | 0/3 [00:22<?, ?it/s, loss: 17.61130, test_loss: 701.49065][A[A

epoch:   0%|          | 0/3 [00:22<?, ?it/s, loss: 17.60596, test_loss: 701.37227][A

KeyboardInterrupt: 

In [None]:
model.to('cpu')
with torch.no_grad():
    for images, _ in test_loader:
        for im in images[:10]:
            plot_preds(im, model)
            plt.show()
        plot_preds(torch.randn(1, 28, 28), model)
        plt.show()
        break


def plot_preds(im, model):
    preds = [model(im.view(784)).detach().numpy()
             for _ in range(50)]
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(im, [1, 2, 0])[:, :, 0], )
    plt.subplot(1, 2, 2)
    plt.hist(np.argmax(preds, axis=1), bins=10, density=True)
    plt.ylim(0, 1)
    plt.xlim(0, 9)