# Test-time Adaptation for Image Classification

**Authors**: Davide Cavicchini, Laurence Bonat, Lorenzo Orsingher  
**Date**: 01/06/2024

# Structure

1. [Introduction](#Introduction): Overview of the problem, proposed solutions, and some implementation details on how results are stored
2. [Imports](#Imports): Imports the necessary libraries and modules
3. [Parameters](#Parameters): Defines the tests to be performed and the parameters to be used. Each macro technique (MEMO, TPT, Ensemble) has a dedicated section explaining its parameters
4. [Data Loading](#Data-Loading): Load the data and prepare it for training
5. [Models](#models): Describes in detail the strategies used to adapt the model at test time, the motivations behind them, and their implementation.
6. [Results](#test-time-adaptation): Shows the results of the tests performed
7. [Conclusions](#Conclusions): Conclusions

# Introduction
This report presents a study of various deep learning techniques for image classification, specifically focusing on "Test-time Prompt Tuning" (TPT) and "Test Time Robustness via Adaptation and Augmentation" (MEMO). We introduce several contributions, including MEMO with dropout, a simple ensemble with probability marginalization, and an ensemble of multiple models.

## Reference Techniques
### Baselines
As baselines we decided to use the following models:
- RN50: ResNet50, a 50-layer deep convolutional neural network trained on the ImageNet dataset. We test it with both the V1 and V2 weights.
- CLIP RN50: A model that uses a vision encoder and a language encoder to learn visual concepts from natural language supervision. We use the RN50 version.

### [MEMO](https://arxiv.org/abs/2110.09506)
MEMO proposes a simple approach that can be used in any test setting where the model is probabilistic and adaptable: when presented with a test example, apply different data augmentations to the data point, then adapt (all of) the model parameters by minimizing the entropy of the model’s average, or marginal, output distribution across the augmentations.
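
The quantity being minimized can be illustrated with a small sketch. It is written in plain Python rather than the PyTorch used later in this notebook, and the helper names `softmax` and `marginal_entropy` as well as the toy logits are assumptions made for illustration, not part of the MEMO code.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def marginal_entropy(logits_per_aug):
    """Entropy of the average predictive distribution over augmentations."""
    probs = [softmax(logits) for logits in logits_per_aug]
    n_aug, n_cls = len(probs), len(probs[0])
    # Marginal distribution: average the per-augmentation probabilities.
    marginal = [sum(p[c] for p in probs) / n_aug for c in range(n_cls)]
    entropy = -sum(p * math.log(p) for p in marginal if p > 0.0)
    return entropy, marginal

# Hypothetical logits for 3 classes and 2 augmented views.
agree = [[4.0, 0.0, 0.0], [4.0, 0.0, 0.0]]      # views agree on class 0
disagree = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0]]   # views contradict each other
h_agree, marginal_agree = marginal_entropy(agree)
h_disagree, _ = marginal_entropy(disagree)
```

The marginal entropy is lower when the augmented views agree, so a gradient step on this loss pushes the model toward predictions that are both confident and consistent across augmentations.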

### [Test-time Prompt Tuning (TPT)](https://arxiv.org/abs/2209.07511)
For image classification, TPT optimizes the prompt given to CLIP by minimizing the entropy with confidence selection so that the model has consistent predictions across different augmented views of each test sample.
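
Confidence selection can be sketched as follows: rank the augmented views by the entropy of their prediction and keep only the most confident fraction. The function names and the toy probabilities are hypothetical, and truncating `len(views) * rho` to an integer is an assumption about how the cutoff is computed.

```python
import math

def entropy(probs):
    """Shannon entropy of one predictive distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def select_confident(view_probs, rho=0.10):
    """Return the indices of the fraction `rho` of views with the lowest entropy."""
    keep = max(1, int(len(view_probs) * rho))  # always keep at least one view
    ranked = sorted(range(len(view_probs)), key=lambda i: entropy(view_probs[i]))
    return sorted(ranked[:keep])

# Hypothetical predictions for 4 augmented views over 3 classes.
views = [
    [0.90, 0.05, 0.05],  # confident
    [0.34, 0.33, 0.33],  # nearly uniform, likely a destructive augmentation
    [0.80, 0.10, 0.10],  # confident
    [0.40, 0.30, 0.30],  # uncertain
]
selected = select_confident(views, rho=0.5)
```

Only the selected views would then contribute to the averaged distribution whose entropy is minimized.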

## Our Contributions

### TPT with Alignment Steps
The first idea we explored is to align the embeddings of the augmented images before passing them to the classifier.

Since all augmented views depict the same content, we first apply a gradient update to the image classifier that pulls their embeddings closer to each other. As with TPT, we apply confidence selection so that only the relevant augmentations, which hopefully do not destroy the content of the image, are taken into account.

### Simple Ensemble with Probability Marginalization
Both MEMO and TPT minimize the entropy of the model's average output distribution across the augmentations. This, in turn, should raise the probability of the correct class and thereby change the model's output on the original image.

So the idea here is also quite straightforward: since we are maximizing the probability of the class predicted by the ensemble, why not just return this? No need to backpropagate the network.
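
A minimal sketch of this shortcut, assuming the per-view class probabilities are already available: average them and return the arg-max, with no gradient step. The function name and the numbers are illustrative only.

```python
def marginal_prediction(view_probs):
    """Average class probabilities over views and return (class index, marginal)."""
    n_views, n_cls = len(view_probs), len(view_probs[0])
    marginal = [sum(v[c] for v in view_probs) / n_views for c in range(n_cls)]
    predicted = max(range(n_cls), key=lambda c: marginal[c])
    return predicted, marginal

# Hypothetical probabilities: the first view alone would predict class 0,
# but the averaged distribution favours class 1.
view_probs = [
    [0.50, 0.40, 0.10],
    [0.30, 0.60, 0.10],
    [0.20, 0.70, 0.10],
]
predicted, marginal = marginal_prediction(view_probs)
```

Since nothing is backpropagated, the cost is one forward pass per augmented view.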

### Ensemble with Dropout
We propose to use dropout to stochastically remove features that are not relevant for the classification, using networks such as ResNet50.

The idea is naive: if an image should be classified as something, then most of the features extracted should point toward it, but there might be some noise from other features that influence the result or may even strongly polarize the classification. With dropout we hope to stochastically remove these features, while keeping the correct ones that (hopefully) are the majority.

This approach can also be implemented very efficiently: the image goes through the backbone only once, and only the final classification head is run multiple times.
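
The efficiency argument can be sketched in plain Python: the backbone features are computed once, and only a linear head is re-evaluated under different dropout masks. Everything here is a hypothetical stand-in (a toy linear head, inverted dropout with `p < 1`, and averaging of logits rather than probabilities), not the implementation used in this notebook.

```python
import random

def linear_head(features, weights, bias):
    """Apply a linear classification head: one logit per class."""
    return [sum(w * f for w, f in zip(row, features)) + b
            for row, b in zip(weights, bias)]

def dropout(features, p, rng):
    """Inverted dropout: zero each feature with probability p, rescale the rest."""
    if p <= 0.0:
        return list(features)
    scale = 1.0 / (1.0 - p)  # assumes p < 1
    return [0.0 if rng.random() < p else f * scale for f in features]

def dropout_ensemble(features, weights, bias, p=0.5, n_passes=16, seed=0):
    """Average the head's logits over several dropout masks of the same features."""
    rng = random.Random(seed)
    totals = [0.0] * len(weights)
    for _ in range(n_passes):
        logits = linear_head(dropout(features, p, rng), weights, bias)
        totals = [t + l for t, l in zip(totals, logits)]
    return [t / n_passes for t in totals]

# Hypothetical backbone output, computed once and reused for every pass.
features = [1.0, 0.5, 2.0, 0.0]
weights = [[1.0, 0.0, 1.0, 0.0],   # class 0
           [0.0, 1.0, 0.0, 1.0]]   # class 1
bias = [0.0, 0.0]
avg_logits = dropout_ensemble(features, weights, bias, p=0.5, n_passes=16)
no_drop = dropout_ensemble(features, weights, bias, p=0.0, n_passes=4)
```

With `p=0.0` the dropout acts as an identity and every pass returns the same logits, which matches the behaviour described for `drop = 0` in the parameters section.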

### Ensemble of Multiple Models
Finally, we wanted to see if multiple models trained with different techniques could improve the performance. The idea is that each model covers the out-of-distribution data of the others, leading to a more robust ensemble.

## Results Storage
After the execution, each experiment is stored in a separate directory, with the following structure:
```bash
└── results
    └── results_{timestamp}
        ├── TPT
        │   ├── {experiment_name}.txt    - dump of the console output during the experiment
        │   ├── {experiment_name}.json   - dump of the experiment configuration and obtained results
        │   └── ...
        ├── MEMO
        │   ├── {experiment_name}.txt    - dump of the console output during the experiment
        │   ├── {experiment_name}.json   - dump of the experiment configuration and obtained results
        │   └── ...
        └── Ensemble
            ├── {experiment_name}.txt    - dump of the console output during the experiment
            ├── {experiment_name}.json   - dump of the experiment configuration and obtained results
            └── ...
```
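
As a rough sketch of how one experiment could be written into this layout, the snippet below saves a `.json` and a `.txt` file under a technique folder. The function name `save_experiment`, the JSON keys `config` and `results`, and the placeholder values are assumptions; the notebook's own saving code may differ.

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path

def save_experiment(results_root, technique, experiment_name, config, results, log_text):
    """Write <technique>/<experiment_name>.json and .txt under results_root."""
    out_dir = Path(results_root) / technique
    out_dir.mkdir(parents=True, exist_ok=True)
    json_path = out_dir / f"{experiment_name}.json"
    txt_path = out_dir / f"{experiment_name}.txt"
    json_path.write_text(json.dumps({"config": config, "results": results}, indent=2))
    txt_path.write_text(log_text)
    return json_path, txt_path

# Use a temporary directory so the sketch does not touch real results.
base = Path(tempfile.mkdtemp())
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_root = base / "results" / f"results_{timestamp}"
json_path, txt_path = save_experiment(
    results_root, "TPT", "TPT_sel_A",
    config={"augs": 64, "confidence": 0.10},
    results={"accuracy": None},  # placeholder, no real result is implied
    log_text="console output would go here\n",
)
```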

# Imports

In [None]:
!pip install opencv-python-headless
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!wget https://raw.githubusercontent.com/DavidC001/MEMO-TPT-DL2024/main/dataloaders/wordNetIDs2Classes.csv --no-check-certificate

In [None]:

import os
import json
import time
import sys
import torch
import csv
import boto3
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as torch_models

from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torchvision.transforms.v2 import AugMix
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet50, ResNet50_Weights
from copy import deepcopy
from clip import load, tokenize
from PIL import Image
from io import BytesIO
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
from datetime import timedelta


In [None]:
os.makedirs("results", exist_ok=True)
RESULTS_PATH = f"results/results_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
os.makedirs(RESULTS_PATH, exist_ok=True)

# Set the seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

# Parameters

## MEMO
In this section we define the parameters used to run the MEMO tests.  
We define the device MEMO runs on and which tests to run with the following variables: `memo_tests`, `drop_tests`, `ensemble_tests`, `baseline_tests`.  
The parameters `augs_no_selection` and `augs_selection` set how many augmentations are applied without and with top-k (confidence) selection, respectively.

There are then three variables that define specific paths for the tests: 
- `DATASET_ROOT` is the root of the datasets. In this notebook it is ignored, since the data is read from the AWS bucket
- `DATASET_TO_TEST` is the dataset to test: 'a' for ImageNet-A, 'v2' for ImageNetV2, or 'both' for both
- `INITIAL_WEIGHTS` selects the weights used to initialize the ResNet model: 'default' for the default ones or 'v1' for the V1 weights

The dictionary `memo_base_test` is the default configuration for the MEMO tests. It is used as a base for all the tests, and each test overrides only the parameters that need to change. An overview of the parameters is given below:
    
- **memo**: Contains the parameters for the EasyMEMO class
  - **device**: Device to use for the test
  - **prior_strength**: Prior strength for the Batch Normalization layer that we modified
  - **lr**: Learning rate for the optimizer
  - **weight_decay**: Weight decay for the optimizer
  - **opt**: Optimizer to use for the neural network
  - **niter**: Number of optimization steps. The default is 1, which is the value used in every test.
  - **top**: Percentage of the top augmentations to consider (confidence selection)
  - **ensemble**: Whether to run the model in augmentation ensemble mode

- **dataset**: Contains the parameters for the dataset
  - **imageNetA**: Whether to use ImageNet-A or not
  - **naug**: Number of augmentations to perform
  - **dataset_root**: Root of the datasets
  - **aug_type**: Type of augmentation to use between `augmix`, `cut` and `identity`
- **weights**: Weights used to initialize the ResNet model: 'default' for the default ones or 'v1' for the V1 weights
- **run**: Whether to run the test or not
- **drop**: Dropout rate for the simple ensemble model. Setting it to 0 replaces the dropout with an identity layer


In [None]:
memo_device = "cuda" if torch.cuda.is_available() else "cpu"
memo_tests = True
drop_tests = True
ensemble_tests = True
baseline_tests = True
augs_no_selection = 8
augs_selection = 64

# The root of the datasets, in this notebook it is simply ignored, and we use the AWS bucket
DATASET_ROOT = 'datasets'
# Which dataset to test: 'a' for ImageNet-A, 'v2' for ImageNet-V2, 'both' for both
DATASET_TO_TEST = 'both'
# Which weights to use to initialize the ResNet model: 'default' for the default ones, 'v1' for the V1 weights
INITIAL_WEIGHTS = 'default'

memo_base_test = {
    "memo": {
        "device": memo_device,
        "prior_strength": 1.0,
        "lr": 0.005,
        "weight_decay": 0.0001,
        "opt": 'sgd',
        "niter": 1,
        "top": 1,
        "ensemble": False,
    },
    "dataset": {
        "imageNetA": True,
        "naug": 1,
        "dataset_root": DATASET_ROOT,
        "aug_type": "augmix",
    },
    "weights": INITIAL_WEIGHTS,
}
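
Because `memo_base_test` is nested, overriding a single parameter needs a recursive merge rather than a plain `dict.update`. The helper below is a hypothetical sketch of that step (the name `override` and the trimmed configuration are assumptions, not the notebook's own code).

```python
from copy import deepcopy

def override(base, changes):
    """Return a deep copy of `base` with nested keys replaced by `changes`."""
    merged = deepcopy(base)
    for key, value in changes.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = override(merged[key], value)  # merge nested sections
        else:
            merged[key] = value
    return merged

# Trimmed copy of the base configuration, so the sketch runs on its own.
base = {
    "memo": {"lr": 0.005, "niter": 1, "top": 1, "ensemble": False},
    "dataset": {"imageNetA": True, "naug": 1, "aug_type": "augmix"},
    "weights": "default",
}
# Hypothetical test: 64 augmentations with 10% confidence selection.
test = override(base, {"memo": {"top": 0.1}, "dataset": {"naug": 64}})
```

The deep copy keeps the base dictionary untouched, so several tests can be derived from it without interfering with one another.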

## TPT

To make the test suite as self-contained as possible, the tests are defined as a list of dictionaries, where each dictionary contains all the parameters for a single test. The parameters are the same as the ones accessible via the command line in the stand-alone version of the code in the `main.py` file.

***tpt_base_test*** includes the default settings for all the parameters. Each new test inherits from this dictionary and overrides the parameters that need to be changed, so only the parameters that differ from the defaults need to be specified.

- **name**: Name of the test, used to identify the experiment in the results folder and to briefly describe the test
- **dataset**: Dataset to use for the test, can be either 'A' for ImageNet-A or 'V2' for ImageNetV2
- **augs**: Number of augmentations of the original image to use for the test, the default value is 64
- **ttt_steps**: Number of test-time tuning steps to be performed on the prompt
- **align_steps**: Number of alignment steps to be performed on image embedding before classification
- **ensemble**: Whether to run the model in augmentation ensemble mode. When this parameter is set to true, the model skips the prompt tuning
- **test_stop**: Number of test samples to use in this run. The default value -1 means that the whole dataset is used
- **confidence**: Confidence selection threshold for the TPT model, the default value is 0.10 
- **base_prompt**: Base prompt to be used for the TPT model, the default value is 'A photo of a [CLS]' where [CLS] is the placeholder for the class to be predicted. The class token can be placed anywhere in the sentence
- **arch**: Architecture of the backbone for CLIP's visual encoder
- **splt_ctx**: Whether to keep the context vector separate between prompt prefix and suffix or not
- **lr**: Learning rate for the prompt tuning
- **device**: Device to use for the test

For consistency, the parameters we picked for our tests are as close as possible to the ones used in the original paper.
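
To illustrate how the `[CLS]` placeholder in `base_prompt` can be handled at the text level, the sketch below expands a prompt for each class name and splits it into a prefix and a suffix around the placeholder. Both helper names are hypothetical; the actual model tunes context embeddings rather than strings, so this only shows the placeholder convention.

```python
def build_prompts(base_prompt, classnames, placeholder="[CLS]"):
    """Substitute each class name for the placeholder in the base prompt."""
    if placeholder not in base_prompt:
        raise ValueError(f"base prompt must contain {placeholder!r}")
    return [base_prompt.replace(placeholder, name) for name in classnames]

def split_context(base_prompt, placeholder="[CLS]"):
    """Split the prompt into the text before and after the class token."""
    prefix, _, suffix = base_prompt.partition(placeholder)
    return prefix.strip(), suffix.strip()

prompts = build_prompts("A photo of a [CLS]", ["tench", "goldfish"])
# The class token can sit anywhere in the sentence.
prefix, suffix = split_context("A [CLS] in the wild")
```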

In [None]:
VERBOSE = 5

tpt_base_test = {
    "name": "Base",
    "dataset": "A",
    "augs": 64,
    "ttt_steps": 1,
    "align_steps": 0,
    "ensemble": False,
    "test_stop": -1,
    "confidence": 0.10,
    "base_prompt": "A photo of a [CLS]",
    "arch": "RN50",
    "splt_ctx": False,
    "lr": 0.005,
    "device": "cuda:0",
}

# test_stop stops the testing after a certain number of samples;
# to run the entire dataset keep it at -1.
# Tests using V2 with image alignment or prompt tuning are too big to fit into the GPU;
# see EasyTPT/test.py for the configurations used to get the results in the report.
tpt_tests = [
        {
            "name": "TPT_baseline_A",
            "dataset": "A",
            "augs": 1,
            "ensemble": True,
            "confidence": 1,
        },
        {
            "name": "TPT_sel_A",
            "dataset": "A",
        },
        {
            "name": "TPT_ens_nosel_A",
            "dataset": "A",
            "augs": 8,
            "ensemble": True,
            "confidence": 1,
        },
        {
            "name": "TPT_ens_sel_A",
            "dataset": "A",
            "ensemble": True,
            "confidence": 0.10,
        },
        {
            "name": "TPT_align_A",
            "dataset": "A",
            "align_steps": 1,
        },
        {
            "name": "TPT_baseline_V2",
            "dataset": "V2",
            "augs": 1,
            "ensemble": True,
            "confidence": 1,
        },
        {
            "name": "TPT_ens_nosel_V2",
            "dataset": "V2",
            "augs": 8,
            "ensemble": True,
            "confidence": 1,
        },
        {
            "name": "TPT_ens_sel_V2",
            "dataset": "V2",
            "augs": 64,
            "ensemble": True,
            "confidence": 0.1,
        },
    ]
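
Since the TPT configuration is flat, inheriting from the base dictionary reduces to a shallow merge. The sketch below shows one way to resolve each test against the defaults. The helper `expand_tests` and the check for unknown keys are assumptions added for illustration, and the dictionaries are trimmed copies of the ones above.

```python
def expand_tests(base, tests):
    """Fill every test with the defaults from `base`, rejecting unknown keys."""
    expanded = []
    for test in tests:
        unknown = set(test) - set(base)
        if unknown:
            # Catches typos in parameter names before a long run starts.
            raise KeyError(f"unknown parameter(s) in {test.get('name')!r}: {sorted(unknown)}")
        expanded.append({**base, **test})  # test values win over defaults
    return expanded

# Trimmed copies of the notebook's dictionaries, so the sketch runs on its own.
base = {
    "name": "Base", "dataset": "A", "augs": 64, "ttt_steps": 1,
    "align_steps": 0, "ensemble": False, "confidence": 0.10,
}
tests = [
    {"name": "TPT_sel_A", "dataset": "A"},
    {"name": "TPT_ens_nosel_A", "dataset": "A", "augs": 8, "ensemble": True, "confidence": 1},
]
resolved = expand_tests(base, tests)
```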

## Ensemble

Similarly, for the tests to perform with the ensemble, we use a dictionary with the following structure:

```python
{
    "TEST_NAME": {
        "imageNetA" : True, # True to test on ImageNet-A, False to test on ImageNet-V2
        "naug" : 64, # number of augmentations to perform
        "top" : 0.1, # percentage of the top augmentations to consider (confidence selection)
        "niter" : 1, # number of optimization steps
        "testSingleModels" : True, # if we also want to compute the results for the single models
        "simple_ensemble" : True, # if we also want to compute the results for the simple ensemble strategy
        "device" : "cuda", # device to use for the computation
            
        "models_type" : ["memo", "tpt", "..."], # list of models types to use for the ensemble, can be "memo" or "tpt"
        "args" : [ # arguments for each model
            {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RN50"}, # arguments for the first model
            {"device": "cuda", "ttt_steps": 1, "align_steps": 0, "arch": "RN50"}, # arguments for the second model
            "..."
            ],
        "temps" : [1.55, 0.7], # temperature rescaling to use for each model
        "names" : ["MEMO", "TPT"], # names to use for each model
    }
}
```
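
The role of `temps` can be sketched as follows, under the assumption that each model's logits are divided by its temperature before the softmax and that the resulting probabilities are then averaged across models. The function names and the logits are hypothetical; only the temperatures 1.55 and 0.7 come from the configuration above.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax: T > 1 flattens, T < 1 sharpens the distribution."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_probs(logits_per_model, temps):
    """Average the temperature-scaled probabilities of several models."""
    if len(logits_per_model) != len(temps):
        raise ValueError("one temperature is required per model")
    probs = [softmax(logits, t) for logits, t in zip(logits_per_model, temps)]
    n_cls = len(probs[0])
    return [sum(p[c] for p in probs) / len(probs) for c in range(n_cls)]

# Hypothetical logits from two models over 3 classes.
memo_logits = [2.0, 1.0, 0.0]
tpt_logits = [0.5, 1.5, 0.0]
combined = ensemble_probs([memo_logits, tpt_logits], temps=[1.55, 0.7])
sharp = softmax(memo_logits, temperature=0.5)
soft = softmax(memo_logits, temperature=2.0)
```

Rescaling puts models with differently calibrated confidences on a comparable footing before their predictions are combined.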

In [None]:
ENSTests = {
        "ImageNet-A RN50 + RNXT": {
            "imageNetA" : True,
            "naug" : 64,
            "top" : 0.1,
            "niter" : 1,
            "testSingleModels" : True,
            "simple_ensemble" : True,
            "device" : "cuda",
            
            "models_type" : ["memo", "memo"],
            "args" : [
                {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RN50"},
                {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RNXT"}
                ],
            "temps" : [1, 1],
            "names" : ["MEMO RN50", "MEMO RNXT"],
        },

        # THIS TEST IS TOO BIG TO RUN IN THE AWS INSTANCE
        # "ImageNet-V2 TPT RN50 + RNXT": {
        #     "imageNetA" : False,
        #     "naug" : 64,
        #     "top" : 0.2,
        #     "niter" : 1,
        #     "testSingleModels" : True,
        #     "simple_ensemble" : True,
        #     "device" : "cuda",
            
        #     "models_type" : ["memo", "memo"],
        #     "args" : [
        #         {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RN50"},
        #         {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RNXT"}
        #         ],
        #     "temps" : [1, 1],
        #     "names" : ["MEMO RN50", "MEMO RNXT"],
        # },

        "ImageNet-A TPT + MEMO": {
            "imageNetA" : True,
            "naug" : 64,
            "top" : 0.2,
            "niter" : 1,
            "testSingleModels" : True,
            "simple_ensemble" : True,
            "device" : "cuda",
            
            "models_type" : ["memo", "tpt"],
            "args" : [
                {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RN50"},
                {"device": "cuda", "ttt_steps": 1, "align_steps": 0, "arch": "RN50"}
                ],
            "temps" : [1.55, 0.7],
            "names" : ["MEMO", "TPT"],
        },

        # also too big to run in the AWS instance
        # "ImageNet-A RN50 + RNXT + TPT": {
        #     "imageNetA": True,
        #     "naug": 64,
        #     "top": 0.1,
        #     "niter": 1,
        #     "testSingleModels": True,
        #     "simple_ensemble": True,
        #     "device": "cuda",

        #     "models_type": ["memo", "memo", "tpt"],
        #     "args": [
        #         {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RN50"},
        #         {"device": "cuda", "drop": 0, "ttt_steps": 1, "model": "RNXT"},
        #         {"device": "cuda", "ttt_steps": 1, "align_steps": 0, "arch": "RN50"}
        #     ],
        #     "temps": [1, 1, 0.7],
        #     "names": ["MEMO-RN50", "MEMO-RNXT", "TPT"],
        #     "dataset_root": DATASET_ROOT,
        # },
    }

# Data Loading

## Mappings

In [None]:
imagenet_classes = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "green iguana", "Carolina anole", "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", "American alligator", "triceratops", "worm snake", "ring-necked snake", "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod",
"white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann",
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon",
"siamang", "guenon", "patas monkey", "baboon", "macaque", "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", "coffee mug", "coffeemaker", "spiral or coil",
"combination lock", "computer keyboard", "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup",
"medicine cabinet", "megalith", "microphone", "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", "shoji screen / room divider", "shopping basket",
"shopping cart", "shovel", "shower cap", "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes",
"cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"]

## ImageNet-A

In [None]:
class ImageNetA(Dataset):
    """
    A custom dataset class for loading images from the ImageNet-A dataset.

    Args:
        root (str): The root directory of the dataset.
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "wordNetIDs2Classes.csv".
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
    """

    def __init__(
        self, root, csvMapFile="wordNetIDs2Classes.csv", transform=None
    ):
        self.s3_bucket = "deeplearning2024-datasets"
        self.s3_region = "eu-west-1"
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)

        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        #print(objects)
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        mapping = {}
        # Map each numeric WordNet ID to its ResNet label and class name (header row skipped)
        with open(csvMapFile, "r", newline="") as f:
            for id, wordnet, name in csv.reader(f):
                if id == "resnet_label":
                    continue
                mapping[int(wordnet)] = {"id": id, "name": name}

        # print(mapping)
        self.classnames = {}

        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = int(path.parent.name[1:])
            name = mapping[label]["name"]
            self.classnames[mapping[label]["id"]] = name
            label = int(mapping[label]["id"])


            # Keep track of valid instances
            self.instances.append((label, name, key))

        self.transform = transform

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, name, key = self.instances[idx]
            # Download image from S3
            # response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=key)
            # img_bytes = response["Body"]._raw_stream.data

            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # Ensure the BytesIO object is at the start
            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            # Chain the original exception so the underlying S3/PIL error stays visible
            raise RuntimeError(f"Error loading image at index {idx}: {e}") from e

        return {"img": img, "label": label, "name": name}
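
The label lookup performed in `__init__` can be illustrated in isolation. The sketch below is a dependency-free example under stated assumptions: the object key and the CSV rows are made up for illustration (only the `resnet_label` header check and the three-column layout come from the code above), and `load_mapping` and `parse_label` are hypothetical helpers, not part of the notebook.

```python
import csv
from io import StringIO
from pathlib import PurePosixPath

# Hypothetical CSV content with the three columns read by the loaders:
# ResNet label, numeric WordNet ID, class name.
CSV_TEXT = "resnet_label,wordnet_id,name\n6,01498041,stingray\n11,01531178,goldfinch\n"


def load_mapping(text):
    """Map the numeric WordNet ID to its ResNet label and class name."""
    mapping = {}
    for resnet_id, wordnet, name in csv.reader(StringIO(text)):
        if resnet_id == "resnet_label":  # skip the header row
            continue
        mapping[int(wordnet)] = {"id": int(resnet_id), "name": name}
    return mapping


def parse_label(key, mapping):
    """Derive (label, name) from an object key shaped like '<prefix>/n<wordnet>/<file>'."""
    folder = PurePosixPath(key).parent.name  # e.g. 'n01498041'
    entry = mapping[int(folder[1:])]         # drop the leading 'n'
    return entry["id"], entry["name"]


mapping = load_mapping(CSV_TEXT)
label, name = parse_label("imagenet-a/n01498041/example.jpg", mapping)
```

Because the folder name is converted with `int(...)`, leading zeros in the WordNet ID do not matter as long as the CSV column is parsed the same way.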

## ImageNet-V2

In [None]:
class ImageNetV2(Dataset):
    """
    A custom dataset class for loading images from the ImageNet-V2 dataset.

    Args:
        root (str): The root directory of the dataset.
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "wordNetIDs2Classes.csv".
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
    """

    def __init__(
        self, root, csvMapFile="wordNetIDs2Classes.csv", transform=None
    ):
        self.s3_bucket = "deeplearning2024-datasets"
        self.s3_region = "eu-west-1"
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)

        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        mapping = {}
        # Map each ResNet label to its class name (header row skipped)
        with open(csvMapFile, "r", newline="") as f:
            for id, _, name in csv.reader(f):
                if id == "resnet_label":
                    continue
                mapping[id] = name

        self.classnames = {}
        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = path.parent.name
            name = mapping[label]
            self.classnames[label] = name

            label = int(label)

            # Keep track of valid instances
            self.instances.append((label, name, key))

        self.transform = transform

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, name, key = self.instances[idx]
            # Download image from S3
            # response = self.s3_client.get_object(Bucket=self.s3_bucket, Key=key)
            # img_bytes = response["Body"]._raw_stream.data

            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # Ensure the BytesIO object is at the start
            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            # Chain the original exception so the underlying S3/PIL error stays visible
            raise RuntimeError(f"Error loading image at index {idx}: {e}") from e

        return {"img": img, "label": label, "name": name}

## EasyAugmenter

Augmentations are a core part of both MEMO and TPT: each technique applies multiple augmentations to the same image and then adapts the model, in different ways, using its average output across those augmentations.

The structure of EasyAugmenter is similar to the one used in the official TPT repository, with some adaptations that make it flexible enough for the different techniques it has to support. At its core it takes two predefined sets of transformations: **preprocess** and **base_transform**.

The **preprocess** transformations are applied as the last step in the chain. They are tied to the pretrained model used as the backbone: they contain the normalization and conversion steps needed to make the image compatible with the model's input format and distribution.

The **base_transform** transformations are a set of simple transformations, such as resize and crop, that produce the original (non-augmented) view.

Finally, three augmentation modes can be selected: `augmix`, `cut` and `identity`. `identity`, as the name suggests, applies no augmentation beyond the base transform, so the result is a set of identical images; we use this mode when MEMO runs in dropout mode. `cut` applies a random resized crop and a random horizontal flip, while `augmix` applies augmentations randomly selected from the predefined set of transformations in the AugMix module.

In every case the result is a list of size **n_views + 1**, with the original (transformed) image in first place followed by the augmented samples.
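
The contract described above (original view first, `n_views` augmented views after it, `preprocess` always applied last) can be sketched without any image library. This is only an illustration: the "image" is a plain list of numbers, and `make_views`, `random_crop` and the toy transforms are hypothetical stand-ins, not the class defined in the next cell.

```python
import random


def make_views(image, base_transform, preaugment, preprocess, n_views):
    """Return [original, view_1, ..., view_n]; `preprocess` is always the last step."""
    original = preprocess(base_transform(image))
    views = [preprocess(preaugment(image)) for _ in range(n_views)]
    return [original] + views


rng = random.Random(0)


def base_transform(img):
    """Deterministic stand-in for resize + center crop."""
    return img[:4]


def random_crop(img):
    """Stand-in for the 'cut' mode: a random window of 4 values."""
    start = rng.randrange(len(img) - 3)
    return img[start:start + 4]


def preprocess(img):
    """Stand-in for tensor conversion and normalization."""
    return [p / 255 for p in img]


image = list(range(0, 80, 10))  # 8 "pixels"
batch = make_views(image, base_transform, random_crop, preprocess, n_views=3)
```

Passing `base_transform` as `preaugment` reproduces the `identity` mode: every element of the list is then the same view.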

In [None]:
class EasyAgumenter(object):
    def __init__(self, base_transform, preprocess, augmentation, n_views=63):
        """
        This class provides an easy way to apply custom augmentations to images. When called,
        it returns a list of augmented views with the original image in first place.

        Args:

        - base_transform (torchvision.transforms.Compose): The base transformation to apply to the images.
        - preprocess (torchvision.transforms.Compose): The preprocessing transformation to apply to the images (will be applied last).
        - augmentation (str): The type of augmentation to apply, can be 'augmix', 'identity' or 'cut'.
        - n_views (int): The number of augmentations to apply to the image.
        
        Returns:
        - (list) A list of images with the augmentations applied.
        """
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views

        if augmentation == 'augmix':

            self.preaugment = transforms.Compose(
                [
                    AugMix(),
                    transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
                    transforms.CenterCrop(224),
                ]
            )
        elif augmentation == 'identity':
            self.preaugment = self.base_transform
        elif augmentation == 'cut':
            self.preaugment = transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                ]
            )
        else:
            raise ValueError('Augmentation type not recognized')
    
    def __call__(self, x):

        if isinstance(x, np.ndarray):
            x = transforms.ToPILImage()(x)

        image = self.preprocess(self.base_transform(x))

        views = [self.preprocess(self.preaugment(x)) for _ in range(self.n_views)]

        return [image] + views

## Functions

In [None]:
def get_dataloaders(root, transform=None, csvMapFile="wordNetIDs2Classes.csv"):
    """
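    Returns the ImageNet-A and ImageNet-V2 datasets.

    Args:
        root (str): The root directory of the dataset. Currently unused: the S3 prefixes are hardcoded below.
        transform (callable, optional): A function/transform that takes in an image and returns a transformed version. Defaults to None.
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "wordNetIDs2Classes.csv".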
    """
    root_A = "imagenet-a"
    imageNet_A = ImageNetA(root_A, transform=transform, csvMapFile=csvMapFile)
    root_V2 = "imagenetv2-matched-frequency-format-val"
    imageNet_V2 = ImageNetV2(root_V2, transform=transform, csvMapFile=csvMapFile)

    return imageNet_A, imageNet_V2

def get_classes_names(csvMapFile="wordNetIDs2Classes.csv"):
    """
    Returns the class names of the dataset.

    Args:
        csvMapFile (str, optional): The path to the CSV file containing the mapping of WordNet IDs to class names. Defaults to "wordNetIDs2Classes.csv".
    """
    names = [""]*1000
    # Index the class names by their ResNet label (header row skipped)
    with open(csvMapFile, 'r', newline="") as f:
        for id, wordnet, name in csv.reader(f):
            if id == 'resnet_label':
                continue
            names[int(id)] = name
    
    return names

def memo_get_datasets(augmentation, augs=64, dataset_root="datasets"):
    """
    Returns the ImageNetA and ImageNetV2 datasets for the memo model
    Args:
        dataset_root: the root folder of all the datasets
        augmentation (str): What type of augmentation to use in EasyAugmenter. Can be 'augmix', 'identity' or 'cut'
        augs (int): The total number of views to compute, including the original image. Must be greater than 0

    Returns: The ImageNetA and ImageNetV2 datasets for the memo model, with the Augmentations already applied

    """
    assert augs > 0, 'The number of augmentations must be greater than 0'
    memo_transforms = transforms.Compose([transforms.Resize(256),
                                          transforms.CenterCrop(224)])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    transform = EasyAgumenter(memo_transforms, preprocess, augmentation, augs - 1)
    imageNet_A, imageNet_V2 = get_dataloaders(dataset_root, transform)
    return imageNet_A, imageNet_V2


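def tpt_get_transforms(augs=64, augmix=False):
    """
    Returns the EasyAgumenter transform used by TPT.

    Parameters:
    - augs (int): The total number of views, including the original image.
    - augmix (bool): Whether to use AugMix ('augmix') or random crops ('cut').
    """
    base_transform = transforms.Compose(
        [
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
        ]
    )

    # CLIP normalization statistics
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ]
    )

    # `augmentation` is a required argument of EasyAgumenter
    data_transform = EasyAgumenter(
        base_transform,
        preprocess,
        augmentation=("augmix" if augmix else "cut"),
        n_views=augs - 1,
    )

    return data_transform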


def tpt_get_datasets(data_root, augmix=False, augs=64, all_classes=True):
    """
    Returns the ImageNetA and ImageNetV2 datasets.

    Parameters:
    - data_root (str): The root directory of the datasets. Currently unused: the S3 prefixes are hardcoded below.
    - augmix (bool): Whether to use AugMix or not.
    - augs (int): The number of augmentations to use.
    - all_classes (bool): Whether to use all classes or not.

    Returns:
    - imageNet_A (ImageNetA): The ImageNetA dataset.
    - ima_names (list): The original classnames in ImageNetA.
    - ima_custom_names (list): The retouched  classnames in ImageNetA.
    - ima_id_mapping (list): The mapping between the index of the classname and the ImageNet label

    same for ImageNetV2

    For instance the first element of ima_names corresponds to the label '90'.  After running the
    inference run the predicted output through the ima_id_mapping to recover the correct class label.

    out = tpt(inputs)
    pred = out.argmax().item()
    out_id = ima_id_mapping[pred]

    """
    base_transform = transforms.Compose(
        [
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
        ]
    )

    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711],
            ),
        ]
    )

    data_transform = EasyAgumenter(
        base_transform,
        preprocess,
        augmentation=("augmix" if augmix else "cut"),
        n_views=augs - 1,
    )

    imageNet_A = ImageNetA(
        "imagenet-a", transform=data_transform
    )
    imageNet_V2 = ImageNetV2(
        "imagenetv2-matched-frequency-format-val",
        transform=data_transform,
    )

    imv2_label_mapping = list(imageNet_V2.classnames.keys())
    imv2_names = list(imageNet_V2.classnames.values())
    imv2_custom_names = [imagenet_classes[int(i)] for i in imv2_label_mapping]

    ima_label_mapping = list(imageNet_A.classnames.keys())
    ima_names = list(imageNet_A.classnames.values())
    ima_custom_names = [imagenet_classes[int(i)] for i in ima_label_mapping]

    if all_classes:
        ima_names += [name for name in imv2_names if name not in ima_names]
        ima_custom_names += [
            name for name in imv2_custom_names if name not in ima_custom_names
        ]
        ima_label_mapping += [
            label for label in imv2_label_mapping if label not in ima_label_mapping
        ]

    return (
        imageNet_A,
        ima_names,
        ima_custom_names,
        ima_label_mapping,
        imageNet_V2,
        imv2_names,
        imv2_custom_names,
        imv2_label_mapping,
    )
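
The docstring above explains that a predicted index must be run through the label mapping to recover the ImageNet label. A minimal sketch of that step follows, with made-up class tables; `to_imagenet_label` is a hypothetical helper. It assumes, as in the loaders above, that the `classnames` keys are strings, so an `int(...)` conversion is needed before comparing with the integer `label` returned by the datasets.

```python
# Hypothetical class tables, keyed by ImageNet label (stored as strings, as in the loaders).
ima_classnames = {"90": "lorikeet", "6": "stingray"}
imv2_classnames = {"6": "stingray", "11": "goldfinch"}

ima_label_mapping = list(ima_classnames.keys())
# all_classes=True: append the labels that only the second dataset provides.
ima_label_mapping += [lab for lab in imv2_classnames if lab not in ima_label_mapping]


def to_imagenet_label(pred_index, label_mapping):
    """Translate an argmax index over the prompt list into an integer ImageNet label."""
    return int(label_mapping[pred_index])


first = to_imagenet_label(0, ima_label_mapping)
last = to_imagenet_label(2, ima_label_mapping)
```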


# Models

## EasyModel
This is the parent class for all the models used in the experiments. It serves as a template that defines the structure of the models and the methods needed to run the tests. The class is composed of the following methods:
- **select_confident_samples**: This method is used to select the top fraction of most confident (lowest-entropy) samples from the output of the model
- **avg_entropy**: This method is used to compute the entropy of the model's average (marginal) output distribution across the augmentations
- **forward**: The forward pass of the model, it takes the input image and returns the output of the model
- **predict**: The predict method is used to get the prediction of the model on the input image
- **reset**: This method is used to reset the model to the initial state

Each model that inherits from EasyModel should override these methods to adapt them to its specific needs.

In [None]:
class EasyModel(nn.Module):
    def __init__(self):
        super(EasyModel, self).__init__()

    def select_confident_samples(self, logits, top):
        """
        Performs confidence selection, will return the filtered logits as
        well as the indexes of the augmentations with the highest confidence
        (lowest entropy)

        Parameters:
        - logits (torch.Tensor): the logits of the model [NAUGS, NCLASSES]
        - top (float): the fraction of most confident augmentations to keep, in (0, 1]

        Returns:
        - logits (torch.Tensor): the filtered logits of the model [NAUGS*top, NCLASSES]
        - idx (torch.Tensor): the indexes of the selected augmentations [NAUGS*top]
        """
        batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
        idx = torch.argsort(batch_entropy, descending=False)[
            : int(batch_entropy.size()[0] * top)
        ]
        return logits[idx], idx
    
    def avg_entropy(self, outputs):
        """
        Computes the entropy of the marginal (averaged) output distribution

        Parameters:
        - outputs (torch.Tensor): the logits of the model [NAUGS, NCLASSES]
        
        Returns:
        - avg_entropy (torch.Tensor): the entropy of the averaged prediction (scalar)
        """
        logits = outputs - outputs.logsumexp(
            dim=-1, keepdim=True
        )  # per-view log-probabilities, i.e. outputs.log_softmax(dim=-1) [N, NCLASSES]
        avg_logits = logits.logsumexp(dim=0) - np.log(
            logits.shape[0]
        )  # log of the mean probability over the views [NCLASSES]
        min_real = torch.finfo(avg_logits.dtype).min
        avg_logits = torch.clamp(avg_logits, min=min_real)
        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

    def forward(self, x):
        return super(EasyModel, self).forward(x)
        
    def predict(self, x):
        raise NotImplementedError
    
    def reset(self):
        raise NotImplementedError
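
To make the quantity computed by `avg_entropy` concrete, here is a plain-Python sketch of the same computation: per-view log-softmax, then the log of the mean probability via log-sum-exp, then the entropy of that averaged distribution. It is an illustrative reimplementation with hypothetical helper names, not the method above, and it skips the clamping step.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    lse = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - lse for v in row]


def marginal_entropy(logits):
    """Entropy of the mean softmax distribution over N views (logits: N x C)."""
    log_probs = [log_softmax(row) for row in logits]
    n, c = len(log_probs), len(log_probs[0])
    entropy = 0.0
    for j in range(c):
        col = [log_probs[i][j] for i in range(n)]
        m = max(col)
        # log of the mean probability of class j: logsumexp over views minus log N
        log_mean = m + math.log(sum(math.exp(v - m) for v in col)) - math.log(n)
        entropy -= math.exp(log_mean) * log_mean
    return entropy


agree = [[5.0, 0.0, 0.0], [5.0, 0.0, 0.0]]     # both views pick class 0
disagree = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0]]  # confident views that contradict each other
```

The second case shows why the marginal is used: each view is individually confident, but the averaged distribution is spread over two classes, so its entropy is high.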

## MEMO
Here we see the implementation of our `EasyMemo` class.
`EasyMemo` inherits from `EasyModel` and implements the functions specific to the method described in the paper "MEMO: Test Time Robustness via Adaptation and Augmentation".
Our implementation is cleaner, since everything is encapsulated in the `EasyMemo` class.
We implemented a simple and modular way to predict a sample: we wrap the chosen model, in our case a `ResNet50` network, and implement the **inference**, **forward** and **predict** steps so that we can easily switch between the different modes of operation.
### MEMO Test Time Robustness via Adaptation and Augmentation
In the MEMO paper, the authors use a modified version of the batch normalization layer, in which the running mean and variance from training are combined at each iteration with the mean and variance estimated from the current batch of augmented samples. 
This is done to adapt the model to the current data distribution. The prior strength controls the influence of the current batch on the statistics: the closer it is to 1, the more weight the original running statistics receive. 
The authors show that this method is effective in improving the robustness of the model to corruptions and other distribution shifts.

To keep the original mean and variance, we simply set the prior strength to 1, and the model uses the original values unmodified.
#### Implementation details
The implementation is quite simple: we took the modified batch normalization layer and applied it to the model. We then pass the sample to the predict function, which performs backpropagation on the single test sample while the batch normalization layers mix their running statistics with those of the current batch according to the chosen prior strength (e.g. 0.94 for the MEMO tests). 
The values used are those from the paper.
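
The mixing rule applied by the modified layer is a convex combination of the stored statistics and the current batch estimate. The plain-Python sketch below illustrates it on made-up two-channel statistics; `mix_bn_statistics` is a hypothetical helper. Reading 0.94 as N / (N + 1) with N = 16 is our assumption about where the value comes from.

```python
def mix_bn_statistics(running_mean, running_var, batch_mean, batch_var, prior):
    """Blend stored BN statistics with the current batch estimate, channel by channel."""
    mean = [prior * r + (1 - prior) * b for r, b in zip(running_mean, batch_mean)]
    var = [prior * r + (1 - prior) * b for r, b in zip(running_var, batch_var)]
    return mean, var


# Hypothetical two-channel statistics.
running_mean, running_var = [0.0, 1.0], [1.0, 4.0]
batch_mean, batch_var = [1.7, -0.7], [2.7, 0.6]

prior = 16 / 17  # roughly 0.94 (assumed to be N / (N + 1) with N = 16)
mixed_mean, mixed_var = mix_bn_statistics(running_mean, running_var, batch_mean, batch_var, prior)

# prior = 1.0 leaves the stored statistics untouched
unchanged = mix_bn_statistics(running_mean, running_var, batch_mean, batch_var, prior=1.0)
```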

### Confidence selection
Confidence selection is a method for picking the most confident samples from the output of the model.
The method is quite simple: we compute the entropy of the model output for each augmentation and keep only the fraction of augmentations with the lowest entropy.
In this way, only the most confident views contribute to the final prediction.

#### Implementation details
The **select_confident_samples** function sorts the augmentations by the entropy of their output and returns the logits and indexes of the top fraction.
The predict call is then based only on the logits of the selected views.
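
As a small worked example, the selection can be reproduced in plain Python on four made-up views. This is a sketch mirroring the logic described above, with hypothetical helper names, rather than the method itself.

```python
import math


def entropy(logits_row):
    """Shannon entropy of softmax(logits_row)."""
    m = max(logits_row)
    exps = [math.exp(v - m) for v in logits_row]
    total = sum(exps)
    return -sum((e / total) * math.log(e / total) for e in exps)


def select_confident(logits, top):
    """Keep the `top` fraction of views with the lowest entropy."""
    ranked = sorted(range(len(logits)), key=lambda i: entropy(logits[i]))
    keep = ranked[: int(len(logits) * top)]
    return [logits[i] for i in keep], keep


views = [
    [0.0, 0.0, 0.0],  # uniform: least confident
    [4.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 8.0, 0.0],  # sharpest: most confident
]
selected, idx = select_confident(views, top=0.5)
```

Note that the number of kept views is truncated with `int(...)`, so with few views and a small fraction (for example 4 views and `top=0.1`) the selection is empty; the fraction should be chosen with the number of augmentations in mind.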

### Ensemble
To cut inference time, we can skip the adaptation step altogether: the ensemble method runs a batch of augmentations of the same image through the model and averages the predictions to get the final prediction.
It can be combined with confidence selection to obtain better results.
#### Augmentation ensemble (cut)
This ensemble uses the random cut augmentation, and it turns out to be quite effective because each prediction focuses on a different area of the image.
Samples in the ImageNet-A dataset often contain a lot of noise, with only a portion of the image relevant to the class.
By using different cuts, we test different areas of the image and further improve the ability of the model to make a confident guess.

##### Implementation details
We use the `EasyAugmenter` class to apply the random cut augmentation to the image, and then we pass the augmented images to the model to get the logits. 
The per-view probabilities are obtained with a standard softmax and summed across the views.
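
The marginalization step can be sketched in a few lines of plain Python. The logits below are made up, and `ensemble_predict` is a hypothetical helper that averages the per-view softmax outputs and takes the argmax, which selects the same class as summing them.

```python
import math


def softmax(row):
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(logits):
    """Average the per-view probabilities and return (predicted class, marginal)."""
    probs = [softmax(row) for row in logits]
    marginal = [sum(col) / len(probs) for col in zip(*probs)]
    prediction = max(range(len(marginal)), key=marginal.__getitem__)
    return prediction, marginal


# One confident view for class 0, two hesitant views leaning toward class 1.
views = [[2.0, 0.0], [0.0, 0.5], [0.0, 0.5]]
prediction, marginal = ensemble_predict(views)
```

In this example a majority vote over the per-view argmax would pick class 1, while the marginal picks class 0: averaging probabilities lets a confident view outweigh several uncertain ones.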

#### Dropout ensemble
If an image belongs to a given class, most of the extracted features should point toward it, but noise from other features may influence the result or even strongly polarize the classification.
With dropout we stochastically zero out some features, hoping that most of the informative ones are kept while the misleading ones are removed. 

##### Implementation details
With the dropout ensemble we add a dropout layer at the end of the residual blocks of the network, before the average pooling layer. 
The printout below shows where the dropout layer sits: we appended it after the last sequential block because that was the quickest option to implement. 
```
(...(...(...
        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (dropout): Dropout(p=0, inplace=True)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
    (fc): Linear(in_features=2048, out_features=1000, bias=True)
  )
```
Passing a batch of identical images to the model then yields a different output for each copy, and we average these outputs to get the final prediction.
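
The effect of averaging over random dropout masks can be illustrated with a toy, dependency-free sketch. It is a simplification under stated assumptions: dropout is applied to a flat feature vector in front of a made-up linear classifier (in the network above it acts on the feature maps before pooling), and all names and values are hypothetical.

```python
import math
import random


def softmax(row):
    """Numerically stable softmax of one row of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def dropout(features, p, rng):
    """Inverted dropout: zero each feature with probability p, rescale survivors by 1/(1-p)."""
    if p == 0:
        return list(features)
    return [0.0 if rng.random() < p else f / (1 - p) for f in features]


def dropout_ensemble(features, weights, p, n_passes, seed=0):
    """Average the class probabilities over `n_passes` random dropout masks."""
    rng = random.Random(seed)
    marginal = [0.0] * len(weights)
    for _ in range(n_passes):
        dropped = dropout(features, p, rng)
        logits = [sum(w * f for w, f in zip(row, dropped)) for row in weights]
        marginal = [m + q / n_passes for m, q in zip(marginal, softmax(logits))]
    return max(range(len(marginal)), key=marginal.__getitem__), marginal


features = [3.0, 2.5, 0.2, 0.1]                         # two strong and two weak features
weights = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]  # class 0 reads the strong ones
pred, marginal = dropout_ensemble(features, weights, p=0.5, n_passes=32)
```

With `p=0` every pass is identical, which matches the `Dropout(p=0, ...)` shown in the printout: the layer only has an effect once a non-zero probability is set.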

In [None]:
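def _modified_bn_forward(self, input):
    # Estimate the statistics of the current batch: with training=True and momentum=1.0,
    # batch_norm overwrites est_mean and est_var in place with the batch mean and variance
    est_mean = torch.zeros(self.running_mean.shape, device=self.running_mean.device)
    est_var = torch.ones(self.running_var.shape, device=self.running_var.device)
    nn.functional.batch_norm(input, est_mean, est_var, None, None, True, 1.0, self.eps)
    # Mix the stored running statistics with the batch estimate according to the prior strength
    running_mean = self.prior * self.running_mean + (1 - self.prior) * est_mean
    running_var = self.prior * self.running_var + (1 - self.prior) * est_var
    # Normalize with the mixed statistics (training=False, so nothing is updated in place)
    return nn.functional.batch_norm(input, running_mean, running_var, self.weight, self.bias, False, 0, self.eps)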


class EasyMemo(EasyModel):
    """
    A class to wrap a neural network with the MEMO TTA method
    """

    def __init__(self, net, device, classes_mask, prior_strength: float = 1.0, lr=0.005, weight_decay=0.0001, opt='sgd',
                 niter=1, top=0.1, ensemble=False):
        """
        Initializes the EasyMemo model with various arguments
        Args:
            net: The model to wrap with EasyMemo
            device: The device to run the model on (usually 'cpu' or 'cuda')
            classes_mask: The indices of the classes to consider for the model (used for ImageNet-A)
            prior_strength: The strength of the prior to use in the modified BN forward pass
            lr: The learning rate for the optimizer of the model
            weight_decay: The weight decay for the optimizer of the model
            opt: Which optimizer to use for this model, either 'sgd' or 'adamw'
            niter: The number of iterations to run the memo pass for
            top: The fraction of the most confident augmentations to keep for confidence selection
            ensemble: Whether to use the ensemble method or not
        """
        super(EasyMemo, self).__init__()

        self.ens = ensemble
        self.device = device
        self.prior_strength = prior_strength
        self.net = net.to(device)
        self.optimizer = self.get_optimizer(lr=lr, weight_decay=weight_decay, opt=opt)
        self.lr = lr
        self.weight_decay = weight_decay
        self.opt = opt
        self.confidence_idx = None
        self.memo_modify_bn_pass()
        self.criterion = self.avg_entropy
        self.niter = niter
        self.top = top
        self.initial_state = deepcopy(self.net.state_dict())
        self.classes_mask = classes_mask

    def forward(self, x, top=-1):
        """
        Forward pass where we check which type of input we have and we call the inference on the input image Tensor
        Args:
            top: How many samples to select from the batch
            x: A Tensor of shape (N, C, H, W) or a list of Tensors of shape (N, C, H, W)

        Returns: The logits after the inference pass

        """
        self.top = top if top > 0 else self.top
        # print(f"Shape forward: {x.shape}")
        if isinstance(x, list):
            x = torch.stack(x).to(self.device)
            # print(f"Shape forward: {x.shape}")
            logits = self.inference(x)
            logits, self.confidence_idx = self.select_confident_samples(logits, self.top)
        else:
            if len(x.shape) == 3:
                x = x.unsqueeze(0)
            x = x.to(self.device)
            logits = self.inference(x)

        # print(f"[EasyMemo] input shape: {x.shape}")
        # print(f"[EasyMemo] logits shape: {logits.shape}")
        return logits

    def inference(self, x):
        """
        Return the logits of the image in input x
        Args:
            x: A Tensor of shape (N, C, H, W) of an Image

        Returns: The logits for that Tensor image

        """
        if self.ens:
            self.net.train()
        else:
            self.net.eval()
        outputs = self.net(x)

        # Keep only the logits of the classes listed in the mask
        return outputs[:, self.classes_mask]

    def predict(self, x, niter=1):
        """
        Predicts the class of the input x, which is an image
        Args:
            niter: The number of iteration on which to run the memo pass
            x: Tensor of shape (N, C, H, W)

        Returns: The predicted classes

        """
        self.niter = niter
        nn.BatchNorm2d.prior = self.prior_strength

        if self.ens:
            self.net.train()
            predicted = self.ensemble(x)
        else:
            self.net.eval()
            for iteration in range(self.niter):
                self.optimizer.zero_grad()
                outputs = self.forward(x)
                loss = self.criterion(outputs)
                loss.backward()
                self.optimizer.step()

            with torch.no_grad():
                outputs = self.net(x[0].unsqueeze(0).to(self.device))
                # Keep only the logits of the classes listed in the mask
                outs = outputs[:, self.classes_mask]
                predicted = outs.argmax(1).item()

        nn.BatchNorm2d.prior = 1.0
        return predicted

    def reset(self):
        """Resets the model to its initial state"""
        del self.optimizer
        self.optimizer = self.get_optimizer(lr=self.lr, weight_decay=self.weight_decay, opt=self.opt)
        self.confidence_idx = None
        self.net.load_state_dict(deepcopy(self.initial_state))

    def memo_modify_bn_pass(self):
        print('modifying BN forward pass')
        nn.BatchNorm2d.prior = 1.0
        nn.BatchNorm2d.forward = _modified_bn_forward

    def get_optimizer(self, lr=0.005, weight_decay=0.0001, opt='sgd'):
        """
        Initializes the optimizer for the memo model
        Args:
            lr: The learning rate for the optimizer
            weight_decay: The weight decay for the optimizer
            opt: Which optimizer to use

        Returns: The optimizer for the memo model

        """
        if opt == 'sgd':
            optimizer = optim.SGD(self.net.parameters(), lr=lr, weight_decay=weight_decay)
        elif opt == 'adamw':
            optimizer = optim.AdamW(self.net.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError('Invalid optimizer selected')
        return optimizer

    def ensemble(self, x):
        with torch.no_grad():
            outputs = self.forward(x)
            outputs = nn.functional.softmax(outputs, dim=1)
            prediction = outputs.sum(0).argmax().item()

        return prediction

## TPT

This is the core of our implementation of Test-Time Prompt Tuning. 

`EasyTPT` inherits from `EasyModel` and acts as a wrapper around `EasyPromptLearner`, while also implementing the functions specific to each mode of operation. The implementation is based on the paper "Test-Time Prompt Tuning for Zero-Shot Generalization in Vision-Language Models" and is inspired by the authors' original code; our adaptation is more polished and offers two additional modes of operation and some improvements.

### Alignment Mode
This is the first of the two additional modes of operation; it is triggered by setting `align_steps` to a value greater than 0. The idea is to pull the image embeddings closer together by acting directly on CLIP's image encoder, as this could lower the variance between the augmented images and therefore make the classification more stable. The alignment is performed by maximizing the cosine similarity between the embeddings of the set of augmented images. Similarly to the prompt learner, before running the actual classification step we pass the augmented images through the image encoder and compute the cosine similarity between their embeddings; the error is then backpropagated through the image encoder and its weights are updated, with the attention weights excluded from the optimization. This procedure is repeated `align_steps` times before the classification step.

This approach can be seen as a form of contrastive learning, in which we pull the embeddings of the augmented images closer to each other by comparing them pairwise within the set. The results, however, are not as good as those obtained with the prompt learner: in our tests the alignment step brought little benefit, and it is computationally expensive because it requires a backward pass through the image encoder. It is also memory hungry, as the encoder's computational graph must be built too.
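
The objective described above can be sketched as a loss over pairwise cosine similarities. The exact form used in the implementation is not shown in this section, so the version below (one minus the mean pairwise similarity) is an assumption chosen for illustration, written in plain Python on made-up 2-D embeddings.

```python
import math
from itertools import combinations


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def alignment_loss(embeddings):
    """1 - mean pairwise cosine similarity: 0 when all embeddings point the same way."""
    pairs = list(combinations(embeddings, 2))
    mean_similarity = sum(cosine_similarity(u, v) for u, v in pairs) / len(pairs)
    return 1.0 - mean_similarity


aligned = [[1.0, 0.0], [2.0, 0.0], [0.5, 0.0]]  # same direction, different norms
spread = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]  # orthogonal and opposite directions
```

Minimizing this loss is equivalent to maximizing the mean similarity; since cosine similarity ignores vector norms, only the directions of the embeddings are pulled together.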

#### Implementation Details

In test-time adaptation the model is supposed to be unaware of the target domain, so no information about past data should be available. For this reason, after each prediction we must reset the model to its original state. This includes both the model parameters and the optimizer state: both are saved at model initialization and reloaded every time the model is reset.
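
The save-and-restore pattern can be illustrated with a toy class in plain Python. `ResettableModel` and its dictionaries are hypothetical stand-ins for the real parameters and optimizer state; the point of the sketch is that the snapshots must be deep copies, otherwise in-place updates during adaptation would modify them too.

```python
from copy import deepcopy


class ResettableModel:
    """Toy model that snapshots its state at construction time."""

    def __init__(self, params, optimizer_state):
        self.params = params
        self.optimizer_state = optimizer_state
        # Deep copies, so later in-place updates cannot touch the snapshots.
        self._initial_params = deepcopy(params)
        self._initial_optimizer_state = deepcopy(optimizer_state)

    def adapt(self, grads, lr=0.1):
        """One SGD-like step that mutates both parameters and optimizer state."""
        for name, grad in grads.items():
            self.params[name] -= lr * grad
        self.optimizer_state["steps"] += 1

    def reset(self):
        """Restore the state saved at initialization."""
        self.params = deepcopy(self._initial_params)
        self.optimizer_state = deepcopy(self._initial_optimizer_state)


model = ResettableModel({"w": 1.0, "b": 0.0}, {"steps": 0})
model.adapt({"w": 0.5, "b": -0.2})
adapted = dict(model.params)  # state after adapting to one sample
model.reset()                 # the next sample starts from the original state
```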

To perform alignment we are forced to use `clip.float()` to convert the model to float precision, this is due to the fact that when fine-tuning CLIP the use of different precision causes the gradients to show NaN values, it's not clear why this happens, but it could be due to numerical incompatibilities and numerical instability. The side effect of this is that the model is now using more memory and is slower to run, this is a trade-off that we have to accept to be able to perform alignment.

Alignment is also subject to confidence selection: just as with prompt tuning, we are only interested in the most confident predictions.
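
A minimal sketch of confidence selection, assuming it keeps the fraction `top` of augmentations whose predictions have the lowest entropy. The helper names, the `max(1, ...)` guard and the toy logits are assumptions made for illustration; the sketch uses plain Python so that it runs without dependencies.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def select_confident(batch_logits, top):
    """Indices of the `top` fraction of rows with the lowest entropy."""
    ents = [entropy(softmax(row)) for row in batch_logits]
    k = max(1, int(len(batch_logits) * top))  # assumption: keep at least one
    order = sorted(range(len(ents)), key=lambda i: ents[i])
    return order[:k]


# toy logits (hypothetical), one row per augmentation
augs = [
    [5.0, 0.0, 0.0],  # very confident
    [1.0, 1.0, 1.0],  # uniform, maximally uncertain
    [3.0, 0.0, 0.0],  # fairly confident
    [0.5, 0.4, 0.3],  # nearly uniform
]
selected = select_confident(augs, top=0.5)
```

With `top=0.5` the two sharpest rows are kept and the near-uniform ones are discarded.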

### Ensemble Mode
This is the second of the two additional modes of operation; it is enabled by setting `ensemble` to *True*. The idea is strikingly simple: instead of using the entropy of the predictions to tune the prompt so that the predictions become more confident and, hopefully, more accurate, we skip the prompt tuning step altogether and pick the class with the highest marginal probability. In a sense, we cut out the middle man by directly taking the most confident marginalized prediction.

This method can be seen as an ensemble of models: instead of running the same image through multiple models and averaging the predictions, we run multiple augmentations of the same image through the same model and average the predictions. Compared to prompt tuning it is extremely efficient in terms of both computational cost and memory usage, as we only need to run a few forward passes and no learning is involved.

What's even more surprising is that this method is extremely effective, not only scoring better than the baseline, but outperforming the prompt tuning method too.

#### Implementation Details

The implementation is straightforward: we run the augmented images through the model, apply a softmax to each prediction, and average the resulting probabilities. The class with the highest marginal probability is returned as the final prediction.
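
The procedure can be sketched in a few lines. This is a dependency-free illustration in plain Python; the function name `marginal_prediction` and the toy logits are assumptions and do not appear in the notebook.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def marginal_prediction(batch_logits):
    """Softmax each row, average over rows, return (argmax class, marginal)."""
    probs = [softmax(row) for row in batch_logits]
    n_rows = len(probs)
    n_classes = len(probs[0])
    marginal = [sum(p[c] for p in probs) / n_rows for c in range(n_classes)]
    best = max(range(n_classes), key=lambda c: marginal[c])
    return best, marginal


# toy logits (hypothetical), one row per augmentation of the same image
augs = [
    [2.0, 1.0, 0.0],  # this view alone would vote for class 0
    [0.0, 3.0, 0.0],
    [1.0, 2.0, 0.5],
]
pred, marginal = marginal_prediction(augs)
```

Here the first view favours class 0, but the marginal over all views settles on class 1, showing how averaging smooths out a single misleading augmentation.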

### Additional Improvements

The two modes described above are the main differences between our implementation and the original one, but we have also added a set of further features and improvements. Most of the options offered by the original version are present in our implementation too: it is possible to use different backbones for CLIP (RN50, ViT...) and to tweak the learning rate, confidence selection, number of augmentations and test-time tuning steps.

#### Split Context

Unlike the original implementation, we let the user specify where the class name is placed within the prompt simply by using the `[CLS]` token in the *base_prompt* specification. Wherever the `[CLS]` token appears, the context is split in two parts: the first part is the prefix and the second part is the suffix. The split context option (`splt_ctx`) then controls whether, during the tuning phase, the context is treated as a single learnable vector or as two separate ones.
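
At the text level, the split around `[CLS]` can be illustrated as below. This is only a sketch on strings: the helpers `split_prompt` and `fill_prompt` and the example prompt are hypothetical, and the notebook itself works on token embeddings rather than raw text.

```python
def split_prompt(base_prompt, marker="[CLS]"):
    """Split a base prompt into the text before and after the class marker."""
    if base_prompt.count(marker) != 1:
        raise ValueError(f"expected exactly one {marker} in {base_prompt!r}")
    prefix, suffix = base_prompt.split(marker)
    return prefix.strip(), suffix.strip()


def fill_prompt(base_prompt, classname, marker="[CLS]"):
    """Build the full text prompt for one class."""
    return base_prompt.replace(marker, classname)


base = "a photo of a [CLS], a type of pet"  # illustrative prompt
prefix, suffix = split_prompt(base)
full = fill_prompt(base, "dog")
```

When the marker sits at the end of the prompt, as in `"a photo of [CLS]"`, the suffix is simply empty.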

#### Prediction

The original implementation was very fragmented, and neither the prediction step nor the test-time tuning step was included in the main class. In our project we tried to keep the model as self-contained as possible, with modularity in mind. For this reason we have included a prediction function that takes care of the behaviour of the model at test time, according to the specified parameters.

`predict`, `inference` and `forward` have very distinct roles:
- **inference** is the lowest-level function; it runs the forward pass through CLIP's image feature extractor and the custom text encoder
- **forward** sits one step above and decides whether to run the confidence selection step and, according to the input, how the inference should be performed
- **predict** is the highest-level function and the one the user should call; it takes care of the alignment step, the ensemble step and the prompt tuning step, and is all the user needs to classify a sample.




TODO: memory optimizations and base implementation

In [None]:

class EasyPromptLearner(nn.Module):
    """
    This class is responsible for learning the prompt for the TPT model,
    it takes the classnames and the base prompt and creates the prompt
    for each class. The prompts get tokenized and embedded, the embeddings
    of the base prompt are then used to create the context for each class.
    It's possible to put the context in any part of the prompt
    using the [CLS] token. It's also possible to choose whether to
    split the context into separate learning parameters for the prefix and
    suffix or to keep them together.

    Parameters:
    - device (str): the device to run the model
    - clip (torch.nn.Module): the clip model
    - base_prompt (str): the base prompt to use
    - splt_ctx (bool): split the context or not
    - classnames (list): the classnames to use
    """

    def __init__(
        self,
        device,
        clip,
        base_prompt="a photo of [CLS]",
        splt_ctx=False,
        classnames=None,
    ):
        super().__init__()

        self.device = device
        self.base_prompt = base_prompt
        self.tkn_embedder = clip.token_embedding
        self.tkn_embedder.requires_grad_(False)

        self.split_ctx = splt_ctx

        self.prepare_prompts(classnames)

    def prepare_prompts(self, classnames):
        """
        Prepares the prompts for the TPT model, this method tokenizes,
        embeds and prepares the context for each class and the base prompt.

        Parameters:
        - classnames (list): the classnames to use
        """
        print("[PromptLearner] Preparing prompts")

        self.classnames = classnames

        # get number of classes
        self.cls_num = len(self.classnames)

        # get prompt text prefix and suffix
        txt_prefix = self.base_prompt.split("[CLS]")[0]
        txt_suffix = self.base_prompt.split("[CLS]")[1]

        # tokenize the prefix and suffix
        tkn_prefix = tokenize(txt_prefix)
        tkn_suffix = tokenize(txt_suffix)
        tkn_pad = tokenize("")
        tkn_cls = tokenize(self.classnames)

        # get the index of the last element of the prefix and suffix
        idx = torch.arange(tkn_prefix.shape[1], 0, -1)
        self.indp = torch.argmax((tkn_prefix == 0) * idx, 1, keepdim=True)
        self.inds = torch.argmax((tkn_suffix == 0) * idx, 1, keepdim=True)

        # token length for each class
        self.indc = torch.argmax((tkn_cls == 0) * idx, 1, keepdim=True)

        # get the prefix, suffix, SOT and EOT
        self.tkn_sot = tkn_prefix[:, :1]
        self.tkn_prefix = tkn_prefix[:, 1 : self.indp - 1]
        self.tkn_suffix = tkn_suffix[:, 1 : self.inds - 1]
        self.tkn_eot = tkn_suffix[:, self.inds - 1 : self.inds]
        self.tkn_pad = tkn_pad[:, 2:]

        # load segments to CUDA, be ready to be embedded
        self.tkn_sot = self.tkn_sot.to(self.device)
        self.tkn_prefix = self.tkn_prefix.to(self.device)
        self.tkn_suffix = self.tkn_suffix.to(self.device)
        self.tkn_eot = self.tkn_eot.to(self.device)
        self.tkn_pad = self.tkn_pad.to(self.device)

        self.tkn_cls = tkn_cls.to(self.device)

        # gets the embeddings
        with torch.no_grad():
            self.emb_sot = self.tkn_embedder(self.tkn_sot)
            self.emb_prefix = self.tkn_embedder(self.tkn_prefix)
            self.emb_suffix = self.tkn_embedder(self.tkn_suffix)
            self.emb_eot = self.tkn_embedder(self.tkn_eot)
            self.emb_cls = self.tkn_embedder(self.tkn_cls)
            self.emb_pad = self.tkn_embedder(self.tkn_pad)

        # take out the embeddings of the class tokens (they have different lengths)
        self.all_cls = []
        for i in range(self.cls_num):
            self.all_cls.append(self.emb_cls[i][1 : self.indc[i] - 1])

        # prepare the prompts, they are needed for text encoding
        self.txt_prompts = [
            self.base_prompt.replace("[CLS]", cls) for cls in self.classnames
        ]
        self.tkn_prompts = tokenize(self.txt_prompts)

        # set the initial context, this will be reused at every new inference
        # this is the context that will be optimized

        if self.split_ctx:
            self.pre_init_state = self.emb_prefix.detach().clone()
            self.suf_init_state = self.emb_suffix.detach().clone()
            self.emb_prefix = nn.Parameter(self.emb_prefix)
            self.emb_suffix = nn.Parameter(self.emb_suffix)
            self.register_parameter("emb_prefix", self.emb_prefix)
            self.register_parameter("emb_suffix", self.emb_suffix)
        else:
            self.ctx = torch.cat((self.emb_prefix, self.emb_suffix), dim=1)
            self.ctx_init_state = self.ctx.detach().clone()
            self.ctx = nn.Parameter(self.ctx)
            self.register_parameter("ctx", self.ctx)

    def build_ctx(self):
        """
        While the context will be optimized, the embedded classnames
        must stay the same, this method builds the context for each class
        at each forward pass, using the optimized context.

        Returns:
        - torch.Tensor: the embedded prompt for each class
        """

        prompts = []
        for i in range(self.cls_num):

            # get the size of the padding (length depends on the classname size)
            pad_size = self.emb_cls.shape[1] - (
                self.emb_prefix.shape[1]
                + self.indc[i].item()
                + self.emb_suffix.shape[1]
            )

            if self.split_ctx:
                prefix = self.emb_prefix
                suffix = self.emb_suffix
            else:
                prefix = self.ctx[:, : self.emb_prefix.shape[1]]
                suffix = self.ctx[:, self.emb_prefix.shape[1] :]

            # concatenates all elements to build the prompt
            prompt = torch.cat(
                (
                    self.emb_sot,
                    prefix,
                    self.all_cls[i].unsqueeze(0),
                    suffix,
                    self.emb_eot,
                    self.emb_pad[:, :pad_size],
                ),
                dim=1,
            )
            prompts.append(prompt)
        prompts = torch.cat(prompts, dim=0)

        return prompts

    def forward(self):
        return self.build_ctx()

    def reset(self):
        """
        This function resets the context to the initial state, it
        has to be run before each new inference to bring the context
        to the initial state.
        """
        if self.split_ctx:
            self.emb_prefix.data.copy_(self.pre_init_state)  # to be optimized
            self.emb_suffix.data.copy_(self.suf_init_state)  # to be optimized
        else:
            self.ctx.data.copy_(self.ctx_init_state)  # to be optimized


class EasyTPT(EasyModel):
    """
    This class is the main class for the TPT, it contains
    the logic for running the TPT model in all its configurations,
    as well as EasyPromptLearner, which is responsible for the
    prompt learning.

    Modes:
    - Ensemble: in this mode the model won't perform tuning steps on the prompt,
    instead it will run the inference on all the augmentations and take the prediction
    that maximizes the probability marginalized over the augmentations.
    - Alignment: when align_steps > 0 the model will also perform align_steps tuning
    steps on the image encoder in an effort to minimize the distance between the
    embeddings of the augmentations.

    Parameters:
    - device (str): the device to run the model
    - base_prompt (str): the base prompt to use
    - arch (str): the architecture to use for CLIP
    - splt_ctx (bool): split the context or not
    - classnames (list): the classnames to use
    - ensemble (bool): run TPT in ensemble mode
    - ttt_steps (int): number of test time tuning steps
    - lr (float): the learning rate
    - align_steps (int): number of alignment steps
    - confidence (float): confidence threshold for the confidence selection
    """

    def __init__(
        self,
        device,
        base_prompt="a photo of a [CLS]",
        arch="RN50",
        splt_ctx=False,
        classnames=None,
        ensemble=False,
        ttt_steps=1,
        lr=0.005,
        align_steps=0,
        confidence=0.10,
    ):
        super(EasyTPT, self).__init__()
        self.device = device

        ### TODO: to be parametrized
        DOWNLOAD_ROOT = "~/.cache/clip"
        ###

        self.base_prompt = base_prompt
        self.ttt_steps = ttt_steps
        self.selected_idx = None
        self.ensemble = ensemble
        self.align_steps = align_steps
        self.confidence = confidence

        # Load clip
        clip, self.preprocess = load(
            arch, device=device, download_root=DOWNLOAD_ROOT, jit=False
        )

        if align_steps > 0:  # clip tuning must run in float
            clip.float()

        self.clip = clip
        self.dtype = clip.dtype
        self.image_encoder = clip.encode_image
        # self.text_encoder = clip.encode_text

        # freeze the parameters
        for name, param in self.named_parameters():
            param.requires_grad_(False)

        # create the prompt learner
        self.prompt_learner = EasyPromptLearner(
            device, clip, base_prompt, splt_ctx, classnames
        )

        # create optimizer and save the state
        trainable_param = []
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(f"[EasyTPT TPT] Training parameter: {name}")
                trainable_param.append(param)
        self.optimizer = torch.optim.AdamW(trainable_param, lr)
        self.optim_state = deepcopy(self.optimizer.state_dict())

        if align_steps > 0:
            emb_trainable_param = []
            # unfreeze the image encoder
            for name, param in self.clip.visual.named_parameters():
                # skip the attention pooling (attnpool) parameters
                if "attnpool" not in name:
                    param.requires_grad_(True)
                    emb_trainable_param.append(param)
                    print(f"[EasyTPT Emb] Training parameter: {name}")

            self.emb_optimizer = torch.optim.AdamW(emb_trainable_param, 0.0001)
            self.emb_optim_state = deepcopy(self.emb_optimizer.state_dict())
            self.clip_init_state = deepcopy(self.clip.visual.state_dict())

        if self.ensemble:
            print("[EasyTPT] Running TPT in Ensemble mode")

        if self.align_steps > 0:
            print("[EasyTPT] Running TPT with alignment")

        # for name, param in self.named_parameters():
        #     if param.requires_grad:
        #         print(f"[EasyTPT] Training parameter: {name}")

    def forward(self, x, top=-1):
        """
        If x is a list of augmentations, run the confidence selection,
        otherwise just run the inference

        Parameters:
        - x (torch.Tensor or list): the image(s) to run the inference. One
        image for the final prediction or a list of augmentations for the
        tuning steps.
        - top (int): the top percentage of samples to select

        Returns:
        - logits (torch.Tensor): the logits of the inference
        """

        if top == -1:
            top = self.confidence

        self.eval()
        if isinstance(x, list):
            x = torch.stack(x).to(self.device)
            logits = self.inference(x)
            if self.selected_idx is None:
                logits, self.selected_idx = self.select_confident_samples(logits, top)
            else:
                logits = logits[self.selected_idx]
        else:
            if len(x.shape) == 3:
                x = x.unsqueeze(0)
            x = x.to(self.device)

            logits = self.inference(x)

        return logits

    def inference(self, x):
        """
        Basically CLIP's forward method, but with the custom
        encoder to use our embeddings
        """

        with torch.no_grad():
            image_feat = self.image_encoder(x)
            image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)

        emb_prompts = self.prompt_learner()

        txt_features = self.custom_encoder(emb_prompts, self.prompt_learner.tkn_prompts)
        txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True)

        logit_scale = self.clip.logit_scale.exp()
        logits = logit_scale * image_feat @ txt_features.t()

        return logits

    def predict(self, images, niter=1):

        if self.ensemble:
            with torch.no_grad():
                out = self(images)
                marginal_prob = F.softmax(out, dim=1).mean(0)
                out_id = marginal_prob.argmax().item()
        else:
            if self.align_steps > 0:
                self.align_embeddings(images)

            for _ in range(niter):
                out = self(images)
                loss = self.avg_entropy(out)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            with torch.no_grad():
                out = self(images[0])
                out_id = out.argmax(1).item()
                prediction = self.prompt_learner.classnames[out_id]

        # return out_id, prediction
        return out_id

    def align_emb_loss(self, image_feat):

        norm_feat = torch.nn.functional.normalize(image_feat, p=2, dim=1)

        cos_sim = torch.mm(norm_feat, norm_feat.T)

        # noself_mean = (cos_sim.sum() - torch.trace(cos_sim)) / (
        #     cos_sim.numel() - cos_sim.shape[0]
        # )
        loss = 1 - cos_sim.mean()

        return loss

    def align_embeddings(self, x):
        """
        Aligns the embeddings of the image encoder
        """

        self.forward(x)
        self.clip.visual.train()
        x = torch.stack(x).to(self.device)
        selected_augs = torch.index_select(x, 0, self.selected_idx)
        for _ in range(self.align_steps):
            image_feat = self.clip.visual(selected_augs.type(self.dtype))
            loss = self.align_emb_loss(image_feat)
            self.emb_optimizer.zero_grad()
            loss.backward()
            # print("distance before: ", loss.item())
            self.emb_optimizer.step()
        image_feat = self.clip.visual(selected_augs.type(self.dtype))
        loss = self.align_emb_loss(image_feat)
        # print("distance after: ", loss.item())
        self.clip.visual.eval()

    def custom_encoder(self, prompts, tokenized_prompts):
        """
        Custom clip text encoder, unlike the original clip encoder this one
        takes the prompts embeddings from the prompt learner
        """
        x = prompts + self.clip.positional_embedding
        x = x.permute(1, 0, 2).type(self.dtype)  # NLD -> LND
        x = self.clip.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = (
            x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)]
            @ self.clip.text_projection
        )

        return x

    def reset(self):
        """
        Resets the optimizer and the prompt learner to their initial state,
        this has to be run before each new test
        """
        self.optimizer.load_state_dict(deepcopy(self.optim_state))
        self.prompt_learner.reset()
        self.selected_idx = None

        if self.align_steps > 0:
            # print("[EasyTPT] Resetting embeddings optimizer")
            self.emb_optimizer.load_state_dict(deepcopy(self.emb_optim_state))
            self.clip.visual.load_state_dict(deepcopy(self.clip_init_state))

    def select_closest_samples(self, x, top):

        with torch.no_grad():
            feat = self.clip.visual(x.type(self.dtype))
            feat = feat / feat.norm(dim=-1, keepdim=True)

            # Compute cosine similarities
            sims = F.cosine_similarity(feat[0].unsqueeze(0), feat[1:], dim=1)
            vals, idxs = torch.topk(sims, int(sims.shape[0] * top))

        return idxs

    def get_optimizer(self):
        """
        Returns the optimizer

        Returns:
        - torch.optim: the optimizer
        """
        return self.optimizer


## Ensemble
This model implements the core idea of an ensemble of multiple models.
As already mentioned, the idea is to combine models trained with different techniques so that each covers the out-of-distribution data of the others, leading to a more robust ensemble.

`EasyEnsemble` aggregates multiple `EasyModel` instances and combines their predictions. This means that we can use any of the previous models with any of the implemented backbones and configurations, which gives maximum flexibility: the ensemble can be tuned to the specific needs of the user, whether those are memory constraints, limited computational power or the need for a more robust model.

### Ensemble of Multiple Models
This is the core functionality of this module. It builds on the same idea as MEMO and TPT, which is to reduce the entropy of the marginal distribution obtained by augmenting the image.
The difference is that we feed the augmentations to multiple models, each with a different configuration, and then combine the predictions of all models. This gives a more robust model, in which each member can cover the out-of-distribution data of the others.

The final output is a weighted sum of the predictions, where the weights are computed from the entropy of the distribution of each model. This lets the ensemble give more importance to the models that are more confident in their predictions.
However, to prevent some models from overpowering the others, we also use temperature scaling: the logits of each model are divided by a per-model factor, the temperature. This lets us control the confidence of each model, and therefore its weight in the final prediction.
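
To make the effect of the temperature concrete, the sketch below shows how dividing the same logits by different temperatures changes the entropy of the resulting distribution. It is a plain-Python illustration; the function names and the numeric values are assumptions chosen for the example.

```python
import math


def softmax_with_temperature(logits, temp):
    """Softmax of logits divided by a temperature (temp > 0)."""
    scaled = [x / temp for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


logits = [4.0, 1.0, 0.0]  # toy logits (hypothetical)

ent_sharp = entropy(softmax_with_temperature(logits, 0.5))  # temp < 1 sharpens
ent_plain = entropy(softmax_with_temperature(logits, 1.0))  # unchanged
ent_soft = entropy(softmax_with_temperature(logits, 4.0))   # temp > 1 flattens
```

A higher temperature flattens the distribution and raises its entropy, which in turn lowers the weight that model receives in the entropy-based weighting.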

#### Implementation Details
To compute the ensemble output we follow these steps:
1. Compute the logits for each model and augmentation  
$outs = [model_0(augs), model_1(augs), \dots, model_M(augs)]$
2. Average the outputs of each model over the augmentations (in the code this is a log-mean-exp, that is `logsumexp` minus the log of the number of augmentations, rather than a plain arithmetic mean)  
$avg\_outs = [mean(outs_0), mean(outs_1), \dots, mean(outs_M)]$
3. Rescale the averaged logits by the temperature of each model  
$rescaled\_outs = [avg\_outs_0 / temp_0, avg\_outs_1 / temp_1, \dots, avg\_outs_M / temp_M]$
4. Compute the log softmax of the rescaled logits  
$log\_probs_i = \log(softmax(rescaled\_outs_i))$
5. Compute the entropy of the distribution of each model  
$ent = [entropy(log\_probs_0), entropy(log\_probs_1), \dots, entropy(log\_probs_M)]$
6. Compute a normalized score for each model, so that models with lower entropy get a larger weight  
$ent\_scale_i = \left(\sum_{j=0}^{M} ent_j\right) / ent_i$  
$scale_i = ent\_scale_i / \sum_{j=0}^{M} ent\_scale_j$
7. Compute the final prediction as the weighted sum of the log-distributions of each model  
$final\_pred = \sum_{i=0}^{M} scale_i \cdot log\_probs_i$

After obtaining the final prediction, we compute its entropy and backpropagate it through all the models, each of which is then updated using its own optimizer.

Finally, we compute the final output by repeating the same steps with the adapted models and returning the class with the highest probability.
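
Steps 3 to 7 can be sketched as follows, starting from already averaged per-model logits. This is a dependency-free illustration in plain Python, without gradients; the function names and the toy logits are assumptions, and the sketch presumes that no model has exactly zero entropy.

```python
import math


def log_softmax(logits):
    """Numerically stable log softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def entropy(log_probs):
    """Entropy of a distribution given as log-probabilities."""
    return -sum(math.exp(lp) * lp for lp in log_probs)


def entropy_weights(entropies):
    """Lower entropy gets a larger weight; weights sum to 1 (entropies > 0)."""
    total = sum(entropies)
    raw = [total / e for e in entropies]
    norm = sum(raw)
    return [r / norm for r in raw]


def combine(avg_logits, temps):
    """Temperature-scale, log-softmax, then entropy-weighted sum per class."""
    log_probs = [
        log_softmax([x / t for x in row]) for row, t in zip(avg_logits, temps)
    ]
    weights = entropy_weights([entropy(lp) for lp in log_probs])
    n_classes = len(log_probs[0])
    final = [
        sum(w * lp[c] for w, lp in zip(weights, log_probs))
        for c in range(n_classes)
    ]
    return final, weights


# toy averaged logits (hypothetical): model 0 is confident, model 1 is unsure
avg_logits = [[4.0, 0.0, 0.0], [0.5, 1.0, 0.0]]
final_pred, weights = combine(avg_logits, temps=[1.0, 1.0])
pred_class = max(range(len(final_pred)), key=lambda c: final_pred[c])
```

In this example the confident model dominates the weighted sum, so the ensemble follows its vote even though the unsure model prefers another class.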

### Single Models Test
In the test arguments we can specify whether to also compute the results of the single models. This is useful to compare the performance of the ensemble with that of its members.

To compute the results for the single models we simply call the `predict` method of each model.

### Simple Ensemble
As with the previous techniques, we also implemented the ensemble-of-augmentations idea using multiple models.
This keeps the advantages seen before and adds the benefit of having different models that hopefully cover slightly different areas of the feature space.

The implementation follows the same procedure as the ensemble of multiple models. However, we directly return the most probable class in the $final\_pred$ tensor, without backpropagating through the models.


In [None]:
class Ensemble(nn.Module):
    """
    Ensemble class. Implements an ensemble of models with entropy minimization.

    Attributes:
        models (list[EasyModel]): A list of models to be used in the ensemble.
        temps (list): A list of temperature values corresponding to each model.
        test_single_models (bool): Whether to test each individual model in addition to the ensemble.
        simple_ensemble (bool): Whether to also compute the ensemble prediction without the entropy minimization step.
        device (str): The device to be used for computation.
    """
    def __init__(self, models:list[EasyModel], temps, device="cuda", test_single_models=False, simple_ensemble=False):
        """
        Initializes an Ensemble object.

        Args:
            models (list[EasyModel]): A list of models to be used in the ensemble.
            temps (list): A list of temperature values corresponding to each model.
            device (str, optional): The device to be used for computation. Defaults to "cuda".
            test_single_models (bool, optional): Whether to test each individual model in addition to the ensemble. Defaults to False.
            simple_ensemble (bool, optional): Whether to also compute the ensemble prediction without the entropy minimization step. Defaults to False.
        """
        super(Ensemble, self).__init__()
        self.models = models
        self.temps = temps
        self.test_single_models = test_single_models
        self.device = device
        self.simple_ensemble = simple_ensemble

    def entropy(self, logits):
        """
        Computes the entropy of a set of logits.

        Args:
            logits (torch.Tensor): The logits to compute the entropy of.

        Returns:
            torch.Tensor: The entropy of the logits.
        """
        return -(torch.exp(logits) * logits).sum(dim=-1)

    def marginal_distribution(self, models_logits):
        """
        Computes the marginal distribution of the ensemble.

        Args:
            models_logits (torch.Tensor): The logits of the models in the ensemble.

        Returns:
            torch.Tensor: The marginal distribution of the ensemble.
        """
        # average logits for each model
        avg_models_distributions = torch.Tensor(models_logits.shape[0], models_logits.shape[2]).to(self.device)
        for i, model_logits in enumerate(models_logits):
            avg_outs = torch.logsumexp(model_logits, dim=0) - torch.log(torch.tensor(model_logits.shape[0]))
            min_real = torch.finfo(avg_outs.dtype).min
            avg_outs = torch.clamp(avg_outs, min=min_real)
            avg_outs /= self.temps[i]
            avg_models_distributions[i] = torch.log_softmax(avg_outs, dim=0)

        with torch.no_grad():
            entropies = torch.stack([self.entropy(logits) for logits in avg_models_distributions]).to(self.device)
            sum_entropies = torch.sum(entropies, dim=0)
            scale = torch.stack([sum_entropies / ent for ent in entropies]).to(self.device)
            #normalize sum to 1
            scale = scale / torch.sum(scale)

        # print("\t\t[Ensemble] Entropies: ", entropies)
        # print("\t\t[Ensemble] Scales: ", scale)

        marginal_dist = torch.sum(torch.stack([scale[i].item() * avg_models_distributions[i] for i in range(len(avg_models_distributions))]), dim=0)

        return marginal_dist

    def get_models_outs(self, inputs, top=0.1):
        """
        Computes the outputs of the models in the ensemble.

        Args:
            inputs (list): A list of inputs to be fed to the models.
            top (float, optional): The top percentage of the outputs to be used. Defaults to 0.1.

        Returns:
            torch.Tensor: The outputs of the models in the ensemble.
        """
        model_outs = torch.stack([model(inputs[i], top).to(self.device) for i, model in enumerate(self.models)]).to(self.device)
        return model_outs.to(self.device)

    def get_models_predictions(self, inputs):
        """
        Computes the predictions of the single models in the ensemble.

        Args:
            inputs (list): A list of inputs to be fed to the models.

        Returns:
            list: A list of the predictions of the single models.
        """
        models_pred = [model.predict(inputs[i]) for i, model in enumerate(self.models)]
        return models_pred

    def entropy_minimization(self, inputs, niter=1, top=0.1):
        """
        Test time adaptation step. Minimizes the entropy of the ensemble's predictions.

        Args:
            inputs (list): A list of inputs to be fed to the models.
            niter (int, optional): The number of iterations to perform. Defaults to 1.
            top (float, optional): The top percentage of the outputs to be used. Defaults to 0.1.
        """
        for i in range(niter):
            outs = self.get_models_outs(inputs, top)
            avg_logit = self.marginal_distribution(outs)

            loss = self.entropy(avg_logit)
            loss.backward()
            for model in self.models:
                model.optimizer.step()
                model.optimizer.zero_grad()

    def forward(self, inputs, niter=1, top=0.1):
        """
        Forward pass of the ensemble.

        Args:
            inputs (list): A list of inputs to be fed to the models.
            niter (int, optional): The number of iterations to perform. Defaults to 1.
            top (float, optional): The top percentage of the outputs to be used. Defaults to 0.1.

        Returns:
            list, torch.Tensor, torch.Tensor: The predictions of the single models, the prediction of the ensemble without the entropy minimization step, and the prediction of the ensemble.
        """
        # get models outputs
        self.reset()
        models_pred = self.get_models_predictions(inputs)

        self.reset()
        self.entropy_minimization(inputs, niter, top)
            
        with torch.no_grad():
            outs = self.get_models_outs([i[0] for i in inputs], top)
            avg_logit = self.marginal_distribution(outs)
            prediction = torch.argmax(avg_logit, dim=0)

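        # stays None unless simple_ensemble is enabled (avoids a NameError on return)
        prediction_no_back = None
        if self.simple_ensemble:
            self.reset()
            for model in self.models:
                model.eval()
            # no adaptation here, so no gradients are needed
            with torch.no_grad():
                outs = self.get_models_outs(inputs, top)
                avg_logit = self.marginal_distribution(outs)
                prediction_no_back = torch.argmax(avg_logit, dim=0)

        return models_pred, prediction_no_back, prediction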

    def reset(self):
        """
        Resets the models in the ensemble.
        """
        for model in self.models:
            model.reset()

# Test Time Adaptation

## MEMO
We start with the baselines obtained with different pretrained weights: since the MEMO paper was published, new weights have been released in torchvision.

|          Test          | ImageNet-A | ImageNet-V2 |
|:----------------------:|:----------:|:-----------:|
|  baseline<br>default   |   16.67    |    69.95    |
| baseline<br>weights-v1 |    0.03    |    63.13    |

The new weights give much better results on ImageNet-A (16.67% against 0.03%), due to improved training, and they also improve the results on ImageNet-V2.
For all the tests we used the RandomCrop augmentation (cut), since it generally performed better than AugMix augmentations on ImageNet-A.

### Results with weights-v1

|              Test              | ImageNet-A | ImageNet-V2 | Delta A | Delta V2 | Iterations(it/s) |   Time (A - V2)   |
|:------------------------------:|:----------:|:-----------:|:-------:|:--------:|:----------------:|:-----------------:|
|            Baseline            |    0.03    |    63.13    |  0.00   |   0.00   |  22.73 - 22.73   |    5:30 - 7:20    |
|              MEMO              |    1.92    |    64.47    |  1.89   |   1.34   |   8.19 - 8.06    |  15:15  - 20:40   |
|     MEMO<br>topk selection     |  **7.43**  |    66.68    |  7.40   |   3.55   |   1.64 - 1.64    | 1:16:00 - 1:42:20 |
|            Dropout             |    0.05    |    63.06    |  0.02   |  -0.07   |  14.28 - 13.16   |   8:45 - 12:40    |
|   Dropout<br>topk selection    |    0.03    |    63.18    |  0.00   |   0.05   |   3.98 - 3.85    |   31:23 - 43:20   |
|          Cut ensemble          |    2.00    |    65.90    |  1.97   |   2.77   |  14.71 - 13.88   |   8:10 - 12:00    |
| Cut ensemble<br>topk selection |    6.37    |  **67.05**  |  6.34   |   3.92   |   4.00 - 3.87    |   31:45 - 43:00   |

The dropout tests are not very effective: we start from a very low accuracy on ImageNet-A, and dropout does not help to improve it, while on ImageNet-V2 the accuracy stays within 0.1 points of the baseline.
MEMO with top-k selection is the best on ImageNet-A with an accuracy of 7.43%, but the cut ensemble with top-k selection beats it on ImageNet-V2 with an accuracy of 67.05%.
Confidence selection improves both MEMO and the cut ensemble because, among the many augmentations processed, it keeps only those on which the model is most confident.
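
To make the confidence selection concrete, here is a minimal sketch in plain Python. It is not the notebook's `EasyMemo` implementation: it assumes that confidence is measured by the entropy of each augmented view's softmax output, keeps the `top` fraction of views with the lowest entropy, and averages their distributions. The helper names `softmax`, `entropy`, `select_confident` and `marginal_prediction` are hypothetical, and the logits are toy values.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy (in nats) of a probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def select_confident(view_logits, top=0.1):
    """Keep the `top` fraction of views with the lowest entropy.

    Assumption: low entropy is used as the confidence measure.
    """
    probs = [softmax(logits) for logits in view_logits]
    keep = max(1, int(len(probs) * top))
    return sorted(probs, key=entropy)[:keep]


def marginal_prediction(probs):
    """Average the per-view distributions and return (argmax class, average)."""
    n_classes = len(probs[0])
    avg = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=avg.__getitem__), avg


# Toy example: 4 augmented views, 3 classes. Two views are confident about
# class 2, two are nearly uniform and would only add noise to the average.
views = [
    [0.1, 0.2, 4.0],
    [0.3, 0.2, 0.1],
    [0.0, 0.1, 3.5],
    [0.2, 0.3, 0.2],
]
selected = select_confident(views, top=0.5)
prediction, avg = marginal_prediction(selected)
print(prediction, len(selected))
```

With `top=1` every view is kept, which corresponds to the tests "without topk selection".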

### Results with default weights

|              Test              | ImageNet-A | ImageNet-V2 | Delta A | Delta V2 |  it/s (A - V2)   | Time (A - V2) |
|:------------------------------:|:----------:|:-----------:|:-------:|:--------:|:----------------:|:-------------:|
|            Baseline            |   16.67    |    69.55    |  0.00   |   0.00   |  22.73 - 22.22   |  5:30 - 7:30  |
|              MEMO              |   21.53    |    68.75    |  4.86   |  -0.80   |   9.86 - 10.05   | 12:40 - 16:35 |
|     MEMO<br>topk selection     |   27.01    |    68.47    |  10.34  |  -1.08   |   3.52 - 3.45    | 35:30 - 48:22 |
|            Dropout             |   17.29    |    69.40    |  0.62   |  -0.15   |  18.34 - 17.60   |  6:49 - 9:28  |
|   Dropout<br>topk selection    |   18.44    |    69.69    |  1.77   |   0.14   |   4.26 - 3.96    | 29:20 - 42:02 |
|          Cut ensemble          |   22.65    |  **70.29**  |  5.98   |   0.74   |  20.19 - 19.04   |  6:11 - 8:45  |
| Cut ensemble<br>topk selection | **28.08**  |    69.77    |  11.41  |   0.22   |   5.50 - 5.17    | 22:43 - 32:15 |

For the MEMO tests, confidence selection is much better than plain MEMO on ImageNet-A (27.01% against 21.53%), although both MEMO variants fall slightly below the baseline on ImageNet-V2. The dropout ensemble is better than the baseline on ImageNet-A, but this comes at the cost of being slower.
As explained above, confidence selection is very effective at improving both MEMO and the cut ensemble on ImageNet-A.

The cut ensemble with top-k selection is the best overall, with 28.08% accuracy on ImageNet-A and 69.77% on ImageNet-V2, although on ImageNet-V2 the plain cut ensemble is better than every other method (70.29%).

In general the old weights are not as effective as the newer ones, since the newer ones come from a more refined training.

Note that all the timing data in this section was obtained on the SageMaker session provided to us.
The times are estimates, extrapolated from the iteration speed measured after the first 500 samples of each dataset.
We measured the speed directly for the default weights and reused those results for the v1 weights, since the architecture is the same.
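
As a rough illustration of how a time estimate follows from a measured speed, the sketch below divides the dataset size by the iterations per second. The helper `estimated_runtime` is hypothetical, and the dataset sizes (7,500 images for ImageNet-A and 10,000 for ImageNet-V2) are assumptions that are not stated in this report.

```python
def estimated_runtime(n_samples, it_per_s):
    """Return the estimated total runtime as an h:mm:ss or m:ss string."""
    seconds = round(n_samples / it_per_s)
    hours, rem = divmod(seconds, 3600)
    minutes, secs = divmod(rem, 60)
    if hours:
        return f"{hours}:{minutes:02d}:{secs:02d}"
    return f"{minutes}:{secs:02d}"


# Assumed dataset sizes (not stated in the report).
N_IMAGENET_A = 7_500
N_IMAGENET_V2 = 10_000

# Baseline speed reported in the weights-v1 table: 22.73 it/s on both datasets.
print(estimated_runtime(N_IMAGENET_A, 22.73))   # 5:30
print(estimated_runtime(N_IMAGENET_V2, 22.73))  # 7:20
```

Under these assumptions the two printed values agree with the baseline row of the weights-v1 table; other rows may differ by a few seconds because the reported times are rounded.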


In [None]:
results_path_memo = f"{RESULTS_PATH}/MEMO"
os.makedirs(results_path_memo, exist_ok=True)


def MEMO_testing_step(test, arguments):
    print(f"Starting {test} evaluation...")
    device = arguments["memo"]["device"]
    mapping = arguments["memo"]["mapping"]
    prior_strength = arguments["memo"]["prior_strength"]
    lr = arguments["memo"]["lr"]
    weight_decay = arguments["memo"]["weight_decay"]
    opt = arguments["memo"]["opt"]
    niter = arguments["memo"]["niter"]
    top = arguments["memo"]["top"]
    ensemble = arguments["memo"]["ensemble"]
    dataset_root = arguments["dataset"]["dataset_root"]
    naugs = arguments["dataset"]["naug"]
    aug_type = arguments["dataset"]["aug_type"]
    weights = arguments["weights"]

    weights = ResNet50_Weights.IMAGENET1K_V1 if weights == 'v1' else ResNet50_Weights.DEFAULT
    net = resnet50(weights=weights).to(device)
    if "drop" in arguments.keys():
        net.layer4.add_module('dropout', nn.Dropout(arguments["drop"], inplace=True))

    model = EasyMemo(net, device, mapping, prior_strength=prior_strength, top=top, ensemble=ensemble, lr=lr,
                     weight_decay=weight_decay, opt=opt, niter=niter)
    imageNet_A, imageNet_V2 = memo_get_datasets(aug_type, naugs, dataset_root)
    dataset = imageNet_A if arguments['dataset']['imageNetA'] else imageNet_V2
    
    correct = 0
    cnt = 0
    total_time = 0

    # Redirect prints to the per-test log file; stdout is restored in the `finally` block.
    log_file = open(f"{results_path_memo}/{test}.txt", "w")
    original_stdout = sys.stdout
    sys.stdout = log_file
    try:
        index = np.random.permutation(len(dataset))
        # iterate = tqdm(index)
        for i in index:
            data = dataset[i]
            image = data["img"]
            label = int(data["label"])

            # time only the prediction and the reset of the adapted weights
            start = time.time()
            prediction = model.predict(image)
            model.reset()
            end = time.time()
            cnt += 1
            total_time += end - start
            avg_time = total_time / cnt

            correct += mapping[prediction] == label
            # iterate.set_description(desc=f"Current accuracy {(correct / cnt) * 100:.2f}")

        print(f"Final Accuracy: {(correct / cnt) * 100:.2f} over {cnt} samples")
        memo_acc = (correct / cnt) * 100
        arguments["result"] = round(memo_acc, 3)
        arguments["speed"] = round(avg_time, 3)
        json_dump = f"{results_path_memo}/{test}.json"
        # use a separate handle so the log file handle is not shadowed
        with open(json_dump, "w") as json_file:
            json.dump(arguments, json_file)

        print("--------------------------------------------------------------")
    finally:
        sys.stdout = original_stdout
        log_file.close()

In [None]:
imageNet_A, imageNet_V2 = memo_get_datasets('augmix', 1, DATASET_ROOT)
mapping_a = [int(x) for x in imageNet_A.classnames.keys()]
mapping_v2 = [int(x) for x in imageNet_V2.classnames.keys()]

del imageNet_A, imageNet_V2

memo_dict_tests = {
    "Baseline ImageNetA": {
        "memo": {
            "mapping": mapping_a,
            "ensemble": True,
            "top": 1,
        },
        "dataset": {
            "imageNetA": True,
        },
        "run": baseline_tests and (DATASET_TO_TEST in ["a", "both"]),
        "weights": "default",
    },
    "Baseline ImageNetV2": {
        "memo": {
            "mapping": mapping_v2,
            "ensemble": True,
            "top": 1,
        },
        "dataset": {
            "imageNetA": False,
        },
        "run": baseline_tests and (DATASET_TO_TEST in ["v2", "both"]),
        "weights": "default",
    },
    "Baseline ImageNetA ResNet50 weights V1": {
        "memo": {
            "mapping": mapping_a,
            "ensemble": True,
            "top": 1,
        },
        "dataset": {
            "imageNetA": True,
        },
        "weights": "v1",
        "run": baseline_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "Baseline ImageNetV2 ResNet50 weights V1": {
        "memo": {
            "mapping": mapping_v2,
            "ensemble": True,
            "top": 1,
        },
        "dataset": {
            "imageNetA": False,
        },
        "weights": "v1",
        "run": baseline_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
    "MEMO ImageNetA, without topk selection": {
        "memo": {
            "top": 1,
            "mapping": mapping_a,
            "prior_strength": 0.94
        },
        "dataset": {
            "imageNetA": True,
            "naug": augs_no_selection,
            "aug_type": "cut",
        },
        "run": memo_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "MEMO ImageNetV2, without topk selection": {
        "memo": {
            "top": 1,
            "mapping": mapping_v2,
            "prior_strength": 0.94
        },
        "dataset": {
            "imageNetA": False,
            "naug": augs_no_selection,
            "aug_type": "cut",
        },
        "run": memo_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
    "MEMO ImageNetA, with topk selection": {
        "memo": {
            "top": 0.1,
            "mapping": mapping_a,
            "prior_strength": 0.94
        },
        "dataset": {
            "imageNetA": True,
            "naug": augs_selection,
            "aug_type": "cut",
        },
        "run": memo_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "MEMO ImageNetV2, with topk selection": {
        "memo": {
            "top": 0.1,
            "mapping": mapping_v2,
            "prior_strength": 0.94
        },
        "dataset": {
            "imageNetA": False,
            "naug": augs_selection,
            "aug_type": "cut",
        },
        "run": memo_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
    "DROP ImageNetA, without topk selection": {
        "memo": {
            "top": 1,
            "mapping": mapping_a,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": True,
            "naug": augs_no_selection,
            "aug_type": "identity",
        },
        "drop": 0.5,
        "run": drop_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "DROP ImageNetV2, without topk selection": {
        "memo": {
            "top": 1,
            "mapping": mapping_v2,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": False,
            "naug": augs_no_selection,
            "aug_type": "identity",
        },
        "drop": 0.5,
        "run": drop_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
    "DROP ImageNetA, with topk selection": {
        "memo": {
            "top": 0.1,
            "mapping": mapping_a,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": True,
            "naug": augs_selection,
            "aug_type": "identity",
        },
        "drop": 0.5,
        "run": drop_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "DROP ImageNetV2, with topk selection": {
        "memo": {
            "top": 0.1,
            "mapping": mapping_v2,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": False,
            "naug": augs_selection,
            "aug_type": "identity",
        },
        "drop": 0.5,
        "run": drop_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
    "Cut ensemble ImageNetA, without topk selection": {
        "memo": {
            "top": 1,
            "mapping": mapping_a,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": True,
            "naug": augs_no_selection,
            "aug_type": "cut",
        },
        "drop": 0,
        "run": ensemble_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "Cut ensemble ImageNetV2, without topk selection": {
        "memo": {
            "top": 1,
            "mapping": mapping_v2,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": False,
            "naug": augs_no_selection,
            "aug_type": "cut",
        },
        "drop": 0,
        "run": ensemble_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
    "Cut ensemble ImageNetA, with topk selection": {
        "memo": {
            "top": 0.1,
            "mapping": mapping_a,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": True,
            "naug": augs_selection,
            "aug_type": "cut",
        },
        "drop": 0,
        "run": ensemble_tests and (DATASET_TO_TEST in ["a", "both"]),
    },
    "Cut ensemble ImageNetV2, with topk selection": {
        "memo": {
            "top": 0.1,
            "mapping": mapping_v2,
            "ensemble": True,
        },
        "dataset": {
            "imageNetA": False,
            "naug": augs_selection,
            "aug_type": "cut",
        },
        "drop": 0,
        "run": ensemble_tests and (DATASET_TO_TEST in ["v2", "both"]),
    },
}

for t in memo_dict_tests:
    if memo_dict_tests[t]["run"]:
        arg = memo_base_test | memo_dict_tests[t]
        arg['memo'] = memo_base_test['memo'] | memo_dict_tests[t]['memo']
        arg['dataset'] = memo_base_test['dataset'] | memo_dict_tests[t]['dataset']
        MEMO_testing_step(t, arg)


## TPT

For consistency, and due to time and memory constraints, we decided to use only ResNet-50 as the backbone for CLIP. Our baselines include CLIP zero-shot and the TPT results from the original paper.

|      Baselines      | ImageNet-A | ImageNet-V2 |
|:-------------------:|:----------:|:-----------:|
| CLIP <br> zero-shot |   21.91    |    51.2     |
|     TPT (paper)     |   26.67    |    54.70    |

The accuracy of our TPT implementation falls within half a percentage point of the original paper, which gives us a solid baseline for verifying the effectiveness of the additional features we implemented. To make the comparison fair, we used the same number of augmentations and the same prompt as the original paper, as well as the same learning rate and the same value for every other hyperparameter that the two implementations share.

|                   Test                   | ImageNet-A | ImageNet-V2 | Delta A  | Delta V2 |  it/s (A - V2)   |
| :--------------------------------------: | :--------: | :---------: | :------: | :------: | :--------------: |
|           CLIP <br> zero-shot            |    21.9    |    51.2     |   0.00   |   0.00   |   6.41 - 1.31    |
|                TPT (ours)                |    26.4    |  **54.2**   |   4.5    | **3.0**  |     1.52 - ~     |
|               TPT ensemble               |    26.4    |    53.1     |   4.5    |   1.9    |   6.33 - 1.31    |
| TPT ensemble + <br> confidence selection |  **32.2**  |    54.0     | **10.3** |   2.8    |   4.31 - 1.20    |
|             TPT + alignment              |    17.1    |      ~      |   -4.8   |    ~     |     0.46 - ~     |

### Experiments

The actual parameters used in our experiments can be found in the **Parameters** section at the beginning. Our TPT baseline was run with a single prompt tuning step and the prompt "A photo of a [CLS]". 
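
As an illustration of how a base prompt such as "A photo of a [CLS]" expands into one text prompt per class, the sketch below uses a hypothetical helper, `build_prompts`. It is not the `EasyTPT` implementation, which also has to handle tokenization and the prompt tuning itself, and the class names are only examples.

```python
def build_prompts(template, classnames, placeholder="[CLS]"):
    """Fill the class placeholder of a prompt template, one prompt per class."""
    if placeholder not in template:
        raise ValueError(f"template must contain {placeholder!r}")
    return [template.replace(placeholder, name) for name in classnames]


prompts = build_prompts("A photo of a [CLS]", ["goldfish", "school bus"])
print(prompts)  # ['A photo of a goldfish', 'A photo of a school bus']
```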

The **ensemble** experiments were run without prompt tuning: the prediction is the class with the highest marginal probability over the set of augmentations. With a single augmentation (the original image) this is equivalent to a CLIP zero-shot prediction, and this mode was used for the CLIP zero-shot baseline. The first ensemble experiment used 8 augmentations, while the second one used 64 augmentations with 10% confidence selection.
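
The equivalence between a single-view ensemble and a zero-shot prediction can be checked with a small sketch. It assumes that the marginal probability is the plain average of the per-view softmax outputs; `ensemble_predict` is a hypothetical helper and the logits are toy values, not CLIP outputs.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ensemble_predict(view_logits):
    """Return the class with the highest probability averaged over all views."""
    probs = [softmax(logits) for logits in view_logits]
    n_classes = len(probs[0])
    marginal = [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]
    return max(range(n_classes), key=marginal.__getitem__)


original = [2.0, 1.5, 0.1]                       # logits for the original image
augmented = [[0.2, 3.0, 0.1], [0.1, 2.5, 0.3]]   # logits for two augmented views

# Zero-shot prediction: argmax of the logits of the original image.
zero_shot = max(range(len(original)), key=original.__getitem__)
# With one view, averaging changes nothing, so the prediction is the same.
single_view = ensemble_predict([original])
# With more views, the marginal can favour a different class.
multi_view = ensemble_predict([original] + augmented)
print(zero_shot, single_view, multi_view)
```

In this toy case the augmented views agree on class 1 and outvote the original image, which is the behaviour the ensemble relies on.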

The **alignment** experiments were run in many different configurations: besides varying the number of steps and the learning rate, we also tried freezing and unfreezing different parts of the visual encoder for both the RN50 and ViT backbones. The results were always disappointing. The specific alignment experiment presented in the table was run by tuning an RN50 backbone with the attention weights frozen.

Results for **TPT + alignment** on ImageNet-V2 are not available, as the model was too large to fit in memory.


In [None]:

results_path_tpt = f"{RESULTS_PATH}/TPT"
os.makedirs(results_path_tpt, exist_ok=True)
DATASET_ROOT = "datasets"


for idx, settings in enumerate(tpt_tests):

    test = tpt_base_test | settings

    dataset_name = test["dataset"]
    test_name = test["name"]
    device = test["device"]

    BASE_PROMPT = test["base_prompt"]
    ARCH = test["arch"]
    SPLT_CTX = test["splt_ctx"]
    LR = test["lr"]
    AUGS = test["augs"]
    TTT_STEPS = test["ttt_steps"]
    ALIGN_STEPS = test["align_steps"]
    ENSEMBLE = test["ensemble"]
    TEST_STOP = test["test_stop"]
    CONFIDENCE = test["confidence"]

    # Redirect prints to the per-test log file; stdout is restored in the `finally` block.
    log_file = open(f"{results_path_tpt}/{test_name}.txt", "w")
    original_stdout = sys.stdout
    sys.stdout = log_file
    try:
        print("-" * 30)
        print(f"[TEST] Running test {idx + 1} of {len(tpt_tests)}: {test_name} \n{test}")

        print(f"[TEST] loading datasets with {AUGS} augmentation...")
        datasetRoot = DATASET_ROOT
        (
            imageNetA,
            _,
            imageNetACustomNames,
            imageNetAMap,
            imageNetV2,
            _,
            imageNetV2CustomNames,
            imageNetV2Map,
        ) = tpt_get_datasets(datasetRoot, augs=AUGS, all_classes=False)
        print("[TEST] datasets loaded.")

        if dataset_name == "A":
            print("[TEST] using ImageNet A")
            dataset = imageNetA
            classnames = imageNetACustomNames
            id_mapping = imageNetAMap
            del imageNetV2, imageNetV2CustomNames, imageNetV2Map
        elif dataset_name == "V2":
            print("[TEST] using ImageNet V2")
            dataset = imageNetV2
            classnames = imageNetV2CustomNames
            id_mapping = imageNetV2Map
            del imageNetA, imageNetACustomNames, imageNetAMap

        tpt = EasyTPT(
            device,
            base_prompt=BASE_PROMPT,
            arch=ARCH,
            splt_ctx=SPLT_CTX,
            classnames=classnames,
            ttt_steps=TTT_STEPS,
            lr=LR,
            align_steps=ALIGN_STEPS,
            ensemble=ENSEMBLE,
            confidence=CONFIDENCE,
        )

        cnt = 0
        tpt_correct = 0
        total_time = 0

        idxs = list(range(len(dataset)))

        # fixed seed so every test sees the samples in the same order
        SEED = 1
        np.random.seed(SEED)
        np.random.shuffle(idxs)

        # `sample_idx` avoids shadowing the outer loop's `idx` (the test counter)
        for sample_idx in idxs:
            data = dataset[sample_idx]
            label = data["label"]
            imgs = data["img"]
            name = data["name"]

            start = time.time()

            cnt += 1
            with torch.no_grad():
                tpt.reset()

            out_id = tpt.predict(imgs)
            tpt_predicted = classnames[out_id]

            end = time.time()

            total_time += end - start
            avg_time = total_time / cnt

            if int(id_mapping[out_id]) == label:
                emoji = ":D"
                tpt_correct += 1
            else:
                emoji = ":("

            tpt_acc = tpt_correct / (cnt)

            if cnt % VERBOSE == 0:
                print(emoji)
                print(f"TPT Accuracy: {round(tpt_acc, 3)}")
                print(f"GT: \t{name}\nTPT: \t{tpt_predicted}")
                print(
                    f"after {cnt} samples, average time {round(avg_time, 3)}s ({round(1 / avg_time, 3)}it/s)\n"
                )

            if cnt == TEST_STOP:
                print(f"[TEST] Early stopping at {cnt} samples")
                break

        del tpt

        print(f"[TEST] Final TPT Accuracy: {round(tpt_acc, 3)} over {cnt} samples")

        test["result"] = tpt_acc
        test["speed"] = round(avg_time, 3)
        json_dump = f"{results_path_tpt}/{test_name}.json"
        # use a separate handle so the log file handle is not shadowed
        with open(json_dump, "w") as json_file:
            json.dump(test, json_file)
    finally:
        sys.stdout = original_stdout
        log_file.close()

## Ensemble

Here we present the results of the ensemble of multiple models. We used the MEMO and TPT models with different backbones to see if the ensemble could improve the performance of the models.

|               Test               |                     Mode                     |               ImageNet-A               |                    Delta A                    |           ImageNet-V2           |            Delta V2             |
| :------------------------------: | :------------------------------------------: | :------------------------------------: | :-------------------------------------------: | :-----------------------------: | :-----------------------------: |
|      MEMO-RN50 + MEMO-RNXT       | Single Models<br>Ensemble<br>Simple Ensemble |     27.28 : 27.03<br>28.48<br>31.4     |        <br>1.20 : 1.45<br>4.12 : 4.37         | 68.48 : 69.88<br>69.63<br>70.96 | <br>1.15 : -0.25<br>2.48 : 1.08 |
|       MEMO-RN50 + TPT-RN50       | Single Models<br>Ensemble<br>Simple Ensemble |    27.33 : 27.99<br>30.49<br>34.97     |         <br>3.16 : 2.5<br>7.64 : 6.98         |           -<br>-<br>-           |           <br>-<br>-            |
| MEMO-RN50 + MEMO-RNXT + TPT-RN50 | Single Models<br>Ensemble<br>Simple Ensemble | 27.1 : 27.17 : 28.37<br>27.74<br>**36.23** | <br>0.64 : 0.57 : -0.63<br>9.13 : 9.06 : 7.86 |           -<br>-<br>-           |           <br>-<br>-            |

### Results with MEMO-RN50 + MEMO-RNXT

These two models were trained with a similar loss and dataset, so we expected this to be the weakest of the ensembles.
Nonetheless, on ImageNet-A the ensemble reaches a top-1 accuracy of 28.48%, higher than either of the two models fine-tuned separately. On ImageNet-V2 it loses a little, with an accuracy of 69.63%, lower than the ResNeXt model fine-tuned with MEMO alone (69.88%).

As in the previous tests, the simpler strategy of aggregating the predictions over multiple augmentations, without backpropagation, is the most effective, with an accuracy of 31.4% on ImageNet-A and 70.96% on ImageNet-V2. This time we see an improvement across the board!

The results are promising, and we can see that the ensemble of multiple models can improve the performance of the models even if they are trained with similar techniques.

### Results with MEMO-RN50 + TPT-RN50

This test was a lot more interesting, since it combines two models trained with two very different objectives.
Thanks to this, the ensemble improves on both models, probably because each model covers part of the out-of-distribution data that the other one misses.

The ensemble of the two models reaches an accuracy of 30.49% on ImageNet-A. This is a significant improvement over the single models (27.33% and 27.99%), and it shows that ensembling can be a very effective strategy. The simple ensemble strategy is still the most effective, with an accuracy of 34.97% on ImageNet-A, outperforming the two single models by 7.64 and 6.98 percentage points respectively.

### Results with MEMO-RN50 + MEMO-RNXT + TPT-RN50
Finally, we combined three models, each using a different backbone or a different training objective, to see whether we could push the performance of the ensemble even further.

The results confirmed our expectations: the simple ensemble reaches an accuracy of 36.23% on ImageNet-A, the best result of all the tests. This shows that an ensemble of multiple models can be a very effective strategy, even when the models are trained with different techniques.

Note that the ensemble strategy with backpropagation (27.74%) was not able to outperform the best single model in the ensemble (28.37%). Tuning the temperature values could probably improve this result, but we did not have time to test different values.
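
To show why the temperature values matter, here is a minimal sketch of a temperature-scaled combination of two models. It assumes that each model's logits are divided by its temperature before the softmax and that the resulting distributions are averaged; the notebook's `Ensemble` class may apply `temps` differently. The helpers `combine` and `argmax` and the logits are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature (assumed convention for `temps`)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def combine(models_logits, temps):
    """Average the temperature-scaled distributions of several models."""
    probs = [softmax(logits, t) for logits, t in zip(models_logits, temps)]
    n_classes = len(probs[0])
    return [sum(p[c] for p in probs) / len(probs) for c in range(n_classes)]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


# Two models that disagree: model A produces large logits, model B small ones.
logits_a = [8.0, 2.0, 1.0]   # votes for class 0
logits_b = [0.2, 0.9, 0.1]   # votes for class 1

# With equal temperatures the sharper model A dominates the average.
same_temp = argmax(combine([logits_a, logits_b], temps=[1.0, 1.0]))
# Softening A and sharpening B flips the ensemble decision.
rescaled = argmax(combine([logits_a, logits_b], temps=[10.0, 0.1]))
print(same_temp, rescaled)
```

Models with very different logit scales, such as a supervised classifier and CLIP, can therefore dominate or be drowned out depending on the temperatures chosen.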

In [None]:
def TPT(device="cuda", naug=64, arch="RN50", A=True, ttt_steps=1, align_steps=0, top=0.1):
    """
    Return the TPT model initialized with the given parameters

    Args:
        - device: device to use - default: cuda
        - naug: number of augmentations to use - default: 64
        - arch: backbone model to use - default: RN50
        - A: use ImageNet A or ImageNet V2 - default: True
        - ttt_steps: number of iterations for the TTT - default: 1
        - align_steps: number of iterations for the alignment of the image embeddings - default: 0
        - top: top confidence to select the augmented samples - default: 0.1
    """
    # prepare TPT
    if not torch.cuda.is_available():
        print("Using CPU this is no bueno")
    else:
        print("Using GPU, brace yourself!")

    datasetRoot = "datasets"
    imageNetA, _, imageNetACustomNames, imageNetAMap, imageNetV2, _, imageNetV2CustomNames, imageNetV2Map = tpt_get_datasets(datasetRoot, augs=naug, all_classes=False)
    
    if A:
        dataset = imageNetA
        classnames = imageNetACustomNames
        mapping = imageNetAMap
    else:
        dataset = imageNetV2
        classnames = imageNetV2CustomNames
        mapping = imageNetV2Map
    
    tpt = EasyTPT(
        base_prompt="A bad photo of a [CLS]",
        arch=arch,
        classnames=classnames,
        device=device,
        ttt_steps=ttt_steps,
        align_steps=align_steps,
        confidence=top
    )
    
    return tpt, dataset, mapping


def memo(device="cuda", prior_strength=0.94, naug=64, A=True, drop=0, ttt_steps=1, model="RN50", top=0.1):
    """
    Return the MEMO model initialized with the given parameters

    Args:
        - device: device to use - default: cuda
        - prior_strength: strength of the prior for the BN layers - default: 0.94
        - naug: number of augmentations to use - default: 64
        - A: use ImageNet A or ImageNet V2 - default: True
        - drop: dropout to use, by setting it to >0 the model will use the ensemble strategy - default: 0
        - ttt_steps: number of iterations for the TTT - default: 1
        - model: backbone model to use - default: RN50
        - top: top confidence to select the augmented samples - default: 0.1
    """
    load_model = {
        "RN50": torch_models.resnet50,
        "RNXT": torch_models.resnext50_32x4d
    }
    models_weights = {
        "RN50": torch_models.ResNet50_Weights.DEFAULT,
        "RNXT": torch_models.ResNeXt50_32X4D_Weights.DEFAULT
    }
    # prepare MEMO
    imageNet_A, imageNet_V2 = memo_get_datasets(augmentation=('cut' if drop==0 else 'identity'), augs=naug)
    dataset = imageNet_A if A else imageNet_V2

    mapping = [int(class_id) for class_id in dataset.classnames.keys()]

    # distinct names so the `model` argument and the `memo` function are not shadowed
    net = load_model[model](weights=models_weights[model])
    net.layer4.add_module('dropout', nn.Dropout(drop))

    memo_model = EasyMemo(
        net,
        device=device,
        classes_mask=mapping,
        prior_strength=prior_strength,
        niter=ttt_steps,
        ensemble=(drop>0),
        top=top
    )

    return memo_model, dataset, mapping


In [None]:
def test(models, datasets, temps, mapping, names,
         device="cuda", niter=1, top=0.1,
         simple_ensemble=False, testSingleModels=False, verbose=100):
    """
    Test the ensemble model on the datasets

    Args:
        - models: list of models
        - datasets: list of datasets
        - temps: list of temperatures for the models to rescale the logits
        - names: names of the models
        - mapping: mapping of the classes outputted by the models to the original 1000 classes
        - device: device to use
        - niter: number of iterations for the TTT
        - top: top confidence to select the augmented samples
        - simple_ensemble: use the simple ensemble, only marginalizing the distributions
        - testSingleModels: test the single models inside the ensemble
        - verbose: verbosity level
    
    Returns:
        - results: dictionary with the results of the test
    """
    correct = 0
    correct_no_back = 0
    correctSingle = [0] * len(models)
    cnt = 0

    class_names = get_classes_names()

    # shuffle the data
    indx = np.random.permutation(range(len(datasets[0])))

    model = Ensemble(models, temps=temps,
                     device=device, test_single_models=testSingleModels,
                     simple_ensemble=simple_ensemble)
    print("Ensemble model created starting TTA, samples:", len(indx))
    for i in indx:
        cnt += 1
        # fetch each sample once per dataset instead of calling `__getitem__` three times
        samples = [dataset[i] for dataset in datasets]
        data = [sample["img"] for sample in samples]

        labels = [sample["label"] for sample in samples]
        # check if the labels are the same
        assert all(x == labels[0] for x in labels), "Labels are not the same"
        label = labels[0]
        name = samples[0]["name"]

        if(cnt%verbose==0): print(f"Tested Samples: {cnt} / {len(datasets[0])} - current sample: {name}")

        models_out, pred_no_back, prediction = model(data, niter=niter, top=top)
        models_out = [int(mapping[model_out]) for model_out in models_out]
        prediction = int(mapping[prediction])

        if testSingleModels:
            # `m` indexes the models; it avoids shadowing the sample index `i`
            for m, model_out in enumerate(models_out):
                if label == model_out:
                    correctSingle[m] += 1

                if(cnt%verbose==0): 
                    print(
                    f"\t{names[m]} model accuracy: {correctSingle[m]}/{cnt} - predicted class {model_out}: {class_names[model_out]} - tested: {cnt} / {len(datasets[0])}")

        if simple_ensemble:
            pred_no_back = int(mapping[pred_no_back])
            if label == pred_no_back:
                correct_no_back += 1
            if(cnt%verbose==0): 
                print(
                f"\tSimple Ens accuracy: {correct_no_back}/{cnt} - predicted class {pred_no_back}: {class_names[pred_no_back]} - tested: {cnt} / {len(datasets[0])}")

        if label == prediction:
            correct += 1

        if(cnt%verbose==0): 
            print(
            f"\tEnsemble accuracy: {correct}/{cnt} - predicted class {prediction}: {class_names[prediction]} - tested: {cnt} / {len(datasets[0])}")
    
    results = {
        "Ensemble": correct / len(datasets[0]) * 100,
        "Simple Ens": correct_no_back / len(datasets[0]) * 100,
        "Single Models": [correctSingle[i] / len(datasets[0]) * 100 for i in range(len(models))]
    }
    return results


# expand args
def runTest(models_type, args, temps, names, naug=64, niter=1, top=0.1, device="cuda", simple_ensemble=False,
            testSingleModels=False, imageNetA=True, verbose=100):
    """
    Run the test on Ensemble model with the given arguments

    Args:
        - models_type: list of the models type to use
        - args: list of the arguments for the models
        - temps: list of temperatures for the models to rescale the logits
        - names: names of the models
        - naug: number of augmentations to use
        - niter: number of iterations for the TTT
        - top: top confidence to select the augmented samples
        - device: device to use
        - simple_ensemble: use the simple ensemble, only marginalizing the distributions
        - testSingleModels: test the single models inside the ensemble
        - imageNetA: use ImageNet A or ImageNet V2
        - verbose: verbosity level

    Returns:
        - results: dictionary with the results of the test
    """

    models = []
    datasets = []
    mapping = None
    load_model = {
        "memo": memo,
        "tpt": TPT
    }
    for model_type, model_args in zip(models_type, args):
        model, data, mapping = load_model[model_type](**model_args, A=imageNetA, top=top, naug=naug)
        models.append(model)
        datasets.append(data)

    result = test(models=models, datasets=datasets, temps=temps, mapping=mapping, names=names,
         device=device, niter=niter, top=top, simple_ensemble=simple_ensemble, testSingleModels=testSingleModels,
         verbose=verbose)

    # drop the references to the models and datasets so their memory can be reclaimed
    del models, datasets
    
    return result

In [None]:
results_path_tta = f"{RESULTS_PATH}/Ensemble"
os.makedirs(results_path_tta, exist_ok=True)

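for ENStest in ENSTests:
    # Redirect prints to the per-test log file; stdout is restored in the `finally` block.
    log_file = open(f"{results_path_tta}/{ENStest}.txt", "w")
    original_stdout = sys.stdout
    sys.stdout = log_file
    try:
        print(f"Running test: {ENStest}")
        test_params = ENSTests[ENStest]
        test_params["verbose"] = VERBOSE
        result = runTest(**test_params)

        # remove verbose from the stored parameters
        test_params.pop("verbose")
        test_params["result"] = result
        # json.dump takes the object first and an open file second
        with open(f"{results_path_tta}/{ENStest}.json", "w") as json_file:
            json.dump(test_params, json_file)

        print("\tFinal Results:")
        for key in result:
            print(f"\t\t{key}: {result[key]}")

        print("\n-------------------\n")
    finally:
        sys.stdout = original_stdout
        log_file.close()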

# Conclusions

The entire project was created with flexibility in mind by using a modular structure that allows the user to easily switch between different models and backbones.

In conclusion, for both MEMO and TPT we were able to match the baselines set by the original papers and, in some cases, improve on them. On ImageNet-A, TPT in ensemble mode with confidence selection surpassed the prompt tuning method by almost 6 points and the CLIP baseline by 10.3. MEMO saw similar improvements, with the cut ensemble outperforming the MEMO reference by 6.55 points on ImageNet-A and 1.54 on ImageNet-V2.
 
What surprised us the most was the effectiveness of the ensemble of augmentations, which outperformed the reference techniques without any need to backpropagate through the network. The best result was obtained by ensembling RN50, RNXT and CLIP (RN50) with the simple ensemble strategy, which achieved an accuracy of 36.23% on ImageNet-A, 14.33 points higher than the CLIP zero-shot baseline.

- [Laurence Bonat](https://github.com/blauer4)
- [Davide Cavicchini](https://github.com/DavidC001)
- [Lorenzo Orsingher](https://github.com/lorenzoorsingher)
- [**Github repo**](https://github.com/DavidC001/MEMO-TPT-DL2024)
