# Direct Preference Optimization (DPO) Training with SageMaker

This notebook demonstrates how to use the **DPOTrainer** to fine-tune large language models using Direct Preference Optimization (DPO). DPO is a technique that trains models to align with human preferences by learning from preference data without requiring a separate reward model.

## What is DPO?

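Direct Preference Optimization (DPO) is a method for training language models to follow human preferences. Each training example is a preference pair: a prompt together with a preferred ("chosen") and a dispreferred ("rejected") response. Traditional RLHF (Reinforcement Learning from Human Feedback) first trains a separate reward model on such pairs and then optimizes the language model against it with reinforcement learning. DPO skips the reward model and optimizes the language model directly on the preference pairs.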
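
For reference, the objective from the original DPO paper (Rafailov et al., 2023) is shown below in the paper's notation; it is not a description of how `DPOTrainer` implements the loss internally. Here $\pi_\theta$ is the model being trained, $\pi_{\text{ref}}$ is a frozen reference model (typically the starting checkpoint), $(x, y_w, y_l)$ is a prompt with its preferred and rejected responses drawn from the preference dataset $\mathcal{D}$, $\sigma$ is the logistic function, and $\beta$ controls how far the trained model may drift from the reference.

$$
\mathcal{L}_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]
$$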

**Key Benefits:**
- Simpler than RLHF: no separate reward model to train
- More stable training, since DPO uses a classification-style loss instead of a reinforcement learning loop
- Direct optimization on preference data
- Works with LoRA for parameter-efficient fine-tuning
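
The DPO loss for a single preference pair can be computed in a few lines of plain Python. This is a minimal sketch of the objective from the DPO paper, not the code `DPOTrainer` runs; the names `dpo_pair_loss` and `log_sigmoid` and the default `beta = 0.1` are illustrative assumptions. Each log-probability argument is the summed token log-probability of a full response under either the model being trained or the frozen reference model.

```python
import math


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigmoid(z))."""
    # Branch so that exp() is only ever called on a non-positive argument.
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_pair_loss(
    policy_logp_chosen: float,
    policy_logp_rejected: float,
    ref_logp_chosen: float,
    ref_logp_rejected: float,
    beta: float = 0.1,  # illustrative default, not taken from DPOTrainer
) -> float:
    """DPO loss for one (prompt, chosen, rejected) triple."""
    # How much more (or less) likely each response became relative to the reference.
    chosen_logratio = policy_logp_chosen - ref_logp_chosen
    rejected_logratio = policy_logp_rejected - ref_logp_rejected
    margin = beta * (chosen_logratio - rejected_logratio)
    return -log_sigmoid(margin)


# Model identical to the reference: margin is 0, so the loss is log(2).
print(dpo_pair_loss(-12.0, -15.0, -12.0, -15.0))
# Model has shifted probability toward the chosen response: the loss drops.
print(dpo_pair_loss(-10.0, -17.0, -12.0, -15.0))
```

The loss only falls when the chosen response gains log-probability relative to the rejected one, measured against the reference model, which is why DPO needs no separate reward model.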

## Workflow Overview

1. **Prepare Preference Dataset**: Upload preference data in JSONL format
2. **Register Dataset**: Create a SageMaker AI Registry dataset
3. **Configure DPO Trainer**: Set up model, training parameters, and resources
4. **Execute Training**: Run the DPO fine-tuning job
5. **Track Results**: Monitor training with MLflow integration
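
Step 1 asks for preference data in JSONL format, meaning one JSON object per line. The sketch below writes and sanity-checks such a file locally before it is uploaded. The record schema is an assumption: the field names `prompt`, `chosen`, and `rejected` follow a convention common in open-source DPO tooling (for example Hugging Face TRL) and may not match what your SageMaker recipe expects. The file name `dpo_train.jsonl` and the helpers `write_jsonl` and `validate_jsonl` are hypothetical.

```python
import json
from pathlib import Path

# Assumed field names; confirm the schema your SageMaker recipe expects.
REQUIRED_KEYS = ("prompt", "chosen", "rejected")

records = [
    {
        "prompt": "How do I reset my password?",
        "chosen": "Open Settings, choose Security, then select Reset password.",
        "rejected": "I don't know.",
    },
]


def write_jsonl(path: Path, rows: list) -> None:
    """Write one JSON object per line (JSON Lines)."""
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def validate_jsonl(path: Path) -> int:
    """Check that every line parses and has the required keys; return the row count."""
    count = 0
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            row = json.loads(line)
            missing = [k for k in REQUIRED_KEYS if k not in row]
            if missing:
                raise ValueError(f"line {lineno}: missing keys {missing}")
            if row["chosen"] == row["rejected"]:
                raise ValueError(f"line {lineno}: chosen and rejected are identical")
            count += 1
    return count


train_path = Path("dpo_train.jsonl")
write_jsonl(train_path, records)
print(validate_jsonl(train_path), "valid preference pair(s)")
```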

***

### Prerequisites

#### Setup and dependencies

In [None]:
import boto3
from sagemaker.core.helper.session_helper import Session, get_execution_role

sess = Session()
sagemaker_session_bucket = None

if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

s3_client = boto3.client("s3")
sess = Session(default_bucket=sagemaker_session_bucket)
sm_client = boto3.client("sagemaker", region_name=sess.boto_region_name)
bucket_name = sess.default_bucket()
default_prefix = sess.default_bucket_prefix

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [None]:
import os
from sagemaker.ai_registry.dataset import DataSet

# Required config
base_model_id = "meta-textgeneration-llama-3-2-1b-instruct"
# Assumes both datasets are already registered in the SageMaker AI Registry (workflow steps 1-2)
training_dataset = DataSet.get(name="humanlike-dpo-train")
validation_dataset = DataSet.get(name="humanlike-dpo-val")

# Optional Configs
mlflow_resource_arn = "arn:aws:sagemaker:<region>:<account>:mlflow-app/app-xxxxxxx"  # Only needed to use an MLflow app other than the default; replace the placeholder
job_name = f"dpo-{base_model_id.split('/')[-1].replace('.', '-')}"
mlflow_experiment_name = "humanlike-llama3-2-1b-dpo"

if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{base_model_id}-dpo"
else:
    output_path = f"s3://{bucket_name}/{base_model_id}-dpo"

os.environ["SAGEMAKER_MLFLOW_CUSTOM_ENDPOINT"] = (
    f"https://mlflow.sagemaker.{sess.boto_region_name}.app.aws"
)
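
The `if default_prefix` branch above can also be written as a small helper that skips an unset prefix and trims stray slashes, in case a prefix is configured with a trailing `/`. `build_s3_uri` is a hypothetical helper, not part of the SageMaker SDK.

```python
from typing import Optional


def build_s3_uri(bucket: str, *parts: Optional[str]) -> str:
    """Join a bucket and optional key parts into an s3:// URI.

    None or empty parts (e.g. an unset default bucket prefix) are skipped,
    and leading/trailing slashes on each part are trimmed so that joining
    never introduces '//'.
    """
    key_parts = [p.strip("/") for p in parts if p and p.strip("/")]
    return "s3://" + "/".join([bucket.strip("/"), *key_parts])


print(build_s3_uri("my-bucket", None, "meta-textgeneration-llama-3-2-1b-instruct-dpo"))
print(build_s3_uri("my-bucket", "team-a/", "meta-textgeneration-llama-3-2-1b-instruct-dpo"))
```

With this helper the branch reduces to `output_path = build_s3_uri(bucket_name, default_prefix, f"{base_model_id}-dpo")`.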

***

### Create Model Package Group

In [None]:
from sagemaker.core.resources import ModelPackageGroup

model_package_group_name = f"{base_model_id}-dpo"

model_package_group = ModelPackageGroup.create(
    model_package_group_name=model_package_group_name,
    model_package_group_description="store models from SageMaker serverless customization",  # Description is required
)
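
`base_model_id` is reused verbatim in the model package group name and the job name, so it must satisfy SageMaker's resource-naming rules. The sketch below checks a name against the rule as summarized from the SageMaker API reference (1 to 63 characters; letters, digits, and hyphens; starting and ending with a letter or digit). Treat the pattern as an assumption and confirm it for the specific resource; `is_valid_sagemaker_name` and `to_sagemaker_name` are hypothetical helpers.

```python
import re

# Assumed naming rule: letters/digits/hyphens, alphanumeric at both ends.
_SAGEMAKER_NAME = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9])*")
MAX_NAME_LEN = 63


def is_valid_sagemaker_name(name: str) -> bool:
    """True if `name` fits the assumed length limit and character pattern."""
    # fullmatch (not match with '$') so a trailing newline is rejected.
    return len(name) <= MAX_NAME_LEN and _SAGEMAKER_NAME.fullmatch(name) is not None


def to_sagemaker_name(raw: str) -> str:
    """Hypothetical helper: coerce an arbitrary string into a valid name."""
    cleaned = re.sub(r"[^a-zA-Z0-9-]+", "-", raw)  # replace runs of disallowed chars
    cleaned = cleaned[:MAX_NAME_LEN].strip("-")     # truncate, then trim edge hyphens
    if not cleaned:
        raise ValueError(f"cannot derive a valid name from {raw!r}")
    return cleaned


print(is_valid_sagemaker_name("meta-textgeneration-llama-3-2-1b-instruct-dpo"))
print(to_sagemaker_name("org/Llama-3.2-1B_instruct dpo"))
```

Running the check before `ModelPackageGroup.create` surfaces a bad name locally instead of as a failed API call.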

# Part 1: Configure and Execute DPO Training

### Step 1: Creating the Trainer

#### Create DPO Trainer (Direct Preference Optimization)


##### Key Parameters:
- `model`: Base model to fine-tune (from SageMaker Hub)
- `training_type`: Fine-tuning method (LoRA is recommended for efficiency)
- `training_dataset`: Preference data for training, given as a registered dataset (a `DataSet` object in this notebook), its ARN, or an S3 path. Required for the job to run; supply it either to the trainer or to `.train()`.
- `model_package_group`: Model package group in which the fine-tuned model is registered
- `mlflow_resource_arn`: ARN of the MLflow app used to track the training job (optional)
- `mlflow_experiment_name`: MLflow experiment name, as a `str` (optional)
- `mlflow_run_name`: Name of the MLflow run recorded under that experiment, as a `str` (optional)
- `validation_dataset`: Validation data, given as a dataset ARN or an S3 path (optional)
- `s3_output_path`: S3 path for the trained model artifacts (optional)
- `base_job_name`: Prefix used to name the training job
- `sagemaker_session`: SageMaker session used for API calls
- `accept_eula`: Accepts the base model's end-user license agreement, which gated models such as Meta Llama require
- `role`: ARN of the IAM execution role the job runs under
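
When the datasets are passed as strings, a quick local check can catch a malformed ARN or S3 path before a job is submitted. `dataset_ref_kind` is a hypothetical helper: it only checks the generic ARN layout (`arn:partition:service:region:account:resource`) and the `s3://bucket/key` shape, not whether the resource exists or which resource types the trainer accepts. The ARN in the test is a placeholder.

```python
from urllib.parse import urlparse


def dataset_ref_kind(ref: str) -> str:
    """Hypothetical check: classify `ref` as an S3 URI ("s3") or a SageMaker ARN ("arn").

    Raises ValueError for anything else so a typo surfaces before submission.
    """
    if ref.startswith("arn:"):
        # ARN layout: arn:partition:service:region:account:resource
        parts = ref.split(":", 5)
        if len(parts) == 6 and parts[2] == "sagemaker" and parts[5]:
            return "arn"
        raise ValueError(f"not a SageMaker ARN: {ref!r}")
    parsed = urlparse(ref)
    # Require a bucket and a non-empty object key.
    if parsed.scheme == "s3" and parsed.netloc and parsed.path.strip("/"):
        return "s3"
    raise ValueError(f"expected an s3:// URI or a SageMaker ARN, got {ref!r}")


print(dataset_ref_kind("s3://my-bucket/dpo/train.jsonl"))
```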

### Training Features:
- **Serverless Training**: Automatically managed compute resources
- **LoRA Integration**: Parameter-efficient fine-tuning
- **MLflow Tracking**: Automatic experiment and metrics logging
- **Model Versioning**: Automatic model package creation
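
LoRA (Hu et al., 2021) freezes the base weights and trains a low-rank update for selected weight matrices, which is where its parameter efficiency comes from. The arithmetic below shows the saving for a single matrix. The 2048 x 2048 shape and rank 8 are illustrative; which layers and which rank `DPOTrainer` uses by default is not stated here.

```python
def lora_trainable_params(d_out: int, d_in: int, rank: int) -> int:
    """Trainable parameters LoRA adds for one d_out x d_in weight matrix.

    The frozen weight W is adapted by a low-rank update B @ A,
    where B is d_out x rank and A is rank x d_in.
    """
    return rank * (d_out + d_in)


# Illustrative sizes only.
d_out, d_in, rank = 2048, 2048, 8
full = d_out * d_in
lora = lora_trainable_params(d_out, d_in, rank)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

Because the trainable count grows linearly with `rank` while the full matrix grows with `d_out * d_in`, the saving increases with model width.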

### Run Serverless Job

In [None]:
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.train.common import TrainingType

In [None]:

trainer = DPOTrainer(
    model=base_model_id,
    training_type=TrainingType.LORA,
    model_package_group=model_package_group,
    training_dataset=training_dataset,
    validation_dataset=validation_dataset,
    s3_output_path=output_path,
    mlflow_resource_arn=mlflow_resource_arn,
    mlflow_experiment_name=mlflow_experiment_name,
    base_job_name=job_name,
    sagemaker_session=sess,
    accept_eula=True,
    role=role
)

Print the default hyperparameters:

In [None]:
from rich.pretty import pprint

print("Default fine-tuning options:")
pprint(trainer.hyperparameters.to_dict())

Override selected hyperparameters before training:

In [None]:
# Uncomment to change the learning rate
# trainer.hyperparameters.learning_rate = 0.0001
trainer.hyperparameters.global_batch_size = 64
trainer.hyperparameters.max_epochs = 3
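
With `global_batch_size = 64` and `max_epochs = 3`, the number of optimizer steps depends on the size of the training set. A minimal sketch of that arithmetic, assuming `global_batch_size` counts preference pairs per optimizer step across all devices and the final partial batch is kept rather than dropped; `optimizer_steps` is a hypothetical helper and `num_pairs = 10_000` a hypothetical dataset size.

```python
import math


def optimizer_steps(num_examples: int, global_batch_size: int, max_epochs: int) -> int:
    """Total optimizer steps, assuming the last partial batch is kept."""
    steps_per_epoch = math.ceil(num_examples / global_batch_size)
    return steps_per_epoch * max_epochs


# Hypothetical dataset size; replace with your own count of preference pairs.
num_pairs = 10_000
print(optimizer_steps(num_pairs, global_batch_size=64, max_epochs=3))
```

For 10,000 pairs this gives 157 steps per epoch and 471 steps in total, so doubling `global_batch_size` roughly halves the number of updates the model receives.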

In [None]:
print("\nModified/user defined options:")
pprint(trainer.hyperparameters.to_dict())

In [None]:
from rich import print as rprint
from rich.pretty import pprint

training_job = trainer.train(wait=True)

TRAINING_JOB_NAME = training_job.training_job_name

pprint(training_job)