# Fine-tune an LLM with PyTorch FSDP and QLoRA on Amazon SageMaker AI using ModelTrainer

In this notebook, we fine-tune an LLM on Amazon SageMaker AI, using Python scripts and the SageMaker ModelTrainer to run a training job.

## Prerequisites

In [None]:
%pip install -r ./scripts/requirements.txt --upgrade

In [None]:
# Copy Ray launcher script to the scripts directory. 
%cp ../../../scripts/launcher.py ./scripts/

***

## Set up the AWS profile and model ID

In [None]:
import os

# os.environ["AWS_PROFILE"] = "<aws_profile>"

In [None]:
import os

model_id = "Qwen/Qwen3-0.6B"

os.environ["model_id"] = model_id

***

## Prepare the dataset

We are going to load the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset.

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
role = sagemaker.get_execution_role()

In [None]:
from datasets import load_dataset

dataset = load_dataset(
    "FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train[:10000]"
)

dataset

In [None]:
import pandas as pd

df = pd.DataFrame(dataset)

df.head()

In [None]:
from sklearn.model_selection import train_test_split

train, val = train_test_split(df, test_size=0.1, random_state=42)

print("Number of train elements: ", len(train))
print("Number of val elements: ", len(val))

Create a prompt template, apply it to the train and validation splits, and print a random formatted sample to check the result.

In [None]:
from transformers import AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained(model_id)

def prepare_dataset(sample):

    system_text = (
        "You are a deep-thinking AI assistant.\n\n"
        "For every user question, first write your thoughts and reasoning inside <think>...</think> tags, then provide your answer."
    )

    messages = []

    messages.append({"role": "system", "content": system_text})
    messages.append({"role": "user", "content": sample["Question"]})
    messages.append(
        {
            "role": "assistant",
            "content": f"<think>\n{sample['Complex_CoT'].lower()}\n</think>\n\n{sample['Response']}",
        }
    )

    # Apply chat template
    sample["text"] = tokenizer.apply_chat_template(messages, tokenize=False)

    return sample
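
To inspect the message structure without downloading the tokenizer, the following is a minimal standalone sketch. The record is a hypothetical placeholder that only mirrors the three dataset columns (`Question`, `Complex_CoT`, `Response`), and `build_messages` is an illustrative helper, not part of the training scripts.

```python
# Hypothetical record that mirrors the three columns of the dataset.
sample = {
    "Question": "What is the first-line treatment for condition X?",
    "Complex_CoT": "Consider the guidelines. Then weigh the risks.",
    "Response": "Treatment Y is usually recommended.",
}

SYSTEM_TEXT = (
    "You are a deep-thinking AI assistant.\n\n"
    "For every user question, first write your thoughts and reasoning "
    "inside <think>...</think> tags, then provide your answer."
)


def build_messages(sample: dict) -> list:
    """Build the system/user/assistant turns in the same shape as prepare_dataset."""
    # The reasoning trace is lowercased, as in prepare_dataset above.
    reasoning = sample["Complex_CoT"].lower()
    return [
        {"role": "system", "content": SYSTEM_TEXT},
        {"role": "user", "content": sample["Question"]},
        {
            "role": "assistant",
            "content": f"<think>\n{reasoning}\n</think>\n\n{sample['Response']}",
        },
    ]


messages = build_messages(sample)
for message in messages:
    print(f"[{message['role']}]\n{message['content']}\n")
```

The tokenizer's chat template then turns this list into a single training string, so the exact special tokens depend on the model chosen in `model_id`.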

In [None]:
from datasets import Dataset, DatasetDict
from random import randint

train_dataset = Dataset.from_pandas(train)
val_dataset = Dataset.from_pandas(val)

dataset = DatasetDict({"train": train_dataset, "val": val_dataset})

train_dataset = dataset["train"].map(
    prepare_dataset, remove_columns=list(train_dataset.features)
)

# Print one random formatted training sample (randint is inclusive on both ends)
print(train_dataset[randint(0, len(train_dataset) - 1)]["text"])

val_dataset = dataset["val"].map(
    prepare_dataset, remove_columns=list(val_dataset.features)
)

### Upload to Amazon S3

In [None]:
import boto3
import shutil
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
s3_client = boto3.client('s3')

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
# save train_dataset to s3 using our SageMaker session
if default_prefix:
    input_path = f"{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft-ray"
else:
    input_path = f"datasets/llm-fine-tuning-modeltrainer-sft-ray"

train_dataset_s3_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
val_dataset_s3_path = f"s3://{bucket_name}/{input_path}/val/dataset.json"
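
The optional-prefix logic above is repeated in several cells of this notebook. As a sketch only, it could be factored into two small helpers. `build_s3_key` and `build_s3_uri` are hypothetical names introduced here, and the bucket and prefix values in the usage lines are placeholders.

```python
def build_s3_key(*parts, prefix=None):
    """Join key segments with '/', prepending the optional default prefix."""
    segments = ([prefix] if prefix else []) + list(parts)
    # Strip stray slashes so segments never produce '//' in the key.
    return "/".join(segment.strip("/") for segment in segments if segment)


def build_s3_uri(bucket, key):
    """Return the s3:// URI for a bucket and key."""
    return f"s3://{bucket}/{key}"


# Placeholder values, for illustration only.
dataset_root = "datasets/llm-fine-tuning-modeltrainer-sft-ray"
key_without_prefix = build_s3_key(dataset_root, "train", "dataset.json")
key_with_prefix = build_s3_key(dataset_root, "train", "dataset.json", prefix="team-a/")

print(build_s3_uri("example-bucket", key_without_prefix))
print(build_s3_uri("example-bucket", key_with_prefix))
```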

In [None]:
# Save the full train and validation splits locally as JSON, then upload them to S3
train_dataset.to_json("./data/train/dataset.json", orient="records")
val_dataset.to_json("./data/val/dataset.json", orient="records")

s3_client.upload_file(
    "./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json"
)
s3_client.upload_file(
    "./data/val/dataset.json", bucket_name, f"{input_path}/val/dataset.json"
)

shutil.rmtree("./data")

print(f"Training data uploaded to:")
print(train_dataset_s3_path)
print(val_dataset_s3_path)

***

## (Optional) Copy Prometheus binary

If you want to prevent Ray from downloading Prometheus at runtime, you can copy the binary to S3 and pass it to the training job as an input channel.

In [None]:
! wget https://github.com/prometheus/prometheus/releases/download/v3.4.2/prometheus-3.4.2.linux-amd64.tar.gz

In [None]:
import boto3
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
s3_client = boto3.client('s3')

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
if default_prefix:
    input_path = f"{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft-ray"
else:
    input_path = f"datasets/llm-fine-tuning-modeltrainer-sft-ray"

prometheus_s3_path = (
    f"s3://{bucket_name}/{input_path}/prometheus/prometheus-3.4.2.linux-amd64.tar.gz"
)

In [None]:
s3_client.upload_file(
    "./prometheus-3.4.2.linux-amd64.tar.gz",
    bucket_name,
    f"{input_path}/prometheus/prometheus-3.4.2.linux-amd64.tar.gz",
)

print(f"Prometheus binary uploaded to:")
print(prometheus_s3_path)

***

## Model fine-tuning

We are now ready to fine-tune our model. We will use the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) from transformers to fine-tune our model. We prepared a script [train.py](./scripts/train.py) which loads the dataset from disk, prepares the model and tokenizer, and starts the training.

For configuration we use `TrlParser`, which lets us provide hyperparameters in a `yaml` file. This file is uploaded to S3 and provided to Amazon SageMaker as an input channel, in the same way as our datasets. We save the config file as `args.yaml` and upload it to S3.

In [None]:
%%bash

cat > ./args.yaml <<EOF
model_id: "${model_id}"                           # Hugging Face model id
# sagemaker specific parameters
output_dir: "/opt/ml/model"                       # path to where SageMaker will upload the model 
checkpoint_dir: "/opt/ml/checkpoints/"
train_dataset_path: "/opt/ml/input/data/train/"   # path where SageMaker places the train channel data
val_dataset_path: "/opt/ml/input/data/val/"       # path where SageMaker places the val channel data
save_steps: 100                                   # Save checkpoint every this many steps
# training parameters
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05                 
learning_rate: 1e-4                    # learning rate
num_train_epochs: 1                    # number of training epochs
per_device_train_batch_size: 4         # batch size per device during training
per_device_eval_batch_size: 2          # batch size per device during evaluation
gradient_accumulation_steps: 4         # number of micro-batches accumulated before each optimizer update
gradient_checkpointing: false          # use gradient checkpointing
bf16: true                             # use bfloat16 precision
tf32: false                            # use tf32 precision
fsdp: "full_shard auto_wrap offload"
fsdp_config: 
    backward_prefetch: "backward_pre"
    cpu_ram_efficient_loading: true
    offload_params: true
    forward_prefetch: false
    use_orig_params: true
    activation_checkpointing: true
warmup_steps: 100
weight_decay: 0.01
merge_weights: true                    # merge weights in the base model
EOF
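
To see how these values interact, here is a rough back-of-the-envelope sketch. It assumes the 9,000-sample train split created earlier, one training process per GPU, and a single `ml.g5.12xlarge` instance (4 GPUs), which is the instance type selected later in this notebook. The exact step count depends on the Trainer version and on how the last partial batch is handled, so treat the numbers as estimates.

```python
import math

# Values taken from args.yaml and the train/val split above.
num_train_samples = 9_000
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
warmup_steps = 100
save_steps = 100

# Assumption: one worker per GPU on a single ml.g5.12xlarge (4 GPUs).
num_gpus = 4

# Samples consumed per optimizer update across all devices.
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_gpus
)

# Approximate number of optimizer updates in one epoch.
steps_per_epoch = math.ceil(num_train_samples / effective_batch_size)

print(f"Effective batch size: {effective_batch_size}")
print(f"Approximate optimizer steps per epoch: {steps_per_epoch}")
print(f"Share of the epoch spent in warmup: {warmup_steps / steps_per_epoch:.0%}")
print(f"Checkpoints written in one epoch: {steps_per_epoch // save_steps}")
```

Under these assumptions, `warmup_steps: 100` covers most of the single epoch, so it may be worth lowering it if you keep the other values unchanged.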

Let's upload the config file to S3.

In [None]:
import os
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = (
        f"s3://{bucket_name}/{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft-ray"
    )
else:
    input_path = f"s3://{bucket_name}/datasets/llm-fine-tuning-modeltrainer-sft-ray"

# upload the model yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(
    local_path=model_yaml, desired_s3_uri=f"{input_path}/config"
)

os.remove("./args.yaml")

print(f"Training config uploaded to:")
print(train_config_s3_path)

## Fine-tune model

The ModelTrainer below trains the model with QLoRA, merges the adapter into the base model, and saves the result to S3.

#### Get PyTorch image_uri

We are going to use the native PyTorch container image, pre-built for Amazon SageMaker.

In [None]:
import sagemaker
from sagemaker.config import load_sagemaker_config

In [None]:
sagemaker_session = sagemaker.Session()

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
configs = load_sagemaker_config()

In [None]:
instance_type = "ml.g5.12xlarge" # Override the instance type if you want to get a different container version
instance_count = 1

instance_type

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.7.1",
    instance_type=instance_type,
    image_scope="training",
)

image_uri

In [None]:
from sagemaker.modules.configs import (
    CheckpointConfig,
    Compute,
    OutputDataConfig,
    RemoteDebugConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.modules.train import ModelTrainer

args = [
    "--entrypoint",
    "train_ray.py",
    "--config",
    "/opt/ml/input/data/config/args.yaml",  # path to TRL config which was uploaded to s3
]

# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    command=f"python launcher.py {' '.join(args)}",
)

# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=instance_count,
    keep_alive_period_in_seconds=0,
)

# define Training Job Name
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-ray"

# define OutputDataConfig path
if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{job_name}"
else:
    output_path = f"s3://{bucket_name}/{job_name}"

# Define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    output_data_config=OutputDataConfig(s3_output_path=output_path),
    checkpoint_config=CheckpointConfig(
        s3_uri=output_path + "/checkpoint", local_path="/opt/ml/checkpoints"
    ),
    environment={
        # "launch_prometheus": "true", # enable for local prometheus
        "RAY_PROMETHEUS_HOST": "<PROMETHEUS_HOST>",
        "RAY_GRAFANA_HOST": "<GRAFANA_HOST>",
        "RAY_PROMETHEUS_NAME": "prometheus",
    },
    role=role,
).with_remote_debug_config(RemoteDebugConfig(enable_remote_debug=True))
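
The job name above replaces dots in the model ID because SageMaker training job names accept only alphanumeric characters and hyphens, up to 63 characters. The following is a small sketch of that derivation with a validity check. `make_job_name` and `is_valid_job_name` are hypothetical helpers, not SDK functions.

```python
import re

# Pattern for SageMaker training job names: alphanumerics and hyphens,
# starting with an alphanumeric character, at most 63 characters.
JOB_NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}")


def make_job_name(model_id, suffix="sft-ray"):
    """Derive a base job name from a Hugging Face model ID."""
    short_name = model_id.split("/")[-1].replace(".", "-")
    return f"train-{short_name}-{suffix}"


def is_valid_job_name(name):
    """Return True if the whole name matches the job name pattern."""
    return JOB_NAME_PATTERN.fullmatch(name) is not None


base_name = make_job_name("Qwen/Qwen3-0.6B")
print(base_name, is_valid_job_name(base_name))
print("train-Qwen3-0.6B", is_valid_job_name("train-Qwen3-0.6B"))
```

Because the value is passed as `base_job_name`, the SDK is expected to append its own unique suffix, so keeping the base name well under the 63-character limit is the safer choice.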

In [None]:
from sagemaker.modules.configs import InputData, S3DataSource

# Pass the input data
train_input = InputData(
    channel_name="train",
    data_source=S3DataSource(
        s3_data_type="S3Prefix",
        s3_uri=train_dataset_s3_path,
        s3_data_distribution_type="FullyReplicated",
    ),  # S3 path where training data is stored
)

val_input = InputData(
    channel_name="val",
    data_source=S3DataSource(
        s3_data_type="S3Prefix",
        s3_uri=val_dataset_s3_path,
        s3_data_distribution_type="FullyReplicated",
    ),  # S3 path where val data is stored
)

config_input = InputData(
    channel_name="config",
    data_source=S3DataSource(
        s3_data_type="S3Prefix",
        s3_uri=train_config_s3_path,
        s3_data_distribution_type="FullyReplicated",
    ),  # S3 path where configs are stored
)

# Uncomment these lines if you want to provide the Prometheus binary

# prometheus_input = InputData(
#     channel_name="prometheus",
#     data_source=S3DataSource(
#         s3_data_type="S3Prefix",
#         s3_uri=prometheus_s3_path,
#         s3_data_distribution_type="FullyReplicated",
#     ),  # S3 path where the Prometheus binary is stored
# )

# Check input channels configured
data = [
    train_input,
    val_input,
    config_input,
    # prometheus_input,
]
data
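
Each channel name determines where the data appears inside the training container: SageMaker mounts a channel under `/opt/ml/input/data/<channel_name>`. This is why `args.yaml` points at `/opt/ml/input/data/train/` and the launcher receives `/opt/ml/input/data/config/args.yaml`. A minimal sketch of that mapping, where `channel_path` is a hypothetical helper:

```python
from pathlib import PurePosixPath

# Root directory where SageMaker places input channels in the container.
INPUT_DATA_ROOT = PurePosixPath("/opt/ml/input/data")


def channel_path(channel_name, filename=None):
    """Return the in-container path for a channel, optionally for one file in it."""
    path = INPUT_DATA_ROOT / channel_name
    return str(path / filename) if filename else str(path)


for name in ("train", "val", "config"):
    print(f"{name} -> {channel_path(name)}")

print(channel_path("config", "args.yaml"))
```

If you rename a channel, the corresponding paths in `args.yaml` and in the launcher arguments need to change with it.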

In [None]:
# starting the train job with our uploaded datasets as input
model_trainer.train(input_data_config=data, wait=False)
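
Because the job is started with `wait=False`, the call returns immediately. One way to follow progress is to poll the job status, sketched below with the boto3 `describe_training_job` call. `wait_for_training_job` is a hypothetical helper, and the job name in the usage comment is a placeholder that you would replace with the name shown in the SageMaker console.

```python
import time

# Statuses after which a training job no longer changes state.
TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(sagemaker_client, job_name, poll_seconds=60, max_polls=600):
    """Poll a training job until it reaches a terminal status and return that status."""
    for _ in range(max_polls):
        description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
        status = description["TrainingJobStatus"]
        print(f"{job_name}: {status}")
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


# Usage (requires AWS credentials, so it is left commented out):
# import boto3
# final_status = wait_for_training_job(boto3.client("sagemaker"), "<training-job-name>")
```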