In [27]:
# !pip install openai pydantic python-dotenv

In [None]:
from __future__ import annotations

from functools import update_wrapper
from typing import Callable, Any, Optional, Dict, TypeVar, overload, Union, cast
from pydantic import create_model
from openai.types.chat import ChatCompletionToolUnionParam
import inspect

def build_tool_schema(
    func: Callable[..., Any], 
    name: str,
    description: str = ""
) -> ChatCompletionToolUnionParam:
    signature = inspect.signature(func)
    final_description = inspect.getdoc(func) or description

    model_fields = {
        name: (param.annotation, ...)
        for name, param in signature.parameters.items()
    }

    ToolArguments = create_model("ToolArguments", **model_fields) # type: ignore
    raw_schema = ToolArguments.model_json_schema() # type: ignore
    raw_schema.pop("title", None) # type: ignore
    raw_schema["additionalProperties"] = False

    return {
        "type": "function",
        "function": {
            "name": name,
            "description": final_description,
            "strict": True,
            "parameters": raw_schema,
        },
    }

class Tool:
    """
    Wrapper that makes a function into a tool with a schema.
    """
    __tool_wrapped__ = True

    def __init__(
        self,
        func: Optional[Callable[..., Any]] = None,
        *,
        schema: Optional[ChatCompletionToolUnionParam] = None,
    ) -> None:
        if func is None:
            func = self.__call__

        self._func = func

        if schema is None:
            schema = build_tool_schema(self._func, name=str(self.__class__.__name__))
        
        self.schema = schema
        update_wrapper(self, self._func)

    @classmethod
    def from_decorator(
        cls,
        func: Callable[..., Any],
        *,
        schema: ChatCompletionToolUnionParam,
    ) -> "Tool":
        """ 
        Create a Tool instance from a function and a schema.
        This is used internally by the `tool` decorator.
        """
        tool_instance = cls(func, schema=schema)
        return tool_instance

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if inspect.iscoroutinefunction(self._func):
            return self._func(*args, **kwargs)
        return self._func(*args, **kwargs)

    def __repr__(self) -> str:
        return f"<Tool {self._func.__name__}>"
    
    def __name__(self) -> str:
        return self._func.__name__

T = TypeVar("T", bound=Callable[..., Any])

@overload
def tool(func: T) -> Tool: ... # type: ignore

@overload
def tool(
    *,
    description: str = "",
) -> Callable[[T], Tool]: ... # type: ignore

def tool(
    func: Optional[T] = None,
    *,
    description: str = "",
) -> Union[Tool, Callable[[T], Tool]]:
    """
    Decorator to wrap a function into a Tool with OpenAI function-calling schema.
    """
    def decorator(inner_func: T) -> Tool:
        # Inspect signature and build pydantic model for parameters
        signature = inspect.signature(inner_func)
        final_description = inspect.getdoc(inner_func) or description

        model_fields = {
            name: (param.annotation, ...)
            for name, param in signature.parameters.items()
        }
        ToolArguments = create_model("ToolArguments", **model_fields)  # type: ignore
        raw_schema = ToolArguments.model_json_schema() # type: ignore
        assert isinstance(raw_schema, dict), "ToolArguments.model_json_schema() must return a dict"
        tool_arguments: Dict[str, Any] = raw_schema # type: ignore
        tool_arguments.pop("title", None)
        tool_arguments["additionalProperties"] = False

        schema: Any = {
            "type": "function",
            "function": {
                "name": inner_func.__name__,
                "description": final_description,
                "strict": True,
                "parameters": tool_arguments,
            },
        }

        schema = cast(ChatCompletionToolUnionParam, schema)

        return Tool.from_decorator(inner_func, schema=schema)

    # If used without args: @tool
    return decorator if func is None else decorator(func)

In [29]:
from abc import ABC, abstractmethod
from typing import Generator, Callable, Any
from pydantic.dataclasses import dataclass
from openai.types.chat import ChatCompletionMessageParam
from enum import Enum
import json

MessageLike = Union[Dict[str, Any], ChatCompletionMessageParam]

class StreamingEventType(Enum):
    TOKEN = "token"
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    COMPLETED = "completed"

class AgentStreamingEventType(Enum):
    TOKEN = "token"
    ACTION = "action"
    ACTION_RESULT = "action_result"
    COMPLETED = "completed"


@dataclass
class ToolCall:
    name: str
    tool_call_id: str
    arguments: dict[str, Any]

@dataclass
class Context:
    conversation_id: str
    user_query: str
    messages: list[MessageLike]
    tools: list[Callable[..., str]]

@dataclass
class LLMStreamingEvent:
    event: StreamingEventType
    data: str
    tool_call_id: str | None = None
    tool_calls: list[ToolCall] | None = None

@dataclass
class AgentStreamingEvent:
    event: AgentStreamingEventType
    data: str
    action_id: str | None = None

class LLM(ABC):
    @abstractmethod
    def generate_text(self, messages: list[MessageLike], tools: Optional[list[Callable[..., str]]] = None) -> Generator[LLMStreamingEvent, None, None]:
        """Generate text based on the given prompt."""
        pass

class EmbeddingModel(ABC):
    @abstractmethod
    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Generate embedding for the given texts."""
        pass

class ToolHandler:
    def __init__(self, tools: list[Callable[..., str]]) -> None:
        super().__init__()
        self.tool_mappings: dict[str, Callable[..., str]] = {}
        self.add_tools(tools)

    def add_tools(self, tools: list[Callable[..., str]]) -> None:
        """Add a tool to the router."""
        for tool in tools:
            if tool.__name__ in self.tool_mappings:
                continue
            self.tool_mappings[tool.__name__] = tool

    def execute_tool(self, tool_call: ToolCall) -> str:
        """Execute the specified tool call."""
        tool = self.tool_mappings.get(tool_call.name)
        if tool:
            return tool(**tool_call.arguments)
        return "Tool not found."
    
class ToolRouter(ABC):
    def __init__(self, tool_handler: ToolHandler) -> None:
        super().__init__()
        self.tool_handler = tool_handler

    def add_tools(self, tools: list[Callable[..., str]]) -> None:
        """Add tools to the router."""
        self.tool_handler.add_tools(tools)        

    @abstractmethod
    def retrieve(self, query: str) -> list[Callable[..., str]]:
        """Retrieve appropriate tools based on the query."""
        pass

class Agent:
    def __init__(self, llm: LLM, tool_handler: ToolHandler) -> None:
        super().__init__()
        self.llm = llm
        self.tool_handler = tool_handler
        self._tools: list[Callable[..., str]] = []

    def add_tools(self, tools: list[Callable[..., str]]) -> None:
        """Add tools to the agent."""
        self._tools.extend(tools)
        self.tool_handler.add_tools(tools)

    def run(self, messages: list[MessageLike]) -> Generator[LLMStreamingEvent, None, None]:
        """Run the agent with the given prompt."""
        while True:
            print("Agent running with messages:", messages)
            for delta in self.llm.generate_text(messages, tools=self._tools):
                if delta.event == StreamingEventType.TOOL_CALL and delta.tool_calls:
                    yield delta
                    for tool_call in delta.tool_calls:
                        print(f"Agent executing tool: {tool_call.name} with arguments {tool_call.arguments}")
                        messages.append({"role": "assistant", "tool_calls": [{"id": tool_call.tool_call_id, "type": "function", "function": {"name": tool_call.name, "arguments": json.dumps(tool_call.arguments)}}]})
                        tool_result = self.tool_handler.execute_tool(tool_call)
                        yield LLMStreamingEvent(
                            event=StreamingEventType.TOOL_RESULT,
                            tool_call_id=tool_call.tool_call_id,
                            data=tool_result,
                        )
                        messages.append({"role": "tool", "tool_call_id": tool_call.tool_call_id, "content": tool_result})
                elif delta.event == StreamingEventType.COMPLETED:
                    yield LLMStreamingEvent(
                        event=StreamingEventType.COMPLETED,
                        data=delta.data,
                    )
                    return
                yield LLMStreamingEvent(
                    event=StreamingEventType.TOKEN,
                    data=delta.data,
                )

class ConversationRepository(ABC):
    @abstractmethod
    def get_conversation_history(self, conversation_id: str) -> list[MessageLike]:
        """Retrieve the conversation history for the given conversation ID."""
        pass

    @abstractmethod
    def add_messages(self, conversation_id: str, messages: list[MessageLike]) -> None:
        """Add messages to the conversation history."""
        pass

class ContextMiddleware:
    def __init__(
        self, 
        conversation_repository: ConversationRepository,
        tool_router: ToolRouter
    ) -> None:
        super().__init__()
        self.conversation_repository = conversation_repository
        self.tool_router = tool_router
        
    def prepare_context(self, conversation_id: str, query: str) -> Context:
        """Prepare context for the LLM based on conversation ID and query."""
        history = self.conversation_repository.get_conversation_history(conversation_id)
        tools = self.tool_router.retrieve(query)

        history.append({"role": "user", "content": query})
        
        return Context(
            conversation_id=conversation_id,
            user_query=query,
            messages=history,
            tools=tools
        )

class AgentService:
    def __init__(
        self, 
        agent: Agent,
        context_middleware: ContextMiddleware
    ) -> None:
        super().__init__()
        self.agent = agent
        self.context_middleware = context_middleware

    def handle_request(self, conversation_id: str, query: str) -> Generator[AgentStreamingEvent, None, None]:
        """Handle the request by preparing context and generating text."""
        context = self.context_middleware.prepare_context(conversation_id, query)
        self.agent.add_tools(context.tools)
        for delta in self.agent.run(context.messages):
            if delta.event == StreamingEventType.TOOL_CALL and delta.tool_calls:
                for tool_call in delta.tool_calls:
                    yield AgentStreamingEvent(
                        event=AgentStreamingEventType.ACTION,
                        data=f"Calling tool {tool_call.name} with arguments {tool_call.arguments}",
                        action_id=tool_call.tool_call_id or ""
                    )
            
            elif delta.event == StreamingEventType.TOOL_RESULT:
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.ACTION_RESULT,
                    data=delta.data,
                    action_id=delta.tool_call_id or ""
                )

            elif delta.event == StreamingEventType.COMPLETED:
                self.context_middleware.conversation_repository.add_messages(
                    conversation_id,
                    [{"role": "user", "content": query}, {"role": "assistant", "content": delta.data}]
                )
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.COMPLETED,
                    data=delta.data
                )

            else:
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.TOKEN,
                    data=delta.data
                )

In [30]:
import openai
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
import os
from dotenv import load_dotenv
from typing import Sequence

load_dotenv()

class OpenAILLM(LLM):
    def __init__(self, model_name: openai.types.ChatModel, api_key: str | None = os.getenv("OPENAI_API_KEY")) -> None:
        self.model_name = model_name
        self.api_key = api_key 
        self.client = openai.OpenAI(api_key=self.api_key)

    def generate_text(self, messages: list[MessageLike], tools: Optional[list[Callable[..., str]]] = None) -> Generator[LLMStreamingEvent, None, None]:
        openai_tools: Sequence[ChatCompletionToolUnionParam] = []
        for t in tools or []:
            wrapped_tool = tool(t)
            openai_tools.append(wrapped_tool.schema)

        print(f"Generating text with model {self.model_name} and tools: {[openai_tool['function']['name'] for openai_tool in openai_tools]}") # type: ignore
        
        generator = self.client.chat.completions.create(
            model=self.model_name,
            messages=cast(Sequence[ChatCompletionMessageParam], messages), 
            tools=openai_tools if openai_tools else openai.omit,
            stream=True,
        )
        
        full_content = ""
        final_tool_calls: dict[int, ChoiceDeltaToolCall] = {}
        for chunk in generator:
            delta = chunk.choices[0].delta
            
            if delta.content:
                full_content += delta.content
                yield LLMStreamingEvent(
                    event=StreamingEventType.TOKEN,
                    data=delta.content
                )
            
            if delta.tool_calls:
                for tool_call in delta.tool_calls or []:
                    
                    index = tool_call.index

                    if index not in final_tool_calls:
                        final_tool_calls[index] = tool_call

                    final_tool_calls[index].function.arguments += tool_call.function.arguments # type: ignore
            
            if chunk.choices[0].finish_reason == "tool_calls":
                yield LLMStreamingEvent(
                    event=StreamingEventType.TOOL_CALL,
                    data="",
                    tool_calls=[ToolCall(
                        name=tool_call.function.name or "" if tool_call.function else "",
                        tool_call_id=tool_call.id or "",
                        arguments=json.loads(tool_call.function.arguments or "{}" if tool_call.function else "{}")
                    ) for tool_call in final_tool_calls.values()]
                )
            
            if chunk.choices[0].finish_reason == "stop":
                yield LLMStreamingEvent(
                    event=StreamingEventType.COMPLETED,
                    data=full_content
                )

In [31]:
class OpenAIEmbeddingModel(EmbeddingModel):
    def __init__(self, model_name: openai.types.EmbeddingModel, api_key: str | None = os.getenv("OPENAI_API_KEY")) -> None:
        self.model_name = model_name
        self.api_key = api_key 
        self.client = openai.OpenAI(api_key=self.api_key)

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        if not texts:
            return []
        print(f"Generating embeddings given model {self.model_name} for texts: {texts}")
        response = self.client.embeddings.create(
            model=self.model_name,
            input=texts
        )
        embeddings = [data.embedding for data in response.data]
        return embeddings

In [32]:
def cosine_similarity(vec1: list[float], vec2: list[float]) -> float:
    dot_product = sum(a * b for a, b in zip(vec1, vec2))
    magnitude1 = sum(a * a for a in vec1) ** 0.5
    magnitude2 = sum(b * b for b in vec2) ** 0.5
    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0
    return dot_product / (magnitude1 * magnitude2)

def elbow_method(similarities: list[float]) -> float:
    if not similarities:
        return 0.0

    n = len(similarities)
    if n < 3:
        # if there are less than 3 points, return the minimum similarity
        return min(similarities)

    # Normalize indices to [0, 1] range for better distance calculation
    points = [(i / (n - 1), sim) for i, sim in enumerate(similarities)]
    start = points[0]
    end = points[-1]

    max_distance = -1.0
    elbow_index = 0

    # Calculate perpendicular distance from each point to the line connecting start and end
    for i in range(1, n - 1):
        point = points[i]
        
        # Vector from start to end
        line_vec_x = end[0] - start[0]
        line_vec_y = end[1] - start[1]
        
        # Vector from start to point
        point_vec_x = point[0] - start[0]
        point_vec_y = point[1] - start[1]
        
        # Calculate perpendicular distance using cross product formula
        line_length_sq = line_vec_x ** 2 + line_vec_y ** 2
        
        if line_length_sq == 0:
            distance = 0.0
        else:
            # Perpendicular distance = |cross product| / |line length|
            cross_product = abs(line_vec_x * point_vec_y - line_vec_y * point_vec_x)
            distance = cross_product / (line_length_sq ** 0.5)

        if distance > max_distance:
            max_distance = distance
            elbow_index = i

    return similarities[elbow_index]

class EmbeddingBasedToolRouter(ToolRouter):
    def __init__(self, tool_handler: ToolHandler, embedding_model: EmbeddingModel, threshold: float = 0.0) -> None:
        super().__init__(tool_handler=tool_handler)
        self.embedding_model = embedding_model
        self.tool_embeddings: dict[str, list[float]] = {}
        self.threshold = threshold
        self.add_tools(list(tool_handler.tool_mappings.values()))
        
    def add_tools(self, tools: list[Callable[..., str]]) -> None:
        """Add a tool to the router and compute its embedding."""
        super().add_tools(tools)
        tool_descriptions = [f"{tool.__name__}: {tool.__doc__ or ''}" for tool in tools]
        tool_embeddings = self.embedding_model.embed_texts(tool_descriptions)
        for tool, tool_embedding in zip(tools, tool_embeddings):
            name = tool.__name__
            self.tool_embeddings[name] = tool_embedding
        print("Added tools with embeddings:", list(self.tool_embeddings.keys()))
        
    def retrieve(self, query: str) -> list[Callable[..., str]]:
        query_embedding = self.embedding_model.embed_texts([query])[0]

        similarities: list[tuple[float, str]] = []
        for tool_name, tool_embedding in self.tool_embeddings.items():
            similarity = cosine_similarity(query_embedding, tool_embedding)
            similarities.append((similarity, tool_name))

        similarities.sort(reverse=True)

        print("Tool similarities:", similarities)

        elbow_threshold = elbow_method([sim for sim, _ in similarities])
        final_threshold = max(self.threshold, elbow_threshold)

        top_tools = [self.tool_handler.tool_mappings[tool_name] for similarity, tool_name in similarities if similarity >= final_threshold]

        print("Retrieved tools:", [tool.__name__ for tool in top_tools])
        
        return top_tools

In [33]:
class InMemoryConversationRepository(ConversationRepository):
    def __init__(self) -> None:
        self.conversations: dict[str, list[MessageLike]] = {}

    def get_conversation_history(self, conversation_id: str) -> list[MessageLike]:
        return self.conversations.get(conversation_id, [])

    def add_messages(self, conversation_id: str, messages: list[MessageLike]) -> None:
        if conversation_id not in self.conversations:
            self.conversations[conversation_id] = []
        self.conversations[conversation_id].extend(messages)

In [34]:

class ToolKit:
    def registry(self) -> list[Callable[..., str]]:
        """Return a list of all available tools."""
        return [
            self.get_weather,
            self.get_user_profile,
            self.create_reminder,
            self.get_news,
            self.code_interpreter,
            self.get_joke,
            self.add,
            self.subtract,
            self.multiply,
            self.divide,
        ]
    
    def get_weather(self, location: str) -> str:
        """Get the current weather for a given location."""
        # Dummy implementation for illustration
        return f"The current weather in {location} is sunny with a temperature of 25Â°C."

    def get_user_profile(self, user_id: str) -> str:
        """Get the user profile information for a given user ID."""
        # Dummy implementation for illustration
        return f"User {user_id} is a 30-year-old software developer from San Francisco."

    def create_reminder(self, user_id: str, reminder_text: str, time: str) -> str:
        """Create a reminder for the user."""
        # Dummy implementation for illustration
        return f"Reminder for user {user_id}: '{reminder_text}' at {time} has been created."

    def get_news(self, topic: str) -> str:
        """Get the latest news on a given topic."""
        # Dummy implementation for illustration
        return f"The latest news on {topic} is that everything is going great!"

    def code_interpreter(self, code: str) -> str:
        """Execute the given code and return the output."""
        try:
            # WARNING: Using eval/exec can be dangerous. This is just for illustration.
            local_vars = {}
            exec(code, {}, local_vars)
            return str(local_vars.get('result', 'No result variable defined.'))
        except Exception as e:
            return f"Error executing code: {e}"
        
    def get_joke(self) -> str:
        """Get a random joke."""
        # Dummy implementation for illustration
        return "Why did the scarecrow win an award? Because he was outstanding in his field!"

    def add(self, a: float, b: float) -> str:
        """Add two numbers."""
        return str(a + b)

    def subtract(self, a: float, b: float) -> str:
        """Subtract two numbers."""
        return str(a - b)

    def multiply(self, a: float, b: float) -> str:
        """Multiply two numbers."""
        return str(a * b)

    def divide(self, a: float, b: float) -> str:
        """Divide two numbers."""
        if b == 0:
            return "Error: Division by zero."
        return str(a / b)

In [35]:
embedding_model = OpenAIEmbeddingModel(model_name="text-embedding-3-small")

tool_kit = ToolKit()

tool_handler = ToolHandler(tools=tool_kit.registry())

tool_router = EmbeddingBasedToolRouter(tool_handler=tool_handler, embedding_model=embedding_model)

llm = OpenAILLM(model_name="gpt-4.1-mini")

agent = Agent(llm=llm, tool_handler=tool_handler)

conversation_repository = InMemoryConversationRepository()

context_middleware = ContextMiddleware(
    conversation_repository=conversation_repository,
    tool_router=tool_router
)

agent_service = AgentService(
    agent=agent,
    context_middleware=context_middleware
)

Generating embeddings given model text-embedding-3-small for texts: ['get_weather: Get the current weather for a given location.', 'get_user_profile: Get the user profile information for a given user ID.', 'create_reminder: Create a reminder for the user.', 'get_news: Get the latest news on a given topic.', 'code_interpreter: Execute the given code and return the output.', 'get_joke: Get a random joke.', 'add: Add two numbers.', 'subtract: Subtract two numbers.', 'multiply: Multiply two numbers.', 'divide: Divide two numbers.']
Added tools with embeddings: ['get_weather', 'get_user_profile', 'create_reminder', 'get_news', 'code_interpreter', 'get_joke', 'add', 'subtract', 'multiply', 'divide']


In [36]:
import timeit

start_time = timeit.default_timer()

gen = agent_service.handle_request(
    conversation_id="conv1",
    query="What's 13 + 24 * 2 - 5 / (1 + 1)?"
)

first_token = True
for event in gen:
    if first_token:
        elapsed = timeit.default_timer() - start_time
        print(f"\n[First token received in {elapsed:.2f} seconds]\n")
        first_token = False
        
    if event.event == AgentStreamingEventType.TOKEN:
        print(event.data, end="", flush=True)
    elif event.event == AgentStreamingEventType.ACTION and event.action_id:
        print(f"\n[Action] {event.data}\n")
    elif event.event == AgentStreamingEventType.ACTION_RESULT:
        print(f"\n[Action Result] {event.data}\n")
    elif event.event == AgentStreamingEventType.COMPLETED:
        print("\n[Completed]")

Generating embeddings given model text-embedding-3-small for texts: ["What's 13 + 24 * 2 - 5 / (1 + 1)?"]
Tool similarities: [(0.41344234278303993, 'multiply'), (0.3905658599540017, 'subtract'), (0.37787729461485337, 'add'), (0.37557557448783335, 'divide'), (0.19657426433152653, 'code_interpreter'), (0.14942267334605042, 'get_joke'), (0.11217535548130472, 'create_reminder'), (0.05213801635570409, 'get_weather'), (0.04245934768334578, 'get_user_profile'), (0.03898759808035919, 'get_news')]
Retrieved tools: ['multiply', 'subtract', 'add', 'divide']
Agent running with messages: [{'content': "What's 13 + 24 * 2 - 5 / (1 + 1)?", 'role': 'user'}]
Generating text with model gpt-4.1-mini and tools: ['multiply', 'subtract', 'add', 'divide']

[First token received in 3.45 seconds]


[Action] Calling tool multiply with arguments {'a': 24, 'b': 2}


[Action] Calling tool add with arguments {'a': 1, 'b': 1}

Agent executing tool: multiply with arguments {'a': 24, 'b': 2}

[Action Result] 48

Agent 