# Test Time Adaptation on a CNN (ResNet50)

In this project we implement a technique to improve the performance of a pretrained model at test time. This technique is called **Test Time Adaptation**.

Test Time Adaptation adjusts the model parameters during the test phase, leveraging the test data itself to enhance the model's predictions.

## Dataset
The dataset we use is the *Imagenet-V2 matched frequency* dataset, which contains 10,000 images divided into 1,000 classes.

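The `CustomImageFolder` class fixes the class labels that torchvision's `ImageFolder` would assign. The class folders are named with integers, and `ImageFolder` sorts the folder names as strings before numbering them, so its indices don't match the real class indices. `CustomImageFolder` uses the integer value of the folder name as the label instead.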
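
To see why a position-based mapping goes wrong, here is a minimal sketch in plain Python. It assumes, as the `int(cls_name)` conversion in `CustomImageFolder` suggests, that the class folders are named by their integer index; the four folder names below are only an illustration.

```python
# Illustrative folder names (assumption: class folders are named by integer index)
folders = ["0", "1", "2", "10"]

# Position-based mapping after sorting the names as strings
by_position = {name: idx for idx, name in enumerate(sorted(folders))}

# Mapping used by CustomImageFolder: the folder name itself is the label
by_name = {name: int(name) for name in sorted(folders)}

print(by_position)  # {'0': 0, '1': 1, '10': 2, '2': 3}
print(by_name)      # {'0': 0, '1': 1, '10': 10, '2': 2}
```

With string sorting, folder `10` lands in position 2 and folder `2` in position 3, so every label after the first few folders would be shifted.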

Then:
- We import the full dataset from the files.
- We decide the number of samples to select from the dataset.
- We create a subset of the full dataset by sampling randomly from the full dataset.

In [None]:
!wget https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-matched-frequency.tar.gz
!tar -xf imagenetv2-matched-frequency.tar.gz
!pip install timm

In [None]:
# Path to dataset
dataset_path = 'imagenetv2-matched-frequency-format-val'

In [None]:
import os
import torch
import numpy as np

from torch.utils.data import Dataset
from PIL import Image

# Customization of the class ImageFolder to import the real name of the classes of Imagenet-v2
class CustomImageFolder(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.classes = sorted(os.listdir(root_dir))
        self.class_to_idx = {cls_name: int(cls_name) for cls_name in self.classes}
        self.image_paths = []
        self.labels = []

        for cls_name in self.classes:
            cls_dir = os.path.join(root_dir, cls_name)
            for img_name in os.listdir(cls_dir):
                img_path = os.path.join(cls_dir, img_name)
                self.image_paths.append(img_path)
                self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

# Import the full dataset from the files
full_dataset = CustomImageFolder(root_dir=dataset_path)

# Decide the number of samples to use for the test and create a random subset
num_samples = 3000
np.random.seed(0)
indices = np.random.choice(len(full_dataset), num_samples, replace=False)
subset_dataset = torch.utils.data.Subset(full_dataset, indices)

## Model
We load the pretrained *Resnet50d* model from *Timm*.

This model features:
- ReLU activations
- 3-layer stem of 3x3 convolutions with pooling
- 2x2 average pool + 1x1 convolution shortcut downsample
- Trained on ImageNet-1k


In [None]:
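import copy
import timm

# Load the Resnet50d model from timm
model = timm.create_model('resnet50d', pretrained=True)
# Save a deep copy of the initial weights; state_dict() alone returns tensors
# that share storage with the live parameters and would change during adaptation
initial_state = copy.deepcopy(model.state_dict())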

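The model is reset to its pretrained weights before each test sample, so the adaptation on one sample never carries over to the next. We create a function that loads the saved initial state back into the model.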

In [None]:
# Function to reset the state of the model to a saved initial state
def reset_model(model, initial_state):
    model.load_state_dict(initial_state)
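
One caveat when saving weights for a later reset: `state_dict()` returns tensors that share storage with the live parameters, so an in-place optimizer step also changes them. The sketch below shows this on a tiny `nn.Linear` layer, used here only as a stand-in for the real model; the variable names are illustrative.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(2, 1)  # tiny stand-in for the real model

shared = layer.state_dict()                    # shares storage with the parameters
snapshot = copy.deepcopy(layer.state_dict())   # independent copy of the weights

# One SGD step updates the parameters in place
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
layer(torch.ones(1, 2)).sum().backward()
optimizer.step()

# The shared dictionary followed the update, the deep copy did not
shared_changed = not torch.equal(shared["weight"], snapshot["weight"])
print("shared copy changed:", shared_changed)

# Restoring from the deep copy brings back the original weights
layer.load_state_dict(snapshot)
restored = torch.equal(layer.weight.detach(), snapshot["weight"])
print("restored:", restored)
```

Keeping a deep copy of the state is the safer choice whenever the model stays on the device where the state was captured.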

## Preprocess the data

We define three sets of transformations:
- `preprocess_step1_test` does the resizing and cropping for the test function.
- `preprocess_step1` does a random resized crop and a random horizontal flip with probability of 0.5 for the adapt function.
- `preprocess_step2` converts the image to a tensor and then normalizes it.

These transformations are used for correctly feeding the images to the test and adapt functions.

In [None]:
from torchvision import transforms

preprocess_step1_test = transforms.Compose([
    transforms.Resize(256, antialias=True),
    transforms.CenterCrop(224)
])

preprocess_step1 = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip()
])

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
preprocess_step2 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
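
As a worked example of what `preprocess_step2` does to a single pixel, the sketch below reproduces the arithmetic in plain Python: scale from 0-255 to 0-1, then subtract the channel mean and divide by the channel standard deviation. The helper `normalize_pixel` is hypothetical and only mimics the two transforms for one RGB value.

```python
# Channel statistics, copied from the preprocessing cell
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def normalize_pixel(rgb_uint8):
    """Mimic ToTensor + Normalize for one (R, G, B) pixel given as 0-255 ints."""
    scaled = [v / 255.0 for v in rgb_uint8]  # ToTensor: [0, 255] -> [0.0, 1.0]
    return [(s - m) / d for s, m, d in zip(scaled, mean, std)]  # Normalize

mid_gray = normalize_pixel((128, 128, 128))
print([round(v, 3) for v in mid_gray])  # [0.074, 0.205, 0.426]
```

A mid-gray pixel ends up close to zero in every channel, which is the range the pretrained network expects.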

## Augmentations

We define a set of augmentations.
These augmentations are taken from the [official MEMO implementation](https://github.com/zhangmarvin/memo).

In [None]:
from PIL import ImageOps

def autocontrast(pil_img, level=None):
    return ImageOps.autocontrast(pil_img)

def equalize(pil_img, level=None):
    return ImageOps.equalize(pil_img)

def rotate(pil_img, level):
    degrees = int_parameter(rand_lvl(level), 30)
    if np.random.uniform() > 0.5:
        degrees = -degrees
    return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128)

def solarize(pil_img, level):
    level = int_parameter(rand_lvl(level), 256)
    return ImageOps.solarize(pil_img, 256 - level)

def shear_x(pil_img, level):
    level = float_parameter(rand_lvl(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    return pil_img.transform((224, 224), Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128)

def shear_y(pil_img, level):
    level = float_parameter(rand_lvl(level), 0.3)
    if np.random.uniform() > 0.5:
        level = -level
    return pil_img.transform((224, 224), Image.AFFINE, (1, 0, 0, level, 1, 0), resample=Image.BILINEAR, fillcolor=128)

def translate_x(pil_img, level):
    level = int_parameter(rand_lvl(level), 224 / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform((224, 224), Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR, fillcolor=128)

def translate_y(pil_img, level):
    level = int_parameter(rand_lvl(level), 224 / 3)
    if np.random.random() > 0.5:
        level = -level
    return pil_img.transform((224, 224), Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR, fillcolor=128)

def posterize(pil_img, level):
    level = int_parameter(rand_lvl(level), 4)
    return ImageOps.posterize(pil_img, 4 - level)

def int_parameter(level, maxval):
    """Helper function to scale `maxval` according to `level`.
    Args:
        level: Level of the operation, expected to be in [0, 10].
        maxval: Maximum value that the operation can have. This will be scaled
            to level/10.
    Returns:
        An int that results from scaling `maxval` according to `level`.
    """
    return int(level * maxval / 10)

def float_parameter(level, maxval):
    """Helper function to scale `maxval` according to `level`.
    Args:
        level: Level of the operation, expected to be in [0, 10].
        maxval: Maximum value that the operation can have. This will be scaled
            to level/10.
    Returns:
        A float that results from scaling `maxval` according to `level`.
    """
    return float(level) * maxval / 10.

# Function to compute a random level of strength to pass to the augmentations
def rand_lvl(n):
    return np.random.uniform(low=0.1, high=n)

augmentations = [
    autocontrast,
    equalize,
    lambda x: rotate(x, 1),
    lambda x: solarize(x, 1),
    lambda x: shear_x(x, 1),
    lambda x: shear_y(x, 1),
    lambda x: translate_x(x, 1),
    lambda x: translate_y(x, 1),
    lambda x: posterize(x, 1),
]
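
To get a feel for how strong these augmentations are, the sketch below evaluates the scaling helpers at the two ends of the interval that `rand_lvl(1)` samples from. The two helper functions are copied from the cell above so that the snippet runs on its own.

```python
def int_parameter(level, maxval):
    return int(level * maxval / 10)

def float_parameter(level, maxval):
    return float(level) * maxval / 10.

# rand_lvl(1) draws the level uniformly between 0.1 and 1
for level in (0.1, 1.0):
    print(
        f"level={level}: "
        f"rotate={int_parameter(level, 30)} deg, "
        f"solarize threshold={256 - int_parameter(level, 256)}, "
        f"shear={float_parameter(level, 0.3):.3f}, "
        f"translate={int_parameter(level, 224 / 3)} px, "
        f"posterize bits={4 - int_parameter(level, 4)}"
    )
```

With the level fixed to 1 in the `augmentations` list, the geometric transforms stay mild: at most about 3 degrees of rotation and 7 pixels of translation, and `posterize` always keeps 4 bits.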

We then define the `_single_aug` function to compute a single augmentation.
This function takes as input an image and applies an augmentation chosen randomly from the augmentations set.

In [None]:
# Given a sample, computes a random augmentation from the augmentations set and then returns the augmented sample
def _single_aug(x_orig):
    x_orig = preprocess_step1(x_orig)
    x_aug = x_orig.copy()
    x_aug = np.random.choice(augmentations)(x_aug)
    x_aug = preprocess_step2(x_aug)
    return x_aug

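The `compute_n_aug` function adaptively decides how many augmentations to compute per sample: the lower the initial confidence, the more augmentations we compute. With the adaptive mode enabled the count is `ceil(n_aug_max * (1 - confidence) + 3)`, so it can exceed `n_aug_max` by up to 3 for very uncertain samples.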

In [None]:
import math

# If adaptive_augmentation_strength is True, return a number of augmentations that decreases linearly with the confidence of the model; otherwise return n_aug_max
def compute_n_aug(conf_sample_clean, n_aug_max, adaptive_augmentation_strength):
    if adaptive_augmentation_strength:
        n_aug = math.ceil((n_aug_max * (1-conf_sample_clean))+3)
    else:
        n_aug = n_aug_max
    return n_aug
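
As a quick worked example, the sketch below evaluates the same formula for a few confidence values with `n_aug_max=16`. The function is repeated here so that the snippet runs on its own; the confidence values are arbitrary.

```python
import math

def compute_n_aug(conf_sample_clean, n_aug_max, adaptive_augmentation_strength):
    # Same formula as in the cell above
    if adaptive_augmentation_strength:
        return math.ceil((n_aug_max * (1 - conf_sample_clean)) + 3)
    return n_aug_max

for conf in (0.1, 0.5, 0.69):
    print(conf, compute_n_aug(conf, 16, True))
# 0.1 18
# 0.5 11
# 0.69 8
```

A very uncertain sample gets 18 augmentations, while a sample just below a confidence of 0.7 gets 8, roughly half of `n_aug_max`.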

## Adapt and Test functions

### Adapt function

The `adapt_single` function is used to compute the augmentations from an original sample, stack the multiple images (original and augmented) into a single tensor and then compute the output of the model.
Then the function computes the marginal entropy loss and the gradient, and it updates the model weights with the optimizer to minimize the loss.

In [None]:
# We set the device to cuda if it is available, otherwise we set it to cpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Transformation function to use
tr_transforms = _single_aug

# Function for updating the model by computing the gradient on a set of augmented samples
def adapt_single(model, image, optimizer, criterion, conf_sample_clean, n_aug_max, adaptive_augmentation_strength=False):

    model.eval()
    # Compute the number of augmentations to perform
    n_aug = compute_n_aug(conf_sample_clean, n_aug_max, adaptive_augmentation_strength)

    # Perform the augmentations and put it on a list together with the original sample
    inputs = [preprocess_step2(preprocess_step1_test(image))] + [tr_transforms(image) for _ in range(n_aug)]
    # Stack the augmented and the original samples on a single tensor where the first dimension now is number of augmentations + 1
    inputs = torch.stack(inputs).to(device)
    # Zero the gradient
    optimizer.zero_grad()
    # Compute the output of the model given the inputs
    outputs = model(inputs)
    # Compute the loss
    loss = criterion(outputs)
    # Compute the gradient
    loss.backward()
    # Update the model parameters, SGD step
    optimizer.step()

### Test function
The `test_single` function takes a single sample and its ground-truth label and checks whether the label predicted by the model matches it. It returns 1 if the prediction is correct and 0 otherwise.

In [None]:
# Function for testing the model on a single sample and checking if the label predicted is the same as the ground truth
def test_single(model, image, label):
    model.eval()
    with torch.no_grad():
        image = preprocess_step1_test(image)
        image = preprocess_step2(image)
        outputs = model(image.unsqueeze(0).to(device))
        _, predicted = torch.max(outputs.data, 1)
        result = 1 if (predicted == label) else 0
    return result

## Loss Function / Criterion

Since we don't have access to the labels at test-time, we need an unsupervised loss function.
We use the marginal entropy loss. This loss is computed on the aggregated probabilities of the augmented data, encouraging the model to make the same (confident) prediction across multiple augmentations.

In [None]:
# This function is the loss function or criterion used
def marginal_entropy(outputs):
    # Normalize the logits by computing the log-softmax
    logits = outputs - outputs.logsumexp(dim=-1, keepdim=True)
    # Aggregate the log-probabilities across the batch and normalize by the batch size
    avg_logits = logits.logsumexp(dim=0) - np.log(logits.shape[0])
    # Compute the marginal entropy by multiplying the log-probabilities with the probabilities
    return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)
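
To illustrate what this loss rewards, here is a plain-Python sketch of the same quantity, written without PyTorch so it can be run on its own. The function `marginal_entropy_py` and the toy logits are illustrative assumptions, not part of the notebook. It averages the class probabilities over the views and returns the entropy of that average.

```python
import math

def log_softmax(logits):
    """Log-probabilities of one logit vector (numerically stable)."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]

def marginal_entropy_py(batch_logits):
    """Entropy of the class distribution averaged over all rows of batch_logits."""
    log_probs = [log_softmax(row) for row in batch_logits]
    n_views = len(log_probs)
    entropy = 0.0
    for c in range(len(log_probs[0])):
        # Average probability of class c over the views
        avg_prob = sum(math.exp(row[c]) for row in log_probs) / n_views
        entropy -= avg_prob * math.log(avg_prob)
    return entropy

# Two views that agree on class 0 versus two views that disagree
agree = [[5.0, 0.0, 0.0], [5.0, 0.0, 0.0]]
disagree = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]]

h_agree = marginal_entropy_py(agree)
h_disagree = marginal_entropy_py(disagree)
print(f"agree: {h_agree:.3f}, disagree: {h_disagree:.3f}")
```

The loss is low when the views agree on one confident class and high when they are confident about different classes, which is why minimizing it pushes the model toward consistent predictions.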

## Main TTA loop
We perform a loop, iterating over all samples, where for each sample we:
- Compute if the model correctly predicts the label of the sample (after doing this over all samples we can get the **accuracy** of the model).
- Compute augmentations and update the model.
- Again compute if the model correctly predicts the label of the sample.

It is important to note that the first and third steps are performed only to measure the change in performance caused by the adaptation step. The TTA procedure itself doesn't need the labels.

This lets us check whether updating the model to make consistent, confident predictions across augmentations yields a useful gain in accuracy.

In [None]:
import torch.nn as nn
from tqdm import tqdm

softmax = nn.Softmax(dim=1)

def tta(model, subset):

    model.to(device)
    model.eval()

    #  For the optimizer we use Stochastic Gradient Descent
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

    tot_outputs_augm = 0.0
    tot_outputs_clean = 0.0
    running_average_sample_number = 0

    # We iterate over all samples, one iteration per sample
    with tqdm(total=num_samples) as pbar:
        for i in range(num_samples):
            # Retrieve the image and respective label from the subset
            image, label = subset[i]
            # We reset the model for each sample
            reset_model(model, initial_state)

            # Test if the model correctly predicts the label of the sample
            output_sample_clean = test_single(model, image, label)
            # Sum all results, after the loop we average to get the accuracy
            tot_outputs_clean = tot_outputs_clean + output_sample_clean

            # Compute the confidence of the model on its prediction (no gradient is needed here)
            with torch.no_grad():
                conf_sample_clean = torch.max(softmax(model(preprocess_step2(preprocess_step1_test(image)).unsqueeze(0).to(device))), dim=1).values[0].item()
            # If the confidence is 0.7 or higher we skip the augmentation and optimization step altogether
            if conf_sample_clean < 0.7:
                # Adapt loop, computes the loss of output of the augmentations and updates the model
                adapt_single(model, image, optimizer, marginal_entropy,
                    conf_sample_clean, n_aug_max=16, adaptive_augmentation_strength = True)

            # Test again if the model correctly predicts the label of the sample
            output_sample_augm = test_single(model, image, label)
            # Sum all results, after the loop we average to get the accuracy
            tot_outputs_augm = tot_outputs_augm + output_sample_augm

            # Some accuracy statistics to be visualized mid-run
            running_average_sample_number +=1
            running_average_original = tot_outputs_clean/running_average_sample_number
            running_average_augmented = tot_outputs_augm/running_average_sample_number
            difference = (running_average_augmented - running_average_original) * 100
            # Update progress bar with difference
            pbar.set_postfix({
            'Difference': f'{difference:.1f} %'
            })
            pbar.update(1)

    print("Original accuracy over all samples:", tot_outputs_clean/num_samples)
    print("Accuracy after adaptation over all samples:", tot_outputs_augm/num_samples)
    print("Difference in accuracy over all samples:", f'{(tot_outputs_augm/num_samples - tot_outputs_clean/num_samples)*100:.1f} %')


We run the TTA.

In [None]:
# Run the tta
tta(model, subset_dataset)

## Results

We performed multiple runs of the code with different numbers of augmentations.

- We apply the adaptation step only to samples whose confidence is below a threshold of 0.7, as we observed that the accuracy remained the same while the computation time was reduced.

- After trying different augmentations with different levels of strength, we found a good tradeoff with the ones presented.

- We introduced an adaptive mechanism for selecting the number of augmentations based on the confidence level. This approach produced very similar accuracy results while saving computational time, especially when dealing with a high number of augmentations.

The performances are shown in the table.
- The baseline accuracy for the *Resnet50d* on the *Imagenet-V2 matched frequency* dataset is 69.2%.
- The difference in accuracy is represented as a percentage, where a positive value means that the accuracy increases after the adaptation step.
- We represent the data relative to the number of augmentations performed and annotate if the adaptive augmentation strength was set to *True*.
- We also report the speed, expressed in iterations per second (it/s), measured on an Nvidia T4 GPU.

|Number of augmentations|Accuracy difference (%)|Speed (it/s)|Adaptive augmentation strength|
|---|---|---|---|
|1|0|12|False|
|2|0.7|11.27|False|
|4|0.8|9.96|False|
|8|0.7|8.17|False|
|16|1.1|5.89|False|
|32|1.6|3.89|False|
|32|1.5|4.96|True|
|64|1.5|2.23|False|
|64|1.6|3.27|True|

We can see that increasing the number of augmentations improves accuracy, from no gain with 1 augmentation to about 1.5-1.6 % with 32, after which the gain levels off, at the cost of more computing time. The adaptive augmentation strength noticeably reduces the computing time while giving very similar accuracy.
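
The speed gain from the adaptive mode can be read directly off the table. The short sketch below does the arithmetic for the two settings where both modes were measured; the numbers are copied from the table and nothing else is assumed.

```python
# Throughput in it/s, copied from the results table (Nvidia T4)
throughput = {
    32: {"fixed": 3.89, "adaptive": 4.96},
    64: {"fixed": 2.23, "adaptive": 3.27},
}

speedups = {}
for n_aug, t in throughput.items():
    speedups[n_aug] = t["adaptive"] / t["fixed"]
    print(f"{n_aug} augmentations: adaptive is {speedups[n_aug]:.2f}x faster")
# 32 augmentations: adaptive is 1.28x faster
# 64 augmentations: adaptive is 1.47x faster
```

The relative benefit grows with the number of augmentations.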

### Techniques we tried that didn't improve performance

During this project we ran several other experiments that didn't improve performance. They were based on:

- Different sets of augmentations for different levels of confidence, following the idea that we could apply more disruptive augmentations on low confidence samples.

- Different levels of strength of the augmentations for different levels of confidence.

- Using the sample and the augmentations to build a sort of batch and compute batch normalization with the relative mean and variance.

- Filtering augmentations by confidence: after passing each augmentation through the model, but before computing the marginal entropy, we checked its confidence and discarded it if it was not above a certain threshold. The marginal entropy was then computed using only the augmentations with confidence above the threshold. If no augmentation was above the threshold, the adapt step was skipped and the prediction on the original sample was kept.

- During TTA, samples with low confidence were temporarily skipped and grouped together. After iterating over all the samples, the adapt step was performed on this low-confidence group, either without resetting the model or by performing multiple cycles of the adaptation step for each sample.

- Using the confidence gap (the difference between the two highest predicted class probabilities) compared against certain thresholds, instead of the confidence level.
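
As an illustration of the last idea, here is a minimal sketch of a confidence gap. The helper `confidence_gap` and the toy probability vectors are hypothetical; this is not the code used in the experiments.

```python
def confidence_gap(probs):
    """Difference between the largest and second-largest probability."""
    top1, top2 = sorted(probs, reverse=True)[:2]
    return top1 - top2

# Both samples have the same top confidence (0.5) but very different gaps
close_call = [0.5, 0.45, 0.05]
clear_lead = [0.5, 0.1, 0.1, 0.1, 0.1, 0.1]

gap_close = confidence_gap(close_call)
gap_clear = confidence_gap(clear_lead)
print(f"close call: {gap_close:.2f}, clear lead: {gap_clear:.2f}")
```

The gap separates a prediction that is nearly tied with a second class from one that leads clearly, even when the top confidence is identical.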

## Conclusions

In this project, we showed that Test Time Adaptation with augmentations can improve the accuracy of a pretrained model, by up to 1.6 % over the 69.2 % baseline in our experiments.

This approach lets us choose a tradeoff between computational time and the gain in accuracy.

Because TTA is applied at test time without any additional information about the samples, it cannot provide drastic improvements in accuracy. It nevertheless remains a useful and easy-to-implement technique.