# Assignment 0 Part 3 - CLIP and Zero-Shot Image Classification

In our previous work, we focused solely on visual information, leveraging convolutional neural networks (CNNs) and Vision Transformers to analyze and interpret images. However, the field of machine learning has expanded to include multimodal models that integrate different types of data. This brings us to CLIP (Contrastive Language-Image Pretraining), a groundbreaking model that combines both text and images to learn their interrelationships.

CLIP, developed by OpenAI, is designed to understand and connect textual descriptions with corresponding images. Unlike traditional models that operate within a single modality, CLIP learns to map both text and images into a shared feature space. This shared space enables the model to grasp the relationship between visual content and textual descriptions, making it possible to perform tasks that involve both types of information. For example, CLIP can take a sentence and find the image that best matches it, or take an image and pick the best-matching caption from a set of candidates. Note that CLIP itself does not generate text; it only scores how well an image and a piece of text match.

The ability to learn representations across different modalities opens up new possibilities for tasks that CLIP wasn't explicitly trained for. One of the most intriguing applications is zero-shot classification, where CLIP can classify images into categories it has never seen during training, simply based on the similarity of the image to textual descriptions of those categories. Another powerful feature is the ability to find the closest image match to a query caption, which can be particularly useful for image retrieval and recommendation systems.

By implementing CLIP, we move beyond single-modality tasks and explore the potential of models that understand and integrate multiple forms of information. This approach not only enhances the capabilities of our models but also expands the range of applications they can address, making them more versatile and powerful in handling complex real-world tasks.

Here is a list of resources to help you with this part:
- [CLIP — Intuitively and Exhaustively Explained](https://medium.com/towards-data-science/clip-intuitively-and-exhaustively-explained-1d02c07dbf40) - A good article to get an overview of CLIP.
- [The Annotated CLIP](https://amaarora.github.io/posts/2023-03-06_Understanding_CLIP.html) - A two-part series on a more implementation-focused overview of CLIP.
- [Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020) - The official CLIP paper

In [None]:
from typing import List, Dict, Callable, Optional
from dataclasses import dataclass
import itertools

import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision import datasets

import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
from tqdm import tqdm

import matplotlib.pyplot as plt

## Task 0. Getting our Data

To implement CLIP effectively, we need a dataset that pairs complete sentences with their corresponding images, enabling the model to learn the relationship between textual descriptions and visual content. One such dataset is Flickr8k, which contains about 8,000 images, each paired with multiple descriptive captions. This dataset is a good fit for training models that require understanding and associating both text and images.

You can download Flickr8k using [this link from Kaggle](https://www.kaggle.com/datasets/adityajn105/flickr8k).

However, it's important to note that replicating the exact training process described in the original CLIP paper would require substantial computational resources. Instead, we can use a subset of the dataset for our experiments, which will allow us to explore CLIP's capabilities without the need for extensive training.

In the CLIP architecture, we need to process both text and images through separate encoders. The Text Encoder converts input captions into embeddings, while the Image Encoder transforms images into a compatible feature space. One critical step in working with text is tokenization. Tokenization involves breaking down a sentence into smaller units, such as words or subwords, which are then converted into numerical representations that the model can understand.

Tokenization prepares the text for embedding by splitting it into manageable pieces and mapping these pieces to unique identifiers. For example, the sentence "A cat sitting on a mat" might be tokenized into individual words or subword units, each of which is then converted into an embedding vector by the Text Encoder. This process allows the model to handle various lengths and structures of text inputs effectively.

By incorporating both text and image encoders into our implementation, and understanding the importance of tokenization, we can create a system that learns to relate textual descriptions to visual content accurately. This approach enables the model to perform tasks such as matching captions to images and understanding the semantic connections between different modalities.

The cell(s) below define a configuration class for the model's hyperparameters and a `CLIPDataset` class that makes the data easier to work with. Take some time to understand how the tokenizer works and what one batch of data would look like.

In [None]:
@dataclass
class Config:
    images_path: str = "flickr8k/Images"
    captions_path: str = "flickr8k/captions.txt"
    batch_size: int = 32
    epochs: int = 3

    image_encoder: str = 'resnet50'
    image_emb_size: int = 2048
    image_size: int = 224
    image_encoder_lr: float = 1e-4
    
    text_encoder: str = 'distilbert-base-uncased'
    text_embedding: int = 768
    text_tokenizer: str = 'distilbert-base-uncased'
    max_length: int = 200
    text_encoder_lr: float = 1e-5

    projection_dim: int = 256
    head_lr: float = 1e-3
    weight_decay: float = 1e-3

cfg = Config()

In [None]:
class CLIPDataset(Dataset):
    def __init__(
        self, 
        image_filenames: List[str], 
        captions: List[str], 
        tokenizer: Callable[[List[str]], Dict[str, torch.Tensor]], 
        transform: Optional[Callable[[Image.Image], Image.Image]] = None
    ):
        """
        Initializes the dataset with image filenames, captions, a tokenizer function, and optional image transforms.
        
        :param image_filenames: List of image file names.
        :param captions: List of captions corresponding to the images.
        :param tokenizer: Function to tokenize captions. It should return a dictionary with tensors.
        :param transform: Optional transform to be applied on images.
        """
        self.image_filenames = image_filenames
        self.captions = captions
        self.tokenizer = tokenizer
        self.transform = transform

        # Tokenize all captions
        self.encoded_captions = tokenizer(captions, padding=True, truncation=True, max_length=cfg.max_length)
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Retrieves an item from the dataset given an index.

        :param idx: Index of the item to retrieve.
        :return: Dictionary with 'image' and 'caption' keys.
        """
        # Get encoded caption
        encoded_caption = {key: torch.tensor(values[idx]) for key, values in self.encoded_captions.items()}
        
        # Load and process image
        image_path = f"{cfg.images_path}/{self.image_filenames[idx]}"
        image = Image.open(image_path).convert("RGB")
        
        if self.transform:
            image = self.transform(image)
        
        # Return dictionary with tensors
        return {
            'image': image,
            'caption': encoded_caption
        }
    
    def __len__(self) -> int:
        """
        Returns the number of items in the dataset.

        :return: Number of items.
        """
        return len(self.captions)



df = pd.read_csv(cfg.captions_path)
print(f"Size of original dataset: {df.shape}")

# Remove rows of duplicate images (otherwise we have one image with multiple captions)
df = df.drop_duplicates(subset='image', keep='first')
print(f"Size after deduplication: {df.shape}")

tokenizer = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)

clip_ds = CLIPDataset(
    image_filenames=df.image.values.tolist(),
    captions=df.caption.values.tolist(),
    tokenizer=tokenizer,
    transform=transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)
clip_dl = DataLoader(
    clip_ds,
    batch_size=cfg.batch_size,
    shuffle=True
)

## Task 1. Image Encoder

One of the core components of CLIP is the Image Encoder, which transforms images into embeddings that are compared against text embeddings. This process allows CLIP to map visual information into the same feature space as textual descriptions, enabling effective cross-modal comparisons.

For our implementation, we will use the `timm` library to load a pre-trained ResNet50 model. ResNet50 is a popular convolutional neural network known for its robust feature extraction capabilities. However, instead of using the full model with its classifier head, we'll focus on the feature extractor backbone. This allows us to leverage ResNet50's powerful feature extraction while bypassing the classification layer, which is unnecessary for our purpose.

Take some time to go through the documentation for `timm` to load in the feature extractor for a ResNet50.

To enhance training efficiency, consider freezing the parameters of the Image Encoder. Since the focus is on training the projection heads for the CLIP model, freezing the Image Encoder parameters can prevent unnecessary updates and speed up the training process. This approach allows the model to concentrate on learning how to project the image embeddings into the shared feature space.
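As a rough illustration of freezing, the sketch below disables gradients on a small stand-in network. The `encoder` here is a hypothetical placeholder, not the ResNet50 backbone; the same loop should apply to any `nn.Module`.

```python
import torch.nn as nn

# Small stand-in network (hypothetical); the same loop applies to any nn.Module.
encoder = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))

# Freeze every parameter so the optimizer never updates it.
for param in encoder.parameters():
    param.requires_grad = False

# Count the parameters that would still receive gradients.
trainable = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print(trainable)  # 0
```

Frozen parameters still take part in the forward pass; they are simply excluded from gradient computation and updates.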

When performing a forward pass through the Image Encoder, it's important to note the dimensionality of the output features. We are aiming for a tensor of shape $(B, d)$, so if the tensor has a higher dimensionality, consider performing some form of global pooling.
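To make the pooling step concrete, here is a minimal sketch on a dummy tensor. The shape `(4, 2048, 7, 7)` is an assumption standing in for an un-pooled ResNet50 output on 224x224 inputs; your encoder's actual output shape depends on how you load it.

```python
import torch

# Dummy feature map standing in for an un-pooled CNN output: (B, C, H, W).
feature_map = torch.randn(4, 2048, 7, 7)

# Global average pooling: average over the two spatial dimensions.
pooled = feature_map.mean(dim=(2, 3))
print(pooled.shape)  # torch.Size([4, 2048])
```

Averaging over the spatial dimensions is one common choice; max pooling over the same dimensions would also give a $(B, d)$ tensor.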

By implementing the Image Encoder in this way, we ensure that our model converts images into a form that can be used in conjunction with text embeddings, enabling accurate cross-modal comparisons and the desired functionality of CLIP.

In [None]:
class ImageEncoder(nn.Module):
    def __init__(self, 
                 model_name: str,
                 trainable: bool = False):
        raise NotImplementedError
    
    def forward(self, x: torch.Tensor):
        raise NotImplementedError

In [None]:
img_batch = next(iter(clip_dl))['image']
print(img_batch.shape)

img_encoder = ImageEncoder(cfg.image_encoder)
img_enc_out = img_encoder(img_batch)

print(img_enc_out.shape)

## Task 2. Text Encoder

The Text Encoder in CLIP is responsible for transforming textual descriptions into embeddings that can be compared with image embeddings. For this purpose, we will use DistilBERT, a lighter and faster variant of the BERT model. DistilBERT retains most of BERT’s language understanding capabilities while reducing its size and computational requirements, making it well-suited for our task.

**Tokenization** is a crucial step in processing text for neural network models. It involves breaking down text into smaller units, such as words or subwords, and converting these units into numerical representations that the model can understand. For example, the sentence "A cat sitting on a mat" might be tokenized into individual words or subword units.

In the context of DistilBERT, tokenization is performed using a tokenizer that splits text into tokens and maps them to corresponding IDs from a vocabulary. The tokenizer also handles special tokens required by the model, such as the [CLS] token used for classification.

When preparing text for input to DistilBERT, two key components are generated:
- **`input_ids`**: These are the token IDs representing the input text. Each token is mapped to a unique identifier from the model's vocabulary.
- **`attention_mask`**: This is a binary mask indicating which tokens should be attended to and which should be ignored (typically, padding tokens are ignored). It helps the model focus on the actual content of the text while disregarding any padding.
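The toy example below sketches how `input_ids` and `attention_mask` relate to each other. It is not how DistilBERT's tokenizer works internally: the real tokenizer uses WordPiece subwords and adds special tokens such as [CLS] and [SEP]. The vocabulary `TOY_VOCAB` and the function `toy_tokenize` are hypothetical names made up for illustration.

```python
from typing import Dict, List

# Toy vocabulary (hypothetical); real tokenizers have tens of thousands of entries.
TOY_VOCAB: Dict[str, int] = {
    "[PAD]": 0, "[UNK]": 1, "a": 2, "cat": 3, "sitting": 4, "on": 5, "mat": 6,
}

def toy_tokenize(sentence: str, max_length: int = 8) -> Dict[str, List[int]]:
    """Lowercase, split on whitespace, map words to IDs, then pad to max_length."""
    tokens = sentence.lower().split()[:max_length]            # truncation
    ids = [TOY_VOCAB.get(tok, TOY_VOCAB["[UNK]"]) for tok in tokens]
    pad = max_length - len(ids)
    return {
        "input_ids": ids + [TOY_VOCAB["[PAD]"]] * pad,
        "attention_mask": [1] * len(ids) + [0] * pad,          # 1 = real token, 0 = padding
    }

encoded = toy_tokenize("A cat sitting on a mat")
print(encoded["input_ids"])       # [2, 3, 4, 5, 2, 6, 0, 0]
print(encoded["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```

The two lists always have the same length, and the mask is zero exactly where `input_ids` holds padding.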

After tokenization, the text is passed through DistilBERT to obtain embeddings. The output of DistilBERT includes embeddings for each token, but we are primarily interested in the embedding of the [CLS] token. This token is specially designated for representing the entire sequence, and its embedding serves as the sentence embedding.

For our purposes, the sentence embedding corresponds to the embedding of the [CLS] token, which captures the overall meaning of the input sentence. This embedding can then be used in conjunction with image embeddings to perform various cross-modal tasks, such as matching or classification.
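Selecting the [CLS] embedding amounts to indexing the first position of the per-token output. The sketch below uses a random tensor in place of a real model output; the shape `(4, 12, 768)` is an assumed batch of 4 sequences of 12 tokens each, with DistilBERT's hidden size of 768.

```python
import torch

# Dummy stand-in for the per-token output of the text model: (B, T, hidden).
last_hidden_state = torch.randn(4, 12, 768)

# BERT-style tokenizers place [CLS] first, so position 0 holds its embedding.
cls_embedding = last_hidden_state[:, 0, :]
print(cls_embedding.shape)  # torch.Size([4, 768])
```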

By implementing the Text Encoder with DistilBERT and understanding the roles of tokenization, `input_ids`, `attention_mask`, and the CLS token, we ensure that textual descriptions are effectively transformed into embeddings that align with the image embeddings produced by the Image Encoder.

Take some time to familiarize yourself with [Tokenizers](https://huggingface.co/docs/tokenizers/en/index) and [DistilBERT](https://huggingface.co/docs/transformers/en/model_doc/distilbert).

In [None]:
class TextEncoder(nn.Module):
    def __init__(self, 
                 model_name: str = cfg.text_encoder,
                 trainable: bool = False):
        raise NotImplementedError

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor):
        
        raise NotImplementedError

In [None]:
text_batch = next(iter(clip_dl))['caption']
print(text_batch.keys())

text_encoder = TextEncoder(cfg.text_encoder)
text_enc_out = text_encoder(
    input_ids=text_batch["input_ids"],
    attention_mask=text_batch["attention_mask"]
)

print(text_enc_out.shape)

## Task 3. Creating CLIP

In the CLIP architecture, the model integrates both text and image encoders to produce embeddings in a shared feature space. Since the dimensionalities of text and image embeddings differ, projection heads are employed to map these embeddings into a common space, allowing for effective similarity computation.

The `CLIPModel` class encompasses four primary components:
1. **Text Encoder**: Converts text into embeddings.
2. **Image Encoder**: Converts images into embeddings.
3. **Text Projection Head**: Projects text embeddings into the shared feature space.
4. **Image Projection Head**: Projects image embeddings into the same shared feature space.

The projection heads ensure that both text and image embeddings are in a compatible format, making it possible to compare them directly.

To streamline the training process, you can freeze the parameters of the image and text encoders. This approach focuses training on the projection heads, which are responsible for mapping the embeddings from their respective domains into the shared feature space. By freezing the encoders, you reduce computational overhead and concentrate on training the projection heads, which are crucial for aligning text and image embeddings.

To train CLIP, the loss function compares the embeddings in the shared space. For a batch of $B$ image-caption pairs, the goal is for the $B \times B$ matrix of pairwise similarities between image and text embeddings to be close to an identity matrix: high on the diagonal (matching pairs) and low everywhere else. The targets for the loss function follow from this.

Since duplicates have been removed, each image appears only once, so you can use `torch.arange` to set the targets: the correct caption for the image in row $i$ is the one in column $i$. The diagonal of the similarity matrix therefore marks the correct matches, and the model is trained to distinguish correct from incorrect pairs by comparing the similarities of their embeddings.
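One way to turn this into a loss is the symmetric cross-entropy described in the CLIP paper. The sketch below runs it on random embeddings rather than real model outputs. The fixed `temperature` of 0.07 is an assumption for illustration (the paper learns this value), and your own implementation may reasonably differ.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, proj_dim = 8, 256

# Random stand-ins for projected embeddings; row i of each is a matching pair.
image_emb = F.normalize(torch.randn(batch_size, proj_dim), dim=-1)
text_emb = F.normalize(torch.randn(batch_size, proj_dim), dim=-1)

temperature = 0.07  # assumed fixed value for illustration
logits = image_emb @ text_emb.T / temperature   # (B, B) pairwise similarities

# The correct caption for image i sits in column i, so the targets are 0..B-1.
targets = torch.arange(batch_size)

loss_i2t = F.cross_entropy(logits, targets)     # images -> captions
loss_t2i = F.cross_entropy(logits.T, targets)   # captions -> images
loss = (loss_i2t + loss_t2i) / 2
print(logits.shape, loss.item())
```

Averaging the two directions treats image-to-text and text-to-image matching equally.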

In [None]:
class ProjectionHead(nn.Module):
    def __init__(self,
                 embedding_dim: int,
                 proj_dim: int = cfg.projection_dim):
        super().__init__()
        self.proj = nn.Linear(embedding_dim, proj_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(proj_dim, proj_dim)
        self.layernorm = nn.LayerNorm(proj_dim)
    
    def forward(self, x: torch.Tensor):
        x = self.proj(x)
        x = x + self.fc(self.gelu(x))
        x = self.layernorm(x)
        return x

In [None]:
class CLIPModel(nn.Module):
    def __init__(self, 
                 cfg: Config = cfg):
        raise NotImplementedError

    def forward(self,
                x: dict):
        raise NotImplementedError

## Task 4. Training CLIP

### Training the CLIP Model

Now that we have our `CLIPModel` set up with both text and image encoders, along with their respective projection heads, it’s time to train the model. The training process involves optimizing the alignment between text and image embeddings, so they can be effectively compared.

Since the CLIP model consists of four distinct components (text encoder, image encoder, and their corresponding projection heads), it can be beneficial to use different learning rates for each component. This allows for finer control over the training process, ensuring that each part of the model is updated appropriately according to its role. This, alongside the training setup, has been done for you.

Your task is to complete the training function and plot the training loss curve. This is not at all different from what you've done before.

You are allowed to tweak parts of the training pipeline (including freezing/unfreezing the encoders, training for more epochs, etc.) at your own risk. We just want to see a nice hockey-stick-shaped loss curve that looks like it's converging.

In [None]:
def move_tensors_to_device(d: dict, device: torch.device) -> dict:
    """
    Helper function for moving tensors inside (nested) dictionaries to a target device - not necessary to use but can be useful depending on implementation
    
    :param d: Dictionary with potential nested dictionaries and tensors.
    :param device: The device to move tensors to (e.g., torch.device('cuda:0') or torch.device('cpu')).
    :return: A new dictionary with tensors moved to the specified device.
    """
    new_dict = {}
    for k, v in d.items():
        if isinstance(v, dict):
            # Recursively process nested dictionaries
            new_dict[k] = move_tensors_to_device(v, device)
        elif isinstance(v, torch.Tensor):
            # Move tensors to the device
            new_dict[k] = v.to(device)
        else:
            # For non-tensor, non-dict items, just copy them as is
            new_dict[k] = v
    return new_dict

In [None]:
model = CLIPModel()
params = [
    {"params": model.image_encoder.parameters(), "lr": cfg.image_encoder_lr},
    {"params": model.text_encoder.parameters(), "lr": cfg.text_encoder_lr},
    {"params": itertools.chain(
        model.image_projection.parameters(), model.text_projection.parameters()
    ), "lr": cfg.head_lr, "weight_decay": cfg.weight_decay}
]
optimizer = torch.optim.AdamW(params, weight_decay=0.)
epochs = 15

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

In [None]:
# Train the model
raise NotImplementedError

In [None]:
# Plot the loss curve
raise NotImplementedError

## Task 5. Zero-shot Image Classification

Now for a very new type of downstream task.

Zero-shot image classification is a technique where a model can correctly classify images into categories it has never seen during training. Unlike traditional classification models that require training on specific labeled examples for each category, a zero-shot model like CLIP leverages its ability to understand and relate images and text descriptions to classify new categories on the fly. This is possible because CLIP learns a joint embedding space for both images and text, allowing it to generalize to new tasks without additional training.

CLIP achieves zero-shot classification by mapping both images and text (such as category names or descriptions) into the same embedding space. During inference, you provide the model with a set of textual descriptions of potential categories (e.g., "a dog," "a cat," "a car") and an image to classify. The model computes the similarity between the image embedding and each text embedding. The category with the highest similarity score is then chosen as the predicted label for the image.

Here are the steps to perform it:

1. **Prepare Your Text Prompts**: Write down the categories you want to classify the image into. For example, if you're classifying an animal, your categories might be "a dog," "a cat," "a bird," etc. Note that prompts are not limited to one-word categories; you can use full-length sentences too.

2. **Encode the Text Prompts**: Use the text encoder part of the CLIP model to convert these category descriptions into embeddings.

3. **Encode the Image**: Use the image encoder part of the CLIP model to convert the image into an embedding.

4. **Compute Similarities**: Calculate the similarity between the image embedding and each of the text embeddings. This step determines how closely the image matches each category description.

5. **Select the Best Match**: Identify the text prompt (category) that has the highest similarity score with the image. This category is your model's prediction.
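Steps 4 and 5 above can be sketched on hand-made embeddings. The 3-dimensional vectors below are hypothetical stand-ins for the projected outputs of a trained model; they only show the similarity and argmax logic, not the encoding steps.

```python
import torch
import torch.nn.functional as F

# Hand-made 3-d embeddings standing in for projected CLIP outputs (hypothetical).
prompts = ["a dog", "a cat", "a car"]
text_emb = F.normalize(torch.tensor([[1.0, 0.0, 0.0],
                                     [0.0, 1.0, 0.0],
                                     [0.0, 0.0, 1.0]]), dim=-1)
image_emb = F.normalize(torch.tensor([[0.9, 0.2, 0.1]]), dim=-1)  # one image

# Cosine similarity between the image and every prompt: shape (1, 3).
similarity = image_emb @ text_emb.T

# Index of the best-matching prompt for each image.
pred = similarity.argmax(dim=-1)
print(prompts[pred.item()])  # a dog
```

Because both sets of embeddings are L2-normalized, the matrix product gives cosine similarities directly.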

Your task is to create a function that performs the zero-shot classification as described. Once you've implemented the function, test it with an image of your choice.

You may find the model does not perform spectacularly on some sets of images and prompts, but that's alright: we trained it on a very small dataset with a very small number of steps. You can look into scaling up in your own time.

In [None]:
def zero_shot_classify(clip_model: nn.Module,
                       tokenizer: Callable[[List[str]], Dict[str, torch.Tensor]],
                       images: torch.Tensor,
                       class_descriptions: List[str],
                       transform: Optional[Callable[[Image.Image], Image.Image]] = None,
                       device: str = "cuda") -> torch.Tensor:
    """
    Perform zero-shot classification on images using textual class descriptions.

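    :param clip_model: Trained CLIP model (a PyTorch nn.Module).
    :param tokenizer: Tokenizer used to encode the class descriptions.
    :param images: Image(s) to classify; if not already a tensor, apply `transform` first.
    :param class_descriptions: List of class descriptions (prompts) for zero-shot classification.
    :param transform: Optional preprocessing applied to the image(s) before encoding.
    :param device: Device to run inference on.
    :return: Tensor of predicted class indices.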
    """
    raise NotImplementedError

In [None]:
!wget -O dummy_img.jpg https://cdn11.bigcommerce.com/s-t04x4i8lh4/product_images/uploaded_images/how-to-exercise-your-dog-in-the-winter.jpg

In [None]:
from PIL import Image
img = Image.open("dummy_img.jpg")
img

In [None]:
zero_shot_classify(
    model,
    tokenizer,
    img,
    ["car driving through a city", "dog running through snow", "helicopter smashing into a tree", "students crying over an exam"],
    transform=transforms.Compose([
        transforms.Resize((cfg.image_size, cfg.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
)

## Fin.