# Data Preprocessing for Fine-tuning a Korean ReRanker
 - **The Korean ReRanker fine-tuning example is based on [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file).**

## AutoReload

In [1]:
%load_ext autoreload
%autoreload 2

## 0. [Data format](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune#data-format)
- `{"query": str, "pos": List[str], "neg": List[str]}`
    - `query` and `pos` each require **at least one sentence**; `neg` may contain multiple sentences.
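To catch malformed samples early, each JSON line can be checked against this format before training. A minimal sketch: `is_valid_sample` is a hypothetical helper, and requiring at least one `neg` is an assumption that mirrors the filter applied later in "4. Check files".

```python
import json

def is_valid_sample(sample):
    """Return True if `sample` follows {"query": str, "pos": List[str], "neg": List[str]}."""
    query, pos, neg = sample.get("query"), sample.get("pos"), sample.get("neg")
    # `query` must be a non-empty string
    if not isinstance(query, str) or not query.strip():
        return False
    # `pos` needs at least one non-empty passage
    if not isinstance(pos, list) or not pos or not all(isinstance(p, str) and p.strip() for p in pos):
        return False
    # `neg` may hold several passages; at least one is required here (assumption)
    if not isinstance(neg, list) or not neg or not all(isinstance(n, str) and n.strip() for n in neg):
        return False
    return True

line = '{"query": "what is a reranker?", "pos": ["A reranker scores query-passage pairs."], "neg": ["Paris is in France."]}'
print(is_valid_sample(json.loads(line)))  # True
```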

## 1. Download dataset
- [msmarco-triplets](https://github.com/microsoft/MSMARCO-Passage-Ranking)
    - (Question, Answer, Negative)-Triplets from MS MARCO Passages dataset, 499,184 samples
    - The dataset is in English.

In [None]:
!mkdir -p ./dataset/msmarco
!wget -O ./dataset/msmarco/msmarco-triplets.jsonl.gz https://huggingface.co/datasets/sentence-transformers/embedding-training-data/resolve/main/msmarco-triplets.jsonl.gz
!gunzip ./dataset/msmarco/msmarco-triplets.jsonl.gz
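If you would rather keep only the compressed file on disk, Python's standard `gzip` module can stream the JSON Lines directly. A minimal sketch: `iter_jsonl_gz` is a hypothetical helper (not part of FlagEmbedding), and it assumes the `.gz` file still exists, i.e. the `gunzip` step was skipped.

```python
import gzip
import json

def iter_jsonl_gz(path):
    """Yield one dict per line from a gzip-compressed JSON Lines file."""
    # "rt" opens the gzip stream in text mode, so each line is a str
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                yield json.loads(line)

# Peek at the first sample without decompressing the whole file:
# first = next(iter_jsonl_gz("./dataset/msmarco/msmarco-triplets.jsonl.gz"))
# print(first["query"])
```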

## 2. [Hard negatives](https://github.com/FlagOpen/FlagEmbedding/tree/master/examples/finetune#hard-negatives) (optional)
- Hard negatives are a widely used method for improving the quality of sentence embeddings.
- You can mine hard negatives with the following command:
```
!python ./src/preprocess/hn_mine.py \
    --model_name_or_path BAAI/bge-base-en-v1.5 \
    --input_file ./dataset/toy_finetune_data.jsonl \
    --output_file ./dataset/toy_finetune_data_minedHN.jsonl \
    --range_for_sampling 2-200 \
    --use_gpu_for_searching
```

- `input_file`: JSON data for fine-tuning. The script retrieves the top-k documents for each query and randomly samples negatives from them (excluding the positive documents).
- `output_file`: path where the JSON data with mined hard negatives is saved for fine-tuning.
- `range_for_sampling`: the rank range to sample negatives from. For example, `2-200` means sampling negatives from the top-2 to top-200 documents. A larger range lowers the difficulty of the negatives (e.g., `60-300` samples negatives from the top-60 to top-300 passages).
- `candidate_pool`: the pool to retrieve from. The default is None, in which case the script retrieves from the combination of all `neg` entries in `input_file`. The file format is the same as the pretraining data. If a `candidate_pool` is given, the script retrieves negatives from that file.
- `use_gpu_for_searching`: whether to use faiss-gpu to retrieve negatives.
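The sampling idea described above (rank the pool by similarity, keep a rank window, drop positives, sample at random) can be sketched in plain Python. This is an illustration, not `hn_mine.py` itself: it assumes embeddings are already computed and non-zero, interprets `range_for_sampling` as a 0-based slice `[lo:hi]`, and `sample_hard_negatives` and `negative_number` are hypothetical names.

```python
import math
import random

def cosine(a, b):
    # Cosine similarity; assumes neither vector is all zeros
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

def sample_hard_negatives(query_emb, corpus, corpus_embs, positives,
                          range_for_sampling=(2, 200), negative_number=15, seed=42):
    """Rank the corpus by similarity to the query, keep ranks [lo, hi), drop positives,
    and randomly sample up to `negative_number` of the remaining passages."""
    lo, hi = range_for_sampling
    ranked = sorted(range(len(corpus)),
                    key=lambda i: cosine(query_emb, corpus_embs[i]), reverse=True)
    candidates = [corpus[i] for i in ranked[lo:hi] if corpus[i] not in positives]
    rng = random.Random(seed)
    return rng.sample(candidates, min(negative_number, len(candidates)))
```

Moving the window further down the ranking (e.g. `(60, 300)`) returns passages that are less similar to the query, which is why a larger range makes the negatives easier.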

In [None]:
from FlagEmbedding import FlagModel  # sanity check that FlagEmbedding is installed

In [3]:
# Searches on CPU; append `--use_gpu_for_searching` to retrieve with faiss-gpu
!python ./src/preprocess/hn_mine.py \
    --model_name_or_path BAAI/bge-base-en-v1.5 \
    --input_file ./dataset/toy_finetune_data.jsonl \
    --output_file ./dataset/toy_finetune_data_minedHN.jsonl \
    --range_for_sampling 2-200

----------using 4*GPUs----------
inferencing embedding for corpus (number=80)--------------
corpus: 80 || ['Two males are performing.', 'The family was falling apart.', 'The people are watching a funeral procession.', 'A boy is sitting outside playing in the sand.', 'It is boring and mundane.', 'Conrad was being plotted against, to be hit on the head.', 'A group of people plays volleyball.', 'Some people are playing a tune.', 'They sold their home because they were retiring and not because of the loan.', 'The man sits at the table and eats food.', 'Mother Teresa is an easy choice.', 'Person in black clothing, with white bandanna and sunglasses waits at a bus stop.', "She's not going to court to clear her record.", 'a cat is running', 'A man is in a city.', 'A group of women watch soap operas.', 'A girl sits beside a boy.', 'A woman is riding her bike.', 'Some women with flip-flops on, are walking along the beach', 'Two people jumped off the dock.', 'Two men watching a magic show.', 'Th

## 3. Translation (en -> ko)
 - Translate the English text into Korean with Amazon Translate.
 - **[Caution] Using Amazon Translate incurs costs!**
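Because Amazon Translate is billed by the number of source characters, it can help to estimate the volume before translating ~500K triplets. A rough sketch: `count_characters` and `estimate_cost` are hypothetical helpers, and the price per million characters is an input you must look up on the current Amazon Translate pricing page (the `15.0` below is only a placeholder assumption).

```python
import json

def count_characters(path):
    """Count the characters that would be sent to Amazon Translate (query + pos + neg)."""
    total = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            sample = json.loads(line)
            total += len(sample["query"])
            total += sum(len(p) for p in sample["pos"])
            total += sum(len(n) for n in sample["neg"])
    return total

def estimate_cost(num_chars, usd_per_million_chars):
    # The rate is a parameter: check current pricing for your region
    return num_chars / 1_000_000 * usd_per_million_chars

# chars = count_characters("./dataset/msmarco/msmarco-triplets.jsonl")
# print(chars, estimate_cost(chars, usd_per_million_chars=15.0))  # 15.0 is an assumed rate
```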

In [6]:
import json
import time
import boto3
import datetime
import threading
from functools import partial
from collections import OrderedDict
from multiprocessing.pool import ThreadPool

In [7]:
input_file = "./dataset/msmarco/msmarco-triplets.jsonl"
out_file = "./dataset/translated/msmarco/msmarco-triplets-trans"

In [8]:
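translate = boto3.client("translate")

def trans(text, target="ko"):
    """Translate `text` into `target`; returns "err" if the API call fails."""
    try:
        response = translate.translate_text(
            Text=text,
            SourceLanguageCode="en",  # MS MARCO is English; "auto" would add automatic language detection
            TargetLanguageCode=target,
        )
        text_translate = response["TranslatedText"]

    except Exception as e:
        print(f'Translation failed: {e}')
        text_translate = "err"

    return text_translate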
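`trans()` turns any failure (for example, throttling under many concurrent threads) into `"err"`, and that passage is then dropped. A minimal retry sketch with exponential backoff, assuming failures are transient; `trans_with_retry` is a hypothetical wrapper. Alternatively, botocore's built-in retries can be enabled when creating the client, e.g. `boto3.client("translate", config=botocore.config.Config(retries={"max_attempts": 10, "mode": "adaptive"}))`.

```python
import random
import time

def trans_with_retry(translate_fn, text, max_attempts=5, base_delay=1.0, sleep=time.sleep):
    """Call `translate_fn(text)` until it returns something other than "err",
    sleeping with exponential backoff plus jitter between attempts."""
    result = "err"
    for attempt in range(max_attempts):
        result = translate_fn(text)
        if result != "err":
            return result
        if attempt < max_attempts - 1:
            # base_delay * 1, 2, 4, ... plus up to 10% random jitter
            delay = base_delay * (2 ** attempt)
            sleep(delay + random.uniform(0, 0.1 * delay))
    return result  # still "err" after all attempts

# query_ko = trans_with_retry(trans, query)
```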

[TIP] To speed up processing, two options are provided: 1/ `Multi-thread` and 2/ `Multi-processing with multi-thread`.
- Choose one of the two (option 2 is faster).

### 3.1 Multi-thread

In [None]:
pool = ThreadPool(processes=10)  # threads that translate pos/neg passages concurrently
write_freq = 10                  # write to disk every `write_freq` samples
total_row = 499184               # number of samples in msmarco-triplets
trans_data_chunk = []
start = time.time()
# Saved as a single "-0" shard so the file matches the `msmarco-triplets-trans-*.jsonl` glob in section 4
trans_file = f'{out_file}-0.jsonl'

# error callback function
def custom_error_callback(error):
    print(f'Got error: {error}')

def write_jsonl(chunk, path):
    # Append samples as JSON Lines; ensure_ascii=False keeps Korean text readable
    with open(path, "a+", encoding="utf-8") as f:
        for data in chunk:
            json.dump(data, f, ensure_ascii=False)
            f.write("\n")

with open(input_file, encoding="utf-8") as fin:
    for idx, line in enumerate(fin):
        line = json.loads(line.strip())
        query, pos, neg = line["query"], line["pos"], line["neg"]

        query_ko = trans(query)

        task_pos = [pool.apply_async(partial(trans, text=p), error_callback=custom_error_callback) for p in pos]
        task_neg = [pool.apply_async(partial(trans, text=n), error_callback=custom_error_callback) for n in neg]

        # Drop passages whose translation failed; a failed query becomes "" and is filtered in section 4
        pos_ko = [r for r in (task.get() for task in task_pos) if r != "err"]
        neg_ko = [r for r in (task.get() for task in task_neg) if r != "err"]
        trans_data_chunk.append({"query": "" if query_ko == "err" else query_ko, "pos": pos_ko, "neg": neg_ko})

        if len(trans_data_chunk) == write_freq:
            write_jsonl(trans_data_chunk, trans_file)
            trans_data_chunk = []

        if idx % write_freq == 0:
            elapsed = datetime.timedelta(seconds=time.time() - start)
            print(f'{idx}/{total_row}, Elapsed: {elapsed}')

# Write the samples left over from the last, partially filled chunk
if trans_data_chunk:
    write_jsonl(trans_data_chunk, trans_file)

pool.close()
pool.join()
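The output is opened in append mode (`"a+"`), so re-running the loop after an interruption would duplicate samples. Since the loop writes exactly one output line per input line, in input order, a run can be resumed by skipping the lines already written. A minimal sketch: `count_done` is a hypothetical helper, and the resume logic is an assumption that only holds while that one-line-per-input ordering is kept.

```python
import os

def count_done(path):
    """Number of samples already written to `path` (0 if the file does not exist yet)."""
    if not os.path.exists(path):
        return 0
    with open(path, encoding="utf-8") as f:
        return sum(1 for line in f if line.strip())

# Hypothetical usage with the output file the loop appends to:
# done = count_done(path_of_the_output_file)
# for idx, line in enumerate(fin):
#     if idx < done:
#         continue  # already translated in a previous run
```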


### 3.2 Multi-processing with multi-thread

In [None]:
from multiprocessing import Pool

In [None]:
import math
from itertools import islice

# error callback function
def custom_error_callback(error):
    print(f'Got error: {error}')

def write_jsonl(chunk, path):
    # Append samples as JSON Lines; ensure_ascii=False keeps Korean text readable
    with open(path, "a+", encoding="utf-8") as f:
        for data in chunk:
            json.dump(data, f, ensure_ascii=False)
            f.write("\n")

def translation(input_file, out_file, start_idx, end_idx, write_freq):
    # boto3 clients must not be shared across processes, so each worker process creates its own
    global translate
    translate = boto3.client("translate")

    pool = ThreadPool(processes=7)
    trans_data_chunk = []
    shard_file = f'{out_file}-{start_idx}.jsonl'
    start = time.time()

    with open(input_file, encoding="utf-8") as fin:
        # Read only this worker's rows: [start_idx, end_idx)
        for idx, line in enumerate(islice(fin, start_idx, end_idx), start=start_idx):
            line = json.loads(line.strip())
            query, pos, neg = line["query"], line["pos"], line["neg"]

            query_ko = trans(query)

            task_pos = [pool.apply_async(partial(trans, text=p), error_callback=custom_error_callback) for p in pos]
            task_neg = [pool.apply_async(partial(trans, text=n), error_callback=custom_error_callback) for n in neg]

            # Drop passages whose translation failed; a failed query becomes "" and is filtered in section 4
            pos_ko = [r for r in (task.get() for task in task_pos) if r != "err"]
            neg_ko = [r for r in (task.get() for task in task_neg) if r != "err"]
            trans_data_chunk.append({"query": "" if query_ko == "err" else query_ko, "pos": pos_ko, "neg": neg_ko})

            if len(trans_data_chunk) == write_freq:
                write_jsonl(trans_data_chunk, shard_file)
                trans_data_chunk = []

            if (idx - start_idx) % write_freq == 0:
                elapsed = datetime.timedelta(seconds=time.time() - start)
                print(f'{idx-start_idx}/{end_idx-start_idx}, Elapsed: {elapsed}')

    # Flush the remaining samples (also needed for the last shard, which ends at EOF)
    if trans_data_chunk:
        write_jsonl(trans_data_chunk, shard_file)

    pool.close()
    pool.join()


In [None]:
total_row = 499184
worker_size = 6
interval = math.ceil(total_row / worker_size)  # round up so exactly `worker_size` shards cover every row
mp_pool = Pool(worker_size)

results = []
for start_idx in range(0, total_row, interval):
    end_idx = min(start_idx + interval, total_row)
    print(start_idx, end_idx)

    trans_jobs = partial(
        translation,
        input_file=input_file,
        out_file=out_file,
        start_idx=start_idx,
        end_idx=end_idx,
        write_freq=50,
    )
    results.append(mp_pool.apply_async(trans_jobs))

mp_pool.close()
mp_pool.join()

# apply_async hides worker exceptions until .get() is called
for r in results:
    r.get()
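The split of rows into per-process shards is easy to get wrong at the boundaries. With `int(499184 / 6) = 83197`, `range(0, total_row, interval)` yields seven shards, the last being only `(499182, 499184)`; rounding the shard size up yields exactly six. A minimal sketch, where `shard_ranges` is a hypothetical helper:

```python
import math

def shard_ranges(total_row, worker_size):
    """Split [0, total_row) into at most `worker_size` contiguous [start, end) ranges."""
    interval = math.ceil(total_row / worker_size)
    return [(start, min(start + interval, total_row)) for start in range(0, total_row, interval)]

print(shard_ranges(499184, 6))
# [(0, 83198), (83198, 166396), (166396, 249594), (249594, 332792), (332792, 415990), (415990, 499184)]
```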

## 4. Check files
Remove samples that do not match the data format.

In [None]:
import os
from glob import glob

In [None]:
dir_path = "./dataset/translated/msmarco/"

In [None]:
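total_cnt = 0
# Raw translated shards only; skip the "-processed-" files produced by the next cell
raw_files = [p for p in glob(os.path.join(dir_path, "msmarco-triplets-trans-*.jsonl")) if "-processed-" not in p]
for input_file in raw_files:
    with open(input_file, encoding="utf-8") as f:
        cnt = sum(1 for _ in f)
    total_cnt += cnt
    print(f'{input_file}: currently {cnt} lines')
print(f'total: {total_cnt} lines')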

In [None]:
# Raw translated shards only; skip "-processed-" files left over from a previous run
raw_files = [p for p in glob(os.path.join(dir_path, "msmarco-triplets-trans-*.jsonl")) if "-processed-" not in p]
for input_file in raw_files:
    out_file = input_file.replace("msmarco-triplets-trans", "msmarco-triplets-trans-processed")

    print("==========")
    print(f'input_file: {input_file}')
    print(f'out_file: {out_file}')

    processed_data = []
    with open(input_file, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            line = json.loads(line.strip())
            query, pos, neg = line["query"], line["pos"], line["neg"]

            # Keep only samples with a translated query and at least one pos and one neg
            if len(query) > 0 and query != "err" and len(pos) > 0 and len(neg) > 0:
                processed_data.append(line)
            else:
                print(f'Skip line {idx}: query: {len(query)}, pos: {len(pos)}, neg: {len(neg)}')

    with open(out_file, "w", encoding="utf-8") as f:
        for data in processed_data:
            json.dump(data, f, ensure_ascii=False)  # ensure_ascii=False keeps Korean text readable
            f.write("\n")  # JSON Lines: one JSON object per line


## 5. Merge

### 5.1 Check each file

In [None]:
dir_path = "./dataset/translated/msmarco/"

total_cnt = 0
for input_file in glob(os.path.join(dir_path, "msmarco-triplets-trans-processed-*.jsonl")):
    cnt = 0
    for idx, line in enumerate(open(input_file)): cnt += 1
    total_cnt += cnt
    print (f'{input_file}: currently {cnt} lines')
print (f'total: {total_cnt} lines')

### 5.2 Merge them

In [None]:
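src = os.path.join(dir_path, "msmarco-triplets-trans-processed-*.jsonl")
dst = "./dataset/translated/merged/msmarco-triplets-trans-processed-merged.jsonl"
os.makedirs(os.path.dirname(dst), exist_ok=True)  # the shell redirect cannot create the target directory
!cat $src > $dst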

In [None]:
cnt = 0
for idx, line in enumerate(open(dst)): cnt += 1
print (f'{dst}: {cnt} lines')
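`cat` with a shell glob works in a Linux notebook; where no shell is available, the same merge can be done in pure Python. A minimal sketch: `merge_jsonl` is a hypothetical helper, files are concatenated in name order, and `dst` must not match `pattern` (otherwise the output would be read while it is being written).

```python
import glob
import os

def merge_jsonl(pattern, dst):
    """Concatenate every file matching `pattern` into `dst` (sorted by name); return the line count."""
    os.makedirs(os.path.dirname(dst) or ".", exist_ok=True)
    n_lines = 0
    with open(dst, "w", encoding="utf-8") as out:
        for path in sorted(glob.glob(pattern)):
            with open(path, encoding="utf-8") as f:
                for line in f:
                    if not line.endswith("\n"):
                        line += "\n"  # guard against a missing trailing newline
                    out.write(line)
                    n_lines += 1
    return n_lines

# merge_jsonl(src, dst)
```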

## 6. Store data to S3

In [None]:
import sagemaker

In [None]:
bucket_name = sagemaker.Session().default_bucket()
print (f'bucket_name: {bucket_name}')

In [None]:
s3_data_path = f"s3://{bucket_name}/fine-tune-reranker-kr/dataset"
local_data_Path = os.path.join(os.getcwd(), "dataset", "translated", "merged")
file_name = "msmarco-triplets-trans-processed-merged.jsonl"

print (f's3_data_path: {s3_data_path}')
print (f'local_data_Path: {local_data_Path}')
print (f'file_name: {file_name}')

In [None]:
%%bash
aws configure set default.s3.max_concurrent_requests 100
aws configure set default.s3.max_queue_size 10000
aws configure set default.s3.multipart_threshold 1GB
aws configure set default.s3.multipart_chunksize 64MB

In [None]:
!aws s3 sync $local_data_Path $s3_data_path
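If the AWS CLI is not available, the merged file can also be uploaded from Python with boto3's managed transfer (`upload_file` with a `TransferConfig`). A minimal sketch, assuming credentials are configured as for the CLI: `split_s3_uri` and `upload_to_s3` are hypothetical helpers, the transfer settings only roughly mirror the `aws configure set default.s3.*` values above, and unlike `aws s3 sync` this uploads a single file unconditionally.

```python
import os

def split_s3_uri(uri):
    """Split "s3://bucket/prefix/key" into ("bucket", "prefix/key")."""
    if not uri.startswith("s3://"):
        raise ValueError(f"not an S3 URI: {uri}")
    bucket, _, key = uri[len("s3://"):].partition("/")
    return bucket, key

def upload_to_s3(local_path, s3_dir_uri):
    # Imported lazily so split_s3_uri can be used without boto3 installed
    import boto3
    from boto3.s3.transfer import TransferConfig

    config = TransferConfig(
        multipart_threshold=1024 ** 3,       # 1GB
        multipart_chunksize=64 * 1024 ** 2,  # 64MB
        max_concurrency=100,
    )
    bucket, prefix = split_s3_uri(s3_dir_uri.rstrip("/") + "/")
    key = prefix + os.path.basename(local_path)
    boto3.client("s3").upload_file(local_path, bucket, key, Config=config)
    return f"s3://{bucket}/{key}"

# upload_to_s3(os.path.join(local_data_Path, file_name), s3_data_path)
```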

### Data back-up (optional)

In [None]:
s3_data_path = f"s3://{bucket_name}/reranker-dataset-ko/"
local_data_Path = os.path.join(os.getcwd(), "dataset")

In [None]:
!aws s3 sync $local_data_Path $s3_data_path