Video retrieval by embedding single frames using CLIP

In [10]:
import os
import csv
from towhee import ops, pipe, register
from towhee.operator import PyOperator
from towhee import DataCollection
from tqdm import tqdm
import pandas as pd
import json
import numpy as np
from helpers import milvus_utils
from helpers.extract_frames import extract_frame

In [15]:
# CONSTANTS

# Files
MSRVTT_SAMPLES = "./MSRVTT_1K.csv"
MSRVTT_SAMPLES_WITH_FRAMES = "./MSRVTT_1K_frames.csv"
# file created using raw FIRE judgements, see clean_fire_judgements.ipynb
FIRE_BENCHMARK_Q_JUDGEMENTS = "./fire_benchmark_q_judgements.csv" 

# Database Collections
FRAME_RET_COLLECTION = "msrvtt_frame_ret_1"

In [5]:
raw_samples_df = pd.read_csv(MSRVTT_SAMPLES)
raw_samples_df[['video_id', 'video_path', 'sentence']].head()

Unnamed: 0,video_id,video_path,sentence
0,video7579,./test_1k_compress/video7579.mp4,a girl wearing red top and black trouser is pu...
1,video7725,./test_1k_compress/video7725.mp4,young people sit around the edges of a room cl...
2,video9258,./test_1k_compress/video9258.mp4,a person is using a phone
3,video7365,./test_1k_compress/video7365.mp4,cartoon people are eating at a restaurant
4,video8068,./test_1k_compress/video8068.mp4,a woman on a couch talks to a a man


Before we embed video frames, we need to extract and/or construct a single frame from each of the 1000 videos.

In [13]:
for row in tqdm(raw_samples_df.iterrows(), total=len(raw_samples_df)):
    video_path = row[1]['video_path']
    images_dir = "./test_1k_images"
    image_name = os.path.basename(video_path).split('.')[0]
    image_path = os.path.join(images_dir, image_name) + ".jpg"
    # extract frame 
    extract_frame(video_path, image_path)
    # add column with val to current row
    raw_samples_df.at[row[0], 'frame_path'] = image_path
    raw_samples_df.at[row[0], 'frame_id'] = image_name
    
# Now this should contain new columns with the frame_path and frame_id
raw_samples_df

100%|██████████| 1000/1000 [00:09<00:00, 110.02it/s]


Unnamed: 0.1,Unnamed: 0,key,vid_key,video_id,sentence,video_path,frame_path,frame_id
0,521,ret521,msr7579,video7579,a girl wearing red top and black trouser is pu...,./test_1k_compress/video7579.mp4,./test_1k_images/video7579.jpg,video7579
1,737,ret737,msr7725,video7725,young people sit around the edges of a room cl...,./test_1k_compress/video7725.mp4,./test_1k_images/video7725.jpg,video7725
2,740,ret740,msr9258,video9258,a person is using a phone,./test_1k_compress/video9258.mp4,./test_1k_images/video9258.jpg,video9258
3,660,ret660,msr7365,video7365,cartoon people are eating at a restaurant,./test_1k_compress/video7365.mp4,./test_1k_images/video7365.jpg,video7365
4,411,ret411,msr8068,video8068,a woman on a couch talks to a a man,./test_1k_compress/video8068.mp4,./test_1k_images/video8068.jpg,video8068
...,...,...,...,...,...,...,...,...
995,106,ret106,msr7034,video7034,man in black shirt is holding a baby upside do...,./test_1k_compress/video7034.mp4,./test_1k_images/video7034.jpg,video7034
996,270,ret270,msr7568,video7568,the queen of england is seen walking with an e...,./test_1k_compress/video7568.mp4,./test_1k_images/video7568.jpg,video7568
997,860,ret860,msr7979,video7979,people talking about a fight,./test_1k_compress/video7979.mp4,./test_1k_images/video7979.jpg,video7979
998,435,ret435,msr7356,video7356,a vehicle with details on what comes with it b...,./test_1k_compress/video7356.mp4,./test_1k_images/video7356.jpg,video7356


In [16]:
# We write the transformed samples data to a CSV file so it can be loaded into the load pipeline
raw_samples_df[['video_id', 'frame_path', 'frame_id', 'sentence']].to_csv(MSRVTT_SAMPLES_WITH_FRAMES, index=False)

In [17]:
# Create the collection in Milvus to store image embeddings
milvus_utils.create_milvus_collection(FRAME_RET_COLLECTION, 512)

<Collection>:
-------------
<name>: msrvtt_frame_ret_1
<description>: video retrieval
<schema>: {'auto_id': False, 'description': 'video retrieval', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}]}

In [20]:
def read_frame_loader_csv(csv_path, encoding='utf-8-sig'):
    with open(csv_path, 'r', encoding=encoding) as f:
        data = csv.DictReader(f)
        for line in data:
            raw_id = line['frame_id']
            cleaned_id = raw_id[len('video'):]
            yield int(cleaned_id), line['frame_path']

frame_loader_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('frame_id', 'frame_path'), read_frame_loader_csv)
    .map('frame_path', 'img', ops.image_decode.cv2('rgb'))
    .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', device='mps'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map(('frame_id', 'vec'), (), ops.ann_insert.milvus_client(collection_name=FRAME_RET_COLLECTION))
    .output()
)

2025-04-16 17:13:57,032 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-16 17:13:57,081 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-16 17:13:57,137 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-16 17:13:57,140 - 14545874944 - connectionpool.py-connectionpool:1049 - DEBUG: Starting new HTTPS connection (1): huggingface.co:443
2025-04-16 17:13:57,266 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2025-04-16 17:13:57,323 - 14545874944 - connectionpool.py-connectionpool:544 - DEBUG: h

In [None]:
frame_loader_pipeline(MSRVTT_SAMPLES_WITH_FRAMES)

2025-04-16 17:14:32,098 - 14545874944 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-16 17:14:32,098 - 14576283648 - node.py-node:167 - INFO: Begin to run Node-read_frame_loader_csv-0
2025-04-16 17:14:32,098 - 14593110016 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-1
2025-04-16 17:14:32,099 - 14609936384 - node.py-node:167 - INFO: Begin to run Node-image-text-embedding/clip-2
2025-04-16 17:14:32,099 - 14626762752 - node.py-node:167 - INFO: Begin to run Node-lambda-3
2025-04-16 17:14:32,100 - 14545874944 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-4
2025-04-16 17:14:32,100 - 14643589120 - node.py-node:167 - INFO: Begin to run Node-_output


<towhee.runtime.data_queue.DataQueue at 0x360f3ea10>

2025-04-16 17:25:48,058 - 14576283648 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-16 17:25:48,058 - 14593110016 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-1
2025-04-16 17:25:48,058 - 14609936384 - node.py-node:167 - INFO: Begin to run Node-image-text-embedding/clip-2
2025-04-16 17:25:48,059 - 14626762752 - node.py-node:167 - INFO: Begin to run Node-lambda-3
2025-04-16 17:25:48,059 - 14545874944 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-4
2025-04-16 17:25:48,059 - 14643589120 - node.py-node:167 - INFO: Begin to run Node-_output
2025-04-16 17:27:26,008 - 14576283648 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-16 17:27:26,008 - 14593110016 - node.py-node:167 - INFO: Begin to run Node-image-decode/cv2-1
2025-04-16 17:27:26,009 - 14609936384 - node.py-node:167 - INFO: Begin to run Node-image-text-embedding/clip-2
2025-04-16 17:27:26,009 - 14626762752 - node.py-node:167 - INFO: Begin to run Node-lambda-3
2025-04-1

In [None]:
def read_frame_search_csv(csv_file):
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield line['frame_id'], line['sentence']
            
frame_search_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('rel_frame_id', 'query'), read_frame_search_csv)
    .map('query', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text', device='mps'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'top10_raw_res', ops.ann_search.milvus_client(collection_name=FRAME_RET_COLLECTION, limit=10))
    # .map('vec', 'top10_raw_res', 
    #      ops.ann_search.milvus_client(collection_name=VIDEO_RET_COLLECTION, limit=10))
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .output('rel_frame_id', 'query', 'top1', 'top5', 'top10')
    # .output('vec')
)

2025-04-16 17:29:06,010 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-16 17:29:06,391 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-16 17:29:06,424 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-16 17:29:06,426 - 15110057984 - connectionpool.py-connectionpool:1049 - DEBUG: Starting new HTTPS connection (1): huggingface.co:443
2025-04-16 17:29:06,512 - 15110057984 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch16 HTTP/1.1" 200 3499
2025-04-16 17:29:06,522 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:

2025-04-16 17:29:07,006 - 15110057984 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/refs%2Fpr%2F10/model.safetensors.index.json HTTP/1.1" 404 0
2025-04-16 17:29:07,055 - 15110057984 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/refs%2Fpr%2F10/model.safetensors HTTP/1.1" 302 0


In [36]:
ret_dc = DataCollection(frame_search_pipeline(MSRVTT_SAMPLES_WITH_FRAMES))

2025-04-16 17:30:47,544 - 15126884352 - node.py-node:167 - INFO: Begin to run Node-read_frame_search_csv-0
2025-04-16 17:30:47,545 - 15143710720 - node.py-node:167 - INFO: Begin to run Node-image-text-embedding/clip-1
2025-04-16 17:30:47,545 - 15160537088 - node.py-node:167 - INFO: Begin to run Node-lambda-2
2025-04-16 17:30:47,545 - 15177363456 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-3
2025-04-16 17:30:47,546 - 15194189824 - node.py-node:167 - INFO: Begin to run Node-lambda-4


2025-04-16 17:42:04,875 - 15126884352 - node.py-node:167 - INFO: Begin to run Node-read_frame_search_csv-0
2025-04-16 17:42:04,875 - 15143710720 - node.py-node:167 - INFO: Begin to run Node-image-text-embedding/clip-1
2025-04-16 17:42:04,876 - 15160537088 - node.py-node:167 - INFO: Begin to run Node-lambda-2
2025-04-16 17:42:04,876 - 15177363456 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-3
2025-04-16 17:42:04,876 - 15194189824 - node.py-node:167 - INFO: Begin to run Node-lambda-4
2025-04-16 17:42:04,880 - 15126884352 - node.py-node:142 - INFO: read_frame_search_csv-0 ends with status: NodeStatus.FAILED
2025-04-16 17:43:27,083 - 15126884352 - node.py-node:167 - INFO: Begin to run Node-read_frame_search_csv-0
2025-04-16 17:43:27,084 - 15143710720 - node.py-node:167 - INFO: Begin to run Node-image-text-embedding/clip-1
2025-04-16 17:43:27,084 - 15160537088 - node.py-node:167 - INFO: Begin to run Node-lambda-2
2025-04-16 17:43:27,084 - 15177363456 - node.py-node:

In [37]:
ret_dc.show()

rel_frame_id,query,top1,top5,top10
video7579,a girl wearing red top and black trouser is putting a sweater on a dog,"[[7579, 1.4025176763534546]] len=1","[[7579, 1.4025176763534546],[7113, 1.4259119033813477],[8044, 1.4578959941864014]] len=3","[[7579, 1.4025176763534546],[7113, 1.4259119033813477],[8044, 1.4578959941864014]] len=3"
video7725,young people sit around the edges of a room clapping and raising their arms while others dance in the center during a party,"[[8441, 1.3825130462646484]] len=1","[[8441, 1.3825130462646484],[7725, 1.4156618118286133],[9908, 1.4577584266662598]] len=3","[[8441, 1.3825130462646484],[7725, 1.4156618118286133],[9908, 1.4577584266662598]] len=3"
video9258,a person is using a phone,"[[7728, 1.4489479064941406]] len=1","[[7728, 1.4489479064941406],[7029, 1.4623463153839111],[9258, 1.4660117626190186]] len=3","[[7728, 1.4489479064941406],[7029, 1.4623463153839111],[9258, 1.4660117626190186]] len=3"
video7365,cartoon people are eating at a restaurant,"[[8911, 1.4093154668807983]] len=1","[[8911, 1.4093154668807983],[9777, 1.4120123386383057],[7365, 1.4395549297332764]] len=3","[[8911, 1.4093154668807983],[9777, 1.4120123386383057],[7365, 1.4395549297332764]] len=3"
video8068,a woman on a couch talks to a a man,"[[9793, 1.3899463415145874]] len=1","[[9793, 1.3899463415145874],[7724, 1.4295322895050049],[7549, 1.4335342645645142]] len=3","[[9793, 1.3899463415145874],[7724, 1.4295322895050049],[7549, 1.4335342645645142]] len=3"


In [46]:
# TODO remove this, import from helpers and rerun the whole notebook

def twohee_data_col_to_df(twohee_data_collection):
    res_list = twohee_data_collection.to_list()
    res_obj_list = []
    for r in res_list:
        res_obj = vars(r)
        res_obj_list.append(res_obj)
    res_df = pd.DataFrame(res_obj_list)
    
    # Add ground truth column
    if 'rel_video_id' in res_df.columns:
        res_df['ground_truth'] = res_df['rel_video_id'].apply(
            lambda x: int(x[len('video'):]))
    if 'rel_frame_id' in res_df.columns:
        res_df['ground_truth'] = res_df['rel_frame_id'].apply(
            lambda x: int(x[len('video'):]))
    else:
        raise ValueError("No rel_video_id or rel_frame_id found in the DataCollection")
    return res_df.copy()


def average_precision(ground_truth, predictions):
    """
    Calculate the Average Precision (AP) for a single query.

    Args:
        ground_truth (int): The ground truth video ID.
        predictions (list): List of predicted video IDs.

    Returns:
        float: The Average Precision (AP) score for the query.
    """
    hits = 0
    sum_precision = 0
    for i, pred in enumerate(predictions):
        if pred == ground_truth:
            hits += 1
            sum_precision += hits / (i + 1)
    return sum_precision / hits if hits > 0 else 0


def calculate_mean_average_precision(df):
    """
    Calculate the Mean Average Precision (MAP) for the given dataframe.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'query', 'ground_truth', 'top1', 'top5', 'top10'.

    Returns:
        float: The Mean Average Precision (MAP) score.
    """
    # Calculate AP for each query
    ap_scores = []
    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        predictions_with_scores = row['top10']
        predictions = [pred[0] for pred in predictions_with_scores]
        ap_scores.append(average_precision(ground_truth, predictions))

    # Calculate MAP
    mean_ap = sum(ap_scores) / len(ap_scores) if ap_scores else 0
    return mean_ap


def calculate_recall(df):
    """
    Calculate recall@1, recall@5, and recall@10 for the given dataframe.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'query', 'ground_truth', 'top1', 'top5', 'top10'.

    Returns:
        dict: A dictionary containing recall@1, recall@5, and recall@10.
    """
    recall_at_1 = 0
    recall_at_5 = 0
    recall_at_10 = 0
    total_queries = len(df)

    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        if ground_truth in [pred[0] for pred in row['top1']]:
            recall_at_1 += 1
        if ground_truth in [pred[0] for pred in row['top5']]:
            recall_at_5 += 1
        if ground_truth in [pred[0] for pred in row['top10']]:
            recall_at_10 += 1

    return {
        'recall@1': recall_at_1 / total_queries,
        'recall@5': recall_at_5 / total_queries,
        'recall@10': recall_at_10 / total_queries
    }


def ndcg_score(ground_truth, predictions, k=10):
    """
    Calculate the Normalized Discounted Cumulative Gain (NDCG) for a single query.

    Args:
        ground_truth (int): The ground truth video ID.
        predictions (list): List of predicted video IDs with scores [(id, score), ...].
        k (int): The number of top predictions to consider.

    Returns:
        float: The NDCG score for the query.
    """
    def dcg(relevance_scores):
        return sum(rel / np.log2(idx + 2) for idx, rel in enumerate(relevance_scores))

    # Relevance scores: 1 if the prediction matches the ground truth, else 0
    relevance_scores = [1 if pred[0] ==
                        ground_truth else 0 for pred in predictions[:k]]

    # Calculate DCG and IDCG
    actual_dcg = dcg(relevance_scores)
    ideal_dcg = dcg(sorted(relevance_scores, reverse=True))

    # Return NDCG
    return actual_dcg / ideal_dcg if ideal_dcg > 0 else 0

# call this function to get the NDCG score for each query


def calculate_ndcg(df, k=10):
    """
    Calculate NDCG for the given dataframe.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'query', 'ground_truth', 'top1', 'top5', 'top10'.
        k (int): The number of top predictions to consider.

    Returns:
        float: The mean NDCG score.
    """
    ndcg_scores = []
    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        predictions_with_scores = row['top10']
        ndcg_scores.append(ndcg_score(
            ground_truth, predictions_with_scores, k))

    return sum(ndcg_scores) / len(ndcg_scores) if ndcg_scores else 0


def get_all_eval_scores(df):
    """Return a dataframe with all evaluation scores: Recall@1, Recall@5, Recall@10, MAP, NDCG@1, NDCG@5, NDCG@10"""
    recall_scores = calculate_recall(df)
    map_score = calculate_mean_average_precision(df)
    ndcg_score_1 = calculate_ndcg(df, k=1)
    ndcg_score_5 = calculate_ndcg(df, k=5)
    ndcg_score_10 = calculate_ndcg(df, k=10)

    eval_scores = {
        'recall@1': recall_scores['recall@1'],
        'recall@5': recall_scores['recall@5'],
        'recall@10': recall_scores['recall@10'],
        'map': map_score,
        'ndcg@1': ndcg_score_1,
        'ndcg@5': ndcg_score_5,
        'ndcg@10': ndcg_score_10
    }

    return eval_scores


In [56]:
twohee_data_col_to_df(ret_dc)

Unnamed: 0,rel_frame_id,query,top1,top5,top10,ground_truth
0,video7579,a girl wearing red top and black trouser is pu...,"[[7579, 1.4025176763534546]]","[[7579, 1.4025176763534546], [7113, 1.42591190...","[[7579, 1.4025176763534546], [7113, 1.42591190...",7579
1,video7725,young people sit around the edges of a room cl...,"[[8441, 1.3825130462646484]]","[[8441, 1.3825130462646484], [7725, 1.41566181...","[[8441, 1.3825130462646484], [7725, 1.41566181...",7725
2,video9258,a person is using a phone,"[[7728, 1.4489479064941406]]","[[7728, 1.4489479064941406], [7029, 1.46234631...","[[7728, 1.4489479064941406], [7029, 1.46234631...",9258
3,video7365,cartoon people are eating at a restaurant,"[[8911, 1.4093154668807983]]","[[8911, 1.4093154668807983], [9777, 1.41201233...","[[8911, 1.4093154668807983], [9777, 1.41201233...",7365
4,video8068,a woman on a couch talks to a a man,"[[9793, 1.3899463415145874]]","[[9793, 1.3899463415145874], [7724, 1.42953228...","[[9793, 1.3899463415145874], [7724, 1.42953228...",8068
...,...,...,...,...,...,...
995,video7034,man in black shirt is holding a baby upside do...,"[[9037, 1.4728624820709229]]","[[9037, 1.4728624820709229], [9028, 1.47589492...","[[9037, 1.4728624820709229], [9028, 1.47589492...",7034
996,video7568,the queen of england is seen walking with an e...,"[[7568, 1.2998905181884766]]","[[7568, 1.2998905181884766], [8306, 1.51000714...","[[7568, 1.2998905181884766], [8306, 1.51000714...",7568
997,video7979,people talking about a fight,"[[8490, 1.4737529754638672]]","[[8490, 1.4737529754638672], [7352, 1.47550010...","[[8490, 1.4737529754638672], [7352, 1.47550010...",7979
998,video7356,a vehicle with details on what comes with it b...,"[[7597, 1.4469205141067505]]","[[7597, 1.4469205141067505], [7701, 1.46081280...","[[7597, 1.4469205141067505], [7701, 1.46081280...",7356


In [49]:
get_all_eval_scores(twohee_data_col_to_df(ret_dc))

{'recall@1': 0.255,
 'recall@5': 0.398,
 'recall@10': 0.398,
 'map': 0.3164999999999999,
 'ndcg@1': 0.255,
 'ndcg@5': 0.33736716954643114,
 'ndcg@10': 0.33736716954643114}

# Try evaluation against queries from FIRE benchmark

We are working with a sample of MSR-VTT and our evaluation pipeline supports only one relevant query per video, hence we need to filter the full FIRE benchmark to only include videos we have sampled and ones with a single relevant result.

FIRE_BENCHMARK_Q_JUDGEMENTS is created in the notebook `./clean_fire_judgements.ipynb`

In [None]:
# Run query pipeline using FIRE


# CSV parser function and pipeline recreated since the FIRE csv uses `video_id` instead of `frame_id`
def read_frame_search_fire_csv(csv_file):
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield line['video_id'], line['sentence']
            
frame_search_fire_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('rel_frame_id', 'query'), read_frame_search_fire_csv)
    .map('query', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text', device='mps'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'top10_raw_res', ops.ann_search.milvus_client(collection_name=FRAME_RET_COLLECTION, limit=10))
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .output('rel_frame_id', 'query', 'top1', 'top5', 'top10')
)
            
fire_query_results = DataCollection(frame_search_fire_pipeline(FIRE_BENCHMARK_Q_JUDGEMENTS))
fire_query_results.show()

2025-04-16 17:45:31,823 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-16 17:45:31,928 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-16 17:45:31,979 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-16 17:45:32,042 - 17460719616 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch16 HTTP/1.1" 200 3499
2025-04-16 17:45:32,094 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2025-04-16 17:45:32,188 - 8454604864 - connectionp

rel_frame_id,query,top1,top5,top10
video8469,two parrots in a bird cage one white chick and on green adult,"[[8469, 1.4449390172958374]] len=1","[[8469, 1.4449390172958374],[7849, 1.4497870206832886],[7822, 1.4854648113250732]] len=3","[[8469, 1.4449390172958374],[7849, 1.4497870206832886],[7822, 1.4854648113250732]] len=3"
video9687,a man chopping lobster and taking off the shell,"[[7820, 1.40888512134552]] len=1","[[7820, 1.40888512134552],[9742, 1.4197094440460205],[9687, 1.4254179000854492]] len=3","[[7820, 1.40888512134552],[9742, 1.4197094440460205],[9687, 1.4254179000854492]] len=3"
video7698,two women are walking in a parking lot,"[[7558, 1.4385546445846558]] len=1","[[7558, 1.4385546445846558],[9039, 1.4457066059112549],[7698, 1.4519243240356445]] len=3","[[7558, 1.4385546445846558],[9039, 1.4457066059112549],[7698, 1.4519243240356445]] len=3"
video9503,a woman is talking about how jeans with patches or rips is trendy,"[[9503, 1.4195761680603027]] len=1","[[9503, 1.4195761680603027],[8825, 1.4488005638122559],[9039, 1.4948625564575195]] len=3","[[9503, 1.4195761680603027],[8825, 1.4488005638122559],[9039, 1.4948625564575195]] len=3"
video8903,a naked child runs through a field,"[[9031, 1.3999378681182861]] len=1","[[9031, 1.3999378681182861],[9805, 1.4242286682128906],[8125, 1.4620842933654785]] len=3","[[9031, 1.3999378681182861],[9805, 1.4242286682128906],[8125, 1.4620842933654785]] len=3"


In [55]:
get_all_eval_scores(twohee_data_col_to_df(fire_query_results))

{'recall@1': 0.3853503184713376,
 'recall@5': 0.5222929936305732,
 'recall@10': 0.5222929936305732,
 'map': 0.4437367303609342,
 'ndcg@1': 0.3853503184713376,
 'ndcg@5': 0.46382902575068446,
 'ndcg@10': 0.46382902575068446}