In [None]:
%pip install sagemaker boto3 sagemaker litellm -qU

<div class="alert alert-block alert-info">
<center>⚠️ <b>Important:</b> Please restart the kernel after installing the dependencies. ⚠️</center>
</div>

----

# Tool calling with Amazon SageMaker AI

<div class="alert alert-block alert-info">
<center>Make sure you've deployed the model according to the previous lab before proceeding.</center>
</div>

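The Amazon SageMaker AI `InvokeEndpoint` API does not natively support tool calling: it has no dedicated tool parameters and simply passes the request body to the model server running in the container. Tool calling therefore depends on the serving container and the model. In this lab the endpoint accepts an OpenAI-compatible Chat Completions payload, so we send the tool definitions in the request body (`tools`, `tool_choice`). The serving stack then embeds them in the prompt it sends to the model, typically through the model's chat template. We recommend using models that have been fine-tuned for function calling to make sure tool calling works as expected.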

In [1]:
try:
    # Reuse the predictor created in the previous lab if it still exists in this kernel.
    predictor
except NameError:
    import boto3
    from sagemaker.session import Session
    from sagemaker.predictor import Predictor
    from sagemaker.serializers import JSONSerializer
    from sagemaker.deserializers import JSONDeserializer

    endpoint_name = input("> Enter your endpoint name: ")
    component_name = input("> Enter your inference component name (leave empty if not using a component): ") or None

    # NOTE(review): region is hard-coded; it must match the region the endpoint was deployed in.
    boto_session = boto3.session.Session(region_name="us-west-2")
    session = Session(boto_session=boto_session)

    predictor = Predictor(
        sagemaker_session=session,
        endpoint_name=endpoint_name, component_name=component_name,
        serializer=JSONSerializer(), deserializer=JSONDeserializer()
    )

# Later cells call the runtime API with these names, so define them even when
# `predictor` was inherited from an earlier cell/lab instead of created above.
endpoint_name = predictor.endpoint_name
component_name = getattr(predictor, "component_name", None)



sagemaker.config INFO - Not applying SDK defaults from location: /Library/Application Support/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /Users/dggallit/Library/Application Support/sagemaker/config.yaml


In [2]:
def get_top_song(sign):
    """Return the most popular song for the requested radio station.

    Args:
        sign (str): The call sign of the station, e.g. ``"WZPZ"``. Surrounding
            whitespace and letter case are ignored, because the value arrives
            as model-generated tool arguments.

    Returns:
        dict: ``{"song": <title>, "artist": <artist>}`` for the station.

    Raises:
        ValueError: If the station is unknown. ``ValueError`` subclasses
            ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    # Static demo data: call sign -> (song, artist).
    top_songs = {
        "WZPZ": ("Elemental Hotel", "8 Storey Hike"),
    }

    entry = top_songs.get(str(sign).strip().upper())
    if entry is None:
        raise ValueError(f"Station {sign} not found.")

    song, artist = entry
    return {
        "song": song,
        "artist": artist
    }

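For the LLM to know that it can use this tool, we have to pass it a tool definition in the OpenAI-compatible `tools` format. The definition contains the function's name, a description of what it does, and a JSON Schema describing its parameters.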

In [3]:
# Tool definitions in the OpenAI-compatible "function" format. The model only sees
# these descriptions, so their wording directly affects when and how it calls the tool.
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_top_song",  # must match the Python function to execute
            "description": "Get the most popular song played on a radio station.",
            "parameters": {
                # JSON Schema for the function's keyword arguments.
                "type": "object",
                "properties": {
                    "sign": {
                        "type": "string",
                        "description": "The call sign for the radio station for which you want the most popular song. Example call signs are WZPZ and WKRP."
                    }
                },
                "required": ["sign"],
            },
        },
    }
]

Now we can start conversing with the model.

In [4]:
# The user question for this walkthrough; it names a station the tool knows about.
input_text = (
    "What is the most popular song on WZPZ?"
)

In [5]:
from datetime import datetime

# The tool definitions are NOT templated into this prompt; they are sent separately
# in the request payload below (`tools`), so only the date is substituted here.
system_prompt = """\
You are an AI assistant, created by AWS and powered by Amazon SageMaker AI.
Your goal is to help the user by answering their questions honestly, helpfully and truthfully.
The current date is {currentDateTime} .

Follow these principles when responding to queries:
1. Avoid tool calls if not needed
2. If uncertain, answer normally and offer to use tools
3. Always use the best tools for the query
"""
messages = [
    {'role':'system', 'content':system_prompt.format(currentDateTime=datetime.now())},
    {'role':'user', 'content':input_text}
]
# OpenAI-style request: `tool_choice='auto'` lets the model decide whether to call a tool;
# `max_tokens` caps the generated length at 4096 tokens.
payload = {'messages': messages, 'max_tokens': 4*1024, 'tools':tools, 'tool_choice':'auto'}

In [6]:
import boto3, json

# Talk to the endpoint in the same region as the SageMaker session behind `predictor`;
# a hard-coded region that differs from the deployment region makes InvokeEndpoint fail.
region_name = predictor.sagemaker_session.boto_region_name
endpoint_name = predictor.endpoint_name
component_name = getattr(predictor, "component_name", None)

sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=region_name)

invoke_kwargs = {
    "EndpointName": endpoint_name,
    "ContentType": "application/json",
    "Body": json.dumps(payload),
}
# Endpoints that host the model in an inference component need the component name.
if component_name:
    invoke_kwargs["InferenceComponentName"] = component_name

response = sagemaker_runtime.invoke_endpoint(**invoke_kwargs)
output = json.loads(response['Body'].read().decode())
output

{'id': 'chatcmpl-6fb61aeab4294502bb03dcb04b9e3c04',
 'object': 'chat.completion',
 'created': 1747317996,
 'model': 'lmi',
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'reasoning_content': None,
    'content': None,
    'tool_calls': [{'id': 'SsUEOjFNU',
      'type': 'function',
      'function': {'name': 'get_top_song', 'arguments': '{"sign": "WZPZ"}'}}]},
   'logprobs': None,
   'finish_reason': 'tool_calls',
   'stop_reason': None}],
 'usage': {'prompt_tokens': 224,
  'total_tokens': 249,
  'completion_tokens': 25,
  'prompt_tokens_details': None},
 'prompt_logprobs': None}

In [7]:
# Keep only `role` and `content` from the assistant turn. A tool-call turn comes back
# with content=None, so any falsy value is replaced with a placeholder string.
assistant_message = output['choices'][0]['message']
clean_message = {}
for key, value in assistant_message.items():
    if key not in ('role', 'content'):
        continue
    clean_message[key] = value if value else "Thinking ..."
messages.append(clean_message)
messages

[{'role': 'system',
  'content': 'You are an AI assistant, created by AWS and powered by Amazon SageMaker AI.\nYour goal is to help the user by answering their questions honestly, helpfully and truthfully.\nThe current date is 2025-05-15 16:06:34.237947 .\n\nFollow these principles when responding to queries:\n1. Avoid tool calls if not needed\n2. If uncertain, answer normally and offer to use tools\n3. Always use the best tools for the query\n'},
 {'role': 'user', 'content': 'What is the most popular song on WZPZ?'},
 {'role': 'assistant', 'content': 'Thinking ...'}]

In [8]:
# The response's `finish_reason` (stored here as `stop_reason`) is "tool_calls"
# when the model is asking us to run one or more tools.
first_choice = output['choices'][0]
stop_reason = first_choice['finish_reason']
tool_calls = first_choice['message']['tool_calls']
(stop_reason, tool_calls)

('tool_calls',
 [{'id': 'SsUEOjFNU',
   'type': 'function',
   'function': {'name': 'get_top_song', 'arguments': '{"sign": "WZPZ"}'}}])

In [9]:
import sys

# The function name comes from model output, so only allow functions that were
# actually advertised to the model in `tools`.
declared_tool_names = {
    tool["function"]["name"] for tool in tools if tool.get("type") == "function"
}

tool_results = []
if stop_reason == "tool_calls":
    tool_calls = output['choices'][0]['message']['tool_calls']
    for tool_call in tool_calls:
        # Skip non-function calls instead of executing with stale/undefined name and args.
        if tool_call['type'] != 'function':
            continue
        name = tool_call['function']['name']
        if name not in declared_tool_names:
            raise ValueError(f"Model requested an undeclared tool: {name!r}")
        args = json.loads(tool_call['function']['arguments'])
        # Execute the notebook-level function with the requested name.
        tool_foo = getattr(sys.modules[__name__], name)
        tool_results.append(tool_foo(**args))
    # The next cell sends `output` back to the model, so keep the (last) tool result there.
    # NOTE(review): with several tool calls only the last result is forwarded.
    if tool_results:
        output = tool_results[-1]
output

In [10]:
# Hand the tool's JSON result back to the model as a plain user turn.
tool_result_message = dict(role="user", content=json.dumps(output))
messages.append(tool_result_message)

In [11]:
# Show the conversation history that is sent to the endpoint in the follow-up request.
messages

[{'role': 'system',
  'content': 'You are an AI assistant, created by AWS and powered by Amazon SageMaker AI.\nYour goal is to help the user by answering their questions honestly, helpfully and truthfully.\nThe current date is 2025-05-15 16:06:34.237947 .\n\nFollow these principles when responding to queries:\n1. Avoid tool calls if not needed\n2. If uncertain, answer normally and offer to use tools\n3. Always use the best tools for the query\n'},
 {'role': 'user', 'content': 'What is the most popular song on WZPZ?'},
 {'role': 'assistant', 'content': 'Thinking ...'},
 {'role': 'user',
  'content': '{"song": "Elemental Hotel", "artist": "8 Storey Hike"}'}]

In [12]:
# Second round trip: send the conversation, now including the tool result, so the
# model can phrase the final answer. `tools` is not resent in this request.
payload = {'messages': messages, 'max_tokens': 4*1024}

invoke_kwargs = {
    "EndpointName": predictor.endpoint_name,
    "ContentType": "application/json",
    "Body": json.dumps(payload),
}
# Endpoints that host the model in an inference component need the component name.
inference_component = getattr(predictor, "component_name", None)
if inference_component:
    invoke_kwargs["InferenceComponentName"] = inference_component

response = sagemaker_runtime.invoke_endpoint(**invoke_kwargs)
output = json.loads(response['Body'].read().decode())
output

{'id': 'chatcmpl-be1f5cb4e016437d8b91b2ffb8eca279',
 'object': 'chat.completion',
 'created': 1747318020,
 'model': 'lmi',
 'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'reasoning_content': None,
    'content': 'The most popular song on WZPZ right now is "Elemental Hotel" by 8 Storey Hike.',
    'tool_calls': []},
   'logprobs': None,
   'finish_reason': 'stop',
   'stop_reason': None}],
 'usage': {'prompt_tokens': 150,
  'total_tokens': 176,
  'completion_tokens': 26,
  'prompt_tokens_details': None},
 'prompt_logprobs': None}