In [None]:
# notebook.ipynb
import sagemaker
import os

from sagemaker.pytorch.estimator import PyTorch
from time import gmtime, strftime
# Unique training-job name: a fixed prefix followed by the current UTC time
# formatted as YYYY-MM-DD-HH-MM-SS.
job_id = f"llama3-chinese-full-{strftime('%Y-%m-%d-%H-%M-%S', gmtime())}"

# Execution role used for training; make sure it is allowed to write to the
# S3 bucket configured below.
sm_role = sagemaker.get_execution_role()
print(sm_role)

# S3 bucket where the training results are stored.
bucket = "YOUR_BUCKET"

# For PEFT training, the ml.g5 family is worth considering, together with
# use_spot_instances=True.
instance_type = 'ml.p4d.24xlarge'
use_spot_instances = False

# Maximum run time of the training job in seconds (432000 s = 5 days);
# adjust as appropriate.
max_run = 432000

# max_wait only applies to managed spot training. It bounds the time spent
# waiting for capacity plus the training time, so it must not be smaller than
# max_run; a fixed 1200 would be rejected as soon as use_spot_instances is
# turned on. Allow 1200 s of waiting on top of max_run instead.
max_wait = (max_run + 1200) if use_spot_instances else None

# Set keep_alive_period_in_seconds to keep the instances around after the job
# finishes so that the next training job starts faster; None leaves it unset.
keep_alive_period_in_seconds = None

# Build the S3 output locations for this job. S3 URIs always use "/" as the
# separator, so they are formatted directly instead of going through
# os.path.join, which follows the host OS convention ("\" on Windows).
output_uri = f"s3://{bucket}/{job_id}/output"
checkpoint_uri = f"s3://{bucket}/{job_id}/checkpoints"

# Regex for one numeric value as printed in the training log dict. The first
# (outer) capture group is the metric value. It accepts:
#   - an optional leading minus sign (e.g. -123.45),
#   - an optional fractional part with a literal dot (e.g. 0.6931, or plain 1),
#   - an optional exponent (e.g. 1e-05, 1.9e-05, 1e+16).
_METRIC_VALUE_REGEX = r"(-?[0-9]+(\.[0-9]+)?([eE][-+]?[0-9]+)?)"

# Keys to extract from log lines of the form {'loss': 0.69, 'epoch': 1.0, ...}.
_METRIC_NAMES = [
    'loss',
    'grad_norm',
    'learning_rate',
    'rewards/chosen',
    'rewards/rejected',
    'rewards/accuracies',
    'rewards/margins',
    'logps/rejected',
    'logps/chosen',
    'logits/rejected',
    'logits/chosen',
    'sft_loss',
    'odds_ratio_loss',
    'epoch',
]

# One {'Name', 'Regex'} entry per metric; each regex matches "'<name>': <value>"
# followed by an optional comma.
metric_definitions = [
    {'Name': name, 'Regex': f"'{name}': {_METRIC_VALUE_REGEX},?"}
    for name in _METRIC_NAMES
]

# Environment variables handed to the estimator through its `environment`
# argument.
environment = {
    'NCCL_DEBUG': 'INFO',
    'FI_PROVIDER': 'efa',
    'NCCL_PROTO': 'simple',
    'FI_EFA_USE_DEVICE_RDMA': '1',
}

# Configure the training job. Note that, despite its name, this variable holds
# a sagemaker PyTorch estimator.
huggingface_estimator = PyTorch(
    # What to run: the train.sh entry point from ./src, in the given image.
    image_uri='YOUR_IMAGE_URI',
    source_dir='./src',
    entry_point='train.sh',
    # Who runs it and on what hardware.
    role=sm_role,
    instance_type=instance_type,
    instance_count=1,
    # Spot usage, time limits and keep-alive period.
    use_spot_instances=use_spot_instances,
    max_run=max_run,
    max_wait=max_wait,
    keep_alive_period_in_seconds=keep_alive_period_in_seconds,
    # S3 locations for outputs and checkpoints.
    output_path=output_uri,
    checkpoint_s3_uri=checkpoint_uri,
    # Metric regexes and container environment variables.
    metric_definitions=metric_definitions,
    environment=environment,
)

# Start the training job under the name held in job_id.
huggingface_estimator.fit(job_name=job_id)
