In [1]:
import pickle 
from pathlib import Path

import numpy as np

from milvus import Milvus, IndexType, MetricType, Status

In [2]:
# Connect to the local Milvus server and make sure the target collection exists.
client = Milvus(host='localhost', port='19530')
collection_name = "LSC_full"

status, ok = client.has_collection(collection_name)
if not status.OK():
    # A failed lookup must not be mistaken for "the collection is missing".
    raise RuntimeError(f"Could not check for collection {collection_name}: {status}")

if ok:
    print(f"Collection {collection_name} already exists! Drop collection to re-create.")
else:
    # 768-dimensional vectors compared with L2 distance.
    status = client.create_collection({
        "collection_name": collection_name,
        "dimension": 768,
        "index_file_size": 2048,
        "metric_type": MetricType.L2,
    })
    if not status.OK():
        raise RuntimeError(f"Could not create collection {collection_name}: {status}")

_, collections = client.list_collections()

# Show the schema actually stored on the server as a sanity check.
status, collection = client.get_collection_info(collection_name)
if not status.OK():
    raise RuntimeError(f"Could not read schema of collection {collection_name}: {status}")
print(collection)

Collection LSC_full already exists! Drop collection to re-create.
CollectionSchema(collection_name='LSC_full', dimension=768, index_file_size=2048, metric_type=<MetricType: L2>)


In [16]:
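# DESTRUCTIVE: uncomment only to drop the collection (and every vector in it)
# before re-creating it from scratch. Left disabled so "Run All" cannot wipe the data.
# status = client.drop_collection(collection_name)
# print(status)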

Status(code=0, message='Delete collection successfully!')


In [10]:
# Load the pre-computed image features.
# NOTE: pickle can execute arbitrary code while loading; only use files from a trusted source.
features_path = Path("/home/vbs2/lsc/L14_336_features_128.pkl")
with features_path.open("rb") as features_file:
    data = pickle.load(features_file)

In [7]:
# Map every extracted .jpg file name to its full path.
# The tree is walked two levels deep (named yearmonth/day here); later duplicates
# of the same file name overwrite earlier ones.
name_to_paths = {}
for yearmonth in Path("/home/vbs2/lsc/extracted").iterdir():
    # Skip stray files at the top level; iterdir() on a file would raise NotADirectoryError.
    if not yearmonth.is_dir():
        continue
    for day in yearmonth.iterdir():
        if not day.is_dir():
            continue
        name_to_paths.update({
            path.name: path for path in day.iterdir() if path.suffix == ".jpg"
        })
print(len(name_to_paths))

724443


In [29]:
# Collect the images whose file name starts with "2000", keyed by the name up to
# its first dot, and dump "<key> <path>" lines to path_correction.txt.
name_2000 = {
    file_name.partition('.')[0]: file_path
    for file_name, file_path in name_to_paths.items()
    if file_name.startswith("2000")
}
with open("path_correction.txt", 'w') as out_file:
    out_file.writelines(
        f"{key} {file_path}\n" for key, file_path in name_2000.items()
    )

In [21]:
# Image ids (file name up to its first dot), ordered by full file name,
# written one per line to image_list_full.txt.
image_list = [file_name.partition('.')[0] for file_name in sorted(name_to_paths)]
with open("image_list_full.txt", 'w') as out_file:
    out_file.write('\n'.join(image_list) + '\n')

In [12]:
# Keep only the features whose file name was found on disk, as
# (image id, feature) pairs ordered by image id.
features_list = sorted(
    (
        (file_name.split('.')[0], feature)
        for file_name, feature in data.items()
        if file_name in name_to_paths
    ),
    key=lambda pair: pair[0],
)

In [13]:
# Split the (id, feature) pairs into parallel sequences and pack the features
# into a single float32 matrix with one row per image.
feature_name_list, feature_rows = zip(*features_list)
vectors = np.array(feature_rows, dtype=np.float32)
print(vectors.shape)

(724443, 768)


In [22]:
# Row i of `vectors` must correspond to line i of the written image list.
assert image_list == list(feature_name_list)

# NOTE(review): this normalizes the probe vector before measuring it, so the value
# shown is ~1 for any non-zero vector; if the goal was to check whether the stored
# features are already unit-length, display np.linalg.norm(probe) instead - confirm intent.
probe = vectors[230]
np.linalg.norm(probe / np.linalg.norm(probe))

0.99999994

In [16]:
from tqdm.notebook import tqdm

# Insert in batches: a single Milvus insert request is size-limited (the original
# note cites 256MB), so the matrix is split into roughly 256 equal chunks.
# NOTE: re-running this cell inserts the same vectors again; drop the collection first.
MAX_INSERT_ATTEMPTS = 5  # give up instead of retrying forever on a persistent failure

n = vectors.shape[0]
bs = max(1, n // 256)  # n < 256 would otherwise give a zero step and crash range()
all_ids = []

for i in tqdm(range(0, n, bs)):
    # copy=False avoids duplicating the slice when it is already float32.
    batch = vectors[i:i+bs].astype(np.float32, copy=False)
    for attempt in range(1, MAX_INSERT_ATTEMPTS + 1):
        status, batch_ids = client.insert(collection_name=collection_name, records=batch)
        if status.OK():
            break
    else:
        raise RuntimeError(
            f"Insert of rows {i}..{i + len(batch) - 1} failed after "
            f"{MAX_INSERT_ATTEMPTS} attempts: {status}"
        )
    # Ids are collected in insertion order, so all_ids[k] belongs to vectors[k].
    all_ids.extend(batch_ids)

  0%|          | 0/257 [00:00<?, ?it/s]

In [17]:
# Persist the ids collected in all_ids, one per line, in list order.
with open("milvus_ids_full.txt", 'w') as out_file:
    out_file.write('\n'.join(str(milvus_id) for milvus_id in all_ids) + '\n')

In [19]:
# Fetch the collection statistics to confirm what the server actually stored.
status, stats = client.get_collection_stats(collection_name)
if not status.OK():
    # Fail loudly rather than printing an empty/None result.
    raise RuntimeError(f"Could not fetch stats for collection {collection_name}: {status}")
print(stats)

{'partitions': [{'row_count': 724443, 'segments': [{'data_size': 1389283280, 'index_name': 'IDMAP', 'name': '1656254268313068000', 'row_count': 451066}, {'data_size': 14543760, 'index_name': 'IDMAP', 'name': '1656254320775117000', 'row_count': 4722}, {'data_size': 123643520, 'index_name': 'IDMAP', 'name': '1656254339850177000', 'row_count': 40144}, {'data_size': 703813880, 'index_name': 'IDMAP', 'name': '1656254342851931000', 'row_count': 228511}], 'tag': '_default'}], 'row_count': 724443}


In [89]:
%%time
# Build an IVF_FLAT index with nlist=2048 over the collection.
# The returned Status is left as the cell's last expression so it is displayed.
index_params = {"nlist": 2048}
client.create_index(collection_name, IndexType.IVF_FLAT, params=index_params)

CPU times: user 1.17 s, sys: 92.8 ms, total: 1.26 s
Wall time: 25min 26s


Status(code=0, message='Build index successfully!')

In [23]:
# Read back the index configuration currently attached to the collection.
# NOTE: run this after the create_index cell; a result captured before the build
# describes the pre-build state.
status, index = client.get_index_info(collection_name)
if not status.OK():
    raise RuntimeError(f"Could not fetch index info for collection {collection_name}: {status}")
print(index)

(collection_name='LSC_full', index_type=<IndexType: FLAT>, params={})


In [14]:
%%time
# Smoke-test / timing query: one 768-d query vector, top 4000 hits.
search_param = {
    "nprobe": 1024
}

# Seeded generator so the query vector (and therefore the result) is reproducible.
query_rng = np.random.default_rng(0)

param = {
    'collection_name': collection_name,
    'query_records': query_rng.random((1, 768), dtype=np.float32),
    'top_k': 4000,
    'params': search_param,
}

status, results = client.search(**param)
if status.OK():
    print(results.shape)
else:
    # Previously a failed search produced no output at all.
    print(f"Search failed: {status}")

(1, 4000)
CPU times: user 3.9 ms, sys: 4.78 ms, total: 8.68 ms
Wall time: 69.9 ms
