In [1]:
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
from tqdm.auto import tqdm
from einops import rearrange

from transformers import CLIPProcessor

import torch.multiprocessing as mp
try:
    mp.set_start_method("fork", force=True)
except RuntimeError:
    # Start method already set in this session
    pass

sys.path.append('../..')
from fast_nystrom_attention import CLIPModelFNA

In [2]:
class COCODataset(Dataset):
    def __init__(self, input_filename: str, processor: CLIPProcessor,
                 img_key: str = "filepath", caption_key: str = "captions", sep: str = "\t", token_labels_file: str = None, image_size: int = None):
        """
        Args:
            input_filename (str): Path to the CSV file.
            processor (CLIPProcessor): Hugging Face CLIP processor for images and text.
            img_key (str): Column name for image file paths.
            caption_key (str): Column name for text captions.
            sep (str): Separator used in the CSV file.
            token_labels_file: Path to token labels file (optional).
            image_size (int): The size to resize images to.
        """
        df = pd.read_csv(input_filename, sep=sep)
        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.processor = processor
        
        if image_size:
            self.processor.image_processor.size = {"shortest_edge": image_size}
            self.processor.image_processor.crop_size = {"height": image_size, "width": image_size}
        
        if token_labels_file:
            self.token_labels = torch.load(token_labels_file)

    def __len__(self) -> int:
        return len(self.captions)

    def __getitem__(self, idx):
        image = Image.open(str(self.images[idx]))
        if image.mode != "RGB":
            image = image.convert("RGB")
        
        captions_list = eval(self.captions[idx])
        
        # Process image and text
        inputs = self.processor(text=captions_list, images=image, return_tensors="pt", padding=True, truncation=True)
        
        image = inputs['pixel_values'].squeeze(0)
        texts = inputs['input_ids']
        
        if hasattr(self, 'token_labels'):
            token_labels = self.token_labels[idx]
        else:
            token_labels = None

        return image, texts, token_labels

def custom_collate_fn(batch):
    """
    Custom collate function to handle variable number of captions per image and
    optional token_labels. Expects each sample to be a tuple of:
        (image, texts, token_labels)
    where token_labels can be None.
    """
    # Unpack the batch. token_labels may be None.
    images, texts, token_labels = zip(*batch)
    
    # Stack images into a batch tensor.
    images = torch.stack(images, dim=0)  # Shape: [B, 3, H, W]
    
    # Process texts: pad and flatten while tracking number of captions per image.
    num_captions = [len(t) for t in texts]
    
    # Find the max sequence length in the batch
    max_len = max(t.shape[1] for t in texts)
    
    # Pad all text tensors to the max sequence length
    padded_texts = [F.pad(t, (0, max_len - t.shape[1]), 'constant', 0) for t in texts]
    
    flattened_texts = torch.cat(padded_texts, dim=0)  # Shape: [sum(num_captions), max_seq_len]
    
    # Process token_labels if available (assumes similar structure as texts).
    if token_labels[0] is not None:
        return images, flattened_texts, num_captions, token_labels
    else:
        return images, flattened_texts, num_captions

def get_coco_dataloader(input_filename: str, processor: CLIPProcessor, batch_size: int = 32, 
                          shuffle: bool = False, num_workers: int = 4, token_labels_file: str = None, image_size: int = None):
    dataset = COCODataset(input_filename, processor, token_labels_file=token_labels_file, image_size=image_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=custom_collate_fn,
        multiprocessing_context="fork",
        persistent_workers=(num_workers > 0),
        pin_memory=True
    )
    return dataloader

In [3]:
def compute_retrieval_metrics(image_embeddings, text_embeddings, text_gt_indices, k_values=[1, 5, 10]):
    """
    Compute retrieval metrics (Recall@K) for text-to-image and image-to-text retrieval.
    
    Args:
      - image_embeddings (torch.Tensor): Normalized image embeddings of shape (num_images, D)
      - text_embeddings (torch.Tensor): Normalized text embeddings of shape (num_texts, D)
      - text_gt_indices (list[int]): List of length num_texts where each element is the ground truth image index for that caption.
      - k_values (list): List of K values for computing recall.
    
    Returns:
      - dict: Retrieval metrics for text-to-image and image-to-text
    """
    # Ensure embeddings are of compatible dimensions
    if image_embeddings.dim() != 2 or text_embeddings.dim() != 2:
        raise ValueError(f"Embeddings must be 2D tensors. Got image_embeddings:{image_embeddings.shape}, text_embeddings:{text_embeddings.shape}")
    
    if image_embeddings.shape[1] != text_embeddings.shape[1]:
        raise ValueError(f"Embedding dimensions don't match: {image_embeddings.shape[1]} vs {text_embeddings.shape[1]}")
        
    # Compute similarity matrix: (num_images, num_texts)
    similarity_matrix = image_embeddings @ text_embeddings.T

    num_images = image_embeddings.shape[0]
    num_texts = text_embeddings.shape[0]

    # --- Text-to-Image Retrieval ---
    text_to_image_metrics = {}
    for k in k_values:
        recalls = []
        for text_idx in range(num_texts):
            gt_image_idx = text_gt_indices[text_idx]
            scores = similarity_matrix[:, text_idx]
            top_k_indices = torch.topk(scores, k).indices
            recalled = (gt_image_idx in top_k_indices)
            recalls.append(recalled)
        text_to_image_metrics[f'R@{k}'] = np.mean(recalls) * 100

    # --- Image-to-Text Retrieval ---
    image_to_text_metrics = {}
    # Build mapping from image index to set of text indices (ground truth)
    image_to_text_map = {i: set() for i in range(num_images)}
    for text_idx, img_idx in enumerate(text_gt_indices):
        image_to_text_map[img_idx].add(text_idx)
    
    for k in k_values:
        recalls = []
        for img_idx in range(num_images):
            gt_text_indices = image_to_text_map[img_idx]
            scores = similarity_matrix[img_idx, :]
            top_k_texts = set(torch.topk(scores, k).indices.tolist())
            recalled = len(gt_text_indices.intersection(top_k_texts)) > 0
            recalls.append(recalled)
        image_to_text_metrics[f'R@{k}'] = np.mean(recalls) * 100

    return {
        'text_to_image': text_to_image_metrics,
        'image_to_text': image_to_text_metrics
    }

def run_retrieval_evaluation(model, val_dataloader, device):
    """
    Compute retrieval metrics for the validation dataset.
    
    Each image may have a variable number of captions. For each image,
    we encode the image once, and we flatten all captions across the batch,
    recording which image each caption came from.
    
    Returns:
      - dict: Retrieval metrics.
    """
    image_embeddings_list = []
    text_embeddings_list = []
    text_gt_indices = []  # Ground-truth image index for each caption
    global_image_counter = 0  # Keeps track of the image index across batches
    patch_size = model.vision_model.config.patch_size
    
    with torch.no_grad():
        #for batch_idx, (images, texts, num_captions_list, token_labels) in enumerate(tqdm(val_dataloader)):
        for batch_idx, batch in enumerate(tqdm(val_dataloader)):
            # Handle different return formats from dataloader
            if len(batch) == 3:
                images, texts, num_captions_list = batch
            else:  # len(batch) == 4
                images, texts, num_captions_list, token_labels = batch
                
            # images: tensor of shape [B, 3, H, W]
            # texts: tensor of shape [sum(num_captions), seq_length]
            images = images.to(device)
            # Create boolean tensor of shape B x N where the first element in each N is True
            feature_h, feature_w = images.shape[-2] // patch_size, images.shape[-1] // patch_size
            cls_labels = torch.zeros(images.shape[0], feature_h*feature_w+1, dtype=torch.bool, device=device)
            cls_labels[:, 0] = True
            zero_labels = torch.zeros_like(cls_labels, dtype=torch.bool, device=device)
            model.load_cache({"mask_dict": {"guarantee": cls_labels, "exclude": zero_labels}})
            image_emb = model.get_image_features(images, interpolate_pos_encoding=True)  # [B, D]
            image_emb = F.normalize(image_emb, p=2, dim=-1)
            image_embeddings_list.append(image_emb)
            
            for i, num_captions in enumerate(num_captions_list):
                # Record that these captions correspond to image with global index (global_image_counter + i)
                text_gt_indices.extend([global_image_counter + i] * num_captions)

            texts = texts.to(device)
            text_emb = model.get_text_features(input_ids=texts)  # [sum(num_captions), D]
            if isinstance(text_emb, tuple):
                text_emb = text_emb[0]  # Extract the embeddings if it's a tuple
            text_emb = F.normalize(text_emb, p=2, dim=-1)
            text_embeddings_list.append(text_emb)
            
            global_image_counter += images.shape[0]

    # Concatenate all embeddings
    image_embeddings = torch.cat(image_embeddings_list, dim=0)  # [num_images, D]
    text_embeddings = torch.cat(text_embeddings_list, dim=0)    # [num_texts, D]
    
    assert text_embeddings.shape[0] == len(text_gt_indices), "Number of text embeddings doesn't match number of ground truth indices"
    
    metrics = compute_retrieval_metrics(image_embeddings, text_embeddings, text_gt_indices)
    return metrics

In [4]:
MODEL_ID = "openai/clip-vit-large-patch14"
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda:0")
IMAGE_SIZE = 224 * 2
fna_config = {
    'fna_layers': range(13, 24),
    'num_sample': 64,
    'sampling_strategy': 'fps',
    'sampling_features': 'q',
    'resample_fps': False, 
}

processor = CLIPProcessor.from_pretrained(MODEL_ID, use_fast=False)
model = CLIPModelFNA.from_pretrained(
    MODEL_ID, 
    fna_config=fna_config, 
    torch_dtype=DTYPE, 
    device_map=DEVICE
)

model.eval()
model.requires_grad_(False)

val_dataloader = get_coco_dataloader(
    "/home/andrew/codebases/tmp/COCO/annotations/val.csv", 
    processor, 
    batch_size=16, 
    shuffle=False, 
    num_workers=4,
    image_size=IMAGE_SIZE
)

In [5]:
retrieval_metrics = run_retrieval_evaluation(model, val_dataloader, DEVICE)

print("Text-to-Image Retrieval Metrics:")
for k, value in retrieval_metrics['text_to_image'].items():
    print(f"{k}: {value:.2f}%")

print("\nImage-to-Text Retrieval Metrics:")
for k, value in retrieval_metrics['image_to_text'].items():
    print(f"{k}: {value:.2f}%")

  0%|          | 0/313 [00:00<?, ?it/s]

Text-to-Image Retrieval Metrics:
R@1: 34.06%
R@5: 58.96%
R@10: 69.79%

Image-to-Text Retrieval Metrics:
R@1: 52.40%
R@5: 76.54%
R@10: 84.88%
