# Fine-tune an LLM with PyTorch FSDP and QLoRA on Amazon SageMaker AI using ModelTrainer

In this notebook, we fine-tune an LLM on Amazon SageMaker AI, using Python scripts and the SageMaker `ModelTrainer` class to run a training job.

## Prerequisites

In [None]:
%pip install -r ./scripts/requirements.txt --upgrade

## This cell will restart the kernel. Click "OK".

In [None]:
from IPython import get_ipython
get_ipython().kernel.do_shutdown(True)

***

## Setup Configuration file path

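Create the SageMaker session and S3 client, and read the region, default bucket, and default prefix used throughout the notebook. If you have created a managed MLflow tracking server, its name and the experiment name are set in the cells that follow.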

In [None]:
import boto3
import shutil
from sagemaker.core.helper.session_helper import Session, get_execution_role
from sagemaker.core.config import load_sagemaker_config

sagemaker_session = Session()
s3_client = boto3.client('s3')

region = sagemaker_session.boto_session.region_name
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
configs = load_sagemaker_config()

If you have your own MLflow tracking server, update the `TrackingServerName` value below to enable experiment tracking.

In [None]:
from botocore.exceptions import ClientError

try:
    response = boto3.client('sagemaker').describe_mlflow_tracking_server(
        TrackingServerName='genai-mlflow-tracker'
    )
    mlflow_tracking_server_uri = response['TrackingServerArn']
except ClientError:
    mlflow_tracking_server_uri = ""

if mlflow_tracking_server_uri == "":
    print("No MLflow Tracking Server Found, experiments will not be tracked.")
else:
    print(f"MLflow Tracking Server ARN: {mlflow_tracking_server_uri}")

In [None]:
import os

os.environ["mlflow_uri"] = mlflow_tracking_server_uri
os.environ["mlflow_experiment_name"] = "Qwen3-4B-Instruct-2507-sft"

***

## Visualize and upload the dataset

We load the first `num_samples` English samples of the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset.

In [None]:
from datasets import load_dataset
import pandas as pd

num_samples = 100

full_dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split=f"train[:{num_samples}]")

full_dataset[0]

In [None]:
train_test_split_datasets = full_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split_datasets["train"]
test_dataset = train_test_split_datasets["test"]

print(f"Number of train elements: {len(train_dataset)}")
print(f"Number of test elements: {len(test_dataset)}")

Create a prompt template that converts each sample into a chat-style list of messages (system, user, assistant).

In [None]:
SYSTEM_PROMPT = """You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response."""


# template dataset to add the system prompt to each sample
def convert_to_messages(sample, system_prompt=""):
    # use the system_prompt argument (passed via fn_kwargs) rather than the global
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sample["Question"]},
        {"role": "assistant", "content": f"{sample['Complex_CoT']}\n\n{sample['Response']}"}
    ]

    sample["messages"] = messages

    return sample
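
To make the resulting structure concrete, here is a minimal sketch that applies the same mapping to a hand-written record. The record and the helper name `to_messages` are hypothetical (only the column names `Question`, `Complex_CoT`, and `Response` come from the dataset), and the function is re-declared so the snippet runs on its own.

```python
# Illustrative only: a stand-in for convert_to_messages applied to a made-up record.
def to_messages(sample, system_prompt=""):
    """Return a copy of `sample` with a chat-style `messages` list added."""
    out = dict(sample)
    out["messages"] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sample["Question"]},
        {"role": "assistant", "content": f"{sample['Complex_CoT']}\n\n{sample['Response']}"},
    ]
    return out


# Hypothetical record that reuses the dataset's column names.
example = {
    "Question": "Example question text",
    "Complex_CoT": "Example reasoning text",
    "Response": "Example answer text",
}

converted = to_messages(example, system_prompt="You are a medical expert.")
for message in converted["messages"]:
    print(message["role"], "->", repr(message["content"]))
```

The assistant turn concatenates the chain of thought and the final response, separated by a blank line, so the model is trained to reason before answering.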

Apply the template to the train and test splits, dropping the original columns, and print a random sample from each.

In [None]:
from random import randint

train_dataset = train_dataset.map(convert_to_messages, remove_columns=list(full_dataset.features), fn_kwargs={"system_prompt": SYSTEM_PROMPT})
test_dataset = test_dataset.map(convert_to_messages, remove_columns=list(full_dataset.features), fn_kwargs={"system_prompt": SYSTEM_PROMPT})

#grab a sample from the training and test sets
print(f"Train Sample:\n{train_dataset[randint(0, len(train_dataset)-1)]}\n\n")
print(f"Test Sample:\n{test_dataset[randint(0, len(test_dataset)-1)]}\n\n")

### Upload to Amazon S3

In [None]:
# build the S3 key prefix for the datasets, honoring the session's default prefix if set
if default_prefix:
    input_path = f'{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft'
else:
    input_path = 'datasets/llm-fine-tuning-modeltrainer-sft'

# Save datasets to s3
train_dataset.to_json("./data/train/dataset.json", orient="records")
test_dataset.to_json("./data/test/dataset.json", orient="records")

s3_client.upload_file("./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json")
train_dataset_s3_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
s3_client.upload_file("./data/test/dataset.json", bucket_name, f"{input_path}/test/dataset.json")
test_dataset_s3_path = f"s3://{bucket_name}/{input_path}/test/dataset.json"

shutil.rmtree("./data")

print("Train and test datasets uploaded to:")
print(train_dataset_s3_path)
print(test_dataset_s3_path)
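
The prefix handling above recurs several times in this notebook (datasets, model, config, output). A minimal sketch of a helper that centralizes it follows; `build_s3_uri` and the bucket and prefix values are hypothetical names used for illustration, not part of the SageMaker SDK.

```python
def build_s3_uri(bucket, key_suffix, prefix=None):
    """Join bucket, optional prefix, and key suffix into an s3:// URI."""
    # Skip the prefix when it is None or empty, and avoid doubled slashes.
    parts = [part.strip("/") for part in (prefix, key_suffix) if part]
    return f"s3://{bucket}/{'/'.join(parts)}"


# Hypothetical bucket and prefix values.
suffix = "datasets/llm-fine-tuning-modeltrainer-sft/train/dataset.json"
print(build_s3_uri("example-bucket", suffix))
print(build_s3_uri("example-bucket", suffix, prefix="team-a"))
```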

In [None]:
from utils import plot_length_distribution

plot_length_distribution(
    train_dataset=train_dataset,
    validation_dataset=test_dataset,
    bins=20,
    figsize=(10, 6)
)

***

## Model fine-tuning

We are now ready to fine-tune our model.

In [None]:
model_id = "Qwen/Qwen3-4B-Instruct-2507"
model_id_filesafe = model_id.replace("/","_")

use_local_model = True  # True: download the model locally and upload it to S3; False: the training job downloads it from the Hugging Face Hub

In [None]:
from huggingface_hub import snapshot_download
import os
import subprocess

if use_local_model:

    model_local_location = f"../models/{model_id_filesafe}"
    print("Downloading model ", model_id)
    os.makedirs(model_local_location, exist_ok=True)
    snapshot_download(repo_id=model_id, local_dir=model_local_location)
    print(f"Model {model_id} downloaded under {model_local_location}")

    if default_prefix:
        model_s3_destination = f"s3://{bucket_name}/{default_prefix}/models/{model_id_filesafe}"
    else:
        model_s3_destination = f"s3://{bucket_name}/models/{model_id_filesafe}"
    
    print("Beginning model upload...")

    # check=True raises CalledProcessError if the AWS CLI copy fails
    subprocess.run(['aws', 's3', 'cp', model_local_location, model_s3_destination, '--recursive', '--exclude', '.cache/*', '--exclude', '.gitattributes'], check=True)

    print(f"Model uploaded to: \n {model_s3_destination}")

    os.environ["model_location"] = model_s3_destination
else:
    os.environ["model_location"] = model_id

In [None]:
%%bash

cat > ./args.yaml <<EOF

# MLflow Config
mlflow_uri: "${mlflow_uri}"
mlflow_experiment_name: "${mlflow_experiment_name}"


model_id: "${model_location}"       # Hugging Face model id, or S3 location

# sagemaker specific parameters
output_dir: "/opt/ml/model"                       # path SageMaker uploads to S3 as the model artifact
train_dataset_path: "/opt/ml/input/data/train/"   # path where SageMaker places the train channel data
test_dataset_path: "/opt/ml/input/data/test/"     # path where SageMaker places the test channel data
# training parameters
max_seq_length: 1500                   # maximum sequence length (e.g. 512 or 2048)
lora_r: 8                              # LoRA rank
lora_alpha: 16                         # LoRA scaling factor
lora_dropout: 0.1                      # dropout applied to the LoRA layers
learning_rate: 2e-4                    # learning rate
num_train_epochs: 1                    # number of training epochs
per_device_train_batch_size: 2         # batch size per device during training
per_device_eval_batch_size: 1          # batch size for evaluation
gradient_accumulation_steps: 2         # number of batches accumulated before each optimizer update
gradient_checkpointing: true           # use gradient checkpointing
fp16: true                             # use float16 mixed precision
bf16: false                            # set to true (and fp16 to false) to use bfloat16 precision
tf32: false                            # use tf32 precision

merge_weights: true                    # merge weights in the base model
EOF
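
As a rough sanity check on these settings, the sketch below works out the effective batch size, an approximate number of optimizer steps, and the LoRA scaling factor. It assumes a single GPU (the `ml.g5.2xlarge` instance used later in this notebook has one) and 90 training samples (100 loaded samples minus the 10% test split). The exact step count depends on how the trainer handles the last partial accumulation and on whether sequences are packed.

```python
import math

# Values copied from args.yaml.
per_device_train_batch_size = 2
gradient_accumulation_steps = 2
num_train_epochs = 1
lora_r, lora_alpha = 8, 16

# Assumptions for this sketch: one GPU and 90 training samples.
num_gpus = 1
num_train_samples = 90

# Samples that contribute to each optimizer update.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus

# Upper bound on optimizer updates per epoch (a partial last batch counts as one step).
steps_per_epoch = math.ceil(num_train_samples / effective_batch_size)
total_optimizer_steps = steps_per_epoch * num_train_epochs

# Standard LoRA scales the low-rank update by alpha / r.
lora_scaling = lora_alpha / lora_r

print(f"effective batch size: {effective_batch_size}")
print(f"approx. optimizer steps: {total_optimizer_steps}")
print(f"LoRA scaling (alpha / r): {lora_scaling}")
```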

Let's upload the config file to S3.

In [None]:
from sagemaker.core.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/training_config/{model_id_filesafe}"
else:
    input_path = f"s3://{bucket_name}/training_config/{model_id_filesafe}"

# upload the model yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

print(f"Training config uploaded to:")
print(train_config_s3_path)

## Fine-tune model

The cells below train the model with QLoRA, merge the adapter into the base model, and save the result to S3.
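
The training script itself (`train.py`) is not shown here, so the following is only a conceptual sketch of what "merging the adapter" means in standard LoRA: the low-rank update `B @ A`, scaled by `alpha / r`, is added to the frozen base weight. The matrices are tiny made-up values and the helper names are hypothetical. A real merge would typically go through a library such as PEFT and operate on tensors.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha, r):
    """Return W + (alpha / r) * (B @ A), the merged dense weight."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy shapes: W is (2, 3), A is (r=1, 3), B is (2, r=1).
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # frozen base weight
A = [[0.5, 0.0, 1.0]]                    # LoRA down-projection
B = [[1.0], [2.0]]                       # LoRA up-projection

merged = merge_lora(W, A, B, alpha=2, r=1)
print(merged)
```

After merging, the adapter matrices are no longer needed at inference time, which is why the deployment section can load a single set of weights.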

#### Get PyTorch image_uri

In [None]:
instance_type = "ml.g5.2xlarge"

instance_type

In [None]:
from sagemaker.core import image_uris

image_uri = image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.6.0",
    instance_type=instance_type,
    image_scope="training"
)

image_uri

In [None]:
from sagemaker.train.model_trainer import ModelTrainer, InputData, Torchrun, StoppingCondition
from sagemaker.core.training.configs import Compute, SourceCode
from sagemaker.core.shapes import OutputDataConfig

# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="train.py",
)

# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=1,
    keep_alive_period_in_seconds=3600,
    volume_size_in_gb=50
)

# define Training Job Name 
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-script"

# define OutputDataConfig path
if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{job_name}"
else:
    output_path = f"s3://{bucket_name}/{job_name}"

# Define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    distributed=Torchrun(),
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=7200
    ),
    hyperparameters={
        "config": "/opt/ml/input/data/config/args.yaml"
    },
    output_data_config=OutputDataConfig(
        s3_output_path=output_path
    ),
    environment={
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True"
    }
)
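
The `replace('.', '-')` in the job name matters because SageMaker training job names may contain only alphanumeric characters and hyphens, and are limited to 63 characters. The sketch below checks a candidate name against that documented pattern. `is_valid_job_name` is a hypothetical helper, and the second model name is a made-up example containing a dot. The unique suffix the SDK appends to `base_job_name` also counts toward the length limit.

```python
import re

# Pattern and length limit documented for TrainingJobName in the CreateTrainingJob API.
NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def is_valid_job_name(name):
    """Return True if `name` satisfies the training job naming rules."""
    return bool(NAME_PATTERN.match(name))


model_id = "Qwen/Qwen3-4B-Instruct-2507"
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-script"

print(job_name, len(job_name), is_valid_job_name(job_name))

# A hypothetical name with a dot left in place is rejected.
print(is_valid_job_name("train-Example2.5-7B-sft-script"))
```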

In [None]:
# Pass the input data
train_input = InputData(
    channel_name="train",
    data_source=train_dataset_s3_path,
)

test_input = InputData(
    channel_name="test",
    data_source=test_dataset_s3_path,
)

config_input = InputData(
    channel_name="config",
    data_source=train_config_s3_path,
)

# Check input channels configured
data = [train_input, test_input, config_input]
data
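
Each channel name determines where SageMaker places the data inside the training container: `/opt/ml/input/data/<channel_name>`. This is why `args.yaml` points at `/opt/ml/input/data/train/` and the `config` hyperparameter points at `/opt/ml/input/data/config/args.yaml`. A minimal sketch of that mapping (the helper `channel_dir` is hypothetical):

```python
def channel_dir(channel_name, base="/opt/ml/input/data"):
    """Return the container directory for a SageMaker input channel."""
    return f"{base}/{channel_name}"


channels = ["train", "test", "config"]
paths = {name: channel_dir(name) for name in channels}

# The config file keeps its file name inside the channel directory.
config_file = f"{paths['config']}/args.yaml"

for name, path in paths.items():
    print(f"{name:>6} -> {path}")
print(f"config file -> {config_file}")
```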

In [None]:
# starting the train job with our uploaded datasets as input
model_trainer.train(input_data_config=data, wait=True)

***

# Model Deployment

In the following sections, we are going to deploy the fine-tuned model on an Amazon SageMaker Real-time endpoint.

## Load Fine-Tuned model

In [None]:
from utils import get_last_job_name

job_prefix = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-script"

job_name = get_last_job_name(job_prefix)

job_name

#### Inference configurations

In [None]:
instance_count = 1
instance_type = "ml.g5.2xlarge"
health_check_timeout = 700

In [None]:
inference_image_uri = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.33.0-lmi15.0.0-cu128"
print(f"using image to host: {inference_image_uri}")

In [None]:
import json
from sagemaker.core.resources import Model, Endpoint, EndpointConfig
from sagemaker.core.shapes import ContainerDefinition, ProductionVariant

role = get_execution_role(sagemaker_session, use_default=True)

if default_prefix:
    model_data=f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz"
else:
    model_data=f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz"

deploy_model_name = f"Qwen3-4B-sft-{job_name[-8:]}"

core_model = Model.create(
    model_name=deploy_model_name,
    execution_role_arn=role,
    primary_container=ContainerDefinition(
        image=inference_image_uri,
        model_data_url=model_data,
        environment={
            'HF_MODEL_ID': "/opt/ml/model",
            'OPTION_TRUST_REMOTE_CODE': 'true',
            'OPTION_ROLLING_BATCH': "vllm",
            'OPTION_DTYPE': 'bf16',
            'OPTION_QUANTIZE': 'fp8',
            'OPTION_TENSOR_PARALLEL_DEGREE': 'max',
            'OPTION_MAX_ROLLING_BATCH_SIZE': '32',
            'OPTION_MODEL_LOADING_TIMEOUT': '3600',
            'OPTION_MAX_MODEL_LEN': '4096'
        }
    ),
)
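
The `OPTION_*` environment variables follow the LMI container's naming convention: `OPTION_<NAME>` corresponds to `option.<name>` in a `serving.properties` file, and `HF_MODEL_ID` corresponds to `option.model_id`. The sketch below only illustrates that convention. `env_to_serving_properties` is a hypothetical helper, not the container's actual code.

```python
def env_to_serving_properties(env):
    """Translate LMI-style environment variables into serving.properties keys."""
    props = {}
    for key, value in env.items():
        if key.startswith("OPTION_"):
            # OPTION_MAX_MODEL_LEN -> option.max_model_len
            props["option." + key[len("OPTION_"):].lower()] = value
        elif key == "HF_MODEL_ID":
            props["option.model_id"] = value
    return props


# A subset of the environment defined above.
environment = {
    "HF_MODEL_ID": "/opt/ml/model",
    "OPTION_ROLLING_BATCH": "vllm",
    "OPTION_MAX_MODEL_LEN": "4096",
}

for key, value in env_to_serving_properties(environment).items():
    print(f"{key}={value}")
```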

In [None]:
from sagemaker.core.common_utils import name_from_base
from sagemaker.core.helper.session_helper import _wait_until, _deploy_done

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-sft"
TUNED_ENDPOINT_NAME = name_from_base(endpoint_name)

EndpointConfig.create(
    endpoint_config_name=TUNED_ENDPOINT_NAME,
    production_variants=[
        ProductionVariant(
            variant_name="AllTraffic",
            model_name=deploy_model_name,
            initial_instance_count=instance_count,
            instance_type=instance_type,
            container_startup_health_check_timeout_in_seconds=health_check_timeout,
            model_data_download_timeout_in_seconds=3600,
        )
    ],
)

core_endpoint = Endpoint.create(
    endpoint_name=TUNED_ENDPOINT_NAME,
    endpoint_config_name=TUNED_ENDPOINT_NAME,
)

_wait_until(lambda: _deploy_done(sagemaker_session.sagemaker_client, TUNED_ENDPOINT_NAME), poll=30)
core_endpoint = Endpoint.get(endpoint_name=TUNED_ENDPOINT_NAME)
print(f"Endpoint status: {core_endpoint.endpoint_status}")
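
The wait above relies on private SDK helpers (names starting with an underscore), which may change between releases. As an alternative, here is a minimal polling sketch built only on a status callback. `wait_for_status` is a hypothetical helper, and the simulated status sequence stands in for real calls such as `Endpoint.get(endpoint_name=TUNED_ENDPOINT_NAME).endpoint_status`.

```python
import time


def wait_for_status(get_status, done=("InService",), failed=("Failed",),
                    poll=30, timeout=3600, sleep=time.sleep):
    """Poll get_status() until it returns a done or failed status, or time runs out."""
    waited = 0
    while True:
        status = get_status()
        if status in done:
            return status
        if status in failed:
            raise RuntimeError(f"deployment ended with status {status}")
        if waited >= timeout:
            raise TimeoutError(f"still {status} after {waited} seconds")
        sleep(poll)
        waited += poll


# Simulated status sequence so the sketch runs without an endpoint.
statuses = iter(["Creating", "Creating", "InService"])
final_status = wait_for_status(lambda: next(statuses), poll=0, sleep=lambda seconds: None)
print(f"Endpoint status: {final_status}")
```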

#### Predict

In [None]:
SYSTEM_PROMPT = """You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. 
Below is an instruction that describes a task, paired with an input that provides further context. 
Write a response that appropriately completes the request.
Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response."""

USER_PROMPT = "A 3-week-old child has been diagnosed with late onset perinatal meningitis, and the CSF culture shows gram-positive bacilli. What characteristic of this bacterium can specifically differentiate it from other bacterial agents?"

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": USER_PROMPT},
]

messages

In [None]:
payload = json.dumps({
    "messages": messages,
    # the response is parsed below as a chat completion, so sampling
    # parameters use the chat completions field names at the top level
    "temperature": 0.2,
    "top_p": 0.9,
    "max_tokens": 1024,
})

response = core_endpoint.invoke(
    body=payload,
    content_type="application/json",
    accept="application/json",
)

result = json.loads(response.body.read().decode("utf-8"))
result["choices"][0]["message"]["content"]
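
Indexing `result["choices"][0]` directly raises an unhelpful `KeyError` or `IndexError` if the endpoint returns an error body instead of a completion. A small defensive sketch follows. `extract_answer` is a hypothetical helper, and the sample response is made up to mirror the structure indexed above.

```python
def extract_answer(result):
    """Return the assistant text from a chat-completions style response dict."""
    choices = result.get("choices") or []
    if not choices:
        raise ValueError(f"response contains no choices: {result}")
    return choices[0]["message"]["content"]


# Hypothetical response with the same shape as the one parsed above.
sample_result = {
    "choices": [
        {"message": {"role": "assistant", "content": "Example answer text"}}
    ]
}

print(extract_answer(sample_result))
```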

### Store variables

Save the endpoint name so it can be restored later with `%store -r TUNED_ENDPOINT_NAME`.

In [None]:
%store TUNED_ENDPOINT_NAME