In [1]:
# Full video reterval using CLIP4CLIP and TwoHee

In [20]:
import os
from towhee import ops, pipe, register
from towhee.operator import PyOperator
from towhee import DataCollection
from tqdm import tqdm
import pandas as pd
import json

In [23]:
# CONSTANTS
VIDEOS_FOLDER = "data/MSRVTT/videos/all"
CSV_LOADER_FILE = "MSRVTT_video_paths.csv"
CSV_SEARCH_FILE_1 = "MSRVTT_search_1.csv"
CSV_SEARCH_FILE_2 = "MSRVTT_search_2.csv"
MILVUS_COLLECTION_NAME = "text_video_retrieval_6"
MSR_VTT_ANNOTATION_JSON= "data/MSRVTT/annotation/MSR_VTT.json"

In [41]:
# Load videos into Vector DB

In [3]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

connections.connect(host='localhost', port='19530')

In [4]:
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
    FieldSchema(name='vid_id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video retrieval')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2', #IP
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

collection = create_milvus_collection(MILVUS_COLLECTION_NAME, 512)

In [5]:
# get video paths
def get_video_paths(folder):
    video_paths = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_paths.append(os.path.join(root, file))
    return video_paths
def get_video_ids(folder):
    video_ids = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_ids.append(file.split('.')[0])
    return video_ids

all_vid_ids = get_video_ids(VIDEOS_FOLDER)
len(all_vid_ids)

10000

In [6]:
all_vid_paths = get_video_paths(VIDEOS_FOLDER)
print(*all_vid_paths[:5], sep='\n')
print(len(all_vid_paths))

data/MSRVTT/videos/all/video12.mp4
data/MSRVTT/videos/all/video6110.mp4
data/MSRVTT/videos/all/video9223.mp4
data/MSRVTT/videos/all/video2376.mp4
data/MSRVTT/videos/all/video938.mp4
10000


In [8]:
# Create CSV to feed into TwoHee pipeline
# Create dataframe with video_ids and video_paths columns
pd.DataFrame({
    'video_id': all_vid_ids,
    'video_path': all_vid_paths
}).to_csv(CSV_LOADER_FILE, index=False)

In [9]:
def read_csv(csv_file):
    import csv
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['video_id'][len('video'):]), line['video_path']

twohee_load_videos_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'video_path'), read_csv)
    .map('video_path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}))
    .map('frames', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='video', device='mps'))
    .map(('video_id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name=MILVUS_COLLECTION_NAME))
    .output('video_id')
)

In [10]:
ret = DataCollection(twohee_load_videos_pipeline(CSV_LOADER_FILE))

2025-04-12 17:18:57,508 - 15720230912 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-12 17:18:57,508 - 15737057280 - node.py-node:167 - INFO: Begin to run Node-read_csv-0
2025-04-12 17:18:57,509 - 15753883648 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-1
2025-04-12 17:18:57,509 - 15720230912 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-2
2025-04-12 17:18:57,510 - 15787536384 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-3
2025-04-12 17:18:57,510 - 15770710016 - node.py-node:167 - INFO: Begin to run Node-_output


In [12]:
# Query the collection

In [36]:
def read_csv_searcher(csv_file):
    import csv
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['video_id'][len('video'):]), line['sentence']
            
twohee_search_videos_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'sentence'), read_csv_searcher)
    .map('sentence', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='text', device='mps'))
    .map('vec', 'top10_raw_res', 
         ops.ann_search.milvus_client(
             host='127.0.0.1', port='19530', collection_name=MILVUS_COLLECTION_NAME, limit=10)
        )
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .map('video_id', 'ground_truth', lambda x: x)
    .output('video_id', 'sentence', 'ground_truth', 'top1', 'top5', 'top10', 'top10_raw_res')
)

In [39]:
# Load the MSR_VTT annotation JSON file
with open(MSR_VTT_ANNOTATION_JSON, 'r') as f:
    annotations = json.load(f)

# Extract video_id and sentence columns
annotations_df = pd.DataFrame([
    {'video_id': ann['image_id'], 'sentence': ann['caption']}
    for ann in annotations['annotations']
])

print(annotations_df.head())
annotations_df.head(2000).to_csv(CSV_SEARCH_FILE_2, index=False)

    video_id                                           sentence
0  video2960  a cartoon animals runs through an ice cave in ...
1  video2960  a cartoon character runs around inside of a vi...
2  video2960                 a character is running in the snow
3  video2960  a person plays a video game centered around ic...
4  video2960       a person plays online and records themselves


In [40]:
res = DataCollection(twohee_search_videos_pipeline(CSV_SEARCH_FILE_2))

2025-04-12 18:31:51,620 - 12975009792 - node.py-node:167 - INFO: Begin to run Node-read_csv_searcher-0


In [43]:
res.show()

video_id,sentence,ground_truth,top1,top5,top10,top10_raw_res
2960,a cartoon animals runs through an ice cave in a video game,2960,"[[8203, 1.3518543243408203]] len=1","[[8203, 1.3518543243408203],[2960, 1.3743010759353638],[3121, 1.458393931388855],[1742, 1.4631893634796143],...] len=5","[[8203, 1.3518543243408203],[2960, 1.3743010759353638],[3121, 1.458393931388855],[1742, 1.4631893634796143],...] len=10","[[8203, 1.3518543243408203],[2960, 1.3743010759353638],[3121, 1.458393931388855],[1742, 1.4631893634796143],...] len=10"
2960,a cartoon character runs around inside of a video game,2960,"[[4462, 1.420356273651123]] len=1","[[4462, 1.420356273651123],[6841, 1.4255818128585815],[1368, 1.4360970258712769],[3135, 1.4391367435455322],...] len=5","[[4462, 1.420356273651123],[6841, 1.4255818128585815],[1368, 1.4360970258712769],[3135, 1.4391367435455322],...] len=10","[[4462, 1.420356273651123],[6841, 1.4255818128585815],[1368, 1.4360970258712769],[3135, 1.4391367435455322],...] len=10"
2960,a character is running in the snow,2960,"[[8203, 1.3948254585266113]] len=1","[[8203, 1.3948254585266113],[2960, 1.4484233856201172],[7328, 1.448801875114441],[5853, 1.4639129638671875],...] len=5","[[8203, 1.3948254585266113],[2960, 1.4484233856201172],[7328, 1.448801875114441],[5853, 1.4639129638671875],...] len=10","[[8203, 1.3948254585266113],[2960, 1.4484233856201172],[7328, 1.448801875114441],[5853, 1.4639129638671875],...] len=10"
2960,a person plays a video game centered around ice age the movie,2960,"[[8203, 1.3764996528625488]] len=1","[[8203, 1.3764996528625488],[6790, 1.3854682445526123],[5744, 1.4214403629302979],[2960, 1.4253463745117188],...] len=5","[[8203, 1.3764996528625488],[6790, 1.3854682445526123],[5744, 1.4214403629302979],[2960, 1.4253463745117188],...] len=10","[[8203, 1.3764996528625488],[6790, 1.3854682445526123],[5744, 1.4214403629302979],[2960, 1.4253463745117188],...] len=10"
2960,a person plays online and records themselves,2960,"[[6819, 1.4372600317001343]] len=1","[[6819, 1.4372600317001343],[4635, 1.445406436920166],[4391, 1.4497307538986206],[8860, 1.4539172649383545],...] len=5","[[6819, 1.4372600317001343],[4635, 1.445406436920166],[4391, 1.4497307538986206],[8860, 1.4539172649383545],...] len=10","[[6819, 1.4372600317001343],[4635, 1.445406436920166],[4391, 1.4497307538986206],[8860, 1.4539172649383545],...] len=10"


In [62]:
# output to a csv
res_list = res.to_list()
res_obj_list = []

for r in res_list:
    res_obj = vars(r)
    res_obj_list.append(res_obj)
    
res_df = pd.DataFrame(res_obj_list)
res_df



Unnamed: 0,video_id,sentence,ground_truth,top1,top5,top10,top10_raw_res
0,2960,a cartoon animals runs through an ice cave in ...,2960,"[[8203, 1.3518543243408203]]","[[8203, 1.3518543243408203], [2960, 1.37430107...","[[8203, 1.3518543243408203], [2960, 1.37430107...","[[8203, 1.3518543243408203], [2960, 1.37430107..."
1,2960,a cartoon character runs around inside of a vi...,2960,"[[4462, 1.420356273651123]]","[[4462, 1.420356273651123], [6841, 1.425581812...","[[4462, 1.420356273651123], [6841, 1.425581812...","[[4462, 1.420356273651123], [6841, 1.425581812..."
2,2960,a character is running in the snow,2960,"[[8203, 1.3948254585266113]]","[[8203, 1.3948254585266113], [2960, 1.44842338...","[[8203, 1.3948254585266113], [2960, 1.44842338...","[[8203, 1.3948254585266113], [2960, 1.44842338..."
3,2960,a person plays a video game centered around ic...,2960,"[[8203, 1.3764996528625488]]","[[8203, 1.3764996528625488], [6790, 1.38546824...","[[8203, 1.3764996528625488], [6790, 1.38546824...","[[8203, 1.3764996528625488], [6790, 1.38546824..."
4,2960,a person plays online and records themselves,2960,"[[6819, 1.4372600317001343]]","[[6819, 1.4372600317001343], [4635, 1.44540643...","[[6819, 1.4372600317001343], [4635, 1.44540643...","[[6819, 1.4372600317001343], [4635, 1.44540643..."
...,...,...,...,...,...,...,...
1995,5762,movie preview of an animated feature,5762,"[[7222, 1.4541020393371582]]","[[7222, 1.4541020393371582], [5791, 1.45608532...","[[7222, 1.4541020393371582], [5791, 1.45608532...","[[7222, 1.4541020393371582], [5791, 1.45608532..."
1996,5762,a short 3d animated clip for a children movie,5762,"[[711, 1.424837589263916]]","[[711, 1.424837589263916], [7222, 1.4282803535...","[[711, 1.424837589263916], [7222, 1.4282803535...","[[711, 1.424837589263916], [7222, 1.4282803535..."
1997,5762,an animated boy watching dinozaurs in forest,5762,"[[4979, 1.4297271966934204]]","[[4979, 1.4297271966934204], [4375, 1.43844258...","[[4979, 1.4297271966934204], [4375, 1.43844258...","[[4979, 1.4297271966934204], [4375, 1.43844258..."
1998,5762,a 3d dinosaur stares at 3d kid,5762,"[[9713, 1.4939796924591064]]","[[9713, 1.4939796924591064], [1105, 1.50564360...","[[9713, 1.4939796924591064], [1105, 1.50564360...","[[9713, 1.4939796924591064], [1105, 1.50564360..."


In [None]:
def average_precision(ground_truth, predictions):
    """
    Calculate the Average Precision (AP) for a single query.

    Args:
        ground_truth (int): The ground truth video ID.
        predictions (list): List of predicted video IDs.

    Returns:
        float: The Average Precision (AP) score for the query.
    """
    hits = 0
    sum_precision = 0
    for i, pred in enumerate(predictions):
        if pred == ground_truth:
            hits += 1
            sum_precision += hits / (i + 1)
    return sum_precision / hits if hits > 0 else 0
def calculate_mean_average_precision(df):
    """
    Calculate the Mean Average Precision (MAP) for the given dataframe.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'query', 'ground_truth', 'top1', 'top5', 'top10'.

    Returns:
        float: The Mean Average Precision (MAP) score.
    """
    # Calculate AP for each query
    ap_scores = []
    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        predictions_with_scores = row['top10']
        predictions = [pred[0] for pred in predictions_with_scores]
        ap_scores.append(average_precision(ground_truth, predictions))

    print(ap_scores)
    # Calculate MAP
    mean_ap = sum(ap_scores) / len(ap_scores) if ap_scores else 0
    return mean_ap

calculate_mean_average_precision(res_df)

[8203, 2960, 3121, 1742, 2874, 4462, 7821, 5744, 8567, 3713]
[4462, 6841, 1368, 3135, 2874, 6790, 8203, 789, 2621, 1742]
[8203, 2960, 7328, 5853, 6251, 5744, 187, 4462, 4945, 730]
[8203, 6790, 5744, 2960, 1368, 2874, 3416, 9134, 3191, 8406]
[6819, 4635, 4391, 8860, 7331, 5332, 4873, 607, 5285, 3341]
[6790, 8203, 5744, 2960, 1368, 3191, 8898, 3416, 2874, 4979]
[789, 2960, 6478, 4462, 1368, 8203, 8256, 1052, 686, 9074]
[8203, 2960, 3121, 789, 5260, 5906, 1742, 3713, 2874, 4462]
[4277, 6236, 6790, 5480, 8203, 4231, 4387, 3713, 7798, 3121]
[8203, 6790, 5744, 2960, 1368, 2874, 8898, 3191, 8406, 4873]
[9460, 8493, 6007, 2309, 5773, 4014, 6348, 2948, 6978, 9620]
[8203, 2960, 789, 686, 5906, 3121, 5260, 5744, 1745, 1742]
[8203, 6790, 2960, 1368, 5744, 3191, 2874, 8898, 8406, 3416]
[6478, 5332, 4997, 2162, 2977, 3863, 262, 4272, 3713, 5535]
[6790, 8203, 2960, 5744, 1368, 8898, 2874, 3945, 9134, 8406]
[5800, 6090, 3214, 545, 5332, 2621, 7919, 1779, 3349, 4448]
[9460, 8493, 6007, 2309, 5773, 4014

0.26792658730158725

In [54]:
vars(res.to_list()[0])

{'video_id': 2960,
 'sentence': 'a cartoon animals runs through an ice cave in a video game',
 'ground_truth': 2960,
 'top1': [[8203, 1.3518543243408203]],
 'top5': [[8203, 1.3518543243408203],
  [2960, 1.3743010759353638],
  [3121, 1.458393931388855],
  [1742, 1.4631893634796143],
  [2874, 1.4642057418823242]],
 'top10': [[8203, 1.3518543243408203],
  [2960, 1.3743010759353638],
  [3121, 1.458393931388855],
  [1742, 1.4631893634796143],
  [2874, 1.4642057418823242],
  [4462, 1.468991756439209],
  [7821, 1.4724347591400146],
  [5744, 1.476823329925537],
  [8567, 1.4791474342346191],
  [3713, 1.4855239391326904]],
 'top10_raw_res': [[8203, 1.3518543243408203],
  [2960, 1.3743010759353638],
  [3121, 1.458393931388855],
  [1742, 1.4631893634796143],
  [2874, 1.4642057418823242],
  [4462, 1.468991756439209],
  [7821, 1.4724347591400146],
  [5744, 1.476823329925537],
  [8567, 1.4791474342346191],
  [3713, 1.4855239391326904]]}