In [None]:
!pip install decord
!pip install wandb

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from transformers import DistilBertTokenizer, DistilBertModel
import os
import numpy as np
from PIL import Image
import random
from decord import VideoReader
from decord import cpu
from tqdm import tqdm
import pandas as pd
from scipy.stats import spearmanr
import torch.nn.functional as F
import wandb


In [None]:
!pip install kaggle --upgrade
os.environ['KAGGLE_USERNAME'] = "XXXXX"
os.environ['KAGGLE_KEY'] = "XXXXXXXXXXXXXX"

!kaggle datasets download sarthakjain004/msvd-clips
!unzip msvd-clips.zip

In [None]:
!kaggle datasets download julianschelb/wordsim353-crowd
!unzip wordsim353-crowd.zip

In [None]:
class CFG:
    """Single place for every tunable setting used by the notebook."""

    # --- Runtime objects ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

    # --- Optimisation ---
    val_ratio = 0.2        # fraction of videos held out for validation
    batch_size = 8
    lr = 1e-5
    weight_decay = 1e-4
    train_epochs = 12

    # --- Encoders ---
    video_dropout = 0.3
    text_dropout = 0.3
    video_trainable = True  # False freezes the video backbone parameters
    text_trainable = True   # False freezes the text backbone parameters

    # --- Data sampling / loss ---
    image_mode = False      # passed to the dataset as `single_frame`
    seed = 42
    num_captions = 3        # maximum captions sampled per video
    num_frames = 16         # frames sampled per video when not in image mode
    temperature = 0.07      # divisor applied to the contrastive logits

    # --- Paths ---
    cache_dir = "cache"
    video_dir = "YouTubeClips"
    annotation_file = "annotations.txt"
    simverb_path = "simverb-3500.csv"
    wordsim_path = "wordsim353crowd.csv"

In [None]:
# Authenticate through the Python API so login happens in the kernel's own process.
wandb.login()

# Snapshot CFG for the run config. Plain scalars are logged as they are; any other object
# (torch.device, the tokenizer) is reduced to a short string so that only simple,
# serialisable values are sent - the tokenizer is recorded by its checkpoint name.
config_dict = {
    key: value
    if isinstance(value, (bool, int, float, str, type(None)))
    else str(getattr(value, "name_or_path", None) or value)
    for key, value in vars(CFG).items()
    if not key.startswith("__")
}
wandb.init(project="CLIP", config=config_dict)

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic kernels only, and no autotuning of convolution algorithms.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def worker_init_fn(worker_id):
    """Seed Python's and NumPy's RNGs from the current PyTorch initial seed.

    Args:
        worker_id: worker index supplied by the DataLoader (not used).
    """
    # NumPy only accepts seeds below 2**32, hence the reduction.
    derived_seed = torch.initial_seed() % (1 << 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)


# Seed every global RNG once, before any data or model is created.
set_seed(CFG.seed)

# Dedicated generator handed to the DataLoaders; manual_seed returns the generator itself.
generator = torch.Generator().manual_seed(CFG.seed)

In [None]:
def evaluate_benchmarks(model):
    """Score the text encoder on the SimVerb and WordSim word-similarity benchmarks.

    Each word of a pair is embedded on its own with ``model.text_encoder``; the cosine
    similarity of the two embeddings is then rank-correlated (Spearman) with the human
    rating stored in the benchmark CSV.

    Args:
        model: object exposing a ``text_encoder`` module that accepts a dict of
            tokenizer outputs and returns one embedding per input.

    Returns:
        dict with the Spearman correlations under "simverb_corr" and "wordsim_corr".
    """
    text_encoder = model.text_encoder
    was_training = text_encoder.training
    text_encoder.eval()
    # Read the tokenizer from CFG rather than from a notebook global that is only
    # defined in a later cell.
    tokenizer = CFG.tokenizer

    def get_embedding(word):
        tokens = tokenizer(word, return_tensors="pt", padding="max_length", truncation=True, max_length=30)
        tokens = {k: v.to(CFG.device) for k, v in tokens.items()}
        with torch.no_grad():
            embedding = text_encoder(tokens)
        return embedding.squeeze(0)

    def cosine_similarity(emb1, emb2):
        emb1 = F.normalize(emb1, dim=0)
        emb2 = F.normalize(emb2, dim=0)
        return torch.dot(emb1, emb2).item()

    def benchmark_corr(df, word1_col, word2_col, rating_col):
        # Spearman correlation between model similarities and human ratings for one benchmark.
        model_sims = []
        human_sims = []
        for _, row in df.iterrows():
            sim = cosine_similarity(get_embedding(row[word1_col]), get_embedding(row[word2_col]))
            model_sims.append(sim)
            human_sims.append(row[rating_col])
        corr, _ = spearmanr(human_sims, model_sims)
        return corr

    try:
        # Load both files first so a missing benchmark fails before any encoding work.
        df_verb = pd.read_csv(CFG.simverb_path)
        df_word = pd.read_csv(CFG.wordsim_path)
        results = {
            "simverb_corr": benchmark_corr(df_verb, "word1", "word2", "similarity"),
            "wordsim_corr": benchmark_corr(df_word, "Word 1", "Word 2", "Human (Mean)"),
        }
    finally:
        # Put the encoder back in whichever mode the caller had it in.
        text_encoder.train(was_training)

    return results



In [None]:
class MSVDVideoCaptionDataset(Dataset):
    """(frames, tokenized caption) pairs built from a video folder and an annotation file.

    Every non-blank annotation line is ``<video name> <caption words...>``; ".avi" is
    appended to the video name. Up to ``num_captions`` captions are sampled per video,
    giving one dataset item per (video, caption) pair. Decoded and transformed frames are
    cached as tensors under ``cache_dir/<video name>/`` and reused on later accesses.

    Args:
        video_dir: directory containing the video files.
        annotation_file: path of the annotation text file.
        tokenizer: callable applied to each caption with ``return_tensors``, ``padding``,
            ``truncation`` and ``max_length`` keyword arguments.
        cache_dir: root directory of the frame cache (created if missing).
        num_frames: frames sampled per video when ``single_frame`` is False.
        transform: optional callable applied to every PIL frame.
            NOTE(review): the code calls tensor methods on the result, so it presumably
            has to return a tensor - confirm before passing ``None``.
        video_ids: optional collection of "<name>.avi" ids to keep; ``None`` keeps all.
        single_frame: when True only the middle frame of each video is used.
        num_captions: maximum number of captions sampled per video.
        seed: seed of the private RNG used for caption sampling.
    """

    def __init__(self, video_dir, annotation_file, tokenizer, cache_dir="cache", num_frames=8, transform=None, video_ids=None, single_frame=False, num_captions=5, seed=42):
        self.video_dir = video_dir
        self.transform = transform
        self.num_frames = num_frames
        self.tokenizer = tokenizer
        self.single_frame = single_frame
        self.cache_dir = cache_dir
        self.num_captions = num_captions

        os.makedirs(cache_dir, exist_ok=True)
        rng = random.Random(seed)

        self.captions_by_video = {}
        with open(annotation_file, 'r') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank line (for example a trailing newline): nothing to parse.
                    continue
                video_id = parts[0] + ".avi"
                caption = " ".join(parts[1:])
                if video_ids is None or video_id in video_ids:
                    self.captions_by_video.setdefault(video_id, []).append(caption)

        self.samples = [
            (vid, caption)
            for vid, captions in self.captions_by_video.items()
            for caption in rng.sample(captions, min(len(captions), self.num_captions))
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_id, caption = self.samples[idx]
        frames = self._load_frames(video_id)
        tokens = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=30)
        return frames, tokens

    def _load_frames(self, video_id):
        # Return the frame tensor for one video, from the cache when available.
        base_cache_path = os.path.join(self.cache_dir, os.path.splitext(video_id)[0])
        os.makedirs(base_cache_path, exist_ok=True)

        cache_name = "frame.pt" if self.single_frame else f"frames_{self.num_frames}.pt"
        cache_path = os.path.join(base_cache_path, cache_name)
        if os.path.exists(cache_path):
            return torch.load(cache_path)

        video_path = os.path.join(self.video_dir, video_id)
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        if total_frames == 0:
            raise ValueError(f"Video contains no frames: {video_path}")

        if self.single_frame:
            frame = Image.fromarray(vr[total_frames // 2].asnumpy())
            if self.transform:
                frame = self.transform(frame)
            frames = frame.unsqueeze(0)
        else:
            indices = self._sample_indices(total_frames, self.num_frames)
            frames = [Image.fromarray(frame) for frame in vr.get_batch(indices).asnumpy()]
            if self.transform:
                frames = [self.transform(img) for img in frames]
            frames = torch.stack(frames)

        self._save_atomic(frames, cache_path)
        return frames

    @staticmethod
    def _sample_indices(total_frames, num_frames):
        # Evenly spaced frame indices: from the 20%-80% span of the video when it has at
        # least num_frames frames, otherwise from the whole video (indices then repeat).
        if total_frames >= num_frames:
            start = int(0.2 * total_frames)
            end = int(0.8 * total_frames)
            return np.linspace(start, max(end - 1, start), num_frames, dtype=int)
        indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        return np.clip(indices, 0, total_frames - 1)

    @staticmethod
    def _save_atomic(tensor, cache_path):
        # Write to a per-process temporary file and rename it into place, so that another
        # DataLoader worker never loads a partially written cache file.
        tmp_path = f"{cache_path}.{os.getpid()}.tmp"
        torch.save(tensor, tmp_path)
        os.replace(tmp_path, cache_path)

In [None]:
class VideoEncoder(nn.Module):
    """Encode a clip by attention-pooling per-frame ResNet-18 features.

    Args:
        embed_dim: size of the output embedding.
        trainable: when False the ResNet parameters are frozen.
        dropout: dropout probability applied to the pooled embedding.
    """

    def __init__(self, embed_dim=512, trainable=True, dropout=0.3):
        super().__init__()
        resnet = models.resnet18(pretrained=True)
        if not trainable:
            resnet.requires_grad_(False)
        # Keep every child module except the last one.
        backbone_layers = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*backbone_layers)
        self.fc = nn.Linear(512, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.frame_score_layer = nn.Linear(embed_dim, 1)

    def forward(self, x):
        """Map a (B, T, C, H, W) batch of clips to (B, embed_dim) embeddings."""
        B, T, C, H, W = x.shape
        frames = x.view(B * T, C, H, W)

        # Per-frame features, projected and regrouped per clip.
        per_frame = self.backbone(frames).flatten(1)
        per_frame = self.fc(per_frame).view(B, T, -1)

        # One scalar score per frame, softmax-normalised over the T frames of a clip.
        frame_scores = self.frame_score_layer(per_frame).squeeze(-1)
        frame_weights = frame_scores.softmax(dim=1)

        pooled = torch.sum(per_frame * frame_weights.unsqueeze(-1), dim=1)
        return self.dropout(pooled)

In [None]:
class TextEncoder(nn.Module):
    """DistilBERT text encoder followed by a linear projection.

    Args:
        embed_dim: size of the output embedding.
        pretrained: load the "distilbert-base-uncased" weights when True, otherwise
            build a randomly initialised DistilBERT from the default configuration.
        trainable: when False the DistilBERT parameters are frozen.
        dropout: dropout probability applied to the projected embedding.
    """

    def __init__(self, embed_dim=512, pretrained=True, trainable=True, dropout=0.3):
        super(TextEncoder, self).__init__()
        if pretrained:
            self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        else:
            # The notebook's import cell does not bring DistilBertConfig into scope, so
            # import it here; without this the branch fails with a NameError.
            from transformers import DistilBertConfig
            self.bert = DistilBertModel(config=DistilBertConfig())

        if not trainable:
            for param in self.bert.parameters():
                param.requires_grad = False

        self.fc = nn.Linear(self.bert.config.hidden_size, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, tokens):
        """Encode a dict of tokenizer outputs into one embedding per sequence."""
        # Hidden state at position 0 of each sequence serves as its summary vector.
        output = self.bert(**tokens).last_hidden_state[:, 0]
        return self.dropout(self.fc(output))

In [None]:
class CLIP(nn.Module):
    """Two-tower model pairing a video encoder with a text encoder, configured from CFG."""

    def __init__(self):
        super().__init__()
        self.video_encoder = VideoEncoder(
            trainable=CFG.video_trainable,
            dropout=CFG.video_dropout,
        )
        self.text_encoder = TextEncoder(
            trainable=CFG.text_trainable,
            dropout=CFG.text_dropout,
        )

    def forward(self, video_frames, text_tokens):
        """Return the (video_embedding, text_embedding) pair for one batch."""
        return self.video_encoder(video_frames), self.text_encoder(text_tokens)

In [None]:
def contrastive_loss_hard(video_embed, text_embed, temperature=0.07):
    """Symmetric cross-entropy contrastive loss over a batch of paired embeddings.

    Row i of ``video_embed`` and row i of ``text_embed`` form the positive pair; every
    other row in the batch acts as a negative.

    Args:
        video_embed: (N, D) video embeddings.
        text_embed: (N, D) text embeddings.
        temperature: divisor applied to the cosine-similarity logits.

    Returns:
        Scalar tensor: mean of the video-to-text and text-to-video cross-entropies.
    """
    v = F.normalize(video_embed, dim=-1)
    t = F.normalize(text_embed, dim=-1)

    similarity = v @ t.T / temperature
    targets = torch.arange(len(v)).to(v.device)

    video_to_text = F.cross_entropy(similarity, targets)
    text_to_video = F.cross_entropy(similarity.T, targets)
    return (video_to_text + text_to_video) / 2


In [None]:
def train_loop(model, train_loader, val_loader, optimizer, device, num_epochs=6):
    """Train the model with the contrastive loss, checkpointing the best epoch.

    After every epoch the model is scored on the word-similarity benchmarks and the
    metrics are logged to Weights & Biases. The weights with the lowest validation loss
    so far are written to "output/clip_best.pt"; when no validation loader is given the
    training loss is used for that comparison instead. The wandb run is finished when
    training ends.

    Args:
        model: module returning ``(video_embed, text_embed)`` for a batch.
        train_loader: iterable of ``(video_frames, text_tokens)`` training batches.
        val_loader: iterable of validation batches, or a falsy value to skip validation.
        optimizer: optimizer updating the model parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``train_loader``.
    """
    os.makedirs("output", exist_ok=True)
    checkpoint_path = "output/clip_best.pt"
    best_loss = float('inf')

    def batch_loss(video_frames, text_tokens):
        # Move one batch to the device and return its contrastive loss.
        video_frames = video_frames.to(device)
        # Drop the singleton dimension at index 1 of every token tensor before the forward pass.
        text_tokens = {k: v.squeeze(1).to(device) for k, v in text_tokens.items()}
        video_embed, text_embed = model(video_frames, text_tokens)
        return contrastive_loss_hard(video_embed, text_embed, temperature=CFG.temperature)

    for epoch in range(num_epochs):
        # model.train() also switches every sub-module (including the text encoder).
        model.train()
        total_loss = 0
        pbar = tqdm(train_loader, desc=f"Training Epoch {epoch+1}")
        for video_frames, text_tokens in pbar:
            loss = batch_loss(video_frames, text_tokens)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({"batch_loss": f"{loss.item():.4f}"})
        avg_train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}")

        # Stays None when validation is skipped, so later code never reads an undefined name.
        avg_val_loss = None
        if val_loader:
            model.eval()
            val_loss = 0
            pbar = tqdm(val_loader, desc=f"Validation Epoch {epoch+1}")
            with torch.no_grad():
                for video_frames, text_tokens in pbar:
                    loss = batch_loss(video_frames, text_tokens)
                    val_loss += loss.item()
                    pbar.set_postfix({"batch_loss": f"{loss.item():.4f}"})
            avg_val_loss = val_loss / len(val_loader)
            print(f"Epoch {epoch+1}: Val Loss = {avg_val_loss:.4f}")

        # Model selection uses the validation loss when available, else the training loss.
        selection_loss = avg_train_loss if avg_val_loss is None else avg_val_loss
        if selection_loss <= best_loss:
            best_loss = selection_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model")

        results = evaluate_benchmarks(model)
        print("wordsim: ", results["wordsim_corr"])
        print("simverb: ", results["simverb_corr"])

        metrics = {
            "epoch": epoch+1,
            "train_loss": avg_train_loss,
            "wordsim_corr": results["wordsim_corr"],
            "simverb_corr": results["simverb_corr"],
        }
        if avg_val_loss is not None:
            metrics["val_loss"] = avg_val_loss
        wandb.log(metrics)
        wandb.save(checkpoint_path)

    wandb.finish()


In [None]:
# Per-frame preprocessing: resize to 224x224, convert to a tensor, then normalise each
# channel with the mean / std constants below.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
tokenizer = CFG.tokenizer

# Collect the unique video ids named in the annotation file (first token of each line).
all_videos = set()
with open(CFG.annotation_file, 'r') as f:
    for line in f:
        parts = line.split()
        if parts:  # skip blank lines
            all_videos.add(parts[0] + ".avi")

# Sort before shuffling: the iteration order of a set of strings changes between
# interpreter sessions (hash randomisation), which would give a different split on every
# kernel restart. A private seeded RNG also makes this cell give the same split each
# time it is re-run.
all_videos = sorted(all_videos)
random.Random(CFG.seed).shuffle(all_videos)

# Split by video, so captions of one video never land in both sets.
val_ratio = CFG.val_ratio
val_size = int(len(all_videos) * val_ratio)
val_videos = set(all_videos[:val_size])
train_videos = set(all_videos[val_size:])

In [None]:
# Re-seed from the configured seed (not a separate literal) so a change to CFG.seed
# applies here too.
set_seed(CFG.seed)

# Settings shared by both splits; only the set of video ids differs.
dataset_kwargs = dict(
    video_dir=CFG.video_dir,
    annotation_file=CFG.annotation_file,
    tokenizer=tokenizer,
    cache_dir=CFG.cache_dir,
    num_frames=CFG.num_frames,
    transform=transform,
    single_frame=CFG.image_mode,
    num_captions=CFG.num_captions,
    seed=CFG.seed,
)

train_dataset = MSVDVideoCaptionDataset(video_ids=train_videos, **dataset_kwargs)
val_dataset = MSVDVideoCaptionDataset(video_ids=val_videos, **dataset_kwargs)

In [None]:
# Number of (video, caption) samples in the training and validation sets.
print(*map(len, (train_dataset, val_dataset)))

In [None]:
# Count the non-padding tokens over all training captions.
# Iterating the dataset also loads each sample's frames (decoding and caching them on
# first access), so the first run of this cell can take a while.
pad_token_id = train_dataset.tokenizer.pad_token_id  # instead of a hard-coded 0
total_tokens = 0

for _, tokens in train_dataset:
    total_tokens += (tokens["input_ids"] != pad_token_id).sum().item()

print(f"Number of tokens in train data: {total_tokens}")

In [None]:
# Training batches are reshuffled every epoch; validation keeps the dataset order.
# Both loaders share the seeded generator and the worker seeding hook.
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.batch_size,
    shuffle=True,
    num_workers=4,
    worker_init_fn=worker_init_fn,
    generator=generator,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CFG.batch_size,
    shuffle=False,
    num_workers=4,
    worker_init_fn=worker_init_fn,
    generator=generator,
)

In [None]:
# Number of samples in the dataset wrapped by the training loader.
len(train_loader.dataset)

In [None]:
# Build the model on the configured device and attach an Adam optimizer to all of its
# parameters.
device = CFG.device
model = CLIP()
model = model.to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CFG.lr,
    weight_decay=CFG.weight_decay,
)

In [None]:
# Run the full training schedule.
train_loop(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=CFG.train_epochs,
)

# Evaluation
