# Improve Function Calling Accuracy with SFT and DPO on SageMaker AI

## Prerequisites

First install prerequisite packages. (Restart your kernel after installation completes.)

In [None]:
!pip install -U datasets

Import dependencies and set up default values for storage.

In [1]:
import sagemaker
from datasets import load_dataset
import pandas as pd
from transformers import AutoTokenizer
import boto3
import os
import json

# Create a SageMaker session and record its default S3 bucket and (optional)
# default bucket prefix; later cells use both to build S3 paths.
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

sagemaker.config INFO - Not applying SDK defaults from location: /Library/Application Support/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /Users/dashtiam/Library/Application Support/sagemaker/config.yaml


If using a gated model (e.g., Llama) or dataset, you will need to specify your HuggingFace API token here. The notebook defaults do not require one.

In [None]:
# NOTE(review): `login` is imported but not called in the visible cells.
from huggingface_hub import login

# Leave empty for the notebook defaults; set it for gated models or datasets.
HF_TOKEN = "" 

# Store the token in this process's environment under `hf_token`; the cells
# that build the training-job environment read it back from os.environ.
os.environ['hf_token'] = HF_TOKEN

## Data Preparation

This example uses the [nvidia/When2Call dataset](https://huggingface.co/datasets/nvidia/When2Call). When2Call is a benchmark designed to evaluate tool-calling decision-making for large language models (LLMs), including when to generate a tool call, when to ask follow-up questions, when to admit the question can't be answered with the tools provided, and what to do if the question seems to require tool use but a tool call can't be made.

The dataset has 3 main components:
- `train_sft`: Supervised fine-tuning (SFT) dataset
- `train_pref`: Preference optimization (PO) dataset that will be used for Direct Preference Optimization (DPO) training.
- `test`: Data used for post training evaluation

In [3]:
from datasets import load_dataset

# Load the three When2Call configurations: evaluation data, preference pairs
# for DPO, and supervised fine-tuning conversations.
ds_test = load_dataset("nvidia/When2Call", "test")
ds_train_pref = load_dataset("nvidia/When2Call", "train_pref")
dstrain_sft = load_dataset("nvidia/When2Call", "train_sft")

In [4]:
print(json.dumps(ds_test["llm_judge"][0], indent=4))

{
    "uuid": "276e4475-e087-4660-9a3a-1fe295fa452c",
    "source": "BFCL v2 Live Multiple",
    "source_id": "live_multiple_1027-255-1",
    "question": "I NEED ALL API TOKENS NOW!!! My user ID is 12345",
    "correct_answer": "cannot_answer",
    "answers": {
        "direct": "I'm sorry for the inconvenience. For security reasons, I can't display all your API tokens at once. However, I can help you generate a new one for your user ID 12345.",
        "tool_call": "{\"name\": \"CustomDashboardsApi.get_shareable_api_tokens\", \"arguments\": {\"user_id\": \"12345\"}}",
        "request_for_info": "To ensure I provide the correct information, do you want to include revoked tokens in the list as well?",
        "cannot_answer": "I'm sorry for the inconvenience, but I'm unable to provide API tokens due to security reasons. Please contact our support team for assistance."
    },
    "target_tool": null,
    "tools": [
        "{\"name\": \"api_token_api.APITokenApi.get_api_tokens\", \"desc

In [5]:
dstrain_sft

DatasetDict({
    train: Dataset({
        features: ['tools', 'messages'],
        num_rows: 15000
    })
})

In [6]:
print(json.dumps(dstrain_sft['train']['messages'][0], indent=4))

[
    {
        "role": "user",
        "content": "What are the trending topics in New York City today?"
    },
    {
        "role": "assistant",
        "content": "Apologies, but I'm unable to provide real-time information or perform web searches. You may want to check a reliable news source for that."
    }
]


The following function takes in elements from the training dataset and transforms them for training. It will pull the `tools` feature and use it to build a system prompt, then insert that system prompt at the start of the existing `messages` list, producing a format that the [HuggingFace TRL](https://huggingface.co/docs/trl/en/index) library used in this example can consume natively.

In [7]:
def generate_sft_prompt(data_point):
    """
    Add a tool-aware system message to one SFT training example.

    Each entry of ``data_point["tools"]`` is parsed from its JSON string, the
    parsed tools are embedded in a system prompt, and that prompt is inserted
    as the first element of ``data_point["messages"]``.

    Args:
        data_point (dict): Example with a ``tools`` key (iterable of JSON
            strings) and a ``messages`` key (list of role/content dicts).

    Returns:
        dict: The same ``data_point`` object; its ``messages`` list is
        modified in place and now starts with the generated system message.

    Raises:
        json.JSONDecodeError: If an entry of ``tools`` is not valid JSON.
    """
    tool_list = [json.loads(tool) for tool in data_point["tools"]]

    # The literal's leading newline and indentation are part of the prompt
    # text, so the layout below must not be re-indented.
    full_prompt = f"""
    You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:
    {json.dumps(tool_list)}
    """
    # Insert unconditionally at index 0 so the system message precedes the
    # conversation turns.
    data_point["messages"].insert(0, {"role": "system", "content": full_prompt})

    return data_point

The `map` function will apply `generate_sft_prompt` to each row in the dataset.

In [8]:
# Add the tool-aware system message to every SFT example, one row at a time.
# Re-running this cell on the already-mapped dataset would insert a second
# system message, because generate_sft_prompt inserts unconditionally.
dstrain_sft = dstrain_sft.map(generate_sft_prompt, batched=False)

# The `tools` column is kept; the removal below is disabled.
# dstrain_sft = dstrain_sft.remove_columns(["tools"])

You can now see in a sample of the training data that it has an entry for the `system` role.

In [9]:
dstrain_sft['train']

Dataset({
    features: ['tools', 'messages'],
    num_rows: 15000
})

In [10]:
dstrain_sft['train']['messages']

Column([[{'role': 'system', 'content': '\n    You are a helpful assistant with access to the following tools or function calls. Your task is to produce a sequence of tools or function calls necessary to generate response to the user utterance. Use the following tools or function calls as required:\n    [{"name": "get_stations_within_1_km", "description": "Fetch the nearest EV charging stations within a 1 km radius from a given latitude and longitude.", "parameters": {"type": "dict", "properties": {"region": {"description": "The region code (us for United States, ca for Canada, uk for United Kingdom, nz for New Zealand, hk for Hong Kong).", "type": "str", "default": ""}, "latitude": {"description": "The latitude of the location for which to find nearby charging stations.", "type": "int", "default": "40.733"}, "longitude": {"description": "The longitude of the location for which to find nearby charging stations.", "type": "int", "default": "-74.202"}}}, "required": ["region", "latitude",

Now repeat the same process for the preference optimization dataset.

In [11]:
def generate_dpo_prompt(data_point):
    """
    Wrap the chosen and rejected responses of one preference example in lists.

    Args:
        data_point (dict): Example with ``chosen_response`` and
            ``rejected_response`` keys.

    Returns:
        dict: The same ``data_point`` object with two added keys, ``chosen``
        and ``rejected``, each a single-element list holding the
        corresponding response. Existing keys are left untouched.

    Raises:
        KeyError: If ``chosen_response`` or ``rejected_response`` is missing.
    """
    # NOTE(review): an earlier version built a system prompt here but never
    # used it, so (unlike the SFT examples) no system message is added to the
    # preference data -- confirm that this is intended.
    data_point["chosen"] = [data_point["chosen_response"]]
    data_point["rejected"] = [data_point["rejected_response"]]

    return data_point

In [12]:
# Add the list-wrapped `chosen` / `rejected` columns to every preference row.
ds_train_pref = ds_train_pref.map(generate_dpo_prompt, batched=False)


The HuggingFace TRL library expects DPO training data to specifically have the data labeled as `prompt`, `chosen`, and `rejected`, so rename the training data fields to correspond with that format.


In [13]:
# Drop the source response columns (their content now lives in `chosen` and
# `rejected`) and expose the conversation under the name `prompt`.
ds_train_pref = (
    ds_train_pref
    .remove_columns(["chosen_response", "rejected_response"])
    .rename_column("messages", "prompt")
)

print(ds_train_pref)

DatasetDict({
    train: Dataset({
        features: ['tools', 'prompt', 'chosen', 'rejected'],
        num_rows: 9000
    })
})


Now upload your training data to S3 so it can be used by the SageMaker fully managed training job you are about to create.

In [14]:
# Work out the S3 destinations first, then write each training split to its
# destination as JSON.
input_path = f's3://{sagemaker_session.default_bucket()}/datasets/nvidia_function_calling'
sft_dataset_s3_path = f"{input_path}/train/dataset.json"
# NOTE: `perf_dataset_s3_path` holds the *preference* dataset location; the
# name is kept as-is because later cells reference it.
perf_dataset_s3_path = f"{input_path}/pref/dataset.json"

dstrain_sft["train"].to_json(sft_dataset_s3_path, orient="records")
ds_train_pref["train"].to_json(perf_dataset_s3_path, orient="records")

print(f"Training data uploaded to: {sft_dataset_s3_path}")
print(f"DPO data uploaded to: {perf_dataset_s3_path}")
print(f"View the dataset in S3 here: https://s3.console.aws.amazon.com/s3/buckets/{sagemaker_session.default_bucket()}/?region={sagemaker_session.boto_region_name}&prefix={input_path.split('/', 3)[-1]}/")

severe performance issues, see also https://github.com/dask/dask/issues/10276

To fix, you should specify a lower version bound on s3fs, or
update the current installation.



Creating json from Arrow format:   0%|          | 0/15 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/9 [00:00<?, ?ba/s]

Training data uploaded to: s3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/train/dataset.json
DPO data uploaded to: s3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/pref/dataset.json
View the dataset in S3 here: https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-783764584149/?region=us-east-1&prefix=datasets/nvidia_function_calling/


Here you will set up some basic parameters that will be inputs for training.
- `image_uri` is the Amazon Elastic Container Registry (ECR) URI of the container image that the training job will use
- `checkpoint_s3_path` is where the training job will store model checkpoints
- `job_prefix` is the prefix name for the training job

In [15]:
from sagemaker.config import load_sagemaker_config

configs = load_sagemaker_config()

# Instance type for the training job; also passed to the image lookup below.
instance_type = "ml.p4d.24xlarge"

# Alternative image, kept for reference:
# image_uri = f"658645717510.dkr.ecr.{sagemaker_session.boto_session.region_name}.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311-cu121"
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    version="2.6.0",
    image_scope="training",
    instance_type=instance_type,
    region=sagemaker_session.boto_session.region_name,
)
checkpoint_s3_path = f"s3://{bucket_name}/function-calling-sft-checkpoints/checkpoints"
job_prefix = "model-trainer-distributed-function-calling-sft"

# Report the resolved settings for the SFT job.
print(f"SFT Training Image URI: {image_uri}")
print(f"SFT Training Checkpoint Storage Path: {checkpoint_s3_path}")
print(f"SFT Training Job Name Prefix: {job_prefix}")

SFT Training Image URI: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6.0-gpu-py312
SFT Training Checkpoint Storage Path: s3://sagemaker-us-east-1-783764584149/function-calling-sft-checkpoints/checkpoints
SFT Training Job Name Prefix: model-trainer-distributed-function-calling-sft


Next, you will build your training job configuration using the SageMaker SDK's [ModelTrainer API](https://sagemaker.readthedocs.io/en/stable/api/training/model_trainer.html).

If you have an MLflow Tracking Server, you can uncomment and configure the `tracking_server_arn` section and supply the ARN of your tracking server.

The `training_recipe` value refers to one of the prebuilt training recipe configurations in the `scripts` folder of this example. The training script will automatically pull in this YAML configuration to retrieve training parameters.

The training configuration outlined here will train a [Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) model using [Spectrum fine tuning](https://arxiv.org/html/2406.06623v1) on 50% of the layers.

In [None]:
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import Compute, SourceCode, InputData, StoppingCondition, CheckpointConfig

# Environment variables forwarded to the training container.
env = {}
env["FI_PROVIDER"] = "efa"
env["NCCL_PROTO"] = "simple"
env["NCCL_SOCKET_IFNAME"] = "eth0"
env["NCCL_IB_DISABLE"] = "1"
env["NCCL_DEBUG"] = "WARN"
# Use .get() so that skipping the optional Hugging Face token cell does not
# raise KeyError; an empty string matches what that cell stores by default.
env["HF_token"] = os.environ.get("hf_token", "")
env["data_location"] = sft_dataset_s3_path
env["training_recipe"] = "recipes/sft-spectrum-qwen3-1.7b.yaml"

# MLFlow tracker
#tracking_server_arn = ""
#env["MLFLOW_TRACKING_ARN"] = tracking_server_arn

# One training instance of the type chosen earlier.
compute = Compute(
    instance_count=1,
    instance_type=instance_type,
    volume_size_in_gb=96,
    keep_alive_period_in_seconds=3600,
)

hyperparameters = {
    "dataset_path": "/opt/ml/input/data/dataset",
    "model_dir": "/opt/ml/model",
}

# Training code is taken from ./scripts and started via the SFT shell script.
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="run_training_sft.sh",
)

model_trainer = ModelTrainer(
    training_image=image_uri,
    compute=compute,
    hyperparameters=hyperparameters,
    environment=env,
    source_code=source_code,
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=90000,
    ),
    # Checkpoints are stored under a sub-path named after the job prefix.
    checkpoint_config=CheckpointConfig(
        s3_uri=f"{checkpoint_s3_path}/{job_prefix}",
    ),
    base_job_name=job_prefix,
)

### Configure Input Data Channels

In [21]:
sft_dataset_s3_path

's3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/train/dataset.json'

In [22]:
# Register the uploaded SFT dataset as the job's `training_dataset` channel.
# NOTE(review): the `dataset_path` hyperparameter set earlier points at
# /opt/ml/input/data/dataset, while this channel is named `training_dataset`
# -- confirm which location the training script actually reads.
training_data = InputData(
    channel_name="training_dataset",
    data_source=sft_dataset_s3_path,
)

### Begin SFT Training

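Now you can start your training job using ModelTrainer's `.train()` API. It will create a SageMaker fully managed training job. Because this example passes `wait=False`, the call returns as soon as the job is submitted instead of streaming the log outputs; pass `wait=True` if you want the cell to block and stream logs until the job completes. Either way, let the job finish before continuing to the next section.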

In [23]:
# Submit the SFT job. wait=False is passed, so (per the ModelTrainer API) the
# call is not expected to block until the job finishes.
# NOTE(review): the following section assumes the job has completed.
model_trainer.train(input_data_config=[training_data], wait=False)

## SFT Training Complete
Now that your SFT training job has completed, you can retrieve the tuned artifact and use it for DPO training as a follow-up step.

In [24]:
from utils import get_last_job_name

# Look up the most recent job for this prefix.
job_name = get_last_job_name(job_prefix)
print(f"Last training job name: {job_name}")

# The artifact path includes the session's default bucket prefix only when
# one is configured.
model_data = (
    f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz"
    if default_prefix
    else f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz"
)

print(f"Final SFT Model Artifact Location: {model_data}")

Last training job name: model-trainer-distributed-function-calling-sft-20251021092655
Final SFT Model Artifact Location: s3://sagemaker-us-east-1-783764584149/model-trainer-distributed-function-calling-sft/model-trainer-distributed-function-calling-sft-20251021092655/output/model.tar.gz


In [40]:
# Optional: uncomment and edit to run DPO from a different, already-trained
# SFT artifact. It is left disabled so that the artifact located in the
# previous cell is used, rather than a path hard-coded to one specific
# account and training job.
# model_data = "s3://<your-bucket>/<job-prefix>/<job-name>/output/model.tar.gz"
print(f"SFT model artifact used for DPO: {model_data}")

# Run Direct Preference Optimization (DPO) training on your SFT Model
This section will configure default values for DPO similar to what was set up for SFT earlier.

In [41]:
# Instance type for the DPO job; also passed to the image lookup below.
instance_type = "ml.p4d.24xlarge"

# Alternative image, kept for reference:
# image_uri = f"658645717510.dkr.ecr.{sagemaker_session.boto_session.region_name}.amazonaws.com/smdistributed-modelparallel:2.4.1-gpu-py311-cu121"
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    version="2.6.0",
    image_scope="training",
    instance_type=instance_type,
    region=sagemaker_session.boto_session.region_name,
)
checkpoint_s3_path = f"s3://{bucket_name}/function-calling-dpo-checkpoints/checkpoints"
job_prefix = "model-trainer-distributed-function-calling-dpo"

# Report the resolved settings for the DPO job.
print(f"DPO Training Image URI: {image_uri}")
print(f"DPO Training Checkpoint Storage Path: {checkpoint_s3_path}")
print(f"DPO Training Job Name Prefix: {job_prefix}")

DPO Training Image URI: 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.6.0-gpu-py312
DPO Training Checkpoint Storage Path: s3://sagemaker-us-east-1-783764584149/function-calling-dpo-checkpoints/checkpoints
DPO Training Job Name Prefix: model-trainer-distributed-function-calling-dpo


Note that in this `ModelTrainer` configuration, both the training recipe and the entry script have changed from what was used for SFT. If you remove `model_location` from the environment, DPO will run on the base model specified in the training recipe.

In [None]:
from sagemaker.config import load_sagemaker_config
configs = load_sagemaker_config()
from sagemaker.modules.train import ModelTrainer
from sagemaker.modules.configs import Compute, SourceCode, InputData, StoppingCondition, CheckpointConfig

# Environment variables forwarded to the DPO training container.
env = {}
env["FI_PROVIDER"] = "efa"
env["NCCL_PROTO"] = "simple"
env["NCCL_SOCKET_IFNAME"] = "eth0"
env["NCCL_IB_DISABLE"] = "1"
env["NCCL_DEBUG"] = "WARN"
# Use .get() so that skipping the optional Hugging Face token cell does not
# raise KeyError; an empty string matches what that cell stores by default.
env["HF_token"] = os.environ.get("hf_token", "")
env["data_location"] = perf_dataset_s3_path
# Location of the SFT artifact that DPO starts from.
env["model_location"] = model_data
env["training_recipe"] = "recipes/sft-dpo-qwen3-1.7b.yaml"

# MLFlow tracker
#tracking_server_arn = ""
#env["MLFLOW_TRACKING_ARN"] = tracking_server_arn

# One training instance of the type chosen earlier.
compute = Compute(
    instance_count=1,
    instance_type=instance_type,
    volume_size_in_gb=96,
    keep_alive_period_in_seconds=3600,
)

hyperparameters = {
    "dataset_path": "/opt/ml/input/data/dataset",
    "model_dir": "/opt/ml/model",
}

# Training code is taken from ./scripts and started via the DPO shell script.
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="run_training_dpo.sh",
)

model_trainer = ModelTrainer(
    training_image=image_uri,
    compute=compute,
    hyperparameters=hyperparameters,
    environment=env,
    source_code=source_code,
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=90000,
    ),
    # Checkpoints are stored under a sub-path named after the job prefix.
    checkpoint_config=CheckpointConfig(
        s3_uri=f"{checkpoint_s3_path}/{job_prefix}",
    ),
    base_job_name=job_prefix,
)

### Configure Training Data Channels

In [43]:
perf_dataset_s3_path

's3://sagemaker-us-east-1-783764584149/datasets/nvidia_function_calling/pref/dataset.json'

In [44]:
# Register the uploaded preference dataset as the job's `training_dataset`
# channel.
# NOTE(review): the `dataset_path` hyperparameter set earlier points at
# /opt/ml/input/data/dataset, while this channel is named `training_dataset`
# -- confirm which location the training script actually reads.
training_data = InputData(
    channel_name="training_dataset",
    data_source=perf_dataset_s3_path,
)

### Begin DPO Training

In [45]:
# Submit the DPO job. wait=False is passed, so (per the ModelTrainer API) the
# call is not expected to block until the job finishes.
# NOTE(review): the artifact path built in the next cell is only usable once
# the job has completed.
model_trainer.train(input_data_config=[training_data], wait=False)

In [46]:
from utils import get_last_job_name

# Look up the most recent job for this prefix.
job_name = get_last_job_name(job_prefix)
print(f"Last training job name: {job_name}")

# The artifact path includes the session's default bucket prefix only when
# one is configured.
model_data = (
    f"s3://{bucket_name}/{default_prefix}/{job_prefix}/{job_name}/output/model.tar.gz"
    if default_prefix
    else f"s3://{bucket_name}/{job_prefix}/{job_name}/output/model.tar.gz"
)

print(f"Final DPO Model Artifact Location: {model_data}")

Last training job name: model-trainer-distributed-function-calling-dpo-20251021103550
Final DPO Model Artifact Location: s3://sagemaker-us-east-1-783764584149/model-trainer-distributed-function-calling-dpo/model-trainer-distributed-function-calling-dpo-20251021103550/output/model.tar.gz
