<p align="center">
  <strong><font size="6">Deep Learning 2025 Project</font></strong><br><br>
  Podavini Luca 257844<br>
  Richichi Andrea 257850<br>
  Sorrentino Francesco 256151
</p>

# Introduction
The aim of this project is to implement a **PEFT** (Parameter-Efficient Fine-Tuning) technique for CLIP and to find a strategy to improve it.
The model is evaluated on a base-to-novel classification task on the Flowers102 dataset.
The work is divided into:
1. Evaluation of Zero-Shot CLIP
2. Evaluation of CoCoOp
3. Evaluation of CLIP-LoRA
4. Evaluation of DISEF (an improvement on CLIP-LoRA)
5. Evaluation of our improved DISEF

## Import modules
This project requires several modules, some of which are not pre-installed in Colab. We will use the following modules:
- torch: the PyTorch library.
- torchvision: datasets and image transforms.
- clip: the pretrained CLIP model.
- tqdm: progress bars.
- matplotlib: plotting.

In [33]:
# Install modules that are not pre-installed
%pip install openai_clip



In [34]:
# Importing modules

# Modules for the whole project
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import clip
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

# Modules for CoCoOp
from collections import OrderedDict
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Modules for LoRA
import math
import torch.nn.functional as F

# Modules for DISEF
from PIL import Image

# Modules for our DISEF
from sklearn.metrics import precision_recall_fscore_support

## Define constants
Here we define the constants and parameters used to run the project.

In [35]:
# Constants
# Device where code is run
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# CLIP Backbone
CLIP_BACKBONE = "ViT-B/16"

# Classnames in the dataset, hardcoded for use later
CLASS_NAMES = ["pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold", "tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle", "snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris", "globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily", "fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth", "corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william", "carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly", "ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "lenten rose", "barbeton daisy", "daffodil", "sword lily", "poinsettia", "bolero deep blue", "wallflower", "marigold", "buttercup", "oxeye daisy", "common dandelion", "petunia", "wild pansy", "primula", "sunflower", "pelargonium", "bishop of llandaff", "gaura", "geranium", "orange dahlia", "pink-yellow dahlia?", "cautleya spicata", "japanese anemone", "black-eyed susan", "silverbush", "californian poppy", "osteospermum", "spring crocus", "bearded iris", "windflower", "tree poppy", "gazania", "azalea", "water lily", "rose", "thorn apple", "morning glory", "passion flower", "lotus", "toad lily", "anthurium", "frangipani", "clematis", "hibiscus", "columbine", "desert-rose", "tree mallow", "magnolia", "cyclamen", "watercress", "canna lily", "hippeastrum", "bee balm", "ball moss", "foxglove", "bougainvillea", "camellia", "mallow", "mexican petunia", "bromelia", "blanket flower", "trumpet creeper", "blackberry lily"]
# Prompt default format
PROMPT_FORMAT = "a photo of a {}, a type of flower."

# Parameters to decide what evaluations must be run
RUN_ZERO_SHOT = False
RUN_COCOOP = False
RUN_LORA = False
RUN_DISEF = False
RUN_OUR_DISEF = False


## Load Dataset
Collecting and preprocessing data from torchvision.
We will use the Flowers102 dataset.

The classes are split in half: the first half of the class ids becomes the base classes and the second half the novel ones. This only simulates a real application.
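
As a minimal sketch of this half-and-half split, assuming only that class ids are the integers 0 to 101, the two groups can be built with plain list slicing. The names `NUM_CLASSES`, `base_ids` and `novel_ids` are illustrative and not part of the project code.

```python
# Minimal sketch: split 102 class ids into a base half and a novel half.
NUM_CLASSES = 102  # Flowers102 has 102 categories

all_ids = list(range(NUM_CLASSES))
base_ids = all_ids[: NUM_CLASSES // 2]   # ids 0..50
novel_ids = all_ids[NUM_CLASSES // 2 :]  # ids 51..101

print(len(base_ids), base_ids[0], base_ids[-1])
print(len(novel_ids), novel_ids[0], novel_ids[-1])
```

Both halves contain 51 classes and do not overlap, which matches the 51 base classes mentioned in the dataset statistics.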

In [36]:
def get_data(data_dir="./data", transform=None):
    """Load Flowers102 train, validation and test sets.
    Args:
        data_dir (str): Directory where the dataset will be stored.
        transform (callable, optional): Transform applied to every image (e.g. the CLIP preprocess).
    Returns:
        tuple: A tuple containing the train, validation, and test sets.
    """
    train = torchvision.datasets.Flowers102(root=data_dir, split="train", download=True, transform=transform)
    val = torchvision.datasets.Flowers102(root=data_dir, split="val", download=True, transform=transform)
    test = torchvision.datasets.Flowers102(root=data_dir, split="test", download=True, transform=transform)
    return train, val, test

def base_novel_categories(dataset):
  """Split dataset classes into base and novel classes.
  Args:
    dataset (list): Dataset to split into base and novel classes.
  Returns:
    tuple: A tuple containing base and novel classes
  """
  # Set returns the unique set of all dataset classes
  all_classes = set(dataset._labels)
  num_classes = len(all_classes)

  # Generate base and novel category lists
  base_classes = list(range(num_classes))[:num_classes//2]
  novel_classes = list(range(num_classes))[num_classes//2:]
  return base_classes, novel_classes


def split_data(dataset, base_classes):
  """Split sample given base classes.
  Args:
    dataset (list): list of samples.
    base_classes (list): list of base classes.
  Returns:
    tuple: Tuple containing base and novel datasets.
  """
  # List to store sample idx
  base_categories_samples = []
  novel_categories_samples = []

  # Set of base classes to compute the test below in O(1)
  base_set = set(base_classes)

  # Iterate and get sample idx
  for sample_id, label in enumerate(dataset._labels):
    if label in base_set:
      base_categories_samples.append(sample_id)
    else:
      novel_categories_samples.append(sample_id)

  # Create the dataset subsets using Subset
  base_dataset = torch.utils.data.Subset(dataset, base_categories_samples)
  novel_dataset = torch.utils.data.Subset(dataset, novel_categories_samples)
  return base_dataset, novel_dataset


Now we can write a function that loads the data, with or without the CLIP preprocessing.

In [37]:
def get_dataset(do_preprocess=True):
  """Load Flowers102 datasets
  Args:
    do_preprocess (bool): enable or not preprocessing using CLIP.
  Returns:
    tuple: A tuple containing the base train, validation and test sets, the novel test set, base classes and novel classes.
  """
  # Without preprocessing no transform is applied (raw PIL images are returned)
  clip_preprocess = None
  if do_preprocess:
    # Load the CLIP preprocess
    _, clip_preprocess = clip.load(CLIP_BACKBONE, device=DEVICE)

  # Get the three datasets, applying the CLIP preprocess when enabled
  train_set, val_set, test_set = get_data(transform=clip_preprocess)

  # Split classes into base and novel
  base_classes, novel_classes = base_novel_categories(train_set)

  # Split the three datasets
  train_base, _ = split_data(train_set, base_classes)
  val_base, _ = split_data(val_set, base_classes)
  test_base, test_novel = split_data(test_set, base_classes)

  return train_base, val_base, test_base, test_novel, base_classes, novel_classes


Now data can be loaded and preprocessed to get our datasets.

Before evaluating the model, let's study the number of samples and class distribution.

<img src="sample_counts_per_dataset.png" width="600"/>

The training and validation sets each contain 510 samples.
The base test set contains 2473 samples, while the novel test set contains 3676 samples.

<img src="class_distribution_Train_Base" width="600"/>
<img src="class_distribution_Val_Base" width="600"/>

*train_base* and *val_base* contain 10 shots for every base class (51 classes * 10 shots).

<img src="class_distribution_Test_Base" width="600"/>

*test_base* has a non-uniform class distribution; the largest classes, petunia and wallflower, have more than 200 samples each.

<img src="class_distribution_Test_Novel" width="600"/>

*test_novel* also has a non-uniform class distribution; the largest class, passion flower, has more than 200 samples.
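
Per-class counts like the ones plotted above can be obtained with a simple counter. The snippet below is only a sketch on made-up labels; `class_distribution` and `toy_labels` are hypothetical names, not part of the project code.

```python
from collections import Counter


def class_distribution(labels):
    """Return (class_id, count) pairs sorted by decreasing count."""
    return Counter(labels).most_common()


# Hypothetical labels for illustration only (not the real dataset).
toy_labels = [0, 0, 1, 2, 2, 2]
distribution = class_distribution(toy_labels)
print(distribution)
```

For the real datasets, the labels of a `torch.utils.data.Subset` could presumably be collected as `[subset.dataset._labels[i] for i in subset.indices]`, relying on the same `_labels` attribute used by the splitting functions.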

# Evaluation of Zero-Shot CLIP
Zero-Shot CLIP must be evaluated before applying any technique in order to have a baseline performance to improve.

## Test functions

First, a test function is defined so that it can be reused later in the project.

In [38]:
def harmonic_mean(base_accuracy, novel_accuracy):
  """Compute harmonic mean
  Args:
    base_accuracy (float): accuracy score on base classes.
    novel_accuracy (float): accuracy score on novel classes.
  Returns:
    float: harmonic mean.
  """
  numerator = 2
  denominator = 1 / base_accuracy + 1 / novel_accuracy
  hm = numerator / denominator
  return hm

def clip_test(model, loader, categories, device, label=""):
  """Test function for CLIP model.
  Args:
    model (torch.nn): clip pretrained model to use.
    loader (DataLoader): dataloader for evaluation.
    categories (list): either base or novel idxs.
    device (str): device where to put data.
    label (str): label for evaluation loop.
  Returns:
    float: accuracy score.
  """
  # Set model in eval mode
  model.eval()

  # Dictionary for remapping labels label -> into contiguous set
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  # Apply the standard CLIP template used for oxford flowers to all categories and immediately tokenize each sentence
  text_inputs = clip.tokenize([PROMPT_FORMAT.format(CLASS_NAMES[c]) for c in categories]).to(device)

  with torch.no_grad():
    # Encode text features for all classes
    text_features = model.encode_text(text_inputs)
    # Normalize them (standard practice with CLIP)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Variable to store number of correct predictions
    correct_predictions = 0
    # Iterate through batches
    for image, target in tqdm(loader, desc=label):
      # Map categories to contiguous to get correct predictions
      target = torch.tensor([contig_cat2idx[t.item()] for t in target], dtype=torch.long)

      image, target = image.to(device), target.to(device)

      # Encode image features
      image_features = model.encode_image(image)
      # Normalize image features
      image_features /= image_features.norm(dim=-1, keepdim=True)

      # Cosine similarity between image and text features and keep the argmax for every row (every image)
      predicted_class = (image_features @ text_features.T).argmax(dim=-1)
      # Check which are correct, and sum them (False == 0, True == 1)
      correct_predictions += (predicted_class == target).sum().item()
  # Compute the accuracy
  accuracy = correct_predictions / len(loader.dataset)
  return accuracy
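
The prediction rule used by the test function (normalize both feature sets, take dot products, keep the argmax) can be sketched without any deep learning library. The vectors below are toy values, and `normalize` and `predict` are hypothetical helpers written only for this illustration.

```python
import math


def normalize(vec):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def predict(image_feature, text_features):
    """Return the index of the text feature with the highest cosine similarity."""
    img = normalize(image_feature)
    sims = [sum(a * b for a, b in zip(img, normalize(t))) for t in text_features]
    return max(range(len(sims)), key=sims.__getitem__)


# Toy 2-D features: three "class prompts" and one "image".
toy_text = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_image = [0.9, 1.1]
predicted = predict(toy_image, toy_text)
print(predicted)
```

Because both sides are normalized, the dot product is a cosine similarity, so only the direction of the features matters and not their magnitude.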


With a general test function for CLIP in place, we can define the evaluation function for the zero-shot model.

In [39]:
def zero_shot_eval():
  """Function to run zero_shot evaluation."""
  # Load the model
  clip_model, _ = clip.load(CLIP_BACKBONE, device=DEVICE)


  # Get the datasets
  _, _, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Batch size
  TEST_BATCH = 128

  # Get loaders
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 Zero-shot evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 Zero-shot evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating Zero-Shot","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

In [40]:
if RUN_ZERO_SHOT:
  zero_shot_eval()

The baseline obtained with zero-shot evaluation is:

<div align="center">

| Model      | Base (↑) | Novel (↑) | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |

</div>
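
As a quick sanity check, the harmonic mean can be recomputed from the rounded accuracies in the table. This is a self-contained copy of the formula used in `harmonic_mean`; since the inputs are already rounded, the last digit may differ slightly from the value printed by the notebook.

```python
def harmonic_mean(base_accuracy, novel_accuracy):
    """Harmonic mean of two accuracies, both expressed in (0, 1]."""
    return 2 / (1 / base_accuracy + 1 / novel_accuracy)


hm = harmonic_mean(0.7133, 0.7824)
print(f"{hm * 100:.1f}%")

# The harmonic mean penalizes imbalance between base and novel accuracy.
unbalanced = harmonic_mean(0.9, 0.1)
print(f"{unbalanced * 100:.1f}%")
```

This is why the harmonic mean is a sensible summary here: a model that gains on base classes while collapsing on novel ones scores well below its arithmetic average.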

# Evaluation of CoCoOp
The first technique applied to CLIP, used later as a point of comparison, is CoCoOp.
It is a prompt-tuning method: instead of relying on hand-written prompts, it learns context tokens, conditioned on each input image, with the goal of improving on zero-shot performance.


The `TextEncoder` class encodes the dynamic prompts, given their embeddings and the corresponding tokenized prompts.

In [41]:
class TextEncoder(nn.Module):
  """Encode dynamic prompts with tokenized given prompts."""
  def __init__(self, clip_model):
    """Init the module.
    Args:
      clip_model (torch.nn): clip pretrained model to use.
    """
    super().__init__()
    # Save into the module clip modules
    self.text_encoder = clip_model.transformer
    self.positional_embedding = clip_model.positional_embedding
    self.layer_norm = clip_model.ln_final
    self.proj = clip_model.text_projection

  def forward(self, embedded_tokens, tokens_ids):
    """Forward pass.
    Args:
      embedded_tokens (Tensor): Input token embeddings.
      tokens_ids (Tensor): Tokenizer prompt IDs.
    """
    # Apply positional embeddings
    output = embedded_tokens + self.positional_embedding

    # Rearrange dimension for transformer input
    output = output.permute(1, 0, 2)  # [batch_size, n_ctx, transformer.width] -> [n_ctx, batch_size, transformer.width]
    output = self.text_encoder(output)
    # Go back to original dimensions
    output = output.permute(1, 0, 2)  # [n_ctx, batch_size, transformer.width] -> [batch_size, n_ctx, transformer.width]
    # Apply layer norm
    output = self.layer_norm(output)

    # Select features corresponding to eot tokens
    eot_ids = tokens_ids.argmax(dim=-1)
    output = output[torch.arange(output.shape[0]), eot_ids] @ self.proj

    return output
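
The `argmax` over token ids works because, in CLIP's tokenizer, the end-of-text token has the largest id in the vocabulary, so its position is the position of the maximum. A minimal sketch on a shortened toy sequence (real CLIP sequences have 77 tokens; the word ids in the middle are arbitrary placeholders):

```python
# CLIP's tokenizer uses 49406 for start-of-text and 49407 for end-of-text,
# and pads the rest of the sequence with zeros.
SOT, EOT, PAD = 49406, 49407, 0

# Toy padded sequence; the four middle ids are placeholders for word tokens.
token_ids = [SOT, 320, 1125, 539, 320, EOT, PAD, PAD]

# Position of the largest id, i.e. the end-of-text token.
eot_position = max(range(len(token_ids)), key=token_ids.__getitem__)
print(eot_position)
```

The feature at that position is the one projected into the joint image-text space.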

The `PromptLearner` class then handles the prompt creation, given an initial context string that also determines the context length.

In [42]:

class PromptLearner(nn.Module):
  """Module that learns ctx vectors to adapt to visual features for each class."""
  def __init__(self, n_ctx, ctx_init, classnames, clip_model):
    """Initialize PromptLearner module.
    Args:
      n_ctx (int): Number of ctx tokens to learn.
      ctx_init (string): Initialization string for context.
      classnames (list): class names list.
      clip_model (nn.Module): pretrained CLIP for token embedding.
    """
    super().__init__()
    # Save number of classes and number of ctx tokens
    self.n_cls = len(classnames)
    self.n_ctx = n_ctx

    # Get dimensions from CLIP
    # Context embedding dimension
    ctx_dim = clip_model.ln_final.weight.shape[0]
    # Visual encoder output dimension
    vis_dim = clip_model.visual.output_dim

    # Get CLIP device
    device = clip_model.token_embedding.weight.device

    # Use given words to initialize context vectors
    ctx_init = ctx_init.replace("_", " ")
    # The context length follows the init string, so the slices below stay consistent
    n_ctx = len(ctx_init.split(" "))
    self.n_ctx = n_ctx
    prompt = clip.tokenize(ctx_init).to(device)

    with torch.no_grad():
        embedding = clip_model.token_embedding(prompt) # Convert token to embedding

    # Get only context tokens
    ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
    prompt_prefix = ctx_init

    # Register learnable context parameters
    self.ctx = nn.Parameter(ctx_vectors)

    # Meta-net to apply context based on image features
    self.meta_net = nn.Sequential(OrderedDict([
        ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
        ("relu", nn.ReLU(inplace=True)),
        ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
    ]))

    # Preprocess classnames and add them to the prompts
    classnames = [name.replace("_", " ") for name in classnames]
    # Instantiate tokenizer to get length of classnames
    _tokenizer = _Tokenizer()
    self.name_lens = [len(_tokenizer.encode(name)) for name in classnames]
    prompts = [prompt_prefix + " " + name + "." for name in classnames]

    # Tokenize prompts and get embeddings
    tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(device)  # (n_cls, n_tkn)
    with torch.no_grad():
        embedding = clip_model.token_embedding(tokenized_prompts)

    # Register unchanged parts of the embeddings
    self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
    self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

    # Save tokenized_prompts
    self.tokenized_prompts = tokenized_prompts  # torch.Tensor

  def construct_prompts(self, ctx, prefix, suffix, label=None):
    """Construct final prompts by concatenating prefix, ctx and suffix tokens.
    Args:
      ctx (Tensor): context tokens.
      prefix (Tensor): prefix tokens.
      suffix (Tensor): suffix tokens.
      label (Tensor): indices to select specific class prompts.
    """
    if label is not None:
      # Class specific suffix and prefix if labels are given
      prefix = prefix[label]
      suffix = suffix[label]

    prompts = torch.cat(
      [prefix,
      ctx,
      suffix],
      dim=1
    )

    return prompts

  def forward(self, im_features):
    """Forward pass.
    Args:
      im_features (Tensor): image features.
    Returns:
      Tensor: final prompts.
    """
    # Get fixed prefix and suffix token embeddings
    prefix = self.token_prefix
    suffix = self.token_suffix

    # Using the input image features we generate bias to apply to ctx
    bias = self.meta_net(im_features)
    bias = bias.unsqueeze(1)
    ctx = self.ctx.unsqueeze(0)
    ctx_shifted = ctx + bias # Broadcast bias across context

    # Build prompts for each sample in batch using adapted context
    prompts = []
    for ctx_shifted_i in ctx_shifted:
      # Expands for all classes
      ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
      pts_i = self.construct_prompts(ctx_i, prefix, suffix)
      prompts.append(pts_i)

    # Stack prompts for all batch samples
    prompts = torch.stack(prompts)

    return prompts
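
The way a per-image bias shifts the shared context tokens before the prompt is assembled can be illustrated with plain lists. This is only a sketch with tiny made-up embeddings; `shift_context` and `build_prompt` are hypothetical helpers that mimic, respectively, the broadcasted `ctx + bias` and the concatenation in `construct_prompts`.

```python
def shift_context(ctx, bias):
    """Add the same per-image bias vector to every context token."""
    return [[c + b for c, b in zip(token, bias)] for token in ctx]


def build_prompt(prefix, ctx, suffix):
    """Concatenate prefix, context and suffix along the token dimension."""
    return prefix + ctx + suffix


# Toy embeddings of dimension 2: one prefix token, two context tokens,
# two suffix tokens (class name and end-of-text in the real model).
prefix = [[9.0, 9.0]]
ctx = [[0.5, 0.25], [1.0, 2.0]]
suffix = [[5.0, 5.0], [7.0, 7.0]]
bias = [1.0, -1.0]  # would come from the meta-net for one image

shifted = shift_context(ctx, bias)
prompt = build_prompt(prefix, shifted, suffix)
print(shifted)
print(len(prompt))
```

Only the context tokens change from image to image; prefix and suffix stay fixed, which is why they are registered as buffers rather than parameters.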

The `CoCoOpCLIP` class wraps CLIP and exposes a clean forward method that consumes the prompts produced by the `PromptLearner` class.

In [43]:
class CoCoOpCLIP(nn.Module):
  """A CLIP wrapper that adds prompt tuning and computes logits using dynamic text features."""
  def __init__(self, n_ctx, ctx_init, classnames, clip_model):
    """Initiate module with prompt learner and encoders.
    Args:
      n_ctx (int): number of ctx tokens for prompt learner.
      ctx_init (string): context init string for prompt learner.
      classnames (list): list of classnames for prompt learner.
      clip_model (nn.Module): pretrained CLIP to wrap around.
    """
    super().__init__()
    # Prompt learner generates dynamic prompts based on image features
    self.prompt_learner = PromptLearner(n_ctx, ctx_init, classnames, clip_model)
    # Save tokenized prompts from prompt learner
    self.tokenized_prompts = self.prompt_learner.tokenized_prompts
    # Get visual encoder from CLIP
    self.image_encoder = clip_model.visual
    # Get wrapper around clip text encoder
    self.text_encoder = TextEncoder(clip_model)
    # Learnable scaling factor for logits from CLIP
    self.logit_scale = clip_model.logit_scale

  def forward(self, imgs):
    """Forward pass to compute logits between image and text prompts.
    Args:
      imgs: batch of input images.
    Returns:
      Tensor: similarity logits between image and text features.
    """
    # Get tokenized prompts and logit scale
    tokenized_prompts = self.tokenized_prompts
    logit_scale = self.logit_scale.exp()

    # Encode and normalize image features
    image_features = self.image_encoder(imgs)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # Generate instance-conditioned prompts using image features
    prompts = self.prompt_learner(image_features)

    logits = []
    for pts_i, imf_i in zip(prompts, image_features):
      # Compute similarity for each image in the batch
      text_features = self.text_encoder(pts_i, tokenized_prompts)
      text_features = text_features / text_features.norm(dim=-1, keepdim=True)
      l_i = logit_scale * imf_i @ text_features.t()
      logits.append(l_i)

    # Stack logits for all batch samples
    logits = torch.stack(logits)

    return logits

New training and evaluation functions are needed to handle the wrapper module. Since base and novel classes are evaluated separately, the test function also receives the base or novel class indices and remaps the labels to contiguous values.
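
The remapping matters mostly for the novel classes: their original ids start at 51, while the model outputs one logit per novel class, indexed from 0. A minimal sketch, assuming the novel ids are 51 to 101 (the batch targets are made up):

```python
# Novel class ids as produced by the half-and-half split.
novel_classes = list(range(51, 102))

# Original class id -> position among the novel classes.
contig_cat2idx = {cat: idx for idx, cat in enumerate(novel_classes)}

# Hypothetical batch of original labels.
batch_targets = [51, 77, 101]
remapped = [contig_cat2idx[t] for t in batch_targets]
print(remapped)  # [0, 26, 50]
```

For base classes the mapping is the identity, since their ids already run from 0 to 50.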

In [44]:
def cocoop_test(cocoop_model, loader, categories, device="cuda:0", label=""):
  """Test CoCoOp model.
  Args:
    cocoop_model (nn.Module): CoCoOp model to test.
    loader (DataLoader): dataloader for evaluation.
    categories (list): either base or novel idxs.
    device (str): device where to put data.
    label (str): label for evaluation loop.
  Returns:
    float: accuracy score.
  """
  # Variables to compute performance score
  samples = 0
  cumulative_accuracy = 0

  # Set the network to evaluation mode
  cocoop_model.eval()

  # Remap labels into a contiguous set starting from zero
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}

  with torch.no_grad():
    for images, targets in tqdm(loader, desc=label):
      # Map categories to the [0, 50], otherwise we will have wrong predictions
      targets = torch.tensor([contig_cat2idx[t.item()] for t in targets], dtype=torch.long)
      images, targets = images.to(device), targets.to(device)

      outputs = cocoop_model(images)

      # Compute performance scores
      batch_size = images.shape[0]
      samples += batch_size

      _, predicted = outputs.max(1)

      # Compute accuracy
      cumulative_accuracy += predicted.eq(targets).sum().item()

    accuracy = cumulative_accuracy / samples

    return accuracy

For cleaner code, we can also create a training function that trains the model for a given number of epochs. Here we also implement early stopping with patience.
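
Early stopping with patience can be sketched independently of the model. The function below replays a list of made-up validation accuracies and follows the same rule as the training function: up to `patience` non-improving epochs in a row are tolerated, and training stops on the next one. `run_with_early_stopping` and `toy_val_accs` are hypothetical names.

```python
def run_with_early_stopping(val_accuracies, patience):
    """Replay validation accuracies and return (best_epoch, best_acc, epochs_run)."""
    best_acc = float("-inf")
    best_epoch = None
    remaining = patience
    epochs_run = 0
    for epoch, acc in enumerate(val_accuracies):
        epochs_run += 1
        if acc > best_acc:
            # Improvement: remember it and reset the patience counter
            best_acc, best_epoch = acc, epoch
            remaining = patience
        elif remaining < 1:
            # No patience left: stop training
            break
        else:
            remaining -= 1
    return best_epoch, best_acc, epochs_run


# Made-up validation curve: the late recovery at the end is never reached.
toy_val_accs = [0.80, 0.85, 0.84, 0.83, 0.82, 0.81, 0.90]
best_epoch, best_acc, epochs_run = run_with_early_stopping(toy_val_accs, patience=3)
print(best_epoch, best_acc, epochs_run)
```

The example also shows the trade-off: a small patience saves compute but can stop before a late improvement.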

In [45]:
def cocoop_train(cocoop_model, train_loader, val_loader, optimizer, loss_fun, scheduler, base_classes, num_epochs=5, patience=3, device="cuda:0"):
  """Train CoCoOp.
  Args:
    cocoop_model (nn.Module): CoCoOp model to train.
    train_loader (DataLoader): dataloader for training.
    val_loader (DataLoader): dataloader for validation.
    optimizer (torch.optim): optimizer to use.
    loss_fun (torch.nn): cost function to use.
    scheduler (torch.optim.lr_scheduler): scheduler to use.
    base_classes (list): list of base classes for evaluation.
    num_epochs (int): number of epochs to train.
    patience (int): patience for early stopping.
    device (str): device to use.
  """
  def train_step(cocoop_model, loader, optimizer, loss_fun, device="cuda:0"):
    """Training step for CoCoOp.
    Args:
      cocoop_model (nn.Module): CoCoOp model to train.
      loader (data): loader for training set.
      optimizer (torch.optim): optimizer to use.
      loss_fun (torch.nn): cost function to use.
      device (str): device to use.
    """
    # Variables to store values to compute loss and accuracy
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Set the model to training mode
    cocoop_model.train()

    # Iterate over the training set
    for images, targets in loader:
      images = images.to(device)
      targets = targets.to(device)

      outputs = cocoop_model(images)
      loss = loss_fun(outputs, targets)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      batch_size = images.shape[0]
      samples += batch_size
      cumulative_loss += loss.item() * batch_size

      # Get predictions
      _, predicted = outputs.max(dim=1)
      cumulative_accuracy += predicted.eq(targets).sum().item()

    loss = cumulative_loss / samples
    accuracy = cumulative_accuracy / samples

    return loss, accuracy

  # Initialize vars before the train loop
  best_val_acc = float('-inf')
  best_model_state = None
  curr_patience = patience
  pb = tqdm(range(num_epochs))

  for epoch in pb:
    # Training step
    train_loss, train_acc = train_step(cocoop_model, train_loader, optimizer, loss_fun, device)

    # Evaluation step
    val_acc = cocoop_test(cocoop_model, val_loader, base_classes, device)

    # Show val accuracy at every iteration
    pb.set_description(f"Val Acc: {val_acc*100:.2f}%")

    # If we get better metric save the model
    if val_acc > best_val_acc:
      best_val_acc = val_acc
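      # Clone the tensors: state_dict() returns references that later updates would overwrite
      best_model_state = {k: v.detach().clone() for k, v in cocoop_model.prompt_learner.state_dict().items()}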
      curr_patience = patience
    else:
      if curr_patience < 1:
        break
      else:
        curr_patience = curr_patience - 1
    # Step the scheduler if provided
    if scheduler is not None:
      scheduler.step()

  # Load the best model before returning
  if best_model_state is not None:
    cocoop_model.prompt_learner.load_state_dict(best_model_state)
  return cocoop_model


To avoid retraining every time, we define utilities that save the trained weights to disk and load them back.

In [46]:
def save_model(model, model_name):
  """Save the model weights.
  Args:
    model (nn.Module): model to save.
    model_name (str): name of the model to save.
  """
  # Save directory location (if does not exist create it)
  save_dir = os.path.join("bin")
  os.makedirs(save_dir, exist_ok=True)

  file_path = os.path.join(save_dir, f"{model_name}.pt")
  torch.save(model.state_dict(), file_path)
  print(f"{model_name} weights saved to {file_path}.")

def load_model(model, model_name, device):
  """Load the model weights.
  Args:
    model (nn.Module): model where to load the weights.
    model_name (str): name of the model to load.
    device (str): device to use.
  """
  file_path = os.path.join("bin", f"{model_name}.pt")

  if not os.path.isfile(file_path):
    raise FileNotFoundError(f"No weights file found at {file_path}")

  state_dict = torch.load(file_path, map_location=torch.device(device))
  model.load_state_dict(state_dict)
  print(f"Model weights loaded from {file_path}")

Finally, we create the CoCoOp evaluation function: it builds the models, trains on the base training set (or loads saved weights) and reports the accuracy on base and novel classes.

In [47]:
def cocoop_eval(do_train=True):
  """Function to run CoCoOp evaluation.
  Args:
    do_train (bool): train the model or load weights.
  """
  # Load the model
  clip_model, _ = clip.load(CLIP_BACKBONE, device=DEVICE)
  clip_model = clip_model.float() # To avoid weight type errors

  # Get the datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Hyperparams
  N_CTX = 4
  CTX_INIT = "a photo of a"
  TRAIN_BATCH = 1 # CoCoOp requires a batch of 1 in training to run
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-3
  NUM_EPOCHS = 10 # From CoCoOp paper
  PATIENCE = 3

  # Get base and novel classnames
  base_classnames = [CLASS_NAMES[c] for c in base_classes]
  novel_classnames = [CLASS_NAMES[c] for c in novel_classes]

  # One wrapper for base classification and the other for novel by giving different classnames
  base_model = CoCoOpCLIP(N_CTX, CTX_INIT, base_classnames, clip_model).to(DEVICE)
  novel_model = CoCoOpCLIP(N_CTX, CTX_INIT, novel_classnames, clip_model).to(DEVICE)

  # Freeze all CLIP params other than prompt learner
  for name, param in base_model.named_parameters():
      if "prompt_learner" not in name:
          param.requires_grad_(False)
  for name, param in novel_model.named_parameters():
      if "prompt_learner" not in name:
          param.requires_grad_(False)

  # Sanity check
  enabled = set()
  for name, param in base_model.named_parameters():
      if param.requires_grad:
          enabled.add(name)
  print(f"Parameters to be updated: {enabled}")

  # Get loaders
  train_loader = torch.utils.data.DataLoader(train_base, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train or load previously saved weights
  if do_train:
    optimizer = optim.AdamW([p for p in base_model.parameters() if p.requires_grad], lr=LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Run training
    train_params = {
        "cocoop_model": base_model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    base_model = cocoop_train(**train_params)
    # Save weights
    save_model(base_model.prompt_learner, model_name="CoCoOp")
  else:
    load_model(base_model.prompt_learner, model_name="CoCoOp", device=DEVICE)

  # Load weights on novel model
  novel_model.prompt_learner.ctx.data = base_model.prompt_learner.ctx.data.clone()
  novel_model.prompt_learner.meta_net.load_state_dict(base_model.prompt_learner.meta_net.state_dict())

  # Evaluate base and novel
  base_accuracy = cocoop_test(base_model, test_base_loader, base_classes, DEVICE, label="🧠 CoCoOp evaluation on Base")
  novel_accuracy = cocoop_test(novel_model, test_novel_loader, novel_classes, DEVICE, label="🧠 CoCoOp evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating CoCoOp","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

Let's now see the results by executing the following evaluation.

In [48]:
if RUN_COCOOP:
  cocoop_eval(do_train=True)

The new CoCoOp results can now be compared with zero-shot:
<div align="center">

| Model      | Base (↑) | Novel (↑) | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp     | 95.19%   | 71.11%    | 81.41%        |

</div>


# Evaluation of CLIP-LoRA
Before implementing the synthetic generation pipeline that we will improve, CLIP needs to be adapted without training all of its parameters. The technique chosen is Low-Rank Adaptation (LoRA).

This technique reduces the number of parameters to train and makes it possible to maintain novel-class performance while increasing base-class performance.
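
To make the parameter saving concrete, here is a minimal pure-Python sketch of a low-rank update: the frozen weight `W` is adjusted by `(alpha / r) * B @ A`, where `A` has shape `(r, d_in)` and `B` has shape `(d_out, r)`. The sizes are illustrative, and `matmul`, `effective_weight` and `lora_param_counts` are hypothetical helpers, not the implementation used in the class that follows.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def effective_weight(W, A, B, alpha, r):
    """Return W + (alpha / r) * B @ A."""
    scaling = alpha / r
    delta = matmul(B, A)  # (d_out x r) @ (r x d_in) -> (d_out x d_in)
    return [[w + scaling * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


def lora_param_counts(d_in, d_out, r):
    """Trainable parameters: full fine-tuning vs. a rank-r update."""
    return d_in * d_out, r * (d_in + d_out)


W = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]  # frozen weight, d_out=2, d_in=3
A = [[1.0, 2.0, 3.0]]                   # rank r=1
B_zero = [[0.0], [0.0]]                 # B starts at zero
B_trained = [[1.0], [0.0]]              # made-up values after training

# With B = 0 the adapted layer is identical to the pretrained one.
print(effective_weight(W, A, B_zero, alpha=2, r=1))
print(effective_weight(W, A, B_trained, alpha=2, r=1))

full, lora = lora_param_counts(512, 512, 4)
print(full, lora)
```

Initializing `B` to zero means training starts exactly from the pretrained model, and for a 512 x 512 layer a rank-4 update trains 64 times fewer parameters than the full matrix.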

In [49]:
class LoRAGroupedLinear(nn.Linear):
  """Augmented linear layer using a group of low-rank updates."""
  def __init__(self, in_features, out_features, bias, lora_r, lora_alpha, lora_dropout, enable_lora=[False], merge_weights=True):
    """Initialize the module.
    Args:
      in_features (int): number of input features.
      out_features (int): number of output features.
      bias (bool): whether the linear layer uses a bias term.
      lora_r (int): rank of low rank decomposition.
      lora_alpha (int): scaling factor for LoRA adjustment.
      lora_dropout (float): dropout prob for dropout in LoRA branch.
      enable_lora (list): list of which output groups should use LoRA.
      merge_weights (bool): if true merge LoRA and model weights for inference.
    """
    # Initialize parent linear layer
    super().__init__(in_features, out_features, bias=bias)

    # Save LoRA input params
    self.lora_r = lora_r
    self.lora_alpha = lora_alpha
    self.lora_dropout = nn.Dropout(lora_dropout) if lora_dropout > 0 else nn.Identity()
    self.merge_weights = merge_weights
    self.enable_lora = enable_lora

    # At initialization the weights are unmerged
    self.merged = False

    # Declare LoRA params for the enabled output groups
    if lora_r > 0 and any(enable_lora):
      self.n_lora_out = sum(enable_lora)
      self.n_out = len(enable_lora)
      self.lora_A = nn.Parameter(self.weight.new_zeros((lora_r * self.n_lora_out, in_features)))
      self.lora_B = nn.Parameter(self.weight.new_zeros((out_features // self.n_out * self.n_lora_out, lora_r)))

      self.scaling = self.lora_alpha / self.lora_r

      # Freeze pretrained
      self.weight.requires_grad = False

      # Mask to track where LoRA is applied
      self.lora_ind = self.weight.new_zeros((out_features), dtype=torch.bool).view(self.n_out, -1)
      self.lora_ind[enable_lora, :] = True
      self.lora_ind = self.lora_ind.view(-1)

    # Initialize LoRA params
    self.reset_parameters()

  def reset_parameters(self):
    """Method to reinit linear and LoRA params."""
    super().reset_parameters()
    if hasattr(self, 'lora_A'):
      # lora_A is initialized in a default way as nn.Linear
      nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
      # lora_B is initialized to zero
      nn.init.zeros_(self.lora_B)

  def zero_pad(self, input_seq):
    """Zero pads the input sequence to match output feature dim using LoRA mask.
    Args:
      input_seq (Tensor): input sequence to pad.
    Returns:
      Tensor: padded input sequence.
    """
    # First get zero tensor
    padded = input_seq.new_zeros((len(self.lora_ind), *input_seq.shape[1:]))
    # Substitute input sequence using LoRA mask
    padded[self.lora_ind] = input_seq
    return padded

  def merge_AB(self):
    """Compute effective weight delta from lora_A and lora_B and apply zero padding."""
    weight_delta = F.conv1d(
        self.lora_A.unsqueeze(0),
        self.lora_B.unsqueeze(-1),
        groups=self.n_lora_out
    ).squeeze(0)
    return self.zero_pad(weight_delta)

  def train(self, mode=True):
    """Enable training mode. If merging is enabled, weights are unmerged when in training mode
    and merged back when switching to eval mode."""
    super().train(mode)
    if mode: # If training mode
      if self.merge_weights and self.merged:
        # If merging is active and weights are merged, unmerge
        if self.lora_r > 0 and any(self.enable_lora):
          self.weight.data -= self.merge_AB() * self.scaling
        self.merged = False
    else: # Going in eval mode
      if self.merge_weights and not self.merged:
        # If merging is active and weights are not merged, merge
        if self.lora_r > 0 and any(self.enable_lora):
          self.weight.data += self.merge_AB() * self.scaling
        self.merged = True
    # Return self, matching the nn.Module.train() contract
    return self

  def forward(self, input_seq):
    """Forward pass.
    Args:
      input_seq (Tensor): input sequence.
    Returns:
      Tensor: output sequence.
    """
    output = F.linear(input_seq, self.weight, bias=self.bias)
    if self.merged: # If weights are merged -> eval mode
      return output
    else: # weights not merged -> train mode
      if self.lora_r > 0 and any(self.enable_lora):
        lora_out = self.lora_dropout(input_seq) @ self.merge_AB().T
        output += lora_out * self.scaling
      return output
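
The masking logic behind `lora_ind` and `zero_pad` can be easier to follow on a toy example. The following is a minimal pure-Python sketch, not the tensor implementation above: `build_lora_mask` and `zero_pad_rows` are hypothetical helpers that reproduce the same idea with nested lists, assuming three equally sized output groups (q, k, v).

```python
def build_lora_mask(out_features, enable_lora):
    """Boolean mask over output rows: True where a LoRA update applies."""
    group_size = out_features // len(enable_lora)
    mask = []
    for enabled in enable_lora:
        mask.extend([enabled] * group_size)
    return mask


def zero_pad_rows(delta_rows, mask, n_cols):
    """Scatter the rows of a compact update into a full-size matrix, zeros elsewhere."""
    rows = iter(delta_rows)
    return [list(next(rows)) if keep else [0.0] * n_cols for keep in mask]


# 6 output rows split into q, k, v groups of 2 rows each; only q and v are adapted
mask = build_lora_mask(out_features=6, enable_lora=[True, False, True])
delta = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]  # 4 adapted rows, 2 input features
padded = zero_pad_rows(delta, mask, n_cols=2)
for row in padded:
    print(row)
```

The rows that belong to the key projection stay at zero, so adding the padded update to the frozen weight leaves that group unchanged.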

The multi-head attention layer must also be adapted for LoRA.

In [50]:
class LoRAMultiHeadAttention(nn.Module):
  """Multi-Head Attention module augmented with LoRA."""
  def __init__(self, lora_r, lora_alpha, lora_dropout, embed_dim, num_heads, dropout, bias=True, q_lora=True, k_lora=False, v_lora=True):
    """Initialize the module.
    Args:
      lora_r (int): rank of LoRA update matrices.
      lora_alpha (int): scaling factor for LoRA.
      lora_dropout (float): dropout probability for LoRA.
      embed_dim (int): dimension of input embeddings.
      num_heads (int): number of attention heads.
      dropout (float): dropout probability after attention.
      bias (bool): if True add a learnable bias to projection layers.
      q_lora (bool): enable LoRA for query projection.
      k_lora (bool): enable LoRA for key projection.
      v_lora (bool): enable LoRA for value projection.
    """
    super().__init__()

    # Save embed dim and set dim of key and value vectors equal to it
    self.embed_dim = embed_dim
    self.kdim = embed_dim
    self.vdim = embed_dim

    # Save multihead attention params
    self.num_heads = num_heads
    self.dropout = dropout
    # Dimension of a head is embed_dim / num_heads
    self.head_dim = embed_dim // num_heads

    # Merged linear layer with LoRA
    qkv_params = {
        "in_features": embed_dim,
        "out_features": 3 * embed_dim, # since we get three outputs for query, key and values
        "bias": bias, # enable bias for linear layer
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "enable_lora": [q_lora, k_lora, v_lora]
    }
    self.qkv = LoRAGroupedLinear(**qkv_params)

    self.scaled_dot_product_attention = F.scaled_dot_product_attention
    self.proj = nn.Linear(embed_dim, embed_dim, bias=bias)

  def set_parameters(self, mod):
    """Initialize from existing Pytorch modules to load pretrained weights.
    Args:
      mod (nn.Module): Pytorch module with pretrained weights.
    """
    # Copy weights
    self.qkv.weight.data = mod.in_proj_weight.data
    self.qkv.bias.data = mod.in_proj_bias.data
    self.proj.weight.data = mod.out_proj.weight.data
    self.proj.bias.data = mod.out_proj.bias.data



  def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, is_causal=False, need_weights=True):
    """Forward pass to apply Multi-Head Attention.
    Args:
      query (Tensor): query tensor.
      key (Tensor): key tensor.
      value (Tensor): value tensor.
      key_padding_mask (Tensor): mask to ignore padding.
      attn_mask (Tensor): optional attention mask.
      is_causal (bool): If True, applies a causal mask.
      need_weights (bool): parameter kept for compatibility.
    """
    # Create key padding mask of proper shape and size
    key_padding_mask = F._canonical_mask(
      mask=key_padding_mask,
      mask_name="key_padding_mask",
      other_type=F._none_or_dtype(attn_mask),
      other_name="attn_mask",
      target_type=query.dtype,
    )

    # Get dims
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape
    E = query.size(-1)

    # Project Q, K, V using grouped LoRA linear layer
    qkv = self.qkv(query)
    # Split qkv from merged Tensor
    qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Canonicalize attention mask
    attn_mask = F._canonical_mask(
      mask=attn_mask,
      mask_name="attn_mask",
      other_type=F._none_or_dtype(key_padding_mask),
      other_name="key_padding_mask",
      target_type=q.dtype,
      check_other=False,
    )

    # If attn_mask param is provided, validate and reshape it
    if attn_mask is not None:
      # Mask dim must be 3
      if attn_mask.dim() == 2:
        attn_mask = attn_mask.unsqueeze(0)

    if attn_mask is not None:
      if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
        attn_mask = attn_mask.unsqueeze(0)
      else:
        attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

    # Use dropout only if in training mode
    dropout_p = self.dropout if self.training else 0.0

    # Prepare q, k, v for attention by splitting heads
    q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    src_len = k.size(1)

    # Reshape for attention
    q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
    k = k.view(bsz, self.num_heads, src_len, self.head_dim)
    v = v.view(bsz, self.num_heads, src_len, self.head_dim)

    # Apply scaled dot product attention
    attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)

    # Recombine attention heads
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    # Final output projection
    attn_output = self.proj(attn_output)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    return attn_output, None

To load the modules properly, we define a function that replaces the multi-head attention modules in CLIP.

In [51]:
def lora_replace_multihead_attention(transformer, lora_r, lora_alpha, lora_dropout):
  """Replace multihead attention in clip model with lora enhanced multihead attention.
  All blocks are replaced in this implementation.
  Args:
    transformer (nn.Module): transformer model where to replace layers.
    lora_r (int): rank of LoRA update matrices.
    lora_alpha (int): scaling factor for LoRA.
    lora_dropout (float): dropout probability for LoRA.
  """
  for resblock in transformer.resblocks:
    # Params taken from resblock
    # The reference implementation applies LoRA only to q and v
    lora_multihead_params = {
        "lora_r": lora_r,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "embed_dim": resblock.attn.embed_dim,
        "num_heads": resblock.attn.num_heads,
        "dropout": resblock.attn.dropout,
        "bias": True,
        "q_lora": True,
        "k_lora": False,
        "v_lora": True
    }
    # Initialize, load weights and substitute module
    lora_attn = LoRAMultiHeadAttention(**lora_multihead_params)
    lora_attn.set_parameters(resblock.attn)
    resblock.attn = lora_attn

  return transformer


In [52]:
def get_clip_lora():
  """Load CLIP substituting LoRA modules into text and visual encoders."""
  # LoRA hyperparams taken from reference work
  LORA_VISUAL_R = 64
  LORA_VISUAL_ALPHA = 32
  LORA_VISUAL_DROPOUT = 0.1
  LORA_TEXT_R = 16
  LORA_TEXT_ALPHA = 32
  LORA_TEXT_DROPOUT = 0.1

  clip_model, _ = clip.load(CLIP_BACKBONE, device="cpu")

  # Apply LoRA to both text and visual encoders
  clip_model.visual.transformer = lora_replace_multihead_attention(clip_model.visual.transformer,
    lora_r=LORA_VISUAL_R, lora_alpha=LORA_VISUAL_ALPHA, lora_dropout=LORA_VISUAL_DROPOUT
  )
  clip_model.transformer = lora_replace_multihead_attention(clip_model.transformer,
    lora_r=LORA_TEXT_R, lora_alpha=LORA_TEXT_ALPHA, lora_dropout=LORA_TEXT_DROPOUT
  )
  clip_model = clip_model.to(DEVICE)

  # Freeze CLIP params
  for name, param in clip_model.named_parameters():
      param.requires_grad = 'lora' in name

  # Sanity check, only LoRA params to train
  trainable = {name for name, param in clip_model.named_parameters() if param.requires_grad}
  print(f"Trainable params: {trainable}")

  return clip_model
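
The freezing rule in `get_clip_lora` relies only on parameter names. As a minimal sketch of that rule, the hypothetical helper `select_trainable` below applies the same substring test to a short, hand-written list of names. The names are illustrative examples of what the replaced attention blocks would expose, not a dump of the real model.

```python
def select_trainable(param_names, keyword="lora"):
    """Map each parameter name to its requires_grad flag using a substring test."""
    return {name: keyword in name for name in param_names}


# Hypothetical subset of parameter names, for illustration only
names = [
    "visual.transformer.resblocks.0.attn.qkv.weight",
    "visual.transformer.resblocks.0.attn.qkv.lora_A",
    "visual.transformer.resblocks.0.attn.qkv.lora_B",
    "visual.transformer.resblocks.0.attn.proj.weight",
    "transformer.resblocks.0.attn.qkv.lora_A",
    "ln_final.weight",
]
flags = select_trainable(names)
trainable = sorted(name for name, flag in flags.items() if flag)
print(trainable)
```

A substring test is simple but would also match any unrelated parameter whose name happens to contain `lora`, which is why the sanity-check print in `get_clip_lora` is worth reading.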


Since we want to save only the LoRA weights and not all CLIP weights, we need some utility functions to save and load them.

In [53]:
def save_lora(model):
  """Given a CLIP-LoRA model, save only LoRA weights.
  Args:
    model (nn.Module): model to save.
  """
  save_dir = os.path.join("bin")
  os.makedirs(save_dir, exist_ok=True)

  lora_state_dict = {
    k: v for k, v in model.state_dict().items()
    if model.get_parameter(k).requires_grad
  }
  file_path = os.path.join(save_dir, "LoRA.pt")
  torch.save(lora_state_dict, file_path)
  print(f"LoRA weights saved to {file_path}.")

def load_lora(model, device="cuda:0"):
  """Load into model LoRA only weights.
  Args:
    model (nn.Module): model to load weights.
    device (str): device to use.
  """
  file_path = os.path.join("bin", "LoRA.pt")

  if not os.path.isfile(file_path):
    raise FileNotFoundError(f"No LoRA weights file found at {file_path}")

  lora_state_dict = torch.load(file_path, map_location=torch.device(device))

  # Only load the matching keys
  model.load_state_dict(lora_state_dict, strict=False)
  print(f"LoRA weights loaded from {file_path}")

In [54]:
def lora_train(lora_model, train_loader, val_loader, optimizer, loss_fun, scheduler, base_classes, num_epochs=5, patience=3, device="cuda:0"):
  """Train CLIP-LoRA model.
  Args:
    lora_model (nn.Module): CLIP-LoRA model to train.
    train_loader (DataLoader): training data loader.
    val_loader (DataLoader): validation data loader.
    optimizer (torch.optim): optimizer to use.
    loss_fun (nn.Module): loss function to use.
    scheduler (torch.optim.lr_scheduler): learning rate scheduler.
    base_classes (list): list of base classes.
    num_epochs (int): number of epochs to train.
    patience (int): patience for early stopping.
    device (str): device to use.
  """
  def train_step(lora_model, loader, optimizer, loss_fun, base_classes, device="cuda"):
    """Train CLIP-LoRA model for one epoch.
    Args:
      lora_model (nn.Module): CLIP-LoRA model to train.
      loader (DataLoader): train data loader.
      optimizer (torch.optim): optimizer to use.
      loss_fun (nn.Module): loss function to use.
      base_classes (list): list of base classes.
      device (str): device to use.
    """
    # Variables to compute performance scores
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Put model in training mode
    lora_model.train()

    # Iterate over training set
    for images, labels in loader:
      # Put images and labels on correct device
      images, labels = images.to(device), labels.to(device)

      # Encode images and normalize image features
      image_features = lora_model.encode_image(images)
      image_features = image_features / image_features.norm(dim=-1, keepdim=True)
      # Encode text and normalize text features
      text_prompts = [PROMPT_FORMAT.format(CLASS_NAMES[cls]) for cls in base_classes]
      text_inputs = clip.tokenize(text_prompts).to(device)
      text_features = lora_model.encode_text(text_inputs)
      text_features = text_features / text_features.norm(dim=-1, keepdim=True)
      # Compute similarity logits
      logit_scale = 100 # Fixed logit scale used with CLIP-LoRA
      logits = logit_scale * image_features @ text_features.T
      # Labels are used directly as classification targets over the base-class prompts

      loss = loss_fun(logits, labels)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      # Compute performance scores
      batch_size = images.shape[0]
      samples += batch_size
      cumulative_loss += loss.item() * batch_size
      _, predicted = logits.max(dim=1)
      cumulative_accuracy += predicted.eq(labels).sum().item()

    loss = cumulative_loss / samples
    accuracy = cumulative_accuracy / samples
    return loss, accuracy

  # Initialize training vars before train loop
  best_val_acc = float('-inf')
  best_model_state = None
  curr_patience = patience
  pb = tqdm(range(num_epochs))

  for epoch in pb:
    # Training step
    train_loss, train_acc = train_step(lora_model, train_loader, optimizer, loss_fun, base_classes, device)

    # Evaluation step
    val_acc = clip_test(lora_model, val_loader, base_classes, device)

    # Show val accuracy
    pb.set_description(f"Val Acc: {val_acc*100:.2f}%")

    # If we get better metric save the model
    if val_acc > best_val_acc:
      best_val_acc = val_acc
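      # Clone the tensors: state_dict() returns references that later updates would overwrite
      best_model_state = {k: v.detach().clone() for k, v in lora_model.state_dict().items()}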
      curr_patience = patience
    else:
      if curr_patience < 1:
        break
      else:
        curr_patience = curr_patience - 1

    # If scheduler is given, scheduler step
    if scheduler is not None:
      scheduler.step()

  # Load best model before returning
  if best_model_state is not None:
    lora_model.load_state_dict(best_model_state)
  return lora_model
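
The patience rule in the loop above can be replayed without a model. The sketch below is a hypothetical helper, `early_stopping_trace`, that applies the same bookkeeping (reset on improvement, break once the counter is below 1, otherwise decrement) to a list of made-up validation accuracies.

```python
def early_stopping_trace(val_accuracies, patience):
    """Replay the patience rule on a list of validation accuracies.

    Returns (best_epoch, stop_epoch): the epoch with the best accuracy and
    the epoch at which the loop stops (the last epoch if it never breaks).
    """
    best_acc = float("-inf")
    best_epoch = None
    curr_patience = patience
    stop_epoch = len(val_accuracies) - 1
    for epoch, acc in enumerate(val_accuracies):
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
            curr_patience = patience
        elif curr_patience < 1:
            stop_epoch = epoch
            break
        else:
            curr_patience -= 1
    return best_epoch, stop_epoch


accs = [0.80, 0.85, 0.84, 0.83, 0.85, 0.82, 0.81]  # made-up values
best_epoch, stop_epoch = early_stopping_trace(accs, patience=3)
print(best_epoch, stop_epoch)
```

With this rule a tie does not count as an improvement, and `patience=3` tolerates three non-improving epochs before stopping on the fourth.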


In [55]:
def lora_eval(do_train=True):
  """Evaluate CLIP-LoRA.
  Args:
    do_train (bool): train model or load weights.
  """
  # Get CLIP-LoRA
  clip_model = get_clip_lora()

  # Get datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Hyperparam
  TRAIN_BATCH = 32 # From CLIP-LoRA paper
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-4 # From CLIP-LoRA paper
  NUM_EPOCHS = 15
  PATIENCE = 3

  # Get loaders
  train_loader = torch.utils.data.DataLoader(train_base, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in clip_model.parameters() if p.requires_grad], lr=LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "lora_model": clip_model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    clip_model = lora_train(**train_params)
    # Save weights
    save_lora(clip_model)
  else:
    load_lora(clip_model, device=DEVICE)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 LoRA evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 LoRA evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating LoRA","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")
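
`lora_eval` pairs AdamW with a cosine annealing schedule. As a rough illustration of how the learning rate decays over the 15 epochs, the sketch below evaluates the closed-form cosine curve with a minimum learning rate of 0. The helper `cosine_annealing_lr` is a hypothetical stand-in written for this example and does not call PyTorch.

```python
import math


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: learning rate after `epoch` scheduler steps."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


LR, NUM_EPOCHS = 2e-4, 15  # values used in lora_eval
schedule = [cosine_annealing_lr(LR, e, NUM_EPOCHS) for e in range(NUM_EPOCHS + 1)]
for e in (0, 5, 10, 15):
    print(f"epoch {e:2d}: lr = {schedule[e]:.2e}")
```

Because `T_max` equals the number of epochs, the rate only reaches its minimum if training runs to the end. When early stopping triggers, training ends at a higher rate.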

Now LoRA can be properly evaluated.

In [56]:
if RUN_LORA:
  lora_eval(do_train=True)

The CLIP-LoRA results can now be compared with zero-shot CLIP and CoCoOp:
<div align="center">

| Model      | Base (↑) | Novel (↑) | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp     | 95.19%   | 71.11%    | 81.41%        |
| CLIP-LoRA  | 96.81%   | 74.13%    | 83.96%        |

</div>
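
The harmonic mean column can be recomputed from the base and novel columns. The sketch below assumes the standard formula $H = \frac{2 \cdot B \cdot N}{B + N}$. `harmonic_mean_pct` is a standalone helper written for this example and is not the notebook's own `harmonic_mean` function. Since the inputs are the rounded table values, the recomputed numbers may differ from the table in the last digit.

```python
def harmonic_mean_pct(base_acc, novel_acc):
    """Harmonic mean of two accuracies given in percent."""
    return 2 * base_acc * novel_acc / (base_acc + novel_acc)


# (base, novel) accuracies in percent, copied from the table
results = {
    "Zero-Shot": (71.33, 78.24),
    "CoCoOp": (95.19, 71.11),
    "CLIP-LoRA": (96.81, 74.13),
}
recomputed = {name: harmonic_mean_pct(b, n) for name, (b, n) in results.items()}
for name, value in recomputed.items():
    print(f"{name}: {value:.2f}%")
```

The harmonic mean penalizes imbalance: a method that gains on base classes but loses on novel classes is rewarded less than one that improves both.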


# Evaluation of DISEF

## Synthetic sample generation

In [57]:
# ADD SAMPLE GENERATION CODE

## Training and evaluation

The synthetic dataset is loaded from the `imgs` directory in the current working directory, which must contain one subfolder per class name.

In [58]:
def load_syn_dataset(do_preprocess=False):
  """Load generated samples for training.
  Args:
    do_preprocess (bool): if True preprocess with CLIP preprocess.
  """
  # Collect samples in pairs (img, label)
  syn_dataset = []

  # Get clip preprocess
  if do_preprocess:
    _, preprocess = clip.load(CLIP_BACKBONE, device="cpu")

  for label, class_name in enumerate(CLASS_NAMES):
    current_dir = os.path.join("imgs", class_name)
    # If no folder for a class_name
    if not os.path.isdir(current_dir):
      continue
    # Iterate on img in class folder
    for img_name in os.listdir(current_dir):
      img_path = os.path.join(current_dir, img_name)

      img = Image.open(img_path).convert("RGB")
      if do_preprocess:
        img = preprocess(img)
      syn_dataset.append((img, label))
  return syn_dataset
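
The loader above expects one subfolder per class name and derives the label from the position of that name in `CLASS_NAMES`. The sketch below illustrates only the traversal, on a throwaway directory tree and without opening any image. `collect_image_paths` is a hypothetical helper, and the class and file names are placeholders.

```python
import os
import tempfile


def collect_image_paths(root, class_names):
    """Return (path, label) pairs for a root/<class name>/<image> layout."""
    pairs = []
    for label, class_name in enumerate(class_names):
        class_dir = os.path.join(root, class_name)
        if not os.path.isdir(class_dir):
            continue  # this class has no synthetic images
        for file_name in sorted(os.listdir(class_dir)):
            pairs.append((os.path.join(class_dir, file_name), label))
    return pairs


# Tiny temporary tree: two classes have a folder, one does not
with tempfile.TemporaryDirectory() as root:
    for class_name in ("rose", "lotus"):
        os.makedirs(os.path.join(root, class_name))
        open(os.path.join(root, class_name, "0.png"), "w").close()
    demo_pairs = collect_image_paths(root, ["rose", "tulip", "lotus"])
    demo_labels = [label for _, label in demo_pairs]
print(demo_labels)
```

Classes without a folder are skipped silently, so the labels of the remaining classes still match their index in the full class list.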




DISEF uses a custom training function that combines a loss on real samples with a loss on synthetic samples.

In [59]:
def disef_train(lora_model, train_loader, syn_loader, val_loader, optimizer, loss_fun, scheduler, base_classes, num_epochs=5, patience=3, device="cuda:0"):
  """Train a CLIP-LoRA model using real and synthetic samples.
  Args:
    lora_model (nn.Module): CLIP-LoRA model to train.
    train_loader (DataLoader): training dataloader.
    syn_loader (DataLoader): synthetic training dataloader.
    val_loader (DataLoader): validation dataloader.
    optimizer (torch.optim): optimizer to use.
    loss_fun (nn.Module): loss function to use.
    scheduler (torch.optim.lr_scheduler): learning rate scheduler.
    base_classes (list): list of base classes.
    num_epochs (int): number of epochs to train.
    patience (int): patience for early stopping.
    device (str): device to use.
  """
  def train_step(lora_model, real_loader, syn_loader, optimizer, loss_fun, base_classes, lambda_weight=0.8, device="cuda"):
    """Train CLIP-LoRA model for one epoch.
    Args:
      lora_model (nn.Module): CLIP-LoRA model to train.
      real_loader (DataLoader): real samples training dataloader.
      syn_loader (DataLoader): synthetic samples training dataloader.
      optimizer (torch.optim): optimizer to use.
      loss_fun (nn.Module): loss function to use.
      base_classes (list): list of base classes.
      lambda_weight (float): weight to compute the loss.
      device (str): device to use.
    """
    # Variables to compute performance scores
    samples = 0
    cumulative_loss = 0
    cumulative_accuracy = 0

    # Set training mode
    lora_model.train()

    assert len(real_loader) == len(syn_loader), "real and synthetic loaders must have same batch number"

    # Iterate over training sets
    for (images_real, labels_real), (images_syn, labels_syn) in zip(real_loader, syn_loader):
      # Move to device
      images_real, labels_real = images_real.to(device), labels_real.to(device)
      images_syn, labels_syn = images_syn.to(device), labels_syn.to(device)

      # Real visual features
      fv_real = lora_model.encode_image(images_real)
      fv_real = fv_real / fv_real.norm(dim=-1, keepdim=True)

      # Synthetic visual features
      fv_syn = lora_model.encode_image(images_syn)
      fv_syn = fv_syn / fv_syn.norm(dim=-1, keepdim=True)

      # Text features for all classes
      text_prompts = [PROMPT_FORMAT.format(CLASS_NAMES[cls]) for cls in base_classes]
      text_tokens = clip.tokenize(text_prompts).to(device)
      ft = lora_model.encode_text(text_tokens)
      ft = ft / ft.norm(dim=-1, keepdim=True)

      logit_scale = 100 # Used with CLIP-LoRA
      logits_real = logit_scale * fv_real @ ft.T
      logits_syn = logit_scale * fv_syn @ ft.T

      loss_real = loss_fun(logits_real, labels_real)
      loss_syn = loss_fun(logits_syn, labels_syn)
      loss = lambda_weight * loss_real + (1 - lambda_weight) * loss_syn

      loss.backward()
      optimizer.step()
      optimizer.zero_grad()

      batch_size = images_real.size(0) + images_syn.size(0)
      cumulative_loss += loss.item() * batch_size
      samples += batch_size

      pred_real = logits_real.argmax(dim=1)
      pred_syn = logits_syn.argmax(dim=1)
      cumulative_accuracy += (pred_real == labels_real).sum().item() + (pred_syn == labels_syn).sum().item()

    loss = cumulative_loss / samples
    accuracy = cumulative_accuracy / samples
    return loss, accuracy

  # Initialization before training loop
  best_val_acc = float('-inf')
  best_model_state = None
  curr_patience = patience
  pb = tqdm(range(num_epochs))

  # Logging for plots
  log = {
      "epoch": [],
      "train_loss": [],
      "train_acc": [],
      "val_acc": []
  }

  for epoch in pb:
    # Training step
    lambda_weight = 0.8 # From DISEF paper
    train_loss, train_acc = train_step(lora_model, train_loader, syn_loader, optimizer, loss_fun, base_classes, lambda_weight, device)

    # Evaluation step
    val_acc = clip_test(lora_model, val_loader, base_classes, device)

    # Logging
    log["epoch"].append(epoch)
    log["train_loss"].append(train_loss)
    log["train_acc"].append(train_acc)
    log["val_acc"].append(val_acc)

    # Show val accuracy
    pb.set_description(f"Val Acc: {val_acc*100:.2f}%")

    # If we get better metric save the model
    if val_acc > best_val_acc:
      best_val_acc = val_acc
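      # Clone the tensors: state_dict() returns references that later updates would overwrite
      best_model_state = {k: v.detach().clone() for k, v in lora_model.state_dict().items()}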
      curr_patience = patience
    else:
      if curr_patience < 1:
        break
      else:
        curr_patience = curr_patience - 1

    # If scheduler is given, scheduler step
    if scheduler is not None:
      scheduler.step()

  # Load best model before returning
  if best_model_state is not None:
    lora_model.load_state_dict(best_model_state)
  return lora_model, log
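
The loss used in `train_step` is a convex combination of the real and synthetic cross-entropy terms, $\mathcal{L} = \lambda \mathcal{L}_{real} + (1 - \lambda) \mathcal{L}_{syn}$ with $\lambda = 0.8$. The sketch below reproduces that combination for single samples in pure Python. `cross_entropy` and `mixed_loss` are hypothetical helpers and the logits are made up. The training code uses `nn.CrossEntropyLoss`, which by default also averages over the batch.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: -log softmax(logits)[target]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


def mixed_loss(loss_real, loss_syn, lambda_weight=0.8):
    """Weighted combination of the real and synthetic losses."""
    return lambda_weight * loss_real + (1 - lambda_weight) * loss_syn


# Made-up logits over three classes, both with target class 0
loss_real = cross_entropy([2.0, 0.5, -1.0], target=0)
loss_syn = cross_entropy([0.1, 0.2, 0.3], target=0)
total = mixed_loss(loss_real, loss_syn)
print(f"real={loss_real:.3f} syn={loss_syn:.3f} total={total:.3f}")
```

With $\lambda = 0.8$ the real samples dominate the gradient, so synthetic images act as a regularizer and do not drive the fit.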

A plotting function is provided to understand the performance changes through training.

In [60]:
def plot_log(log):
  """Plot training log.
  Args:
    log (dict): training log.
  """
  plt.figure(figsize=(15,5))

  # Subplot for Cross Entropy
  plt.subplot(1, 2, 1)
  plt.plot(log["epoch"], log["train_loss"], label="Training Loss", color="#6c757d")
  plt.xlabel('Epochs')
  plt.ylabel('Cross Entropy Loss')
  plt.title('Training Loss')
  plt.legend()
  plt.grid(True)

  # Subplot for accuracy
  plt.subplot(1, 2, 2)
  plt.plot(log["epoch"], log["train_acc"], label="Train Acc", color="#6c757d")
  plt.plot(log["epoch"], log["val_acc"], label="Valid Acc", color="#e9c46a")
  plt.xlabel('Epochs')
  plt.ylabel('Accuracy')
  plt.title('Training and Validation Accuracy')
  plt.legend()
  plt.grid(True)
  plt.tight_layout()

  # Save in current directory
  plt.savefig("training_plot.png") # savefig requires a target; this file name is an arbitrary choice
  print("Training plot saved in current path.")

In [61]:
def disef_eval(do_gen=True, do_train=True):
  """Evaluate DISEF technique.
  Args:
    do_gen (bool): generate synthetic samples or not.
    do_train (bool): train model or load weights.
  """

  if do_gen:
    # Generate synthetic dataset
    # generate_samples()
    pass

  syn_dataset = load_syn_dataset(do_preprocess=True)

  # Get CLIP-LoRA
  clip_model = get_clip_lora()

  # Get datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Hyperparam
  TRAIN_BATCH = 32 # From CLIP-LoRA paper
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-4 # From CLIP-LoRA paper
  NUM_EPOCHS = 15
  PATIENCE = 3

  # Get loaders
  train_loader = torch.utils.data.DataLoader(train_base, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  syn_loader = torch.utils.data.DataLoader(syn_dataset, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in clip_model.parameters() if p.requires_grad], lr=LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "lora_model": clip_model,
        "train_loader": train_loader,
        "syn_loader": syn_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    clip_model, log = disef_train(**train_params)
    # Get training plot
    plot_log(log)
    # Save weights
    save_lora(clip_model)
  else:
    load_lora(clip_model, device=DEVICE)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 DISEF evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 DISEF evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating DISEF","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")

In [62]:
if RUN_DISEF:
  disef_eval(do_gen=True, do_train=True)

The DISEF results can now be compared with the previous methods:
<div align="center">

| Model      | Base (↑) | Novel (↑) | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp     | 95.19%   | 71.11%    | 81.41%        |
| CLIP-LoRA  | 96.81%   | 74.13%    | 83.96%        |
| DISEF      |  %       |  %        |  %            |

</div>


# Evaluation of improved DISEF

The proposal is to compute per-class F1 scores with zero-shot CLIP to identify the classes that CLIP finds hardest to predict, and then adjust the synthetic generation accordingly.

## Synthetic sample generation

In [63]:
def get_f1_scores(model, loader, categories, device="cuda:0", label=""):
  """Return class f1 scores given a model and a DataLoader.
  Args:
    model (nn.Module): CLIP model to compute f1 scores.
    loader (DataLoader): DataLoader to test on.
    categories (list): list of categories.
    device (str): device to use.
    label (str): label for progress bar.
  Returns:
    dict: Dictionary mapping label to F1 score.
  """
  # Model in eval mode
  model.eval()

  # Mappings for prediction
  contig_cat2idx = {cat: idx for idx, cat in enumerate(categories)}
  idx2cat = {v: k for k, v in contig_cat2idx.items()}  # Reverse map for output

  # Save predictions and target labels
  all_preds = []
  all_tgts = []

  # Tokenized inputs for CLIP
  text_inputs = clip.tokenize([PROMPT_FORMAT.format(CLASS_NAMES[c]) for c in categories]).to(device)

  with torch.no_grad():
    # Text features
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    for images, labels in tqdm(loader, desc=label):
      # Map labels to contiguous set
      labels = torch.tensor([contig_cat2idx[l.item()] for l in labels]).long()

      images, labels = images.to(device), labels.to(device)
      # Image features
      image_features = model.encode_image(images)
      image_features /= image_features.norm(dim=-1, keepdim=True)

      logits = image_features @ text_features.T
      predicted_class = logits.argmax(dim=-1)

      all_preds.extend(predicted_class.to("cpu").numpy())
      all_tgts.extend(labels.to("cpu").numpy())

  # Use sklearn to compute metric
  precision, recall, f1, _ = precision_recall_fscore_support(all_tgts, all_preds, labels=list(range(len(categories))), zero_division=0)

  # Map class index to F1 score
  f1_per_class = {CLASS_NAMES[idx2cat[i]]: f1[i] for i in range(len(categories))}

  return f1_per_class
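
To show what the per-class F1 computation measures, here is a minimal pure-Python sketch of it. `f1_per_class_sketch` is a hypothetical helper and not the scikit-learn routine used above. It follows the same convention as `zero_division=0`: an undefined precision or recall counts as 0.

```python
def f1_per_class_sketch(targets, preds, n_classes):
    """Per-class F1 from integer targets and predictions (undefined ratios count as 0)."""
    scores = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(targets, preds) if t == c and p == c)
        fp = sum(1 for t, p in zip(targets, preds) if t != c and p == c)
        fn = sum(1 for t, p in zip(targets, preds) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return scores


# Made-up labels: class 2 is never predicted
targets = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 0, 0]
scores = f1_per_class_sketch(targets, preds, n_classes=3)
print([round(s, 2) for s in scores])
```

A class that is never predicted correctly gets an F1 of 0, which is the pattern targeted when looking for classes that zero-shot CLIP struggles with.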

In [64]:
# Get datasets
train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

# Get loaders
train_loader = torch.utils.data.DataLoader(train_base, batch_size=64, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_base, batch_size=64, shuffle=False, num_workers=2)
test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=128, shuffle=False, num_workers=2)
test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=128, shuffle=False, num_workers=2)

# Get f1 scores
clip_model, _ = clip.load(CLIP_BACKBONE, device=DEVICE)
f1_scores_val = get_f1_scores(clip_model, val_loader, base_classes, device=DEVICE, label="Getting f1 scores")
f1_scores_test = get_f1_scores(clip_model, test_base_loader, base_classes, device=DEVICE, label="Getting f1 scores")


Getting f1 scores: 100%|██████████| 8/8 [00:04<00:00,  1.68it/s]
Getting f1 scores: 100%|██████████| 20/20 [00:20<00:00,  1.00s/it]


In [65]:
common_thresh = 0.20 # Threshold chosen so that the same classes fall below it on both val and test

print("VAL")
for k in f1_scores_val:
  if f1_scores_val[k] < common_thresh:
    print(f"{k} F1: {f1_scores_val[k]:.2f}")
print("TEST")
for k in f1_scores_test:
  if f1_scores_test[k] < common_thresh:
    print(f"{k} F1: {f1_scores_test[k]:.2f}")

VAL
hard-leaved pocket orchid F1: 0.00
colt's foot F1: 0.00
globe-flower F1: 0.00
prince of wales feathers F1: 0.18
love in the mist F1: 0.00
mexican aster F1: 0.17
cape flower F1: 0.00
great masterwort F1: 0.00
sword lily F1: 0.11
bolero deep blue F1: 0.00
TEST
hard-leaved pocket orchid F1: 0.00
colt's foot F1: 0.00
globe-flower F1: 0.09
prince of wales feathers F1: 0.00
love in the mist F1: 0.00
mexican aster F1: 0.17
cape flower F1: 0.00
great masterwort F1: 0.00
sword lily F1: 0.05
bolero deep blue F1: 0.00
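
The effect of the threshold can be checked with a small sketch that intersects the low-scoring classes of the two splits. The helper `low_f1_classes` is hypothetical; the dictionaries hold four of the entries printed above.

```python
def low_f1_classes(f1_per_class, thresh):
  """Names of the classes whose F1 score is strictly below thresh."""
  return {name for name, score in f1_per_class.items() if score < thresh}

# Four entries copied from the VAL and TEST output above
val_f1 = {"globe-flower": 0.00, "prince of wales feathers": 0.18,
          "mexican aster": 0.17, "sword lily": 0.11}
test_f1 = {"globe-flower": 0.09, "prince of wales feathers": 0.00,
           "mexican aster": 0.17, "sword lily": 0.05}

# Classes that fall below the threshold on both splits
common_at_020 = low_f1_classes(val_f1, 0.20) & low_f1_classes(test_f1, 0.20)
common_at_010 = low_f1_classes(val_f1, 0.10) & low_f1_classes(test_f1, 0.10)

print(sorted(common_at_020))
print(sorted(common_at_010))
```

With a threshold of 0.20 all four classes are selected on both splits, whereas a stricter threshold of 0.10 would keep only `globe-flower` in the intersection.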


## Training and evaluation

In [66]:
def our_disef_eval(do_gen=True, do_train=True):
  """Evaluate our improved DISEF technique.
  Args:
    do_gen (bool): generate synthetic samples or not.
    do_train (bool): train model or load weights.
  """
  # Hyperparameters (defined first because the loaders below use the batch sizes)
  TRAIN_BATCH = 32 # From CLIP-LoRA paper
  VAL_BATCH = 64
  TEST_BATCH = 128
  LR = 2e-4 # From CLIP-LoRA paper
  NUM_EPOCHS = 15
  PATIENCE = 3

  # Get datasets
  train_base, val_base, test_base, test_novel, base_classes, novel_classes = get_dataset(do_preprocess=True)

  # Get loaders
  train_loader = torch.utils.data.DataLoader(train_base, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)
  val_loader = torch.utils.data.DataLoader(val_base, batch_size=VAL_BATCH, shuffle=False, num_workers=2)
  test_base_loader = torch.utils.data.DataLoader(test_base, batch_size=TEST_BATCH, shuffle=False, num_workers=2)
  test_novel_loader = torch.utils.data.DataLoader(test_novel, batch_size=TEST_BATCH, shuffle=False, num_workers=2)

  # Get f1 scores of zero-shot CLIP on the validation set
  clip_model, _ = clip.load(CLIP_BACKBONE, device=DEVICE)
  f1_scores = get_f1_scores(clip_model, val_loader, base_classes, device=DEVICE, label="Getting f1 scores")

  if do_gen:
    # Generate synthetic dataset
    # generate_samples(f1_scores)
    pass

  syn_dataset = load_syn_dataset(do_preprocess=True)

  # Get CLIP-LoRA
  clip_model = get_clip_lora()

  # Get syn loader
  syn_loader = torch.utils.data.DataLoader(syn_dataset, batch_size=TRAIN_BATCH, shuffle=True, num_workers=2)

  # Get cost function
  loss_fun = torch.nn.CrossEntropyLoss()

  # Train
  if do_train:
    optimizer = optim.AdamW([p for p in clip_model.parameters() if p.requires_grad], lr=LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # Train
    train_params = {
        "lora_model": clip_model,
        "train_loader": train_loader,
        "syn_loader": syn_loader,
        "val_loader": val_loader,
        "optimizer": optimizer,
        "loss_fun": loss_fun,
        "scheduler": scheduler,
        "base_classes": base_classes,
        "num_epochs": NUM_EPOCHS,
        "patience": PATIENCE,
        "device": DEVICE
    }
    clip_model, log = lora_train(**train_params)
    # Get training plot
    plot_log(log)
    # Save weights
    save_lora(clip_model)
  else:
    load_lora(clip_model, device=DEVICE)

  # Evaluate base and novel
  base_accuracy = clip_test(clip_model, test_base_loader, base_classes, DEVICE, label="🧠 Our DISEF evaluation on Base")
  novel_accuracy = clip_test(clip_model, test_novel_loader, novel_classes, DEVICE, label="🧠 Our DISEF evaluation on Novel")

  # Show results
  print() # For separating from progress bars
  print("="*20,"Evaluating our DISEF","="*20)
  print(f"🔍 Base classes accuracy: {base_accuracy*100:.2f}%")
  print(f"🔍 Novel classes accuracy: {novel_accuracy*100:.2f}%")
  print(f"🔍 Harmonic Mean: {harmonic_mean(base_accuracy, novel_accuracy)*100:.2f}%")
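
The learning-rate schedule used in training can be previewed without PyTorch. The sketch below evaluates the closed form that the PyTorch documentation gives for `CosineAnnealingLR`, assuming the default `eta_min=0` and one scheduler step per epoch (the stepping happens inside `lora_train`, which is not shown in this section). The function name `cosine_annealing_lr` is hypothetical.

```python
import math

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
  """Closed-form cosine annealing: starts at base_lr, reaches eta_min at t_max."""
  return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

LR = 2e-4        # same value as in our_disef_eval
NUM_EPOCHS = 15  # used as T_max

# Learning rate at the start of each epoch, plus the final value
schedule = [cosine_annealing_lr(LR, e, NUM_EPOCHS) for e in range(NUM_EPOCHS + 1)]

for epoch, lr in enumerate(schedule):
  print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```

The rate decays smoothly from `LR` to zero over `NUM_EPOCHS`, so if early stopping triggers (via `PATIENCE`) the run ends at an intermediate learning rate.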

In [67]:
if RUN_OUR_DISEF:
  our_disef_eval(do_gen=True, do_train=True)
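
The harmonic mean printed by the evaluation combines base and novel accuracy as $H = \frac{2 \cdot B \cdot N}{B + N}$. The sketch below is a stand-alone version under that assumption (the notebook's own `harmonic_mean` is defined elsewhere); the name `harmonic_mean_sketch` is hypothetical, and the accuracies are the rounded values from the results table.

```python
def harmonic_mean_sketch(base_acc, novel_acc):
  """Harmonic mean of base and novel accuracy (both as fractions in [0, 1])."""
  if base_acc + novel_acc == 0:
    return 0.0  # avoid division by zero when both accuracies are zero
  return 2 * base_acc * novel_acc / (base_acc + novel_acc)

# (base, novel) accuracies taken from the results table
rows = {
  "Zero-Shot": (0.7133, 0.7824),
  "CoCoOp": (0.9519, 0.7111),
  "CLIP-LoRA": (0.9681, 0.7413),
}
hm = {name: harmonic_mean_sketch(b, n) for name, (b, n) in rows.items()}

for name, value in hm.items():
  print(f"{name}: {value*100:.2f}%")
```

Because the inputs are already rounded to two decimals, the recomputed values can differ from the table in the last digit. The harmonic mean is dominated by the smaller of the two accuracies, which is why a gain on base classes alone does not raise it much.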

The results of all the evaluated methods can now be compared:
<div align="center">

| Model      | Base (↑) | Novel (↑) | Harmonic Mean |
|------------|----------|-----------|---------------|
| Zero-Shot  | 71.33%   | 78.24%    | 74.62%        |
| CoCoOp     | 95.19%   | 71.11%    | 81.41%        |
| CLIP-LoRA  | 96.81%   | 74.13%    | 83.96%        |
| DISEF      | %        | %         | %             |
| Our DISEF  | %        | %         | %             |

</div>


# References
1. Radford, Alec, et al. "Learning transferable visual models from natural language supervision." International Conference on Machine Learning. PMLR, 2021.
2. Zhou, Kaiyang, et al. "Conditional prompt learning for vision-language models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
3. Zanella, Maxime, and Ismail Ben Ayed. "Low-rank few-shot adaptation of vision-language models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2024.
4. da Costa, Victor G. Turrisi, et al. "Diversified in-domain synthesis with efficient fine-tuning for few-shot classification." arXiv preprint arXiv:2312.03046 (2023).