In [1]:
import json
from typing import Any, Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "katanemo/Arch-Function-3B"
# `dtype` replaces the deprecated `torch_dtype` argument (this cell previously
# logged "`torch_dtype` is deprecated! Use `dtype` instead!"). It requires a
# transformers release that accepts `dtype`, which the kernel used here does.
# "auto" lets transformers choose the dtype recorded in the checkpoint config.
# NOTE(review): trust_remote_code=True allows executing code shipped in the Hub
# repo; confirm it is actually needed for this checkpoint.
# NOTE(review): loading logs a YaRN `rope_scaling` factor mismatch coming from
# the checkpoint's config.json; it is not addressed in this notebook.
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prompt templates for Arch-Function. Upstream guidance (next line) is to use
# this exact wording, so the text inside these strings is kept verbatim.
# Please use our provided prompt for best performance
TASK_PROMPT = """
You are a helpful assistant.
""".strip()

# Tool section of the system prompt. `{tool_text}` is a str.format placeholder
# for the serialized tool specs; any other literal brace added to this template
# would have to be doubled ("{{" / "}}") to survive .format().
TOOL_PROMPT = """
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{tool_text}
</tools>
""".strip()

# Output-format instructions for the model. This template contains literal
# braces, so it must be used as-is and never passed through str.format (the
# brace-delimited text would be parsed as a replacement field and fail).
FORMAT_PROMPT = """
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
""".strip()

# Define available tools
# OpenAI-style function descriptor: a name, a description, and a
# JSON-Schema-like `parameters` object listing properties and required keys.
get_weather_api = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    # NOTE(review): JSON Schema spells this type "string"; "str"
                    # is presumably the spelling the model expects — confirm
                    # before changing, since it is sent to the model verbatim.
                    "type": "str",
                    "description": "The city and state, e.g. San Francisco, New York",
                },
                "unit": {
                    "type": "str",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The unit of temperature to return",
                },
            },
            "required": ["location"],
        },
    },
}

# Tool specs exposed to the model; passed to format_prompt to build the system
# prompt.
openai_format_tools = [get_weather_api]


def convert_tools(tools: List[Dict[str, Any]]):
    """Serialize tool specs for the <tools> section of the system prompt.

    Args:
        tools: JSON-serializable tool descriptors (e.g. OpenAI-style function
            specs).

    Returns:
        One ``json.dumps`` line per tool, joined by newline characters. An
        empty list yields the empty string.
    """
    serialized_lines = map(json.dumps, tools)
    return "\n".join(serialized_lines)

# Helper that builds the complete system prompt for Arch-Function.
def format_prompt(tools: List[Dict[str, Any]]):
    """Assemble the system prompt advertising ``tools`` to the model.

    The prompt consists of TASK_PROMPT, TOOL_PROMPT with the serialized tools
    substituted for ``{tool_text}``, and FORMAT_PROMPT. The sections are
    separated by blank lines, and the result ends with a single newline.

    Args:
        tools: Tool descriptors, serialized via ``convert_tools``.

    Returns:
        The system prompt string.
    """
    sections = (
        TASK_PROMPT,
        TOOL_PROMPT.format(tool_text=convert_tools(tools)),
        # Appended verbatim: FORMAT_PROMPT holds literal braces and must not be
        # run through str.format.
        FORMAT_PROMPT,
    )
    return "\n\n".join(sections) + "\n"


# Build the system prompt once from the declared tool list; the inference cell
# sends it as the "system" message.
system_prompt = format_prompt(tools=openai_format_tools)

The explicitly set RoPE scaling factor (config.rope_scaling['factor'] = 4.0) does not match the ratio implicitly set by other parameters (implicit factor = post-yarn context length / pre-yarn context length = config.max_position_embeddings / config.rope_scaling['original_max_position_embeddings'] = 1.0). Using the explicit factor (4.0) in YaRN. This may cause unexpected behaviour in model usage, please correct the 'max_position_embeddings' fields in the model config.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Single-turn request: the tool-aware system prompt plus one user question.
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "What is the weather in Redmond?"},
]

print("system_prompt: ", system_prompt)

# return_dict=True yields both `input_ids` and `attention_mask`, so generate()
# receives an explicit mask instead of inferring one. It also keeps this cell
# correct on transformers versions where apply_chat_template returns a dict.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(model.device)

# Greedy decoding. With do_sample=False, sampling settings inherited from the
# model's generation_config (temperature/top_p/top_k) are ignored; the warning
# transformers prints about them is harmless.
outputs = model.generate(
    **inputs,
    max_new_tokens=512,
    do_sample=False,
    num_return_sequences=1,
    eos_token_id=tokenizer.eos_token_id,
)

# Drop the prompt tokens so only the newly generated completion is decoded.
prompt_length = inputs["input_ids"].shape[-1]
response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
print(response)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


system_prompt:  You are a helpful assistant.

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather for a location", "parameters": {"type": "object", "properties": {"location": {"type": "str", "description": "The city and state, e.g. San Francisco, New York"}, "unit": {"type": "str", "enum": ["celsius", "fahrenheit"], "description": "The unit of temperature to return"}}, "required": ["location"]}}}
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>

<tool_call>
{"name": "get_weather", "arguments": {"location": "Redmond, Washington"}}
</tool_call>
