In [19]:
import os
import json
import cv2
import re
import torch
import clip
from PIL import Image
import torchvision.transforms as transforms

# Load CLIP model and preprocessing function
device = "cpu"#"cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Define directories
data_dir = "data"
videos_dir = os.path.join(data_dir, "videos")
captions_dir = os.path.join(data_dir, "captions")

# Function to extract YouTube video ID from filename
def extract_video_id(filename):
    match = re.search(r"\[([A-Za-z0-9_-]+)\]", filename)  # Extracts text inside brackets [videoID]
    return match.group(1) if match else None

# Get first available video file
video_files = [f for f in os.listdir(videos_dir) if f.endswith(".mp4")]
if not video_files:
    print("No video files found!")
    exit()

video_file = video_files[0]  # Grab the first video
video_path = os.path.join(videos_dir, video_file)

# Find corresponding caption file
video_id = extract_video_id(video_file)
if not video_id:
    print(f"Could not extract video ID from {video_file}")
    exit()

caption_file = next((f for f in os.listdir(captions_dir) if video_id in f and f.endswith(".json")), None)
if not caption_file:
    print(f"No matching caption file found for {video_file}")
    exit()

caption_path = os.path.join(captions_dir, caption_file)

# Load the caption JSON
with open(caption_path, "r", encoding="utf-8") as f:
    captions = json.load(f)

# Grab the first timestamp and its frames
if not captions:
    print("Caption file is empty!")
    exit()

first_entry = captions[0]
start_time = first_entry["start_time"]
end_time = first_entry["end_time"]
caption_text = first_entry["caption"]
frame_indices = first_entry.get("frames", [])

if not frame_indices:
    print("No frame indices found in the first caption entry!")
    exit()

print(f"\nSanity Check: {video_file}")
print(f"Caption: \"{caption_text}\"")
print(f"Timestamp: {start_time} → {end_time}")
print(f"Frame Indices: {frame_indices}")

def frames_from_timestamp(timestamp):
    # Extract frames and convert them into tensors
    cap = cv2.VideoCapture(video_path)
    frame_tensors = []

    for frame_idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        
        if ret:
            # Convert BGR (OpenCV) to RGB (PIL)
            #show image
            cv2.imshow(f"Frame {frame_idx}", frame)
            cv2.waitKey(500)
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_image = Image.fromarray(frame_rgb)
            
            # Apply CLIP preprocessing
            frame_tensor = preprocess(pil_image).unsqueeze(0).to(device)  # Add batch dimension
            frame_tensors.append(frame_tensor)
        else:
            print(f"❌ Failed to retrieve frame {frame_idx}")

    cap.release()

    # Stack frames into a single tensor (batch)
    if frame_tensors:
        frames_tensor = torch.cat(frame_tensors, dim=0)
        print(f"Frames tensor shape: {frames_tensor.shape}")  # (batch_size, 3, 224, 224)
    else:
        print("❌ No valid frames were processed.")



Sanity Check: Cal Poly Survivor： S3 E10： Loved Ones [Kggc-m8ntVQ].mp4
Caption: "previously on C paully"
Timestamp: 00:00:00,359 → 00:00:02,350
Frame Indices: [10, 25, 40, 55, 70]


In [20]:
import json
import datetime

def time_to_seconds(time_str):
    """Convert timestamp from 'HH:MM:SS,ms' to total seconds."""
    time_obj = datetime.datetime.strptime(time_str, "%H:%M:%S,%f")
    return (time_obj.hour * 3600) + (time_obj.minute * 60) + time_obj.second + (time_obj.microsecond / 1_000_000)

def find_caption_at_time(json_file, query_time):
    """Finds the caption at a specific query time."""
    
    # Load JSON data
    with open(json_file, "r", encoding="utf-8") as f:
        captions = json.load(f)

    # Convert query time to seconds
    query_seconds = time_to_seconds(query_time)

    # Search for the matching caption
    for entry in captions:
        start_seconds = time_to_seconds(entry["start_time"])
        end_seconds = time_to_seconds(entry["end_time"])
        
        if start_seconds <= query_seconds <= end_seconds:
            return entry  # Return full caption entry

    return None  # No matching caption found

In [21]:
#example of finding a caption
json_file_path = "data/captions/Cal Poly Survivor： S3 E1： Wit, Grit & Charm [dfYMJhtFvuU].en.json"
query_time = "00:45:26,500"

find_caption_at_time(json_file_path, query_time)

{'start_time': '00:45:25,440',
 'end_time': '00:45:27,190',
 'caption': "future I can't believe that literally like orange team is here first guys I",
 'frames': [81773, 81786, 81799, 81812, 81826]}

In [22]:
import torch.nn.functional as F

def clip_dot(image_features, text_features, temperature=0.07):
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    logits_per_image = image_features @ text_features.T
    #logits_per_text = text_features @ image_features.T

    logits_per_image /= temperature
    #logits_per_text /= temperature

    return F.softmax(logits_per_image)



In [23]:
from torch.utils.data import Dataset
import random


def compare_captions(timestamp, cap_list, cap_file, vid_file, show_frame=True):
    video_id = extract_video_id(vid_file)

    if not video_id:
        raise Exception("ouch")


    entry = find_caption_at_time(
        cap_file, timestamp
    )

    frame_indices = entry["frames"]
    caption_text = entry["caption"]

    # Lazy-load one random frame
    cap = cv2.VideoCapture(vid_file)
    frame_idx = random.choice(frame_indices)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    cap.release()

    if not ret:
        frame_tensor = torch.zeros((3, 224, 224), dtype=torch.float)
    else:
        cv2.imshow("Image", frame)
        cv2.waitKey()
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(frame_rgb)
        frame_tensor = preprocess(pil_image).to(device)
    
    frame_tensor = frame_tensor.unsqueeze(0)

    texts = [caption_text, *cap_list]
    text_tokens = clip.tokenize(texts, truncate=True).to(device)

    return frame_tensor, text_tokens, texts


In [24]:
import torch.optim as optim
import os
import tqdm

# Create directory to save/load model checkpoints
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)


# load checkpoint
checkpoint_path = os.path.join(save_dir, f"clip_overnight.pt")
checkpoint = torch.load(checkpoint_path)

model.load_state_dict(checkpoint["model_state_dict"])
model.to(device=device)
model.eval()

CLIP(
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32), bias=False)
    (ln_pre): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): Sequential(
        (0): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): QuickGELU()
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): ResidualAttentionBlock(
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          

In [26]:
frame, text_tok, texts = compare_captions(
    "00:19:05,000",
    ["Gaslight", "Girlboss"],
    "data/captions/Cal Poly Survivor： S3 E1： Wit, Grit & Charm [dfYMJhtFvuU].en.json",
    "data/videos/Cal Poly Survivor： S3 E1： Wit, Grit & Charm [dfYMJhtFvuU].mp4"
)

image_features = model.encode_image(frame)
text_features = model.encode_text(text_tok)

for t, closeness in zip(texts, clip_dot(image_features, text_features).tolist()[0]):
    print(f"P(text): {closeness} | '{t}'")

P(text): 0.9984492063522339 | 'Gaslight and lie to you guys until the game is over I'm so happy I feel so so'
P(text): 0.0010551055893301964 | 'Gaslight'
P(text): 0.0004955871845595539 | 'Girlboss'


  return F.softmax(logits_per_image)
