In [1]:
import os
import csv
from towhee import ops, pipe, register
from towhee.operator import PyOperator
from towhee import DataCollection
from tqdm import tqdm
import pandas as pd
import json
import numpy as np
from helpers import milvus_utils
from helpers.eval_utils import twohee_data_col_to_df, get_all_eval_scores

Connected to Milvus server at port 19530


In [2]:
# --- Notebook configuration ---

# Database collections
VIDEO_RET_COLLECTION = "msrvtt_vid_ret_1"
FRAME_RET_COLLECTION = "msrvtt_frame_ret_1"

# Input files
# MSR-VTT sample listing; the cells below read its video_id, video_path and sentence columns
MSRVTT_SAMPLES = "./MSRVTT_1K.csv"
# Built from the raw FIRE judgements, see clean_fire_judgements.ipynb
FIRE_BENCHMARK_Q_JUDGEMENTS = "./fire_benchmark_q_judgements.csv"

In [3]:
raw_samples_df = pd.read_csv(MSRVTT_SAMPLES)
# Preview only the columns that the loading and search pipelines read
preview_cols = ['video_id', 'video_path', 'sentence']
raw_samples_df[preview_cols].head()

Unnamed: 0,video_id,video_path,sentence
0,video7579,./test_1k_compress/video7579.mp4,a girl wearing red top and black trouser is pu...
1,video7725,./test_1k_compress/video7725.mp4,young people sit around the edges of a room cl...
2,video9258,./test_1k_compress/video9258.mp4,a person is using a phone
3,video7365,./test_1k_compress/video7365.mp4,cartoon people are eating at a restaurant
4,video8068,./test_1k_compress/video8068.mp4,a woman on a couch talks to a a man


In [4]:
# Create the Milvus collection that stores one embedding per video.
# 512 is the vector dimension of the collection; it presumably has to match the
# size of the clip_vit_b32 embeddings inserted by the loading pipeline — TODO confirm.
# NOTE(review): how an already-existing collection is handled depends on
# helpers.milvus_utils, which is not visible in this notebook.
milvus_utils.create_milvus_collection(VIDEO_RET_COLLECTION, 512)

<Collection>:
-------------
<name>: msrvtt_vid_ret_1
<description>: video retrieval
<schema>: {'auto_id': False, 'description': 'video retrieval', 'fields': [{'name': 'id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}]}

We create a pipeline that loads the video embeddings into the Milvus vector DB using a distributed Towhee pipeline.

In [5]:
def read_loader_csv(csv_file):
    """Yield ``(video_id, video_path)`` for every data row of a samples CSV.

    Args:
        csv_file: Path to a CSV file whose header row contains at least the
            ``video_id`` and ``video_path`` columns. ``video_id`` values are
            expected to look like ``video7579``.

    Yields:
        tuple[int, str]: The ``video_id`` with its first ``len('video')``
        characters removed and the remainder converted to ``int``, together
        with the unmodified ``video_path``.

    Raises:
        KeyError: If a required column is missing from the CSV.
        ValueError: If the text after the prefix is not an integer.
    """
    prefix_len = len('video')
    # 'utf-8-sig' drops a leading BOM if the file has one; newline='' leaves
    # newline handling to the csv module, as its documentation requires.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            yield int(row['video_id'][prefix_len:]), row['video_path']

# Operators used by the loading pipeline.
# Samples 12 evenly distributed frames from each video.
frame_sampler = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12})
# device='mps' because this notebook was run on an M2 Max, where it gives better performance.
video_encoder = ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='video', device='mps')
milvus_writer = ops.ann_insert.milvus_client(collection_name=VIDEO_RET_COLLECTION)

# csv_file -> (video_id, video_path) rows -> frames -> embedding -> insert into Milvus
video_loader_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'video_path'), read_loader_csv)
    .map('video_path', 'frames', frame_sampler)
    .map('frames', 'vec', video_encoder)
    .map(('video_id', 'vec'), (), milvus_writer)
    .output('video_id')
)

In [6]:
# Run the loading pipeline on the CSV file containing the video paths.
# Each listed video is decoded, embedded and inserted into VIDEO_RET_COLLECTION;
# the pipeline's declared output column is 'video_id'.
video_loader_ret = video_loader_pipeline(MSRVTT_SAMPLES)

2025-04-17 15:31:05,460 - 15982424064 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-17 15:31:05,461 - 15999250432 - node.py-node:167 - INFO: Begin to run Node-read_loader_csv-0
2025-04-17 15:31:05,461 - 16016076800 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-1
2025-04-17 15:31:05,461 - 16032903168 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-2
2025-04-17 15:31:05,461 - 16049729536 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-3
2025-04-17 15:31:05,461 - 15982424064 - node.py-node:167 - INFO: Begin to run Node-_output


The 1000 videos are now loaded into the Milvus `VIDEO_RET_COLLECTION` collection.

Now, we query these videos using the annotated sentences as queries and the video ids as the ground truth results.

In [7]:
def read_video_search_csv(csv_file):
    """Yield ``(video_id, sentence)`` for every data row of a query CSV.

    Args:
        csv_file: Path to a CSV file whose header row contains at least the
            ``video_id`` and ``sentence`` columns.

    Yields:
        tuple[str, str]: The raw ``video_id`` string (e.g. ``video7579``) and
        the query ``sentence``, both exactly as stored in the file.

    Raises:
        KeyError: If a required column is missing from the CSV.
    """
    # 'utf-8-sig' drops a leading BOM if the file has one; newline='' leaves
    # newline handling to the csv module so that line breaks inside quoted
    # sentences are returned unchanged.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            yield row['video_id'], row['sentence']

# Text-to-video search pipeline: embeds each query sentence and looks up the
# nearest video embeddings stored in VIDEO_RET_COLLECTION.
video_search_pipeline = (
    pipe.input('csv_file')
    # One (ground-truth video id, query sentence) pair per CSV row
    .flat_map('csv_file', ('rel_video_id', 'query'), read_video_search_csv)
    # Same clip_vit_b32 model as the loading pipeline, but with the text modality
    .map('query', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='text', device='mps'))
    # A single search with limit=15 feeds every cutoff below, so the
    # intermediate column is named after the 15 hits it actually holds.
    .map('vec', 'top15_raw_res',
         ops.ann_search.milvus_client(collection_name=VIDEO_RET_COLLECTION, limit=15))
    # Slice the ranked hits into the cutoffs exposed as output columns
    .map('top15_raw_res', ('top1', 'top5', 'top10', 'top15'),
         lambda hits: (hits[:1], hits[:5], hits[:10], hits[:15]))
    .output('rel_video_id', 'query', 'top1', 'top5', 'top10', 'top15')
)

2025-04-17 15:32:08,412 - 19445510144 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-17 15:32:08,414 - 19445510144 - node.py-node:167 - INFO: Begin to run Node-_output


In [8]:
# Use every annotated MSR-VTT sentence as a query and collect the search results
msrvtt_search_output = video_search_pipeline(MSRVTT_SAMPLES)
all_query_results = DataCollection(msrvtt_search_output)

2025-04-17 15:32:08,413 - 19462336512 - node.py-node:167 - INFO: Begin to run Node-read_video_search_csv-0
2025-04-17 15:32:08,413 - 19479162880 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-1
2025-04-17 15:32:08,414 - 19495989248 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-2
2025-04-17 15:32:08,414 - 19512815616 - node.py-node:167 - INFO: Begin to run Node-lambda-3


2025-04-17 15:38:48,191 - 19462336512 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-17 15:38:48,192 - 19479162880 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-1
2025-04-17 15:38:48,193 - 19495989248 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-2
2025-04-17 15:38:48,193 - 19512815616 - node.py-node:167 - INFO: Begin to run Node-lambda-3


In [10]:
# Convert the Towhee DataCollection to a pandas DataFrame so we can apply evaluation methods.
# (The helper really is named `twohee_data_col_to_df` in helpers.eval_utils.)
msr_ans_query_results_df = twohee_data_col_to_df(all_query_results)
msr_ans_query_results_df

Unnamed: 0,rel_video_id,query,top1,top5,top10,top15,ground_truth
0,video7579,a girl wearing red top and black trouser is pu...,"[[7579, 1.415151834487915]]","[[7579, 1.415151834487915], [9969, 1.479910612...","[[7579, 1.415151834487915], [9969, 1.479910612...","[[7579, 1.415151834487915], [9969, 1.479910612...",7579
1,video7725,young people sit around the edges of a room cl...,"[[7725, 1.3622068166732788]]","[[7725, 1.3622068166732788], [8014, 1.48652696...","[[7725, 1.3622068166732788], [8014, 1.48652696...","[[7725, 1.3622068166732788], [8014, 1.48652696...",7725
2,video9258,a person is using a phone,"[[9258, 1.4011969566345215]]","[[9258, 1.4011969566345215], [9257, 1.42286348...","[[9258, 1.4011969566345215], [9257, 1.42286348...","[[9258, 1.4011969566345215], [9257, 1.42286348...",9258
3,video7365,cartoon people are eating at a restaurant,"[[7365, 1.4027695655822754]]","[[7365, 1.4027695655822754], [8781, 1.46230483...","[[7365, 1.4027695655822754], [8781, 1.46230483...","[[7365, 1.4027695655822754], [8781, 1.46230483...",7365
4,video8068,a woman on a couch talks to a a man,"[[7162, 1.4716743230819702]]","[[7162, 1.4716743230819702], [8304, 1.47874724...","[[7162, 1.4716743230819702], [8304, 1.47874724...","[[7162, 1.4716743230819702], [8304, 1.47874724...",8068
...,...,...,...,...,...,...,...
995,video7034,man in black shirt is holding a baby upside do...,"[[9320, 1.5113091468811035]]","[[9320, 1.5113091468811035], [9404, 1.51643335...","[[9320, 1.5113091468811035], [9404, 1.51643335...","[[9320, 1.5113091468811035], [9404, 1.51643335...",7034
996,video7568,the queen of england is seen walking with an e...,"[[7568, 1.2981326580047607]]","[[7568, 1.2981326580047607], [7116, 1.41021490...","[[7568, 1.2981326580047607], [7116, 1.41021490...","[[7568, 1.2981326580047607], [7116, 1.41021490...",7568
997,video7979,people talking about a fight,"[[7211, 1.4528591632843018]]","[[7211, 1.4528591632843018], [7979, 1.46294164...","[[7211, 1.4528591632843018], [7979, 1.46294164...","[[7211, 1.4528591632843018], [7979, 1.46294164...",7979
998,video7356,a vehicle with details on what comes with it b...,"[[7356, 1.4014551639556885]]","[[7356, 1.4014551639556885], [7765, 1.47221362...","[[7356, 1.4014551639556885], [7765, 1.47221362...","[[7356, 1.4014551639556885], [7765, 1.47221362...",7356


In [None]:
# Persist the MSR-VTT query results. The output folder is created first because
# to_csv does not create missing parent directories.
os.makedirs('query_results', exist_ok=True)
msr_ans_query_results_df.to_csv('query_results/c4c_queries_msrvtt.csv', index=False)

In [13]:
# Retrieval metrics (recall, MAP, nDCG) for the annotated MSR-VTT queries
get_all_eval_scores(msr_ans_query_results_df)

{'recall@1': 0.426,
 'recall@5': 0.716,
 'recall@10': 0.814,
 'map': 0.5456543650793645,
 'ndcg@1': 0.426,
 'ndcg@5': 0.5780321865313451,
 'ndcg@10': 0.6100154270801753}

## Try evaluation against queries from FIRE benchmark

We are working with a sample of MSR-VTT, and our evaluation pipeline supports only one relevant video per query. We therefore need to filter the full FIRE benchmark so that it only includes queries whose video is in our sample and that have a single relevant result.

FIRE_BENCHMARK_Q_JUDGEMENTS is created in the notebook `./clean_fire_judgements.ipynb`

In [14]:
# Run query pipeline using FIRE
def read_frame_search_fire_csv(csv_file):
    """Yield ``(video_id, sentence)`` for every data row of a FIRE query CSV.

    NOTE(review): nothing visible in this notebook calls this function. The
    FIRE run goes through ``video_search_pipeline``, which reads its CSV with
    ``read_video_search_csv``, a reader with the same logic.

    Args:
        csv_file: Path to a CSV file whose header row contains at least the
            ``video_id`` and ``sentence`` columns.

    Yields:
        tuple[str, str]: The raw ``video_id`` string and the query
        ``sentence``, both exactly as stored in the file.

    Raises:
        KeyError: If a required column is missing from the CSV.
    """
    # 'utf-8-sig' drops a leading BOM if the file has one; newline='' leaves
    # newline handling to the csv module so that line breaks inside quoted
    # sentences are returned unchanged.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            yield row['video_id'], row['sentence']
            
# Same search pipeline as for the MSR-VTT sentences, fed with the FIRE queries
fire_search_output = video_search_pipeline(FIRE_BENCHMARK_Q_JUDGEMENTS)
fire_query_results = DataCollection(fire_search_output)

2025-04-17 15:38:48,192 - 13025488896 - node.py-node:167 - INFO: Begin to run Node-read_video_search_csv-0
2025-04-17 15:38:48,193 - 19445510144 - node.py-node:167 - INFO: Begin to run Node-_output


In [16]:
# Convert the FIRE search results to a DataFrame and persist them.
fire_c4c_query_results_df = twohee_data_col_to_df(fire_query_results)
# to_csv does not create missing parent directories, so make sure the folder exists
os.makedirs('query_results', exist_ok=True)
fire_c4c_query_results_df.to_csv('query_results/c4c_queries_fire.csv', index=False)

In [17]:
# Retrieval metrics (recall, MAP, nDCG) for the FIRE benchmark queries
get_all_eval_scores(fire_c4c_query_results_df)

{'recall@1': 0.5732484076433121,
 'recall@5': 0.821656050955414,
 'recall@10': 0.9076433121019108,
 'map': 0.6800513092710544,
 'ndcg@1': 0.5732484076433121,
 'ndcg@5': 0.7064313418462581,
 'ndcg@10': 0.7348170561659407}

## Final Scores for Video Retrieval using Clip4Clip

In [18]:
# Metric dictionaries for both query sets
msrvtt_ans_scores = get_all_eval_scores(msr_ans_query_results_df)
fire_ans_scores = get_all_eval_scores(fire_c4c_query_results_df)

# One row per query set, one column per metric
row_labels = ['MSRVTT Ann. Queries', 'FIRE Queries']
scores_df = pd.DataFrame([msrvtt_ans_scores, fire_ans_scores], index=row_labels)

# Caption the table and show three decimals for readability
styled_scores = (
    scores_df.style
    .set_caption("Model Metrics Comparison")
    .format("{:.3f}")
)

styled_scores

2025-04-17 15:41:41,564 - 8454604864 - __init__.py-__init__:342 - DEBUG: matplotlib data path: /Users/suraj/miniconda3/envs/info-ret-proj/lib/python3.10/site-packages/matplotlib/mpl-data
2025-04-17 15:41:41,568 - 8454604864 - __init__.py-__init__:342 - DEBUG: CONFIGDIR=/Users/suraj/.matplotlib
2025-04-17 15:41:41,615 - 8454604864 - __init__.py-__init__:1557 - DEBUG: interactive is False
2025-04-17 15:41:41,615 - 8454604864 - __init__.py-__init__:1558 - DEBUG: platform is darwin
2025-04-17 15:41:41,657 - 8454604864 - __init__.py-__init__:342 - DEBUG: CACHEDIR=/Users/suraj/.matplotlib
2025-04-17 15:41:41,659 - 8454604864 - font_manager.py-font_manager:1635 - DEBUG: Using fontManager instance from /Users/suraj/.matplotlib/fontlist-v390.json


Unnamed: 0,recall@1,recall@5,recall@10,map,ndcg@1,ndcg@5,ndcg@10
MSRVTT Ann. Queries,0.426,0.716,0.814,0.546,0.426,0.578,0.61
FIRE Queries,0.573,0.822,0.908,0.68,0.573,0.706,0.735
