In [None]:
%%capture --no-stderr
%pip install langgraph langchain-google-genai langchain-community langchain-core tavily-python

In [None]:
import os
from google.colab import userdata
from langchain_google_genai import ChatGoogleGenerativeAI

# Read credentials from Colab's secret store (google.colab.userdata).
# The Gemini key is kept in a Python variable because it is passed explicitly
# as `api_key=` when the chat model is constructed; the others are exported
# as environment variables for the libraries that read them.
GEMINI_API_KEY = userdata.get('GEMINI_API_KEY')
for _secret_name in ("TAVILY_API_KEY", "LANGCHAIN_API_KEY"):
    os.environ[_secret_name] = userdata.get(_secret_name)
del _secret_name

# Enable LangChain tracing (LANGCHAIN_TRACING_V2) and set the project name
# that traced runs are grouped under.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "learn_agentic_ai"

In [None]:
from typing import Annotated, List
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from langchain.output_parsers import StructuredOutputParser
from langchain.output_parsers.structured import ResponseSchema
from typing_extensions import TypedDict
from langchain.prompts import ChatPromptTemplate
from IPython.display import Image, display
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, AnyMessage
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command

# In-memory checkpointer; the graph is compiled with it further down, so state
# is saved per `thread_id` supplied in the invoke config.
memory = MemorySaver()

# Gemini chat model shared by every LLM-backed node; temperature=0 keeps
# classification answers as stable as the model allows.
llm: ChatGoogleGenerativeAI = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0,
    api_key=GEMINI_API_KEY,
)

class State(TypedDict):
    """Shared state passed between the nodes of the customer-support graph."""

    # Conversation messages. `add_messages` is the reducer, so updates (including
    # the input of a new invoke on an existing checkpointer thread) are combined
    # with the stored list instead of replacing it.
    messages: Annotated[List[AnyMessage], add_messages]
    # Category label written by `categorize` and read by `route_query`.
    category: str
    # Sentiment label written by `analyze_sentiment` and read by `route_query`.
    sentiment: str
    # Final reply text written by whichever handler node the router selects.
    response: str

# Output schema for category classification: a JSON object with a single
# "category" field naming one of the supported categories.
response_schemas = [
    ResponseSchema(
        name="category",
        description=(
            "The category of the query, one of: "
            "Order Issues, Delivery Status, Product Information, "
            "Feedback / Complaints, Handle General"
        ),
    ),
]

# Parser for that schema, plus the prompt snippet describing the expected
# reply format (to be embedded in a prompt as `format_instructions`).
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()


def categorize(state: State) -> State:
    """
    Categorize the latest customer query into one of the following categories:
    - Order Issues
    - Delivery Status
    - Product Information
    - Feedback / Complaints
    - Handle General

    The prompt embeds the module-level ``format_instructions`` so the model
    answers in the ``StructuredOutputParser`` JSON format. The reply is then
    parsed back into a plain category string that ``route_query`` can compare.

    Args:
        state: Graph state; only the last entry of ``messages`` is read.

    Returns:
        A partial state update ``{"category": <str>}``.
    """
    from langchain_core.exceptions import OutputParserException

    print("categorize node")
    # `{format_instructions}` must be a real template variable. The previous
    # `{{format_instructions}}` escaped the braces and sent the literal text
    # "{format_instructions}" to the model.
    prompt = ChatPromptTemplate.from_template(
        "Categorize the following customer query into one of these categories: "
        "Order Issues, Delivery Status, Product Information, Feedback / Complaints, "
        "Handle General.\nQuery: {query}\n{format_instructions}"
    )
    chain = prompt | llm
    # Classify only the newest message. On a reused checkpointer thread the
    # list also holds earlier queries, and its repr is not the customer's text.
    query = state["messages"][-1].content
    raw = chain.invoke(
        {"query": query, "format_instructions": format_instructions}
    ).content
    try:
        category = output_parser.parse(raw)["category"]
    except OutputParserException:
        # The model ignored the JSON format; fall back to its raw answer.
        category = raw
    return {"category": str(category).strip()}

def analyze_sentiment(state: State) -> State:
    """
    Classify the sentiment of the latest customer query as Positive, Neutral, or Negative.

    The model's reply is normalized so that answers like ``"negative."`` or
    ``"Negative\\n"`` become the canonical label. ``route_query`` escalates on
    ``"Negative"``, so an unnormalized reply could skip escalation.

    Args:
        state: Graph state; only the last entry of ``messages`` is read.

    Returns:
        A partial state update ``{"sentiment": <str>}``. The value is one of the
        three labels when the reply matches one, otherwise the cleaned raw reply.
    """
    print("analyze_sentiment node")
    prompt = ChatPromptTemplate.from_template(
        "Analyze the sentiment of the following customer query. "
        "Respond with either 'Positive', 'Neutral', or 'Negative'. Query: {query}"
    )
    chain = prompt | llm
    # Send the customer's latest text, not the repr of the whole message list.
    query = state["messages"][-1].content
    raw = str(chain.invoke({"query": query}).content)
    # Drop surrounding whitespace, punctuation and quotes, then match case-insensitively.
    cleaned = raw.strip().strip(".!'\"`").strip()
    sentiment = next(
        (
            label
            for label in ("Positive", "Neutral", "Negative")
            if label.lower() == cleaned.lower()
        ),
        cleaned,
    )
    return {"sentiment": sentiment}

def order_issues(state: State) -> State:
    """
    Draft a reply for a customer query categorized as 'Order Issues'.

    Args:
        state: Graph state; only the last entry of ``messages`` is used as the query.

    Returns:
        A partial state update ``{"response": <model reply text>}``.
    """
    print("order_issues node")
    prompt = ChatPromptTemplate.from_template(
        "Provide a response to the following order issue query: {query}"
    )
    chain = prompt | llm
    # Pass the customer's latest text rather than the repr of the message list.
    query = state["messages"][-1].content
    response = chain.invoke({"query": query}).content
    return {"response": response}

def delivery_status(state: State) -> State:
    """
    Draft a reply for a customer query categorized as 'Delivery Status'.

    Args:
        state: Graph state; only the last entry of ``messages`` is used as the query.

    Returns:
        A partial state update ``{"response": <model reply text>}``.
    """
    print("delivery_status node")
    prompt = ChatPromptTemplate.from_template(
        "Provide a response to the following delivery status query: {query}"
    )
    chain = prompt | llm
    # Pass the customer's latest text rather than the repr of the message list.
    query = state["messages"][-1].content
    response = chain.invoke({"query": query}).content
    return {"response": response}

def feedback_complaints(state: State) -> State:
    """
    Draft a reply for a customer query categorized as 'Feedback / Complaints'.

    Args:
        state: Graph state; only the last entry of ``messages`` is used as the query.

    Returns:
        A partial state update ``{"response": <model reply text>}``.
    """
    print("feedback_complaints node")
    prompt = ChatPromptTemplate.from_template(
        "Provide a response to the following feedback or complaints query: {query}"
    )
    chain = prompt | llm
    # Pass the customer's latest text rather than the repr of the message list.
    query = state["messages"][-1].content
    response = chain.invoke({"query": query}).content
    return {"response": response}

def handle_general(state: State) -> State:
    """
    Draft a general support reply; ``route_query`` uses this as the fallback route.

    Args:
        state: Graph state; only the last entry of ``messages`` is used as the query.

    Returns:
        A partial state update ``{"response": <model reply text>}``.
    """
    print("handle_general node")
    prompt = ChatPromptTemplate.from_template(
        "Provide a general support response to the following query: {query}"
    )
    chain = prompt | llm
    # Pass the customer's latest text rather than the repr of the message list.
    query = state["messages"][-1].content
    response = chain.invoke({"query": query}).content
    return {"response": response}

def escalate(state: State) -> State:
    """
    Hand the query off to a human agent instead of generating an LLM reply.

    ``route_query`` sends queries here when their sentiment is negative.
    The incoming state is not inspected.

    Returns:
        A partial state update carrying a fixed escalation notice.
    """
    print("escalate node")
    notice = "This query has been escalated to a human agent due to its negative sentiment."
    return dict(response=notice)

def route_query(state: "State") -> str:
    """
    Route the query based on its sentiment and category.

    Negative sentiment always wins and escalates. Otherwise the category picks
    the handler, with ``handle_general`` as the fallback (including
    "Product Information", which has no dedicated node). The labels come from
    free-form LLM output, so they are compared after normalization: case is
    ignored, and whitespace and punctuation are removed. Values such as
    "negative.", "Order Issues\\n" or "Feedback/Complaints" therefore still
    route correctly.

    Args:
        state (State): The current state containing user messages, category, and sentiment.

    Returns:
        str: The next node name based on the routing logic.
    """

    def _normalize(value) -> str:
        # Keep only letters/digits and casefold: "Feedback / Complaints" -> "feedbackcomplaints".
        return "".join(ch for ch in str(value or "") if ch.isalnum()).casefold()

    if _normalize(state.get("sentiment")) == "negative":
        return "escalate"

    category_routes = {
        "orderissues": "order_issues",
        "deliverystatus": "delivery_status",
        "feedbackcomplaints": "feedback_complaints",
    }
    return category_routes.get(_normalize(state.get("category")), "handle_general")

# Define the LangGraph workflow
# Workflow: START -> categorize -> analyze_sentiment -> (one handler chosen by
# route_query) -> END.
graph_builder = StateGraph(State)

# Register each node under the same name as its function, in a fixed order.
for _name, _fn in (
    ("categorize", categorize),
    ("analyze_sentiment", analyze_sentiment),
    ("order_issues", order_issues),
    ("delivery_status", delivery_status),
    ("feedback_complaints", feedback_complaints),
    ("handle_general", handle_general),
    ("escalate", escalate),
):
    graph_builder.add_node(_name, _fn)

# Every query is categorized first, then scored for sentiment.
graph_builder.add_edge(START, "categorize")
graph_builder.add_edge("categorize", "analyze_sentiment")

# route_query returns a node name; the path map is the identity on those names.
_routed_nodes = (
    "escalate",
    "order_issues",
    "delivery_status",
    "feedback_complaints",
    "handle_general",
)
graph_builder.add_conditional_edges(
    "analyze_sentiment",
    route_query,
    {node: node for node in _routed_nodes},
)

# Each handler terminates the run.
for _name in (
    "order_issues",
    "delivery_status",
    "feedback_complaints",
    "handle_general",
    "escalate",
):
    graph_builder.add_edge(_name, END)
del _name, _fn, _routed_nodes

# Compile with the in-memory checkpointer and render the graph diagram.
graph = graph_builder.compile(checkpointer=memory)
display(Image(graph.get_graph().draw_mermaid_png()))

def test_langgraph_agent(query: str):
    """
    Run one customer query through the compiled graph and print the outcome.

    Each call uses a fresh checkpointer thread. The graph is compiled with a
    checkpointer and ``messages`` uses the ``add_messages`` reducer. Reusing a
    fixed ``thread_id`` would therefore append this query to every query sent
    earlier on that thread, and the nodes would see them all.

    Args:
        query (str): The customer query to test the workflow.

    Returns:
        dict: Final state including messages, category, sentiment, and response.
    """
    import uuid

    # Initialize the starting state.
    initial_state = {
        "messages": [HumanMessage(content=query)],
        "category": "",
        "sentiment": "",
        "response": "",
    }

    # A unique thread id isolates this run's checkpointed state from other calls.
    config = {"configurable": {"thread_id": str(uuid.uuid4())}}

    # Run the query through the compiled LangGraph.
    output = graph.invoke(initial_state, config)

    # Display the results.
    print("Query:", query)
    print("Category:", output["category"])
    print("Sentiment:", output["sentiment"])
    print("Response:", output["response"])
    return output

# Example queries to test: sample customer messages covering different
# categories (refund, delivery, product complaint, general question).
test_queries = [
    "I need help with a refund for my order.",
    "When will my package arrive?",
    "I have a complaint about the quality of the product.",
    "What are your business hours?"
]

# Run each sample through the graph; test_langgraph_agent prints the query,
# category, sentiment and response for each one.
for query in test_queries:
    print("\n--- Testing Query ---")
    test_langgraph_agent(query)
