In [1]:
import torch
from torch import nn
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# Both FashionMNIST splits use the same storage root ("data"), the same
# download flag and the same ToTensor conversion; only `train` differs.
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

# Wrap each split in a DataLoader producing mini-batches of 64 samples.
# No shuffle argument is passed, so the loaders keep their default ordering.
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=64) for split in (training_data, test_data)
)

class NeuralNetwork(nn.Module):
    """Fully connected classifier: flatten, then 784 -> 512 -> 512 -> 10.

    A ReLU follows each of the first two linear layers. The last linear
    layer's output is returned as-is (no softmax is applied in this module).
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Layers are instantiated in forward order and handed to Sequential
        # positionally, so their indices inside the stack are 0..4.
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten ``x`` and return the 10 raw scores (logits) per sample.

        After flattening, each sample must have 28 * 28 = 784 features to
        match the first linear layer.
        """
        return self.linear_relu_stack(self.flatten(x))

# Single network instance shared by the optimizer, training and evaluation cells below.
# NOTE(review): no device is selected here, so the model stays on PyTorch's default device.
model = NeuralNetwork()

In [2]:
# Hyperparameters
# Step size passed to the SGD optimizer as `lr=learning_rate`.
learning_rate = 1e-3
# NOTE(review): the DataLoaders in the first cell are built with a literal
# batch_size=64, so this variable is not read by them; keep the two in sync.
batch_size = 64
# NOTE(review): the training cell further down reassigns `epochs = 10` before
# its loop, so this value of 5 is overridden there.
epochs = 5

Each epoch consists of two main parts:
1. The train loop iterates over the training dataset and tries to converge to optimal parameters.
2. The validation loop iterates over the test dataset to check whether model performance is improving.

In [3]:
# Loss function handed to both train_loop and test_loop.
loss_fn = nn.CrossEntropyLoss()
# Plain SGD over every parameter of `model`, using the configured learning rate.
# NOTE(review): the training cell further down rebuilds both objects before its loop.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

Inside the training loop, optimization happens in three steps:
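1. Call optimizer.zero_grad() to reset the gradients of the model parameters. Gradients accumulate by default, so they must be cleared explicitly on every iteration to avoid double-counting.
2. Backpropagate the prediction loss with loss.backward(), which computes the gradient of the loss with respect to each parameter.
3. Call optimizer.step() to adjust the parameters using the gradients collected in the backward pass.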

In [9]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over every batch yielded by ``dataloader``.

    Args:
        dataloader: iterable of ``(X, y)`` batches that also exposes a
            ``dataset`` attribute supporting ``len()``.
        model: module to train; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer that owns the model's parameters.

    Returns:
        None. Every 100th batch (starting with the first) a progress line with
        the batch loss and the number of samples processed so far is printed.
        Gradients of the final batch are left on the parameters on return.
    """
    size = len(dataloader.dataset)
    # Set the model to training mode, which is important for batch normalization and dropout layers
    model.train()
    seen = 0  # running count of samples processed in this pass
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Clear gradients *before* backpropagating, so anything left on the
        # parameters (e.g. from an interrupted earlier run of this cell)
        # cannot be accumulated into this update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count actual samples rather than (batch + 1) * len(X), which is
        # wrong whenever the logged batch is shorter than the previous ones.
        seen += len(X)
        if batch % 100 == 0:
            print(f'loss: {loss.item():>7f}, [{seen:>5d}/{size:>5d}]')
    


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode, which is important for batch normalization and drouput layers
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f'Test error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n')

In [10]:
# Full training run: each epoch is one train_loop pass over the training
# loader followed by one test_loop pass over the test loader.
epochs = 10  # replaces the value assigned in the hyperparameter cell
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f'Epoch {t + 1} \n ------------------------')
    train_loop(dataloader=train_dataloader, model=model, loss_fn=loss_fn, optimizer=optimizer)
    test_loop(dataloader=test_dataloader, model=model, loss_fn=loss_fn)
print('Done!')

Epoch 1 
 ------------------------
loss: 2.165168, [   64/60000]
loss: 2.162582, [ 6464/60000]
loss: 2.100841, [12864/60000]
loss: 2.116851, [19264/60000]
loss: 2.083656, [25664/60000]
loss: 2.002547, [32064/60000]
loss: 2.036098, [38464/60000]
loss: 1.951793, [44864/60000]
loss: 1.962247, [51264/60000]
loss: 1.885889, [57664/60000]
Test error: 
 Accuracy: 59.5%, Avg loss: 1.893768

Epoch 2 
 ------------------------
loss: 1.922629, [   64/60000]
loss: 1.902165, [ 6464/60000]
loss: 1.786757, [12864/60000]
loss: 1.828131, [19264/60000]
loss: 1.736524, [25664/60000]
loss: 1.663085, [32064/60000]
loss: 1.687982, [38464/60000]
loss: 1.582586, [44864/60000]
loss: 1.619511, [51264/60000]
loss: 1.503922, [57664/60000]
Test error: 
 Accuracy: 59.9%, Avg loss: 1.530989

Epoch 3 
 ------------------------
loss: 1.593470, [   64/60000]
loss: 1.564381, [ 6464/60000]
loss: 1.416939, [12864/60000]
loss: 1.491319, [19264/60000]
loss: 1.382131, [25664/60000]
loss: 1.362900, [32064/60000]
loss: 1.37810