# Intro to PyTorch with MNIST

The MNIST handwritten-digit dataset is well understood and easy to train a model on.
The point of this notebook is to understand how to build models with PyTorch.

In [None]:
%config InlineBackend.figure_format = 'svg'
%matplotlib inline

import gzip
import math
import pathlib
import pickle
import requests

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Seed PyTorch's global RNG so weight initialisation and DataLoader shuffling
# give the same results on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

sns.set()

In [None]:
def load_data(
    url="https://github.com/pytorch/tutorials/raw/main/_static/",
    filename="mnist.pkl.gz",
    path="../../data/mnist/",
):
    """Download (if needed) and load the pickled MNIST dataset.

    The archive is cached at ``path / filename`` and is only downloaded when
    that file does not exist yet.

    Args:
        url: Base URL the archive is fetched from (``url + filename``).
        filename: Name of the gzipped pickle archive.
        path: Cache directory (str or Path). Relative paths are resolved
            against the current working directory.

    Returns:
        A 6-tuple ``(x_train, y_train, x_valid, y_valid, x_test, y_test)`` of
        Torch tensors.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    path = pathlib.Path(path)
    archive = path / filename

    path.mkdir(parents=True, exist_ok=True)
    if not archive.exists():
        response = requests.get(url + filename, timeout=60)
        # Fail loudly rather than caching an HTML error page as the archive,
        # which would make every later run fail inside gzip/pickle.
        response.raise_for_status()
        # write_bytes opens and closes the file for us (no leaked handle).
        archive.write_bytes(response.content)

    # NOTE: pickle.load can execute arbitrary code; only load archives that
    # come from a source you trust.
    with gzip.open(archive, "rb") as f:
        (x_train, y_train), (x_valid, y_valid), (x_test, y_test) = pickle.load(
            f, encoding="latin-1"
        )

    # Convert the NumPy arrays to Torch tensors. torch.tensor always *copies*
    # its input, so the returned tensors do not share memory with the arrays.
    return tuple(
        torch.tensor(arr)
        for arr in (x_train, y_train, x_valid, y_valid, x_test, y_test)
    )

In [None]:
# Train / validation / test splits, already converted to Torch tensors.
(x_train, y_train,
 x_valid, y_valid,
 x_test, y_test) = load_data()

Each image in the dataset is stored as a flattened vector of $28 \times 28 = 784$ pixels.
We must reshape it back to $28 \times 28$ before we can view it.

In [None]:
# Show the first training digit. Un-flatten it to 28x28 and use a grayscale
# colormap, and hide the axes so the seaborn grid is not drawn over the pixels.
fig, ax = plt.subplots()
ax.imshow(x_train[0].reshape(28, 28), cmap="gray")
ax.set_title("First training image (label {})".format(int(y_train[0])))
ax.axis("off")
plt.show()

## From Scratch

This model is simple enough to implement with tensor operations directly.
The point of doing so is to:

1. Gain familiarity with PyTorch
2. Show the benefits of the higher level API when used properly

In [None]:
# Xavier-style initialisation: sample N(0, 1) and scale by 1/sqrt(fan_in).
weights = torch.randn(784, 10).div(math.sqrt(784))
# Methods whose names end in an underscore modify the tensor in place; this
# switches on gradient tracking for the already-created tensor.
weights.requires_grad_()
bias = torch.zeros(10)
bias.requires_grad_()

In [None]:
def log_softmax(x):
    """Numerically stable log-softmax over the last dimension.

    Computes ``x - log(sum(exp(x)))`` along ``dim=-1``. The row maximum is
    subtracted before exponentiating, so large logits cannot overflow ``exp``
    to ``inf``. Without the shift the result would become ``-inf``/``nan``.
    The shift cancels out mathematically.
    """
    # Detaching is safe: the output does not depend on the shift value, so no
    # gradient needs to flow through max().
    shift = x.max(dim=-1, keepdim=True).values.detach()
    shifted = x - shift
    return shifted - shifted.exp().sum(-1, keepdim=True).log()


def model(xb):
    """From-scratch linear classifier: ``log_softmax(xb @ weights + bias)``.

    Reads the global ``weights`` and ``bias`` tensors.
    """
    return log_softmax(xb @ weights + bias)

In [None]:
batch_size = 64  # mini-batch size, reused by every training loop below

xb = x_train[0:batch_size]  # one mini-batch from x
preds = model(xb)  # predictions: one row of per-class log-probabilities per image
# The last expression of a cell is displayed by Jupyter, so no print() is needed.
preds[0], preds.shape

In [None]:
def nll(input, target):
    """Negative log-likelihood.

    For every row ``i``, pick the log-probability ``input[i, target[i]]``.
    Return the negated mean of these values over the batch.
    """
    row_idx = range(target.shape[0])
    picked = input[row_idx, target]
    return -picked.mean()


def accuracy(out, yb):
    """Fraction of rows whose highest-scoring column equals the label in ``yb``."""
    predicted = out.argmax(dim=1)
    return predicted.eq(yb).float().mean()


# The training loop calls the loss through this alias, so it can be swapped later.
loss_func = nll

In [None]:
yb = y_train[0:batch_size]  # labels for the same mini-batch
# Loss first, then accuracy, each on its own line.
for metric in (loss_func, accuracy):
    print(metric(preds, yb))

In [None]:
lr = 0.5  # learning rate
epochs = 2  # how many epochs to train for
n, c = x_train.shape  # n samples, c features per sample

for epoch in range(epochs):
    # Walk the training set in consecutive slices of `batch_size` rows;
    # the final slice may be shorter.
    for start_i in range(0, n, batch_size):
        end_i = start_i + batch_size
        xb, yb = x_train[start_i:end_i], y_train[start_i:end_i]
        loss = loss_func(model(xb), yb)
        loss.backward()

        # Manual SGD step. The update itself must not be tracked by autograd,
        # and gradients are reset so they don't accumulate across batches.
        with torch.no_grad():
            for param in (weights, bias):
                param -= param.grad * lr
                param.grad.zero_()

In [None]:
# Loss and accuracy on the last mini-batch seen by the training loop (xb/yb
# still hold it). Run the model once, without building an autograd graph.
with torch.no_grad():
    out = model(xb)
print(loss_func(out, yb), accuracy(out, yb))

## Using PyTorch `nn` Module

Most of the nitty-gritty details we implemented by hand above can be abstracted away, à la Keras, using the `torch.nn` helper classes.

In [None]:
# Cross entropy combines log-softmax with negative log-likelihood, so the
# model can now return raw logits.
loss_func = F.cross_entropy


def model(xb):
    """Linear classifier returning raw logits (``F.cross_entropy`` applies log-softmax)."""
    return xb @ weights + bias


# Evaluate once, without building an autograd graph, and reuse the logits for
# both metrics (argmax is unaffected by the missing log-softmax).
with torch.no_grad():
    logits = model(xb)
print(loss_func(logits, yb), accuracy(logits, yb))

In [None]:
class MnistModel(nn.Module):
    """Logistic-regression classifier: one linear layer from pixels to class logits.

    Args:
        in_features: Size of each flattened input (784 = 28 * 28 for MNIST).
        n_classes: Number of output classes (10 digits for MNIST).
    """

    def __init__(self, in_features=784, n_classes=10):
        super().__init__()
        # Anything tunable should be created in __init__() so nn.Module
        # registers it and it appears in .parameters().
        self.linear = nn.Linear(in_features, n_classes)

    def forward(self, minibatch):
        """Do one forward pass on ``minibatch`` and return raw class logits.

        The last dimension of ``minibatch`` must equal ``in_features``.
        """
        return self.linear(minibatch)

In [None]:
model = MnistModel()  # the nn.Module version replaces the `model` function
untrained_loss = loss_func(model(xb), yb)
print(untrained_loss)

In [None]:
def fit():
    """Train the global ``model`` with hand-written SGD.

    Reads the globals ``epochs``, ``n``, ``batch_size``, ``x_train``,
    ``y_train``, ``loss_func`` and ``lr``. Updates ``model``'s parameters in place.
    """
    for _ in range(epochs):
        for start in range(0, n, batch_size):
            stop = start + batch_size
            loss = loss_func(model(x_train[start:stop]), y_train[start:stop])
            loss.backward()

            # Parameter updates must not be recorded by autograd.
            with torch.no_grad():
                for param in model.parameters():
                    param -= param.grad * lr
                model.zero_grad()

In [None]:
fit()
# Loss of the trained module on the mini-batch currently held in xb/yb.
trained_loss = loss_func(model(xb), yb)
print(trained_loss)

In [None]:
def get_model():
    """Return a freshly initialised ``MnistModel`` and an SGD optimizer for it.

    The learning rate is read from the global ``lr`` at call time.
    """
    net = MnistModel()
    optimizer = optim.SGD(net.parameters(), lr=lr)
    return net, optimizer

model, opt = get_model()
print(loss_func(model(xb), yb))

for epoch in range(epochs):
    for start_i in range(0, n, batch_size):
        end_i = start_i + batch_size
        xb = x_train[start_i:end_i]
        yb = y_train[start_i:end_i]
        loss = loss_func(model(xb), yb)

        # The optimizer now applies the update and clears the gradients.
        loss.backward()
        opt.step()
        opt.zero_grad()

print(loss_func(model(xb), yb))

In [None]:
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

valid_ds = TensorDataset(x_valid, y_valid)
# No gradients are stored during validation, so a larger batch fits in the
# same memory.
valid_dl = DataLoader(valid_ds, batch_size=batch_size * 2)

model, opt = get_model()

for epoch in range(epochs):
    # Put layers such as dropout / batch-norm (if any) into training mode.
    model.train()
    for xb, yb in train_dl:
        pred = model(xb)
        loss = loss_func(pred, yb)

        loss.backward()
        opt.step()
        opt.zero_grad()

    model.eval()
    with torch.no_grad():
        # Weight each batch's mean loss by its size. The last batch may be
        # smaller, so a plain average over batches would be biased.
        valid_loss = sum(loss_func(model(xb), yb) * len(xb) for xb, yb in valid_dl)

    print(epoch, (valid_loss / len(valid_ds)).item())

In [None]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Compute the loss on one batch, optionally taking an optimizer step.

    Args:
        model: Callable mapping ``xb`` to predictions.
        loss_func: Callable ``(predictions, yb)`` returning a one-element tensor.
        xb, yb: Inputs and targets for the batch.
        opt: When given, backpropagate, step, then zero the gradients. When
            ``None`` (e.g. during validation), only the loss is computed.

    Returns:
        ``(loss as a Python float, number of samples in xb)``.
    """
    loss = loss_func(model(xb), yb)
    training = opt is not None
    if training:
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(), len(xb)

In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train ``model`` for ``epochs`` epochs, reporting validation loss after each.

    Each epoch takes one optimizer step per batch of ``train_dl`` (via
    ``loss_batch``). It then evaluates on ``valid_dl`` without gradients and
    prints the epoch index and the sample-weighted mean validation loss.
    """
    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            per_batch = [loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
        # Weight each batch's mean loss by its size, so a short final batch
        # does not skew the average.
        losses, nums = zip(*per_batch)
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(epoch, val_loss)


In [None]:
def get_data(train_ds, valid_ds, bs):
    """Wrap the datasets in DataLoaders.

    The training loader shuffles and uses batch size ``bs``. The validation
    loader keeps the original order and uses ``2 * bs``, since no gradients
    are stored during evaluation.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)
    return train_loader, valid_loader

In [None]:
train_dl, valid_dl = get_data(train_ds, valid_ds, bs=batch_size)
model, opt = get_model()
fit(epochs, model, loss_func, opt, train_dl=train_dl, valid_dl=valid_dl)

In [None]:
class MnistCnn(nn.Module):
    """Three strided 3x3 convolutions followed by average pooling.

    Spatial size goes 28 -> 14 -> 7 -> 4. The final 4x4 average pool leaves
    one value per channel, i.e. 10 class scores per image.
    """

    def __init__(self):
        super().__init__()
        # Each conv halves the spatial resolution (stride 2, padding 1).
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1)

    def forward(self, xb):
        """Map flattened 784-pixel images to a ``(batch, 10)`` tensor of scores."""
        # Un-flatten to (batch, channels=1, height=28, width=28).
        out = xb.view(-1, 1, 28, 28)
        # NOTE: ReLU is also applied after the last conv, so every score is >= 0.
        for conv in (self.conv1, self.conv2, self.conv3):
            out = F.relu(conv(out))
        out = F.avg_pool2d(out, 4)
        # (batch, 10, 1, 1) -> (batch, 10)
        return out.view(-1, out.size(1))

lr = 0.1
model = MnistCnn()
# Plain SGD plus momentum 0.9 for the convolutional model.
opt = optim.SGD(model.parameters(), momentum=0.9, lr=lr)

fit(epochs, model, loss_func, opt, train_dl=train_dl, valid_dl=valid_dl)

In [None]:
# Continue training the same CNN (and its optimizer state) for more epochs.
epochs = 5
fit(epochs, model, loss_func, opt, train_dl=train_dl, valid_dl=valid_dl)

In [None]:
from torchviz import make_dot

# A random input with the shape of one MNIST image; MnistCnn.forward reshapes
# it to (1, 1, 28, 28) internally.
x = torch.randn(28, 28)
named_params = dict(model.named_parameters())
make_dot(model(x), params=named_params)