In [None]:
import os

# Connection parameters for your cloud Milvus instance.
# Prefer exporting MILVUS_URI / MILVUS_API_KEY in the environment so real
# credentials never get saved into the notebook; the placeholders below are
# only fallbacks to be replaced locally.
URI = os.environ.get("MILVUS_URI", "XXXXXXXXXXXX")
API_KEY = os.environ.get("MILVUS_API_KEY", "XXXXXXXXXXXXXXX")

In [None]:
# CSV manifest of the images to index -- replace with your own file name.
file_name = "XXXXXX.csv"

In [None]:
import pandas as pd

# Load the image manifest; its `id` and `path` columns are used to locate images.
df = pd.read_csv(file_name)
df.head(n=5)

In [None]:
import cv2
from towhee.types.image import Image

# Lookup table from image id to file path, built from the manifest DataFrame.
id_img = df.set_index('id')['path'].to_dict()


def read_images(results):
    """Load the images referenced by a list of search hits.

    Args:
        results: Iterable of hits exposing an ``.id`` attribute whose values
            match the manifest's ``id`` column (presumably pymilvus search
            hits -- confirm against the caller).

    Returns:
        A list of towhee ``Image`` objects tagged as BGR, one per hit, in order.

    Raises:
        KeyError: if a hit id is not present in the manifest.
        OSError: if OpenCV cannot read the file at the mapped path.
    """
    imgs = []
    # `hit` instead of `re`, which shadowed the stdlib regex module name.
    for hit in results:
        path = id_img[hit.id]
        pixels = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if pixels is None:
            raise OSError(f'could not read image {path!r} for id {hit.id!r}')
        imgs.append(Image(pixels, 'BGR'))
    return imgs

In [None]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

def create_milvus_collection(collection_name, dim, metric_type='L2', nlist=512):
    """Connect to Milvus and (re)create an empty, indexed embedding collection.

    WARNING: if ``collection_name`` already exists it is dropped first, so all
    of its data is deleted.

    Args:
        collection_name: Name of the collection to create.
        dim: Dimensionality of the ``embedding`` float-vector field.
        metric_type: Distance metric of the IVF_FLAT index (default ``'L2'``).
        nlist: ``nlist`` parameter of the IVF_FLAT index (default 512).

    Returns:
        The new ``pymilvus.Collection`` with an IVF_FLAT index on ``embedding``.
    """
    # Cloud connection on the default alias. For a local server use
    # connections.connect(host='127.0.0.1', port='19530') instead.
    connections.connect(uri=URI, token=API_KEY, secure=True)

    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    # `description` (not the misspelled `descrition`, which pymilvus silently
    # ignored) so the field descriptions are actually stored in the schema.
    fields = [
        # Ids are supplied by the caller (auto_id=False).
        FieldSchema(name='id', dtype=DataType.INT64, description='ids',
                    is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR,
                    description='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='text image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': metric_type,
        'index_type': 'IVF_FLAT',
        'params': {'nlist': nlist},
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection


# 512 matches the embedding size of clip_vit_base_patch16 used below.
collection = create_milvus_collection('text_image_search', 512)

In [None]:
from towhee import ops, pipe, DataCollection
import numpy as np

In [None]:
# Image -> CLIP embedding pipeline. Each vector is scaled to unit length, so
# L2 distances between embeddings rank results the same way cosine similarity would.
p = (
    pipe.input('path')
        .map('path', 'img', ops.image_decode.cv2('rgb'))
        .map('img', 'vec',
             ops.image_text_embedding.clip(model_name='clip_vit_base_patch16',
                                           modality='image'))
        .map('vec', 'vec', lambda emb: emb / np.linalg.norm(emb))
        .output('img', 'vec')
)

# Smoke test: swap 'image.png' for any local test image.
DataCollection(p('image.png')).show()

In [None]:
# Text -> CLIP embedding pipeline, normalized to unit length like the image
# embeddings so both live on the same scale.
p2 = (
    pipe.input('text')
        .map('text', 'vec',
             ops.image_text_embedding.clip(model_name='clip_vit_base_patch16',
                                           modality='text'))
        .map('vec', 'vec', lambda emb: emb / np.linalg.norm(emb))
        .output('text', 'vec')
)

# Smoke test with a sample caption.
DataCollection(p2("A teddybear on a skateboard in Times Square.")).show()

In [None]:
%%time
# Start from an empty collection so re-running this cell never inserts
# duplicate rows. create_milvus_collection also opens the (secure) Milvus
# connection on the default alias, so no second connections.connect() is needed.
collection = create_milvus_collection('text_image_search', 512)


def read_csv(csv_path, encoding='utf-8-sig'):
    """Yield ``(id, path)`` pairs from the image manifest CSV.

    Args:
        csv_path: Path to a CSV file with ``id`` and ``path`` columns.
        encoding: File encoding; ``utf-8-sig`` also strips a leading BOM.

    Yields:
        ``(int(id), path)`` for each data row.
    """
    import csv
    # newline='' is what the csv module requires for correct quoted-field handling.
    with open(csv_path, 'r', encoding=encoding, newline='') as f:
        for row in csv.DictReader(f):
            yield int(row['id']), row['path']


# Bulk-insert pipeline: CSV row -> decoded image -> unit-length CLIP vector -> Milvus.
p3 = (
    pipe.input('csv_file')
    .flat_map('csv_file', ('id', 'path'), read_csv)
    .map('path', 'img', ops.image_decode.cv2('rgb'))
    # NOTE(review): device=0 presumably selects the first CUDA GPU -- drop it on CPU-only machines.
    .map('img', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', device=0))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map(('id', 'vec'), (), ops.ann_insert.milvus_client(uri=URI, token=API_KEY, collection_name='text_image_search'))
    .output()
)

ret = p3(file_name)

In [None]:
# Flush first: rows that are still buffered after insertion may not be
# reflected in num_entities until they are flushed.
collection.flush()
collection.load()
print('Total number of inserted data is {}.'.format(collection.num_entities))

In [None]:
import pandas as pd
import cv2

import functools


@functools.lru_cache(maxsize=None)
def _id_to_path(csv_path):
    """Parse the manifest once per CSV path into an ``{id: path}`` dict.

    Cached for the kernel's lifetime: re-run this cell if the CSV changes.
    """
    return pd.read_csv(csv_path).set_index('id')['path'].to_dict()


def read_image(image_ids):
    """Decode the images for the given ids, preserving their order.

    Args:
        image_ids: Iterable of ids from the manifest's ``id`` column.

    Returns:
        A list with one decoded image (towhee ``image_decode.cv2('rgb')``
        output) per id.

    Raises:
        KeyError: if an id is not present in the manifest.
    """
    # The id->path map is cached, so search queries no longer re-parse the CSV.
    id_to_path = _id_to_path(file_name)
    decode = ops.image_decode.cv2('rgb')
    return [decode(id_to_path[image_id]) for image_id in image_ids]


# Text query -> unit-length CLIP text vector -> top-5 Milvus hits -> decoded images.
p4 = (
    pipe.input('text')
        .map('text', 'vec',
             ops.image_text_embedding.clip(model_name='clip_vit_base_patch16',
                                           modality='text'))
        .map('vec', 'vec', lambda emb: emb / np.linalg.norm(emb))
        .map('vec', 'result',
             ops.ann_search.milvus_client(uri=URI, token=API_KEY,
                                          collection_name='text_image_search',
                                          limit=5))
        # keep the first field of each hit: the image id consumed by read_image
        .map('result', 'image_ids', lambda hits: [hit[0] for hit in hits])
        .map('image_ids', 'images', read_image)
        .output('text', 'images')
)

for query in ("book", "A black dog"):
    DataCollection(p4(query)).show()

In [None]:
# Number of hits per query; the gradio UI shows one image slot per hit.
SEARCH_TOP_K = 5

search_pipeline = (
    pipe.input('text')
    .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'))
    .map('vec', 'vec', lambda x: x / np.linalg.norm(x))
    .map('vec', 'result', ops.ann_search.milvus_client(uri=URI, token=API_KEY, collection_name='text_image_search', limit=SEARCH_TOP_K))
    .map('result', 'image_ids', lambda x: [item[0] for item in x])
    .output('image_ids')
)


def search(text):
    """Return exactly ``SEARCH_TOP_K`` image file paths for a text query.

    Args:
        text: Free-text query.

    Returns:
        A list of length ``SEARCH_TOP_K`` of image paths, best match first.
        It is padded with ``None`` when Milvus returns fewer hits, and an id
        that is missing from the CSV also maps to ``None``, so the fixed set
        of gradio outputs always receives one value each.
    """
    id_to_path = pd.read_csv(file_name).set_index('id')['path'].to_dict()
    image_ids = search_pipeline(text).to_list()[0][0]
    paths = [id_to_path.get(image_id) for image_id in image_ids][:SEARCH_TOP_K]
    paths += [None] * (SEARCH_TOP_K - len(paths))
    return paths


import gradio as gr

# One image slot per search hit; type="filepath" lets `search` return plain paths.
output_images = []
for _ in range(5):
    output_images.append(gr.Image(type="filepath"))

interface = gr.Interface(
    fn=search,
    inputs="text",
    outputs=output_images,
)

# NOTE(review): share=True publishes a temporary public gradio link to this demo.
interface.launch(inline=True, share=True)