## 使用 Milvus 向量库建立索引

官网：https://milvus.io/docs/milvus_lite.md

样例：https://github.com/milvus-io/milvus-lite/blob/main/examples/example.py

In [1]:
import random
from milvus import default_server
from pymilvus import (
    connections,
    FieldSchema, CollectionSchema, DataType,
    Collection,
    utility
)

In [2]:
# Optional, if you want store all related data to specific location
# default it wil using %APPDATA%/milvus-io/milvus-server
default_server.set_base_dir('db')

# Optional, if you want cleanup previous data
default_server.cleanup()

# star you milvus server
default_server.start()

_HOST = '127.0.0.1'
# The port may be changed, by default it's 19530
_PORT = default_server.listen_port

# Const names
_COLLECTION_NAME = 'test'

# Vector parameters
_DIM = 1536

# Index parameters
_METRIC_TYPE = 'L2'
_INDEX_TYPE = 'IVF_FLAT'
_NLIST = 1024
_NPROBE = 16
_TOPK = 8

In [3]:
# create a connection
print(f"\nCreate connection...")
connections.connect(host=_HOST, port=_PORT)
print(f"\nList connections:")
print(connections.list_connections())

# drop collection if the collection exists
if utility.has_collection(_COLLECTION_NAME):
    collection = Collection(_COLLECTION_NAME)
    collection.drop()
    print("\nDrop collection: {}".format(_COLLECTION_NAME))

# create collection
id_field = FieldSchema(name='id', dtype=DataType.INT64, description="id_int64", is_primary=True)
embedding_field = FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description="embedding_floatvector", dim=_DIM, is_primary=False)
month_field = FieldSchema(name='month', dtype=DataType.INT64, description="month_int64", is_primary=False)
company_field = FieldSchema(name='company', dtype=DataType.INT64, description="company_int64", is_primary=False)
source_field = FieldSchema(name='source', dtype=DataType.INT64, description="source_int64", is_primary=False)
schema = CollectionSchema(fields=[id_field, embedding_field, month_field, company_field, source_field], description="storing all data")
collection = Collection(name=_COLLECTION_NAME, data=None, schema=schema, properties={"collection.ttl.seconds": 15})

collection.set_properties(properties={"collection.ttl.seconds": 2**31-1})

# show collections
print(utility.list_collections())


Create connection...

List connections:
[('default', <pymilvus.client.grpc_handler.GrpcHandler object at 0x000002DC14D23C40>)]
['test']


In [41]:
sources = []
company_ids = []
months = []
embeddings = []
for source in [1, 2, 3, 4]:
    for company_id in [1, 2]:
        for month in [202310, 202311, 202312]:
            for i in range(100):
                sources.append(source)
                company_ids.append(company_id)
                months.append(month)
                embeddings.append([(source*1e9+company_id*1e8+month*10+random.random()) for _ in range(_DIM)])

data = [
    [i for i in range(len(embeddings))],
    embeddings,
    months,
    company_ids,
    sources
]

collection.insert(data)
collection.flush()

# get the number of entities
print(collection.num_entities)

4800


In [42]:
# create index
index_param = {
    "index_type": _INDEX_TYPE,
    "params": {"nlist": _NLIST},
    "metric_type": _METRIC_TYPE}
collection.create_index('embedding', index_param)
print("\nCreated index:\n{}".format(collection.index().params))


Created index:
{'index_type': 'IVF_FLAT', 'params': {'nlist': 1024}, 'metric_type': 'L2'}


In [44]:
# load data to memory
collection.load()

In [53]:
source = 3
company_id = 2
month = 202311

# collection.release()
# print(utility.list_collections())
# collection = Collection(_COLLECTION_NAME)
# collection.load()
print(collection.num_entities)

search_param = {
    "data": [[(source*1e9+company_id*1e8+month*10+random.random()) for _ in range(_DIM)]],
    "anns_field": 'embedding',
    "param": {"metric_type": _METRIC_TYPE, "params": {"nprobe": 1024}},
    "limit": _TOPK,
    "expr": '(company == 2 or company==2) and source == 3',
    "output_fields": ["id", "month", "company", "source"]
}
results = collection.search(**search_param)

for i, result in enumerate(results):
    print("\nSearch result for {}th vector: ".format(i))
    for j, res in enumerate(result):
        print("Top {}: {}".format(j, res))

4800

Search result for 0th vector: 
Top 0: id: 1504, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1504, 'month': 202310}
Top 1: id: 1506, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1506, 'month': 202310}
Top 2: id: 1507, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1507, 'month': 202310}
Top 3: id: 1503, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1503, 'month': 202310}
Top 4: id: 1501, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1501, 'month': 202310}
Top 5: id: 1505, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1505, 'month': 202310}
Top 6: id: 1502, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1502, 'month': 202310}
Top 7: id: 1500, distance: 0.0, entity: {'company': 2, 'source': 3, 'id': 1500, 'month': 202310}


In [None]:
# release memory
collection.release()

# # drop collection index
# collection.drop_index()

# # drop collection
# collection.drop()