# Customer Support Bot
This bot uses LangGraph to build an agent that can complete the following tasks:
- Get total spend by category
- Get transactions by category
- Get products and their current balance

## Let's Set Up the LLM

In [None]:
from langchain_openai.chat_models import ChatOpenAI
from dotenv import load_dotenv
# Load environment variables from a local .env file, overriding any values
# already set in the process (presumably this provides the OpenAI API key that
# ChatOpenAI reads).
load_dotenv(override=True)
# Shared chat model used by every LLM call in this notebook; temperature=0
# reduces randomness in its responses.
LLM = ChatOpenAI(model="gpt-4.1-nano", temperature=0)

## Let's Define the APIs
Rather than writing and calling real services to retrieve the information for these tasks, we will mock those APIs.

In [42]:
import datetime
import random
from typing import Literal

def random_date_past_month(today: "datetime.date | None" = None):
    """Return a random date from the past month as a ``YYYY-MM-DD`` string.

    The window runs from the same day-of-month in the previous month (clamped
    to that month's last day, e.g. 31 March -> 28/29 February) up to and
    including ``today``. Used to give mocked transactions plausible dates.

    Args:
        today (datetime.date | None) - Reference date for the window. Defaults
            to ``datetime.date.today()``; pass a fixed date for reproducible runs.

    Returns:
        str - The chosen date formatted as ``%Y-%m-%d``.
    """
    if today is None:
        today = datetime.date.today()

    # Last day of the previous month: step back one day from the 1st of this
    # month. This also covers January (-> 31 December of the previous year)
    # and leap years without any month-length table.
    last_of_prev_month = today.replace(day=1) - datetime.timedelta(days=1)
    # Same day-of-month one month ago, clamped to the previous month's length
    # (e.g. 31 May -> 30 April instead of an invalid 31 April).
    one_month_ago = last_of_prev_month.replace(day=min(today.day, last_of_prev_month.day))

    # Pick a uniformly random day in [one_month_ago, today].
    delta_days = (today - one_month_ago).days
    random_days = random.randint(0, delta_days)
    random_date = today - datetime.timedelta(days=random_days)

    return random_date.strftime("%Y-%m-%d")

def get_products(username:str):
    """
    Mock API returning the financial products held by a user.

    Every user gets the same three products: a savings account, a current
    account and a credit card, each labelled with the username.

    Args:
        username (str) - Username of the user whose products are requested.

    Returns:
        dict - ``{"accounts": [...], "username": username}`` where each account
        is a dict with ``type``, ``name`` and ``balance`` keys.
    """
    # (product type, display-name suffix, balance) for each mocked product.
    catalogue = (
        ("savings", "Savings Account", 1_000),
        ("current", "Current Account", 350),
        ("credit card", "Everyday Credit Card", -100),
    )
    accounts = [
        {"type": kind, "name": f"{username}'s {label}", "balance": amount}
        for kind, label, amount in catalogue
    ]
    return {"accounts": accounts, "username": username}

# Mocked spend totals per category, drawn once when this cell runs so that the
# total and transaction APIs below report consistent figures. Each value is a
# random integer between 1 and the category's cap (inclusive).
category_spends = {
    category: random.randint(1, cap)
    for category, cap in (
        ("groceries", 500),
        ("holidays", 4_000),
        ("entertainment", 250),
        ("transport", 750),
    )
}

def get_total_category(username:str, category:Literal["groceries", "holidays", "entertainment","transport"]):
    """
    Mock API returning a user's total spend in one category.

    Args:
        username (str) - Username of the user whose spend is requested.
        category (Literal["groceries", "holidays", "entertainment","transport"]) - One of the four spend categories.

    Returns:
        dict - ``{"username": username, "spend": <total>}`` where the total comes
        from the module-level ``category_spends`` mock.

    Raises:
        KeyError - If ``category`` is not one of the keys of ``category_spends``.
    """
    total = category_spends[category]
    return {"username": username, "spend": total}

def get_transactions_per_category(username:str, category:Literal["groceries", "holidays", "entertainment","transport"]):
    """
    Mock API returning the individual transactions within one spend category.

    The category total from ``category_spends`` is split into random amounts,
    each dated within the past month. The amounts always sum to the figure
    reported by ``get_total_category``. Because the split is random, repeated
    calls return different transaction lists.

    Args:
        username (str) - Username of the user whose transactions are requested.
        category (Literal["groceries", "holidays", "entertainment","transport"]) - One of the four spend categories.

    Returns:
        dict - ``{"username": ..., "category": ..., "transactions": [...]}`` where
        each transaction is ``{"transaction_date": "YYYY-MM-DD", "amount": int}``.

    Raises:
        KeyError - If ``category`` is not one of the keys of ``category_spends``.
    """
    transactions = []
    remaining = category_spends[category]
    # Peel off random chunks of at most a third of what is left. The remainder
    # therefore stays positive and shrinks towards a small final amount.
    while remaining > 10:
        amount = random.randint(1, max(1, remaining // 3))
        transactions.append({"transaction_date": random_date_past_month(), "amount": amount})
        remaining -= amount
    # Whatever is left (1-10) becomes the final transaction. This is a plain
    # statement rather than a `while ... else`, because the loop has no
    # `break` and the else-clause would always run anyway.
    transactions.append({"transaction_date": random_date_past_month(), "amount": remaining})
    return {"username": username, "category": category, "transactions": transactions}

# Graph Overview



Our graph will look as follows:

![graph](./graphs/graph.png)

The workflow this reflects is as follows:
1. Identify which APIs need to be called to answer the user question
2. Call those APIs to gather the information to answer.
3. Create a textual answer.

This will involve 2 LLM calls, one for identifying which APIs we need to call, and one for generating our final answer.

In [43]:
# Define the prompts
# Prompt for the first LLM call, which picks the mocked APIs to call.
# Placeholders: {user_question} and {format_instructions} (filled from the
# output parser's schema instructions). The category note matters because
# the category APIs fail on an empty category.
API_SELECTION_PROMPT = """
You are a developer within a retail bank. You have a user question, and you need to decide which APIs should be called to answer the user question.

You have the following APIs you can call:
- An API which gets the products and the balance of the financial products a user owns. This API is called: products
- An API which gets the total spending in a given category over the past month. This API is called: category total
- An API which gets the transactions in a given category and the date they occurred. This API is called: category transactions

The category total and category transactions APIs each need exactly one category (groceries, holidays, entertainment or transport); to cover several categories, make one call per category. The products API does not take a category.

User Question: {user_question}

{format_instructions}
"""

# Prompt for the second LLM call, which turns the collected API results into
# the final answer. Placeholders: {user_question}, {api_responses},
# {username} and {format_instructions}.
SUMMARISATION_PROMPT = """
You are an individual working within a retail bank. Your job is to answer the user question to the best of your ability given the following API responses.

User Question: {user_question}

These are the results of the API calls: {api_responses}

The user's name is {username}

{format_instructions}
"""

In [44]:
from pydantic import BaseModel, Field
# Define the response formats.

# One API call chosen by the selection LLM.
# NOTE: deliberately no class docstring. PydanticOutputParser's format
# instructions embed the model's JSON schema, and pydantic puts a class
# docstring into that schema as "description". A docstring would therefore
# change the text the LLM sees.
class APICall(BaseModel):
    # Which mocked API to call; these names must match the ones checked in call_api.
    api_name: Literal["products","category total", "category transactions"] = Field(description="Name of the API to call.")
    # Spend category for the category APIs. It is "" (the default) when no
    # category applies, e.g. for the products API.
    category: Literal["groceries", "holidays", "entertainment","transport",""] = Field(default="", description="Category to use with the given API")

# Structured output of the API-selection LLM call (parsed in api_selection_node).
# No class docstring on purpose: it would be embedded in the format instructions.
class APIIdentificationResponse(BaseModel):
    # Free-text justification for the choice; api_selection_node prints it.
    reasoning:str = Field(description="Reasoning for your selection")
    # Calls to make; call_api later consumes them front-to-back.
    calls:list[APICall] = Field(description="List of API calls you would like to make.")

# Structured output of the summarisation LLM call (parsed in summarisation_node).
# No class docstring on purpose: it would be embedded in the format instructions.
class SummarisationResponse(BaseModel):
    # Final answer text; summarisation_node copies it into GraphState.summarised_response.
    summarised_answer:str

## Defining the Graph
To define our graph we need to complete the following tasks:

1. Define our State Object
2. Define our Node functions.
3. Define our Decision functions
4. Compile the graph

In [45]:
from pydantic import BaseModel

# 1. Define State Object
class GraphState(BaseModel):
    """State passed between the nodes of the financial-agent graph."""
    # Final answer written by summarisation_node.
    summarised_response: str = ""
    # Pending API calls chosen by api_selection_node; call_api pops them one at a time.
    # NOTE: the `= []` defaults are safe here because pydantic copies mutable
    # defaults per instance (unlike plain class attributes).
    api_calls:list[APICall] = []
    # Results collected so far, each {"query": <description>, "response": <API payload>}.
    api_responses:list[dict] = []
    # The question being answered.
    user_question:str = ""
    # Username passed to the mocked APIs and to the summarisation prompt.
    username:str = ""

In [46]:
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser
from dotenv import load_dotenv

# Repeats the .env load from the first cell, presumably so this cell can run on
# its own; with override=True it simply re-applies the same values.
load_dotenv(override=True)

#2a. Define LLM based nodes

def api_selection_node(state:GraphState)->GraphState:
    """
    First LLM step: ask the model which mocked APIs to call for the question.

    Prints the model's reasoning and stores the chosen calls on
    ``state.api_calls``, replacing anything already there.

    Args:
        state (GraphState) - Current graph state; only ``user_question`` is read.

    Returns:
        GraphState - The same state object, updated in place.
    """
    parser = PydanticOutputParser(pydantic_object=APIIdentificationResponse)
    format_instructions = parser.get_format_instructions()
    selection_prompt = PromptTemplate.from_template(
        API_SELECTION_PROMPT,
        partial_variables={"format_instructions": format_instructions},
    )
    selection: APIIdentificationResponse = (selection_prompt | LLM | parser).invoke(
        {"user_question": state.user_question}
    )
    print(selection.reasoning)
    state.api_calls = selection.calls
    return state

def summarisation_node(state:GraphState)->GraphState:
    """
    Final LLM step: turn the collected API results into an answer.

    Each entry of ``state.api_responses`` is rendered as one line,
    ``"<query> Result: <response>"``. The rendered text is printed for
    inspection and passed to the summarisation prompt together with the
    question and username.

    Args:
        state (GraphState) - Current graph state.

    Returns:
        GraphState - The same state object with ``summarised_response`` set.
    """
    parser = PydanticOutputParser(pydantic_object=SummarisationResponse)
    summary_prompt = PromptTemplate.from_template(
        SUMMARISATION_PROMPT,
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    api_calls_str = "".join(
        f"{entry.get('query')} Result: {entry.get('response')}\n"
        for entry in state.api_responses
    )
    print("-----These were the API call results-----")
    print(api_calls_str)
    print("-----------------------------------------")
    summary: SummarisationResponse = (summary_prompt | LLM | parser).invoke({
        "user_question": state.user_question,
        "api_responses": api_calls_str,
        "username": state.username,
    })
    state.summarised_response = summary.summarised_answer
    return state

In [47]:
# 2b. Define Non-LLM based nodes
def call_api(state:GraphState)->GraphState:
    """
    Execute the next pending API call and record its result.

    Pops the first entry of ``state.api_calls``, dispatches it to the matching
    mocked API and appends ``{"query": ..., "response": ...}`` to
    ``state.api_responses`` for the summarisation step.

    A category API requested without a category (the APICall schema allows
    ``""``) is recorded as an error response instead of raising ``KeyError``.
    One bad selection from the LLM therefore does not abort the whole graph
    run, and the summariser can tell the user what went wrong.

    Args:
        state (GraphState) - Current graph state; decide_make_call only routes
            here while at least one call is pending.

    Returns:
        GraphState - The same state object, updated in place.
    """
    print(state)
    this_call = state.api_calls.pop(0)
    if this_call.api_name == "products":
        response = get_products(state.username)
        state.api_responses.append({"query": "Call to products API.", "response": response.get("accounts")})
    elif this_call.api_name in ("category total", "category transactions") and not this_call.category:
        state.api_responses.append({
            "query": f"Call to the '{this_call.api_name}' API without a category.",
            "response": "Error: this API requires a category (groceries, holidays, entertainment or transport).",
        })
    elif this_call.api_name == "category total":
        response = get_total_category(state.username, this_call.category)
        state.api_responses.append({"query": f"Call to get total spend by category API with the category '{this_call.category}'.", "response": response.get("spend")})
    elif this_call.api_name == "category transactions":
        response = get_transactions_per_category(state.username, this_call.category)
        state.api_responses.append({"query": f"Call to get transactions by category API with the category '{this_call.category}'.", "response": response.get("transactions")})
    return state

In [48]:
# 3. Define decision functions
def decide_make_call(state:GraphState):
    """Route to ``call_api_node`` while API calls remain, otherwise to ``summarisation_node``."""
    return "call_api_node" if state.api_calls else "summarisation_node"

In [49]:
from langgraph.graph import StateGraph, END
# 4. Compile the graph 
def generate_graph():
    """
    Build and compile the financial-agent graph.

    Flow: select_apis_node -> call_api_node (repeated while calls remain)
    -> summarisation_node -> END. ``decide_make_call`` routes out of the
    selection and API nodes.

    Returns:
        The compiled LangGraph graph.
    """
    # Instantiate graph with state class
    graph = StateGraph(GraphState)
    # Map node functions to node names
    graph.add_node("select_apis_node", api_selection_node)
    graph.add_node("summarisation_node", summarisation_node)
    graph.add_node("call_api_node", call_api)
    # Define unconditional edges
    graph.add_edge("summarisation_node", END)
    # Conditional edges. Listing the only destinations decide_make_call can
    # return (path_map) tells LangGraph the real edges. Without it, the
    # rendered graph assumes the router may jump to any node.
    decision_targets = ["call_api_node", "summarisation_node"]
    graph.add_conditional_edges("select_apis_node", decide_make_call, decision_targets)
    graph.add_conditional_edges("call_api_node", decide_make_call, decision_targets)
    # Add entrypoint
    graph.set_entry_point("select_apis_node")
    return graph.compile()

graph = generate_graph()

## Calling the graph

Finally we can now call our agent/graph!

In [50]:
def call_graph(question:str, username:str):
    """
    Run the financial agent on one question and return its final answer.

    Args:
        question (str) - The user's question.
        username (str) - Username passed to the mocked APIs and the prompts.

    Returns:
        The ``summarised_response`` value from the final graph state (None if absent).
    """
    initial_state = GraphState(username=username, user_question=question)
    final_state = graph.invoke(initial_state)
    return final_state.get("summarised_response")

In [51]:
# Example run of the financial agent; its printed trace and answer follow below.
print(call_graph("What was the day I spent the most on groceries this month?", "Username"))

To determine the day the user spent the most on groceries this month, I need to retrieve the transactions categorized under groceries. Thus, I will call the 'category transactions' API which will provide the details of all grocery transactions along with their respective dates.
summarised_response='' api_calls=[APICall(api_name='category transactions', category='groceries')] api_responses=[] user_question='What was the day I spent the most on groceries this month?' username='Username'
-----These were the API call results-----
Call to get transactions by category API with the category 'groceries'. Result: [{'transaction_date': '2025-04-18', 'amount': 47}, {'transaction_date': '2025-05-01', 'amount': 39}, {'transaction_date': '2025-04-27', 'amount': 66}, {'transaction_date': '2025-04-27', 'amount': 70}, {'transaction_date': '2025-04-22', 'amount': 25}, {'transaction_date': '2025-05-15', 'amount': 20}, {'transaction_date': '2025-05-06', 'amount': 47}, {'transaction_date': '2025-05-14', 'a

# Creating an Airline Booking Agent
We will create a second agent which allows users to check flight prices, book flights and cancel bookings. This time we will use LangChain tools for our agent. The exact implementation will not be covered in as much depth.

In [52]:
import hashlib
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, MessagesState, START

from uuid import uuid4

# Define our APIs
# Mocked flight-price tool.
# NOTE: LangChain's @tool uses the docstring as the tool description and the
# type hints as the argument schema the LLM sees, so edit them with care.
@tool
def get_flight_price(departure_airport:str, destination_airport:str)->int:
    """
    Call to get the price of a flight between two airports.

    Args:
        departure_airport (str) - IATA airport code for the departure airport.
        destination_airport (str) - IATA airport code for the destination airport.

    Returns:
        int - Price of the flight
    """
    # Mock: both airports are ignored; a random price in [30, 1000] is returned.
    return random.randint(30,1000)

# In-memory booking store: confirmation ID -> [departure, destination, price].
bookings = {}

@tool
def book_flight(departure_airport:str, destination_airport:str)->str:
    """
    Call to book a flight between two airports.

    Args:
        departure_airport (str) - IATA airport code for the departure airport.
        destination_airport (str) - IATA airport code for the destination airport.
    
    Returns:
        str - Booking confirmation ID
    """
    # Mock: the confirmation ID is the SHA-1 hex digest of a fresh UUID4 string.
    # The stored price is an independent random value in [30, 1000].
    confirmation_id = hashlib.sha1(str(uuid4()).encode()).hexdigest()
    price = random.randint(30, 1000)
    bookings[confirmation_id] = [departure_airport, destination_airport, price]
    return confirmation_id

@tool
def cancel_flight(booking_id:str)->bool:
    """
    Call to cancel a flight.
    Args:
        booking_id - The booking confirmation ID that should be cancelled.
    Returns:
        bool - True if the flight was cancelled successfully, false if it was not.
    """
    # Guard clause: an unknown ID (or an empty stored record) cannot be cancelled.
    if not bookings.get(booking_id):
        return False
    del bookings[booking_id]
    return True

# Tools exposed to the flight agent.
tools = [get_flight_price, book_flight, cancel_flight]

# The shared chat model with these tool schemas attached, so it can emit tool calls.
flight_llm = LLM.bind_tools(tools)

# Graph node that executes the tool calls requested by the model.
tool_node = ToolNode(tools)

def should_continue(state: MessagesState):
    """Route to the ``tools`` node if the latest model message requested tool calls, else finish."""
    latest = state["messages"][-1]
    return "tools" if latest.tool_calls else END


def call_model(state: MessagesState):
    """Invoke the tool-enabled flight LLM on the conversation so far and return its reply."""
    reply = flight_llm.invoke(state["messages"])
    # Returned under "messages"; MessagesState's reducer appends it to the history.
    return {"messages": [reply]}

def build_flight_agent():
    """
    Build and compile the flight agent graph.

    The "agent" (model) node and the "tools" node alternate until the model
    replies without requesting any tool call.

    Returns:
        The compiled LangGraph graph.
    """
    builder = StateGraph(MessagesState)
    for node_name, action in (("agent", call_model), ("tools", tool_node)):
        builder.add_node(node_name, action)

    # START -> agent; agent -> tools or END (chosen by should_continue);
    # tools always hand control back to the agent.
    builder.add_edge(START, "agent")
    builder.add_conditional_edges("agent", should_continue, ["tools", END])
    builder.add_edge("tools", "agent")

    return builder.compile()

flight_agent = build_flight_agent()

def invoke_flight_agent(question:str)->str:
    """
    Ask the flight agent a single question.

    Args:
        question (str) - The question, sent to the agent as a human message.

    Returns:
        str - Content of the last message in the agent's final state (its answer).
    """
    initial_input = {"messages": [("human", question)]}
    final_state = flight_agent.invoke(initial_input)
    answer = final_state.get("messages")[-1]
    return answer.content

In [53]:
# Example: ask for a price. The mocked tool returns a random price, so the answer varies per run.
invoke_flight_agent("How much is a flight from london to new york.")


'The price of a flight from London to New York is $958.'

# Progressing to a Multi Agent Approach
We will now use a hierarchical approach so that both of our agents can work together.

For this we will create a supervisor agent, which we will invoke directly. This agent is then responsible for delegating tasks and coordinating responses.

![graph](./graphs/supervisor_graph.png)

In [54]:
# First rename our call graph function from earlier, so naming conventions are consistent.
invoke_financial_agent = call_graph  # alias of the same function; used by call_financial_agent

## We will define the supervisor agent following the same steps as before:
1. Define our State Object
2. Define our Node functions.
3. Define our Decision functions
4. Compile the graph

In [55]:
from pydantic import BaseModel
from typing import Literal

class Message(BaseModel):
    """One entry in the supervisor's conversation history."""
    # "user" for the original question; "assistant" for sub-agent answers and the final summary.
    role:Literal["user", "assistant"]
    # Text of the message.
    content:str

# 1. Define our State Object
class SupervisorState(BaseModel):
    """State passed between the nodes of the supervisor graph."""
    # Conversation history: the user's question, sub-agent answers, then the final summary.
    messages:list[Message] = []
    # Sub-agent chosen by supervisor_decision_node; "" means stop and summarise.
    agent:Literal["financial", "flight", ""] = ""
    # Instruction/question to send to the chosen sub-agent.
    message:str = ""
    # Username forwarded to the financial agent and to the summary prompt.
    username:str = ""

In [56]:
from pydantic import BaseModel, Field
from langchain.prompts import PromptTemplate
from langchain.output_parsers import PydanticOutputParser
# 2a. Define our LLM based nodes.

# Prompt for the supervisor's routing decision.
# Placeholders: {message_history} and {format_instructions}. It spells out
# the "no agent" value because the schema only accepts an empty string there.
SUPERVISOR_DECISION_PROMPT = """You are a supervisor agent responsible for coordinating multiple sub agents. Your task is to decide whether you need to call a subagent and what question or task you would like to ask that agent to perform.

You have 2 subagents available to you:
1. A flight agent which is responsible for getting costs of flights, booking flights and cancelling bookings.
2. A financial products agent, which has access to a user's financial products and can identify trends in the user's spending categories. This agent can see the user's bank balances.

If all the information needed to respond to the user has been collected, do not call a subagent: set agent to an empty string ("").

Here is the message history so far: {message_history}

{format_instructions}
"""

# Structured output of supervisor_decision_node. No class docstring on purpose:
# it would be embedded in the format instructions shown to the LLM.
class DecisionResponse(BaseModel):
    # Question/instruction forwarded to the chosen sub-agent.
    instruction:str = Field(description="Question or instruction to give to the subagent if it should be called.", default="")
    # Next sub-agent to call, or "" to stop and summarise. The description asks
    # for an empty string because that is the only stop value the Literal
    # accepts; wording like "none" invites none/null, which fails validation.
    agent:Literal["financial", "flight", ""] = Field(description='Agent to call next. Use an empty string ("") if you believe all information has been collected and we can respond to the user.', default="")

# Prompt for the supervisor's final answer once no more sub-agents are needed.
# Placeholders: {message_history} and {username}.
SUPERVISOR_SUMMARISATION_PROMPT = """You are a supervisor agent responsible for coordinating multiple sub agents. You have decided you have obtained all the information you can.
Your task now is to look at the history of the different calls and generate a suitable response for the user.

Here is the message history: {message_history}

This is the user's name: {username}

This is your name: "Sapient Supervisor Bot"
"""

def stringify_chat_history(messages:list[Message])->str:
    """Render the history as newline-separated ``"Role: content"`` lines (role title-cased)."""
    return "\n".join(f"{message.role.title()}: {message.content}" for message in messages)

def supervisor_decision_node(state:SupervisorState):
    """
    Ask the LLM which sub-agent (if any) to call next and with what instruction.

    Stores the choice in ``state.agent`` ("" means stop) and the instruction in
    ``state.message``, then returns the same state object.
    """
    parser = PydanticOutputParser(pydantic_object=DecisionResponse)
    decision_prompt = PromptTemplate.from_template(
        SUPERVISOR_DECISION_PROMPT,
        partial_variables={"format_instructions": parser.get_format_instructions()},
    )
    history = stringify_chat_history(state.messages)
    decision: DecisionResponse = (decision_prompt | LLM | parser).invoke({"message_history": history})
    state.agent, state.message = decision.agent, decision.instruction
    return state

def supervisor_summary_node(state:SupervisorState):
    """Write the final user-facing answer from the history and append it as an assistant message."""
    summary_chain = PromptTemplate.from_template(SUPERVISOR_SUMMARISATION_PROMPT) | LLM
    reply = summary_chain.invoke({
        "message_history": stringify_chat_history(state.messages),
        "username": state.username,
    })
    state.messages.append(Message(content=reply.content, role="assistant"))
    return state


In [57]:
# 2b. Define Non LLM node functions (in this case calling subagents)

def call_financial_agent(state: SupervisorState):
    """Delegate ``state.message`` to the financial agent and append its answer to the history."""
    instruction, user = state.message, state.username
    print(f"Calling financial agent to answer '{instruction}'.")
    answer = invoke_financial_agent(instruction, user)
    state.messages.append(Message(content=answer, role="assistant"))
    print(f"Financial agent response was '{answer}'")
    return state


def call_flights_agent(state: SupervisorState):
    """Delegate ``state.message`` to the flight agent and append its answer to the history."""
    instruction = state.message
    print(f"Calling flight agent to answer '{instruction}'.")
    answer = invoke_flight_agent(instruction)
    state.messages.append(Message(content=answer, role="assistant"))
    print(f"Flight agent response was '{answer}'")
    return state

In [58]:
# 3. Define the decision functions
def decide_call_subagent(state:SupervisorState):
    """Route to the chosen sub-agent node ("financial"/"flight"), or to "summary" when none was chosen."""
    return state.agent or "summary"

In [59]:
#4 Compile the graph
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage
# 4. Compile the graph 
def compile_supervisor_agent():
    """
    Build and compile the supervisor graph.

    Flow: decision -> (flight | financial) -> decision ... -> summary -> END.
    ``decide_call_subagent`` routes out of the decision node.

    Returns:
        The compiled LangGraph graph.
    """
    # Instantiate graph with state class
    graph = StateGraph(SupervisorState)
    # Map node functions to node names
    graph.add_node("flight", call_flights_agent)
    graph.add_node("financial", call_financial_agent)
    graph.add_node("decision", supervisor_decision_node)
    graph.add_node("summary", supervisor_summary_node)
    # Define unconditional edges: every sub-agent reports back to the decision node.
    graph.add_edge("summary", END)
    graph.add_edge("flight", "decision")
    graph.add_edge("financial", "decision")
    # Conditional edge. The path_map lists the only nodes decide_call_subagent
    # can return. The rendered graph then shows the real edges instead of
    # assuming any node is reachable.
    graph.add_conditional_edges("decision", decide_call_subagent, ["flight", "financial", "summary"])
    # Add entrypoint
    graph.set_entry_point("decision")
    return graph.compile()

supervisor_agent = compile_supervisor_agent()

def call_supervisor_agent(question:str,username:str):
    """
    Run the supervisor on one user question and return its final answer.

    Args:
        question (str) - The user's request; it seeds the history as a "user" message.
        username (str) - Username forwarded to the sub-agents and the summary prompt.

    Returns:
        str - Content of the last history message, which the summary node appends.
    """
    opening = Message(content=question, role="user")
    final_state = supervisor_agent.invoke(SupervisorState(messages=[opening], username=username))
    return final_state.get("messages")[-1].content

In [60]:
# Example: a request that may need both sub-agents (a savings check, then a booking); the trace follows below.
print(call_supervisor_agent("Book a flight from new york to singapore if I have enough in savings.", "Marc"))

Calling financial agent to answer 'Check the user's savings balance to determine if there are sufficient funds to book a flight from New York to Singapore.'.
To check if the user has sufficient funds to book a flight from New York to Singapore, I need to retrieve the user's savings balance. The 'products' API will provide the details of the financial products the user owns, including the savings balance.
summarised_response='' api_calls=[APICall(api_name='products', category='')] api_responses=[] user_question="Check the user's savings balance to determine if there are sufficient funds to book a flight from New York to Singapore." username='Marc'
-----These were the API call results-----
Call to products API. Result: [{'type': 'savings', 'name': "Marc's Savings Account", 'balance': 1000}, {'type': 'current', 'name': "Marc's Current Account", 'balance': 350}, {'type': 'credit card', 'name': "Marc's Everyday Credit Card", 'balance': -100}]

-----------------------------------------
Finan