<a href="https://colab.research.google.com/github/Mechanics-Mechatronics-and-Robotics/CV-2025/blob/main/Week_13/Lab10_CLIP_fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Lab 10. CLIP fine-tuning.

In [None]:
!pip install git+https://github.com/openai/CLIP.git

**Let's recall: what is fine-tuning?**

In machine learning, fine-tuning is a process of taking a pre-trained model and “tuning” its parameters slightly to adapt to a new, similar task.

**Why do we need this?**

* Save resources
* Transfer learning effect
* Limited data

To complete this lab, we will fine-tune a pre-trained model: CLIP, installed above from the OpenAI GitHub repository.

Let's start with importing libraries...

In [None]:
from PIL import Image
import torch
from torch import nn, optim
import glob
import os
import pandas as pd
import json
import numpy as np
import clip
from torch.utils.data import Dataset, DataLoader, BatchSampler
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import random
from matplotlib.pyplot import imshow

import nltk, re, string, collections
from nltk.util import ngrams
import collections

%matplotlib inline
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# BATCH_SIZE is handed to the batch samplers as the number of images per batch;
# EPOCH is the number of passes made by the training loop.
BATCH_SIZE = 32
EPOCH = 5

In [None]:
!gdown 1WizPT7uYnpsEEelztKt3aLbRoFCzuijj

In [None]:
!unzip memes.zip

In [None]:
# Build `d`, a mapping {image path: list of usable captions}, from the unpacked archive.
IMG_ROOT = "mems_images"
JSON_ROOT = "meme_project_clean_json"
# A caption is kept only if its length in characters lies within these bounds.
MIN_CAPTION_LEN = 8
MAX_CAPTION_LEN = 72

img_paths = glob.glob(os.path.join(IMG_ROOT, "*.jpg"))
d = {}
for img_path in img_paths:
    # Take the file stem with os.path so that the platform's path separator is
    # honoured and only the final extension is stripped (dotted names stay intact).
    name = os.path.splitext(os.path.basename(img_path))[0]
    # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
    with open(os.path.join(JSON_ROOT, name + ".json"), "r", encoding="utf-8") as f:
        captions = json.load(f)
    # Every entry contributes its first two elements joined by a space
    # (presumably the two text lines of the meme - TODO confirm against the data).
    joined = (cap[0] + " " + cap[1] for cap in captions)
    # Drop captions that contain links and captions that are too short or too long.
    d[img_path] = [
        text for text in joined
        if "http" not in text and MIN_CAPTION_LEN <= len(text) <= MAX_CAPTION_LEN
    ]
len(d)

In [None]:
Image.open(list(d.items())[0][0])

In [None]:
list(d.items())[0]

Splitting our dataset...

In [None]:
# Hold out 20% of the images for testing. The split is made over image paths, so all
# captions of one image end up on the same side; random_state makes it reproducible.
train_img_paths, test_img_paths = train_test_split(img_paths, test_size=0.2, random_state=42)
d_train = {k: d[k] for k in train_img_paths}
d_test = {k: d[k] for k in test_img_paths}
len(d_train), len(d_test)

Downloading our model

In [None]:
# Load the pre-trained "ViT-B/32" CLIP model onto `device`, together with the callable
# that is applied to PIL images before they are fed to the model.
# jit=False asks for the non-JIT variant of the model.
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

In [None]:
class MemeDataset(Dataset):
    """Dataset of (image, caption, label) samples built from an image -> captions mapping.

    Every caption becomes one sample, so an image with k captions appears k times.
    Each image is opened and preprocessed exactly once, in the constructor, and the
    result is kept in memory and shared by all samples of that image.

    Args:
        data: mapping from image path to the list of captions of that image.
        preprocess: callable applied to the opened PIL image. Its result must not
            depend on the file staying open, because the file is closed right after
            the call.

    Attributes set by the constructor:
        img_paths / captions: parallel lists with one entry per sample.
        processed_cache: image path -> preprocessed image.
        img_paths_set: the distinct image paths, in the order of ``data``.
        path2label: image path -> integer label (its position in ``img_paths_set``).
    """

    def __init__(self, data, preprocess):
        self.preprocess = preprocess
        self.img_paths = []
        self.captions = []
        for img_path, captions in data.items():
            # One sample per caption: repeat the path once for each caption.
            self.img_paths.extend([img_path] * len(captions))
            self.captions.extend(captions)

        self.processed_cache = {}
        for img_path in data:
            # Close the file as soon as it has been preprocessed; otherwise one file
            # handle per image stays open for the lifetime of the dataset.
            with Image.open(img_path) as img:
                self.processed_cache[img_path] = self.preprocess(img)

        self.img_paths_set = list(data.keys())
        # enumerate gives the same positions as list.index without the quadratic scan.
        self.path2label = {path: label for label, path in enumerate(self.img_paths_set)}

    def __len__(self):
        """Return the number of samples, i.e. the total number of captions."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return ``(preprocessed image, caption, label)`` for sample ``idx``."""
        img_path = self.img_paths[idx]
        image = self.processed_cache[img_path]
        caption = self.captions[idx]
        label = self.path2label[img_path]
        return image, caption, label

# Build one dataset per split; the constructor preprocesses and caches every image.
train_dataset = MemeDataset(d_train, preprocess)
test_dataset = MemeDataset(d_test, preprocess)


In [None]:
class BalancedBatchSampler(BatchSampler):
    """Batch sampler that draws ``n_classes`` distinct labels and ``n_samples`` indices per label.

    Each yielded batch is a list of dataset indices of size ``n_classes * n_samples``
    (provided every label owns at least ``n_samples`` indices). One pass yields
    exactly ``len(self)`` batches.

    Args:
        labels: 1-D tensor holding the label of every dataset item, in dataset order.
        n_classes: number of distinct labels per batch. Iterating raises ``ValueError``
            (from ``np.random.choice``) if it exceeds the number of distinct labels.
        n_samples: number of indices taken from each selected label.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        labels_np = self.labels.numpy()  # convert once instead of once per label
        self.labels_set = list(set(labels_np))
        self.label_to_indices = {label: np.where(labels_np == label)[0]
                                 for label in self.labels_set}
        for indices in self.label_to_indices.values():
            np.random.shuffle(indices)
        # Per label: how many of its (shuffled) indices have been handed out so far.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(self.labels)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        # "<=" (not "<") so that the number of yielded batches equals __len__ even when
        # the dataset size is an exact multiple of the batch size.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                stop = start + self.n_samples
                indices.extend(self.label_to_indices[class_][start:stop])
                self.used_label_indices_count[class_] = stop
                # Not enough unused indices left for another draw from this label:
                # reshuffle its indices and start handing them out from the beginning.
                if stop + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Return the number of batches produced by one pass."""
        return self.n_dataset // self.batch_size

# The samplers need the label of every sample up front; it is the third element of
# each dataset item. With one sample per label, a batch holds BATCH_SIZE distinct images.
train_labels = torch.tensor([label for _, _, label in train_dataset])
train_sampler = BalancedBatchSampler(
    labels=train_labels,
    n_classes=BATCH_SIZE,
    n_samples=1,
)
train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=train_sampler)

test_labels = torch.tensor([label for _, _, label in test_dataset])
test_sampler = BalancedBatchSampler(
    labels=test_labels,
    n_classes=BATCH_SIZE,
    n_samples=1,
)
test_dataloader = DataLoader(dataset=test_dataset, batch_sampler=test_sampler)

In [None]:
def convert_models_to_fp32(model):
    """Cast every parameter of ``model``, and its gradient when one exists, to float32 in place.

    Args:
        model: module whose ``parameters()`` are converted.
    """
    for p in model.parameters():
        p.data = p.data.float()
        # A parameter has no gradient when it is frozen or did not take part in the
        # last backward pass; there is nothing to cast in that case.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()

# On the CPU the whole model is cast to float32 once, up front.
if device == "cpu":
    model.float()

# One cross-entropy criterion for the per-image logits and one for the per-text logits;
# the training loop averages the two.
loss_img = nn.CrossEntropyLoss()
loss_txt = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=1e-5)
# The scheduler is stepped once per batch, so its period is the total number of
# batches over all epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, len(train_dataloader)*EPOCH)

In [None]:
def _contrastive_batch_loss(batch):
    """Return the symmetric image/text cross-entropy loss of one (images, captions, labels) batch."""
    images, texts, _ = batch
    images = images.to(device)
    texts = clip.tokenize(texts).to(device)
    logits_per_image, logits_per_text = model(images, texts)
    # Image i is paired with caption i, so the target of row i is column i. The targets
    # are sized from the batch itself rather than from BATCH_SIZE, so a batch of any
    # size is handled correctly.
    ground_truth = torch.arange(len(images), device=device)
    return (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2


# Fine-tune for EPOCH epochs. After each epoch the model is evaluated on the test
# loader; the weights with the lowest test loss go to best_model.pt and the final
# weights to last_model.pt.
best_te_loss = float("inf")
best_ep = -1
for epoch in range(EPOCH):
    print(f"running epoch {epoch}, best test loss {best_te_loss} after epoch {best_ep}")

    # ---- training pass ----
    step = 0
    tr_loss = 0
    model.train()
    pbar = tqdm(train_dataloader, leave=False)
    for batch in pbar:
        step += 1
        optimizer.zero_grad()

        total_loss = _contrastive_batch_loss(batch)
        total_loss.backward()
        batch_loss = total_loss.item()
        tr_loss += batch_loss

        # Off the CPU, parameters and gradients are cast to float32 for the optimizer
        # step and handed back to clip.model.convert_weights afterwards (presumably
        # restoring half precision - see the CLIP sources).
        if device != "cpu":
            convert_models_to_fp32(model)
        optimizer.step()
        scheduler.step()
        if device != "cpu":
            clip.model.convert_weights(model)
        pbar.set_description(f"train batchCE: {batch_loss}", refresh=True)
    # An empty loader yields NaN instead of a ZeroDivisionError.
    tr_loss = tr_loss / step if step else float("nan")

    # ---- evaluation pass ----
    step = 0
    te_loss = 0
    model.eval()
    with torch.no_grad():
        test_pbar = tqdm(test_dataloader, leave=False)
        for batch in test_pbar:
            step += 1
            batch_loss = _contrastive_batch_loss(batch).item()
            te_loss += batch_loss
            test_pbar.set_description(f"test batchCE: {batch_loss}", refresh=True)
    # NaN never compares as smaller, so an empty test loader cannot become "best".
    te_loss = te_loss / step if step else float("nan")

    if te_loss < best_te_loss:
        best_te_loss = te_loss
        best_ep = epoch
        torch.save(model.state_dict(), "best_model.pt")
    print(f"epoch {epoch}, tr_loss {tr_loss}, te_loss {te_loss}")
torch.save(model.state_dict(), "last_model.pt")

In [None]:
!gdown 1rC0lq-4_-F4lAHY0aSTvLtYnHohPcqbk

Let's download a model that has already been fine-tuned for 5 epochs...

In [None]:
# Restore the weights stored in last_model.pt, mapping the tensors onto `device`.
model.load_state_dict(torch.load("last_model.pt", map_location=device))
# Retrieval test settings: number of negative captions per trial and number of trials.
NUM_NEG = 32
NUM_TEST = 1000

In [None]:
# Retrieval test: in every trial a random test image is scored against one of its own
# captions (index 0) and NUM_NEG captions of other test images. The trial counts as
# correct when the image's own caption receives the highest score.

# Only images that have at least one caption can act as a query or supply a negative.
# Drawing uniformly from this list is equivalent to drawing from all images and
# rejecting the caption-less ones, without rebuilding the key list on every draw.
test_paths = [path for path, caps in d_test.items() if caps]

n_correct = 0
for i in tqdm(range(NUM_TEST)):
    img_path = random.choice(test_paths)
    pos_txt = random.choice(d_test[img_path])
    with Image.open(img_path) as img:
        image = preprocess(img).unsqueeze(0).to(device)

    # Collect NUM_NEG distinct captions that belong to other images.
    # NOTE: this loop only terminates if the other test images offer at least
    # NUM_NEG distinct captions.
    neg_txts = []
    while len(neg_txts) < NUM_NEG:
        neg_path = random.choice(test_paths)
        if neg_path == img_path:
            continue
        neg_txt = random.choice(d_test[neg_path])
        if neg_txt in neg_txts:
            continue
        neg_txts.append(neg_txt)

    text = clip.tokenize([pos_txt] + neg_txts).to(device)

    # A single forward pass gives the logits; separate encode_image / encode_text calls
    # would only repeat that work.
    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    if np.argmax(probs) == 0:
        n_correct += 1
print(f"Test precision {n_correct/NUM_TEST}")

In [None]:
# Flatten the captions of all training images into one list, in dictionary order.
corpus = [caption for captions in d_train.values() for caption in captions]
len(corpus), corpus[0]

In [None]:
def sample1Caption(img_path, corpus, model, num_cand):
    """Return the candidate caption that ``model`` scores highest for an image.

    Up to ``num_cand`` distinct captions are drawn at random from ``corpus`` (only
    captions with at least 5 words and at most 72 characters are eligible) and scored
    against the image in one forward pass. Uses the module-level ``preprocess`` and
    ``device``.

    Args:
        img_path: path of the image to caption.
        corpus: sequence of caption strings to draw candidates from.
        model: CLIP model, called as ``model(image, text)``.
        num_cand: number of distinct candidates to score. If the corpus holds fewer
            eligible distinct captions, all of them are used.

    Returns:
        The candidate caption with the highest image-text score.

    Raises:
        ValueError: if no candidate could be drawn (no eligible caption in ``corpus``
            or ``num_cand`` < 1).
    """
    with Image.open(img_path) as img:
        image = preprocess(img).unsqueeze(0).to(device)

    # Walk the corpus in random order and keep the first `num_cand` distinct eligible
    # captions. This selects candidates with the same probabilities as repeated random
    # draws with rejection, but always terminates, even when the corpus cannot supply
    # `num_cand` candidates.
    txts = []
    seen = set()
    for txt in random.sample(corpus, len(corpus)):
        if len(txts) >= num_cand:
            break
        if txt in seen or len(txt.split()) < 5 or len(txt) > 72:
            continue
        seen.add(txt)
        txts.append(txt)
    if not txts:
        raise ValueError("no candidate captions: corpus has no eligible caption or num_cand < 1")

    text = clip.tokenize(txts).to(device)

    with torch.no_grad():
        logits_per_image, logits_per_text = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    return txts[np.argmax(probs)]

In [None]:
# Caption a random image from the training split, choosing among 1000 candidates
# drawn from the training captions, and show it next to its first five real captions.
seen_path = random.choice(list(d_train.keys()))
pred_cap_seen = sample1Caption(seen_path, corpus, model, 1000)
gt_cap_seen = d_train[seen_path][:5]
imshow(Image.open(seen_path))
print(f"Some ground truth captions for this seen image: {gt_cap_seen}")
print(f"Caption sampled by fintuned CLIP for this seen image: {pred_cap_seen}")

In [None]:
# Same experiment for a random image from the test split. The candidates still come
# from the training captions, so the image's own captions are not among them unless
# the same text also occurs in the training split.
unseen_path = random.choice(list(d_test.keys()))
pred_cap_unseen = sample1Caption(unseen_path, corpus, model, 1000)
imshow(Image.open(unseen_path))
gt_cap_unseen = d_test[unseen_path][:5]
print(f"Some ground truth captions for this unseen image: {gt_cap_unseen}")
print(f"Caption sampled by fintuned CLIP for this unseen image: {pred_cap_unseen}")

# Homework (2 points)

You have to choose an architecture based on the text-to-image principle and fine-tune it for your task.

Requirements:

* Find an architecture; if you choose CLIP, you must use a different dataset
* Find a dataset adapted for text-to-image task
* Prepare a dataset to fine-tune your model accordingly
* Show metrics of the fine-tuning procedure (metrics on the test set or something else)
* Generate an image based on the text or text based on the image (think about a function for that)
* You can arrange your own fine-tuning process

Examples of datasets:

https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset

https://www.kaggle.com/datasets/validmodel/indo-fashion-dataset/data

https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=trending

