### Imports

In [1]:
import json
import pickle
import random
from pathlib import Path

import numpy as np
import pandas as pd
from FlagEmbedding import FlagReranker
from keybert import KeyBERT
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
from pymilvus.model.reranker import BGERerankFunction
from tqdm import tqdm
# Seed Python's built-in `random` module so any later use of it is repeatable.
random.seed(42)

### Models

In [None]:
# Query encoder: BGE-M3 on the first GPU, configured to return dense vectors only
# (sparse and ColBERT outputs are switched off).
bge_m3_ef = model.hybrid.BGEM3EmbeddingFunction(
    model_name="BAAI/bge-m3",
    batch_size=64,
    device="cuda:0",
    # use_fp16=True,  # left disabled, so the library default applies
    return_dense=True,
    return_sparse=False,
    return_colbert_vecs=False,
)

# Reranker: weights are loaded from a local directory and run on the same GPU.
bge_rf = BGERerankFunction(
    model_name="./models/kw_3_easy_train",
    device="cuda:0",
    batch_size=32,
)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

### Data setup

In [None]:
# Connect with MilvusClient's default settings and load the collection that is searched later.
client = MilvusClient()
client.load_collection(collection_name="msmarco_bgem3")

# Query table; downstream code reads its "qid" and "query" columns.
unique_query = pd.read_csv("./data/unique_query.csv")

# The qrels file is read as header-less TSV with explicit column names.
# presumably the TREC qrels layout (query id, iteration, passage id, label) — TODO confirm
test_qrels = pd.read_csv(
    "./data/test_qrels.tsv",
    sep="\t",
    names=["qid", "r", "pid", "l"],
)

# One entry per qrels row.
# NOTE(review): if a qid appears on several rows it will be listed (and processed) more than once — confirm intended.
test_qid = test_qrels["qid"].to_list()

### Retrieve

In [None]:
# Takes about 19 minutes.
# For every test query: embed it, fetch the 100 nearest passages from Milvus,
# rerank them with the reranker, and record one [qid, pid, rank] row per passage.
result = []      # rows of [qid, pid, rank]; rank is the 1-based position after reranking
error_list = []  # qids whose processing raised an exception

# Build the qid -> query text lookup once instead of filtering the DataFrame
# for every query. keep="first" matches the previous `...tolist()[0]` behaviour
# should a qid ever appear on more than one row.
query_by_qid = (
    unique_query.drop_duplicates(subset="qid", keep="first")
    .set_index("qid")["query"]
    .to_dict()
)

for qid in tqdm(test_qid):
    try:
        # A qid missing from unique_query raises KeyError and is recorded below.
        query = query_by_qid[qid]
        query_vectors = bge_m3_ef.encode_queries([query])["dense"]

        candidate = client.search(
            collection_name="msmarco_bgem3",  # target collection
            data=query_vectors,  # query vectors
            limit=100,  # number of returned entities
            output_fields=["pid", "text"],
            anns_field="dense_vector",
        )
        # A single query vector was sent, so only the first result list is used.
        hits = candidate[0]
        candidate_text = [hit["entity"]["text"] for hit in hits]
        candidate_pid = np.array([hit["entity"]["pid"] for hit in hits])

        top_k = bge_rf(
            query=query,
            documents=candidate_text,
            top_k=100,
        )
        # `.index` on each reranked item points back into candidate_text/candidate_pid.
        # Build all rows first so a failure here never leaves partial rows in `result`.
        rows = [
            [qid, candidate_pid[item.index], rank]
            for rank, item in enumerate(top_k, start=1)
        ]
        result.extend(rows)

    except Exception as exc:
        # Best-effort: note the failing query and carry on with the rest.
        # `Exception` (not a bare except) lets Ctrl-C / kernel interrupt stop the loop.
        error_list.append(qid)
        print(qid, repr(exc))


  0%|          | 0/1324 [00:00<?, ?it/s]

  1%|          | 12/1324 [00:09<16:54,  1.29it/s]

In [None]:
# Persist the reranked run as a TSV. Each row was built as [qid, pid, rank];
# the frame has no column names, so the header line is written as 0, 1, 2.
result_df = pd.DataFrame(result)

tsv_file_path = 'result/bgem3_reranker.tsv'
# to_csv does not create missing directories; make sure the output folder
# exists so a long retrieval run is not lost to a write error.
Path(tsv_file_path).parent.mkdir(parents=True, exist_ok=True)
result_df.to_csv(tsv_file_path, sep='\t', index=False)

In [None]:
# # Version that takes about 19 minutes
# result = []
# error_list = []

# for i in tqdm(range(len(test_qid))):
#     try:
#         qid = test_qid[i]
#         query = unique_query[unique_query["qid"] == qid]["query"].tolist()[0]
#         # qid, query = unique_query.iloc[i]["qid"], unique_query.iloc[i]["query"]
#         query_vectors = bge_m3_ef.encode_queries([query])["dense"]

#         candidate = client.search(
#             collection_name="msmarco_bgem3",  # target collection
#             data=query_vectors,  # query vectors
#             limit=100,  # number of returned entities
#             # filter=f"qid == {qid}",
#             output_fields=["pid","text"],
#             anns_field="dense_vector"
#         )
#         candidate_text = [i["entity"]["text"] for i in candidate[0]]
#         candidate_pid = np.array([i["entity"]["pid"] for i in candidate[0]])

#         top_k = bge_rf(
#             query=query,
#             documents=candidate_text,
#             top_k=100,
#         )
#         for n,i in enumerate(top_k):
#             result.append([qid, candidate_pid[i.index], n+1])
#         break

#     except:
#         error_list.append(qid)
#         print(qid)
