### Imports

In [66]:
import pandas as pd
import pickle
from tqdm import tqdm
import pickle
import numpy as np
import json
import random

from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction
from sklearn.metrics.pairwise import cosine_similarity
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
from pymilvus.model.reranker import BGERerankFunction
from pymilvus import MilvusClient

### Data loading

In [47]:
# Load the MS MARCO dev top-1000 file (columns: qid, pid, query, passage).
msmarco_dev = pd.read_csv("./data/top1000_dev.tsv", sep='\t', names=['qid', 'pid', 'query', 'passage'])

# Unique passages, in order of first appearance.
msmarco_dev_passages = msmarco_dev["passage"].unique().tolist()

# Unique pids, in order of first appearance.
msmarco_dev_pids = msmarco_dev["pid"].unique().tolist()

# The two lists above are de-duplicated independently but are later zipped
# together as (pid, passage) pairs. That pairing is only correct when pid and
# passage are one-to-one. In that case the passage of each pid (first-appearance
# order) equals the list of unique passages. Check it so a mismatch fails here
# instead of silently storing passages under the wrong pid.
_passage_per_pid = msmarco_dev.drop_duplicates(subset="pid")["passage"].tolist()
assert _passage_per_pid == msmarco_dev_passages, (
    "pid and passage columns are not one-to-one; "
    "msmarco_dev_pids and msmarco_dev_passages would be misaligned"
)
del _passage_per_pid

# Load the precomputed BM25 embeddings of the passages.
# NOTE(review): presumably row-aligned with msmarco_dev_passages — TODO confirm.
# pickle.load can execute arbitrary code: only load files you produced yourself.
with open("./files/bm25_docs_embeddings.pickle", "rb") as handle:
    docs_embeddings = pickle.load(handle)

### Create collection

In [8]:
# Connect to the Milvus server. The URI spelled out here is MilvusClient's default.
client = MilvusClient(uri="http://localhost:19530")

In [4]:
# Field definitions for the MS MARCO BM25 collection.

# Primary key: the passage id.
pid_field = FieldSchema(
    name="pid",
    description="Passage ID",
    dtype=DataType.INT64,
    is_primary=True,
)

# Raw passage text. VARCHAR fields need an explicit max_length.
# NOTE(review): inserts are rejected for passages longer than max_length —
# confirm 2048 covers the corpus.
text_field = FieldSchema(
    name="text",
    description="Passage Text",
    dtype=DataType.VARCHAR,
    max_length=2048,
)

# BM25 sparse embedding of the passage.
sparse_vector_field = FieldSchema(
    name="sparse_vector",
    description="Sparse Embedding Vector",
    dtype=DataType.SPARSE_FLOAT_VECTOR,
)

In [6]:
# Collection schema: pid (primary key), passage text, and BM25 sparse vector.
schema = CollectionSchema(
    description="MSMARCO Dataset BM25 Embedding Collection",
    fields=[pid_field, text_field, sparse_vector_field],
)

In [9]:
# Create the collection that will hold the passages and their BM25 vectors.
collection_name = "msmarco_bm25"
client.create_collection(
    schema=schema,
    collection_name=collection_name,
)

### BM25 embedding

In [13]:
def get_bm25_model(model_path: str = "./files/bm25_msmarco_v1.json",
                   analyzer_language: str = "en"):
    """Build a BM25 embedding function and load saved parameters into it.

    Args:
        model_path: Path of the saved BM25 parameters to load.
        analyzer_language: Language passed to ``build_default_analyzer``.

    Returns:
        A ``BM25EmbeddingFunction`` with the parameters from ``model_path`` loaded.
    """
    embedding_fn = BM25EmbeddingFunction(build_default_analyzer(language=analyzer_language))
    embedding_fn.load(model_path)
    return embedding_fn


# Load the BM25 model from the default path.
bm25_ef = get_bm25_model()

In [62]:
# NOTE(review): scratch check — number of unique pids in units of 100,000.
# TODO: remove, or replace the magic number with a named constant.
len(msmarco_dev_pids) / 100_000

38.95239

In [None]:
# Pair each pid with its passage text and BM25 sparse vector, then insert the
# rows into the collection in fixed-size batches. Each row is sent exactly once.
INSERT_BATCH_SIZE = 1000  # rows per insert request, keeps each request small

# zip() silently stops at the shortest input, so check the three sources line up.
# NOTE(review): docs_embeddings is assumed to have one row per passage (e.g. a
# scipy sparse matrix or a list) — TODO confirm.
n_embeddings = docs_embeddings.shape[0] if hasattr(docs_embeddings, "shape") else len(docs_embeddings)
assert len(msmarco_dev_pids) == len(msmarco_dev_passages) == n_embeddings, (
    f"length mismatch: {len(msmarco_dev_pids)} pids, "
    f"{len(msmarco_dev_passages)} passages, {n_embeddings} embeddings"
)

entities = [
    {"pid": pid, "text": passage, "sparse_vector": sparse_vector}
    for pid, passage, sparse_vector in zip(msmarco_dev_pids, msmarco_dev_passages, docs_embeddings)
]

# MilvusClient.insert takes the rows via `data=`.
for start in tqdm(range(0, len(entities), INSERT_BATCH_SIZE), desc="insert"):
    client.insert(
        collection_name=collection_name,
        data=entities[start : start + INSERT_BATCH_SIZE],
    )

### Insert

In [None]:
# NOTE(review): disabled alternative batched insert. Do not re-enable as-is.
# The `<=` loop already sends the final partial batch, plus an empty batch when
# len(entities) is a multiple of batch_size. The trailing insert below therefore
# sends that last batch a second time, inserting those rows twice.
# start_index = 0
# batch_size = 1000

# # Batched insertion because of the per-request insert size limit
# while start_index <= len(entities):
#     entities_batch = entities[start_index : start_index+batch_size]
#     client.insert(
#         collection_name="msmarco_bm25",
#         data=entities_batch
#     )
#     start_index += batch_size

# # Insert the values not yet inserted by the last batch
# client.insert(
#         collection_name="msmarco_bm25",
#         data=entities[start_index-batch_size:]
#     )

### Check

In [87]:
# Connect through the ORM API (separately from `client`) to inspect the collection.
connections.connect("default", host="localhost", port="19530")
collection_name = "msmarco_bm25"
collection = Collection(collection_name)

# num_entities may not include rows that have not been flushed yet.
# Flush first so the count reflects every inserted row.
collection.flush()

# Number of entities in the collection.
collection.num_entities

3895239

### Indexing

In [89]:
######### Sparse Vector ##########

# Build an inverted index on the BM25 sparse vectors.
index_params = client.prepare_index_params()

index_params.add_index(
    field_name="sparse_vector",
    # Inner-product similarity (the metric Milvus uses for sparse vectors), not COSINE.
    metric_type="IP",
    index_type="SPARSE_INVERTED_INDEX",
    index_name="sparse_index",
)

client.create_index(
    collection_name=collection_name,
    index_params=index_params,
)