Install all dependencies

In [1]:
!pip install --upgrade fsspec==2024.10.0
!pip install --upgrade datasets
!pip install torch torchvision transformers datasets

!pip install ftfy regex tqdm
!git clone https://github.com/openai/CLIP.git

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m14.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

Import all necessary libraries

In [2]:
# Import libraries
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from datasets import load_dataset
from torch.utils.data import DataLoader,Dataset
from transformers import CLIPTokenizer

from transformers import CLIPTokenizer, CLIPModel, CLIPProcessor

In [3]:
# Step 1: Download Flickr30k dataset
dataset = load_dataset("nlphuji/flickr30k")

# Check the available splits in the dataset
print(dataset.keys())  # This will print the available split names

# Step 2: Image and Text Preprocessing
# Define transformations for images
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to match CLIP's input size
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))  # ImageNet stats
])

# Tokenize captions using CLIP's tokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")


# Custom PyTorch Dataset Wrapper
class FlickrDataset(Dataset):
    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        # Process image
        image = item['image']
        if not isinstance(image, torch.Tensor):
            image = image_transform(image)
        # Tokenize caption
        caption = item['caption']
        text = tokenizer(caption, return_tensors="pt", truncation=True, padding="max_length")["input_ids"].squeeze(0)
        return {"image": image, "text": text}

# Preprocess and split dataset
train_size = int(0.8 * len(dataset['test']))
train_data = FlickrDataset(dataset['test'].select(range(train_size)))
test_data = FlickrDataset(dataset['test'].select(range(train_size, len(dataset['test']))))

# Create DataLoaders
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False, drop_last=True)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/641 [00:00<?, ?B/s]

flickr30k.py:   0%|          | 0.00/2.51k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/506M [00:00<?, ?B/s]

0001.parquet:   0%|          | 0.00/502M [00:00<?, ?B/s]

0002.parquet:   0%|          | 0.00/506M [00:00<?, ?B/s]

0003.parquet:   0%|          | 0.00/512M [00:00<?, ?B/s]

0004.parquet:   0%|          | 0.00/504M [00:00<?, ?B/s]

0005.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

0006.parquet:   0%|          | 0.00/495M [00:00<?, ?B/s]

0007.parquet:   0%|          | 0.00/497M [00:00<?, ?B/s]

0008.parquet:   0%|          | 0.00/289M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/31014 [00:00<?, ? examples/s]

dict_keys(['test'])


tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/862k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.19k [00:00<?, ?B/s]

In [4]:
# Iterate through the loader and inspect the contents of each batch.
for batch_idx, batch in enumerate(test_loader):
    print(f"Batch {batch_idx}:")
    for key, value in batch.items():
        print(f"  {key}: type={type(value)}, shape={value[0].shape if hasattr(value[0], 'shape') else 'N/A'}")
    if batch_idx == 0:  # Print only the first batch for brevity
        break

Batch 0:
  image: type=<class 'torch.Tensor'>, shape=torch.Size([3, 224, 224])
  text: type=<class 'torch.Tensor'>, shape=torch.Size([5, 77])


In [5]:
# Step 4: Check a Sample
sample = next(iter(train_loader))

# Convert the image data to a NumPy array
image_np = np.array(sample['image'][0])
text_np = np.array(sample['text'][0])  # Assuming 'text' is a list of tokenized text

print(f"Sample image shape: {image_np.shape}")  # Check image dimensions
print(f"Sample text shape: {text_np.shape}")  # Check text dimensions

Sample image shape: (3, 224, 224)
Sample text shape: (5, 77)


Baseline implementation

In [6]:
# Load the pre-trained CLIP model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Baseline Text-to-Image Retrieval
def calculate_similarity(text, images):
    """
    Calculates the cosine similarity between text and image embeddings.

    Args:
        text (torch.Tensor): Text embeddings.
        images (torch.Tensor): Image embeddings.

    Returns:
        similarities (np.ndarray): Similarity matrix.
    """
    with torch.no_grad():
        text = text.view(-1, text.shape[-1]) # Reshape to (batch_size * num_captions, sequence_length)
        text_features = clip_model.get_text_features(text).cpu().numpy()
        image_features = clip_model.get_image_features(images).cpu().numpy()
        # Compute cosine similarity
        similarities = (text_features @ image_features.T) / (np.linalg.norm(text_features, axis=1)[:, None] * np.linalg.norm(image_features, axis=1))
    return similarities


for batch in test_loader:

    # Preprocess and convert images to tensors
    image_list = []
    for img in batch['image']:
        # Convert each image in the batch to a tensor if it's still in list format
        if isinstance(img, list):
            img = torch.tensor(img)  # Convert nested list to tensor
        image_list.append(img)

    # Stack all tensors into a single batch
    images = torch.stack(image_list).to(device)  # Move to device

    # Convert text data into tensor and move to device
    text = batch['text'].to(device)

    # Calculate similarities
    similarities = calculate_similarity(text, images)

    # Debugging output
    #print("Similarity Matrix Shape:", similarities.shape) #Similarity Matrix Shape: (160, 32)
    # Add ranking and metrics computation here


pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

In [8]:
# Learnable Prompt Embedding
class SoftPromptTuning(torch.nn.Module):
    """
    Implements soft prompt tuning for CLIP.

    Args:
        clip_model (CLIPModel): The pre-trained CLIP model.
        prompt_length (int): Length of the soft prompt.
        embedding_dim (int): Dimensionality of the soft prompt embeddings.

    Forward Pass:
        text (torch.Tensor): Input text.
        images (torch.Tensor): Input images.

    Returns:
        text_features (torch.Tensor): Text features after soft prompt tuning.
        image_features (torch.Tensor): Image features.
    """
    def __init__(self, clip_model, prompt_length=5, embedding_dim=512):
        super(SoftPromptTuning, self).__init__()
        self.clip_model = clip_model
        self.prompt_embeddings = torch.nn.Embedding(prompt_length, clip_model.config.text_config.hidden_size)
        self.prompt_length = prompt_length
        self.embedding_dim = embedding_dim

        # Initialize the soft prompt as a learnable parameter
        self.soft_prompt = nn.Parameter(torch.randn(prompt_length, embedding_dim))  # Shape: [prompt_length, embedding_dim]

    def forward(self, text, images):
    # Embed the input text
      with torch.no_grad():
          token_embeddings = self.clip_model.text_model.embeddings(input_ids=text)  # Shape: [batch_size, seq_len, embed_dim]
      token_embeddings = token_embeddings[:, 0, :, :]

      # Reshape soft prompt embeddings if necessary
      batch_size = token_embeddings.size(0)
      soft_prompt = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)  # Shape: [batch_size, prompt_len, embed_dim]


      # Print the shapes of token_embeddings and soft_prompt
      print(f"soft_prompt shape: {soft_prompt.shape}")
      print(f"token_embeddings shape: {token_embeddings.shape}")

      #  # Ensure token_embeddings is a 3D tensor by removing extra dimensions
      # token_embeddings = token_embeddings.squeeze(2)  # Remove extra dimension if token_embeddings has shape [batch_size, seq_len, 1, embed_dim]

      # print(f"token_embeddings shape after squeeze: {token_embeddings.shape}")

      # Concatenate soft prompt embeddings with token embeddings
      augmented_embeddings = torch.cat([soft_prompt, token_embeddings], dim=1)  # Shape: [batch_size, prompt_len + seq_len, embed_dim]

      # Print the shape of augmented_embeddings
      print(f"augmented_embeddings shape: {augmented_embeddings.shape} \n")


      attention_mask = torch.ones(augmented_embeddings.size(0), augmented_embeddings.size(1), device=augmented_embeddings.device)

      # Reshape attention_mask for multi-head attention
      attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)  # Add two new dimensions for num_heads and seq_len


      # Manually forward through the text encoder
      hidden_states = self.clip_model.text_model.encoder(
          inputs_embeds=augmented_embeddings,
          attention_mask=attention_mask,
          #position_ids=position_ids,
      ).last_hidden_state

      # Apply the text pooling (e.g., CLS token or mean pooling)
      text_features = self.clip_model.text_projection(hidden_states[:, 0, :])

      # Pass images through CLIP image model
      image_features = self.clip_model.get_image_features(images)

      return text_features, image_features

# Define contrastive loss
def contrastive_loss(text_features, image_features, temperature=0.07):
    """
    Calculates the contrastive loss between text and image features.

    Args:
        text_features (torch.Tensor): Text features.
        image_features (torch.Tensor): Image features.
        temperature (float): Temperature parameter for scaling logits.

    Returns:
        loss (torch.Tensor): Contrastive loss value.
    """
    logits = text_features @ image_features.T / temperature
    labels = torch.arange(len(logits)).to(logits.device)
    loss = torch.nn.CrossEntropyLoss()(logits, labels)
    return loss

# Training Loop
soft_prompt_model = SoftPromptTuning(clip_model, prompt_length=5).to(device)
optimizer = torch.optim.Adam(soft_prompt_model.parameters(), lr=1e-4)

for epoch in range(10):
    for batch in train_loader:
        images = batch['image'].to(device)
        text = batch['text'].to(device)

        # Forward pass
        text_features, image_features = soft_prompt_model(text, images)

        # Compute loss
        loss = contrastive_loss(text_features, image_features)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

soft_prompt shape: torch.Size([32, 

In [9]:

def evaluate_model(model, test_loader):
    """
    Evaluates the model on the test dataset.

    Args:
        model (torch.nn.Module): The model to evaluate.
        test_loader (DataLoader): DataLoader for the test dataset.

    Returns:
        avg_loss (float): Average loss on the test dataset.
        accuracy (float): Accuracy on the test dataset.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    correct = 0
    total_samples = 0

    with torch.no_grad():  # No gradient computation for evaluation
        for batch in test_loader:
            images = batch['image'].to(device)
            text = batch['text'].to(device)

            # Forward pass
            text_features, image_features = model(text, images)

            # Contrastive loss and accuracy
            logits = text_features @ image_features.T
            labels = torch.arange(len(logits)).to(logits.device)
            loss = torch.nn.CrossEntropyLoss()(logits, labels)

            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total_samples += labels.size(0)

            print(f"Loss for this batch: {loss.item()}")

    accuracy = correct / total_samples
    avg_loss = total_loss / len(test_loader)

    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4%}")
    return avg_loss, accuracy

for epoch in range(10):
  # Evaluate at the end of each epoch
  avg_loss ,accuracy = evaluate_model(soft_prompt_model, test_loader)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Loss for this batch: 3.024122953414917
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Loss for this batch: 3.0777974128723145
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Loss for this batch: 3.10418963432312
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Loss for this batch: 2.936718702316284
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Loss for this batch: 2.9720847606658936
soft_prompt shape: torch.Size([32

In [24]:
import torch
import numpy as np

def calculate_recall_at_k(similarities, labels, k=5):
    """
    Calculate Recall@k.

    Args:
        similarities (torch.Tensor): Similarity matrix of shape (num_queries, num_candidates).
        labels (torch.Tensor): True labels indicating the index of the correct match.
        k (int): The value of k for Recall@k.

    Returns:
        recall (float): Recall@k value.
    """

    # Get top-k indices for each query
    top_k_indices = torch.topk(similarities, k, dim=1).indices  # Shape: (num_queries, k)

    # Check if the true label is in the top-k indices
    hits = torch.tensor([
        labels[i].item() in top_k_indices[i] for i in range(labels.shape[0])
    ], dtype=torch.float32)

    # Calculate recall
    recall = hits.mean().item()
    return recall

# Example Usage
def evaluate_retrieval_metrics(model, test_loader, k=5):
    """
    Evaluates the retrieval performance of the model using Recall@k.

    Args:
        model (torch.nn.Module): The model to evaluate.
        test_loader (DataLoader): DataLoader for the test dataset.
        k (int): The value of k for Recall@k.

    Returns:
        recall_at_k (float): Recall@k value.
    """
    model.eval()
    similarities_list = []
    labels_list = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch['image'].to(device)
            text = batch['text'].to(device)
            labels = torch.arange(len(images)).to(device)  # Assuming labels are the indices

            # Compute similarities
            text_features, image_features = model(text, images)
            similarities = text_features @ image_features.T  # Cosine similarity

            # Debugging: Print shapes
            print(f"Similarities shape: {similarities.shape}")
            print(f"Labels shape: {labels.shape}")


            similarities_list.append(similarities)
            labels_list.append(labels)

    # Determine maximum size for padding
    max_size = max(sim.shape[0] for sim in similarities_list)

    # Pad smaller matrices
    padded_similarities = [
        torch.nn.functional.pad(sim, (0, max_size - sim.shape[1], 0, max_size - sim.shape[0]))
        for sim in similarities_list
    ]

    # Combine the padded tensors
    similarities = torch.cat(padded_similarities, dim=0)  # Shape: (num_queries, num_candidates)
    labels = torch.cat(labels_list, dim=0)  # Shape: (num_queries,)

    # Calculate metrics
    recall_at_k = calculate_recall_at_k(similarities, labels, k=k)

    print(f"Recall@{k}: {recall_at_k:.4f}")
    return recall_at_k


In [25]:
recall_at_k = evaluate_retrieval_metrics(soft_prompt_model, test_loader, k=5)  # You can change k if needed

soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Similarities shape: torch.Size([32, 32])
Labels shape: torch.Size([32])
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Similarities shape: torch.Size([32, 32])
Labels shape: torch.Size([32])
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Similarities shape: torch.Size([32, 32])
Labels shape: torch.Size([32])
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddings shape: torch.Size([32, 82, 512]) 

Similarities shape: torch.Size([32, 32])
Labels shape: torch.Size([32])
soft_prompt shape: torch.Size([32, 5, 512])
token_embeddings shape: torch.Size([32, 77, 512])
augmented_embeddin