In [45]:
import cv2
import numpy as np
from tqdm import tqdm

import torch
import clip
from PIL import Image

In [99]:
def extract_key_frames(video_path, interval=30, shot_threshold=0.5):
    """Collect key frames from a video via uniform sampling and shot-boundary detection.

    A frame is kept when its 0-based index is a multiple of ``interval`` or when
    the correlation between its colour histogram and the previous frame's
    histogram drops below ``shot_threshold``. A frame that satisfies both
    conditions is kept only once, so the result never contains the same frame
    twice in a row.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        interval: Keep every ``interval``-th frame.
        shot_threshold: Histogram correlation (``cv2.HISTCMP_CORREL``) below
            which a frame is treated as the start of a new shot.

    Returns:
        list: The selected frames, in video order, exactly as returned by
        ``cap.read()``. Empty when no frame could be read.
    """
    cap = cv2.VideoCapture(video_path)
    key_frames = []
    prev_hist = None
    frame_count = 0

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # 8x8x8 colour histogram over the three channels, flattened to a vector.
            hist = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
            hist = cv2.normalize(hist, hist).flatten()

            # Uniform sampling
            is_sampled = frame_count % interval == 0

            # Shot boundary detection: low correlation with the previous frame.
            is_shot_boundary = (
                prev_hist is not None
                and cv2.compareHist(prev_hist, hist, cv2.HISTCMP_CORREL) < shot_threshold
            )

            # Append once even if both criteria fire for the same frame.
            if is_sampled or is_shot_boundary:
                key_frames.append(frame)

            prev_hist = hist
            frame_count += 1
    finally:
        # Release the capture handle even if histogram computation raises.
        cap.release()

    return key_frames

def get_image_embeddings(key_frames, model, preprocess, device):
    """Encode each key frame with ``model.encode_image`` and stack the results.

    Args:
        key_frames: Sequence of OpenCV frames (converted BGR -> RGB before
            being handed to ``preprocess``).
        model: Object exposing ``encode_image``.
        preprocess: Callable applied to the PIL image; its result is given a
            leading batch dimension and moved to ``device``.
        device: Device on which the model is run.

    Returns:
        torch.Tensor: The per-frame embeddings concatenated along dim 0, on the
        CPU. Presumably ``(len(key_frames), D)`` when ``encode_image`` returns
        ``(1, D)`` per frame - confirm against the model in use.

    Raises:
        ValueError: If ``key_frames`` is empty.
    """
    # Fail early with a clear message instead of an opaque concatenation error.
    if len(key_frames) == 0:
        raise ValueError("key_frames is empty: there are no frames to embed")

    image_embeddings = []

    for frame in tqdm(key_frames):
        # Convert OpenCV image (BGR) to PIL image (RGB)
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Preprocess the image and add the batch dimension
        preprocessed_image = preprocess(pil_image).unsqueeze(0).to(device)

        # Get the image embedding without tracking gradients
        with torch.no_grad():
            image_embedding = model.encode_image(preprocessed_image)

        image_embeddings.append(image_embedding.cpu())

    # Concatenate directly in torch; no numpy round-trip is needed.
    return torch.cat(image_embeddings, dim=0)

def get_text_embedding(text, model, device):
    """Tokenize ``text`` with CLIP and return its ``model.encode_text`` output on the CPU.

    Args:
        text: A single prompt string; it is wrapped in a one-element list
            before tokenization.
        model: Object exposing ``encode_text``.
        device: Device the token tensor is moved to before encoding.

    Returns:
        The result of ``model.encode_text`` moved to the CPU.
    """
    tokens = clip.tokenize([text]).to(device)

    # Inference only: no gradient bookkeeping is needed.
    with torch.no_grad():
        encoded = model.encode_text(tokens)

    return encoded.cpu()

def cosine_sim(x, y):
    """Pairwise cosine similarity between the rows of ``x`` and the rows of ``y``.

    Each row is scaled to unit length (norm taken over dim 1) and the result is
    ``x_unit @ y_unit.T``. The norm is clamped to the smallest positive normal
    value of the tensor's dtype so that an all-zero row yields a similarity of 0
    rather than NaN from a division by zero.

    Args:
        x: 2-D floating-point tensor.
        y: 2-D floating-point tensor with the same size along dim 1 as ``x``.

    Returns:
        torch.Tensor: Matrix with one row per row of ``x`` and one column per
        row of ``y``.
    """
    x_norm = x.norm(dim=1, keepdim=True).clamp_min(torch.finfo(x.dtype).tiny)
    y_norm = y.norm(dim=1, keepdim=True).clamp_min(torch.finfo(y.dtype).tiny)
    return (x / x_norm) @ (y / y_norm).T  # dot product of unit vectors

def top_k_similarity(image_embeddings, text_embedding, k=5):
    """Return the indices of the ``k`` image embeddings most similar to the text.

    Similarities come from ``cosine_sim(image_embeddings, text_embedding)`` and
    the top entries are taken along dim 0 (over images). ``k`` is capped at the
    number of image embeddings so that asking for more results than there are
    frames returns all of them, best first, instead of raising.

    Args:
        image_embeddings: 2-D tensor, one row per image.
        text_embedding: 2-D tensor; a single row is the case used by this file.
        k: Maximum number of indices to return.

    Returns:
        list: For a single-row ``text_embedding``, a flat list of int indices
        ordered from most to least similar. With several text rows, a nested
        list with one inner list per rank.
    """
    similarities = cosine_sim(image_embeddings, text_embedding)

    # torch.topk raises when k exceeds the size of the selected dimension.
    k = min(k, similarities.shape[0])
    indices = torch.topk(similarities, k, dim=0).indices

    # Drop the text dimension explicitly; squeeze() would also collapse k == 1.
    if indices.shape[1] == 1:
        indices = indices[:, 0]

    return indices.cpu().tolist()

In [100]:
def main(video_path, text_prompt, interval=30, k=1):
    """Find the key frames of a video that best match a text prompt and save them.

    Key frames are extracted first so that an unreadable video fails fast,
    before the CLIP model is loaded. The selected frames are written to
    ``top_frame_<i>.jpg`` in the current working directory, and the chosen
    key-frame indices are printed.

    Args:
        video_path: Path of the video to search.
        text_prompt: Text to match frames against.
        interval: Uniform sampling interval passed to ``extract_key_frames``.
        k: Number of best-matching frames to return.

    Returns:
        list: The selected frames, best match first.

    Raises:
        ValueError: If no frames could be read from ``video_path``.
    """
    key_frames = extract_key_frames(video_path, interval=interval)
    if not key_frames:
        # Without this guard the failure surfaces later as an opaque stacking error.
        raise ValueError(f"No frames could be read from {video_path!r}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    image_embeddings = get_image_embeddings(key_frames, model, preprocess, device)
    text_embedding = get_text_embedding(text_prompt, model, device)
    top_k_indices = top_k_similarity(image_embeddings, text_embedding, k=k)
    print(top_k_indices)

    top_k_frames = [key_frames[i] for i in top_k_indices]
    for i, frame in enumerate(top_k_frames):
        cv2.imwrite(f'top_frame_{i}.jpg', frame)
    return top_k_frames
    

In [102]:
# main() already writes each selected frame to top_frame_<i>.jpg, so the
# frames are not saved a second time here.
top_k_frames = main("test2.mp4", "lava")

100%|██████████| 148/148 [00:11<00:00, 12.90it/s]


[136]
