In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import time
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from local_logger import Logger
from blundell import Net, Loss

from matplotlib import pyplot as plt

In [3]:
# Pick the compute device once: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [4]:
# Load a dataset
def get_mnist(batch_size, data_dir='./data', num_workers=4):
    """Build the MNIST train and test data loaders.

    Args:
        batch_size: number of samples per batch, used for both loaders.
        data_dir: directory the dataset is stored in (downloaded there if
            missing). Defaults to './data'.
        num_workers: number of DataLoader worker processes per loader.
            Defaults to 4.

    Returns:
        A ``(train_loader, test_loader)`` tuple of
        ``torch.utils.data.DataLoader`` objects.
    """
    # ToTensor followed by Normalize(mean=0.1307, std=0.3081); these are the
    # commonly used MNIST statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def make_loader(train):
        """Create the loader for the train (True) or test (False) split."""
        dataset = datasets.MNIST(data_dir, train=train, download=True,
                                 transform=transform)
        # Only the training split is shuffled; evaluation keeps a fixed,
        # reproducible batch order.
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=train,
            num_workers=num_workers)

    return make_loader(train=True), make_loader(train=False)

In [5]:
# Network from the local `blundell` module, placed on the selected device.
model = Net().to(DEVICE)

# Adam optimizer whose learning rate is multiplied by `gamma` at each milestone.
# NOTE(review): the training cell sets epochs = 40, so none of these
# milestones is reached with the current settings — confirm this is intended.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50, 60, 70, 80],
    gamma=0.2,
)

# Per-column format specs handed to the Logger. The 'tr_los' spelling matches
# the key used when the training loss is logged.
# NOTE(review): 'sp_0', 'sp_1' and 'kl' are not logged by the visible training
# loop — presumably leftovers; verify before removing.
fmt = {
    'tr_los': '3.1e',
    'te_loss': '3.1e',
    'sp_0': '.3f',
    'sp_1': '.3f',
    'lr': '3.1e',
    'kl': '.2f',
}
logger = Logger('sparse_vd', fmt=fmt)

# Data loaders, and the loss object built from the model and the number of
# training samples.
train_loader, test_loader = get_mnist(batch_size=100)
loss_function = Loss(
    model,
    len(train_loader.dataset),
).to(DEVICE)

In [6]:
MNIST_INPUT_SHAPE = 28 * 28  # length of one flattened input sample
epochs = 40

for epoch in range(1, epochs + 1):
    start = time.time()

    # ---- training pass -------------------------------------------------
    model.train()
    train_loss, train_acc = 0.0, 0
    # Log the learning rate in effect for this epoch (before scheduler.step()).
    logger.add_scalar(epoch, 'lr', scheduler.get_last_lr()[0])
    for data, target in train_loader:
        # Move the batch to the device and flatten each sample to a vector.
        data = data.to(DEVICE).view(-1, MNIST_INPUT_SHAPE)
        target = target.to(DEVICE)
        optimizer.zero_grad()

        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate a plain Python float: adding the loss tensor itself would
        # chain every batch's autograd history into `train_loss` and keep it
        # alive for the whole epoch.
        train_loss += float(loss)
        # Predicted class = index of the largest output along dim 1.
        pred = output.argmax(dim=1)
        train_acc += (pred == target).sum().item()

    scheduler.step()

    logger.add_scalar(epoch, 'tr_los', train_loss / len(train_loader.dataset))
    logger.add_scalar(epoch, 'tr_acc', train_acc / len(train_loader.dataset) * 100)

    # ---- evaluation pass -----------------------------------------------
    model.eval()
    test_loss, test_acc = 0.0, 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(DEVICE).view(-1, MNIST_INPUT_SHAPE)
            target = target.to(DEVICE)
            output = model(data)
            test_loss += float(loss_function(output, target))
            pred = output.argmax(dim=1)
            test_acc += (pred == target).sum().item()

    logger.add_scalar(epoch, 'te_loss', test_loss / len(test_loader.dataset))
    logger.add_scalar(epoch, 'te_acc', test_acc / len(test_loader.dataset) * 100)

    # Log the negated kl_reg() value of every child module that exposes it.
    # Computed without autograd so the logged values carry no graph.
    with torch.no_grad():
        for i, c in enumerate(model.children()):
            if hasattr(c, 'kl_reg'):
                logger.add_scalar(epoch, 'kl_reg_%s' % i, -c.kl_reg())

    end = time.time()
    logger.add_scalar(epoch, 'time', end - start)

    # Emit the row of metrics collected for this epoch.
    logger.iter_info()

  epoch       lr    tr_los    tr_acc    te_loss    te_acc    kl_reg_0    kl_reg_1    kl_reg_2    time
-------  -------  --------  --------  ---------  --------  ----------  ----------  ----------  ------
      1  1.0e-03  -4.0e+02      90.9   -5.3e+02      95.9    537418.8     67082.5      2092.9    12.9
      2  1.0e-03  -5.4e+02      96.5   -5.5e+02      96.9    539422.7     68432.7      2154.2    12.5
      3  1.0e-03  -5.6e+02      97.5   -5.6e+02      97.4    540277.4     68744.6      2218.6    12.5
      4  1.0e-03  -5.7e+02      98.0   -5.7e+02      97.6    540599.0     68796.2      2228.5    12.5
      5  1.0e-03  -5.8e+02      98.3   -5.7e+02      97.5    541003.4     68899.7      2246.5    12.6
      6  1.0e-03  -5.9e+02      98.6   -5.7e+02      97.7    541119.9     68978.2      2263.9    12.1
      7  1.0e-03  -5.9e+02      98.7   -5.6e+02      97.7    541284.3     69139.3      2280.7    12.1
      8  1.0e-03  -5.9e+02      98.9   -5.7e+02      97.8    541377.7     69169.9 

KeyboardInterrupt: 