# Tool calling with Amazon SageMaker AI

<div class="alert alert-block alert-info">
<center>Make sure you've deployed the model according to the previous lab before proceeding.</center>
</div>

Amazon SageMaker AI APIs do not natively support tool calling. Instead, the tool definitions have to be embedded in the prompt that is sent to the model. We recommend using a model that has been fine-tuned for function calling so that tool calling works reliably.
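
To make "embedding the tool definition in the prompt" concrete, here is a minimal sketch that renders a list of tool definitions as plain text. The prompt wording, the `render_tools_prompt` helper, and the `echo` tool are all hypothetical and only illustrate the idea. This lab does not build the string by hand; it passes a `tools` field in the request payload instead.

```python
import json


def render_tools_prompt(tool_definitions):
    """Render tool definitions as plain text for a system prompt (hypothetical format)."""
    lines = [
        "You can call the following tools.",
        'To call one, reply with JSON: {"name": "<tool name>", "arguments": {...}}',
        "",
    ]
    for tool in tool_definitions:
        function = tool["function"]
        # One line for the purpose of the tool, one for its JSON schema.
        lines.append(f"- {function['name']}: {function['description']}")
        lines.append(f"  parameters: {json.dumps(function['parameters'])}")
    return "\n".join(lines)


# Hypothetical tool, used only for this illustration.
example_tools = [
    {
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Return the given text unchanged.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        },
    }
]

tool_prompt = render_tools_prompt(example_tools)
print(tool_prompt)
```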

## Dependencies (Warnings are safe to ignore)

In [None]:
%pip uninstall -q -y autogluon-multimodal autogluon-timeseries autogluon-features autogluon-common autogluon-core
%pip install -Uq sagemaker==2.239.0
%pip install -Uq boto3==1.38.33

## This cell will restart the kernel. Wait for the pop-up box to appear, then click "OK" before proceeding.

In [None]:
from IPython import get_ipython
get_ipython().kernel.do_shutdown(True)

Retrieve `SAGEMAKER_ENDPOINT_NAME`, the name of the endpoint that was deployed during the prerequisites.

In [None]:
%store -r SAGEMAKER_ENDPOINT_NAME
print(f"Endpoint name: {SAGEMAKER_ENDPOINT_NAME}")

Set up a SageMaker `Predictor` for invoking your endpoint.

In [None]:
import boto3
from sagemaker.session import Session
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

boto_session = boto3.session.Session(region_name=boto3.Session().region_name)
session = Session(boto_session=boto_session)

predictor = Predictor(
    sagemaker_session=session,
    endpoint_name=SAGEMAKER_ENDPOINT_NAME,
    serializer=JSONSerializer(), deserializer=JSONDeserializer()
)

Next, create a function `get_top_song()` to use as a tool with your model. This basic tool takes the call sign of a radio station as its `sign` parameter and returns a mocked-up top song and artist.

In [None]:
def get_top_song(sign):
    """Returns the most popular song for the requested station.

    Args:
        sign (str): The call sign of the station for which you want
            the most popular song.

    Returns:
        dict: The most popular song and its artist.

    Raises:
        Exception: If the station is not known.
    """

    song = ""
    artist = ""
    if sign == 'WZPZ':
        song = "Elemental Hotel"
        artist = "8 Storey Hike"

    else:
        raise Exception(f"Station {sign} not found.")

    return {
        "song": song,
        "artist": artist
    }

Next, build a tool definition. This will later be passed to the LLM and will provide it with the data it needs to understand what the tool is for and what to invoke it with.

The `description` inside of the first `function` object will be used to determine what the tool is for, and the `description` of the `properties` fields will help ensure the LLM submits the correct values when calling the tool.

In [None]:
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_top_song",
            "description": "Get the most popular song played on a radio station.",
            "parameters": {
                "type": "object",
                "properties": {
                    "sign": {
                        "type": "string",
                        "description": "The call sign for the radio station for which you want the most popular song. Example call signs are WZPZ and WKRP."
                    }
                },
                "required": ["sign"],
            },
        },
    }
]
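
Because the model generates the arguments, it can help to check them against the tool definition before calling the tool. The following is a minimal sketch of such a check, not a full JSON Schema validator. The `validate_tool_args` helper and the `sample_tool` definition are hypothetical, and the sketch assumes only the `required`, `properties`, and `type` fields shown above.

```python
def validate_tool_args(tool, args):
    """Return a list of problems found in `args` for the given tool definition."""
    schema = tool["function"]["parameters"]
    properties = schema.get("properties", {})
    # Only a few basic JSON Schema types are mapped in this sketch.
    type_map = {"string": str, "integer": int, "number": (int, float), "boolean": bool}
    problems = []
    for name in schema.get("required", []):
        if name not in args:
            problems.append(f"missing required argument: {name}")
    for name, value in args.items():
        if name not in properties:
            problems.append(f"unexpected argument: {name}")
            continue
        expected = type_map.get(properties[name].get("type"))
        if expected and not isinstance(value, expected):
            problems.append(f"wrong type for {name}: expected {properties[name]['type']}")
    return problems


# Hypothetical definition with the same shape as the tool definition above.
sample_tool = {
    "type": "function",
    "function": {
        "name": "get_top_song",
        "description": "Get the most popular song played on a radio station.",
        "parameters": {
            "type": "object",
            "properties": {"sign": {"type": "string"}},
            "required": ["sign"],
        },
    },
}

print(validate_tool_args(sample_tool, {"sign": "WZPZ"}))
print(validate_tool_args(sample_tool, {}))
```

An empty list means no problems were found; any other result could be sent back to the model as an error message instead of calling the tool.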

Now we can start conversing with the model.

In [None]:
input_text = "What is the most popular song on WZPZ?"

In [None]:
from datetime import datetime

system_prompt = """\
You are an AI assistant, created by AWS and powered by Amazon SageMaker AI.
Your goal is to help the user by answering their questions honestly, helpfully and truthfully.
The current date is {currentDateTime}.

Follow these principles when responding to queries:
1. Avoid tool calls if not needed
2. If uncertain, answer normally and offer to use tools
3. Always use the best tools for the query
"""
messages = [
    {'role':'system', 'content':system_prompt.format(currentDateTime=datetime.now())},
    {'role':'user', 'content':input_text}
]

payload = {
    "messages": messages,
    "tools": tools,
    "tool_choice": "auto", # Requires: OPTION_TOOL_CALL_PARSER, OPTION_ENABLE_AUTO_TOOL_CHOICE
    "max_tokens": 4096,
    "temperature": 0.1,
    "top_p": 0.9,
}

In [None]:
import boto3, json

sagemaker_runtime = boto3.client("sagemaker-runtime", region_name=boto3.Session().region_name)
response = sagemaker_runtime.invoke_endpoint(
    EndpointName=SAGEMAKER_ENDPOINT_NAME,
    ContentType="application/json",
    Body=json.dumps(payload)
)
output = json.loads(response['Body'].read().decode())

output

In [None]:
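# Keep only the role and content of the assistant message. If the content is
# empty (for example, when the model requests a tool call), use a placeholder.
assistant_message = output['choices'][0]['message']
clean_message = {k: v or "Thinking ..." for k, v in assistant_message.items() if k in ['role', 'content']}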
messages.append(clean_message)
messages

In [None]:
# If finish_reason == "tool_calls", the model is asking you to run a tool.
stop_reason = output['choices'][0]['finish_reason']
# Use .get() so that a response without tool calls does not raise a KeyError.
tool_calls = output['choices'][0]['message'].get('tool_calls')
stop_reason, tool_calls

In [None]:
import sys

if stop_reason == "tool_calls":
    tool_calls = output['choices'][0]['message']['tool_calls']
    for tool_call in tool_calls:
        # Only function-type tool calls are handled here.
        if tool_call['type'] == 'function':
            name = tool_call['function']['name']
            args = json.loads(tool_call['function']['arguments'])
            # Look up the function by name in this notebook's namespace and call it.
            tool_foo = getattr(sys.modules[__name__], name)
            # Note: `output` now holds the tool result (the last one, if there are several).
            output = tool_foo(**args)
output
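
Looking up functions by name in the module namespace lets the model call any function defined in the notebook. A more defensive variant, sketched below, dispatches through an explicit registry and turns tool errors into results that can be reported back to the model. The `dispatch_tool_calls` helper, the `demo_top_song` stand-in, and the sample tool calls are hypothetical; the tool call shape is assumed to match the response used above.

```python
import json


def dispatch_tool_calls(calls, registry):
    """Run each function-type tool call through an explicit name-to-function registry."""
    results = []
    for call in calls:
        if call.get("type") != "function":
            continue
        name = call["function"]["name"]
        if name not in registry:
            results.append({"name": name, "error": f"Unknown tool: {name}"})
            continue
        args = json.loads(call["function"]["arguments"])
        try:
            results.append({"name": name, "result": registry[name](**args)})
        except Exception as exc:
            # Report the failure instead of stopping the conversation.
            results.append({"name": name, "error": str(exc)})
    return results


def demo_top_song(sign):
    """Hypothetical stand-in for a real tool."""
    if sign != "WZPZ":
        raise ValueError(f"Station {sign} not found.")
    return {"song": "Elemental Hotel", "artist": "8 Storey Hike"}


# Hand-written sample tool calls, not real model output.
sample_calls = [
    {"type": "function", "function": {"name": "demo_top_song", "arguments": '{"sign": "WZPZ"}'}},
    {"type": "function", "function": {"name": "demo_top_song", "arguments": '{"sign": "WKRP"}'}},
]

demo_results = dispatch_tool_calls(sample_calls, {"demo_top_song": demo_top_song})
print(demo_results)
```

Unlike the cell above, this variant keeps one result per tool call rather than only the last one.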

In [None]:
tool_result_message = {
    "role": "user", "content": json.dumps(output)
}
messages.append(tool_result_message)

In [None]:
messages

In [None]:
payload = {'messages': messages, 'max_tokens': 4*1024}
response = sagemaker_runtime.invoke_endpoint(
    EndpointName=SAGEMAKER_ENDPOINT_NAME,
    ContentType="application/json",
    Body=json.dumps(payload)
)
output = json.loads(response['Body'].read().decode())
output
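
The final answer is in the `content` field of the first choice. The helper below is a small sketch for pulling it out safely. `final_answer` and `sample_response` are hypothetical names, and the sample is hand-written in the same shape as the responses used above, not real model output.

```python
def final_answer(response):
    """Return the assistant text from a chat response, or None if there is none."""
    choices = response.get("choices") or []
    if not choices:
        return None
    return choices[0].get("message", {}).get("content")


# Hand-written sample in the assumed response shape.
sample_response = {
    "choices": [
        {
            "finish_reason": "stop",
            "message": {
                "role": "assistant",
                "content": "The most popular song on WZPZ is Elemental Hotel by 8 Storey Hike.",
            },
        }
    ]
}

print(final_answer(sample_response))
```

In the notebook, calling `final_answer(output)` on the response from the cell above should return the model's reply as plain text.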