In [None]:
from pathlib import Path
import pandas as pd
import yaml
import numpy as np
import sys
sys.path.append('..')
from tqdm.notebook import trange

from services.adapter.src import MilvusWrapper

In [None]:
from pymilvus import CollectionSchema, DataType, FieldSchema
def create_schema(description: str = 'Museum features', dim: int = 1408) -> CollectionSchema:
    """Build the Milvus collection schema for museum image features.

    Fields:
      - ``id``: INT64 primary key, auto-generated by Milvus (``auto_id=True``).
      - ``image_path``: VARCHAR, max length 100, default ``"Unknown"``.
      - ``object_id``: INT64.
      - ``description``: VARCHAR, max length 10000, default ``"Unknown"``.
      - ``features``: FLOAT_VECTOR of length ``dim``.

    Args:
        description: Description attached to the collection schema. Earlier
            versions ignored this argument (it was shadowed by the local
            ``description`` field) and always used ``'Image retrieval'``;
            pass ``description='Image retrieval'`` if you must match a
            collection created by that code.
        dim: Length of the ``features`` vector. Must match the embeddings
            that will be inserted (1408 by default).

    Returns:
        A ``CollectionSchema`` with dynamic fields enabled.
    """
    id_field = FieldSchema(
        name="id",
        dtype=DataType.INT64,
        is_primary=True,
        auto_id=True,
    )

    image_path_field = FieldSchema(
        name="image_path",
        dtype=DataType.VARCHAR,
        max_length=100,
        default_value="Unknown",
    )

    object_id_field = FieldSchema(
        name="object_id",
        dtype=DataType.INT64,
    )

    # Named *_field so it no longer shadows the ``description`` parameter.
    description_field = FieldSchema(
        name="description",
        dtype=DataType.VARCHAR,
        max_length=10000,
        default_value="Unknown",
    )

    features_field = FieldSchema(
        name="features",
        dtype=DataType.FLOAT_VECTOR,
        dim=dim,
    )

    return CollectionSchema(
        fields=[id_field, image_path_field, object_id_field, description_field, features_field],
        description=description,
        enable_dynamic_field=True,
    )

In [None]:
# Read service settings, connect to Milvus, and hand the target collection
# name plus the schema from create_schema() to the wrapper's init_collection.
CONFIG_PATH = Path('../configs') / 'config.yaml'
config = yaml.safe_load(CONFIG_PATH.read_text())

milvus = MilvusWrapper(config['milvus'])
milvus.connect()

collection_name = config['collection_name']
milvus.init_collection(collection_name, schema=create_schema())

In [None]:
# Build one Milvus record per catalogue row. Each record holds:
#   - the image path as served by the app (/storage/<object_id>/<img_name>),
#   - the object id,
#   - the description,
#   - the precomputed embedding stored next to the image as <img_name>.npy.
# NOTE(review): STORAGE_ROOT is a developer-specific absolute path. Override it
# for your machine; ideally it should come from config.yaml.
STORAGE_ROOT = Path('/home/borntowarn/projects/borntowarn/museum_search/storage')

catalogue = pd.read_csv('../dataset/train.csv', sep=';').fillna('')

data = []
# itertuples keeps each column's dtype; iterrows would upcast per row and is much slower.
for row in catalogue.itertuples(index=False):
    disk_path = STORAGE_ROOT / str(row.object_id) / str(row.img_name)
    embedding = np.load(disk_path.with_suffix('.npy'))
    data.append({
        'image_path': f'/storage/{row.object_id}/{row.img_name}',
        'object_id': int(row.object_id),
        'description': row.description,
        # Takes the first entry along axis 0. Presumably the .npy holds a
        # (1, dim) array; TODO confirm.
        'features': embedding[0].astype(np.float32),
    })

In [None]:
# Insert the records into Milvus in chunks of `batch`; the progress bar ticks once per chunk.
batch = 20
for start in trange(0, len(data), batch):
    milvus.insert(data[start:start + batch])