Setup:

In [1]:
%pip install -Uqq fastai
from fastai.vision.all import *

Note: you may need to restart the kernel to use updated packages.


Import MNIST sample that contains samples of all numbers from 0 to 9:

In [2]:
path = untar_data(URLs.MNIST)
Path.BASE_PATH = path
(path/'testing').ls()

(#10) [Path('testing/9'),Path('testing/0'),Path('testing/7'),Path('testing/6'),Path('testing/1'),Path('testing/8'),Path('testing/4'),Path('testing/3'),Path('testing/2'),Path('testing/5')]

Create dictionary `digits` and `digits_training` of all digits in MNIST sample:

In [3]:
import os
dir = os.listdir(path/'training')
print((path/'testing'/'9').ls().sorted())

[Path('testing/9/1000.png'), Path('testing/9/1005.png'), Path('testing/9/1013.png'), Path('testing/9/104.png'), Path('testing/9/1045.png'), Path('testing/9/1048.png'), Path('testing/9/105.png'), Path('testing/9/1058.png'), Path('testing/9/1063.png'), Path('testing/9/108.png'), Path('testing/9/1081.png'), Path('testing/9/1086.png'), Path('testing/9/1088.png'), Path('testing/9/1090.png'), Path('testing/9/1103.png'), Path('testing/9/1105.png'), Path('testing/9/1107.png'), Path('testing/9/113.png'), Path('testing/9/1130.png'), Path('testing/9/1152.png'), Path('testing/9/1165.png'), Path('testing/9/118.png'), Path('testing/9/1183.png'), Path('testing/9/1192.png'), Path('testing/9/12.png'), Path('testing/9/1217.png'), Path('testing/9/1228.png'), Path('testing/9/1232.png'), Path('testing/9/1247.png'), Path('testing/9/125.png'), Path('testing/9/1255.png'), Path('testing/9/1277.png'), Path('testing/9/1282.png'), Path('testing/9/1304.png'), Path('testing/9/1308.png'), Path('testing/9/1309.png'),

In [4]:
import os

dir = os.listdir(path/'training')
dir_valid = os.listdir(path/'testing')

digits = {}
digits['larger_group'] = []
digits['smaller_group'] = []
for digit in dir:
    if digit in ['9', '4', '7', '1']:
        digits['smaller_group'] += (path/'testing'/digit).ls().sorted()
    else:
        digits['larger_group'] += (path/'testing'/digit).ls().sorted()

    
digits_valid = {}
digits_valid['larger_group'] = []
digits_valid['smaller_group'] = []
for digit in dir_valid:
    if digit in ['9', '4', '7', '1']:
        digits_valid['smaller_group'] += (path/'testing'/digit).ls().sorted()
    else:
        digits_valid['larger_group'] += (path/'testing'/digit).ls().sorted()

**Organize all training data:**

Transform data in `digits` elements into tuples containing stacked tensors and the amount of images:

In [5]:
for name, data in digits.items():
    digit_tensor = [tensor(Image.open(o)) for o in data]
    stacked = torch.stack(digit_tensor).float()/255
    digits[name] = (stacked, len(data))

Create training set `train_x` containing each digit in the set:

In [6]:
train_x = torch.cat([i[0] for i in digits.values()]).view(-1, 28*28)

Create training set `train_y` containing the amount of each digit in the set:

In [7]:
train_y = tensor([0]*digits['smaller_group'][1] + [1]*digits['larger_group'][1]).unsqueeze(1)
train_x.shape,train_y.shape

(torch.Size([10000, 784]), torch.Size([10000, 1]))

Create dataset of `train_x` and `train_y`:

In [8]:
dset = list(zip(train_x,train_y))

Create a `DataLoader` from `Dataset`:

In [9]:
dl = DataLoader(dset, batch_size=1024)
xb,yb = first(dl)
xb.shape,yb.shape

(torch.Size([1024, 784]), torch.Size([1024, 1]))

**Organize all validation data:**

Transform data in `digits_valid` elements into tuples containing stacked tensors and the amount of images:

In [10]:
for digit, data in digits_valid.items():
    digit_tensor = [tensor(Image.open(o)) for o in data]
    stacked_digit = torch.stack(digit_tensor).float()/255
    digits_valid[digit] = (stacked_digit, len(data))

Create training set `valid_x` containing each digit in the set:

In [11]:
valid_x = torch.cat([i[0] for i in digits_valid.values()]).view(-1, 28*28)

Create training set `valid_y` containing the amount of each digit in the set:

In [12]:
valid_y = tensor([0]*digits['smaller_group'][1] + [1]*digits['larger_group'][1]).unsqueeze(1)
valid_x.shape,valid_y.shape

(torch.Size([10000, 784]), torch.Size([10000, 1]))

Create dataset of `valid_x` and `valid_y`:

In [13]:
valid_dset = list(zip(valid_x,valid_y))

Create a `DataLoader` from `Dataset`:

In [14]:
valid_dl = DataLoader(valid_dset, batch_size=1024)

**Create dataloader of both training and validation data:**

In [15]:
dls = DataLoaders(dl, valid_dl)

**Training the model:**

Create model to be trained:

In [16]:
from learning_functions import NEURAL_NET_STRUCTURE, mnist_loss, batch_accuracy

learn = Learner(dls, NEURAL_NET_STRUCTURE, opt_func=SGD, loss_func=mnist_loss, metrics=batch_accuracy)

In [17]:
learn.fit(100, 1)
learn.fit(100, 5)
learn.fit(100, 10)
learn.fit(100, 25)

epoch,train_loss,valid_loss,batch_accuracy,time
0,0.61879,0.494184,0.6544,00:00
1,0.603444,0.448517,0.735,00:00
2,0.530937,0.342956,0.6653,00:00
3,0.449519,0.296812,0.7057,00:00
4,0.397933,0.261272,0.7422,00:00
5,0.36381,0.248636,0.7543,00:00
6,0.337771,0.237631,0.7645,00:00
7,0.316389,0.226149,0.7757,00:00
8,0.298106,0.214615,0.7879,00:00
9,0.282288,0.206564,0.7965,00:00


epoch,train_loss,valid_loss,batch_accuracy,time
0,0.168973,0.248907,0.75,00:00
1,0.178088,0.216212,0.7838,00:00
2,0.17712,0.203939,0.7951,00:00
3,0.17785,0.197747,0.8026,00:00
4,0.178442,0.190122,0.8103,00:00
5,0.179525,0.188861,0.8109,00:00
6,0.178809,0.184405,0.8147,00:00
7,0.177431,0.182161,0.8177,00:00
8,0.176404,0.182198,0.8181,00:00
9,0.175346,0.181329,0.8186,00:00


epoch,train_loss,valid_loss,batch_accuracy,time
0,0.07387,0.244685,0.7557,00:00
1,0.166066,0.238744,0.7611,00:00
2,0.166476,0.169055,0.8314,00:00
3,0.148422,0.170249,0.8294,00:00
4,0.138383,0.136828,0.8629,00:00
5,0.129424,0.127475,0.8733,00:00
6,0.123503,0.126541,0.8737,00:00
7,0.117744,0.122366,0.8789,00:00
8,0.115679,0.148904,0.8513,00:00
9,0.115919,0.144296,0.8555,00:00


epoch,train_loss,valid_loss,batch_accuracy,time
0,0.046407,0.084278,0.9156,00:00
1,0.060211,0.313402,0.6875,00:00
2,0.136506,0.20948,0.7905,00:00
3,0.141274,0.324462,0.675,00:00
4,0.145015,0.189618,0.8105,00:00
5,0.146642,0.252428,0.7474,00:00
6,0.15874,0.152387,0.847,00:00
7,0.153074,0.132548,0.8678,00:00
8,0.145751,0.137779,0.8628,00:00
9,0.139848,0.112739,0.8879,00:00
