In [None]:
from __future__ import annotations

from functools import update_wrapper
from typing import Callable, Any, Optional, Dict, TypeVar, overload, Union, cast, Literal
from pydantic import create_model
from openai.types.responses.tool_param import ToolParam
from openai.types.chat import ChatCompletionToolUnionParam
import inspect
from enum import Enum

class OpenAIAPIProtocol(str, Enum):
    """Selects which OpenAI tool-schema layout `build_tool_schema` produces.

    Members are also `str` instances, so they compare equal to their raw values.
    """
    # Nested layout: {"type": "function", "function": {name, description, ...}}.
    CHAT_COMPLETIONS = "chat_completions"
    # Flat layout: name/description/parameters sit next to "type".
    RESPONSES = "responses"

@overload
def build_tool_schema(
    func: Callable[..., Any],
    name: str,
    api: Literal[OpenAIAPIProtocol.RESPONSES],
    description: str = "",
) -> ToolParam: ... # type: ignore

@overload
def build_tool_schema(
    func: Callable[..., Any],
    name: str,
    api: Literal[OpenAIAPIProtocol.CHAT_COMPLETIONS],
    description: str = "",
) -> ChatCompletionToolUnionParam: ... # type: ignore

def build_tool_schema(
    func: Callable[..., Any],
    name: str,
    api: OpenAIAPIProtocol,
    description: str = "",
) -> Union[ToolParam, ChatCompletionToolUnionParam]:
    """
    Build an OpenAI function-tool schema from a callable's signature.

    Every parameter of `func` becomes a required property of the JSON schema
    (generated through a throwaway pydantic model), and the schema is marked
    `strict` with `additionalProperties` disabled.

    Args:
        func: Callable whose signature describes the tool arguments.
        name: Tool name to advertise.
        api: Target API; `RESPONSES` yields the flat layout, anything else the
            nested Chat Completions layout.
        description: Fallback description, used only when `func` has no
            docstring (the docstring takes precedence).

    Returns:
        A dict shaped for the selected API, cast to the matching param type.
    """
    signature = inspect.signature(func)
    final_description = inspect.getdoc(func) or description

    # Parameters without an annotation report `inspect.Parameter.empty`, which
    # is not a usable field type for pydantic; fall back to `Any` for those.
    model_fields = {
        param_name: (
            Any if param.annotation is inspect.Parameter.empty else param.annotation,
            ...,
        )
        for param_name, param in signature.parameters.items()
    }

    ToolArguments = create_model("ToolArguments", **model_fields) # type: ignore
    raw_schema = ToolArguments.model_json_schema() # type: ignore
    raw_schema = cast(Dict[str, Any], raw_schema)
    # Drop the model-level title ("ToolArguments"); it is an implementation detail.
    raw_schema.pop("title", None) # type: ignore
    raw_schema["additionalProperties"] = False

    if api == OpenAIAPIProtocol.RESPONSES:
        responses_schema: dict[str, Any] = {
            "type": "function",
            "name": name,
            "description": final_description,
            "parameters": raw_schema,
            "strict": True,
        }
        return cast(ToolParam, responses_schema)

    chat_schema: dict[str, Any] = {
        "type": "function",
        "function": {
            "name": name,
            "description": final_description,
            "strict": True,
            "parameters": raw_schema,
        },
    }
    return cast(ChatCompletionToolUnionParam, chat_schema)

class Tool:
    """
    Callable wrapper that pairs a function with its OpenAI tool schema.

    A `Tool` can be built around an existing function, or subclassed with an
    overridden `__call__` (in which case no function needs to be passed).
    Metadata such as `__name__` and `__doc__` is copied from the wrapped
    callable onto the instance via `functools.update_wrapper`.
    """
    __tool_wrapped__ = True

    def __init__(
        self,
        func: Optional[Callable[..., Any]] = None,
        *,
        schema: Optional[Union[ToolParam, ChatCompletionToolUnionParam]] = None,
        api: OpenAIAPIProtocol = OpenAIAPIProtocol.CHAT_COMPLETIONS,
    ) -> None:
        """
        Args:
            func: Callable to wrap. When omitted, the instance's own `__call__`
                is used, which is only meaningful for subclasses overriding it.
            schema: Pre-built tool schema. When omitted, one is generated from
                the callable's signature and named after the class.
            api: Schema layout to generate when `schema` is omitted.
        """
        if func is None:
            func = self.__call__

        self._func = func

        if schema is None:
            schema = build_tool_schema(self._func, name=str(self.__class__.__name__), api=api)

        self.schema = schema
        update_wrapper(self, self._func)

        # `update_wrapper` only sets `__name__` when the wrapped callable has
        # one (e.g. `functools.partial` objects do not). Guarantee a string so
        # code that registers tools by `tool.__name__` keeps working.
        if not isinstance(vars(self).get("__name__"), str):
            self.__name__ = type(self).__name__

    @classmethod
    def from_decorator(
        cls,
        func: Callable[..., Any],
        *,
        schema: Union[ToolParam, ChatCompletionToolUnionParam],
    ) -> "Tool":
        """
        Create a Tool instance from a function and a schema.
        This is used internally by the `tool` decorator.
        """
        return cls(func, schema=schema)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """
        Invoke the wrapped callable and return whatever it returns.

        For a coroutine function this is the (un-awaited) coroutine object;
        awaiting it is the caller's responsibility.

        Raises:
            NotImplementedError: If no function was supplied and `__call__`
                was not overridden, i.e. the tool would only call itself.
        """
        func = self._func
        # Without this guard a bare `Tool()` recurses until RecursionError.
        if getattr(func, "__self__", None) is self and getattr(func, "__func__", None) is Tool.__call__:
            raise NotImplementedError(
                f"{type(self).__name__} wraps no function and does not override __call__."
            )
        return func(*args, **kwargs)

    def __repr__(self) -> str:
        wrapped_name = getattr(self._func, "__name__", type(self).__name__)
        return f"<Tool {wrapped_name}>"

T = TypeVar("T", bound=Callable[..., Any])

@overload
def tool(
    func: Callable[..., Any],
    *,
    description: str = "",
    api: OpenAIAPIProtocol = OpenAIAPIProtocol.CHAT_COMPLETIONS,
) -> Tool: ... # type: ignore

@overload
def tool(
    *,
    description: str = "",
    api: OpenAIAPIProtocol = OpenAIAPIProtocol.CHAT_COMPLETIONS,
) -> Callable[[Callable[..., Any]], Tool]: ... # type: ignore

def tool(
    func: Optional[T] = None,
    *,
    description: str = "",
    api: OpenAIAPIProtocol = OpenAIAPIProtocol.CHAT_COMPLETIONS,
) -> Union[Tool, Callable[[T], Tool]]:
    """
    Turn a function into a `Tool` carrying an OpenAI function-calling schema.

    Works both as a bare decorator (`@tool`) / direct call (`tool(fn)`) and as
    a configured decorator (`@tool(description=..., api=...)`).

    Args:
        func: Function to wrap; omitted when used with arguments.
        description: Fallback tool description passed to `build_tool_schema`.
        api: Schema layout to generate.

    Returns:
        The `Tool` when `func` is given, otherwise a decorator producing one.
    """
    def decorator(target: T) -> Tool:
        # The schema is named after the wrapped function itself.
        generated = build_tool_schema(
            target,
            name=target.__name__,
            api=api,
            description=description,
        )
        return Tool.from_decorator(target, schema=generated)

    if func is None:
        # Called with keyword arguments only: hand back the decorator.
        return decorator
    # Called directly on a function (bare `@tool`): wrap it right away.
    return decorator(func)

In [None]:
from abc import ABC, abstractmethod
from typing import Generator, Callable, Any, Generic, TypedDict, Required
from pydantic.dataclasses import dataclass
from openai.types.responses import ResponseInputItemParam, ResponseFunctionToolCallParam
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from openai.types.responses.response_input_item_param import FunctionCallOutput
from enum import Enum
import json

# Conversation items are represented as OpenAI Responses API input-item params.
MessageLike = ResponseInputItemParam

class StreamingEventType(Enum):
    """Event kinds carried by `LLMStreamingEvent.event`."""
    REASONING_TOKEN = "reasoning_token"  # incremental piece of reasoning-summary text
    REASONING = "reasoning"              # complete reasoning-summary text
    TOKEN = "token"                      # incremental piece of the answer text
    TOOL_CALL = "tool_call"              # the model requested one or more tool calls
    TOOL_RESULT = "tool_result"          # output of an executed tool call
    COMPLETED = "completed"              # final answer text; ends the run

class AgentStreamingEventType(Enum):
    """Event kinds carried by `AgentStreamingEvent.event`."""
    REASONING = "reasoning"
    TOKEN = "token"
    ACTION = "action"                # a tool call is being made
    ACTION_RESULT = "action_result"  # output of a tool call
    COMPLETED = "completed"          # final answer text


@dataclass
class ToolCall:
    """A single tool invocation requested by the model."""
    # Name of the tool; ToolHandler looks the callable up by this name.
    name: str
    # Identifier of the call; the tool's output is reported back under the same id.
    tool_call_id: str
    # Keyword arguments, passed to the tool as `tool(**arguments)`.
    arguments: dict[str, Any]

@dataclass
class Context:
    """Per-request context, as assembled by `ContextMiddleware.process`."""
    conversation_id: str
    user_query: str
    # Conversation history followed by the new user message.
    messages: list[MessageLike]
    # Tools the tool router retrieved for this query.
    tools: set[Callable[..., str]]

@dataclass
class LLMStreamingEvent:
    """One event in the stream produced by an `LLM` or by `Agent.run`."""
    event: StreamingEventType
    # Text payload: a delta, a full text, or a tool result depending on `event`.
    data: str
    # Set on TOOL_RESULT events to the id of the call that produced `data`.
    tool_call_id: str | None = None
    # Set on TOOL_CALL events to the calls the model requested.
    tool_calls: list[ToolCall] | None = None

@dataclass
class AgentStreamingEvent:
    """One event in the stream produced by `AgentService.handle_request`."""
    event: AgentStreamingEventType
    data: str
    # Tool call id for ACTION / ACTION_RESULT events; left as None otherwise.
    action_id: str | None = None

class LLM(ABC):
    """Interface for language-model backends that stream their output."""

    @abstractmethod
    def generate_text(
        self,
        messages: list[MessageLike],
        tools: Optional[set[Callable[..., str]]] = None,
    ) -> Generator[LLMStreamingEvent, None, None]:
        """Stream `LLMStreamingEvent`s for the conversation in `messages`.

        `tools` optionally lists the callables the model may ask to invoke.
        """

class EmbeddingModel(ABC):
    """Interface for models that turn texts into embedding vectors."""

    @abstractmethod
    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding (a list of floats) per entry of `texts`."""

TIn = TypeVar("TIn")
TOut = TypeVar("TOut")
TNext = TypeVar("TNext")

class Handler(ABC, Generic[TIn, TOut]):
    """A typed processing stage that maps a `TIn` value to a `TOut` value."""

    @abstractmethod
    def process(self, input: TIn) -> TOut:
        """Transform `input`; supplied by concrete handlers."""

    def handle(self, input: TIn) -> TOut:
        """Run this handler on `input` by delegating to `process`."""
        outcome = self.process(input)
        return outcome

    def then(self, next_handler: Handler[TOut, TNext]) -> Chain[TIn, TNext]:
        """Return a `Chain` that feeds this handler's output into `next_handler`."""
        return Chain(first=self, second=next_handler)

class Chain(Handler[TIn, TOut]):
    """Two handlers linked back to back and exposed as one handler."""

    def __init__(self, first: Handler[TIn, TNext], second: Handler[TNext, TOut]):
        self.first, self.second = first, second

    def process(self, input: TIn) -> TOut:
        """Pass `input` through `first`, then its result through `second`."""
        return self.second.handle(self.first.handle(input))

T = TypeVar("T")

class Command(ABC, Generic[T]):
    """A side-effecting action executed on a `T` value; returns nothing."""

    @abstractmethod
    def exec(self, input: T) -> None:
        """Perform the action for `input`."""

class ToolHandler(Handler[ToolCall, str]):
    """
    Registry of tool callables, keyed by `__name__`, that executes `ToolCall`s.

    Failures are reported as strings rather than raised, so the result can be
    handed back to the model as the tool's output.
    """

    def __init__(
        self,
        tools: set[Callable[..., str]],
        on_error: Optional[Callable[[ToolCall, Exception], str]] = None,
    ) -> None:
        """
        Args:
            tools: Initial tools to register.
            on_error: Optional callback turning a failed call and its exception
                into the string returned to the caller.
        """
        self.tool_mappings: dict[str, Callable[..., str]] = {}
        self.add_tools(tools)
        self.on_error = on_error

    @property
    def tools(self) -> set[Callable[..., str]]:
        """The registered callables."""
        return set(self.tool_mappings.values())

    @property
    def tool_names(self) -> set[str]:
        """The names under which tools are registered."""
        return set(self.tool_mappings.keys())

    def clear_tools(self) -> None:
        """Remove all tools from the handler."""
        self.tool_mappings = {}

    def get_tool(self, name: str) -> Optional[Callable[..., str]]:
        """Return the tool registered as `name`, or None if there is none."""
        return self.tool_mappings.get(name)

    def add_tools(self, tools: set[Callable[..., str]]) -> None:
        """Add tools to the handler; names already registered are left untouched."""
        for tool in tools:
            if tool.__name__ in self.tool_mappings:
                continue
            self.tool_mappings[tool.__name__] = tool

    def set_tools(self, tools: set[Callable[..., str]]) -> None:
        """Set the tools for the handler, replacing any existing tools."""
        self.clear_tools()
        self.add_tools(tools)

    def update_tool(self, tool: Callable[..., str]) -> None:
        """Update or add a single tool in the handler."""
        self.tool_mappings[tool.__name__] = tool

    def remove_tool(self, name: str) -> None:
        """Remove a tool by name from the handler; unknown names are ignored."""
        self.tool_mappings.pop(name, None)

    def process(self, input: ToolCall) -> str:
        """Handler entry point; equivalent to `execute_tool(input)`."""
        return self.execute_tool(input)

    def execute_tool(self, tool_call: ToolCall) -> str:
        """
        Run the tool named by `tool_call` with its arguments as keywords.

        Returns:
            The tool's return value; an error message if the tool is unknown;
            or, if the tool raises, `on_error(tool_call, exc)` when configured
            and a generic error message otherwise.
        """
        tool = self.tool_mappings.get(tool_call.name)
        # Compare against None explicitly: a registered callable object may
        # itself be falsy (e.g. it defines __len__ or __bool__).
        if tool is None:
            return f"Error: Tool '{tool_call.name}' not found."

        try:
            return tool(**tool_call.arguments)
        except Exception as e:
            # Deliberately broad: any tool failure becomes a string result.
            if self.on_error:
                return self.on_error(tool_call, e)
            return f"Error executing tool '{tool_call.name}': {str(e)}"
    
class ToolRouter(ToolHandler):
    """A `ToolHandler` that can additionally pick tools for a given query."""

    @abstractmethod
    def retrieve(self, query: str) -> set[Callable[..., str]]:
        """Return the tools considered appropriate for `query`."""

# Values accepted for the `tool_use_behavior` option of `Agent`.
ToolUseBehavior = Literal[
    "stop_on_tool_call", 
    "stop_on_tool_result", 
    "auto"
]

# Alias marking strings that are tool names (keys of Agent's per-tool hooks).
ToolName = str

class Agent:
    """
    Drives an LLM/tool loop: stream the model, run requested tools, feed the
    results back, and repeat until the model completes or a stop rule applies.
    """

    # `tool_use_behavior` values that end the run right after the tool calls
    # were recorded / right after they were executed. The first entry of each
    # tuple is the value declared in `ToolUseBehavior`; the second is the
    # spelling this class checked for previously, kept for compatibility.
    _STOP_BEFORE_EXECUTION = ("stop_on_tool_call", "stop_before")
    _STOP_AFTER_EXECUTION = ("stop_on_tool_result", "stop_after")

    def __init__(
        self, 
        llm: LLM, 
        tool_handler: ToolHandler,
        tool_use_behavior: Optional[ToolUseBehavior] = "auto",
        on_tool_call: Optional[dict[ToolName, Command[ToolCall]]] = None,
        on_tool_result: Optional[dict[ToolName, Command[tuple[ToolCall, str]]]] = None,
    ) -> None:
        """
        Args:
            llm: Model backend used to generate each turn.
            tool_handler: Registry that supplies and executes the tools.
            tool_use_behavior: "stop_on_tool_call" ends the run once tool calls
                are recorded, "stop_on_tool_result" ends it once they have been
                executed; any other value keeps looping until completion.
            on_tool_call: Per-tool commands run when that tool is requested.
            on_tool_result: Per-tool commands run with `(tool_call, result)`
                after that tool was executed.
        """
        super().__init__()
        self.llm = llm
        self.tool_handler = tool_handler
        self.tool_use_behavior = tool_use_behavior

        self.on_tool_call = on_tool_call
        self.on_tool_result = on_tool_result

    def add_tools(self, tools: set[Callable[..., str]]) -> None:
        """Add tools to the agent."""
        self.tool_handler.add_tools(tools)

    def run(self, messages: list[MessageLike]) -> Generator[LLMStreamingEvent, None, None]:
        """
        Run the agent on `messages`, yielding events as they happen.

        `messages` is extended in place with the reasoning, tool-call,
        tool-output and assistant items produced during the run, so the caller
        can reuse the same list for the next turn.

        The generator ends when the model completes, when `tool_use_behavior`
        requests a stop, or when a model stream finishes without either a
        completion or a tool call (nothing new could be sent to the model, so
        asking again would loop forever).
        """
        while True:
            tools_were_called = False

            for delta in self.llm.generate_text(messages, tools=self.tool_handler.tools):
                if delta.event == StreamingEventType.REASONING_TOKEN:
                    yield delta
                elif delta.event == StreamingEventType.REASONING:
                    yield delta
                    messages.append( 
                        ResponseReasoningItemParam( # type: ignore
                            type="reasoning",
                            summary=[Summary(
                                type="summary_text",
                                text=delta.data
                            )],
                        )
                    )
                elif delta.event == StreamingEventType.TOOL_CALL and delta.tool_calls:
                    tools_were_called = True
                    yield delta

                    # Record every requested call before executing any of them.
                    for tool_call in delta.tool_calls:
                        messages.append(
                            ResponseFunctionToolCallParam(
                                type="function_call",
                                name=tool_call.name,
                                arguments=json.dumps(tool_call.arguments),
                                call_id=tool_call.tool_call_id,
                            )
                        )

                        if self.on_tool_call:
                            command = self.on_tool_call.get(tool_call.name)
                            if command:
                                command.exec(tool_call)

                    if self.tool_use_behavior in self._STOP_BEFORE_EXECUTION:
                        return

                    for tool_call in delta.tool_calls:
                        tool_result = self.tool_handler.execute_tool(tool_call)
                        yield LLMStreamingEvent(
                            event=StreamingEventType.TOOL_RESULT,
                            tool_call_id=tool_call.tool_call_id,
                            data=tool_result,
                        )
                        messages.append(
                            FunctionCallOutput(
                                type="function_call_output",
                                output=tool_result,
                                call_id=tool_call.tool_call_id,
                            )
                        )

                        if self.on_tool_result:
                            command = self.on_tool_result.get(tool_call.name)
                            if command:
                                command.exec((tool_call, tool_result))

                    if self.tool_use_behavior in self._STOP_AFTER_EXECUTION:
                        return
                elif delta.event == StreamingEventType.COMPLETED:
                    yield LLMStreamingEvent(
                        event=StreamingEventType.COMPLETED,
                        data=delta.data,
                    )
                    messages.append(
                        ResponseOutputMessageParam( # type: ignore
                            role="assistant",
                            content=delta.data
                        )
                    )
                    return
                elif delta.event == StreamingEventType.TOKEN:
                    yield LLMStreamingEvent(
                        event=StreamingEventType.TOKEN,
                        data=delta.data,
                    )

            if not tools_were_called:
                # The stream ended without a completion or a tool call.
                return
                
class ConversationRepository(ABC):
    """Storage interface for per-conversation message histories."""

    @abstractmethod
    def get_conversation_history(self, conversation_id: str) -> list[MessageLike]:
        """Return the messages stored for `conversation_id`."""

    @abstractmethod
    def add_messages(self, conversation_id: str, messages: list[MessageLike]) -> None:
        """Append `messages` to the history of `conversation_id`."""

# Input of ContextMiddleware.process: which conversation to load and the new
# user query. Both keys are mandatory.
class ContextPrepareParam(TypedDict):
    conversation_id: Required[str]
    query: Required[str]

class ContextMiddleware(Handler[ContextPrepareParam, Context]):
    """Builds the `Context` for a request from stored history and routed tools."""

    def __init__(
        self, 
        conversation_repository: ConversationRepository,
        tool_router: ToolRouter
    ) -> None:
        """
        Args:
            conversation_repository: Source of the stored conversation history.
            tool_router: Selects the tools relevant to the incoming query.
        """
        super().__init__()
        self.conversation_repository = conversation_repository
        self.tool_router = tool_router

    def process(self, input: ContextPrepareParam) -> Context:
        """
        Prepare context for the LLM based on conversation ID and query.

        Returns:
            A `Context` whose `messages` are the stored history followed by the
            new user message, and whose `tools` come from the tool router.
        """
        history = self.conversation_repository.get_conversation_history(input["conversation_id"])
        tools = self.tool_router.retrieve(input["query"])

        # Build a new list instead of appending to `history`: the repository may
        # hand out its own internal list, and the messages added here (and later
        # by the agent) must not leak into stored state behind its back.
        messages: list[MessageLike] = [*history, {"role": "user", "content": input["query"]}]

        return Context(
            conversation_id=input["conversation_id"],
            user_query=input["query"],
            messages=messages,
            tools=tools
        )

class AgentService:
    """Request-level facade: prepares context, runs the agent, and translates
    its `LLMStreamingEvent`s into `AgentStreamingEvent`s."""

    def __init__(
        self, 
        agent: Agent,
        context_middleware: ContextMiddleware
    ) -> None:
        """
        Args:
            agent: Agent that executes the LLM/tool loop.
            context_middleware: Builds the per-request `Context`; its
                repository is also used to persist the finished exchange.
        """
        super().__init__()
        self.agent = agent
        self.context_middleware = context_middleware

    def handle_request(self, conversation_id: str, query: str) -> Generator[AgentStreamingEvent, None, None]:
        """
        Handle the request by preparing context and generating text.

        Event translation:
            TOOL_CALL        -> one ACTION event per requested tool call
            TOOL_RESULT      -> ACTION_RESULT
            COMPLETED        -> COMPLETED (the user query and the final answer
                                are stored in the conversation repository first)
            REASONING_TOKEN,
            REASONING        -> REASONING (incremental and consolidated
                                reasoning text are both forwarded)
            anything else    -> TOKEN
        """
        context = self.context_middleware.process({"conversation_id": conversation_id, "query": query})
        self.agent.add_tools(context.tools)
        for delta in self.agent.run(context.messages):
            if delta.event == StreamingEventType.TOOL_CALL and delta.tool_calls:
                for tool_call in delta.tool_calls:
                    yield AgentStreamingEvent(
                        event=AgentStreamingEventType.ACTION,
                        data=f"Calling tool {tool_call.name} with arguments {tool_call.arguments}",
                        action_id=tool_call.tool_call_id or ""
                    )

            elif delta.event == StreamingEventType.TOOL_RESULT:
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.ACTION_RESULT,
                    data=delta.data,
                    action_id=delta.tool_call_id or ""
                )

            elif delta.event == StreamingEventType.COMPLETED:
                self.context_middleware.conversation_repository.add_messages(
                    conversation_id,
                    [{"role": "user", "content": query}, {"role": "assistant", "content": delta.data}]
                )
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.COMPLETED,
                    data=delta.data
                )

            elif delta.event in (StreamingEventType.REASONING_TOKEN, StreamingEventType.REASONING):
                # Keep reasoning text out of the TOKEN stream so consumers that
                # concatenate TOKEN data get the answer text only.
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.REASONING,
                    data=delta.data
                )

            else:
                yield AgentStreamingEvent(
                    event=AgentStreamingEventType.TOKEN,
                    data=delta.data
                )

In [None]:
class StringToIntHandler(Handler[str, int]):
    """Demo stage: maps a string to its length."""

    def process(self, input: str) -> int:
        length = len(input)
        return length

class IntToFloatHandler(Handler[int, float]):
    """Demo stage: halves an integer using true division."""

    def process(self, input: int) -> float:
        halved = input / 2
        return halved

class FloatToStringHandler(Handler[float, str]):
    """Demo stage: renders a float inside a result message."""

    def process(self, input: float) -> str:
        message = f"Result is {input}"
        return message

# Compose the three demo stages into one pipeline: str -> int -> float -> str.
chain = (
    StringToIntHandler()
    .then(IntToFloatHandler())
    .then(FloatToStringHandler())
)

result = chain.handle("hello")
print(result)  # Result is 2.5

In [None]:
from openai import OpenAI
from openai.types.responses import ResponseStreamEvent

# Shared OpenAI client used by the cells below. No arguments are passed, so
# credentials and configuration come from the SDK's defaults (presumably the
# environment; verify against the openai package documentation).
client = OpenAI()

# generator = client.responses.create(
#     model="gpt-5.1",
#     input="What's the difference between MCP Servers and traditional function calling?",
#     reasoning={
#         "effort": "medium",
#         "summary": "auto"
#     },
#     stream=True
# )

# events: list[ResponseStreamEvent] = []
# for event in generator:
#     events.append(event)
#     print(event)

In [None]:
from openai.types.responses import (
    ResponseTextDeltaEvent, 
    ResponseReasoningSummaryTextDeltaEvent,
    ResponseCompletedEvent,
    ResponseFunctionToolCall,
    ResponseReasoningSummaryTextDoneEvent
)
from openai.types.chat_model import ChatModel
from openai.types.shared_params import Reasoning
from typing import Sequence
import re


class OpenAILLMResponsesAPI(LLM):
    """
    `LLM` backend that streams from `client.responses.create` and translates
    the SDK's stream events into `LLMStreamingEvent`s.
    """

    # Model names starting with "gpt-5" or "o" are treated as reasoning models.
    _REASONING_MODEL_PATTERN = re.compile(r"(gpt-5\S*|o\S*)")

    def __init__(
        self, 
        client: OpenAI, 
        model: ChatModel = "gpt-4o-mini",
        allow_parallel_tool_calls: bool = True,
        temperature: float = 0.7,
        top_p: float = 1.0,
        reasoning: Optional[Reasoning] = None,
    ) -> None:
        """
        Args:
            client: OpenAI client used for the requests.
            model: Model name.
            allow_parallel_tool_calls: Forwarded as `parallel_tool_calls`.
            temperature: Sampling temperature; ignored for reasoning models.
            top_p: Nucleus-sampling value; ignored for reasoning models.
            reasoning: Reasoning options; ignored for non-reasoning models.
        """
        super().__init__()
        self.client = client
        self.model = model
        self.allow_parallel_tool_calls = allow_parallel_tool_calls
        self.temperature: Optional[float] = temperature
        self.top_p: Optional[float] = top_p
        self.reasoning: Optional[Reasoning] = reasoning

        # The decision is based on the model name, not on whether `reasoning`
        # was passed: reasoning models get no temperature/top_p, and every
        # other model gets no reasoning options.
        if self._is_reasoning_model():
            self.temperature = None
            self.top_p = None
        else:
            self.reasoning = None

    def _is_reasoning_model(self) -> bool:
        """Return True when `self.model` matches the reasoning-model name pattern."""
        return self._REASONING_MODEL_PATTERN.match(self.model) is not None

    def generate_text(self, messages: list[MessageLike], tools: Optional[set[Callable[..., str]]] = None) -> Generator[LLMStreamingEvent, None, None]:
        """
        Stream one model response for `messages`.

        Yields REASONING_TOKEN / REASONING events for reasoning-summary deltas
        and completed summaries, TOKEN events for answer-text deltas, and, once
        the response has completed, either a single TOOL_CALL event carrying
        all requested function calls or a COMPLETED event carrying the
        concatenated answer text.
        """
        # Schemas are rebuilt on every call from the callables' signatures.
        openai_tools: list[ToolParam] = [
            cast(ToolParam, tool(t, api=OpenAIAPIProtocol.RESPONSES).schema)
            for t in tools or []
        ]

        generator = self.client.responses.create(
            model=self.model,
            input=messages,
            tools=openai_tools,
            parallel_tool_calls=self.allow_parallel_tool_calls,
            reasoning=self.reasoning,
            temperature=self.temperature,
            top_p=self.top_p,
            stream=True
        )

        text_parts: list[str] = []
        final_tool_calls: list[ToolCall] = []
        for event in generator:
            if isinstance(event, ResponseReasoningSummaryTextDeltaEvent):
                yield LLMStreamingEvent(
                    event=StreamingEventType.REASONING_TOKEN,
                    data=event.delta
                )
            elif isinstance(event, ResponseTextDeltaEvent):
                text_parts.append(event.delta)
                yield LLMStreamingEvent(
                    event=StreamingEventType.TOKEN,
                    data=event.delta
                )
            elif isinstance(event, ResponseReasoningSummaryTextDoneEvent):
                yield LLMStreamingEvent(
                    event=StreamingEventType.REASONING,
                    data=event.text
                )
            elif isinstance(event, ResponseCompletedEvent):
                # Function calls are read from the completed response's output.
                for output_item in event.response.output:
                    if isinstance(output_item, ResponseFunctionToolCall):
                        final_tool_calls.append(
                            ToolCall(
                                name=output_item.name,
                                tool_call_id=output_item.call_id,
                                arguments=json.loads(output_item.arguments)
                            )
                        )
                if final_tool_calls:
                    yield LLMStreamingEvent(
                        event=StreamingEventType.TOOL_CALL,
                        data="",
                        tool_calls=final_tool_calls
                    )
                else:
                    yield LLMStreamingEvent(
                        event=StreamingEventType.COMPLETED,
                        data="".join(text_parts),
                    )

In [None]:
# NOTE: "gpt-4o-mini" does not match the reasoning-model name pattern checked by
# OpenAILLMResponsesAPI._is_reasoning_model, so __init__ resets `reasoning` to
# None and the `reasoning=` argument below has no effect with this model.
llm = OpenAILLMResponsesAPI(
    client=client,
    model="gpt-4o-mini",
    allow_parallel_tool_calls=True, 
    reasoning=Reasoning(effort="medium", summary="auto")
)

In [None]:
def get_weather(location: str) -> str:
    """Return a canned (demo) weather report for `location`."""
    # The degree sign is written as an escape so the literal cannot be
    # mis-decoded again into mojibake when the file changes encoding.
    return f"The weather in {location} is sunny with a high of 75\N{DEGREE SIGN}F."

def get_user_info(user_name: str) -> str:
    """Return a canned (demo) profile sentence for the user `user_name`."""
    return f"User {user_name} is a software developer from San Francisco."

In [None]:
# for event in llm.generate_text([{"role": "user", "content": "What's the difference between MCP Servers and traditional function calling API?"}], tools=[get_weather, get_user_info]):
#     print(event)

In [None]:
import timeit

# Interactive demo: chat with the agent until the user types "exit".
agent = Agent(
    llm=llm, 
    tool_handler=ToolHandler(tools={get_weather, get_user_info})
)

# The same list is reused across turns; agent.run appends to it as it goes.
messages: list[MessageLike] = []

while True:
    user_input = input("Enter your query (or 'exit' to quit): ")
    if user_input.lower() == 'exit':
        break
    messages.append({"role": "user", "content": user_input})

    # Start the clock only once the query has been submitted, and re-arm the
    # flag for every query, so the latency is reported per turn and does not
    # include the time spent typing.
    start_time = timeit.default_timer()
    first_token_flag = True

    for event in agent.run(messages):
        if first_token_flag:
            elapsed = timeit.default_timer() - start_time
            print(f"Time to first token: {elapsed:.2f} seconds")
            first_token_flag = False
        print(event)