In [1]:
import torch 
from fastai import *
from fastai.vision.all import *


In [2]:
# download MNIST dataset
path = untar_data(URLs.MNIST)


In [3]:
def mnist_loss(predictions, targets):
  sm = torch.log_softmax(predictions, dim=1)
  return F.nll_loss(sm, targets.squeeze())


In [4]:
def calc_grad(xb, yb, model):
    preds = model(xb)
    loss = mnist_loss(preds, yb)
    loss.backward()


In [5]:
def train_epoch(model):
    for xb, yb in train_dl:
        calc_grad(xb, yb, model)
        opt.step()
        opt.zero_grad()


In [6]:
def batch_accuracy(xb, yb):
    pred_nums = torch.argmax(xb, axis=1)
    return (pred_nums==yb.T).float().mean()


In [7]:
def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb, yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)


In [8]:
def train_model(model, epochs):
    for i in range(epochs):
        train_epoch(model)
        print(validate_epoch(model), end=' ')


In [9]:
class BasicOptim:
    def __init__(self,params,lr): self.params,self.lr = list(params),lr

    def step(self, *args, **kwargs):
        for p in self.params: p.data -= p.grad.data * self.lr

    def zero_grad(self, *args, **kwargs):
        for p in self.params: p.grad = None

In [10]:
# take images from downloaded MNIST sample
train_digits = []
valid_digits = []
for i in range(0, 10):
    train_digits.append((path/'training'/f'{i}').ls().sorted())
    valid_digits.append((path/'testing'/f'{i}').ls().sorted())


In [11]:
# images to tensor
train_digits_tensors = []
for digit in train_digits:
    train_digits_tensors.append([tensor(Image.open(o)) for o in digit])

valid_digits_tensors = []
for digit in valid_digits:
    valid_digits_tensors.append([tensor(Image.open(o)) for o in digit])


In [12]:
# stacking tensors
train_stacked_digits = []
for digit in train_digits_tensors:
    train_stacked_digits.append(torch.stack(digit).float()/255)
    
train_stacked_digits[9].shape


torch.Size([5949, 28, 28])

In [13]:
valid_stacked_digits = []
for digit in valid_digits_tensors:
    valid_stacked_digits.append(torch.stack(digit).float()/255)
    
valid_stacked_digits[9].shape


torch.Size([1009, 28, 28])

In [14]:
train_x = torch.cat(train_stacked_digits).view(-1, 28*28)

train_y_arr = [[i]*len(train_digits[i]) for i in range(0, 10)]
train_y_flatted = [item for sublist in train_y_arr for item in sublist]

train_y = tensor(train_y_flatted).unsqueeze(1)

# create train data set
train_dset = list(zip(train_x, train_y))

In [15]:
valid_x = torch.cat(valid_stacked_digits).view(-1, 28*28)

valid_y_arr = [[i]*len(valid_digits[i]) for i in range(0, 10)]
valid_y_flatted = [item for sublist in valid_y_arr for item in sublist]

valid_y = tensor(valid_y_flatted).unsqueeze(1)

# create validation data set
valid_dset = list(zip(valid_x, valid_y))


In [16]:
# create DataLoader's
train_dl = DataLoader(train_dset, batch_size=256)
valid_dl = DataLoader(valid_dset, batch_size=256)

In [17]:
lr = 0.001
epoch = 40
linear_model = nn.Linear(28*28, 10)
opt = BasicOptim(linear_model.parameters(), lr)

# train the model
train_model(linear_model, epoch)


0.3296 0.4893 0.5827 0.6478 0.6869 0.7159 0.7369 0.7554 0.7695 0.7824 0.7921 0.7979 0.8043 0.8073 0.8123 0.8171 0.8211 0.8264 0.8293 0.8322 0.8349 0.8371 0.8395 0.8412 0.8433 0.8449 0.8467 0.8481 0.8492 0.8503 0.8518 0.8531 0.8537 0.8555 0.8571 0.8573 0.8583 0.8589 0.8599 0.8606 

In [18]:
# final accuracy ~ 0.86
validate_epoch(linear_model)


0.8606