# Deep Learning 2024 - Project Assignment
This notebook contains the code for the 2024 **project assignment** of the course **Deep Learning** - MSc in Artificial Intelligence Systems - University of Trento.

The project is developed by [Alessandro Lorenzi](mailto:alessandro.lorenzi-1@studenti.unitn.it) and [Luca Cazzola](mailto:luca.cazzola-1@studenti.unitn.it).


The primary objective of this project is to implement a **Test-Time Adaptation (TTA)** solution, focusing on enhancing the predictive capabilities of pre-trained neural networks on unseen test samples.

Throughout this notebook, we implement the **Test-Time Prompt Tuning (TPT)** method, also in combination with **CoOp**, a few-shot prompt tuning strategy.

Additionally, we experiment with different **image augmentation techniques** to assess their impact on the model's performance.

To further improve the effectiveness of our approach, we propose a novel method that utilizes **image captioning to generate more context-aware prompts**. 

The goal of this notebook is to evaluate these methods and demonstrate improvements in model performance without access to test labels or additional data during inference.

## Smart Index

To make this notebook easier to navigate, we've created a smart index with links to the most important sections. Each section includes explanations and comments to guide you through the key aspects of our implementation and experiments. Click on the links below to jump directly to the relevant parts of the notebook:

- [Image Augmentations](#Image-Augmentations-⭐)
- [Prompt-Augmentation](#Prompt-Augmentation-⭐)
- [Results obtained with TPT on ImageNet-A](#Results-obtained-with-TPT-on-ImageNet-A-⭐)
- [Conclusions](#Conclusions-⭐)
- [Augmentations PlayGround](#Augmentations-PlayGround-⭐)

Feel free to explore each section to understand our approach and findings in detail!

---

## Notebook Setup

In [None]:
!pip install -q ftfy regex tqdm scikit-learn scikit-image
!pip install -q git+https://github.com/openai/CLIP.git
!wget -P /lib https://raw.githubusercontent.com/google-research/augmix/master/augmentations.py

!pip install wandb
!pip install -U diffusers
!pip install keybert[spacy]
!pip install -U accelerate

# Core modules
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split, Subset, RandomSampler

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2
from PIL import Image, ImageOps, ImageEnhance

# Models
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
from diffusers import StableDiffusionImageVariationPipeline
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from keybert import KeyBERT

# Plotting
import wandb
import matplotlib.pyplot as plt
from tqdm import tqdm

# AWS
from pathlib import Path
import boto3
from io import BytesIO

# Utility
import numpy as np
from copy import deepcopy

%matplotlib inline
%config InlineBackend.figure_format = "retina"

In [2]:
# Set globally the GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
# Set clip weights to use
print(clip.available_models())
clip_version = 'RN50'

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


## DATASET INSTANCES

In [3]:
'''
Arguments relative to datasets
'''
args_data = {
    "dataset": "imagenet-a",   # Dataset of choice
    "n_aug": 63,               # number of augmentations per image
    "retain_aug_perc": 0.1,    # percentage of augmentations to retain
}

In [4]:
IMGNET_A_CLASSES = {
    "n01498041": "stingray", "n01531178": "goldfinch", "n01534433": "junco", "n01558993": "American robin", "n01580077": "jay", "n01614925": "bald eagle","n01616318": "vulture","n01631663": "newt","n01641577": "American bullfrog","n01669191": "box turtle","n01677366": "green iguana","n01687978": "agama","n01694178": "chameleon","n01698640": "American alligator","n01735189": "garter snake","n01770081": "harvestman","n01770393": "scorpion","n01774750": "tarantula","n01784675": "centipede","n01819313": "sulphur-crested cockatoo","n01820546": "lorikeet","n01833805": "hummingbird","n01843383": "toucan","n01847000": "duck","n01855672": "goose","n01882714": "koala","n01910747": "jellyfish","n01914609": "sea anemone","n01924916": "flatworm","n01944390": "snail","n01985128": "crayfish","n01986214": "hermit crab","n02007558": "flamingo","n02009912": "great egret","n02037110": "oystercatcher","n02051845": "pelican","n02077923": "sea lion","n02085620": "Chihuahua","n02099601": "Golden Retriever","n02106550": "Rottweiler","n02106662": "German Shepherd Dog","n02110958": "pug","n02119022": "red fox","n02123394": "Persian cat","n02127052": "lynx","n02129165": "lion","n02133161": "American black bear","n02137549": "mongoose","n02165456": "ladybug","n02174001": "rhinoceros beetle","n02177972": "weevil","n02190166": "fly","n02206856": "bee","n02219486": "ant","n02226429": "grasshopper","n02231487": "stick insect","n02233338": "cockroach","n02236044": "mantis","n02259212": "leafhopper","n02268443": "dragonfly","n02279972": "monarch butterfly","n02280649": "small white","n02281787": "gossamer-winged butterfly","n02317335": "starfish","n02325366": "cottontail rabbit","n02346627": "porcupine","n02356798": "fox squirrel","n02361337": "marmot","n02410509": "bison","n02445715": "skunk","n02454379": "armadillo","n02486410": "baboon","n02492035": "white-headed capuchin","n02504458": "African bush elephant","n02655020": "pufferfish","n02669723": "academic gown","n02672831": 
"accordion","n02676566": "acoustic guitar","n02690373": "airliner","n02701002": "ambulance","n02730930": "apron","n02777292": "balance beam","n02782093": "balloon","n02787622": "banjo","n02793495": "barn","n02797295": "wheelbarrow","n02802426": "basketball","n02814860": "lighthouse","n02815834": "beaker","n02837789": "bikini","n02879718": "bow","n02883205": "bow tie","n02895154": "breastplate","n02906734": "broom","n02948072": "candle","n02951358": "canoe","n02980441": "castle","n02992211": "cello","n02999410": "chain","n03014705": "chest","n03026506": "Christmas stocking","n03124043": "cowboy boot","n03125729": "cradle","n03187595": "rotary dial telephone","n03196217": "digital clock","n03223299": "doormat","n03250847": "drumstick","n03255030": "dumbbell","n03291819": "envelope","n03325584": "feather boa","n03355925": "flagpole","n03384352": "forklift","n03388043": "fountain","n03417042": "garbage truck","n03443371": "goblet","n03444034": "go-kart","n03445924": "golf cart","n03452741": "grand piano","n03483316": "hair dryer","n03584829": "clothes iron","n03590841": "jack-o'-lantern","n03594945": "jeep","n03617480": "kimono","n03666591": "lighter","n03670208": "limousine","n03717622": "manhole cover","n03720891": "maraca","n03721384": "marimba","n03724870": "mask","n03775071": "mitten","n03788195": "mosque","n03804744": "nail","n03837869": "obelisk","n03840681": "ocarina","n03854065": "organ","n03888257": "parachute","n03891332": "parking meter","n03935335": "piggy bank","n03982430": "billiard table","n04019541": "hockey puck","n04033901": "quill","n04039381": "racket","n04067472": "reel","n04086273": "revolver","n04099969": "rocking chair","n04118538": "rugby ball","n04131690": "salt shaker","n04133789": "sandal","n04141076": "saxophone","n04146614": "school bus","n04147183": "schooner","n04179913": "sewing machine","n04208210": "shovel","n04235860": "sleeping bag","n04252077": "snowmobile","n04252225": "snowplow","n04254120": "soap dispenser","n04270147": 
"spatula","n04275548": "spider web","n04310018": "steam locomotive","n04317175": "stethoscope","n04344873": "couch","n04347754": "submarine","n04355338": "sundial","n04366367": "suspension bridge","n04376876": "syringe","n04389033": "tank","n04399382": "teddy bear","n04442312": "toaster","n04456115": "torch","n04482393": "tricycle","n04507155": "umbrella","n04509417": "unicycle","n04532670": "viaduct","n04540053": "volleyball","n04554684": "washing machine","n04562935": "water tower","n04591713": "wine bottle","n04606251": "shipwreck","n07583066": "guacamole","n07695742": "pretzel","n07697313": "cheeseburger","n07697537": "hot dog","n07714990": "broccoli","n07718472": "cucumber","n07720875": "bell pepper","n07734744": "mushroom","n07749582": "lemon","n07753592": "banana","n07760859": "custard apple","n07768694": "pomegranate","n07831146": "carbonara","n09229709": "bubble","n09246464": "cliff","n09472597": "volcano","n09835506": "baseball player","n11879895": "rapeseed","n12057211": "yellow lady's slipper","n12144580": "corn","n12267677": "acorn"
}

# Load CLIP pre-process functions
_, clip_preprocess = clip.load(clip_version, device)
base_transform = v2.Compose(clip_preprocess.transforms[:3]) # resize + center crop + RGB
preprocess = v2.Compose(clip_preprocess.transforms[3:]) # toTensor + CLIP normalization

class S3ImageFolder(Dataset):
    '''
    Dataset which handles single images
    '''
    def __init__(self, root, transform=None):
        self.s3_bucket = "deeplearning2024-datasets"
        self.s3_region = "eu-west-1"
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # Get list of objects in the bucket
        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = path.parent.name

            # Keep track of valid instances
            self.instances.append((IMGNET_A_CLASSES[label], key))

        # Sort classes in alphabetical order (as in ImageFolder)
        self.classes = sorted(set(label for label, _ in self.instances))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, key = self.instances[idx]

            # Download image from S3 into an in-memory buffer
            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # rewind the buffer before handing it to PIL

            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")

        return img, self.class_to_idx[label]


    def get_image(self, idx):
        '''
        Returns (PIL image + label) without applying the transform
        '''
        try:
            label, key = self.instances[idx]
            img_bytes = BytesIO()
            response = self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")
        return img, label


class Augmenter(Dataset):
    '''
    Dataset which handles augmentations
    '''
    def __init__(self, dataset_images, aug_component):
        self.dataset_images = dataset_images
        # object responsible for constructing augmentations
        self.aug_component = aug_component

    def __len__(self):
        # One item (a stack of n_aug + 1 views) per source image, matching __getitem__
        return len(self.dataset_images)

    def __getitem__(self, idx):
        # Gather PIL image
        label, key = self.dataset_images.instances[idx]
        img_bytes = BytesIO()
        response = self.dataset_images.s3_client.download_fileobj(Bucket=self.dataset_images.s3_bucket, Key=key, Fileobj=img_bytes)
        img = Image.open(img_bytes).convert("RGB")

        # Apply augmentations
        original_image = preprocess(base_transform(img))
        augmentations = [self.aug_component.augment(img) for _ in range(args_data['n_aug'])]
        augmentations =  [original_image] + augmentations

        # stack augmentations together as tensors
        augmentations = torch.stack(augmentations, dim=0)

        return augmentations


class Augmented_dataset(Dataset):
    '''
    Join together S3ImageFolder & Augmenter in a single dataset
    '''
    def __init__(self, images_dataset, augmenter):
        self.images_dataset = images_dataset
        self.augmenter = augmenter
        self.classes = images_dataset.classes
        self.is_augmenter_on = True

    def __getitem__(self, idx):
        image, label = self.images_dataset[idx]
        augmentations = self.augmenter[idx] if self.is_augmenter_on else torch.zeros(image.shape)
        return image, label, augmentations, idx

    def __len__(self):
        return len(self.images_dataset)

    def turnOFFaugs(self):
        self.is_augmenter_on = False

    def turnONaugs(self):
        self.is_augmenter_on = True


def get_data(dataset, perc=1, batch_size=1, test_batch_size=1):
    '''
    Returns a dataloader over the first `perc` fraction of the given dataset
    (`test_batch_size` is currently unused)
    '''
    # Sub-sample the dataset deterministically: keep the first `size` items
    size = int(perc * len(dataset))
    dataset_subset = Subset(dataset, range(size))
    # Create the DataLoader instance
    data_loader = DataLoader(dataset_subset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Alternative: random sub-sampling instead of taking the first `size` items
    #random_sampler = RandomSampler(dataset, num_samples=size)
    #data_loader = DataLoader(dataset, batch_size=batch_size, sampler=random_sampler, shuffle=False, num_workers=4)

    return data_loader

## Image Augmentations ⭐

We tested 4 methods to augment images:
1) **PreAugment** : applies only a random resized crop and a random horizontal flip to the image
2) **AugMix**    [[1]](#References) [[2]](#References) : the method used in the original TPT implementation, a technique that mixes randomly sampled augmentation chains and, in its original formulation, uses a Jensen-Shannon loss to enforce consistency
3) **AutoAugment** [[3]](#References) : augments an image with policies found through a reinforcement-learning search for maximum validation accuracy (here, the policy learned on ImageNet)
4) **DiffusionAugment** [[4]](#References) : uses a diffusion model to generate augmentations

### Results
Here we present the outcomes of applying different augmentation techniques using TPT with handcrafted prompts. The evaluation metrics include average accuracy and average loss (entropy) across the test dataset, providing insights into how each augmentation method influences the model's performance.
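The entropy loss reported below can be illustrated with a small sketch. It assumes the TPT recipe: predictions are computed for every augmented view, only the most confident (lowest-entropy) fraction of views is retained, and the loss is the entropy of their averaged prediction. The function names, the toy logits, and the rounding used to pick the number of retained views are illustrative assumptions, not the notebook's actual code.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    """Shannon entropy in nats; zero-probability terms contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def marginal_entropy(view_logits, retain_perc=0.1):
    """Entropy of the average prediction over the most confident views."""
    view_probs = [softmax(logits) for logits in view_logits]
    # Keep at least one view, ranked from lowest to highest entropy.
    n_keep = max(1, int(len(view_probs) * retain_perc))
    kept = sorted(view_probs, key=entropy)[:n_keep]
    n_classes = len(kept[0])
    avg = [sum(p[c] for p in kept) / n_keep for c in range(n_classes)]
    return entropy(avg), avg

# Toy logits (hypothetical): 4 views of one image, 3 classes.
toy_logits = [
    [4.0, 0.0, 0.0],   # confident view
    [0.1, 0.0, 0.2],   # near-uniform view
    [3.0, 0.5, 0.0],   # fairly confident view
    [0.0, 0.1, 0.0],   # near-uniform view
]
loss, avg_probs = marginal_entropy(toy_logits, retain_perc=0.5)
predicted_class = avg_probs.index(max(avg_probs))
```

With the settings in `args_data` (63 augmentations plus the original image, `retain_aug_perc` of 0.1), the truncation used in this sketch would keep 6 of the 64 views.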

##### N.B.
* For **DiffusionAugment**, testing showed that <u>generating images online during evaluation is too expensive (time-wise) for our hardware</u>.<br>The diffusion model we selected takes around 12 seconds to perform 25 diffusion steps. Moreover, a single augmentation is not enough for our purposes, and even reducing the number of generated augmentations from 64 to 10 would still be expensive (2 min per image $\times$ 7500 images in ImageNet-A = 250 hours of runtime).<br>[**DiffTPT**](https://arxiv.org/abs/2308.06038) tests the effectiveness of diffusion models combined with TPT and avoids the "online generation" issue by generating the augmentations offline and storing them for use during inference. **We consider this solution poorly aligned with the goal of TTA, since it breaks the principle of improving during inference only**. For this reason we stopped experimenting with this approach and report no results for it (only the code).
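The runtime estimate above can be reproduced with a few lines of arithmetic. The helper name is hypothetical; the figures (12 seconds per generated image, 7500 ImageNet-A images, 10 or 64 augmentations) are the ones quoted in the note.

```python
def diffusion_runtime_hours(n_images, augs_per_image, seconds_per_aug=12.0):
    """Total generation time in hours when every augmentation is produced online."""
    return n_images * augs_per_image * seconds_per_aug / 3600.0

reduced = diffusion_runtime_hours(n_images=7500, augs_per_image=10)
full = diffusion_runtime_hours(n_images=7500, augs_per_image=64)

print(f"10 augmentations per image: {reduced:.0f} hours")
print(f"64 augmentations per image: {full:.0f} hours")
```

Even the reduced setting amounts to more than ten days of continuous generation, and the full 64-view setting to roughly 1600 hours.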

<div align=center>

| <!-- --> | <!-- --> |
|-|-|
|<img src="imgs/augmentations_result_1.png" width="650" />|<img src="imgs/augmentations_result_2.png" width="650" />|

</div>

**Average Accuracy** (TPT - Handcrafted Prompts):

<div align=center>

| Augmentation Technique | Avg Accuracy (%) |
| ---------------------- | ---------------- |
| RandomCrop + Flip      | 27.51             |
| **AutoAugment**        | **30.36**         |
| AugMix                 | 28.80             |

</div>

**Average Loss** (TPT - Handcrafted Prompts):

<div align=center>

| Augmentation Technique | Avg Loss |
| ---------------------- | -------- |
| RandomCrop + Flip      | 3.02041  |
| **AutoAugment**        | **1.89376** |
| AugMix                 | 1.91948  |

</div>

From these results, we observe that <u>**AutoAugment** outperforms the other techniques</u>, achieving the highest average accuracy (30.36%) and the lowest average loss (1.89376). Since the loss is the prediction entropy, this indicates that AutoAugment yields predictions that are both more accurate and more confident.

**RandomCrop + Flip** performs the worst in this comparison and this suggests that while basic augmentations can introduce some robustness, more advanced techniques like AutoAugment are necessary for optimizing performance in this scenario. **AugMix** shows better results than RandomCrop + Flip but still falls short compared to AutoAugment.

Overall, the results highlight the importance of choosing the right augmentation strategy when working with TPT and handcrafted prompts. **AutoAugment** clearly demonstrates superior performance, suggesting that its ability to apply diverse and meaningful transformations to the input data leads to better alignment between visual and textual representations, thereby enhancing the model's accuracy and reducing its loss.


## Image Augmentation Examples

<br>

### > PreAugment - AugMix - AutoAugment
<br>

<div align=center><img src="imgs/augmentations.png" width="1500" /></div>

<br><hr><br>

### > DiffusionAugment
<br>

<div align=center><img src="imgs/diff_augment_ex.png" width="800" /></div>

### Code Process Explanation

1. **PreAugment** : <br>
Applies the basic transformation RandomResizedCrop + RandomHorizontalFlip.<br>This method is **used as a preliminary augmentation for both AugMix and AutoAugment**, before their specific augmentations are applied.

2. **AugMix** : <br>
Follows the original implementation [[2]](#References), as we found the one provided by PyTorch hard to work with.

3. **AutoAugment** : <br>
We don't apply the PyTorch implementation directly; instead, we took inspiration from AugMix in the way each augmentation is returned.<br>A factor $m$ is sampled from a $\mathrm{Beta}(1, 1)$ distribution and the resulting tensor is given by: $m \cdot \text{PreAugment}(x) + (1-m) \cdot \text{AutoAugment}(\text{PreAugment}(x))$.<br>We found that constraining the output with information from the (pre-augmented) original image is much more robust than applying the augmentation directly.

4. **DiffusionAugment** : <br>
Follows the implementation provided by HuggingFace [[4]](#References).

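The mixing step shared by the AugMix and AutoAugment components can be sketched on plain lists of numbers standing in for image tensors. This is a minimal illustration under stated assumptions: the function names and toy values are hypothetical, and the weighting convention (factor $m$ on the pre-augmented view) follows the code cell below.

```python
import random

def beta_mix(pre_view, aug_view, rng):
    """Convex combination m * pre_view + (1 - m) * aug_view, with m ~ Beta(1, 1)."""
    m = rng.betavariate(1.0, 1.0)
    mixed = [m * p + (1.0 - m) * a for p, a in zip(pre_view, aug_view)]
    return mixed, m

def augmix_style_mix(pre_view, aug_chains, rng):
    """Dirichlet-weighted sum of augmentation chains, then Beta-mixed with pre_view."""
    # Dirichlet(1, ..., 1) weights obtained by normalising Gamma(1, 1) draws.
    gammas = [rng.gammavariate(1.0, 1.0) for _ in aug_chains]
    total = sum(gammas)
    weights = [g / total for g in gammas]
    combined = [
        sum(w * chain[i] for w, chain in zip(weights, aug_chains))
        for i in range(len(pre_view))
    ]
    return beta_mix(pre_view, combined, rng)

rng = random.Random(0)
pre_view = [0.2, 0.4, 0.6]            # toy "pre-augmented image"
auto_view = [0.8, 0.0, 1.0]           # toy "AutoAugment output"
chains = [[0.8, 0.0, 1.0], [0.1, 0.9, 0.5], [0.3, 0.3, 0.3]]  # toy AugMix chains

auto_mixed, m_auto = beta_mix(pre_view, auto_view, rng)
augmix_mixed, m_augmix = augmix_style_mix(pre_view, chains, rng)
```

Because both results are convex combinations, every mixed value stays within the range spanned by its inputs, which is what keeps the augmented view anchored to the pre-augmented image.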

In [5]:
IMAGE_SIZE = 224


def get_preaugment():
    '''
    PreAugment : applies a random resized crop and a random horizontal flip to the image
    '''
    return transforms.Compose([
            transforms.RandomResizedCrop(IMAGE_SIZE),
            transforms.RandomHorizontalFlip(),
        ])

class augMethod():
    def __init__(self, selection, severity=1):
        self.selection = selection

        if selection == 'AugMix' :
            self.aug_component = my_augMix(severity=severity, n_aug=args_data['n_aug'])
        elif selection == 'AutoAugment' :
            self.aug_component = my_AutoAugment()
        elif selection == 'DiffusionAugment' :
            self.aug_component = my_diffusionAugmenter()

    def augment(self, img):
        if self.selection == 'PreAugment':
            transform = get_preaugment()
            return preprocess(transform(img))

        return self.aug_component.apply_augmentation(img)


class my_diffusionAugmenter():
    '''
    Personalized instance of diffusion-based augmentation
    uses a diffusion model to generate augmentations
    '''
    def __init__(self):
        self.sd_pipe = StableDiffusionImageVariationPipeline.from_pretrained(
            "lambdalabs/sd-image-variations-diffusers",
            revision="v2.0",
          )
        self.sd_pipe = self.sd_pipe.to(device)

        self.base_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(
                (224, 224),
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=False,
                ),
            transforms.Normalize(
              [0.48145466, 0.4578275, 0.40821073],
              [0.26862954, 0.26130258, 0.27577711]),
        ])
        args_data['n_aug'] = 4


    def apply_augmentation(self, img):
        src_img = img.copy()
        x_orig = self.base_transform(src_img).to(device).unsqueeze(0)
        out = self.sd_pipe(x_orig, guidance_scale=3, num_inference_steps=25, output_type='pil')
        out = preprocess(base_transform(out["images"][0]))
        return out


class my_AutoAugment():
    '''
    Personalized instance of AutoAugment component
    applies augmentation policies found through a reinforcement-learning search for maximum validation accuracy (policy learned on ImageNet)
    '''
    def __init__(self, n_aug=args_data['n_aug']):
        self.aug_component = v2.AutoAugment().to(device)
        self.transform = [
                self.aug_component
            ]

    def apply_augmentation(self, img):
        preaugment = get_preaugment()
        x_orig = preaugment(img)
        x_processed = preprocess(x_orig)

        if len(self.transform) == 0:
            return x_processed

        m = np.float32(np.random.beta(1.0, 1.0))
        x_aug = x_orig.copy()
        x_aug = self.transform[0](x_aug)
        mix = m * x_processed + (1 - m) * preprocess(x_aug)

        return mix


class my_augMix():
    '''
    Personalized instance of AugMix component
    '''
    def __init__(self, severity=3, n_aug=args_data['n_aug']):

        self.severity = severity
        self.n_aug = n_aug

        self.transform = [
            self.autocontrast, self.equalize, self.posterize, self.rotate, self.solarize, self.shear_x, self.shear_y,
            self.translate_x, self.translate_y
        ]

    # apply AugMix to a PIL image and return the mixed result as a single tensor
    def apply_augmentation(self, img):
        preaugment = get_preaugment()
        x_orig = preaugment(img)
        x_processed = preprocess(x_orig)
        if len(self.transform) == 0:
            return x_processed
        w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
        m = np.float32(np.random.beta(1.0, 1.0))

        mix = torch.zeros_like(x_processed)
        for i in range(3):
            x_aug = x_orig.copy()
            for _ in range(np.random.randint(1, 4)):
                x_aug = np.random.choice(self.transform)(x_aug, self.severity)
            mix += w[i] * preprocess(x_aug)
        mix = m * x_processed + (1 - m) * mix

        return mix

    # Utility functions
    def int_parameter(self, level, maxval):
        return int(level * maxval / 10)

    def float_parameter(self, level, maxval):
        return float(level) * maxval / 10.

    def sample_level(self, n):
        return np.random.uniform(low=0.1, high=n)

    # AUGMENTATIONS

    def autocontrast(self, pil_img, _):
        return ImageOps.autocontrast(pil_img)

    def equalize(self, pil_img, _):
        return ImageOps.equalize(pil_img)

    def posterize(self, pil_img, level):
        level = self.int_parameter(self.sample_level(level), 4)
        return ImageOps.posterize(pil_img, 4 - level)

    def rotate(self, pil_img, level):
        degrees = self.int_parameter(self.sample_level(level), 30)
        if np.random.uniform() > 0.5:
            degrees = -degrees
        return pil_img.rotate(degrees, resample=Image.BILINEAR)

    def solarize(self, pil_img, level):
        level = self.int_parameter(self.sample_level(level), 256)
        return ImageOps.solarize(pil_img, 256 - level)

    def shear_x(self, pil_img, level):
        level = self.float_parameter(self.sample_level(level), 0.3)
        if np.random.uniform() > 0.5:
            level = -level
        return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), Image.AFFINE, (1, level, 0, 0, 1, 0), resample=Image.BILINEAR)

    def shear_y(self, pil_img, level):
        level = self.float_parameter(self.sample_level(level), 0.3)
        if np.random.uniform() > 0.5:
            level = -level
        return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), Image.AFFINE, (1, 0, 0, level, 1, 0), resample=Image.BILINEAR)

    def translate_x(self, pil_img, level):
        level = self.int_parameter(self.sample_level(level), IMAGE_SIZE / 3)
        if np.random.random() > 0.5:
            level = -level
        return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), Image.AFFINE, (1, 0, level, 0, 1, 0), resample=Image.BILINEAR)

    def translate_y(self, pil_img, level):
        level = self.int_parameter(self.sample_level(level), IMAGE_SIZE / 3)
        if np.random.random() > 0.5:
            level = -level
        return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), Image.AFFINE, (1, 0, 0, 0, 1, level), resample=Image.BILINEAR)

    # operation that overlaps with ImageNet-C's test set
    def color(self, pil_img, level):
        level = self.float_parameter(self.sample_level(level), 1.8) + 0.1
        return ImageEnhance.Color(pil_img).enhance(level)

    # operation that overlaps with ImageNet-C's test set
    def contrast(self, pil_img, level):
        level = self.float_parameter(self.sample_level(level), 1.8) + 0.1
        return ImageEnhance.Contrast(pil_img).enhance(level)

    # operation that overlaps with ImageNet-C's test set
    def brightness(self, pil_img, level):
        level = self.float_parameter(self.sample_level(level), 1.8) + 0.1
        return ImageEnhance.Brightness(pil_img).enhance(level)

    # operation that overlaps with ImageNet-C's test set
    def sharpness(self, pil_img, level):
        level = self.float_parameter(self.sample_level(level), 1.8) + 0.1
        return ImageEnhance.Sharpness(pil_img).enhance(level)

## Prompt Augmentation ⭐

In this section, we introduce a novel approach for augmenting prompts using an **image captioning system**. 

This method aims to create more context-aware prompts compared to the standard, generic descriptions like "a photo of a [class label]." Our hypothesis is that captions specifically tailored to the content of the image will enhance the alignment between the image and the class labels, leading to improved model performance.

### Method Overview

1. **Image Captioning**: <br>
We use the VisionEncoderDecoderModel (ViT-GPT2) [[5]](#References) to generate descriptive captions from the images. This model integrates a Vision Transformer (ViT) with GPT-2, allowing it to produce detailed captions that capture the visual content of the images.

2. **Keyword Extraction**: <br>
After generating the caption, we utilize KeyBERT [[6]](#References) to extract the most relevant keywords or phrases from the caption. These keywords represent the key elements or subjects described in the caption.

3. **Personalized Prompt Composition**: <br>
We replace the most relevant keyword in the caption with each class label from the dataset to create personalized prompts. This process generates a set of prompts specific to the content of the image and the class labels.
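The prompt composition step can be sketched with plain string operations. The caption and keyword below are hard-coded, hypothetical stand-ins for what the captioner and the keyword extractor would return; the helper name and the 70-character limit as a parameter are illustrative.

```python
def compose_prompts(caption, keyword, class_names, max_chars=70):
    """Swap `keyword` in `caption` for every class name, truncating long prompts."""
    return [caption.replace(keyword, name)[:max_chars] for name in class_names]

caption = "a dog sitting on a wooden bench"   # hypothetical captioner output
keyword = "dog"                               # hypothetical most relevant keyword
class_names = ["goldfinch", "red fox", "teddy bear"]

prompts = compose_prompts(caption, keyword, class_names)
for prompt in prompts:
    print(prompt)
```

Each image thus gets its own set of prompts, one per class, that share the context of the caption and differ only in the class name.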

### High Level Schema
<center><img src="imgs/image_captioning_schema.png" width="1500"></center>

### Results and Discussion

We evaluated our proposed prompt augmentation method, based on an image captioning system, against the standard handcrafted prompt. Performance was assessed on two zero-shot CLIP backbones: CLIP-RN50 and CLIP-ViT-B/16.

Below are the results of our method compared to the baseline:

**Average Loss and Accuracy for Zero-Shot CLIP (CLIP-RN50):**

| Method                | Avg Loss      | Avg Accuracy (%) |
| --------------------- | ------------- | ---------------- |
| Our Method            | 3.0781        | 19.41            |
| Baseline              | -             | 21.83            |

**Average Loss and Accuracy for Zero-Shot CLIP (CLIP-ViT-B/16):**

| Method                | Avg Loss      | Avg Accuracy (%) |
| --------------------- | ------------- | ---------------- |
| Our Method            | 2.5711        | 42.13            |
| Baseline              | -             | 47.87            |

#### Analysis

For **CLIP-RN50**, our method achieved an average loss of 3.0781 and an average accuracy of 19.41%, against a baseline accuracy of 21.83% (2.42 percentage points lower). The personalized prompts generated through image captioning therefore did not provide the anticipated benefit over the standard method.

For **CLIP-ViT-B/16**, our method resulted in an average loss of 2.5711 and an average accuracy of 42.13%, against a baseline accuracy of 47.87% (5.74 percentage points lower). Here too, the context-aware prompts fell short of the standard approach.

#### Discussion

Despite our hypothesis that contextually specific prompts would improve model performance, the **results suggest otherwise**. The personalized **prompts generated by the image captioning system did not achieve better results than the standard approaches**. The observed decrease in accuracy indicates that this method was not effective at enhancing the alignment between image content and class labels in the tested configurations. Potential weaknesses of our method are discussed [**later**](#Drawbacks).

This outcome provides valuable insights into the challenges of using image captions for prompt augmentation. It underscores the need for further refinement of the method and suggests that additional techniques or adjustments might be necessary to achieve the desired improvements. Future exploration could focus on optimizing the image captioning and keyword extraction processes, or integrating other innovative approaches to better leverage context-aware prompts for improved performance.

In [None]:
# Image captioning chain
model_vitGpt = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer_gpt2 = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
# Load keyword extractor
model_keyBert = KeyBERT()

gen_kwargs = {"max_length": 16, "num_beams": 4}
def get_captions(images):
    '''
    Produce captions using vit-gpt2 chain
    '''
    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    output_ids = model_vitGpt.generate(pixel_values, **gen_kwargs)
    preds = tokenizer_gpt2.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds


def get_keywords(text):
    '''
    Extract single-word keywords from a sentence using KeyBERT, most relevant first
    '''
    words = model_keyBert.extract_keywords(text, keyphrase_ngram_range=(1, 1), stop_words=None)
    words = [keyword[0] for keyword in words]
    return words


def get_personilized_prompt(net, dataset, idx):
    '''
    Produce an ad-hoc caption for the image and build one prompt per class
    by replacing the caption's top keyword with each class name
    '''
    img, _ = dataset.images_dataset.get_image(idx)
    caption = get_captions([img])[0]  # feed a list of PIL images to the ViT-GPT2 pipeline and take the caption
    subject = get_keywords(caption)[0]  # most relevant keyword of the caption
    prompts = [caption.replace(subject, label) for label in dataset.classes]  # substitute the keyword with each class label
    prompts = [prompt[:70] for prompt in prompts]  # cap prompts at 70 characters
    tokenized_prompts = clip.tokenize(prompts).to(device)
    texts_z = net.text_encoder(tokenized_prompts)
    return texts_z


## Code Process Explanation

1. **Caption Generation** : <br>
The first step involves generating a descriptive caption for each image using the ViT-GPT2 model. This caption provides a detailed textual representation of the visual content.

2. **Keyword Extraction** : <br>
From the generated caption, we extract the most significant keywords using KeyBERT. KeyBERT is a keyword-extraction library that uses BERT embeddings to rank the words of a sentence by how similar they are to the sentence as a whole.<br>The most relevant words are what we call **KeyWords**. Specifically, we are interested in **the single most relevant word of the caption**.

3. **Prompt Personalization** : <br>
The extracted keyword is then substituted with class labels from the dataset to create contextually relevant prompts. These prompts are designed to be more specific and better aligned with the image content, potentially improving the model’s ability to classify the image accurately.
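
The substitution step can be illustrated without loading any model. The sketch below is a simplified stand-in, not the notebook's pipeline: the caption and keyword are hard-coded where ViT-GPT2 and KeyBERT would normally supply them, and `personalize` is a hypothetical helper name.

```python
def personalize(caption, keyword, labels, max_chars=70):
    """Swap the caption's keyword for each class label and clip long prompts."""
    prompts = [caption.replace(keyword, label) for label in labels]
    return [prompt[:max_chars] for prompt in prompts]


caption = "a dog sitting on a couch"  # stand-in for the ViT-GPT2 caption
keyword = "dog"                       # stand-in for KeyBERT's top keyword
labels = ["goldfish", "school bus"]   # stand-in for dataset.classes

for prompt in personalize(caption, keyword, labels):
    print(prompt)
```

Every class receives the same context ("sitting on a couch"), so the only difference between prompts is the class name itself.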


#### Drawbacks

This implementation essentially <u>delegates the handcrafted prompt design to an image captioning model</u>. This design can be harmful because:
* Performance depends on a secondary model (the ViT-GPT2 chain in this case) that is detached from the rest of the pipeline.
* If the input image is noisy, the produced caption will probably be noisy as well, making inference even harder than with a generic prompt like "a photo of a { label }".<br>Since we evaluate on a noisy dataset such as ImageNet-A, this might be the most relevant aspect.
* As stated by OpenAI itself, CLIP is very sensitive to wording, so small changes in the generated caption can shift its predictions.
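
Two failure modes of the plain string substitution are easy to guard against. The sketch below is one possible mitigation that we did not evaluate; `robust_prompts` is a hypothetical helper. It falls back to the generic template when no usable keyword is found, and it replaces the keyword only as a whole word, whereas `str.replace` would turn "a cat in a category" into "a dog in a dogegory".

```python
import re


def robust_prompts(caption, keywords, labels):
    """Build one prompt per label, falling back to a generic template when needed."""
    generic = [f"a photo of a {label}" for label in labels]
    if not keywords:
        return generic
    # Match the top keyword only as a whole word
    pattern = re.compile(rf"\b{re.escape(keywords[0])}\b")
    if not pattern.search(caption):
        return generic
    # A function replacement keeps backslashes in labels from being read as escapes
    return [
        pattern.sub(lambda _match, label=label: label, caption, count=1)
        for label in labels
    ]


print(robust_prompts("a cat in a category", ["cat"], ["dog"]))  # ['a dog in a category']
print(robust_prompts("", [], ["dog"]))                          # ['a photo of a dog']
```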

---

### Loss Functions

In [7]:
class EntropyLoss(nn.Module):
    '''
    Unsupervised entropy loss function    
    '''
    def __init__(self):
        super(EntropyLoss, self).__init__()

    def forward(self, x):
        return -torch.sum(x * torch.log(x + 1e-9), dim=-1)

def get_loss_function() :
    return EntropyLoss()

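def get_optimizer(model, lr=0.005, wd=0.0005, momentum=0.9):
    # NOTE: only `lr` is used; `wd` and `momentum` are accepted but not passed to AdamW,
    # which therefore runs with its default weight decay and betas
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return optimizer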
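
For reference, `EntropyLoss` computes the Shannon entropy $H(p) = -\sum_i p_i \log p_i$ of a probability vector, with a small constant added inside the logarithm for numerical stability. The dependency-free sketch below (plain Python instead of PyTorch, for illustration only) shows the two extremes: a uniform prediction has the maximum entropy $\log K$ for $K$ classes, while a one-hot prediction has entropy close to zero.

```python
import math


def entropy(probs, eps=1e-9):
    """Shannon entropy of a probability vector, mirroring EntropyLoss.forward."""
    return -sum(p * math.log(p + eps) for p in probs)


uniform = [0.25, 0.25, 0.25, 0.25]  # least confident prediction over 4 classes
one_hot = [1.0, 0.0, 0.0, 0.0]      # most confident prediction

print(entropy(uniform), math.log(4))  # both close to 1.386
print(entropy(one_hot))               # close to 0
```

Minimizing this quantity therefore pushes the model toward confident predictions without requiring any label.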

### CoOp prompt Learner + CLIP adapted text encoder

In [8]:
args_CoOp = {
    "n_ctx": 4,
    "ctx_init": "a photo of a",
    "class_token_position": "end",
    "csc": False,
}

_tokenizer = _Tokenizer()

class TextEncoder(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        # prompts: learned prompt embeddings, shape [batch_size, seq_len, transformer.width]
        # tokenized_prompts: token ids of the original prompts, used only to locate the EOT token
        x = prompts + self.positional_embedding
        x = x.permute(1, 0, 2)  # [batch_size, seq_len, transformer.width] -> [seq_len, batch_size, transformer.width]
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # [seq_len, batch_size, transformer.width] -> [batch_size, seq_len, transformer.width]
        x = self.ln_final(x)

        # Take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection

        return x

class PromptLearner(nn.Module):
    def __init__(self, clip_model, classes, n_ctx, ctx_init, class_token_position, csc=False):
        super().__init__()
        n_cls = len(classes)
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution

        # Use given words to initialize context vectors
        if ctx_init:
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors)

        classes = [name.replace("_", " ") for name in classes]
        name_lens = [len(_tokenizer.encode(name)) for name in classes]
        prompts = [prompt_prefix + " " + name + "." for name in classes]

        # print("+++")
        # print("Prompts:")
        # for p in prompts:
        #     print(p)
        # print("+++")

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def forward(self):
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx

        # If CoOp, expand the ctx for all classes
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(f"Unknown class_token_position: {self.class_token_position!r}")

        return prompts
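
The three `class_token_position` modes differ only in where the class-name tokens sit relative to the learnable context. The sketch below mimics the concatenations in `PromptLearner.forward` with lists of strings instead of embedding tensors; the token strings and the `assemble` helper are illustrative assumptions, not part of CLIP or CoOp.

```python
def assemble(prefix, ctx, class_tokens, rest, position="end"):
    """Arrange prompt pieces the way PromptLearner.forward concatenates them."""
    if position == "end":
        return prefix + ctx + class_tokens + rest
    if position == "middle":
        half = len(ctx) // 2
        return prefix + ctx[:half] + class_tokens + ctx[half:] + rest
    if position == "front":
        return prefix + class_tokens + ctx + rest
    raise ValueError(f"Unknown position: {position!r}")


prefix = ["<SOS>"]               # start-of-sequence token
ctx = ["a", "photo", "of", "a"]  # learnable context (here its initialization words)
class_tokens = ["goldfish"]      # tokens of the class name
rest = [".", "<EOS>"]            # remainder of the suffix buffer

for position in ("end", "middle", "front"):
    print(position, assemble(prefix, ctx, class_tokens, rest, position))
```

With the `"end"` setting used in `args_CoOp`, the result reads "a photo of a goldfish.", which matches the handcrafted prompt at initialization.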

# TPT MODEL
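
The scoring rule used in the model's `forward` method is CLIP's usual one: L2-normalize the image and text embeddings, take their dot products (cosine similarities), multiply by the logit scale, and apply a softmax over classes. Below is a minimal sketch in plain Python; the 2-dimensional embeddings are made up, and the scale of 100 is an assumed stand-in for `logit_scale.exp()`.

```python
import math


def l2_normalize(vector):
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def softmax(logits):
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def class_probabilities(image_z, texts_z, logit_scale=100.0):
    """Softmax over scaled cosine similarities between one image and all prompts."""
    image_z = l2_normalize(image_z)
    logits = [
        logit_scale * sum(a * b for a, b in zip(image_z, l2_normalize(text_z)))
        for text_z in texts_z
    ]
    return softmax(logits)


# Toy embeddings: the image points in the same direction as the first prompt
probs = class_probabilities([1.0, 0.0], [[2.0, 0.0], [0.0, 3.0]])
print(probs)
```

Because of the normalization, only the direction of each embedding matters, not its length.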

In [9]:
class TPT(nn.Module) :
    '''
    Model Class of TPT
    '''
    def __init__(self, clip_model, classes, prompts=None, only_CLIP=True, enable_CoOp=False):
        super().__init__()
        self.enable_CoOp = enable_CoOp # if set True, the model uses TPT + CoOp
        self.only_CLIP = only_CLIP     # if set True, the model uses CLIP only (TPT part is not performed)

        if not enable_CoOp or only_CLIP :
            # Generate text embeddings
            tokenized_prompts = clip.tokenize(prompts).to(device)
            self.texts_z = clip_model.encode_text(tokenized_prompts)
            self.text_encoder = clip_model.encode_text
        else :
            # Setup CoOp prompt learner
            self.prompt_learner = PromptLearner(clip_model, classes, args_CoOp['n_ctx'], args_CoOp['ctx_init'], args_CoOp['class_token_position'], args_CoOp['csc'])
            # Text encoder that consumes the learned prompt embeddings
            self.text_encoder = TextEncoder(clip_model)

        # References to the image encoder and the logit scale
        self.image_encoder = clip_model.visual
        self.logit_scale = clip_model.logit_scale


    def forward (self, inputs):
        if self.enable_CoOp :
            # Embed prompts
            prompts = self.prompt_learner()
            tokenized_prompts = self.prompt_learner.tokenized_prompts
            self.texts_z = self.text_encoder(prompts, tokenized_prompts)

        # Embed augmentations
        images_z = self.image_encoder(inputs)

        # L2 norm to tensors
        images_z = images_z / torch.linalg.vector_norm(images_z, keepdim=True, dim=-1)
        texts_z = self.texts_z / torch.linalg.vector_norm(self.texts_z, keepdim=True, dim=-1)

        # Compute logits
        logit_scale = self.logit_scale.exp()
        output = (logit_scale * images_z @ texts_z.T).softmax(dim=-1)

        return output

    def testTimeAdapt(self, clip_output, loss_function):
        '''
        Apply augmentation selection based on entropies and average the histograms
        '''
        # Entropy for each augmentation
        aug_losses = loss_function(clip_output)
        # keep the 'retain_aug_perc' fraction of augmentations with the lowest entropy (most confident)
        output = torch.index_select(clip_output, 0, torch.argsort(aug_losses, descending=False))[:int(np.ceil(args_data['n_aug'] * args_data['retain_aug_perc']))]
        # average
        output = torch.mean(output, dim=0)
        
        return output
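
The confidence selection in `testTimeAdapt` can be traced on a toy example. The sketch below uses plain Python lists instead of tensors and made-up probability vectors; `select_and_average` is a hypothetical name, and `keep_fraction` plays the role of `retain_aug_perc`.

```python
import math


def entropy(probs, eps=1e-9):
    return -sum(p * math.log(p + eps) for p in probs)


def select_and_average(view_probs, keep_fraction):
    """Keep the lowest-entropy views and average their class probabilities."""
    n_keep = math.ceil(len(view_probs) * keep_fraction)
    kept = sorted(view_probs, key=entropy)[:n_keep]
    n_classes = len(kept[0])
    return [sum(p[c] for p in kept) / n_keep for c in range(n_classes)]


views = [
    [0.90, 0.05, 0.05],  # confident view
    [0.34, 0.33, 0.33],  # nearly uniform view, discarded
    [0.80, 0.10, 0.10],  # confident view
    [0.40, 0.30, 0.30],  # uncertain view, discarded
]
print(select_and_average(views, keep_fraction=0.5))  # approximately [0.85, 0.075, 0.075]
```

Averaging only the confident views keeps uninformative augmentations, such as crops that miss the object, from diluting the prediction.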

## TRAIN / TEST functions

We've separated eval loops to have cleaner code

In [10]:
def eval_loop(net, loss_function, test_loader, dataset, enable_personilized_prompt=False):
    '''
    Perform Zero-Shot clip / TPT
    '''
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    desc = "Zero Shot CLIP" if net.only_CLIP else "TPT - handcrafted prompts" 
    # Set the network to evaluation mode
    net.eval()
    with torch.no_grad():
        # Iterate over the test set
        for batch_idx, (inputs, targets, augmentations, idx) in tqdm(enumerate(test_loader), desc=desc, position=0, leave=True, total=len(test_loader)):
            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)

            if not net.only_CLIP :
                augmentations = torch.squeeze(augmentations, axis=0)
                augmentations = augmentations.to(device)

            # Forward pass
            net.texts_z = get_personilized_prompt(net, dataset, idx) if enable_personilized_prompt else net.texts_z
            output = net(inputs) if net.only_CLIP else net(augmentations)
            # Apply augmentation selection based on entropies and average the histograms
            output = output if net.only_CLIP else net.testTimeAdapt(output, loss_function)
            # Loss computation
            loss = loss_function(output)

            # Compute matching
            _, predicted = output.max(dim=-1)

            cumulative_accuracy += predicted.eq(targets).sum().item()
            samples += inputs.shape[0]
            cumulative_loss += loss.item()

    return cumulative_loss/samples, cumulative_accuracy/samples*100


def fewShot_eval_loop(net, loss_function, optimizer, train_loader):
    '''
    Perform few shot inference with TPT + CoOp
    '''
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    optim_state = deepcopy(optimizer.state_dict())
    prompt_learner_state = deepcopy(net.prompt_learner.state_dict())

    #net.train()
    for batch_idx, (inputs, targets, augmentations, _) in tqdm(enumerate(train_loader), desc="TPT + CoOp", position=0, leave=True, total=len(train_loader)):
        net.train()
        # Load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        augmentations = torch.squeeze(augmentations, axis=0)  # [1, n_aug, 3, 224, 224] -> [n_aug, 3, 224, 224]
        augmentations = augmentations.to(device)

        # Forward pass
        output = net(augmentations)

        # ================================= TPT computation ================================= #

        # Apply augmentation selection based on entropies and average the histograms
        output = net.testTimeAdapt(output, loss_function)
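        # Compute loss and update weights
        # (clear stale gradients first: restoring the state dicts below does not reset .grad)
        optimizer.zero_grad()
        loss = loss_function(output)
        loss.backward()
        optimizer.step()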

        # ================================= INFERENCE ================================= #

        # Forward pass again to compute accuracy
        net.eval()
        with torch.no_grad():
            # Forward again + TPT computation
            output = net(augmentations)
            output = net.testTimeAdapt(output, loss_function)

            loss = loss_function(output)
            _, predicted = output.max(dim=-1)
            # Calculate scores
            cumulative_accuracy += predicted.eq(targets).sum().item()
            samples += inputs.shape[0]
            cumulative_loss += loss.item()

        # reset prompt learner & optimizer to its original state
        net.prompt_learner.load_state_dict(prompt_learner_state)
        optimizer.load_state_dict(optim_state)

    return cumulative_loss/samples, cumulative_accuracy/samples*100
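
The loop above is episodic: each test sample adapts the prompt starting from the same initial state, and the state is restored before the next sample arrives. The toy sketch below isolates that snapshot / update / restore pattern. `ToyPrompt` and the hand-written gradients are illustrative assumptions that stand in for the prompt learner and for `loss.backward()`.

```python
from copy import deepcopy


class ToyPrompt:
    """Minimal stand-in for a module exposing state_dict / load_state_dict."""

    def __init__(self):
        self.ctx = [0.0, 0.0]

    def state_dict(self):
        return {"ctx": list(self.ctx)}

    def load_state_dict(self, state):
        self.ctx = list(state["ctx"])


def adapt_episodically(prompt, gradients, lr=0.1):
    """Take one gradient step per sample, then restore the initial prompt."""
    initial_state = deepcopy(prompt.state_dict())
    adapted = []
    for grad in gradients:
        # one optimization step on the current sample
        prompt.ctx = [w - lr * g for w, g in zip(prompt.ctx, grad)]
        adapted.append(list(prompt.ctx))  # the prediction would use this state
        # reset so that the next sample starts from scratch
        prompt.load_state_dict(initial_state)
    return adapted


prompt = ToyPrompt()
print(adapt_episodically(prompt, [[1.0, 0.0], [0.0, 2.0]]))
print(prompt.ctx)  # back to the initial context
```

Without the restore step, the adaptation to one sample would leak into the prediction for the next one.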

## RUNS

In [11]:
params = {
    "lr": 0.005,
    "wd": 0.0005,
    "momentum": 0.9,
    "n_epochs": 1,
}

# Enable the Weights & Biases (wandb) logger to track results
logger_on = False

if logger_on :
    # start a new wandb run to track this script
    wandb.init(
        # set the wandb project where this run will be logged
        project="TPT-CoOp",
    
        # track hyperparameters and run metadata
        config={
            "epochs": params["n_epochs"],
            "learning_rate": params["lr"],
            "momentum": params["momentum"],
            "weight decay": params["wd"],
            "number of augmentations": args_data["n_aug"],
            "dataset": args_data["dataset"],
        }
    )

In [12]:
params_augs = {
    'img': 'AutoAugment', # PreAugment - AugMix  - AutoAugment - DiffusionAugment
    'prompt' : False,   # (True/False) to enable / disable prompt augmentation
}

In [13]:
def main(flg=(False, False, False)):
    # Flags to enable/disable models when running the function
    flag_zeroShot, flag_TPT_handCrafted, flag_TPT_CoOp = flg

    # Load CLIP pre-trained weights
    clip_model, clip_preprocess = clip.load(clip_version)
    clip_model = clip_model.float()
    clip_model = clip_model.to(device)

    # Setup images dataset
    ds_images = S3ImageFolder(args_data["dataset"], transform=clip_preprocess)
    # Setup augmented dataset
    aug_component = augMethod(params_augs['img'])
    ds_augmentations = Augmenter(ds_images, aug_component)
    # Merge the 2 as a single dataset
    dataset = Augmented_dataset(ds_images, ds_augmentations)

    # Turn off gradients in CLIP's image & text encoders
    for _, param in clip_model.named_parameters():
        param.requires_grad_(False)

    # Handcrafted prompts to use
    handcrafted_prompts = [f"a photo of a {label}" for label in dataset.classes]

    if flag_zeroShot :
        model_TPT_clipOnly = TPT(clip_model, ds_images.classes, prompts=handcrafted_prompts, only_CLIP=True, enable_CoOp=False).to(device)
    
    if flag_TPT_handCrafted:
        model_TPT_handCrafted = TPT(clip_model, ds_images.classes, prompts=handcrafted_prompts, only_CLIP=False, enable_CoOp=False).to(device)

    if flag_TPT_CoOp:
        model_TPT_CoOp = TPT(clip_model, ds_images.classes, only_CLIP=False, enable_CoOp=True).to(device)
        for name, param in model_TPT_CoOp.named_parameters():
            if "prompt_learner" not in name:
                param.requires_grad_(False)

    # ====================================== Run Settings ====================================== 

    loss_function = get_loss_function()
    # Define dataloader
    perc = 1  # percentage of the dataset to use
    data_loader = get_data(dataset, perc=perc, batch_size=1)

    for ep in range(params["n_epochs"]):
        # Runs Zero-Shot evaluation
        if flag_zeroShot :
            dataset.turnOFFaugs()
            avg_loss, avg_acc = eval_loop(model_TPT_clipOnly, loss_function, data_loader, dataset, enable_personilized_prompt=params_augs['prompt'])
            print(f"   Avg loss : {avg_loss}")
            print(f"   Avg accuracy : {avg_acc}\n")
            if logger_on :
                wandb.log({"Avg loss (zero-shot CLIP)": avg_loss, "Avg accuracy (zero-shot CLIP)": avg_acc})

        if flag_TPT_handCrafted :
            dataset.turnONaugs()
            avg_loss, avg_acc = eval_loop(model_TPT_handCrafted, loss_function, data_loader, dataset, enable_personilized_prompt=params_augs['prompt'])
            print(f"   Avg loss : {avg_loss}")
            print(f"   Avg accuracy : {avg_acc}\n")
            if logger_on :
                wandb.log({"Avg loss (TPT - handcrafted prompts)": avg_loss, "Avg accuracy (TPT - handcrafted prompts)": avg_acc})

        if flag_TPT_CoOp :
            dataset.turnONaugs()
            optimizer = get_optimizer(model_TPT_CoOp, lr=params["lr"])
            avg_loss, avg_acc = fewShot_eval_loop(model_TPT_CoOp, loss_function, optimizer, data_loader)
            print(f"   Avg loss : {avg_loss}")
            print(f"   Avg accuracy : {avg_acc}\n")
            if logger_on :
                wandb.log({"Avg loss (TPT + CoOp)": avg_loss, "Avg accuracy (TPT + CoOp)": avg_acc})

    # end run
    if logger_on :
        wandb.finish()

### Run procedure

In [None]:
"""
runs the evaluation procedure with specified models or methods.

Flags:
- `flag_zeroShot`: If True, the zero-shot model will be used.
- `flag_TPT_handCrafted`: If True, Test-Time Prompt Tuning with handcrafted prompts will be used.
- `flag_TPT_CoOp`: If True, Test-Time Prompt Tuning with CoOp will be used.
"""
# Free cached GPU memory before starting a run
torch.cuda.empty_cache()

# Flags to enable/disable models when running the function
# { flag_zeroShot, flag_TPT_handCrafted, flag_TPT_CoOp }
flags = (True, False, False)
main(flags)

## Results obtained with TPT on ImageNet-A ⭐
By implementing TPT on ImageNet-A we verified the results of the original paper.

Here are our results using AugMix:

| Method                    | Avg Accuracy (%) | Avg Loss (entropy) |
| ------------------------- | ---------------- | ------------------ |
| CLIP-RN50 (zero-shot)     | 21.88            | 2.32929            |
| TPT (handcrafted prompts) | 28.8             | 1.91948            |
| TPT + CoOp                | 29.41333         | 1.89968            |
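
The gains over the zero-shot baseline follow directly from the table. The short sketch below only redoes that arithmetic on the reported accuracies; `gain_over_baseline` is a hypothetical helper name.

```python
results = {
    "CLIP-RN50 (zero-shot)": 21.88,
    "TPT (handcrafted prompts)": 28.8,
    "TPT + CoOp": 29.41333,
}
baseline = results["CLIP-RN50 (zero-shot)"]


def gain_over_baseline(accuracy, baseline):
    """Return the absolute gain (accuracy points) and the relative gain (%)."""
    absolute = accuracy - baseline
    relative = 100.0 * absolute / baseline
    return absolute, relative


for method, accuracy in results.items():
    absolute, relative = gain_over_baseline(accuracy, baseline)
    print(f"{method:27s} +{absolute:.2f} points ({relative:.1f}% relative)")
```

TPT with handcrafted prompts gains 6.92 accuracy points over zero-shot CLIP, and adding CoOp brings the gain to about 7.53 points.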


## Conclusions ⭐

The primary aim of this project was to implement a Test-Time Adaptation (TTA) solution by focusing on improving the predictive capabilities of pre-trained neural networks on unseen test samples. Our approach involved implementing Test-Time Prompt Tuning (TPT) with a few-shot prompt tuning strategy called CoOp and experimenting with various image augmentation techniques.

**Key Findings:**

1. **Evaluation of Augmentation Techniques** : <br>
We assessed several image augmentation techniques using TPT with handcrafted prompts. The results demonstrated that **AutoAugment** significantly outperforms other methods, indicating that it provides the most effective transformations, enhancing the model’s performance and generalization ability. AugMix performed better than RandomCrop + Flip but still fell short compared to AutoAugment.

2. **Prompt Augmentation Using Image Captioning** : <br>
We introduced a novel method for generating context-aware prompts using image captioning. Although our hypothesis was that contextually specific captions would improve model alignment and performance, the results did not meet expectations. For both **CLIP-RN50** and **CLIP-ViT-B/16**, the personalized prompts generated by the image captioning system led to a decrease in accuracy compared to standard methods.

3. **Results on ImageNet-A** : <br>
Our implementation of TPT on ImageNet-A validated the results from the original paper. 

**Conclusions** :

Overall, this project highlights the importance of selecting appropriate augmentation techniques and prompt tuning strategies to optimize model performance. While AutoAugment proved to be the most effective augmentation strategy, the image captioning-based prompt augmentation did not deliver the expected improvements. This underscores the need for further exploration and refinement of prompt augmentation methods. Future work could focus on optimizing image captioning techniques, experimenting with other innovative methods, or combining multiple strategies to better align text and visual representations for improved model performance.



---

## References
- [1] [AugMix](https://pytorch.org/vision/main/generated/torchvision.transforms.AugMix.html)
- [2] [AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty](https://arxiv.org/abs/1912.02781)
- [3] [AutoAugment](https://pytorch.org/vision/main/generated/torchvision.transforms.AutoAugment.html)
- [4] [DiffusionAugment](https://huggingface.co/lambdalabs/sd-image-variations-diffusers)
- [5] [ViT-GPT2](https://huggingface.co/nlpconnect/vit-gpt2-image-captioning)
- [6] [KeyBERT](https://github.com/MaartenGr/KeyBERT)

---
---
---
---
---

<br><br>

# Augmentations PlayGround ⭐
In this section you can experiment with the augmentation methods we have implemented.

### Prompt Augmentation

In [None]:
# Choose in {0 - 7499}
input_idx = 7

# Compute
dataset = S3ImageFolder(args_data["dataset"], transform=None)
img, label = dataset.get_image(input_idx)

og_caption = get_captions([img])[0]
print(f"caption : {og_caption}")
print(f"=>   label: {label}")
subject = get_keywords(og_caption)[0]
print(f"=>   KeyWord: {subject}")
caption = og_caption.replace(subject, label)
print(f"\noutput prompt : {caption}\n")

# Display
plt.imshow(img)
plt.axis('off')
plt.tight_layout()
plt.show()

### Image Augmentation : PreAugment - AugMix - AutoAugment

In [None]:
# Choose in {0 - 7499}
input_idx = 99
# Number of augmentations per type
num_augmentations = 7
# regulates aggressiveness of augmentations (positive integer)
severity = 3

# Compute
ds_images = S3ImageFolder(args_data["dataset"], transform=None)
img, label = ds_images.get_image(input_idx)

# One augmenter per method, all sharing the same severity
aug_1 = augMethod('PreAugment', severity=severity)
aug_2 = augMethod('AugMix', severity=severity)
aug_3 = augMethod('AutoAugment', severity=severity)

invNorm = transforms.Compose([
    transforms.Normalize(mean = [ 0., 0., 0. ], std = [ 1/0.26862954, 1/0.26130258, 1/0.27577711 ]),
    transforms.Normalize(mean = [ -0.48145466, -0.4578275, -0.40821073 ], std = [ 1., 1., 1. ])
])
tensor_to_pil_img = transforms.ToPILImage()

augmented_images_1 = [tensor_to_pil_img(invNorm(aug_1.augment(img))) for _ in range(num_augmentations)]
augmented_images_2 = [tensor_to_pil_img(invNorm(aug_2.augment(img))) for _ in range(num_augmentations)]
augmented_images_3 = [tensor_to_pil_img(invNorm(aug_3.augment(img))) for _ in range(num_augmentations)]

# Display
images = augmented_images_1 + augmented_images_2 + augmented_images_3
augmentation_titles = ["PreAugment", "AugMix", "AutoAugment"]
fig, axes = plt.subplots(4, num_augmentations, figsize=(20, 15))

for ax in axes[0]:
    ax.axis('off')

mid_col = num_augmentations // 2
axes[0, mid_col].imshow(img.resize((224, 224)))  
axes[0, mid_col].set_title('Original', fontsize=16, pad=20)

for row in range(3):  
    for col in range(num_augmentations):
        idx = row * num_augmentations + col
        axes[row + 1, col].imshow(images[idx])
        axes[row + 1, col].axis('off')
    axes[row + 1, mid_col].set_title(augmentation_titles[row], fontsize=16, pad=20)

plt.tight_layout()
#plt.savefig('imgs/augmentations.png')
plt.show()
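
The two-step `invNorm` transform undoes the normalization of the augmented tensors so that they can be displayed. `Normalize` computes `(x - mean) / std` per channel, so the first step (mean 0, std `1/s`) multiplies by the original std `s`, and the second step (mean `-m`, std 1) adds the original mean `m` back. The sketch below checks this round trip on a single RGB pixel in plain Python, reusing the constants from the cell above; the helper names are illustrative.

```python
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def normalize(pixel, mean, std):
    """Per-channel (x - mean) / std, the operation transforms.Normalize applies."""
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


def inverse_normalize(pixel):
    """Mirror of invNorm: first undo the std scaling, then undo the mean shift."""
    rescaled = normalize(pixel, (0.0, 0.0, 0.0), tuple(1.0 / s for s in CLIP_STD))
    return normalize(rescaled, tuple(-m for m in CLIP_MEAN), (1.0, 1.0, 1.0))


original = (0.2, 0.5, 0.9)  # an RGB pixel with values in [0, 1]
restored = inverse_normalize(normalize(original, CLIP_MEAN, CLIP_STD))
print(restored)  # matches `original` up to floating-point error
```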

### Image Augmentation : DiffusionAugment

In [None]:
# Load this only once
augmenter = augMethod('DiffusionAugment')

In [None]:
# Choose in {0 - 7499}
input_idx = 7001

# Compute
ds_images = S3ImageFolder(args_data["dataset"], transform=None)
img, label = ds_images.get_image(input_idx)

invNorm = transforms.Compose([
    transforms.Normalize(mean = [ 0., 0., 0. ], std = [ 1/0.26862954, 1/0.26130258, 1/0.27577711 ]),
    transforms.Normalize(mean = [ -0.48145466, -0.4578275, -0.40821073 ], std = [ 1., 1., 1. ])
])
tensor_to_pil_img = transforms.ToPILImage()

aug = tensor_to_pil_img(invNorm(augmenter.augment(img)))

# Display
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(img.resize((224, 224)))
ax[0].set_title("Original Image", fontsize=16, pad=20)
ax[0].axis("off")
ax[1].imshow(aug)
ax[1].set_title("Generated Image", fontsize=16, pad=20)
ax[1].axis("off")

#plt.savefig('imgs/diff_augment_ex.png')
plt.show()

<br>

---
---