In [None]:
!pip install milvus
!pip install pymilvus
!pip3 install faiss-gpu
!pip install torchvision
!pip install timm

In [None]:
import torch
import csv
from glob import glob
from pathlib import Path
from statistics import mean

from towhee import pipe, ops, DataCollection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

In [None]:
# --- Embedding model -------------------------------------------------------
MODEL = 'resnet50'

DEVICE = torch.device('cpu')

# --- Milvus connection / index settings ------------------------------------
HOST = 'localhost'
PORT = '19530'
TOPK = 10
DIM = 2048  # must match the length of the vectors produced by the MODEL embedding op
COLLECTION_NAME = 'reverse_image_search'
INDEX_TYPE = 'IVF_FLAT'
METRIC_TYPE = 'L2'

# --- Dataset ----------------------------------------------------------------
# Directory holding the images; file names are expected to start with one of
# the class names listed below.
DATA_DIR = Path('/Users/siasejung/Desktop/milvus_image')
# Number of images embedded and rendered by the preview at the end of this cell.
PREVIEW_COUNT = 5

classes = ['bareland', 'commercial', 'playground', 'mountain', 'desert', 'river', 'pond', 'sparseresidential', 'parking', 'railwaystation', 'resort', 'baseballfield', 'denseresidential', 'center', 'viaduct', 'mediumresidential', 'meadow', 'forest', 'beach', 'bridge', 'church', 'park', 'stadium', 'storagetanks', 'port', 'airport', 'industrial', 'square', 'school', 'farmland']

IMAGE_POOLS = []
for c in classes:
    # sorted() makes the pool order reproducible across runs and file systems.
    class_lists = sorted(glob(str(DATA_DIR / f'{c}*.jpg')))
    IMAGE_POOLS.extend(class_lists)
# The prefix patterns overlap (e.g. 'park*' also matches 'parking...'), so the
# same file can be collected twice. Drop repeats while keeping first-seen order.
IMAGE_POOLS = list(dict.fromkeys(IMAGE_POOLS))


def load_images(src):
    """Expand a pipeline input into a list of image paths.

    Args:
        src: Either a string (the path of an existing file, or a glob
            pattern) or an iterable of image paths.

    Returns:
        A list of image paths: ``[src]`` for an existing file, the glob
        matches for any other string, and ``list(src)`` for an iterable.
    """
    if isinstance(src, str):
        # Check for a real file first so that paths containing glob
        # metacharacters such as '[' are not misread as patterns.
        if Path(src).is_file():
            return [src]
        return glob(src)
    return list(src)


# Embedding pipeline: input -> image paths -> decoded image -> feature vector.
p_embed = (
    pipe.input('src')
        .flat_map('src', 'img_path', load_images)
        .map('img_path', 'img', ops.image_decode())
        .map('img', 'vec', ops.image_embedding.timm(model_name=MODEL, device=DEVICE))
)
p_display = p_embed.output('img_path', 'img', 'vec')
# Preview only a handful of images; embedding and rendering the whole pool
# here would repeat work that the insertion step performs anyway.
DataCollection(p_display(IMAGE_POOLS[:PREVIEW_COUNT])).show()

In [None]:
def create_milvus_collection(collection_name, dim, nlist=2048):
    """Create a fresh Milvus collection for image embeddings and index it.

    If a collection named ``collection_name`` already exists it is dropped
    first, so calling this function discards any data stored under that name.

    Args:
        collection_name: Name of the collection to (re)create.
        dim: Dimensionality of the ``embedding`` vector field.
        nlist: Value of the ``nlist`` index parameter. Defaults to 2048, the
            value this function previously hard-coded.

    Returns:
        The new ``Collection``, with an index on ``embedding`` built from the
        module-level ``METRIC_TYPE`` and ``INDEX_TYPE`` settings.
    """
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    # The image path doubles as the primary key (auto_id is off), so every
    # inserted row must carry its own path.
    fields = [
        FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image', max_length=500,
                    is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': METRIC_TYPE,
        'index_type': INDEX_TYPE,
        'params': {"nlist": nlist}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection

In [None]:
connections.connect(host=HOST, port=PORT)
# create_milvus_collection drops any existing collection of this name, so
# re-running this cell starts from an empty collection.
collection = create_milvus_collection(COLLECTION_NAME, DIM)

# Extend the embedding pipeline with a stage that writes each
# (img_path, vec) pair into the Milvus collection.
p_insert = (
        p_embed.map(('img_path', 'vec'), 'mr', ops.ann_insert.milvus_client(
                    host=HOST,
                    port=PORT,
                    collection_name=COLLECTION_NAME
                    ))
          .output('mr')
)

# One call processes the entire pool: the flat_map stage of p_embed fans the
# input out into individual image paths. Invoking the pipeline once per image
# would embed and insert images repeatedly.
p_insert(IMAGE_POOLS)

# Flush so that the freshly inserted rows are reflected in num_entities.
collection.flush()
print('Number of data inserted:', collection.num_entities)