# Mini-batch training
In this notebook, we are first going to implement cross entropy, then move on to mini-batch training.

## Setup

In [1]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
from pathlib import Path
from torch import tensor,nn
import torch.nn.functional as F

In [2]:
from fastcore.test import test_close

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

### Data

In [3]:
n,m = x_train.shape
c = y_train.max()+1
nh = 50

In [5]:
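class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # note: layers kept in a plain Python list are not registered as submodules,
        # so their weights do not appear in model.parameters() (nn.ModuleList would register them)
        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]
        
    def __call__(self, x):
        # apply each layer in sequence
        for l in self.layers: x = l(x)
        return x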

In [8]:
model = Model(m, nh, c)
pred = model(x_train)
pred.shape

torch.Size([50000, 10])

## Cross entropy loss
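We start by implementing log softmax, i.e. the log of the softmax probabilities. Cross-entropy needs these log-probabilities anyway, and working in log space is easier to compute.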
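For reference, for one sample with activations $x_1, \dots, x_K$ over $K$ classes, the softmax and its log are:

$$\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{K} e^{x_j}}, \qquad \log \operatorname{softmax}(x)_i = x_i - \log \sum_{j=1}^{K} e^{x_j}$$

The second form replaces the division by a subtraction of the log of the row's sum of exponentials.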

In [13]:
torch.sum(torch.exp(pred), dim=1, keepdim=True).shape

torch.Size([50000, 1])

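Because of `keepdim=True`, the sum keeps a trailing dimension of size 1, so broadcasting can divide each row of `pred.exp()` by that row's sum of exponentials $\sum_j e^{x_j}$, where $x_j$ is the activation for class $j$. In log space, this division becomes the subtraction used in `log_softmax`.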
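A small sketch of the broadcasting step on a toy tensor, assuming hypothetical activations for 2 samples and 3 classes (not taken from the notebook): the `(2, 1)` row sums are stretched across the 3 class columns, and the result matches PyTorch's built-in `torch.softmax`.

```python
import torch

# Hypothetical activations: 2 samples x 3 classes (not taken from the notebook)
x = torch.tensor([[1.0, 2.0, 3.0],
                  [0.5, 0.5, 0.5]])

# Sum of exponentials over the class dimension; keepdim=True gives shape (2, 1)
denom = x.exp().sum(-1, keepdim=True)

# Broadcasting stretches denom from (2, 1) to (2, 3), so each row is divided by its own sum
probs = x.exp() / denom

# In log space the division becomes a subtraction, as in log_softmax
log_probs = x - denom.log()

print(denom.shape)                                    # torch.Size([2, 1])
print(probs.sum(-1))                                  # each row sums to 1
print(torch.allclose(probs, torch.softmax(x, dim=-1)))
```

Without `keepdim=True` the sum would have shape `(2,)`, which broadcasting aligns with the last (class) dimension instead; here that raises an error, and for a square input it would silently divide by the wrong sums.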

In [50]:
def log_softmax(x):
    return x - x.exp().sum(-1, keepdim=True).log()

In [51]:
log_softmax(pred)

tensor([[-1.99, -2.27, -2.32,  ..., -2.51, -2.17, -2.25],
        [-2.08, -2.35, -2.28,  ..., -2.46, -2.15, -2.24],
        [-2.04, -2.33, -2.35,  ..., -2.44, -2.19, -2.26],
        ...,
        [-2.02, -2.21, -2.29,  ..., -2.46, -2.26, -2.30],
        [-2.09, -2.29, -2.29,  ..., -2.41, -2.24, -2.30],
        [-2.08, -2.22, -2.34,  ..., -2.38, -2.25, -2.31]], grad_fn=<SubBackward0>)

In [52]:
log_softmax(pred).shape

torch.Size([50000, 10])

Next, we use the LogSumExp trick. We want the loss and its gradients to be as precise as possible, but computing $e^{x}$ directly is fragile: large activations overflow (in float32, `exp(x)` is already `inf` for $x$ above roughly 88.7), and values of very large magnitude lose precision. The trick relies on the identity $\log \sum_j e^{x_j} = a + \log \sum_j e^{x_j - a}$, which holds for any constant $a$. Choosing $a = \max_j x_j$ makes the largest exponentiated term $e^0 = 1$, so nothing can overflow.
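A minimal sketch comparing a naive log-sum-exp with the max-shifted version on hypothetical inputs, one row with small values and one with values around 1000. The helper names `naive_logsumexp` and `stable_logsumexp` are illustrative only; the stable one is checked against PyTorch's `torch.logsumexp`.

```python
import torch

def naive_logsumexp(x):
    # log(sum(exp(x))) computed directly; exp can overflow to inf
    return x.exp().sum(-1).log()

def stable_logsumexp(x):
    # Subtract the row max a before exponentiating, then add it back:
    # log(sum(exp(x))) = a + log(sum(exp(x - a)))
    a = x.max(-1, keepdim=True).values      # shape (N, 1)
    return a.squeeze(-1) + (x - a).exp().sum(-1).log()

# Hypothetical inputs (float32 by default)
x = torch.tensor([[1.0, 2.0, 3.0],
                  [1000.0, 1001.0, 1002.0]])

print(naive_logsumexp(x))    # second entry is inf: exp(1000.) overflows float32
print(stable_logsumexp(x))   # finite for both rows
print(torch.allclose(stable_logsumexp(x), torch.logsumexp(x, dim=-1)))
```

For the first row both versions agree; only the shifted version survives the second row.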

In [53]:
mx = pred.max(-1)[0]  # row-wise max; named mx so the global m (number of input features) is not overwritten
mx

tensor([0.28, 0.21, 0.27,  ..., 0.29, 0.22, 0.23], grad_fn=<MaxBackward0>)

In [54]:
mx[:, None].shape

torch.Size([50000, 1])

In [55]:
pred.shape

torch.Size([50000, 10])

In [56]:
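def logsumexp_softmax(x):
    # despite the name, this returns log(sum(exp(x))) for each row, not a softmax
    m = x.max(-1)[0]  # row-wise max, shape (N,)
    # subtract the max before exp to avoid overflow, add it back afterwards,
    # then unsqueeze to shape (N, 1) so the result broadcasts against x
    return (m + (x - m[:, None]).exp().sum(-1).log()).unsqueeze(-1)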

In [57]:
logsumexp_softmax(pred).shape

torch.Size([50000, 1])

In [58]:
logsumexp_softmax(pred)

tensor([[2.28],
        [2.29],
        [2.30],
        ...,
        [2.31],
        [2.31],
        [2.31]], grad_fn=<UnsqueezeBackward0>)

In [59]:
pred.exp().sum(-1, keepdim=True)

tensor([[ 9.75],
        [ 9.90],
        [10.02],
        ...,
        [10.10],
        [10.08],
        [10.10]], grad_fn=<SumBackward1>)

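Taking the log of these sums gives the LogSumExp values computed above, e.g. $\log(9.75) \approx 2.28$.

In [60]: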
def log_softmax(x):
    return x - logsumexp_softmax(x)

In [61]:
log_softmax(pred)

tensor([[-1.99, -2.27, -2.32,  ..., -2.51, -2.17, -2.25],
        [-2.08, -2.35, -2.28,  ..., -2.46, -2.15, -2.24],
        [-2.04, -2.33, -2.35,  ..., -2.44, -2.19, -2.26],
        ...,
        [-2.02, -2.21, -2.29,  ..., -2.46, -2.26, -2.30],
        [-2.09, -2.29, -2.29,  ..., -2.41, -2.24, -2.30],
        [-2.08, -2.22, -2.34,  ..., -2.38, -2.25, -2.31]], grad_fn=<SubBackward0>)

Our LogSumExp version gives the same result as before. PyTorch already implements LogSumExp for us, as the `logsumexp` tensor method:

In [64]:
def log_softmax(x):
    return x - x.logsumexp(-1, keepdim=True)

In [66]:
sm_pred = log_softmax(pred)
sm_pred

tensor([[-1.99, -2.27, -2.32,  ..., -2.51, -2.17, -2.25],
        [-2.08, -2.35, -2.28,  ..., -2.46, -2.15, -2.24],
        [-2.04, -2.33, -2.35,  ..., -2.44, -2.19, -2.26],
        ...,
        [-2.02, -2.21, -2.29,  ..., -2.46, -2.26, -2.30],
        [-2.09, -2.29, -2.29,  ..., -2.41, -2.24, -2.30],
        [-2.08, -2.22, -2.34,  ..., -2.38, -2.25, -2.31]], grad_fn=<SubBackward0>)
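PyTorch also ships log softmax itself as `torch.nn.functional.log_softmax`. A quick sketch, assuming hypothetical random activations scaled up so that some entries are large enough that `exp()` of them may overflow in float32, checks that the hand-written version built on `logsumexp` stays finite and agrees with it:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical activations, scaled up to include large values
x = torch.randn(4, 10) * 100

# Hand-written log softmax built on the built-in logsumexp
ours = x - x.logsumexp(-1, keepdim=True)

# PyTorch's own log softmax
ref = F.log_softmax(x, dim=-1)

print(torch.isfinite(ours).all())            # no inf/nan despite large activations
print(torch.allclose(ours, ref, atol=1e-4))
```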

Now that we have the log softmax, we can compute the cross-entropy loss between a target distribution $y$ and the predicted probabilities $p$ as $-\sum_i y_i \log p_i$.

Our targets, however, are class indices rather than distributions. An index is equivalent to a one-hot encoding ($y_i = 1$ for the actual class and $0$ for every other class), so all terms but one vanish and the loss simplifies to $-\log p_t$, where $p_t$ is the predicted probability of the actual class $t$.

This is also known as the negative log likelihood (NLL) loss.
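To see that the one-hot formula and the index shortcut agree, here is a sketch on hypothetical data (5 samples, 10 classes, random activations): the full sum $-\sum_i y_i \log p_i$ with one-hot targets equals picking out the target's log-probability per row with integer indexing, which is also what `F.nll_loss` computes.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(5, 10), dim=-1)  # hypothetical 5 samples, 10 classes
targets = torch.tensor([3, 0, 9, 3, 7])               # hypothetical class indices

# Full formula with one-hot targets, averaged over samples
one_hot = F.one_hot(targets, num_classes=10).float()
full = -(one_hot * log_probs).sum(-1).mean()

# Shortcut: pair row i with column targets[i] and take that log-probability
picked = -log_probs[torch.arange(len(targets)), targets].mean()

print(torch.allclose(full, picked))
print(torch.allclose(picked, F.nll_loss(log_probs, targets)))
```

The indexing form avoids materializing the mostly-zero one-hot matrix.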

In [69]:
y_train

tensor([5, 0, 4,  ..., 8, 4, 8])

In [74]:
sm_pred[torch.arange(sm_pred.shape[0]), y_train].shape, y_train.shape

(torch.Size([50000]), torch.Size([50000]))

In [76]:
sm_pred[torch.arange(sm_pred.shape[0]), y_train].mean()

tensor(-2.31, grad_fn=<MeanBackward0>)

In [82]:
def nll(sm_pred, tgt):
    return -1. * sm_pred[torch.arange(sm_pred.shape[0]), tgt].mean()

In [83]:
nll(sm_pred, y_train)

tensor(2.31, grad_fn=<MulBackward0>)

In PyTorch, log softmax and negative log likelihood loss are combined into a single loss, `nn.CrossEntropyLoss` (functional form: `F.cross_entropy`), which takes the raw activations `pred` directly.

In [84]:
loss = nn.CrossEntropyLoss()
loss(pred, y_train)

tensor(2.31, grad_fn=<NllLossBackward0>)
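A short sketch of that equivalence on hypothetical random logits and targets: applying `F.log_softmax` followed by `F.nll_loss` gives the same value as `F.cross_entropy` and `nn.CrossEntropyLoss`.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(8, 10)            # hypothetical raw activations (no softmax applied)
targets = torch.randint(0, 10, (8,))   # hypothetical class indices

# Two-step version: log softmax, then negative log likelihood
two_step = F.nll_loss(F.log_softmax(logits, dim=-1), targets)

# Fused versions: functional and module forms
fused = F.cross_entropy(logits, targets)
module = nn.CrossEntropyLoss()(logits, targets)

print(torch.allclose(two_step, fused), torch.allclose(fused, module))
```

Since the loss expects unnormalized activations, the model itself should not end with a softmax layer.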