# mnist-torch-cnn
Based on the "Switch to CNN" section of the PyTorch tutorial *What is `torch.nn` really?*: https://pytorch.org/tutorials/beginner/nn_tutorial.html#switch-to-cnn

In [6]:
# setup data
import pickle
import gzip

from pathlib import Path


# Location of the gzipped MNIST pickle, relative to the notebook's working directory.
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

# NOTE(security): pickle.load can execute arbitrary code while unpickling.
# Only load this archive if it comes from a trusted source.
# The archive holds three (inputs, labels) pairs; the third pair is discarded here.
with gzip.open(PATH / FILENAME, "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")

In [11]:
# Training hyperparameters and dataset dimensions.
# n, c: the two dimensions of x_train's shape (presumably number of samples and
# values per sample -- TODO confirm against the loaded data).
n, c = x_train.shape
bs = 64       # mini-batch size for the training DataLoader
lr = 0.1      # learning rate handed to the optimizer
epochs = 10   # number of passes over the training set

In [28]:
import torch
from torch.utils.data import TensorDataset, DataLoader


# Convert the loaded arrays into tensors, rebinding the same names.
# torch.as_tensor is used rather than torch.tensor so that this cell is
# idempotent: on a re-run the names already hold tensors, and as_tensor returns
# them unchanged instead of copy-constructing (which triggers a UserWarning).
# For array input as_tensor keeps the dtype and avoids an extra copy.
x_train, y_train, x_valid, y_valid = map(
    torch.as_tensor, (x_train, y_train, x_valid, y_valid)
)

# Pair inputs with labels so each dataset item is an (x, y) tuple.
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)

  


In [30]:
def get_data(train_ds, valid_ds, bs):
    """Wrap the training and validation datasets in DataLoaders.

    Args:
        train_ds: dataset to draw training batches from.
        valid_ds: dataset to draw validation batches from.
        bs: batch size for the training loader.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. The training loader shuffles
        its samples; the validation loader keeps dataset order and uses a
        batch size of twice ``bs``.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    # Validation batches are twice as large as training batches.
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)
    return train_loader, valid_loader

In [31]:
# Build the shuffled training loader and the validation loader.
train_dl, valid_dl = get_data(train_ds=train_ds, valid_ds=valid_ds, bs=bs)

In [9]:
import math
import numpy as np
import torch.nn.functional as F

from torch import nn, optim

In [19]:
class Mnist_CNN(nn.Module):
    """Three-layer convolutional classifier for 28x28 single-channel images.

    Every convolution uses a 3x3 kernel with stride 2 and padding 1, so the
    spatial size shrinks 28 -> 14 -> 7 -> 4. The last convolution has 10
    output channels, which are averaged over the remaining 4x4 grid to give
    one score per channel.
    """

    def __init__(self):
        super().__init__()
        # Settings shared by all three convolutions.
        conv_opts = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(1, 16, **conv_opts)
        self.conv2 = nn.Conv2d(16, 16, **conv_opts)
        self.conv3 = nn.Conv2d(16, 10, **conv_opts)

    def forward(self, xb):
        """Compute per-class scores for a batch.

        Args:
            xb: tensor that can be viewed as ``(-1, 1, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, 10)``. Values are non-negative because
            ReLU is applied after the final convolution as well.
        """
        features = xb.view(-1, 1, 28, 28)
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        # A 4x4 window covers the whole remaining feature map, leaving 1x1.
        pooled = F.avg_pool2d(features, 4)
        return pooled.view(-1, pooled.size(1))

In [20]:
# Loss function handed to fit(): the functional cross-entropy from
# torch.nn.functional, applied directly to the model's output and the labels.
loss_func = F.cross_entropy

In [25]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train ``model`` and print the validation loss after every epoch.

    Args:
        epochs: number of passes over ``train_dl``.
        model: module to optimise; switched between train and eval mode.
        loss_func: callable taking ``(predictions, targets)`` and returning a
            scalar loss tensor.
        opt: optimizer whose ``step``/``zero_grad`` are called once per
            training batch.
        train_dl: iterable of ``(inputs, targets)`` training batches.
        valid_dl: iterable of ``(inputs, targets)`` validation batches.

    Returns:
        None. One line per epoch is printed with the size-weighted mean
        validation loss.
    """

    def batch_loss(inputs, targets):
        # Forward pass plus loss for a single batch.
        return loss_func(model(inputs), targets)

    def train_step(inputs, targets):
        # Backpropagate, update the parameters, then clear the gradients so
        # they do not accumulate into the next batch.
        current = batch_loss(inputs, targets)
        current.backward()
        opt.step()
        opt.zero_grad()

    for epoch in range(epochs):
        model.train()
        for inputs, targets in train_dl:
            train_step(inputs, targets)

        model.eval()
        with torch.no_grad():
            per_batch = [
                (batch_loss(inputs, targets).item(), len(inputs))
                for inputs, targets in valid_dl
            ]
            losses, nums = zip(*per_batch)
        # Weight each batch loss by its size, since the last batch may be smaller.
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print('{} epoch valid loss: {}'.format(epoch, val_loss))

In [26]:
# Fresh network instance with newly initialised parameters.
model = Mnist_CNN()
# SGD with momentum 0.9 over all of the model's parameters.
opt = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)

In [None]:
# Run the training loop; fit() prints one validation-loss line per epoch.
fit(
    epochs=epochs,
    model=model,
    loss_func=loss_func,
    opt=opt,
    train_dl=train_dl,
    valid_dl=valid_dl,
)

0 epoch valid loss: 0.37535191621780395
1 epoch valid loss: 0.24335638041496277
2 epoch valid loss: 0.19641312398910524
3 epoch valid loss: 0.18391424870491027
4 epoch valid loss: 0.17105478954315184
5 epoch valid loss: 0.1625874858856201
6 epoch valid loss: 0.1585569878101349
7 epoch valid loss: 0.15394278473854064
8 epoch valid loss: 0.154742910861969
