## RLAIF Example - Fine-tuning with SageMaker

This notebook demonstrates the basic user flow for RLAIF fine-tuning, starting from a model available in SageMaker JumpStart.
For information on the models available in JumpStart, see https://docs.aws.amazon.com/sagemaker/latest/dg/jumpstart-foundation-models-latest.html

### Setup and Configuration

Initialize the environment by importing the necessary libraries and configuring AWS credentials.

In [None]:
# Configure AWS credentials and region
#! ada credentials update --provider=isengard --account=<> --role=Admin --profile=default --once
#! aws configure set region us-west-2

In [None]:
#!/usr/bin/env python3

from sagemaker.train.rlaif_trainer import RLAIFTrainer
from sagemaker.train.configs import InputData
from rich import print as rprint
from rich.pretty import pprint
from sagemaker.core.resources import ModelPackage
import os
#os.environ['SAGEMAKER_REGION'] = 'us-east-1'
#os.environ['SAGEMAKER_STAGE'] = 'prod'

import boto3
from sagemaker.core.helper.session_helper import Session

# To get MLflow native metrics while the trainer waits, run the line below with the appropriate region
os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = "https://mlflow.sagemaker.us-west-2.app.aws"
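
The endpoint above has the region embedded in its URL. The sketch below builds the same URL for another region. It assumes the URL pattern shown for `us-west-2` carries over unchanged to other regions, which this notebook does not confirm, and `mlflow_endpoint_for_region` is a hypothetical helper, not part of the SDK.

```python
import os
import re


def mlflow_endpoint_for_region(region: str) -> str:
    """Hypothetical helper: build the MLflow endpoint URL for a region.

    Assumption: the URL pattern used for us-west-2 in this notebook
    applies to other regions as well.
    """
    # Accept region names shaped like us-west-2 or us-gov-west-1.
    if not re.fullmatch(r"[a-z]{2}(-[a-z]+)+-\d+", region):
        raise ValueError(f"Unexpected region format: {region!r}")
    return f"https://mlflow.sagemaker.{region}.app.aws"


# Set the same environment variable the cell above sets by hand.
os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = mlflow_endpoint_for_region("us-west-2")
```

Validating the region string first turns a typo into an immediate `ValueError` rather than a connection error later in the run.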



### Create RLAIFTrainer
**Required Parameters**

* `model`: ID of a base model in SageMaker hub content that is available for fine-tuning, or a `ModelPackage` artifact

**Optional Parameters**
* `reward_model_id`: Bedrock model ID. For the supported evaluation models, see https://docs.aws.amazon.com/bedrock/latest/userguide/evaluation-judge.html
* `reward_prompt`: Reward prompt ARN or a built-in prompt. For the built-in prompts, see https://docs.aws.amazon.com/bedrock/latest/userguide/model-evaluation-metrics.html
* `model_package_group_name`: Model package group name, or a `ModelPackageGroup`
* `mlflow_resource_arn`: ARN of the MLflow app used to track the training job
* `mlflow_experiment_name`: MLflow experiment name (`str`)
* `mlflow_run_name`: MLflow run name (`str`)
* `training_dataset`: Training dataset, given as either a dataset ARN or the S3 path of the dataset. A training dataset is required for the job to run; provide it either to the trainer or to `.train()`
* `validation_dataset`: Validation dataset, given as either a dataset ARN or the S3 path of the dataset
* `s3_output_path`: S3 path for the trained model artifacts
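
Several of the parameters above accept an S3 path. As a minimal sketch, a path can be checked locally before it is handed to the trainer. `split_s3_uri` is a hypothetical helper written for this example; it only inspects the string and does not contact S3.

```python
from typing import Tuple
from urllib.parse import urlparse


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Hypothetical helper: split an s3:// URI into (bucket, key)."""
    parsed = urlparse(uri)
    # A dataset ARN or a local path fails this check on purpose.
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


bucket, key = split_s3_uri("s3://mc-flows-sdk-testing/output/")
print(bucket, key)
```

Because a dataset ARN is rejected by this check, it is only appropriate for values that are meant to be S3 paths, such as `s3_output_path`.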

In [None]:
# For fine-tuning 
rlaif_trainer = RLAIFTrainer(
    model="meta-textgeneration-llama-3-2-1b-instruct", # Union[str, ModelPackage] 
    model_package_group_name="sdk-test-finetuned-models", # Make it Optional
    reward_model_id='anthropic.claude-3-5-sonnet-20240620-v1:0',
    reward_prompt='Builtin.Correctness',
    #mlflow_resource_arn="arn:aws:sagemaker:us-west-2:<>:mlflow-tracking-server/mmlu-eval-experiment",  # Optional[str] - MLflow app ARN (auto-resolved if not provided), can accept name and search in the account
    mlflow_experiment_name="test-rlaif-finetuned-models-exp", # Optional[str]
    mlflow_run_name="test-rlaif-finetuned-models-run", # Optional[str]
    training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", # Optional[str]
    s3_output_path="s3://mc-flows-sdk-testing/output/",
    accept_eula=True
    #sagemaker_session=sagemaker_session,
    #role="arn:aws:iam::<>:role/Admin"
)
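
The training dataset above is a `.jsonl` file, so a quick local check for malformed lines can save a failed job. The sketch below is a generic JSON Lines sanity check. It assumes each record should be a JSON object and does not assume any particular field names, since the schema RLAIF expects is not described here. `find_bad_jsonl_lines` is a hypothetical helper.

```python
import json
from typing import Iterable, List, Tuple


def find_bad_jsonl_lines(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Hypothetical helper: return (line_number, reason) for each bad line."""
    problems = []
    for number, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((number, f"invalid JSON: {exc.msg}"))
            continue
        # Assumption: every record is a JSON object, not a list or scalar.
        if not isinstance(record, dict):
            problems.append((number, "record is not a JSON object"))
    return problems


sample = ['{"a": 1}', "not json", "[1, 2]", ""]
for number, reason in find_bad_jsonl_lines(sample):
    print(f"line {number}: {reason}")
```

For a real file, pass an open file handle (for example `open("train.jsonl")`) in place of `sample`, after downloading the dataset locally.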

### Discover and update fine-tuning options

Each technique and model has overridable hyperparameters that the user can tune.

In [None]:
print("Default Finetuning Options:")
pprint(rlaif_trainer.hyperparameters.to_dict())

# Inspect the available options before overriding them
rlaif_trainer.hyperparameters.get_info()
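
One way to reason about overrides is to merge them into the dictionary returned by `to_dict()` and reject names that are not in the defaults. The sketch below works on plain dictionaries only. The keys `learning_rate` and `epochs` are placeholders chosen for illustration, not confirmed hyperparameter names, and how the merged values are written back to the trainer depends on the SDK and is not shown.

```python
from typing import Any, Dict


def apply_overrides(defaults: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: merge overrides into a copy of the defaults."""
    # Reject names that are not known hyperparameters, to catch typos early.
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise KeyError(f"Unknown hyperparameters: {unknown}")
    merged = dict(defaults)  # leave the caller's defaults untouched
    merged.update(overrides)
    return merged


# Placeholder values; in the notebook this would come from to_dict().
defaults = {"learning_rate": 1e-5, "epochs": 1}
print(apply_overrides(defaults, {"epochs": 2}))
```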


### Start RLAIF training


In [None]:

training_job = rlaif_trainer.train(wait=True)



In [None]:
import re
from sagemaker.core.utils.utils import Unassigned
import json


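def pretty_print(obj):
    """Print obj as indented JSON, dropping Unassigned placeholders and None values."""
    def parse_unassigned(item):
        if isinstance(item, Unassigned):
            return None
        if isinstance(item, dict):
            # Clean each value once, then drop entries that resolved to None
            values = {k: parse_unassigned(v) for k, v in item.items()}
            return {k: v for k, v in values.items() if v is not None}
        if isinstance(item, list):
            values = [parse_unassigned(x) for x in item]
            return [x for x in values if x is not None]
        if isinstance(item, str) and "Unassigned object" in item:
            # Recover key=value pairs from a repr string, skipping values
            # that start with "<" (such as the Unassigned placeholders)
            pairs = re.findall(r"(\w+)=([^<][^=]*?)(?=\s+\w+=|$)", item)
            result = {k: v.strip("'\"") for k, v in pairs}
            return result if result else None
        return item

    cleaned = parse_unassigned(obj.__dict__ if hasattr(obj, '__dict__') else obj)
    # default=str serializes values json cannot handle natively, such as datetimes
    print(json.dumps(cleaned, indent=2, default=str))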

pretty_print(training_job)

### View training job details

We can retrieve the details and status of any training job with `TrainingJob.get(...)`.

In [None]:
from sagemaker.core.resources import TrainingJob

response = TrainingJob.get(training_job_name="meta-textgeneration-llama-3-2-1b-instruct-rlaif-20251124140754")
pretty_print(response)
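
When a job was started without `wait=True`, its status has to be polled. The sketch below is a generic polling loop that stops on one of the terminal `TrainingJobStatus` values SageMaker reports (`Completed`, `Failed`, `Stopped`). `wait_for_terminal_status` is a hypothetical helper, and `get_status` is a caller-supplied function; how the status is read from the `TrainingJob` object is not shown in this notebook, so that part is left to the caller.

```python
import time
from typing import Callable

# Terminal values of TrainingJobStatus; InProgress and Stopping are transient.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_terminal_status(
    get_status: Callable[[], str],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> str:
    """Hypothetical helper: poll get_status() until the job finishes."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = get_status()
        if status in TERMINAL_STATUSES:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Job still {status!r} after {timeout_seconds} seconds")
        sleep(poll_seconds)


# Simulated status sequence standing in for real API calls.
fake_statuses = iter(["InProgress", "InProgress", "Completed"])
print(wait_for_terminal_status(lambda: next(fake_statuses), poll_seconds=0))
```

Passing `sleep` as a parameter keeps the loop easy to test without real delays.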

### Test RLAIF with Custom Reward Prompt

Here we provide a user-defined reward prompt (evaluator) ARN instead of a built-in prompt.

In [None]:


# For fine-tuning 
rlaif_trainer = RLAIFTrainer(
    model="meta-textgeneration-llama-3-2-1b-instruct", # Union[str, ModelPackage] 
    model_package_group_name="sdk-test-finetuned-models", # Make it Optional
    reward_model_id='anthropic.claude-3-5-sonnet-20240620-v1:0',
    reward_prompt="arn:aws:sagemaker:us-west-2:<>:hub-content/sdktest/JsonDoc/rlaif-test-prompt/0.0.1",
    #mlflow_resource_arn="arn:aws:sagemaker:us-west-2:<>:mlflow-tracking-server/mmlu-eval-experiment",  # Optional[str] - MLflow app ARN (auto-resolved if not provided), can accept name and search in the account
    mlflow_experiment_name="test-rlaif-finetuned-models-exp", # Optional[str]
    mlflow_run_name="test-rlaif-finetuned-models-run", # Optional[str]
    training_dataset="s3://mc-flows-sdk-testing/input_data/rlvr-rlaif-test-data/train_285.jsonl", #Optional[str]
    s3_output_path="s3://mc-flows-sdk-testing/output/",
    accept_eula=True
    #sagemaker_session=sagemaker_session,
    #role="arn:aws:iam::<>:role/service-role/AmazonSageMaker-ExecutionRole-20250731T162975"
    #role="arn:aws:iam::<>:role/Admin"
)
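
The two trainer configurations in this notebook pass `reward_prompt` in two forms: a built-in name (`Builtin.Correctness`) and an ARN. The sketch below tells the two apart and splits an ARN into its standard colon-separated fields. `describe_reward_prompt` is a hypothetical helper, and treating every built-in name as starting with `Builtin.` is an assumption drawn from the single example above.

```python
from typing import Dict


def describe_reward_prompt(reward_prompt: str) -> Dict[str, str]:
    """Hypothetical helper: classify a reward_prompt value."""
    # Assumption: built-in prompts are written as "Builtin.<Name>".
    if reward_prompt.startswith("Builtin."):
        return {"kind": "builtin", "name": reward_prompt.split(".", 1)[1]}
    # General ARN layout: arn:partition:service:region:account-id:resource
    parts = reward_prompt.split(":", 5)
    if len(parts) == 6 and parts[0] == "arn":
        return {
            "kind": "arn",
            "service": parts[2],
            "region": parts[3],
            "resource": parts[5],
        }
    raise ValueError(f"Unrecognized reward prompt: {reward_prompt!r}")


print(describe_reward_prompt("Builtin.Correctness"))
```

Checking the region field of a custom prompt ARN against the region of the training job is one simple use of the parsed result.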


In [None]:
training_job = rlaif_trainer.train(wait=True,
                                   logs=True)

In [None]:
training_job = rlaif_trainer.train(wait=True)

In [None]:
pretty_print(training_job)