# REACT AGENT WITH MEMORY

In [None]:
%pip install --quiet langchain langchain-core langchain-community

In [2]:
import os
from dotenv import load_dotenv

# Let python-dotenv populate the environment before any key is read.
load_dotenv()

# MariTalk settings: model name plus the API key taken from the environment
MARITALK_LLM_MODEL = "sabia-3"
MARITALK_API_KEY = os.environ.get('MARITALK_API_KEY')

# Tavily API key, also taken from the environment
TAVILY_API_KEY = os.environ.get('TAVILY_API_KEY')

In [None]:
# NOTE(review): presumably the client library behind the TavilySearchResults
# tool imported in the next cell — confirm.
%pip install --quiet tavily-python 

In [4]:
from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.chat_models import ChatMaritalk

# The agent's only tool: Tavily web search, configured with max_results=1
tools = [TavilySearchResults(max_results=1)]

# Get the ReAct prompt from the LangChain hub
prompt = hub.pull("hwchase17/react")

# MariTalk chat model that drives the agent
llm = ChatMaritalk(
    model=MARITALK_LLM_MODEL,
    api_key=MARITALK_API_KEY,
    max_tokens=1000,
)

# Construct the ReAct agent
agent = create_react_agent(llm, tools, prompt)

# Create an agent executor by passing in the agent and tools.
# handle_parsing_errors=True makes the executor report a reply that does not
# follow the ReAct "Thought/Action/Action Input" format back to the model as an
# observation so it can retry, instead of aborting the whole run with an
# output-parsing exception.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    verbose=True,
    handle_parsing_errors=True,
)

In [None]:
# Ask the agent executor a sample question (the executor was built with
# verbose=True in the previous cell).
agent_executor.invoke({"input": "What is Langchain?"})

# MESSAGE PERSISTENCE

In [None]:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_community.chat_models import ChatMaritalk

# New graph that uses MessagesState as its state schema
workflow = StateGraph(state_schema=MessagesState)

# MariTalk chat model called by the graph's node
model = ChatMaritalk(
    max_tokens=1000,
    api_key=MARITALK_API_KEY,
    model=MARITALK_LLM_MODEL,
)

# Prompt: a fixed system instruction followed by the "messages" placeholder
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful assistant!"),
    MessagesPlaceholder(variable_name="messages"),
])
print(prompt_template)


# Define the function that calls the model
def call_model(state: MessagesState):
    """Render the prompt from the graph state, call the model, return its reply.

    Args:
        state: Current graph state. It is handed whole to ``prompt_template``,
            whose ``messages`` placeholder is filled from it.

    Returns:
        A state update of the form ``{"messages": [response]}``. The reply is
        wrapped in a list so the update has the same shape as the one returned
        by the trimming ``call_model`` defined later in this notebook.
    """
    # Named chat_prompt (not prompt) to avoid shadowing the notebook-level
    # `prompt` pulled from the hub for the ReAct agent.
    chat_prompt = prompt_template.invoke(state)
    response = model.invoke(chat_prompt)
    return {"messages": [response]}

# Register the (single) node first, then connect the graph's START to it
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile the graph with a MemorySaver checkpointer
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

In [None]:
# Run configuration carrying the thread_id used for this conversation
config = {"configurable": {"thread_id": "thread_id_1"}}

# First turn: the user states what they study
query = "Hi! I'm an undergraduate computer science student."
input_messages = [HumanMessage(content=query)]

# Run the graph and show the last message of the resulting state
output = app.invoke(
    {"messages": input_messages},
    config=config,
)
output["messages"][-1].pretty_print()

In [None]:
# Follow-up question sent with the same `config` (same thread_id) as the previous cell
query = "What do I major in?"
input_messages = [HumanMessage(content=query)]

# Run the compiled graph and show the last message of the resulting state
output = app.invoke(
    {"messages": input_messages},
    config=config,
)
output["messages"][-1].pretty_print()

# MANAGING CONVERSATION HISTORY

In [None]:
# NOTE(review): no cell shown in this notebook imports `trimmer`; the message
# trimmer used below comes from langchain_core (trim_messages). Confirm whether
# the `trimmer` package is actually needed. `transformers` is presumably
# required for token counting through the model — verify.
%pip install --quiet trimmer transformers

In [None]:
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, trim_messages
from langchain_community.chat_models import ChatMaritalk

# Sample conversation: one system message followed by two human/AI exchanges
messages = [
    SystemMessage(content="You are a helpful assistant!"),
    HumanMessage(content="Hi there! I'm an undergraduate computer science student"),
    AIMessage(content="Hello! How can I help you today?"),
    HumanMessage(content="I am trying to learn about data structures and algorithms"),
    AIMessage(content="Data structures and algorithms are a fundamental part of computer science!"),
]

# MariTalk chat model; also passed to the trimmer below as its token counter
model = ChatMaritalk(
    max_tokens=1000,
    api_key=MARITALK_API_KEY,
    model=MARITALK_LLM_MODEL,
)

# Trimmer configuration: "last" strategy with a 1000-token budget, system
# message included, no partial messages, window starting on a human message
trimmer = trim_messages(
    strategy="last",
    max_tokens=1000,
    token_counter=model,
    start_on="human",
    include_system=True,
    allow_partial=False,
)

# Apply the trimmer to the sample conversation (last expression of the cell)
trimmer.invoke(messages)

In [None]:
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, MessagesState, StateGraph
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# New graph that uses MessagesState as its state schema
workflow = StateGraph(state_schema=MessagesState)

# Prompt made of nothing but the "messages" placeholder
prompt_template = ChatPromptTemplate.from_messages(
    [MessagesPlaceholder(variable_name="messages")]
)
print(prompt_template)

# Node function: trim the history, then call the model
def call_model(state: MessagesState):
    """Trim the state's messages, render the prompt, and call the model.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A state update ``{"messages": [response]}`` holding the model's reply.
    """
    history = trimmer.invoke(state["messages"])
    response = model.invoke(prompt_template.invoke({"messages": history}))
    return {"messages": [response]}


# Register the (single) node first, then connect the graph's START to it
workflow.add_node("model", call_model)
workflow.add_edge(START, "model")

# Compile the graph with a MemorySaver checkpointer
memory = MemorySaver()
app = workflow.compile(checkpointer=memory)

In [None]:
# Run configuration with a different thread_id from the earlier conversation
config = {"configurable": {"thread_id": "thread_id_2"}}

# Question about the conversation held in the `messages` list defined above
query = "What was my previous interaction? What computer science concepts did I mention?"

# Send the sample history followed by the new question
input_messages = [*messages, HumanMessage(content=query)]
output = app.invoke({"messages": input_messages}, config=config)
output["messages"][-1].pretty_print()