In [None]:
import asyncio
import json

from langchain.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.tools import tool
from langchain_ollama.chat_models import ChatOllama

# 1. Setup

In [None]:
# Chat model shared by every cell below (plain streaming and the tool-calling agent).
llm = ChatOllama(
    model="llama3.1:8b",
)

In [None]:
@tool
def add(x: float, y: float) -> float:
    """Add 'x' and 'y'."""
    # NOTE(review): presumably `@tool` exposes the docstring above to the model
    # as the tool description, so its wording is left untouched -- confirm.
    total = x + y
    return total


@tool
def multiply(x: float, y: float) -> float:
    """Multiply 'x' and 'y'."""
    # NOTE(review): presumably `@tool` exposes the docstring above to the model
    # as the tool description, so its wording is left untouched -- confirm.
    product = x * y
    return product


@tool
def exponentiate(x: float, y: float) -> float:
    """Raise 'x' to the power of 'y'."""
    # NOTE(review): presumably `@tool` exposes the docstring above to the model
    # as the tool description, so its wording is left untouched -- confirm.
    # Two-argument pow() is the function form of the `**` operator.
    return pow(x, y)


@tool
def subtract(x: float, y: float) -> float:
    """Subtract 'x' from 'y'."""
    # NOTE(review): presumably `@tool` exposes the docstring above to the model
    # as the tool description, so its wording is left untouched -- confirm.
    # Argument order matters: `x` is taken away from `y`, not the reverse.
    difference = y - x
    return difference


# Every tool offered to the model.
tools = [add, subtract, multiply, exponentiate]
# Lookup from tool name to the undecorated Python callable (`.func`), used to
# execute a tool call by the name the model returns.
name2tool = dict((t.name, t.func) for t in tools)

In [None]:
# System instructions for the agent; adjacent literals are concatenated into
# one string.
system_instructions = (
    "You're a helpful assistant. When answering a user's question, "
    "you may use one of the provided tools if needed. After using a "
    "tool, you can provide a final answer directly. "
    "DO NOT use the same tool more than once."
)

# Message order: system instructions, prior conversation, the new user input,
# then the scratchpad of tool calls/results from the current run.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_instructions),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

# 2. Streaming with astream

In [None]:
tokens = []
async for token in llm.astream("What is NLP?"):
    tokens.append(token)
    print(token.content, end="|", flush=True)

# 3. Queue Callback Handler

In [None]:
# Agent pipeline: pick the prompt variables out of the input dict, render the
# prompt, then call the LLM with the tools bound to it.
agent: RunnableSerializable = (
    {
        "input": lambda payload: payload["input"],
        "chat_history": lambda payload: payload["chat_history"],
        # The scratchpad is optional in the input; default to an empty list.
        "agent_scratchpad": lambda payload: payload.get("agent_scratchpad", []),
    }
    | prompt
    | llm.bind_tools(tools)
)

In [None]:
class QueueCallbackHandler(AsyncCallbackHandler):
    """Callback handler that forwards streamed LLM chunks into an asyncio queue.

    `on_llm_new_token` enqueues each chunk it receives and `on_llm_end`
    enqueues the string sentinel "<<DONE>>". Iterating the handler with
    `async for` drains the queue and stops at the first sentinel.

    Args:
        queue: The queue that chunks and the end-of-stream sentinel are put on.
    """

    def __init__(self, queue: asyncio.Queue):
        self.queue = queue

    async def __aiter__(self):
        """Yield queued chunks until the "<<DONE>>" sentinel is received."""
        while True:
            # `Queue.get()` suspends until an item is available, so no
            # empty-check/sleep polling loop is needed and a chunk is
            # delivered as soon as it is enqueued.
            token_or_done = await self.queue.get()
            if token_or_done == "<<DONE>>":
                return
            # Falsy items are skipped rather than yielded.
            if token_or_done:
                yield token_or_done

    async def on_llm_new_token(self, *args, **kwargs) -> None:
        """Enqueue the `chunk` keyword argument, if one was provided."""
        chunk = kwargs.get("chunk")
        if chunk:
            self.queue.put_nowait(chunk)

    async def on_llm_end(self, *args, **kwargs) -> None:
        """Enqueue the "<<DONE>>" sentinel that ends one iteration pass."""
        self.queue.put_nowait("<<DONE>>")

In [None]:
# Chunks returned by the agent stream are collected here.
tokens = []

# The callback handler pushes streamed chunks onto this queue.
queue = asyncio.Queue()
streamer = QueueCallbackHandler(queue=queue)


async def stream(query: str):
    """Stream one agent response for `query`, printing and storing each chunk.

    Runs the module-level `agent` with `streamer` attached as a callback,
    using an empty chat history and an empty scratchpad. Every streamed chunk
    is printed and appended to the module-level `tokens` list.

    Args:
        query: The user question passed to the agent as "input".
    """
    payload = {"input": query, "chat_history": [], "agent_scratchpad": []}
    runnable = agent.with_config(callbacks=[streamer])
    async for chunk in runnable.astream(payload):
        print(chunk, flush=True)
        tokens.append(chunk)


await stream("What is 10 + 10")

# 4. Custom Agent Executor - astream

In [None]:
class CustomAgentExecutor:
    """Minimal tool-calling agent loop that streams every LLM step.

    Each call to `invoke` repeatedly streams a response from the agent. While
    the response contains tool calls, the tools are executed and their results
    are fed back through the scratchpad; the first response without tool calls
    is treated as the final answer. The question and final answer are then
    appended to `chat_history`, so later calls see earlier exchanges.

    Args:
        max_iterations: Maximum number of LLM steps per `invoke` call.
    """

    # Messages from previous `invoke` calls (alternating human / AI).
    chat_history: list[BaseMessage]

    def __init__(self, max_iterations: int = 3):
        self.chat_history = []
        self.max_iterations = max_iterations
        # Pipeline: pick prompt variables out of the input dict, render the
        # prompt, then call the LLM with the tools bound to it.
        self.agent: RunnableSerializable = (
            {
                "input": lambda x: x["input"],
                "chat_history": lambda x: x["chat_history"],
                "agent_scratchpad": lambda x: x.get("agent_scratchpad", []),
            }
            | prompt
            | llm.bind_tools(tools)
        )

    async def _stream_step(
        self,
        query: str,
        streamer: QueueCallbackHandler,
        agent_scratchpad: list,
        verbose: bool,
    ):
        """Stream one LLM step and return the merged response.

        Returns the sum of all streamed chunks, or None if nothing was streamed.
        """
        response = self.agent.with_config(callbacks=[streamer])
        output = None
        async for token in response.astream(
            {
                "input": query,
                "chat_history": self.chat_history,
                "agent_scratchpad": agent_scratchpad,
            }
        ):
            # Chunks can be added together as they arrive; the running sum is
            # the full response object once the stream ends.
            if output is None:
                output = token
            else:
                output += token

            if verbose and token.content:
                print(f"content: {token.content}", flush=True)

            tool_calls = token.additional_kwargs.get("tool_calls")
            if tool_calls and verbose:
                print(f"tool_calls: {tool_calls}", flush=True)

        return output

    async def invoke(
        self, input: str, streamer: QueueCallbackHandler, verbose: bool = False
    ) -> dict:
        """Answer `input`, running tools as requested by the model.

        Args:
            input: The user's question.
            streamer: Callback handler attached to every LLM step.
            verbose: When True, print streamed content and tool invocations.

        Returns:
            A dict with a single key "answer" holding the final answer. If no
            final answer is produced within `max_iterations` steps, the answer
            is a message saying so.
        """
        count = 0
        agent_scratchpad = []
        final_output = None

        while count < self.max_iterations:
            output = await self._stream_step(input, streamer, agent_scratchpad, verbose)

            if not output or not getattr(output, "tool_calls", None):
                print("Detected final answer.")
                if output is None:
                    # Nothing was streamed; use an empty answer rather than
                    # the literal string "None".
                    final_output = ""
                else:
                    final_output = (
                        output.content if hasattr(output, "content") else str(output)
                    )
                break

            # Record the model's tool-call message, then answer EVERY tool
            # call it contains so none is left without a matching ToolMessage.
            agent_scratchpad.append(
                AIMessage(
                    content=output.content,
                    tool_calls=output.tool_calls,
                )
            )
            for tool_call in output.tool_calls:
                tool_name = tool_call["name"]
                tool_args = tool_call["args"]
                # Arguments may arrive as a JSON string instead of a dict.
                if isinstance(tool_args, str):
                    tool_args = json.loads(tool_args)

                if verbose:
                    print(f"{count}: {tool_name}({tool_args})")

                tool_fn = name2tool.get(tool_name)
                if tool_fn is None:
                    # Report an unknown tool name back to the model instead of
                    # failing the whole run with a KeyError.
                    tool_out = (
                        f"Error: unknown tool '{tool_name}'. "
                        f"Available tools: {sorted(name2tool)}"
                    )
                else:
                    tool_out = tool_fn(**tool_args)

                agent_scratchpad.append(
                    ToolMessage(content=f"{tool_out}", tool_call_id=tool_call["id"])
                )

            count += 1

        if final_output is None:
            # The loop ran out of iterations while the model was still asking
            # for tools. AIMessage content must not be None, so store an
            # explicit explanation instead.
            final_output = (
                f"Stopped after {self.max_iterations} iterations "
                "without reaching a final answer."
            )

        # Add exchange to chat history
        self.chat_history.extend(
            [HumanMessage(content=input), AIMessage(content=final_output)]
        )

        return {"answer": final_output}

In [None]:
agent_executor = CustomAgentExecutor()

queue = asyncio.Queue()
streamer = QueueCallbackHandler(queue)

out = await agent_executor.invoke("What is 10 + 10", streamer, verbose=True)

In [None]:
# Display the executor's result dict (its "answer" key holds the final answer).
out