# MNIST Comparision

Here we are going to compare predicting MNIST with a few different set of models:

* Simple Custom Liner Model
* with nn.Linear
* Simple 3 Layer Custom Nueral Network Model
* Simple 3 Layer Nueral Network Model (based on nn.XXX)
* With CNN

In [2]:
from fastai2.vision.all import *
from utils import *
matplotlib.rc('image', cmap='Blues')

## Loading Images

In [3]:
im_path = untar_data(URLs.MNIST_SAMPLE)

In [4]:
im_path.ls()

(#3) [Path('/storage/data/mnist_sample/labels.csv'),Path('/storage/data/mnist_sample/valid'),Path('/storage/data/mnist_sample/train')]

In [5]:
(im_path/"train").ls()

(#2) [Path('/storage/data/mnist_sample/train/7'),Path('/storage/data/mnist_sample/train/3')]

In [6]:
(im_path/"train/7").ls()

(#6265) [Path('/storage/data/mnist_sample/train/7/32208.png'),Path('/storage/data/mnist_sample/train/7/79.png'),Path('/storage/data/mnist_sample/train/7/54193.png'),Path('/storage/data/mnist_sample/train/7/4545.png'),Path('/storage/data/mnist_sample/train/7/2161.png'),Path('/storage/data/mnist_sample/train/7/11473.png'),Path('/storage/data/mnist_sample/train/7/3914.png'),Path('/storage/data/mnist_sample/train/7/58565.png'),Path('/storage/data/mnist_sample/train/7/8302.png'),Path('/storage/data/mnist_sample/train/7/59871.png')...]

In [7]:
def load_images(im_dir_path):
    return torch.stack([tensor(Image.open(im)).float()/255 for im in im_dir_path.ls()])

In [8]:
train_3s = load_images((im_path/"train/3"));
train_7s = load_images((im_path/"train/7"));
valid_3s = load_images((im_path/"valid/3"));
valid_7s = load_images((im_path/"valid/7"));

In [9]:
train_3s.shape, train_7s.shape, valid_3s.shape, valid_7s.shape

(torch.Size([6131, 28, 28]),
 torch.Size([6265, 28, 28]),
 torch.Size([1010, 28, 28]),
 torch.Size([1028, 28, 28]))

## Creating Dataloaders

Now we need to define these images into a format where fastai understands.

In [43]:
def get_dataloader(a, b, batch_size=225):
    x = torch.cat([a, b]).view(-1, 28*28)
    y = tensor([1] * len(a) + [0]* len(b)).unsqueeze(1)
    dset = list(zip(x, y))
    return DataLoader(dset, batch_size)

In [44]:
train_dl = get_dataloader(train_3s, train_7s)
valid_dl = get_dataloader(valid_3s, valid_7s)

In [45]:
for X, Y in train_dl:
    print(X.shape, Y.shape)

torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Size([225, 1])
torch.Size([225, 784]) torch.Siz

In [14]:
dls = DataLoaders(train_dl, valid_dl)

## Loss & Accuracy

In [59]:
def mnist_loss(preds, Y):
    return (torch.sigmoid(preds) - Y).abs().float().mean()

In [58]:
def mnist_accuracy(preds, Y):
    return ((torch.sigmoid(preds) > 0.5).float() == Y).float().mean()

## Model 1. Simple Linear Model

In [52]:
class SimpleLinerModel(nn.Module):
    def __init__(self, in_features, out_features=1):
        super(SimpleLinerModel, self).__init__()
        self.W = torch.randn((in_features, out_features)).requires_grad_()
        self.B = torch.randn(1).requires_grad_()
        
    def parameters(self):
        return [self.W, self.B]
    
    def forward(self, X):
        return X@self.W + self.B

In [77]:
learn = Learner(dls, SimpleLinerModel(28*28), opt_func=SGD, loss_func=mnist_loss, metrics=mnist_accuracy)

In [78]:
learn.fit(20, lr=1.)

epoch,train_loss,valid_loss,mnist_accuracy,time
0,0.18557,0.331276,0.665358,00:00
1,0.099681,0.151431,0.853778,00:00
2,0.063297,0.084099,0.92002,00:00
3,0.047414,0.060975,0.942591,00:00
4,0.039413,0.049494,0.953876,00:00
5,0.03445,0.043393,0.959274,00:00
6,0.031132,0.039668,0.963199,00:00
7,0.028854,0.037071,0.965653,00:00
8,0.027196,0.035109,0.965653,00:00
9,0.025901,0.033562,0.965162,00:00


In [84]:
learn.fit(10, lr=1.)

epoch,train_loss,valid_loss,mnist_accuracy,time
0,0.01325,0.016796,0.98577,00:01
1,0.013206,0.016747,0.98577,00:00
2,0.013157,0.016702,0.98528,00:00
3,0.013105,0.01666,0.98528,00:00
4,0.013052,0.016621,0.98528,00:00
5,0.013,0.016584,0.98528,00:00
6,0.012948,0.01655,0.98528,00:00
7,0.012898,0.016518,0.98528,00:00
8,0.012849,0.016488,0.98528,00:00
9,0.0128,0.016459,0.98528,00:01


## Model 2: With nn.Linear

In [68]:
learn = Learner(dls, nn.Linear(28*28, 1), opt_func=SGD, loss_func=mnist_loss, metrics=mnist_accuracy)

In [69]:
learn.fit(20, lr=1.)

epoch,train_loss,valid_loss,mnist_accuracy,time
0,0.652766,0.496072,0.495584,00:00
1,0.214868,0.341936,0.642296,00:01
2,0.080524,0.155133,0.857704,00:00
3,0.039129,0.101883,0.910206,00:00
4,0.025461,0.076285,0.933759,00:00
5,0.020614,0.061219,0.947988,00:00
6,0.018674,0.051477,0.958783,00:00
7,0.017716,0.044971,0.965162,00:00
8,0.01711,0.04049,0.966634,00:01
9,0.016636,0.037244,0.967615,00:01


## Model 3: Nueral Network Based on nn.

In [86]:
model = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)
learn = Learner(dls, model, opt_func=SGD, loss_func=mnist_loss, metrics=mnist_accuracy)

In [87]:
learn.fit(20, lr=.1)

epoch,train_loss,valid_loss,mnist_accuracy,time
0,0.252845,0.426497,0.505397,00:01
1,0.114675,0.20817,0.827772,00:01
2,0.062914,0.107159,0.919038,00:01
3,0.042746,0.073937,0.9421,00:01
4,0.033941,0.058212,0.955348,00:01
5,0.029519,0.049229,0.964671,00:01
6,0.026925,0.043523,0.967615,00:01
7,0.025164,0.039637,0.969087,00:01
8,0.023833,0.036818,0.97105,00:00
9,0.022752,0.034669,0.971541,00:00


In [92]:
learn.fit(10, lr=.2)

epoch,train_loss,valid_loss,mnist_accuracy,time
0,0.010632,0.018264,0.982826,00:00
1,0.010199,0.01821,0.982826,00:00
2,0.010038,0.018166,0.982826,00:00
3,0.009933,0.018126,0.982826,00:00
4,0.009845,0.018091,0.982826,00:00
5,0.009764,0.01806,0.982826,00:01
6,0.009687,0.018032,0.982826,00:00
7,0.009614,0.018007,0.982336,00:01
8,0.009543,0.017983,0.982336,00:00
9,0.009474,0.017959,0.982336,00:00
