# Training Functions

## Overview

This notebook provides the code for training our models. 

Two training functions are defined:

* **train**: Implements training for our classic approach
* **trainV2**: Implements training for the novel multitask learning approach

Both functions:

* Set the model to training mode
* Define a CrossEntropyLoss criterion
* Loop through batches
* Perform forward and backward passes
* Optimise model weights
* Track the accumulated loss and a running accuracy, which is sampled on every 20th batch for the progress display
* Compute the average per-batch loss over the epoch and return it

In addition, the train function calls the classic model's get_bn_means method to obtain the batch normalisation means of the current task, and returns them alongside the average loss.
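
The notebook does not show how get_bn_means is implemented, so the following is only a sketch of one plausible reading. It assumes the "batch normalisation means" are the running means tracked by each BatchNorm layer. The `TinyBNNet` class and its layer sizes are hypothetical stand-ins, not the model used here, and `task_id` is accepted but unused in this sketch.

```python
import torch
import torch.nn as nn


class TinyBNNet(nn.Module):
    """Hypothetical stand-in for the classic model (not the notebook's model)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 8)
        self.bn = nn.BatchNorm1d(8)
        self.head = nn.Linear(8, 3)

    def forward(self, x):
        return self.head(torch.relu(self.bn(self.fc(x))))

    def get_bn_means(self, task_id):
        # Assumption: collect a copy of the running mean of every BatchNorm
        # layer, keyed by module name. task_id is unused in this sketch.
        return {
            name: module.running_mean.detach().clone()
            for name, module in self.named_modules()
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d))
        }


torch.manual_seed(0)
model = TinyBNNet()
model.train()

# In training mode, a forward pass updates the BatchNorm running statistics
_ = model(torch.randn(16, 4))

bn_means = model.get_bn_means(task_id=0)
```

Returning cloned tensors means later training steps do not silently change the values already handed back to the caller.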

The training functions encapsulate our models' fundamental training process, including loss computation, weight updates, and metrics tracking.
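
The per-batch steps listed above can be sketched on a single synthetic batch. This is a minimal illustration, not the notebook's training code: the linear model, the SGD optimiser, the learning rate, and the random data are all assumptions chosen to keep the example small.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: 4 input features, 3 classes
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# One synthetic batch of 8 samples with integer class labels
batch = torch.randn(8, 4)
labels = torch.randint(0, 3, (8,))

model.train()
weights_before = model.weight.detach().clone()

# Forward pass and loss computation
logits = model(batch)
loss = criterion(logits, labels)

# Clear old gradients, backpropagate, and update the weights
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Accuracy on this batch, computed from the logits of the forward pass
batch_accuracy = (logits.argmax(dim=-1) == labels).float().mean().item()
weights_changed = not torch.equal(weights_before, model.weight.detach())
```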

## Importing Required Libraries

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

## Training the Multitask Model (Classic)

In [None]:
def train(model, trainloader, optimizer, epoch, task_id):
    # Set the model to training mode
    model.train()
    
    # Define the loss function (cross-entropy loss)
    criterion = nn.CrossEntropyLoss()
    
    # Counter to keep track of the number of correct predictions
    num_correct = 0
    
    # Counter to keep track of the total number of samples seen
    total_seen = 0
    
    # Counter to accumulate the total loss during training
    total_loss = 0

    # Iterate through the training data loader, displaying a progress bar
    for i, (batch, labels) in tqdm(
        enumerate(trainloader), ascii=True, total=len(trainloader)):
        # Forward pass of the model to compute logits
        logits = model(batch)
        
        # Compute loss using the cross-entropy criterion
        loss = criterion(logits, labels)
        
        # Accumulate loss for the batch
        total_loss += loss.item()

        # Zero out the gradients
        optimizer.zero_grad()
        
        # Compute gradients through backpropagation
        loss.backward()
        
        # Perform a step of the optimiser to update the weights
        optimizer.step()

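        # Log progress every 20 iterations; the accuracy counters are updated
        # only on these logged batches, so Acc@1 is a sampled running estimate
        if i % 20 == 0:
            # Get the predicted labels
            predictions = logits.argmax(dim=-1)
            
            # Count correct predictions on this logged batch
            num_correct += (predictions == labels).sum().item()
            
            # Increase the total number of samples seen on logged batches
            total_seen += logits.size(0)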
            
            # Display the current epoch, step, loss, and accuracy
            tqdm.write(
                (
                    f"e{epoch} {i+1}/{len(trainloader)}"
                    f" => Loss {loss.item():0.4f}, "
                    f"Acc@1 {(num_correct / total_seen):0.4f}"
                ),
                end="\r",
            )

    # Get the batch normalisation means for the task (computed once and reused)
    bn_means = model.get_bn_means(task_id)
    
    # Calculate the average per-batch loss over the epoch
    average_loss = total_loss / len(trainloader)
    
    # Return the average loss and the batch normalisation means for the task
    return average_loss, bn_means
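
Because `nn.CrossEntropyLoss` averages over each batch by default, `total_loss / len(trainloader)` is a mean of per-batch means. When the last batch is smaller than the others, this differs slightly from the mean over individual samples. The batch sizes and loss values below are made up purely to show the arithmetic.

```python
# Hypothetical per-batch results: (batch size, mean loss over that batch)
batches = [(4, 1.0), (4, 2.0), (2, 4.0)]

# What dividing the accumulated batch losses by the number of batches gives
mean_of_batch_means = sum(loss for _, loss in batches) / len(batches)

# Per-sample mean: weight each batch loss by its batch size
total_samples = sum(size for size, _ in batches)
per_sample_mean = sum(size * loss for size, loss in batches) / total_samples

print(f"mean of batch means: {mean_of_batch_means:.4f}")
print(f"per-sample mean:     {per_sample_mean:.4f}")
```

The two values coincide whenever every batch has the same size, for example when the data loader drops the final incomplete batch.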

## Training the Multitask Model (Novel)

In [None]:
def trainV2(model, trainloader, optimizer, epoch, bn_means):
    # Note: bn_means is not referenced anywhere in the body of this function
    # Set the model to training mode
    model.train()
    
    # Define the loss function (cross-entropy loss)
    criterion = nn.CrossEntropyLoss()
    
    # Counter to keep track of the number of correct predictions
    num_correct = 0
    
    # Counter to keep track of the total number of samples seen
    total_seen = 0
    
    # Counter to accumulate the total loss during training
    total_loss = 0

    # Iterate through the training data loader, displaying a progress bar
    for i, (batch, labels) in tqdm(
        enumerate(trainloader), ascii=True, total=len(trainloader)):
        # Forward pass of the model to compute logits
        logits = model(batch)

        # Compute loss using the cross-entropy criterion
        loss = criterion(logits, labels)
        
        # Accumulate loss for the batch
        total_loss += loss.item()

        # Zero out gradients
        optimizer.zero_grad()
        
        # Compute gradients through backpropagation
        loss.backward()
        
        # Perform a step of the optimiser to update the weights
        optimizer.step()

        # Log training progress every 20 steps; the accuracy counters are
        # updated only on these logged batches, so Acc@1 is a sampled estimate
        if i % 20 == 0:
            # Compute predictions by taking the class with the highest logit
            predictions = logits.argmax(dim=-1)
            
            # Update the number of correct predictions on logged batches
            num_correct += (predictions == labels).sum().item()
            
            # Update the total number of predictions on logged batches
            total_seen += logits.size(0)
            
            # Display the current epoch, step, loss, and accuracy
            tqdm.write(
                (
                    f"e{epoch} {i+1}/{len(trainloader)}"
                    f" => Loss {loss.item():0.4f}, "
                    f"Acc@1 {(num_correct / total_seen):0.4f}"
                ),
                end="\r",
            )

    # Calculate the average per-batch loss over the epoch
    average_loss = total_loss / len(trainloader)
    
    # Return the average loss for the task
    return average_loss
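
Both functions update the accuracy counters only on every 20th batch. If accuracy over all batches were wanted instead, the counting could be moved into a small helper and called on each batch. The sketch below is an assumed variation, not part of the notebook: `count_correct` is a hypothetical helper and the logits and labels are hand-written examples.

```python
import torch


def count_correct(logits, labels):
    """Return (number of correct top-1 predictions, number of samples)."""
    predictions = logits.argmax(dim=-1)
    return int((predictions == labels).sum().item()), labels.size(0)


# Two hand-written batches of logits and labels for a 2-class problem
logits_a = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3]])
labels_a = torch.tensor([0, 1, 1])
logits_b = torch.tensor([[0.1, 3.0], [1.2, 0.4]])
labels_b = torch.tensor([1, 0])

num_correct, total_seen = 0, 0
for logits, labels in [(logits_a, labels_a), (logits_b, labels_b)]:
    correct, seen = count_correct(logits, labels)
    num_correct += correct
    total_seen += seen

accuracy = num_correct / total_seen
```

Keeping the counters as plain Python integers avoids accumulating tensors across the epoch and makes the final division an ordinary float.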

--------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/pytorch
* https://github.com/RAIVNLab/supsup