<p align="center">
  <strong><font size="6">Deep Learning 2025 Project</font></strong><br><br>
  Podavini Luca 257844<br>
  Richichi Andrea 257850<br>
  Sorrentino Francesco 256151
</p>

# Introduction
The aim of this project is to implement a **PEFT** (Parameter-Efficient Fine-Tuning) technique for CLIP[1] and to find a strategy that improves on its baseline performance.
The model is evaluated on a base-to-novel classification task on the Flowers102 dataset.
The technique we set out to improve is CLIP-LoRA[3], which we extend with a generative pipeline that produces a synthetic dataset to support the limited training set.
The reimplemented technique is DISEF[4]; we evaluated its results and tried to improve its performance and to reduce the computational cost of the generative process.

First, however, to have a comparison for the last and more complex technique, we implemented the CoCoOp[2] prompt-tuning technique and vanilla CLIP-LoRA[3].
The work is therefore divided into:
1. Evaluation of Zero-Shot CLIP
2. Evaluation of CoCoOp
3. Evaluation of CLIP-LoRA
4. Evaluation of DISEF (an improvement on CLIP-LoRA)
5. Evaluation of our improved DISEF

## Import modules
This project requires several modules to work properly, some of which are not pre-installed in Colab. We will use the following modules:
- torch: the PyTorch library.
- torchvision: library providing the dataset.
- clip: containing the pretrained model.
- tqdm: progress bar.
- matplotlib: plotting.
- os: for saving plots and weights.
- collections: used for a cleaner CoCoOp implementation.
- bitsandbytes, transformers and diffusers: used for the generative pipeline.
- PIL: used for loading images.
- shutil: used for moving files.
- sklearn: used to compute zero-shot f1 scores.

In [None]:
# Install modules that are not pre-installed
%pip install openai_clip
%pip install bitsandbytes
# Colab comes with the other modules preinstalled;
# when running on a VM, transformers and diffusers may require installation

In [None]:
# Importing modules

# Modules for the whole project
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

# Modules for CoCoOp
from collections import OrderedDict
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Modules for LoRA
import math
import torch.nn.functional as F

# Modules for DISEF
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import random
import shutil
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
from diffusers import StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler


# Modules for our DISEF
from sklearn.metrics import precision_recall_fscore_support

## Define Constants

Here we define constants and parameters used throughout the project. These control the behavior of various components and help manage memory by allowing execution in stages.

In this section, each parameter is explained and formatted for clarity:

- `DEVICE`: A simple variable used to track which device is being used (e.g., `"cuda"` or `"cpu"`).
- `CLIP_BACKBONE`: Throughout the whole project, the chosen model is `"ViT-B/16"` for consistency.
- `CLASS_NAMES`: A global constant containing the class names of the Flowers102 dataset.
- `PROMPT_FORMAT`: The default format of the prompt used for evaluation and training.
- `LLAVA_MODEL_ID`: The chosen model used for captioning in the generative pipeline.
- `LLAVA_USER_PROMPT`: Fixed prompt used when querying the LLaVA model for captioning.
- `LLAVA_LABEL_PROMPT_FORMAT`: Prompt format used to add a label in the captioning prompt.
- `LLAVA_FULL_PROMPT_FORMAT`: Complete prompt required for the LLaVA model input.
- `SD_MODEL`: Stable Diffusion model used in the generative pipeline.
- `DIFFUSION_SAMPLER`: Diffusion sampler used by the diffusion model.
- `CFG_STRENGHT`: Classifier-free guidance strength, which determines how much the diffusion model is constrained by the prompt[4].
- `DIFFUSION_STEPS`: Number of diffusion steps[4].
- `NOISING_STEPS`: Number of noising steps[4].
- `SD_PROMPT_FORMAT`: Prompt format used with the Stable Diffusion model[4].
- `DISEF_GEN_PATH`: Path to save generated images.
- `DISEF_DEL_PATH`: Path to move rejected images.
- `OUR_DISEF_GEN_PATH`: Path to save generated images in the `our_disef` part of the project.
- `OUR_DISEF_DEL_PATH`: Path to move rejected images in the `our_disef` part of the project.
- `GENERATION_K_SHOTS`: Number of samples to generate. We used 32 shots instead of 64 due to fewer available starting samples compared to the reference work[4].
- `RUN_ZERO_SHOT`: Boolean flag indicating whether to run the zero-shot evaluation.
- `RUN_COCOOP`: Boolean flag indicating whether to run the CoCoOp evaluation.
- `RUN_LORA`: Boolean flag indicating whether to run the CLIP-LoRA evaluation.
- `RUN_DISEF`: Boolean flag indicating whether to run the DISEF evaluation.
- `RUN_OUR_DISEF`: Boolean flag indicating whether to run the modified DISEF evaluation.
- `TRAIN_COCOOP`: Boolean flag to train CoCoOp or just load the weights.
- `TRAIN_LORA`: Boolean flag to train CLIP-LoRA or just load the weights.
- `TRAIN_DISEF`: Boolean flag to train the DISEF model or just load the weights.
- `GEN_DISEF`: Boolean flag to generate images in the DISEF section or just load them.
- `TRAIN_OUR_DISEF`: Boolean flag to train our DISEF model or just load the weights.
- `GEN_OUR_DISEF`: Boolean flag to generate images in our DISEF section or just load them.
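
As a quick illustration of how the format-string constants are meant to be filled in, the sketch below builds a CLIP prompt and a LLaVA prompt for one class. The string values are copied from the constants cell of this notebook; the way `LLAVA_USER_PROMPT` and `LLAVA_LABEL_PROMPT_FORMAT` are concatenated is an assumption based on the descriptions above, and `build_llava_prompt` is a hypothetical helper rather than a function of the notebook.

```python
# Values copied from the constants cell of this notebook
PROMPT_FORMAT = "a photo of a {}, a type of flower."
LLAVA_USER_PROMPT = "Describe this image focusing on the object"
LLAVA_LABEL_PROMPT_FORMAT = ", knowing it's a {}, a type of flower."
LLAVA_FULL_PROMPT_FORMAT = "A chat between a curious human and an artificial intelligence assistant. The assistant gives concise, factual descriptions of images.###Human: <image>\n{}###Assistant:"


def build_llava_prompt(class_name):
    """Hypothetical helper: append the label hint to the user prompt (assumed composition)."""
    user_prompt = LLAVA_USER_PROMPT + LLAVA_LABEL_PROMPT_FORMAT.format(class_name)
    return LLAVA_FULL_PROMPT_FORMAT.format(user_prompt)


# Fill the templates for a single class name
clip_prompt = PROMPT_FORMAT.format("sunflower")
llava_prompt = build_llava_prompt("sunflower")

print(clip_prompt)
print(llava_prompt)
```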

In [None]:
# Constants
# Device where code is run
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# CLIP Backbone
CLIP_BACKBONE = "ViT-B/16"

# Classnames in the dataset, hardcoded for use later
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
# Prompt default format
PROMPT_FORMAT = "a photo of a {}, a type of flower."

# Llava constants
LLAVA_MODEL_ID = "llava-hf/llava-1.5-7b-hf"
LLAVA_USER_PROMPT = "Describe this image focusing on the object"
LLAVA_LABEL_PROMPT_FORMAT = ", knowing it's a {}, a type of flower."
LLAVA_FULL_PROMPT_FORMAT = "A chat between a curious human and an artificial intelligence assistant. The assistant gives concise, factual descriptions of images.###Human: <image>\n{}###Assistant:"

# Stable diffusion constants
SD_MODEL = "stabilityai/stable-diffusion-2-1-base"
DIFFUSION_SAMPLER="dpmsolver++" # From DISEF paper
CFG_STRENGHT=8 # From DISEF paper
DIFFUSION_STEPS=20 # From DISEF paper
NOISING_STEPS=15 # From DISEF paper
SD_PROMPT_FORMAT = "A realistic photo of a {}, a type of flower."

# Location for generated images
DISEF_GEN_PATH = os.path.join("imgs","disef","data") # imgs/disef/data
DISEF_DEL_PATH = os.path.join("imgs","disef","discarded") # imgs/disef/discarded
OUR_DISEF_GEN_PATH = os.path.join("imgs","our_disef","data") # imgs/our_disef/data
OUR_DISEF_DEL_PATH = os.path.join("imgs","our_disef","discarded") # imgs/our_disef/discarded

# Number of shots to generate
GENERATION_K_SHOTS = 32

# Parameters to decide what evaluations must be run
# Better to run an evaluation per session to avoid out-of-memory errors
RUN_ZERO_SHOT = True
RUN_COCOOP = True
RUN_LORA = True
RUN_DISEF = True
RUN_OUR_DISEF = True # To run our_disef LoRA weights are needed inside bin folder

# Parameters to decide if training is done and if generation is done or we load files
TRAIN_COCOOP = True
TRAIN_LORA = True
TRAIN_DISEF = True
GEN_DISEF = True
TRAIN_OUR_DISEF = True
GEN_OUR_DISEF = True




## Load Datasets
Collecting and preprocessing data from torchvision.

The classes are split into base and novel ones by putting the first half of the class indices into base and the second half into novel. This only simulates a real application, where the novel classes would not be known at training time.

In [None]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (callable, optional): Transform applied to each image (e.g. the CLIP preprocess).
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

def base_novel_categories(dataset):
  """Split dataset classes into base and novel classes.
  Args:
    dataset (list): Dataset to split into base and novel classes.
  Returns:
    tuple: A tuple containing base and novel classes
  """
  # Set returns the unique set of all dataset classes
  all_classes = set(dataset._labels)
  num_classes = len(all_classes)

  # Generate base and novel category lists
  base_classes = list(range(num_classes))[:num_classes//2]
  novel_classes = list(range(num_classes))[num_classes//2:]
  return base_classes, novel_classes


def split_data(dataset, base_classes):
  """Split samples into base and novel subsets given the base classes.
  Args:
    dataset (list): list of samples.
    base_classes (list): list of base classes.
  Returns:
    tuple: Tuple containing base and novel datasets.
  """
  # List to store sample idx
  base_categories_samples = []
  novel_categories_samples = []

  # Set of base classes to compute the test below in O(1)
  base_set = set(base_classes)

  # Iterate and get sample idx
  for sample_id, label in enumerate(dataset._labels):
    if label in base_set:
      base_categories_samples.append(sample_id)
    else:
      novel_categories_samples.append(sample_id)

  # Create the dataset subsets using Subset
  base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
  novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
  return base_dataset, novel_dataset


All of these previously defined functions will now be used to create the full data loading pipeline.
<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/get_dataset.png" width="50%">
</div>


In [None]:
def get_dataset(do_preprocess=True):
  """Load Flowers102 datasets.
  Args:
    do_preprocess (bool): whether to apply the CLIP preprocessing transform.
  Returns:
    tuple: train base, validation base, test base and test novel sets, followed by the base and novel class lists.
  """
  clip_preprocess=None
  if do_preprocess:
    # Load the CLIP preprocess
    _, clip_preprocess = clip.load(CLIP_BACKBONE, device=DEVICE)

  # Get the three datasets using the CLIP preprocess
  train_set, val_set, test_set = get_data(transform=clip_preprocess)

  # Split classes into base and novel
  base_classes, novel_classes = base_novel_categories(train_set)

  # Split the three datasets
  train_base, _ = split_data(train_set, base_classes)
  val_base, _ = split_data(val_set, base_classes)
  test_base, test_novel = split_data(test_set, base_classes)

  return train_base, val_base, test_base, test_novel, base_classes, novel_classes


Now data can be loaded and preprocessed to get our datasets.

Before evaluating the model, let's study the number of samples and the class distributions. This is especially relevant later on, when evaluating the quality of our generated data.

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/sample_counts_per_dataset.png" width="80%">
</div>

We have 510 samples for both training set and validation set.
Test base contains 2473 samples while test novel contains 3676 samples.

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/class_distribution_Train_Base.png"  width="80%">
	<img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/class_distribution_Validation_Base.png" width="80%">
</div>

*train_base* and *val_base* contain 10 shots for every base class (51 classes * 10 shots).
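
Per-class counts like the ones plotted here can be obtained by simply counting labels. The sketch below does this with `collections.Counter` on a synthetic label list that stands in for the labels of *train_base* (an assumption made for illustration; `count_shots` is a hypothetical helper, not part of the notebook).

```python
from collections import Counter


def count_shots(labels):
    """Return a {class_id: number_of_samples} mapping."""
    return dict(Counter(labels))


# Synthetic stand-in for the labels of train_base: 51 base classes, 10 shots each
toy_labels = [class_id for class_id in range(51) for _ in range(10)]
shots = count_shots(toy_labels)

print(len(toy_labels))      # total number of samples
print(set(shots.values()))  # distinct shot counts across classes
```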

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/class_distribution_Test_Base.png" width="80%" >
	<img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/class_distribution_Test_Novel.png" width="80%">
</div>

*test_base* has a non-uniform class distribution; the classes with the most samples (> 200) are petunia and wallflower.

*test_novel* also has a non-uniform class distribution; the class with the most samples (> 200) is passion flower.

# Evaluation of Zero-Shot CLIP
Zero-Shot CLIP must be evaluated before applying any technique in order to have a baseline performance to improve.

## Test functions

First, a test function is defined to be reused later in the project.
CLIP is evaluated by first computing the text features of all class prompts and then computing their cosine similarity with the visual features of the images in each batch; the most similar class is the prediction, and accuracy is accumulated across batches.
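
The prediction rule can be sketched without any model: normalize the features, take the cosine similarity between one image feature and every class text feature, and pick the argmax. The vectors below are toy values chosen for illustration and `zero_shot_predict` is a hypothetical helper; real CLIP features are produced by the encoders and have many more dimensions.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def zero_shot_predict(image_feature, text_features):
    """Return the index of the most similar class and all cosine similarities."""
    img = l2_normalize(image_feature)
    sims = [
        sum(a * b for a, b in zip(img, l2_normalize(txt)))
        for txt in text_features
    ]
    predicted = max(range(len(sims)), key=sims.__getitem__)
    return predicted, sims


# Toy features: three "class prompts" and one "image" (assumed values)
toy_text_features = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.6, 0.8, 0.0]]
toy_image_feature = [0.5, 0.9, 0.1]

pred, sims = zero_shot_predict(toy_image_feature, toy_text_features)
print(pred, [round(s, 3) for s in sims])
```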

The harmonic mean of base and novel accuracy is also computed, as a single value that shows whether a technique improves both. This matters because simply fine-tuning the model would increase base accuracy while greatly reducing the generalization ability of the pre-trained weights.
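
For two accuracies $B$ and $N$ the harmonic mean is $2BN/(B+N)$, which is pulled towards the smaller of the two values. The small sketch below compares two hypothetical models with the same arithmetic mean; the accuracy values are made up for illustration and `harmonic_mean_of` is a stand-alone helper written for this example.

```python
def harmonic_mean_of(base_acc, novel_acc):
    """Harmonic mean of two accuracies, written as 2BN / (B + N)."""
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)


# Two hypothetical models with the same arithmetic mean (0.80)
balanced = harmonic_mean_of(0.80, 0.80)
unbalanced = harmonic_mean_of(0.95, 0.65)
arithmetic = (0.95 + 0.65) / 2

print(f"balanced:   {balanced:.4f}")
print(f"unbalanced: {unbalanced:.4f} (arithmetic mean {arithmetic:.4f})")
```

The unbalanced model scores lower even though its arithmetic mean is identical, which is the behavior we want from a metric that rewards improving base accuracy only when novel accuracy is preserved.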

In [None]:
def harmonic_mean(base_accuracy, novel_accuracy):
  """Compute harmonic mean
  Args:
    base_accuracy (float): accuracy score on base classes.
    novel_accuracy (float): accuracy score on novel classes.
  Returns:
    float: harmonic mean.
  """
  numerator = 2
  denominator = 1 / base_accuracy + 1 / novel_accuracy
  hm = numerator / denominator
  return hm

def clip_test(model, loader, categories, device, label=""):
  """Test function for CLIP model.
  Args:
    model (torch.nn): clip pretrained model to use.
    loader (DataLoader): dataloader for evaluation.
    categories (list): either base or novel idxs.
    device (str): device where to put data.
    label (str): label for evaluation loop.
  Returns:
    float: accuracy score.
  """
  # Set model in eval mode
  model.eval()

  # Dictionary for remapping labels into a contiguous set
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  # Apply the standard CLIP template used for oxford flowers to all categories and immediately tokenize each sentence
  text_inputs = clip.tokenize([PROMPT_FORMAT.format(CLASS_NAMES[c]) for c in categories]).to(device)

  with torch.no_grad():
    # Encode text features for all classes
    text_features = model.encode_text(text_inputs)
    # Normalize them (standard practice with CLIP)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Variable to store number of correct predictions
    correct_predictions = 0
    # Iterate through batches
    for image, target in tqdm(loader, desc=label):
      # Map categories to contiguous to get correct predictions
      target = torch.Tensor([contig_cat2idx[t.item()] for t in target]).long()

      image, target = image.to(device), target.to(device)

      # Encode image features
      image_features = model.encode_image(image)
      # Normalize image features
      image_features /= image_features.norm(dim=-1, keepdim=True)

      # Cosine similarity between image and text features and keep the argmax for every row (every image)
      predicted_class = (image_features @ text_features.T).argmax(dim=-1)
      # Check which are correct, and sum them (False == 0, True == 1)
      correct_predictions += (predicted_class == target).sum().item()
  # Compute the accuracy
  accuracy = correct_predictions / len(loader.dataset)
  return accuracy


Having defined a test function for CLIP, we can now write an evaluation function for the zero-shot model. This pattern will also be used later on: a single function independently executes all the steps needed to evaluate a technique.

In [None]:
def zero_shot_eval():
  """Function to run zero_shot evaluation."""
  # Load the model
  clip_model, _ = clip.load(CLIP_BACKBONE, device=DEVICE)


  # Get the datasets
  _, _, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Batch size
  TEST_BATCH = 128

  # Get loaders
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 Zero-shot evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 Zero-shot evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating Zero-Shot","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

The evaluation is then run depending on the parameter set at the start of the notebook, and the results are printed.

In [None]:
if RUN_ZERO_SHOT:
  zero_shot_eval()

The baseline performance obtained with zero-shot evaluation is:

<div align="center">

| Model      | Base | Novel | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |

</div>

The desired improvement upon this performance is to raise the base accuracy without lowering the novel performance.

# Evaluation of CoCoOp
The first technique applied to CLIP is Conditional Context Optimization, or CoCoOp[2].
It belongs to the field of prompt tuning: instead of hand-crafting the textual prompt, it learns prompts that improve on the zero-shot performance.
CoCoOp was created to improve on CoOp by making the learned prompt adapt to each individual image, using a meta-net that computes a bias from the image features.

Adapting to each image addresses the main pitfall of the previous technique, where a single static prompt was not enough to keep good performance on novel classes.

Mathematically this approach can be seen as follows:
- The image encoder $g$ extracts features from the image: $x = g(I)$, where $x$ is the feature vector and $I$ the image.
- The meta-net $h$ is fed the features and generates a conditional token $\pi = h(x)$.
- Then $\pi$ is combined with a set of $m$ learnable context vectors $V_1, \dots, V_m$ to generate the final instance-specific prompt for class $c_k$. Writing $[V(\pi)]_j = V_j + \pi$, the prompt can be seen as a function of the image features:

$P(x)_k = [V(\pi)]_1 [V(\pi)]_2 \dots [V(\pi)]_m [\text{CLASS}]_k$

- Then the similarity is computed between the embeddings of the dynamically generated prompt and of the image:

$p(y=k \mid I) = \frac{\exp(\tau \cdot \cos(x, f(P(x)_k)))}{\sum_{i=1}^{K} \exp(\tau \cdot \cos(x, f(P(x)_i)))}$

where $f$ is the text encoder, $\tau$ is CLIP's logit scale and $K$ is the total number of classes.
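
The steps above can be traced on toy numbers. The sketch below shifts each context vector by the conditional token, builds one prompt per class and applies a softmax over the scaled cosine similarities. It is only an illustration under strong assumptions: the text encoder is replaced by a simple average of the token vectors (`toy_text_encoder`, not CLIP's transformer), the conditional token is given by hand instead of being produced by a meta-net, and all vectors are made-up two-dimensional values.

```python
import math


def add_bias(ctx_vectors, pi):
    """Shift every context vector by the conditional token: V_j + pi."""
    return [[v + p for v, p in zip(vec, pi)] for vec in ctx_vectors]


def toy_text_encoder(prompt_tokens):
    """Stand-in for the text encoder f: average of the token vectors (NOT CLIP)."""
    dim = len(prompt_tokens[0])
    return [sum(tok[d] for tok in prompt_tokens) / len(prompt_tokens) for d in range(dim)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def class_probabilities(x, ctx_vectors, pi, class_tokens, tau):
    """Softmax over tau * cos(x, f(P(x)_k)) for every class k."""
    shifted = add_bias(ctx_vectors, pi)
    logits = [tau * cosine(x, toy_text_encoder(shifted + [cls])) for cls in class_tokens]
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - max_logit) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


x = [1.0, 0.0]                        # toy image feature
ctx = [[0.1, 0.1], [0.2, -0.1]]       # m = 2 toy context vectors
pi = [0.05, 0.0]                      # would be h(x) in CoCoOp
class_tokens = [[1.0, 0.0], [0.0, 1.0]]

probs = class_probabilities(x, ctx, pi, class_tokens, tau=10.0)
print([round(p, 4) for p in probs])
```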

The following diagram summarizes the CoCoOp architecture:

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/CoCoOp_architecture.png" width="80%">
</div>

The following sections show the CoCoOp implementation and its results.

The implementation involves the creation of several wrapper modules around CLIP components.
First, the `TextEncoder` class wraps CLIP's text encoder so that it accepts prompt embeddings directly, together with the tokenized prompts.

This makes it possible to encode dynamically created prompts and obtain their output embeddings.
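
One detail of this wrapper is how the sentence-level feature is selected: the tokenized prompt is still needed because the end-of-text token has the largest id in CLIP's vocabulary, so an argmax over the token ids gives its position. The sketch below reproduces that selection on plain lists. The start, end and padding ids follow the CLIP tokenizer, while the word ids in the rows are made-up values and `eot_position` is a hypothetical helper.

```python
SOT_ID, EOT_ID, PAD_ID = 49406, 49407, 0  # special ids used by the CLIP tokenizer


def eot_position(token_ids):
    """Index of the largest id in one row, i.e. what argmax(dim=-1) returns."""
    return max(range(len(token_ids)), key=token_ids.__getitem__)


# Toy rows: the word ids are invented, only SOT/EOT/padding follow CLIP's convention
rows = [
    [SOT_ID, 320, 1125, 539, EOT_ID, PAD_ID, PAD_ID, PAD_ID],
    [SOT_ID, 320, 1125, 539, 320, 3568, EOT_ID, PAD_ID],
]

positions = [eot_position(row) for row in rows]
print(positions)
```

The transformer output at these positions is the one that gets projected into the joint embedding space.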

In [None]:
class TextEncoder(nn.Module):
  """Encode dynamic prompts with tokenized given prompts."""
  def __init__(self, clip_model):
    """Init the module.
    Args:
      clip_model (torch.nn): clip pretrained model to use.
    """
    super().__init__()
    # Save into the module clip modules
    self.text_encoder = clip_model.transformer
    self.positional_embedding = clip_model.positional_embedding
    self.layer_norm = clip_model.ln_final
    self.proj = clip_model.text_projection

  def forward(self, embedded_tokens, tokens_ids):
    """Forward pass.
    Args:
      embedded_tokens (Tensor): Input token embeddings.
      tokens_ids (Tensor): Tokenizer prompt IDs.
    """
    # Apply positional embeddings
    output = embedded_tokens + self.positional_embedding

    # Rearrange dimension for transformer input
    output = output.permute(1, 0, 2)  # [batch_size, n_ctx, transformer.width] -> [n_ctx, batch_size, transformer.width]
    output = self.text_encoder(output)
    # Go back to original dimensions
    output = output.permute(1, 0, 2)  # [n_ctx, batch_size, transformer.width] -> [batch_size, n_ctx, transformer.width]
    # Apply layer norm
    output = self.layer_norm(output)

    # Select features corresponding to eot tokens
    eot_ids = tokens_ids.argmax(dim=-1)
    output = output[torch.arange(output.shape[0]), eot_ids] @ self.proj

    return output

The main CoCoOp component is the `PromptLearner` class, which handles prompt creation given an initial context string.
This module implements both the learnable context and the meta-net presented previously. The forward method takes image features as input and returns the dynamic prompts.
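
To make the tensor bookkeeping of the forward pass easier to follow, the sketch below only computes the shapes involved. It assumes the ViT-B/16 text width of 512 and CLIP's context length of 77; the batch size, the number of classes and the number of context tokens in the usage example are illustrative values, and `cocoop_prompt_shapes` is a hypothetical helper.

```python
def cocoop_prompt_shapes(batch_size, n_cls, n_ctx, ctx_dim=512, seq_len=77):
    """Shapes of the tensors combined when building instance-conditioned prompts."""
    return {
        "bias": (batch_size, 1, ctx_dim),             # meta-net output, unsqueezed
        "ctx": (1, n_ctx, ctx_dim),                   # learnable context, unsqueezed
        "ctx_shifted": (batch_size, n_ctx, ctx_dim),  # broadcast sum ctx + bias
        "prefix": (n_cls, 1, ctx_dim),                # start-of-text embedding
        "suffix": (n_cls, seq_len - 1 - n_ctx, ctx_dim),  # class tokens, EOT, padding
        "prompts": (batch_size, n_cls, seq_len, ctx_dim),  # one prompt set per image
    }


shapes = cocoop_prompt_shapes(batch_size=4, n_cls=51, n_ctx=4)
for name, shape in shapes.items():
    print(f"{name:12s} {shape}")

# Number of values in the stacked prompts tensor
n_values = 1
for size in shapes["prompts"]:
    n_values *= size
print(n_values)
```

Because the prompts tensor grows with both the batch size and the number of classes, every image needs its own pass of all class prompts through the text encoder.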

In [None]:

class PromptLearner(nn.Module):
  """Module that learns ctx vectors to adapt to visual features for each class."""
  def __init__(self, n_ctx, ctx_init, classnames, clip_model):
    """Initialize PromptLearner module.
    Args:
      n_ctx (int): Number of ctx tokens to learn.
      ctx_init (string): Initialization string for context.
      classnames (list): class names list.
      clip_model (nn.Module): pretrained CLIP for token embedding.
    """
    super().__init__()
    # Save number of classes and number of ctx tokens
    self.n_cls = len(classnames)
    self.n_ctx = n_ctx

    # Get dimensions from CLIP
    # Context embedding dimension
    ctx_dim = clip_model.ln_final.weight.shape[0]
    # Visual encoder output dimension
    vis_dim = clip_model.visual.output_dim

    # Get CLIP device
    device = clip_model.token_embedding.weight.device

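    # Use given words to initialize context vectors
    ctx_init = ctx_init.replace("_", " ")
    # The number of context tokens is derived from the init string,
    # so that slicing below stays consistent with self.n_ctx
    n_ctx = len(ctx_init.split(" "))
    self.n_ctx = n_ctx
    prompt = clip.tokenize(ctx_init).to(device)
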
    with torch.no_grad():
        embedding = clip_model.token_embedding(prompt) # Convert token to embedding

    # Get only context tokens
    ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
    prompt_prefix = ctx_init

    # Register learnable context parameters
    self.ctx = nn.Parameter(ctx_vectors)

    # Meta-net to apply context based on image features
    self.meta_net = nn.Sequential(OrderedDict([
        ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
        ("relu", nn.ReLU(inplace=True)),
        ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
    ]))

    # Preprocess classnames and add them to the prompts
    classnames = [name.replace("_", " ") for name in classnames]
    # Instantiate tokenizer to get the token length of each class name
    _tokenizer = _Tokenizer()
    self.name_lens = [len(_tokenizer.encode(name)) for name in classnames]
    prompts = [prompt_prefix + " " + name + "." for name in classnames]

    # Tokenize prompts and get embeddings
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)  # (n_cls, n_tkn)
    with torch.no_grad():
        embedding = clip_model.token_embedding(tokenized_prompts)

    # Register unchanged parts of the embeddings
    self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
    self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

    # Save tokenized_prompts
    self.tokenized_prompts = tokenized_prompts  # torch.Tensor

  def construct_prompts(self, ctx, prefix, suffix, label=None):
    """Construct final prompts by concatenating prefix, ctx and suffix tokens.
    Args:
      ctx (Tensor): context tokens.
      prefix (Tensor): prefix tokens.
      suffix (Tensor): suffix tokens.
      label (Tensor): indices to select specific class prompts.
    """
    if label is not None:
      # Class specific suffix and prefix if labels are given
      prefix = prefix[label]
      suffix = suffix[label]

    prompts = torch.cat(
      [prefix,
      ctx,
      suffix],
      dim=1
    )

    return prompts

  def forward(self, im_features):
    """Forward pass.
    Args:
      im_features (Tensor): image features.
    Returns:
      Tensor: final prompts.
    """
    # Get fixed prefix and suffix token embeddings
    prefix = self.token_prefix
    suffix = self.token_suffix

    # Using the input image features we generate bias to apply to ctx
    bias = self.meta_net(im_features)
    bias = bias.unsqueeze(1)
    ctx = self.ctx.unsqueeze(0)
    ctx_shifted = ctx + bias # Broadcast bias across context

    # Build prompts for each sample in batch using adapted context
    prompts = []
    for ctx_shifted_i in ctx_shifted:
      # Expands for all classes
      ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
      pts_i = self.construct_prompts(ctx_i, prefix, suffix)
      prompts.append(pts_i)

    # Stack prompts for all batch samples
    prompts = torch.stack(prompts)

    return prompts

The `CoCoOpCLIP` class is a wrapper around CLIP that provides a clean forward method: it takes the prompts from the `PromptLearner` and directly produces the logits, which makes the model easier to train.

In [None]:
class CoCoOpCLIP(nn.Module):
  """A CLIP wrapper that adds prompt tuning and computes logits using dynamic text features."""
  def __init__(self, n_ctx, ctx_init, classnames, clip_model):
    """Initiate module with prompt learner and encoders.
    Args:
      n_ctx (int): number of ctx tokens for prompt learner.
      ctx_init (string): context init string for prompt learner.
      classnames (list): list of classnames for prompt learner.
      clip_model (nn.Module): pretrained CLIP to wrap around.
    """
    super().__init__()
    # Prompt learner generates dynamic prompts based on image features
    self.prompt_learner = PromptLearner(n_ctx, ctx_init, classnames, clip_model)
    # Save tokenized prompts from prompt learner
    self.tokenized_prompts = self.prompt_learner.tokenized_prompts
    # Get visual encoder from CLIP
    self.image_encoder = clip_model.visual
    # Get wrapper around clip text encoder
    self.text_encoder = TextEncoder(clip_model)
    # Learnable scaling factor for logits from CLIP
    self.logit_scale = clip_model.logit_scale

  def forward(self, imgs):
    """Forward pass to compute logits between image and text prompts.
    Args:
      imgs: batch of input images.
    Returns:
      Tensor: similarity logits between image and text features.
    """
    # Get tokenized prompts and logit scale
    tokenized_prompts = self.tokenized_prompts
    logit_scale = self.logit_scale.exp()

    # Encode and normalize image features
    image_features = self.image_encoder(imgs)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Generate instance-conditioned prompts using image features
    prompts = self.prompt_learner(image_features)

    logits = []
    for pts_i, imf_i in zip(prompts, image_features):
      # Compute similarity for each image in the batch
      text_features = self.text_encoder(pts_i, tokenized_prompts)
      text_features = text_features / text_features.norm(dim=-1, keepdim=True)
      l_i = logit_scale * imf_i @ text_features.t()
      logits.append(l_i)

    # Stack logits for all batch samples
    logits = torch.stack(logits)

    return logits

A new evaluation function is needed to handle the wrapper module. Since we are in a few-shot base-to-novel setting, the test phase also needs the base or novel class indices, so that labels can be remapped to contiguous values.
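
The remapping matters mostly for the novel split: its dataset labels run from 51 to 101, while the logits only have one column per evaluated class. A minimal sketch of the mapping, using a few hypothetical target labels:

```python
# Novel classes of Flowers102 under the half/half split used in this notebook
novel_classes = list(range(51, 102))

# Same construction as in the test functions: original label -> contiguous index
contig_cat2idx = {cat: idx for idx, cat in enumerate(novel_classes)}

# Hypothetical batch of original dataset labels
targets = [51, 77, 101]
remapped = [contig_cat2idx[t] for t in targets]

print(remapped)  # → [0, 26, 50]
```

For the base split the mapping is the identity, since base labels already run from 0 to 50.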

In [None]:
def cocoop_test(cocoop_model, loader, categories, device="cuda:0", label=""):
  """Test CoCoOp model.
  Args:
    cocoop_model (nn.Module): CoCoOp model to test.
    loader (DataLoader): dataloader for evaluation.
    categories (list): either base or novel idxs.
    device (str): device where to put data.
    label (str): label for the progress bar.
  Returns:
    float: accuracy score.
  """
  # Variables to compute performance score
  samples = 0
  cumulative_accuracy = 0

  # Set the network to evaluation mode
  cocoop_model.eval()

  # Remap labels into a contiguous set starting from zero
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  with torch.no_grad():
    for images, targets in tqdm(loader, desc=label):
      # Map categories to the [0, 50], otherwise we will have wrong predictions
      targets = torch.Tensor([contig_cat2idx[t.item()] for t in targets]).long()
      images, targets = images.to(device), targets.to(device)

      outputs = cocoop_model(images)

      # Compute performance scores
      batch_size = images.shape[0]
      samples += batch_size

      _, predicted = outputs.max(1)

      # Compute accuracy
      cumulative_accuracy += predicted.eq(targets).sum().item()

    accuracy = cumulative_accuracy / samples

    return accuracy

For cleaner code, we also create a training function that trains the model for a given number of epochs. It handles the loss computation for each training step and implements early stopping with patience.
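
The early-stopping rule can be checked in isolation by replaying it over a list of validation accuracies. The sketch below mirrors the patience logic of the training function that follows (reset the counter on improvement, otherwise decrement it and stop once it is exhausted); the accuracy values are invented for illustration and `early_stopping_trace` is a hypothetical helper.

```python
def early_stopping_trace(val_accuracies, patience=3):
    """Return (best_epoch, last_epoch_run) for a sequence of validation accuracies."""
    best_acc = float("-inf")
    best_epoch = None
    remaining = patience

    for epoch, acc in enumerate(val_accuracies):
        if acc > best_acc:
            # Improvement: remember the epoch and reset the patience counter
            best_acc, best_epoch, remaining = acc, epoch, patience
        elif remaining < 1:
            # Patience exhausted: stop training at this epoch
            return best_epoch, epoch
        else:
            remaining -= 1

    return best_epoch, len(val_accuracies) - 1


# Hypothetical validation curve: the best value is reached at epoch 1
accs = [0.70, 0.75, 0.74, 0.73, 0.74, 0.72, 0.71]
best_epoch, stop_epoch = early_stopping_trace(accs, patience=3)
print(best_epoch, stop_epoch)
```

With this rule, training stops on the fourth consecutive epoch without improvement when `patience=3`, and the weights kept are those of the best epoch.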

In [None]:
def cocoop_train(cocoop_model, train_loader, val_loader, optimizer, loss_fun, scheduler, base_classes, num_epochs=5, patience=3, device="cuda:0"):
  """Train CoCoOp.
  Args:
    cocoop_model (nn.Module): CoCoOp model to train.
    train_loader (DataLoader): dataloader for training.
    val_loader (DataLoader): dataloader for validation.
    optimizer (torch.optim): optimizer to use.
    loss_fun (torch.nn): cost function to use.
    scheduler (torch.optim.lr_scheduler): scheduler to use.
    base_classes (list): list of base classes for evaluation.
    num_epochs (int): number of epochs to train.
    patience (int): patience for early stopping.
    device (str): device to use.
  Returns:
    nn.Module: model with the best prompt learner weights loaded.
  """
  def train_step(cocoop_model, loader, optimizer, loss_fun, device="cuda:0"):
    """Training step for CoCoOp.
    Args:
      cocoop_model (nn.Module): CoCoOp model to train.
      loader (DataLoader): loader for training set.
      optimizer (torch.optim): optimizer to use.
      loss_fun (torch.nn): cost function to use.
      device (str): device to use.
    Returns:
      tuple: average loss and accuracy over the epoch.
    """
    # Variables to store values to compute loss and accuracy
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Set the model to training mode
    cocoop_model.train()

    # Iterate over the training set
    for images, targets in loader:
      images = images.to(device)
      targets = targets.to(device)

      outputs = cocoop_model(images)
      loss = loss_fun(outputs, targets)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      batch_size = images.shape[0]
      samples += batch_size
      cumulative_loss += loss.item() * batch_size

      # Get predictions
      _, predicted = outputs.max(dim=1)
      cumulative_accuracy += predicted.eq(targets).sum().item()

    loss = cumulative_loss / samples
    accuracy = cumulative_accuracy / samples

    return loss, accuracy

  # Initialize vars before the train loop
  best_val_acc = float('-inf')
  best_model_state = None
  curr_patience = patience
  pb = tqdm(range(num_epochs))

  for epoch in pb:
    # Training step
    train_loss, train_acc = train_step(cocoop_model, train_loader, optimizer, loss_fun, device)

    # Evaluation step
    val_acc = cocoop_test(cocoop_model, val_loader, base_classes, device)

    # Show val accuracy at every iteration
    pb.set_description(f"Val Acc: {val_acc*100:.2f}%")

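    # If we get better metric save the model
    if val_acc > best_val_acc:
      best_val_acc = val_acc
      # Clone the tensors: state_dict() returns references that later updates overwrite in place
      best_model_state = {k: v.detach().clone() for k, v in cocoop_model.prompt_learner.state_dict().items()}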
      curr_patience = patience
    else:
      if curr_patience < 1:
        break
      else:
        curr_patience = curr_patience - 1
    # Step the scheduler if provided
    if scheduler is not None:
      scheduler.step()

  # Load the best model before returning
  if best_model_state is not None:
    cocoop_model.prompt_learner.load_state_dict(best_model_state)
  return cocoop_model


Some utility functions to save and load weights have also been created in order to store the best result.
An important detail is that only the weights of the prompt learner module are stored in order to save space (8 MB compared to ~500 MB of CLIP weights).

In [None]:
def save_model(model, model_name):
  """Save the model weights.
  Args:
    model (nn.Module): model to save.
    model_name (str): name of the model to save.
  """
  # Save directory location (if does not exist create it)
  save_dir = os.path.join("bin")
  os.makedirs(save_dir, exist_ok=True)

  file_path = os.path.join(save_dir, f"{model_name}.pt")
  torch.save(model.state_dict(), file_path)
  print(f"{model_name} weights saved to {file_path}.")

def load_model(model, model_name, device):
  """Load the model weights.
  Args:
    model (nn.Module): model where to load the weights.
    model_name (str): name of the model to load.
    device (str): device to use.
  """
  file_path = os.path.join("bin", f"{model_name}.pt")

  if not os.path.isfile(file_path):
    raise FileNotFoundError(f"No weights file found at {file_path}")

  state_dict = torch.load(file_path, map_location=torch.device(device))
  model.load_state_dict(state_dict)
  print(f"Model weights loaded from {file_path}")

With all of this in place, we can define a CoCoOp evaluation function that builds the models, trains on the base training set and reports the accuracy after training.

The hyperparameters have been chosen in accordance with the reference[2], and the prompt learner is initialized in the same way.
- `N_CTX`: number of context tokens to learn.
- `CTX_INIT`: context to initialize the prompt_learner. We use a default context 'a photo of a'.
- `TRAIN_BATCH`, `VAL_BATCH`, `TEST_BATCH`: batch sizes of different datasets.
- `LR`: learning rate, tuned over multiple runs to reach the best performance in the lowest number of epochs.
- `NUM_EPOCHS`: number of epochs to train. 10 is enough to reach a good performance.
- `PATIENCE`: patience for early stopping, which helps avoid overfitting on the training set.

A relevant detail for training is that CoCoOp is trained with a batch size of 1: to create the prompts, the forward method has to iterate over the images inside the batch.
This severely slows down training compared to other techniques.
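
To illustrate why the prompts have to be built image by image, here is a minimal sketch of instance-conditioned prompt construction. It is not the notebook's `CoCoOpCLIP` implementation: the sizes, `meta_net`, `class_emb` and `build_prompts` are hypothetical stand-ins chosen only to show the shapes involved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy sizes, much smaller than the real model
n_cls, n_ctx, ctx_dim, feat_dim = 3, 4, 8, 8

ctx = torch.randn(n_ctx, ctx_dim)           # shared learnable context vectors
meta_net = nn.Linear(feat_dim, ctx_dim)     # stand-in for the meta network
class_emb = torch.randn(n_cls, 1, ctx_dim)  # stand-in for class name embeddings

def build_prompts(image_features):
  """Build one full set of class prompts for every image in the batch."""
  bias = meta_net(image_features)           # (batch, ctx_dim)
  prompts = []
  for b in bias:                            # one iteration per image
    ctx_shifted = ctx + b                   # context conditioned on this image
    ctx_per_class = ctx_shifted.unsqueeze(0).expand(n_cls, -1, -1)
    prompts.append(torch.cat([ctx_per_class, class_emb], dim=1))
  return torch.stack(prompts)               # (batch, n_cls, n_ctx + 1, ctx_dim)

prompts = build_prompts(torch.randn(2, feat_dim))
print(prompts.shape)
```

Every image needs its own set of `n_cls` prompts, so the text encoder has to process `batch * n_cls` sequences per step, which is what makes larger training batches expensive.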


In [None]:
def cocoop_eval(do_train=True):
  """Function to run CoCoOp evaluation.
  Args:
    do_train (bool): train the model or load weights.
  """
  # Load the model
  clip_model, _ = clip.load(CLIP_BACKBONE, device=DEVICE)
  clip_model = clip_model.float() # To avoid weight type errors

  # Get the datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Hyperparams
  N_CTX = 4
  CTX_INIT = "a photo of a"
  TRAIN_BATCH = 1 # CoCoOp requires a batch of 1 in training to run
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-3
  NUM_EPOCHS = 10 # From CoCoOp paper
  PATIENCE = 3

  # Get base and novel classnames
  base_classnames = [CLASS_NAMES[c] for c in base_classes]
  novel_classnames = [CLASS_NAMES[c] for c in novel_classes]

  # One wrapper for base classification and the other for novel by giving different classnames
  base_model = CoCoOpCLIP(N_CTX, CTX_INIT, base_classnames, clip_model).to(DEVICE)
  novel_model = CoCoOpCLIP(N_CTX, CTX_INIT, novel_classnames, clip_model).to(DEVICE)

  # Freeze all CLIP params other than prompt learner
  for name, param in base_model.named_parameters():
      if "prompt_learner" not in name:
          param.requires_grad_(False)
  for name, param in novel_model.named_parameters():
      if "prompt_learner" not in name:
          param.requires_grad_(False)

  # Sanity check
  enabled = set()
  for name, param in base_model.named_parameters():
      if param.requires_grad:
          enabled.add(name)
  print(f"Parameters to be updated: {enabled}")

  # Get loaders
  train_loader = torch.utils.data.DataLoader(train_base, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in base_model.parameters() if p.requires_grad], lr=LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "cocoop_model": base_model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    base_model = cocoop_train(**train_params)
    # Save weights
    save_model(base_model.prompt_learner, model_name="CoCoOp")
  else:
    load_model(base_model.prompt_learner, model_name="CoCoOp", device=DEVICE)

  # Load weights on novel model
  novel_model.prompt_learner.ctx.data = base_model.prompt_learner.ctx.data.clone()
  novel_model.prompt_learner.meta_net.load_state_dict(base_model.prompt_learner.meta_net.state_dict())

  # Evaluate base and novel
  base_accuracy = cocoop_test(base_model, test_base_loader, base_classes, DEVICE, label="🧠 CoCoOp evaluation on Base")
  novel_accuracy = cocoop_test(novel_model, test_novel_loader, novel_classes, DEVICE, label="🧠 CoCoOp evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating CoCoOp","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

Let's now see the results by executing the following evaluation.

In [None]:
if RUN_COCOOP:
  cocoop_eval(do_train=TRAIN_COCOOP)

The new CoCoOp results can now be compared with zero-shot:
<div align="center">

| Model      | Base | Novel | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp     | 95.19%   | 71.11%    | 81.41%        |

</div>
We can clearly see that CoCoOp outperforms zero-shot CLIP on base classes, greatly increasing base accuracy and raising the harmonic mean as we initially planned, at the cost of a lower novel accuracy.
After having implemented this prompt tuning technique, let's look at a more complex technique that improves performance even further while also providing better training times.
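
As a quick check of the last column, the harmonic mean can be recomputed from the base and novel accuracies. The helper below is a hypothetical standalone re-implementation for illustration, not the notebook's own `harmonic_mean` function.

```python
def harmonic_mean_pct(base, novel):
  """Harmonic mean of two accuracies expressed as percentages."""
  return 2 * base * novel / (base + novel)

# CoCoOp row of the table: 95.19% base, 71.11% novel
cocoop_hm = harmonic_mean_pct(95.19, 71.11)
print(f"CoCoOp harmonic mean: {cocoop_hm:.2f}%")
```

Values recomputed from the rounded accuracies in the table can differ in the last digit from the ones obtained with the unrounded accuracies.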

# Evaluation of CLIP-LoRA
After introducing a prompt tuning technique, we implement Low-Rank Adaptation, or LoRA.
This PEFT technique was introduced in the context of fine-tuning large pretrained models.

The key idea behind this technique is that the weight updates during fine-tuning exhibit a "low intrinsic rank": a change to a large weight matrix can be approximated by the product of two much smaller low-rank matrices. Mathematically, LoRA freezes the pretrained weights of the model $W_0 \in \mathbb{R}^{d \times k}$ and approximates the weight update as $\Delta W = BA$, with $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, where the rank satisfies $r \ll \min(d, k)$.
During training only the parameters in $A$ and $B$ are trained, so the output of the adapted layer is:

$o = W_0x + (BA)x$ with $x$ as the input of the adapted layer.

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/LoRA_architecture.png" width="70%">
</div>
The following code implements this adaptation of the linear layers (query, key and value) inside of the multi-head attention modules.


The scaling factor `lora_alpha / lora_r` multiplies the output of the LoRA branch ($BAx$) before it is added to the frozen output, while `lora_dropout` optionally applies dropout to the input of the LoRA branch.
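
Before the full implementation, the idea can be checked on a toy example. This is a minimal sketch with hypothetical sizes, assuming $B$ has shape $d \times r$ and $A$ has shape $r \times k$ so that $BA$ has the same shape as $W_0$. It verifies that the unmerged and merged forms give the same output and compares the parameter counts.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: W0 is d x k, the rank r is smaller than both
d, k, r, alpha = 6, 8, 2, 4
scaling = alpha / r

W0 = torch.randn(d, k, dtype=torch.float64)  # frozen pretrained weight
B = torch.randn(d, r, dtype=torch.float64)   # random here only to make the check non-trivial
A = torch.randn(r, k, dtype=torch.float64)
x = torch.randn(k, dtype=torch.float64)

# Unmerged form: frozen path plus scaled low-rank path
o_unmerged = W0 @ x + scaling * (B @ (A @ x))
# Merged form: fold the update into a single weight matrix
o_merged = (W0 + scaling * (B @ A)) @ x

same_output = torch.allclose(o_unmerged, o_merged)
full_params = d * k        # parameters of a full update
lora_params = r * (d + k)  # parameters of the low-rank update
print(same_output, full_params, lora_params)
```

The equivalence of the two forms is what allows the LoRA weights to be merged into the pretrained weight for inference at no extra cost.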

In [None]:
class LoRAGroupedLinear(nn.Linear):
  """Augmented linear layer using a group of low-rank updates."""
  def __init__(self, in_features, out_features, bias, lora_r, lora_alpha, lora_dropout, enable_lora=[False], merge_weights=True):
    """Initialize the module.
    Args:
      in_features (int): number of input features.
      out_features (int): number of output features.
      bias (bool): if True add a learnable bias to the linear layer.
      lora_r (int): rank of low rank decomposition.
      lora_alpha (int): scaling factor for LoRA adjustment.
      lora_dropout (float): dropout prob for dropout in LoRA branch.
      enable_lora (list): list of which output groups should use LoRA.
      merge_weights (bool): if true merge LoRA and model weights for inference.
    """
    # Initialize parent linear layer
    super().__init__(in_features, out_features, bias=bias)

    # Save LoRA input params
    self.lora_r = lora_r
    self.lora_alpha = lora_alpha
    self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0 else nn.Identity()
    self.merge_weights = merge_weights
    self.enable_lora = enable_lora

    # When initialization occurs weights are unmerged
    self.merged=False

    # Declare as many LoRA params as are the outputs
    if lora_r > 0 and any(enable_lora):
      self.n_lora_out = sum(enable_lora)
      self.n_out = len(enable_lora)
      self.lora_A = nn.Parameter(self.weight.new_zeros((lora_r * self.n_lora_out, in_features)))
      self.lora_B = nn.Parameter(self.weight.new_zeros((out_features // self.n_out * self.n_lora_out, lora_r)))

      self.scaling = self.lora_alpha / self.lora_r

      # Freeze pretrained
      self.weight.requires_grad = False

      # Mask to track where LoRA is applied
      self.lora_ind = self.weight.new_zeros((out_features), dtype=torch.bool).view(self.n_out, -1)
      self.lora_ind[enable_lora, :] = True
      self.lora_ind = self.lora_ind.view(-1)

    # Initialize LoRA params
    self.reset_parameters()

  def reset_parameters(self):
    """Method to reinit linear and LoRA params."""
    super().reset_parameters()
    if hasattr(self, 'lora_A'):
      # lora_A is initialized in a default way as nn.Linear
      nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
      # lora_B is initialized to zero
      nn.init.zeros_(self.lora_B)

  def zero_pad(self, input_seq):
    """Zero pads the input sequence to match output feature dim using LoRA mask.
    Args:
      input_seq (Tensor): input sequence to pad.
    Returns:
      Tensor: padded input sequence.
    """
    # First get zero tensor
    padded = input_seq.new_zeros((len(self.lora_ind), *input_seq.shape[1:]))
    # Substitute input sequence using LoRA mask
    padded[self.lora_ind] = input_seq
    return padded

  def merge_AB(self):
    """Compute effective weight delta from lora_A and lora_B and apply zero padding."""
    weight_delta = F.conv1d(
        self.lora_A.unsqueeze(0),
        self.lora_B.unsqueeze(-1),
        groups=self.n_lora_out
    ).squeeze(0)
    return self.zero_pad(weight_delta)

  def train(self, mode=True):
    """Enable training mode. If merging is enabled, weights are unmerged when in training mode
    and merged back when switching to eval mode."""
    super().train(mode)
    if mode: # If training mode
      if self.merge_weights and self.merged:
        # If merging is active and weights are merged, unmerge
        if self.lora_r > 0 and any(self.enable_lora):
          self.weight.data -= self.merge_AB() * self.scaling
        self.merged = False
    else: # Going in eval mode
      if self.merge_weights and not self.merged:
        # If merging is active and weights are not merged, merge
        if self.lora_r > 0 and any(self.enable_lora):
          self.weight.data += self.merge_AB() * self.scaling
        self.merged = True

  def forward(self, input_seq):
    """Forward pass.
    Args:
      input_seq (Tensor): input sequence.
    Returns:
      Tensor: output sequence.
    """
    output = F.linear(input_seq, self.weight, bias=self.bias)
    if self.merged: # If weights are merged -> eval mode
      return output
    else: # weights not merged -> train mode
      if self.lora_r > 0:
        lora_out = self.lora_dropout(input_seq) @ self.merge_AB().T
        output += lora_out * self.scaling
      return output
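
The grouped convolution in `merge_AB` can look opaque at first. The sketch below, with hypothetical toy sizes, suggests that it is equivalent to computing one $B_g A_g$ product per adapted group and stacking the results along the output dimension.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy sizes: 2 adapted groups (e.g. query and value)
n_groups, r, in_features, out_per_group = 2, 3, 5, 4

lora_A = torch.randn(r * n_groups, in_features, dtype=torch.float64)
lora_B = torch.randn(out_per_group * n_groups, r, dtype=torch.float64)

# Grouped convolution with kernel size 1, same call pattern as merge_AB
delta_conv = F.conv1d(lora_A.unsqueeze(0), lora_B.unsqueeze(-1), groups=n_groups).squeeze(0)

# Explicit version: one B_g @ A_g product per group, stacked along the output dim
blocks = []
for g in range(n_groups):
  A_g = lora_A[g * r:(g + 1) * r]
  B_g = lora_B[g * out_per_group:(g + 1) * out_per_group]
  blocks.append(B_g @ A_g)
delta_loop = torch.cat(blocks, dim=0)

matches = torch.allclose(delta_conv, delta_loop)
print(delta_conv.shape, matches)
```

The convolution simply computes all per-group products in a single call; `zero_pad` then places the resulting rows at the positions of the adapted groups.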

The MultiHeadAttention module inside CLIP needs to be slightly modified to account for the adapted linear projections, while keeping the implementation of the vanilla module intact to avoid unexpected behaviour.

The main additions are the LoRA parameters and the flags that select which linear projections (query, key or value) have LoRA enabled.
The reference[4] provides insight that adapting query and value layers seems to yield the best results and so we keep that choice.

Some parameters that are never used are kept because they are required when injecting these custom modules inside CLIP to avoid errors.

In [None]:
class LoRAMultiHeadAttention(nn.Module):
  """Multi-Head Attention module augmented with LoRA."""
  def __init__(self, lora_r, lora_alpha, lora_dropout, embed_dim, num_heads, dropout, bias=True, q_lora=True, k_lora=False, v_lora=True):
    """Initialize the module.
    Args:
      lora_r (int): rank of LoRA update matrices.
      lora_alpha (int): scaling factor for LoRA.
      lora_dropout (float): dropout probability for LoRA.
      embed_dim (int): dimension of input embeddings.
      num_heads (int): number of attention heads.
      dropout (float): dropout probability after attention.
      bias (bool): if True add a learnable bias to projection layers.
      q_lora (bool): enable LoRA for query projection.
      k_lora (bool): enable LoRA for key projection.
      v_lora (bool): enable LoRA for value projection.
    """
    super().__init__()

    # Save embed dim and set dim of key and value vectors equal to it
    self.embed_dim = embed_dim
    self.kdim = embed_dim
    self.vdim = embed_dim

    # Save multihead attention params
    self.num_heads = num_heads
    self.dropout = dropout
    # Dimension of a head is embed_dim / num_heads
    self.head_dim = embed_dim // num_heads

    # Merged linear layer with LoRA
    qkv_params = {
        "in_features": embed_dim,
        "out_features": 3 * embed_dim, # since we get three outputs for query, key and values
        "bias": bias, # enable bias for linear layer
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "enable_lora": [q_lora, k_lora, v_lora]
    }
    self.qkv = LoRAGroupedLinear(**qkv_params)

    self.scaled_dot_product_attention = F.scaled_dot_product_attention
    self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)

  def set_parameters(self, mod):
    """Initialize from existing Pytorch modules to load pretrained weights.
    Args:
      mod (nn.Module): Pytorch module with pretrained weights.
    """
    # Copy weights
    self.qkv.weight.data = mod.in_proj_weight.data
    self.qkv.bias.data = mod.in_proj_bias.data
    self.proj.weight.data = mod.out_proj.weight.data
    self.proj.bias.data = mod.out_proj.bias.data



  def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, is_causal=False, need_weights=True):
    """Forward pass to apply Multi-Head Attention.
    Args:
      query (Tensor): query tensor.
      key (Tensor): key tensor.
      value (Tensor): value tensor.
      key_padding_mask (Tensor): mask to ignore padding.
      attn_mask (Tensor): optional attention mask.
      is_causal (bool): If True, applies a causal mask.
      need_weights (bool): parameter kept for compatibility.
    """
    # Create key padding mask of proper shape and size
    key_padding_mask = F._canonical_mask(
      mask=key_padding_mask,
      mask_name="key_padding_mask",
      other_type=F._none_or_dtype(attn_mask),
      other_name="attn_mask",
      target_type=query.dtype,
    )

    # Get dims
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    E = query.size(-1)

    # Project Q, K, V using grouped LoRA linear layer
    qkv = self.qkv(query)
    # Split qkv from merged Tensor
    qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Canonicalize attention mask
    attn_mask = F._canonical_mask(
      mask=attn_mask,
      mask_name="attn_mask",
      other_type=F._none_or_dtype(key_padding_mask),
      other_name="key_padding_mask",
      target_type=q.dtype,
      check_other=False,
    )

    # If attn_mask is provided, reshape it to 4 dims for scaled dot product attention
    if attn_mask is not None:
      # A 2D mask is first expanded to 3D
      if attn_mask.dim() == 2:
        attn_mask = attn_mask.unsqueeze(0)
      if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
        attn_mask = attn_mask.unsqueeze(0)
      else:
        attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

    # Use dropout only if in training mode
    dropout_p = self.dropout if self.training else 0.0

    # Prepare q, k, v for attention by splitting heads
    q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    src_len = k.size(1)

    # Reshape for attention
    q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
    k = k.view(bsz, self.num_heads, src_len, self.head_dim)
    v = v.view(bsz, self.num_heads, src_len, self.head_dim)

    # Apply scaled dot product attention
    attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)

    # Recombine attention heads
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    # Final output projection
    attn_output = self.proj(attn_output)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    return attn_output, None

Once the modules have been properly defined, they have to be injected inside the vanilla CLIP model. The following code takes either the CLIP visual or text encoder transformer as input, iterates over all resblocks and replaces their MultiHeadAttention modules with our modified LoRA modules.

The pretrained parameters have to be copied into the LoRA layer so that they are not lost.

In [None]:
def lora_replace_multihead_attention(transformer, lora_r, lora_alpha, lora_dropout):
  """Replace multihead attention in clip model with lora enhanced multihead attention.
  All blocks are replaced in this implementation.
  Args:
    transformer (nn.Module): transformer model where to replace layers.
    lora_r (int): rank of LoRA update matrices.
    lora_alpha (int): scaling factor for LoRA.
    lora_dropout (float): dropout probability for LoRA.
  """
  for resblock in transformer.resblocks:
    # Params taken from resblock
    # The reference applies LoRA only to q and v
    lora_multihead_params = {
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "embed_dim": resblock.attn.embed_dim,
        "num_heads": resblock.attn.num_heads,
        "dropout": resblock.attn.dropout,
        "bias": True,
        "q_lora": True,
        "k_lora": False,
        "v_lora": True
    }
    # Initialize, load weights and substitute module
    lora_attn = LoRAMultiHeadAttention(**lora_multihead_params)
    lora_attn.set_parameters(resblock.attn)
    resblock.attn = lora_attn

  return transformer


The reference work[4] proposes an ablation study showing the best hyperparameters for LoRA, which we reuse here, and showing that adapting both visual and textual encoders yields the best results.

Another essential step is to freeze the pretrained parameters, as we do here, making sure we only train the LoRA parameters.
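
The freezing rule can be tried on a toy model first. This is a minimal sketch where `ToyAdaptedLayer` is a hypothetical stand-in for an adapted attention block; it applies the same name-based rule used by `get_clip_lora` and counts the parameters that remain trainable.

```python
import torch
import torch.nn as nn

class ToyAdaptedLayer(nn.Module):
  """Hypothetical layer with one pretrained weight and two LoRA-style parameters."""
  def __init__(self):
    super().__init__()
    self.weight = nn.Parameter(torch.zeros(4, 4))
    self.lora_A = nn.Parameter(torch.zeros(2, 4))
    self.lora_B = nn.Parameter(torch.zeros(4, 2))

model = nn.Sequential(ToyAdaptedLayer(), nn.Linear(4, 2))

# A parameter stays trainable only if its name contains 'lora'
for name, param in model.named_parameters():
  param.requires_grad = "lora" in name

trainable = sorted(name for name, p in model.named_parameters() if p.requires_grad)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(trainable, f"{n_trainable}/{n_total} parameters trainable")
```

Printing the ratio of trainable to total parameters is a cheap sanity check that complements listing the trainable names.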

In [None]:
def get_clip_lora():
  """Load CLIP substituting LoRA modules into text and visual encoders."""
  # LoRA hyperparams taken from reference work [4]
  LORA_VISUAL_R = 64
  LORA_VISUAL_ALPHA = 32
  LORA_VISUAL_DROPOUT = 0.1
  LORA_TEXT_R = 16
  LORA_TEXT_ALPHA = 32
  LORA_TEXT_DROPOUT = 0.1

  clip_model, _ = clip.load(CLIP_BACKBONE, device="cpu")

  # Apply LoRA to both text and visual encoders
  clip_model.visual.transformer = lora_replace_multihead_attention(clip_model.visual.transformer,
    lora_r=LORA_VISUAL_R, lora_alpha=LORA_VISUAL_ALPHA, lora_dropout=LORA_VISUAL_DROPOUT
  )
  clip_model.transformer = lora_replace_multihead_attention(clip_model.transformer,
    lora_r=LORA_TEXT_R, lora_alpha=LORA_TEXT_ALPHA, lora_dropout=LORA_TEXT_DROPOUT
  )
  clip_model = clip_model.to(DEVICE)

  # Freeze CLIP params
  for name, param in clip_model.named_parameters():
      param.requires_grad = 'lora' in name

  # Sanity check, only LoRA params to train
  trainable = {name for name, param in clip_model.named_parameters() if param.requires_grad}
  print(f"Trainable params: {trainable}")

  return clip_model


Utility functions to save and load LoRA weights are defined here. Saving and loading only the LoRA parameters saves a lot of space, as we did with CoCoOp (10 MB against 500 MB).
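
The partial save and load pattern can be sketched on a toy module. `ToyAdapted` is a hypothetical stand-in for the adapted CLIP model, and an in-memory buffer replaces the file on disk. With `strict=False`, `load_state_dict` returns the lists of missing and unexpected keys, which can be inspected to confirm that only the frozen weights were skipped.

```python
import io
import torch
import torch.nn as nn

class ToyAdapted(nn.Module):
  """Hypothetical module with a frozen weight and one trainable LoRA-style parameter."""
  def __init__(self):
    super().__init__()
    self.weight = nn.Parameter(torch.randn(4, 4), requires_grad=False)
    self.lora_A = nn.Parameter(torch.randn(2, 4))

src, dst = ToyAdapted(), ToyAdapted()

# Keep only the trainable parameters
trainable_only = {name: p.detach() for name, p in src.named_parameters() if p.requires_grad}

# Save to an in-memory buffer instead of a file, then load it back
buffer = io.BytesIO()
torch.save(trainable_only, buffer)
buffer.seek(0)
result = dst.load_state_dict(torch.load(buffer), strict=False)

print("missing:", result.missing_keys, "unexpected:", result.unexpected_keys)
```

An empty list of unexpected keys means every saved tensor found a matching parameter in the target model.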

In [None]:
def save_lora(model, filename="LoRA.pt"):
  """Given a CLIP-LoRA model, save only LoRA weights.
  Args:
    model (nn.Module): model to save.
    filename (str): filename to save weights.
  """
  save_dir = os.path.join("bin")
  os.makedirs(save_dir, exist_ok=True)

  lora_state_dict = {
    k: v for k, v in model.state_dict().items()
    if model.get_parameter(k).requires_grad
  }
  file_path = os.path.join(save_dir, filename)
  torch.save(lora_state_dict, file_path)
  print(f"LoRA weights saved to {file_path}.")

def load_lora(model, filename="LoRA.pt", device="cuda:0"):
  """Load into model LoRA only weights.
  Args:
    model (nn.Module): model to load weights.
    filename (str): filename to load weights.
    device (str): device to use.
  """
  file_path = os.path.join("bin", filename)

  if not os.path.isfile(file_path):
    raise FileNotFoundError(f"No LoRA weights file found at {file_path}")

  lora_state_dict = torch.load(file_path, map_location=torch.device(device))

  # Only load the matching keys
  model.load_state_dict(lora_state_dict, strict=False)
  print(f"LoRA weights loaded from {file_path}")

A LoRA training function is defined to train the LoRA model. The training step computes similarity logits between the images in the batch and the textual features of the tokenized default prompts for the base classes, then cross-entropy loss is applied to maximize the similarity between each image feature and the textual feature of its class.

We also implement early stopping logic and optional support for a scheduler.
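
The loss computation can be sketched in isolation. Here random tensors are hypothetical stand-ins for the outputs of the image and text encoders, and the sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for encoder outputs: 4 images, 3 classes, 8-dim features
image_features = torch.randn(4, 8)
text_features = torch.randn(3, 8)
labels = torch.tensor([0, 2, 1, 2])  # class index of each image

# Normalize so that the dot product is a cosine similarity
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# One row of scaled similarities per image, one column per class prompt
logit_scale = 100
logits = logit_scale * image_features @ text_features.T  # shape (4, 3)

# Cross-entropy pushes each image towards the prompt of its own class
loss = F.cross_entropy(logits, labels)
print(logits.shape, loss.item())
```

Since the features are normalized, each logit is bounded by `logit_scale` in absolute value, so the scale acts as an inverse temperature for the softmax.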

In [None]:
def lora_train(lora_model, train_loader, val_loader, optimizer, loss_fun, scheduler, base_classes, num_epochs=5, patience=3, device="cuda:0"):
  """Train CLIP-LoRA model.
  Args:
    lora_model (nn.Module): CLIP-LoRA model to train.
    train_loader (DataLoader): training data loader.
    val_loader (DataLoader): validation data loader.
    optimizer (torch.optim): optimizer to use.
    loss_fun (nn.Module): loss function to use.
    scheduler (torch.optim.lr_scheduler): learning rate scheduler.
    base_classes (list): list of base classes.
    num_epochs (int): number of epochs to train.
    patience (int): patience for early stopping.
    device (str): device to use.
  """
  def train_step(lora_model, loader, optimizer, loss_fun, base_classes, device="cuda"):
    """Train CLIP-LoRA model for one epoch.
    Args:
      lora_model (nn.Module): CLIP-LoRA model to train.
      loader (DataLoader): train data loader.
      optimizer (torch.optim): optimizer to use.
      loss_fun (nn.Module): loss function to use.
      base_classes (list): list of base classes.
      device (str): device to use.
    """
    # Variables to compute performance scores
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Put model in training mode
    lora_model.train()

    # Iterate over training set
    for images, labels in loader:
      # Put images and labels on correct device
      images, labels = images.to(device), labels.to(device)

      # Encode images and normalize image features
      image_features = lora_model.encode_image(images)
      image_features = image_features / image_features.norm(dim=-1, keepdim=True)
      # Encode text and normalize text features
      text_prompts = [PROMPT_FORMAT.format(CLASS_NAMES[cls]) for cls in base_classes]
      text_inputs = clip.tokenize(text_prompts).to(device)
      text_features = lora_model.encode_text(text_inputs)
      text_features = text_features / text_features.norm(dim=-1, keepdim=True)
      # Compute similarity logits
      logit_scale = 100 # This value is used with LoRA
      logits = logit_scale * image_features @ text_features.T
      # Targets are the class labels: each image should match the text prompt of its class

      loss = loss_fun(logits, labels)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      # Compute performance scores
      batch_size = images.shape[0]
      samples += batch_size
      cumulative_loss += loss.item() * batch_size
      _, predicted = logits.max(dim=1)
      cumulative_accuracy += predicted.eq(labels).sum().item()

    loss = cumulative_loss / samples
    accuracy = cumulative_accuracy / samples
    return loss, accuracy

  # Initialize training vars before train loop
  best_val_acc = float('-inf')
  best_model_state = None
  curr_patience = patience
  pb = tqdm(range(num_epochs))

  for epoch in pb:
    # Training step
    train_loss, train_acc = train_step(lora_model, train_loader, optimizer, loss_fun, base_classes, device)

    # Evaluation step
    val_acc = clip_test(lora_model, val_loader, base_classes, device)

    # Show val accuracy
    pb.set_description(f"Val Acc: {val_acc*100:.2f}%")

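    # If we get better metric save the model
    if val_acc > best_val_acc:
      best_val_acc = val_acc
      # Clone the tensors: state_dict() returns references that later updates overwrite in place
      best_model_state = {k: v.detach().clone() for k, v in lora_model.state_dict().items()}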
      curr_patience = patience
    else:
      if curr_patience < 1:
        break
      else:
        curr_patience = curr_patience - 1

    # If scheduler is given, scheduler step
    if scheduler is not None:
      scheduler.step()

  # Load best model before returning
  if best_model_state is not None:
    lora_model.load_state_dict(best_model_state)
  return lora_model


LoRA is then evaluated using the following hyperparams:
- `TRAIN_BATCH`, `VAL_BATCH`, `TEST_BATCH`: batch sizes of different datasets. The training batch size of 32 is taken from reference [4].
- `LR`: learning rate, taken from reference [3] in order to reach the best performance.
- `NUM_EPOCHS`: number of epochs to train. 15 is enough to reach a good performance.
- `PATIENCE`: patience for early stopping, which helps avoid overfitting on the training set.

Note that a cosine scheduler is also added since it was used in the reference [4].
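
To get a feeling for the schedule, the closed form of cosine annealing documented for `CosineAnnealingLR` can be evaluated directly. This sketch uses the `LR` and `NUM_EPOCHS` values listed above and assumes the default minimum learning rate of 0; `cosine_annealing_lr` is a hypothetical helper written for illustration.

```python
import math

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
  """Closed-form cosine annealing: decays from base_lr to eta_min over t_max epochs."""
  return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

LR, NUM_EPOCHS = 2e-4, 15
schedule = [cosine_annealing_lr(LR, epoch, NUM_EPOCHS) for epoch in range(NUM_EPOCHS + 1)]

for epoch in (0, 5, 10, 15):
  print(f"epoch {epoch:2d}: lr = {schedule[epoch]:.2e}")
```

The learning rate starts at `LR`, decreases slowly at first, faster in the middle, and reaches the minimum after `NUM_EPOCHS` scheduler steps.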

In [None]:
def lora_eval(do_train=True):
  """Evaluate CLIP-LoRA.
  Args:
    do_train (bool): train model or load weights.
  """
  # Get CLIP-LoRA
  clip_model = get_clip_lora()

  # Get datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Hyperparam
  TRAIN_BATCH = 32 # From CLIP-LoRA paper
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-4 # From CLIP-LoRA paper
  NUM_EPOCHS = 15
  PATIENCE = 3

  # Get loaders
  train_loader = torch.utils.data.DataLoader(train_base, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in clip_model.parameters() if p.requires_grad], lr=LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "lora_model": clip_model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    clip_model = lora_train(**train_params)
    # Save weights
    save_lora(clip_model, "LoRA.pt")
  else:
    load_lora(clip_model, "LoRA.pt", device=DEVICE)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 LoRA evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 LoRA evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating LoRA","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

The LoRA evaluation is run and the recorded results are compared to the previous ones.

In [None]:
if RUN_LORA:
  lora_eval(do_train=TRAIN_LORA)

<div align="center">

| Model      | Base | Novel | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp     | 95.19%   | 71.11%    | 81.41%        |
| CLIP-LoRA  | 96.81%   | 74.13%    | 83.96%        |

</div>
We can see that LoRA improves on CoCoOp in both base and novel accuracy. Training is also faster: LoRA trains with a batch size of 32 instead of 1, which makes a larger number of training epochs affordable.

After having evaluated the efficacy of LoRA, our next goal is to improve on it by implementing and evaluating the DISEF technique, which adds synthetic samples to further improve LoRA performance on base classes.

# Evaluation of DISEF
Diversified in-domain synthesis with efficient fine-tuning for few-shot classification (DISEF) is a technique that generates synthetic samples to further enhance LoRA. We implemented its pipeline to evaluate its behaviour and to study whether adding the synthetic dataset improves or harms the results.

From now on we split our work into a generative section and a training section, discussing implementation choices with reference to the original work.

## Synthetic sample generation

The following diagram shows the steps of the generative process:

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/generative_pipeline.png" width="80%">
</div>

The process begins by picking two random samples from the same class: one is used to generate a caption, then the caption and the second image are fed to the diffusion model to obtain a new synthetic sample.

After generation, we apply CLIP filtering to discard uninformative samples.

This process is repeated iteratively until either we get all desired samples or a stopping criterion is reached (i.e. 100 samples of that class have already been discarded).
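
The control flow described above can be sketched with stand-ins for the real models. This is a toy illustration, not the notebook's implementation: `generate_until_filled` and `pass_rate` are hypothetical names, and a random draw replaces both generation and the CLIP check. It only shows how the per-class discard budget guarantees termination.

```python
import random

def generate_until_filled(target_per_class, pass_rate, max_discards=100, seed=0):
  """Toy generate-and-filter loop with stand-ins for the real models.

  target_per_class: dict label -> number of samples wanted.
  pass_rate: dict label -> probability that a generated sample passes the filter
             (a stand-in for the CLIP zero-shot check).
  max_discards: discard budget per class; once spent, samples are accepted unfiltered.
  """
  rng = random.Random(seed)
  kept = {label: 0 for label in target_per_class}
  discarded = {label: 0 for label in target_per_class}

  # Repeat rounds until every class has all its samples
  while any(kept[c] < target_per_class[c] for c in target_per_class):
    for label, target in target_per_class.items():
      # Each round regenerates only the samples that are still missing
      for _ in range(target - kept[label]):
        budget_spent = discarded[label] >= max_discards
        if budget_spent or rng.random() < pass_rate[label]:
          kept[label] += 1
        else:
          discarded[label] += 1
  return kept, discarded

kept, discarded = generate_until_filled(
  target_per_class={0: 8, 1: 8},
  pass_rate={0: 0.9, 1: 0.0},  # class 1 never passes the toy filter
  max_discards=100,
)
print(kept, discarded)
```

Even for the class that never passes the toy filter, the loop ends once its 100 discards are spent, because later samples of that class are accepted without filtering.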

### Prepare samples
The following code simply wraps the dataset used for generation in order to expose file paths (the generative pipeline needs images without preprocessing). A function that picks two distinct random samples from the same class is also implemented.
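
As a side note, one possible way to pick the two samples is to group file paths by class once and then draw from the group. The sketch below illustrates this idea on toy data; it is not the approach used in the cells that follow, and `index_by_class`, `pick_pair` and the file names are hypothetical.

```python
import random
from collections import defaultdict

def index_by_class(samples):
  """Group (label, path) pairs into a dict label -> list of paths."""
  by_class = defaultdict(list)
  for label, path in samples:
    by_class[label].append(path)
  return by_class

def pick_pair(by_class, chosen_class, rng=random):
  """Pick two distinct paths of the same class."""
  paths = by_class.get(chosen_class, [])
  if len(paths) < 2:
    raise ValueError(f"Not enough samples found for class {chosen_class}.")
  # sample() draws without replacement, so the two paths differ
  return tuple(rng.sample(paths, 2))

# Hypothetical file names, for illustration only
toy_samples = [(0, "a.jpg"), (0, "b.jpg"), (0, "c.jpg"), (1, "d.jpg")]
by_class = index_by_class(toy_samples)
first, second = pick_pair(by_class, 0, random.Random(0))
print(first, second)
```

Grouping once avoids rescanning the whole dataset for every generated sample, at the cost of keeping a small index in memory.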

In [None]:
class DatasetForGeneration(Dataset):
  """Wrap dataset to save filepaths for easily loading images in generation."""
  def __init__(self, subset):
    """Initialize dataset given a subset.
    Args:
      subset: subset of dataset without filepaths.
    """
    self.subset = subset
    self.paths = self.extract_paths()

  def __len__(self):
    """Get number of samples."""
    return len(self.subset)

  def __getitem__(self, idx):
    """Get img, label and filepath"""
    image, label = self.subset[idx]
    path = self.paths[idx]
    return image, label, path

  # Extract an image path given a subset i.e. train_base
  def extract_paths(self):
    """Given a subset of a dataset get filepaths."""
    dataset = self.subset.dataset
    indices = self.subset.indices
    return [dataset._image_files[i] for i in indices]

In [None]:
def get_random_samples(dataset, chosen_class):
  """Given a chosen class, get two distinct random samples.
  Args:
    dataset (list): list of samples with image and labels.
    chosen_class (int): label of chosen class to pick samples from.
  Returns:
    tuple: a tuple of two random samples.
  """
  # Filter samples matching to the original class
  class_samples = [sample for sample in dataset if sample[1] == chosen_class]

  # Check that at least two samples are contained for that class
  if len(class_samples) < 2:
    raise ValueError(f"Not enough samples found for class {chosen_class}.")

  # Randomly pick two samples
  random_samples = tuple(random.sample(class_samples, 2))
  return random_samples

### Captioning model

In this section we load the captioning model LLaVA from Hugging Face Hub using the `transformers` library. The choice to use a model integrated into the `transformers` library was made both to maintain the libraries' compatibility with the rest of the project and for its ease of use. Since there was no need to make changes within the model, we preferred a simpler and more reliable architecture at the expense of customization. We also configure quantization for faster processing.


The captioning process, on the other hand, occurs in the following steps: the image is loaded, the complete prompt is composed by adding information about the image class, and finally, using `llava_processor`, the inputs are prepared to be correctly read by LLaVA. Specifically, the text is tokenized, the images are normalized, and the image feature extraction takes place. After generation, it is important to decode the output to remove tokenization and obtain human-readable text.
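
The prompt composition step can be illustrated in isolation. The strings below are hypothetical placeholders: the notebook's `LLAVA_USER_PROMPT`, `LLAVA_LABEL_PROMPT_FORMAT` and `LLAVA_FULL_PROMPT_FORMAT` constants are defined elsewhere and may differ, and the chat template is only assumed to follow the LLaVA-1.5 convention.

```python
# Hypothetical prompt pieces: the notebook's LLAVA_* constants may differ
USER_PROMPT = "Describe this image in one sentence"
LABEL_PROMPT_FORMAT = ", knowing that it shows a {}."
FULL_PROMPT_FORMAT = "USER: <image>\n{} ASSISTANT:"  # assumed LLaVA-1.5 style chat template

def build_caption_prompt(class_info=None):
  """Compose the full captioning prompt, optionally mentioning the class."""
  user_prompt = USER_PROMPT
  # Explicit None check, so that falsy values such as label 0 are not skipped
  if class_info is not None:
    user_prompt += LABEL_PROMPT_FORMAT.format(class_info)
  else:
    user_prompt += "."
  return FULL_PROMPT_FORMAT.format(user_prompt)

print(build_caption_prompt("sunflower"))
print(build_caption_prompt())
```

After generation, the returned token ids still start with the prompt tokens, which is why the output is sliced at the prompt length before decoding.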

In [None]:
def load_llava(device="cuda:0"):
  """Load off-the-shelf captioning model.
  Args:
    device (str): device to be used.
  Returns:
    tuple: llava_processor, llava_model
  """

  print(f"Loading LLaVA from Hugging Face Hub: {LLAVA_MODEL_ID}")

  # Load processor (tokenizer and image processor)
  llava_processor = AutoProcessor.from_pretrained(LLAVA_MODEL_ID)

  # Configure 4-bit quantization (GPU only)
  bnb_config = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_compute_dtype=torch.float16,
      bnb_4bit_use_double_quant=True,
  )

  # Load the llava model with correct configuration
  llava_model = LlavaForConditionalGeneration.from_pretrained(
      LLAVA_MODEL_ID,
      quantization_config=bnb_config if device == "cuda:0" else None,
      torch_dtype=torch.float16 if device == "cuda:0" else torch.float32,
      device_map="auto" if device == "cuda:0" else None,
      low_cpu_mem_usage=True
  )

  print(f"\nLLaVA model '{LLAVA_MODEL_ID}' loaded.")

  return llava_processor, llava_model


In [None]:
def caption_image(llava_processor, llava_model, image_path, label=None):
  """Given a llava model and a image, caption the image and return the caption.
  Args:
    llava_processor (AutoProcessor): a tokenizer and image processor.
    llava_model (nn.Module): llava model pretrained for captioning.
    image_path (str): path of the image to caption.
    label (int): if a label is passed it is added to the prompt for better captions.
  Returns:
    str: caption of the image.
  """

  # Read the image
  image = Image.open(image_path).convert("RGB")

  # Build the prompt
  # User prompt
  user_prompt = LLAVA_USER_PROMPT
  # If we give a label, embed the label inside the prompt
  # (explicit None check so that class label 0 is not skipped)
  if label is not None:
    user_prompt += LLAVA_LABEL_PROMPT_FORMAT.format(label)
  else:
    user_prompt += "."

  # Llava models expect a specific chat template
  full_prompt = (LLAVA_FULL_PROMPT_FORMAT.format(user_prompt))

  # Process image and prompt and put on correct device
  llava_inputs = llava_processor(text=full_prompt, images=image, return_tensors="pt")
  llava_inputs = {k: v.to(llava_model.device) for k, v in llava_inputs.items()}

  # Caption generation
  with torch.inference_mode(): # Avoid gradient computations
    outputs = llava_model.generate(**llava_inputs, do_sample=True, temperature=0.7, max_new_tokens=60, use_cache=True)

  # Length of the prompt to skip it in the output
  prompt_len = llava_inputs['input_ids'].shape[1]
  # Decode output to get caption
  caption = llava_processor.tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True).strip()
  return caption

### Stable Diffusion for generation

In this section we load the Stable Diffusion model that will be used in the generation pipeline. The parameters of the model have been set according to the reference[4]. For this purpose, we define two functions: the first initializes and configures a pre-trained Stable Diffusion pipeline, the second performs the actual image generation.

Stable Diffusion is loaded using the `StableDiffusionImg2ImgPipeline` provided by the diffusers library. We chose to use the full pipeline to have a simple and reliable tool since we were not interested in making internal changes to the stable diffusion logic. 

The generation step consists of loading the input image, combining the base prompt and the caption to form the full prompt, and calculating the strength to pass to the model. The function returns the generated image and the full prompt used in generation.
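
To make the role of the strength concrete, the sketch below shows how it translates into the number of denoising steps. It rests on an assumption about the image-to-image pipeline, namely that it runs roughly `int(num_inference_steps * strength)` steps and skips the rest of the schedule. The numeric values are hypothetical and are not the notebook's `NOISING_STEPS` and `DIFFUSION_STEPS`.

```python
def effective_denoising_steps(num_inference_steps, strength):
  """Number of denoising steps actually run by an img2img call.

  Assumption: the pipeline skips the first (1 - strength) fraction of the
  schedule, i.e. it runs int(num_inference_steps * strength) steps.
  """
  return min(int(num_inference_steps * strength), num_inference_steps)

# Hypothetical values, not the notebook's actual constants
diffusion_steps = 20
noising_steps = 15

strength = noising_steps / diffusion_steps
print(strength)  # 0.75
print(effective_denoising_steps(diffusion_steps, strength))  # 15
```

Under this assumption, defining the strength as the ratio of noising steps to diffusion steps means that the input image is noised and then denoised for the desired number of steps, while a strength of 1 would ignore the input image almost entirely.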

In [None]:
def load_stable_diffusion():
  """Load stable diffusion pipeline.
  Returns:
    StableDiffusionImg2ImgPipeline: stable diffusion image-to-image pipeline for generation.
  """
  print("Loading Stable Diffusion Pipeline.")

  sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
      SD_MODEL,
      torch_dtype=torch.float16,
      local_files_only = False # True if an offline model is used
  ).to(DEVICE)

  # Configure SD scheduler
  sd_scheduler = DPMSolverMultistepScheduler.from_config(
      sd_pipe.scheduler.config,
      algorithm_type=DIFFUSION_SAMPLER
  )
  sd_pipe.scheduler = sd_scheduler

  print("Loaded Stable Diffusion Pipeline.")
  return sd_pipe

In [None]:
def run_stable_diffusion(sd_pipe, image_path, caption, base_sd_prompt):
  """Given a stable diffusion pipeline, an image path and prompts, generate an image.
  Args:
    sd_pipe (StableDiffusionImg2ImgPipeline): stable diffusion image-to-image pipeline.
    image_path (str): path of input image.
    caption (str): caption to use for generation.
    base_sd_prompt (str): base stable diffusion prompt for generation.
  Returns:
    tuple: generated image and final prompt used for logging.
  """
  # Load and preprocess the input image
  input_image = Image.open(image_path).convert("RGB")

  # Compose the final text prompt
  final_sd_prompt = f"{base_sd_prompt} {caption}"

  # Strength computed as in the paper
  computed_strength = NOISING_STEPS / DIFFUSION_STEPS

  # Generation of the image
  gen_image = sd_pipe(
      prompt=final_sd_prompt,
      image=input_image,
      num_inference_steps=DIFFUSION_STEPS,
      guidance_scale=CFG_STRENGHT,
      strength=computed_strength
  ).images[0]

  return gen_image, final_sd_prompt

### CLIP Filter

CLIP filtering is one of the most important parts of the pipeline as it is responsible for filtering the generated images, keeping only the truly informative samples. The filtering is done by keeping only samples where a zero-shot CLIP model is able to predict the correct class. This is implemented similarly to our previous `clip_test` function, by computing similarity logits and predictions of the given model.

The following function also lets us decide whether the discarded samples are deleted or simply moved to another directory.
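
The filtering rule itself is small: a sample is kept only if the class prompt most similar to it is the prompt of its own label. The sketch below shows that rule on made-up 2-d features in plain Python; `keep_sample` is a hypothetical name, and the real function works on batched CLIP features instead.

```python
import math

def normalize(vec):
  """Scale a vector to unit L2 norm."""
  norm = math.sqrt(sum(x * x for x in vec))
  return [x / norm for x in vec]

def keep_sample(image_feat, text_feats, label_idx):
  """Return True if the most similar class prompt matches the sample's label."""
  image_feat = normalize(image_feat)
  # Cosine similarity between the image and every class prompt
  sims = [sum(a * b for a, b in zip(image_feat, normalize(t))) for t in text_feats]
  predicted = max(range(len(sims)), key=sims.__getitem__)
  return predicted == label_idx

# Toy 2-d features: one row per class prompt (made-up numbers)
text_feats = [[1.0, 0.0], [0.0, 1.0]]
print(keep_sample([0.9, 0.1], text_feats, label_idx=0))  # True: closest to class 0
print(keep_sample([0.2, 0.8], text_feats, label_idx=0))  # False: closest to class 1
```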

In [None]:
def clip_filter(clip_model, clip_preprocess, generated_imgs, categories, class_filtering_counter, device="cuda:0", discarded_dir=None):
  """CLIP filtering given a CLIP model on generated imgs.
  Args:
    clip_model (nn.Module): CLIP for filtering images.
    clip_preprocess: preprocess pipeline for CLIP input.
    generated_imgs (dict): dictionary containing filepaths, imgs and labels of generated images.
    categories (list): list of classes of the generated imgs.
    class_filtering_counter (dict): mapping label -> number of filtered images per class.
    device (str): device to use.
    discarded_dir (str): discarded dir where to put discarded imgs.
  Returns:
    tuple: number of discarded imgs, class_filtering_counter.
  """
  # Dataset class for dataloader
  class SyntheticDataset(Dataset):
    def __init__(self, imgs, labels, filepaths):
      """
      Args:
          imgs (list): List of images, already preprocessed for CLIP.
          labels (list of int): Class labels for each image.
          filepaths (list of str): Corresponding image file paths.
      """
      self.images = imgs
      self.labels = labels
      self.filepaths = filepaths

    def __len__(self):
      return len(self.images)

    def __getitem__(self, idx):
      img = self.images[idx]
      label = self.labels[idx]
      path = self.filepaths[idx]
      return img, label, path
  # If we save discarded create discarded dir
  if discarded_dir is not None:
    os.makedirs(discarded_dir, exist_ok=True)

  # Preprocess imgs
  generated_imgs["imgs"] = [clip_preprocess(img) for img in generated_imgs["imgs"]]

  # Create dataset and dataloader
  syn_dataset = SyntheticDataset(**generated_imgs)
  syn_loader = DataLoader(syn_dataset, batch_size=32, shuffle=False, num_workers=2)

  # CLIP Filtering
  clip_model.eval()

  # Dictionary for remapping labels label -> into contiguous set
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
  idx2cat = {v: k for k, v in contig_cat2idx.items()}  # Reverse map for moving imgs

  # Apply the standard CLIP template used for oxford flowers to all categories and immediately tokenize each sentence
  text_inputs = clip.tokenize([PROMPT_FORMAT.format(CLASS_NAMES[c]) for c in categories]).to(device)

  # Collect filtered image count
  discarded_count = 0

  with torch.no_grad():
    # Encode text features for all classes
    text_features = clip_model.encode_text(text_inputs)
    # Normalize them (standard practice with CLIP)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    for imgs, labels, paths in tqdm(syn_loader, desc="CLIP Filtering"):
      # Map categories to contiguous to get correct predictions
      labels = torch.Tensor([contig_cat2idx[l.item()] for l in labels]).long()
      imgs, labels = imgs.to(device), labels.to(device)

      # Encode image features
      image_features = clip_model.encode_image(imgs)
      # Normalize image features
      image_features /= image_features.norm(dim=-1, keepdim=True)

      # Cosine similarity between image and text features and keep the argmax for every row (every image)
      predicted_class = (image_features @ text_features.T).argmax(dim=-1)
      # Check which are correct
      correct_predictions = (predicted_class == labels)

      # Iterate on predictions for filtering
      for i, correct in enumerate(correct_predictions):

        # If more than 100 instances filtered out for that class we skip filtering
        if class_filtering_counter[idx2cat[labels[i].item()]] >= 100:
          continue

        if not correct: # If filtered out
          discarded_count += 1
          class_filtering_counter[idx2cat[labels[i].item()]] += 1

          if not discarded_dir:
            os.remove(paths[i])
          else:
            true_class_name = CLASS_NAMES[idx2cat[labels[i].item()]]
            pred_class_name = CLASS_NAMES[idx2cat[predicted_class[i].item()]]
            # Move into true class name folder
            tgt_dir = os.path.join(discarded_dir, true_class_name, f"predicted_{pred_class_name}")
            os.makedirs(tgt_dir, exist_ok=True)

            tgt_im = os.path.join(tgt_dir, os.path.basename(paths[i]))
            # Overwrite if an image is already there
            if os.path.exists(tgt_im):
              os.remove(tgt_im)

            shutil.move(paths[i], tgt_dir)
  return discarded_count, class_filtering_counter


### Generation Pipeline

Here we join diffusion and captioning to generate multiple samples iterating over the classes.

In [None]:
def generate_sample(dataset, llava, sd_pipe, chosen_class):
  """Generate one class sample.
  Args:
    dataset (list): list of samples to use for generation pipeline.
    llava (tuple): llava model and processor for captioning.
    sd_pipe (StableDiffusionPipeline): stable diffusion pipeline for generation.
    chosen_class (int): label of the class of which a sample is generated.
  Returns:
    tuple: generated image and the prompt used for generation.
  """
  img1, img2 = get_random_samples(dataset, chosen_class)
  # Get filepaths
  _, label1, filepath1 = img1
  _, label2, filepath2 = img2

  assert chosen_class == label1 and chosen_class == label2, "Wrong classes used for generation."

  # Create sd prompt
  sd_prompt = SD_PROMPT_FORMAT.format(CLASS_NAMES[label1])

  # Get caption
  llava_processor, llava_model = llava
  caption = caption_image(llava_processor, llava_model, str(filepath2), label2)
  # Generate sample
  generated, prompt_used = run_stable_diffusion(sd_pipe, str(filepath1), caption, sd_prompt)

  return generated, prompt_used

In [None]:
def generate_samples_for_classes(dataset, llava, sd_pipe, class_samples_to_gen, generation_dir, force_regeneration=False):
    """Given a list of samples to generate for each class, generate them.
    Args:
      sd_pipe: stable diffusion pipeline.
      llava: llava model and processor.
      dataset (list): dataset for generation.
      class_samples_to_gen (dict): class samples to gen for each class with mapping label -> number of shots.
      generation_dir (str): directory where to put generated data.
      force_regeneration (bool): delete all old synthetic data and generate new one.
    Returns:
      dict: generated images, labels and filepaths, to be passed to CLIP filtering.
    """
    # Create dir where to put synthetic data
    os.makedirs(generation_dir, exist_ok=True)

    # Output the generated data for filtering
    generated_data = {
      "imgs": [],
      "labels": [],
      "filepaths": []
    }

    # Progress bar for classes
    with tqdm(total=len(class_samples_to_gen), position=0, desc="Class Generation Progress") as class_bar:
      for class_label, num_samples in class_samples_to_gen.items():
        class_dir = os.path.join(generation_dir, CLASS_NAMES[class_label])

        # If regeneration is forced, delete old data
        if force_regeneration and os.path.exists(class_dir):
          print(f"Clearing directory: {class_dir}")
          shutil.rmtree(class_dir)

        # Create class_dir
        os.makedirs(class_dir, exist_ok=True)

        missing_ids = []
        for i in range(1, num_samples + 1):
          expected_filename = f"Image_{i}.png"
          expected_filepath = os.path.join(class_dir, expected_filename)
          if force_regeneration or not os.path.exists(expected_filepath):
            missing_ids.append(i)

        if not missing_ids:
          class_bar.update(1)
          continue

        with tqdm(total=len(missing_ids), position=1, desc=f"Generating class {CLASS_NAMES[class_label]}", leave=False) as sample_bar:
          for sample_id in missing_ids:
            image, prompt_used = generate_sample(dataset, llava, sd_pipe, class_label)
            # print(f"{prompt_used}")

            # Filename to save
            base_filename = os.path.join(class_dir, f"Image_{sample_id}")
            img_path = f"{base_filename}.png"
            image.save(img_path)

            # Save generated data
            generated_data["imgs"].append(image)
            generated_data["labels"].append(class_label)
            generated_data["filepaths"].append(img_path)

            sample_bar.update(1)

        class_bar.update(1)

    return generated_data


### Generation loop

Now that every individual component has been defined, we create the generation loop. It stops when all desired samples have been created; in addition, filtering is disabled for a class once 100 of its samples have been filtered out (a condition needed to reach convergence).
This threshold of 100 samples is a hyperparameter and could be increased given more GPU resources: with a threshold of 100, at least 13 hours were already necessary to complete the generation.

This long runtime could be due both to the difficulty for the generator of creating images similar to the dataset and to the still low accuracy of zero-shot CLIP on base classes. This pitfall is addressed later in the section of the report on our improved DISEF.

In [None]:
def count_generated_samples(save_dir):
  """Count currently generated images in save_dir.
  Args:
    save_dir (str): save directory for generated images.
  Returns:
    int: number of generated images.
  """
  count_generated = 0
  for _, _, files in os.walk(save_dir):
    count_generated += len(files)

  return count_generated

In [None]:
def generation_pipeline(sample_per_class, dataset, categories, clip_model, clip_preprocess, device, save_dir, del_dir=None, force_regeneration=False):
  """Wrapper function for the whole generation pipeline.
  Args:
    sample_per_class (dict): map label -> number of sample to generate.
    dataset (Dataset): dataset to use for generation.
    categories (list): list of classes in the dataset.
    clip_model (nn.Module): CLIP model for filtering.
    clip_preprocess: CLIP preprocess for filtering.
    device (str): device to use.
    save_dir (str): directory where to save generated images.
    del_dir (str): directory where to put filtered out images.
    force_regeneration (bool): if images already generated, regenerate them.
  """

  # Prepare dataset
  gen_dataset = DatasetForGeneration(dataset)

  # Load llava
  llava = load_llava(device)
  # Load sd_pipe
  sd_pipe = load_stable_diffusion()

  # Counter for checking how many times we filtered an instance of a class
  class_filtering_counter = {}
  for c in categories:
    class_filtering_counter[c] = 0

  # Total samples to generate
  total_samples_to_gen = sum(sample_per_class.values())

  print("="*20,"Start image generation","="*20)

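  # Generate until we are satisfied with the samples
  # If we filtered more than 100 times a class we accept it without filtering
  # A forced regeneration must run at least once, even if save_dir is already full
  while force_regeneration or count_generated_samples(save_dir) < total_samples_to_gen:
    generated_data = generate_samples_for_classes(gen_dataset, llava, sd_pipe, sample_per_class, save_dir, force_regeneration)
    # Old data is cleared only on the first pass, later passes just refill discarded samples
    force_regeneration = False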
    discarded, class_filtering_counter = clip_filter(clip_model, clip_preprocess, generated_data, categories, class_filtering_counter, device, del_dir)
    print(f"CLIP filtered out {discarded} shots.")
    for cls, count in class_filtering_counter.items():
      print(f"{CLASS_NAMES[cls]} deleted {count}.", end=" ")

## Training and evaluation

The synthetic dataset must be loaded from the current `imgs` directory. The following utility function loads the dataset, optionally preprocessing the images.

In [None]:
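def load_syn_dataset(data_dir, do_preprocess=False):
  """Load generated samples for training.
  Args:
    data_dir (str): directory containing one folder of generated images per class.
    do_preprocess (bool): if True preprocess with CLIP preprocess.
  Returns:
    list: list of (img, label) pairs.
  """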
  # Collect samples in pairs (img, label)
  syn_dataset = []

  # Get clip preprocess
  if do_preprocess:
    _, preprocess = clip.load(CLIP_BACKBONE, device="cpu")

  for label, class_name in enumerate(CLASS_NAMES):
    current_dir = os.path.join(data_dir, class_name)
    # If no folder for a class_name (i. e. skip novel)
    if not os.path.isdir(current_dir):
      continue
    # Iterate on img in class folder
    for img_name in os.listdir(current_dir):
      img_path = os.path.join(current_dir, img_name)

      img = Image.open(img_path).convert("RGB")
      if do_preprocess:
        img = preprocess(img)
      syn_dataset.append((img, label))
  return syn_dataset




In this cell we define a custom dataset that flags its samples as real or synthetic, and a custom `collate_fn` function that separates the two kinds within a batch, so that the loss can be computed separately on each during training.
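
The splitting logic can be sketched without tensors. In this toy version strings stand in for images, and `split_real_synthetic` is a hypothetical name; the actual collate function additionally stacks the images into batched tensors.

```python
def split_real_synthetic(batch):
  """Split (sample, label, is_synthetic) triples into real and synthetic groups."""
  real = [(x, y) for x, y, is_syn in batch if not is_syn]
  syn = [(x, y) for x, y, is_syn in batch if is_syn]
  # None marks a missing group, so the training step can skip its loss term
  return (real or None), (syn or None)

# Strings stand in for image tensors
batch = [("real_0", 3, False), ("syn_0", 3, True), ("real_1", 7, False)]
real, syn = split_real_synthetic(batch)
print(real)
print(syn)
```

Returning `None` for an empty group matters because a shuffled batch may contain only real or only synthetic samples.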

In [None]:
class DisefDataset(torch.utils.data.Dataset):
  """Dataset to set if real or synthetic samples are contained."""
  def __init__(self, dataset, is_synthetic=False):
    """Initialize dataset.
    Args:
      dataset (list): list of samples and labels.
      is_synthetic (bool): images are generated or not.
    """
    self.dataset = dataset
    self.is_synthetic = is_synthetic

  def __len__(self):
    """Get number of samples."""
    return len(self.dataset)

  def __getitem__(self, idx):
    """Given an index returns a sample.
    Args:
      idx (int): index.
    """
    image, label = self.dataset[idx]
    return image, label, self.is_synthetic

def disef_collate_fn(batch):
  """Custom collate function to create a DataLoader"""
  images_real, labels_real = [], []
  images_syn, labels_syn = [], []

  for img, lbl, is_syn in batch:
    if is_syn:
      images_syn.append(img)
      labels_syn.append(lbl)
    else:
      images_real.append(img)
      labels_real.append(lbl)

  return (
    torch.stack(images_real) if images_real else None,
    torch.tensor(labels_real) if labels_real else None,
    torch.stack(images_syn) if images_syn else None,
    torch.tensor(labels_syn) if labels_syn else None,
  )


DISEF training introduces the main difficulty of handling both real and synthetic data. For this task, the following loss is proposed[4]:

$L = \lambda L_{real} + (1 - \lambda) L_{syn}$ where $\lambda = 0.8$[4]. 

This approach takes a weighted sum of the losses on real and synthetic data, giving more weight to the real samples.
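
A small worked example may help to see the weighting at play. The sketch below computes the loss on made-up logits in plain Python, assuming each term is a cross-entropy averaged over its group of samples; `cross_entropy`, `mean_loss` and `disef_loss` are hypothetical helper names.

```python
import math

def cross_entropy(logits, label):
  """Cross-entropy of one sample from raw logits (log-sum-exp minus true logit)."""
  max_logit = max(logits)
  log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
  return log_sum_exp - logits[label]

def mean_loss(batch):
  """Average cross-entropy over a list of (logits, label) pairs."""
  return sum(cross_entropy(logits, label) for logits, label in batch) / len(batch)

def disef_loss(real_batch, syn_batch, lambda_weight=0.8):
  """L = lambda * L_real + (1 - lambda) * L_syn."""
  return lambda_weight * mean_loss(real_batch) + (1 - lambda_weight) * mean_loss(syn_batch)

# Made-up logits over three classes
real_batch = [([2.0, 0.5, 0.1], 0), ([0.2, 1.5, 0.3], 1)]
syn_batch = [([0.4, 0.3, 1.0], 0)]

print(f"L_real = {mean_loss(real_batch):.4f}")
print(f"L_syn  = {mean_loss(syn_batch):.4f}")
print(f"L      = {disef_loss(real_batch, syn_batch):.4f}")
```

With $\lambda = 0.8$, a poorly classified synthetic sample contributes only a fifth of its loss, which limits the influence of imperfect generations on the gradient.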

The rest of the implementation follows the LoRA training loop, with early stopping and an LR scheduler.
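
For intuition on the scheduler, the sketch below evaluates the closed-form cosine annealing curve, which is assumed here to match the behaviour of the cosine scheduler when it is stepped once per epoch without restarts. The defaults reuse the learning rate and number of epochs set later in the evaluation cell; `cosine_annealing_lr` is a hypothetical helper.

```python
import math

def cosine_annealing_lr(epoch, base_lr=2e-4, t_max=50, eta_min=0.0):
  """Closed-form cosine annealing: the LR decays from base_lr to eta_min over t_max epochs."""
  return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

for epoch in (0, 25, 50):
  print(epoch, f"{cosine_annealing_lr(epoch):.2e}")
```

The learning rate starts at its base value, is halved at the midpoint and approaches zero at the last epoch; if early stopping triggers, only the first part of the curve is actually used.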

In [None]:
def disef_train(lora_model, train_loader, val_loader, optimizer, loss_fun, scheduler, base_classes, lambda_weight, num_epochs=5, patience=3, device="cuda:0"):
  """Training a CLIP-LoRA model using synthetic samples.
  Args:
    lora_model (nn.Module): CLIP-LoRA model to train.
    train_loader (DataLoader): training dataloader.
    val_loader (DataLoader): validation dataloader.
    optimizer (torch.optim): optimizer to use.
    loss_fun (nn.Module): loss function to use.
    scheduler (torch.optim.lr_scheduler): learning rate scheduler.
    base_classes (list): list of base classes.
    lambda_weight (float): weight for loss function.
    num_epochs (int): number of epochs to train.
    patience (int): patience for early stopping.
    device (str): device to use.
  Returns:
    tuple: trained model (with the best validation weights loaded) and training log.
  """
  def train_step(lora_model, loader, optimizer, loss_fun, base_classes, lambda_weight=0.8, device="cuda"):
    """Train CLIP-LoRA model for one epoch.
    Args:
      lora_model (nn.Module): CLIP-LoRA model to train.
      loader (DataLoader): real samples training dataloader.
      optimizer (torch.optim): optimizer to use.
      loss_fun (nn.Module): loss function to use.
      base_classes (list): list of base classes.
      lambda_weight (float): weight to compute the loss.
      device (str): device to use.
    """
    # Variables to compute performance scores
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Set training mode
    lora_model.train()


    # Iterate over training sets
    for images_real, labels_real, images_syn, labels_syn in loader:
      # Init batch_size and loss
      batch_size = 0
      loss = 0

      # Text features for all classes
      text_prompts = [PROMPT_FORMAT.format(CLASS_NAMES[cls]) for cls in base_classes]
      text_tokens = clip.tokenize(text_prompts).to(device)
      ft = lora_model.encode_text(text_tokens)
      ft = ft / ft.norm(dim=-1, keepdim=True)
      
      logit_scale = 100 # Used with CLIP-LoRA

      # Real images
      if images_real is not None:
        images_real, labels_real = images_real.to(device), labels_real.to(device)
        # Real visual features
        fv_real = lora_model.encode_image(images_real)
        fv_real = fv_real / fv_real.norm(dim=-1, keepdim=True)
        # Compute weighted loss
        logits_real = logit_scale * fv_real @ ft.T
        loss_real = loss_fun(logits_real, labels_real)
        loss += lambda_weight * loss_real

        batch_size += images_real.size(0)
        pred_real = logits_real.argmax(dim=1)
        cumulative_accuracy += (pred_real == labels_real).sum().item()

      # Synthetic images
      if images_syn is not None:
        images_syn, labels_syn = images_syn.to(device), labels_syn.to(device)
        # Synthetic visual features
        fv_syn = lora_model.encode_image(images_syn)
        fv_syn = fv_syn / fv_syn.norm(dim=-1, keepdim=True)
        # Compute weighted loss
        logits_syn = logit_scale * fv_syn @ ft.T
        loss_syn = loss_fun(logits_syn, labels_syn)
        loss += (1 - lambda_weight) * loss_syn

        batch_size += images_syn.size(0)
        pred_syn = logits_syn.argmax(dim=1)
        cumulative_accuracy += (pred_syn == labels_syn).sum().item()
      
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      cumulative_loss += loss.item() * batch_size
      samples += batch_size

    loss = cumulative_loss / samples
    accuracy = cumulative_accuracy / samples
    return loss, accuracy

  # Initialization before training loop
  best_val_acc = float('-inf')
  best_model_state = None
  curr_patience = patience
  pb = tqdm(range(num_epochs))

  # Logging for plots
  log = {
      "epoch": [],
      "train_loss": [],
      "train_acc": [],
      "val_acc": []
  }

  for epoch in pb:
    # Training step
    train_loss, train_acc = train_step(lora_model, train_loader, optimizer, loss_fun, base_classes, lambda_weight, device)

    # Evaluation step
    val_acc = clip_test(lora_model, val_loader, base_classes, device)

    # Logging
    log["epoch"].append(epoch)
    log["train_loss"].append(train_loss)
    log["train_acc"].append(train_acc)
    log["val_acc"].append(val_acc)

    # Show val accuracy
    pb.set_description(f"Val Acc: {val_acc*100:.2f}%")

    # If we get better metric save the model
    if val_acc > best_val_acc:
      best_val_acc = val_acc
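      # Clone tensors: state_dict() returns references that later training steps would overwrite
      best_model_state = {k: v.detach().clone() for k, v in lora_model.state_dict().items()}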
      curr_patience = patience
    else:
      if curr_patience < 1:
        break
      else:
        curr_patience = curr_patience - 1

    # If scheduler is given, scheduler step
    if scheduler is not None:
      scheduler.step()

  # Load best model before returning
  if best_model_state is not None:
    lora_model.load_state_dict(best_model_state)
  return lora_model, log

To analyze the training loop, a plotting function is written.

In [None]:
def plot_log(log, filename):
  """Plot training log.
  Args:
    log (dict): training log.
    filename (str): name of the file to save.
  """
  plt.figure(figsize=(15,5))

  # Subplot for Cross Entropy
  plt.subplot(1, 2, 1)
  plt.plot(log["epoch"], log["train_loss"], label="Training Loss", color="#6c757d")
  plt.xlabel('Epochs')
  plt.ylabel('Cross Entropy Loss')
  plt.title('Training Loss')
  plt.legend()
  plt.grid(True)

  # Subplot for accuracy
  plt.subplot(1, 2, 2)
  plt.plot(log["epoch"], log["train_acc"], label="Train Acc", color="#6c757d")
  plt.plot(log["epoch"], log["val_acc"], label="Valid Acc", color="#e9c46a")
  plt.xlabel('Epochs')
  plt.ylabel('Accuracy')
  plt.title('Training and Validation Accuracy')
  plt.legend()
  plt.grid(True)
  plt.tight_layout()

  # Save in current directory
  plt.savefig(filename)
  print("Training plot saved in current path.")

The implemented technique can now be evaluated in order to study how the synthetic data changes performance.
We set the number of synthetic shots per class to 32, whereas the reference[4] proposed 64 shots with 16 starting shots per class.

The hyperparameters set are:
- `TRAIN_BATCH`, `VAL_BATCH`, `TEST_BATCH`: batch sizes are kept the same as in the LoRA evaluation.
- `LR`: the used learning rate is the same as proposed in the reference [4].
- `NUM_EPOCHS`: 50 epochs are used to train the model as in reference [4].
- `PATIENCE`: the patience for early stopping is kept as in previous evaluations.
- `LAMBDA_WEIGHT`: 0.8 as in reference [4].
- `WEIGHT_DECAY`: 1e-2 as in reference [4].

In [None]:
def disef_eval(do_gen=True, do_train=True):
  """Evaluate DISEF technique.
  Args:
    do_gen (bool): generate synthetic samples or not.
    do_train (bool): train model or load weights.
  """
  # Get datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  if do_gen:
    # Generate synthetic dataset

    # Zero-shot CLIP for filtering and preprocess
    clip_model, clip_preprocess = clip.load(CLIP_BACKBONE, device=DEVICE)

    # K shots for every class have to be generated
    sample_per_class = {label: GENERATION_K_SHOTS for label in base_classes}

    generation_pipeline(sample_per_class, train_base, base_classes, clip_model, clip_preprocess, DEVICE, DISEF_GEN_PATH, DISEF_DEL_PATH, force_regeneration=False)

  syn_dataset = load_syn_dataset(DISEF_GEN_PATH, do_preprocess=True)

  # Get CLIP-LoRA
  clip_model = get_clip_lora()

  # Hyperparam
  TRAIN_BATCH = 32 # From CLIP-LoRA paper
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-4
  NUM_EPOCHS = 50
  PATIENCE = 3
  LAMBDA_WEIGHT = 0.8
  WEIGHT_DECAY = 1e-2
  
  # Create training datasets
  real_dataset = DisefDataset(train_base, is_synthetic=False)
  synthetic_dataset = DisefDataset(syn_dataset, is_synthetic=True)
  # Merge datasets
  merged_dataset = torch.utils.data.ConcatDataset([real_dataset, synthetic_dataset])

  # Get loaders
  train_loader = torch.utils.data.DataLoader(merged_dataset,batch_size=TRAIN_BATCH,shuffle=True,num_workers=2,collate_fn=disef_collate_fn)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in clip_model.parameters() if p.requires_grad], lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "lora_model": clip_model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "lambda_weight": LAMBDA_WEIGHT,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    clip_model, log = disef_train(**train_params)
    # Get training plot
    plot_log(log, "disef")
    # Save weights
    save_lora(clip_model, "disef.pt")
  else:
    load_lora(clip_model, "disef.pt", device=DEVICE)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 DISEF evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 DISEF evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating DISEF","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

Now the evaluation of DISEF can be run.

In [None]:
if RUN_DISEF:
  disef_eval(do_gen=GEN_DISEF, do_train=TRAIN_DISEF)

<div align="center">

| Model     | Base   | Novel  | Harmonic Mean |
|-----------|--------|--------|---------------|
| Zero-Shot | 71.33% | 78.24% | 74.62%        |
| CoCoOp    | 95.19% | 71.11% | 81.41%        |
| CLIP-LoRA | 96.81% | 74.13% | 83.96%        |
| DISEF     | 95.27% | 75.44% | 84.20%        |
</div>
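
The harmonic mean column can be checked by hand. The sketch below is a hypothetical stand-in for the notebook's `harmonic_mean` helper (its actual definition is not shown in this section) and recomputes the DISEF row from the rounded accuracies in the table. Because the table reports rounded values, recomputing other rows this way may differ from the reported number in the last digit.

```python
def harmonic_mean_sketch(base_acc, novel_acc):
    """Hypothetical stand-in for the notebook's `harmonic_mean` helper."""
    # Avoid a division by zero when both accuracies are zero
    if base_acc + novel_acc == 0:
        return 0.0
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)

# Rounded accuracies from the DISEF row of the table above
hm = harmonic_mean_sketch(0.9527, 0.7544)
print(f"{hm * 100:.2f}%")
```

The harmonic mean penalizes imbalance: a model that gains on base classes but loses on novel ones is rewarded less than one that improves both.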

Some unexpected results can be noticed. First, the synthetic data does not improve base accuracy; it actually seems to worsen it (96.81% for CLIP-LoRA against 95.27% for DISEF), while novel accuracy increases (74.13% against 75.44%).
The generated data is probably too far from the real dataset. The diffusion model may have been trained on data too different from ours, so it does not produce samples that are similar enough.


Let's inspect the training loop plot to get some more insight:
<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/disef_training.png" width="90%">
</div>


It seems that the validation accuracy stays almost flat, with only marginal improvements, while the model keeps learning on the training set.
Further iterations and tests (especially with different `LAMBDA_WEIGHT` values) have shown that our generative pipeline does not produce data that drastically improves base accuracy, contrary to what the results in [4] would suggest.
This could be due to slightly different implementation choices and to a different number of shots used to generate samples.

We can also notice that the model converges in just 7 epochs. This suggests that the synthetic data does carry some information about the base classes, but not enough to boost performance beyond what the real training data provides. Notice also that novel accuracy has increased compared to vanilla LoRA; this may tell us something about the properties of our synthetic data, but further testing is required.

# Evaluation of improved DISEF


In this section, following the previous result, we modify the approach to reach better accuracy or at least to reduce the high computational load of generating the whole dataset (~10 hours to obtain the filtered data in the previous step).

The first proposal is to compute per-class F1 scores using zero-shot CLIP, to understand which classes are the most problematic for CLIP to predict, and then adjust the generation accordingly. The second is to use a trained CLIP-LoRA model for CLIP filtering. The diagram shows in red the parts of the pipeline we want to act on:
<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/our_disef_pipeline.png" width="80%">
</div>


## Synthetic sample generation

The work of He et al. [5] motivated us to use the performance score on the validation set to set the number of images to generate. We decided to generate images only for classes whose F1 score was lower than 70%.
This allowed us to skip 28 base classes with an F1 score of 70% or higher, as seen in this plot:

<div style="text-align: center;">
  <img src="https://raw.githubusercontent.com/Frasor2002/DL_Project/main/report_diagrams/flower_accuracy_chart.png" width="80%">
</div>

Reducing the number of samples this way made the generation about 55% faster, lowering the computational cost of the pipeline.
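
The 55% figure is consistent with a simple class count. The sketch below assumes 51 base classes (half of the 102 Flowers102 categories, a number not stated in this section) and assumes that every selected class costs the same to generate; both are assumptions made for illustration.

```python
# Assumption: 51 base classes (half of the 102 Flowers102 categories)
NUM_BASE_CLASSES = 51
SKIPPED_CLASSES = 28  # classes at or above the F1 threshold, from the text

generated_classes = NUM_BASE_CLASSES - SKIPPED_CLASSES
reduction = SKIPPED_CLASSES / NUM_BASE_CLASSES

print(f"Classes still generated: {generated_classes}")
print(f"Fraction of generation skipped: {reduction:.1%}")
```

Under these assumptions only 23 classes still need synthetic samples, and the skipped fraction is close to the measured speed-up.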

In [None]:
def get_f1_scores(model, loader, categories, device="cuda:0", label=""):
  """Return class f1 scores given a model and a DataLoader.
  Args:
    model (nn.Module): CLIP model to compute f1 scores.
    loader (DataLoader): DataLoader to test on.
    categories (list): list of categories.
    device (str): device to use.
    label (str): label for progress bar.
  Returns:
    dict: Dictionary mapping label to F1 score.
  """
  # Model in eval mode
  model.eval()

  # Mappings for prediction
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
  idx2cat = {v: k for k, v in contig_cat2idx.items()}  # Reverse map for output

  # Save predictions and target labels
  all_preds = []
  all_tgts = []

  # Tokenized inputs for CLIP
  text_inputs = clip.tokenize([PROMPT_FORMAT.format(CLASS_NAMES[c]) for c in categories]).to(device)

  with torch.no_grad():
    # Text features
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    for images, labels in tqdm(loader, desc=label):
      # Map labels to contiguous set
      labels = torch.tensor([contig_cat2idx[l.item()] for l in labels]).long()

      images, labels = images.to(device), labels.to(device)
      # Image features
      image_features = model.encode_image(images)
      image_features /= image_features.norm(dim=-1, keepdim=True)

      logits = image_features @ text_features.T
      predicted_class = logits.argmax(dim=-1)

      all_preds.extend(predicted_class.to("cpu").numpy())
      all_tgts.extend(labels.to("cpu").numpy())

  # Use sklearn to compute metric
  precision, recall, f1, _ = precision_recall_fscore_support(all_tgts, all_preds, labels=list(range(len(categories))), zero_division=0)

  # Map class index to F1 score
  f1_per_class = {idx2cat[i]: f1[i] for i in range(len(categories))}

  return f1_per_class
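
To make explicit what the per-class score measures, here is a minimal pure-Python sketch of the F1 computation for each class index. The helper `per_class_f1` and the toy labels are hypothetical; the notebook itself relies on `precision_recall_fscore_support` from sklearn.

```python
def per_class_f1(targets, preds, num_classes):
    """Hypothetical pure-Python helper: F1 score of each class index."""
    scores = []
    for c in range(num_classes):
        # Count true positives, false positives and false negatives for class c
        tp = sum(1 for t, p in zip(targets, preds) if t == c and p == c)
        fp = sum(1 for t, p in zip(targets, preds) if t != c and p == c)
        fn = sum(1 for t, p in zip(targets, preds) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        # Same convention as zero_division=0: undefined scores become 0
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return scores

# Toy example with 3 classes (made-up labels)
targets = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]
f1 = per_class_f1(targets, preds, num_classes=3)
print([round(s, 2) for s in f1])
```

Unlike plain accuracy, F1 also drops for a class that attracts many wrong predictions from other classes, which is why it is a reasonable signal for deciding which classes need extra data.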

The logic used to compute the number of samples per class is deliberately simple and aims to reduce computational cost. More complex strategies should be tested to find the one that best reduces cost while maximizing performance.

In [None]:
def get_samples_per_class(f1_scores, max_samples):
  """Given f1 scores we compute samples to generate for every class.
  Args:
    f1_scores (dict): map labels -> class f1_score
    max_samples (int): max samples to generate for a class.
  Returns:
    dict: map labels -> samples to generate.
  """
  samples_per_class = {}

  threshold = 0.70
  for label, f1 in f1_scores.items():
    samples_per_class[label] = max_samples if f1 < threshold else 0
  return samples_per_class
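
As one example of a more complex strategy, the sketch below scales the number of samples with how far a class falls below the threshold, instead of the all-or-nothing rule of `get_samples_per_class`. This is a hypothetical alternative that we did not evaluate; the function name `get_samples_proportional` and the toy scores are made up for illustration.

```python
def get_samples_proportional(f1_scores, max_samples, threshold=0.70):
    """Hypothetical alternative: lower F1 -> more samples, zero at or above the threshold."""
    samples_per_class = {}
    for label, f1 in f1_scores.items():
        if f1 >= threshold:
            samples_per_class[label] = 0
        else:
            # Scale linearly: F1 = 0 gets max_samples, F1 near the threshold gets few
            deficit = (threshold - f1) / threshold
            # Classes below the threshold always get at least one sample
            samples_per_class[label] = max(1, round(max_samples * deficit))
    return samples_per_class

# Toy F1 scores keyed by class label (made-up values)
toy_f1 = {3: 0.95, 7: 0.35, 12: 0.0, 20: 0.69}
print(get_samples_proportional(toy_f1, max_samples=16))
```

A rule like this would spend most of the generation budget on the hardest classes, at the price of one more design choice (the shape of the scaling) to validate.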

## Training and evaluation
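After the new strategy has been defined, we perform the evaluation step with the same setup as in the previous section. Note that the code below sets `LR = 2e-5` and `LAMBDA_WEIGHT = 0.95`, which differ from the values used for DISEF.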


In [None]:
def our_disef_eval(do_gen=True, do_train=True):
  """Evaluate our improved DISEF technique.
  Args:
    do_gen (bool): generate synthetic samples or not.
    do_train (bool): train model or load weights.
  """
  # Batch sizes
  TRAIN_BATCH = 32 # From CLIP-LoRA paper
  VAL_BATCH = 64
  TEST_BATCH = 128

  # Get datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Get loaders
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)


  if do_gen:
    # Generate synthetic dataset

    # Get samples_per_class with new strategy
    clip_model, clip_preprocess = clip.load(CLIP_BACKBONE, device=DEVICE)
    f1_scores = get_f1_scores(clip_model, val_loader, base_classes, device=DEVICE, label="Getting f1 scores")
    sample_per_class = get_samples_per_class(f1_scores, GENERATION_K_SHOTS)

    # LoRA CLIP for filtering and preprocess
    clip_model = get_clip_lora()
    load_lora(clip_model, "LoRA.pt", device=DEVICE)

    generation_pipeline(sample_per_class, train_base, base_classes, clip_model, clip_preprocess, DEVICE, OUR_DISEF_GEN_PATH, OUR_DISEF_DEL_PATH, force_regeneration=False)


  syn_dataset = load_syn_dataset(OUR_DISEF_GEN_PATH, do_preprocess=True)

  # Get CLIP-LoRA
  clip_model = get_clip_lora()

  # Get syn loader
  # Create training datasets
  real_dataset = DisefDataset(train_base, is_synthetic=False)
  synthetic_dataset = DisefDataset(syn_dataset, is_synthetic=True)
  # Merge datasets
  merged_dataset = torch.utils.data.ConcatDataset([real_dataset, synthetic_dataset])

  # Get loaders
  train_loader = torch.utils.data.DataLoader(merged_dataset, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2, collate_fn=disef_collate_fn)

  # Hyperparam
  LR = 2e-5 # From CLIP-LoRA paper
  NUM_EPOCHS = 50
  PATIENCE = 3
  LAMBDA_WEIGHT = 0.95
  WEIGHT_DECAY = 1e-2

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in clip_model.parameters() if p.requires_grad], lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "lora_model": clip_model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "lambda_weight": LAMBDA_WEIGHT,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    clip_model, log = disef_train(**train_params)
    # Get training plot
    plot_log(log, "our_disef")
    # Save weights
    save_lora(clip_model, "our_disef.pt")
  else:
    load_lora(clip_model, "our_disef.pt", device=DEVICE)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 Our DISEF evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 Our DISEF evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating our DISEF","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

The evaluation can now be run and the results logged.

In [None]:
if RUN_OUR_DISEF:
  our_disef_eval(do_gen=GEN_OUR_DISEF, do_train=TRAIN_OUR_DISEF)

<div align="center">

| Model     | Base (↑) | Novel (↑) | Harmonic Mean |
|-----------|----------|-----------|---------------|
| Zero-Shot | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp    | 95.19%   | 71.11%    | 81.41%        |
| CLIP-LoRA | 96.81%   | 74.13%    | 83.96%        |
| DISEF     | 95.27%   | 75.44%    | 84.20%        |
| Our DISEF | 95.07%   | 75.87%    | 84.39%        |


</div>

A slight reduction in base accuracy can be noticed (95.27% for DISEF against 95.07% for our version). Since the synthetic data already lowered base accuracy in the previous section, we could expect that generating less of the faulty data would recover some of it, but we do not see such an increase. This suggests that some samples contribute more than others to the reduction in performance.

It seems that the generative pipeline fails especially on rare classes, the same ones that are harder for zero-shot CLIP to predict correctly.

# Conclusion

This work implemented and evaluated the CoCoOp and CLIP-LoRA techniques, showing in detail how these strategies are essential to improve the performance of a large pretrained model while keeping its generalization capabilities.

We then tried to implement a generative pipeline that adds data to improve base accuracy, motivated by the DISEF [4] work. The results of this step show that our solution still needs refining before it can serve as a robust generative way to augment few-shot base class samples.
Another relevant limitation of our work is the limited amount of GPU resources available to properly test the generative aspect of the work, which is a requirement when dealing with a generative pipeline of this scope.

The issues found in the last part of the project leave room for improvement in future work, where fine-tuning the generative model could be tried, or the Stable Diffusion hyperparameters could be explored further.

# References
1. Radford, Alec, et al. "Learning transferable visual models from natural language supervision." International conference on machine learning. PMLR, 2021.
2. Zhou, Kaiyang, et al. "Conditional prompt learning for vision-language models." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2022.
3. Zanella, Maxime, and Ismail Ben Ayed. "Low-rank few-shot adaptation of vision-language models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2024.
4. da Costa, Victor G. Turrisi, et al. "Diversified in-domain synthesis with efficient fine-tuning for few-shot classification." arXiv preprint arXiv:2312.03046 (2023).
5. He, Ruifei, et al. "Is synthetic data from generative models ready for image recognition?." arXiv preprint arXiv:2210.07574 (2022).