In [None]:
import math
import time

import torch

from model import ConvStackModel
from dataset import MNISTDataset

# sci_mode=True is the default which will print values such as 9.8e-01 instead of 0.98.
torch.set_printoptions(precision=3, sci_mode=False)

%load_ext autoreload
%autoreload 2

#### Set up the dataloader, model, loss function & optimizer

In [None]:
# Seed PyTorch's global random-number generator before anything random happens, so that a
# "Restart Kernel -> Run All" reproduces the same batch order (and presumably the same initial
# model weights, assuming ConvStackModel uses PyTorch's default random initialisation).
random_seed = 0
torch.manual_seed(random_seed)

mnist_dataset_train = MNISTDataset(split="train")
batch_size = 32
# shuffle=True re-draws the order of the samples at the start of every epoch.
mnist_data_loader_train = torch.utils.data.DataLoader(
    dataset=mnist_dataset_train,
    batch_size=batch_size,
    shuffle=True,
)

model = ConvStackModel()

# reduction='sum' adds up the per-element losses of a batch instead of averaging them.
bce_loss_fn = torch.nn.BCELoss(reduction='sum')
adam_optimizer = torch.optim.Adam(params=model.parameters(), lr=4e-5)

#### Pull an example element & example batch from the dataloader

In [None]:
# A DataLoader is an Iterable rather than an Iterator: it can be iterated over, but the stepping itself
# is carried out by a separate Iterator object. iter() calls the Iterable's __iter__ method, which
# creates and returns that Iterator; next() then asks the Iterator for its next item.
mnist_data_loader_train_iterator = iter(mnist_data_loader_train)
batch_of_images, batch_of_labels = next(mnist_data_loader_train_iterator)

# Indexing the dataset directly returns a single (image, label) pair, without a batch dimension.
single_image, single_label = mnist_dataset_train[672]
# The model is called on batches, so prepend a batch dimension of size 1 to the lone image.
single_image = torch.unsqueeze(single_image, dim=0)

#### Training-step using one image

In [None]:
# Forward pass on the one-image batch.
# NOTE(review): the indexing below assumes the model outputs one probability per class with
# shape (1, num_classes) — confirm against ConvStackModel.
predictions = model(single_image)
# Keep only the probability the model assigned to the true class.
probability_of_correct_class = predictions[:, single_label.item()].squeeze()
# With a target of 1, binary cross-entropy reduces to -log(probability of the correct class).
loss_score = bce_loss_fn(input=probability_of_correct_class, target=torch.tensor(1.0))

print(f"Model predicted a probability of {probability_of_correct_class:.2f} for the correct class.")
print(f"Loss score: {loss_score:.3f}.")

# Back-propagate, apply one Adam update, then clear the gradients so that they do not
# accumulate into the next training step.
loss_score.backward()
adam_optimizer.step()
adam_optimizer.zero_grad()

#### Training-step using a batch of images

In [None]:
predictions = model(batch_of_images)

# gather() needs an int64 index with as many dimensions as `predictions`, so turn the labels
# into a single-column index.
correct_class_indices = batch_of_labels.unsqueeze(dim=1).to(torch.int64)
# gather() returns a single column; squeeze exactly that column so the batch dimension is kept
# even when the batch holds one element.
probability_of_correct_classes = predictions.gather(dim=1, index=correct_class_indices).squeeze(dim=1)

# Build the all-ones target from the probabilities themselves, so its shape and dtype always match
# the batch actually drawn (which can be smaller than `batch_size`).
loss_score = bce_loss_fn(input=probability_of_correct_classes, target=torch.ones_like(probability_of_correct_classes))

# The inferred joint-probability predicted for every correct class in the batch. Change the data-type from float32 to float64
# to reduce the likelihood of underflowing to 0.
joint_probability = probability_of_correct_classes.to(torch.float64).prod()
# We can also recover this joint-probability from the loss-score, since we know the loss is computed as -log(joint_probability).
# .item() extracts the Python float explicitly instead of relying on math.exp() to convert the tensor implicitly.
joint_probability_from_loss = math.exp(-loss_score.item())
# The joint-probability is the product of all probabilities, so we invert that step (take the n-th root, i.e. the
# geometric mean) to get the per batch-element average probability. n is the number of elements actually in the batch.
per_batch_element_average_probability = joint_probability_from_loss ** (1 / len(batch_of_labels))

print(f"Joint probability: {joint_probability:.2e}. Joint probability inferred from loss-score: {joint_probability_from_loss:.2e}.")
print(f"The per batch-element average probability is: {per_batch_element_average_probability:.3f}.")
print(f"Loss score: {loss_score:.3f}")

# Back-propagate, apply one Adam update, then clear the gradients so that they do not
# accumulate into the next training step.
loss_score.backward()
adam_optimizer.step()
adam_optimizer.zero_grad()

#### Train for num_epochs

In [None]:
num_epochs = 3
# The length of a DataLoader is the number of batches it yields per epoch, so there is no need to
# re-derive it from the dataset size and batch size.
num_batches_per_epoch = len(mnist_data_loader_train)
num_batches_to_print_per_epoch = 10

for epoch_idx in range(num_epochs):
    print(f"\nBeginning epoch: {epoch_idx + 1}.")

    for batch_idx, (batch_of_images, batch_of_labels) in enumerate(mnist_data_loader_train):

        predictions = model(batch_of_images)

        # Pick out, for every batch element, the probability predicted for its true class.
        correct_class_indices = batch_of_labels.unsqueeze(dim=1).to(torch.int64)
        # Squeeze only the gathered column so the batch dimension is kept even for a one-element batch.
        probability_of_correct_classes = predictions.gather(dim=1, index=correct_class_indices).squeeze(dim=1)

        # With an all-ones target and reduction='sum', the loss is -log of the joint probability of the correct classes.
        loss_score = bce_loss_fn(input=probability_of_correct_classes, target=torch.ones_like(probability_of_correct_classes))
        # .item() extracts the Python float explicitly instead of relying on math.exp() to convert the tensor implicitly.
        joint_probability_from_loss = math.exp(-loss_score.item())

        # Note: the last batch may contain less than batch-size elements.
        per_batch_element_average_probability = joint_probability_from_loss ** (1 / len(batch_of_labels))

        # Only print the first few batches of results for each epoch.
        if batch_idx < num_batches_to_print_per_epoch:
            print(
                f"Batch: {batch_idx + 1}. "
                f"Joint probability inferred from loss-score: {joint_probability_from_loss:.2e}. "
                f"Per batch-element average probability is: {per_batch_element_average_probability:.3f} "
                f"Loss score: {loss_score:.3f}"
            )

            # Add a sleep to be able to watch in real-time. Otherwise, the process is too quick.
            time.sleep(2)

        elif batch_idx == num_batches_to_print_per_epoch:
            # Printed once per epoch, on the first batch that is no longer reported individually.
            print(
                f"Accelerating through the remaining {num_batches_per_epoch - num_batches_to_print_per_epoch:,} "
                f"batches in this epoch without printing updates..."
            )

        # Back-propagate, apply one Adam update, then clear the gradients so that they do not
        # accumulate into the next iteration.
        loss_score.backward()
        adam_optimizer.step()
        adam_optimizer.zero_grad()