# Fine-tune for Korean ReRanker based on Amazon SageMaker
 - **한국어 ReRanker 모델 파인튜닝 예시는 [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding/tree/master?tab=readme-ov-file)을 기반으로 합니다.**
 - Fine-tuning은 SageMaker 기반 Distributed Learning으로 진행됩니다.

## AutoReload

In [1]:
# Load IPython's autoreload extension and set it to mode 2 for this notebook session.
%load_ext autoreload
%autoreload 2

## 1. Dataset

In [2]:
import os
import sagemaker

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


In [3]:
# Default bucket of a fresh SageMaker session; the S3 locations built later
# in this notebook are placed under it.
bucket_name = sagemaker.Session().default_bucket()
print(f"bucket_name: {bucket_name}")

bucket_name: sagemaker-us-east-1-419974056037


#### [확인] `1.data-preprocessing.ipynb`에서 데이터를 저장한 경로를 이용합니다.
#### [Sample data] `./dataset/toy_finetune_data_kr.jsonl` 참고

In [4]:
# S3 prefix that holds the training dataset.
s3_data_path = f"s3://{bucket_name}/fine-tune-reranker-kr/dataset/"
# Directory under the current working directory that holds the same dataset locally.
local_data_Path = os.path.join(os.getcwd(), "dataset", "translated", "merged")
# Dataset file name, shared by the S3 and the local location.
file_name = "msmarco-triplets-trans-processed-merged.jsonl"

print(
    f"s3_data_path: {s3_data_path}",
    f"local_data_Path: {local_data_Path}",
    f"file_name: {file_name}",
    sep="\n",
)

s3_data_path: s3://sagemaker-us-east-1-419974056037/fine-tune-reranker-kr/dataset/
local_data_Path: /home/ec2-user/SageMaker/aws-ai-ml-workshop-kr/genai/aws-gen-ai-kr/30_fine_tune/reranker-kr/dataset/translated/merged
file_name: msmarco-triplets-trans-processed-merged.jsonl


## 2. Training job


### 2.1 params for training job


In [5]:
from sagemaker import get_execution_role
from sagemaker.inputs import TrainingInput

In [6]:
# Set to True to run the training job in SageMaker local mode.
local_mode = False

# Name of the input data channel; also used as the key of `data_channel`.
channel = "train"


def fast_file(x):
    """Wrap a data location in a TrainingInput that uses FastFile input mode."""
    return TrainingInput(x, input_mode="FastFile")


if local_mode:
    from sagemaker.local import LocalSession

    instance_type = "local_gpu"
    sagemaker_session = LocalSession()
    sagemaker_session.config = {"local": {"local_code": True}}

    data_channel = {
        #channel: f'file:///home/ec2-user/SageMaker/fine-tune-reranker-kr/dataset/translated/merged/msmarco-triplets-trans-processed-merged-sample.jsonl',
        channel: f"file://{os.path.join(local_data_Path, file_name)}",
    }
else:
    # Alternatives listed by the author: "ml.p3.8xlarge", "ml.g5.12xlarge", "ml.p3dn.24xlarge"
    instance_type = "ml.p3.8xlarge"
    sagemaker_session = sagemaker.Session()

    # S3 URIs always use "/" as the separator, so build them with plain string
    # formatting instead of os.path.join (which follows the local OS convention).
    data_channel = {
        channel: f"{s3_data_path.rstrip('/')}/{file_name}",
    }

# Keep only the text after the last "/" of the execution role, i.e. the role name.
role = get_execution_role().rsplit("/", 1)[-1]

instance_count = 1

# Spot settings. max_wait is only set when spot training is enabled.
# NOTE(review): these look like seconds (1 hour) — confirm against the estimator arguments.
spot_training = False
max_run = 1 * 60 * 60
max_wait = 1 * 60 * 60 if spot_training else None

# A warm pool avoids downloading the training image again, which speeds up
# repeated jobs. It requires a service-quota request for warm pools and
# cannot be used together with spot instances.
use_train_warm_pool = True
if spot_training:
    use_train_warm_pool = False
# Keep the warm pool alive for at most one hour.
keep_alive_seconds = 3600 if use_train_warm_pool else None

prefix = "fine-tune-reranker-kr"
job_name = "-".join([prefix, "training"])

# All training artifacts live under s3://<bucket>/<prefix>/training/.
s3_training_root = "/".join([f"s3://{bucket_name}", prefix, "training"])
output_path = f"{s3_training_root}/model-output"
code_location = f"{s3_training_root}/backup-codes"
s3_chkpt_path = f"{s3_training_root}/checkpoints"

In [7]:
# Echo every training-job setting so the configuration can be checked before launch.
print(
    f"SageMaker Execution Role Name: {role}",
    f"job_name: {job_name}",
    f"instance_type: {instance_type}",
    f"instance_count: {instance_count}",
    f"sagemaker_session: {sagemaker_session}",
    f"spot_training: {spot_training}",
    f"data_channel: {data_channel}",
    f"output_path: {output_path}",
    f"code_location: {code_location}",
    f"use_train_warm_pool: {use_train_warm_pool}/{keep_alive_seconds}",
    f"s3_chkpt_path: {s3_chkpt_path}",
    sep="\n",
)

SageMaker Execution Role Name: AmazonSageMaker-ExecutionRole-20221206T163436
job_name: fine-tune-reranker-kr-training
instance_type: ml.p3.8xlarge
instance_count: 1
sagemaker_session: <sagemaker.session.Session object at 0x7ff61d43f8e0>
spot_training: False
data_channel: {'train': 's3://sagemaker-us-east-1-419974056037/fine-tune-reranker-kr/dataset/msmarco-triplets-trans-processed-merged.jsonl'}
output_path: s3://sagemaker-us-east-1-419974056037/fine-tune-reranker-kr/training/model-output
code_location: s3://sagemaker-us-east-1-419974056037/fine-tune-reranker-kr/training/backup-codes
use_train_warm_pool: True/3600
s3_chkpt_path: s3://sagemaker-us-east-1-419974056037/fine-tune-reranker-kr/training/checkpoints


### 2.2 Define training job


In [8]:
from sagemaker.huggingface import HuggingFace

In [9]:
# Path of the training file as referenced inside the training container.
# For the small sample file use instead:
#   '/opt/ml/input/data/train/msmarco-triplets-trans-processed-merged-sample.jsonl'
container_train_file = f"/opt/ml/input/data/{channel}/{file_name}"

# Hyperparameters handed to the training entry point.
hyperparameters = {
    # Model and data locations
    "output_dir": "/opt/ml/model",
    "model_name_or_path": "BAAI/bge-reranker-large",
    "train_data": container_train_file,
    # Optimisation settings
    "learning_rate": 5e-6,
    "fp16": True,
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 32,
    "dataloader_drop_last": True,
    "train_group_size": 3,
    "max_len": 512,
    "weight_decay": 0.01,
    # Logging and checkpointing (alternative: "save_strategy": "no")
    "logging_steps": 30,
    "save_steps": 1000,
    "save_total_limit": 1,
}

# Enable torchrun-based distributed training.
distribution = {
    "torch_distributed": {
        "enabled": True,
    },
}


- [SageMaker built-in images](https://github.com/aws/deep-learning-containers/blob/master/available_images.md)

In [10]:
# estimator
huggingface_estimator = HuggingFace(
    entry_point='run.py',
    source_dir='./src/fine-tune/',
    instance_type=instance_type,
    instance_count=instance_count,
    volume_size=500,
    role=role,
    job_name=job_name,
    transformers_version='4.28.1',
    pytorch_version='2.0.0',
    py_version="py310",
    hyperparameters = hyperparameters,
    distribution=distribution,
    sagemaker_session=sagemaker_session,
    keep_alive_period_in_seconds=keep_alive_seconds,
    output_path=output_path,
    code_location=code_location,
    #input_mode='FastFile',
    checkpoint_s3_uri=s3_chkpt_path if instance_type not in ['local', 'local_gpu'] else None,
    checkpoint_local_path='/opt/checkpoints' if instance_type not in ['local', 'local_gpu'] else None,
)

### 2.3 Start Training job
S3에서 훈련 인스턴스로 복사될 데이터를 지정한 후 SageMaker 훈련 job을 시작합니다. 모델 크기, 데이터 세트 크기에 따라서 몇십 분에서 몇 시간까지 소요될 수 있습니다.

In [11]:
# Start the training job on the configured data channel; wait=False is passed
# so the call does not block this cell until the job finishes.
huggingface_estimator.fit(inputs=data_channel, wait=False)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: huggingface-pytorch-training-2024-01-05-01-05-18-679


### 2.4 View job information and logs
훈련 로그는 CloudWatch Logs를 통해서 확인할 수 있습니다. 만약 다른 코드 셀을 실행하고 싶다면 이 코드 셀의 실행을 중단하셔도 됩니다.



In [12]:
import boto3
from IPython.display import display, HTML

In [13]:
def make_console_link(region: str, train_job_name: str, train_task: str = '[Training]') -> tuple:
    """Build HTML snippets linking to a training job's AWS console pages.

    Args:
        region: AWS region name inserted into both console URLs.
        train_job_name: Training job name; used as the job page path and as
            the CloudWatch log-stream prefix.
        train_task: Label shown in front of each link.

    Returns:
        A ``(train_job_link, cloudwatch_link)`` tuple of HTML strings.
    """
    console = "https://console.aws.amazon.com"
    job_url = f"{console}/sagemaker/home?region={region}#/jobs/{train_job_name}"
    logs_url = (
        f"{console}/cloudwatch/home?region={region}"
        f"#logStream:group=/aws/sagemaker/TrainingJobs;prefix={train_job_name}"
        ";streamFilter=typeLogStreamPrefix"
    )

    # "_blank" (with the underscore) is the HTML keyword for opening a new
    # browsing context; plain "blank" is just a window name.
    train_job_link = f'<b> {train_task} Review <a target="_blank" href="{job_url}">Training Job</a></b>'
    cloudwatch_link = f'<b> {train_task} Review <a target="_blank" href="{logs_url}">CloudWatch Logs</a></b>'
    return train_job_link, cloudwatch_link
        
# Region of the default boto3 session and the name of the job just launched.
region = boto3.Session().region_name
train_job_name = huggingface_estimator.latest_training_job.job_name

# Build both console links, then render them in the notebook output.
train_job_link, cloudwatch_link = make_console_link(
    region,
    train_job_name,
    '[Fine-tuning]',
)
for link_html in (train_job_link, cloudwatch_link):
    display(HTML(link_html))

In [None]:
# Print the training job's logs in this cell's output.
huggingface_estimator.logs()

2024-01-05 01:05:20 Starting - Starting the training job......
2024-01-05 01:06:06 Starting - Preparing the instances for training......
2024-01-05 01:07:26 Downloading - Downloading input data..................................................................
2024-01-05 01:18:29 Training - Training image download completed. Training in progress..
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2024-01-05 01:18:30,869 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2024-01-05 01:18:30,906 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2024-01-05 01:18:30,918 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2024-01-05 01:18:30,920 sagemaker_pytorch_container.training INFO     Invoking TorchDistributed...
2024-01-05 01:18:30,920 sagemaker_pytorch



Training seconds: 1408
Billable seconds: 1408
