In [None]:
import math
import time
import random

import torch
import numpy as np

from model import ConvStackModel
from dataset import MNISTDataset

# Left to its default, torch may pick scientific notation and print values such as 9.8e-01;
# sci_mode=False forces plain decimals such as 0.98 (here with 3 digits of precision).
torch.set_printoptions(precision=3, sci_mode=False)
# Manually specify how NumPy floats should be printed. Notably, this override pads every value
# with a space on each side, which I feel makes arrays easier to read.
np.set_printoptions(formatter={'float': lambda val: f" {val:.3f} "}, linewidth=100)

%load_ext autoreload
%autoreload 2

#### Set up the dataloader, model, loss function & optimizer

In [None]:
# Training split of the dataset, served in shuffled mini-batches of `batch_size` entries.
mnist_dataset_train = MNISTDataset(split="train")
batch_size = 64
mnist_data_loader_train = torch.utils.data.DataLoader(
    mnist_dataset_train,
    shuffle=True,
    batch_size=batch_size,
)

model = ConvStackModel()

# reduction='sum' adds the per-element losses together instead of averaging them; later cells
# rely on this when they read the loss as -log(joint probability).
# NOTE(review): BCELoss expects inputs in [0, 1] - presumably the model ends in a sigmoid; confirm.
bce_loss_fn = torch.nn.BCELoss(reduction="sum")
adam_optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

#### Pull an example element & example batch from the dataloader

In [None]:
# Indexing the dataset directly yields one (image, label) pair; 672 is just an arbitrary example index.
single_image, single_label = mnist_dataset_train[672]

# A DataLoader is a Python Iterable (not to be confused with an Iterator). As the name suggests, an Iterable
# can be iterated over. The iteration itself is carried out by a separate object called an Iterator. Generally,
# an Iterable implements the __iter__ method, which creates & returns an Iterator object.
mnist_data_loader_train_iterator = iter(mnist_data_loader_train)
# next() pulls a single (images, labels) batch out of the iterator.
batch_of_images, batch_of_labels = next(mnist_data_loader_train_iterator)

#### Training-step using one image

In [None]:
predictions = model(single_image)

# The target, or ideal prediction, would have a zero for each irrelevant digit, and a one
# for the relevant digit (i.e. the label).
# NOTE(review): indexing with a single integer assumes the model returns a 1-D vector of per-digit
# scores for an un-batched image - confirm against the model definition.
target = torch.zeros_like(predictions)
target[single_label.item()] = 1.0

loss_score = bce_loss_fn(input=predictions, target=target)

# .detach() is the supported way to get a tensor that is cut off from the autograd graph (the legacy
# .data attribute bypasses autograd's safety checks). .item() turns the 0-dim loss into a Python float.
print(f"The model predicted: \n{predictions.detach().numpy()}. \nThe target is: \n{target.detach().numpy()}.")
print(f"Loss score: {loss_score.item():.3f}.")

# Back-propagate, apply one Adam update, then clear the gradients so they do not accumulate into the
# next training step.
loss_score.backward()
adam_optimizer.step()
adam_optimizer.zero_grad()

#### Training-step using a batch of images

In [None]:
def print_num_random_batch_entries(predictions: torch.Tensor, target: torch.Tensor, num_entries_to_print: int = 3):
    """Print the predictions and targets of a few randomly chosen batch entries.

    Args:
        predictions: Tensor whose first dimension indexes the entries of the batch.
        target: Tensor indexed the same way as `predictions`.
        num_entries_to_print: Maximum number of distinct entries to print. When the batch holds
            fewer entries than this, every entry is printed instead.
    """
    # The final batch in an epoch may have a batch-size smaller than the rest, so cap the sample
    # size: random.sample raises ValueError when asked for more items than the population holds.
    actual_batch_size = predictions.shape[0]
    num_to_sample = min(num_entries_to_print, actual_batch_size)
    random_indices_to_print = sorted(random.sample(range(actual_batch_size), k=num_to_sample))

    for random_idx in random_indices_to_print:
        print(f"Batch-entry: {random_idx}.")
        # .detach() drops autograd tracking and .cpu() is a no-op for CPU tensors; together they
        # make the conversion to a NumPy array safe.
        print(f"predictions: \n{predictions[random_idx].detach().cpu().numpy()}")
        print(f"target: \n{target[random_idx].detach().cpu().numpy()}")
        print()
    

In [None]:
predictions = model(batch_of_images)

# Again, the target should have 0's for each irrelevant digit and 1 for the relevant digit, for each entry (or image) in
# the batch. Pairing row i with column batch_of_labels[i] sets exactly one 1.0 per row.
batch_indices = torch.arange(len(batch_of_labels))
target = torch.zeros_like(predictions)
target[batch_indices, batch_of_labels.to(torch.int64)] = 1.0

loss_score = bce_loss_fn(input=predictions, target=target)

# Find the probability the model assigned to the correct answer for every entry.
# Each value p in predictions is read as the probability that the entry is True (i.e. the input-image is that
# digit); the probability that the entry is False is therefore 1 - p.
# Where target is 1, (1 - target) is 0, so |0 - p| = p: the predicted probability of class=1 (i.e. True).
# Where target is 0, (1 - target) is 1, so |1 - p| = 1 - p: the predicted probability of class=0 (i.e. False).
probability_is_correct = ((1 - target) - predictions).abs()

# The inferred joint-probability predicted for every correct class in the batch. Change the data-type from float32 to float64
# to reduce the likelihood of underflowing to 0. This is only a diagnostic, so it is detached from the autograd graph.
joint_probability = probability_is_correct.detach().to(torch.float64).prod().item()
# We can also recover this joint-probability from the loss-score, since (with reduction='sum') the
# loss = -log(joint_probability). .item() hands math.exp a plain Python float instead of a grad-tracking tensor.
# Caveat: BCELoss clamps every log term at -100, so the two values can disagree when a prediction is (numerically) 0 or 1.
joint_probability_from_loss = math.exp(-loss_score.item())
# The joint-probability is the product of all probabilities, so the per batch-element (geometric) average is its
# n-th root. exp(-loss) ** (1 / n) == exp(-loss / n); dividing before exponentiating keeps the average meaningful
# even when the joint-probability itself underflows to 0.
per_batch_element_average_probability = math.exp(-loss_score.item() / predictions.numel())

print_num_random_batch_entries(predictions, target)

# Note: Python floats and torch.float64 are both IEEE-754 double-precision numbers. The smallest positive value
# they can represent is roughly 5e-324, which is about e^-745; anything smaller becomes 0. Accordingly, any loss
# value above roughly 745 results in an inferred joint-probability that underflows to 0.
# Precision already starts to degrade below the smallest *normal* double, roughly 2.2e-308 or about e^-708.
print(f"Joint probability: {joint_probability:.2e}. Joint probability inferred from loss-score: {joint_probability_from_loss:.2e}.")
print(f"The per batch-element average probability is: {per_batch_element_average_probability:.3f}.")
print(f"Loss score: {loss_score.item():.3f}")

# Back-propagate, apply one Adam update, then clear the gradients so they do not accumulate into the next step.
loss_score.backward()
adam_optimizer.step()
adam_optimizer.zero_grad()

#### Train for num_epochs

In [None]:
num_epochs = 20
# Used only for the progress message; ceil() accounts for a smaller final batch.
num_batches_per_epoch = math.ceil(len(mnist_dataset_train) / batch_size)
num_batches_to_print_per_epoch = 2

for epoch_idx in range(num_epochs):
    print(f"\nBeginning epoch: {epoch_idx + 1}.")

    for batch_idx, (batch_of_images, batch_of_labels) in enumerate(mnist_data_loader_train):

        predictions = model(batch_of_images)

        # One-hot target: row i gets a 1.0 in the column given by batch_of_labels[i].
        batch_indices = torch.arange(len(batch_of_labels))
        target = torch.zeros_like(predictions)
        target[batch_indices, batch_of_labels.to(torch.int64)] = 1.0

        loss_score = bce_loss_fn(input=predictions, target=target)
        # With reduction='sum' the loss is -log(joint probability). .item() hands math.exp a plain
        # Python float instead of a grad-tracking tensor. This underflows to 0 for losses above ~745.
        joint_probability_from_loss = math.exp(-loss_score.item())

        # Note: the last batch may contain less than batch-size elements, hence numel() rather than a constant.
        # exp(-loss / n) equals exp(-loss) ** (1 / n) but stays meaningful when exp(-loss) underflows to 0.
        per_batch_element_average_probability = math.exp(-loss_score.item() / predictions.numel())

        # Only print the first few batches of results for each epoch.
        if batch_idx < num_batches_to_print_per_epoch:
            print(
                f"Batch: {batch_idx + 1}. "
                f"Joint probability inferred from loss-score: {joint_probability_from_loss:.2e}. "
                f"Per batch-element average probability is: {per_batch_element_average_probability:.3f} "
                f"Loss score: {loss_score.item():.3f}"
            )

            # Add a sleep to be able to watch in real-time. Otherwise, the process is too quick.
            time.sleep(2)

        elif batch_idx == num_batches_to_print_per_epoch:
            print(
                f"Accelerating through the remaining {num_batches_per_epoch - num_batches_to_print_per_epoch:,} "
                f"batches in this epoch without printing updates..."
            )

        # Back-propagate, apply one Adam update, then clear the gradients so they do not accumulate
        # into the next batch.
        loss_score.backward()
        adam_optimizer.step()
        adam_optimizer.zero_grad()

#### Save the model's learned parameters

In [None]:
model_weights = model.state_dict()
# state_dict() is a Python dictionary that maps a string-name to the relevant parameter-tensor.
# (The names follow the module structure, e.g. something like conv_stack.0.weight - the exact keys
# depend on the model definition.)
print("Entries in model.state_dict():")
# Show only the first five entries to keep the output short.
for k, v in list(model_weights.items())[:5]:
    print(f"key: {k}")
    print(f"value.shape: {v.shape}")
print("...\n")

model_weights_filename = "model_weights.pt"
print(f"Saving model parameters to {model_weights_filename}")
# Persist only the parameters (the state dict), not the whole model object.
torch.save(model_weights, model_weights_filename)

#### View some sample outputs of the trained model

In [None]:
# batch_of_images / batch_of_labels still hold the last batch produced by the training loop.
# This is inference only, so torch.no_grad() skips building an autograd graph that would never be used.
# NOTE(review): if the model contains dropout or batch-norm layers, also call model.eval() first - confirm
# against the model definition.
with torch.no_grad():
    predictions = model(batch_of_images)

# One-hot target: row i gets a 1.0 in the column given by batch_of_labels[i].
batch_indices = torch.arange(len(batch_of_labels))
target = torch.zeros_like(predictions)
target[batch_indices, batch_of_labels.to(torch.int64)] = 1.0

# The final batch of an epoch can be smaller than batch_size, so never request more entries than it holds.
print_num_random_batch_entries(predictions, target, min(5, len(batch_of_labels)))