## Training A Digit Classifier That Distinguishes 3s from 7s using the MNIST Data Set 

In [1]:
# Import the data 
import numpy as np
from fastai.vision.all import *
path = untar_data(URLs.MNIST_SAMPLE)



In [2]:
# collecting and formatting the data 
threes = (path/'train'/'3').ls()
sevens = (path/'train'/'7').ls()

sevens_tensors = [tensor(Image.open(o)) for o in sevens]
three_tensors = [tensor(Image.open(o)) for o in threes]

stacked_sevens = torch.stack(sevens_tensors).float()/255
stacked_threes = torch.stack(three_tensors).float()/255


valid_3_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'3').ls()])
valid_3_tens = valid_3_tens.float()/255

valid_7_tens = torch.stack([tensor(Image.open(o)) for o in (path/'valid'/'7').ls()])
valid_7_tens = valid_7_tens.float()/255

In [3]:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
train_y = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)

dset = list(zip(train_x, train_y)) 

valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1] * len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x, valid_y))

In [4]:
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets==1, 1-predictions, predictions).mean()

def batch_accuracy(xb, yb):
    preds = xb.sigmoid()
    correct = (preds>0.5) == yb
    return correct.float().mean()

In [5]:
dl = DataLoader(dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)


In [6]:
dls = DataLoaders(dl, valid_dl)
learn = Learner(dls, nn.Linear(28*28, 1), opt_func=SGD, loss_func=mnist_loss, metrics=batch_accuracy)

In [7]:
learn.fit(10, lr=1.)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.636673,0.503002,0.495584,00:00
1,0.400988,0.248487,0.773307,00:00
2,0.151842,0.169814,0.84789,00:00
3,0.069864,0.103047,0.914622,00:00
4,0.039466,0.075096,0.935231,00:00
5,0.027632,0.05996,0.950932,00:00
6,0.02267,0.05068,0.960255,00:00
7,0.020289,0.044648,0.965162,00:00
8,0.018939,0.040474,0.967125,00:00
9,0.018051,0.0374,0.969087,00:00
