# Enabling Tool Calling with SageMaker Real-Time Endpoints utilizing vLLM via the LMI Container

In this notebook we'll explore how to enable tool calling directly on SageMaker-hosted LLMs through vLLM. You can use orchestration frameworks such as LangChain with their higher-level agent constructs, but vLLM also supports tool calling natively in its serving engine.

For more complex agentic workflows with multiple agents and built-in memory/session management, I'd recommend LangChain and multi-agent frameworks such as LangGraph (note that you can also enable sticky session routing with vLLM). However, if your tool specs and workflows are simpler and you want to stick to native vLLM specs, this is a great option to consider.

### Additional Resources/Credits

- Great blog by my colleague [Davide Gallitelli](https://www.linkedin.com/in/dgallitelli/) that helped me understand this functionality: https://dgallitelli95.medium.com/tool-calling-with-amazon-sagemaker-ai-and-djl-serving-inference-6a97dc854881. Check out his other work too, since he puts out some great content.
- vLLM Official Tool Calling Docs: https://docs.vllm.ai/en/stable/features/tool_calling.html
- LMI Container Intro: https://www.youtube.com/watch?v=N0r5AWZe2HU

## Setup
This notebook was run on a SageMaker classic notebook instance (ml.c5.2xlarge).

In [None]:
%pip install sagemaker --upgrade --quiet --no-warn-conflicts

In [None]:
import json
import sagemaker
import boto3

role = sagemaker.get_execution_role()  # execution role for the endpoint
sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs
bucket = sess.default_bucket()  # bucket to house artifacts
region = sess.boto_region_name  # region of the current SageMaker session
account_id = sess.account_id()

sm_client = boto3.client("sagemaker", region_name=region)  # client to interact with SageMaker
smr_client = boto3.client("sagemaker-runtime", region_name=region)  # client to interact with SageMaker endpoints

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {sess.boto_region_name}")
print(f"sagemaker version: {sagemaker.__version__}")

## Sample Deployment

In [None]:
# Specify hardware
instance_type = "ml.g5.12xlarge"
num_gpu = 4

# Specify the container (LMI v16)
CONTAINER_VERSION = "0.34.0-lmi16.0.0-cu128"
inference_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:{CONTAINER_VERSION}"
print(f"Using image URI: {inference_image}")

# Use the vLLM async handler
vllm_env = {
    "HF_MODEL_ID": "Qwen/Qwen3-8B",
    "HF_TOKEN": "Enter HF Token",  # only required for gated models; Qwen/Qwen3-8B is publicly available
    "SERVING_FAIL_FAST": "true",
    "OPTION_ASYNC_MODE": "true",
    "OPTION_ROLLING_BATCH": "disable",
    "OPTION_TENSOR_PARALLEL_DEGREE": json.dumps(num_gpu),
    "OPTION_ENTRYPOINT": "djl_python.lmi_vllm.vllm_async_service",
    "OPTION_TRUST_REMOTE_CODE": "true",
    "OPTION_ENABLE_AUTO_TOOL_CHOICE": "true",
    "OPTION_TOOL_CALL_PARSER": "hermes",
}
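In LMI, each `OPTION_<NAME>` environment variable is the environment-variable form of an `option.<name>` entry in `serving.properties`. The two tool-calling settings correspond to vLLM's `--enable-auto-tool-choice` and `--tool-call-parser` server flags described in the vLLM tool calling docs linked above. The helper below is a hypothetical sketch that renders the `OPTION_*` keys of an env dict in `option.*` form, which can make it easier to compare against LMI documentation examples. It assumes the simple prefix-and-lowercase convention and skips non-`OPTION_` keys such as `HF_MODEL_ID`.

```python
from typing import Dict


def env_to_serving_properties(env: Dict[str, str]) -> str:
    """Render OPTION_* environment variables as option.* lines (hypothetical helper)."""
    lines = []
    for key, value in env.items():
        # Only OPTION_-prefixed variables follow the option.<name> convention
        if key.startswith("OPTION_"):
            name = key[len("OPTION_"):].lower()
            lines.append(f"option.{name}={value}")
    return "\n".join(lines)


example_env = {
    "HF_MODEL_ID": "Qwen/Qwen3-8B",  # not an OPTION_ key, so it is skipped here
    "OPTION_TENSOR_PARALLEL_DEGREE": "4",
    "OPTION_ENABLE_AUTO_TOOL_CHOICE": "true",
    "OPTION_TOOL_CALL_PARSER": "hermes",
}
print(env_to_serving_properties(example_env))
```

In the notebook, `print(env_to_serving_properties(vllm_env))` shows the same view of the actual deployment config.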

In [None]:
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

# SageMaker Constructs
model_name = sagemaker.utils.name_from_base("lmi-qwen")
endpoint_name = model_name
inference_component_name = f"ic-{model_name}"

# SageMaker Model Object -> vLLM env
lmi_model = sagemaker.Model(
    image_uri=inference_image,
    env=vllm_env,
    role=role,
    name=model_name,
)

lmi_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    container_startup_health_check_timeout=600,
    endpoint_name=endpoint_name,
    endpoint_type=sagemaker.enums.EndpointType.INFERENCE_COMPONENT_BASED,
    inference_component_name=inference_component_name,
    # "memory" is the minimum memory (in MB) reserved per copy; keep it within what the instance provides
    resources=ResourceRequirements(requests={"num_accelerators": num_gpu, "memory": 1024 * 50, "copies": 1}),
)
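The `memory` request is expressed in megabytes, so `1024 * 50` reserves roughly 50 GB per copy. Before deploying, a quick sanity check can catch a request that could never be placed on the chosen instance. The sketch below is hypothetical: it hard-codes the assumption that an ml.g5.12xlarge provides 4 GPUs and 192 GiB of host memory. Confirm these figures against the official instance-type documentation, and keep in mind that SageMaker reserves some host memory, so the usable amount is lower than the total.

```python
from typing import Dict

# Assumed specs for illustration; confirm against the official instance-type pages.
INSTANCE_SPECS: Dict[str, Dict[str, int]] = {
    "ml.g5.12xlarge": {"gpus": 4, "memory_mb": 192 * 1024},
}


def check_resource_requests(instance_type: str, num_accelerators: int, memory_mb: int, copies: int = 1) -> None:
    """Raise ValueError if the requested copies cannot fit on one instance."""
    spec = INSTANCE_SPECS[instance_type]
    if num_accelerators * copies > spec["gpus"]:
        raise ValueError(f"{instance_type} has {spec['gpus']} GPUs; requested {num_accelerators * copies}")
    if memory_mb * copies > spec["memory_mb"]:
        raise ValueError(f"{instance_type} has ~{spec['memory_mb']} MB; requested {memory_mb * copies}")


# Mirrors the ResourceRequirements used in the deploy call above
check_resource_requests("ml.g5.12xlarge", num_accelerators=4, memory_mb=1024 * 50, copies=1)
print("requests fit")
```

Scaling `copies` to 2 with 4 accelerators each would raise here, because one instance has only 4 GPUs.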

In [None]:
import json
content_type = "application/json"

# Adjust payload and parameters as needed
payload = "Who is Roger Federer?"
response = smr_client.invoke_endpoint(
    EndpointName=endpoint_name,
    InferenceComponentName=inference_component_name,  # specify the inference component name
    ContentType=content_type,
    Accept=content_type,
    Body=json.dumps(
        {
            "inputs": payload,
            "parameters": {
                "max_new_tokens": 100  # Adjust this value as needed
            },
        }
    ),
)
result = json.loads(response["Body"].read().decode())['generated_text']
result
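The request above uses LMI's `inputs`/`parameters` schema and reads back `generated_text`. The tool-calling cells below send an OpenAI-style chat payload instead (`messages`, `tools`, `tool_choice`), and the reply is read from `choices[0].message`. Here is a minimal, hypothetical builder for that shape. The optional `max_tokens` field is the OpenAI chat-completions name and is assumed, not confirmed here, to be honored by the endpoint.

```python
import json
from typing import Any, Dict, List, Optional


def build_chat_payload(
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]] = None,
    tool_choice: str = "auto",
    max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """Assemble an OpenAI-style chat request body (hypothetical helper)."""
    payload: Dict[str, Any] = {"messages": messages}
    if tools:
        # tool_choice is only meaningful when tools are supplied
        payload["tools"] = tools
        payload["tool_choice"] = tool_choice
    if max_tokens is not None:
        payload["max_tokens"] = max_tokens
    return payload


chat_payload = build_chat_payload([{"role": "user", "content": "Who is Roger Federer?"}], max_tokens=100)
print(json.dumps(chat_payload))
```

This body can be passed as `Body=json.dumps(chat_payload)` to the same `invoke_endpoint` call. The text then comes back under `["choices"][0]["message"]["content"]` rather than `generated_text`.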

## Building our Mock Tool

In [None]:
import json
from pydantic import BaseModel, Field

# --- Model and data ---
class BioData(BaseModel):
    age: int
    gender: str
    occupation: str
    interests: list[str]

people_db = {
    "Ram":   BioData(age=25, gender="Male",   occupation="Engineer", interests=["Reading", "Traveling"]),
    "Shyam": BioData(age=22, gender="Male",   occupation="Doctor",   interests=["Reading", "Traveling"]),
    "Sita":  BioData(age=21, gender="Female", occupation="Teacher",  interests=["Reading", "Traveling"]),
    "Gita":  BioData(age=23, gender="Female", occupation="Lawyer",   interests=["Reading", "Traveling"]),
    "Hari":  BioData(age=24, gender="Male",   occupation="Engineer", interests=["Reading", "Traveling"]),
}

def return_biodata(name: str) -> dict:
    """Return biodata as a plain JSON-serializable dictionary."""
    if name not in people_db:
        return {"error": f"Person '{name}' not found"}
    return people_db[name].model_dump()

# --- Tool spec and registry ---
biodata_tool_spec = {
    "type": "function",
    "function": {
        "name": "return_biodata",
        "description": "Return biodata for a known person name.",
        "parameters": {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "required": ["name"],
        },
    },
}
tools = [biodata_tool_spec]
TOOLS_REGISTRY = {"return_biodata": return_biodata}
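Hand-writing the JSON schema works for a single tool, but it can drift from the Python signature as more tools are added. One option is to derive the spec from the function itself. The helper below is a hypothetical sketch that handles only simple scalar and list annotations (an assumption). Richer argument types would need something like Pydantic's `model_json_schema()`.

```python
import inspect
from typing import Any, Callable, Dict, List, get_origin, get_type_hints

# Assumed mapping from simple Python annotations to JSON Schema types
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean", list: "array", dict: "object"}


def function_to_tool_spec(func: Callable[..., Any]) -> Dict[str, Any]:
    """Build an OpenAI-style tool spec from a function's signature and docstring."""
    hints = get_type_hints(func)
    properties: Dict[str, Any] = {}
    required: List[str] = []
    for name, param in inspect.signature(func).parameters.items():
        hint = hints.get(name, str)
        # get_origin(List[str]) is list; plain types like str have no origin
        base = get_origin(hint) or hint
        properties[name] = {"type": _JSON_TYPES.get(base, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": inspect.getdoc(func) or "",
            "parameters": {"type": "object", "properties": properties, "required": required},
        },
    }


# Hypothetical stand-in with the same signature as return_biodata
def lookup_person(name: str) -> dict:
    """Return biodata for a known person name."""
    return {}


print(function_to_tool_spec(lookup_person))
```

In the notebook, `function_to_tool_spec(return_biodata)` yields the same `parameters` block as `biodata_tool_spec`. The description, however, comes from the function's docstring, so it reads differently.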

## Tool Calling Utility Functions
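The utilities in the next cell follow the OpenAI-style chat message protocol used by vLLM tool calling. The assistant turn carries a `tool_calls` list, and each result goes back to the model as a `role: "tool"` message whose `tool_call_id` matches the call's `id`. Below is a hand-built example of the history after one tool round (the `id` value is made up), along with a hypothetical check that every requested call has a matching result.

```python
import json

# Illustrative conversation after one tool round; the tool-call id is made up.
conversation = [
    {"role": "system", "content": "You are a helpful assistant with access to tools."},
    {"role": "user", "content": "Who is Shyam?"},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_example_1",
                "type": "function",
                "function": {"name": "return_biodata", "arguments": json.dumps({"name": "Shyam"})},
            }
        ],
    },
    {
        "role": "tool",
        "tool_call_id": "call_example_1",
        "name": "return_biodata",
        "content": json.dumps({"age": 22, "gender": "Male", "occupation": "Doctor",
                               "interests": ["Reading", "Traveling"]}),
    },
]


def unanswered_tool_calls(messages):
    """Return ids of assistant tool calls that have no matching tool message."""
    requested = {tc["id"] for m in messages if m["role"] == "assistant" for tc in m.get("tool_calls") or []}
    answered = {m["tool_call_id"] for m in messages if m["role"] == "tool"}
    return requested - answered


print([m["role"] for m in conversation])
print(unanswered_tool_calls(conversation))
```

Running `unanswered_tool_calls(updated_messages)` after a tool round is a quick way to debug a conversation the model rejects.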

In [None]:
import json
from typing import Any, Dict, List, Optional, Tuple, Callable

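def execute_tool(tool_name: str, raw_arguments: str) -> Any:
    """
    Look up the tool the model requested and call it with the parsed JSON arguments.
    Errors are returned as dicts so they can be passed back to the model as tool output.
    """
    func = TOOLS_REGISTRY.get(tool_name)
    if not func:
        return {"error": f"Unknown tool: {tool_name}"}
    try:
        args = json.loads(raw_arguments or "{}")
        return func(**args)
    except (json.JSONDecodeError, TypeError) as exc:
        return {"error": f"Invalid arguments for {tool_name}: {exc}"}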

def handle_tool_calls(
    result: Dict[str, Any],
    messages: List[Dict[str, Any]],
    executor: Callable[[str, str], Any],
) -> Tuple[Optional[str], List[Dict[str, Any]]]:
    """
    Handle tool calls emitted by the model.
    If no tools are called, return (assistant_content, messages).
    If tools are called, execute each tool, append the results, and return (None, updated_messages).
    """
    msg = result["choices"][0]["message"]
    tool_calls = msg.get("tool_calls") or []  # the field may be missing or null

    if not tool_calls:
        return msg.get("content"), messages

    # Append assistant message with its tool calls
    messages.append({k: v for k, v in msg.items() if k in ("role", "content", "tool_calls")})

    # Execute each tool sequentially and append outputs
    for tc in tool_calls:
        tool_id = tc["id"]
        tool_name = tc["function"]["name"]
        raw_args = tc["function"]["arguments"]
        tool_output = executor(tool_name, raw_args)

        messages.append({
            "role": "tool",
            "tool_call_id": tool_id,
            "name": tool_name,
            "content": json.dumps(tool_output, ensure_ascii=False),
        })

    return None, messages


def _invoke_ep(
    smr_client,
    endpoint_name: str,
    inference_component_name: str,
    payload: Dict[str, Any],
    content_type: str = "application/json",
) -> Dict[str, Any]:
    """
    Helper: Call the SageMaker endpoint and return the decoded JSON result.
    """
    resp = smr_client.invoke_endpoint(
        EndpointName=endpoint_name,
        InferenceComponentName=inference_component_name,
        ContentType=content_type,
        Accept=content_type,
        Body=json.dumps(payload),
    )
    return json.loads(resp["Body"].read().decode())


def chat_with_tools(
    smr_client,
    endpoint_name: str,
    inference_component_name: str,
    messages: List[Dict[str, Any]],
    tools: List[Dict[str, Any]],
    executor: Callable[[str, str], Any],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """
    Run a single tool-calling round:
    1. Invoke the model once with messages + tools.
    2. If tool calls exist, execute them and append their outputs.
    3. Re-invoke the model with the updated messages.
    Returns (final_result, updated_messages). Tool calls in the second response are not executed.
    """
    messages = list(messages)  # copy so re-running a cell does not mutate the caller's list
    payload = {"messages": messages, "tools": tools, "tool_choice": "auto"}

    first = _invoke_ep(smr_client, endpoint_name, inference_component_name, payload)
    assistant_content, updated_messages = handle_tool_calls(first, messages, executor)

    # If no tool calls were made, return first result
    if assistant_content is not None:
        return first, updated_messages

    # Reinvoke after tool execution
    final_payload = {"messages": updated_messages, "tools": tools, "tool_choice": "auto"}
    final = _invoke_ep(smr_client, endpoint_name, inference_component_name, final_payload)
    return final, updated_messages
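`chat_with_tools` runs at most one tool round. If the second response itself contains `tool_calls` (for example, when the model looks people up one at a time), those calls are never executed. A bounded loop is one way to handle that case. The sketch below is hypothetical. Instead of the SageMaker client, it takes an `invoke` callable (payload in, response dict out), so it can be exercised against a fake model. The `max_rounds` value is an arbitrary safety limit.

```python
import json
from typing import Any, Callable, Dict, List, Tuple


def chat_with_tools_loop(
    invoke: Callable[[Dict[str, Any]], Dict[str, Any]],
    messages: List[Dict[str, Any]],
    tools: List[Dict[str, Any]],
    executor: Callable[[str, str], Any],
    max_rounds: int = 5,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Keep executing tool calls until the model answers without any (hypothetical helper)."""
    messages = list(messages)  # avoid mutating the caller's list
    for _ in range(max_rounds):
        result = invoke({"messages": messages, "tools": tools, "tool_choice": "auto"})
        msg = result["choices"][0]["message"]
        tool_calls = msg.get("tool_calls") or []
        if not tool_calls:
            return result, messages  # final answer reached
        messages.append({"role": "assistant", "content": msg.get("content"), "tool_calls": tool_calls})
        for tc in tool_calls:
            output = executor(tc["function"]["name"], tc["function"]["arguments"])
            messages.append({
                "role": "tool",
                "tool_call_id": tc["id"],
                "name": tc["function"]["name"],
                "content": json.dumps(output, ensure_ascii=False),
            })
    raise RuntimeError(f"No final answer after {max_rounds} rounds")


# Fake model: requests one tool call, then answers from the tool result
def fake_invoke(payload):
    last = payload["messages"][-1]
    if last["role"] == "tool":
        age = json.loads(last["content"])["age"]
        return {"choices": [{"message": {"role": "assistant", "content": f"Shyam is {age}."}}]}
    call = {"id": "call_1", "type": "function",
            "function": {"name": "return_biodata", "arguments": "{\"name\": \"Shyam\"}"}}
    return {"choices": [{"message": {"role": "assistant", "content": None, "tool_calls": [call]}}]}


final, history = chat_with_tools_loop(
    fake_invoke,
    [{"role": "user", "content": "Who is Shyam?"}],
    tools=[],
    executor=lambda name, args: {"age": 22},
)
print(final["choices"][0]["message"]["content"])
```

To run the loop against the real endpoint, `invoke` could be `lambda p: _invoke_ep(smr_client, endpoint_name, inference_component_name, p)` and `executor` could be `execute_tool`.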

## Sample Inference

In [None]:
# --- No tool call example ---
messages_no_tool = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant with access to tools. "
            "Use the 'return_biodata' tool only when the user asks for information about specific people. "
            "For all general knowledge questions, rely on your own knowledge."
        ),
    },
    {"role": "user", "content": "What is the capital of France?"},
]


# --- Single tool call example ---
messages_single_tool = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant with access to tools. "
            "Use the 'return_biodata' tool only when the user asks for information about specific people. "
            "For all general knowledge questions, rely on your own knowledge."
        ),
    },
    {"role": "user", "content": "Who is Shyam?"},
]


# --- Multi-tool (complex) example ---
messages_multi_tool = [
    {
        "role": "system",
        "content": (
            "You are a helpful assistant with access to tools. "
            "Use the 'return_biodata' tool for any information requested about people. "
            "For all other queries, use your general knowledge."
        ),
    },
    {"role": "user", "content": "What is the combined age of Ram and Shyam?"},
]
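Because the tool data is mocked, the correct answer to the multi-tool prompt is known in advance: Ram is 25 and Shyam is 22, so their combined age is 25 + 22 = 47. A small check like the one below can flag a response that does not contain the expected number. The ages are copied from `people_db`, and the standalone-number match is only a rough heuristic.

```python
import re

# Ages copied from the mock people_db above
ages = {"Ram": 25, "Shyam": 22}
expected_total = sum(ages.values())


def mentions_number(text: str, number: int) -> bool:
    """Rough heuristic: does the number appear as a standalone integer in the text?"""
    return re.search(rf"(?<!\d){number}(?!\d)", text or "") is not None


print(expected_total)
print(mentions_number("Their combined age is 47 years.", expected_total))
```

After the multi-tool call, `mentions_number(final_result["choices"][0]["message"]["content"], expected_total)` gives a quick pass/fail signal.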

In [None]:
# No Tool
final_result, updated_messages = chat_with_tools(
    smr_client,
    endpoint_name,
    inference_component_name,
    messages_no_tool,
    tools,
    execute_tool,
)

print(final_result["choices"][0]["message"]["content"])


# Single Tool
final_result, updated_messages = chat_with_tools(
    smr_client,
    endpoint_name,
    inference_component_name,
    messages_single_tool,
    tools,
    execute_tool,
)

print(final_result["choices"][0]["message"]["content"])


# Multi-Tool
final_result, updated_messages = chat_with_tools(
    smr_client,
    endpoint_name,
    inference_component_name,
    messages_multi_tool,   # or messages_no_tool / messages_single_tool
    tools,
    execute_tool,
)

print(final_result["choices"][0]["message"]["content"])
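Qwen3 models generate in "thinking" mode by default. Unless the deployment is configured with a reasoning parser, the returned `content` may begin with a `<think>...</think>` block ahead of the answer. The sketch below is one way to separate the two when printing results. It assumes there is at most one leading block, delimited by exactly these tags.

```python
import re
from typing import Optional, Tuple

# Matches a single leading <think>...</think> block plus trailing whitespace
_THINK_RE = re.compile(r"^\s*<think>(.*?)</think>\s*", re.DOTALL)


def split_reasoning(content: Optional[str]) -> Tuple[str, str]:
    """Split a leading <think>...</think> block from the answer text."""
    content = content or ""
    match = _THINK_RE.match(content)
    if not match:
        return "", content.strip()
    return match.group(1).strip(), content[match.end():].strip()


reasoning, answer = split_reasoning(
    "<think>\nThe user asks about France.\n</think>\n\nThe capital of France is Paris."
)
print(answer)
```

In the cells above, `split_reasoning(final_result["choices"][0]["message"]["content"])[1]` would print only the answer portion.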