In [1]:
import json
from typing import Any, Dict, List
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "katanemo/Arch-Router-1.5B"
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Please use our provided prompt for best performance
TASK_INSTRUCTION = """
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>

{routes}

</routes>

<conversation>

{conversation}

</conversation>
"""

FORMAT_PROMPT = """
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags.  Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent. 
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.

Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"} 
"""

# Define route config
route_config = [
    {
        "name": "code_generation",
        "description": "Generating new code snippets, functions, or boilerplate based on user prompts or requirements",
    },
    {
        "name": "bug_fixing",
        "description": "Identifying and fixing errors or bugs in the provided code across different programming languages",
    },
    {
        "name": "performance_optimization",
        "description": "Suggesting improvements to make code more efficient, readable, or scalable",
    },
    {
        "name": "api_help",
        "description": "Assisting with understanding or integrating external APIs and libraries",
    },
    {
        "name": "programming",
        "description": "Answering general programming questions, theory, or best practices",
    },
]

# Helper function to create the system prompt for our model
def format_prompt(
    route_config: List[Dict[str, Any]], conversation: List[Dict[str, Any]]
):
    return (
        TASK_INSTRUCTION.format(
            routes=json.dumps(route_config), conversation=json.dumps(conversation)
        )
        + FORMAT_PROMPT
    )

# Define conversations

conversation = [
    {
        "role": "user",
        "content": "what is the Stripe API used for?",
    }
]

route_prompt = format_prompt(route_config, conversation)

messages = [
    {"role": "user", "content": route_prompt},
]

input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# 2. Generate
generated_ids = model.generate(
    input_ids=input_ids,  # or just positional: model.generate(input_ids, …)
    max_new_tokens=32768,
)

# 3. Strip the prompt from each sequence
prompt_lengths = input_ids.shape[1]  # same length for every row here
generated_only = [
    output_ids[prompt_lengths:]  # slice off the prompt tokens
    for output_ids in generated_ids
]

# 4. Decode if you want text
response = tokenizer.batch_decode(generated_only, skip_special_tokens=True)[0]
print(response)


`torch_dtype` is deprecated! Use `dtype` instead!
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


{'route': 'api_help'}
