In [5]:
#base model and data that we're going to train

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

def _load_fashion_mnist(train):
    """Build the FashionMNIST train split (``train=True``) or test split.

    Files are stored under ``data/`` and downloaded if not already present;
    each image is converted with ``ToTensor()``.
    """
    return datasets.FashionMNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


training_data = _load_fashion_mnist(train=True)
test_data = _load_fashion_mnist(train=False)

# Mini-batches of 64 samples, iterated in dataset order (no shuffle argument).
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class NeuralNetwork(nn.Module):
    """Fully connected classifier producing 10 unnormalised logits per sample.

    Every dimension after the batch dimension is flattened (``nn.Flatten``
    default), and the result must have 28*28 = 784 values, e.g. a 1x28x28
    FashionMNIST image. It then passes through 784 -> 512 -> 512 -> 10
    linear layers with ReLU activations between them.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 512
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden_width), nn.ReLU(),
            nn.Linear(hidden_width, hidden_width), nn.ReLU(),
            nn.Linear(hidden_width, 10),
        )

    def forward(self, x):
        # Raw logits; the loss function is responsible for any softmax.
        return self.linear_relu_stack(self.flatten(x))

# Seed torch's global RNG right before building the model so the random
# weight initialisation (and therefore every loss/accuracy printed later)
# is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)
model = NeuralNetwork()

In [6]:
# Hyperparameters: values set by hand that control training but are not
# themselves learned by the optimiser.
learning_rate = 1e-3  # SGD step size
batch_size = 64       # samples per mini-batch; keep equal to the DataLoaders' batch_size=64
epochs = 10           # full passes over the training set; the training run uses 10 epochs

We're going to train this model in an optimisation loop. Each full pass of the loop over the training data is called an epoch.

Each epoch has two parts:
- **Train loop** – iterate over the training dataset and try to converge to optimal parameters.
- **Validation/test loop** – iterate over the test dataset to check whether model performance is improving.

**Train loop:**

The loss function is central here: it measures how dissimilar the model's prediction is from the actual result, and it is the quantity we try to minimise during training.
To calculate the loss, we make a prediction from a batch of inputs and compare it to the true label (the ground truth).

Common loss functions include nn.MSELoss (Mean Square Error) for regression tasks, and nn.NLLLoss (Negative Log Likelihood) for classification. nn.CrossEntropyLoss combines nn.LogSoftmax and nn.NLLLoss.

In [7]:
# Initialize the loss function.
# CrossEntropyLoss takes raw logits: it applies log-softmax internally and then
# the negative log-likelihood of the target class, averaged over the batch.
loss_fn = nn.CrossEntropyLoss(reduction="mean")

Optimisation is the process of adjusting the model's parameters to reduce the model's error at each training step. Optimisation algorithms define how this adjustment is performed.

In this example we use stochastic gradient descent (SGD). All optimisation logic is encapsulated and abstracted within the `optimizer` object.

Other optimisers, such as Adam or RMSProp, can work better for different kinds of models and data.

We initialize the optimizer by registering the model’s parameters that need to be trained, and passing in the learning rate hyperparameter.

In [8]:
# Plain SGD over every trainable parameter of `model`, using the configured step size.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

Inside the training loop, optimization happens in three steps:
1. Call optimizer.zero_grad() to reset the gradients of model parameters. Gradients by default add up; to prevent double-counting, we explicitly zero them at each iteration.

2. Backpropagate the prediction loss with a call to loss.backward(). PyTorch deposits the gradients of the loss w.r.t. each parameter.

3. Once we have our gradients, we call optimizer.step() to adjust the parameters by the gradients collected in the backward pass.

# Full implementation

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer, log_every=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Args:
        dataloader: a ``DataLoader`` yielding ``(X, y)`` batches; its
            ``.dataset`` must support ``len()`` (used for progress output).
        model: the ``nn.Module`` being trained.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer constructed over ``model.parameters()``.
        log_every: print the loss every ``log_every`` batches (default 100).

    Prints ``loss: <value>  [<samples seen>/<dataset size>]`` on logged batches.
    """
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    # Count samples actually consumed, so progress is correct regardless of the
    # loader's batch size (and does not depend on any global variable).
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear old gradients FIRST so nothing accumulated
        # before this step (e.g. from an earlier cell) leaks into the update.
        optimizer.zero_grad()
        loss.backward()   # gradient of the loss w.r.t. each parameter
        optimizer.step()  # adjust parameters using those gradients

        seen += len(X)
        if batch % log_every == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [10]:
# Rebuild the loss and the optimizer so this cell is self-contained: the
# optimizer wraps the current `model` parameters at `learning_rate`.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

epochs = 10  # number of epochs actually run here
for epoch_num in range(1, epochs + 1):
    print(f"Epoch {epoch_num}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 2.306755  [   64/60000]
loss: 2.298370  [ 6464/60000]
loss: 2.274577  [12864/60000]
loss: 2.266144  [19264/60000]
loss: 2.257879  [25664/60000]
loss: 2.219342  [32064/60000]
loss: 2.227928  [38464/60000]
loss: 2.198219  [44864/60000]
loss: 2.202961  [51264/60000]
loss: 2.153521  [57664/60000]
Test Error: 
 Accuracy: 40.2%, Avg loss: 2.162516 

Epoch 2
-------------------------------
loss: 2.176166  [   64/60000]
loss: 2.169892  [ 6464/60000]
loss: 2.112028  [12864/60000]
loss: 2.122149  [19264/60000]
loss: 2.080901  [25664/60000]
loss: 2.011679  [32064/60000]
loss: 2.039401  [38464/60000]
loss: 1.970496  [44864/60000]
loss: 1.985421  [51264/60000]
loss: 1.886009  [57664/60000]
Test Error: 
 Accuracy: 57.1%, Avg loss: 1.905084 

Epoch 3
-------------------------------
loss: 1.943356  [   64/60000]
loss: 1.914415  [ 6464/60000]
loss: 1.798731  [12864/60000]
loss: 1.831680  [19264/60000]
loss: 1.723822  [25664/60000]
loss: 1.666452  [32064/600