### Module

In [2]:
import pandas as pd
import pickle

from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction

from sklearn.metrics.pairwise import cosine_similarity

import pandas as pd
from tqdm import tqdm
import pickle
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
import numpy as np
import json
from FlagEmbedding import FlagReranker
from pymilvus.model.reranker import BGERerankFunction
import random

### Functions

In [3]:
# Load the BM25 model
def get_bm25_model(model_path: str = "./files/bm25_msmarco_v1.json",
                   analyzer_language: str = "en"):
    """Build a BM25 embedding function and restore its saved state.

    Args:
        model_path: File passed to ``BM25EmbeddingFunction.load``.
        analyzer_language: Language code handed to ``build_default_analyzer``.

    Returns:
        The ``BM25EmbeddingFunction`` instance after ``load(model_path)`` has
        been called on it.
    """
    embedding_fn = BM25EmbeddingFunction(
        build_default_analyzer(language=analyzer_language)
    )
    embedding_fn.load(model_path)
    return embedding_fn

# Load the BGE reranker model
def get_reranker_model(model_path: str = "BAAI/bge-reranker-v2-m3",
                       device: str = "cuda:0",
                       batch_size: int = 32):
    """Create a BGE rerank function.

    Args:
        model_path: Value forwarded as ``model_name`` (a hub id or a local
            checkpoint path, depending on what the caller passes).
        device: Device string forwarded to the reranker, e.g. ``"cuda:0"``.
        batch_size: Batch size forwarded to the reranker. Defaults to 32,
            the value that was previously hard-coded, so existing callers
            behave exactly as before.

    Returns:
        The constructed ``BGERerankFunction`` instance.
    """
    return BGERerankFunction(
        model_name=model_path,
        device=device,
        batch_size=batch_size,
    )

# Save the search result in TSV format
def save_to_tsv(result, output_path):
    """Write ``result`` to ``output_path`` as a tab-separated file.

    Args:
        result: Anything ``pd.DataFrame`` accepts (e.g. a list of rows).
        output_path: Destination handed to ``DataFrame.to_csv``. When it is a
            filesystem path (``str`` or ``os.PathLike``), a missing parent
            directory is created first so the write does not fail on a fresh
            checkout. File-like objects are passed through untouched.

    The file is written without the index column; the header row is whatever
    column labels ``pd.DataFrame(result)`` produces. Prints ``Done!`` once the
    write has finished.
    """
    import os  # local import: the notebook's import cell does not provide it

    if isinstance(output_path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(output_path))
        if parent_dir:  # empty string means "current directory" - nothing to create
            os.makedirs(parent_dir, exist_ok=True)

    result_df = pd.DataFrame(result)
    result_df.to_csv(output_path, sep='\t', index=False)
    print("Done!")
    

### Data Load
* dataset -> MSMARCO Passage Ranking `Dev` dataset
* testset -> 20% of the MSMARCO Passage Ranking `Dev` set

In [4]:
# MSMARCO dev file, read as tab-separated (qid, pid, query, passage) columns
msmarco_dev = pd.read_csv(
    "./data/top1000_dev.tsv",
    sep='\t',
    names=['qid', 'pid', 'query', 'passage'],
)

# MSMARCO query table
# NOTE(review): absolute, machine-specific path - consider a path relative to the notebook.
unique_query = pd.read_csv('/home/livin/rimo/llm/msmarco/notebook/unique_query.csv')

# Distinct passages (duplicates removed, first-seen order kept)
msmarco_dev_passages = msmarco_dev["passage"].drop_duplicates().tolist()

# Distinct pids (duplicates removed, first-seen order kept)
msmarco_dev_pids = msmarco_dev["pid"].drop_duplicates().tolist()

# Ground truth for the test set, read as tab-separated (qid, r, pid, l) columns
test_qrels = pd.read_csv(
    "./data/test_qrels.tsv",
    sep='\t',
    names=['qid', 'r', 'pid', 'l'],
)

# qids of the test set, one entry per qrels row
test_qid = test_qrels["qid"].tolist()

print("passages : ", len(msmarco_dev_passages))
print("test_qid : ", len(test_qid))


passages :  3895239
test_qid :  1324


### Model Load

In [5]:
# BM25 embedding function, built with the loader's default model file and analyzer language
bm25_ef = get_bm25_model()

# BGE reranker, loaded from a local checkpoint directory onto the first CUDA device
bge_rf = get_reranker_model(
    device="cuda:0",
    model_path="./models/kw_3_easy_train",
)

### test set load

In [6]:
# Resolve every test qid to its query text.
# A single qid -> query lookup table replaces the per-qid boolean scan over the
# whole `unique_query` frame. Keeping the first row per qid matches the previous
# "first match" behaviour; a qid missing from `unique_query` raises KeyError.
_first_query_rows = unique_query.drop_duplicates(subset="qid", keep="first")
_query_by_qid = dict(zip(_first_query_rows["qid"].tolist(),
                         _first_query_rows["query"].tolist()))
test_query_list = [_query_by_qid[qid_i] for qid_i in test_qid]

# Intended number of candidates to keep per query
top_n = 100

NameError: name 'docs_embeddings' is not defined

### Retrieve
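* 1st-stage retriever -> Milvus sparse-vector search with BM25 query embeddings
* 2nd-stage reranker -> BGE reranker (`BAAI/bge-reranker-v2-m3` is the loader default; this notebook loads the local checkpoint `./models/kw_3_easy_train`)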

In [10]:
# Connect with the client's default settings and load the collection that the
# retrieval cell searches ("msmarco_bm25").
# NOTE(review): assumes a Milvus instance is reachable at the client's default address - confirm.
client = MilvusClient()
client.load_collection("msmarco_bm25")

In [78]:
# Long-running cell (originally annotated as taking about 24 minutes).
# Two-stage retrieval for every test query:
#   1) sparse search in the "msmarco_bm25" collection using the BM25 query embedding,
#   2) rerank the returned passages with the BGE reranker.
# Each output row is [qid, pid, rank], with rank starting at 1.
result = []
failed_qids = []  # qids whose retrieval or reranking raised; reported below

for qid, query in tqdm(zip(test_qid, test_query_list), total=len(test_query_list)):
    try:
        query_embeddings = bm25_ef.encode_queries([query])

        res = client.search(
            collection_name="msmarco_bm25",  # target collection
            data=query_embeddings,
            limit=top_n,  # number of returned entities
            output_fields=["pid", "text"],  # fields returned with each hit
            anns_field="sparse_vector",
        )

        candidate_pids = [hit["entity"]["pid"] for hit in res[0]]
        candidate_passages = [hit["entity"]["text"] for hit in res[0]]

        reranked = bge_rf(
            query=query,
            documents=candidate_passages,
            top_k=top_n,
        )

        # Build the rows for this query before touching `result`, so a failure
        # part-way through never leaves a partial ranking behind.
        ranked_rows = [
            [qid, candidate_pids[hit.index], rank]
            for rank, hit in enumerate(reranked, start=1)
        ]
    except Exception as err:
        # Best effort: keep going with the remaining queries, but say which one
        # failed and why. `Exception` (unlike a bare `except`) still lets
        # KeyboardInterrupt stop this long-running loop.
        failed_qids.append(qid)
        print(qid, repr(err))
    else:
        result.extend(ranked_rows)

  0%|          | 0/1324 [00:00<?, ?it/s]

 94%|█████████▎| 1239/1324 [05:25<00:23,  3.62it/s]

983451


100%|██████████| 1324/1324 [05:48<00:00,  3.80it/s]


In [79]:
# Persist the reranked rows collected in `result` as a TSV file.
save_to_tsv(result, "./result/bm25_reranker_collection.tsv")

Done!
