# Evaluation and Prediction Procedures

## Overview

This notebook defines the code for evaluating our neural network models on the validation set and for making predictions with a trained model.

It first defines an evaluation function. This function puts a model in evaluation mode and runs it over the validation loader. It then compares the predicted classes with the ground-truth labels, prints the result, and returns the overall top-1 accuracy. We use it to monitor validation performance during training.
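The accuracy computation at the core of the evaluation function reduces to an `argmax` followed by an element-wise comparison. The logits and labels below are made-up toy values used only to show the arithmetic on a single batch:

```python
import torch

# Hypothetical logits for 4 samples over 3 classes, plus their true labels
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.1, 0.2, 3.0],
                       [1.0, 0.9, 0.8]])
labels = torch.tensor([0, 1, 1, 0])

# Predicted class = index of the highest logit in each row
predictions = logits.argmax(dim=-1)
# Count matches and divide by the number of samples (top-1 accuracy)
num_correct = (predictions == labels).sum().item()
accuracy = num_correct / labels.size(0)
print(predictions.tolist(), accuracy)  # [0, 1, 2, 0] 0.75
```

Three of the four predictions match the labels, giving an accuracy of 0.75. The evaluation function accumulates the same two counts over every batch before dividing.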

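Next, a prediction function runs a model over a data loader and collects every prediction together with its corresponding label. Each batch must be an `(inputs, labels)` pair. After training, the function can therefore generate predictions for any labelled dataset, such as a held-out test set.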
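The prediction function defined below expects labelled batches. For genuinely new data that has no labels, a variant could iterate over inputs only. The sketch below is hypothetical: `predict_unlabeled` is not part of this notebook. It assumes a loader built from a single-tensor `TensorDataset` and a model that maps a batch of inputs to class logits.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_unlabeled(model, loader):
    """Hypothetical variant of predict() for batches that carry no labels."""
    model.eval()
    # Assumes the model has parameters; inputs are moved to their device
    device = next(model.parameters()).device
    all_predictions = []
    # A TensorDataset wrapping one tensor yields 1-element batches, unpacked here
    for (batch,) in loader:
        logits = model(batch.to(device))
        all_predictions += logits.argmax(dim=-1).tolist()
    return all_predictions


# Toy usage: 10 unlabeled samples with 4 features, a linear model over 3 classes
torch.manual_seed(0)
new_inputs = torch.randn(10, 4)
loader = DataLoader(TensorDataset(new_inputs), batch_size=4)
model = nn.Linear(4, 3)
preds = predict_unlabeled(model, loader)
print(preds)
```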

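The `@torch.no_grad()` decorator disables gradient tracking inside both functions. They only perform forward passes, so skipping construction of the autograd graph saves memory and computation. Both functions also call `model.eval()` and leave the model in evaluation mode. The training loop must therefore call `model.train()` before it resumes training. Each function loops over the batches and uses tqdm to display a progress bar.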
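`torch.no_grad()` and `model.eval()` do different things. The first turns off gradient tracking. The second switches layers such as dropout to their inference behaviour. The toy model below is an assumption chosen to show both effects:

```python
import torch
import torch.nn as nn

# Toy model with a dropout layer (hypothetical, for illustration only)
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))
x = torch.randn(3, 4)

# Outside no_grad, the output is attached to the autograd graph
y_train = model(x)
print(y_train.requires_grad)  # True


@torch.no_grad()
def forward_only(m, inputs):
    m.eval()  # dropout becomes the identity in eval mode
    return m(inputs)


y_eval = forward_only(model, x)
print(y_eval.requires_grad)  # False
# With dropout disabled, repeated forward passes give identical outputs
print(torch.equal(y_eval, forward_only(model, x)))  # True
# The model is still in eval mode afterwards; call model.train() to undo this
print(model.training)  # False
```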

## Importing Required Libraries

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

## Evaluation Function

In [None]:
# Decorator that disables gradient tracking for every operation inside the function
@torch.no_grad()
def evaluate(model, val_loader, epoch):
    # Set the model in evaluation mode (dropout off, batch norm uses running statistics);
    # callers must call model.train() again before resuming training
    model.eval()
    # Assumes the model has parameters; batches are moved to the same device
    device = next(model.parameters()).device

    # Initialise counters for correct predictions and total predictions
    num_correct = 0
    total_seen = 0

    # Loop over all batches of data in the validation loader
    for batch, labels in tqdm(val_loader, ascii=True, total=len(val_loader)):
        batch, labels = batch.to(device), labels.to(device)
        # Forward pass: compute the model output (logits) for the current batch
        logits = model(batch)
        # Compute predictions by taking the class with the highest logit
        predictions = logits.argmax(dim=-1)
        # Update number of correct predictions (as a Python int, not a tensor)
        num_correct += (predictions == labels).sum().item()
        # Update total number of predictions
        total_seen += labels.size(0)

    # Top-1 accuracy as a plain Python float
    accuracy = num_correct / total_seen

    # Display the validation performance after the current epoch
    tqdm.write(f"Val Perf after {epoch + 1} epochs Acc@1 {accuracy:0.4f}")

    # Return the accuracy for this evaluation
    return accuracy
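The sketch below shows where a per-epoch validation call fits in a training loop. Everything in it is hypothetical: the random data, the model, the optimiser settings, and `top1_accuracy`, a compact stand-in for `evaluate` without tqdm that keeps the block self-contained. The key detail is the `model.train()` call at the start of each epoch, which undoes the `model.eval()` performed during validation.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 64 samples, 4 features, 3 classes
inputs = torch.randn(64, 4)
targets = torch.randint(0, 3, (64,))
train_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)

model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.1), nn.Linear(16, 3))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()


@torch.no_grad()
def top1_accuracy(model, loader):
    # Same argmax-and-compare logic as evaluate(), without the progress bar
    model.eval()
    correct = total = 0
    for batch, labels in loader:
        correct += (model(batch).argmax(dim=-1) == labels).sum().item()
        total += labels.size(0)
    return correct / total


best_acc = 0.0
for epoch in range(3):
    model.train()  # validation left the model in eval mode, so switch back
    for batch, labels in train_loader:
        optimizer.zero_grad()
        loss_fn(model(batch), labels).backward()
        optimizer.step()
    acc = top1_accuracy(model, val_loader)
    best_acc = max(best_acc, acc)
    print(f"epoch {epoch + 1}: val acc {acc:.4f}")
```

Tracking `best_acc` is one common way to use the returned value, for example to decide when to save a checkpoint.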

## Prediction Function

In [None]:
# Decorator that disables gradient tracking for every operation inside the function
@torch.no_grad()
def predict(model, val_loader, epoch=None):
    # `epoch` is not used; it is kept (now optional) so existing calls still work
    # Set the model in evaluation mode
    model.eval()
    # Assumes the model has parameters; inputs are moved to the same device
    device = next(model.parameters()).device

    # Initialise lists to store all predictions and corresponding labels
    all_predictions = []
    all_labels = []

    # Loop over all batches; any loader yielding (inputs, labels) works, not only the validation one
    for batch, labels in tqdm(val_loader, ascii=True, total=len(val_loader)):
        # Forward pass: compute the model output (logits) for the current batch
        logits = model(batch.to(device))
        # Compute predictions by taking the class with the highest logit
        predictions = logits.argmax(dim=-1)
        # Extend the lists with plain Python ints rather than one 0-d tensor per sample
        all_predictions += predictions.tolist()
        all_labels += labels.tolist()

    # Return lists of all predictions and corresponding labels
    return all_predictions, all_labels
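One possible use of the two lists returned by the prediction function is a confusion matrix. The helper `confusion_matrix` below is a hypothetical sketch, not part of this notebook. It assumes the lists hold integer class indices in `0 .. num_classes - 1`, and the example lists are made up. Each (true, predicted) pair is encoded as a single index so that `torch.bincount` can count all pairs at once.

```python
import torch


def confusion_matrix(predictions, labels, num_classes):
    # Rows index the true class, columns the predicted class
    preds = torch.as_tensor(predictions, dtype=torch.long)
    truth = torch.as_tensor(labels, dtype=torch.long)
    # Encode each (true, predicted) pair as one flat index, then count occurrences
    flat = truth * num_classes + preds
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Hypothetical outputs of predict(): lists of class indices
all_predictions = [0, 2, 1, 1, 0, 2]
all_labels = [0, 1, 1, 1, 2, 2]
cm = confusion_matrix(all_predictions, all_labels, num_classes=3)
print(cm)
# Overall accuracy is the diagonal (correct predictions) over the total count
accuracy = cm.diag().sum().item() / cm.sum().item()
print(f"{accuracy:.4f}")  # 0.6667
```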

--------------------------------------------------------------------------------------------------------------------------------

#### Code adapted from:

* https://github.com/pytorch
* https://github.com/RAIVNLab/supsup