# CLIP from Scratch


**CLIP** or **Contrastive Language-Image Pre-training** is a model that learns the relationship between a whole sentence and the image it describes; in a sense that when the model is trained, given an input sentence it will be able to retrieve the most related images corresponding to that sentence. The important thing here is that it is trained on full sentences instead of single classes like car, dog, etc. The intuition is that when trained on whole sentences, the model can learn a lot more things and finds some pattern between images and texts.
They also show that when this model is trained on a huge dataset of images and their corresponding texts, it can also act as a classifier too. I encourage you to study the paper to learn more about this exciting model and their astonishing results on benchmarking datasets . To mention just one, CLIP model trained with this strategy classifies ImageNet better than those SOTA models trained on the ImageNet itself optimized for the only task of classification!

As a **teaser**, let's see what the final model that we will build in this article from scratch is capable of: given a query (raw text) like "a boy jumping with skateboard" or "a girl jumping from swing", the model will retrieve the most relevant images:

![](https://i.ibb.co/9gdYqNP/teaser-cropped.png)

In this notebook, we will see how to implement CLIP from Scratch

## Imports

In [None]:
# !pip install -q timm
# # !pip install ipywidgets==7.6.5


In [None]:
# !pip install --upgrade ipywidgets
# !jupyter nbextension enable --py widgetsnbextension
# !jupyter labextension install @jupyter-widgets/jupyterlab-manager


In [None]:
import os
import cv2
import gc
import numpy as np
import pandas as pd
import itertools
from tqdm.autonotebook import tqdm
import albumentations as A
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer

In [None]:
# base_dir= r"F:\cv class project\flickr30k_images"
# dataset = r"F:\cv class project\flickr30k_images\results.csv"
# IMG_PATH = r"F:\cv class project\flickr30k_images\flickr30k_images"

## Some pre-preocessing

In [None]:
import json
import pandas as pd
import os
import re

# Path to your JSON file
json_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\ICFG-PEDES.json"

# Base directory where images are actually stored
base_img_dir = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\imgs\test"

# Load the JSON dataset
with open(json_path, "r") as f:
    data = json.load(f)

def clean_path(path):
    # Remove all occurrences of 'train/'
    path = path.replace("train/", "")
    # Replace multiple 'test/' occurrences with a single one
    path = re.sub(r'(test/)+', 'test/', path)
    # Remove the initial 'test/' since base_img_dir already includes it
    if path.startswith("test/"):
        path = path[len("test/"):]
    # Join with base directory and normalize
    return os.path.normpath(os.path.join(base_img_dir, path))

# Build the DataFrame
df = pd.DataFrame({
    "image": [clean_path(item["file_path"]) for item in data],
    "caption": [item["captions"][0] for item in data]
})

# Add unique IDs
df["id"] = range(len(df))

# Save to CSV
output_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions.csv"
df.to_csv(output_path, index=False)

print(f"✅ Preprocessing complete! File saved at: {output_path}")


In [None]:
import pandas as pd
import os

# Load the CSV file
csv_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions.csv"
df = pd.read_csv(csv_path)

# Check if the file exists for each image path
df["valid_path"] = df["image"].apply(lambda x: os.path.exists(x))

# Count valid and invalid paths
valid_count = df["valid_path"].sum()
invalid_count = len(df) - valid_count

print(f"✅ Valid image paths: {valid_count}")
print(f"❌ Invalid image paths: {invalid_count}")

# Optionally: Save invalid paths to a file for inspection
invalid_df = df[~df["valid_path"]]
invalid_output_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\invalid_paths.csv"
invalid_df.to_csv(invalid_output_path, index=False)

print(f"📁 Invalid paths saved to: {invalid_output_path}")


In [None]:
import pandas as pd
import os

# Load the CSV file
csv_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions.csv"
df = pd.read_csv(csv_path)

# Check if the file exists for each image path
df["valid_path"] = df["image"].apply(lambda x: os.path.exists(x))

# Filter only valid rows
valid_df = df[df["valid_path"]].drop(columns=["valid_path"])  # drop helper column

# Save only the valid paths
cleaned_output_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions_cleaned.csv"
valid_df.to_csv(cleaned_output_path, index=False)

print(f"✅ Cleaned data saved with {len(valid_df)} valid rows at: {cleaned_output_path}")


In [None]:
df.head(10)

## Config

In [None]:
# class CFG:
#     debug = False
#     image_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\imgs\test"
#     captions_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions_cleaned.csv"
#     batch_size = 64
#     num_workers = 0
#     head_lr = 1e-3
#     image_encoder_lr = 1e-4
#     text_encoder_lr = 1e-5
#     weight_decay = 1e-3
#     patience = 1
#     factor = 0.8
#     epochs = 1
#     device = "cuda" if torch.cuda.is_available() else "cpu"

#     model_name = 'resnet50'
#     image_embedding = 2048
#     text_encoder_model = "distilbert-base-uncased"
#     text_embedding = 768
#     text_tokenizer = "distilbert-base-uncased"
#     max_length = 200

#     pretrained = True # for both image encoder and text encoder
#     trainable = True # for both image encoder and text encoder
#     temperature = 1.0

#     # image size
#     size = 224

#     # for projection head; used for both image and text encoders
#     num_projection_layers = 1
#     projection_dim = 256 
#     dropout = 0.1
# cfg = CFG()

class CFG:
    debug = False
    image_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\imgs\test"
    captions_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions_cleaned.csv"
    batch_size = 64
    num_workers = 0
    epochs = 1
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Using pretrained CLIP model
    pretrained = True  # For both image encoder and text encoder
    trainable = True   # For both image encoder and text encoder

    temperature = 1.0  # You can adjust this for contrastive loss
    size = 224         # Default image size for CLIP


## Utils

In [None]:
class AvgMeter:
    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        self.avg, self.sum, self.count = [0] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        text = f"{self.name}: {self.avg:.4f}"
        return text

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]

## Dataset

We need to encode both images and their describing texts. We use **Flickr 30k** dataset that contains 31.8k images and caption pairs. 

We will use **DistilBERT** model (which is smaller than BERT but performs nearly as well as BERT) from **HuggingFace** library as our text encoder; so, we need to **tokenize** the sentences (captions) with DistilBERT tokenizer and then feed the token ids (input_ids) and the attention masks to DistilBERT. Therefore, the dataset needs to take care of the tokenization as well. Below you can see the dataset's code. Below that I'll explain the most important things that is happening in the code.

In the **\_\_init\_\_** we receive a tokenizer object which is actually a HuggingFace tokinzer; this tokenizer will be loaded when running the model. We are padding and truncating the captions to a specified max_length. In the **\_\_getitem\_\_** we will first load an encoded caption which is a dictionary with keys input_ids and attention_mask, make tensors out of its values and after that we will load the corresponding image, transform and augment it (if there is any!) and then we make it a tensor and put it in the dictionary with "image" as the key. Finally we put the raw text of the caption with the key "caption" in the dictionary only for visualization purposes. 

I did not use additional data augmentations but you can add them if you want to improve the model's performance.

In [None]:
# def get_transforms(mode="train"):
#     if mode == "train":
#         return A.Compose(
#             [
#                 A.Resize(cfg.size, cfg.size, always_apply=True),
#                 A.Normalize(max_pixel_value=255.0, always_apply=True),
#             ]
#         )
#     else:
#         return A.Compose(
#             [
#                 A.Resize(cfg.size, cfg.size, always_apply=True),
#                 A.Normalize(max_pixel_value=255.0, always_apply=True),
#             ]
#         )


# class CLIPDataset(torch.utils.data.Dataset):
#     def __init__(self, image_filenames, captions, tokenizer, transforms):
#         """
#         image_filenames and cpations must have the same length; so, if there are
#         multiple captions for each image, the image_filenames must have repetitive
#         file names
#         """

#         self.image_filenames = image_filenames
#         self.captions = list(captions)
#         self.encoded_captions = tokenizer(
#             list(captions), padding=True, truncation=True, max_length=cfg.max_length
#         )
#         self.transforms = transforms
#         # Print image paths to a text file before loading
#         with open("image_paths.txt", "w") as f:
#             for filename in self.image_filenames:
#                 image_path = os.path.join(cfg.image_path, filename)
#                 f.write(image_path + "\n")
#         print("Image paths saved to image_paths.txt")
            
#     def __getitem__(self, idx):
#         try:
#             item = {
#                 key: torch.tensor(values[idx])
#                 for key, values in self.encoded_captions.items()
#             }

#             # Construct the full image path
#             image_path = os.path.join(cfg.image_path, self.image_filenames[idx])
#             print(f"Original image path: {image_path}")

#             # Normalize the path to handle different separators (e.g., / or \)
#             image_path = os.path.normpath(image_path)

#             # Split the path into parts for manipulation
#             parts = image_path.split(os.sep)

#             # Remove duplicate consecutive 'test' if present
#             cleaned_parts = []
#             for i, part in enumerate(parts):
#                 # Avoid adding consecutive 'test' directories
#                 if part == "test" and i > 0 and parts[i - 1] == "test":
#                     continue
#                 cleaned_parts.append(part)

#             # Reconstruct the cleaned path
#             image_path = os.sep.join(cleaned_parts)

#             # Double-check the corrected path
#             print(f"Trying to load image from: {image_path}", flush=True)

#             # Read the image
#             image = cv2.imread(image_path)

#             # Handle missing/corrupt images
#             if image is None:
#                 print(f"⚠️ Error loading image: {image_path}", flush=True)
#                 # Return placeholder data with an empty caption
#                 image = np.zeros((cfg.size, cfg.size, 3), dtype=np.uint8)  # Blank image
#                 caption = ""  # Empty caption
#             else:
#                 # Convert BGR to RGB
#                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#                 caption = self.captions[idx]  # Use the original caption

#             # Apply transforms if available
#             image = self.transforms(image=image)["image"]

#             # Prepare the item
#             item["image"] = torch.tensor(image).permute(2, 0, 1).float()
#             item["caption"] = caption

#             return item

#         except Exception as e:
#             print(f"Error processing sample {idx}: {e}")
#             # Return placeholder data with an empty caption
#             image = np.zeros((cfg.size, cfg.size, 3), dtype=np.uint8)  # Blank image
#             caption = ""  # Empty caption
#             return {
#                 "image": torch.tensor(image).permute(2, 0, 1).float(),
#                 "caption": caption,
#             }        
#     def __len__(self):
#         return len(self.captions)



import torchvision.transforms as T

def get_transforms(mode="train"):
    if mode == "train":
        return A.Compose(
            [
                A.Resize(cfg.size, cfg.size, always_apply=True),
                A.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], always_apply=True),
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(cfg.size, cfg.size, always_apply=True),
                A.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], always_apply=True),
            ]
        )


class CLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, tokenizer, transforms):
        """
        image_filenames and captions must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names
        """
        self.image_filenames = image_filenames
        self.captions = list(captions)
        self.encoded_captions = tokenizer(
            list(captions), padding=True, truncation=True, max_length=cfg.max_length
        )
        self.transforms = transforms
        
        # Save image paths to a text file before loading
        with open("image_paths.txt", "w") as f:
            for filename in self.image_filenames:
                image_path = os.path.join(cfg.image_path, filename)
                f.write(image_path + "\n")
        print("Image paths saved to image_paths.txt")

    def __getitem__(self, idx):
        try:
            item = {
                key: torch.tensor(values[idx])
                for key, values in self.encoded_captions.items()
            }

            # Construct the full image path
            image_path = os.path.join(cfg.image_path, self.image_filenames[idx])
            image_path = os.path.normpath(image_path)

            # Clean up path to handle duplicate 'test' directories
            parts = image_path.split(os.sep)
            cleaned_parts = []
            for i, part in enumerate(parts):
                if part == "test" and i > 0 and parts[i - 1] == "test":
                    continue
                cleaned_parts.append(part)
            image_path = os.sep.join(cleaned_parts)

            # Read the image
            image = cv2.imread(image_path)

            if image is None:
                print(f"⚠️ Error loading image: {image_path}")
                image = np.zeros((cfg.size, cfg.size, 3), dtype=np.uint8)  # Placeholder image
                caption = ""  # Empty caption
            else:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                caption = self.captions[idx]  # Use original caption

            # Apply transformations (resize + normalization)
            image = self.transforms(image=image)["image"]

            item["image"] = torch.tensor(image).permute(2, 0, 1).float()
            item["caption"] = caption

            return item

        except Exception as e:
            print(f"Error processing sample {idx}: {e}")
            image = np.zeros((cfg.size, cfg.size, 3), dtype=np.uint8)  # Placeholder image
            caption = ""  # Empty caption
            return {
                "image": torch.tensor(image).permute(2, 0, 1).float(),
                "caption": caption,
            }

    def __len__(self):
        return len(self.captions)


## Image Encoder

The image encoder code is straight forward. I'm using PyTorch Image Models library (timm) here which makes a lot of different image models available from ResNets to EfficientNets and many more. Here we will use a ResNet50 as our image encoder. You can easily use torchvision library to use ResNets if you don't want to install a new library.

The code encodes each image to a fixed size vector with the size of the model's output channels (in case of ResNet50 the vector size will be **2048**). This is the output after the nn.AdaptiveAvgPool2d() layer.

In [None]:
# class ImageEncoder(nn.Module):
#     """
#     Encode images to a fixed size vector
#     """

#     def __init__(
#         self, model_name=cfg.model_name, pretrained=cfg.pretrained, trainable=cfg.trainable
#     ):
#         super().__init__()
#         self.model = timm.create_model(
#             model_name, pretrained, num_classes=0, global_pool="avg"
#         )
#         for p in self.model.parameters():
#             p.requires_grad = trainable

#     def forward(self, x):
#         return self.model(x)

import clip
import torch
import torch.nn as nn

class ImageEncoder(nn.Module):
    """
    Encode images using CLIP's image encoder
    """

    def __init__(self, pretrained=cfg.pretrained, trainable=cfg.trainable):
        super().__init__()
        # Load the pretrained CLIP model
        self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=cfg.device)  # You can choose other CLIP architectures here
        
        # Set the image encoder part of CLIP to be trainable or frozen
        for param in self.clip_model.parameters():
            param.requires_grad = trainable

    def forward(self, x):
        # The image preprocessing is applied here for CLIP
        x = self.clip_preprocess(x).to(cfg.device)
        return self.clip_model.encode_image(x)


## Text Encoder

I'll use DistilBERT as the text encoder. Like its bigger brother BERT, two special tokens will be added to the actual input tokens: **CLS** and **SEP** which mark the start and end of a sentence. To grab the whole representation of a sentence (as the related BERT and DistilBERT papers point out) we use the final representations of the CLS token and we hope that this representation captures the overall meaning of the sentence (caption). Thinking it in this way, it is similar to what we did to images and converted them into a fixed size vector.

In the case of DistilBERT (and also BERT) the output hidden representation for each token is a vector with size **768**. So, the whole caption will be encoded in the CLS token representation whose size is 768.

In [None]:
# class TextEncoder(nn.Module):
#     def __init__(self, model_name=cfg.text_encoder_model, pretrained=cfg.pretrained, trainable=cfg.trainable):
#         super().__init__()
#         if pretrained:
#             self.model = DistilBertModel.from_pretrained(model_name)
#         else:
#             self.model = DistilBertModel(config=DistilBertConfig())
            
#         for p in self.model.parameters():
#             p.requires_grad = trainable

#         # we are using the CLS token hidden representation as the sentence's embedding
#         self.target_token_idx = 0

#     def forward(self, input_ids, attention_mask):
#         output = self.model(input_ids=input_ids, attention_mask=attention_mask)
#         last_hidden_state = output.last_hidden_state
#         return last_hidden_state[:, self.target_token_idx, :]

## Projection Head

Now that we have encoded both our images and texts into fixed size vectors (2048 for image and 768 for text) we need to bring (project) them into a _new world_ with **similar dimensions** for both images and texts in order to be able to compare them and push apart the non-relevant image and texts and pull together those that match. So, the following code will bring the 2048 and 768 dimensional vectors into a 256 (projection_dim) dimensional world, where we can **compare** them.

"embedding_dim" is the size of the input vector (2048 for images and 768 for texts) and "projection_dim" is the the size of the output vector which will be 256 for our case. For understanding the details of this part you can refer to the CLIP paper.

In [None]:
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim=cfg.projection_dim,
        dropout=cfg.dropout
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)
    
    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

## CLIP

![clip.png](attachment:fb0403a4-73b2-4b97-bfec-c824d11677ee.png)

Here we will use the previous modules that we built to implement the main model. The \_\_init\_\_ function is self-explanatory. In the forward function, we first encode the images and texts separately into fixed size vectors (with different dimensionalities). After that, using separate projection modules we project them to that shared world (space) that I talked about previously. Here the encodings will become of similar shape (256 in our case). After that we will compute the loss. Again I recommend reading CLIP paper to get it better but I'll try my best to explain this part.

In Linear Algebra, one common way to measure if two vectors are of similar characteristics (they are like each other) is to calculate their **dot product** (multiplying the matching entries and take the sum of them); if the final number is big, they are alike and if it is small they are not (relatively speaking)!

Let's now understand the loss function. We talked about two vectors, but, what do we have here? We have image_embeddings, a matrix with shape (batch_size, 256) and text_embeddings with shape (batch_size, 256). It means we have two groups of vectors instead of two single vectors. How do we measure how similar two groups of vectors (two matrices) are to each other? Again, with dot product (@ operator in PyTorch does the dot product or matrix multiplication in this case). To be able to multiply these two matrices together, we transpose the second one. Okay, we get a matrix with shape (batch_size, batch_size) which we will call logits. (temperature is equal to 1.0 in our case, so, it does not make a difference. You can play with it and see what difference it makes. Also look at the paper to see why it is here!).

In [None]:
# class CLIPModel(nn.Module):
#     def __init__(
#         self,
#         temperature=cfg.temperature,
#         image_embedding=cfg.image_embedding,
#         text_embedding=cfg.text_embedding,
#     ):
#         super().__init__()
#         self.image_encoder = ImageEncoder()
#         self.text_encoder = TextEncoder()
#         self.image_projection = ProjectionHead(embedding_dim=image_embedding)
#         self.text_projection = ProjectionHead(embedding_dim=text_embedding)
#         self.temperature = temperature

#     def forward(self, batch):
#         # Getting Image and Text Features
#         image_features = self.image_encoder(batch["image"])
#         text_features = self.text_encoder(
#             input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
#         )
#         # Getting Image and Text Embeddings (with same dimension)
#         image_embeddings = self.image_projection(image_features)
#         text_embeddings = self.text_projection(text_features)

#         # Calculating the Loss
#         logits = (text_embeddings @ image_embeddings.T) / self.temperature
#         images_similarity = image_embeddings @ image_embeddings.T
#         texts_similarity = text_embeddings @ text_embeddings.T
#         targets = F.softmax(
#             (images_similarity + texts_similarity) / 2 * self.temperature, dim=-1
#         )
#         texts_loss = cross_entropy(logits, targets, reduction='none')
#         images_loss = cross_entropy(logits.T, targets.T, reduction='none')
#         loss =  (images_loss + texts_loss) / 2.0 # shape: (batch_size)
#         return loss.mean()


# def cross_entropy(preds, targets, reduction='none'):
#     log_softmax = nn.LogSoftmax(dim=-1)
#     loss = (-targets * log_softmax(preds)).sum(1)
#     if reduction == "none":
#         return loss
#     elif reduction == "mean":
#         return loss.mean()

from transformers import CLIPModel

class CLIPWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

    def forward(self, inputs):
        return self.clip(**inputs)



So, in the best case scenario, text_embeddings and image_embedding matricies should be the same because they are describing similar things. Let's think now: if this happens, what would the logits matrix be like? Let's see with a simple example!

In [None]:
# A simple Example

batch_size = 4
dim = 256
embeddings = torch.randn(batch_size, dim)
out = embeddings @ embeddings.T
print(F.softmax(out, dim=-1))

So logits, in the best case, will be a matrix that if we take its softmax, will have 1.0s in the diagonal (An identity matrix to call it with fancy words!). As the loss function's job is to make model's predictions similar to targets (at least in most cases!), we want such a matrix as our target. That's the reason why we are calculating images_similarity and texts_similarity matrices in the code block above.

Now that we've got our targets matrix, we will use simple cross entropy to calculate the actual loss. I've written the full matrix form of cross entropy as a function which you can see in the bottom of the code block.

There's a simpler way to calculate this loss in PyTorch; by doing this: nn.CrossEntropyLoss()(logits, torch.arange(batch_size)). The reason of not using that here is that the dataset we are using has multiple captions for a single image; so, there is the possibility that two identical images with their similar captions exist in a batch (it is rare but it can happen). Taking the loss with this easier method will ignore this possibility and the model learns to pull apart two representations (assume them different)  that are actually the same. Obviously, we don't want this to happen so I calculated the whole target matrix in a way that takes care of these edge cases.

## Train

Here are some funtions to help us load train and valid dataloaders, our model and then train and evaluate our model on those. There's not much going on here; just simple training loop and utility functions

In [None]:
def make_train_valid_dfs():
    dataframe = pd.read_csv(r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions_cleaned.csv")
    max_id = dataframe["id"].max() + 1 if not cfg.debug else 100
    image_ids = np.arange(0, max_id)
    np.random.seed(42)
    valid_ids = np.random.choice(
        image_ids, size=int(0.2 * len(image_ids)), replace=False
    )
    train_ids = [id_ for id_ in image_ids if id_ not in valid_ids]
    train_dataframe = dataframe[dataframe["id"].isin(train_ids)].reset_index(drop=True)
    valid_dataframe = dataframe[dataframe["id"].isin(valid_ids)].reset_index(drop=True)
    return train_dataframe, valid_dataframe


def build_loaders(dataframe, tokenizer, mode):
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataframe["image"].values,
        dataframe["caption"].values,
        tokenizer=tokenizer,
        transforms=transforms,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        shuffle=True if mode == "train" else False,
    )
    return dataloader

Here's a handy function to train our model. There's not much happening here; just loading the batches, feeding them to the model and stepping the optimizer and lr_scheduler.

In [None]:
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch in tqdm_object:
        batch = {k: v.to(cfg.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter


def valid_epoch(model, valid_loader):
    loss_meter = AvgMeter()

    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch in tqdm_object:
        batch = {k: v.to(cfg.device) for k, v in batch.items() if k != "caption"}
        loss = model(batch)

        count = batch["image"].size(0)
        loss_meter.update(loss.item(), count)

        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter

Running the next cell start training the model. Put the kernel on GPU mode. Every epoch should take about 24 minutes on GPU (even one epoch is enough!).

In [None]:

train_df, valid_df = make_train_valid_dfs()
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
train_loader = build_loaders(train_df, tokenizer, mode="train")
valid_loader = build_loaders(valid_df, tokenizer, mode="valid")

model = CLIPModel().to(CFG.device)

params = [
    {"params": model.image_encoder.parameters(), "lr": cfg.image_encoder_lr},
    {"params": model.text_encoder.parameters(), "lr": cfg.text_encoder_lr},
    {"params": itertools.chain(
        model.image_projection.parameters(), model.text_projection.parameters()
    ), "lr": cfg.head_lr, "weight_decay": cfg.weight_decay}
]
optimizer = torch.optim.AdamW(params, weight_decay=0.)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
)
step = "epoch"

best_loss = float('inf')
for epoch in range(cfg.epochs):
    print(f"Epoch: {epoch + 1}")
    model.train()
    train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
    model.eval()
    with torch.no_grad():
        valid_loss = valid_epoch(model, valid_loader)

    if valid_loss.avg < best_loss:
        best_loss = valid_loss.avg
        torch.save(model.state_dict(), "best1.pt")
        print("Saved Best Model!")

    lr_scheduler.step(valid_loss.avg)



    

    # "F:\cv class project\ICFG-PDES\ICFG-PEDES\imgs\test\0000\0000_004_01_0303morning_0017_1.jpg"

## Inference

Okay! We are done with training the model. Now, we need to do inference which in our case will be giving the model a piece of text and want it to retrieve the most relevant images from an unseen validation (or test) set.

### Getting Image Embeddings

In this function, we are loading the model that we saved after training, feeding it images in validation set and returning the image_embeddings with shape (valid_set_size, 256) and the model itself.

In [None]:
def get_image_embeddings(valid_df, model_path):
    tokenizer = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)
    valid_loader = build_loaders(valid_df, tokenizer, mode="valid")
    
    model = CLIPModel().to(cfg.device)
    model.load_state_dict(torch.load(model_path, map_location=cfg.device))
    model.eval()
    
    valid_image_embeddings = []
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            image_features = model.image_encoder(batch["image"].to(cfg.device))
            image_embeddings = model.image_projection(image_features)
            valid_image_embeddings.append(image_embeddings)
    return model, torch.cat(valid_image_embeddings)

In [None]:
_, valid_df = make_train_valid_dfs()
model, image_embeddings = get_image_embeddings(valid_df, "best.pt")

### Finding Matches

This function does the final task that we wished our model would be capable of: it gets the model, image_embeddings, and a text query. It will display the most relevant images from the validation set! Isn't it amazing? Let's see how it performs.

In [None]:
def find_matches(model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(cfg.device)
        for key, values in encoded_query.items()
    }
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)
    
    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T
    
    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]
    
    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    for match, ax in zip(matches, axes.flatten()):
        image = cv2.imread(f"{cfg.image_path}/{match}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")
    
    plt.show()

In [None]:
import os
import matplotlib.pyplot as plt
import cv2
import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer

def find_matches(model, image_embeddings, query, image_filenames, n=9):
    tokenizer = DistilBertTokenizer.from_pretrained(cfg.text_tokenizer)
    encoded_query = tokenizer([query])
    batch = {
        key: torch.tensor(values).to(cfg.device)
        for key, values in encoded_query.items()
    }

    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )
        text_embeddings = model.text_projection(text_features)

    image_embeddings_n = F.normalize(image_embeddings, p=2, dim=-1)
    text_embeddings_n = F.normalize(text_embeddings, p=2, dim=-1)
    dot_similarity = text_embeddings_n @ image_embeddings_n.T

    values, indices = torch.topk(dot_similarity.squeeze(0), n * 5)
    matches = [image_filenames[idx] for idx in indices[::5]]

    _, axes = plt.subplots(3, 3, figsize=(10, 10))
    shown = 0

    for match, ax in zip(matches, axes.flatten()):
        image_path = os.path.join(cfg.image_path, match)
        if not os.path.exists(image_path):
            print(f"[Warning] File does not exist: {image_path}")
            ax.axis("off")
            continue

        image = cv2.imread(image_path)
        if image is None:
            print(f"[Error] Could not read image at: {image_path}")
            ax.axis("off")
            continue

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        ax.imshow(image)
        ax.axis("off")
        shown += 1

    if shown == 0:
        print("[Info] No images were successfully loaded.")

    plt.tight_layout()
    plt.show()


This is how we use this function. The results:

In [None]:
find_matches(model, 
             image_embeddings,
             query="A middle-aged woman with black hair tied on the back is wearing a black hooded insulated jacket with a grey round patch on front over a grey shirt. She is also wearing fitted blue denim pants.",
             
             image_filenames=valid_df['image'].values,
             n=9)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

# Load trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 1: Initialize the CLIP Model
model = CLIPModel().to(device)  # Make sure this matches your training architecture

# Step 2: Load the trained weights into the model
model.load_state_dict(torch.load("best.pt", map_location=device))  # Load state_dict

# Step 3: Set the model to evaluation mode
model.eval()  # Load the best saved model
tokenizer = DistilBertTokenizer.from_pretrained(CFG.text_tokenizer)
# Function to extract features
def extract_features(dataloader, model, feature_type="image"):
    """Extract image or text features using the trained model."""
    features = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc=f"Extracting {feature_type} features"):
            batch = {k: v.to(device) for k, v in batch.items() if k != "caption"}

            if feature_type == "image":
                feature = model.image_projection(model.image_encoder(batch["image"]))
            else:  # Text Features
                feature = model.text_projection(model.text_encoder(
                    input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
                ))

            feature = F.normalize(feature, dim=1)  # Normalize for cosine similarity
            features.append(feature.cpu().numpy())

    return np.vstack(features)

# Load validation dataset
_, valid_dataframe = make_train_valid_dfs()  # Get validation split

# Initialize tokenizer (assuming you already have a tokenizer object)
valid_loader = build_loaders(valid_dataframe, tokenizer, mode="valid")

# Extract features using validation data
image_features = extract_features(valid_loader, model, "image")
text_features = extract_features(valid_loader, model, "text")


# Compute cosine similarity between image and text features
similarity_matrix = np.matmul(text_features, image_features.T)

# Function to compute retrieval metrics
def compute_retrieval_metrics(similarity_matrix, top_k=[1, 5, 10]):
    """Compute Recall@K and Mean Average Precision (mAP)."""
    num_queries = similarity_matrix.shape[0]
    recall_at_k = {k: 0 for k in top_k}
    average_precision = []
    image_to_indices = valid_df.groupby('image').indices  # Map image to caption indices

    for i in range(num_queries):
        # sorted_indices = np.argsort(similarity_matrix[i])[::-1]  # Descending order
        # gt_index = i  # Assuming ground-truth image-text pair is at index i
        sorted_indices = np.argsort(similarity_matrix[i])[::-1]
        gt_image = valid_df.iloc[i]['image']
        gt_indices = image_to_indices[gt_image]  # All caption indices for this image

        # Compute Recall@K
        # for k in top_k:
        #     if gt_index in sorted_indices[:k]:
        #         recall_at_k[k] += 1

        # # Compute AP
        # rank = np.where(sorted_indices == gt_index)[0][0] + 1
        # average_precision.append(1.0 / rank)

        for k in top_k:
            if any(idx in sorted_indices[:k] for idx in gt_indices):
                recall_at_k[k] += 1

        ranks = [np.where(sorted_indices == idx)[0][0] + 1 for idx in gt_indices]
        average_precision.append(np.mean([1.0 / rank for rank in ranks]))



    # Normalize metrics
    recall_at_k = {k: recall_at_k[k] / num_queries for k in recall_at_k}
    mean_ap = np.mean(average_precision)

    return recall_at_k, mean_ap

# Compute metrics
recall_k, mean_ap = compute_retrieval_metrics(similarity_matrix)

# Print results
print("\nRetrieval Performance:")
for k in recall_k:
    print(f"Recall@{k}: {recall_k[k]:.4f}")
print(f"Mean Average Precision (mAP): {mean_ap:.4f}")

In [None]:
"F:\cv class project\ICFG-PDES\ICFG-PEDES\imgs\test\0000\0000_000_01_0303morning_0015_0.jpg"
 F:\cv class project\ICFG-PDES\ICFG-PEDES\imgs\test\train\0715\0715_009_07_0302afternoon_1291_0.jpg

In [11]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch.nn.functional as F
from tqdm import tqdm

# =========================
# Load CSV
# =========================
csv_file = r'F:\cv class project\ICFG-PDES\ICFG-PEDES\invalid_paths.csv'
df = pd.read_csv(csv_file, header=0)  # Ensure your CSV has header: image_path, description, id

# =========================
# Transform for CLIP
# =========================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# =========================
# Load CLIP
# =========================
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).cuda().eval()

# =========================
# Dataset Definition
# =========================
class ImageTextDataset(Dataset):
    def __init__(self, dataframe, processor, transform=None):
        self.df = dataframe
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row['image_path']
        caption = row['description']
        label_id = int(row['id'])

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Missing image: {image_path}")

        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        # Prompt-enhanced caption
        prompt = f"A photo of a person. {caption}"
        text_inputs = self.processor(
            text=prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77
        )

        return image, text_inputs, label_id

# =========================
# Collate Function for Dataloader
# =========================
def collate_fn(batch):
    images, texts, labels = zip(*batch)
    images = torch.stack(images)

    # Convert text dict list into batched dict
    input_ids = torch.cat([t['input_ids'] for t in texts], dim=0)
    attention_mask = torch.cat([t['attention_mask'] for t in texts], dim=0)

    labels = torch.tensor(labels)
    return images, {'input_ids': input_ids, 'attention_mask': attention_mask}, labels

# =========================
# Dataloader
# =========================
dataset = ImageTextDataset(df, processor, transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# =========================
# Inference
# =========================
image_embeds = []
text_embeds = []
label_ids = []

with torch.no_grad():
    for images, text_inputs, labels in tqdm(dataloader, desc="Extracting embeddings"):
        images = images.cuda()
        input_ids = text_inputs['input_ids'].cuda()
        attention_mask = text_inputs['attention_mask'].cuda()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=images)

        # Normalize
        img_embed = F.normalize(outputs.image_embeds, p=2, dim=1)
        txt_embed = F.normalize(outputs.text_embeds, p=2, dim=1)

        image_embeds.append(img_embed.cpu())
        text_embeds.append(txt_embed.cpu())
        label_ids.append(labels)

# =========================
# Save Embeddings
# =========================
image_embeds = torch.cat(image_embeds, dim=0)
text_embeds = torch.cat(text_embeds, dim=0)
label_ids = torch.cat(label_ids, dim=0)

torch.save({
    "image_embeds": image_embeds,
    "text_embeds": text_embeds,
    "labels": label_ids
}, "clip_embeddings.pt")

print("✅ Embeddings saved to clip_embeddings.pt")


Extracting embeddings: 100%|██████████| 32/32 [00:03<00:00,  8.76it/s]

✅ Embeddings saved to clip_embeddings.pt





Extracting embeddings:   0%|          | 0/1725 [00:00<?, ?it/s]


KeyError: 'image'

In [12]:
def recall_at_k(similarity_matrix, labels, k):
    correct = 0
    sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]
    for i in range(len(labels)):
        if labels[i] in labels[sorted_indices[i, :k]]:
            correct += 1
    return correct / len(labels)

def compute_mAP(similarity_matrix, labels):
    sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]
    num_classes = len(np.unique(labels))
    aps = []

    for cls in np.unique(labels):
        y_true = (labels == cls).astype(int)
        y_score = np.array([
            1 if cls in labels[sorted_indices[i, :10]] else 0
            for i in range(len(labels))
        ])
        aps.append(average_precision_score(y_true, y_score))

    return np.mean(aps)


In [13]:
# Load saved embeddings
image_embeds = np.load("image_embeddings.npy")
text_embeds = np.load("text_embeddings.npy")
labels = np.load("labels.npy")

# Cosine similarity
similarity_matrix = np.dot(image_embeds, text_embeds.T)

# Recall and mAP
print(f"Recall@1: {recall_at_k(similarity_matrix, labels, 1):.4f}")
print(f"Recall@5: {recall_at_k(similarity_matrix, labels, 5):.4f}")
print(f"Recall@10: {recall_at_k(similarity_matrix, labels, 10):.4f}")
print(f"mAP: {compute_mAP(similarity_matrix, labels):.4f}")


Recall@1: 0.0022
Recall@5: 0.0079
Recall@10: 0.0137


KeyboardInterrupt: 

In [14]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from sklearn.metrics import average_precision_score

# =========================
# Load CSV (Update your path!)
# =========================
csv_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\invalid_paths.csv"  # Make sure it has columns: image_path, description, id
df = pd.read_csv(csv_path)

# =========================
# CLIP Model Setup
# =========================
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)
model = CLIPModel.from_pretrained(model_name).cuda().eval()

# =========================
# Image Preprocessing
# =========================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# =========================
# Dataset Class
# =========================
class CLIPDataset(Dataset):
    def __init__(self, dataframe, processor, transform=None):
        self.df = dataframe
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row['image_path']
        caption = row['description']
        label_id = int(row['id'])

        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        prompt = f"A photo of a person. {caption}"
        text_input = self.processor(
            text=prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77
        )

        return image, text_input, label_id

# =========================
# Collate Function
# =========================
def collate_fn(batch):
    images, texts, labels = zip(*batch)
    images = torch.stack(images)
    input_ids = torch.cat([t['input_ids'] for t in texts], dim=0)
    attention_mask = torch.cat([t['attention_mask'] for t in texts], dim=0)
    labels = torch.tensor(labels)
    return images, {'input_ids': input_ids, 'attention_mask': attention_mask}, labels

# =========================
# Dataloader
# =========================
dataset = CLIPDataset(df, processor, transform)
loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

# =========================
# Embedding Extraction
# =========================
image_embeds = []
text_embeds = []
label_ids = []

with torch.no_grad():
    for images, text_inputs, labels in tqdm(loader, desc="Extracting Embeddings"):
        images = images.cuda()
        input_ids = text_inputs['input_ids'].cuda()
        attention_mask = text_inputs['attention_mask'].cuda()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=images)
        img_embed = F.normalize(outputs.image_embeds, p=2, dim=1)
        txt_embed = F.normalize(outputs.text_embeds, p=2, dim=1)

        image_embeds.append(img_embed.cpu())
        text_embeds.append(txt_embed.cpu())
        label_ids.append(labels)

# =========================
# Save Embeddings
# =========================
image_embeds = torch.cat(image_embeds).numpy()
text_embeds = torch.cat(text_embeds).numpy()
labels = torch.cat(label_ids).numpy()

np.save("image_embeddings.npy", image_embeds)
np.save("text_embeddings.npy", text_embeds)
np.save("labels.npy", labels)

print("✅ Embeddings saved")

# =========================
# Evaluation Functions
# =========================
def recall_at_k(similarity_matrix, labels, k):
    correct = 0
    sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]
    for i in range(len(labels)):
        if labels[i] in labels[sorted_indices[i, :k]]:
            correct += 1
    return correct / len(labels)

def compute_mAP(similarity_matrix, labels):
    aps = []
    sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]

    for i in range(len(labels)):
        y_true = (labels == labels[i]).astype(int)
        y_score = similarity_matrix[i]
        aps.append(average_precision_score(y_true, y_score))

    return np.mean(aps)

# =========================
# Evaluation
# =========================
image_embeds = np.load("image_embeddings.npy")
text_embeds = np.load("text_embeddings.npy")
labels = np.load("labels.npy")

# Ensure unique image-text pair handling
unique_images = {}
for idx, label in enumerate(labels):
    if label not in unique_images:
        unique_images[label] = image_embeds[idx]
image_embeds = np.array([unique_images[label] for label in labels])  # One image per label

similarity = np.dot(text_embeds, image_embeds.T)

print(f"Recall@1 : {recall_at_k(similarity, labels, 1):.4f}")
print(f"Recall@5 : {recall_at_k(similarity, labels, 5):.4f}")
print(f"Recall@10: {recall_at_k(similarity, labels, 10):.4f}")
print(f"mAP      : {compute_mAP(similarity, labels):.4f}")


Extracting Embeddings: 100%|██████████| 16/16 [00:03<00:00,  4.94it/s]


✅ Embeddings saved
Recall@1 : 0.0381
Recall@5 : 0.0942
Recall@10: 0.1583
mAP      : 0.0783


In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from sklearn.metrics import average_precision_score
from tqdm import tqdm
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configuration
class CFG:
    model_name = "openai/clip-vit-base-patch32"
    batch_size = 4
    num_workers = 4
    max_length = 77
    image_size = 224
    epochs = 1
    lr = 5e-6
    weight_decay = 0.01
    device = "cuda" if torch.cuda.is_available() else "cpu"
    data_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\invalid_paths.csv"
    checkpoint_path = "clip_finetuned_best.pt"

cfg = CFG()

# Custom Dataset
class ImageTextDataset(Dataset):
    def __init__(self, dataframe, processor, transform=None):
        self.dataframe = dataframe
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        image_path = self.dataframe.iloc[idx]['image']  # Use full path from CSV
        description = self.dataframe.iloc[idx]['caption']
        label_id = self.dataframe.iloc[idx]['id']

        try:
            if not os.path.exists(image_path):
                logging.warning(f"Image not found: {image_path}")
                return None, None, None

            image = Image.open(image_path).convert("RGB")
            if self.transform:
                image = self.transform(image)

            text_inputs = self.processor(
                text=description,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=cfg.max_length
            )

            return image, text_inputs, label_id

        except Exception as e:
            logging.error(f"Error loading sample {idx}: {e}")
            return None, None, None

    def collate_fn(self, batch):
        images, text_inputs, label_ids = [], [], []
        for image, text, label in batch:
            if image is not None:
                images.append(image)
                text_inputs.append(text)
                label_ids.append(label)

        if not images:
            return None, None, None

        images = torch.stack(images)
        input_ids = torch.cat([t['input_ids'] for t in text_inputs], dim=0)
        attention_mask = torch.cat([t['attention_mask'] for t in text_inputs], dim=0)
        label_ids = torch.tensor(label_ids)

        return images, {'input_ids': input_ids, 'attention_mask': attention_mask}, label_ids

# Data Augmentation
train_transform = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

valid_transform = transforms.Compose([
    transforms.Resize((cfg.image_size, cfg.image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Load and Split Data
df = pd.read_csv(cfg.data_path)
train_size = int(0.8 * len(df))
train_df = df[:train_size]
valid_df = df[train_size:]

# Initialize Processor and Model
processor = CLIPProcessor.from_pretrained(cfg.model_name)
model = CLIPModel.from_pretrained(cfg.model_name).to(cfg.device)

# Datasets and Dataloaders
train_dataset = ImageTextDataset(train_df, processor, train_transform)
valid_dataset = ImageTextDataset(valid_df, processor, valid_transform)

train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=cfg.num_workers,
    collate_fn=train_dataset.collate_fn
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
    collate_fn=valid_dataset.collate_fn
)

# Contrastive Loss
def contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    logits = torch.matmul(image_embeds, text_embeds.T) / temperature
    labels = torch.arange(logits.size(0)).to(cfg.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

# Training Loop
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)
best_map = 0

for epoch in range(cfg.epochs):
    model.train()
    train_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs} - Training"):
        if batch[0] is None:
            continue

        images, text_inputs, _ = batch
        images = images.to(cfg.device)
        input_ids = text_inputs['input_ids'].to(cfg.device)
        attention_mask = text_inputs['attention_mask'].to(cfg.device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=images)
        loss = contrastive_loss(outputs.image_embeds, outputs.text_embeds)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)
    logging.info(f"Epoch {epoch+1}/{cfg.epochs} - Train Loss: {train_loss:.4f}")

    # Validation
    model.eval()
    image_embeds, text_embeds, label_ids = [], [], []
    with torch.no_grad():
        for batch in tqdm(valid_loader, desc=f"Epoch {epoch+1}/{cfg.epochs} - Validation"):
            if batch[0] is None:
                continue

            images, text_inputs, ids = batch
            images = images.to(cfg.device)
            input_ids = text_inputs['input_ids'].to(cfg.device)
            attention_mask = text_inputs['attention_mask'].to(cfg.device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=images)
            image_embeds.append(outputs.image_embeds.cpu())
            text_embeds.append(outputs.text_embeds.cpu())
            label_ids.extend(ids.numpy())

    image_embeds = torch.cat(image_embeds).numpy()
    text_embeds = torch.cat(text_embeds).numpy()
    label_ids = np.array(label_ids)

    # Compute Metrics
    similarity_matrix = np.dot(image_embeds, text_embeds.T)

    def recall_at_k(similarity_matrix, labels, k):
        correct = 0
        sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]
        for i, label in enumerate(labels):
            top_k_labels = labels[sorted_indices[i, :k]]
            if label in top_k_labels:
                correct += 1
        return correct / len(labels)

    def compute_mAP(similarity_matrix, labels):
        aps = []
        for i, label in enumerate(labels):
            y_true = (labels == label).astype(int)
            y_score = similarity_matrix[i]
            aps.append(average_precision_score(y_true, y_score))
        return np.mean(aps)

    recall_1 = recall_at_k(similarity_matrix, label_ids, 1)
    recall_5 = recall_at_k(similarity_matrix, label_ids, 5)
    recall_10 = recall_at_k(similarity_matrix, label_ids, 10)
    mAP = compute_mAP(similarity_matrix, label_ids)

    logging.info(f"Epoch {epoch+1}/{cfg.epochs} - Recall@1: {recall_1:.4f}, Recall@5: {recall_5:.4f}, Recall@10: {recall_10:.4f}, mAP: {mAP:.4f}")

    # Save Best Model
    if mAP > best_map:
        best_map = mAP
        torch.save(model.state_dict(), cfg.checkpoint_path)
        logging.info(f"Saved best model with mAP: {best_map:.4f}")

    scheduler.step()

# Final Evaluation
model.load_state_dict(torch.load(cfg.checkpoint_path))
model.eval()
image_embeds, text_embeds, label_ids = [], [], []
with torch.no_grad():
    for batch in tqdm(valid_loader, desc="Final Evaluation"):
        if batch[0] is None:
            continue

        images, text_inputs, ids = batch
        images = images.to(cfg.device)
        input_ids = text_inputs['input_ids'].to(cfg.device)
        attention_mask = text_inputs['attention_mask'].to(cfg.device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=images)
        image_embeds.append(outputs.image_embeds.cpu())
        text_embeds.append(outputs.text_embeds.cpu())
        label_ids.extend(ids.numpy())

image_embeds = torch.cat(image_embeds).numpy()
text_embeds = torch.cat(text_embeds).numpy()
label_ids = np.array(label_ids)

# Save Embeddings
np.save("image_embeddings_finetuned.npy", image_embeds)
np.save("text_embeddings_finetuned.npy", text_embeds)
np.save("labels_finetuned.npy", label_ids)

# Compute Final Metrics
similarity_matrix = np.dot(image_embeds, text_embeds.T)
recall_1 = recall_at_k(similarity_matrix, label_ids, 1)
recall_5 = recall_at_k(similarity_matrix, label_ids, 5)
recall_10 = recall_at_k(similarity_matrix, label_ids, 10)
mAP = compute_mAP(similarity_matrix, label_ids)

print(f"Final Metrics:")
print(f"Recall@1: {recall_1:.4f}")
print(f"Recall@5: {recall_5:.4f}")
print(f"Recall@10: {recall_10:.4f}")
print(f"mAP: {mAP:.4f}")

  from .autonotebook import tqdm as notebook_tqdm
Epoch 1/1 - Training:   0%|          | 0/100 [00:00<?, ?it/s]

In [18]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from sklearn.metrics import average_precision_score
from torch import nn
from torch.optim import AdamW

# =========================
# Load CSV (Update your path!)
# =========================
csv_path = r"F:\cv class project\ICFG-PDES\ICFG-PEDES\captions_cleaned.csv"
df = pd.read_csv(csv_path)

# =========================
# CLIP Model Setup
# =========================
model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_name)

# =========================
# Image Preprocessing
# =========================
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# =========================
# Dataset Class
# =========================
class CLIPDataset(Dataset):
    def __init__(self, dataframe, processor, transform=None):
        self.df = dataframe
        self.processor = processor
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = row['image_path']
        caption = row['description']
        label_id = int(row['id'])

        image = Image.open(image_path).convert("RGB")
        image = self.transform(image)

        prompt = f"A photo of a person. {caption}"
        text_input = self.processor(
            text=prompt,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=77
        )

        return image, text_input, label_id

# =========================
# Collate Function
# =========================
def collate_fn(batch):
    images, texts, labels = zip(*batch)
    images = torch.stack(images)
    input_ids = torch.cat([t['input_ids'] for t in texts], dim=0)
    attention_mask = torch.cat([t['attention_mask'] for t in texts], dim=0)
    labels = torch.tensor(labels)
    return images, {'input_ids': input_ids, 'attention_mask': attention_mask}, labels

# =========================
# DataLoader
# =========================
dataset = CLIPDataset(df, processor, transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# =========================
# Model with Projection Head
# =========================
class CLIPRetrievalModel(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32", embed_dim=512):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        self.proj_image = nn.Linear(self.clip.config.projection_dim, embed_dim)
        self.proj_text = nn.Linear(self.clip.config.projection_dim, embed_dim)

    def forward(self, input_ids, attention_mask, pixel_values):
        outputs = self.clip(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values)
        img_embed = F.normalize(self.proj_image(outputs.image_embeds), p=2, dim=1)
        txt_embed = F.normalize(self.proj_text(outputs.text_embeds), p=2, dim=1)
        return img_embed, txt_embed

# =========================
# Training Loop
# =========================
model = CLIPRetrievalModel().cuda()
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(1):
    model.train()
    total_loss = 0
    for images, text_inputs, labels in tqdm(loader, desc=f"Training Epoch {epoch+1}"):
        input_ids = text_inputs['input_ids'].cuda()
        attention_mask = text_inputs['attention_mask'].cuda()
        images = images.cuda()

        img_embed, txt_embed = model(input_ids, attention_mask, images)

        logits_per_image = img_embed @ txt_embed.T
        logits_per_text = txt_embed @ img_embed.T
        ground_truth = torch.arange(len(images)).cuda()

        loss_img = loss_fn(logits_per_image, ground_truth)
        loss_txt = loss_fn(logits_per_text, ground_truth)
        loss = (loss_img + loss_txt) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

# =========================
# Save Fine-Tuned Model
# =========================
torch.save(model.state_dict(), "clip_finetuned.pt")

# =========================
# Evaluation Functions
# =========================
def recall_at_k(similarity_matrix, labels, k):
    correct = 0
    sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]
    for i in range(len(labels)):
        if labels[i] in labels[sorted_indices[i, :k]]:
            correct += 1
    return correct / len(labels)

def compute_mAP(similarity_matrix, labels):
    aps = []
    sorted_indices = np.argsort(similarity_matrix, axis=1)[:, ::-1]
    for i in range(len(labels)):
        y_true = (labels == labels[i]).astype(int)
        y_score = similarity_matrix[i]
        aps.append(average_precision_score(y_true, y_score))
    return np.mean(aps)

# =========================
# Inference and Evaluation
# =========================
model.eval()
image_embeds = []
text_embeds = []
label_ids = []

with torch.no_grad():
    for images, text_inputs, labels in tqdm(loader, desc="Evaluating"):
        input_ids = text_inputs['input_ids'].cuda()
        attention_mask = text_inputs['attention_mask'].cuda()
        images = images.cuda()

        img_embed, txt_embed = model(input_ids, attention_mask, images)

        image_embeds.append(img_embed.cpu())
        text_embeds.append(txt_embed.cpu())
        label_ids.append(labels)

image_embeds = torch.cat(image_embeds).numpy()
text_embeds = torch.cat(text_embeds).numpy()
labels = torch.cat(label_ids).numpy()

# Handle multiple captions per image (optional simplification)
unique_images = {}
for idx, label in enumerate(labels):
    if label not in unique_images:
        unique_images[label] = image_embeds[idx]
image_embeds = np.array([unique_images[label] for label in labels])

similarity = np.dot(text_embeds, image_embeds.T)

print(f"Recall@1 : {recall_at_k(similarity, labels, 1):.4f}")
print(f"Recall@5 : {recall_at_k(similarity, labels, 5):.4f}")
print(f"Recall@10: {recall_at_k(similarity, labels, 10):.4f}")
print(f"mAP      : {compute_mAP(similarity, labels):.4f}")


Training Epoch 1: 100%|██████████| 863/863 [05:47<00:00,  2.48it/s]


Epoch 1: Loss = 2.9102


Evaluating: 100%|██████████| 863/863 [02:35<00:00,  5.54it/s]


Recall@1 : 0.0028
Recall@5 : 0.0115
Recall@10: 0.0209
mAP      : 0.0116


In [21]:
# =========================
# Training Loop
# =========================
best_recall1 = 0  # Variable to track the best Recall@1 score
model = CLIPRetrievalModel().cuda()
optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(1):
    model.train()
    total_loss = 0
    for images, text_inputs, labels in tqdm(loader, desc=f"Training Epoch {epoch+1}"):
        input_ids = text_inputs['input_ids'].cuda()
        attention_mask = text_inputs['attention_mask'].cuda()
        images = images.cuda()

        img_embed, txt_embed = model(input_ids, attention_mask, images)

        logits_per_image = img_embed @ txt_embed.T
        logits_per_text = txt_embed @ img_embed.T
        ground_truth = torch.arange(len(images)).cuda()

        loss_img = loss_fn(logits_per_image, ground_truth)
        loss_txt = loss_fn(logits_per_text, ground_truth)
        loss = (loss_img + loss_txt) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(loader):.4f}")

    # =========================
    # Inference and Evaluation
    # =========================
    model.eval()
    image_embeds = []
    text_embeds = []
    label_ids = []

    with torch.no_grad():
        for images, text_inputs, labels in tqdm(loader, desc="Evaluating"):
            input_ids = text_inputs['input_ids'].cuda()
            attention_mask = text_inputs['attention_mask'].cuda()
            images = images.cuda()

            img_embed, txt_embed = model(input_ids, attention_mask, images)

            image_embeds.append(img_embed.cpu())
            text_embeds.append(txt_embed.cpu())
            label_ids.append(labels)

    image_embeds = torch.cat(image_embeds).numpy()
    text_embeds = torch.cat(text_embeds).numpy()
    labels = torch.cat(label_ids).numpy()

    # Handle multiple captions per image (optional simplification)
    unique_images = {}
    for idx, label in enumerate(labels):
        if label not in unique_images:
            unique_images[label] = image_embeds[idx]
    image_embeds = np.array([unique_images[label] for label in labels])

    similarity = np.dot(text_embeds, image_embeds.T)

    recall1 = recall_at_k(similarity, labels, 1)
    print(f"Recall@1 : {recall1:.4f}")

    # Save the best model based on Recall@1
    if recall1 > best_recall1:
        best_recall1 = recall1
        torch.save(model.state_dict(), "model.pt")  # Save the best model

print(f"Best Recall@1: {best_recall1:.4f}")


Training Epoch 1: 100%|██████████| 863/863 [06:12<00:00,  2.31it/s]


Epoch 1: Loss = 2.9317


Evaluating: 100%|██████████| 863/863 [02:36<00:00,  5.52it/s]


Recall@1 : 0.0022
Best Recall@1: 0.0022
