## Setup and Imports

In [0]:
%pip install -U langchain-databricks langchain-core langchain mlflow langgraph



In [0]:
# Restart the notebook's Python process so that the package versions installed
# by the %pip cell above are the ones imported by the cells that follow.
dbutils.library.restartPython()

In [0]:
import operator
from typing import TypedDict, Annotated, Sequence
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_databricks import ChatDatabricks
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END

@tool
def get_user_balance(user_id: int) -> str:
    """
    Finds the account balance for a specific user_id from the user_data table.
    """
    # NOTE: with @tool, the docstring above is also the tool description sent to
    # the LLM, so it is kept short and model-facing; implementation notes live in
    # these comments instead.
    print(f"--- Tool Called: get_user_balance(user_id={user_id}) ---")
    try:
        # Bind user_id as a named parameter marker rather than formatting it into
        # the SQL text, so a non-integer value can never change the query itself.
        df = spark.sql(
            "SELECT account_balance FROM main.agents_demo.user_data WHERE user_id = :user_id",
            args={"user_id": user_id},
        )
        result = df.first()  # None when no row matches
        if result is None:
            return f"No user found with user_id {user_id}"
        return f"The balance for user_id {user_id} is ${result['account_balance']}"
    except Exception as e:
        # Deliberately returned rather than raised: the agent receives the failure
        # as an ordinary tool result and can report it to the user.
        return f"Error querying database: {str(e)}"

# Registry of every tool the agent may use: it is bound to the LLM and searched
# by name when the model asks for a tool call.
tools = [get_user_balance]

# --- Define the Agent's State ---
# The graph's shared "memory": every node receives it and returns a partial update.
class AgentState(TypedDict):
    """State passed between all nodes of the agent graph."""

    # Full conversation history. The ``operator.add`` reducer concatenates the
    # list each node returns onto the existing history instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]

## Graph Nodes

## Define the Graph Edges

In [0]:
# --- Imports from Step 2 (make sure you have these) ---
from langgraph.graph import StateGraph, END
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage
import operator
from typing import TypedDict, Annotated, Sequence
from langchain_databricks import ChatDatabricks
# --- Imports for our NEW Step 3 ---
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# --- LLM, tool binding and prompt ---
# Chat model served from a Databricks serving endpoint, with replies capped at
# 200 tokens.
llm = ChatDatabricks(endpoint="databricks-gpt-oss-120b", max_tokens=200)

# Attach the tool schemas so the model can answer with structured tool calls.
llm_with_tools = llm.bind_tools(tools)

# One fixed system instruction followed by the whole conversation history.
# The placeholder name has to match the "messages" key of the graph state,
# because the entire state dict is passed in as the prompt's input.
system_instruction = ("system", "You are a helpful assistant. You must use your tools to answer questions.")
prompt = ChatPromptTemplate.from_messages(
    [
        system_instruction,
        MessagesPlaceholder(variable_name="messages"),
    ]
)

# prompt -> tool-aware LLM; a single invocation produces one AI message.
agent_runnable = prompt | llm_with_tools

# --- 2. Define the 'call_model' node ---
def call_model(state: AgentState):
    """Planner node: run the prompt + tool-bound LLM over the conversation.

    Args:
        state: Current graph state; its ``messages`` list fills the prompt's
            ``MessagesPlaceholder``.

    Returns:
        A partial state update containing only the new model message, which
        the ``operator.add`` reducer on ``messages`` appends to the history.
        That message carries either plain content or tool calls.
    """
    print("--- Calling Model ---")
    return {"messages": [agent_runnable.invoke(state)]}

# --- 3. Define the 'call_tool' node ---
def call_tool(state: AgentState):
    """Our "Tool Executor" node.

    Executes every tool call requested by the most recent message and answers
    each one with a ToolMessage.

    Args:
        state: Graph state whose last message is the AIMessage carrying
            ``tool_calls`` (``should_continue`` only routes here in that case).

    Returns:
        ``{"messages": [...]}`` with exactly one ToolMessage per tool call,
        tagged with the originating ``tool_call_id``. An unknown tool name or a
        tool that raises produces an error ToolMessage instead of aborting the
        graph, so every tool call still gets a reply the model can react to.
    """
    print("--- Calling Tool ---")

    last_message = state['messages'][-1]  # the AIMessage with tool calls

    # Build the name -> tool map once per invocation rather than rescanning the
    # tool list for every call.
    tools_by_name = {t.name: t for t in tools}

    tool_outputs = []
    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']

        selected_tool = tools_by_name.get(tool_name)
        if selected_tool is None:
            # The model asked for a tool that does not exist. Report it back
            # (instead of letting a lookup raise) so the model can correct itself.
            available = ", ".join(tools_by_name)
            output = f"Error: unknown tool '{tool_name}'. Available tools: {available}"
        else:
            print(f"--- Calling Tool: {tool_name} with args {tool_args} ---")
            try:
                output = selected_tool.invoke(tool_args)
            except Exception as e:  # tool boundary: surface the failure to the model
                output = f"Error calling tool '{tool_name}': {e}"

        tool_outputs.append(
            ToolMessage(content=str(output), tool_call_id=tool_call['id'])
        )

    # The operator.add reducer on `messages` appends these to the history.
    return {"messages": tool_outputs}
# --- 4. Define the conditional edge ---
def should_continue(state: AgentState):
    """Choose the route taken after the "agent" node.

    Returns:
        "continue_to_tools" when the newest message is an AIMessage that
        requests tool calls (mapped to the "action" node when the graph is
        built), otherwise "end_conversation" (mapped to END).
    """
    newest = state['messages'][-1]
    wants_tools = isinstance(newest, AIMessage) and bool(newest.tool_calls)
    return "continue_to_tools" if wants_tools else "end_conversation"

# --- 5. Build the Graph ---
# call_model, call_tool and should_continue must all be defined before this runs.
workflow = StateGraph(AgentState)

# Nodes: "agent" plans with the LLM, "action" executes the requested tools.
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)

# Every run starts at the planner.
workflow.set_entry_point("agent")

# After the planner, the label returned by should_continue picks the next hop.
workflow.add_conditional_edges(
    "agent",
    should_continue,
    {"continue_to_tools": "action", "end_conversation": END},
)

# Tool results always go back to the planner, closing the agent loop.
workflow.add_edge("action", "agent")

# Freeze the builder into a runnable graph.
app = workflow.compile()

## Run the Agent

In [0]:
# Logging with MLflow
import mlflow

# Derive the experiment folder from the signed-in workspace user instead of a
# hard-coded personal e-mail address, so the notebook runs unchanged for anyone
# (and doesn't publish the author's address).
_current_user = spark.sql("SELECT current_user()").first()[0]
mlflow.set_experiment(f"/Users/{_current_user}/langgraph_agent_demo")
# Enable MLflow's LangChain autologging for the graph invocations below.
mlflow.langchain.autolog()

with mlflow.start_run(run_name="LangGraph Agent Run"):

    # The graph's input is its state: a dict holding the list of messages.
    inputs = {"messages": [HumanMessage(content="What is the balance for user_id 3?")]}

    # 'stream' the output to see every step; each item is a dict keyed by the
    # name of the node that just ran, mapped to that node's state update.
    for output in app.stream(inputs):
        for node_name, node_update in output.items():
            print(f"--- Step: {node_name} ---")
            print(node_update)
            print("\n")

# To get just the final response
# final_response = app.invoke(inputs)
# print(final_response['messages'][-1].content)