# Minibatch Training KMNIST

In [None]:
import pickle, gzip, math, os, time, shutil, torch, matplotlib as mpl, numpy as np
import numpy as np
from torch import tensor
from pathlib import Path
from fastcore.test import test_close

from torch import nn
import torch.nn.functional as F

# Seed PyTorch's global random number generator.
torch.manual_seed(42)

# Default colormap for images, and compact 2-decimal printing for tensors and arrays.
mpl.rcParams['image.cmap'] = 'gray'
torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)
np.set_printoptions(precision=2, linewidth=125)

# Directory that holds the KMNIST .npz archives.
path_data = Path('data')

# Archive names: images and labels for the training and test splits.
train_images_file = 'kmnist-train-imgs.npz'
train_labels_file = 'kmnist-train-labels.npz'
test_images_file = 'kmnist-test-imgs.npz'
test_labels_file = 'kmnist-test-labels.npz'

# Every archive keeps its data under the key 'arr_0'. Images are flattened to
# 784 values per row, divided by 255.0 and only then cast to float32; labels
# are used exactly as stored.
x_train = np.float32(np.load(path_data/train_images_file)['arr_0'].reshape(-1, 784)/255.0)
y_train = np.load(path_data/train_labels_file)['arr_0']
# The test split serves as the validation set.
x_valid = np.float32(np.load(path_data/test_images_file)['arr_0'].reshape(-1, 784)/255.0)
y_valid = np.load(path_data/test_labels_file)['arr_0']

# Turn the four NumPy arrays into PyTorch tensors.
x_train, y_train, x_valid, y_valid = (tensor(arr) for arr in (x_train, y_train, x_valid, y_valid))

## Initial Setup

### Data

In [None]:
# n = number of training rows, m = number of values per row
n,m = x_train.shape
# number of classes: largest label + 1 (this stays a tensor, it is not a Python int)
c = y_train.max() + 1
# number of hidden units used for the model
nh = 50
n, m, c

(60000, 784, tensor(10, dtype=torch.uint8))

In [None]:
class Model(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        n_in: number of input features.
        nh: number of hidden units.
        n_out: number of output activations.
    """

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList (instead of a plain Python list) registers the layers
        # as sub-modules, so parameters(), state_dict(), .to(), ... see them.
        # It is still iterable, so `for l in model.layers` keeps working.
        self.layers = nn.ModuleList([nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])

    def forward(self, x):
        """Pass `x` through every layer in order and return the final activations."""
        # Implementing forward() rather than overriding __call__ leaves
        # nn.Module.__call__ (and the hooks it runs) intact; model(x) still works.
        for layer in self.layers:
            x = layer(x)
        return x

In [None]:
# Build the network (m inputs, nh hidden units, 10 outputs) and run all of x_train through it.
model = Model(m, nh, 10)
pred = model(x_train)
pred.shape

torch.Size([60000, 10])

### Cross Entropy Loss

First, we need to compute the softmax of our activations. This is defined by:
    
$$ \hbox{softmax(x)}_{i} = \frac{e^{x_{i}}}{e^{x_{0}} + e^{x_{1}} + \cdots + e^{x_{n-1}}}$$

or more concisely:

$$ \hbox{softmax(x)}_{i} =  \frac{e^{x_{i}}}{\sum\limits_{0 \leq j \lt n} e^{x_{j}}}$$

In practice, we will need the log of softmax when we calculate the loss

In [None]:
def log_softmax(x):
    """Log of the softmax of `x` over its last dimension (direct, unstabilised formula)."""
    exps = x.exp()
    probs = exps / exps.sum(-1, keepdim=True)
    return probs.log()

In [None]:
log_softmax(pred)

tensor([[-2.35, -2.37, -2.08,  ..., -2.47, -2.45, -2.20],
        [-2.38, -2.34, -2.07,  ..., -2.50, -2.55, -2.17],
        [-2.36, -2.48, -1.95,  ..., -2.47, -2.46, -2.13],
        ...,
        [-2.39, -2.39, -2.04,  ..., -2.42, -2.46, -2.40],
        [-2.45, -2.24, -2.22,  ..., -2.42, -2.46, -2.11],
        [-2.30, -2.22, -2.09,  ..., -2.43, -2.44, -2.29]], grad_fn=<LogBackward0>)

Note that the formula

$$ \log \left ( \frac{a}{b} \right ) = \log(a) - \log(b)$$

gives a simplification when we compute the log softmax

In [None]:
def log_softmax(x):
    """Log-softmax over the last dimension, written as log(a) - log(b)."""
    log_denominator = x.exp().sum(-1, keepdim=True).log()
    return x - log_denominator

Then, there is a way to compute the log of the sum of exponentials in a more stable way, called the LogSumExp trick. The idea is to use the following formula

$$\log \left ( \sum_{j=1}^{n} e^{x_{j}} \right ) = \log \left ( e^{a} \sum_{j=1}^{n} e^{x_{j}-a} \right ) = a + \log \left ( \sum_{j=1}^{n} e^{x_{j}-a} \right )$$

where $a$ is the maximum of the $x_{j}$.

In [None]:
def logsumexp(x):
    """Return log(sum(exp(x))) over the last dimension of `x`.

    The maximum along the last dimension is subtracted before exponentiating,
    so the largest exponent is exp(0) = 1 and big activations do not overflow.
    Works for tensors with any number of dimensions (at least one); the last
    dimension is reduced away in the result.
    """
    # keepdim=True lets the maximum broadcast against x whatever its rank is.
    row_max = x.max(-1, keepdim=True)[0]
    return row_max.squeeze(-1) + (x - row_max).exp().sum(-1).log()

This way, we will avoid an overflow when taking the exponential of a big activation. In PyTorch, this is already implemented for us

In [None]:
def log_softmax(x): return x - x.logsumexp(-1, keepdim=True)

In [None]:
# The hand-written logsumexp must agree with the tensor's built-in method.
test_close(logsumexp(pred), pred.logsumexp(-1))
# Log-softmax of the model's predictions on the whole training set.
sm_pred = log_softmax(pred)
sm_pred

tensor([[-2.35, -2.37, -2.08,  ..., -2.47, -2.45, -2.20],
        [-2.38, -2.34, -2.07,  ..., -2.50, -2.55, -2.17],
        [-2.36, -2.48, -1.95,  ..., -2.47, -2.46, -2.13],
        ...,
        [-2.39, -2.39, -2.04,  ..., -2.42, -2.46, -2.40],
        [-2.45, -2.24, -2.22,  ..., -2.42, -2.46, -2.11],
        [-2.30, -2.22, -2.09,  ..., -2.43, -2.44, -2.29]], grad_fn=<SubBackward0>)

In [None]:
y_train[:3]

tensor([8, 7, 0], dtype=torch.uint8)

In [None]:
sm_pred[0, 8], sm_pred[1, 7], sm_pred[2, 0]

(tensor(-2.45, grad_fn=<SelectBackward0>),
 tensor(-2.50, grad_fn=<SelectBackward0>),
 tensor(-2.36, grad_fn=<SelectBackward0>))

In [None]:
sm_pred.shape

torch.Size([60000, 10])

In [None]:
y_train[:3].tolist()

[8, 7, 0]

In [None]:
sm_pred[[0, 1, 2],y_train[:3].long() ]

tensor([-2.45, -2.50, -2.36], grad_fn=<IndexBackward0>)

In [None]:
def nll(input, target):
    """Mean negative log-likelihood of the target classes.

    Args:
        input: 2-D tensor of log-probabilities, one row per sample and one
            column per class.
        target: 1-D tensor of integer class indices, one per row of `input`.

    Returns:
        A scalar tensor: minus the mean of `input[i, target[i]]` over all rows.
    """
    # Cast to int64 before indexing: a uint8 tensor would be interpreted as a
    # mask rather than as class indices, which silently picks the wrong entries.
    idx = target.long()
    return -input[range(idx.shape[0]), idx].mean()

In [None]:
# The labels are passed as int64 (.long()) so they can be used as class indices.
loss = nll(sm_pred, y_train.long())
loss

tensor(2.31, grad_fn=<NegBackward0>)

Then use PyTorch's implementation

In [None]:
test_close(F.nll_loss(F.log_softmax(pred, -1), y_train.long()), loss, 1e-3)

In PyTorch, `F.log_softmax` and `F.nll_loss` are combined in one optimized function, `F.cross_entropy`.

In [None]:
test_close(F.cross_entropy(pred, y_train), loss, 1e-3)

### Basic training loop

Basically the training loop repeats over the following steps:
- get the output of the model on a batch of inputs
- compare the output to the labels we have and compute the loss
- calculate the gradients of the loss with respect to every parameter of the model
- update said parameters with those gradients to make them a little bit better

In [None]:
loss_func = F.cross_entropy

In [None]:
bs = 64

xb = x_train[0:bs]
preds = model(xb)
preds[0], preds.shape

(tensor([-0.07, -0.09,  0.20, -0.14,  0.10, -0.09,  0.07, -0.19, -0.17,  0.08], grad_fn=<SelectBackward0>),
 torch.Size([64, 10]))

In [None]:
yb = y_train[0:bs]
loss_func(preds, yb)

tensor(2.33, grad_fn=<NllLossBackward0>)

In [None]:
torch.argmax(preds, dim=1)

tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 9, 2, 4, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 6, 2, 2, 2, 2, 4, 2, 2, 2,
        2, 9, 2, 2, 9, 1, 2, 2, 1, 2, 6, 3, 9, 1, 2, 2, 9, 2, 1, 2, 2, 3, 4, 2, 2])

In [None]:
def accuracy(out, yb):
    """Fraction of rows of `out` whose highest-scoring column equals the label in `yb`."""
    predicted = out.argmax(dim=1)
    hits = predicted == yb
    return hits.float().mean()

In [None]:
accuracy(preds, yb)

tensor(0.08)

In [None]:
lr = 0.5 # learning rate
epochs = 3 # how many epochs to train for

In [None]:
# Manual gradient-descent loop: `epochs` passes over x_train in consecutive slices of `bs` rows.
for epoch in range(epochs):
    for i in range(0, n, bs):
        # the slice end is clipped at n, so the last batch may hold fewer than bs rows
        s = slice(i, min(n, i+bs))
        xb, yb = x_train[s], y_train[s]
        preds = model(xb)
        loss = loss_func(preds, yb)
        loss.backward()
        # report loss and accuracy for the first batch of every epoch
        if i==0: print(loss.item(), accuracy(preds, yb).item())
        # the in-place parameter updates below must not be tracked by autograd
        with torch.no_grad():
            for l in model.layers:
                # only layers that expose a `weight` attribute are updated
                if hasattr(l, 'weight'):
                    l.weight -= l.weight.grad * lr
                    l.bias -= l.bias.grad * lr
                    # reset the gradients so the next backward() starts from zero
                    l.weight.grad.zero_()
                    l.bias.grad.zero_()

2.154031991958618 0.28125
0.4318438172340393 0.890625
0.3388446569442749 0.921875
