In [None]:
import os
from typing import Annotated

from dotenv import load_dotenv
from IPython.display import Image, display
from langchain_groq import ChatGroq
# from langchain.chat_models import init_chat_model
from langchain_tavily import TavilySearch
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from typing_extensions import TypedDict

: 

In [None]:
# Pull API keys from a local .env file into os.environ. load_dotenv() does not
# raise when no .env file is found, so verify up front that the keys needed by
# the Groq chat model and the Tavily search tool are present; otherwise the
# problem only surfaces later as a less obvious error from those clients.
load_dotenv()

_missing_keys = [
    key for key in ("GROQ_API_KEY", "TAVILY_API_KEY") if not os.environ.get(key)
]
if _missing_keys:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_keys)}"
    )

In [None]:
class State(TypedDict):
    """Graph state passed to every node of the agent graph.

    ``messages`` holds the conversation as a plain list. The ``add_messages``
    function in its annotation is the reducer that decides how this key is
    updated: messages returned by a node are appended to the existing list
    rather than overwriting it.
    """

    messages: Annotated[list, add_messages]


In [None]:
# Groq-hosted Llama 3.1 8B chat model used by every node below.
# `init_chat_model("groq:llama-3.1-8b-instant")` is an equivalent alternative,
# but its import is commented out in the imports cell, so calling it here would
# raise NameError (and would silently replace this instance anyway).
llm = ChatGroq(model="llama-3.1-8b-instant")

In [None]:
# ## node functionality 
# def chatbot(state:State):
#     return {"messages":[llm.invoke(state["messages"])]}

In [None]:
# graph_builder = StateGraph(State)

# ## add node
# graph_builder.add_node("llmchatbot",chatbot)

# ## add edges
# graph_builder.add_edge(START,"llmchatbot")
# graph_builder.add_edge("llmchatbot",END)

# ## compile the graph
# graph = graph_builder.compile()

In [None]:
# ## Visualize the graph

# try:
#     display(Image(graph.get_graph().draw_mermaid_png()))
# except Exception:
#     pass

In [None]:
# response = graph.invoke({"messages":"Hello"})

In [None]:
# response["messages"][-1].content

In [None]:
# for event in graph.stream({"messages":"Hello How are you?"}):
#     for value in event.values():
#         print(value["messages"][-1].content)


## Add tools to agent

In [None]:
# Web-search tool exposed to the LLM; caps how many search hits Tavily returns
# for each query.
MAX_SEARCH_RESULTS = 2
search_internet = TavilySearch(max_results=MAX_SEARCH_RESULTS)

In [None]:
def multiply(a: int, b: int) -> int:
    """Multiply a and b.

    This docstring doubles as the tool description sent to the LLM, so it uses
    Google style (separate ``Args:`` paragraph, ``name (type): description``
    entries) so LangChain can attach a description to each argument.

    Args:
        a (int): first int
        b (int): second int

    Returns:
        int: product of a and b
    """
    return a * b

In [None]:
# Every tool the agent may use. The same list is given to `bind_tools` (so the
# model sees the tool schemas) and to `ToolNode` (which executes the calls),
# keeping the advertised and executable tool sets identical.
tools = [
    search_internet,
    multiply,
]

In [None]:
## bind llm with tools
# A copy of `llm` that advertises the schemas of `tools` on every request,
# allowing the model to answer with tool calls instead of plain text.
llm_with_tools = llm.bind_tools(tools=tools)

In [None]:
# initialize memory
# Checkpointer handed to `builder.compile(checkpointer=...)` so that graph runs
# sharing the same `thread_id` see the earlier messages of that thread.
# MemorySaver keeps checkpoints in process memory only (presumably lost when
# the kernel restarts — see LangGraph persistence docs).
memory = MemorySaver()

In [None]:
def tool_calling_llm(state: State):
    """Graph node: send the conversation so far to the tool-aware LLM.

    Args:
        state: current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [reply]}``. Because ``messages``
        is annotated with ``add_messages``, the reply is appended to the
        history rather than replacing it. The reply may contain tool calls,
        which the conditional edge after this node routes to the tools node.
    """
    history = state["messages"]
    reply = llm_with_tools.invoke(history)
    return {"messages": [reply]}

In [None]:
# Agent loop: the LLM node either answers directly or requests tool calls.
# Requested calls run in the "tools" node and their results go back to the
# LLM, which can call more tools or finish.
builder = StateGraph(State)
builder.add_node("tool_calling_llm", tool_calling_llm)
builder.add_node("tools", ToolNode(tools))

builder.add_edge(START, "tool_calling_llm")
builder.add_conditional_edges(
    "tool_calling_llm",
    # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
    # If the latest message (result) from assistant is not a tool call -> tools_condition routes to END
    tools_condition,
)
# Loop back so the LLM sees the tool output; a plain "tools" -> END edge would
# stop after a single round of tool calls.
builder.add_edge("tools", "tool_calling_llm")

## compile the graph
# With a checkpointer attached, LangGraph stores state per thread, so each
# invoke on this graph must pass {"configurable": {"thread_id": ...}}.
graph = builder.compile(checkpointer=memory)

# Rendering the PNG can fail (e.g. the default Mermaid renderer calls a remote
# service). Report the problem instead of aborting the cell, because the graph
# itself has already compiled and later cells can still use it.
try:
    display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as exc:
    print(f"Could not render graph diagram: {exc}")


In [None]:
# The graph was compiled with a checkpointer, so every invoke needs a
# `thread_id` in its config; without one LangGraph raises ValueError.
# Each independent demo question gets its own thread so earlier conversations
# don't leak into later answers, and each answer is printed (otherwise only the
# cell's last expression would be shown).
demo_questions = [
    "What is the recent ai news",
    "What is 2 multiplied by 3",
    "What is 2 multiplied by 3 and then multiplied by 6",
    "Give me the recent ai news and then multiply 3 by 6",
]
for i, question in enumerate(demo_questions):
    demo_config = {"configurable": {"thread_id": f"demo-{i}"}}
    result = graph.invoke({"messages": question}, config=demo_config)
    print(f"Q: {question}\nA: {result['messages'][-1].content}\n")


## use memory
# Both calls share thread "1", so the second question is answered with the
# first message still in the conversation history.
config = {"configurable": {"thread_id": "1"}}

graph.invoke({"messages": "Hi my name is Anjali"}, config=config)
memory_reply = graph.invoke({"messages": "What is my name?"}, config=config)
memory_reply["messages"][-1].content
