<a href="https://colab.research.google.com/github/Decoding-Data-Science/airesidency/blob/main/Langraph_with_tools_cohort7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
%pip install -U "langchain[openai]" langchain-community langgraph langchain-text-splitters


In [2]:
import os
from google.colab import userdata

# Retrieve API keys from Colab's secure storage
groq_api_key = userdata.get("GROQ_API_KEY")
openai_api_key = userdata.get("openai")

# Set them as environment variables
if groq_api_key:
    os.environ["GROQ_API_KEY"] = groq_api_key
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key

In [3]:
# --- 1. Imports & model -------------------------------------------------------
from typing import Dict
import operator

from typing_extensions import TypedDict, Annotated

from langchain.tools import tool
from langchain.chat_models import init_chat_model
from langchain.messages import (
    AnyMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langgraph.graph import StateGraph, START, END

# Initialize a chat model (OpenAI GPT-4o-mini via LangChain)
# Uses OPENAI_API_KEY from the previous cell
model = init_chat_model("gpt-4o-mini", model_provider="openai")

# --- 2. Define simple tools ---------------------------------------------------

@tool
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b

tools = [add, multiply]
tools_by_name = {t.name: t for t in tools}
model_with_tools = model.bind_tools(tools)

# --- 3. Define graph state ----------------------------------------------------

class AgentState(TypedDict):
    # Keep a running list of messages in the conversation
    messages: Annotated[list[AnyMessage], operator.add]

# --- 4. Define the LLM node ---------------------------------------------------

def llm_node(state: Dict) -> Dict:
    """LLM decides whether to answer directly or call a tool."""
    response = model_with_tools.invoke(
        [SystemMessage(content="You are a calculator assistant.")]
        + state["messages"]
    )
    return {"messages": [response]}

# --- 5. Define the tool node --------------------------------------------------

def tool_node(state: Dict) -> Dict:
    """Executes the tool calls requested by the LLM."""
    last = state["messages"][-1]
    results = []

    for call in last.tool_calls:
        tool = tools_by_name[call["name"]]
        observation = tool.invoke(call["args"])
        results.append(ToolMessage(content=str(observation), tool_call_id=call["id"]))

    return {"messages": results}

# --- 6. Define routing logic --------------------------------------------------

from typing import Literal

def should_continue(state: AgentState) -> Literal["tool_node", END]:
    """If the last message has tool calls, go to tools; otherwise, stop."""
    last = state["messages"][-1]
    if getattr(last, "tool_calls", None):
        return "tool_node"
    return END

# --- 7. Build and compile the graph ------------------------------------------

builder = StateGraph(AgentState)

builder.add_node("llm", llm_node)
builder.add_node("tool_node", tool_node)

builder.add_edge(START, "llm")
builder.add_conditional_edges("llm", should_continue, ["tool_node", END])
builder.add_edge("tool_node", "llm")

agent = builder.compile()

# --- 8. Test it ---------------------------------------------------------------

from langchain.messages import HumanMessage

query = "What is 12 * 7, then add 5 to the result?"
messages = [HumanMessage(content=query)]

result_state = agent.invoke({"messages": messages})

for m in result_state["messages"]:
    m.pretty_print()



What is 12 * 7, then add 5 to the result?
Tool Calls:
  multiply (call_bN3f0SO9qArsmKXkhM0NkyRd)
 Call ID: call_bN3f0SO9qArsmKXkhM0NkyRd
  Args:
    a: 12
    b: 7
  add (call_Rw6tkMLJcqTHNugcaqpPXuLR)
 Call ID: call_Rw6tkMLJcqTHNugcaqpPXuLR
  Args:
    a: 12
    b: 5

84

17

The result of \( 12 \times 7 \) is 84. Then, when you add 5 to the result, you would get \( 84 + 5 = 89 \).
