# Few-shot Adaptation: A Review of Different Methods

**Damiano Orlandi**


**Tommaso Rondani**


**Raffaele Sinani**


## Introduction

The aim of this work is to explore different methods to perform few-shot adaptation of a pre-trained CLIP model. Few-shot adaptation allows a model to improve its performance on a number of base classes, of which some examples are provided, while retaining the general performance of the pre-trained model on other novel, unseen classes.

We propose three methods: confidence-based thresholding, an anomaly detection-based approach, and an architecture based on Topological Data Analysis (TDA).

The outline of this paper is as follows. First, we briefly introduce the methodologies used and the design choices behind them. Then, we present our code implementation, showing and explaining the steps needed to reproduce our results. Lastly, we present and comment on the results, comparing the advantages and disadvantages of each technique.

## Methodologies

### Confidence-based Thresholding

The first and simplest method we employed was confidence-based thresholding. We trained a linear classifier on the 51 base classes, using the frozen, pre-trained ResNet-50 CLIP backbone as a feature extractor and keeping the CLIP text encoder fixed. A threshold value was selected on the validation set. At inference time, each input was classified according to the following criterion: if the classifier's maximum softmax probability exceeded the threshold, the classifier's prediction was used. Otherwise, we fell back on a similarity-based approach: we computed the cosine similarity between the image features and the text features, mimicking the behavior of the original CLIP model.

If an appropriate number of training epochs and a well-tuned threshold are selected, the worst-case performance of the model should approximate that of the original CLIP. A sufficiently high threshold ensures that the classifier is only used for examples where it is highly confident.
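The routing rule described above can be sketched in a few lines of NumPy. This is only an illustration under stated assumptions: the function name `route_predictions`, the toy logits and similarities, and the threshold of 0.5 are hypothetical and do not come from our experiments.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def route_predictions(classifier_logits, clip_similarities, threshold):
    """Use the classifier when it is confident, otherwise fall back to CLIP scores."""
    confidence = softmax(classifier_logits).max(axis=-1)
    use_classifier = confidence > threshold
    scores = np.where(use_classifier[:, None], classifier_logits, clip_similarities)
    return scores.argmax(axis=-1), use_classifier

# Toy batch of two samples and three classes (made-up numbers)
logits = np.array([[4.0, 0.0, 0.0],    # confident classifier output
                   [0.2, 0.1, 0.0]])   # nearly uniform classifier output
sims = np.array([[0.1, 0.3, 0.2],
                 [0.1, 0.2, 0.3]])
pred, used = route_predictions(logits, sims, threshold=0.5)
print(pred, used)
```

The first sample is decided by the classifier, while the second falls back to the similarity scores, which is the behavior that keeps the worst case close to zero-shot CLIP.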

### Anomaly Detection-based methods

The second approach we experimented with involved anomaly detection techniques.  
We explored two different methodologies:

- **A simple Gaussian cluster-based method:**  
  After computing and storing the centroid of the extracted features for each class,  
  we calculated the distance from each new input to every class centroid.  
  A prediction was made based on whether the distance fell below a threshold,  
  which was determined using the standard deviation of the training examples within each class.

- **A hierarchical generative model:**  
  A more complex, hierarchical generative model was applied to the input images  
  to distinguish between base and novel classes.  
  This approach should enable selective use of finetuned methods for base classes  
  and CLIP for novel classes.  
  The model is described in [Sheynin et al., (2021)](https://arxiv.org/abs/2104.14535) and was adapted for training on the Flowers102 dataset.
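
As a rough illustration of the Gaussian cluster-based rule in the first bullet, the following NumPy sketch stores one centroid and one spread value per class and accepts an input as "base-like" when it lies close enough to any centroid. The helper names (`fit_class_stats`, `is_base_like`), the multiplier `k`, and the toy 2-D features are assumptions made for the example only.

```python
import numpy as np

def fit_class_stats(features, labels):
    """Per-class centroid and standard deviation of distances to that centroid."""
    stats = {}
    for c in np.unique(labels):
        x = features[labels == c]
        centroid = x.mean(axis=0)
        dists = np.linalg.norm(x - centroid, axis=1)
        stats[int(c)] = (centroid, dists.std(ddof=1))
    return stats

def is_base_like(x, stats, k):
    """True if x is within k * spread of at least one class centroid."""
    return any(np.linalg.norm(x - centroid) < k * spread
               for centroid, spread in stats.values())

# Toy 2-D features for two classes (made-up numbers)
features = np.array([[0.0, 0.0], [1.0, 0.0], [4.0, 0.0],
                     [10.0, 10.0], [11.0, 10.0], [14.0, 10.0]])
labels = np.array([0, 0, 0, 1, 1, 1])
stats = fit_class_stats(features, labels)

print(is_base_like(np.array([1.5, 0.0]), stats, k=1.0))    # close to class 0
print(is_base_like(np.array([50.0, 50.0]), stats, k=1.0))  # far from both classes
```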


### TopoVPT with thresholding

The last approach attempted to bridge deep learning and Topological Data Analysis (TDA). By using persistence images ([Adams et al., 2015](https://arxiv.org/abs/1507.06217)) as prompts for Visual Prompt Tuning (VPT) ([Jia et al., 2022](https://arxiv.org/abs/2203.12119)), we expected to improve accuracy on images from the base classes. Without delving into a heavily mathematical explanation (which is found [here](https://arxiv.org/abs/2312.05840)), persistence images may be thought of as topological filter maps: they track the creation and dissolution of, in this case, 0-dimensional and 1-dimensional cavities and transform the resulting data into an image. Due to the novelty of this technique, only black-and-white (single-channel) images can be processed, since multiple channels would require multiparameter homology. For single-channel images, the natural topological structure is cubical homology, which treats each pixel as a vertex of a discrete lattice. This model was adopted in the work (not yet published) by Casacuberta and Ferrà Marcus on cardiovascular magnetic resonance imaging.

Two versions of this model have been proposed: one in which only the prompt-injection parameters are trained, and one in which the parameters of the third and fourth layers are trained as well.

The main limitations are that methods arising from TDA are computationally intensive, do not yet have a GPU-based implementation, and require a large number of images to capture the necessary information. Moreover, more complex topological tools, such as multiparameter homology, would be better suited to capturing the RGB colouring of the images.
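
To give some intuition for the "creation and dissolution" of topological features, here is a toy sketch in plain Python. It is not the cubical-complex computation used in our pipeline: it only handles 0-dimensional features (connected components) of a 1-D signal under the sublevel-set filtration, and the function name `sublevel_persistence_0d` is hypothetical. Each local minimum creates a component, and when two components meet, the younger one dies (the so-called elder rule).

```python
def sublevel_persistence_0d(values):
    """(birth, death) pairs of connected components of a 1-D signal."""
    n = len(values)
    parent = [None] * n   # None marks a position that has not entered the filtration yet
    birth = {}            # root index -> value at which its component was born

    def find(i):
        # Union-find lookup with path halving
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    pairs = []
    # Add positions in order of increasing value
    for i in sorted(range(n), key=lambda k: values[k]):
        parent[i] = i
        birth[i] = values[i]
        for j in (i - 1, i + 1):
            if 0 <= j < n and parent[j] is not None:
                ri, rj = find(i), find(j)
                if ri == rj:
                    continue
                # Elder rule: the component with the larger birth value dies
                young, old = (ri, rj) if birth[ri] > birth[rj] else (rj, ri)
                if birth[young] < values[i]:   # skip zero-persistence pairs
                    pairs.append((birth[young], values[i]))
                parent[young] = old
    # The oldest component never dies
    pairs.append((birth[find(0)], float("inf")))
    return pairs

diagram = sublevel_persistence_0d([0, 2, 1, 3, 0.5])
print(diagram)
```

In this toy signal the minima at heights 1 and 0.5 each create a component that later merges into the one born at height 0, which is the only one that survives.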

## Implementation

### Requirements

In [None]:
%pip install openai_clip   #This is required on AWS as clip is not naturally installed
%pip install gudhi

import torch
import torchvision
import clip
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from PIL import Image
import skimage
import torch.nn.init as init
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import itertools
import gudhi
import gudhi.representations
import time

Note: you may need to restart the kernel to use updated packages.



### Dataset Loading
The Flowers102 dataset can be downloaded directly via torchvision.

In [None]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (torch.Compose)
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

### Base and Novel categories
To split the dataset into base and novel categories, we collect the set of all class labels and count them.
We then assign the first half of the class indices to the base categories and the remaining half to the novel ones.

In [None]:
def base_novel_categories(dataset):
    # set returns the unique set of all dataset classes
    all_classes = set(dataset._labels)
    # and let's count them
    num_classes = len(all_classes)

    # here list(range(num_classes)) returns a list from 0 to num_classes - 1
    # then we slice the list in half and generate base and novel category lists
    base_classes = list(range(num_classes))[:num_classes//2]
    novel_classes = list(range(num_classes))[num_classes//2:]
    return base_classes, novel_classes, all_classes
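
A quick way to sanity-check the split logic without downloading the dataset is to run the same slicing on a synthetic label list. The helper name `split_classes_in_half` and the synthetic labels are assumptions for this sketch; like the function above, it assumes labels are the integers `0 .. num_classes - 1`.

```python
def split_classes_in_half(labels):
    """First half of the class ids are base classes, the rest are novel."""
    num_classes = len(set(labels))
    class_ids = list(range(num_classes))
    return class_ids[: num_classes // 2], class_ids[num_classes // 2 :]

# 102 classes with 10 samples each, mimicking the size of the Flowers102 training split
base, novel = split_classes_in_half(list(range(102)) * 10)
print(len(base), len(novel), base[-1], novel[0])

# With an odd number of classes, the novel half receives the extra class
odd_base, odd_novel = split_classes_in_half([0, 1, 2, 3, 4])
print(odd_base, odd_novel)
```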

### Inspect Classes
We can now visualize which are the base and novel classes.

In [None]:
_, _, tmp_test = get_data()
base_classes, novel_classes, all_classes = base_novel_categories(tmp_test)
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
print("Base Class Names:", [(i, CLASS_NAMES[i]) for i in base_classes])
print("Novel Class Names:", [(i, CLASS_NAMES[i]) for i in novel_classes])

Base Class Names: [(0, 'pink primrose'), (1, 'hard-leaved pocket orchid'), (2, 'canterbury bells'), (3, 'sweet pea'), (4, 'english marigold'), (5, 'tiger lily'), (6, 'moon orchid'), (7, 'bird of paradise'), (8, 'monkshood'), (9, 'globe thistle'), (10, 'snapdragon'), (11, "colt's foot"), (12, 'king protea'), (13, 'spear thistle'), (14, 'yellow iris'), (15, 'globe-flower'), (16, 'purple coneflower'), (17, 'peruvian lily'), (18, 'balloon flower'), (19, 'giant white arum lily'), (20, 'fire lily'), (21, 'pincushion flower'), (22, 'fritillary'), (23, 'red ginger'), (24, 'grape hyacinth'), (25, 'corn poppy'), (26, 'prince of wales feathers'), (27, 'stemless gentian'), (28, 'artichoke'), (29, 'sweet william'), (30, 'carnation'), (31, 'garden phlox'), (32, 'love in the mist'), (33, 'mexican aster'), (34, 'alpine sea holly'), (35, 'ruby-lipped cattleya'), (36, 'cape flower'), (37, 'great masterwort'), (38, 'siam tulip'), (39, 'lenten rose'), (40, 'barbeton daisy'), (41, 'daffodil'), (42, 'sword 

### Split Dataset
The next step is to actually split the dataset into the base and novel categories we extract from `base_novel_categories`.
To split the data we need the dataset and the list of base classes. If the sample label is not part of the base categories, then it must be part of the novel ones.

In [None]:
def split_data(dataset, base_classes):
    # these two lists will store the sample indexes
    base_categories_samples = []
    novel_categories_samples = []

    # we create a set of base classes to compute the test below in O(1)
    # this is optional and can be removed
    base_set = set(base_classes)

    # here we iterate over sample labels and also get the correspondent sample index
    for sample_id, label in enumerate(dataset._labels):
        if label in base_set:
            base_categories_samples.append(sample_id)
        else:
            novel_categories_samples.append(sample_id)

    # here we create the dataset subsets
    # the torch Subset is just a wrapper around the dataset
    # it simply stores the subset indexes and the original dataset (your_subset.dataset)
    # when asking for sample i in the subset, torch will look for its original position in the dataset and retrieve it
    # https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset
    base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
    novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
    return base_dataset, novel_dataset
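
The index bookkeeping behind `split_data` can be illustrated without PyTorch. The sketch below partitions sample indices by label membership in the same way; the name `partition_indices` and the toy labels are hypothetical, and in the notebook the resulting index lists are what gets handed to `torch.utils.data.Subset`.

```python
def partition_indices(labels, base_classes):
    """Return (base_indices, novel_indices) for a list of sample labels."""
    base_set = set(base_classes)  # O(1) membership test
    base_idx = [i for i, y in enumerate(labels) if y in base_set]
    novel_idx = [i for i, y in enumerate(labels) if y not in base_set]
    return base_idx, novel_idx

toy_labels = [0, 3, 1, 2, 0, 3]
base_idx, novel_idx = partition_indices(toy_labels, base_classes=[0, 1])
print(base_idx, novel_idx)
```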

### Extract k shots
Flowers102 already provides 10 training and 10 validation images per class, so we do not need to extract the shots ourselves.


### Load CLIP

Using the `clip` module we can import the default image preprocessing for CLIP.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# available models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
_, preprocess = clip.load("RN50", device=device)

### Load and Prepare Data
Here we load the three dataset splits and apply CLIP's predefined preprocessing. Then, we compute the base and novel categories. Finally, we split the three datasets into base and novel subsets.
The novel categories are excluded from the training set but retained in the validation and test sets. We kept novel examples in the validation set because we did not want to tune hyperparameters on the test set, which would artificially inflate the results.

In [None]:
# get the three datasets
train_set, val_set, test_set = get_data(transform=preprocess)

# split classes into base and novel
base_classes, novel_classes, all_classes = base_novel_categories(train_set)

# split the three datasets
train_base, _ = split_data(train_set, base_classes)
val_base, val_novel = split_data(val_set, base_classes)
test_base, test_novel = split_data(test_set, base_classes)

### Creating a function to convert to grayscale

In [None]:
def rgb_to_grayscale(image: torch.Tensor) -> torch.Tensor:
    if image.ndim != 3 or image.shape[0] != 3:
        raise ValueError("Expected image tensor of shape (3, H, W)")
    r, g, b = image[0:1], image[1:2], image[2:3]
    return 0.299 * r + 0.587 * g + 0.114 * b
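
For readers who want to check the conversion outside PyTorch, here is an equivalent NumPy sketch. The coefficients are the ITU-R BT.601 luma weights, which sum to 1, so a uniform white image keeps its intensity. The function name `rgb_to_grayscale_np` is an assumption for this example.

```python
import numpy as np

def rgb_to_grayscale_np(image):
    """Weighted sum of the R, G, B channels of a (3, H, W) array, returned as (1, H, W)."""
    if image.ndim != 3 or image.shape[0] != 3:
        raise ValueError("Expected image array of shape (3, H, W)")
    weights = np.array([0.299, 0.587, 0.114]).reshape(3, 1, 1)
    return (weights * image).sum(axis=0, keepdims=True)

white = np.ones((3, 2, 2))
green = np.zeros((3, 2, 2))
green[1] = 1.0  # only the green channel is lit

gray_white = rgb_to_grayscale_np(white)
gray_green = rgb_to_grayscale_np(green)
print(gray_white.shape, gray_white[0, 0, 0], gray_green[0, 0, 0])
```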

### Defining a function to create the persistence images

The function utilizes the `gudhi` library to generate the persistence images from the input images. They are then injected as prompts in the TopoVPT model.

In [None]:
def topo_hammer(img: torch.Tensor) -> torch.Tensor:
    """Return a **50×50** persistence image (H₁) as float32 CPU tensor."""
    # Transform the input image to grayscale
    gray = rgb_to_grayscale(img).squeeze(0).cpu().numpy()
    # Create cubical complex
    cubical = gudhi.CubicalComplex(top_dimensional_cells=gray)
    # Compute persistence diagram
    d = cubical.persistence()
    # Keep only the necessary data
    d = np.array([interval for (dim,interval) in d])
    # Remove np.inf
    d = d[~np.isinf(d).any(axis=1)]
    # Instantiate persistence image
    pim = gudhi.representations.PersistenceImage(
        resolution=[50, 50], bandwidth=0.1, weight=lambda pt: pt[1] ** 2
    )
    # Compute persistence image
    vec = pim.fit_transform([d])[0].astype(np.float32)
    # Return a torch.tensor
    return torch.tensor(vec).view(50, 50)  # 50 prompts × 50 dims
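
To make the last step less of a black box, the sketch below builds a persistence image by hand in NumPy: each (birth, death) pair is mapped to (birth, persistence), weighted by the squared persistence as in the `weight` argument above, and smoothed with a Gaussian on a regular grid. This is a simplified illustration, not a reimplementation of `gudhi.representations.PersistenceImage`; the grid range and the lack of normalization are assumptions, so the values will not match the library's output.

```python
import numpy as np

def persistence_image(diagram, resolution=(50, 50), bandwidth=0.1):
    """Gaussian-smoothed, persistence-weighted image of a persistence diagram."""
    d = np.asarray(diagram, dtype=float)
    # Move to (birth, persistence) coordinates
    pts = np.column_stack([d[:, 0], d[:, 1] - d[:, 0]])
    weights = pts[:, 1] ** 2  # long-lived features count more
    lo, hi = pts.min(axis=0), pts.max(axis=0)
    xs = np.linspace(lo[0], hi[0], resolution[0])
    ys = np.linspace(lo[1], hi[1], resolution[1])
    gx, gy = np.meshgrid(xs, ys)
    img = np.zeros_like(gx)
    for (b, p), w in zip(pts, weights):
        img += w * np.exp(-((gx - b) ** 2 + (gy - p) ** 2) / (2 * bandwidth ** 2))
    return img

# Two finite intervals (made-up numbers): one long-lived, one short-lived
toy_image = persistence_image([(0.0, 1.0), (0.2, 0.5)])
print(toy_image.shape, float(toy_image.max()))
```

The long-lived interval dominates the image, which is the intended effect of weighting by squared persistence.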

### Creating our CLIP models

We now create the CLIP model used for the clustering and thresholding methodologies. The mode is selected at model initialization through the `crit` parameter.
We start from the baseline ResNet-50 CLIP backbone and add a linear classifier to implement thresholding. To implement clustering, we also initialize a few extra tensors that are filled during training to store the cluster centroids and distances.

The following code can be used to initialize both Threshold and Clustering models.

In [None]:
class ThresholdCLIP(nn.Module):
    def __init__(self, text_features = 0, exnumber: int = 10, num_classes: int = 102, crit = "thresholding"):
        super().__init__()
        #Initialize the CLIP visual encoder
        self.model, _ = clip.load("RN50")
        self.model = self.model.float()

        self.num_classes = num_classes
        #The original CLIP text and visual encoders are maintained
        self.text_encoder = self.model.encode_text
        self.encoder = self.model.visual

        #The linear classifier is implemented. We utilize a size of 102, even though the base classes are 51 so that the tensor shape is still 102, matching the testing scenario
        self.classifier = nn.Linear(1024, num_classes)

        #This conserves the type of model we are going to utilize
        self.criterion = crit

        #This initializes a default threshold, used for both clustering and thresholding
        self.threshold = 0.3

        #Initialize means and distribution tensors (zero-filled, so unused entries hold defined values).
        if self.criterion == "clustering":
            self.distributions = torch.zeros(num_classes, exnumber, 1024).to(device)
            self.means = torch.zeros(num_classes, 1024).to(device)
            self.stddevs = torch.zeros(num_classes).to(device)
            self.average_distances = torch.zeros(num_classes).to(device)

    def freeze_encoders(self):
        '''
        This function freezes both text and image encoders so that they are not updated during training
        '''
        #freeze the image encoder
        for param in self.encoder.parameters():
            param.requires_grad = False

        #freeze text encoder as well
        for param in self.model.transformer.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor, text_features=None) -> torch.Tensor:
        x = self.encoder(x)

        if self.training:
            # Save encoded feature for access outside
            self.x_distr = x.detach().clone()  # for storing later
            x = self.classifier(x)

            return x
        else:
            if self.criterion == "thresholding":
                tmp = self.classifier(x)
                results = torch.zeros(x.shape[0], self.num_classes, device=x.device)
                for i in range(x.shape[0]):
                    if F.softmax(tmp[i], dim=0).max() > self.threshold:  #This compares the confidence of the prediction with the threshold
                        results[i] = tmp[i]

                    else: #Regular CLIP forward pass
                        x[i] /= x[i].norm(dim=-1, keepdim=True)

                        # Compute cosine similarity with text prompts
                        sims = x[i] @ text_features.T
                        results[i] = sims

            elif self.criterion == "clustering":
                tmp = self.classifier(x)
                results = torch.zeros(x.shape[0], self.num_classes, device=x.device)
                for i in range(x.shape[0]):
                    if self.detect_class_distance(x[i]): #We see if it belongs to any cluster
                        results[i] = tmp[i]

                    else:  #Regular CLIP forward pass
                        x[i] /= x[i].norm(dim=-1, keepdim=True)

                        # Compute cosine similarity with text prompts
                        sims = x[i] @ text_features.T  # [num_classes]
                        results[i] = sims
            return results

    def detect_class_distance(self, x: torch.Tensor):
        '''
        Computes the distances between the features tensor and each cluster centroid and compares it to the threshold
        '''
        for i in range(self.num_classes//2):
            dist = torch.nn.functional.pairwise_distance(x.unsqueeze(0), self.means[i,:].unsqueeze(0)).item()
            if dist < self.stddevs[i]*self.threshold:
                return True
        return False

    def encode_text(self, text: torch.Tensor) -> torch.Tensor:
        '''
        Encodes the text and normalizes
        '''
        text_features = self.text_encoder(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features


The following code can be used to initialize our TopoVPT model.

In [None]:
class ResNet50Prompted(nn.Module):
    PROMPTS = 50  # number of prompts (rows)
    PDIM    = 50  # dimension of each prompt (columns)

    def __init__(self):
        super().__init__()
        model, _ = clip.load("RN50")
        model = model.float()
        self.base = model.visual
        self.conv1 = self.base.conv1
        self.bn1 = self.base.bn1
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2 = self.base.conv2
        self.bn2 = self.base.bn2
        self.conv3 = self.base.conv3
        self.bn3 = self.base.bn3
        ##
        self.layer1 = self.base.layer1
        self.layer2 = self.base.layer2
        self.layer3 = self.base.layer3
        self.layer4 = self.base.layer4

        self.avgpool = self.base.avgpool

        self.attnpool = self.base.attnpool

        # lift prompt-dim (50) → feature channels
        self.lift2 = nn.Conv2d(self.PDIM, 512, 1, bias=False)
        self.lift3 = nn.Conv2d(self.PDIM, 1024, 1, bias=False)
        # adapters to restore original channel size after concat
        self.adapter2 = nn.Conv2d(512 + 512, 512, 1, bias=False)
        self.adapter3 = nn.Conv2d(1024 + 1024, 1024, 1, bias=False)

        # Initialize parameters for lift2 and adapter2
        self._initialize_weights()

    def _initialize_weights(self):
        # Initialize weights for lift2
        init.kaiming_normal_(self.lift2.weight, mode='fan_out', nonlinearity='relu', )  # He initialization
        self.lift2.weight.data *= 0.00001
        if self.lift2.bias is not None:
            init.constant_(self.lift2.bias, 0)  # Initialize bias to zero if it exists

        # Initialize weights for adapter2 (He initialization scaled by 0, i.e. all zeros)
        conv_layer = self.adapter2
        init.kaiming_normal_(conv_layer.weight, mode='fan_out', nonlinearity='relu')
        conv_layer.weight.data *= 0
        if conv_layer.bias is not None:
            init.constant_(conv_layer.bias, 0)  # Initialize bias to zero if it exists

        # Initialize weights for lift3 (He initialization scaled by 0, i.e. all zeros)
        init.kaiming_normal_(self.lift3.weight, mode='fan_out', nonlinearity='relu')
        self.lift3.weight.data *= 0
        if self.lift3.bias is not None:
            init.constant_(self.lift3.bias, 0)  # Initialize bias to zero if it exists

        # Initialize weights for adapter3 (He initialization scaled down to near zero)
        conv_layer = self.adapter3
        init.kaiming_normal_(conv_layer.weight, mode='fan_out', nonlinearity='relu')
        conv_layer.weight.data *= 0.00001
        if conv_layer.bias is not None:
            init.constant_(conv_layer.bias, 0)  # Initialize bias to zero if it exists

    def _prompt_map(self, img: torch.Tensor, H: int, W: int, lift: nn.Conv2d, C_out: int):
        """Compute lifted prompt map (1×C_out×H×W) from a single image."""
        prompts = topo_hammer(img).to(img.device)  # 50×50

        P, Dp = prompts.shape  # 50,50
        # Broadcast each prompt over spatial dims and sum across P
        p = prompts.view(P, Dp, 1, 1).expand(-1, -1, H, W).sum(dim=0, keepdim=True)  # 1×Dp×H×W

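        # Standardize each channel over its spatial dims.
        # Note: p is spatially constant at this point (each prompt was broadcast over H×W),
        # so (p - mean) is zero up to floating-point rounding.
        mean = p.mean(dim=(2, 3), keepdim=True)
        std = p.std(dim=(2, 3), keepdim=True) + 1e-6  # add epsilon to prevent divide-by-zero
        p = (p - mean) / std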

        return lift(p)

    def forward(self, x: torch.Tensor):
        B = x.size(0)

        # First convolution
        conv1_out = self.conv1(x)

        # Batch normalization
        bn1_out = self.bn1(conv1_out)

        # ReLU
        relu_out = self.relu(bn1_out)

        # Second convolution
        conv2_out = self.conv2(relu_out)

        # Batch normalization
        bn2_out = self.bn2(conv2_out)

        # ReLU
        relu_out = self.relu(bn2_out)

        # Third convolution
        conv3_out = self.conv3(relu_out)  # use relu_out from above

        # Batch normalization
        bn3_out = self.bn3(conv3_out)

        # ReLU
        relu_out = self.relu(bn3_out)

        #Average pooling
        h = self.avgpool(relu_out)

        # Layer1
        h = self.layer1(h)

        # Layer2 with prompts
        h = self.layer2(h)

        _, _, H2, W2 = h.shape

        # Prompt mapping
        p2 = torch.cat([self._prompt_map(x[i], H2, W2, self.lift2, 512) for i in range(B)], dim=0)
        p2 = F.normalize(p2, dim=1)

        # Concatenate and adapt
        h = self.adapter2(torch.cat([p2, h], dim=1)) + h #We add residual connections

        # Layer3 with prompts
        h = self.layer3(h)

        _, _, H3, W3 = h.shape
        p3 = torch.cat([self._prompt_map(x[i], H3, W3, self.lift3, 1024) for i in range(B)], dim=0)
        p3 = F.normalize(p3, dim=1)

        h = self.adapter3(torch.cat([p3, h], dim=1)) + h  #We add residual connections

        # Layer 4
        h = self.layer4(h)

        #Attention pooling
        h = self.attnpool(h)

        return h


class CLIPWithTopoPrompts(nn.Module):
    def __init__(self, num_classes: int = 102, retrain = False):
        super().__init__()
        self.orig, _ = clip.load("RN50")
        self.orig = self.orig.float()
        self.num_classes = num_classes
        self.visual = ResNet50Prompted()
        self.retrain = retrain

        self.text_encoder = self.orig.encode_text
        self.threshold = 0.014

    def freeze_encoders(self):
        for p in self.visual.parameters():
            p.requires_grad = False
        for p in itertools.chain(self.visual.lift2.parameters(),
                         self.visual.lift3.parameters(),
                         self.visual.adapter2.parameters(),
                         self.visual.adapter3.parameters(),
                                ):
            p.requires_grad = True
        #Utilize the retrain option when initializing to train the parameters of layer 3 and 4
        if self.retrain:
            for p in itertools.chain(
                self.visual.layer3.parameters(),
                self.visual.layer4.parameters(),
            ):
                p.requires_grad = True
        #freeze text encoder as well
        for param in self.orig.transformer.parameters():
            param.requires_grad = False

    def encode_text(self, text: torch.Tensor) -> torch.Tensor:
        text_f = self.text_encoder(text)
        norm = text_f.norm(dim=-1, keepdim=True)
        epsilon = 1e-8  # Small value to prevent division by zero
        text_features = text_f / (norm + epsilon)
        return text_features

    def forward(self, x: torch.Tensor, text_features=None) -> torch.Tensor:
        img = self.visual(x)
        img_norm = img.norm(dim=-1, keepdim=True)
        image_features_normed = img / img_norm

        if self.training:
            sims = image_features_normed @ text_features.T
            return sims

        else:
            tmp = (image_features_normed @ text_features.T)
            results = torch.zeros(x.shape[0], self.num_classes, device=x.device)
            for i in range(x.shape[0]):
                if F.softmax(tmp[i], dim=0).max() > self.threshold:
                    results[i] = tmp[i]

                else:
                    # forward image through CLIP image encoder
                    image = x[i].unsqueeze(0)
                    img[i] = self.orig.encode_image(image)
                    # and normalize
                    img[i] /= img[i].norm(dim=-1, keepdim=True)

                    # Compute cosine similarity with text prompts
                    sims = img[i] @ text_features.T
                    results[i] = sims
            return results

### Freeze Batch Normalization stats

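We define a helper that puts every batch-normalization layer in evaluation mode, so that its running statistics are no longer updated, and disables gradients for its parameters.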

In [None]:
def freeze_batchnorm_stats(model: nn.Module):
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
            module.eval()
            module.requires_grad_(False)

### Define Training and Evaluation routines

We now define our training and evaluation routines. The evaluation routine runs in `eval` mode. The text features are encoded once at the beginning of each routine, as they are shared by all images. The loss we selected is cross-entropy, the standard loss for classification tasks. When training the TopoVPT model, the freeze parameter should be set to `False`.
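
As a reminder of what the loss measures, the following NumPy sketch computes cross-entropy from raw logits with mean reduction over the batch (the default reduction of `torch.nn.CrossEntropyLoss`). The function name `cross_entropy` and the toy logits are assumptions for the example.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean negative log-likelihood of the target classes under softmax(logits)."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

# Uniform logits over 4 classes: the loss equals log(4)
loss_uniform = cross_entropy(np.zeros((2, 4)), np.array([0, 3]))
# A confident, correct prediction: the loss is close to zero
loss_confident = cross_entropy(np.array([[10.0, 0.0, 0.0, 0.0]]), np.array([0]))
print(loss_uniform, loss_confident)
```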

In [None]:
def training_step_thresh(net, data_loader, optimizer, cost_function, device="cuda"):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0
    class_counts = torch.zeros(net.num_classes, dtype=torch.long).to(device)

    # Set the network to training mode
    net.train()
    #Freeze both encoders and batchnorm stats
    net.freeze_encoders()
    freeze_batchnorm_stats(net)


    # Iterate over the training set
    pbar = tqdm(data_loader, desc="Training", position=0, leave=True, total=len(data_loader))
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        # Load data into GPU
        inputs = inputs.float().to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = net(inputs)
        if net.criterion == "clustering":
            for i in range(inputs.size(0)):
                label = targets[i]
                count = class_counts[label]

                if count < net.distributions.shape[1]:
                    net.distributions[label, count] = net.x_distr[i]
                    class_counts[label] += 1

        # Loss computation

        loss = cost_function(outputs, targets)

        # Backward pass
        loss.backward()

        # Parameters update
        optimizer.step()

        # Gradients reset
        optimizer.zero_grad()

        # Fetch prediction and loss value
        samples += inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(dim=1)

        # Compute training accuracy
        cumulative_accuracy += predicted.eq(targets).sum().item()

        pbar.set_postfix(train_loss=loss.item(), train_acc=cumulative_accuracy / samples * 100)
        pbar.update(1)

    if net.criterion == "clustering":
        # Compute the per-class means and stddevs of distances
        for i in range(net.num_classes):
            if class_counts[i] > 0:
                # Compute class mean
                net.means[i] = net.distributions[i].mean(dim=0)

                # Compute standard deviation of distances from class mean
                diffs = net.distributions[i] - net.means[i]
                dists = torch.norm(diffs, dim=1)
                net.stddevs[i] = dists.std()
                net.average_distances[i] = dists.mean()

    return cumulative_loss / samples, cumulative_accuracy / samples * 100


def training_step_VPT(net, data_loader, optimizer, cost_function, categories, device="cuda"):
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0
    class_counts = torch.zeros(net.num_classes, dtype=torch.long).to(device)

    # Set the network to training mode
    net.train()
    #Freeze both encoders
    net.freeze_encoders()
    #Batch norm stats shouldn't be frozen when training the VPT


    #Encoding the text prompts is required for the TopoVPT training
    with torch.no_grad():
        text_inputs = clip.tokenize([f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)
        text_features = net.encode_text(text_inputs).to(device, dtype=torch.float32)
        norm = text_features.norm(dim=-1, keepdim=True)
        text_features = text_features / norm

    # Iterate over the training set
    pbar = tqdm(data_loader, desc="Training", position=0, leave=True, total=len(data_loader))
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        # Load data into GPU
        inputs = inputs.float().to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = net(inputs, text_features)


        # Loss computation

        loss = cost_function(outputs, targets)

        # Backward pass
        loss.backward()

        # Parameters update
        optimizer.step()

        # Gradients reset
        optimizer.zero_grad()

        # Fetch prediction and loss value
        samples += inputs.shape[0]
        cumulative_loss += loss.item()
        _, predicted = outputs.max(dim=1)

        # Compute training accuracy
        cumulative_accuracy += predicted.eq(targets).sum().item()

        pbar.set_postfix(train_loss=loss.item(), train_acc=cumulative_accuracy / samples * 100)
        pbar.update(1)

    return cumulative_loss / samples, cumulative_accuracy / samples * 100


def eval_step(net, data_loader, cost_function, num_classes, categories, device="cuda"):
    net.eval()
    samples = 0.0
    cumulative_loss = 0.0
    cumulative_accuracy = 0.0

    with torch.no_grad():
        text_inputs = clip.tokenize([f"a photo of a {CLASS_NAMES[c]}, a type of flower." for c in categories]).to(device)
        text_features = net.encode_text(text_inputs).to(device, dtype=torch.float32)
        norm = text_features.norm(dim=-1, keepdim=True)
        text_features = text_features / norm

    # Map the original category ids to contiguous indices (0 .. len(categories) - 1)
    contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

    with torch.no_grad():
        # Iterate over the evaluation set
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            targets = torch.tensor([contig_cat2idx[t.item()] for t in targets], dtype=torch.long)

            # Load data into GPU
            inputs = inputs.to(device)
            targets = targets.to(device)
            # Forward pass
            outputs = net(inputs, text_features)

            # Loss computation
            loss = cost_function(outputs, targets)

            # Fetch prediction and loss value
            samples += inputs.shape[0]
            cumulative_loss += loss.item() # Note: the .item() is needed to extract scalars from tensors
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()

        del text_inputs, inputs, targets #To avoid GPU memory overusage

        return cumulative_loss / samples, cumulative_accuracy / samples * 100

### Model initialization

To initialize the thresholding model, use the option crit = "thresholding"; for the clustering model, use crit = "clustering".


For our TopoVPT model, retrain can be set to True or False. If True, the parameters of the third and fourth layers are retrained together with the prompt-injection parameters. If False, only the prompt-injection parameters are trained.

In [None]:
net = ThresholdCLIP(exnumber=10, num_classes=102, crit = "thresholding").to(device)
#net = CLIPWithTopoPrompts(retrain=True).to(device)
#net = CLIPWithTopoPrompts(retrain=False).to(device)

### Hyperparameter selection

We now define some hyperparameters. For the confidence-based thresholding CLIP, the learning rate is set to 0.03 which, as justified before, allowed fast convergence. Weight decay was kept at 0.0001 to limit overfitting. For the TopoVPT model, the learning rate was set to 0.005 and the weight decay to 0.01. Batch size was kept at 128, as the GPU memory supported it. Finally, we set the maximum number of epochs to 30, since checkpointing allowed us to compare models trained for different numbers of epochs.

In [None]:
learning_rate_thresh = 0.03
learning_rate_vpt = 0.005
weight_decay_thresh = 1e-4
weight_decay_vpt = 1e-2
batch_size = 128
epochs = 1  # Set to a higher value (30 for thresholding, 20 for TopoVPT) to replicate the results

### Optimizer selection

The optimizer we selected for the linear classifier (used for thresholding) was Adam. More advanced optimizers should not be required, as the trained network is quite simple.

In [None]:
optimizer_thresh = torch.optim.Adam(net.parameters(), lr=learning_rate_thresh,  weight_decay=weight_decay_thresh)

For the TopoVPT training we used the AdamW optimizer with a small learning rate to avoid gradient explosion, together with an exponential learning rate scheduler.

In [None]:
# Uncomment CLIPWithTopoPrompts before
# Parameters trained in both configurations: the prompt-injection (lift and adapter) modules
trainable_params = [
    *net.visual.lift2.parameters(),
    *net.visual.lift3.parameters(),
    *net.visual.adapter2.parameters(),
    *net.visual.adapter3.parameters(),
]
if net.retrain:
    # Also retrain the third and fourth layers of the Modified ResNet
    trainable_params += [
        *net.visual.layer3.parameters(),
        *net.visual.layer4.parameters(),
    ]
optimizer_VPT = torch.optim.AdamW(trainable_params, lr=learning_rate_vpt, weight_decay=weight_decay_vpt)
# Decay the learning rate by a factor of 0.9 after every scheduler step
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer_VPT, gamma=0.9)
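
To get a feel for this schedule, the short sketch below works out the learning rate that an exponential decay with gamma = 0.9 would give, assuming the scheduler is stepped once per epoch as in the training loop further down. It is plain arithmetic rather than a call into PyTorch, and `lr_at_epoch` is a hypothetical helper introduced only for illustration.

```python
def lr_at_epoch(initial_lr: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` scheduler steps under exponential decay."""
    return initial_lr * gamma ** epoch


# Values used for TopoVPT in this notebook: initial lr 0.005, gamma 0.9
for epoch in (0, 5, 10, 20):
    print(f"epoch {epoch:2d}: lr = {lr_at_epoch(0.005, 0.9, epoch):.6f}")
```

Under these assumptions the learning rate drops to roughly 0.0006 by epoch 20, about an eighth of its initial value.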

### Dataloader creation

In [None]:
dataloader = torch.utils.data.DataLoader(train_base, batch_size=batch_size, shuffle=True, num_workers=4)
valnovelloader = torch.utils.data.DataLoader(val_novel, batch_size=batch_size, shuffle=False, num_workers=4)
valdataloader = torch.utils.data.DataLoader(val_base, batch_size=batch_size, shuffle=False, num_workers=4)
test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=batch_size, shuffle=False, num_workers=4)
test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=batch_size, shuffle=False, num_workers=4)

### Training

Thresholding/Clustering training

In [None]:
for e in range(epochs):
    training_step_thresh(net=net, data_loader=dataloader, optimizer=optimizer_thresh, cost_function=torch.nn.CrossEntropyLoss(), device=device)
    if (e+1)%5 == 0:
        torch.save(net.state_dict(), f'model_{e+1}.pth')

Training: 100%|██████████| 4/4 [00:03<00:00,  1.01it/s, train_acc=30, train_loss=3.43]


TopoVPT training.

Note: if the training is performed on Colab, it might take a very long time.

In [None]:
# Uncomment CLIPWithTopoPrompts before
# Anomaly detection helps trace NaN/Inf gradients, but it slows training down
torch.autograd.set_detect_anomaly(True)
for e in range(epochs):
    training_step_VPT(net=net, data_loader=dataloader, optimizer=optimizer_VPT, cost_function=torch.nn.CrossEntropyLoss(), categories=all_classes, device=device)
    scheduler.step()
    if (e+1)%5 == 0:
        torch.save(net.state_dict(), f'model_{e+1}.pth')

### Checkpoint loading (Optional)

In [None]:
net.load_state_dict(torch.load('model_20.pth'))


<All keys matched successfully>

### Threshold and Model selection

The threshold hyperparameter and the number of training epochs are chosen by evaluating on the validation set.

In [None]:
net.threshold = 0.65
# For CLIPWithTopoPrompts
#net.threshold = 0.013 or 0.0165, depending on the chosen model

In [None]:
eval_step(net, data_loader=valnovelloader,cost_function=torch.nn.CrossEntropyLoss(), device=device, num_classes = 102, categories=all_classes)

(0.1480332804661171, 65.29411764705883)

In [None]:
eval_step(net, data_loader=valdataloader,cost_function=torch.nn.CrossEntropyLoss(), device=device, num_classes = 102, categories=all_classes)

(0.013900058409746955, 82.54901960784314)

In [None]:
_, base_accuracy = eval_step(net, data_loader=test_base_loader,cost_function=torch.nn.CrossEntropyLoss(), device=device, num_classes = 102, categories=all_classes)
_, novel_accuracy = eval_step(net, data_loader=test_novel_loader,cost_function=torch.nn.CrossEntropyLoss(), device=device, num_classes = 102, categories=all_classes)

In [None]:
print(base_accuracy, novel_accuracy)

80.7925596441569 67.21980413492928


### Define harmonic mean function

In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
    numerator = 2
    denominator = 1 / base_accuracy + 1 / novel_accuracy
    hm = numerator / denominator
    return hm

print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)}%")

🔍 Harmonic Mean: 73.38387005217447%
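
The harmonic mean can also drive the threshold choice itself. The following is a minimal sketch of such a sweep, not the exact procedure used for this report: `select_threshold` and its `evaluate` callback are hypothetical names. In this notebook, `evaluate` could set `net.threshold`, run `eval_step` on the two validation loaders and return the two accuracies. The demo instead replays the 10-iteration validation rows reported in the Results table.

```python
from typing import Callable, Iterable, Tuple


def harmonic_mean(base_accuracy: float, novel_accuracy: float) -> float:
    """Harmonic mean of two accuracies; returns 0.0 if either is not positive."""
    if base_accuracy <= 0 or novel_accuracy <= 0:
        return 0.0
    return 2 / (1 / base_accuracy + 1 / novel_accuracy)


def select_threshold(
    thresholds: Iterable[float],
    evaluate: Callable[[float], Tuple[float, float]],
) -> Tuple[float, float]:
    """Return (best_threshold, best_harmonic_mean) over the candidates.

    `evaluate` is a hypothetical callback mapping a threshold to
    (base_accuracy, novel_accuracy) measured on the validation set.
    """
    best_threshold, best_hm = None, float("-inf")
    for t in thresholds:
        base_acc, novel_acc = evaluate(t)
        hm = harmonic_mean(base_acc, novel_acc)
        if hm > best_hm:
            best_threshold, best_hm = t, hm
    return best_threshold, best_hm


# Validation accuracies (base, novel) of the 10-iteration model, per threshold
val_10_iters = {
    0.5: (91.57, 55.88),
    0.55: (90.00, 58.03),
    0.6: (87.06, 60.78),
    0.65: (85.88, 64.12),
    0.7: (83.14, 65.49),
}
best_t, best_hm = select_threshold(val_10_iters, lambda t: val_10_iters[t])
print(best_t, round(best_hm, 2))
```

On these rows the sweep picks 0.65, the threshold used for the 10-iteration model in the test evaluation.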


## Results

### Confidence-based Thresholding

Confidence-based thresholding was the simplest method we employed, but it was able to yield some gains, achieving slightly higher accuracy than the base CLIP model. Our model was trained using the Adam optimizer with a learning rate of 0.03 and a weight decay of 0.0001. This configuration provided both speed and stability, allowing smooth convergence to the minimum. In contrast, lower learning rates (0.001 and 0.003) led to extremely slow convergence. Since only a simple linear classifier is actually trained, the optimization process should converge relatively smoothly.

#### Validation set performance across iterations and thresholds

| **# Iterations** | **Threshold** | **Base (%)** | **Novel (%)** |
|------------------|---------------|--------------|---------------|
|**BASE CLIP**| - | 63.92 | 68.24 |
| 5  | 0.5  | 80.98 | 65.68 |
| 5  | 0.65 | 72.35 | 68.04 |
| 10 | 0.5  | 91.57 | 55.88 |
| 10 | 0.55 | 90.00 | 58.03 |
| 10 | 0.6  | 87.06 | 60.78 |
| 10 | 0.65 | 85.88 | 64.12 |
| 10 | 0.7  | 83.14 | 65.49 |
| 15 | 0.55 | 88.63 | 59.61 |
| 15 | 0.6  | 86.66 | 63.53 |
| 15 | 0.65 | 84.71 | 63.92 |
| 20 | 0.65 | 82.55 | 65.29 |
| 20 | 0.7  | 79.02 | 66.47 |
| 30 | 0.6  | 84.90 | 64.31 |
| 30 | 0.65 | 82.16 | 66.27 |

The table summarizes the results on the validation set. A trend can be observed: additional training generally improves performance on base classes, though often at the expense of novel-class accuracy. As expected, prediction confidence also increases with more epochs, requiring higher thresholds. Both training error and validation accuracy stabilize after roughly 10 epochs.

The best balance was found after 15 training iterations with a threshold of 0.6. This model maintained strong performance on novel classes without a significant drop on base classes. The above-mentioned trade-off can be observed at 10 iterations with a 0.65 threshold: improved novel-class accuracy at the cost of base-class performance. The same held for the 20-iteration, 0.65-threshold model, which slightly improved novel-class accuracy but again reduced base-class accuracy.

By comparison, at 5 iterations the model underfits, producing weaker confidence scores for base-class examples and requiring lower thresholds to show improvements over base CLIP. At 30 iterations, performance was similar but slightly worse on both base and novel classes, showing some signs of overfitting. Despite that, the model still maintains decent performance, suggesting that further training would plateau rather than reduce performance.

Based on the validation results, we selected models trained for 10, 15, and 20 iterations for test evaluation.

#### Test set performance for selected models

#### BASE CLIP
| **Metric**       | **Value (%)** |
|------------------|---------------|
| Base Accuracy    | 69.47         |
| Novel Accuracy   | 73.78         |
| Harmonic Mean    | 71.56         |

#### 10 Iterations, 0.65 Threshold
| **Metric**       | **Value (%)** |
|------------------|---------------|
| Base Accuracy    | 82.37         |
| Novel Accuracy   | 65.70         |
| Harmonic Mean    | 73.09         |

#### 15 Iterations, 0.6 Threshold
| **Metric**       | **Value (%)** |
|------------------|---------------|
| Base Accuracy    | 84.63         |
| Novel Accuracy   | 64.68         |
| Harmonic Mean    | 73.33         |

#### 20 Iterations, 0.65 Threshold
| **Metric**       | **Value (%)** |
|------------------|---------------|
| Base Accuracy    | 80.79         |
| Novel Accuracy   | 67.22         |
| Harmonic Mean    | 73.38         |


The thresholds selected for the three models were 0.65, 0.6 and 0.65 respectively. The 20-iteration model marginally outperformed the 15-iteration one (harmonic means of top-1 accuracy of 73.38% vs. 73.33%), though both were nearly identical and clearly improved over the base CLIP score of 71.56%. The 10-iteration model performed slightly worse (73.09%), suggesting that additional training was beneficial. Overall, these results suggest that performance stabilized around 15 iterations, showing little improvement beyond that point. Importantly, the proposed method consistently improves upon CLIP while requiring only modest training time and memory.

A final insight is that the threshold acts as a hyperparameter governing the emphasis placed on base versus novel classes. In our experiments, we chose a balanced approach, aiming to maximize the harmonic mean of the accuracies. However, in a context where maintaining novel-class accuracy is more important, the threshold could easily be raised accordingly. Conversely, it could be lowered if the goal were to maximize base-class accuracy while still maintaining reasonable performance on novel classes.
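
The role of the threshold can be illustrated with a minimal, framework-free sketch of the decision rule described in the Methodologies section. All names here (`predict_with_threshold`, `base_classes`, the toy logits and similarities) are hypothetical, and the sketch assumes the linear head outputs one logit per base class while the zero-shot similarities cover all classes. The notebook's `ThresholdCLIP` may organize this differently.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_with_threshold(
    classifier_logits: Sequence[float],
    zero_shot_similarities: Sequence[float],
    base_classes: Sequence[int],
    threshold: float,
) -> Tuple[int, str]:
    """Use the classifier if it is confident enough, else fall back to zero-shot.

    classifier_logits: one logit per base class (assumed linear head output).
    zero_shot_similarities: image-text cosine similarity for every class.
    base_classes: class id corresponding to each classifier logit.
    """
    probs = softmax(classifier_logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    if probs[best] > threshold:
        return base_classes[best], "classifier"
    # Low confidence: mimic the original CLIP by taking the most similar text prompt
    best_zs = max(range(len(zero_shot_similarities)), key=zero_shot_similarities.__getitem__)
    return best_zs, "zero-shot"


# Toy example: 3 base classes (ids 0-2) and 4 classes overall
logits = [2.0, 0.5, 0.1]                 # classifier confidence for class 0 is about 0.73
similarities = [0.21, 0.19, 0.18, 0.27]  # zero-shot prefers class 3
print(predict_with_threshold(logits, similarities, [0, 1, 2], threshold=0.65))
print(predict_with_threshold(logits, similarities, [0, 1, 2], threshold=0.8))
```

With the lower threshold the classifier's prediction is kept, while the higher one routes the same sample to the zero-shot branch, which is how raising the threshold shifts emphasis towards novel classes.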

### Anomaly Detection

The anomaly detection methodologies we employed did not yield any significant result.

The simple Gaussian-clustering approach was not able to distinguish any significant cluster: the number of training examples was too low, and distance to the centroid was not a good measure, as features might be heavily reweighted by the successive forward pass. Changing the multiplier applied to the standard deviation caused massive shifts in model performance. For example, using a multiplier of 0.6 in a 30-iteration model resulted in none of the base-class examples being recognized, causing the model to behave like standard CLIP. In contrast, increasing the multiplier to 0.7 led to all examples being recognized, but at the cost of 0% accuracy on novel classes.
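
The sensitivity to the multiplier is easier to see with a small sketch. It mirrors the per-class statistics computed in the training step (class mean, mean distance and standard deviation of distances), but the acceptance rule itself is an assumption made for illustration: a sample is treated as belonging to a base class if its distance to the class mean is at most the average distance plus `k` standard deviations. The function names and toy features are hypothetical.

```python
import math
import statistics
from typing import List, Sequence, Tuple


def fit_class_stats(features: Sequence[Sequence[float]]) -> Tuple[List[float], float, float]:
    """Return (class mean, mean distance to it, sample std of those distances)."""
    dim = len(features[0])
    mean = [sum(f[d] for f in features) / len(features) for d in range(dim)]
    dists = [math.dist(f, mean) for f in features]
    # statistics.stdev is the sample (n - 1) estimator, like torch's default .std()
    return mean, statistics.mean(dists), statistics.stdev(dists)


def is_base_sample(x: Sequence[float], mean: Sequence[float],
                   avg_dist: float, std_dist: float, k: float) -> bool:
    """Assumed rule: accept if the distance is within avg_dist + k * std_dist."""
    return math.dist(x, mean) <= avg_dist + k * std_dist


# Toy 2-D "features" for a single class
train = [[0.0, 0.0], [4.0, 0.0], [1.0, 0.0], [3.0, 0.0]]
mean, avg_dist, std_dist = fit_class_stats(train)

query = [4.0, 0.0]  # lies at distance 2.0 from the class mean
for k in (0.0, 1.0):
    print(k, is_base_sample(query, mean, avg_dist, std_dist, k))
```

With only a handful of examples per class the distance statistics are noisy, so a small change in `k` can flip the decision for many samples at once.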

Overall, this approach failed to match the performance of the thresholding method and significantly degraded the base CLIP model's accuracy. It also required extensive hyperparameter tuning just to achieve reasonable results on the validation set. Moreover, the method relies on storing all distances and class means in memory, which is necessary for computing the standard deviations, and this resulted in substantially higher memory consumption.

On the other hand, the approach based on a hierarchical transformation-discriminating generative model proved daunting from a different point of view. The model was too computationally intensive to train on both our machines and the cloud instances provided by the University. We were only able to train the model for a few iterations at a time, and performance was lackluster when tested on models that had not been trained enough. Additionally, we would have been required to train a model for each of the 51 base classes, making it a lengthy and disk-consuming effort. This led us to abandon this methodology without being able to properly apply it to the dataset in question.

### TopoVPT with thresholding

TopoVPT proved to be a demanding model to train, because the topological part cannot take advantage of GPU computing. We decided to inject $50 \times 50$ persistence images into the second and third layers of the Modified ResNet50 used by CLIP, so that the network could extract both larger and smaller-scale features from them. As mentioned before, since persistence images are computationally intensive to produce, it was necessary to keep them relatively small.

We initialize the parameters of our model with very small values drawn from a Kaiming distribution, also known as He initialization. This approach is particularly effective for layers with ReLU activation functions, as it helps to maintain a balanced variance across layers during the forward and backward passes. By using a Kaiming distribution, we ensure that the initial weights are set in a way that mitigates issues like vanishing or exploding gradients, thereby facilitating more efficient training and improving the overall convergence of the model. Adding skip connections proved helpful in training the model, as it improved the performance and reached convergence faster. A further introduction of a learnable parameter did not achieve any meaningful result.
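
As a rough illustration of this initialization, the sketch below draws weights from a zero-mean normal distribution with standard deviation $\sqrt{2 / \text{fan\_in}}$, which is the Kaiming (He) normal rule for ReLU layers. It is written in plain Python for clarity. The function name and the extra `scale` factor (a stand-in for shrinking the initial values further) are assumptions, not the notebook's actual code. In PyTorch the same rule is available as `torch.nn.init.kaiming_normal_`.

```python
import math
import random
import statistics
from typing import List


def kaiming_normal(fan_in: int, fan_out: int, scale: float = 1.0, seed: int = 0) -> List[List[float]]:
    """Sample a fan_out x fan_in weight matrix with std = scale * sqrt(2 / fan_in)."""
    rng = random.Random(seed)
    std = scale * math.sqrt(2.0 / fan_in)
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


weights = kaiming_normal(fan_in=256, fan_out=64)
flat = [w for row in weights for w in row]
print(f"target std:    {math.sqrt(2.0 / 256):.4f}")
print(f"empirical std: {statistics.pstdev(flat):.4f}")
```

Because the standard deviation shrinks with the number of inputs, wider layers start with smaller weights, which keeps the variance of the activations roughly constant from layer to layer.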

After training, a clear issue emerged: the addition of the topological features significantly alters the structure of the network, even when the parameters are initialized at very low values. Therefore, the network outputs do not replicate the ones from base CLIP when the topological prompt parameters are initialized at 0. This led the network to overfit on the training data while losing the ability to classify the novel classes. Additionally, the network's performance varies greatly between classes, suggesting that the prompt extraction methodology works differently for different kinds of input.

Since TopoVPT does not produce good results on the novel classes, thresholding was employed, in order to benefit from both the zero-shot performance of CLIP and the improvements achieved by the model.

#### Test set performance for selected models

We propose two versions of TopoVPT: one where only the added parameters are trained, and another where the parameters of the third and fourth layers of the Modified ResNet are also trained. Based on validation results, we found that the best performance is obtained by training for 20 epochs using AdamW with an initial learning rate of 0.005 and a weight decay of 0.01. Decreasing the weight decay does not yield better results. The threshold hyperparameters were selected based on validation set performance.


#### First proposal, 20 Iterations, 0.013 Threshold
| **Metric**       | **Value (%)** |
|------------------|---------------|
| Base Accuracy    | 63.4          |
| Novel Accuracy   | 68.2          |
| Harmonic Mean    | 65.7          |


#### Second proposal, 20 Iterations, 0.0165 Threshold
| **Metric**       | **Value (%)** |
|------------------|---------------|
| Base Accuracy    | 72.7          |
| Novel Accuracy   | 66.0          |
| Harmonic Mean    | 69.2          |

TopoVPT with thresholding shows results which are slightly worse than CLIP. This is likely due to the fact that more data is required to properly train the model.

## Conclusions

The only model that produced satisfying results is the one based on confidence thresholding. It is very fast to train and produces a clear improvement over base CLIP, without needlessly increasing the performance overhead at test time. With a more thorough parallel implementation of the forward pass, the model should be a direct upgrade over base CLIP in few-shot adaptation scenarios. This improvement is further reinforced by the ease of putting the emphasis on either base or novel classes by modifying the threshold parameter.

The other methodologies we employed all had a number of issues. Gaussian clustering did not show much promise when trained on a very small amount of data. The hierarchical transformation-discriminating generative model proved too hard to train in terms of time and hardware resources. Finally, the approach based on Topological Visual Prompt Tuning was not able to maintain novel-class performance during training as we had hoped. Even when combined with thresholding, the method proved ill-suited to an application with only a small number of training examples.

In conclusion, confidence-based thresholding proved to be effective despite its simplicity, while the other, more complex methods were not suitable for the task at hand.



## Bibliography



1. Henry Adams, Sofya Chepushtanova et al. 2015. 'Persistence Images: A Stable Vector Representation of Persistent Homology'. [link](https://arxiv.org/abs/1507.06217)
2. Rubén Ballester, Carles Casacuberta, Sergio Escalera. 2023. 'Topological Data Analysis for Neural Network Analysis: A Comprehensive Survey'. [link](https://arxiv.org/abs/2312.05840)
3. Menglin Jia, Lumin Tang et al. 2022. 'Visual Prompt Tuning'. [link](https://arxiv.org/abs/2203.12119)
4. Shelly Sheynin, Sagie Benaim, Lior Wolf. 2021. 'A Hierarchical Transformation-Discriminating Generative Model for Few Shot Anomaly Detection'. [link](https://arxiv.org/abs/2104.14535)