In [1]:
!pip install -Uqq fastbook
import fastbook 
fastbook.setup_book()

[K     |████████████████████████████████| 727kB 5.5MB/s 
[K     |████████████████████████████████| 194kB 17.0MB/s 
[K     |████████████████████████████████| 51kB 5.3MB/s 
[K     |████████████████████████████████| 1.2MB 17.8MB/s 
[K     |████████████████████████████████| 61kB 6.1MB/s 
[K     |████████████████████████████████| 51kB 6.1MB/s 
[?25hMounted at /content/gdrive


In [2]:
from fastai.vision.all import *
from fastbook import *
from time import time

In [3]:
# Fix torch's global RNG seed so runs are repeatable, and print tensors
# in plain decimal rather than scientific notation.
torch.manual_seed(42)
torch.set_printoptions(sci_mode=False)

## Load data

In [None]:
# Fetch the MNIST dataset through fastai's `untar_data`; `path` holds the location it returns.
path = untar_data(URLs.MNIST)
# Presumably makes displayed paths relative to the dataset root — TODO confirm against fastai docs.
Path.BASE_PATH = path

In [6]:
# DataBlock describing the task: black-and-white images in, a category label out.
# Items are image files; the label comes from each file's parent folder name and the
# train/valid split from the grandparent folder ('training' vs 'testing').
digits = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='training', valid_name='testing'),
)

In [27]:
# Build the DataLoaders from the `digits` DataBlock, rooted at `path`.
dls = digits.dataloaders(path)
# Sanity check: shape of the first element of the first item in one validation batch
# (the recorded output shows a single-channel 28x28 image).
print(dls.valid.one_batch()[0][0].shape)

torch.Size([1, 28, 28])


## Loss function

In [9]:
def softmax(x):
  """Row-wise softmax over dim 1.

  Args:
    x: 2-D tensor of activations, one row per sample.

  Returns:
    Tensor of the same shape whose rows are non-negative and sum to 1.

  The row maximum is subtracted before exponentiating. Softmax is invariant to adding a
  constant to every entry of a row, so the result is mathematically unchanged, but the
  largest exponent becomes exp(0) = 1 and large activations no longer overflow to
  inf (which previously produced NaN via inf / inf).
  """
  shifted = x - x.max(dim=1, keepdim=True).values
  exps = torch.exp(shifted)  # computed once and reused for numerator and denominator
  return exps / exps.sum(dim=1, keepdim=True)

The torch [docs](https://pytorch.org/docs/stable/nn.functional.html#log-softmax) explain that in practice doing softmax() followed by log() is slower and numerically unstable. Note to myself: check out F.log_softmax() to see what they are doing differently when I'm further into the course.

In [11]:
def neg_log_likelihood(x, targ):
  """Per-sample negative log-likelihood.

  For every row i, picks the entry x[i, targ[i]] and negates it. No reduction is
  applied: the result has one value per target.
  """
  row_idx = range(len(targ))
  picked = x[row_idx, targ]
  return -picked

In [12]:
def cross_entropy_loss(acts, targ, reduction="mean"):
  """Cross-entropy between raw activations and integer class targets.

  Args:
    acts: 2-D tensor of unnormalised activations (one row per sample, one column per class).
    targ: 1-D sequence/tensor of class indices, one per row of `acts`.
    reduction: 'mean' returns the scalar mean loss; any other value returns the
      unreduced per-sample losses.

  Returns:
    A scalar tensor when reduction == 'mean', otherwise a 1-D tensor of per-sample losses.

  The log-probabilities are computed as `acts - logsumexp(acts)` rather than
  `log(softmax(acts))`: taking exp and then log overflows to NaN for large activations
  and underflows to log(0) = -inf for very unlikely classes, whereas the log-sum-exp
  form stays finite.
  """
  log_probs = acts - torch.logsumexp(acts, dim=1, keepdim=True)
  # Negative log-likelihood of the target class for each row; computed once and reused.
  losses = -log_probs[range(len(targ)), targ]
  return losses.mean() if reduction == 'mean' else losses

## Model

In [18]:
def batch_accuracy(preds, yb):
  """Fraction of rows in `preds` whose highest-scoring column equals the label in `yb`.

  Returns a 0-dim float tensor.
  """
  predicted_labels = preds.argmax(dim=1)
  hits = (predicted_labels == yb).float()
  return hits.mean()

In [19]:
class BasicOptimiser:
  """Minimal stochastic-gradient-descent optimiser.

  Args:
    params: iterable of tensors to optimise (materialised into a list so it can be
      iterated on every step).
    lr: learning rate applied to every parameter.
  """
  def __init__(self, params, lr):
    self.params,self.lr = list(params),lr

  def step(self):
    """Apply one update `p <- p - lr * p.grad` to every parameter that has a gradient."""
    # The update itself must not be tracked by autograd.
    with torch.no_grad():
      for p in self.params:
        # A parameter has no gradient if it did not take part in the last backward
        # pass, or if step() is called right after zero_grad(); skip it instead of
        # failing on `None`.
        if p.grad is None:
          continue
        p.sub_(p.grad * self.lr)

  def zero_grad(self):
    """Drop all stored gradients so the next backward pass starts from scratch."""
    for p in self.params:
      p.grad = None

In [23]:
class BasicLearner:
  """Bare-bones training loop tying together data, model, optimiser, loss and metric.

  Args:
    dls: object exposing `.train` and `.valid` iterables that yield `(xb, yb)` batches.
    model: callable model exposing `.parameters()`.
    opt_func: optimiser factory, called as `opt_func(model.parameters(), lr)`; the
      result must provide `step()` and `zero_grad()`.
    loss_function: callable `(preds, yb) -> scalar tensor` supporting `.backward()`.
    batch_accuracy: callable `(preds, yb) -> 0-dim tensor` metric for one batch.
    lr: learning rate handed to `opt_func` (defaults to the previously hard-coded 0.03).
  """
  def __init__(self, dls: "DataLoaders", model, opt_func, loss_function, batch_accuracy, lr=0.03):
    self.dls = dls
    self.model = model
    # Despite its name, this attribute holds the constructed optimiser instance.
    self.opt_func = opt_func(model.parameters(), lr)
    self.loss_function = loss_function
    self.batch_accuracy = batch_accuracy

  def validate_epoch(self):
    """Return the mean of the per-batch accuracies on the validation set, rounded to 4 dp.

    Every batch is weighted equally, so a smaller final batch counts as much as a full one.
    """
    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
      accs = [self.batch_accuracy(self.model(xb), yb) for xb, yb in self.dls.valid]
    return round(torch.stack(accs).mean().item(), 4)

  def fit(self, epochs):
    """Train for `epochs` passes over `dls.train`, printing validation accuracy and time per epoch."""
    for epoch in range(epochs):
      start_time = time()
      for xb, yb in self.dls.train:
        preds = self.model(xb)
        loss = self.loss_function(preds, yb)
        loss.backward()
        self.opt_func.step()
        # Clear gradients so they do not accumulate into the next batch.
        self.opt_func.zero_grad()
      print(f"Epoch {epoch}, Accuracy: {self.validate_epoch()}, took {time() - start_time:.2f}s")

  def pred(self, xb):
    """Return the predicted class index (argmax over dim 1) for each row of `xb`."""
    with torch.no_grad():
      return self.model(xb).argmax(dim=1)

In [None]:
# Two-layer fully connected classifier: flatten each 1x28x28 image to 784 values,
# project to 30 hidden units with a ReLU, then to 10 output activations (one per digit).
simple_net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 10)
)
# Use the current GPU when one is available and fall back to the CPU otherwise, so the
# cell does not crash on a CPU-only runtime.
# NOTE(review): this assumes the batches from `dls` live on the same device — confirm.
simple_net.to('cuda' if torch.cuda.is_available() else 'cpu')

In [26]:
# Wire the hand-written pieces together (optimiser class, loss and metric defined above)
# and train `simple_net` for 10 epochs.
# NOTE(review): the saved output below stops after epoch 4 although 10 epochs are
# requested — re-run the cell to refresh it.
learner = BasicLearner(dls, simple_net, BasicOptimiser, cross_entropy_loss, batch_accuracy)
learner.fit(10)

Epoch 0, Accuracy: 0.9011, took 72.18s
Epoch 1, Accuracy: 0.9135, took 71.68s
Epoch 2, Accuracy: 0.9241, took 71.44s
Epoch 3, Accuracy: 0.9288, took 71.52s
Epoch 4, Accuracy: 0.9341, took 71.67s


## Experiment: Using median loss

In the lecture someone asked why you wouldn't use median loss, so I decided to try it out.

In [None]:
def mnist_loss_median(preds, target):
  """Median-reduced distance loss.

  Each element's distance is `1 - pred` where the target equals 1 and `pred`
  elsewhere; the loss is the median of those distances.
  """
  is_positive = target == 1
  distances = torch.where(is_positive, 1 - preds, preds)
  return distances.median()

In [None]:
# NOTE(review): `SimpleNet` is not defined in any cell visible in this notebook; it has
# to be defined (or restored) before this cell can run.
# NOTE(review): `mnist_loss_median` tests `target==1` and the model has a single output,
# which looks like a binary task, while `dls` above holds 10 digit classes and
# `batch_accuracy` takes an argmax over dim 1 — confirm which data this was meant for.
model = SimpleNet(28*28, 1)
# BasicLearner builds the optimiser itself by calling `opt_func(model.parameters(), lr)`,
# so it must receive a factory rather than an already-built optimiser (which is not
# callable). The wrapper ignores the learner's lr and pins this experiment's 0.13.
learner_exp1 = BasicLearner(dls, model, lambda params, _lr: BasicOptimiser(params, 0.13),
                            mnist_loss_median, batch_accuracy)
learner_exp1.fit(10)

Epoch 0, Accuracy: 0.5068
Epoch 1, Accuracy: 0.5068
Epoch 2, Accuracy: 0.5068
Epoch 3, Accuracy: 0.5068
Epoch 4, Accuracy: 0.5166
Epoch 5, Accuracy: 0.5552
Epoch 6, Accuracy: 0.6074
Epoch 7, Accuracy: 0.6519
Epoch 8, Accuracy: 0.7769
Epoch 9, Accuracy: 0.8369


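It seems that even though it converges at some point in this case, it just trains slower. I don't have a strong intuition for why that is yet; one likely factor is that the gradient of a median flows only through the single element that happens to be the median, so each step learns from just one sample in the batch, whereas the mean uses all of them.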