# text video retrieval engine

> [How to Build a Text-Video Retrieval Engine](https://github.com/towhee-io/examples/blob/07b8d446d26f3af371d8f3e8cb4fd91c9c0cd991/video/text_video_retrieval/1_text_video_retrieval_engine.ipynb)

## prepare the data

`pdm run data_msrvtt`

In [1]:
import os
from pathlib import Path

import pandas as pd

DATA_DIR = Path("./data")

raw_video_path = DATA_DIR / "test_1k_compress"  # 1k test video path.
test_csv_path = DATA_DIR / "MSRVTT_JSFUSION_test.csv"  # 1k video caption csv.

test_sample_csv_path = DATA_DIR / "MSRVTT_JSFUSION_test_sample.csv"

sample_num = 1000  # you can change this sample_num to be smaller, so that this notebook will be faster.
test_df = pd.read_csv(test_csv_path)
print("length of all test set is {}".format(len(test_df)))
sample_df = test_df.sample(sample_num, random_state=42)

sample_df["video_path"] = sample_df.apply(
    lambda x: os.path.join(raw_video_path, x["video_id"]) + ".mp4", axis=1
)

sample_df.to_csv(test_sample_csv_path)
print("random sample {} examples".format(sample_num))

df = pd.read_csv(test_sample_csv_path)

df[["video_id", "video_path", "sentence"]].head()

length of all test set is 1000
random sample 1000 examples


Unnamed: 0,video_id,video_path,sentence
0,video7579,data/test_1k_compress/video7579.mp4,a girl wearing red top and black trouser is pu...
1,video7725,data/test_1k_compress/video7725.mp4,young people sit around the edges of a room cl...
2,video9258,data/test_1k_compress/video9258.mp4,a person is using a phone
3,video7365,data/test_1k_compress/video7365.mp4,cartoon people are eating at a restaurant
4,video8068,data/test_1k_compress/video8068.mp4,a woman on a couch talks to a a man


## Create a Milvus Collection

In [2]:
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)

connections.connect(host="127.0.0.1", port="19530")


def create_milvus_collection(collection_name, dim):
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(
            name="id",
            dtype=DataType.INT64,
            descrition="ids",
            is_primary=True,
            auto_id=False,
        ),
        FieldSchema(
            name="embedding",
            dtype=DataType.FLOAT_VECTOR,
            descrition="embedding vectors",
            dim=dim,
        ),
    ]
    schema = CollectionSchema(fields=fields, description="video retrieval")
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection.
    index_params = {
        "metric_type": "L2",  # IP
        "index_type": "IVF_FLAT",
        "params": {"nlist": 2048},
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

In [3]:
collection = create_milvus_collection("text_video_retrieval", 512)

## Text-Video retrieval

In [11]:
%%time
import os

from towhee import ops, pipe
from towhee.datacollection import DataCollection
from towhee.operator import PyOperator


def read_csv(csv_file):
    import csv

    with open(csv_file, "r", encoding="utf-8-sig") as f:
        data = csv.DictReader(f)
        for line in data:
            yield int(line["video_id"][len("video") :]), line["video_path"]


dc = (
    pipe.input("csv_file")
    .flat_map("csv_file", ("video_id", "video_path"), read_csv)
    .map(
        "video_path",
        "frames",
        ops.video_decode.ffmpeg(
            sample_type="uniform_temporal_subsample", args={"num_samples": 12}
        ),
    )
    .map(
        "frames",
        "vec",
        ops.video_text_embedding.clip4clip(model_name="clip_vit_b32", modality="video"),
    )
    .map(
        ("video_id", "vec"),
        (),
        ops.ann_insert.milvus_client(
            host="127.0.0.1", port="19530", collection_name="text_video_retrieval"
        ),
    )
    .output("video_id")
)

CPU times: user 2.89 s, sys: 134 ms, total: 3.02 s
Wall time: 2.97 s


In [5]:
dc(test_sample_csv_path)
collection.load()

In [6]:
print("Total number of inserted data is {}.".format(collection.num_entities))

Total number of inserted data is 0.


In [7]:
%%time


def read_csv(csv_file):
    import csv

    with open(csv_file, "r", encoding="utf-8-sig") as f:
        data = csv.DictReader(f)
        for line in data:
            yield line["video_id"], line["sentence"]


dc_search = (
    pipe.input("csv_file")
    .flat_map("csv_file", ("video_id", "sentence"), read_csv)
    .map(
        "sentence",
        "vec",
        ops.video_text_embedding.clip4clip(model_name="clip_vit_b32", modality="text"),
    )
    .map(
        "vec",
        "top10_raw_res",
        ops.ann_search.milvus_client(
            host="127.0.0.1",
            port="19530",
            collection_name="text_video_retrieval",
            limit=10,
        ),
    )
    .map("top10_raw_res", ("top1", "top5", "top10"), lambda x: (x[:1], x[:5], x[:10]))
    .map("video_id", "ground_truth", lambda x: x[len("video") :])
    .output("video_id", "sentence", "ground_truth", "top1", "top5", "top10")
)

CPU times: user 2.9 s, sys: 135 ms, total: 3.04 s
Wall time: 3.02 s


In [8]:
from towhee.datacollection import DataCollection

ret = DataCollection(dc_search(test_sample_csv_path))
ret.show()

video_id,sentence,ground_truth,top1,top5,top10
video7579,a girl wearing red top and black trouser is putting a sweater on a dog,7579,"[[7579, 1.4151520729064941]] len=1","[[7579, 1.4151520729064941],[9969, 1.4799103736877441],[8837, 1.4897732734680176],[9347, 1.4948582649230957],...] len=5","[[7579, 1.4151520729064941],[9969, 1.4799103736877441],[8837, 1.4897732734680176],[9347, 1.4948582649230957],...] len=10"
video7725,young people sit around the edges of a room clapping and raising their arms while others dance in the center during a party,7725,"[[7725, 1.3622068166732788]] len=1","[[7725, 1.3622068166732788],[8014, 1.4865269660949707],[8339, 1.4922082424163818],[8442, 1.5024113655090332],...] len=5","[[7725, 1.3622068166732788],[8014, 1.4865269660949707],[8339, 1.4922082424163818],[8442, 1.5024113655090332],...] len=10"
video9258,a person is using a phone,9258,"[[9258, 1.401197075843811]] len=1","[[9258, 1.401197075843811],[9257, 1.4228630065917969],[9697, 1.4413856267929077],[7910, 1.4945622682571411],...] len=5","[[9258, 1.401197075843811],[9257, 1.4228630065917969],[9697, 1.4413856267929077],[7910, 1.4945622682571411],...] len=10"
video7365,cartoon people are eating at a restaurant,7365,"[[7365, 1.4027700424194336]] len=1","[[7365, 1.4027700424194336],[8781, 1.4623045921325684],[9537, 1.4739770889282227],[7831, 1.505112886428833],...] len=5","[[7365, 1.4027700424194336],[8781, 1.4623045921325684],[9537, 1.4739770889282227],[7831, 1.505112886428833],...] len=10"
video8068,a woman on a couch talks to a a man,8068,"[[7162, 1.471674919128418]] len=1","[[7162, 1.471674919128418],[8304, 1.4787474870681763],[8068, 1.4926886558532715],[7724, 1.4982554912567139],...] len=5","[[7162, 1.471674919128418],[8304, 1.4787474870681763],[8068, 1.4926886558532715],[7724, 1.4982554912567139],...] len=10"


## Evaluation

`Recall@topk` is the proportion of relevant items found in the top-k recommendations.

In [9]:
def mean_hit_ratio(actual, *predicteds):
    rets = []
    for predicted in predicteds:
        ratios = []
        for act, pre in zip(actual, predicted):
            hit_num = len(set(act) & set(pre))
            ratios.append(hit_num / len(act))
        rets.append(sum(ratios) / len(ratios))
    return rets


def get_label_from_raw_data(data):
    ret = []
    for item in data:
        ret.append(item[0])
    return ret


ev = (
    pipe.input("dc_data")
    .flat_map("dc_data", "data", lambda x: x)
    .map(
        "data",
        ("ground_truth", "top1", "top5", "top10"),
        lambda x: (
            [int(x.ground_truth)],
            get_label_from_raw_data(x.top1),
            get_label_from_raw_data(x.top5),
            get_label_from_raw_data(x.top10),
        ),
    )
    .window_all(
        ("ground_truth", "top1", "top5", "top10"),
        ("top1_mean_hit_ratio", "top5_mean_hit_ratio", "top10_mean_hit_ratio"),
        mean_hit_ratio,
    )
    .output("top1_mean_hit_ratio", "top5_mean_hit_ratio", "top10_mean_hit_ratio")
)

DataCollection(ev(ret)).show()

top1_mean_hit_ratio,top5_mean_hit_ratio,top10_mean_hit_ratio
0.426,0.716,0.814


## Release a Showcase

In [10]:
import gradio

show_num = 3

milvus_search_pipe = (
    pipe.input("sentence")
    .map(
        "sentence",
        "vec",
        ops.video_text_embedding.clip4clip(
            model_name="clip_vit_b32", modality="text", device="cpu"
        ),
    )
    .map(
        "vec",
        "rows",
        ops.ann_search.milvus_client(
            host="127.0.0.1",
            port="19530",
            collection_name="text_video_retrieval",
            limit=show_num,
        ),
    )
    .map(
        "rows",
        "videos_path",
        lambda rows: (
            os.path.join(raw_video_path, "video" + str(r[0]) + ".mp4") for r in rows
        ),
    )
    .output("videos_path")
)


def milvus_search_function(text):
    return milvus_search_pipe(text).to_list()[0][0]


# print(milvus_search_function('a girl wearing red top and black trouser is putting a sweater on a dog'))


interface = gradio.Interface(
    milvus_search_function,
    inputs=[gradio.Textbox()],
    outputs=[gradio.Video(format="mp4") for _ in range(show_num)],
)

interface.launch(inline=True, share=False)

Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.


