### Module

In [1]:
import pandas as pd
import pickle

from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction

from sklearn.metrics.pairwise import cosine_similarity

import pandas as pd
from tqdm import tqdm
import pickle
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
import numpy as np
import json
from FlagEmbedding import FlagReranker
from pymilvus.model.reranker import BGERerankFunction
import random

### Functions

In [None]:
# BM25 모델 Load
def get_bm25_model(model_path:str = "./files/bm25_msmarco_v1.json",
                   analyzer_language:str = "en"):
    analyzer = build_default_analyzer(language=analyzer_language)
    bm25_ef = BM25EmbeddingFunction(analyzer)
    bm25_ef.load(model_path)
    return bm25_ef

# BGE Reranker 모델 Load
def get_reranker_model(model_path:str = "BAAI/bge-reranker-v2-m3",
                       device:str = "cuda:0"):
    bge_rf = BGERerankFunction(
        model_name=model_path,
        device=device,
        batch_size=32,
    )
    return bge_rf

# 검색 result를 tsv 형태로 저장
def save_to_tsv(result, output_path):
    result_df = pd.DataFrame(result)
    result_df.to_csv(output_path, sep='\t', index=False)
    print("Done!")
    

### Data Load
* dataset -> MSMARCO Passage Ranking <Dev> Dataset
* testset -> MSMARCO Passage Ranking <Dev>'s 20%

In [3]:
# MSMARCO Dev dataset
msmarco_dev = pd.read_csv("./data/top1000_dev.tsv", sep='\t', names=['qid', 'pid', 'query', 'passage'])

# MSMARCO Query
unique_query = pd.read_csv('/home/livin/rimo/llm/msmarco/notebook/unique_query.csv')

# Passage 추출 (중복된 Passage는 제거)
msmarco_dev_passages = msmarco_dev["passage"].unique().tolist()

# Pid 추출 (중복된 Pid는 제거)
msmarco_dev_pids = msmarco_dev["pid"].unique().tolist()

# Test Set에 대한 Ground Truth
test_qrels = pd.read_csv("./data/test_qrels.tsv", sep='\t', names=['qid', 'r', 'pid', 'l'])

# Test Set에 대한 qid 추출
test_qid = test_qrels["qid"].tolist()

print("passages : ", len(msmarco_dev_passages))
print("test_qid : ", len(set(test_qid)))


passages :  3895239
test_qid :  1324


### Model Load

In [4]:
# BM25
bm25_ef = get_bm25_model()

# BGE Reranker
bge_rf = get_reranker_model(model_path="./models/kw_3_easy_train", device="cuda:0")

### BM25_docs_embedding Load

In [5]:
# BM25 모델을 활용한 passage들의 imbedding 불러오기
with open("./files/bm25_docs_embeddings.pickle", "rb") as handle:
    docs_embeddings = pickle.load(handle)

### test set load

In [43]:
result = []
pid_array = np.array(msmarco_dev_pids)
passages_array = np.array(msmarco_dev_passages)
test_query_list = [unique_query[unique_query["qid"] == qid_i]["query"].tolist()[0] for qid_i in test_qid]
query_embeddings = bm25_ef.encode_queries(test_query_list)
cosine_similarities = cosine_similarity(docs_embeddings, query_embeddings)#.flatten()
top_n = 100

### Retrieve
* 1차 retriever -> Milvus/BM25
* 2차 retriever -> BAAI/bge-reranker-v2-m3

In [57]:
# 24분 소요
for i in tqdm(range(len(test_query_list))):
    qid = test_qid[i]
    query = test_query_list[i]
    
    candidate_idxs = np.argsort(cosine_similarities[:,i])[-100:][::-1]
    candidate_pids = pid_array[candidate_idxs]
    candidate_passages = passages_array[candidate_idxs]

    top_k = bge_rf(
                query=query,
                documents=candidate_passages,
                top_k=100,
            )
    for n,i in enumerate(top_k):
            result.append([qid, candidate_pids[i.index], n+1])
    break

  0%|          | 0/1324 [00:00<?, ?it/s]


In [47]:
save_to_tsv(result, "./result/bm25_reranker_.tsv")

Done!


In [33]:
result = []
pid_array = np.array(msmarco_dev_pids)
passages_array = np.array(msmarco_dev_passages)
top_n = 100

for i in tqdm(range(len(test_qid))):
    try:
        # qid, query 추출
        qid = test_qid[i]
        query = unique_query[unique_query["qid"] == qid]["query"].tolist()[0]

        # BM25를 활용한 1차 retrieve
        query_embeddings = bm25_ef.encode_queries([query]) # query 임베딩
        cosine_similarities = cosine_similarity(docs_embeddings, query_embeddings).flatten() # query와 passage간의 유사도 계산
        candidate_idxs = np.argsort(cosine_similarities)[-100:][::-1] # top 100 indexes 추출

        # BM25 검색 결과에 대한 pids와 passages 추출
        candidate_pids = pid_array[candidate_idxs]
        candidate_passages = passages_array[candidate_idxs]

        # reranker를 활용한 2차 retrieve
        top_k = bge_rf(
                query=query,
                documents=candidate_passages,
                top_k=100,
            )
        
        # [qid, pid, rank] format으로 저장
        for n,i in enumerate(top_k):
            result.append([qid, candidate_pids[i.index], n+1])
        break
    except:
        print(qid)


save_to_tsv(result, "./result/bm25_reranker.tsv")

  0%|          | 0/1324 [00:06<?, ?it/s]

Done!





In [35]:
qid, query

(118448, 'define body muscular endurance')