# GRPO Fine-Tuning Llama-3.1 8B with the SageMaker PyTorch Estimator on ml.g6e.48xlarge

## Source Directory structure

```
src/config.yaml - Contains all training arguments, including the hyperparameters required for training.
src/grpo_training.py - Training script (contains the reward and training functions; check the script args in the @dataclass if prompt_length needs to be updated).
src/requirements.txt - Dependencies to be installed before training.
``` 

### Overview of grpo_training.py script
```
class ScriptArguments - Handles script args such as the train_dataset and test_dataset paths, model_id, etc.
def merge_and_save_model - Merges the PEFT adapter into the base model and saves the result after training.
def format_reward - Reward function (1): rewards based on an exact (regex) match between ground truth and model-generated responses.
def accuracy_reward - Reward function (2): rewards based on the semantic similarity between ground truth and model-generated responses, computed with sentence transformers.
def training_function - Training function with the LoRA config, quantization config, GRPO config, and GRPO trainer.
def main - Main function.
```
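
The reward functions themselves live in `src/grpo_training.py` and are not reproduced in this notebook. As a minimal sketch of the general shape such a function can take, the example below assumes TRL-style conversational completions (each completion is a list of message dicts) and assumes the reward checks for the `<think>…</think><answer>…</answer>` layout requested by the system prompt. The name `format_reward_sketch` and the pattern are hypothetical, not the script's actual implementation.

```python
import re

# Hypothetical pattern: reasoning in <think> tags followed by an answer in <answer> tags.
FORMAT_PATTERN = re.compile(r"^<think>.*?</think>\s*<answer>.*?</answer>$", re.DOTALL)


def format_reward_sketch(completions, **kwargs):
    """Return 1.0 for each completion that follows the expected tag layout, else 0.0."""
    rewards = []
    for completion in completions:
        # Assumption: each completion is a list of {"role", "content"} dicts.
        text = completion[0]["content"].strip()
        rewards.append(1.0 if FORMAT_PATTERN.match(text) else 0.0)
    return rewards
```

A reward function of this shape returns one float per completion, which is what a GRPO trainer needs to compare the completions sampled for the same prompt.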


## Install dependencies

In [None]:
!pip install --upgrade --quiet sagemaker transformers datasets "huggingface_hub[cli]"

## Log in to Hugging Face using your token

In [None]:
!huggingface-cli login --token ""

## Import the SageMaker and boto3 modules and define the S3 bucket for input and output data, along with the role and region

In [None]:
import sagemaker
import boto3
sess = sagemaker.Session()
# SageMaker session bucket -> used for uploading data, models, and logs
# SageMaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

source_dir = "./src"
try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']
 
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)
 
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

## Data Preparation

In [None]:
from datasets import load_dataset
dataset_id = "trl-lib/tldr"
train_dataset, test_dataset = load_dataset(dataset_id, split=["train[:5%]", "test[:1%]"])


## Transform the data into the chat template with a system prompt

In [None]:
# Adjacent string literals inside parentheses are concatenated into one prompt,
# so no stray quote characters, newlines, or indentation end up in the text.
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

def make_conversation(data):
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['prompt']}
        ],
        'completion': [{'role': 'assistant', 'content': x['completion']}]
    })
    return data


train_dataset = make_conversation(train_dataset)
test_dataset = make_conversation(test_dataset)
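
To make the target record shape concrete, here is a minimal sketch of the same mapping applied to a single plain dict, so it can be inspected without loading the dataset. The helper name `to_conversation` and the sample row are hypothetical and used for illustration only.

```python
def to_conversation(example, system_prompt):
    """Build the chat-style record for one raw example (plain dicts only)."""
    return {
        "prompt": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": example["prompt"]},
        ],
        "completion": [{"role": "assistant", "content": example["completion"]}],
    }


# Hypothetical sample row, for illustration only.
record = to_conversation(
    {"prompt": "POST: I brushed with gritty baking soda.\n\nTL;DR:", "completion": " It hurt my gums."},
    system_prompt="You are a helpful assistant.",
)
print([message["role"] for message in record["prompt"]])  # ['system', 'user']
```

After the mapping, both `prompt` and `completion` are lists of role/content messages rather than plain strings.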

In [None]:
train_dataset.to_json("train_dataset.json")
test_dataset.to_json("test_dataset.json")
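
`Dataset.to_json` writes JSON Lines by default (one JSON object per line), so the exported files can be spot-checked with the standard library before uploading. The sketch below is one way to do that; the helper names `read_first_record` and `has_chat_columns` are hypothetical.

```python
import json


def read_first_record(path):
    """Return the first record of a JSON Lines file (one JSON object per line)."""
    with open(path, encoding="utf-8") as f:
        return json.loads(f.readline())


def has_chat_columns(record):
    """Check that a record carries the 'prompt' and 'completion' message lists."""
    return all(isinstance(record.get(key), list) for key in ("prompt", "completion"))
```

For example, `has_chat_columns(read_first_record("train_dataset.json"))` should return `True` once the cell above has run.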


### Upload the train/test dataset to S3 bucket

In [None]:
# save train_dataset to s3 using our SageMaker session
input_path = f's3://{sagemaker_session_bucket}/datasets/llama3'
 
from sagemaker.s3 import S3Uploader
train_dataset_s3_path = S3Uploader.upload(local_path="./train_dataset.json", desired_s3_uri=f"{input_path}/train_v3")
test_dataset_s3_path = S3Uploader.upload(local_path="./test_dataset.json", desired_s3_uri=f"{input_path}/test_v3")

print("Training data uploaded to:")
print(train_dataset_s3_path)
print(test_dataset_s3_path)

### Upload config.yaml from source_dir to S3

In [None]:
from sagemaker.s3 import S3Uploader
 
# upload the model yaml file to s3
model_yaml = "{}/config.yaml".format(source_dir)
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")
 
print("Training config uploaded to:")
print(train_config_s3_path)

## Training with the PyTorch estimator and a DLC (Deep Learning Container) image

In [None]:
from sagemaker.pytorch import PyTorch  # the job below uses the PyTorch estimator with a PyTorch DLC image
from huggingface_hub import HfFolder

train_dlc_image = "763104351884.dkr.ecr.{}.amazonaws.com/pytorch-training:2.7.1-gpu-py312-cu128-ubuntu22.04-sagemaker".format(sess.boto_region_name)
# define the training job name
job_name = 'llama3-1-8b-grpo'


# create the Estimator 
pytorch_estimator = PyTorch(
    entry_point          = 'grpo_training.py',  # training script (src/grpo_training.py)
    source_dir           = source_dir,  # directory which includes all the files needed for training
    instance_type        = 'ml.g6e.48xlarge',  # instances type used for the training job
    instance_count       = 1,                 # the number of instances used for training
    base_job_name        = job_name,          # prefix for the training job name
    role                 = role,              # IAM role used by the training job to access AWS resources, e.g. S3
    volume_size          = 500,               # the size of the EBS volume in GB
    py_version           = 'py312',           # the python version used in the training job
    image_uri            = train_dlc_image,
    hyperparameters      =  {
        "config": "/opt/ml/input/data/config/config.yaml" # path to TRL config which was uploaded to s3
    },
    # distribution={"torch_distributed": {"enabled": True}},   # enables torchrun
    keep_alive_period_in_seconds = 1800,      # keep the instance in a warm pool for 30 minutes
    disable_output_compression = True,        # skip output compression to save training time and cost
    environment  = {
        "HUGGINGFACE_HUB_CACHE": "/tmp/.cache", # set env variable to cache models in /tmp
        "HF_TOKEN": HfFolder.get_token(),       # huggingface token to access gated models, e.g. llama 3
        "ACCELERATE_USE_FSDP": "1",             # enable FSDP
        "FSDP_CPU_RAM_EFFICIENT_LOADING": "1"   # enable CPU RAM efficient loading
    }, 
)
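
SageMaker passes each entry in `hyperparameters` to the entry point as a command-line argument, so the script receives `--config /opt/ml/input/data/config/config.yaml`. The sketch below shows one way a script could pick that up with `argparse`. It is a stand-in under stated assumptions: the actual `grpo_training.py` may parse its arguments differently (for example with TRL's own argument parser), and `parse_args` is a hypothetical helper.

```python
import argparse


def parse_args(argv=None):
    """Parse the --config hyperparameter that SageMaker forwards to the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="Path to the YAML config inside the container")
    # Ignore any other arguments the platform may add.
    args, _unknown = parser.parse_known_args(argv)
    return args


args = parse_args(["--config", "/opt/ml/input/data/config/config.yaml"])
print(args.config)
```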

In [None]:
# define a data input dictionary with our uploaded S3 URIs
data = {
  'train': train_dataset_s3_path,
  'test': test_dataset_s3_path,
  'config': train_config_s3_path
  }
 
# starting the train job with our uploaded datasets as input
pytorch_estimator.fit(data, wait=True)
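
Each key of the `data` dictionary becomes an input channel, and SageMaker makes a channel's contents available under `/opt/ml/input/data/<channel>` inside the training container. That is why the `config` hyperparameter points at `/opt/ml/input/data/config/config.yaml`. A minimal sketch of the mapping, using a hypothetical helper `channel_path`:

```python
import posixpath

# SageMaker mounts each input channel under this directory inside the training container.
CHANNEL_ROOT = "/opt/ml/input/data"


def channel_path(channel, filename=""):
    """Return the in-container path for a channel, optionally joined with a file name."""
    base = posixpath.join(CHANNEL_ROOT, channel)
    return posixpath.join(base, filename) if filename else base


for name in ("train", "test", "config"):
    print(name, "->", channel_path(name))
print(channel_path("config", "config.yaml"))
```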

## Deploy the fine-tuned model to a SageMaker endpoint

In [None]:
pytorch_estimator.model_data

In [None]:
from sagemaker.huggingface import get_huggingface_llm_image_uri
 
# retrieve the llm image uri
hf_image = get_huggingface_llm_image_uri(
  "huggingface",
  session=sess,
)
# print ecr image uri
print(f"llm image uri: {hf_image}")

In [None]:
from huggingface_hub import HfFolder
from sagemaker.huggingface import HuggingFaceModel
 
# sagemaker config
instance_type = "ml.g5.12xlarge"
health_check_timeout = 1200 # 20 minutes
 
# Define Model and Endpoint configuration parameter
config = {
  'HF_MODEL_ID': "/opt/ml/model",       # Path to the model in the container
  'SM_NUM_GPUS': "4",                   # Number of GPUs used per replica
  'MAX_INPUT_LENGTH': "1024",           # Max length of the input text, in tokens
  'MAX_TOTAL_TOKENS': "2048",           # Max total tokens per request (input plus generated)
  'MAX_BATCH_PREFILL_TOKENS': "4096",   # Limits the number of tokens processed in parallel during prefill
  'MESSAGES_API_ENABLED': "true",       # Enable the OpenAI-compatible Messages API
}
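
The token limits above interact: the input limit should stay below the total limit, the prefill budget should be at least as large as one full input, and `SM_NUM_GPUS` should not exceed the GPUs on the instance (`ml.g5.12xlarge` has 4). The sketch below checks those relationships before deploying; `check_tgi_limits` is a hypothetical helper, not part of the SageMaker SDK.

```python
def check_tgi_limits(env, gpus_on_instance):
    """Return a list of problems found in the token limits and GPU count."""
    max_input = int(env["MAX_INPUT_LENGTH"])
    max_total = int(env["MAX_TOTAL_TOKENS"])
    max_prefill = int(env["MAX_BATCH_PREFILL_TOKENS"])
    num_gpus = int(env["SM_NUM_GPUS"])

    problems = []
    if max_input >= max_total:
        problems.append("MAX_INPUT_LENGTH must be smaller than MAX_TOTAL_TOKENS")
    if max_prefill < max_input:
        problems.append("MAX_BATCH_PREFILL_TOKENS must be at least MAX_INPUT_LENGTH")
    if num_gpus > gpus_on_instance:
        problems.append("SM_NUM_GPUS exceeds the GPUs available on the instance")
    return problems


sample_env = {
    "SM_NUM_GPUS": "4",
    "MAX_INPUT_LENGTH": "1024",
    "MAX_TOTAL_TOKENS": "2048",
    "MAX_BATCH_PREFILL_TOKENS": "4096",
}
print(check_tgi_limits(sample_env, gpus_on_instance=4))  # []
```

With these values, a request that uses the full 1024-token input still leaves 1024 tokens for generation, which covers the `max_new_tokens` of 512 used in the test request.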
 
# create HuggingFaceModel with the image uri
grpo_model = HuggingFaceModel(
  role=role,
  # path to s3 bucket with model, we are not using a compressed model
  # {'S3DataSource':{'S3Uri': "s3://...",'S3DataType': 'S3Prefix','CompressionType': 'None'}},
  model_data=pytorch_estimator.model_data,
  image_uri=hf_image,
  env=config
)

In [None]:
# Deploy model to an endpoint
reasoning_model = grpo_model.deploy(
  initial_instance_count=1,
  instance_type=instance_type,
  container_startup_health_check_timeout=health_check_timeout, # 20 minutes to give SageMaker the time to download and merge model
)

### Test Inference

In [None]:
def inference_request(messages):
    # Note: `tokenizer` is loaded in the next cell; run that cell before calling this function.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    outputs = reasoning_model.predict({
      "inputs": prompt,
      "parameters": {
        "max_new_tokens": 512,
        "do_sample": False,
      }
    })
    return {"role": "assistant", "content": outputs[0]["generated_text"].strip()}

In [None]:
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

prompt = "SUBREDDIT: r/tifu\n\nTITLE: TIFU bY brushing with Baking Soda without learning how to do it correctly.\n\nPOST: Always wanted White Teeth but never visited the dentist since I was 8 due to fear [gotten bad experience as a kid].        \n\nSo I heard that baking soda makes your teeth white if you brush your teeth with it.        \nWhat I didn't get from all the reading, is that though it is supposed to be made into a paste, it shouldn't still be gritty.       \n\nI always kept my baking soda paste gritty by putting very little water.        \n\nAfter brushing straight with it for three months, my gum was extremely sore, but on the up side is, it is true, it is all true, I am amazed myself ! My teeth is very VERY white now compared to the past and even when taking pictures, the teeth becomes the center of attention simply because of how white it is, even my friends jokingly asked if I have painted it white.       \nThese are the images after baking soda brushing for months, understand that I have NEVER visited a dentist ever since I was 8:    \n   \n\nAs my ego grew, I forget about the irritation from the gum and keep on using it.      \nOne fine day, my gum gave up...I was brushing and I saw a nice chunk of my gum get physically brushed OUT of my teeth, I was shocked and at a lost of what I should do...I tried to piece the gum back in hoping that it would stay, suffices to say by the very next day, the gum eventually fall off.       \n\nIt is not that visible if I don't smile too big, but let this be a lesson to all of you out there, baking soda paste works, BUT PLEASE, make sure the paste is not gritty, PLEASE...don't experience this ever.\n\nTL;DR:"
messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": prompt},
]


In [None]:
inference_request(messages)
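
Because the system prompt asks for `<think>` and `<answer>` tags, the returned text can be split into its reasoning and its final summary. The sketch below is one way to do that with a regular expression; `split_reasoning` is a hypothetical helper and the sample string is made up for illustration, not real model output.

```python
import re


def split_reasoning(text):
    """Extract the <think> and <answer> sections; missing sections come back as None."""
    think = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    return {
        "think": think.group(1).strip() if think else None,
        "answer": answer.group(1).strip() if answer else None,
    }


# Made-up response text, for illustration only.
sample = "<think> The post is about gritty paste. </think><answer> Gritty baking soda paste damaged my gums. </answer>"
print(split_reasoning(sample)["answer"])
```

In the notebook, this could be applied as `split_reasoning(inference_request(messages)["content"])`. A model that ignores the requested format yields `None` for the missing sections.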

## Delete endpoint and model

In [None]:
reasoning_model.delete_model()
reasoning_model.delete_endpoint(delete_endpoint_config=True)