# Fine-tune CLIP on Title-Thumbnail Pairs

Code authored by: Shaw Talebi

[Video link](https://youtu.be/W4s6b2ZM6kI) | [Blog link](https://medium.com/towards-data-science/fine-tuning-multimodal-embedding-models-bf007b1c5da5) <br>
[Dataset](https://huggingface.co/datasets/shawhin/yt-title-thumbnail-pairs) | [Fine-tuned Model](https://huggingface.co/shawhin/clip-title-thumbnail-embeddings)

### imports

In [1]:
from datasets import load_dataset

from PIL import Image
import requests

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.evaluation import TripletEvaluator, SentenceEvaluator

from typing import List, Dict
import torch

### import model and dataset

In [2]:
model_name = "sentence-transformers/clip-ViT-L-14"
model = SentenceTransformer(model_name)

In [3]:
dataset = load_dataset("shawhin/yt-title-thumbnail-pairs")

### freeze model params

In [4]:
# pick specific layers to train (note: you can add more layers to this list)
trainable_layers_list = ['projection']

# Apply freezing configuration
for name, param in model.named_parameters():
    # freeze all params
    param.requires_grad = False

    # unfreeze layers in trainable_layers_list
    if any(layer in name for layer in trainable_layers_list):
        param.requires_grad = True

In [5]:
# Verify trainable parameters
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"Trainable: {name}")

Trainable: 0.model.visual_projection.weight
Trainable: 0.model.text_projection.weight


In [6]:
# Count total and trainable parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Percentage of trainable parameters: {100 * trainable_params / total_params:.2f}%")

Total parameters: 427,616,513
Trainable parameters: 1,376,256
Percentage of trainable parameters: 0.32%


### preprocess data

In [7]:
# process positive pairs
def preprocess(batch):
    """
        Preprocessing data without augmentations for test set
    """
    # get images from urls
    image_list = [Image.open(requests.get(url, stream=True).raw) for url in batch["thumbnail_url"]]

    # return columns with standard names
    return {
        "anchor": image_list,       
        "positive": batch["title"],  
        "negative": batch["title_neg"]
    }

In [8]:
# remove columns not relevant to training
columns_to_remove = [col for col in dataset['train'].column_names if col not in ['anchor', 'positive', 'negative']]
# applu transformations
dataset = dataset.map(preprocess, batched=True, remove_columns=columns_to_remove)

In [9]:
dataset

DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 53
    })
    valid: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 11
    })
    test: Dataset({
        features: ['anchor', 'positive', 'negative'],
        num_rows: 12
    })
})

### eval pre-trained model

In [10]:
def create_triplet_evaluator(set_name):
    """
        Create triplet evaluator for "train", "valid", or "test" split
    """

    return TripletEvaluator(
        anchors=dataset[f"{set_name}"]["anchor"],
        positives=dataset[f"{set_name}"]["positive"],
        negatives=dataset[f"{set_name}"]["negative"],
        name=f"yt-title-thumbnail-{set_name}",
    )

In [11]:
evaluator_train = create_triplet_evaluator("train")
evaluator_valid = create_triplet_evaluator("valid")

print("Train:", evaluator_train(model))
print("Valid:", evaluator_valid(model))

Train: {'yt-title-thumbnail-train_cosine_accuracy': np.float64(0.9622641509433962)}
Valid: {'yt-title-thumbnail-valid_cosine_accuracy': np.float64(1.0)}


In [12]:
class ImageTextRetrievalEvaluator(SentenceEvaluator):
    def __init__(
        self,
        images: List,
        texts: List[str],
        name: str = '',
        k: int = 1,
        batch_size: int = 32,
        show_progress_bar: bool = False
    ):
        self.images = images
        self.texts = texts
        self.name = name
        self.k = k
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar

    def __call__(self,
        model: SentenceTransformer,
        output_path: str = None,
        epoch: int = -1,
        steps: int = -1) -> Dict[str, float]:
        
        # Get embeddings for all images
        img_embeddings = model.encode(
            self.images,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True
        )
        
        # Get embeddings for all texts
        text_embeddings = model.encode(
            self.texts,
            batch_size=self.batch_size,
            show_progress_bar=self.show_progress_bar,
            convert_to_tensor=True
        )
        
        # Compute similarity matrix
        cos_scores = torch.nn.functional.cosine_similarity(
            img_embeddings.unsqueeze(1),
            text_embeddings.unsqueeze(0),
            dim=2
        )
        
        # Get indices of top k predictions for each image
        _, top_indices = torch.topk(cos_scores, k=self.k, dim=1)
        
        # Calculate Recall@k (correct if ground truth index is in top k predictions)
        correct = sum(i in top_indices[i].tolist() for i in range(len(self.images)))
        recall_at_k = correct / len(self.images)

        return {f'{self.name}_Recall@{self.k}': recall_at_k}

In [13]:
def create_recall_evaluator(set_name, k=1):
    """
        Create triplet evaluator for "train", "valid", or "test" split
    """

    return ImageTextRetrievalEvaluator(
        images=dataset[f"{set_name}"]["anchor"],
        texts=dataset[f"{set_name}"]["positive"],
        name=f"yt-title-thumbnail-{set_name}",
        k=k
    )

In [14]:
# Create new evaluator with Recall@k
evaluator_recall_train = create_recall_evaluator("train", k=1)
evaluator_recall_valid = create_recall_evaluator("valid", k=1)

print("Train:", evaluator_recall_train(model))
print("Valid:", evaluator_recall_valid(model))

Train: {'yt-title-thumbnail-train_Recall@1': 0.660377358490566}
Valid: {'yt-title-thumbnail-valid_Recall@1': 0.6363636363636364}


### define training args

In [15]:
# define loss (note: loss expects columns to be ordered as anchor-positive-negative)
loss = MultipleNegativesRankingLoss(model)

# hyperparameters
num_epochs = 2
batch_size = 16
lr = 1e-4
finetuned_model_name = "clip-title-thumbnail-embeddings"

train_args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{finetuned_model_name}",
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=lr,
    # Evaluation settings
    eval_strategy="epoch",
    eval_steps=1,
    logging_steps=1,
)

### fine-tune model

In [16]:
%%time
trainer = SentenceTransformerTrainer(
    model=model,
    args=train_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["valid"],
    loss=loss,
    evaluator=[evaluator_recall_train, evaluator_recall_valid],
)
trainer.train()

Epoch,Training Loss,Validation Loss,Yt-title-thumbnail-train Recall@1,Yt-title-thumbnail-valid Recall@1,Sequential Score
1,0.7505,1.491625,0.830189,0.909091,0.909091
2,0.3315,1.499041,0.849057,0.909091,0.909091


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

CPU times: user 1.75 s, sys: 822 ms, total: 2.57 s
Wall time: 6.66 s


TrainOutput(global_step=8, training_loss=1.3635553307831287, metrics={'train_runtime': 6.3182, 'train_samples_per_second': 16.777, 'train_steps_per_second': 1.266, 'total_flos': 0.0, 'train_loss': 1.3635553307831287, 'epoch': 2.0})

### evaluate fine-tuned model

In [17]:
evaluator_test = create_triplet_evaluator("test")

print("Train:", evaluator_train(model))
print("Valid:", evaluator_valid(model))
print("Test:", evaluator_valid(model))

Train: {'yt-title-thumbnail-train_cosine_accuracy': np.float64(1.0)}
Valid: {'yt-title-thumbnail-valid_cosine_accuracy': np.float64(1.0)}
Test: {'yt-title-thumbnail-valid_cosine_accuracy': np.float64(1.0)}


In [18]:
evaluator_recall_test = create_recall_evaluator("test")

print("Train:", evaluator_recall_train(model))
print("Valid:", evaluator_recall_valid(model))
print("Test:", evaluator_recall_test(model))

Train: {'yt-title-thumbnail-train_Recall@1': 0.8490566037735849}
Valid: {'yt-title-thumbnail-valid_Recall@1': 0.9090909090909091}
Test: {'yt-title-thumbnail-test_Recall@1': 0.75}


### push model to hub

In [20]:
model.push_to_hub(f"shawhin/{finetuned_model_name}")

model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

'https://huggingface.co/shawhin/clip-title-thumbnail-embeddings/commit/05dcc90819309f6823025915ff0a58d4e2bdd95d'