### Imports

In [1]:
import pandas as pd
from tqdm import tqdm
import pickle
import numpy as np
import json
import random
random.seed(42)

from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
from pymilvus.model.reranker import BGERerankFunction

from FlagEmbedding import FlagReranker

### Functions

In [2]:
# Load the BGE-M3 embedding model (dense output only).
def get_embedding_model(model_name="BAAI/bge-m3", batch_size=64, device="cuda:0"):
    """Build a pymilvus BGE-M3 embedding function that emits dense vectors only.

    Args:
        model_name: Model id or local path handed to ``BGEM3EmbeddingFunction``.
        batch_size: Encoding batch size.
        device: Torch device string, e.g. "cuda:0".

    Returns:
        The configured ``BGEM3EmbeddingFunction`` instance (sparse and ColBERT
        outputs disabled).
    """
    embedding_kwargs = dict(
        model_name=model_name,
        batch_size=batch_size,
        device=device,
        return_dense=True,
        return_sparse=False,
        return_colbert_vecs=False,
    )
    return model.hybrid.BGEM3EmbeddingFunction(**embedding_kwargs)

# Load the BGE reranker model.
def get_reranker_model(model_path="./models/kw_3_easy_train",
                       device="cuda:0"):
    """Build a pymilvus ``BGERerankFunction``.

    Args:
        model_path: Model id or directory passed through as ``model_name``. The
            default points at a local checkpoint under ./models (presumably a
            fine-tuned reranker; confirm).
        device: Torch device string, e.g. "cuda:0".

    Returns:
        The configured ``BGERerankFunction``, which scores documents in batches of 32.
    """
    reranker_config = {"model_name": model_path, "device": device, "batch_size": 32}
    return BGERerankFunction(**reranker_config)

# Save retrieval results as a header-less TSV (one row per entry in `result`).
def save_to_tsv(result, output_path):
    """Write ranked retrieval results to a tab-separated file.

    Args:
        result: Sequence of rows, e.g. ``[qid, pid, rank]`` lists as built by the
            retrieval loop.
        output_path: Destination file path. Missing parent directories are created.

    The file is written without a header row or an index column. The run is scored
    by ms_marco_eval.py next to header-less qrels. With pandas' default settings the
    unnamed columns would emit a ``0\t1\t2`` header line, which the eval script
    would presumably parse as a bogus (qid, pid, rank) row; confirm against the script.
    """
    import os

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        # Avoid FileNotFoundError when e.g. ./result/ has not been created yet.
        os.makedirs(parent_dir, exist_ok=True)
    pd.DataFrame(result).to_csv(output_path, sep='\t', index=False, header=False)
    print("Successfully Saved!")

def get_vector_search_result(client, query, top_n = 100, embedding_fn=None,
                             collection_name="msmarco_bgem3"):
    """Run a dense ANN search in Milvus for a single query.

    Args:
        client: A connected ``MilvusClient`` (or any object exposing the same
            ``search`` keyword interface).
        query: Query text.
        top_n: Number of passages to request from Milvus (``limit``).
        embedding_fn: Object with ``encode_queries(list_of_str)`` returning a mapping
            that has a ``"dense"`` entry. Defaults to the notebook-global ``bge_m3_ef``
            so existing calls keep working.
        collection_name: Milvus collection to search.

    Returns:
        ``(candidate_pids, candidate_passages)``: a NumPy array of passage ids and a
        list of passage texts, both in the order Milvus returned the hits.
    """
    if embedding_fn is None:
        # Fall back to the model loaded in the "Model Load" section.
        embedding_fn = bge_m3_ef

    query_vectors = embedding_fn.encode_queries([query])["dense"]
    hits = client.search(
        collection_name=collection_name,
        data=query_vectors,
        limit=top_n,
        output_fields=["pid", "text"],
        anns_field="dense_vector",
    )[0]  # a single query was sent, so only the first result list is relevant

    candidate_passages = [hit["entity"]["text"] for hit in hits]
    candidate_pids = np.array([hit["entity"]["pid"] for hit in hits])
    return candidate_pids, candidate_passages

def get_reranker_result(query, candidate_passages, top_n = 100, reranker=None):
    """Rerank candidate passages for a query.

    Args:
        query: Query text.
        candidate_passages: Passage texts to score, in the order returned by search.
        top_n: Passed to the reranker as ``top_k``.
        reranker: Callable accepting ``query=``, ``documents=`` and ``top_k=``.
            Defaults to the notebook-global ``bge_rf`` so existing calls keep working.

    Returns:
        Whatever the reranker returns. In the retrieval loop each element's
        ``.index`` is used to map back into ``candidate_passages``.
    """
    if reranker is None:
        # Fall back to the model loaded in the "Model Load" section.
        reranker = bge_rf
    return reranker(
        query=query,
        documents=candidate_passages,
        top_k=top_n,
    )

### Data Load

In [3]:
unique_query = pd.read_csv('./data/unique_query.csv')
# Header-less qrels file. Columns are presumably qid, iteration, pid, relevance
# label (MS MARCO/TREC layout; confirm). Only 'qid' is used below.
test_qrels = pd.read_csv("./data/test_qrels.tsv", sep='\t', names=['qid', 'r', 'pid', 'l'])
# A query with several relevant passages has several qrels rows. Deduplicate (first-seen
# order kept) so each query is searched, reranked and written to the run file once.
test_qid = test_qrels["qid"].drop_duplicates().tolist()

### Model Load

In [4]:
# Load both models once; the retrieval helpers default to these globals.
bge_m3_ef = get_embedding_model()  # BGE-M3 dense query/passage encoder
bge_rf = get_reranker_model()      # BGE reranker (local checkpoint)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

### Retrieve (dense search) and rerank

In [5]:
# Connect a client to the vector DB (Milvus, default connection settings) and load
# the collection so it can be searched.
client = MilvusClient()
client.load_collection(collection_name="msmarco_bgem3")

In [6]:
# Dense retrieval from Milvus followed by reranking (the full run took ~19 minutes).
TOP_N = 100  # candidates retrieved per query, and length of the reranked list
output_path = "./result/bgem3_reranker.tsv"

# Build the qid -> query lookup once instead of boolean-masking the DataFrame for every
# query. keep="first" matches the previous `.tolist()[0]` choice when a qid repeats.
_query_rows = unique_query.drop_duplicates(subset="qid", keep="first")
qid_to_query = dict(zip(_query_rows["qid"], _query_rows["query"]))

result = []
error_list = []

for qid in tqdm(test_qid):
    try:
        query = qid_to_query[qid]
        candidate_pids, candidate_passages = get_vector_search_result(client, query, TOP_N)
        top_k = get_reranker_result(query, candidate_passages, TOP_N)

        # Rank is the 1-based position in the reranked list. `hit.index` points back
        # into the Milvus candidate list, which recovers the passage id.
        for rank, hit in enumerate(top_k, start=1):
            result.append([qid, candidate_pids[hit.index], rank])

    except Exception as e:
        # Catch Exception rather than using a bare `except:`, so KeyboardInterrupt
        # still stops this long-running loop instead of being logged as a failed query.
        error_list.append(qid)
        print(f"qid={qid} failed: {type(e).__name__}: {e}")

save_to_tsv(result, output_path)

 15%|█▍        | 192/1324 [02:27<14:19,  1.32it/s]

In [None]:
# Score the reranked run with the MS MARCO evaluation script. Arguments are the
# reference qrels and then the candidate run file. The candidate file only exists
# after the retrieval cell above has run to completion and saved it.
!python ms_marco_eval.py \
./data/test_qrels.tsv \
./result/bgem3_reranker.tsv

Traceback (most recent call last):
  File "/home/livin/rimo/llm/MSMARCO_TEST/ms_marco_eval.py", line 176, in <module>
    main()
  File "/home/livin/rimo/llm/MSMARCO_TEST/ms_marco_eval.py", line 170, in main
    metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
  File "/home/livin/rimo/llm/MSMARCO_TEST/ms_marco_eval.py", line 157, in compute_metrics_from_files
    qids_to_ranked_candidate_passages = load_candidate(path_to_candidate)
  File "/home/livin/rimo/llm/MSMARCO_TEST/ms_marco_eval.py", line 74, in load_candidate
    with open(path_to_candidate,'r') as f:
FileNotFoundError: [Errno 2] No such file or directory: './result/bgem3_reranker.tsv'
