In [None]:
import math
from matplotlib import pyplot as plt
import timeit
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from novelty.models.dists import benford_dist, instrumental_dist
from novelty.models.mlp import MLP
from novelty.models.util import train, train_streaming, train_streaming_unbalanced, test
from novelty.visualization.models import plot_accs

# Vanilla Model Training

In [None]:
# Set to True to train on a hardware accelerator when one is available.
use_gpu = False

# Resolve the request to a backend that actually exists on this machine so the
# notebook still runs from a fresh kernel on non-Apple hardware. MPS is tried
# first, then CUDA, and finally the CPU.
if use_gpu and torch.backends.mps.is_available():
    device = torch.device("mps")
elif use_gpu and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Number of epochs for the vanilla training run.
epochs = 10

# Keyword arguments forwarded to the train / test DataLoader constructors.
train_kwargs = dict(batch_size=100)
test_kwargs = dict(batch_size=1000)
if use_gpu:
    # Extra loader options, merged into both loaders only when use_gpu is set.
    cuda_kwargs = dict(num_workers=1, pin_memory=True, shuffle=True)
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)

# Convert images to tensors, then normalise with the image mean and std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # image mean and std
])

In [None]:
# MNIST train / test splits stored under ../data/raw; only the training split
# requests a download.
dataset1 = datasets.MNIST(root='../data/raw', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(root='../data/raw', train=False, transform=transform)

# Batch iterators over each split, configured by the kwargs dicts.
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [None]:
# Fresh MLP moved to the selected device, paired with an Adam optimizer that
# uses its default hyperparameters.
model = MLP()
model = model.to(device)
optimizer = optim.Adam(model.parameters())

In [None]:
# Per-epoch metric histories for the vanilla run.
train_losses, train_accs, train_kept = [], [], []
test_losses, test_accs = [], []
histories = (train_losses, train_accs, train_kept, test_losses, test_accs)

for epoch in tqdm(range(1, epochs + 1)):
    # `train` yields (loss, accuracy, kept); `test` yields (loss, accuracy).
    train_loss, train_acc, kept = train(model, device, train_loader, optimizer)
    test_loss, test_acc = test(model, device, test_loader)

    # Record this epoch's values in the matching history list.
    epoch_metrics = (train_loss, train_acc, kept, test_loss, test_acc)
    for history, value in zip(histories, epoch_metrics):
        history.append(value)

In [None]:
# Visualise the per-epoch training accuracies recorded in train_accs
# (plot_accs is a project helper; its exact rendering is not shown here).
plot_accs(train_accs)

In [None]:
# Visualise the per-epoch test accuracies recorded in test_accs
# (plot_accs is a project helper; its exact rendering is not shown here).
plot_accs(test_accs)

# Streaming Training

In a data-streaming context, we assume that the model can only see each of the 60000 training examples once. Therefore, we'd like to maximize the information gain from each training example. Instead of 10 passes over the full dataset, we'll split it into 6 epochs of 10000 examples each.

In [None]:
# Size of one streaming epoch, expressed in examples and in batches.
examples_per_epoch = 10_000
batches_per_epoch = examples_per_epoch // train_kwargs['batch_size']

# Number of whole epochs of that size that fit in the training set.
epochs = len(train_loader.dataset) // examples_per_epoch

In [None]:
# Start the streaming experiment from a newly constructed MLP and a new Adam
# optimizer (default hyperparameters).
model = MLP()
model = model.to(device)
optimizer = optim.Adam(model.parameters())

In [None]:
# Per-epoch metric histories for the streaming run.
train_losses, train_accs, train_kept = [], [], []
test_losses, test_accs = [], []
histories = (train_losses, train_accs, train_kept, test_losses, test_accs)

for epoch in tqdm(range(1, epochs + 1)):
    # The 1-based epoch index is handed to train_streaming; presumably it
    # selects which slice of the stream to consume — confirm in the helper.
    train_loss, train_acc, kept = train_streaming(model, device, train_loader, optimizer, epoch)
    test_loss, test_acc = test(model, device, test_loader)

    # Record this epoch's values in the matching history list.
    epoch_metrics = (train_loss, train_acc, kept, test_loss, test_acc)
    for history, value in zip(histories, epoch_metrics):
        history.append(value)

In [None]:
# Visualise the per-epoch training accuracies of the streaming run
# (plot_accs is a project helper; its exact rendering is not shown here).
plot_accs(train_accs)

In [None]:
# Visualise the per-epoch test accuracies of the streaming run
# (plot_accs is a project helper; its exact rendering is not shown here).
plot_accs(test_accs)

# Unbalanced Streaming Training

Unlike in the previous streaming context, we'll use rejection sampling to create an unbalanced distribution over the 10 MNIST digits as they stream in. The goal here is to ensure that the model sees more 0s than 1s, more 1s than 2s, etc. For an initial probability distribution, I'll use Benford's law, which describes the frequency of first digits in data: $P(d) = \log_{10}(1 + 1/d)$. Because we have 10 (rather than 9) digits, we'll use $P(d) = \log_{11}(1 + 1/d)$, where $d$ runs from 1 to 10 (so digit $k$ corresponds to $d = k + 1$) and the probabilities sum to 1.

In [None]:
# Target probabilities for the digit labels 0-9, plus the largest of them.
# max_scaler is later handed to train_streaming_unbalanced.
benford_probs = benford_dist(torch.arange(10))
max_scaler = max(benford_probs)

In [None]:
# Start the unbalanced-streaming experiment from a newly constructed MLP and a
# new Adam optimizer (default hyperparameters).
model = MLP()
model = model.to(device)
optimizer = optim.Adam(model.parameters())

In [None]:
# Per-epoch metric histories for the unbalanced streaming run.
train_losses, train_accs, train_kept = [], [], []
test_losses, test_accs = [], []
histories = (train_losses, train_accs, train_kept, test_losses, test_accs)

for epoch in tqdm(range(1, epochs + 1)):
    # The distribution callables (not precomputed probabilities) are passed
    # through, together with the scaling constant and the per-epoch batch count.
    train_loss, train_acc, kept = train_streaming_unbalanced(
        model, device, train_loader, optimizer, epoch,
        benford_dist, instrumental_dist, max_scaler,
        batches_per_epoch,
    )
    test_loss, test_acc = test(model, device, test_loader)

    # Record this epoch's values in the matching history list.
    epoch_metrics = (train_loss, train_acc, kept, test_loss, test_acc)
    for history, value in zip(histories, epoch_metrics):
        history.append(value)

In [None]:
# Visualise the per-epoch training accuracies of the unbalanced streaming run
# (plot_accs is a project helper; its exact rendering is not shown here).
plot_accs(train_accs)

In [None]:
# Visualise the per-epoch test accuracies of the unbalanced streaming run
# (plot_accs is a project helper; its exact rendering is not shown here).
plot_accs(test_accs)

In [None]:
# Sum the per-epoch `kept` tensors element-wise across epochs.
# Assumes each entry of train_kept is a length-10 tensor of per-digit counts
# — TODO confirm against train_streaming_unbalanced.
# The result is detached and moved to the CPU so matplotlib can convert it even
# when training ran on an accelerator.
kept_per_digit = torch.stack(train_kept).sum(dim=0).detach().cpu()

fig, ax = plt.subplots()
ax.plot(kept_per_digit)
ax.set(
    title='Training examples kept per digit (all epochs)',
    xlabel='digit',
    ylabel='frequency',
    ylim=(0, 7000),
)
ax.set_xticks(range(10))
plt.show()

In [None]:
# Plot the target probabilities returned by benford_dist for the digit labels
# 0-9, for visual comparison with the kept-frequency figure.
fig, ax = plt.subplots()
ax.plot(benford_probs)
ax.set(
    title="Target digit distribution (Benford's law)",
    xlabel='digit',
    ylabel='probability',
    ylim=(0, 0.4),
)
ax.set_xticks(range(10))
plt.show()