In [4]:
! python -m pip install -q towhee towhee.models pillow ipython gradio

In [14]:
import pandas as pd

# Load the video metadata table; later cells read its 'id', 'path' and 'label' columns.
df = pd.read_csv('./reverse_video_search.csv')

In [6]:
import os

# Lookup table: video id -> file path, used to turn search hits back into paths.
id_video = df.set_index('id')['path'].to_dict()

# Lookup table: label -> list of the ids of every video carrying that label.
# One groupby pass replaces re-filtering the whole frame once per distinct label.
label_ids = df.groupby('label')['id'].apply(list).to_dict()

def ground_truth(path):
    """Return the ids of the indexed videos that share the query video's label.

    The label is the name of the directory that directly contains ``path``.
    The received path and the extracted label are printed on every call.

    Args:
        path: Path of the query video file.

    Returns:
        The entry of ``label_ids`` for the extracted label, or ``None`` (after
        printing a notice) when the label is not a key of ``label_ids``.
    """
    print("Path received:", path)
    parent_dir = os.path.dirname(path)
    label = os.path.basename(parent_dir)
    print("Extracted label:", label)
    if label in label_ids:
        return label_ids[label]
    print("Label not found in label_ids dictionary:", label)
    return None

In [7]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# Connect to the Milvus server on localhost:19530 before the collection helpers below are used.
connections.connect(host='localhost', port='19530')

def create_milvus_collection(collection_name, dim):
    """(Re)create an empty Milvus collection for embeddings and index it.

    Any existing collection called ``collection_name`` is dropped first, so
    whatever was stored under that name is discarded.

    Args:
        collection_name: Name of the collection to create.
        dim: Dimension of the vectors stored in the ``embedding`` field.

    Returns:
        The newly created ``Collection``, carrying an IVF_FLAT index
        (L2 metric, nlist=400) on its ``embedding`` field.
    """
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    # Schema: an explicit (not auto-generated) INT64 primary key plus the float vector.
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='reverse video search')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 400}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

# Create (or reset) the 'x3d_m' collection; dim=2048 has to match the length of the
# 'features' vectors inserted below (presumably the x3d_m feature size - TODO confirm).
collection = create_milvus_collection('x3d_m', 2048)

In [9]:
from towhee import pipe, ops
from towhee.datacollection import DataCollection

def read_csv(csv_file):
    """Yield ``(id, path, label)`` for every data row of a CSV file.

    The file must have a header row naming at least the columns ``id``,
    ``path`` and ``label``. All three values are yielded as strings.

    Args:
        csv_file: Path of the CSV file to read.

    Yields:
        Tuples ``(id, path, label)``, one per data row, in file order.

    Raises:
        KeyError: If one of the three columns is missing from the header.
    """
    import csv
    # newline='' leaves line-ending handling to the csv module, as its documentation
    # requires, so line breaks inside quoted fields survive unchanged.
    # utf-8-sig drops a leading byte-order mark so the first header name stays clean.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            yield row['id'], row['path'], row['label']


# Operators of the ingestion pipeline, built up front so the chain below reads as a list of steps.
insert_decode_op = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16})
insert_feature_op = ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True)
insert_milvus_op = ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m')

# CSV rows -> integer id -> 16 sampled frames -> x3d_m features -> insert (id, features) into 'x3d_m'.
insert_pipe = (
    pipe.input('csv_path')
        .flat_map('csv_path', ('id', 'path', 'label'), read_csv)
        .map('id', 'id', lambda raw_id: int(raw_id))
        .map('path', 'frames', insert_decode_op)
        .map('frames', ('labels', 'scores', 'features'), insert_feature_op)
        .map(('id', 'features'), 'insert_res', insert_milvus_op)
        .output()
)

insert_pipe('reverse_video_search.csv')
# NOTE(review): depending on the pymilvus version, num_entities may lag behind
# until the collection has been flushed - confirm if the printed total looks low.
inserted_total = collection.num_entities
print('Total number of inserted data is {}.'.format(inserted_total))

Using cache found in C:\Users\HP/.cache\torch\hub\facebookresearch_pytorchvideo_main
2024-04-25 09:23:49,627 - 7196 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:23:49,638 - 7196 - node.py-node:167 - INFO: Begin to run Node-read_csv-0
2024-04-25 09:23:49,639 - 10084 - node.py-node:167 - INFO: Begin to run Node-lambda-1
2024-04-25 09:23:49,649 - 7196 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-2
2024-04-25 09:23:49,661 - 10084 - node.py-node:167 - INFO: Begin to run Node-action-classification/pytorchvideo-3
2024-04-25 09:23:49,661 - 1492 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-4
2024-04-25 09:23:49,662 - 6832 - node.py-node:167 - INFO: Begin to run Node-_output


Total number of inserted data is 220.


In [10]:
# Load the collection ahead of the searches issued below.
collection.load()

query_path = './test/eating_carrots/ty4UQlowp0c.mp4'

# Operators of the query pipeline; decoding and feature extraction use the same
# settings as the ingestion pipeline.
query_decode_op = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16})
query_feature_op = ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True)
query_search_op = ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m', limit=10)

# Video path -> frames -> features -> up to 10 search hits -> paths of the hit videos.
query_pipe = (
    pipe.input('path')
        .map('path', 'frames', query_decode_op)
        .map('frames', ('labels', 'scores', 'features'), query_feature_op)
        .map('features', 'result', query_search_op)
        # The first element of each hit is used as the key into id_video.
        .map('result', 'candidates', lambda hits: [id_video[hit[0]] for hit in hits])
        .output('path', 'candidates')
)

query_output = query_pipe(query_path)
res = DataCollection(query_output)
res.show()

Using cache found in C:\Users\HP/.cache\torch\hub\facebookresearch_pytorchvideo_main
2024-04-25 09:39:38,615 - 11180 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:39:38,627 - 13124 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-0
2024-04-25 09:39:38,628 - 17208 - node.py-node:167 - INFO: Begin to run Node-action-classification/pytorchvideo-1
2024-04-25 09:39:38,630 - 11180 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-2
2024-04-25 09:39:38,635 - 16492 - node.py-node:167 - INFO: Begin to run Node-lambda-3
2024-04-25 09:39:38,639 - 5960 - node.py-node:167 - INFO: Begin to run Node-_output


path,candidates
./test/eating_carrots/ty4UQlowp0c.mp4,./train/eating_carrots/V7DUq0JJneY.mp4 ./train/eating_carrots/bTCznQiu0hc.mp4 ./train/eating_carrots/Ou1w86qEr58.mp4 ./train/eating_carrots/Ka6z9NtiVMQ.mp4 ./train/eating_carrots/9OZhQqMhX50.mp4


In [11]:
import os
from IPython import display
from PIL import Image

# Directory that receives the GIFs written by video_to_gif; created up front if it is missing.
tmp_dir = './tmp'
os.makedirs(tmp_dir, exist_ok=True)

def video_to_gif(video_path):
    """Convert a video into an animated GIF stored in ``tmp_dir``.

    Sixteen uniformly sampled frames are decoded from the video and saved as
    one multi-frame GIF named after the video file (extension replaced by
    ``.gif``).

    Args:
        video_path: Path of the video file to convert.

    Returns:
        Path of the GIF that was written.
    """
    # Use os.path rather than split('/') and a fixed 4-character slice: this also
    # handles paths using the platform's own separator (e.g. what glob returns on
    # Windows) and extensions that are not exactly three characters long.
    stem = os.path.splitext(os.path.basename(video_path))[0]
    gif_path = os.path.join(tmp_dir, stem + '.gif')
    p = (
        pipe.input('path')
            .map('path', 'frames', ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16}))
            .output('frames')
    )
    # The pipeline outputs a single column; take the decoded frames out of it.
    frames = p(video_path).get()[0]
    imgs = [Image.fromarray(frame) for frame in frames]
    # First frame is the base image, the remaining ones are appended to it.
    imgs[0].save(fp=gif_path, format='GIF', append_images=imgs[1:], save_all=True, loop=0)
    return gif_path

# Build the preview as a list of HTML fragments: a caption with the query's
# parent directory name, the query GIF, then the first three candidates inline.
html_parts = ['Query video "{}": <br/>'.format(query_path.split('/')[-2])]

query_gif = video_to_gif(query_path)
html_parts.append('<img src="{}"> <br/>'.format(query_gif))
html_parts.append('Top 3 search results: <br/>')

top_candidates = res[0]['candidates'][:3]
html_parts.extend(
    '<img src="{}" style="display:inline;margin:1px"/>'.format(video_to_gif(candidate))
    for candidate in top_candidates
)

html = ''.join(html_parts)
display.HTML(html)

2024-04-25 09:39:49,532 - 5700 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:39:49,541 - 14900 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-0
2024-04-25 09:39:49,550 - 7724 - node.py-node:167 - INFO: Begin to run Node-_output
2024-04-25 09:39:51,948 - 13820 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:39:51,950 - 16464 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-0
2024-04-25 09:39:51,951 - 16456 - node.py-node:167 - INFO: Begin to run Node-_output
2024-04-25 09:39:54,650 - 5704 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:39:54,655 - 4772 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-0
2024-04-25 09:39:54,656 - 5704 - node.py-node:167 - INFO: Begin to run Node-_output
2024-04-25 09:39:55,810 - 13968 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:39:55,812 - 3344 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-0
2024-04-25 09:39:55

In [12]:
import glob

def mean_hit_ratio(actual, predicted):
    """Average, over all queries, the share of ground-truth items that were retrieved.

    For each (ground truth, prediction) pair the ratio is the number of distinct
    items present in both, divided by the length of the ground-truth sequence.

    Args:
        actual: Sequence of ground-truth id collections, one per query.
        predicted: Sequence of predicted id collections, aligned with ``actual``.

    Returns:
        The arithmetic mean of the per-query ratios.

    Raises:
        ZeroDivisionError: If there are no pairs or a ground-truth entry is empty.
    """
    per_query = [
        len(set(truth).intersection(set(guess))) / len(truth)
        for truth, guess in zip(actual, predicted)
    ]
    return sum(per_query) / len(per_query)

def mean_average_precision(actual, predicted):
    """Average the per-query precision scores over all queries.

    For one query, the precision after each rank (hits so far divided by the
    rank) is recorded for every rank - whether or not that rank was a hit - and
    the query's score is the mean of those values.

    Args:
        actual: Sequence of ground-truth id collections, one per query.
        predicted: Sequence of ranked prediction lists, aligned with ``actual``.

    Returns:
        The arithmetic mean of the per-query scores.

    Raises:
        ZeroDivisionError: If there are no pairs or a prediction list is empty.
    """
    def average_precision(relevant, ranked):
        # Running precision after each rank, counted from 1.
        hits_so_far = 0
        precision_at_rank = []
        for rank, candidate in enumerate(ranked, start=1):
            if candidate in relevant:
                hits_so_far += 1
            precision_at_rank.append(hits_so_far / rank)
        return sum(precision_at_rank) / len(precision_at_rank)

    per_query = [average_precision(truth, guess) for truth, guess in zip(actual, predicted)]
    return sum(per_query) / len(per_query)

# Operators of the evaluation pipeline (same decode/feature/search settings as the query pipeline).
eval_decode_op = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16})
eval_feature_op = ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True)
eval_search_op = ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m', limit=10)

# Glob pattern -> one row per matching file -> features -> predicted ids and
# ground-truth ids -> mHR and mAP computed over all rows at once.
eval_pipe = (
    pipe.input('path')
        .flat_map('path', 'path', lambda pattern: glob.glob(pattern))
        .map('path', 'frames', eval_decode_op)
        .map('frames', ('labels', 'scores', 'features'), eval_feature_op)
        .map('features', 'result', eval_search_op)
        # The first element of each hit is taken as the predicted id.
        .map('result', 'predict', lambda hits: [hit[0] for hit in hits])
        .map('path', 'ground_truth', ground_truth)
        .window_all(('ground_truth', 'predict'), 'mHR', mean_hit_ratio)
        .window_all(('ground_truth', 'predict'), 'mAP', mean_average_precision)
        .output('mHR', 'mAP')
)

eval_output = eval_pipe('./test/*/*.mp4')
res = DataCollection(eval_output)
res.show()

Using cache found in C:\Users\HP/.cache\torch\hub\facebookresearch_pytorchvideo_main
2024-04-25 09:40:09,053 - 9372 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 09:40:09,057 - 13600 - node.py-node:167 - INFO: Begin to run Node-lambda-0
2024-04-25 09:40:09,058 - 9372 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-1
2024-04-25 09:40:09,061 - 15832 - node.py-node:167 - INFO: Begin to run Node-action-classification/pytorchvideo-2
2024-04-25 09:40:09,065 - 15868 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-3
2024-04-25 09:40:09,066 - 2172 - node.py-node:167 - INFO: Begin to run Node-lambda-4
2024-04-25 09:40:09,077 - 16996 - node.py-node:167 - INFO: Begin to run Node-ground_truth-5
2024-04-25 09:40:09,078 - 1988 - node.py-node:167 - INFO: Begin to run Node-mean_hit_ratio-6
2024-04-25 09:40:09,086 - 14900 - node.py-node:167 - INFO: Begin to run Node-mean_average_precision-7
2024-04-25 09:40:09,107 - 13600 - node.py-node:167 - INFO:

Path received: ./test\chopping_wood\kDuAS29BCwk.mp4
Extracted label: chopping_wood
Path received: ./test\clay_pottery_making\QRQfX7aUPqs.mp4
Extracted label: clay_pottery_making
Path received: ./test\country_line_dancing\1TzGn3qcsTk.mp4
Extracted label: country_line_dancing
Path received: ./test\dancing_gangnam_style\LhZSA_QY8Fg.mp4
Extracted label: dancing_gangnam_style
Path received: ./test\doing_aerobics\r63FpwJ9dik.mp4
Extracted label: doing_aerobics
Path received: ./test\drop_kicking\ONMjPkk2x0Y.mp4
Extracted label: drop_kicking
Path received: ./test\dunking_basketball\y_-ivQSPV0Q.mp4
Extracted label: dunking_basketball
Path received: ./test\eating_carrots\ty4UQlowp0c.mp4
Extracted label: eating_carrots
Path received: ./test\eating_hotdog\rJu8mSNHX_8.mp4
Extracted label: eating_hotdog
Path received: ./test\javelin_throw\ZmBDBldVa74.mp4
Extracted label: javelin_throw
Path received: ./test\juggling_fire\umd-9rS3hQg.mp4
Extracted label: juggling_fire
Path received: ./test\juggling_so

mHR,mAP
0.345,0.7013333333333333


In [17]:
# Second collection, filled with normalized feature vectors.
collection = create_milvus_collection('x3d_m_norm', 2048)

# Operators of the normalized ingestion pipeline.
norm_insert_decode_op = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16})
norm_insert_feature_op = ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True)
norm_insert_normalize_op = ops.towhee.np_normalize()
norm_insert_milvus_op = ops.ann_insert.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m_norm')

# Same chain as the first ingestion pipeline, with a normalization step applied
# to 'features' before they are inserted into 'x3d_m_norm'.
insert_pipe = (
    pipe.input('csv_path')
        .flat_map('csv_path', ('id', 'path', 'label'), read_csv)
        .map('id', 'id', lambda raw_id: int(raw_id))
        .map('path', 'frames', norm_insert_decode_op)
        .map('frames', ('labels', 'scores', 'features'), norm_insert_feature_op)
        .map('features', 'features', norm_insert_normalize_op)
        .map(('id', 'features'), 'insert_res', norm_insert_milvus_op)
        .output()
)

insert_pipe('reverse_video_search.csv')

# Load the normalized collection ahead of the searches issued below.
collection.load()

# Operators of the normalized evaluation pipeline.
norm_eval_decode_op = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 16})
norm_eval_feature_op = ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True)
norm_eval_normalize_op = ops.towhee.np_normalize()
norm_eval_search_op = ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m_norm', limit=10)

# Same evaluation as before, but query features are normalized and the search
# runs against 'x3d_m_norm'.
eval_pipe = (
    pipe.input('path')
        .flat_map('path', 'path', lambda pattern: glob.glob(pattern))
        .map('path', 'frames', norm_eval_decode_op)
        .map('frames', ('labels', 'scores', 'features'), norm_eval_feature_op)
        .map('features', 'features', norm_eval_normalize_op)
        .map('features', 'result', norm_eval_search_op)
        # The first element of each hit is taken as the predicted id.
        .map('result', 'predict', lambda hits: [hit[0] for hit in hits])
        .map('path', 'ground_truth', ground_truth)
        .window_all(('ground_truth', 'predict'), 'mHR', mean_hit_ratio)
        .window_all(('ground_truth', 'predict'), 'mAP', mean_average_precision)
        .output('mHR', 'mAP')
)

norm_eval_output = eval_pipe('./test/*/*.mp4')
res = DataCollection(norm_eval_output)
res.show()

Using cache found in C:\Users\HP/.cache\torch\hub\facebookresearch_pytorchvideo_main
2024-04-24 12:26:08,512 - 10224 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-24 12:26:08,521 - 10224 - node.py-node:167 - INFO: Begin to run Node-read_csv-0
2024-04-24 12:26:08,523 - 3944 - node.py-node:167 - INFO: Begin to run Node-lambda-1
2024-04-24 12:26:08,523 - 16320 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-2
2024-04-24 12:26:08,524 - 1696 - node.py-node:167 - INFO: Begin to run Node-action-classification/pytorchvideo-3
2024-04-24 12:26:08,536 - 9600 - node.py-node:167 - INFO: Begin to run Node-towhee/np-normalize-4
2024-04-24 12:26:08,540 - 10224 - node.py-node:167 - INFO: Begin to run Node-ann-insert/milvus-client-5
2024-04-24 12:26:08,541 - 5424 - node.py-node:167 - INFO: Begin to run Node-_output
Using cache found in C:\Users\HP/.cache\torch\hub\facebookresearch_pytorchvideo_main
2024-04-24 12:29:59,700 - 17184 - node.py-node:167 - INFO: Begin to run Nod

Path received: ./test\chopping_wood\kDuAS29BCwk.mp4
Extracted label: chopping_wood
Path received: ./test\clay_pottery_making\QRQfX7aUPqs.mp4
Extracted label: clay_pottery_making
Path received: ./test\country_line_dancing\1TzGn3qcsTk.mp4
Extracted label: country_line_dancing
Path received: ./test\dancing_gangnam_style\LhZSA_QY8Fg.mp4
Extracted label: dancing_gangnam_style
Path received: ./test\doing_aerobics\r63FpwJ9dik.mp4
Extracted label: doing_aerobics
Path received: ./test\drop_kicking\ONMjPkk2x0Y.mp4
Extracted label: drop_kicking
Path received: ./test\dunking_basketball\y_-ivQSPV0Q.mp4
Extracted label: dunking_basketball


2024-04-24 12:30:12,849 - 2412 - node.py-node:167 - INFO: Begin to run Node-_output


Path received: ./test\eating_carrots\ty4UQlowp0c.mp4
Extracted label: eating_carrots
Path received: ./test\eating_hotdog\rJu8mSNHX_8.mp4
Extracted label: eating_hotdog
Path received: ./test\javelin_throw\ZmBDBldVa74.mp4
Extracted label: javelin_throw
Path received: ./test\juggling_fire\umd-9rS3hQg.mp4
Extracted label: juggling_fire
Path received: ./test\juggling_soccer_ball\bH9cE46eZJY.mp4
Extracted label: juggling_soccer_ball
Path received: ./test\playing_trombone\iiJS7YTT5Gs.mp4
Extracted label: playing_trombone
Path received: ./test\pumping_fist\t6-eBFhPsxo.mp4
Extracted label: pumping_fist
Path received: ./test\pushing_cart\esDMtgFIAtQ.mp4
Extracted label: pushing_cart
Path received: ./test\riding_mule\2uzRsYqIwy4.mp4
Extracted label: riding_mule
Path received: ./test\shuffling_cards\8oy1RNINVfQ.mp4
Extracted label: shuffling_cards
Path received: ./test\tap_dancing\EC3Jzrs_mNA.mp4
Extracted label: tap_dancing
Path received: ./test\trimming_trees\dxe9LkBB4Q8.mp4
Extracted label: tri

mHR,mAP
0.66,0.7376626984126984


In [None]:
import gradio as gr
import torch
import torchvision.transforms as transforms
from torchvision.models import video
from PIL import Image
import glob
import os

# Operators of the interactive search pipeline.
# NOTE(review): 32 frames are sampled here, while the 'x3d_m_norm' collection was
# filled from 16-frame samples - confirm that the mismatch is intended.
search_decode_op = ops.video_decode.ffmpeg(sample_type='uniform_temporal_subsample', args={'num_samples': 32})
search_feature_op = ops.action_classification.pytorchvideo(model_name='x3d_m', skip_preprocess=True)
search_normalize_op = ops.towhee.np_normalize()
search_milvus_op = ops.ann_search.milvus_client(host='127.0.0.1', port='19530', collection_name='x3d_m_norm', limit=3)

# Path (or glob pattern) -> frames -> normalized features -> up to 3 hits -> paths of the hit videos.
video_search_pipe = (
    pipe.input('path')
        .flat_map('path', 'path', lambda pattern: glob.glob(pattern))
        .map('path', 'frames', search_decode_op)
        .map('frames', ('labels', 'scores', 'features'), search_feature_op)
        .map('features', 'features', search_normalize_op)
        .map('features', 'result', search_milvus_op)
        # The first element of each hit is used as the key into id_video.
        .map('result', 'predict', lambda hits: [id_video[hit[0]] for hit in hits])
        .output('predict')
)


def video_search_function(video):
    """Run ``video_search_pipe`` on one video and return its predictions.

    Args:
        video: Value handed to the pipeline's ``path`` input.

    Returns:
        The ``predict`` value of the first output row, i.e. the list of video
        paths looked up for the search hits.
    """
    output_rows = video_search_pipe(video).to_list()
    first_row = output_rows[0]
    return first_row[0]

# Gradio UI: one uploaded video in, three result videos out.
# The module is imported as `gr` at the top of this cell, so it must be referenced
# through that alias; the bare name `gradio` is undefined on a fresh kernel.
interface = gr.Interface(video_search_function,
                         inputs=gr.Video(source='upload'),
                         outputs=[gr.Video(format='mp4') for _ in range(3)]
                         )

# NOTE(review): share=True requests a share link for the app - confirm that
# exposing it beyond the local URL is intended.
interface.launch(inline=True, share=True)

Using cache found in C:\Users\HP/.cache\torch\hub\facebookresearch_pytorchvideo_main
2024-04-25 10:33:59,374 - 12828 - connectionpool.py-connectionpool:1055 - DEBUG: Starting new HTTPS connection (1): api.gradio.app:443
2024-04-25 10:33:59,381 - 12832 - connectionpool.py-connectionpool:1055 - DEBUG: Starting new HTTPS connection (1): api.gradio.app:443
2024-04-25 10:33:59,749 - 12528 - selector_events.py-selector_events:54 - DEBUG: Using selector: SelectSelector
2024-04-25 10:33:59,808 - 2612 - connectionpool.py-connectionpool:244 - DEBUG: Starting new HTTP connection (1): 127.0.0.1:7862
2024-04-25 10:33:59,845 - 2612 - connectionpool.py-connectionpool:549 - DEBUG: http://127.0.0.1:7862 "GET /startup-events HTTP/1.1" 200 5
2024-04-25 10:33:59,914 - 2612 - connectionpool.py-connectionpool:244 - DEBUG: Starting new HTTP connection (1): 127.0.0.1:7862
2024-04-25 10:33:59,966 - 2612 - connectionpool.py-connectionpool:549 - DEBUG: http://127.0.0.1:7862 "HEAD / HTTP/1.1" 200 0
2024-04-25 10:

Running on local URL:  http://127.0.0.1:7862


2024-04-25 10:34:01,983 - 12832 - connectionpool.py-connectionpool:549 - DEBUG: https://api.gradio.app:443 "GET /pkg-version HTTP/1.1" 200 21
2024-04-25 10:34:02,002 - 12828 - connectionpool.py-connectionpool:549 - DEBUG: https://api.gradio.app:443 "POST /gradio-initiated-analytics/ HTTP/1.1" 200 None
2024-04-25 10:34:05,146 - 2612 - connectionpool.py-connectionpool:549 - DEBUG: https://api.gradio.app:443 "GET /v2/tunnel-request HTTP/1.1" 200 None
2024-04-25 10:34:05,153 - 2612 - connectionpool.py-connectionpool:1055 - DEBUG: Starting new HTTPS connection (1): cdn-media.huggingface.co:443
2024-04-25 10:34:07,376 - 2612 - connectionpool.py-connectionpool:549 - DEBUG: https://cdn-media.huggingface.co:443 "GET /frpc-gradio-0.2/frpc_windows_amd64.exe HTTP/1.1" 200 11681280



Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.


2024-04-25 10:34:20,916 - 14864 - connectionpool.py-connectionpool:1055 - DEBUG: Starting new HTTPS connection (1): api.gradio.app:443


2024-04-25 10:34:21,006 - 18972 - connectionpool.py-connectionpool:1055 - DEBUG: Starting new HTTPS connection (1): api.gradio.app:443




2024-04-25 10:34:24,390 - 18972 - connectionpool.py-connectionpool:549 - DEBUG: https://api.gradio.app:443 "POST /gradio-launched-telemetry/ HTTP/1.1" 200 None
2024-04-25 10:34:24,391 - 14864 - connectionpool.py-connectionpool:549 - DEBUG: https://api.gradio.app:443 "POST /gradio-error-analytics/ HTTP/1.1" 200 None
2024-04-25 10:35:23,075 - 16848 - node.py-node:167 - INFO: Begin to run Node-_input
2024-04-25 10:35:23,105 - 10472 - node.py-node:167 - INFO: Begin to run Node-lambda-0
2024-04-25 10:35:23,124 - 12972 - node.py-node:167 - INFO: Begin to run Node-video-decode/ffmpeg-1
2024-04-25 10:35:23,131 - 16848 - node.py-node:167 - INFO: Begin to run Node-action-classification/pytorchvideo-2
2024-04-25 10:35:23,197 - 10472 - node.py-node:167 - INFO: Begin to run Node-towhee/np-normalize-3
2024-04-25 10:35:23,199 - 16140 - node.py-node:167 - INFO: Begin to run Node-ann-search/milvus-client-4
2024-04-25 10:35:23,262 - 16496 - node.py-node:167 - INFO: Begin to run Node-lambda-5
2024-04-25 