<a href="https://colab.research.google.com/github/agnair00/playdate_with_python/blob/main/mnist_basics_trial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install fastbook;

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from fastai.vision.all import *
from fastbook import *
import torch

In [3]:
path = untar_data(URLs.MNIST_SAMPLE); path

Path('/root/.fastai/data/mnist_sample')

First step is create our dataloaders. Atthe end of it, we should have our training dataloader and our validation dataloader.

In [4]:
train_3_stack = torch.stack([tensor(Image.open(p)).float() for p in (path/'train/3').ls()])
train_3_stack.shape

torch.Size([6131, 28, 28])

In [5]:
train_7_stack = torch.stack([tensor(Image.open(p)).float() for p in (path/'train/7').ls()])
train_7_stack.shape

torch.Size([6265, 28, 28])

In [6]:
train_x = torch.cat([train_3_stack, train_7_stack]).flatten(1)
train_x.shape

torch.Size([12396, 784])

In [7]:
train_y = tensor([1]*len(train_3_stack) + [0]*len(train_7_stack)).float().unsqueeze(1)
train_y.shape

torch.Size([12396, 1])

In [8]:
valid_3_stack = torch.stack([tensor(Image.open(p)).float() for p in (path/'valid/3').ls()])
valid_3_stack.shape

torch.Size([1010, 28, 28])

In [9]:
valid_7_stack = torch.stack([tensor(Image.open(p)).float() for p in (path/'valid/7').ls()])
valid_7_stack.shape

torch.Size([1028, 28, 28])

In [10]:
valid_x = torch.cat([valid_3_stack, valid_7_stack]).flatten(1)
valid_x.shape

torch.Size([2038, 784])

In [11]:
valid_y = tensor([1]*len(valid_3_stack) + [0]*len(valid_7_stack)).float().unsqueeze(1)
valid_y.shape

torch.Size([2038, 1])

In [12]:
train_dset = list(zip(train_x, train_y))

In [13]:
train_dl = torch.utils.data.DataLoader(train_dset, batch_size=256)
len(train_dl)

49

So our training dataloader has 49 batches

In [14]:
valid_dset = list(zip(valid_x, valid_y))
valid_dl = torch.utils.data.DataLoader(valid_dset, batch_size=256)
len(valid_dl)

8

So now we have our training and our validation dataloaders

In [15]:
def init_params(size):
    return torch.randn(size).requires_grad_()

In [16]:
params = init_params((784, 1)), init_params(1)

In [17]:
def linear1(xb, params):
    weights, bias = params
    return (xb @ weights) + bias

In [18]:
def mnist_loss(predictions, targets):
    predictions = predictions.sigmoid()
    return torch.where(targets == 1.0, 1 - predictions, predictions).mean()

In [19]:
mnist_loss(linear1(train_x[:4], params), train_y[:4])

tensor(0.7500, grad_fn=<MeanBackward0>)

In [20]:
def calc_grad(model, loss_f, params, xb, yb):
    preds = model(xb, params)
    loss = loss_f(preds, yb)
    loss.backward()

In [21]:
for param in params:
    try:
        print(param.grad.shape)
    except:
        print('param has no grad')

param has no grad
param has no grad


In [22]:
calc_grad(linear1, mnist_loss, params, train_x[:4], train_y[:4])

In [23]:
for param in params:
    try:
        print(param.grad.shape)
    except:
        print('param has no grad')

torch.Size([784, 1])
torch.Size([1])


In [24]:
def train_epoch(train_dl, model, loss, params, lr):
    for xb, yb in train_dl:
        calc_grad(model, loss, params, xb, yb)
        for param in params:
            with torch.no_grad():
                param -= lr * param.grad
            param.grad.zero_()

In [25]:
def batch_accuracy(preds, yb):
    preds = preds.sigmoid()
    return ((preds > 0.5) == yb).float().mean()

In [26]:
def validate_epoch(valid_dl, model, params):
    with torch.no_grad():
        accs = [batch_accuracy(model(xb, params), yb) for xb, yb in valid_dl]
        accuracy = torch.stack(accs)
    return round(accuracy.mean().item(), 4)

In [27]:
validate_epoch(valid_dl, linear1, params)

0.5451

In [28]:
lr = 1.0
for _ in range(20):
    train_epoch(train_dl, linear1, mnist_loss, params, lr)
    print(validate_epoch(valid_dl, linear1, params), end=' ')

0.6279 0.9191 0.9405 0.9381 0.9454 0.9443 0.9443 0.9189 0.9256 0.9237 0.9485 0.9552 0.9553 0.9557 0.9557 0.9557 0.9557 0.9557 0.9557 0.9557 