# Customer Support Bot
This bot uses LangGraph to build an agent that can complete the following tasks:
- Get total spend by category
- Get transactions by category
- Get products and their current balances

## Let's define the APIs
Rather than writing and calling real services to retrieve the information for these tasks, we will mock those APIs.

```python
import calendar
import datetime
import random
from typing import Literal

def random_date_past_month():
    """Generates a random date in the past month for transaction mocking."""
    # Get today's date
    today = datetime.date.today()

    # Calculate the date one calendar month ago, clamping the day to the length of the
    # previous month (e.g. March 31 -> February 28/29, May 31 -> April 30).
    if today.month == 1:
        year, month = today.year - 1, 12
    else:
        year, month = today.year, today.month - 1
    last_day = calendar.monthrange(year, month)[1]  # days in that month, leap years included
    one_month_ago = datetime.date(year, month, min(today.day, last_day))

    # Generate a random number of days between one month ago and today
    delta_days = (today - one_month_ago).days
    random_days = random.randint(0, delta_days)

    # Calculate the random date
    random_date = today - datetime.timedelta(days=random_days)

    return random_date.strftime("%Y-%m-%d")

def get_products(username: str):
    """
    This function mocks calling an API that gets the products a given user has.

    Args:
        username (str) - Username of the user to retrieve accounts for.
    """
    # The user has a savings account, a current account and a credit card.
    accounts = [
        {
            "type": "savings",
            "name": f"{username}'s Savings Account",
            "balance": 6_000
        },
        {
            "type": "current",
            "name": f"{username}'s Current Account",
            "balance": 350
        },
        {
            "type": "credit card",
            "name": f"{username}'s Everyday Credit Card",
            "balance": -100
        }
    ]
    return {"accounts": accounts, "username": username}

# Mocked total spend per category, drawn once when this cell runs.
category_spends = {
    "groceries": random.randint(1, 500),
    "holidays": random.randint(1, 4_000),
    "entertainment": random.randint(1, 250),
    "transport": random.randint(1, 750),
}

def get_total_category(username: str, category: Literal["groceries", "holidays", "entertainment", "transport"]):
    """
    Gets a user's total spend for a category.

    Args:
        username (str) - Username of the user to retrieve spend for.
        category (Literal["groceries", "holidays", "entertainment", "transport"]) - Category to retrieve spend for. Note there are 4 categories.
    """
    return {"username": username, "spend": category_spends[category]}

def get_transactions_per_category(username: str, category: Literal["groceries", "holidays", "entertainment", "transport"]):
    """
    Gets a list of all transactions within a category.

    The amounts and dates are randomly generated on every call, so they differ between
    calls, but the amounts always sum to the category's total spend.

    Args:
        username (str) - Username of the user to retrieve transactions for.
        category (Literal["groceries", "holidays", "entertainment", "transport"]) - Category to retrieve transactions for.
    """
    out = []
    remaining = category_spends[category]
    # Carve off random chunks of at most a third of what is left...
    while remaining > 10:
        this_val = random.randint(1, max(1, remaining // 3))
        this_date = random_date_past_month()
        out.append({"transaction_date": this_date, "amount": this_val})
        remaining -= this_val
    # ...then book the final remainder (at most 10) as the last transaction.
    out.append({"transaction_date": random_date_past_month(), "amount": remaining})
    return {"username": username, "category": category, "transactions": out}
```
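
Because the category total and the transaction list come from separate mock APIs, it is worth confirming they agree: the transactions for a category always add back up to that category's total, even though the split and the dates change on every call. A minimal standalone sketch of the same splitting loop, where the helper name `split_into_transactions` and the seeded `random.Random` are illustrative additions rather than part of the notebook:

```python
import random

def split_into_transactions(total: int, rng: random.Random) -> list[int]:
    """Hypothetical helper mirroring the loop in get_transactions_per_category.

    Repeatedly carves off a random chunk of at most a third of what remains,
    then books the final remainder (between 1 and 10) as the last transaction.
    """
    amounts = []
    remaining = total
    while remaining > 10:
        amount = rng.randint(1, max(1, remaining // 3))
        amounts.append(amount)
        remaining -= amount
    amounts.append(remaining)
    return amounts

rng = random.Random(0)  # seeded so the split is reproducible
for total in (1, 10, 11, 500, 4_000):
    amounts = split_into_transactions(total, rng)
    # The individual transactions always sum back to the category total,
    # and every transaction is positive.
    assert sum(amounts) == total
    assert all(a > 0 for a in amounts)
```

Since each chunk is at most a third of what remains, the remainder never reaches zero, which is why the final transaction is always positive.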

## Graph Overview

Our graph looks as follows:

![graph](./graphs/graph.png)

The workflow this reflects is as follows:
1. Identify which APIs need to be called to answer the user question.
2. Call those APIs to gather the information needed for the answer.
3. Generate a textual answer.

This involves two LLM calls: one to identify which APIs to call, and one to generate the final answer.
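
The three steps amount to a simple loop: one LLM call up front, as many deterministic API calls as it requested, then one LLM call at the end. A plain-Python sketch of that control flow, with both LLM calls replaced by stubs (`SketchState`, `select_apis`, `call_next_api`, `summarise` and `run` are hypothetical names, not part of the notebook), may help before the LangGraph version:

```python
from dataclasses import dataclass, field

@dataclass
class SketchState:
    """Hypothetical stand-in for GraphState, holding only what the routing needs."""
    user_question: str
    api_calls: list[str] = field(default_factory=list)
    api_responses: list[str] = field(default_factory=list)
    summarised_response: str = ""

def select_apis(state: SketchState) -> SketchState:
    # Stub for LLM call 1: a real version would ask the model which APIs to call.
    state.api_calls = ["products", "category total"]
    return state

def call_next_api(state: SketchState) -> SketchState:
    # Deterministic step: pop one pending call and record a placeholder result.
    name = state.api_calls.pop(0)
    state.api_responses.append(f"result of {name}")
    return state

def summarise(state: SketchState) -> SketchState:
    # Stub for LLM call 2: a real version would turn the API results into an answer.
    state.summarised_response = f"Answer built from {len(state.api_responses)} API result(s)"
    return state

def run(question: str) -> SketchState:
    state = select_apis(SketchState(user_question=question))
    # Keep calling APIs until none are pending, then summarise once.
    while state.api_calls:
        state = call_next_api(state)
    return summarise(state)

print(run("What did I spend on groceries?").summarised_response)
```

However many APIs the first call selects, there are still exactly two LLM calls; only the number of loop iterations changes.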

```python
# Define the prompts
API_SELECTION_PROMPT = """
You are a developer within a retail bank. You have a user question, and you need to decide which APIs should be called to answer the user question.

You have the following APIs you can call:
- An API which gets the products and the balance of the financial products a user owns. This API is called: products
- An API which gets the total spending in a given category over the past month. This API is called: category total
- An API which gets the transactions in a given category and the date they occurred. This API is called: category transactions

User Question: {user_question}

{format_instructions}
"""

SUMMARISATION_PROMPT = """
You are an individual working within a retail bank. Your job is to answer the user question to the best of your ability given the following API responses.

User Question: {user_question}

These are the results of the API calls: {api_responses}

The user's name is {username}

{format_instructions}
"""
```

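`PromptTemplate.from_template` treats these strings as f-string-style templates by default, so every `{name}` becomes an input variable and literal braces would need doubling (`{{` and `}}`). Since `format_instructions` is supplied up front as a partial variable, the remaining placeholders are exactly the keys each `chain.invoke(...)` call must pass. A stdlib sketch that lists them with `string.Formatter`; the `placeholders` helper is hypothetical and the two templates are abbreviated copies for illustration:

```python
from string import Formatter

def placeholders(template: str) -> set[str]:
    """Return the {named} fields in an f-string-style template."""
    # Formatter.parse yields (literal_text, field_name, format_spec, conversion);
    # field_name is None for plain text, including escaped {{ and }}.
    return {name for _, name, _, _ in Formatter().parse(template) if name}

# Abbreviated copies of the two prompts, keeping only their placeholders.
api_selection = "User Question: {user_question}\n\n{format_instructions}"
summarisation = (
    "User Question: {user_question}\n"
    "These are the results of the API calls: {api_responses}\n"
    "The user's name is {username}\n\n{format_instructions}"
)

# format_instructions is filled in as a partial variable, so these are the
# keys each chain.invoke(...) call must provide.
print(sorted(placeholders(api_selection) - {"format_instructions"}))
print(sorted(placeholders(summarisation) - {"format_instructions"}))
```

The second list matches the dictionary passed in `summarisation_node`, which supplies `user_question`, `api_responses` and `username`.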
```python
from pydantic import BaseModel, Field

# Define the response formats.
class APICall(BaseModel):
    api_name: Literal["products", "category total", "category transactions"] = Field(description="Name of the API to call.")
    category: Literal["groceries", "holidays", "entertainment", "transport", ""] = Field(default="", description="Category to use with the given API")

class APIIdentificationResponse(BaseModel):
    reasoning: str = Field(description="Reasoning for your selection")
    calls: list[APICall] = Field(description="List of API calls you would like to make.")

class SummarisationResponse(BaseModel):
    summarised_answer: str
```
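
`PydanticOutputParser` asks the model for JSON matching `APIIdentificationResponse` and rejects values outside the `Literal` types. A stdlib-only approximation of that validation, assuming the reply has already been extracted as a JSON string (the `check_calls` helper and the sample reply are hypothetical; in the notebook pydantic does this work). It also adds one check the models above do not express: the two category APIs need a non-empty category, otherwise `category_spends[category]` would raise `KeyError`.

```python
import json
from typing import Literal, get_args

# Same literal types as APICall above.
ApiName = Literal["products", "category total", "category transactions"]
Category = Literal["groceries", "holidays", "entertainment", "transport", ""]

def check_calls(raw: str) -> list[tuple[str, str]]:
    """Hypothetical validator for a JSON reply shaped like APIIdentificationResponse."""
    data = json.loads(raw)
    calls = []
    for call in data["calls"]:
        name = call["api_name"]
        category = call.get("category", "")  # mirrors default="" on APICall
        if name not in get_args(ApiName):
            raise ValueError(f"unknown API: {name!r}")
        if category not in get_args(Category):
            raise ValueError(f"unknown category: {category!r}")
        # Extra rule the pydantic models do not enforce: category APIs need a category.
        if name != "products" and not category:
            raise ValueError(f"{name!r} requires a category")
        calls.append((name, category))
    return calls

# A made-up model reply in the JSON shape the format instructions request.
reply = '{"reasoning": "Need the grocery transactions.", "calls": [{"api_name": "category transactions", "category": "groceries"}]}'
print(check_calls(reply))  # [('category transactions', 'groceries')]
```

In the pydantic version, the same extra rule could be expressed as a model validator on `APICall`.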

## Defining the Graph
To define our graph we need to complete the following tasks:

1. Define our state object.
2. Define our node functions.
3. Define our decision functions.
4. Compile the graph.

```python
from pydantic import BaseModel

# 1. Define State Object
class GraphState(BaseModel):
    summarised_response: str = ""
    api_calls: list[APICall] = []
    api_responses: list[dict] = []
    user_question: str = ""
    username: str = ""
```

```python
from dotenv import load_dotenv
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai.chat_models import ChatOpenAI

# Load API keys and endpoints from a local .env file
load_dotenv(override=True)

# 2a. Define LLM-based nodes

# Switch this to the LLM you would like to use.
# ChatOpenAI reads OPENAI_API_KEY from the environment; to pass a key directly instead, use:
# LLM = ChatOpenAI(model="gpt-4o", api_key="YOUR OPENAI KEY GOES HERE!")
# If using Azure OpenAI, ensure AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT are set
# (langchain_openai's AzureChatOpenAI reads these).
LLM = ChatOpenAI(model="gpt-4.1-nano", temperature=0)

def api_selection_node(state: GraphState) -> GraphState:
    # LLM call 1: decide which APIs to call, parsed into an APIIdentificationResponse
    output_parser = PydanticOutputParser(pydantic_object=APIIdentificationResponse)
    prompt_template = PromptTemplate.from_template(
        API_SELECTION_PROMPT,
        partial_variables={"format_instructions": output_parser.get_format_instructions()}
    )
    chain = prompt_template | LLM | output_parser
    response: APIIdentificationResponse = chain.invoke({"user_question": state.user_question})
    print(response.reasoning)
    state.api_calls = response.calls
    return state

def summarisation_node(state: GraphState) -> GraphState:
    # LLM call 2: turn the collected API responses into a final answer
    output_parser = PydanticOutputParser(pydantic_object=SummarisationResponse)
    prompt_template = PromptTemplate.from_template(
        SUMMARISATION_PROMPT,
        partial_variables={"format_instructions": output_parser.get_format_instructions()}
    )
    chain = prompt_template | LLM | output_parser
    # Render each API call and its result on its own line for the prompt
    api_calls_str = ""
    for api_call in state.api_responses:
        api_calls_str += f"{api_call.get('query')} Result: {api_call.get('response')}\n"
    print("-----These were the API call results-----")
    print(api_calls_str)
    print("-----------------------------------------")
    response: SummarisationResponse = chain.invoke({"user_question": state.user_question, "api_responses": api_calls_str, "username": state.username})
    state.summarised_response = response.summarised_answer
    return state
```

```python
# 2b. Define non-LLM-based nodes
def call_api(state: GraphState) -> GraphState:
    # Print the current state for debugging
    print(state)
    # Take the next pending call off the front of the queue
    this_call = state.api_calls.pop(0)
    if this_call.api_name == "products":
        response = get_products(state.username)
        state.api_responses.append({"query": "Call to products API.", "response": response.get("accounts")})
    elif this_call.api_name == "category total":
        response = get_total_category(state.username, this_call.category)
        state.api_responses.append({"query": f"Call to get total spend by category API with the category '{this_call.category}'.", "response": response.get("spend")})
    elif this_call.api_name == "category transactions":
        response = get_transactions_per_category(state.username, this_call.category)
        state.api_responses.append({"query": f"Call to get transactions by category API with the category '{this_call.category}'.", "response": response.get("transactions")})
    return state
```
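
The `if`/`elif` chain in `call_api` works, but a lookup table keeps each API's handler and its query description together, and makes an unrecognised `api_name` fail loudly rather than silently add no response. A sketch of that alternative, using hypothetical `fake_*` stand-ins for the mock APIs so it runs on its own (`HANDLERS` and `dispatch` are illustrative names, not part of the notebook):

```python
from typing import Callable

# Hypothetical stand-ins for get_products, get_total_category and
# get_transactions_per_category, returning only the field call_api keeps.
def fake_products(username: str, category: str) -> object:
    return [{"type": "savings", "balance": 6_000}]

def fake_total(username: str, category: str) -> object:
    return {"groceries": 120}.get(category, 0)

def fake_transactions(username: str, category: str) -> object:
    return [{"transaction_date": "2025-05-06", "amount": 21}]

# One table maps each api_name to its handler and query description.
HANDLERS: dict[str, tuple[Callable[[str, str], object], str]] = {
    "products": (fake_products, "Call to products API."),
    "category total": (
        fake_total,
        "Call to get total spend by category API with the category '{category}'.",
    ),
    "category transactions": (
        fake_transactions,
        "Call to get transactions by category API with the category '{category}'.",
    ),
}

def dispatch(api_name: str, username: str, category: str = "") -> dict:
    # An unrecognised api_name raises KeyError instead of adding nothing.
    handler, query = HANDLERS[api_name]
    return {"query": query.format(category=category), "response": handler(username, category)}

print(dispatch("category total", "Username", "groceries"))
```

To use this inside `call_api`, the real mock functions would need thin wrappers that return just the kept field (`accounts`, `spend` or `transactions`); the three branches then collapse to a single `dispatch(...)` call.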

```python
# 3. Define decision functions
def decide_make_call(state: GraphState):
    # If we need to make more API calls go to the API call node, else summarise.
    if state.api_calls:
        return "call_api_node"
    return "summarisation_node"
```

```python
from langgraph.graph import StateGraph, END

# 4. Compile the graph
def generate_graph():
    # Instantiate graph with state class
    graph = StateGraph(GraphState)
    # Map node functions to node names
    graph.add_node("select_apis_node", api_selection_node)
    graph.add_node("summarisation_node", summarisation_node)
    graph.add_node("call_api_node", call_api)
    # Define unconditional edges
    graph.add_edge("summarisation_node", END)
    # Define conditional edges: decide_make_call returns the name of the next node
    graph.add_conditional_edges("select_apis_node", decide_make_call)
    graph.add_conditional_edges("call_api_node", decide_make_call)
    # Add entrypoint
    graph.set_entry_point("select_apis_node")
    return graph.compile()

graph = generate_graph()
```

## Calling the graph

Finally, we can call our agent/graph!

```python
def call_graph(question: str, username: str):
    initial_state = GraphState(username=username, user_question=question)
    # invoke returns the final state as a dict of field values
    response = graph.invoke(initial_state)
    return response.get("summarised_response")
```
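
For a question like the one below, the LLM has to add up same-day transactions itself, which it may not do reliably. A deterministic check over a transactions list in the shape returned by `get_transactions_per_category` can confirm its answer; the `busiest_day` helper and the sample data are illustrative only, not part of the notebook or of any real run:

```python
from collections import defaultdict

def busiest_day(transactions: list[dict]) -> tuple[str, int]:
    """Return (date, total) for the date with the highest summed spend."""
    per_day: dict[str, int] = defaultdict(int)
    for t in transactions:
        # Several transactions can share a date, so sum per day before comparing.
        per_day[t["transaction_date"]] += t["amount"]
    return max(per_day.items(), key=lambda item: item[1])

# Made-up sample: two transactions on 2025-05-06 outweigh the single largest one.
sample = [
    {"transaction_date": "2025-05-06", "amount": 21},
    {"transaction_date": "2025-04-28", "amount": 62},
    {"transaction_date": "2025-05-06", "amount": 56},
]
print(busiest_day(sample))  # ('2025-05-06', 77)
```

Comparing single transactions instead of daily totals would pick 2025-04-28 here, which is exactly the kind of slip worth checking for.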

```python
print(call_graph("What was the day I spent the most on groceries this month?", "Username"))
```

To answer the user's question about the day they spent the most on groceries this month, I will need to call the 'category transactions' API. This API will provide the individual transaction details for the groceries category, including the dates and amounts spent, which will allow me to determine the day with the highest expenditure.
summarised_response='' api_calls=[APICall(api_name='category transactions', category='groceries')] api_responses=[] user_question='What was the day I spent the most on groceries this month?' username='Username'
-----These were the API call results-----
Call to get transactions by category API with the category 'groceries'. Result: [{'transaction_date': '2025-05-06', 'amount': 21}, {'transaction_date': '2025-05-15', 'amount': 4}, {'transaction_date': '2025-04-28', 'amount': 62}, {'transaction_date': '2025-05-11', 'amount': 56}, {'transaction_date': '2025-04-27', 'amount': 15}, {'transaction_date': '2025-05-04', 'amount': 33}, {'transaction_date': '2025-04-