### Setup

In [None]:
import json
from cactus_bindings import bindings
import pandas as pd
import numpy as np
import warnings
from enum import Enum
import time
from typing import List, Dict, Callable, Optional, get_args, get_origin, Union
import inspect
from docstring_parser import parse
from dataclasses import dataclass, field
from functools import wraps

warnings.simplefilter('ignore')

In [None]:
def calculate_cosine_similarity(text1:str, text2:str, verbose:bool=False)->float:
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are embedded through the module-level ``clm`` object
    (``clm.embed``), so ``clm`` must be defined before this is called.

    Args:
        text1: First text to embed.
        text2: Second text to embed.
        verbose: When True, print how long each embedding call took.

    Returns:
        The cosine similarity as a Python float. Returns 0.0 when either
        embedding has zero norm, instead of dividing by zero and producing
        NaN (NaN scores make any later sort on the result unreliable).
    """
    # perf_counter is monotonic, so the reported durations cannot go negative
    t1 = time.perf_counter()
    v1 = np.asarray(clm.embed(text1))
    t2 = time.perf_counter()
    v2 = np.asarray(clm.embed(text2))
    t3 = time.perf_counter()
    if verbose:
        print(f'Embeddings complete in {t2-t1:.2f}sec | {t3-t2:.2f}sec')

    denominator = np.linalg.norm(v1) * np.linalg.norm(v2)
    if denominator == 0:
        return 0.0
    return float(np.dot(v1, v2) / denominator)

def keywords(kw: Optional[List[str]] = None):
    """Return a decorator that attaches a list of keywords to a function.

    The keywords are stored on the function object as ``_keywords``. The
    decorated function itself is returned (it is not wrapped), so its
    signature, docstring and name are untouched.

    Args:
        kw: Keywords to attach; ``None`` or an empty list attaches ``[]``.

    Returns:
        A decorator that sets ``func._keywords`` and returns ``func``.
    """
    def decorator(func):
        # Copy so that later mutation of the caller's list does not leak
        # into the metadata stored on the function.
        setattr(func, '_keywords', list(kw) if kw else [])
        return func

    return decorator

class TransformationType(str, Enum):
    """Level of detail used when a tool is rendered to text by ``Tool.to_string``."""

    # Function name and description only
    NAME_DESCRIPTION = "name-description"

    # Name, description and the argument names
    NAME_DESCRIPTION_ARGS = "name-description-args"

    # Name, description, and each argument name with its description
    NAME_DESCRIPTION_ARGS_DESCRIPTIONS = "name-description-args-descriptions"


@dataclass
class ToolArgument:
    """Description of a single argument accepted by a tool.

    Attributes:
        name: Argument name.
        description: Human-readable description of the argument.
        _type: Type name emitted as the ``"type"`` of the argument in
            ``Tool.to_openai_format``.
        required: Whether the argument is listed under ``"required"``.
    """

    name: str
    description: str
    _type: str
    required: bool


@dataclass
class Tool:
    name:str
    description:str
    func:Callable
    args:Optional[List[ToolArgument]]=field(default_factory=list)
    keywords:Optional[List[str]]=field(default_factory=list)

    def to_openai_format(self)->Dict:
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        arg.name: {
                            "type": arg._type,
                            "description": arg.description
                        }
                        for arg in self.args
                    },
                    "required": [arg.name for arg in self.args if arg.required]
                }
            }
        }

    def to_string(
        self,
        transformation_type:TransformationType=TransformationType.NAME_DESCRIPTION_ARGS_DESCRIPTIONS
    )->str:
        match transformation_type.value:
            case 'name-description':
                return f"Function {self.name} with description: `{self.description}`" 
            case 'name-description-args':
                return f"""Function {self.name} with description: `{self.description}` and arguments {', '.join([f"`{arg.name}`" for arg in self.args])}"""
            case 'name-description-args-descriptions':
                return f"""Function {self.name} with description: `{self.description}` and arguments {', '.join([f"`{arg.name}` ({arg.description})" for arg in self.args])}"""
            case _:
                raise Exception('unknown case!')


class Tools:
    """A collection of ``Tool`` objects with similarity-based retrieval."""

    def __init__(
        self,
        tools:List[Tool],
    )->None:
        self.tools=tools

    def retrieve_relevant(
        self,
        user_query:str,
        return_format:str='openai',
        top_n:int=1
    )->List[Tool | Dict]:
        """Return the ``top_n`` tools most similar to ``user_query``.

        Every tool is rendered with ``Tool.to_string()`` and scored against
        the query with ``calculate_cosine_similarity``; the highest-scoring
        tools come first.

        Args:
            user_query: The text to match tools against.
            return_format: ``'openai'`` returns ``to_openai_format()`` dicts;
                ``'dict'`` returns the ``Tool`` objects themselves.
            top_n: Number of tools to keep.
        """
        assert return_format in ('openai', 'dict'), 'Unsupported return format!'

        scored = [
            (tool, calculate_cosine_similarity(tool.to_string(), user_query))
            for tool in self.tools
        ]
        # Stable sort: tools with equal scores keep their original order
        scored.sort(key=lambda pair: pair[1], reverse=True)
        best = [tool for tool, _score in scored[:top_n]]

        if return_format == 'openai':
            return [tool.to_openai_format() for tool in best]
        if return_format == 'dict':
            return best

class CactusModel:
    """Thin wrapper around the ``cactus_bindings`` API for one model.

    The model is loaded in ``__init__`` from ``WEIGHTS_PATH + slug``.
    ``complete`` and ``embed`` are best-effort: on any exception they print
    the error and return ``None`` so that callers can test the result for
    truthiness.
    """

    # NOTE(review): machine-specific absolute path; must end with a path
    # separator because the slug is appended by plain string concatenation.
    WEIGHTS_PATH="/Users/noahcylich/Documents/Desert/cactus-fc/weights/"

    def __init__(self, slug:str, context_size:int=2048)->None:
        """Store the model location/context size and load the model.

        Args:
            slug: Name appended to ``WEIGHTS_PATH`` to form the model path.
            context_size: Context size passed to ``cactus_init``.
        """
        self.model_path=self.WEIGHTS_PATH+slug
        self.context_size=context_size
        self.initialize_model()

    def initialize_model(self)->None:
        """Load the model through ``bindings.cactus_init`` into ``self.model``."""
        self.model=bindings.cactus_init(
            model_path=self.model_path,
            context_size=self.context_size
        )

    def complete(self, messages:List[Dict], max_tokens:int=1024, tools:Optional[List]=None)->Optional[Dict]:
        """Run a completion and return the parsed JSON response.

        Args:
            messages: Chat messages, serialised to JSON for the bindings.
            max_tokens: Passed as ``max_tokens`` in the options JSON.
            tools: Tool definitions serialised to JSON; ``None`` means no tools.

        Returns:
            The decoded response dict, or ``None`` if the call or the JSON
            decoding raised (the exception is printed).
        """
        try:
            response_json = bindings.cactus_complete(
                model=self.model,
                messages_json=json.dumps(messages),
                response_buffer_size=4096,
                options_json=json.dumps({
                    "max_tokens": max_tokens,
                    "temperature": 0.7
                }),
                # None default avoids a shared mutable default argument
                tools_json=json.dumps(tools if tools is not None else []),
                callback=None
            )
            return json.loads(response_json)
        except Exception as e:
            # Deliberately best-effort: report and signal failure with None
            print(e)
            return None

    def embed(self, text:str)->Optional[List]:
        """Return the embedding of ``text`` from ``bindings.cactus_embed``.

        Returns:
            Whatever ``cactus_embed`` returns, or ``None`` if it raised
            (the exception is printed).
        """
        try:
            return bindings.cactus_embed(
                model=self.model,
                text=text
            )
        except Exception as e:
            print(e)
            return None


class CactusChatModel(CactusModel):
    """``CactusModel`` with a running conversation history and tool calling.

    ``message_history_raw`` holds the system/user/tool messages as plain
    role/content dicts and the model responses exactly as returned by
    ``complete``. ``message_history`` converts that into the role/content
    list that is sent back to the model.
    """

    def __init__(
        self,
        slug:str,
        context_size:int=2048,
        prompt:str='You are a helpful assistant.'
    )->None:
        """Load the model and start a history containing only the system prompt."""
        super().__init__(slug, context_size)
        self.system_prompt = prompt
        self.reset_history()

    def reset_history(self)->None:
        """Resets the conversation to just the initial system prompt."""
        self.message_history_raw = [{'role': 'system', 'content': self.system_prompt}]

    def send_message(
        self,
        message:str,
        tools:Optional[Tools]=None,
        filter_tools:bool=True,
        top_n_tools:int=1,
        auto_call_tool:bool=True
    )->Dict:
        """Send a user message, optionally executing any tool calls returned.

        Args:
            message: The user message.
            tools: Available tools; ``None`` means no tools.
            filter_tools: When True, only the ``top_n_tools`` most relevant
                tools (per ``Tools.retrieve_relevant``) are offered to the model.
            top_n_tools: Number of tools kept when filtering.
            auto_call_tool: When True and the response contains
                ``function_calls``, run each call, append its output to the
                history and ask the model for a follow-up completion.

        Returns:
            The last response dict received from the model, or ``{}`` when a
            completion call failed.
        """
        # A fresh collection per call instead of one instance shared via the default
        if tools is None:
            tools = Tools([])

        if filter_tools:
            self.tools=tools.retrieve_relevant(message, return_format='openai', top_n=top_n_tools)
            print(f"Filtered tools down to {[t.get('function').get('name') for t in self.tools]} for {message}")
        else:
            self.tools=[tool.to_openai_format() for tool in tools.tools]

        self.message_history_raw.append({'role': 'user', 'content': message})
        response_json = self.complete(messages=self.message_history, tools=self.tools)
        if not response_json:
            return {}
        self.message_history_raw.append(response_json)

        function_calls = response_json.get('function_calls')
        if not (auto_call_tool and function_calls):
            return response_json

        for call in function_calls:
            self.message_history_raw.append(
                {'role': 'tool_response', 'content': self._run_tool_call(tools, call)}
            )

        follow_up = self.complete(self.message_history, max_tokens=1024)
        if not follow_up:
            # Never store a failed (None) completion: message_history calls
            # .get() on every raw entry and would break on it.
            return {}
        self.message_history_raw.append(follow_up)
        return follow_up

    def _run_tool_call(self, tools:Tools, call:Dict):
        """Execute one tool call and return its output or an error string."""
        tool_name = call.get('name') or call.get('function')
        # Missing arguments are treated as "no arguments"
        tool_args = call.get('arguments') or {}

        tool = next((t for t in tools.tools if t.name == tool_name), None)
        if tool is None:
            return f"Error: Tool '{tool_name}' not found."
        try:
            return tool.func(**tool_args)
        except Exception:
            # Best-effort: a failing tool is reported to the model, not raised
            return 'error calling tool'

    @property
    def message_history(self)->List:
        """The history as role/content messages suitable for ``complete``.

        User, system and tool-response entries are passed through unchanged.
        Any other entry is treated as a model response and converted to an
        assistant message whose content is either the response text or a
        list of ``{"type": "tool_call", "function": ...}`` items.
        """
        history = []
        for entry in self.message_history_raw:
            if entry.get('role') in ('user', 'system', 'tool_response'):
                history.append(entry)
                continue

            calls = entry.get('function_calls')
            if calls:
                content = [{"type": "tool_call", "function": call} for call in calls]
            else:
                # Tolerate a response dict without a 'response' key
                content = entry.get('response', '')
            history.append({'role': 'assistant', 'content': content})
        return history

    @property
    def messages_df(self)->pd.DataFrame:
        """The raw history as a DataFrame with ``response`` folded into ``content``."""
        df = pd.DataFrame(self.message_history_raw)
        # The 'response' column only exists once the model has answered
        if 'response' in df.columns:
            df['content'] = df['content'].combine_first(df['response'])
        # Assign back instead of chained inplace fillna, which does not
        # reliably update the frame in recent pandas versions
        df['role'] = df['role'].fillna('assistant')
        return df[[col for col in df if col != 'response']]

def _map_type(py_type) -> str:
    """Maps Python types to JSON schema string types."""
    if py_type == str:
        return "string"
    if py_type == int:
        return "integer"
    if py_type == float:
        return "number"
    if py_type == bool:
        return "boolean"
    return "string" # Default

def build_tool_from_func(func: Callable) -> Tool:
    """Generate a ``Tool`` by inspecting a Python function.

    The tool name is the function name, the description is the short
    description of the docstring, and one ``ToolArgument`` is created per
    parameter of the signature. Argument descriptions come from the parsed
    docstring parameters; types come from the annotations via ``_map_type``.

    Args:
        func: The function to wrap. A missing docstring is allowed.

    Returns:
        The ``Tool`` describing ``func``. A parameter is marked required
        when it has no default value.
    """
    import types  # local import: only needed for the PEP 604 union check

    sig = inspect.signature(func) # get signature (names, types, defaults)
    # getdoc returns None for undocumented functions
    docstring = parse(inspect.getdoc(func) or "")
    param_docs = {param.arg_name: param.description for param in docstring.params}

    # typing.get_origin yields Union for Optional[X]/Union[...] and
    # types.UnionType for the `X | None` syntax; it never yields Optional.
    union_origins = (Union, types.UnionType)

    tool_args = []
    for name, param in sig.parameters.items():
        # Identity check against the sentinel; `==` could invoke a custom
        # __eq__ on the default value.
        is_required = param.default is inspect.Parameter.empty
        real_type = param.annotation
        if get_origin(real_type) in union_origins:
            # first type argument that isn't None (fall back to str)
            real_type = next((t for t in get_args(real_type) if t is not type(None)), str)

        tool_args.append(
            ToolArgument(
                name=name,
                description=param_docs.get(name) or "",
                _type=_map_type(real_type),
                required=is_required
            )
        )

    return Tool(
        name=func.__name__,
        description=docstring.short_description or "",
        func=func,
        args=tool_args,
        keywords=getattr(func, '_keywords', [])
    )

#### Tool definitions:

In [None]:
from typing import Optional

# Using Optional[type] which is equivalent to type | None
# for broader Python compatibility.

def create_note(text: str):
    """
    Creates a new note with the given text. Call this tool if asked to be reminded or to take a note.

    Args:
        text: The text of the note, usually a direct quote from the user
    """
    # The docstring above is parsed by build_tool_from_func into the tool
    # description, so it is part of the runtime behaviour.
    confirmation = "Note created with text: {}".format(text)
    return confirmation


def set_alarm(time_hours: int, time_minutes: int):
    """
    Sets an alarm for a specific time.

    Args:
        time_hours: The hour component of the alarm time (24 hour time)
        time_minutes: The minute component of the alarm time (0-59)
    """
    # Stub: the arguments are accepted but not echoed in the confirmation
    return "Alarm set successfully!"


def set_timer_absolute(day_offset: Optional[str], time_hours: int, time_minutes: int):
    """
    Sets a timer to go off at an absolute day and time.

    Args:
        day_offset: The offset of the day to remind the user at e.g. 'tomorrow', 'today', 'thursday' (will be the next thursday), '3' (will be in 3 days)
        time_hours: The hour component of the desired end time (24 hour time)
        time_minutes: The minute component of the desired end time (0-59)
    """
    # Stub: only reports what would have been scheduled
    template = "Absolute timer set for {} at {}:{}"
    return template.format(day_offset, time_hours, time_minutes)


def set_timer(time_hours: Optional[int], time_minutes: Optional[int], time_seconds: Optional[int]):
    """
    Sets a timer for a relative duration (hours, minutes, seconds).

    Args:
        time_hours: The number of hours on the timer
        time_minutes: The number of minutes on the timer
        time_seconds: The number of seconds on the timer
    """
    # Stub: only reports the requested duration
    template = "Timer set for {}h {}m {}s"
    return template.format(time_hours, time_minutes, time_seconds)


def reminder_absolute(day_offset: Optional[str], absolute_time_hour: int, absolute_time_minute: int, date_month_day: Optional[str], date_year: Optional[int], message: str):
    """
    Creates a reminder for a specific absolute date and time.

    Args:
        day_offset: The offset of the day to remind the user at e.g. 'tomorrow', 'today', 'thursday' (will be the next thursday), '3' (will be in 3 days)
        absolute_time_hour: The absolute time to remind the user at as a 24 hour hour part e.g. '17'
        absolute_time_minute: The absolute time to remind the user at as a minute part e.g. '30', or '00' for the top of the hour
        date_month_day: The date to remind the user at if specified by the user as a date part (month-day) e.g. '12-31'
        date_year: The year to remind the user at if specified by the user as a year part e.g. '2022'
        message: The message to remind the user e.g. 'Buy more milk'
    """
    # Stub: only reports what would have been scheduled
    template = (
        "Absolute reminder set for '{message}' on {date_month_day}-{date_year} "
        "or {day_offset} at {absolute_time_hour}:{absolute_time_minute}"
    )
    return template.format(
        message=message,
        date_month_day=date_month_day,
        date_year=date_year,
        day_offset=day_offset,
        absolute_time_hour=absolute_time_hour,
        absolute_time_minute=absolute_time_minute,
    )


def create_reminder_relative(relative_time: int, time_unit: str, message: str):
    """
    When the user requires a reminder at a relative time e.g. 'in 5 minutes' use the create_reminder_relative tool.

    Args:
        relative_time: The relative time to remind the user at as n 'time_unit's in the future
        time_unit: The unit of time for the relative time. Must be one of: ["seconds", "minutes", "hours", "days", "weeks", "months", "years"]
        message: The message to remind the user e.g. 'Buy more milk'
    """
    # Stub: only reports what would have been scheduled
    template = "Relative reminder set for '{}' in {} {}"
    return template.format(message, relative_time, time_unit)

# Registry of every tool offered to the model: each plain function is turned
# into a Tool via its signature and docstring.
tools = Tools(list(map(build_tool_from_func, (
    create_note,
    set_alarm,
    set_timer_absolute,
    set_timer,
    reminder_absolute,
    create_reminder_relative,
))))

### Benchmarking

In [None]:
# Evaluation set: each sample pairs a user query with the name of the tool the
# model is expected to call ("correct_tool"), or None when no tool call is
# expected.
# NOTE(review): the labels "write_text_message" and "weather_lookup" do not
# correspond to any function registered in `tools` above — confirm whether
# those tools should be defined or these samples removed.
eval_data = [
    {
        "query": "send Henry a message about our upcoming framework release.",
        "correct_tool": "write_text_message"
    },
    {
        "query": "what is the weather in London?",
        "correct_tool": "weather_lookup"
    },
    {
        "query": "Wake me up at 5 am tomorrow please.",
        "correct_tool": "set_alarm"
    },
    {
        "query": "Write down that i need to go buy groceries for the house tomorrow",
        "correct_tool": "create_note"
    },
    {
        "query": "Hey how are you!",
        "correct_tool": None
    },
    {
        "query": "Text mom I'll be home late.",
        "correct_tool": "write_text_message"
    },
    {
        "query": "Can you message Alex about the 3pm call?",
        "correct_tool": "write_text_message"
    },
    {
        "query": "Tell Henry I've finished the draft for the medium article.", # Personalized from your context
        "correct_tool": "write_text_message"
    },
    {
        "query": "Will I need an umbrella tomorrow in New York?",
        "correct_tool": "weather_lookup"
    },
    {
        "query": "How cold is it in Paris right now?",
        "correct_tool": "weather_lookup"
    },
    {
        "query": "Get me the forecast for San Francisco this weekend.",
        "correct_tool": "weather_lookup"
    },

    {
        "query": "Set an alarm for 7:30 PM.",
        "correct_tool": "set_alarm"
    },
    {
        "query": "I need an alarm for 6:15 tomorrow morning.",
        "correct_tool": "set_alarm"
    },
    {
        "query": "Remind me to buy milk and eggs.",
        "correct_tool": "create_note"
    },
    {
        "query": "Make a note: pick up dry cleaning on Tuesday.",
        "correct_tool": "create_note"
    },
    {
        "query": "Save this thought: on-device inference is key for privacy.", # Personalized from your context
        "correct_tool": "create_note"
    },
    {
        "query": "That's great, thanks!",
        "correct_tool": None
    },
    {
        "query": "What is the capital of France?",
        "correct_tool": None
    },
    {
        "query": "Who won the game last night?",
        "correct_tool": None
    },
    {
        "query": "Make a note of the weather in Berlin.", # Ambiguous query to test boundaries
        "correct_tool": "create_note"
    },
    {
        "query": "I need an alarm for 8:45 in the morning.",
        "correct_tool": "set_alarm"
    },
    {
        "query": "Set an alarm for 11:30 PM tonight.",
        "correct_tool": "set_alarm"
    },
    {
        "query": "Alarm for 6am.",
        "correct_tool": "set_alarm"
    },
    {
        "query": "Can you wake me up at 7:15 am?",
        "correct_tool": "set_alarm"
    },

    # --- weather_lookup ---
    {
        "query": "What's the weather like in Boston?",
        "correct_tool": "weather_lookup"
    },
    {
        "query": "I'm going to Paris tomorrow, what's the forecast?",
        "correct_tool": "weather_lookup"
    },
    {
        "query": "Tell me the temperature in Dubai.",
        "correct_tool": "weather_lookup"
    },
    {
        "query": "Weather forecast for Seattle for the next 3 days.",
        "correct_tool": "weather_lookup"
    },

    # --- write_text_message ---
    {
        "query": "Send a message to Alice asking 'What time is dinner?'",
        "correct_tool": "write_text_message"
    },
    {
        "query": "Text Bob: 'I'm running about 15 minutes late.'",
        "correct_tool": "write_text_message"
    },
    {
        "query": "Please message my manager that I've pushed the new code.",
        "correct_tool": "write_text_message"
    },
    {
        "query": "Text 'On my way!' to Sarah.",
        "correct_tool": "write_text_message"
    },
    {
        "query": "Can you text my brother 'Happy birthday!'?",
        "correct_tool": "write_text_message"
    },
    # --- create_note ---
    {
        "query": "Note to self: buy milk.",
        "correct_tool": "create_note"
    },
    {
        "query": "Remember this: the new inference engine for the react-native app is a priority.",
        "correct_tool": "create_note"
    },
    {
        "query": "I need to make a note about the meeting... just write down 'Follow up with marketing'.",
        "correct_tool": "create_note"
    },
    {
        "query": "Jot this down: need to research more apps for the Cactus library.",
        "correct_tool": "create_note"
    },
    {
        "query": "Create a new note titled 'Gift Ideas' with 'book for mom' in it.",
        "correct_tool": "create_note"
    },
    # --- no tool expected ---
    {
        "query": "What time is it?",
        "correct_tool": None
    },
    {
        "query": "Thanks, that's perfect.",
        "correct_tool": None
    },
    {
        "query": "How do I set an alarm?",
        "correct_tool": None
    },
    {
        "query": "Who was the first person on the moon?",
        "correct_tool": None
    },
    {
        "query": "How old is the Eiffel Tower?",
        "correct_tool": None
    },
    {
        "query": "What's the alarm for?",
        "correct_tool": None
    },
    {
        "query": "Tell me a joke.",
        "correct_tool": None
    }
]

In [None]:
# Models to benchmark (slugs are resolved relative to CactusModel.WEIGHTS_PATH)
model_slugs=[
    'Qwen3-0.6B_tool_calling_lora',
    # 'qwen3-0.6b',
    # 'lfm2-1.2B-Tool',
    # 'lfm2-1.2B',
    # 'lfm2-350m',
]

# One record per (model, filter_tools, sample)
results = []

for model_slug in model_slugs:

    # `clm` must stay a module-level name: calculate_cosine_similarity
    # computes its embeddings through it.
    clm = CactusChatModel(
        slug=model_slug,
        prompt="You are a helpful assistant. You have access to a list of tools. You are communicating via one-shot interactions. If using a tool/function, just call it without asking follow-up questions."
    )

    # Run every sample twice: with similarity-based tool filtering and without
    for filter_tools in [True, False]:
        for sample in eval_data:
            query = sample['query']
            correct_tool = sample['correct_tool']

            # Each sample is an independent one-shot conversation
            clm.reset_history()

            clm.send_message(
                message=query,
                tools=tools,
                filter_tools=filter_tools,
                top_n_tools=3
            )
            try:
                # History entries that carry tool calls. `m and ...` tolerates
                # None entries so such samples are still scored rather than
                # dropped by the except below.
                function_calls = [m for m in clm.message_history_raw if m and m.get('function_calls')]
                if function_calls:
                    # Same name lookup as send_message: 'name', else 'function'
                    tools_called = [
                        fc.get('name') or fc.get('function')
                        for fc in function_calls[-1].get('function_calls', [])
                    ]
                else:
                    tools_called = []

                # Correct when no tool was expected and none was called, or
                # when the expected tool is among the called ones
                correct_tool_called = (correct_tool is None and tools_called == []) or (correct_tool in tools_called)

                results.append({
                    "query": query,
                    "model": model_slug,
                    "filter_tools": filter_tools,
                    "correct_tool": correct_tool,
                    "tools_called": tools_called,
                    "correct_tool_called": correct_tool_called,
                    "message_history": clm.message_history_raw
                })
            except Exception as e:
                print(e)

Filtered tools down to ['create_note', 'reminder_absolute', 'set_timer'] for Who was the first person on the moon?


In [None]:
# One row per benchmark record collected in `results`
df = pd.DataFrame(data=results)

In [None]:
# Show the samples labelled with one expected tool
# (swap in 'weather_lookup' to inspect that label instead)
df.loc[df['correct_tool'] == 'write_text_message']

In [None]:
# Accuracy per expected tool, with and without tool filtering.
# groupby drops NA keys by default, which would silently exclude every
# "no tool expected" sample (correct_tool is None); label them explicitly.
(
    df.assign(correct_tool=df['correct_tool'].fillna('no_tool'))
    .groupby(['filter_tools', 'correct_tool'])['correct_tool_called']
    .mean()
    .reset_index()
    .pivot(
        columns='correct_tool',
        index='filter_tools',
        values='correct_tool_called'
    )
)

In [None]:
# Overall accuracy with and without tool filtering
df.groupby(['filter_tools'])['correct_tool_called'].mean()

In [None]:
# Accuracy per model (columns) with and without tool filtering (rows)
accuracy_by_filter_and_model = df.groupby(['filter_tools', 'model'])['correct_tool_called'].mean()
accuracy_by_filter_and_model.reset_index().pivot(
    index='filter_tools',
    columns='model',
    values='correct_tool_called'
)

In [None]:
# Print the first model response of every failed sample for one model.
# Index 1 is used when several slugs are active; with a single active slug
# this falls back to it instead of raising IndexError.
inspected_model = model_slugs[min(1, len(model_slugs) - 1)]

response_df = df[
    (df.model == inspected_model) &
    (~df.correct_tool_called.astype(bool))
].message_history.apply(
    # History layout is [system, user, response, ...]: entry 1 is the user
    # message, so look for the first entry that actually has a 'response'.
    lambda history: next(
        (m['response'] for m in history if isinstance(m, dict) and 'response' in m),
        None
    )
)

for resp in response_df.dropna().tolist():
    print("----")
    print(resp)