In [4]:
import torch
from fastai.vision.all import *

In [5]:
path = untar_data(URLs.MNIST_SAMPLE)

threes = (path / "train" / "3").ls().sorted()
sevens = (path / "train" / "7").ls().sorted()

# Training Image to tensors
three_tensors = torch.stack(
    [tensor(Image.open(o)) for o in (path / "train" / "3").ls()]
)
seven_tensors = torch.stack(
    [tensor(Image.open(o)) for o in (path / "train" / "7").ls()]
)
valid_3_tensors = torch.stack(
    [tensor(Image.open(o)) for o in (path / "valid" / "3").ls()]
)
valid_7_tensors = torch.stack(
    [tensor(Image.open(o)) for o in (path / "valid" / "7").ls()]
)

# Normalize data
stacked_threes = three_tensors.float() / 255
stacked_sevens = seven_tensors.float() / 255
valid_3_tens = valid_3_tensors.float() / 255
valid_7_tens = valid_7_tensors.float() / 255

In [6]:
BATCH = 256

In [7]:
from nn.loader.dataloader import DataLoaders

In [8]:
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28 * 28)
train_y = tensor([1] * len(threes) + [0] * len(sevens)).unsqueeze(1)

valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28 * 28)
valid_y = tensor([1] * len(valid_3_tens) + [0] * len(valid_7_tens)).unsqueeze(1)

In [9]:
train_x.size(), train_y.size()

(torch.Size([12396, 784]), torch.Size([12396, 1]))

In [10]:
dls = DataLoaders(
    x_train=train_x,
    y_train=train_y,
    batch_size=BATCH,
    x_val=valid_x,
    y_val=valid_y
)

In [11]:
from nn.learner.learner import Learner
from nn.model.linear import MNISTModel

In [12]:
learner = Learner(
    dls=dls,
    model=MNISTModel(28 * 28, lr=0.1),
    )

In [13]:
learner.fit(num_epochs=20)

Epoch 1/20, Validation Accuracy: 0.7502
Epoch 2/20, Validation Accuracy: 0.8081
Epoch 3/20, Validation Accuracy: 0.8430
Epoch 4/20, Validation Accuracy: 0.8675
Epoch 5/20, Validation Accuracy: 0.8852
Epoch 6/20, Validation Accuracy: 0.8999
Epoch 7/20, Validation Accuracy: 0.9102
Epoch 8/20, Validation Accuracy: 0.9156
Epoch 9/20, Validation Accuracy: 0.9210
Epoch 10/20, Validation Accuracy: 0.9235
Epoch 11/20, Validation Accuracy: 0.9264
Epoch 12/20, Validation Accuracy: 0.9289
Epoch 13/20, Validation Accuracy: 0.9333
Epoch 14/20, Validation Accuracy: 0.9352
Epoch 15/20, Validation Accuracy: 0.9362
Epoch 16/20, Validation Accuracy: 0.9392
Epoch 17/20, Validation Accuracy: 0.9406
Epoch 18/20, Validation Accuracy: 0.9421
Epoch 19/20, Validation Accuracy: 0.9436
Epoch 20/20, Validation Accuracy: 0.9441


Testing model against valid dataset

In [32]:
# valid_3_tens = valid_3_tensors.float() / 255
print(valid_3_tens.shape)
preds = []
for o in valid_3_tens:
    flatten_image = valid_3_tens.view(-1, 28*28)
    preds.append(learner.predict(x=flatten_image))
preds

torch.Size([1010, 28, 28])


[tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
         ...,
         [1],
         [1],
         [1]], dtype=torch.int32),
 tensor([[1],
         [1],
         [1],
      

So 3 is true

In [33]:
# valid_7_tens = valid_7_tensors.float() / 255
print(valid_7_tens.shape)
preds = []
for o in valid_7_tens:
    flatten_image = valid_7_tens.view(-1, 28*28)
    preds.append(learner.predict(x=flatten_image))
preds

torch.Size([1028, 28, 28])


[tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
         ...,
         [0],
         [0],
         [0]], dtype=torch.int32),
 tensor([[0],
         [0],
         [0],
      

Its working! (for 3 vs 7)