In [1]:
# Full video reterval using CLIP4CLIP and TwoHee

In [2]:
import os
from towhee import ops, pipe, register
from towhee.operator import PyOperator
from towhee import DataCollection
from tqdm import tqdm
import pandas as pd

In [40]:
# CONSTANTS
VIDEOS_FOLDER = "data/MSRVTT/videos/all"
CSV_LOADER_FILE = "MSRVTT_video_paths.csv"
CSV_SEARCH_FILE_1 = "MSRVTT_search_1.csv"
MILVUS_COLLECTION_NAME = "text_video_retrieval_5"

In [41]:
# Load videos into Vector DB

In [42]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

connections.connect(host='localhost', port='19530')

In [43]:
def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
    
    fields = [
    FieldSchema(name='vid_id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='video retrieval')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type':'L2', #IP
        'index_type':"IVF_FLAT",
        'params':{"nlist":2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

collection = create_milvus_collection(MILVUS_COLLECTION_NAME, 512)

In [44]:
# get video paths
def get_video_paths(folder):
    video_paths = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_paths.append(os.path.join(root, file))
    return video_paths
def get_video_ids(folder):
    video_ids = []
    for root, _, files in os.walk(folder):
        for file in files:
            if file.endswith(('.mp4', '.avi', '.mov')):
                video_ids.append(file.split('.')[0])
    return video_ids

all_vid_ids = get_video_ids(VIDEOS_FOLDER)
len(all_vid_ids)

10000

In [45]:
all_vid_paths = get_video_paths(VIDEOS_FOLDER)
print(*all_vid_paths[:5], sep='\n')
print(len(all_vid_paths))

data/MSRVTT/videos/all/video12.mp4
data/MSRVTT/videos/all/video6110.mp4
data/MSRVTT/videos/all/video9223.mp4
data/MSRVTT/videos/all/video2376.mp4
data/MSRVTT/videos/all/video938.mp4
10000


In [46]:
# Create CSV to feed into TwoHee pipeline
# Create dataframe with video_ids and video_paths columns
pd.DataFrame({
    'video_id': all_vid_ids,
    'video_path': all_vid_paths
}).head(100).to_csv(CSV_LOADER_FILE, index=False)

In [47]:
def read_csv(csv_file):
    import csv
    with open(csv_file, 'r', encoding='utf-8-sig') as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line['video_id'][len('video'):]), line['video_path']

twohee_load_videos_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'video_path'), read_csv)
    .map('video_path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 12}))
    .map('frames', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='video', device='mps'))
    .map(('video_id', 'vec'), (), ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name=MILVUS_COLLECTION_NAME))
    .output('video_id')  
)

In [48]:
DataCollection(twohee_load_videos_pipeline(CSV_LOADER_FILE)).show()

2025-04-12 17:00:38,522 - 20240478208 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-12 17:00:38,522 - 20257304576 - node.py-node:167 - INFO: Begin to run Node-read_csv-0
2025-04-12 17:00:38,523 - 20274130944 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-1
2025-04-12 17:00:38,523 - 20290957312 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-2
2025-04-12 17:00:38,524 - 20307783680 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-3
2025-04-12 17:00:38,524 - 20240478208 - node.py-node:167 - INFO: Begin to run Node-_output


video_id
12
6110
9223
2376
938


In [12]:
# Query the collection

In [49]:
twohee_search_videos_pipeline = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('video_id', 'sentence'), read_csv)
    .map('sentence', 'vec', ops.video_text_embedding.clip4clip(model_name='clip_vit_b32', modality='text', device='mps'))
    .map('vec', 'top10_raw_res', 
         ops.ann_search.milvus_client(
             host='127.0.0.1', port='19530', collection_name=MILVUS_COLLECTION_NAME, limit=10)
        )
    .map('top10_raw_res', ('top1', 'top5', 'top10'), lambda x: (x[:1], x[:5], x[:10]))
    .map('video_id', 'ground_truth', lambda x: x)
    .output('video_id', 'sentence', 'ground_truth', 'top1', 'top5', 'top10')
)

In [50]:
# Create video_id, sentence csv
pd.DataFrame({
    'video_id': all_vid_ids,
    'video_path': all_vid_paths,
    'sentence': "car driving fast"
}).head().to_csv(CSV_SEARCH_FILE_1, index=False)

In [51]:
DataCollection(twohee_search_videos_pipeline(CSV_SEARCH_FILE_1)).show()

2025-04-12 17:01:50,887 - 18770833408 - node.py-node:167 - INFO: Begin to run Node-_input
2025-04-12 17:01:50,888 - 18787659776 - node.py-node:167 - INFO: Begin to run Node-read_csv-0
2025-04-12 17:01:50,889 - 18804486144 - node.py-node:167 - INFO: Begin to run Node-video-text-embedding/clip4clip-1
2025-04-12 17:01:50,889 - 18821312512 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-2
2025-04-12 17:01:50,890 - 18838138880 - node.py-node:167 - INFO: Begin to run Node-lambda-3
2025-04-12 17:01:50,890 - 18770833408 - node.py-node:167 - INFO: Begin to run Node-lambda-4
2025-04-12 17:01:50,891 - 18854965248 - node.py-node:167 - INFO: Begin to run Node-_output


video_id,sentence,ground_truth,top1,top5,top10
12,data/MSRVTT/videos/all/video12.mp4,12,"[[2410, 1.5083791017532349]] len=1","[[2410, 1.5083791017532349],[8854, 1.533107042312622],[3068, 1.5452654361724854],[904, 1.5486265420913696],...] len=5","[[2410, 1.5083791017532349],[8854, 1.533107042312622],[3068, 1.5452654361724854],[904, 1.5486265420913696],...] len=10"
6110,data/MSRVTT/videos/all/video6110.mp4,6110,"[[2410, 1.5018949508666992]] len=1","[[2410, 1.5018949508666992],[8854, 1.521482229232788],[3068, 1.5339115858078003],[904, 1.5529305934906006],...] len=5","[[2410, 1.5018949508666992],[8854, 1.521482229232788],[3068, 1.5339115858078003],[904, 1.5529305934906006],...] len=10"
9223,data/MSRVTT/videos/all/video9223.mp4,9223,"[[2410, 1.5197679996490479]] len=1","[[2410, 1.5197679996490479],[8854, 1.5299631357192993],[3068, 1.5301129817962646],[2376, 1.5480180978775024],...] len=5","[[2410, 1.5197679996490479],[8854, 1.5299631357192993],[3068, 1.5301129817962646],[2376, 1.5480180978775024],...] len=10"
2376,data/MSRVTT/videos/all/video2376.mp4,2376,"[[2410, 1.5124073028564453]] len=1","[[2410, 1.5124073028564453],[8854, 1.527886152267456],[3068, 1.5359500646591187],[6104, 1.5534569025039673],...] len=5","[[2410, 1.5124073028564453],[8854, 1.527886152267456],[3068, 1.5359500646591187],[6104, 1.5534569025039673],...] len=10"
938,data/MSRVTT/videos/all/video938.mp4,938,"[[2410, 1.5077203512191772]] len=1","[[2410, 1.5077203512191772],[8854, 1.5304750204086304],[3068, 1.5305721759796143],[904, 1.5520923137664795],...] len=5","[[2410, 1.5077203512191772],[8854, 1.5304750204086304],[3068, 1.5305721759796143],[904, 1.5520923137664795],...] len=10"
