In [1]:
from typing import Annotated,Sequence,TypedDict
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolMessage
from langchain_core.messages import SystemMessage
from langchain_core.tools import tool
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph,END
from langgraph.prebuilt import ToolNode
from langchain_groq import ChatGroq

In [2]:
from google.colab import userdata
groq_api_key=userdata.get('groq_api')

In [3]:
llm = ChatGroq(groq_api_key=groq_api_key,
               model_name="deepseek-r1-distill-llama-70b")

In [4]:
class AgentState(TypedDict):
  messages: Annotated[Sequence[BaseMessage], add_messages]

In [5]:
@tool
def add(a: int, b:int):
  """This is an addition function that adds 2 numbers together"""

  return a + b

@tool
def subtract(a: int, b: int):
  """Subtraction function"""
  return a - b

@tool
def multiply(a: int, b: int):
  """Multiplication function"""
  return a * b

In [6]:
tools = [add, subtract, multiply]

In [7]:
model=llm.bind_tools(tools)

In [8]:
def model_call(state:AgentState) -> AgentState:
  system_prompt = SystemMessage(content=
      "You are my AI assistant, please answer my query to the best of your ability."
  )
  response = model.invoke([system_prompt] + state["messages"])
  return {"messages": [response]}

In [9]:
def should_continue(state: AgentState):
  messages = state["messages"]
  last_message = messages[-1]
  if not last_message.tool_calls:
      return "end"
  else:
      return "continue"

In [10]:
graph = StateGraph(AgentState)
graph.add_node("our_agent", model_call)


tool_node = ToolNode(tools=tools)
graph.add_node("tools", tool_node)

graph.set_entry_point("our_agent")

<langgraph.graph.state.StateGraph at 0x7f8b974a5810>

In [11]:
graph.add_conditional_edges(
    "our_agent",
    should_continue,
    {
        "continue": "tools",
        "end": END,
    },
)

<langgraph.graph.state.StateGraph at 0x7f8b974a5810>

In [12]:
graph.add_edge("tools", "our_agent")

app = graph.compile()


In [13]:
def print_stream(stream):
  for s in stream:
      message = s["messages"][-1]
      if isinstance(message, tuple):
          print(message)
      else:
          message.pretty_print()

In [14]:
inputs = {"messages": [("user", "Add 40 + 12 and then multiply the result by 6. Also tell me a joke please.")]}
print_stream(app.stream(inputs, stream_mode="values"))


Add 40 + 12 and then multiply the result by 6. Also tell me a joke please.

To solve the problem step by step:

1. **Addition**: 40 + 12 = 52
2. **Multiplication**: 52 × 6 = 312

**Answer**: 312

**Joke**: Why don't skeletons fight each other? Because they don't have the guts! 😄
