In [None]:
%pip install sagemaker boto3 litellm -qU

<div class="alert alert-block alert-info">
<center>⚠️ <b>Important:</b> Please restart the kernel after installing the dependencies. ⚠️</center>
</div>

----

# Tool calling with Amazon SageMaker AI

<div class="alert alert-block alert-info">
<center>Make sure you've deployed the model according to the previous lab before proceeding.</center>
</div>

Amazon SageMaker AI APIs do not natively support tool calling. To work around this, we embed the tool definitions in the prompt we send to the model and parse the requested tool call out of the model's reply. We recommend using models that have been fine-tuned for function calling to make sure tool calling works as expected.

In [None]:
# Reuse a `predictor` left in the kernel (for example by the previous lab);
# otherwise build one from the endpoint details below.
try:
    predictor
except NameError:
    # Only a missing name means "not created yet"; any other error should surface.
    import boto3
    from sagemaker.session import Session
    from sagemaker.predictor import Predictor
    from sagemaker.serializers import JSONSerializer
    from sagemaker.deserializers import JSONDeserializer

    endpoint_name = "YOUR-ENDPOINT-NAME"
    component_name = "YOUR-INFERENCE-COMPONENT-NAME"

    boto_session = boto3.session.Session(region_name="us-west-2")
    session = Session(boto_session=boto_session)

    predictor = Predictor(
        sagemaker_session=session,
        endpoint_name=endpoint_name, component_name=component_name,
        serializer=JSONSerializer(), deserializer=JSONDeserializer()
    )
else:
    # Later cells call the runtime API with `endpoint_name` / `component_name`,
    # so make sure they are defined even when the predictor was reused.
    endpoint_name = predictor.endpoint_name
    # NOTE(review): `component_name` is assumed to be exposed by the installed
    # SDK version; fall back to None (no inference component) if it is not.
    component_name = getattr(predictor, "component_name", None)

In [None]:
def get_top_song(sign):
    """Return the most popular song for the requested radio station.

    Args:
        sign (str): Call sign of the station to look up. Only 'WZPZ' is
            known to this demo implementation.

    Returns:
        dict: A mapping with the keys "song" and "artist".

    Raises:
        ValueError: If ``sign`` is not a known station. ValueError is a
            subclass of Exception, so callers catching Exception still work.
    """
    # Guard clause: reject unknown stations before building the result.
    if sign != 'WZPZ':
        raise ValueError(f"Station {sign} not found.")

    return {
        "song": "Elemental Hotel",
        "artist": "8 Storey Hike"
    }

In order for the LLM to know that it can use this tool, we have to pass the tool definition to the LLM.

In [None]:
# JSON-schema description of the single argument accepted by get_top_song.
sign_parameter = {
    "type": "string",
    "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ and WKRP."
}

# Function-style tool definition; key order is kept stable because the
# structure is later interpolated into the prompt text.
get_top_song_tool = {
    "type": "function",
    "function": {
        "name": "get_top_song",
        "description": "Get the most popular song played on a radio station.",
        "parameters": {
            "type": "object",
            "properties": {"sign": sign_parameter},
            "required": ["sign"],
        },
    },
}

# List of every tool the model is allowed to request.
tools = [get_top_song_tool]

Now we can start conversing with the model.

In [None]:
# User question sent to the model; it asks about station WZPZ.
input_text = "What is the most popular song on WZPZ?"

In [None]:
import json

# Prompt template. The tool definitions are embedded as JSON (not a Python
# repr) and the expected shape of the tool call is spelled out, because the
# reply is later parsed with json.loads and read via the "arguments" key.
# Literal braces are doubled since the template goes through str.format.
initial_prompt = """\
Answer the following query: {user_query}\n
You have access to the following tools:
{tools}
Reply separating your thinking from the function call with <thinking> </thinking> and <tool> </tool> tags.
Inside the <tool> tags, write a single JSON object of the form {{"name": "TOOL_NAME", "arguments": {{"PARAMETER_NAME": "VALUE"}}}}.
Make sure that your response only contains one set of tags each, no more.\
"""
messages = [{
    'role': 'user',
    'content': initial_prompt.format(
        user_query=input_text,
        tools=json.dumps(tools, indent=2),
    ),
}]
# Request body for the chat-style endpoint; max_tokens caps the reply length.
payload = {'messages': messages, 'max_tokens': 4*1024}

In [None]:
import json

import boto3

# Create the runtime client in the same region as the predictor's session.
# The endpoint lives there, so a separately hard-coded region can point the
# client at a region where the endpoint does not exist.
sagemaker_runtime = boto3.client(
    "sagemaker-runtime",
    region_name=predictor.sagemaker_session.boto_region_name,
)

invoke_args = {
    "EndpointName": endpoint_name,
    "ContentType": "application/json",
    "Body": json.dumps(payload),
}
# boto3 rejects None for string parameters, so only pass the inference
# component name when one is actually configured.
if component_name:
    invoke_args["InferenceComponentName"] = component_name

response = sagemaker_runtime.invoke_endpoint(**invoke_args)
# The response body is a stream of JSON bytes; decode it into a dict.
output = json.loads(response['Body'].read().decode())
output

In [None]:
# Keep the model's reply (first choice) in the conversation history so the
# next request carries the full exchange.
assistant_message = output['choices'][0]['message']
messages.append(assistant_message)
messages

In [None]:
# Show the raw text of the model's reply (first choice).
output['choices'][0]['message']['content']

In [None]:
# Extract the text inside the <thinking> tags and the JSON tool call inside
# the <tool> tags from the model's reply.
import re

assistant_text = output['choices'][0]['message']['content']

# `thinking` stays a list of every matched <thinking> section.
thinking = re.findall(r'<thinking>(.*?)</thinking>', assistant_text, re.DOTALL)

tool_blocks = re.findall(r'<tool>(.*?)</tool>', assistant_text, re.DOTALL)
if not tool_blocks:
    # Fail with an explicit message instead of an opaque IndexError.
    raise ValueError("The model response does not contain a <tool>...</tool> block.")

# Only the first <tool> block is used; the prompt asks for exactly one.
tool_call = json.loads(tool_blocks[0])
print(thinking, tool_call)

In [None]:
# Run the tool requested by the model. get_top_song returns a dict, so read
# the values by key; unpacking the dict directly would yield its key names
# ("song", "artist") instead of the actual song and artist.
top_song = get_top_song(tool_call['arguments']['sign'])
song, artist = top_song["song"], top_song["artist"]
tool_result = {
    "content": [{"song": song, "artist": artist}]
}
tool_result

In [None]:
# Hand the tool output back to the model as a user turn whose content is the
# JSON-serialized tool result.
tool_result_payload = json.dumps([{"toolResult": tool_result}])
tool_result_message = {"role": "user", "content": tool_result_payload}
messages.append(tool_result_message)

In [None]:
# Send the whole conversation, now including the tool result, back to the
# model so it can produce the final answer.
payload = {'messages': messages, 'max_tokens': 4*1024}

followup_args = {
    "EndpointName": endpoint_name,
    "ContentType": "application/json",
    "Body": json.dumps(payload),
}
# boto3 rejects None for string parameters, so only pass the inference
# component name when one is actually configured.
if component_name:
    followup_args["InferenceComponentName"] = component_name

response = sagemaker_runtime.invoke_endpoint(**followup_args)
# The response body is a stream of JSON bytes; decode it into a dict.
output = json.loads(response['Body'].read().decode())
output