In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from transformers import AutoTokenizer, AutoModel
import numpy as np
from PIL import Image
import os
import pandas as pd
import json
from tqdm import tqdm

In [2]:
class CLIPConfig:
    """Hyper-parameters and runtime settings shared by the whole notebook."""

    def __init__(self):
        # --- model ---
        self.embed_dim = 512          # size of the joint image/text embedding
        self.num_head = 8
        self.dropout = 0.1
        self.temperature = 0.7        # initial softmax temperature (logit scale = 1 / temperature)

        # --- optimisation ---
        self.batch_size = 64
        self.epochs = 10
        self.lr = 1e-4
        self.weight_decay = 1e-5
        self.warmup_steps = 500
        self.max_grad_norm = 1.0

        # --- runtime ---
        if torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)
        self.log_interval = 50        # print the loss every N optimisation steps
        self.eval_interval = 2        # run retrieval evaluation every N epochs
        self.use_amp = True


# Single shared configuration instance used by the cells below.
config = CLIPConfig()



In [3]:
class ProjectionHead(nn.Module):
    """Three-layer MLP mapping encoder features into the joint embedding space.

    Layout: Linear(in_dim -> hidden_dim) -> GELU -> Dropout
            -> Linear(hidden_dim -> out_dim) -> GELU -> Dropout
            -> Linear(out_dim -> out_dim)
    """

    def __init__(self, in_dim, out_dim, hidden_dim = 2048, dropout = 0.1):
        super().__init__()
        # Layer order is kept fixed so the state_dict keys (mlp.0, mlp.3, mlp.6) stay stable.
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(out_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Project `x` (last dimension `in_dim`) to last dimension `out_dim`."""
        projected = self.mlp(x)
        return projected

In [4]:
class CLIP(nn.Module):
    """Dual-encoder contrastive model: ResNet-50 for images, BERT for text.

    Both encoders are loaded with pretrained weights; their outputs are mapped
    into a shared `config.embed_dim`-dimensional space by two `ProjectionHead`s
    and L2-normalised so a dot product between them is a cosine similarity.
    """

    def __init__(self, config):
        super(CLIP, self).__init__()
        self.config = config

        # Image tower: pretrained ResNet-50 with its classifier removed, so the
        # backbone outputs the 2048-d pooled feature vector.
        weight = ResNet50_Weights.IMAGENET1K_V2
        self.image_encoder = resnet50(weights = weight)
        self.image_encoder.fc = nn.Identity()
        self.image_projection = ProjectionHead(in_dim= 2048, out_dim= config.embed_dim, dropout= config.dropout)

        # Text tower: the tokenizer turns strings into ids, the *model* turns ids
        # into features. (The encoder must be an AutoModel - loading a second
        # tokenizer here would make encode_text fail on `.pooler_output`.)
        self.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
        self.text_encoder = AutoModel.from_pretrained('bert-base-uncased')
        self.text_projection = ProjectionHead(in_dim= 768, out_dim=config.embed_dim, dropout= config.dropout)

        # Temperature: learnable log of the logit scale, initialised to log(1 / temperature).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1/ config.temperature))

        # Only the freshly created projection heads get custom initialisation.
        # Calling self.apply(...) would also overwrite every nn.Linear inside the
        # pretrained text encoder and throw away its weights.
        self.image_projection.apply(self._init_weights)
        self.text_projection.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise an `nn.Linear` with N(0, 0.02) weights and zero bias; ignore other modules."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode_image(self, images):
        """Return L2-normalised embeddings for a batch of image tensors."""
        features = self.image_encoder(images)
        projected = self.image_projection(features)
        return F.normalize(projected, dim = -1)

    def encode_text(self, text):
        """Tokenise `text` (a string or list of strings) and return L2-normalised embeddings."""
        inputs = self.tokenizer(text, return_tensors = "pt", padding = True, truncation = True, max_length = 77)

        # Follow the device the model actually lives on, so the method keeps
        # working after `model.to(other_device)` even if config.device differs.
        device = self.logit_scale.device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs['attention_mask'].to(device)

        outputs = self.text_encoder(input_ids = input_ids, attention_mask = attention_mask)
        pooled_output = outputs.pooler_output
        projected = self.text_projection(pooled_output)
        return F.normalize(projected, dim = -1)

    def forward(self, images, texts):
        """Encode a paired batch and return `(image_features, text_features)`."""
        images_features = self.encode_image(images)
        text_features = self.encode_text(texts)
        return images_features, text_features




In [5]:
def compute_logits(image_features, text_features, logit_scale):
    """Return the scaled image-to-text similarity matrix.

    Args:
        image_features: tensor of shape (num_images, dim).
        text_features: tensor of shape (num_texts, dim).
        logit_scale: log of the scale factor (exponentiated here).

    Returns:
        Tensor of shape (num_images, num_texts); entry [i, j] is
        exp(logit_scale) * <image_i, text_j>.
    """
    # The text matrix must be transposed: (N, D) @ (D, M) -> (N, M).
    # Without it the product is a shape error whenever batch size != dim.
    return logit_scale.exp() * image_features @ text_features.T

In [6]:
def contrastive_loss(logits):
    """Symmetric cross-entropy over a square similarity matrix.

    Row i is expected to match column i, so the targets are simply 0..N-1.
    The image->text loss (rows) and text->image loss (columns) are averaged.
    """
    num_pairs = len(logits)
    targets = torch.arange(num_pairs, device=logits.device)

    image_to_text = F.cross_entropy(logits, targets)
    text_to_image = F.cross_entropy(logits.T, targets)

    combined = image_to_text + text_to_image
    return combined / 2


In [7]:
def recall_at_k(sim_matrix, query_indices, k=5):
    """Fraction of queries whose true match is among their k most similar candidates.

    Args:
        sim_matrix: 2-D numpy array, one row per query and one column per candidate.
        query_indices: for each query row (in order), the column index of its true match.
        k: number of top candidates to consider; clamped to the number of columns.

    Returns:
        Mean hit rate over the queries listed in `query_indices`.
    """
    num_candidates = sim_matrix.shape[1]
    top_k = min(k, num_candidates)

    # argpartition places the top_k largest entries of every row in the last top_k slots
    # (unordered, which is all a membership check needs).
    partitioned = np.argpartition(sim_matrix, -top_k, axis=1)
    top_k_indices = partitioned[:, -top_k:]

    hits = [
        1 if true_idx in top_k_indices[row] else 0
        for row, true_idx in enumerate(query_indices)
    ]
    return np.mean(hits)

In [8]:
class Flickr8kDataset(Dataset):
    """Image/caption pairs read from a two-column captions CSV.

    Each CSV row is one (image file name, caption) pair. Images are split
    95% / 5% into train / validation at the *image* level, so all captions of
    an image land in the same split.

    Args:
        csv_file: path to the captions CSV (two columns: image name, caption).
        image_dir: directory containing the image files.
        split: "train" selects the training images; any other value selects
            the validation images.
        transform: callable applied to the PIL image; defaults to a
            224x224 resize + tensor conversion + normalisation.
    """

    def __init__(self, csv_file, image_dir, split="train", transform=None):
        self.image_dir = image_dir

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711]
                )
            ])
        self.transform = transform

        # Load 2-column CSV: image, caption
        df = pd.read_csv(csv_file)
        df.columns = ["ImageName", "Comment"]  # normalize names

        # Image-level split (no leakage). A private RandomState seeded with 42
        # produces the same draw as `np.random.seed(42)` followed by
        # `np.random.choice`, but does not reset the global NumPy RNG that the
        # rest of the notebook may rely on.
        unique_imgs = df["ImageName"].unique()
        rng = np.random.RandomState(42)
        train_imgs = rng.choice(unique_imgs, int(0.95 * len(unique_imgs)), replace=False)
        val_imgs = np.setdiff1d(unique_imgs, train_imgs)

        selected = train_imgs if split == "train" else val_imgs
        self.df = df[df["ImageName"].isin(selected)].reset_index(drop=True)

    def __len__(self):
        """Number of (image, caption) pairs in this split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return `(transformed_image, caption)` for row `idx`."""
        import warnings

        row = self.df.iloc[idx]
        img_path = os.path.join(self.image_dir, row["ImageName"])

        try:
            # The context manager closes the underlying file handle once the
            # pixel data has been copied by convert().
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except (OSError, ValueError) as exc:
            # Best effort: keep the batch alive with a gray placeholder, but
            # make the problem visible instead of silently training on it.
            warnings.warn(f"Could not load image {img_path!r} ({exc}); using a gray placeholder.")
            image = Image.new("RGB", (224, 224), color="gray")

        image = self.transform(image)
        caption = row["Comment"]

        return image, caption


In [9]:
class WarmupLinearSchedule(optim.lr_scheduler._LRScheduler):
    """Linear warm-up from 0 to the base LR, then linear decay back to 0.

    Args:
        optimizer: wrapped optimizer.
        warmup_steps: number of scheduler steps over which the LR rises linearly.
        total_steps: step at which the LR reaches 0; it stays at 0 afterwards.
        last_epoch: index of the last step (-1 starts a fresh schedule).
    """

    def __init__(self, optimizer, warmup_steps, total_steps, last_epoch=-1):
        # These must be set before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for the current step."""
        step = self.last_epoch
        if step < self.warmup_steps:
            lr_mult = float(step) / float(max(1, self.warmup_steps))
        else:
            decay_steps = float(max(1, self.total_steps - self.warmup_steps))
            # Clamp at zero: stepping past total_steps must not yield a
            # negative learning rate (which would turn descent into ascent).
            lr_mult = max(0.0, 1.0 - float(step - self.warmup_steps) / decay_steps)
        return [base_lr * lr_mult for base_lr in self.base_lrs]

In [10]:
def evaluate_retrieval(model, val_dataloader, k_values=(1, 5, 10)):
    """Compute image->text and text->image Recall@k on a validation loader.

    Every (image, caption) pair yielded by the loader is encoded; pair i is
    treated as the single correct match for query i in both directions.
    NOTE(review): if the dataset holds several captions per image, the other
    captions of the same image count as misses - confirm this is intended.

    Args:
        model: object exposing `encode_image`, `encode_text`, `eval` and `parameters`.
        val_dataloader: iterable yielding `(images, captions)` batches.
        k_values: iterable of k for which Recall@k is reported.

    Returns:
        Dict with keys 'R@{k} (i2t)' and 'R@{k} (t2i)' for every k.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    # Use the device the model actually lives on rather than a notebook global.
    device = next(model.parameters()).device

    all_image_features = []
    all_text_features = []

    with torch.no_grad():
        for images, captions in tqdm(val_dataloader, desc='Encoding for eval'):
            images = images.to(device)
            i_feats = model.encode_image(images)
            all_image_features.append(i_feats.cpu())

            # Default collation may hand captions over as a tuple; the
            # training loop passes a list, so do the same here.
            t_feats = model.encode_text(list(captions))
            all_text_features.append(t_feats.cpu())

    if not all_image_features:
        raise ValueError("val_dataloader produced no batches; nothing to evaluate")

    all_image_features = torch.cat(all_image_features)
    all_text_features = torch.cat(all_text_features)

    # Ground truth is the *global* position of each pair in the concatenated
    # features. Numbering per batch (0..B-1 for every batch) would point most
    # queries at the wrong column once there is more than one batch.
    all_true_indices = list(range(all_image_features.shape[0]))

    # Compute similarities (image-to-text)
    logits = (all_image_features @ all_text_features.T).numpy()

    recalls = {}
    for k in k_values:
        recalls[f'R@{k} (i2t)'] = recall_at_k(logits, all_true_indices, k=k)
        # Symmetric t2i
        recalls[f'R@{k} (t2i)'] = recall_at_k(logits.T, all_true_indices, k=k)

    return recalls


def train_clip(model, train_loader, val_loader, config):
    """Train `model` with the symmetric contrastive loss and return it.

    Uses AdamW with a warm-up + linear-decay schedule stepped once per batch,
    clips the gradient norm to `config.max_grad_norm`, logs the loss every
    `config.log_interval` steps and runs retrieval evaluation every
    `config.eval_interval` epochs.
    """
    device = config.device
    model.to(device)

    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.lr,
        weight_decay=config.weight_decay,
    )
    steps_per_epoch = len(train_loader)
    scheduler = WarmupLinearSchedule(
        optimizer,
        config.warmup_steps,
        steps_per_epoch * config.epochs,
    )

    global_step = 0
    for epoch in range(config.epochs):
        epoch_number = epoch + 1

        model.train()
        running_loss = 0
        progress = tqdm(train_loader, desc=f'Epoch {epoch_number}/{config.epochs}')

        for images, captions in progress:
            images = images.to(device)
            captions = list(captions)  # the tokenizer is given a plain list of strings

            optimizer.zero_grad()

            # Forward pass -> similarity matrix -> symmetric cross-entropy.
            image_features, text_features = model(images, captions)
            loss = contrastive_loss(
                compute_logits(image_features, text_features, model.logit_scale)
            )

            # Backward pass with gradient clipping, then advance optimizer and schedule.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            optimizer.step()
            scheduler.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            global_step += 1

            current_lr = scheduler.get_last_lr()[0]
            progress.set_postfix({'Loss': f'{batch_loss:.4f}', 'LR': f'{current_lr:.2e}'})

            if global_step % config.log_interval == 0:
                print(f'Step {global_step}: Loss = {batch_loss:.4f}')

        avg_loss = running_loss / len(train_loader)
        print(f'Epoch {epoch_number} completed. Avg Loss: {avg_loss:.4f}')

        if epoch_number % config.eval_interval == 0:
            recalls = evaluate_retrieval(model, val_loader)
            print(f'Eval Metrics: {recalls}')

    return model

In [11]:
# NOTE(review): this import lives outside the notebook's main import cell;
# consider moving it there so dependencies are visible up front.
import kagglehub

# Fetch the "adityajn105/flickr8k" dataset through kagglehub and keep the value
# it returns; the printed output below shows it is the local dataset directory.
path = kagglehub.dataset_download("adityajn105/flickr8k")

print("Path to dataset files:", path)

Path to dataset files: C:\Users\modam\.cache\kagglehub\datasets\adityajn105\flickr8k\versions\1


In [None]:
# Use the directory returned by the kagglehub download cell instead of a
# hard-coded, user-specific absolute path, so the notebook runs on any machine.
root = path

captions_file = os.path.join(root, "captions.txt")
images_dir = os.path.join(root, "Images")

train_dataset = Flickr8kDataset(
    csv_file=captions_file,
    image_dir=images_dir,
    split="train"
)

val_dataset = Flickr8kDataset(
    csv_file=captions_file,
    image_dir=images_dir,
    split="val"
)

# On Windows, DataLoader workers are spawned as fresh processes that must
# re-import the dataset class; a class defined in a notebook cell is not
# importable there, so worker start-up fails or hangs. Load in the main
# process on Windows and keep 4 workers elsewhere.
num_workers = 0 if os.name == "nt" else 4
# Pinned host memory only helps host->GPU copies.
pin_memory = config.device.type == "cuda"

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

model = CLIP(config)

trained_model = train_clip(model, train_loader, val_loader, config)

# The file name is kept unchanged (despite saying "flickr30k") so anything
# that already loads this checkpoint keeps working; the data is Flickr8k.
checkpoint_path = 'flickr30k_clip_pretrained.pth'
torch.save({
        'model_state_dict': trained_model.state_dict(),
        'config': config.__dict__,
        'logit_scale': trained_model.logit_scale.item()
    }, checkpoint_path)
print(f"Flickr8k pretraining complete. Model saved as '{checkpoint_path}'.")

Train size: 38430, Val size: 2025


Epoch 1/10:   0%|          | 0/19215 [00:00<?, ?it/s]

In [None]:
from PIL import Image

def clip_inference(model, image_path, candidate_texts, device="cuda"):
    """Rank candidate captions for one image and return the best match.

    Prints every candidate with its cosine similarity, best first.

    Args:
        model: object exposing `to`, `eval`, `encode_image` and `encode_text`.
        image_path: path of the image file to describe.
        candidate_texts: a caption string or a sequence of caption strings.
        device: device to run on; a CUDA request falls back to CPU when CUDA
            is not available.

    Returns:
        The candidate text with the highest similarity to the image.

    Raises:
        ValueError: if `candidate_texts` is empty.
    """
    # Validate before touching the model or the file system.
    if isinstance(candidate_texts, str):
        candidate_texts = [candidate_texts]
    candidate_texts = list(candidate_texts)
    if not candidate_texts:
        raise ValueError("candidate_texts must contain at least one caption")

    # The default is "cuda"; degrade gracefully on CPU-only machines instead
    # of failing inside model.to().
    if torch.device(device).type == "cuda" and not torch.cuda.is_available():
        device = "cpu"

    model.to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.48145466, 0.4578275, 0.40821073],
            std=[0.26862954, 0.26130258, 0.27577711]
        )
    ])
    # Context manager so the image file handle is released right away.
    with Image.open(image_path) as img:
        image = transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        img_feat = model.encode_image(image)
        txt_feat = model.encode_text(candidate_texts)
        sims = (img_feat @ txt_feat.T).squeeze(0)

    scores, indices = torch.sort(sims, descending=True)
    # Convert to plain Python numbers before indexing the list / formatting.
    ranked = list(zip(indices.tolist(), scores.tolist()))
    for rank, (idx, score) in enumerate(ranked, start=1):
        print(f"{rank}. {candidate_texts[idx]}    (score={score:.4f})")

    return candidate_texts[ranked[0][0]]
