### Module

In [1]:
import pandas as pd
import pickle

from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction

from sklearn.metrics.pairwise import cosine_similarity

import pandas as pd
from tqdm import tqdm
import pickle
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
import numpy as np
import json
from FlagEmbedding import FlagReranker
from pymilvus.model.reranker import BGERerankFunction
import random

### Functions
* **get_bm25_model**      -> BM25 모델 Load 함수
* **get_reranker_model**  -> BGEM3 reranker 모델 Load 함수
* **save_to_tsv**         -> 추론 파일 저장 함수
* **get_bm25_result**     -> BM25 모델 추론 결과 출력 함수
* **get_reranker_result** -> Reranker 모델 추론 결과 출력 함수

In [2]:
# Load the BM25 model
def get_bm25_model(model_path:str = "./files/bm25_msmarco_v1.json",
                   analyzer_language:str = "en"):
    """Build a BM25 embedding function and load its saved state from disk.

    Args:
        model_path: File handed to ``BM25EmbeddingFunction.load``.
        analyzer_language: Language code passed to ``build_default_analyzer``.

    Returns:
        The ``BM25EmbeddingFunction`` instance after ``load(model_path)`` ran.
    """
    embedding_fn = BM25EmbeddingFunction(
        build_default_analyzer(language=analyzer_language)
    )
    embedding_fn.load(model_path)
    return embedding_fn

# Load the BGE reranker model
def get_reranker_model(model_path:str = "./models/kw_3_easy_train",
                       device:str = "cuda:0"):
    """Create the BGE rerank function.

    Args:
        model_path: Value forwarded as ``model_name`` to ``BGERerankFunction``.
        device: Device string forwarded to ``BGERerankFunction``.

    Returns:
        A ``BGERerankFunction`` configured with a batch size of 32.
    """
    reranker_settings = {
        "model_name": model_path,
        "device": device,
        "batch_size": 32,
    }
    return BGERerankFunction(**reranker_settings)

# Save the retrieval result as a TSV file
def save_to_tsv(result, output_path):
    """Write retrieval results to a tab-separated file.

    Args:
        result: Anything ``pd.DataFrame`` accepts; this notebook passes a list
            of ``[qid, pid, rank]`` rows.
        output_path: Destination file. Missing parent directories are created
            so a long retrieval run is not lost to a missing folder.

    NOTE(review): ``to_csv`` writes a header row. For a list-of-lists input the
    column labels are the integers 0..n-1, so the file starts with ``0\\t1\\t2``;
    confirm the downstream evaluator tolerates that line.
    """
    from pathlib import Path  # local import keeps this definition self-contained

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    result_df = pd.DataFrame(result)
    result_df.to_csv(output_path, sep='\t', index=False)
    print("Successfully Saved!")
    

def get_bm25_result(client, query, error_handle, top_n = 100):
    """Fetch first-stage BM25 candidates for one query from Milvus.

    Uses the notebook-level ``bm25_ef`` embedding function to encode the query
    and searches the ``msmarco_bm25`` collection on its ``sparse_vector`` field.

    Args:
        client: Object exposing a Milvus-style ``search`` method.
        query: Query text.
        error_handle: When truthy, the literal text ``"this query is dummy"``
            is appended to the query before encoding. The retrieval loop uses
            this as a retry after the plain query raised.
        top_n: Maximum number of hits requested from Milvus.

    Returns:
        ``(candidate_pids, candidate_passages)``: two parallel lists holding
        the ``pid`` and ``text`` of each hit, in the order Milvus returned them.
    """
    # NOTE(review): the suffix is concatenated without a separating space, so
    # it fuses with the last word of the query. Kept as-is to preserve the
    # existing fallback results; confirm whether a space was intended.
    search_text = (query + "this query is dummy") if error_handle else query
    query_embeddings = bm25_ef.encode_queries([search_text])

    res = client.search(
        collection_name="msmarco_bm25",   # target collection
        data=query_embeddings,
        limit=top_n,                      # number of returned entities
        output_fields=["pid", "text"],    # fields returned with every hit
        anns_field="sparse_vector",
    )

    # One query was sent, so only the first result list is relevant.
    hits = res[0]
    candidate_pids = [hit["entity"]["pid"] for hit in hits]
    candidate_passages = [hit["entity"]["text"] for hit in hits]

    return candidate_pids, candidate_passages

def get_reranker_result(query, candidate_passages, top_n = 100):
    """Rerank candidate passages for a query with the notebook-level ``bge_rf``.

    Args:
        query: Query text.
        candidate_passages: Passage texts to score against the query.
        top_n: Forwarded to the reranker as ``top_k``.

    Returns:
        Whatever ``bge_rf`` returns for the given query and documents.
    """
    return bge_rf(query=query, documents=candidate_passages, top_k=top_n)
    

### Data Load
* dataset -> MSMARCO Passage Ranking (Dev) dataset
* testset -> 20% of the MSMARCO Passage Ranking (Dev) dataset

In [3]:
# MSMARCO queries
unique_query = pd.read_csv('./data/unique_query.csv')

# Ground truth (qrels) for the test set; column names are supplied here
# because the file is read without a header row.
qrels_columns = ['qid', 'r', 'pid', 'l']
test_qrels = pd.read_csv("./data/test_qrels.tsv", sep='\t', names=qrels_columns)

# qids of the test set, one entry per qrels row.
# NOTE(review): this list is not de-duplicated; only the count printed below is.
test_qid = test_qrels["qid"].tolist()

unique_qid_count = len(set(test_qid))
print("test_qid : ", unique_qid_count)

test_qid :  1266


### Model Load

In [4]:
# First-stage retriever: BM25 embedding function (default path and language)
bm25_ef = get_bm25_model()

# Second-stage reranker: the fine-tuned BGE reranker checkpoint
bge_rf = get_reranker_model(
    model_path="./models/kw_3_easy_train",
    device="cuda:0",
)

### prepare retriever

In [5]:
# Resolve every test qid to its query text.
# A single qid -> query table is built up front so the DataFrame is not
# filtered once per qid. If a qid occurs more than once in `unique_query`,
# its first row is used.
qid_to_query = (
    unique_query
    .drop_duplicates(subset="qid", keep="first")
    .set_index("qid")["query"]
    .to_dict()
)

# Raises KeyError if a test qid has no entry in `unique_query`.
test_query_list = [qid_to_query[qid_i] for qid_i in test_qid]

# Number of candidates kept per query by both retrieval stages.
top_n = 100

### Retrieve
* 1차 retriever -> Milvus/BM25
* 2차 retriever -> BAAI/bge-reranker-v2-m3 (fine-tuned checkpoint loaded from `./models/kw_3_easy_train`)

In [6]:
# Connect a client to the vector DB and load the BM25 collection that
# get_bm25_result searches.
client = MilvusClient()
client.load_collection(collection_name="msmarco_bm25")

In [7]:
# Takes about 6 minutes.
def _rank_query(qid, query, error_handle):
    """Run BM25 retrieval plus reranking for one query.

    Returns a list of ``[qid, pid, rank]`` rows (rank starts at 1). The rows
    are built completely before being returned, so a failure part-way through
    never leaves partial output behind.
    """
    candidate_pids, candidate_passages = get_bm25_result(client, query, error_handle, top_n)
    reranked = get_reranker_result(query, candidate_passages, top_n)
    # Each reranker hit carries `.index`, its position in `candidate_passages`.
    return [
        [qid, candidate_pids[hit.index], rank]
        for rank, hit in enumerate(reranked, start=1)
    ]


result = []

for qid, query in tqdm(zip(test_qid, test_query_list), total=len(test_query_list)):
    try:
        rows = _rank_query(qid, query, False)
    except Exception as exc:
        # The plain query failed: retry once with the dummy-suffixed query.
        # `Exception` (not a bare except) lets Ctrl-C / kernel interrupts
        # through, and the cause is reported instead of silently swallowed.
        tqdm.write(f"[qid {qid}] retrieval failed ({exc!r}); retrying with dummy-suffixed query")
        rows = _rank_query(qid, query, True)
    result.extend(rows)

output_path = "./result/bm25_reranker.tsv"

save_to_tsv(result, output_path)

100%|██████████| 1324/1324 [05:47<00:00,  3.81it/s]

Done!





### MRR@100

In [8]:
!python ms_marco_eval.py \
./data/test_qrels.tsv \
./result/bm25_reranker_collection.tsv

################################
# MRR @100: 0.3914833242741064 #
################################
