# Fine-tune an LLM with PyTorch FSDP and QLoRA on Amazon SageMaker AI using ModelTrainer

In this notebook, we fine-tune an LLM on Amazon SageMaker AI, using Python scripts and the SageMaker ModelTrainer to run a training job.

## Prerequisites

In [None]:
%pip install -r ./scripts/requirements.txt --upgrade

***

## Set up configuration

If you have created a managed MLflow tracking server, paste its ARN into `mlflow_uri` below and give the experiment a name.

In [None]:
import os

model_id = "Qwen/Qwen3-0.6B"

os.environ["model_id"] = model_id
os.environ["mlflow_uri"] = ""
os.environ["mlflow_experiment_name"] = "qwen3-06b-fsdp"

***

## Visualize and upload the dataset

We load the [glaiveai/glaive-function-calling-v2](https://huggingface.co/datasets/glaiveai/glaive-function-calling-v2) dataset and keep its first 10,000 rows.

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
from datasets import load_dataset

dataset = load_dataset("glaiveai/glaive-function-calling-v2", split="train[:10000]")

In [None]:
from utils.preprocessing import glaive_to_openai

processed_dataset = glaive_to_openai(dataset)

In [None]:
import pandas as pd

df = pd.DataFrame(processed_dataset)

df.head()

In [None]:
from sklearn.model_selection import train_test_split

train, val = train_test_split(df, test_size=0.1, random_state=42)

print("Number of train elements: ", len(train))
print("Number of validation elements: ", len(val))

Apply the model's chat template to every conversation, then print a random training sample to inspect the formatted text.

In [None]:
import json
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

def prepare_dataset(sample):
    # Parse tools only once if needed
    tools = json.loads(sample["tools"]) if sample["tools"] is not None else []

    # Define message transformations based on role
    messages = []
    for message in sample["messages"]:
        role = message["role"]

        # Base message with role and content
        msg = {"role": role}

        if role in ["system", "user"]:
            # Simple roles just need content
            msg["content"] = message["content"]
        elif role == "assistant":
            if message["tool_calls"]:
                msg.update({"content": "", "tool_calls": message["tool_calls"]})
            else:
                msg["content"] = message["content"]
        elif role == "tool":
            # Tool messages need additional fields
            msg.update(
                {
                    "content": message["content"],
                    "tool_call_id": message["tool_call_id"],
                    "name": message["name"],
                }
            )

        messages.append(msg)

    # Apply chat template
    sample["text"] = tokenizer.apply_chat_template(
        messages, tools=tools, tokenize=False
    )

    return sample
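
The function above reads `sample["tools"]` as a JSON string and `sample["messages"]` as a list of role-tagged dictionaries. The sketch below shows a record in that layout. It is inferred only from the fields `prepare_dataset` accesses; the values are made up, and the real output of `glaive_to_openai` may carry additional fields.

```python
import json

# Hypothetical record: field names follow what prepare_dataset reads,
# values are illustrative only.
sample = {
    "tools": json.dumps([
        {
            "type": "function",
            "function": {
                "name": "add",
                "description": "Add two numbers",
                "parameters": {
                    "type": "object",
                    "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
                    "required": ["a", "b"],
                },
            },
        }
    ]),
    "messages": [
        {"role": "user", "content": "What is 2 + 2?", "tool_calls": None},
        {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_0",
                    "type": "function",
                    "function": {"name": "add", "arguments": '{"a": 2, "b": 2}'},
                }
            ],
        },
        {"role": "tool", "content": "4", "tool_call_id": "call_0", "name": "add"},
        {"role": "assistant", "content": "2 + 2 = 4", "tool_calls": None},
    ],
}

# Same parsing rule as prepare_dataset: a missing tools column becomes an empty list
tools = json.loads(sample["tools"]) if sample["tools"] is not None else []
roles = [m["role"] for m in sample["messages"]]
# Assistant turns that call a tool are the ones that get empty text content
tool_call_turns = [
    m for m in sample["messages"] if m["role"] == "assistant" and m["tool_calls"]
]

print(len(tools), roles, len(tool_call_turns))
```

Only the assistant turn with a non-empty `tool_calls` list takes the branch that sets `content` to an empty string; the final assistant turn keeps its text.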

In [None]:
from datasets import Dataset, DatasetDict
from random import randint

train_dataset = Dataset.from_pandas(train)
val_dataset = Dataset.from_pandas(val)

columns_to_remove = list(processed_dataset.features)

dataset = DatasetDict({"train": train_dataset, "val": val_dataset})

train_dataset = dataset["train"].map(
    prepare_dataset, remove_columns=list(train_dataset.features)
)

# randint is inclusive on both ends, so the upper bound is the last valid index
print(train_dataset[randint(0, len(train_dataset) - 1)]["text"])

val_dataset = dataset["val"].map(
    prepare_dataset, remove_columns=list(val_dataset.features)
)

### Upload to Amazon S3

In [None]:
import boto3
import shutil
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
s3_client = boto3.client('s3')

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
# save train_dataset to s3 using our SageMaker session
if default_prefix:
    input_path = f"{default_prefix}/datasets/llm-fine-tuning-modeltrainer-fsdp"
else:
    input_path = f"datasets/llm-fine-tuning-modeltrainer-fsdp"

train_dataset_s3_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
val_dataset_s3_path = f"s3://{bucket_name}/{input_path}/val/dataset.json"
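
The same "prefix or no prefix" branching recurs several times in this notebook. One way to centralize it is sketched below; `build_s3_uri` is a hypothetical helper, not part of the SageMaker SDK, and the bucket and prefix values are placeholders.

```python
from typing import Optional


def build_s3_uri(bucket: str, *parts: str, prefix: Optional[str] = None) -> str:
    """Join an optional bucket prefix and path parts into an s3:// URI."""
    segments = [prefix, *parts] if prefix else list(parts)
    # Strip stray slashes so "team-a/" and "team-a" produce the same key
    key = "/".join(s.strip("/") for s in segments if s)
    return f"s3://{bucket}/{key}"


base = "datasets/llm-fine-tuning-modeltrainer-fsdp"
print(build_s3_uri("my-bucket", base, "train/dataset.json"))
print(build_s3_uri("my-bucket", base, "val/dataset.json", prefix="team-a/"))
```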

In [None]:
# Save the datasets locally as JSON, then upload them to S3
train_dataset.to_json("./data/train/dataset.json", orient="records")
val_dataset.to_json("./data/val/dataset.json", orient="records")

s3_client.upload_file("./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json")
s3_client.upload_file("./data/val/dataset.json", bucket_name, f"{input_path}/val/dataset.json")

shutil.rmtree("./data")

print(f"Training data uploaded to:")
print(train_dataset_s3_path)
print(val_dataset_s3_path)

***

## Model fine-tuning

We are now ready to fine-tune our model. We will use the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) from transformers to fine-tune our model. We prepared a script, [train.py](./scripts/train.py), which loads the dataset from disk, prepares the model and tokenizer, and starts the training.

For configuration we use `TrlParser`, which lets us provide hyperparameters in a `yaml` file. This file is uploaded to S3 and passed to the training job as an input channel, in the same way as our datasets. We save the config file as `args.yaml` and upload it to S3.

In [None]:
%%bash

cat > ./args.yaml <<EOF
model_id: "${model_id}"                           # Hugging Face model id
mlflow_uri: "${mlflow_uri}"                       # MLflow tracking server URI
mlflow_experiment_name: "${mlflow_experiment_name}" # MLflow experiment name
# sagemaker specific parameters
output_dir: "/opt/ml/model"                       # path to where SageMaker will upload the model 
checkpoint_dir: "/opt/ml/checkpoints/"            # directory for saving training checkpoints
train_dataset_path: "/opt/ml/input/data/train/"   # path where SageMaker places the train channel
val_dataset_path: "/opt/ml/input/data/val/"       # path where SageMaker places the val channel
token: "${HF_TOKEN}"                              # Hugging Face API token
merge_weights: true                               # merge weights in the base model
# training parameters
apply_truncation: true                           # apply truncation to datasets
attn_implementation: "flash_attention_2"         # attention implementation type
learning_rate: 5e-5                              # learning rate
num_train_epochs: 1                              # number of training epochs
per_device_train_batch_size: 4                   # batch size per device during training
per_device_eval_batch_size: 2                    # batch size for evaluation
gradient_accumulation_steps: 4                   # number of steps before performing a backward/update pass
gradient_checkpointing: true                     # use gradient checkpointing
torch_dtype: "bfloat16"                          # float precision type
bf16: true                                       # use bfloat16 precision
tf32: true                                       # use tf32 precision
ignore_data_skip: true                           # when resuming, do not fast-forward the data loader
logging_strategy: "steps"                        # logging strategy
logging_steps: 1                                 # log every N steps
log_on_each_node: false                          # disable logging on each node
ddp_find_unused_parameters: false                # DDP unused parameter detection
save_total_limit: 1                              # maximum number of checkpoints to keep
save_steps: 100                                  # Save checkpoint every this many steps
warmup_steps: 100                                # number of warmup steps
weight_decay: 0.01                               # weight decay coefficient
fsdp: "full_shard auto_wrap offload"             # FSDP sharding strategy
fsdp_config:                                     # FSDP configuration options
    backward_prefetch: "backward_pre"            # prefetch parameters during backward pass
    cpu_ram_efficient_loading: true              # enable CPU RAM efficient model loading
    offload_params: true                         # offload parameters to CPU
    forward_prefetch: false                      # disable forward prefetch
    use_orig_params: true                        # use original parameter names
# LoRA parameters
load_in_4bit: true                               # enable 4-bit quantization
lora_r: 32                                       # LoRA rank
lora_alpha: 64                                   # LoRA alpha parameter
lora_dropout: 0.03                               # LoRA dropout rate
EOF
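
To see how these settings interact, it can help to estimate the effective batch size and the number of optimizer steps per epoch. This is a back-of-the-envelope sketch, not the exact count the trainer will report. It assumes 4 GPUs on the training instance and that all 9,000 training rows (90% of 10,000) survive preprocessing.

```python
import math

per_device_train_batch_size = 4
gradient_accumulation_steps = 4
num_gpus = 4            # assumption: one ml.g6.12xlarge instance with 4 GPUs
train_examples = 9_000  # assumption: no rows dropped by preprocessing

# Samples consumed per optimizer update across all devices
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
# Approximate optimizer updates in one pass over the data
steps_per_epoch = math.ceil(train_examples / effective_batch)

print(effective_batch, steps_per_epoch)
```

Under these assumptions the run has roughly 141 optimizer steps, so `warmup_steps: 100` covers most of the single epoch and `save_steps: 100` produces one intermediate checkpoint. Lowering `warmup_steps` may be worth considering for short runs.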

Let's upload the config file to S3.

In [None]:
import os
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/datasets/llm-fine-tuning-modeltrainer-fsdp"
else:
    input_path = f"s3://{bucket_name}/datasets/llm-fine-tuning-modeltrainer-fsdp"

# upload the model yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

os.remove("./args.yaml")

print(f"Training config uploaded to:")
print(train_config_s3_path)

## Fine-tune model

The `ModelTrainer` below trains the model with QLoRA, merges the adapter into the base model, and saves the result to S3.

#### Get PyTorch image_uri

We are going to use the native PyTorch container image, pre-built for Amazon SageMaker.

In [None]:
import sagemaker
from sagemaker.config import load_sagemaker_config

In [None]:
sagemaker_session = sagemaker.Session()

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
configs = load_sagemaker_config()

In [None]:
instance_type = "ml.g6.12xlarge" # Override the instance type if you want to get a different container version
instance_count = 1

instance_type

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.7.1",
    instance_type=instance_type,
    image_scope="training"
)

image_uri

In [None]:
from sagemaker.modules.configs import (
    CheckpointConfig,
    Compute,
    OutputDataConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.train import ModelTrainer

# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="train.py",
)

# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=instance_count,
    keep_alive_period_in_seconds=0,
)

# define Training Job Name
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-fsdp"

# define OutputDataConfig path
if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{job_name}"
else:
    output_path = f"s3://{bucket_name}/{job_name}"

# Define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    distributed=Torchrun(),
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    hyperparameters={
        "config": "/opt/ml/input/data/config/args.yaml"  # path to TRL config which was uploaded to s3
    },
    output_data_config=OutputDataConfig(s3_output_path=output_path),
    checkpoint_config=CheckpointConfig(
        s3_uri=output_path + "/checkpoint", local_path="/opt/ml/checkpoints"
    ),
)
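
The job name above replaces the dot in the model name because SageMaker training job names accept only alphanumeric characters and hyphens. The sketch below checks a derived name against that pattern. `make_job_name` is a hypothetical wrapper around the expression used in the cell, and the SDK may append its own unique suffix to `base_job_name`, so the final name can be longer.

```python
import re

# Naming rule for training jobs: alphanumerics separated by hyphens, at most 63 characters
JOB_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def make_job_name(model_id: str) -> str:
    """Derive the base job name the same way the cell above does."""
    short = model_id.split("/")[-1].replace(".", "-")
    return f"train-{short}-fsdp"


job_name = make_job_name("Qwen/Qwen3-0.6B")
print(job_name, bool(JOB_NAME_PATTERN.match(job_name)))

# Without the replacement, the dot makes the name invalid
print(bool(JOB_NAME_PATTERN.match("train-Qwen3-0.6B-fsdp")))
```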

In [None]:
from sagemaker.modules.configs import InputData

# Pass the input data
train_input = InputData(
    channel_name="train",
    data_source=train_dataset_s3_path, # S3 path where training data is stored
)

val_input = InputData(
    channel_name="val",
    data_source=val_dataset_s3_path, # S3 path where validation data is stored
)

config_input = InputData(
    channel_name="config",
    data_source=train_config_s3_path, # S3 path where the training config is stored
)

# Check input channels configured
data = [train_input, val_input, config_input]
data

In [None]:
# starting the train job with our uploaded datasets as input
model_trainer.train(input_data_config=data, wait=False)

***

# Model Deployment

In the following sections, we are going to deploy the fine-tuned model on an Amazon SageMaker Real-time endpoint.

## Load Fine-Tuned model

In [None]:
import boto3
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "Qwen/Qwen3-0.6B"

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
job_prefix = f"train-{model_id.split('/')[-1].replace('.', '-')}-fsdp"

In [None]:
def get_last_job_name(job_name_prefix):
    sagemaker_client = boto3.client('sagemaker')

    matching_jobs = []
    next_token = None

    while True:
        # Prepare the search parameters
        search_params = {
            'Resource': 'TrainingJob',
            'SearchExpression': {
                'Filters': [
                    {
                        'Name': 'TrainingJobName',
                        'Operator': 'Contains',
                        'Value': job_name_prefix
                    },
                    {
                        'Name': 'TrainingJobStatus',
                        'Operator': 'Equals',
                        'Value': "Completed"
                    }
                ]
            },
            'SortBy': 'CreationTime',
            'SortOrder': 'Descending',
            'MaxResults': 100
        }

        # Add NextToken if we have one
        if next_token:
            search_params['NextToken'] = next_token

        # Make the search request
        search_response = sagemaker_client.search(**search_params)

        # Filter and add matching jobs
        matching_jobs.extend([
            job['TrainingJob']['TrainingJobName'] 
            for job in search_response['Results']
            if job['TrainingJob']['TrainingJobName'].startswith(job_name_prefix)
        ])

        # Check if we have more results to fetch
        next_token = search_response.get('NextToken')
        if not next_token or matching_jobs:  # Stop if we found at least one match or no more results
            break

    if not matching_jobs:
        raise ValueError(f"No completed training jobs found starting with prefix '{job_name_prefix}'")

    return matching_jobs[0]
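
The function combines a server-side `Contains` filter with a client-side `startswith` check. The second check matters because `Contains` also matches names that merely embed the prefix. The sketch below reproduces only the client-side selection on a mocked result list; `pick_latest` and the job names are hypothetical.

```python
def pick_latest(results, prefix):
    """Return the first job name that starts with prefix, or None.

    results is assumed to be sorted newest first, as in the search call above.
    """
    names = [r["TrainingJob"]["TrainingJobName"] for r in results]
    matches = [name for name in names if name.startswith(prefix)]
    return matches[0] if matches else None


# Mocked search results, newest first; names are made up for illustration
fake_results = [
    {"TrainingJob": {"TrainingJobName": "exp-train-Qwen3-0-6B-fsdp-20250102"}},
    {"TrainingJob": {"TrainingJobName": "train-Qwen3-0-6B-fsdp-20250101"}},
]

# The newest job contains the prefix but does not start with it, so it is skipped
print(pick_latest(fake_results, "train-Qwen3-0-6B-fsdp"))
```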

In [None]:
job_name = get_last_job_name(job_prefix)

job_name

#### Inference configurations

In [None]:
import sagemaker
from sagemaker import get_execution_role
from sagemaker import Model

In [None]:
instance_count = 1
instance_type = "ml.g5.xlarge"
number_of_gpu = 1
health_check_timeout = 700

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="djl-lmi",
    region=sagemaker_session.boto_session.region_name,
    version="latest"
)

image_uri = image_uri.split("/")[0] + "/djl-inference:0.33.0-lmi15.0.0-cu128"

image_uri

In [None]:
if default_prefix:
    model_data=f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz"
else:
    model_data=f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz"

model = Model(
    image_uri=image_uri,
    model_data=model_data,
    role=get_execution_role(),
    env={
        "HF_MODEL_ID": "/opt/ml/model",  # path to where sagemaker stores the model
        "OPTION_TRUST_REMOTE_CODE": "true",
        "OPTION_ROLLING_BATCH": "vllm",
        "OPTION_DTYPE": "bf16",
        "OPTION_QUANTIZE": "fp8",
        "OPTION_TENSOR_PARALLEL_DEGREE": "max",
        "OPTION_MAX_ROLLING_BATCH_SIZE": "32",
        "OPTION_MODEL_LOADING_TIMEOUT": "3600",
        "OPTION_ENABLE_AUTO_TOOL_CHOICE": "true",
        "OPTION_TOOL_CALL_PARSER": "hermes",
        "OPTION_ENABLE_STREAMING": "true",
    },
)

In [None]:
endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-djl"

In [None]:
predictor = model.deploy(
    endpoint_name=endpoint_name,
    initial_instance_count=instance_count,
    instance_type=instance_type,
    container_startup_health_check_timeout=health_check_timeout,
    model_data_download_timeout=3600
)

#### Predict

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "Qwen/Qwen3-0.6B"

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-djl"

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
import json

system_prompt = f"""
You are a helpful AI assistant that can answer questions and provide information.
You must include your reasoning activities in the tags <thinking></thinking>
You must include your final answer in the tags <answer></answer>
You can use tools to help you with your tasks.
IMPORTANT: Before using any tool, you must have ALL required information from the user.
If any required parameter is missing, ask the user to provide it instead of making assumptions.
"""

tools = [
    {
        "type": "function",
        "function": {
            "name": "calculate_bmi",
            "description": "Calculate BMI given weight in kg and height in meters",
            "parameters": {
                "type": "object",
                "properties": {
                    "weight_kg": {
                        "type": "number",
                        "description": "Property weight_kg",
                    },
                    "height_m": {"type": "number", "description": "Property height_m"},
                },
                "required": ["weight_kg", "height_m"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "fetch_weather",
            "description": 'Fetch weather information\n\nArgs:\nquery: The weather query (e.g., "weather in New York")\nnum_results: Number of results to return (default: 1)\n\nReturns:\nJSON string containing weather information\n',
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Property query"},
                    "num_results": {
                        "type": "integer",
                        "description": "Property num_results",
                    },
                },
                "required": ["query"],
            },
        },
    },
]

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "What is the weather in Rome, Italy?"},
]

response = predictor.predict(
    {
        "messages": messages,
        "temperature": 0.2,
        "max_tokens": 4096,
        "tools": tools
    }
)

response
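
If the endpoint returns an OpenAI-style chat completion, the requested tool calls can be pulled out of the response as sketched below. The response layout (`choices[0].message.tool_calls`, with `arguments` encoded as a JSON string) is an assumption, and `mock_response` is a hand-written stand-in rather than real endpoint output; inspect `response` first and adjust the keys if they differ.

```python
import json


def extract_tool_calls(response: dict) -> list:
    """Return (function_name, parsed_arguments) pairs from a chat completion."""
    message = response["choices"][0]["message"]
    calls = []
    # tool_calls may be absent or None when the model answers in plain text
    for call in message.get("tool_calls") or []:
        function = call["function"]
        calls.append((function["name"], json.loads(function["arguments"])))
    return calls


# Hypothetical response, shaped like an OpenAI chat completion
mock_response = {
    "choices": [
        {
            "message": {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_0",
                        "type": "function",
                        "function": {
                            "name": "fetch_weather",
                            "arguments": '{"query": "weather in Rome, Italy"}',
                        },
                    }
                ],
            }
        }
    ]
}

calls = extract_tool_calls(mock_response)
print(calls)
```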

#### Delete Endpoint

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()

In [None]:
model_id = "Qwen/Qwen3-0.6B"

endpoint_name = f"{model_id.split('/')[-1].replace('.', '-')}-djl"

In [None]:
predictor = sagemaker.Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sagemaker_session,
    serializer=sagemaker.serializers.JSONSerializer(),
    deserializer=sagemaker.deserializers.JSONDeserializer(),
)

In [None]:
predictor.delete_model()
predictor.delete_endpoint(delete_endpoint_config=True)