# mnist-torch-function
Based on the PyTorch tutorial "What is torch.nn really?", section "Using torch.nn.functional": https://pytorch.org/tutorials/beginner/nn_tutorial.html#using-torch-nn-functional

In [83]:
# setup data: read the gzipped, pickled MNIST splits from disk
import pickle
import gzip

from pathlib import Path


DATA_PATH = Path("data")
PATH = DATA_PATH / "mnist"
FILENAME = "mnist.pkl.gz"

# The archive unpickles to three (inputs, labels) pairs; the first two are the
# training and validation splits, the third (presumably a test split) is discarded.
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted local copy.
with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding="latin-1")

In [84]:
# hyperparameters
bs = 64  # mini-batch size
n, c = x_train.shape  # rows / columns of the (2-D) training input matrix
epochs, lr = 10, 0.5  # number of training passes; SGD learning rate

In [87]:
# load data: wrap the raw splits in tensors, datasets and loaders
# `torch` is imported here because this cell uses it before the later import
# cell runs; without it, running the notebook top-to-bottom raises NameError.
import torch
from torch.utils.data import TensorDataset, DataLoader


# Convert the unpickled arrays (presumably NumPy) into tensors.
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)

train_ds = TensorDataset(x_train, y_train)
# NOTE: these unshuffled loaders are rebuilt by get_data() before training.
train_dl = DataLoader(train_ds, batch_size=bs)

valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs)

  


In [88]:
import math
import torch
import numpy as np
import torch.nn.functional as F

from torch import nn, optim

In [89]:
# model
class Mnist_Logistic(nn.Module):
    """Multinomial logistic-regression classifier: one linear layer.

    The output is unnormalised class scores (logits). Pair it with a loss such
    as ``F.cross_entropy``, which applies log-softmax internally.

    Args:
        in_features: Length of each flattened input row. The default is 784
            (presumably 28x28 MNIST pixels).
        num_classes: Number of scores produced per row. The default is 10.
    """

    def __init__(self, in_features=784, num_classes=10):
        super().__init__()
        self.lin = nn.Linear(in_features, num_classes)

    def forward(self, xb):
        """Return logits for ``xb``.

        A 2-D ``(batch, in_features)`` input gives ``(batch, num_classes)``.
        """
        return self.lin(xb)

In [82]:
loss_func = nn.functional.cross_entropy  # takes raw logits and integer class targets

In [68]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    """Evaluate the loss on one batch, optionally taking an optimizer step.

    Args:
        model: Callable applied to ``xb``.
        loss_func: Called as ``loss_func(model(xb), yb)``; must return a tensor.
        xb: Batch inputs (``len(xb)`` is reported as the batch size).
        yb: Batch targets.
        opt: If given, the loss is back-propagated, ``opt.step()`` is taken,
            and the gradients are cleared with ``opt.zero_grad()``.

    Returns:
        ``(loss, batch_size)``: the loss as a Python scalar and ``len(xb)``.
        Callers use the size to form a size-weighted average.
    """
    batch_loss = loss_func(model(xb), yb)
    if opt is None:
        return batch_loss.item(), len(xb)

    batch_loss.backward()
    opt.step()
    opt.zero_grad()
    return batch_loss.item(), len(xb)

In [70]:
def get_data(train_ds, valid_ds, bs):
    """Build ``(train_loader, valid_loader)`` for the two datasets.

    The training loader reshuffles each epoch and uses batches of ``bs``. The
    validation loader keeps dataset order and uses batches of ``bs * 2``.
    Validation in ``fit`` runs under ``torch.no_grad``, so larger batches are
    affordable there.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2)
    return train_loader, valid_loader

In [72]:
def get_model():
    """Return a freshly initialised ``Mnist_Logistic`` and an SGD optimizer over its parameters.

    The learning rate comes from the module-level ``lr``, read at call time.
    """
    net = Mnist_Logistic()
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd

In [79]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """Train ``model`` for ``epochs`` passes, printing validation loss after each.

    Each epoch has two phases:

    - Training: switch to training mode and take one optimizer step per batch
      of ``train_dl``.
    - Validation: switch to eval mode and, with gradient tracking disabled,
      print the batch-size-weighted mean loss over ``valid_dl``.
    """
    for epoch_idx in range(epochs):
        # --- training phase ---
        model.train()
        for batch_x, batch_y in train_dl:
            loss_batch(model, loss_func, batch_x, batch_y, opt)

        # --- validation phase ---
        model.eval()
        with torch.no_grad():
            per_batch = [
                loss_batch(model, loss_func, batch_x, batch_y)
                for batch_x, batch_y in valid_dl
            ]
        losses, nums = zip(*per_batch)
        # Weight each batch's loss by its size, since the last batch may be
        # smaller. This assumes loss_func returns a per-batch mean, which is
        # F.cross_entropy's default.
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print('{} epoch valid loss: {}'.format(epoch_idx, val_loss))

In [90]:
# Rebuild the loaders (training data shuffled each epoch), then train a fresh model.
train_dl, valid_dl = get_data(train_ds, valid_ds, bs=bs)
model, opt = get_model()
fit(epochs, model, loss_func, opt, train_dl=train_dl, valid_dl=valid_dl)

0 epoch valid loss: 0.3459483913421631
1 epoch valid loss: 0.28463590693473817
2 epoch valid loss: 0.31753084144592286
3 epoch valid loss: 0.3550556812763214
4 epoch valid loss: 0.2753871307373047
5 epoch valid loss: 0.2849240284919739
6 epoch valid loss: 0.37974022221565246
7 epoch valid loss: 0.2707487875938416
8 epoch valid loss: 0.28815152859687804
9 epoch valid loss: 0.282098974943161
