In [1]:
# Full video reterval using CLIP4CLIP and TwoHee

In [1]:
import os
import csv
from towhee import ops, pipe, register
from towhee.operator import PyOperator
from towhee import DataCollection
from tqdm import tqdm
import pandas as pd
import json
import numpy as np
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

In [2]:
# CONSTANTS
VIDEOS_FOLDER = "data/MSRVTT/videos/all"
CSV_LOADER_FILE = "MSRVTT_video_paths.csv"
CSV_SEARCH_FILE_1 = "MSRVTT_search_1.csv"
CSV_SEARCH_FILE_2 = "MSRVTT_search_2.csv"
CSV_SEARCH_DISTINCT = "MSRVTT_search_distinct.csv"
MILVUS_COLLECTION_NAME = "text_video_retrieval_6"
MSR_VTT_ANNOTATION_JSON= "data/MSRVTT/annotation/MSR_VTT.json"

# Constants for CLIP model
IMAGES_FOLDER = "syn_data/msrvtt_imgs"
MILVUS_IMAGES_COLLECTION_NAME = "text_video_retrieval_images_1"
CSV_IMAGES_LOADER_FILE = "MSRVTT_images_paths.csv"
CSV_IMAGES_SEARCH_FILE_1 = "MSRVTT_images_search_1.csv"

In [41]:
# Load videos into Vector DB

In [3]:
connections.connect(host='localhost', port='19530')

In [3]:
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
    FieldSchema(name='vid_id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video retrieval')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2', #IP
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection


In [None]:

collection = create_milvus_collection(MILVUS_COLLECTION_NAME, 512)

In [17]:
# get video paths
def get_video_paths(folder):
    video_paths = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_paths.append(os.path.join(root, file))
    return video_paths
def get_video_ids(folder):
    video_ids = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_ids.append(file.split('.')[0])
    return video_ids

def get_video_frame_paths(folder):
    video_frame_paths = []
    # print(list(os.walk(folder)))
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.jpg', '.png')):
                video_frame_paths.append(os.path.join(root, file))
    return video_frame_paths

all_vid_ids = get_video_ids(VIDEOS_FOLDER)
len(all_vid_ids)

10000

In [18]:
# next(os.walk("syn_data/msrvtt_imgs"))
all_frame_paths = get_video_frame_paths(IMAGES_FOLDER)

In [19]:
all_vid_paths = get_video_paths(VIDEOS_FOLDER)
print(*all_vid_paths[:5], sep='\n')
print(len(all_vid_paths))

data/MSRVTT/videos/all/video12.mp4
data/MSRVTT/videos/all/video6110.mp4
data/MSRVTT/videos/all/video9223.mp4
data/MSRVTT/videos/all/video2376.mp4
data/MSRVTT/videos/all/video938.mp4
10000


In [None]:
# Create CSV to feed into TwoHee pipeline
# Create dataframe with video_ids and video_paths columns
pd.DataFrame({
    'video_id': all_vid_ids,
    'video_path': all_vid_paths
}).to_csv(CSV_LOADER_FILE, index=False)

In [20]:
pd.DataFrame({
    'video_id': all_vid_ids,
    'frame_path': all_frame_paths
}).head(100).to_csv(CSV_IMAGES_LOADER_FILE, index=False)

In [None]:
def read_csv(csv_file):
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['video_id'][len('video'):]), line['video_path']

twohee_load_videos_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'video_path'), read_csv)
    .map('video_path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}))
    .map('frames', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='video', device='mps'))
    .map(('video_id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name=MILVUS_COLLECTION_NAME))
    .output('video_id')
)

In [10]:
ret = DataCollection(twohee_load_videos_pipeline(CSV_LOADER_FILE))

2025-04-12 17:18:57,508 - 15720230912 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-12 17:18:57,508 - 15737057280 - node.py-node:167 - INFO: Begin to run Node-read_csv-0
2025-04-12 17:18:57,509 - 15753883648 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-1
2025-04-12 17:18:57,509 - 15720230912 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-2
2025-04-12 17:18:57,510 - 15787536384 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-3
2025-04-12 17:18:57,510 - 15770710016 - node.py-node:167 - INFO: Begin to run Node-_output


In [12]:
# Query the collection

In [3]:
def read_csv_searcher(csv_file):
    import csv
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['video_id'][len('video'):]), line['sentence']
            
twohee_search_videos_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'sentence'), read_csv_searcher)
    .map('sentence', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='text', device='mps'))
    .map('vec', 'top10_raw_res', 
         ops.ann_search.milvus_client(
             host='127.0.0.1', port='19530', collection_name=MILVUS_COLLECTION_NAME, limit=10)
        )
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .map('video_id', 'ground_truth', lambda x: x)
    .output('video_id', 'sentence', 'ground_truth', 'top1', 'top5', 'top10', 'top10_raw_res')
)

In [13]:
# Load the MSR_VTT annotation JSON file
with open(MSR_VTT_ANNOTATION_JSON, 'r') as f:
    annotations = json.load(f)

# Extract video_id and sentence columns
annotations_df = pd.DataFrame([
    {'video_id': ann['image_id'], 'sentence': ann['caption']}
    for ann in annotations['annotations']
])

# select distinct video_ids
anno_df_single_vids = annotations_df.drop_duplicates(subset=['video_id'])
anno_df_single_vids.head(10_000).to_csv(CSV_SEARCH_DISTINCT, index=False)

# print(annotations_df.head())
annotations_df.to_csv(CSV_SEARCH_FILE_2, index=False)

In [14]:
print("Num entries in search csv: " + str(len(open(CSV_SEARCH_DISTINCT).readlines()) - 1))
res = DataCollection(twohee_search_videos_pipeline(CSV_SEARCH_DISTINCT))

2025-04-15 13:02:42,362 - 16185847808 - node.py-node:167 - INFO: Begin to run Node-_input


Num entries in search csv: 10000


In [15]:
res.show()

video_id,sentence,ground_truth,top1,top5,top10,top10_raw_res
2960,a cartoon animals runs through an ice cave in a video game,2960,"[[8203, 1.3518543243408203]] len=1","[[8203, 1.3518543243408203],[2960, 1.3743010759353638],[3121, 1.458393931388855],[1742, 1.4631893634796143],...] len=5","[[8203, 1.3518543243408203],[2960, 1.3743010759353638],[3121, 1.458393931388855],[1742, 1.4631893634796143],...] len=10","[[8203, 1.3518543243408203],[2960, 1.3743010759353638],[3121, 1.458393931388855],[1742, 1.4631893634796143],...] len=10"
2636,a man gets hit in the face with a chair during a wwf wrestling match,2636,"[[2636, 1.3602731227874756]] len=1","[[2636, 1.3602731227874756],[4774, 1.3791619539260864],[9847, 1.3864306211471558],[9636, 1.399008870124817],...] len=5","[[2636, 1.3602731227874756],[4774, 1.3791619539260864],[9847, 1.3864306211471558],[9636, 1.399008870124817],...] len=10","[[2636, 1.3602731227874756],[4774, 1.3791619539260864],[9847, 1.3864306211471558],[9636, 1.399008870124817],...] len=10"
4311,a person is explaining something,4311,"[[4242, 1.492103934288025]] len=1","[[4242, 1.492103934288025],[6415, 1.5018973350524902],[8060, 1.5108928680419922],[5002, 1.51104736328125],...] len=5","[[4242, 1.492103934288025],[6415, 1.5018973350524902],[8060, 1.5108928680419922],[5002, 1.51104736328125],...] len=10","[[4242, 1.492103934288025],[6415, 1.5018973350524902],[8060, 1.5108928680419922],[5002, 1.51104736328125],...] len=10"
1844,a man conducting a science experiment,1844,"[[641, 1.3759840726852417]] len=1","[[641, 1.3759840726852417],[9176, 1.4121618270874023],[7282, 1.4177645444869995],[1844, 1.419187307357788],...] len=5","[[641, 1.3759840726852417],[9176, 1.4121618270874023],[7282, 1.4177645444869995],[1844, 1.419187307357788],...] len=10","[[641, 1.3759840726852417],[9176, 1.4121618270874023],[7282, 1.4177645444869995],[1844, 1.419187307357788],...] len=10"
2213,a couple of men wrestling on the ground,2213,"[[4778, 1.414879322052002]] len=1","[[4778, 1.414879322052002],[3685, 1.424682378768921],[5689, 1.4251071214675903],[5475, 1.4268649816513062],...] len=5","[[4778, 1.414879322052002],[3685, 1.424682378768921],[5689, 1.4251071214675903],[5475, 1.4268649816513062],...] len=10","[[4778, 1.414879322052002],[3685, 1.424682378768921],[5689, 1.4251071214675903],[5475, 1.4268649816513062],...] len=10"


In [16]:
# output to a csv
res_list = res.to_list()
res_obj_list = []

for r in res_list:
    res_obj = vars(r)
    res_obj_list.append(res_obj)
    
res_df = pd.DataFrame(res_obj_list)

res_df.to_csv('res_10_000.csv', index=False)
res_df



Unnamed: 0,video_id,sentence,ground_truth,top1,top5,top10,top10_raw_res
0,2960,a cartoon animals runs through an ice cave in ...,2960,"[[8203, 1.3518543243408203]]","[[8203, 1.3518543243408203], [2960, 1.37430107...","[[8203, 1.3518543243408203], [2960, 1.37430107...","[[8203, 1.3518543243408203], [2960, 1.37430107..."
1,2636,a man gets hit in the face with a chair during...,2636,"[[2636, 1.3602731227874756]]","[[2636, 1.3602731227874756], [4774, 1.37916195...","[[2636, 1.3602731227874756], [4774, 1.37916195...","[[2636, 1.3602731227874756], [4774, 1.37916195..."
2,4311,a person is explaining something,4311,"[[4242, 1.492103934288025]]","[[4242, 1.492103934288025], [6415, 1.501897335...","[[4242, 1.492103934288025], [6415, 1.501897335...","[[4242, 1.492103934288025], [6415, 1.501897335..."
3,1844,a man conducting a science experiment,1844,"[[641, 1.3759840726852417]]","[[641, 1.3759840726852417], [9176, 1.412161827...","[[641, 1.3759840726852417], [9176, 1.412161827...","[[641, 1.3759840726852417], [9176, 1.412161827..."
4,2213,a couple of men wrestling on the ground,2213,"[[4778, 1.414879322052002]]","[[4778, 1.414879322052002], [3685, 1.424682378...","[[4778, 1.414879322052002], [3685, 1.424682378...","[[4778, 1.414879322052002], [3685, 1.424682378..."
...,...,...,...,...,...,...,...
9995,7795,a person in black turning a red color wing on ...,7795,"[[4546, 1.4018069505691528]]","[[4546, 1.4018069505691528], [7299, 1.42924499...","[[4546, 1.4018069505691528], [7299, 1.42924499...","[[4546, 1.4018069505691528], [7299, 1.42924499..."
9996,7112,men outside playing basketball one of the men ...,7112,"[[2592, 1.3740363121032715]]","[[2592, 1.3740363121032715], [7112, 1.39124894...","[[2592, 1.3740363121032715], [7112, 1.39124894...","[[2592, 1.3740363121032715], [7112, 1.39124894..."
9997,8658,a man displays a very thin circular silicon wafer,8658,"[[8658, 1.2997075319290161]]","[[8658, 1.2997075319290161], [3236, 1.36788225...","[[8658, 1.2997075319290161], [3236, 1.36788225...","[[8658, 1.2997075319290161], [3236, 1.36788225..."
9998,8978,a man at a bar with a beard blow cigarette smo...,8978,"[[418, 1.4422941207885742]]","[[418, 1.4422941207885742], [9180, 1.448186516...","[[418, 1.4422941207885742], [9180, 1.448186516...","[[418, 1.4422941207885742], [9180, 1.448186516..."


In [17]:
def average_precision(ground_truth, predictions):
    """
    Calculate the Average Precision (AP) for a single query.

    Args:
        ground_truth (int): The ground truth video ID.
        predictions (list): List of predicted video IDs.

    Returns:
        float: The Average Precision (AP) score for the query.
    """
    hits = 0
    sum_precision = 0
    for i, pred in enumerate(predictions):
        if pred == ground_truth:
            hits += 1
            sum_precision += hits / (i + 1)
    return sum_precision / hits if hits > 0 else 0
def calculate_mean_average_precision(df):
    """
    Calculate the Mean Average Precision (MAP) for the given dataframe.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'query', 'ground_truth', 'top1', 'top5', 'top10'.

    Returns:
        float: The Mean Average Precision (MAP) score.
    """
    # Calculate AP for each query
    ap_scores = []
    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        predictions_with_scores = row['top10']
        predictions = [pred[0] for pred in predictions_with_scores]
        ap_scores.append(average_precision(ground_truth, predictions))

    print(ap_scores)
    # Calculate MAP
    mean_ap = sum(ap_scores) / len(ap_scores) if ap_scores else 0
    return mean_ap

calculate_mean_average_precision(res_df)

[0.5, 1.0, 0, 0.25, 0, 0.5, 1.0, 0, 0, 0, 1.0, 0, 0.14285714285714285, 0.1111111111111111, 1.0, 1.0, 0.1, 0.5, 0, 0.3333333333333333, 0, 0, 1.0, 0, 0, 0, 0, 1.0, 0.16666666666666666, 0, 1.0, 1.0, 0, 0.5, 0, 0.5, 1.0, 1.0, 0, 1.0, 0, 0.1, 0.3333333333333333, 1.0, 0, 1.0, 0, 0, 0, 0.25, 0, 0.3333333333333333, 1.0, 0, 0, 0, 0, 0.3333333333333333, 0, 0, 1.0, 0.5, 0.25, 0, 1.0, 1.0, 0, 0, 0.5, 0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0.5, 0, 0, 0.1, 0, 0, 0.3333333333333333, 0.3333333333333333, 0, 0, 0.3333333333333333, 0, 1.0, 0, 1.0, 0.125, 1.0, 1.0, 0.3333333333333333, 0, 0.1, 0, 0.14285714285714285, 0, 0.5, 0, 1.0, 0, 0, 0, 0.5, 0.5, 0.14285714285714285, 1.0, 1.0, 1.0, 1.0, 0.1, 0, 1.0, 0.3333333333333333, 0, 1.0, 1.0, 0.5, 0, 0.25, 0, 0.1111111111111111, 0.2, 0.5, 1.0, 0.5, 0.14285714285714285, 0, 0, 0, 1.0, 0, 0.5, 0, 0, 1.0, 0, 0.16666666666666666, 0, 0, 0, 0, 0, 0.125, 0.125, 0.3333333333333333, 0, 0.1, 0.3333333333333333, 1.0, 0, 0.2, 0, 0.5, 0.25, 0.25, 1.0, 0, 0.5, 1.0, 0, 0.3333333333333333

0.32125265873016007

In [18]:
def calculate_recall(df):
    """
    Calculate recall@1, recall@5, and recall@10 for the given dataframe.

    Args:
        df (pd.DataFrame): DataFrame containing columns 'query', 'ground_truth', 'top1', 'top5', 'top10'.

    Returns:
        dict: A dictionary containing recall@1, recall@5, and recall@10.
    """
    recall_at_1 = 0
    recall_at_5 = 0
    recall_at_10 = 0
    total_queries = len(df)

    for _, row in df.iterrows():
        ground_truth = row['ground_truth']
        if ground_truth in [pred[0] for pred in row['top1']]:
            recall_at_1 += 1
        if ground_truth in [pred[0] for pred in row['top5']]:
            recall_at_5 += 1
        if ground_truth in [pred[0] for pred in row['top10']]:
            recall_at_10 += 1

    return {
        'recall@1': recall_at_1 / total_queries,
        'recall@5': recall_at_5 / total_queries,
        'recall@10': recall_at_10 / total_queries
    }

# Example usage:
recall_scores = calculate_recall(res_df)
print(recall_scores)

{'recall@1': 0.2334, 'recall@5': 0.4368, 'recall@10': 0.5216}


## Implementation using CLIP

In [24]:
create_milvus_collection(MILVUS_IMAGES_COLLECTION_NAME, 512)

<Collection>:
-------------
<name>: text_video_retrieval_images_1
<description>: video retrieval
<schema>: {'auto_id': False, 'description': 'video retrieval', 'fields': [{'name': 'vid_id', 'description': 'ids', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': False}, {'name': 'embedding', 'description': 'embedding vectors', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 512}}]}

In [None]:
# Create 

In [21]:
def read_csv(csv_path, encoding='utf-8-sig'):
    with open(csv_path, 'r', encoding=encoding) as f:
        data = csv.DictReader(f)
        for line in data:
            raw_id = line['video_id']
            cleaned_id = raw_id[len('video'):]
            yield int(cleaned_id), line['frame_path']

p3 = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'frame_path'), read_csv)
    .map('frame_path', 'img', ops.image_decode.cv2('rgb'))
    .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', device='mps'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map(('video_id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name=MILVUS_IMAGES_COLLECTION_NAME))
    .output()
)

ret = p3(CSV_IMAGES_LOADER_FILE)

2025-04-15 17:44:16,113 - 8454604864 - connectionpool.py-connectionpool:289 - DEBUG: Resetting dropped connection: huggingface.co
2025-04-15 17:44:16,198 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-15 17:44:16,474 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-15 17:44:16,596 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-15 17:44:16,598 - 16450269184 - connectionpool.py-connectionpool:1049 - DEBUG: Starting new HTTPS connection (1): huggingface.co:443
2025-04-15 17:44:16,789 - 16450269184 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "GET /api/models/openai/clip-vit-base-patch

In [23]:
def read_image(image_ids):
    df = pd.read_csv('reverse_image_search.csv')
    id_img = df.set_index('id')['path'].to_dict()
    imgs = []
    decode = ops.image_decode.cv2('rgb')
    for image_id in image_ids:
        path = id_img[image_id]
        imgs.append(decode(path))
    return imgs


p4 = (
    pipe.input('text')
    .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text', device='mps'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'result', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name=MILVUS_IMAGES_COLLECTION_NAME, limit=4))
    .map('result', 'image_ids', lambda x: [item[0] for item in x])
    # .map('image_ids', 'images', read_image)
    .output('text', 'image_ids')
)

DataCollection(p4("snowy mountain with a clear blue sky")).show()

2025-04-15 17:44:38,719 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-15 17:44:38,749 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/config.json HTTP/1.1" 200 0
2025-04-15 17:44:38,783 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/model.safetensors HTTP/1.1" 404 0
2025-04-15 17:44:38,787 - 14472818688 - connectionpool.py-connectionpool:1049 - DEBUG: Starting new HTTPS connection (1): huggingface.co:443
2025-04-15 17:44:39,261 - 8454604864 - connectionpool.py-connectionpool:544 - DEBUG: https://huggingface.co:443 "HEAD /openai/clip-vit-base-patch16/resolve/main/tokenizer_config.json HTTP/1.1" 200 0
2025-04-15 17:44:39,282 - 14472818688 - connectionpool.py-connectionpool:544 - DEBUG: h

text,image_ids
snowy mountain with a clear blue sky,"[2342,6009,6465,3466] len=4"
