In [1]:
import os
import json
import cv2
import re
import torch
import clip
from PIL import Image
import torchvision.transforms as transforms

# Data layout: data/videos holds the .mp4 files, data/captions the JSON caption files.
data_dir = "data"
videos_dir = os.path.join(data_dir, "videos")
captions_dir = os.path.join(data_dir, "captions")

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained CLIP ViT-B/32 together with the image transform that matches it.
model, preprocess = clip.load("ViT-B/32", device=device)

# Function to extract YouTube video ID from filename
def extract_video_id(filename):
    """Return the bracketed video ID embedded in *filename*, or None.

    Files are expected to be named like ``"<title> [<id>].<ext>"`` (e.g.
    ``"... Like a Mob Boss [jJePD7jcNBQ].mp4"``). Titles may contain their own
    bracketed tags such as ``"[HD]"``, so the *last* bracketed token made of
    ``[A-Za-z0-9_-]`` characters is taken as the ID.
    """
    matches = re.findall(r"\[([A-Za-z0-9_-]+)\]", filename)
    return matches[-1] if matches else None

# Get the first available video file. Sort so "first" is reproducible:
# os.listdir() order depends on the filesystem.
video_files = sorted(f for f in os.listdir(videos_dir) if f.endswith(".mp4"))
# Failures raise instead of calling exit(): exit() does not reliably abort a
# Jupyter cell (the kernel's exit only takes effect afterwards), so later lines
# such as video_files[0] would still run and crash.
if not video_files:
    raise FileNotFoundError(f"No video files found in {videos_dir}!")

video_file = video_files[0]  # Grab the first video
video_path = os.path.join(videos_dir, video_file)

# Find corresponding caption file
video_id = extract_video_id(video_file)
if not video_id:
    raise ValueError(f"Could not extract video ID from {video_file}")

caption_file = next(
    (f for f in sorted(os.listdir(captions_dir)) if video_id in f and f.endswith(".json")),
    None,
)
if not caption_file:
    raise FileNotFoundError(f"No matching caption file found for {video_file}")

caption_path = os.path.join(captions_dir, caption_file)

# Load the caption JSON
with open(caption_path, "r", encoding="utf-8") as f:
    captions = json.load(f)

# Grab the first timestamp and its frames
if not captions:
    raise ValueError("Caption file is empty!")

first_entry = captions[0]
start_time = first_entry["start_time"]
end_time = first_entry["end_time"]
caption_text = first_entry["caption"]
frame_indices = first_entry.get("frames", [])

if not frame_indices:
    raise ValueError("No frame indices found in the first caption entry!")

print(f"\nSanity Check: {video_file}")
print(f"Caption: \"{caption_text}\"")
print(f"Timestamp: {start_time} → {end_time}")
print(f"Frame Indices: {frame_indices}")

# Extract the listed frames and convert them into CLIP input tensors.
frame_tensors = []
cap = cv2.VideoCapture(video_path)
try:
    if not cap.isOpened():
        raise IOError(f"Could not open video {video_path}")
    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
            print(f"❌ Failed to retrieve frame {frame_idx}")
            continue
        # OpenCV decodes to BGR; PIL and CLIP preprocessing expect RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        frame_tensors.append(preprocess(Image.fromarray(frame_rgb)))
finally:
    # Always free the decoder handle, even if decoding or preprocessing raises.
    cap.release()

# Stack frames into a single batch and move it to the device in one transfer.
if frame_tensors:
    frames_tensor = torch.stack(frame_tensors).to(device)
    print(f"Frames tensor shape: {frames_tensor.shape}")  # (num_frames, 3, 224, 224)
else:
    print("❌ No valid frames were processed.")



Sanity Check: Cal Poly Survivor： S3 E8： Like a Mob Boss [jJePD7jcNBQ].mp4
Caption: "previously on C paully"
Timestamp: 00:00:00,359 → 00:00:02,350
Frame Indices: [10, 25, 40, 55, 70]
Frames tensor shape: torch.Size([5, 3, 224, 224])


In [2]:
from torch.utils.data import Dataset
import random

class LazyFrameCaptionDataset(Dataset):
    """(frame, caption-tokens) pairs decoded lazily from video files.

    Every caption entry that lists frame indices becomes one sample. On access,
    one of that entry's frames is chosen at random, decoded with OpenCV and run
    through ``preprocess``; the caption text is tokenized with ``tokenizer``.

    Args:
        videos_dir: directory scanned for ``*.mp4`` files.
        captions_dir: directory holding ``*.json`` caption files. A caption file
            is matched to a video when its name contains the video's ID.
        tokenizer: callable invoked as ``tokenizer([text], truncate=True)``
            (e.g. ``clip.tokenize``).
        preprocess: transform applied to a PIL image (e.g. CLIP's ``preprocess``).
        device: device the returned tensors are moved to. When this is a CUDA
            device, keep the DataLoader at ``num_workers=0`` (CUDA cannot be
            initialised inside forked worker processes).
    """

    def __init__(self, videos_dir, captions_dir, tokenizer, preprocess, device="cpu"):
        self.device = device
        self.tokenizer = tokenizer
        self.preprocess = preprocess
        # One dict per sample: video_path, frame_indices, caption.
        self.metadata = []

        # List the caption directory once rather than once per video. Sort both
        # listings so sample order and caption matching are deterministic.
        caption_files = sorted(f for f in os.listdir(captions_dir) if f.endswith(".json"))
        video_files = sorted(f for f in os.listdir(videos_dir) if f.endswith(".mp4"))

        for video_file in video_files:
            video_id = extract_video_id(video_file)
            if not video_id:
                continue

            # Match caption file
            caption_file = next((f for f in caption_files if video_id in f), None)
            if not caption_file:
                continue

            video_path = os.path.join(videos_dir, video_file)
            caption_path = os.path.join(captions_dir, caption_file)
            with open(caption_path, "r", encoding="utf-8") as f:
                captions = json.load(f)

            for entry in captions:
                # Entries without frame indices cannot produce an image sample.
                if not entry.get("frames"):
                    continue

                self.metadata.append({
                    "video_path": video_path,
                    "frame_indices": entry["frames"],
                    "caption": entry["caption"]
                })

    def __len__(self):
        """Number of (frame, caption) samples discovered at construction time."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """Decode one random frame of sample ``idx`` and tokenize its caption.

        Returns:
            ``(frame_tensor, text_tokens)``, both on ``self.device``. If the frame
            cannot be read, a zero image of shape ``(3, 224, 224)`` is returned in
            its place.
        """
        entry = self.metadata[idx]
        frame_idx = random.choice(entry["frame_indices"])

        # Lazy-load one random frame; release the decoder even if reading raises.
        cap = cv2.VideoCapture(entry["video_path"])
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
        finally:
            cap.release()

        if ret:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_tensor = self.preprocess(Image.fromarray(frame_rgb))
        else:
            # Unreadable frame: substitute a blank image (224x224 matches the
            # ViT-B/32 preprocess output) so the batch still collates.
            frame_tensor = torch.zeros((3, 224, 224), dtype=torch.float)
        # Both branches go to the same device; otherwise default_collate's
        # torch.stack fails on a mix of CPU and CUDA tensors.
        frame_tensor = frame_tensor.to(self.device)

        text_tokens = self.tokenizer([entry["caption"]], truncate=True).squeeze(0).to(self.device)

        return frame_tensor, text_tokens


In [3]:
from torch.utils.data import DataLoader

# Each sample: one randomly chosen preprocessed frame plus its tokenized caption.
dataset = LazyFrameCaptionDataset(
    videos_dir,
    captions_dir,
    clip.tokenize,
    preprocess,
    device=device,
)
# The dataset already places tensors on `device`, so loading stays in the main
# process (DataLoader's default num_workers=0).
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)


In [4]:
import torch.nn.functional as F

def clip_contrastive_loss(image_features, text_features, temperature=0.07):
    """Symmetric CLIP (InfoNCE) loss over a batch of paired embeddings.

    Row ``i`` of ``image_features`` is the positive match for row ``i`` of
    ``text_features``; every other row in the batch serves as a negative.

    Args:
        image_features: ``(batch, dim)`` image embeddings; need not be normalized.
        text_features: ``(batch, dim)`` text embeddings; need not be normalized.
        temperature: logits are cosine similarities divided by this value.

    Returns:
        Scalar float32 tensor: mean of the image->text and text->image
        cross-entropy losses.
    """
    # CLIP on CUDA yields fp16 features; do the similarity / softmax math in
    # fp32 for numerical stability (gradients flow back through .float()).
    image_features = F.normalize(image_features.float(), dim=-1)
    text_features = F.normalize(text_features.float(), dim=-1)

    # A single matmul: the text->image logits are just the transpose.
    logits_per_image = image_features @ text_features.T / temperature
    logits_per_text = logits_per_image.T

    targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
    loss_i2t = F.cross_entropy(logits_per_image, targets)
    loss_t2i = F.cross_entropy(logits_per_text, targets)

    return (loss_i2t + loss_t2i) / 2


In [5]:
import torch.optim as optim
import os

# Create directory to save model checkpoints
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)

# clip.load() converts the weights to fp16 on CUDA. Adam steps of ~1e-6 cannot
# be represented reliably in fp16 weights, so fine-tune with fp32 weights.
# On CPU clip.load() already returns fp32 and this is a no-op.
model.float()
# clip.load() returns the model in eval mode; switch to training mode.
model.train()

optimizer = optim.Adam(model.parameters(), lr=1e-6)
num_epochs = 10

# Guard the per-epoch average below against division by zero.
if len(dataloader) == 0:
    raise ValueError("DataLoader is empty: no (frame, caption) samples were found.")

for epoch in range(num_epochs):
    total_loss = 0.0
    for images, text_tokens in dataloader:
        images = images.to(device)
        text_tokens = text_tokens.to(device)

        image_features = model.encode_image(images)
        text_features = model.encode_text(text_tokens)

        loss = clip_contrastive_loss(image_features, text_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.4f}")

    # Save checkpoint (model + optimizer state so training can be resumed).
    checkpoint_path = os.path.join(save_dir, f"clip_epoch_{epoch+1}.pt")
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_loss,
    }, checkpoint_path)

    print(f"✅ Model saved to {checkpoint_path}")


Epoch 1/10 - Avg Loss: 3.2319
Epoch 2/10 - Avg Loss: 2.7050
Epoch 3/10 - Avg Loss: 2.2439
Epoch 4/10 - Avg Loss: 1.8727
Epoch 5/10 - Avg Loss: 1.5761
Epoch 6/10 - Avg Loss: 1.3329


KeyboardInterrupt: 