In [44]:
import gzip
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

In [2]:
# Location of the gzipped pickle that get_files opens: directory (with a
# trailing slash) and file name, kept separate so either can be changed alone.
PATH = './data/mnist/'
FILENAME = 'mnist.pkl.gz'

In [3]:
def get_files(path, filename):
    """Load the training and validation splits from a gzipped pickle.

    Args:
        path: Directory that contains the archive.
        filename: Name of the gzipped pickle inside ``path``.

    Returns:
        A tuple ``(x_train, y_train, x_val, y_val)``. The third element stored
        in the pickle is read but discarded.
    """
    # Use the arguments rather than the notebook-level PATH/FILENAME so the
    # function loads whatever the caller asks for; os.path.join also copes
    # with a directory given without a trailing separator.
    archive = os.path.join(path, filename)
    with gzip.open(archive, "rb") as file:
        # NOTE(security): pickle.load can execute arbitrary code -- only use
        # this on archives from a trusted source.
        ((x_train, y_train), (x_val, y_val), _) = pickle.load(file, encoding='latin-1')
    return x_train, y_train, x_val, y_val

In [4]:
def tensor_map(x_train, y_train, x_val, y_val):
    """Return a lazy ``map`` that applies ``torch.tensor`` to each argument, in order."""
    splits = (x_train, y_train, x_val, y_val)
    return map(torch.tensor, splits)

In [5]:
def preprocess(x):
    """View ``x`` as a 4-D tensor of shape ``(-1, 1, 28, 28)``.

    The leading dimension is inferred from the number of elements in ``x``.
    """
    target_shape = (-1, 1, 28, 28)
    return x.view(*target_shape)

In [6]:
def conv(in_size, out_size, pad=1):
    """Build a 3x3, stride-2 ``nn.Conv2d`` from ``in_size`` to ``out_size`` channels.

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
        pad: Padding passed to ``nn.Conv2d`` (defaults to 1).
    """
    layer_kwargs = dict(kernel_size=3, stride=2, padding=pad)
    return nn.Conv2d(in_size, out_size, **layer_kwargs)

In [7]:
class ResBlock(nn.Module):
    """Residual block: two conv -> batchnorm -> ReLU stages added back onto the input.

    Both convolutions come from ``conv`` (3x3, stride 2), so ``pad`` has to be
    chosen such that the stage output can still be added to the input.
    NOTE(review): when ``in_size`` differs from ``out_size`` the addition in
    ``forward`` depends on tensor broadcasting -- confirm that is intended.
    """

    def __init__(self, in_size:int, hidden_size:int, out_size:int, pad:int):
        super().__init__()
        # Registration order is kept stable so parameter ordering is unchanged.
        self.conv1 = conv(in_size, hidden_size, pad)
        self.conv2 = conv(hidden_size, out_size, pad)
        self.batchnorm1 = nn.BatchNorm2d(hidden_size)
        self.batchnorm2 = nn.BatchNorm2d(out_size)

    def convblock(self, x):
        """Run ``x`` through both conv/batchnorm/ReLU stages in sequence."""
        stages = (
            (self.conv1, self.batchnorm1),
            (self.conv2, self.batchnorm2),
        )
        for convolve, normalise in stages:
            x = F.relu(normalise(convolve(x)))
        return x

    def forward(self, x):
        """Return the input plus the output of ``convblock`` (skip connection)."""
        residual = self.convblock(x)
        return x + residual

In [8]:
class ResNet(nn.Module):
    """Two ``ResBlock`` stages followed by a conv/batchnorm head and global max-pool.

    Args:
        n_classes: Number of output channels of the final convolution, and
            therefore the length of each output vector (defaults to 10).
    """

    def __init__(self, n_classes=10):
        super().__init__()
        self.res1 = ResBlock(1, 8, 16, 15)
        self.res2 = ResBlock(16, 32, 16, 15)
        self.conv = conv(16, n_classes)
        self.batchnorm = nn.BatchNorm2d(n_classes)
        self.maxpool = nn.AdaptiveMaxPool2d(1)

    def forward(self, x):
        """Reshape the input with ``preprocess`` and return one ``n_classes`` vector per sample."""
        features = self.res2(self.res1(preprocess(x)))
        pooled = self.maxpool(self.batchnorm(self.conv(features)))
        # The adaptive pool leaves a (1, 1) grid; drop it to get a flat vector per sample.
        batch = pooled.size(0)
        return pooled.view(batch, -1)

In [25]:
def loss_batch(model, loss_func, xb, yb, opt=None, scheduler=None):
    """Compute the loss for one batch and, if an optimizer is given, take a training step.

    Args:
        model: Callable applied to ``xb`` to produce predictions.
        loss_func: Callable ``loss_func(predictions, yb)`` returning a scalar tensor.
        xb: Batch of inputs.
        yb: Batch of targets.
        opt: Optional optimizer. When ``None`` no gradients are computed and
            no parameters are updated (evaluation mode).
        scheduler: Optional learning-rate scheduler, stepped once per batch
            after the optimizer. Ignored when ``opt`` is ``None``.

    Returns:
        ``(loss_value, batch_size)`` where ``loss_value`` is a Python float and
        ``batch_size`` is ``len(xb)``.
    """
    loss = loss_func(model(xb), yb)
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
        # The scheduler must advance only after the optimizer has used the
        # current learning rate; stepping it first skips the first value of
        # the schedule and makes PyTorch emit a warning.
        if scheduler is not None:
            scheduler.step()

    return loss.item(), len(xb)

In [40]:
def get_model(learning_rate=None, momentum=0.9):
    """Create a fresh ``ResNet`` together with an SGD optimizer over its parameters.

    Args:
        learning_rate: Learning rate for SGD. When ``None`` (the default) the
            notebook-level ``lr`` variable is used, which keeps the original
            zero-argument call working.
        momentum: SGD momentum (defaults to 0.9).

    Returns:
        ``(model, optimizer)``.
    """
    if learning_rate is None:
        # Backward-compatible fallback: `lr` must already be defined in the
        # notebook namespace when get_model() is called without arguments.
        learning_rate = lr
    model = ResNet()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    return model, optimizer

In [11]:
def accuracy(out, yb):
    """Return the fraction of rows in ``out`` whose arg-max along dim 1 equals ``yb``.

    The result is a 0-dim float tensor.
    """
    hits = out.argmax(dim=1) == yb
    return hits.float().mean()

In [None]:
def get_data_batches(train_ds, valid_ds, bs):
    """Wrap the two datasets in ``DataLoader`` objects.

    The training loader shuffles and uses batches of ``bs``; the validation
    loader keeps dataset order and uses batches of ``bs * 2``.

    Returns:
        ``(train_loader, valid_loader)``.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)
    return train_loader, valid_loader

In [54]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl, scheduler=None):
    """Train ``model`` for ``epochs`` epochs, printing the validation loss after each.

    Args:
        epochs: Number of passes over ``train_dl``.
        model: Module to train; switched between train and eval mode each epoch.
        loss_func: Loss callable forwarded to ``loss_batch``.
        opt: Optimizer forwarded to ``loss_batch`` for the training batches.
        train_dl: Iterable of ``(xb, yb)`` training batches.
        valid_dl: Iterable of ``(xb, yb)`` validation batches.
        scheduler: Optional scheduler forwarded to ``loss_batch`` for the
            training batches.
    """
    for epoch in range(epochs):
        # Training pass: one optimizer update per batch.
        model.train()
        for train_x, train_y in train_dl:
            loss_batch(model, loss_func, train_x, train_y, opt, scheduler)

        # Validation pass: no optimizer is passed, so loss_batch only evaluates.
        model.eval()
        with torch.no_grad():
            batch_stats = [
                loss_batch(model, loss_func, val_x, val_y) for val_x, val_y in valid_dl
            ]
        losses, nums = zip(*batch_stats)

        # Weight each batch loss by its size so a short final batch is not over-counted.
        weighted_total = np.sum(np.multiply(losses, nums))
        val_loss = weighted_total / np.sum(nums)

        print(epoch, val_loss)

In [15]:
# Read the train/validation inputs and labels from the gzipped pickle.
x_train, y_train, x_val, y_val = get_files(PATH, FILENAME)

In [16]:
# Convert every split with torch.tensor. This rebinds the same four names, so
# the objects loaded in the previous cell are no longer reachable afterwards.
x_train, y_train, x_val, y_val = tensor_map(x_train, y_train, x_val, y_val)

In [18]:
train_ds = TensorDataset(x_train, y_train)
val_ds = TensorDataset(x_val, y_val)
train_dl, val_dl = get_data_batches(train_ds, val_ds, bs)

In [52]:
bs=64
lr=0.01
loss_func = F.cross_entropy

In [53]:
# Build a fresh ResNet and its SGD optimizer; relies on the notebook-level `lr`.
model, opt = get_model()

In [55]:
# Train for 5 epochs without a scheduler; fit prints "<epoch> <validation loss>" per epoch.
fit(5, model, loss_func, opt, train_dl, val_dl)

0 0.2162591547846794
1 0.13687330031245948
2 0.14416173471733928
3 0.10076363628059626
4 0.07622794383913278
