In [1]:
import os
from datetime import datetime, timezone
from dotenv import load_dotenv

import asyncio
from typing import Any, Union, Optional
from pydantic import SecretStr, BaseModel, Field

from langchain.schema import HumanMessage, AIMessage, SystemMessage
from langchain.callbacks.base import AsyncCallbackHandler

from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain_core.runnables import ConfigurableField#, ConfigurableFieldSpec
# from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.tools import tool

from langchain_openai import ChatOpenAI

In [2]:
os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")

### LLM & Prompt

In [3]:
# LLM and Prompt Setup
llm_agent = ChatOpenAI(
    model="gpt-4.1-nano",
    temperature=0.0,
    streaming=True,
).configurable_fields(
    callbacks=ConfigurableField(
        id="callbacks",
        name="callbacks",
        description="A list of callbacks to use for streaming",
    )
)

system_prompt = """
You are a conversational AI assistant that relies solely on the conversation history and tool outputs as your sources of truth.
Always use tools to answer the user's current question not previous questions before responding.
When you have enough information, use the final_answer tool to provide the final response.

Guidelines:
1. Use only chat history or tool results — never invent or assume facts.
2. If information is missing, ask one short clarifying question.
3. If the user changes topic or interrupts, handle it naturally and retain relevant context.
4. Be concise, factual, and context-aware.
5. When resuming after an interruption, reuse past context only if relevant.
6. When exucting a calculation take note of rules of arithmetic order of operations:
    a. Parentheses — " ( " , " ) "
    b. Exponentiation — " ^ " , " ** "
    c. Multiplication and division — " * " , " / "
    d. Addition and subtraction — " + " , " - "

Your goal is to respond clearly, naturally, and accurately across turns.
"""

prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])

### Chat Memory

In [13]:
class ConversationSummaryBufferMemory_custom(BaseChatMessageHistory, BaseModel):
    """
    Based on number of messages. Where if number of messages is more than k, 
    pop oldest messages and create a new summary by adding information from poped messages.
    """
    messages: list[BaseMessage] = Field(default_factory=list)
    llm: Any = None
    k: int = Field(default_factory=int)

    def __init__(self, llm: Any, k: int):
        super().__init__(llm=llm, k=k)
        # print(f"Initializing ConversationSummaryBufferMemory_custom with k={k}")

    async def add_messages(self, messages: list[BaseMessage]) -> None:
        """Add messages to the history, 
        keep only the last 'k' messages and 
        generate new summary by combining information from dropped messages.
        """
        existing_summary: SystemMessage | None = None
        old_messages: list[BaseMessage] | None = None

        # check if there is already a summary message
        if len(self.messages) > 0 and isinstance(self.messages[0], SystemMessage):
            # print(">> Found existing summary")
            existing_summary = self.messages.pop(0) # remove old summary from messages

        # add the new messages to the history
        self.messages.extend(messages)

        # check if there is too many messages
        if len(self.messages) > self.k:
            # print(
            #     f">> Found {len(self.messages)} messages, dropping "
            #     f"oldest {len(self.messages) - self.k} messages.")
            
            # pull out the oldest messages
            num_to_drop = len(self.messages) - self.k
            old_messages = self.messages[:num_to_drop] # self.messages[:self.k]

            # keep only the most recent messages
            self.messages = self.messages[-self.k:]

        # if no old_messages, no new info to update the summary
        if old_messages is None:
            # print(">> No old messages to update summary with")
            return
        
        # construct the summary chat messages
        summary_prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(
                "Given the existing conversation summary and the new messages, "
                "generate a new summary of the conversation. Ensuring to maintain "
                "as much relevant information as possible."
            ),
            HumanMessagePromptTemplate.from_template(
                "Existing conversation summary:\n{existing_summary}\n\n"
                "New messages:\n{old_messages}"
            )
        ])

        # format the messages and invoke the LLM
        new_summary = self.llm.invoke(
            summary_prompt.format_messages(
                existing_summary=existing_summary,
                old_messages=old_messages
            )
        )

        # call synchronous llm.invoke in a thread so we don't block the event loop:
        # loop = asyncio.get_running_loop()
        
        # new_summary = await loop.run_in_executor(
        #     None,  # default ThreadPoolExecutor
        #     lambda: self.llm.invoke(
        #         summary_prompt.format_messages(
        #             existing_summary=existing_summary,
        #             old_messages=old_messages
        #         )
        #     )
        # )

        # print(f">> New summary: {new_summary.content}")
        # prepend the new summary to the history
        self.messages = [SystemMessage(content=new_summary.content)] + self.messages


    def clear(self) -> None:
        """Clear the history."""
        self.messages = []

# function to get memory for specific session id
def get_chat_history(session_id: str, llm: ChatOpenAI, k: int = 4) -> ConversationSummaryBufferMemory_custom:
    # print(f"get_chat_history called with session_id={session_id} and k={k}")
    if session_id not in chat_map:
        # if session ID doesn't exist, create a new chat history
        chat_map[session_id] = ConversationSummaryBufferMemory_custom(llm=llm, k=k)
    # remove anything beyond the last
    return chat_map[session_id]

chat_map = {}
llm_memory = ChatOpenAI(temperature=0.0, model="gpt-4.1-nano")

### Tools

In [5]:
# Tools definition
# note: all tools as async to simplify later code
@tool
async def add(x: float, y: float) -> float:
    """Add 'x' and 'y'."""
    return x + y

@tool
async def multiply(x: float, y: float) -> float:
    """Multiply 'x' and 'y'."""
    return x * y

@tool
async def exponentiate(x: float, y: float) -> float:
    """Raise 'x' to the power of 'y'."""
    return x ** y

@tool
async def subtract(x: float, y: float) -> float:
    """Subtract 'x' from 'y'."""
    return y - x

@tool
async def divide(x: float, y: float) -> float:
    """Divide 'x' by 'y'."""
    if y > 0:
        return x / y
    else:
        return "Division by error y must be more than 0"

@tool
async def final_answer(answer: str, tools_used: list[str]) -> dict[str, str | list[str]]:
    """Use this tool to provide a final answer to the user."""
    return {"answer": answer, "tools_used": tools_used}

tools = [add, subtract, multiply, exponentiate, divide, final_answer]
# note when we have sync tools we use tool.func, when async we use tool.coroutine
name2tool = {tool.name: tool.coroutine for tool in tools}

### Streaming Callback

In [6]:
# Streaming Handler
class QueueCallbackHandler(AsyncCallbackHandler):
    """Callback handler that puts tokens into a queue."""

    def __init__(self, queue: asyncio.Queue):
        self.queue = queue
        self.final_answer_seen = False

    async def __aiter__(self): # outputs tokens
        while True:
            if self.queue.empty():
                await asyncio.sleep(0.1)
                continue
            token_or_done = await self.queue.get()
            if token_or_done == "<<DONE>>":
                # this means we're done
                return
            if token_or_done:
                yield token_or_done
    
    async def on_llm_new_token(self, *args, **kwargs) -> None:
        """Put new token in the queue."""

        #print(f"on_llm_new_token: {args}, {kwargs}")
        chunk = kwargs.get("chunk")
        # check for final_answer tool call
        if chunk and chunk.message.additional_kwargs.get("tool_calls"):
            if chunk.message.additional_kwargs["tool_calls"][0]["function"]["name"] == "final_answer":
                # this will allow the stream to end on the next `on_llm_end` call
                self.final_answer_seen = True
                
        self.queue.put_nowait(kwargs.get("chunk")) #store tokens into queue
    
    async def on_llm_end(self, *args, **kwargs) -> None:
        """Put DONE in the queue to signal completion."""

        #print(f"on_llm_end: {args}, {kwargs}")
        # this should only be used at the end of our agent execution, however LangChain
        # will call this at the end of every tool call, not just the final tool call
        # so we must only send the "done" signal if we have already seen the final_answer

        if self.final_answer_seen:
            self.queue.put_nowait("<<DONE>>")
        else:
            self.queue.put_nowait("<<STEP_END>>")

async def execute_tool(tool_call: AIMessage) -> ToolMessage:
    tool_name = tool_call.tool_calls[0]["name"]
    tool_args = tool_call.tool_calls[0]["args"]
    tool_out = await name2tool[tool_name](**tool_args)
    return ToolMessage(
        content=f"{tool_out}",
        tool_call_id=tool_call.tool_calls[0]["id"]
    )

### Agent Executor

In [14]:
# Agent Executor
class CustomAgentExecutor:
    def __init__(self, max_iterations: int = 5):
        self.chat_history: list[BaseMessage] = []
        self.max_iterations = max_iterations
        self.agent = (
            {
                "input": lambda x: x["input"],
                "chat_history": lambda x: x["chat_history"],
                "agent_scratchpad": lambda x: x.get("agent_scratchpad", [])
            }
            | prompt
            | llm_agent.bind_tools(tools, tool_choice="required")
        )

        # In-memory trace list for inspection
        self.decision_trace: list[dict] = []

    async def invoke(self, input: str, 
                     streamer: QueueCallbackHandler, 
                     verbose: bool = False, 
                     chat_memory: Optional[Union[list[BaseMessage], object]] = None, 
                     session_id: str = "session_id_00") -> dict:
        # --- pick the chat history container to use for this invocation ---
        # support memory object (with .messages) or plain list
        if chat_memory is None:
            # fallback (existing behavior)
            chat_container = self.chat_history
            use_memory_api = False
        else:
            # detect if memory object (duck-typing)
            if hasattr(chat_memory, "messages") and hasattr(chat_memory, "add_messages"):
                chat_container = chat_memory  # memory object
                use_memory_api = True

            elif isinstance(chat_memory, list):
                chat_container = chat_memory
                use_memory_api = False
            else:
                # Unexpected type — fall back to list view if possible
                raise TypeError("chat_memory must be a list or a memory object with .messages/.add_messages")
                                
        # invoke the agent but we do this iteratively in a loop until reach a final answer
        count = 0
        final_answer: str | None = None
        agent_scratchpad: list[AIMessage | ToolMessage] = []
        # streaming function
        async def stream(query: str) -> list[AIMessage]:
            # get the current messages list to pass to the agent prompt
            if use_memory_api:
                history_for_prompt = chat_container.messages
            else:
                history_for_prompt = chat_container

            configured_agent = self.agent.with_config(
                callbacks=[streamer]
            )
            # Initialize the output dictionary that will be populating with streamed output
            outputs = []
            # now begin streaming
            async for token in configured_agent.astream({
                "input": query,
                "chat_history": history_for_prompt,
                "agent_scratchpad": agent_scratchpad
            }):
                tool_calls = token.additional_kwargs.get("tool_calls")
                if tool_calls: # -> outputs = [tool1, tool2, tool3]
                    # first check if have a tool call id - this indicates a new tool
                    if tool_calls[0]["id"]:
                        outputs.append(token)
                    else:
                        outputs[-1] += token
                else:
                    pass
            return [
                AIMessage(
                    content=x.content,
                    tool_calls=x.tool_calls,
                    tool_call_id=x.tool_calls[0]["id"]
                ) for x in outputs
            ]

        while count < self.max_iterations:
            # invoke a step for the agent to generate a tool call
            tool_calls = await stream(query=input)
            # gather tool execution coroutines
            tool_obs = await asyncio.gather(
                *[execute_tool(tool_call) for tool_call in tool_calls]
            )
            # append tool calls and tool observations to the scratchpad in order
            id2tool_obs = {tool_call.tool_call_id: tool_obs for tool_call, tool_obs in zip(tool_calls, tool_obs)}
            for tool_call in tool_calls:
                agent_scratchpad.extend([
                    tool_call,
                    id2tool_obs[tool_call.tool_call_id]
                ])

            # ------------------ TOOL-USAGE LOGGING ------------------
            # Only log when a tool was requested this iteration
            if tool_calls:
                # Use first tool_call as representative (agent may call multiple tools)
                tc = tool_calls[0]
                tool_name = tc.tool_calls[0].get("name")
                tool_args = tc.tool_calls[0].get("args", {})
                obs = id2tool_obs.get(tc.tool_call_id)
                tool_result_summary = getattr(obs, "content", str(obs)) if obs is not None else None

                trace_entry = {
                    "session_id": session_id,
                    "turn": count + 1,
                    "query": input,
                    "tool": tool_name,
                    "tool_args": tool_args,
                    "tool_result_summary": tool_result_summary,
                    "timestamp": datetime.now(timezone.utc) #datetime.now(datetime.timezone.utc) + "Z"
                }

                # keep the trace in memory and print it (for screenshots / manual inspection)
                self.decision_trace.append(trace_entry)
                # pretty print the trace to console (one-line)
                # print("PLANNER TRACE:", trace_entry)
            
            count += 1
            # if the tool call is the final answer tool, then stop
            found_final_answer = False
            for tool_call in tool_calls:
                if tool_call.tool_calls[0]["name"] == "final_answer":
                    final_answer_call = tool_call.tool_calls[0]
                    final_answer = final_answer_call["args"]["answer"]
                    found_final_answer = True
                    break
            
            # Only break the loop if found a final answer
            if found_final_answer:
                break
            
        # --- write final messages back into the chat memory (or list) ---
        human_msg = HumanMessage(content=input)
        ai_msg = AIMessage(content=final_answer if final_answer else "No answer found")

        # If the agent exited due to reaching max iteration limit, ensure the streamer stops cleanly
        if not found_final_answer and count >= self.max_iterations:
            print(f"[WARN] Max iteration limit ({self.max_iterations}) reached — stopping agent loop.")
            try:
                streamer.queue.put_nowait("<<DONE>>")
            except Exception:
                pass

        if use_memory_api:
            # use memory object's API (ConversationSummaryBufferMemory_custom.add_messages)
            # Note: add_messages expects a list[BaseMessage]
            try:
                await chat_container.add_messages([human_msg, ai_msg])
                
            except Exception as e:
                # fallback: append to the messages list
                chat_container.messages.extend([human_msg, ai_msg])
        else:
            # plain list
            chat_container.extend([human_msg, ai_msg])

        # return the final answer in dict form
        return final_answer_call if final_answer else {"answer": "No answer found or iteration limit reached"}

# Initialize agent executor
agent_executor = CustomAgentExecutor() 

### Token Generator

In [None]:
# streaming function
async def token_generator(
        content: str, 
        streamer: QueueCallbackHandler, 
        chat_memory: Optional[Union[list[BaseMessage], object]] = None, 
        session_id: str = "session_id_00"):
    
    task = asyncio.create_task(agent_executor.invoke(
        input=content,
        streamer=streamer,
        verbose=True,  # set to True to see verbose output in console
        chat_memory = chat_memory,
        session_id = session_id
    ))
    
    # initialize various components to stream
    async for token in streamer:
        try:
            if token == "<<STEP_END>>":
                # send end of step token
                # yield "</step>"
                print("</step>", flush=True)

            elif tool_calls := token.message.additional_kwargs.get("tool_calls"):
                if tool_name := tool_calls[0]["function"]["name"]:
                    # send start of step token followed by step name tokens
                    # yield f"<step><step_name>{tool_name}</step_name>"
                    print(f"<step><step_name>{tool_name}</step_name>", flush=True)

                if tool_args := tool_calls[0]["function"]["arguments"]:
                    # tool args are streamed directly, ensure it's properly encoded
                    # yield tool_args
                    print(f"{tool_args}", end="", flush=True)

                # print(f"\n{tool_calls[0]}", end="", flush=True)
                    
        except Exception as e:
            print(f"Error streaming token: {e}")
            continue

    final_answer_call = await task
    print("\n")
    print(final_answer_call)
    print("\n")
    
    if final_answer_call["args"]:
        print(f"Bot Output: {final_answer_call["args"]["answer"]}")
        print(f"Tools Used: {final_answer_call["args"]["tools_used"]}")

### Print Chat History Function

In [9]:
# Prints Conversation History
def print_history(session_id: str):
    print("\n===== Conversation history =====")
    history = chat_map.get(session_id)

    if history is None:
        print("(no history)")
        return
    
    for msg in history.messages:
        if isinstance(msg, HumanMessage):
            role_label = "Human"
        elif isinstance(msg, AIMessage):
            role_label = "AI"
        elif isinstance(msg, SystemMessage):
            role_label = "System (Summary)"
        else:
            role_label = msg.__class__.__name__  # fallback to type name

        content = getattr(msg, "content", str(msg))
        print(f"\n{role_label}: {content}")
    
    print("============================\n")

### Calculator Tests

### Successful Calculation

In [15]:
queue: asyncio.Queue = asyncio.Queue()
streamer = QueueCallbackHandler(queue)
session_id = "test_1"
k = 6
chat_memory = get_chat_history(session_id=session_id, llm=llm_memory, k=k)

content = "2*(5*(3+6)/3)^3"
await token_generator(content, streamer, chat_memory=chat_memory, session_id=session_id)

# conversation history
print_history(session_id)

# tool usage logs
print("Tool Usage Log:")
for log in agent_executor.decision_trace:
    print(log)

<step><step_name>add</step_name>
{"x": 3, "y": 6}<step><step_name>divide</step_name>
{"x": 5, "y": 3}</step>
<step><step_name>multiply</step_name>
{"x": 3, "y": 6}<step><step_name>divide</step_name>
{"x": 5, "y": 3}</step>
<step><step_name>add</step_name>
{"x": 3, "y": 6}<step><step_name>divide</step_name>
{"x": 5, "y": 3}</step>
<step><step_name>multiply</step_name>
{"x": 2, "y": 1.6666666666666667}<step><step_name>exponentiate</step_name>
{"x": 9, "y": 3}</step>
<step><step_name>multiply</step_name>
{"x[WARN] Max iteration limit (5) reached — stopping agent loop.
":3.3333333333333335,"y":729}</step>


{'answer': 'No answer found or iteration limit reached'}




KeyError: 'args'