# MNIST CNN in PyTorch using nn.Sequential
Reference: https://pytorch.org/tutorials/beginner/nn_tutorial.html#nn-sequential

In [1]:
# setup data
import pickle
import gzip

from pathlib import Path


# Location of the gzipped, pickled MNIST archive, relative to the working directory.
DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

# NOTE: unpickling can execute arbitrary code; only load this archive from a
# source you trust.
# The archive unpickles to three entries; the first two are (inputs, labels)
# pairs and the third is discarded (presumably a test split - confirm).
with gzip.open(PATH / FILENAME, "rb") as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(
        f, encoding="latin-1"
    )

In [2]:
# Hyperparameters and dataset dimensions.
n, c = x_train.shape  # x_train must be 2-D here; presumably (samples, pixels) - confirm
bs = 64       # batch size handed to get_data()
lr = 0.1      # learning rate handed to the optimizer
epochs = 10   # number of epochs handed to fit()

In [26]:
import torch
from torch.utils.data import TensorDataset, DataLoader


# Convert every split to a tensor. All four conversions are evaluated before
# any name is rebound. (Inputs are presumably NumPy arrays from the pickle -
# confirm.)
x_train, y_train, x_valid, y_valid = [
    torch.tensor(array) for array in (x_train, y_train, x_valid, y_valid)
]

# Pair inputs with labels so that indexing a dataset yields an (x, y) tuple.
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)

  


In [29]:
def get_data(train_ds, valid_ds, bs):
    """Build the training and validation data loaders.

    Args:
        train_ds: dataset to draw training batches from.
        valid_ds: dataset to draw validation batches from.
        bs: training batch size; the validation loader uses ``bs * 2``.

    Returns:
        A ``(train_loader, valid_loader)`` tuple. Only the training loader
        is created with ``shuffle=True``.
    """
    train_loader = DataLoader(train_ds, shuffle=True, batch_size=bs)
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)
    return train_loader, valid_loader

In [41]:
def preprocess(x, y):
    """Reshape a batch of inputs to ``(N, 1, 28, 28)``; labels pass through.

    Args:
        x: input tensor whose element count is a multiple of 28 * 28.
        y: labels for the batch, returned unchanged.

    Returns:
        ``(images, y)`` where ``images`` is a view of ``x`` (no copy).
    """
    images = x.view(-1, 1, 28, 28)
    return images, y

class WrappedDataLoader:
    """Iterable wrapper that passes every batch of ``dl`` through ``func``.

    Each batch produced by ``dl`` is unpacked into positional arguments for
    ``func``, and whatever ``func`` returns is yielded in its place.
    """

    def __init__(self, dl, func):
        self.dl, self.func = dl, func

    def __len__(self):
        """Return the length reported by the wrapped loader."""
        return len(self.dl)

    def __iter__(self):
        """Lazily yield ``func(*batch)`` for each batch of the wrapped loader."""
        for batch in self.dl:
            yield self.func(*batch)

In [42]:
# Build the raw loaders, then wrap each one so that every batch goes through
# `preprocess` as it is iterated.
train_dl, valid_dl = [
    WrappedDataLoader(loader, preprocess)
    for loader in get_data(train_ds, valid_ds, bs)
]

In [43]:
import math
import numpy as np
import torch.nn.functional as F

from torch import nn, optim

In [44]:
class Lambda(nn.Module):
    """Adapter that lets a plain callable be used as an ``nn.Module`` layer."""

    def __init__(self, func):
        """Store ``func``; it is called with the layer's input in ``forward``."""
        super().__init__()
        self.func = func

    def forward(self, x):
        """Return ``func(x)``."""
        transform = self.func
        return transform(x)

In [45]:
# Three stride-2 convolutions, then average pooling; the trailing Lambda
# flattens the result to one row per sample.
# Spatial sizes noted below assume the (N, 1, 28, 28) batches that
# `preprocess` produces.
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),  # 1x28x28 -> 16x14x14
    nn.ReLU(),
    nn.Conv2d(16, 16, kernel_size=3, stride=2, padding=1),  # -> 16x7x7
    nn.ReLU(),
    nn.Conv2d(16, 10, kernel_size=3, stride=2, padding=1),  # -> 10x4x4
    nn.ReLU(),
    nn.AvgPool2d(4),  # -> 10x1x1
    Lambda(lambda x: x.view(x.size(0), -1))  # -> (N, 10)
)

In [46]:
# SGD with momentum over all of the model's parameters.
opt = optim.SGD(
    params=model.parameters(),
    lr=lr,
    momentum=0.9,
)

In [47]:
# Loss passed to fit(); it is called there as loss_func(model(xb), yb).
# Per the PyTorch docs, F.cross_entropy takes unnormalized class scores as input.
loss_func = F.cross_entropy

In [48]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train ``model`` and print the validation loss after every epoch.

    Args:
        epochs: number of passes over ``train_dl``.
        model: module to train; switched between ``train()`` and ``eval()``
            mode around the training and validation phases.
        loss_func: callable invoked as ``loss_func(model(xb), yb)`` that
            returns a scalar tensor.
        opt: optimizer used to update the model after each training batch.
        train_dl: iterable of ``(xb, yb)`` training batches.
        valid_dl: iterable of ``(xb, yb)`` validation batches; may be empty.

    Returns:
        None. One line per epoch is printed with the size-weighted mean
        validation loss (``nan`` when ``valid_dl`` yields no batches).
    """
    def loss_batch(model, loss_func, xb, yb, opt=None):
        """Return ``(loss value, batch size)``; take an optimizer step if ``opt`` is given."""
        loss = loss_func(model(xb), yb)

        if opt is not None:
            # Clear gradients *before* backward so that anything already
            # accumulated on the parameters (e.g. from an interrupted earlier
            # run) cannot leak into this update.
            opt.zero_grad()
            loss.backward()
            opt.step()

        return loss.item(), len(xb)

    for epoch in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()
        with torch.no_grad():
            results = [
                loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl
            ]

        if results:
            losses, nums = zip(*results)
            # Weight each batch loss by its size so a smaller final batch
            # does not skew the mean.
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        else:
            # No validation batches: report nan instead of failing on unpack.
            val_loss = float("nan")

        print('{} epoch valid loss: {}'.format(epoch, val_loss))

In [49]:
# Run the training loop; the validation loss is printed once per epoch.
fit(
    epochs=epochs,
    model=model,
    loss_func=loss_func,
    opt=opt,
    train_dl=train_dl,
    valid_dl=valid_dl,
)

0 epoch valid loss: 0.6199473428726197
1 epoch valid loss: 0.49880292506217955
2 epoch valid loss: 0.47153973088264467
3 epoch valid loss: 0.4790508419036865
4 epoch valid loss: 0.4486215820789337
5 epoch valid loss: 0.4340327945232391
6 epoch valid loss: 0.4076675771713257
7 epoch valid loss: 0.40055596237182617
8 epoch valid loss: 0.44266173310279844
9 epoch valid loss: 0.3866080888748169
