Once a PyTorch model is defined, how do you train it and use it?

In [None]:
import torch
import torch.nn as nn
from torchtyping import TensorType
import matplotlib.pyplot as plt

# Seed PyTorch's global random number generator so that operations drawing
# from it after this point produce the same values on every run.
torch.manual_seed(0)

class DigitRecognition(nn.Module):
    """Small fully connected classifier for flattened 28x28 digit images.

    Architecture:
        Linear(784 -> 512) -> ReLU -> Dropout(0.2) -> Linear(512 -> 10) -> Sigmoid

    NOTE(review): the final sigmoid squashes every class score into [0, 1]
    independently. ``nn.CrossEntropyLoss`` is documented to take raw
    (unnormalized) logits, so pairing it with this model feeds it squashed
    values. Sigmoid is monotonic, so the ranking of the classes is unchanged;
    confirm whether the sigmoid should stay before relying on the loss value.
    """

    def __init__(self):
        super().__init__()
        self.first_layer = nn.Linear(784, 512)
        # Nonlinearity, so the network can model more than a linear map.
        self.relu = nn.ReLU()
        # Randomly zeroes 20% of the hidden activations while training.
        self.dropout = nn.Dropout(0.2)
        self.final_layer = nn.Linear(512, 10)
        # Maps each of the 10 outputs into the range [0, 1].
        self.sigmoid = nn.Sigmoid()

    # The annotations are quoted so they are not evaluated when the class is
    # defined; they still refer to torchtyping's ``TensorType``.
    def forward(self, images: "TensorType[float]") -> "TensorType[float]":
        """Score a batch of flattened images.

        Args:
            images: float tensor whose last dimension is 784, as required by
                the first linear layer.

        Returns:
            A tensor with 10 scores in [0, 1] per input row. In eval mode the
            scores are rounded to 4 decimal places. In training mode they are
            returned unrounded, because rounding has a zero gradient almost
            everywhere and would stop every weight from being updated.
        """
        # The global RNG is deliberately NOT re-seeded here: seeding on every
        # call would make dropout replay the same random stream for each
        # batch and would reset the RNG state for all other code.
        hidden = self.dropout(self.relu(self.first_layer(images)))
        scores = self.sigmoid(self.final_layer(hidden))

        if self.training:
            return scores
        return torch.round(scores, decimals=4)

In [None]:
model = DigitRecognition()

# Criterion that compares the model's per-class outputs with the true labels.
loss_function = nn.CrossEntropyLoss()

# The optimizer performs the gradient-descent updates for us.
# model.parameters() hands it every trainable weight in the network.
# Adam adapts the step size per parameter; 0.001 is its default learning rate,
# spelled out here for clarity.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# One epoch = one full pass over the training set. More passes give the model
# more chances to fit, but too many can lead to overfitting / memorization.
epochs = 5

for epoch in range(epochs):
    # train_dataloader yields (images, labels) mini-batches.
    for images, labels in train_dataloader:
        # Flatten each 28x28 image into a 784-long vector (one row per image).
        batch_size = images.shape[0]
        images = images.view(batch_size, 784)

        # Drop the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass, then measure how far the outputs are from the labels.
        model_prediction = model(images)
        loss = loss_function(model_prediction, labels)

        # Backward pass: derivative of the loss w.r.t. every weight.
        # This is the most computationally expensive part of the step.
        loss.backward()

        # Apply the update, conceptually new_w = old_w - learning_rate * derivative.
        optimizer.step()

In [None]:
# Put the model in evaluation mode: layers such as dropout switch to their
# inference behaviour. This alone does NOT turn off gradient tracking;
# torch.no_grad() below does that.
model.eval()

# No derivatives are needed for prediction, so skip building the autograd graph.
with torch.no_grad():
    for images, labels in test_dataloader:
        # Flatten each 28x28 image into a 784-long vector (one row per image).
        images = images.view(images.shape[0], 784)

        model_prediction = model(images)  # has dimension of batch_size x 10
        # Max over the class dimension: `indices` holds the predicted class per image.
        values, indices = torch.max(model_prediction, dim=1)

        for image, predicted in zip(images, indices):
            # Restore the 28x28 layout so the image can be displayed.
            plt.imshow(image.view(28, 28))
            plt.show()
            # .item() converts the single-element tensor into a plain Python number.
            print(predicted.item())