# Module

In [1]:
from keybert import KeyBERT
import pandas as pd
from tqdm import tqdm
import pickle
from pymilvus import model
from pymilvus import MilvusClient, Collection, connections, DataType, CollectionSchema, FieldSchema
import numpy as np
import json
from FlagEmbedding import FlagReranker
from pymilvus.model.reranker import BGERerankFunction
import random
random.seed(42)

# Data Load

In [2]:
# Dataset load: the two TSV files are read with explicitly supplied column
# names; the unique-query CSV is read with pandas' default header handling.
data = pd.read_csv(
    "/home/livin/rimo/llm/msmarco/data/top1000_dev.tsv",
    sep='\t',
    names=['qid', 'pid', 'query', 'passage'],
)
unique_query = pd.read_csv(
    "/home/livin/rimo/llm/msmarco/notebook/unique_query.csv"
)
qrels = pd.read_csv(
    "/home/livin/rimo/llm/msmarco/data/qrels.dev.small.tsv",
    sep='\t',
    names=['qid', 'r', 'pid', 'l'],
)

In [None]:
# with open('/home/livin/rimo/llm/msmarco_test/data/train_pid_list.pkl', 'wb') as file:
#     pickle.dump(train_pid_list, file)

# with open('/home/livin/rimo/llm/msmarco_test/data/test_pid_list.pkl', 'wb') as file:
#     pickle.dump(test_pid_list, file)

# print("변수가 성공적으로 저장되었습니다.")


In [3]:
def _load_pickle(path):
    """Open ``path`` in binary mode and return the unpickled object.

    Unpickling can execute arbitrary code, so only point this at files
    that were produced by a trusted run of this project.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Previously saved pid splits: the full list plus its train / test parts.
all_pid_list = _load_pickle('/home/livin/rimo/llm/msmarco_test/data/all_pid_list.pkl')
train_pid_list = _load_pickle('/home/livin/rimo/llm/msmarco_test/data/train_pid_list.pkl')
test_pid_list = _load_pickle('/home/livin/rimo/llm/msmarco_test/data/test_pid_list.pkl')

In [80]:
# # split train data
# file_name_list = ["kw_3_easy","kw_3_hard","kw_5_easy","kw_5_hard","kw_7_easy","kw_7_hard","kw_9_easy","kw_9_hard"]

# for file_name in file_name_list:
# 	file_path = f"/home/livin/rimo/llm/msmarco_test/data/{file_name}.jsonl"
# 	re = []
# 	with open(file_path, 'r', encoding='utf-8') as file:
# 		for n, line in enumerate(file):
# 			re.append(json.loads(line))
			
# 	re = re[:5288]
			
# 	with open(f"/home/livin/rimo/llm/msmarco_test/data/{file_name}_train.jsonl" , encoding= "utf-8",mode="w") as file: 
# 		for i in re: file.write(json.dumps(i) + "\n")

# Keyword Query Generation

In [4]:
# Keyword extractor built on the BAAI/bge-m3 checkpoint.
kw_model = KeyBERT("BAAI/bge-m3")

# BGE-M3 embedding function configured to return dense vectors only
# (sparse and ColBERT outputs are switched off), running on cuda:1.
# use_fp16 is not passed, so the library default applies.
bge_m3_ef = model.hybrid.BGEM3EmbeddingFunction(
    model_name="BAAI/bge-m3",
    batch_size=16,
    device="cuda:1",
    return_dense=True,
    return_sparse=False,
    return_colbert_vecs=False,
)

# Milvus client created with its default connection settings.
client = MilvusClient()

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [5]:
# Every value of the "pid" column of `data`, as a plain Python list
# (one entry per row, so a pid may appear more than once).
pid_list = data.loc[:, "pid"].tolist()

def get_random_pid(pid_list, exclude_pid, total_numbers):
    """Draw ``total_numbers`` distinct random pids, never returning ``exclude_pid``.

    Args:
        pid_list: Iterable of hashable passage ids. It may contain repeats
            (e.g. when built from a table with one row per query/passage pair).
        exclude_pid: The pid that must not be sampled (typically the positive).
        total_numbers: How many pids to draw.

    Returns:
        A list of ``total_numbers`` pids, all different from each other and
        from ``exclude_pid``.

    Raises:
        ValueError: If fewer than ``total_numbers`` distinct candidates remain
            after removing ``exclude_pid`` (raised by ``random.sample``).
    """
    # De-duplicate before sampling: random.sample picks distinct *positions*,
    # so repeated pids in the input could otherwise be returned several times.
    # dict.fromkeys keeps first-seen order, which keeps results reproducible
    # under a fixed random seed.
    candidates = [pid for pid in dict.fromkeys(pid_list) if pid != exclude_pid]
    return random.sample(candidates, total_numbers)

def get_keyword_query(top_n, negative_type):
    """Build training records whose query is a string of passage keywords.

    For every pid in the module-level ``train_pid_list`` the passage text is
    looked up in ``data``, ``top_n`` single-word keywords are extracted with
    ``kw_model``, ordered by where they first occur in the passage, and joined
    with spaces to form the query. Ten negatives are then attached.

    Args:
        top_n: Number of keywords requested from KeyBERT per passage.
        negative_type: ``"hard"`` to take the top-10 Milvus hits for the query
            embedding (collection ``msmarco_bgem3``, current pid filtered out),
            or ``"easy"`` to take 10 random passages via ``get_random_pid``.

    Returns:
        A list of dicts with keys ``"query"`` (str), ``"pos"`` (the passage
        text) and ``"neg"`` (list of negative passage texts).

    Raises:
        ValueError: If ``negative_type`` is neither ``"hard"`` nor ``"easy"``.
        IndexError: If a pid has no row in ``data``.

    NOTE(review): ``"pos"`` is emitted as a bare string. FlagEmbedding's
    documented fine-tuning format describes ``pos`` as a list of strings;
    confirm the trainer consuming these files handles a plain string.
    """
    # Fail fast: previously an unknown value left `neg_list` undefined
    # (NameError on the first pid, or a stale list from an earlier iteration).
    if negative_type not in ("hard", "easy"):
        raise ValueError(
            f"negative_type must be 'hard' or 'easy', got {negative_type!r}"
        )

    train_json = []
    for pid in tqdm(train_pid_list):
        passage = data[data["pid"] == pid]["passage"].tolist()[0]

        passage_keywords = kw_model.extract_keywords(passage, keyphrase_ngram_range=(1,1), top_n=top_n)
        # Order keywords by first occurrence in the passage. The lookup is
        # case-insensitive because the extracted keywords may be lowercased
        # (KeyBERT's default vectorizer lowercases tokens); a case-sensitive
        # find() would return -1 for capitalised words and push them first.
        passage_lower = passage.lower()
        passage_keywords = sorted(
            passage_keywords,
            key=lambda kw: passage_lower.find(kw[0].lower()),
        )
        query = " ".join(kw[0] for kw in passage_keywords)

        if negative_type == "hard":
            # The embedding is only needed for the vector search, so it is
            # computed here rather than on every iteration.
            query_vectors = bge_m3_ef.encode_queries([query])["dense"]
            res = client.search(
                collection_name="msmarco_bgem3",
                data=query_vectors,
                limit=10,
                output_fields=["text"],
                anns_field="dense_vector",
                filter=f"pid != {pid}",
            )
            neg_list = [hit["entity"]["text"] for hit in res[0]]
        else:  # "easy"
            neg_pid = get_random_pid(pid_list, pid, 10)
            neg_list = [data[data["pid"] == n]["passage"].tolist()[0] for n in neg_pid]

        train_json.append({
            "query": query,
            "pos": passage,
            "neg": neg_list,
        })

    return train_json

def get_gt_query(negative_type):
    """Build training records that pair each passage with its ground-truth query.

    For every pid in the module-level ``train_pid_list`` the passage text is
    looked up in ``data``, the first qid that ``qrels`` lists for that pid is
    taken, and the matching query text is read from ``unique_query``. Ten
    negatives are then attached.

    Args:
        negative_type: ``"hard"`` to take the top-10 Milvus hits for the query
            embedding (collection ``msmarco_bgem3``), or ``"easy"`` to take
            10 random passages via ``get_random_pid``.

    Returns:
        A list of dicts with keys ``"query"`` (str), ``"pos"`` (the passage
        text) and ``"neg"`` (list of negative passage texts).

    Raises:
        ValueError: If ``negative_type`` is neither ``"hard"`` nor ``"easy"``.
        IndexError: If a pid is missing from ``data`` or ``qrels``, or its qid
            is missing from ``unique_query``.

    NOTE(review): ``"pos"`` is emitted as a bare string. FlagEmbedding's
    documented fine-tuning format describes ``pos`` as a list of strings;
    confirm the trainer consuming these files handles a plain string.
    """
    # Fail fast: previously an unknown value left `neg_list` undefined
    # (NameError on the first pid, or a stale list from an earlier iteration).
    if negative_type not in ("hard", "easy"):
        raise ValueError(
            f"negative_type must be 'hard' or 'easy', got {negative_type!r}"
        )

    train_json = []
    for pid in tqdm(train_pid_list):
        passage = data[data["pid"] == pid]["passage"].tolist()[0]
        qid = qrels[qrels["pid"] == pid]["qid"].tolist()[0]
        query = unique_query[unique_query["qid"] == qid]["query"].tolist()[0]

        if negative_type == "hard":
            # Exclude every passage that qrels lists for this query, not only
            # the current one; otherwise another judged passage for the same
            # query could come back from the search and be labelled negative.
            # The list always contains `pid` itself because `qid` was derived
            # from that qrels row.
            positive_pids = qrels[qrels["qid"] == qid]["pid"].tolist()

            query_vectors = bge_m3_ef.encode_queries([query])["dense"]
            res = client.search(
                collection_name="msmarco_bgem3",
                data=query_vectors,
                limit=10,
                output_fields=["text"],
                anns_field="dense_vector",
                filter=f"pid not in {positive_pids}",
            )
            neg_list = [hit["entity"]["text"] for hit in res[0]]
        else:  # "easy"
            neg_pid = get_random_pid(pid_list, pid, 10)
            neg_list = [data[data["pid"] == n]["passage"].tolist()[0] for n in neg_pid]

        train_json.append({
            "query": query,
            "pos": passage,
            "neg": neg_list,
        })

    return train_json

In [6]:
# re = get_keyword_query(3, "hard")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_3_hard.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(3, "easy")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_3_easy.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(5, "hard")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_5_hard.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(5, "easy")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_5_easy.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(7, "hard")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_7_hard.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(7, "easy")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_7_easy.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(9, "hard")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_9_hard.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_keyword_query(9, "easy")
# with open("/home/livin/rimo/llm/msmarco_test/data/kw_9_easy.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# re = get_gt_query("easy")
# with open("/home/livin/rimo/llm/msmarco_test/data/gt_easy_train.jsonl" , encoding= "utf-8",mode="w") as file: 
# 	for i in re: file.write(json.dumps(i) + "\n")

# Build the ground-truth-query records with hard negatives and dump them
# as JSON Lines (one JSON object per line).
re = get_gt_query("hard")
with open("/home/livin/rimo/llm/msmarco_test/data/gt_hard_train.jsonl", mode="w", encoding="utf-8") as file:
    for record in re:
        file.write(json.dumps(record) + "\n")

100%|██████████| 5288/5288 [48:49<00:00,  1.81it/s]


# Fine-Tuning

In [7]:
# Launch FlagEmbedding's reranker training module through torchrun with
# 2 processes per node, starting from BAAI/bge-reranker-large and writing
# to /home/livin/rimo/llm/msmarco_test/model.
# NOTE(review): --train_data points at dev_train.jsonl, which none of the
# cells visible here write (they produce kw_*_{easy,hard}.jsonl and
# gt_*_train.jsonl) — confirm this is the intended training file.
# NOTE(review): --train_group_size 4 presumably means 1 positive + 3
# negatives per query even though 10 negatives are stored — verify.
!torchrun --nproc_per_node 2 -m FlagEmbedding.reranker.run --output_dir /home/livin/rimo/llm/msmarco_test/model --model_name_or_path BAAI/bge-reranker-large --train_data /home/livin/rimo/llm/msmarco_test/data/dev_train.jsonl --learning_rate 6e-5 --fp16 --num_train_epochs 100 --per_device_train_batch_size 2 --gradient_accumulation_steps 4 --dataloader_drop_last True --train_group_size 4 --max_len 512 --weight_decay 0.01 --logging_steps 10

W0715 17:59:23.985000 135470223644480 torch/distributed/run.py:757] 
W0715 17:59:23.985000 135470223644480 torch/distributed/run.py:757] *****************************************
W0715 17:59:23.985000 135470223644480 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0715 17:59:23.985000 135470223644480 torch/distributed/run.py:757] *****************************************
07/15/2024 17:59:28 - INFO - __main__ -   Training/evaluation parameters TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'gradient_accumulation_kwargs': None},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=

# Test

In [6]:
# Dense-only BGE-M3 query encoder for the evaluation run, on cuda:0.
# use_fp16 is not passed, so the library default applies.
bge_m3_ef = model.hybrid.BGEM3EmbeddingFunction(
    model_name="BAAI/bge-m3",
    batch_size=16,
    device="cuda:0",
    return_dense=True,
    return_sparse=False,
    return_colbert_vecs=False,
)

# Reranker loaded from a local fine-tuned checkpoint directory.
# To compare against the stock model, pass model_name="BAAI/bge-reranker-large"
# instead (when model_name is omitted the function defaults to
# `BAAI/bge-reranker-v2-m3`).
bge_rf = BGERerankFunction(
    model_name="/home/livin/rimo/llm/msmarco_finetuning/model/checkpoint-7000",
    device="cuda:0",  # e.g. 'cpu' or 'cuda:0'
)

Fetching 30 files:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
# Milvus client used by the evaluation loop; created with the client's
# default connection settings.
client = MilvusClient()

In [8]:
# Retrieve-then-rerank evaluation: for each query, fetch 100 candidates from
# Milvus by dense similarity, rerank them with `bge_rf`, and record
# [qid, pid, rank] rows (rank is 1-based) in `result`. Queries that raise are
# collected in `error_list` so the run can continue.
unique_query = pd.read_csv('/home/livin/rimo/llm/msmarco/notebook/unique_query.csv')

result = []
error_list = []

query_pairs = zip(unique_query["qid"], unique_query["query"])
for qid, query in tqdm(query_pairs, total=len(unique_query)):
    try:
        query_vectors = bge_m3_ef.encode_queries([query])["dense"]

        candidate = client.search(
            collection_name="msmarco_bgem3",  # target collection
            data=query_vectors,  # query vectors
            limit=100,  # number of returned entities
            output_fields=["pid", "text"],
            anns_field="dense_vector"
        )
        candidate_text = [hit["entity"]["text"] for hit in candidate[0]]
        candidate_pid = np.array([hit["entity"]["pid"] for hit in candidate[0]])

        top_k = bge_rf(
            query=query,
            documents=candidate_text,
            top_k=100,
        )
        # Each reranked item carries the index of its document in
        # `candidate_text`; map that back to the candidate's pid.
        for rank, reranked in enumerate(top_k, start=1):
            result.append([qid, candidate_pid[reranked.index], rank])
    except Exception as exc:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still stop the loop, and report the
        # cause instead of silently discarding it.
        error_list.append(qid)
        tqdm.write(f"qid {qid} failed: {exc!r}")


  1%|          | 54/6980 [00:43<1:32:08,  1.25it/s]

In [None]:
# Write the collected [qid, pid, rank] rows to a tab-separated file without
# the index column.
# NOTE(review): the frame has no column names, so pandas' default header row
# ("0", "1", "2") is written as the first line — confirm the downstream
# evaluator expects or tolerates that line.
tsv_file_path = '/home/livin/rimo/llm/msmarco_finetuning/result/ms_triplet_7000.tsv'

result_df = pd.DataFrame(result)
result_df.to_csv(tsv_file_path, index=False, sep='\t')