In [1]:
import os
from dotenv import load_dotenv
load_dotenv()

from typing import Annotated
from typing_extensions import TypedDict

from langgraph.graph import StateGraph,START,END
from langgraph.graph.message import add_messages

In [2]:
from langchain_google_genai import ChatGoogleGenerativeAI
from google.generativeai.types import HarmCategory, HarmBlockThreshold

import os

os.environ["GOOGLE_API_KEY"] = os.environ["GOOGLE_API_KEY4"]
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    temperature=0.5,
    max_tokens=None,
    timeout=None,
    max_retries=1,
    safety_settings={
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
    }
)


In [3]:
class State(TypedDict):
    # Messages have the type "list". The `add_messages` function
    # in the annotation defines how this state key should be updated
    # (in this case, it appends messages to the list, rather than overwriting them)
    messages:Annotated[list,add_messages]
    loop_count:int
    answer:str

In [4]:
# Wait 60 seconds before connecting using these details, or login to https://console.neo4j.io to validate the Aura Instance is available
NEO4J_URI="neo4j+s://2d5e8539.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD="xn8iCGEj2vymA3-43-57qlL63CD70SthzTE_Mt8QfG0"
NEO4J_DATABASE="neo4j"
AURA_INSTANCEID="2d5e8539"
AURA_INSTANCENAME="Instance01"


from langchain_neo4j import Neo4jGraph
enhanced_graph = Neo4jGraph(
    url=NEO4J_URI,
    username="neo4j",
    password=NEO4J_PASSWORD,
    driver_config={
        "max_connection_lifetime": 300,  # 5 minutes
        "keep_alive": True,
        "max_connection_pool_size": 50
    },
    enhanced_schema=False)

In [5]:
from langraph_neo4j3 import AgentState, run_agent_workflow

def alpha_tool(query):
    """
    This tool executes a factual query on a Neo4j graph database.
    """
    Agentstate: AgentState = {
            "question": query,
            "next_action": "",
            "cypher_errors": [],
            "database_records": [],
            "steps": [],
            "answer": "",
            "cypher_statement": ""
        }
    result = run_agent_workflow(Agentstate,enhanced_graph)
    return f"For query '{query}', result: {result['answer']}"

In [6]:
tools=[alpha_tool]
llm_with_tool=llm.bind_tools(tools)
llm_with_tool

RunnableBinding(bound=ChatGoogleGenerativeAI(model='models/gemini-2.5-flash', google_api_key=SecretStr('**********'), temperature=0.5, max_retries=1, safety_settings={<HarmCategory.HARM_CATEGORY_HARASSMENT: 7>: <HarmBlockThreshold.BLOCK_NONE: 4>, <HarmCategory.HARM_CATEGORY_HATE_SPEECH: 8>: <HarmBlockThreshold.BLOCK_NONE: 4>, <HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: 9>: <HarmBlockThreshold.BLOCK_NONE: 4>, <HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: 10>: <HarmBlockThreshold.BLOCK_NONE: 4>}, client=<google.ai.generativelanguage_v1beta.services.generative_service.client.GenerativeServiceClient object at 0x00000218A6975650>, default_metadata=()), kwargs={'tools': [{'type': 'function', 'function': {'name': 'alpha_tool', 'description': 'This tool executes a factual query on a Neo4j graph database.', 'parameters': {'properties': {'query': {}}, 'required': ['query'], 'type': 'object'}}}]}, config={}, config_factories=[])

In [7]:
# from langchain_core.prompts import ChatPromptTemplate
# from typing import Literal
# from pydantic import BaseModel, Field
# guard_prompt = ChatPromptTemplate.from_messages(
#         [
#             (
#                 "system",
#                 """

# You are a conversational agent with reasoning and data interpretation capabilities, working with a graph database.
# Every use input can be classified into 3 modes: Conversation, Query and Hypothesis.
#  1. If the question is unrelated to the graph or can be answered without querying graph, reply normally (no tool use). (Conversation)
#  2. If the question can be answered with a single direct query â†’ call query_tool ONCE with a plain English instruction. (Query)
#  3. If the question is indirect or requires reasoning, i.e. iterative hypothesis generation and querying to verify its' correctness. (Hypothesis)

# ### Schema
# ---
# {schema}
# ---

# Response must be in JSON format with 2 fields - 'Response' and 'Mode'.
# Simple response for Conversation mode, plain english instruction for Query mode and empty string in case of Hypothesis mode.
# """
#             ),
#             (
#                 "human",
#         """User Question:
#     {question}
#     """
#             ),
#         ]
#     )
# class guardOutput(BaseModel):
#     Mode: Literal["Conversation", "Query", "Hypothesis"] = Field(
#         description="Decision on whether the question is related to data"
#     )
#     Response: str = Field(
#         description="LLM response."
#     )
    
# guard_chain = guard_prompt | llm.with_structured_output(guardOutput)
# response = guard_chain.invoke({"schema":enhanced_graph.schema, "question":"Why are my customers declining?"})

In [8]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages([
    (
        "system",
        """
"You are performing root cause analysis. You may call alpha_tool multiple times to gather data.
Continue reasoning until you have sufficient evidence to explain the root cause.
After each tool response, reassess and decide whether another query is needed."

## Schema Reference:
{schema}
Understand the nodes and relations from the above schema and leverage it to make a hypothesis for the root cause analysis. 
"""
    ),
    (
        "human",
        """
User Question:
{question}
"""
    ),
])

chain = prompt | llm_with_tool

#    - Form a hypothesis for the root question.
#    - Iteratively call query_tool to test hypotheses.
#    - Refine results step by step.
#    - Then summarize findings.

# Always:  
# - Never generate Cypher.  
# - Use plain English instructions with `query_tool`.  
# - Use the schema above to understand what entities and relationships are available before deciding whether to query.

In [9]:
from langchain_core.prompts import ChatPromptTemplate
summary_prompt = ChatPromptTemplate.from_messages(
        [
            ("system","You job is to summarize the conversation and frame an answer for the question asked."),
            ("human","{conversation}")
        ]
    )
summary_chain = summary_prompt | llm

In [10]:
## Stategraph
from langgraph.graph import StateGraph,START,END
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from langchain_core.messages import AIMessage, ToolMessage

# def guard_condition(state: State,) -> Literal[]:
    
#     response = guard_chain.invoke({"schema":enhanced_graph.schema, "question":"Why are my customers declining?"})
#     if response.Mode == "Conversation":
#         return {"messages":AIMessage[response.Response]}
#     elif response.Mode == "Query":
#         return {"messages":ToolMessage[query_tool(response.Response)]}
#     elif response.Mode == "Hypothesis":
#         return {"messages":AIMessage["Hypothesis mode..."]}
    
## Node definition
def tool_calling_llm(state: State):
    # Increment loop
    state["loop_count"] += 1

    # Extract latest user question only
    user_question = state["messages"][-1].content if state["messages"] else ""

    # Invoke your LLM chain
    ai_msg = chain.invoke({
        "question": user_question,
        "schema": enhanced_graph.schema,
    })

    # Return as a list (so Annotated[add_messages] will append it)
    return {"messages": [ai_msg]}

def summarize_llm(state:State):
    state["answer"] = summary_chain.invoke({"conversation":state["messages"]})
    return state

## Graph
builder=StateGraph(State)
builder.add_node("tool_calling_llm",tool_calling_llm)
builder.add_node("tools",ToolNode(tools))
builder.add_node("summarize", summarize_llm)

## Add Edges
builder.add_edge(START, "tool_calling_llm")
def reasoning_condition(state: State):
    # If tool call found, go to tools
    last_msg = state["messages"][-1]
    if isinstance(last_msg, AIMessage) and getattr(last_msg, "tool_calls", None):
        return "tools"
    
    # If loop count < 4, keep reasoning
    if state.get("loop_count", 0) < 4:
        return "tool_calling_llm"
    
    # Else summarize and stop
    return "summarize"

builder.add_conditional_edges(
    "tool_calling_llm",
    reasoning_condition  # This function chooses either "tools", END, or "summarize"
)
builder.add_edge("tools","tool_calling_llm")
builder.add_edge("summarize", END)

## compile the graph
graph=builder.compile()

# from IPython.display import Image, display
# display(Image(graph.get_graph().draw_mermaid_png()))

In [11]:
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage

initial_state = {
    "messages": [HumanMessage(content="Why did my sale drop from 2005 to 2004?")],
    "loop_count": 0,
    "answer": ""
}

for event in graph.stream(initial_state):
    for value in event.values():
        last_msg = value["messages"][-1]

        if isinstance(last_msg, AIMessage):
            if getattr(last_msg, "tool_calls", None):  
                tool_call = last_msg.tool_calls[0]
                print(f"ðŸ¤– Tool Call - Query: {tool_call['args']['query']}")
            else:
                # Extract content safely
                if isinstance(last_msg.content, list):
                    content_text = " ".join(
                        part.get("text", "") for part in last_msg.content if isinstance(part, dict)
                    )
                else:
                    content_text = last_msg.content
                print(f"ðŸ¤– AI - {content_text}")

        elif isinstance(last_msg, ToolMessage):
            print(f"ðŸ§° Tool Response - {last_msg.content}")


ðŸ¤– Tool Call - Query: MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2004 RETURN sum(c.SALES) AS Sales2004
ðŸ§° Tool Response - For query 'MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2004 RETURN sum(c.SALES) AS Sales2004', result: The total sales for 2004 were 4,724,162.60.
ðŸ¤– Tool Call - Query: MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2003 RETURN sum(c.SALES) AS Sales2003
ðŸ§° Tool Response - For query 'MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2003 RETURN sum(c.SALES) AS Sales2003', result: The total sales for the year 2003 were 3,516,979.54.
ðŸ¤– Tool Call - Query: MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2002 RETURN sum(c.SALES) AS Sales2002
ðŸ§° Tool Response - For query 'MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2004 RETURN sum(c.SALES) AS Sales2004', result: The total sales for 2004 were 4,724,162.60.
ðŸ¤– Tool Call - Query: MATCH (o:Order)-[c:CONTAINS]->() WHERE o.YEAR_ID = 2003 RETURN sum(c.SALES) AS Sales2003
ðŸ§° Tool Resp

GraphRecursionError: Recursion limit of 25 reached without hitting a stop condition. You can increase the limit by setting the `recursion_limit` config key.
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/GRAPH_RECURSION_LIMIT