<a href="https://colab.research.google.com/github/RoyEHamlin/PyTorch-Lightning-Practice-01/blob/main/MNIST_nn_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

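# Practice MNIST
#### Following the video tutorial https://www.youtube.com/watch?v=OMDn66kM9Qc

This notebook trains a small fully connected network (784 → 64 → 64 → 10) on MNIST. It uses a hand-written PyTorch training/validation loop. As written, everything runs on the CPU.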

In [None]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import random_split, DataLoader

In [None]:
# Defining model: a fully connected network that maps a flattened 28x28 image
# (784 features) through two 64-unit ReLU hidden layers to 10 class logits.
torch.manual_seed(0)  # seed PyTorch's global RNG so weight initialisation is reproducible
model = nn.Sequential(
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, 10),  # raw logits; nn.CrossEntropyLoss applies log-softmax itself
)

In [None]:
# Define my optimizer
# Materialise the parameters as a list. model.parameters() returns a one-shot
# generator, which optim.SGD would consume, leaving `params` empty afterwards.
params = list(model.parameters())
optimizer = optim.SGD(params, lr=1e-2)  # plain SGD with learning rate 0.01

In [None]:
# Define loss: cross-entropy between raw logits and integer class labels,
# averaged over the batch (reduction='mean' is the PyTorch default, spelled out here).
loss = nn.CrossEntropyLoss(reduction='mean')

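### Either the model or the optimizer can zero the gradients
`model.zero_grad()` resets the `.grad` of every parameter in the model. `optimizer.zero_grad()` resets the `.grad` of every parameter the optimizer was given. Here both cover the same set, because the optimizer was built from `model.parameters()`.
#### https://youtu.be/OMDn66kM9Qc?t=712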

In [None]:
# Train, Val split
# Only the MNIST training set (train=True) is loaded. It is divided into a
# training subset and a held-out validation subset.
train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())
n_val = 5000
n_train = len(train_data) - n_val  # 55000 for the standard 60,000-image training set
# A dedicated, seeded generator makes the split reproducible regardless of how
# many random numbers earlier cells have drawn from the global RNG.
train, val = random_split(train_data, [n_train, n_val], generator=torch.Generator().manual_seed(42))
# Reshuffle training batches every epoch; keep the validation order fixed.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [None]:
# Training and validation Loop
nb_epochs = 5
for epoch in range(nb_epochs):
    # ---- Training pass ----
    model.train()  # training-mode behaviour (matters once dropout / batch-norm are added)
    losses = list()
    for x, y in train_loader:
        # x arrives as b x 1 x 28 x 28; flatten each image into one 784-long row
        b = x.size(0)      # b = batch size (number of rows)
        x = x.view(b, -1)  # x: b x 784 (= 28 * 28)

        # 1 forward
        l = model(x)  # l: logits, b x 10

        # 2 compute objective function
        J = loss(l, y)  # l = logits, y = labels

        # 3 clear gradients left over from the previous step  # https://youtu.be/OMDn66kM9Qc?t=1235
        # (equivalent to model.zero_grad() here: the optimizer holds all model parameters)
        optimizer.zero_grad()

        # 4 accumulate the partial derivatives of J wrt params into each parameter's .grad
        # https://youtu.be/OMDn66kM9Qc?t=1339
        J.backward()

        # 5 step in the opposite direction of the gradient
        # long hand: with torch.no_grad(): p -= lr * p.grad for every parameter p
        # https://youtu.be/OMDn66kM9Qc?t=796 (logic)
        optimizer.step()

        # .item() stores a plain Python float, so the autograd graph is not kept alive
        # https://youtu.be/OMDn66kM9Qc?t=1557
        losses.append(J.item())

    print(f'Epoch {epoch + 1}, train loss: {torch.tensor(losses).mean():.2f}')

    # ---- Validation pass ----
    model.eval()  # inference-mode behaviour for layers that distinguish train/eval
    losses = list()
    correct = 0
    total = 0
    # No gradients are needed for validation  # https://youtu.be/OMDn66kM9Qc?t=1617
    with torch.no_grad():
        for x, y in val_loader:
            b = x.size(0)
            x = x.view(b, -1)  # x: b x 784

            l = model(x)    # l: logits
            J = loss(l, y)  # l = logits, y = labels
            losses.append(J.item())

            # predicted class = index of the largest logit
            correct += (l.argmax(dim=1) == y).sum().item()
            total += b

    print(f'Epoch {epoch + 1}, validation loss: {torch.tensor(losses).mean():.2f}, '
          f'validation acc: {correct / total:.2f}')
        

Epoch 1, train loss: 1.24
Epoch 1, validation loss: 0.53
Epoch 2, train loss: 0.41
Epoch 2, validation loss: 0.37
Epoch 3, train loss: 0.32
Epoch 3, validation loss: 0.33
Epoch 4, train loss: 0.28
Epoch 4, validation loss: 0.29
Epoch 5, train loss: 0.25
Epoch 5, validation loss: 0.27
