## Training A Digit Classifier That Distinguishes 3s from 7s using the MNIST Data Set 

In [44]:
# Import the data 
import numpy as np
from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)

In [45]:
# collecting and formatting the data 
threes = (path/'train'/'3').ls()
sevens = (path/'train'/'7').ls()

sevens_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]

stacked_sevens = torch.stack(sevens_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255


valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255

valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255

In [46]:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)

dset = list(zip(train_x, train_y)) 

valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1] * len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

In [47]:
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

In [48]:
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)


In [49]:
dls = DataLoaders(dl, valid_dl)
learn = Learner(dls, nn.Linear(28*28, 1), opt_func=SGD, loss_func=mnist_loss, metrics=batch_accuracy)

In [50]:
learn.fit(10, lr=1.)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.636678,0.502701,0.495584,00:00
1,0.365168,0.266403,0.752208,00:00
2,0.139909,0.163116,0.85476,00:00
3,0.065297,0.101019,0.913641,00:00
4,0.037651,0.074214,0.936212,00:00
5,0.026861,0.059421,0.950442,00:00
6,0.022291,0.050305,0.959274,00:00
7,0.02006,0.044364,0.966143,00:00
8,0.018772,0.040238,0.967125,00:00
9,0.017913,0.037189,0.969087,00:00


In [53]:
# using a more complex net 
complex_net = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30,1)
)

learn = Learner(dls, complex_net, opt_func=SGD, 
                loss_func=mnist_loss, metrics=batch_accuracy)

learn.fit(40, 0.1)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.316639,0.40442,0.50736,00:00
1,0.145767,0.222731,0.810599,00:00
2,0.080953,0.109807,0.919529,00:00
3,0.05352,0.074146,0.943572,00:00
4,0.04082,0.058028,0.958292,00:00
5,0.034279,0.049082,0.964671,00:00
6,0.030461,0.043475,0.966143,00:00
7,0.027947,0.03964,0.967615,00:00
8,0.026125,0.036827,0.970559,00:00
9,0.024713,0.034659,0.972522,00:00


###  By using a more complicated architecture, we achieved a higher accuracy `98.2%` versus `96.9%`

In [1]:
%pip install numpy

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip' command.[0m
Note: you may need to restart the kernel to use updated packages.
