In [1]:
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from fastai.vision.all import *

In [2]:
from pathlib import Path
import pickle, gzip, os
from urllib.request import urlretrieve

In [3]:
# Remote location of the pickled, gzip-compressed MNIST archive.
MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'

# Local cache: ./data/mnist.pkl.gz (the directory is created on first run and
# silently reused afterwards).
path_data = Path('data')
path_gz = path_data.joinpath('mnist.pkl.gz')
path_data.mkdir(exist_ok=True)

In [4]:
# Fetch the archive only when it is not cached yet. The download is written
# to a temporary sibling file and renamed once complete, so an interrupted
# transfer never leaves a truncated mnist.pkl.gz behind that the exists()
# check would accept on every later run.
if not path_gz.exists():
    path_tmp = path_gz.with_name(path_gz.name + '.part')
    urlretrieve(MNIST_URL, path_tmp)
    path_tmp.replace(path_gz)

In [5]:
# Unpickle the cached archive. It holds three (inputs, labels) pairs; the
# third one is not used in this notebook and is discarded into `_`.
# NOTE(review): this unpickles a file downloaded from MNIST_URL; pickle can
# execute arbitrary code, so only use it with a source you trust.
with gzip.open(path_gz, 'rb') as f:
    train_split, valid_split, _ = pickle.load(f, encoding='latin-1')
x_train, y_train = train_split
x_valid, y_valid = valid_split

In [6]:
# Convert the four loaded arrays into torch tensors, keeping their names.
x_train, y_train, x_valid, y_valid = (
    torch.tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)
)

In [7]:
# Largest value found anywhere in the training inputs as loaded.
torch.max(x_train)

tensor(0.9961)

In [9]:
# Preview of the maximum after rounding (x_train itself is not modified here).
x_train.round().max()

tensor(1.)

In [7]:
# Dimensions of the training inputs.
x_train.size()

torch.Size([50000, 784])

In [8]:
# Dimensions of the training and validation inputs, side by side.
x_train.size(), x_valid.size()

(torch.Size([50000, 784]), torch.Size([10000, 784]))

In [9]:
# Dimensions of the training and validation labels, side by side.
y_train.size(), y_valid.size()

(torch.Size([50000]), torch.Size([10000]))

In [10]:
# Round every input value to the nearest integer (the maximum becomes 1.).
x_train = torch.round(x_train)
x_valid = torch.round(x_valid)

# Give the labels an explicit trailing axis, i.e. shape (N, 1).
# reshape(-1, 1) produces the same tensor as unsqueeze(1) on the first run
# but, unlike unsqueeze, does not stack yet another axis every time this
# cell is re-executed, so the cell is safe to run more than once.
y_train = y_train.reshape(-1, 1)
y_valid = y_valid.reshape(-1, 1)

In [11]:
# Confirm the rounding took effect: the maximum should now be exactly 1.
torch.max(x_train)

tensor(1.)

In [12]:
# Pair every input row with its label so each dataset item is an (x, y) tuple.
train_dset = [(inp, lbl) for inp, lbl in zip(x_train, y_train)]
valid_dset = [(inp, lbl) for inp, lbl in zip(x_valid, y_valid)]

# Mini-batches of 256; only the training loader shuffles its items.
valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=False)
train_dl = DataLoader(train_dset, batch_size=256, shuffle=True)

# fast.ai container that bundles the training and validation loaders.
dls = DataLoaders(train_dl, valid_dl)

In [15]:
# Fully connected classifier: 784 inputs -> 350 -> 70 -> 10 outputs, with a
# ReLU after each of the two hidden layers. The linear layers are created in
# network order, one per line, before being chained together.
hidden_1 = nn.Linear(28*28, 350)
hidden_2 = nn.Linear(350, 70)
output_layer = nn.Linear(70, 10)

n_net = nn.Sequential(hidden_1, nn.ReLU(), hidden_2, nn.ReLU(), output_layer)

In [16]:
# Train n_net with a cross-entropy loss, reporting accuracy after each epoch.
ce_loss = CrossEntropyLossFlat()
learn_ce = Learner(dls, n_net, loss_func=ce_loss, metrics=accuracy)

# NOTE(review): the stored cell output lists only epochs 0-9 although 15
# epochs are requested here; re-run to confirm which one is current.
learn_ce.fit(15)

epoch,train_loss,valid_loss,accuracy,time
0,0.30939,0.230343,0.9316,00:02
1,0.18986,0.156143,0.9553,00:01
2,0.129075,0.130662,0.9615,00:01
3,0.093555,0.11046,0.9678,00:01
4,0.072031,0.101787,0.9707,00:01
5,0.052493,0.094721,0.973,00:01
6,0.040231,0.097211,0.9723,00:01
7,0.027599,0.099014,0.9736,00:01
8,0.025306,0.095247,0.9731,00:01
9,0.016609,0.095541,0.9745,00:01


In [17]:
# Ground-truth label of training sample 52, for comparison with the model's
# prediction on the same sample.
sample_idx = 52
y_train[sample_idx]

tensor([7])

In [18]:
# Classify training sample 52 with the trained network. Gradients are not
# needed for inference, so tracking is switched off.
with torch.no_grad():
    logits = n_net(x_train[52])
    class_probs = torch.softmax(logits, dim=0)
    print(torch.argmax(class_probs))

tensor(7)


In [19]:
# Persist the trained network by pickling the whole module object.
# NOTE(review): a pickled module can only be loaded where the same class
# definitions are importable; consider saving n_net.state_dict() instead.
model_file = "mnist_net.pkl"
with open(model_file, "wb") as out_file:
    pickle.dump(n_net, out_file)

In [None]:
# Restore the network written by the previous cell into `tnet`.
# Only unpickle files you produced yourself: pickle can run arbitrary code.
model_file = "mnist_net.pkl"
with open(model_file, "rb") as in_file:
    tnet = pickle.load(in_file)

In [None]:
# Check that the reloaded network still predicts: classify training sample 10
# without tracking gradients.
with torch.no_grad():
    logits = tnet(x_train[10])
    class_probs = torch.softmax(logits, dim=0)
    print(torch.argmax(class_probs))

In [None]:
# Ground-truth label for the sample classified by the reloaded network in the
# preceding cell (x_train[10]). The index has to match that cell, otherwise
# the prediction and the label describe two different samples.
y_train[10]