# Fine-tune an LLM with PyTorch FSDP and QLoRA on Amazon SageMaker AI using ModelTrainer

In this notebook, we fine-tune an LLM on Amazon SageMaker AI, using Python scripts and the SageMaker ModelTrainer to run a training job.

## Prerequisites

In [None]:
%pip install -r ./scripts/requirements.txt --upgrade

***

## Setup Configuration file path

If you have created a managed MLflow tracking server, copy its `ARN` into `mlflow_uri` below and assign a name to the experiment in `mlflow_experiment_name`.
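
How these two values are consumed depends on `scripts/train.py`, which is not shown here. As a minimal sketch, assuming the script treats empty values as "tracking disabled", a hypothetical helper `mlflow_settings` could read them like this:

```python
import os


def mlflow_settings(env=os.environ):
    """Return (uri, experiment_name), or None when tracking is not configured.

    Hypothetical helper for illustration; not part of the notebook's scripts.
    """
    uri = env.get("mlflow_uri", "").strip()
    experiment = env.get("mlflow_experiment_name", "").strip()
    if not uri or not experiment:
        return None
    return uri, experiment


settings = mlflow_settings()
if settings is None:
    print("MLflow tracking is not configured; skipping experiment logging.")
else:
    uri, experiment = settings
    # With the mlflow package installed, a script could then call
    # mlflow.set_tracking_uri(uri) and mlflow.set_experiment(experiment).
    print(f"Logging to experiment '{experiment}'")
```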

In [None]:
import os

os.environ["mlflow_uri"] = ""
os.environ["mlflow_experiment_name"] = ""

***

## Visualize and upload the dataset

We are going to load the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset and keep the first 1,000 records.

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
from datasets import load_dataset
import pandas as pd

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

df = pd.DataFrame(dataset['train'])
df = df[:1000]

df.head()

In [None]:
from sklearn.model_selection import train_test_split

train, test = train_test_split(df, test_size=0.1, random_state=42)

print("Number of train elements: ", len(train))
print("Number of test elements: ", len(test))

Create a prompt template that combines each sample's question, chain of thought, and answer into a single training text.
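
The template in the next cell is written as an f-string with doubled braces. A small sketch of why that works, using a shortened, hypothetical template: the f-string collapses `{{question}}` to the literal placeholder `{question}`, which a later `.format()` call fills in.

```python
# Shortened, hypothetical template; the real one is defined in the next cell.
short_template = f"Question: {{question}}\nAnswer: {{answer}}"

# After f-string evaluation the doubled braces become single-brace placeholders.
assert "{question}" in short_template and "{{" not in short_template

# str.format() then substitutes the placeholders with the sample's fields.
filled = short_template.format(question="What is 2 + 2?", answer="4")
print(filled)
```

Dropping the `f` prefix and writing single braces would produce the same template; the doubled braces are only needed because the string is an f-string.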

In [None]:
# custom instruct prompt start
prompt_template = f"""
<｜begin▁of▁sentence｜>
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
<｜User｜>
{{question}}
<｜Assistant｜>
<think>
{{complex_cot}}
</think>

{{answer}}
<｜end▁of▁sentence｜>
"""

# template dataset to add prompt to each sample
def template_dataset(sample):
    sample["text"] = prompt_template.format(question=sample["Question"],
                                            complex_cot=sample["Complex_CoT"],
                                            answer=sample["Response"])
    return sample

Apply the prompt template to the train and test splits, then print a random training sample to check the result.

In [None]:
from datasets import Dataset, DatasetDict
from random import randint

train_dataset = Dataset.from_pandas(train)
test_dataset = Dataset.from_pandas(test)

dataset = DatasetDict({"train": train_dataset, "test": test_dataset})

train_dataset = dataset["train"].map(template_dataset, remove_columns=list(dataset["train"].features))

print(train_dataset[randint(0, len(train_dataset) - 1)]["text"])

test_dataset = dataset["test"].map(template_dataset, remove_columns=list(dataset["test"].features))

### Upload to Amazon S3

In [None]:
import boto3
import shutil
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
s3_client = boto3.client('s3')

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
# save train_dataset to s3 using our SageMaker session
if default_prefix:
    input_path = f'{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft'
else:
    input_path = f'datasets/llm-fine-tuning-modeltrainer-sft'

# Save datasets to s3
# We fine-tune on a subset only (the first 1,000 records, split 90/10) due to limited compute resources for the workshop
train_dataset.to_json("./data/train/dataset.json", orient="records")
test_dataset.to_json("./data/test/dataset.json", orient="records")

s3_client.upload_file("./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json")
train_dataset_s3_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
s3_client.upload_file("./data/test/dataset.json", bucket_name, f"{input_path}/test/dataset.json")
test_dataset_s3_path = f"s3://{bucket_name}/{input_path}/test/dataset.json"

shutil.rmtree("./data")

print("Datasets uploaded to:")
print(train_dataset_s3_path)
print(test_dataset_s3_path)

***

## Model fine-tuning

We are now ready to fine-tune our model. We will use the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) from transformers to fine-tune our model. We prepared a script, [train.py](./scripts/train.py), which loads the dataset from disk, prepares the model and tokenizer, and starts the training.

For configuration we use `TrlParser`, which lets us provide hyperparameters in a `yaml` file. This file is uploaded and provided to Amazon SageMaker in the same way as our datasets. Below is the config file for fine-tuning the model on `ml.g5.12xlarge`. We save the config file as `args.yaml` and upload it to S3.
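
As a rough sanity check on the batch-size settings in this config, the sketch below computes the effective global batch size. It assumes all 4 GPUs of an `ml.g5.12xlarge` take part in data-parallel training and that each of the 900 training records prepared earlier counts as one sample (no packing); the actual step count depends on how `train.py` builds its batches.

```python
import math

per_device_train_batch_size = 2   # from args.yaml
gradient_accumulation_steps = 2   # from args.yaml
num_gpus = 4                      # ml.g5.12xlarge has 4 NVIDIA A10G GPUs
num_train_samples = 900           # 90% of the 1,000 records kept earlier

# Samples contributing to a single optimizer update across all devices.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus

# Upper estimate of optimizer updates in one epoch (assumption: no packing).
approx_steps_per_epoch = math.ceil(num_train_samples / effective_batch_size)

print(f"Effective batch size: {effective_batch_size}")
print(f"Approximate optimizer steps per epoch: {approx_steps_per_epoch}")
```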

In [None]:
%%bash

cat > ./args.yaml <<EOF
model_id: "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"       # Hugging Face model id
mlflow_uri: "${mlflow_uri}"
mlflow_experiment_name: "${mlflow_experiment_name}"
# sagemaker specific parameters
output_dir: "/opt/ml/model"                       # path to where SageMaker will upload the model 
train_dataset_path: "/opt/ml/input/data/train/"   # path where SageMaker places the train channel
test_dataset_path: "/opt/ml/input/data/test/"     # path where SageMaker places the test channel
# training parameters
lora_r: 8
lora_alpha: 16
lora_dropout: 0.1                 
learning_rate: 2e-4                    # learning rate
num_train_epochs: 1                    # number of training epochs
per_device_train_batch_size: 2         # batch size per device during training
per_device_eval_batch_size: 1          # batch size for evaluation
gradient_accumulation_steps: 2         # number of steps to accumulate gradients before an optimizer update
gradient_checkpointing: true           # use gradient checkpointing
bf16: true                             # use bfloat16 precision
tf32: false                            # use tf32 precision
fsdp: "full_shard auto_wrap offload"
fsdp_config: 
    backward_prefetch: "backward_pre"
    cpu_ram_efficient_loading: true
    offload_params: true
    forward_prefetch: false
    use_orig_params: true
merge_weights: true                    # merge weights in the base model
EOF

Let's upload the config file to S3.

In [None]:
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft"
else:
    input_path = f"s3://{bucket_name}/datasets/llm-fine-tuning-modeltrainer-sft"

# upload the model yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

print("Training config uploaded to:")
print(train_config_s3_path)

## Fine-tune model

The ModelTrainer below trains the model with QLoRA, merges the adapter into the base model, and saves the result to S3.
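
To give a sense of why adapter training is cheap, the sketch below applies the standard LoRA bookkeeping to the `lora_r` and `lora_alpha` values from `args.yaml`. The 5120 x 5120 projection size is a hypothetical example dimension, not a value read from the model.

```python
lora_r = 8        # from args.yaml
lora_alpha = 16   # from args.yaml

# Standard LoRA scales the low-rank update B @ A by alpha / r.
scaling = lora_alpha / lora_r

# Hypothetical projection matrix size, chosen only for illustration.
d_in, d_out = 5120, 5120

full_params = d_in * d_out                 # weights in the frozen base matrix
adapter_params = lora_r * (d_in + d_out)   # A is (r x d_in), B is (d_out x r)

print(f"Scaling factor: {scaling}")
print(f"Trainable adapter parameters: {adapter_params:,} of {full_params:,} "
      f"({adapter_params / full_params:.4%})")
```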

#### Get PyTorch image_uri

We are going to use the native PyTorch container image, pre-built for Amazon SageMaker.

In [None]:
import sagemaker
from sagemaker import image_uris
from sagemaker.config import load_sagemaker_config

In [None]:
sagemaker_session = sagemaker.Session()

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
configs = load_sagemaker_config()

In [None]:
instance_type = "ml.g5.12xlarge" # Override the instance type if you want to get a different container version

instance_type

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.5.1",
    instance_type=instance_type,
    image_scope="training"
)

image_uri

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

In [None]:
from sagemaker.modules.configs import Compute, InputData, OutputDataConfig, SourceCode, StoppingCondition
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.train import ModelTrainer

# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="train.py",
)

# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=1,
    keep_alive_period_in_seconds=0
)

# define Training Job Name 
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-script"

# define OutputDataConfig path
if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{job_name}"
else:
    output_path = f"s3://{bucket_name}/{job_name}"

# Define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    distributed=Torchrun(),
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=7200
    ),
    hyperparameters={
        "config": "/opt/ml/input/data/config/args.yaml" # path to TRL config which was uploaded to s3
    },
    output_data_config=OutputDataConfig(
        s3_output_path=output_path
    ),
)

In [None]:
from sagemaker.modules.configs import InputData

# Pass the input data
train_input = InputData(
    channel_name="train",
    data_source=train_dataset_s3_path, # S3 path where training data is stored
)

test_input = InputData(
    channel_name="test",
    data_source=test_dataset_s3_path, # S3 path where test data is stored
)

config_input = InputData(
    channel_name="config",
    data_source=train_config_s3_path, # S3 path where the args.yaml config is stored
)

# Check input channels configured
data = [train_input, test_input, config_input]
data

In [None]:
# starting the train job with our uploaded datasets as input
model_trainer.train(input_data_config=data, wait=False)

***

# Model evaluation

In the following sections, we evaluate the fine-tuned model using the ROUGE metrics (ROUGE-1, ROUGE-2, ROUGE-L, and ROUGE-L-Sum), which measure the similarity between machine-generated text and human-written reference text.
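
To make the metric concrete, here is a simplified sketch of ROUGE-1 as a unigram-overlap F1 score. It assumes lowercasing and whitespace tokenization; `rouge1_f1` is a hypothetical helper, and the evaluation script `rouge_evaluation.py` may use a library implementation with different tokenization or stemming.

```python
from collections import Counter


def rouge1_f1(candidate, reference):
    """Unigram-overlap F1 between two whitespace-tokenized, lowercased strings."""
    cand_counts = Counter(candidate.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter intersection keeps the minimum count per token (clipped matches).
    overlap = sum((cand_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("the cat lay on the mat", "the cat sat on the mat")
print(round(score, 4))
```

ROUGE-2 follows the same pattern with bigrams instead of unigrams, while ROUGE-L is based on the longest common subsequence.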

## Load Fine-Tuned model

In [None]:
import boto3
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
job_prefix = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-script"

In [None]:
def get_last_job_name(job_name_prefix):
    sagemaker_client = boto3.client('sagemaker')

    matching_jobs = []
    next_token = None

    while True:
        # Prepare the search parameters
        search_params = {
            'Resource': 'TrainingJob',
            'SearchExpression': {
                'Filters': [
                    {
                        'Name': 'TrainingJobName',
                        'Operator': 'Contains',
                        'Value': job_name_prefix
                    },
                    {
                        'Name': 'TrainingJobStatus',
                        'Operator': 'Equals',
                        'Value': "Completed"
                    }
                ]
            },
            'SortBy': 'CreationTime',
            'SortOrder': 'Descending',
            'MaxResults': 100
        }

        # Add NextToken if we have one
        if next_token:
            search_params['NextToken'] = next_token

        # Make the search request
        search_response = sagemaker_client.search(**search_params)

        # Filter and add matching jobs
        matching_jobs.extend([
            job['TrainingJob']['TrainingJobName'] 
            for job in search_response['Results']
            if job['TrainingJob']['TrainingJobName'].startswith(job_name_prefix)
        ])

        # Check if we have more results to fetch
        next_token = search_response.get('NextToken')
        if not next_token or matching_jobs:  # Stop if we found at least one match or no more results
            break

    if not matching_jobs:
        raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")

    return matching_jobs[0]

In [None]:
job_name = get_last_job_name(job_prefix)

job_name

Define S3 path for the model data

In [None]:
if default_prefix:
    model_data = f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz"
else:
    model_data = f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz"

model_data

### Run evaluation job using SageMaker ModelTrainer

In [None]:
instance_type = "ml.g5.12xlarge" # Override the instance type if you want to get a different container version

instance_type

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.6.0",
    instance_type=instance_type,
    image_scope="training"
)

image_uri

In [None]:
from sagemaker.modules.configs import Compute, SourceCode, StoppingCondition
from sagemaker.modules.train import ModelTrainer

# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="rouge_evaluation.py",
)

# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=1,
    keep_alive_period_in_seconds=0
)

# define Training Job Name
job_name = f"eval-{job_prefix}"

# Define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=7200
    ),
    hyperparameters={
        "model_id": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",  # Hugging Face model id
        "dataset_name": "FreedomIntelligence/medical-o1-reasoning-SFT"
    }
)

In [None]:
from sagemaker.modules.configs import InputData

# Pass the fine-tuned model artifacts as the input channel
adapter_input = InputData(
    channel_name="adapter",
    data_source=model_data,
)

# Check input channels configured
data = [adapter_input]
data

In [None]:
# start the evaluation job with the fine-tuned model artifacts as input
model_trainer.train(input_data_config=data, wait=False)

***

# Model Deployment

In the following sections, we are going to deploy the fine-tuned model on an Amazon SageMaker Real-time endpoint.

## Load Fine-Tuned model

In [None]:
import boto3
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
job_prefix = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-script"

In [None]:
def get_last_job_name(job_name_prefix):
    sagemaker_client = boto3.client('sagemaker')

    matching_jobs = []
    next_token = None

    while True:
        # Prepare the search parameters
        search_params = {
            'Resource': 'TrainingJob',
            'SearchExpression': {
                'Filters': [
                    {
                        'Name': 'TrainingJobName',
                        'Operator': 'Contains',
                        'Value': job_name_prefix
                    },
                    {
                        'Name': 'TrainingJobStatus',
                        'Operator': 'Equals',
                        'Value': "Completed"
                    }
                ]
            },
            'SortBy': 'CreationTime',
            'SortOrder': 'Descending',
            'MaxResults': 100
        }

        # Add NextToken if we have one
        if next_token:
            search_params['NextToken'] = next_token

        # Make the search request
        search_response = sagemaker_client.search(**search_params)

        # Filter and add matching jobs
        matching_jobs.extend([
            job['TrainingJob']['TrainingJobName'] 
            for job in search_response['Results']
            if job['TrainingJob']['TrainingJobName'].startswith(job_name_prefix)
        ])

        # Check if we have more results to fetch
        next_token = search_response.get('NextToken')
        if not next_token or matching_jobs:  # Stop if we found at least one match or no more results
            break

    if not matching_jobs:
        raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")

    return matching_jobs[0]

In [None]:
job_name = get_last_job_name(job_prefix)

job_name

#### Inference configurations

In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker import Model

In [None]:
instance_count = 1
instance_type = "ml.g5.12xlarge"
number_of_gpu = 1
health_check_timeout = 700

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi",
    region=sagemaker_session.boto_session.region_name,
    version="latest"
)

image_uri

In [None]:
if default_prefix:
    model_data=f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz"
else:
    model_data=f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz"

model = Model(
    image_uri=image_uri,
    model_data=model_data,
    role=get_execution_role(),
    env={
        'HF_MODEL_ID': "/opt/ml/model", # path to where sagemaker stores the model
        'OPTION_TRUST_REMOTE_CODE': 'true',
        'OPTION_ROLLING_BATCH': "vllm",
        'OPTION_DTYPE': 'bf16',
        'OPTION_QUANTIZE': 'fp8',
        'OPTION_TENSOR_PARALLEL_DEGREE': 'max',
        'OPTION_MAX_ROLLING_BATCH_SIZE': '32',
        'OPTION_MODEL_LOADING_TIMEOUT': '3600',
        'OPTION_MAX_MODEL_LEN': '4096'
    }
)

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-sft-djl"

In [None]:
predictor = model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=instance_count,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    model_data_download_timeout=3600
)

#### Predict

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-sft-djl"

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
base_prompt = f"""
<｜begin▁of▁sentence｜>
You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.
<｜User｜>
{{question}}
<｜Assistant｜>
"""

In [None]:
prompt = base_prompt.format(
    question="A 3-week-old child has been diagnosed with late onset perinatal meningitis, and the CSF culture shows gram-positive bacilli. What characteristic of this bacterium can specifically differentiate it from other bacterial agents?"
)

print(prompt)

In [None]:
response = predictor.predict({
    "inputs": prompt,
    "parameters": {
        "temperature": 0.2,
        "top_p": 0.9,
        "return_full_text": False,
        "max_new_tokens": 1024,
        "stop": ['<｜end▁of▁sentence｜>']
    }
})

response = response["generated_text"].split("<｜end▁of▁sentence｜>")[0]

response
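
Because the prompt template used for fine-tuning wraps the chain of thought in `<think>` tags, the generated text can be separated into reasoning and final answer. A minimal sketch, assuming the model follows that template in its output; `split_reasoning` is a hypothetical helper and the sample string is a placeholder, not real model output:

```python
def split_reasoning(generated_text):
    """Split generated text into (reasoning, answer) around the </think> tag.

    Returns an empty reasoning string when no closing tag is present.
    """
    head, sep, tail = generated_text.partition("</think>")
    if not sep:
        return "", generated_text.strip()
    # Drop the opening tag, if the model emitted one, and trim whitespace.
    reasoning = head.replace("<think>", "", 1).strip()
    return reasoning, tail.strip()


# Placeholder text shaped like the training template.
sample = "<think>\nFirst consider the findings.\n</think>\n\nFinal answer goes here."
reasoning, answer = split_reasoning(sample)
print("Reasoning:", reasoning)
print("Answer:", answer)
```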

#### Delete Endpoint

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-sft-djl"

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
predictor.delete_model()
predictor.delete_endpoint(delete_endpoint_config=True)