In [84]:
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pickle, gzip, os
from pathlib import Path
from fastai import datasets

In [5]:
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

In [27]:
path = Path(os.getcwd() + "/dataset/mnist")
path = datasets.download_data(MNIST_URL, path, ext='.gz')

In [39]:
with gzip.open(path, 'rb') as f:
    ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

In [46]:
x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))

In [72]:
class Dataset():
    def __init__(self, x, y): self.x, self.y = x, y
    def __len__(self): return len(self.x)
    def __getitem__(self, i): return self.x[i], self.y[i] 

In [162]:
bs = 64
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)

train_dl = DataLoader(train_ds, bs, True)
valid_dl = DataLoader(valid_ds, bs*2)

In [165]:
model = nn.Sequential(nn.Linear(x_train.shape[-1], 50), nn.ReLU(), nn.Linear(50, 10))
opt = optim.SGD(model.parameters(), .5)

In [166]:
def accuracy(out, yb): return (torch.argmax(out, dim=1)==yb).float().mean()

In [167]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    for e in range(epochs):
        model.train()
        for xb, yb in train_dl:
            loss_func(model(xb), yb).backward()
            opt.step()
            opt.zero_grad()
        
        model.eval()
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb, yb in valid_dl:
                pred = model(xb)
                tot_loss += loss_func(pred, yb)
                tot_acc  += accuracy (pred, yb)
        nv = len(valid_dl)
        print("epoch: {}  loss: {}  acc: {}".format(e, tot_loss/nv, tot_acc/nv))
    return tot_loss/nv, tot_acc/nv

In [169]:
fit(4, model, F.cross_entropy, opt, train_dl, valid_dl)

epoch: 0  loss: 0.10733925551176071  acc: 0.969936728477478
epoch: 1  loss: 0.09393870085477829  acc: 0.974782407283783
epoch: 2  loss: 0.11095812171697617  acc: 0.9709256291389465
epoch: 3  loss: 0.16840021312236786  acc: 0.9548061490058899


(tensor(0.1684), tensor(0.9548))

In [139]:
torch.save(model.state_dict(), "weight.pth")