# Multi-Turn GRPO Training on SageMaker P4d

This notebook launches the multi-turn GRPO training job on SageMaker using P4d instances.

In [None]:
import os

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch

# Resolve the SageMaker session, the execution role, the AWS region and the
# session's default S3 bucket; later cells read these names.
sess = sagemaker.Session()
role = get_execution_role()
region = boto3.Session().region_name
bucket = sess.default_bucket()

# Echo the resolved values so the launch context is visible in the output.
print(
    f"Region: {region}",
    f"Role: {role}",
    f"Bucket: {bucket}",
    sep="\n",
)

In [None]:
from datetime import datetime

# --- Job configuration ---

# Compute: one ml.p4d.24xlarge node (8x A100 GPUs).
instance_count = 1
instance_type = "ml.p4d.24xlarge"

# Training code directory and the recipe file passed to the job.
source_dir = "./sagemaker_code"
config_file = "hf_recipes/Qwen/Qwen3-1.7B--mt-grpo.yaml"

# Timestamped job name and the S3 prefix used as the estimator's output path.
job_name = f"mt-grpo-qwen3-17b-15epochs-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
output_path = f"s3://{bucket}/mt-grpo-output"

In [None]:
# Build the SageMaker PyTorch estimator for the multi-turn GRPO job.
# The job runs `sm_mt_grpo_train.sh` from `source_dir` and receives the recipe
# path through the `config` hyperparameter.
estimator = PyTorch(
    entry_point="sm_mt_grpo_train.sh",
    source_dir=source_dir,
    role=role,
    instance_type=instance_type,
    instance_count=instance_count,
    framework_version="2.5.1",
    py_version="py311",
    output_path=output_path,
    hyperparameters={
        "config": config_file,
    },
    environment={
        "HF_HUB_ENABLE_HF_TRANSFER": "1",
        "NCCL_DEBUG": "INFO",
        # Weights & Biases settings come from the notebook's own environment so
        # the API key never has to be written into the notebook. Variables that
        # are not set fall back to the empty string.
        "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", ""),
        "WANDB_ENTITY": os.environ.get("WANDB_ENTITY", ""),
        "WANDB_PROJECT": os.environ.get("WANDB_PROJECT", ""),
    },
    keep_alive_period_in_seconds=1800,  # 30 minutes
    disable_output_compression=True,
    # 604800 seconds = 7 days. (28 days would be 28 * 24 * 60 * 60 = 2419200.)
    max_run=7 * 24 * 60 * 60,
)

In [None]:
# Submit the training job under the timestamped name and block this cell
# until the job finishes.
estimator.fit(
    wait=True,
    job_name=job_name,
)

In [None]:
# Summarise the finished job: its name, the model artifact location reported
# by the estimator, and the configured S3 output path.
print(
    f"Training job name: {estimator.latest_training_job.name}",
    f"Model artifacts: {estimator.model_data}",
    f"Output path: {output_path}",
    sep="\n",
)