In [None]:
import math
from matplotlib import pyplot as plt
import timeit
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from novelty.models.dists import benford_dist, instrumental_dist
from novelty.models.mlp import MLP
from novelty.models.util import train_streaming_unbalanced, test
from novelty.visualization.models import plot_accs

In [None]:
# Seed torch's global RNG so torch-driven randomness (e.g. DataLoader
# shuffling, default nn layer initialisation) repeats across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Hardware toggle: when enabled this selects Apple's Metal ("mps") backend.
use_gpu = False
device = torch.device("mps" if use_gpu else "cpu")

In [None]:
# NOTE: overwritten by the streaming-setup cell (epochs = len(dataset) // examples_per_epoch)
# before the training loop runs.
epochs = 10
# Shuffle the training stream on every device, not only when accelerating;
# previously CPU runs saw the training set in a fixed order.
train_kwargs = {'batch_size': 1, 'shuffle': True}
# Test metrics are presumably aggregated over the whole set, so no shuffling here.
test_kwargs = {'batch_size': 1000}
if use_gpu:
    # Loader options applied only when accelerating.
    # NOTE(review): pin_memory targets CUDA host->device copies, but `device`
    # is "mps" in this notebook — confirm it has any effect there.
    cuda_kwargs = {'num_workers': 1,
                   'pin_memory': True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST image mean and std
])

In [None]:
# Training split: downloaded into ../data/raw on first run, then loaded.
dataset1 = datasets.MNIST(root='../data/raw', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)

# Test split: loaded without download=True, so it relies on the files already
# being on disk (presumably fetched by the training-split download above —
# verify for your torchvision version).
dataset2 = datasets.MNIST(root='../data/raw', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [None]:
# Probabilities from benford_dist for the values 0..9 (presumably one per
# digit label — confirm against benford_dist).
benford_probs = benford_dist(torch.arange(0, 10))
# Largest of those probabilities; passed to the training loop as `max_scaler`.
# Tensor.max() reduces in one vectorised call (and propagates NaN) instead of
# iterating element-by-element through Python's builtin max().
# NOTE(review): classic Benford is defined for leading digits 1-9; confirm how
# benford_dist handles 0 so this maximum is finite.
max_scaler = benford_probs.max()

In [None]:
def prob_func_inverse(prob, scale=0.001):
    """Learning-rate scale inversely proportional to a sample's probability.

    Args:
        prob: probability of the sample (Python float or torch tensor).
        scale: constant numerator; the default 0.001 reproduces the
            previously hard-coded value.

    Returns:
        ``scale * (1 / prob)`` elementwise, so rarer samples (small ``prob``)
        get larger values. ``prob == 0`` gives ``inf`` for tensors and raises
        ``ZeroDivisionError`` for Python floats.
    """
    return scale * (1 / prob)

def prob_func_log(prob, scale=0.01, base=11):
    """Learning-rate scale proportional to a sample's surprisal in base ``base``.

    Computes ``scale * -log_base(prob)``. With the defaults this is
    ``0.01 * -ln(prob) / ln(11)``, matching the original hard-coded version.

    Args:
        prob: probability of the sample, as a Python float or torch tensor.
            Floats are converted with ``torch.as_tensor`` because ``torch.log``
            rejects non-tensor input.
        scale: multiplicative constant (default 0.01).
        base: logarithm base (default 11).

    Returns:
        A torch tensor: 0 when ``prob == 1``, growing as ``prob`` shrinks,
        ``inf`` at ``prob == 0``.
    """
    prob = torch.as_tensor(prob)  # no-op for tensors; keeps dtype and autograd
    return scale * -torch.log(prob) / math.log(base)

# Naive Novelty-Based Updates

Because the training samples are drawn from the Benford distribution, we can scale the learning rate by the inverse of each sample's probability, so that rarer (more novel) samples produce larger updates.

In [None]:
# Treat the training stream as a sequence of fixed-size chunks ("epochs").
examples_per_epoch = 10_000
# Loader batches per chunk (floor division by the training batch size).
batches_per_epoch = examples_per_epoch // train_kwargs["batch_size"]
# Whole chunks that fit in the training set; any remainder is dropped.
epochs = len(train_loader.dataset) // examples_per_epoch

In [None]:
# Fresh MLP moved to the selected device.
model = MLP()
model = model.to(device)
# Plain SGD with base learning rate 0.01 (PyTorch defaults: no momentum, no weight decay).
optimizer = optim.SGD(params=model.parameters(), lr=1e-2)

In [None]:
# Per-epoch histories; appended once per epoch so partial results remain
# available if the loop is interrupted.
train_losses, train_accs, train_kept = [], [], []
test_losses, test_accs = [], []

for epoch in tqdm(range(1, epochs + 1)):
    # Positional arguments follow train_streaming_unbalanced's signature
    # (defined outside this notebook). NOTE(review): the bare `True` flag's
    # meaning lives in that helper — consider passing it by keyword.
    train_loss, train_acc, kept = train_streaming_unbalanced(
        model, device, train_loader, optimizer, epoch,
        benford_dist, instrumental_dist, max_scaler,
        batches_per_epoch, True, prob_func_inverse,
    )
    # Evaluate on test_loader after each epoch.
    test_loss, test_acc = test(model, device, test_loader)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    train_kept.append(kept)
    test_losses.append(test_loss)
    test_accs.append(test_acc)

In [None]:
# Plot the per-epoch training accuracies collected by the training loop.
# NOTE(review): test_accs is collected too but not plotted here — consider
# showing it alongside to compare train vs. held-out accuracy.
plot_accs(train_accs)