* Dataset
* Build model
* Define loss function and optimizer
* Define trainer (model produces prediction -> compute the loss by comparing pred with label -> backward pass -> optimizer step)
* Define test (on validation) -> training/validation
* Run trainer and test

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


In [None]:
# Load MNIST as tensors; both splits share the same root/download/transform
# settings and differ only in the `train` flag.
mnist_kwargs = dict(root='data', download=True, transform=ToTensor())
train_ds = MNIST(train=True, **mnist_kwargs)
valid_ds = MNIST(train=False, **mnist_kwargs)

In [None]:
# Peek at the first training sample: print its label and draw the digit

image, label = train_ds[0]
print(label)
# imshow wants a 2-D array, so reshape the sample to a plain 28 x 28 grid
pixels = image.float().reshape(28, 28)
plt.imshow(pixels, cmap='gray')

In [None]:
# Create the data loaders
# They make it easy to iterate over a dataset in mini-batches.
# GPU acceleration tip (not enabled here): num_workers=4 loads batches in
# parallel and pin_memory=True uses pinned memory for faster RAM -> GPU transfer.
bs = 64
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps validation
# runs reproducible from one epoch to the next.
valid_dl = DataLoader(valid_ds, batch_size=bs, shuffle=False)

In [None]:
from torch import nn

# build a model
class MNISTModel(nn.Module):
    """Single linear layer that maps a flattened image to class scores.

    Args:
        in_features: Number of values per sample after flattening.
            Defaults to 784 (a 1 x 28 x 28 MNIST image).
        num_classes: Number of output scores per sample. Defaults to 10.
    """

    def __init__(self, in_features: int = 784, num_classes: int = 10):
        super().__init__()
        self.lin = nn.Linear(in_features, num_classes)

    def forward(self, xb):
        """Flatten every dimension after the batch one and apply the layer.

        Returns the raw (un-normalised) scores, one row per sample.
        """
        xb = xb.flatten(start_dim=1)  # e.g. (bs, 1, 28, 28) -> (bs, 784)
        return self.lin(xb)


model = MNISTModel()
print(model)

In [None]:
from torch import optim

# Step size handed to the optimizer
lr = 0.5

# Loss: cross-entropy computed on the model's outputs.
# Optimizer: plain SGD over every parameter of the model.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=lr)

# accuracy func for logging
def accuracy_func(pred, yb):
    """Return the fraction of rows of ``pred`` whose arg-max matches ``yb``.

    Args:
        pred: Scores with one row per sample; the predicted class is the
            index of the largest value along dim 1.
        yb: Target class index for each row.

    Returns:
        A 0-dim float tensor in [0, 1].
    """
    hits = pred.argmax(dim=1).eq(yb)
    return hits.float().mean()


In [None]:
# Define a trainer

def train(dataloader, model, loss_func, optimizer):
    """Run one optimisation pass over every batch in ``dataloader``.

    For each batch: forward pass, loss, backward pass, optimizer step, then
    the gradients are cleared. Every 100th batch (starting with the first)
    the loss and accuracy of that single batch are printed.

    Args:
        dataloader: Iterable yielding ``(xb, yb)`` batches.
        model: Module to optimise; it is switched to training mode.
        loss_func: Callable ``loss_func(pred, yb)`` returning a scalar tensor.
        optimizer: Optimizer that updates ``model``'s parameters.
    """
    # test() leaves the model in eval mode; switch back so mode-dependent
    # layers (dropout, batch-norm, ...) behave correctly on every epoch,
    # not only on the first one.
    model.train()

    for batch_idx, (xb, yb) in enumerate(dataloader):
        # loss
        pred = model(xb)
        loss = loss_func(pred, yb)

        # backprop
        loss.backward()  # autograd
        optimizer.step()  # updates the parameters using the optimizer (SGD/Adam)
        optimizer.zero_grad()  # clear gradients so they do not accumulate into the next batch

        # logging
        if batch_idx % 100 == 0:
            train_loss, train_accuracy = loss.item(), accuracy_func(pred, yb).item() * 100
            print(f"Loss: {train_loss:6f} Accuracy: {train_accuracy:0.1f}%")

In [None]:
# Define test -> Not updating any parameters
def test(dataloader, model, loss_func):
    """Evaluate ``model`` on all of ``dataloader`` and print loss and accuracy.

    No parameters are updated. The printed numbers are averages over every
    sample seen, not the metrics of the final batch only.

    Args:
        dataloader: Iterable yielding ``(xb, yb)`` batches.
        model: Module to evaluate; it is left in eval mode afterwards.
        loss_func: Callable ``loss_func(pred, yb)`` returning a scalar tensor.
            Assumed to return the per-batch MEAN loss (the default for
            ``nn.CrossEntropyLoss``), so it is re-weighted by batch size.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()  # flag to make sure things like dropout work as expected in testing

    total_loss = 0.0
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for xb, yb in dataloader:
            pred = model(xb)
            batch_size = len(yb)

            # Weight by batch size so a smaller final batch does not skew the mean.
            total_loss += loss_func(pred, yb).item() * batch_size
            num_correct += (pred.argmax(dim=1) == yb).sum().item()
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("dataloader yielded no samples to evaluate")

    # logging
    test_loss = total_loss / num_samples
    test_accuracy = num_correct / num_samples * 100
    print(f"Test:\n Loss: {test_loss:6f}, Accuracy: {test_accuracy:0.1f}%")

In [None]:
# Run

epochs = 2


def _run_epoch(index):
    """Print the epoch header, train for one pass, then evaluate."""
    print(f"\nEpoch: {index}\n--------------------")
    train(train_dl, model, loss_func, optimizer)
    test(valid_dl, model, loss_func)


for t in range(epochs):
    _run_epoch(t)

print("\nFinished!!!!!")