Video retrieval by embedding multiple frames per video using CLIP (frame embeddings mean-pooled into one vector per video)

In [1]:
import os
import csv
from towhee import ops, pipe, register
from towhee.operator import PyOperator
from towhee import DataCollection
from tqdm import tqdm
import pandas as pd
import json
import numpy as np
from helpers import milvus_utils
from helpers.extract_frames import extract_frame, extract_n_frames

Connected to Milvus server at port 19530


In [97]:
# CONSTANTS

# Files
# Input sample of MSR-VTT; the notebook uses its video_id, video_path and sentence columns.
MSRVTT_SAMPLES = "./MSRVTT_1K.csv"
# Written after frame extraction (frame_path_* columns, video_id, frame_id, sentence);
# read by both the load pipeline and the search pipeline.
MSRVTT_SAMPLES_WITH_FRAMES = "./MSRVTT_1K_frames.csv"
# file created using raw FIRE judgements, see clean_fire_judgements.ipynb
# (read below for its video_id and sentence columns)
FIRE_BENCHMARK_Q_JUDGEMENTS = "./fire_benchmark_q_judgements.csv" 

# Database Collections
# Milvus collection holding one mean-pooled, L2-normalised CLIP image vector per video.
FRAME_RET_COLLECTION = "msrvtt_multi_frame_ret_8b"

In [3]:
raw_samples_df = pd.read_csv(MSRVTT_SAMPLES)
# Preview just the columns the rest of the notebook reads from this frame.
raw_samples_df.loc[:, ['video_id', 'video_path', 'sentence']].head()

Unnamed: 0,video_id,video_path,sentence
0,video7579,./test_1k_compress/video7579.mp4,a girl wearing red top and black trouser is pu...
1,video7725,./test_1k_compress/video7725.mp4,young people sit around the edges of a room cl...
2,video9258,./test_1k_compress/video9258.mp4,a person is using a phone
3,video7365,./test_1k_compress/video7365.mp4,cartoon people are eating at a restaurant
4,video8068,./test_1k_compress/video8068.mp4,a woman on a couch talks to a a man


Before we embed video frames, we need to extract `NUM_FRAMES` (here 8) equally spaced frames from each of the 1000 videos. Their CLIP embeddings are later averaged into a single vector per video.

In [20]:
import cv2


def extract_n_frames_2(video_path, output_folder, n=10):
    """
    Extract n equally spaced frames from a video file and save them to a directory.

    Frames are written as ``<video_name>_frame_<i>_of_<n>.jpg``, where ``i`` is the
    1-based position in the sampled sequence, zero-padded to three digits. If the
    video reports ``n`` or fewer frames, every frame is extracted instead.

    Args:
        video_path (str): Path to the input video file
        output_folder (str): Path to the output directory to save frames
            (created if it does not exist)
        n (int): Number of equally spaced frames to extract (default: 10)

    Returns:
        list[str]: Paths of the frames that were actually read AND written, in
        sampling order. Frames that fail to read or save are skipped (a message
        is printed), so the list can be shorter than ``n``.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Used as the filename prefix, e.g. "video7579" for ".../video7579.mp4".
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    saved_paths = []

    video = cv2.VideoCapture(video_path)
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        if total_frames <= n:
            # Too few frames to sample from: take all of them.
            frame_indices = list(range(total_frames))
        else:
            # Evenly spaced starting indices across the whole video.
            frame_indices = [int(i * total_frames / n) for i in range(n)]

        for i, frame_index in enumerate(frame_indices):
            video.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ret, frame = video.read()
            if not ret:
                # Skip (not abort) so the remaining frames are still extracted.
                print(f"Failed to read frame at index {frame_index}")
                continue

            output_filename = f"{video_name}_frame_{i+1:03d}_of_{n:03d}.jpg"
            output_path = os.path.join(output_folder, output_filename)
            if not cv2.imwrite(output_path, frame):
                print(f"Failed to write frame to {output_path}")
                continue

            # Only report paths that really exist on disk.
            saved_paths.append(output_path)
    finally:
        # Always release the capture handle, even if decoding raises.
        video.release()

    return saved_paths

In [35]:
NUM_FRAMES = 8
# One output folder for every video; the suffix follows NUM_FRAMES
# (resolves to "./test_1k_images_8" for NUM_FRAMES = 8).
images_dir = f"./test_1k_images_{NUM_FRAMES}"

print(f"Creating {NUM_FRAMES} frames per video...")
for row_idx, row in tqdm(raw_samples_df.iterrows(), total=len(raw_samples_df)):
    video_path = row['video_path']
    # e.g. "./test_1k_compress/video7579.mp4" -> "video7579"
    image_name = os.path.basename(video_path).split('.')[0]

    all_frame_paths = extract_n_frames_2(video_path, images_dir, NUM_FRAMES)
    # Fail with a clear message instead of a bare IndexError when a video
    # yields fewer frame paths than requested.
    if len(all_frame_paths) < NUM_FRAMES:
        raise RuntimeError(
            f"Expected {NUM_FRAMES} frames from {video_path}, got {len(all_frame_paths)}"
        )
    for num_f, frame_path in enumerate(all_frame_paths[:NUM_FRAMES], start=1):
        raw_samples_df.at[row_idx, f'frame_path_{num_f}'] = frame_path

    raw_samples_df.at[row_idx, 'frame_id'] = image_name

# Now this should contain new columns with the frame_path_* and frame_id
raw_samples_df

Creating 8 per video...


100%|██████████| 1000/1000 [00:42<00:00, 23.54it/s]


Unnamed: 0.1,Unnamed: 0,key,vid_key,video_id,sentence,video_path,frame_path_1,frame_path_2,frame_path_3,frame_path_4,frame_id,frame_path_5,frame_path_6,frame_path_7,frame_path_8
0,521,ret521,msr7579,video7579,a girl wearing red top and black trouser is pu...,./test_1k_compress/video7579.mp4,./test_1k_images_8/video7579_frame_001_of_008.jpg,./test_1k_images_8/video7579_frame_002_of_008.jpg,./test_1k_images_8/video7579_frame_003_of_008.jpg,./test_1k_images_8/video7579_frame_004_of_008.jpg,video7579,./test_1k_images_8/video7579_frame_005_of_008.jpg,./test_1k_images_8/video7579_frame_006_of_008.jpg,./test_1k_images_8/video7579_frame_007_of_008.jpg,./test_1k_images_8/video7579_frame_008_of_008.jpg
1,737,ret737,msr7725,video7725,young people sit around the edges of a room cl...,./test_1k_compress/video7725.mp4,./test_1k_images_8/video7725_frame_001_of_008.jpg,./test_1k_images_8/video7725_frame_002_of_008.jpg,./test_1k_images_8/video7725_frame_003_of_008.jpg,./test_1k_images_8/video7725_frame_004_of_008.jpg,video7725,./test_1k_images_8/video7725_frame_005_of_008.jpg,./test_1k_images_8/video7725_frame_006_of_008.jpg,./test_1k_images_8/video7725_frame_007_of_008.jpg,./test_1k_images_8/video7725_frame_008_of_008.jpg
2,740,ret740,msr9258,video9258,a person is using a phone,./test_1k_compress/video9258.mp4,./test_1k_images_8/video9258_frame_001_of_008.jpg,./test_1k_images_8/video9258_frame_002_of_008.jpg,./test_1k_images_8/video9258_frame_003_of_008.jpg,./test_1k_images_8/video9258_frame_004_of_008.jpg,video9258,./test_1k_images_8/video9258_frame_005_of_008.jpg,./test_1k_images_8/video9258_frame_006_of_008.jpg,./test_1k_images_8/video9258_frame_007_of_008.jpg,./test_1k_images_8/video9258_frame_008_of_008.jpg
3,660,ret660,msr7365,video7365,cartoon people are eating at a restaurant,./test_1k_compress/video7365.mp4,./test_1k_images_8/video7365_frame_001_of_008.jpg,./test_1k_images_8/video7365_frame_002_of_008.jpg,./test_1k_images_8/video7365_frame_003_of_008.jpg,./test_1k_images_8/video7365_frame_004_of_008.jpg,video7365,./test_1k_images_8/video7365_frame_005_of_008.jpg,./test_1k_images_8/video7365_frame_006_of_008.jpg,./test_1k_images_8/video7365_frame_007_of_008.jpg,./test_1k_images_8/video7365_frame_008_of_008.jpg
4,411,ret411,msr8068,video8068,a woman on a couch talks to a a man,./test_1k_compress/video8068.mp4,./test_1k_images_8/video8068_frame_001_of_008.jpg,./test_1k_images_8/video8068_frame_002_of_008.jpg,./test_1k_images_8/video8068_frame_003_of_008.jpg,./test_1k_images_8/video8068_frame_004_of_008.jpg,video8068,./test_1k_images_8/video8068_frame_005_of_008.jpg,./test_1k_images_8/video8068_frame_006_of_008.jpg,./test_1k_images_8/video8068_frame_007_of_008.jpg,./test_1k_images_8/video8068_frame_008_of_008.jpg
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,106,ret106,msr7034,video7034,man in black shirt is holding a baby upside do...,./test_1k_compress/video7034.mp4,./test_1k_images_8/video7034_frame_001_of_008.jpg,./test_1k_images_8/video7034_frame_002_of_008.jpg,./test_1k_images_8/video7034_frame_003_of_008.jpg,./test_1k_images_8/video7034_frame_004_of_008.jpg,video7034,./test_1k_images_8/video7034_frame_005_of_008.jpg,./test_1k_images_8/video7034_frame_006_of_008.jpg,./test_1k_images_8/video7034_frame_007_of_008.jpg,./test_1k_images_8/video7034_frame_008_of_008.jpg
996,270,ret270,msr7568,video7568,the queen of england is seen walking with an e...,./test_1k_compress/video7568.mp4,./test_1k_images_8/video7568_frame_001_of_008.jpg,./test_1k_images_8/video7568_frame_002_of_008.jpg,./test_1k_images_8/video7568_frame_003_of_008.jpg,./test_1k_images_8/video7568_frame_004_of_008.jpg,video7568,./test_1k_images_8/video7568_frame_005_of_008.jpg,./test_1k_images_8/video7568_frame_006_of_008.jpg,./test_1k_images_8/video7568_frame_007_of_008.jpg,./test_1k_images_8/video7568_frame_008_of_008.jpg
997,860,ret860,msr7979,video7979,people talking about a fight,./test_1k_compress/video7979.mp4,./test_1k_images_8/video7979_frame_001_of_008.jpg,./test_1k_images_8/video7979_frame_002_of_008.jpg,./test_1k_images_8/video7979_frame_003_of_008.jpg,./test_1k_images_8/video7979_frame_004_of_008.jpg,video7979,./test_1k_images_8/video7979_frame_005_of_008.jpg,./test_1k_images_8/video7979_frame_006_of_008.jpg,./test_1k_images_8/video7979_frame_007_of_008.jpg,./test_1k_images_8/video7979_frame_008_of_008.jpg
998,435,ret435,msr7356,video7356,a vehicle with details on what comes with it b...,./test_1k_compress/video7356.mp4,./test_1k_images_8/video7356_frame_001_of_008.jpg,./test_1k_images_8/video7356_frame_002_of_008.jpg,./test_1k_images_8/video7356_frame_003_of_008.jpg,./test_1k_images_8/video7356_frame_004_of_008.jpg,video7356,./test_1k_images_8/video7356_frame_005_of_008.jpg,./test_1k_images_8/video7356_frame_006_of_008.jpg,./test_1k_images_8/video7356_frame_007_of_008.jpg,./test_1k_images_8/video7356_frame_008_of_008.jpg


In [96]:
# We write the transformed samples data to a CSV file so it can be loaded into the load pipeline
# raw_samples_df[['video_id', 'frame_path', 'frame_id', 'sentence']].to_csv(MSRVTT_SAMPLES_WITH_FRAMES, index=False)
def frame_path_cols(num_frames):
    """Return the column names ``frame_path_1`` .. ``frame_path_<num_frames>`` in order."""
    return [f'frame_path_{idx}' for idx in range(1, num_frames + 1)]
columns = [*frame_path_cols(NUM_FRAMES), 'video_id', 'frame_id', 'sentence']
print(columns)
# Written without the index: downstream readers look columns up by name.
raw_samples_df.loc[:, columns].to_csv(MSRVTT_SAMPLES_WITH_FRAMES, index=False)

['frame_path_1', 'frame_path_2', 'frame_path_3', 'frame_path_4', 'frame_path_5', 'frame_path_6', 'frame_path_7', 'frame_path_8', 'video_id', 'frame_id', 'sentence']


In [98]:
# Create the collection in Milvus to store image embeddings
# 512 is the vector dimension; it has to match the size of the CLIP image
# embeddings inserted by the load pipeline (clip_vit_base_patch16).
# NOTE(review): confirm against the model config if the model name changes.
milvus_utils.create_milvus_collection(FRAME_RET_COLLECTION, 512)

<Collection>:
-------------
<name>: msrvtt_multi_frame_ret_8b
<description>: video retrieval
<schema>: {'auto_id': False, 'description': 'video retrieval', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}]}

In [None]:
def read_frame_loader_csv(csv_path, encoding='utf-8-sig', num_frames=None):
    """Yield one ``[numeric_id, frame_path_1, ..., frame_path_N]`` list per CSV row.

    The numeric id is ``frame_id`` with its first ``len('video')`` characters dropped
    and cast to ``int`` (``"video7579"`` -> ``7579``); the load pipeline inserts it
    as the Milvus primary key.

    Args:
        csv_path: Path to the frames CSV (``frame_id`` and ``frame_path_*`` columns).
        encoding: File encoding; the default also strips a UTF-8 BOM if present.
        num_frames: How many ``frame_path_*`` columns to read. Defaults to the
            notebook-level ``NUM_FRAMES`` (looked up at call time).
    """
    if num_frames is None:
        num_frames = NUM_FRAMES
    path_columns = [f'frame_path_{i}' for i in range(1, num_frames + 1)]

    # newline='' lets the csv module handle line endings / quoted newlines itself.
    with open(csv_path, 'r', encoding=encoding, newline='') as f:
        for line in csv.DictReader(f):
            numeric_id = int(line['frame_id'][len('video'):])
            yield [numeric_id] + [line[col] for col in path_columns]


def _build_frame_loader_pipeline():
    """Build the towhee load pipeline for ``NUM_FRAMES`` frames per video.

    For each row yielded by ``read_frame_loader_csv``, it runs these steps:
    1. Decode every frame as RGB.
    2. Embed each frame with CLIP's image tower.
    3. Mean-pool the per-frame vectors and L2-normalise the result.
    4. Insert ``(frame_id, vec)`` into ``FRAME_RET_COLLECTION``.

    The stage count follows ``NUM_FRAMES``, so the flat_map schema always matches
    the number of paths the reader yields. The pipeline outputs ``frame_id``.
    """
    frame_nums = range(1, NUM_FRAMES + 1)
    path_keys = tuple(f'f_path_{i}' for i in frame_nums)
    img_keys = tuple(f'img{i}' for i in frame_nums)
    vec_keys = tuple(f'vec{i}' for i in frame_nums)

    builder = pipe.input('csv_file').flat_map('csv_file', ('frame_id',) + path_keys, read_frame_loader_csv)
    # All decode stages first, then all embedding stages (same order as before).
    for path_key, img_key in zip(path_keys, img_keys):
        builder = builder.map(path_key, img_key, ops.image_decode.cv2('rgb'))
    for img_key, vec_key in zip(img_keys, vec_keys):
        builder = builder.map(img_key, vec_key, ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', device='mps'))

    return (
        builder
        # Mean-pool the per-frame embeddings into a single video-level vector.
        .map(vec_keys, 'vec', lambda *vecs: np.mean(vecs, axis=0))
        # Re-normalise: a mean of vectors is generally not unit length.
        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
        .map(('frame_id', 'vec'), (), ops.ann_insert.milvus_client(collection_name=FRAME_RET_COLLECTION))
        .output('frame_id')
    )


frame_loader_pipeline = _build_frame_loader_pipeline()



2025-04-17 11:41:06,721 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-17 11:41:06,752 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-17 11:41:06,790 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-17 11:41:06,798 - 18874396672 - connectionpool.py-connectionpool:1049 - DEBUG: Starting new HTTPS connection (1): huggingface.co:443
2025-04-17 11:41:06,864 - 18874396672 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch16 HTTP/1.1" 200 3499
2025-04-17 11:41:06,914 - 18874396672 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co

2025-04-17 11:41:12,915 - 14630809600 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch16/discussions?p=0 HTTP/1.1" 200 6380
2025-04-17 11:41:12,971 - 14630809600 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch16/commits/refs%2Fpr%2F10 HTTP/1.1" 200 4064
2025-04-17 11:41:13,010 - 14630809600 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/refs%2Fpr%2F10/model.safetensors.index.json HTTP/1.1" 404 0
2025-04-17 11:41:13,041 - 14630809600 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/refs%2Fpr%2F10/model.safetensors HTTP/1.1" 302 0


In [100]:
# Run the load pipeline over every row of the frames CSV: embeds each video's frames and inserts one vector per video into Milvus.
frame_loader_pipeline(MSRVTT_SAMPLES_WITH_FRAMES)

2025-04-17 11:41:16,308 - 14630809600 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-17 11:41:16,308 - 14647635968 - node.py-node:167 - INFO: Begin to run Node-read_frame_loader_csv-0
2025-04-17 11:41:16,309 - 14664462336 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-1
2025-04-17 11:41:16,309 - 14681288704 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-2
2025-04-17 11:41:16,310 - 14698115072 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-3
2025-04-17 11:41:16,310 - 14630809600 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-4
2025-04-17 11:41:16,310 - 14714941440 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-5
2025-04-17 11:41:16,311 - 14731767808 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-6
2025-04-17 11:41:16,311 - 14748594176 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-7
2025-04-17 11:41:16,312 - 14765420544 - node.py-node:167 - INFO: Begin to run Node-image-

<towhee.runtime.data_queue.DataQueue at 0x32f007520>

In [101]:
def read_frame_search_csv(csv_file, id_column='frame_id'):
    """Yield ``(relevant_id, sentence)`` pairs, one per CSV row.

    Args:
        csv_file: Path to a CSV with a ``sentence`` column and an id column.
        id_column: Column that holds the relevant item's id for each query.
            Defaults to ``'frame_id'`` (the frames CSV); pass ``'video_id'`` for
            CSVs such as the FIRE judgements that identify rows by video.
    """
    # utf-8-sig strips a BOM if present; newline='' is what the csv module expects.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for line in csv.DictReader(f):
            yield line[id_column], line['sentence']
            
frame_search_pipeline = (
    pipe.input('csv_file')
    # One query per CSV row; `rel_frame_id` is the single relevant video used as ground truth.
    .flat_map('csv_file', ('rel_frame_id', 'query'), read_frame_search_csv)
    # Embed the caption with CLIP's text tower, then L2-normalise like the stored video vectors.
    .map('query', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text', device='mps'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'top10_raw_res', ops.ann_search.milvus_client(collection_name=FRAME_RET_COLLECTION, limit=10))
    # Prefixes of the ranked hit list, consumed by the recall@1/5/10 metrics.
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .output('rel_frame_id', 'query', 'top1', 'top5', 'top10')
)

2025-04-17 11:48:50,225 - 8454604864 - connectionpool.py-connectionpool:289 - DEBUG: Resetting dropped connection: huggingface.co
2025-04-17 11:48:50,404 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-17 11:48:50,479 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-17 11:48:50,513 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-17 11:48:50,518 - 15049191424 - connectionpool.py-connectionpool:1049 - DEBUG: Starting new HTTPS connection (1): huggingface.co:443
2025-04-17 11:48:50,588 - 15049191424 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch

In [103]:
# Run every caption in the frames CSV as a query and collect the ranked Milvus hits.
ret_dc = DataCollection(frame_search_pipeline(MSRVTT_SAMPLES_WITH_FRAMES))

2025-04-17 11:50:21,987 - 15267295232 - node.py-node:167 - INFO: Begin to run Node-read_frame_search_csv-0


In [104]:
# Preview a few rows of the search results.
ret_dc.show()

rel_frame_id,query,top1,top5,top10
video7579,a girl wearing red top and black trouser is putting a sweater on a dog,"[[7579, 1.360194444656372]] len=1","[[7579, 1.360194444656372],[9451, 1.4077376127243042],[7730, 1.4435832500457764],[9034, 1.4514113664627075],...] len=5","[[7579, 1.360194444656372],[9451, 1.4077376127243042],[7730, 1.4435832500457764],[9034, 1.4514113664627075],...] len=10"
video7725,young people sit around the edges of a room clapping and raising their arms while others dance in the center during a party,"[[7725, 1.3989646434783936]] len=1","[[7725, 1.3989646434783936],[7444, 1.4208428859710693],[8556, 1.4406566619873047],[8441, 1.4462244510650635],...] len=5","[[7725, 1.3989646434783936],[7444, 1.4208428859710693],[8556, 1.4406566619873047],[8441, 1.4462244510650635],...] len=10"
video9258,a person is using a phone,"[[9257, 1.422074317932129]] len=1","[[9257, 1.422074317932129],[9697, 1.4267330169677734],[9258, 1.4296905994415283],[8945, 1.4394712448120117],...] len=5","[[9257, 1.422074317932129],[9697, 1.4267330169677734],[9258, 1.4296905994415283],[8945, 1.4394712448120117],...] len=10"
video7365,cartoon people are eating at a restaurant,"[[9777, 1.3951191902160645]] len=1","[[9777, 1.3951191902160645],[9537, 1.4230724573135376],[7365, 1.4236586093902588],[7741, 1.4339861869812012],...] len=5","[[9777, 1.3951191902160645],[9537, 1.4230724573135376],[7365, 1.4236586093902588],[7741, 1.4339861869812012],...] len=10"
video8068,a woman on a couch talks to a a man,"[[7724, 1.3635627031326294]] len=1","[[7724, 1.3635627031326294],[7341, 1.4104743003845215],[9347, 1.4147610664367676],[7685, 1.4171432256698608],...] len=5","[[7724, 1.3635627031326294],[7341, 1.4104743003845215],[9347, 1.4147610664367676],[7685, 1.4171432256698608],...] len=10"


In [105]:
# TODO remove this, import from helpers and rerun the whole notebook

def twohee_data_col_to_df(twohee_data_collection):
    """Convert a towhee ``DataCollection`` of search results into a DataFrame.

    Each entity's attributes (``vars(entity)``) become one row. A ``ground_truth``
    column is added by dropping the first ``len('video')`` characters of the
    relevant-id column and casting to ``int`` (``"video7579"`` -> ``7579``), so it can
    be compared with the integer ids returned by Milvus.

    ``rel_frame_id`` is used when present, otherwise ``rel_video_id``.
    (The function name keeps the historical "twohee" spelling of towhee.)

    Args:
        twohee_data_collection: Object exposing ``to_list()`` that returns entities
            carrying the result fields as instance attributes.

    Returns:
        pd.DataFrame: One row per query, with a ``ground_truth`` column added.

    Raises:
        ValueError: If neither ``rel_frame_id`` nor ``rel_video_id`` is present.
    """
    res_df = pd.DataFrame([vars(entity) for entity in twohee_data_collection.to_list()])

    # A single if/elif/else chain: either column is sufficient on its own, and
    # `rel_frame_id` takes precedence when both are present.
    if 'rel_frame_id' in res_df.columns:
        id_column = 'rel_frame_id'
    elif 'rel_video_id' in res_df.columns:
        id_column = 'rel_video_id'
    else:
        raise ValueError("No rel_video_id or rel_frame_id found in the DataCollection")

    res_df['ground_truth'] = res_df[id_column].apply(lambda x: int(x[len('video'):]))
    return res_df


def average_precision(ground_truth, predictions):
    """
    Average Precision (AP) of one ranked prediction list.

    Every rank at which ``ground_truth`` appears contributes the precision at
    that rank; AP is the mean of those contributions.

    Args:
        ground_truth (int): The relevant video ID.
        predictions (list): Predicted video IDs in rank order.

    Returns:
        float: AP for the query, or 0 when ``ground_truth`` never appears.
    """
    hit_ranks = [rank for rank, pred in enumerate(predictions, start=1) if pred == ground_truth]
    if not hit_ranks:
        return 0
    precisions_at_hits = [hit_no / rank for hit_no, rank in enumerate(hit_ranks, start=1)]
    return sum(precisions_at_hits) / len(hit_ranks)


def calculate_mean_average_precision(df):
    """
    Mean Average Precision (MAP) over every query (row) in ``df``.

    For each row, the ids in ``top10`` (first element of each ``[id, score]`` hit)
    are scored against ``ground_truth`` with ``average_precision``, and the
    per-query scores are averaged.

    Args:
        df (pd.DataFrame): Needs 'ground_truth' and 'top10' columns.

    Returns:
        float: The MAP score, or 0 for an empty DataFrame.
    """
    per_query_ap = [
        average_precision(row['ground_truth'], [hit[0] for hit in row['top10']])
        for _, row in df.iterrows()
    ]
    if not per_query_ap:
        return 0
    return sum(per_query_ap) / len(per_query_ap)


def calculate_recall(df):
    """
    Calculate recall@1, recall@5, and recall@10 for the given dataframe.

    A query is a hit at ``k`` when its ``ground_truth`` id appears among the ids
    (first element of each ``[id, score]`` pair) in column ``top{k}``.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'ground_truth', 'top1', 'top5', 'top10'.

    Returns:
        dict: ``{'recall@1': ..., 'recall@5': ..., 'recall@10': ...}``. For an empty
        DataFrame every value is 0.0 rather than raising ZeroDivisionError, matching
        how the MAP and NDCG helpers treat empty input.
    """
    cutoffs = (1, 5, 10)
    hits = dict.fromkeys(cutoffs, 0)
    total_queries = len(df)

    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        for k in cutoffs:
            if ground_truth in [pred[0] for pred in row[f'top{k}']]:
                hits[k] += 1

    if total_queries == 0:
        return {f'recall@{k}': 0.0 for k in cutoffs}
    return {f'recall@{k}': hits[k] / total_queries for k in cutoffs}


def ndcg_score(ground_truth, predictions, k=10):
    """
    Normalized Discounted Cumulative Gain (NDCG@k) for a single query.

    Only the first ``k`` predictions are considered. A prediction has gain 1 when
    its id (first element) equals ``ground_truth`` and 0 otherwise. The ideal DCG
    is computed from those same gains sorted in descending order, so a query whose
    ground truth is missing from the top ``k`` scores 0.

    Args:
        ground_truth (int): The relevant video ID.
        predictions (list): Ranked hits ``[(id, score), ...]``.
        k (int): Cut-off rank.

    Returns:
        float: NDCG@k, or 0 when no considered prediction is relevant.
    """
    gains = [int(hit[0] == ground_truth) for hit in predictions[:k]]

    def _discounted_sum(values):
        # Rank r (0-based) is discounted by log2(r + 2).
        total = 0
        for rank0, gain in enumerate(values):
            total += gain / np.log2(rank0 + 2)
        return total

    achieved = _discounted_sum(gains)
    best = _discounted_sum(sorted(gains, reverse=True))
    if best <= 0:
        return 0
    return achieved / best

# call this function to get the mean NDCG score over all queries


def calculate_ndcg(df, k=10):
    """
    Mean NDCG@k over every query (row) in ``df``.

    Each row's ``top10`` hits are scored against its ``ground_truth`` with
    ``ndcg_score``. Because only ``top10`` is read, any ``k`` above 10 behaves
    like ``k=10``.

    Args:
        df (pd.DataFrame): Needs 'ground_truth' and 'top10' columns.
        k (int): The number of top predictions to consider.

    Returns:
        float: The mean NDCG score, or 0 for an empty DataFrame.
    """
    per_query = [
        ndcg_score(row['ground_truth'], row['top10'], k)
        for _, row in df.iterrows()
    ]
    if not per_query:
        return 0
    return sum(per_query) / len(per_query)


def get_all_eval_scores(df):
    """
    Compute every retrieval metric for a results DataFrame.

    Args:
        df (pd.DataFrame): DataFrame with 'ground_truth', 'top1', 'top5' and 'top10'
            columns, e.g. the output of ``twohee_data_col_to_df``.

    Returns:
        dict: A plain dict (not a DataFrame) with keys 'recall@1', 'recall@5',
        'recall@10', 'map', 'ndcg@1', 'ndcg@5', 'ndcg@10', in that order, mapped
        to their scores.
    """
    recall_scores = calculate_recall(df)
    eval_scores = {key: recall_scores[key] for key in ('recall@1', 'recall@5', 'recall@10')}
    eval_scores['map'] = calculate_mean_average_precision(df)
    for k in (1, 5, 10):
        eval_scores[f'ndcg@{k}'] = calculate_ndcg(df, k=k)
    return eval_scores


In [106]:
# Inspect the results as a DataFrame, including the integer ground_truth column.
twohee_data_col_to_df(ret_dc)

Unnamed: 0,rel_frame_id,query,top1,top5,top10,ground_truth
0,video7579,a girl wearing red top and black trouser is pu...,"[[7579, 1.360194444656372]]","[[7579, 1.360194444656372], [9451, 1.407737612...","[[7579, 1.360194444656372], [9451, 1.407737612...",7579
1,video7725,young people sit around the edges of a room cl...,"[[7725, 1.3989646434783936]]","[[7725, 1.3989646434783936], [7444, 1.42084288...","[[7725, 1.3989646434783936], [7444, 1.42084288...",7725
2,video9258,a person is using a phone,"[[9257, 1.422074317932129]]","[[9257, 1.422074317932129], [9697, 1.426733016...","[[9257, 1.422074317932129], [9697, 1.426733016...",9258
3,video7365,cartoon people are eating at a restaurant,"[[9777, 1.3951191902160645]]","[[9777, 1.3951191902160645], [9537, 1.42307245...","[[9777, 1.3951191902160645], [9537, 1.42307245...",7365
4,video8068,a woman on a couch talks to a a man,"[[7724, 1.3635627031326294]]","[[7724, 1.3635627031326294], [7341, 1.41047430...","[[7724, 1.3635627031326294], [7341, 1.41047430...",8068
...,...,...,...,...,...,...
995,video7034,man in black shirt is holding a baby upside do...,"[[9522, 1.4664983749389648]]","[[9522, 1.4664983749389648], [9320, 1.47163987...","[[9522, 1.4664983749389648], [9320, 1.47163987...",7034
996,video7568,the queen of england is seen walking with an e...,"[[7568, 1.2713391780853271]]","[[7568, 1.2713391780853271], [7116, 1.43614828...","[[7568, 1.2713391780853271], [7116, 1.43614828...",7568
997,video7979,people talking about a fight,"[[7211, 1.4263927936553955]]","[[7211, 1.4263927936553955], [7835, 1.44549965...","[[7211, 1.4263927936553955], [7835, 1.44549965...",7979
998,video7356,a vehicle with details on what comes with it b...,"[[7356, 1.3739084005355835]]","[[7356, 1.3739084005355835], [9358, 1.39286637...","[[7356, 1.3739084005355835], [9358, 1.39286637...",7356


In [108]:
# Retrieval metrics for the MSR-VTT caption queries.
get_all_eval_scores(twohee_data_col_to_df(ret_dc))

{'recall@1': 0.332,
 'recall@5': 0.551,
 'recall@10': 0.656,
 'map': 0.42762857142857114,
 'ndcg@1': 0.332,
 'ndcg@5': 0.44768332562758373,
 'ndcg@10': 0.48185493150879904}

# Try evaluation against queries from FIRE benchmark

We are working with a sample of MSR-VTT, and our evaluation pipeline supports only one relevant video per query. Hence we need to filter the full FIRE benchmark to include only queries whose videos we have sampled and that have a single relevant result.

FIRE_BENCHMARK_Q_JUDGEMENTS is created in the notebook `./clean_fire_judgements.ipynb`

In [110]:
# Run query pipeline using FIRE


# CSV parser function and pipeline recreated since the FIRE csv uses `video_id` instead of `frame_id`
def read_frame_search_fire_csv(csv_file):
    """Yield ``(video_id, sentence)`` pairs from the FIRE judgements CSV.

    Same shape as ``read_frame_search_csv``, but the relevant id comes from the
    ``video_id`` column, which is how the FIRE CSV identifies each query's video.
    """
    with open(csv_file, 'r', encoding='utf-8-sig') as handle:
        yield from ((record['video_id'], record['sentence']) for record in csv.DictReader(handle))
            
frame_search_fire_pipeline = (
    pipe.input('csv_file')
    # FIRE rows are keyed by `video_id`; it is emitted as `rel_frame_id` so the
    # evaluation helpers treat FIRE queries exactly like the MSR-VTT ones.
    .flat_map('csv_file', ('rel_frame_id', 'query'), read_frame_search_fire_csv)
    .map('query', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text', device='mps'))
    .map('vec', 'vec', lambda text_vec: text_vec / np.linalg.norm(text_vec))
    .map('vec', 'top10_raw_res', ops.ann_search.milvus_client(collection_name=FRAME_RET_COLLECTION, limit=10))
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda hits: (hits[:1], hits[:5], hits[:10]))
    .output('rel_frame_id', 'query', 'top1', 'top5', 'top10')
)

fire_query_results = DataCollection(frame_search_fire_pipeline(FIRE_BENCHMARK_Q_JUDGEMENTS))
fire_query_results.show()

2025-04-17 12:06:37,280 - 8454604864 - connectionpool.py-connectionpool:289 - DEBUG: Resetting dropped connection: huggingface.co
2025-04-17 12:06:37,372 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-17 12:06:37,399 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-17 12:06:37,464 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-17 12:06:37,467 - 15284121600 - connectionpool.py-connectionpool:289 - DEBUG: Resetting dropped connection: huggingface.co
2025-04-17 12:06:37,770 - 15284121600 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch16 HTTP/1.

rel_frame_id,query,top1,top5,top10
video8469,two parrots in a bird cage one white chick and on green adult,"[[7546, 1.4175093173980713]] len=1","[[7546, 1.4175093173980713],[8469, 1.4307200908660889],[7849, 1.4471447467803955],[8481, 1.4707032442092896],...] len=5","[[7546, 1.4175093173980713],[8469, 1.4307200908660889],[7849, 1.4471447467803955],[8481, 1.4707032442092896],...] len=10"
video9687,a man chopping lobster and taking off the shell,"[[9687, 1.3523515462875366]] len=1","[[9687, 1.3523515462875366],[9309, 1.3865060806274414],[8025, 1.3969764709472656],[8686, 1.4153034687042236],...] len=5","[[9687, 1.3523515462875366],[9309, 1.3865060806274414],[8025, 1.3969764709472656],[8686, 1.4153034687042236],...] len=10"
video7698,two women are walking in a parking lot,"[[8899, 1.390440583229065]] len=1","[[8899, 1.390440583229065],[7138, 1.4211058616638184],[7212, 1.42289137840271],[8016, 1.4232873916625977],...] len=5","[[8899, 1.390440583229065],[7138, 1.4211058616638184],[7212, 1.42289137840271],[8016, 1.4232873916625977],...] len=10"
video9503,a woman is talking about how jeans with patches or rips is trendy,"[[9503, 1.387997031211853]] len=1","[[9503, 1.387997031211853],[8825, 1.424526572227478],[7724, 1.441739559173584],[8488, 1.4568039178848267],...] len=5","[[9503, 1.387997031211853],[8825, 1.424526572227478],[7724, 1.441739559173584],[8488, 1.4568039178848267],...] len=10"
video8903,a naked child runs through a field,"[[8125, 1.4102811813354492]] len=1","[[8125, 1.4102811813354492],[9240, 1.42606782913208],[8931, 1.4346017837524414],[9031, 1.4469490051269531],...] len=5","[[8125, 1.4102811813354492],[9240, 1.42606782913208],[8931, 1.4346017837524414],[9031, 1.4469490051269531],...] len=10"


In [111]:
# Retrieval metrics for the filtered FIRE benchmark queries.
get_all_eval_scores(twohee_data_col_to_df(fire_query_results))

{'recall@1': 0.49044585987261147,
 'recall@5': 0.6878980891719745,
 'recall@10': 0.7484076433121019,
 'map': 0.5742796481649985,
 'ndcg@1': 0.49044585987261147,
 'ndcg@5': 0.5961737289559949,
 'ndcg@10': 0.616343918349418}