## Hugging Face GRPOTrainer with Accelerate Multi-GPU Training on Amazon SageMaker

In [None]:
import os
from dotenv import load_dotenv
import boto3
from huggingface_hub import HfFolder
from sagemaker.modules import Session
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.train.model_trainer import Mode
from sagemaker.modules.configs import (
    Compute,
    SourceCode,
    InputData,
    CheckpointConfig,
)


# Pull variables from a local .env file into os.environ before anything reads
# them (the W&B and/or Hugging Face tokens are expected to live there).
load_dotenv()

# Region used for the boto3 session and, further down, for the ECR image URI.
region = "eu-central-1"
boto_session = boto3.Session(region_name=region)

In [None]:
# PyTorch training image, resolved against the ECR registry of the chosen region.
pytorch_image = (
    f"763104351884.dkr.ecr.{region}.amazonaws.com/"
    "pytorch-training:2.6.0-gpu-py312-cu126-ubuntu22.04-sagemaker"
)

# Define what runs inside the container. An explicit `command` is used instead
# of `entry_script` so the job is started through `accelerate launch` with the
# config file shipped in `source_dir`. The command refers to the files under
# /opt/ml/input/data/code, where the contents of `source_dir` are expected to
# be available inside the container.
source_code = SourceCode(
    source_dir="scripts/",
    requirements="requirements.txt",
    command="accelerate launch --config_file /opt/ml/input/data/code/default_config.yaml /opt/ml/input/data/code/grpo.py",
)

# Credentials forwarded to the training container as environment variables.
# Both lookups yield None when the credential is not configured locally, and a
# None value is not a valid environment variable value for the job, so unset
# entries are left out instead of being passed through.
environment = {
    name: value
    for name, value in (
        ("HF_TOKEN", HfFolder.get_token()),
        ("WANDB_API_KEY", os.environ.get("WANDB_API_KEY")),
    )
    if value is not None
}

# Single-instance job; multi-GPU parallelism is handled by accelerate on the
# instance itself.
compute = Compute(
    instance_count=1,
    instance_type="ml.p4d.24xlarge",
    # volume_size_in_gb=96,
    # Keep the provisioned instance around for an hour after the job ends.
    keep_alive_period_in_seconds=3600,
)

In [None]:
# Prefix shared by the training job name and the checkpoint location.
base_job_name = "grpo-trl"

# SageMaker session built on the region-pinned boto3 session; its default
# bucket hosts the checkpoints.
sess = Session(boto_session=boto_session)
bucket = sess.default_bucket()

# Checkpoints are written under the job-name prefix of the default bucket.
checkpoint_path = f"s3://{bucket}/{base_job_name}/checkpoints/"
checkpoint_config = CheckpointConfig(s3_uri=checkpoint_path)

# Assemble the trainer from the image, code, compute and environment settings
# prepared in the previous cell.
model_trainer = ModelTrainer(
    training_image=pytorch_image,
    source_code=source_code,
    compute=compute,
    environment=environment,
    base_job_name=base_job_name,
    checkpoint_config=checkpoint_config,
    sagemaker_session=sess,
)

In [None]:
# Start the training job. `wait=False` is passed so the call is not expected to
# block this notebook until the job finishes.
model_trainer.train(wait=False)

In [None]:
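# Alternative: launch a script that takes command-line arguments.
# Uncomment to use it in place of the `source_code` defined above.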

# source_code = SourceCode(
#     source_dir="scripts/",
#     requirements="requirements.txt",
#     # entry_script="grpo_advanced.py",
#     command="accelerate launch --config_file /opt/ml/input/data/code/default_config.yaml /opt/ml/input/data/code/grpo_advanced.py \
#             --dataset_name trl-lib/tldr \
#             --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
#             --reward_model_name_or_path Qwen/Qwen2-72B-Instruct-AWQ \
#             --output_dir Qwen2-0.5B-GRPO-ADV \
#             --report_to wandb",
# )