In [9]:
from langchain.memory import ConversationBufferWindowMemory
from langchain.agents import load_tools, AgentType, initialize_agent
from langchain.chat_models import ChatOpenAI
from dotenv import load_dotenv
import os
from langchain.callbacks import FinalStreamingStdOutCallbackHandler

from helper import *



llm = ChatOpenAI(
    openai_api_key=api_key,
    temperature=0.0,
    model_name="gpt-4o",
    streaming=True,  # ! important
    callbacks=[
        FinalStreamingStdOutCallbackHandler(
            answer_prefix_tokens = ["Final", "Answer"]
        )
    ]
)

# initialize conversational memory
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    k=5,
    return_messages=True,
    output_key="output"
)

# create a single tool to see how it impacts streaming
tools = load_tools(["llm-math"], llm=llm)

# initialize the agent
agent = initialize_agent(
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    tools=tools,
    llm=llm,
    memory=memory,
    verbose=True,
    max_iterations=3,
    early_stopping_method="generate",
    return_intermediate_steps=False
)

In [12]:
import sys

class CallbackHandler(StreamingStdOutCallbackHandler):
    def __init__(self):
        self.content: str = ""
        self.final_answer: bool = False

    def on_llm_new_token(self, token: str, **kwargs: any) -> None:
        self.content += token
        if "Final Answer" in self.content:
            # now we're in the final answer section, but don't print yet
            self.final_answer = True
            self.content = ""
        if self.final_answer:
            if '"action_input": "' in self.content:
                if token not in ["}"]:
                    sys.stdout.write(token)  # equal to `print(token, end="")`
                    sys.stdout.flush()

agent.agent.llm_chain.llm.callbacks = [CallbackHandler()]

In [13]:
prompt = "what is the square root of 72?"
agent(prompt)



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m```json
{
    "action": "Calculator",
    "action_input": "sqrt(72)"
}
```[0m
Observation: [36;1m[1;3mAnswer: 8.48528137423857[0m
Thought: "The square root of 72 is approximately 8.485."
}
```[32;1m[1;3m```json
{
    "action": "Final Answer",
    "action_input": "The square root of 72 is approximately 8.485."
}
```[0m

[1m> Finished chain.[0m


{'input': 'what is the square root of 72?',
 'chat_history': [HumanMessage(content='what is the square root of 71?', additional_kwargs={}, response_metadata={}),
  AIMessage(content='The square root of 71 is approximately 8.426.', additional_kwargs={}, response_metadata={}),
  HumanMessage(content='what is the square root of 72?', additional_kwargs={}, response_metadata={}),
  AIMessage(content='The square root of 72 is approximately 8.485.', additional_kwargs={}, response_metadata={})],
 'output': 'The square root of 72 is approximately 8.485.'}

In [27]:
import requests


def get_stream(query: str):
    s = requests.Session()
    with s.post(
        "http://localhost:8000/chat",
        stream=True,
        json={"text": query}
    ) as r:
        for line in r.iter_content():
            print(line.decode("utf-8"), end="")

In [29]:
get_stream("what did I ask you before")

 "I'm sorry, I don't have the ability to recall past interactions or questions. Each session is stateless and doesn't retain memory of previous conversations."
}
```