## Model Training & Fine Tuning CLIP For Flickr8K Dataset

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("adityajn105/flickr8k")

print("Path to dataset files:", path)

Path to dataset files: /home/panchani.d/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1


In [3]:
import os

# Path to the directory
directory_path = "/home/gohil.de/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/Images"

# List all paths in the directory
image_paths = []
for root, dirs, files in os.walk(directory_path):
    for name in files:
        image_paths.append(os.path.join(root, name))
print("Image File paths: \n")
print(image_paths[:3])

# Initialize dictionary to store image captions
captions = {}

# Open and read the captions.txt file
with open("/home/panchani.d/.cache/kagglehub/datasets/adityajn105/flickr8k/versions/1/captions.txt", "r") as file:
    for line in file:
        # Split line into image name and caption
        line = line.strip()
        if not line:  # Skip empty lines
            continue
        image_name, caption = line.split(",", 1)  # Split at the first comma

        # Add caption to the dictionary
        if image_name in captions:
            captions[image_name].append(caption)
        else:
            captions[image_name] = [caption]

del captions['image']
print("Image - Captions Map: \n")
captions = list(captions.items())
print(captions[:3])

Image File paths: 

[]
Image - Captions Map: 

[('1000268201_693b08cb0e.jpg', ['A child in a pink dress is climbing up a set of stairs in an entry way .', 'A girl going into a wooden building .', 'A little girl climbing into a wooden playhouse .', 'A little girl climbing the stairs to her playhouse .', 'A little girl in a pink dress going into a wooden cabin .']), ('1001773457_577c3a7d70.jpg', ['A black dog and a spotted dog are fighting', 'A black dog and a tri-colored dog playing with each other on the road .', 'A black dog and a white dog with brown spots are staring at each other in the street .', 'Two dogs of different breeds looking at each other on the road .', 'Two dogs on pavement moving toward each other .']), ('1002674143_1b742ab4b8.jpg', ['A little girl covered in paint sits in front of a painted rainbow with her hands in a bowl .', 'A little girl is sitting in front of a large painted rainbow .', 'A small girl in the grass plays with fingerpaints in front of a white canvas

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from diffusers import UNet2DConditionModel, AutoencoderKL, LMSDiscreteScheduler
import numpy as np
from tqdm import tqdm

# 1. Prepare Flickr8k Dataset
class Flickr8kDataset(Dataset):
    def __init__(self, image_paths, captions, transform=None):
        self.image_paths = image_paths
        self.captions = captions
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        caption = self.captions[idx][1][0]
        return image, caption

# Set up data preprocessing
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

dataset = Flickr8kDataset(image_paths, captions, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# 2. Define CLIP Fine-tuning Model
class CLIPFineTuner(nn.Module):
    def __init__(self, text_encoder, vision_encoder, projection_dim=512):
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder
        
        # Add projection layers to align dimensions
        self.text_projection = nn.Linear(768, projection_dim)  # Project text features
        self.image_projection = nn.Linear(1024, projection_dim)  # Project image features
        
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, images, input_ids, attention_mask):
        # Extract features
        image_features = self.vision_encoder(images).last_hidden_state[:, 0, :]
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        
        # Normalize features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        # Project features into the same dimension
        image_features = self.image_projection(image_features)
        text_features = self.text_projection(text_features)
        
        # Normalize projected features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        
        # Compute logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text


# 3. Fine-tune CLIP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
vision_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")

model = CLIPFineTuner(text_encoder, vision_encoder).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 30

for epoch in range(num_epochs):
    for images, captions in tqdm(dataloader):
        images = images.to(device)
        inputs = tokenizer(captions, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)

        logits_per_image, logits_per_text = model(images, input_ids, attention_mask)

        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        loss = (loss_fn(logits_per_image, ground_truth) + loss_fn(logits_per_text, ground_truth)) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Save the fine-tuned model
torch.save(model.state_dict(), "fine_tuned_clip.pth")

100%|██████████| 506/506 [04:56<00:00,  1.71it/s]


Epoch 1/30, Loss: 2.3978962898254395


100%|██████████| 506/506 [04:53<00:00,  1.73it/s]


Epoch 2/30, Loss: 2.3978958129882812


100%|██████████| 506/506 [04:58<00:00,  1.69it/s]


Epoch 3/30, Loss: 2.3978958129882812


100%|██████████| 506/506 [04:56<00:00,  1.71it/s]


Epoch 4/30, Loss: 2.3978965282440186


100%|██████████| 506/506 [05:47<00:00,  1.46it/s]


Epoch 5/30, Loss: 2.397895336151123


100%|██████████| 506/506 [05:14<00:00,  1.61it/s]


Epoch 6/30, Loss: 2.3978958129882812


100%|██████████| 506/506 [05:19<00:00,  1.58it/s]


Epoch 7/30, Loss: 2.397895574569702


100%|██████████| 506/506 [06:47<00:00,  1.24it/s]


Epoch 8/30, Loss: 2.397895574569702


100%|██████████| 506/506 [06:38<00:00,  1.27it/s]


Epoch 9/30, Loss: 2.3978958129882812


100%|██████████| 506/506 [06:04<00:00,  1.39it/s]


Epoch 10/30, Loss: 2.3978958129882812


100%|██████████| 506/506 [05:28<00:00,  1.54it/s]


Epoch 11/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:54<00:00,  1.72it/s]


Epoch 12/30, Loss: 2.3978958129882812


100%|██████████| 506/506 [04:52<00:00,  1.73it/s]


Epoch 13/30, Loss: 2.397895574569702


100%|██████████| 506/506 [04:52<00:00,  1.73it/s]


Epoch 14/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:52<00:00,  1.73it/s]


Epoch 15/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:54<00:00,  1.72it/s]


Epoch 16/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:55<00:00,  1.71it/s]


Epoch 17/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:52<00:00,  1.73it/s]


Epoch 18/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:55<00:00,  1.71it/s]


Epoch 19/30, Loss: 2.397895574569702


100%|██████████| 506/506 [04:53<00:00,  1.72it/s]


Epoch 20/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:51<00:00,  1.73it/s]


Epoch 21/30, Loss: 2.397895574569702


100%|██████████| 506/506 [04:52<00:00,  1.73it/s]


Epoch 22/30, Loss: 2.397895574569702


100%|██████████| 506/506 [04:52<00:00,  1.73it/s]


Epoch 23/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:54<00:00,  1.72it/s]


Epoch 24/30, Loss: 2.397895336151123


100%|██████████| 506/506 [04:54<00:00,  1.72it/s]


Epoch 25/30, Loss: 2.397895574569702


 41%|████▏     | 209/506 [02:00<02:51,  1.74it/s]