### Module

In [1]:
from keybert import KeyBERT
import pandas as pd
from tqdm import tqdm
import pickle
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
import numpy as np
import json
from FlagEmbedding import FlagReranker
from pymilvus.model.reranker import BGERerankFunction
import random
random.seed(42)

### Models

In [2]:
bge_m3_ef = model.hybrid.BGEM3EmbeddingFunction(
        model_name= "BAAI/bge-m3",
        batch_size = 64,
        device = "cuda:0",
        # use_fp16 = True,
        return_dense = True,
        return_sparse = False,
        return_colbert_vecs = False,
    )

bge_rf = BGERerankFunction(
        model_name="./models/kw_3_easy_train", 
        device="cuda:0",
        batch_size=32
    )

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

### Data setting

In [3]:
client = MilvusClient()
unique_query = pd.read_csv('./data/unique_query.csv')
test_qrels = pd.read_csv("./data/test_qrels.tsv", sep='\t', names=['qid', 'r', 'pid', 'l'])
test_qid = test_qrels["qid"].tolist()

### Retrieve

In [12]:
# 19분 소요
result = []
error_list = []

for i in tqdm(range(len(test_qid))):
    try:
        qid = test_qid[i]
        query = unique_query[unique_query["qid"] == qid]["query"].tolist()[0]
        # qid, query = unique_query.iloc[i]["qid"], unique_query.iloc[i]["query"]
        query_vectors = bge_m3_ef.encode_queries([query])["dense"]

        candidate = client.search(
            collection_name="msmarco_bgem3",  # target collection
            data=query_vectors,  # query vectors
            limit=100,  # number of returned entities
            # filter=f"qid == {qid}",
            output_fields=["pid","text"],
            anns_field="dense_vector"
        )
        candidate_text = [i["entity"]["text"] for i in candidate[0]]
        candidate_pid = np.array([i["entity"]["pid"] for i in candidate[0]])

        top_k = bge_rf(
            query=query,
            documents=candidate_text,
            top_k=100,
        )
        for n,k in enumerate(top_k):
            result.append([qid, candidate_pid[k.index], n+1])

    except:
        error_list.append(qid)
        print(qid)


 10%|▉         | 132/1324 [01:41<14:54,  1.33it/s]

1082948


 11%|█▏        | 149/1324 [01:54<14:50,  1.32it/s]

: 

: 

In [10]:
result_df = pd.DataFrame(result)

tsv_file_path = 'result/bgem3_reranker.tsv'
result_df.to_csv(tsv_file_path, sep='\t', index=False)

In [8]:
query = unique_query["query"].tolist()[0:100]
query

[' Androgen receptor define',
 '3 levels of government in canada and their responsibilities',
 '3/5 of 60',
 '60x40 slab cost',
 'Bethel University was founded in what year',
 'Does Suddenlink Carry ESPN3',
 'Explain what a bone scan is and what it is used for.',
 'Is the Louisiana sales tax 4.75',
 'Ludacris Net Worth',
 'Sony PS-LX300USB how to connect to pc',
 'The hormone that does the opposite of calcitonin is',
 'What Does Noel Mean in the Bible',
 'When did the earthquake hit San Francisco during the World Series',
 '_____ is the ability of cardiac pacemaker cells to spontaneously initiate an electrical impulse without being stimulated from another source, such as a nerve.',
 '_____ is the name used to refer to the era of legalized segregation in the united states',
 '_______ is a fuel produced by fermenting crops.',
 '________ disparity refers to the slightly different view of the world that each eye receives.cyclopeanbinocularmonoculartrichromatic',
 '____________________ is c

In [9]:
query_vectors = bge_m3_ef.encode_queries(query)["dense"]
query_vectors

[array([ 0.01229456, -0.0285839 ,  0.00026209, ...,  0.01170966,
        -0.05474449,  0.03620829], dtype=float32),
 array([ 0.02562712, -0.00611602, -0.02199598, ...,  0.01604042,
         0.0076088 , -0.03495618], dtype=float32),
 array([ 0.00285246,  0.00149539, -0.06758212, ..., -0.0204071 ,
         0.02017461, -0.01204825], dtype=float32),
 array([-0.04419862, -0.00739527, -0.03273088, ..., -0.04458462,
         0.00706671,  0.07837474], dtype=float32),
 array([-0.01665251,  0.00897697,  0.01971322, ..., -0.00369375,
         0.01407065,  0.02688692], dtype=float32),
 array([-0.06936234, -0.01865093, -0.01795062, ...,  0.01110846,
        -0.01361303,  0.03744141], dtype=float32),
 array([ 0.01577861,  0.00223051, -0.08699854, ..., -0.01199264,
         0.05077625, -0.02666679], dtype=float32),
 array([-0.02100987, -0.0058685 , -0.05266322, ..., -0.02853285,
        -0.05244388,  0.03993858], dtype=float32),
 array([-0.04774194,  0.0484274 , -0.01176151, ...,  0.00623611,
       

In [10]:
for i in tqdm(query_vectors):
    candidate = client.search(
            collection_name="msmarco_bgem3",  # target collection
            data=[i],  # query vectors
            limit=100,  # number of returned entities
            output_fields=["pid","text"],
            anns_field="dense_vector"
        )

100%|██████████| 100/100 [00:49<00:00,  2.00it/s]


In [11]:
candidate = client.search(
            collection_name="msmarco_bgem3",  # target collection
            data=query_vectors,  # query vectors
            limit=100,  # number of returned entities
            output_fields=["pid","text"],
            anns_field="dense_vector"
        )

In [28]:
len(candidate[0])

100

In [None]:
# # 19분 소요 버전
# result = []
# error_list = []

# for i in tqdm(range(len(test_qid))):
#     try:
#         qid = test_qid[i]
#         query = unique_query[unique_query["qid"] == qid]["query"].tolist()[0]
#         # qid, query = unique_query.iloc[i]["qid"], unique_query.iloc[i]["query"]
#         query_vectors = bge_m3_ef.encode_queries([query])["dense"]

#         candidate = client.search(
#             collection_name="msmarco_bgem3",  # target collection
#             data=query_vectors,  # query vectors
#             limit=100,  # number of returned entities
#             # filter=f"qid == {qid}",
#             output_fields=["pid","text"],
#             anns_field="dense_vector"
#         )
#         candidate_text = [i["entity"]["text"] for i in candidate[0]]
#         candidate_pid = np.array([i["entity"]["pid"] for i in candidate[0]])

#         top_k = bge_rf(
#             query=query,
#             documents=candidate_text,
#             top_k=100,
#         )
#         for n,i in enumerate(top_k):
#             result.append([qid, candidate_pid[i.index], n+1])
#         break

#     except:
#         error_list.append(qid)
#         print(qid)
