# LangChain Integration

## 1. Dependencies

In [177]:
import pandas as pd
import matplotlib.pyplot as plt
import os
from dotenv import load_dotenv
from langchain_neo4j import Neo4jGraph, GraphCypherQAChain, Neo4jVector
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

#Advanced LangGraph
from operator import add
from typing import Annotated, List, Literal, Optional, Union
from typing_extensions import TypedDict
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_core.output_parsers import StrOutputParser
from langchain_neo4j.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from pydantic import BaseModel, Field
from neo4j.exceptions import CypherSyntaxError
from IPython.display import Image, display
from langgraph.graph import END, START, StateGraph

## 2. Connections with graph database

In [178]:
# Connection settings come from the environment (or a local .env file), so no
# credentials are hardcoded in the notebook. The URI and username fall back to
# the local defaults used previously; the password must be provided.
load_dotenv()
uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
username = os.getenv("NEO4J_USERNAME", "neo4j")
password = os.getenv("NEO4J_PASSWORD")
if not password:
    raise ValueError(
        "NEO4J_PASSWORD is not set. Add it to your environment or .env file before connecting to Neo4j."
    )
graph = Neo4jGraph(uri, username, password)


ValueError: Could not connect to Neo4j database. Please ensure that the username and password are correct

In [None]:
# Peek at two arbitrary nodes to confirm the connection and see node properties.
nodes = graph.query("MATCH (n) RETURN n LIMIT 2")
if nodes:
    print(*nodes, sep="\n")

{'n': {'symbol': 'AAPL', 'betweennessCentrality': 0.0, 'name': 'Apple Inc.', 'eigenvectorCentrality': 2.0812589256979136e-08, 'degreeCentrality': 3.0, 'sector': 'Information Technology'}}
{'n': {'symbol': 'ABBV', 'betweennessCentrality': 0.0, 'name': 'AbbVie', 'eigenvectorCentrality': 1.9054538849600147e-170, 'degreeCentrality': 0.0, 'sector': 'Health Care'}}


In [None]:
# Peek at two arbitrary (start, relationship, end) triples.
relationships = graph.query("MATCH (a)-[r]->(b) RETURN a, r, b LIMIT 2")
if relationships:
    print(*relationships, sep="\n")

{'a': {'symbol': 'AAPL', 'betweennessCentrality': 0.0, 'name': 'Apple Inc.', 'eigenvectorCentrality': 2.0812589256979136e-08, 'degreeCentrality': 3.0, 'sector': 'Information Technology'}, 'r': ({'symbol': 'AAPL', 'betweennessCentrality': 0.0, 'name': 'Apple Inc.', 'eigenvectorCentrality': 2.0812589256979136e-08, 'degreeCentrality': 3.0, 'sector': 'Information Technology'}, 'BELONGS_TO', {'name': 'Information Technology'}), 'b': {'name': 'Information Technology'}}
{'a': {'symbol': 'AAPL', 'betweennessCentrality': 0.0, 'name': 'Apple Inc.', 'eigenvectorCentrality': 2.0812589256979136e-08, 'degreeCentrality': 3.0, 'sector': 'Information Technology'}, 'r': ({'symbol': 'AAPL', 'betweennessCentrality': 0.0, 'name': 'Apple Inc.', 'eigenvectorCentrality': 2.0812589256979136e-08, 'degreeCentrality': 3.0, 'sector': 'Information Technology'}, 'CORRELATED', {'symbol': 'AMZN', 'betweennessCentrality': 2.0, 'name': 'Amazon', 'eigenvectorCentrality': 2.985062353989397e-08, 'degreeCentrality': 5.0, 's

## 3. Integration with GraphCypherQAChain

In [166]:
# Re-read the schema from the database, then print the plain-text summary.
graph.refresh_schema()
schema_text = graph.schema
print(schema_text)

Node properties:
Stock {symbol: STRING, name: STRING, sector: STRING, eigenvectorCentrality: FLOAT, betweennessCentrality: FLOAT, degreeCentrality: FLOAT, louvainCommunityId: INTEGER}
Sector {name: STRING}
Relationship properties:
CORRELATED {correlation: FLOAT}
The relationships:
(:Stock)-[:BELONGS_TO]->(:Sector)
(:Stock)-[:CORRELATED]->(:Stock)


Setting `enhanced_schema=True` enriches the schema with property details such as example values and min/max ranges. This additional context helps guide the LLM toward generating more accurate and effective queries.

In [None]:
# Second connection to the same database, with the enhanced (richer) schema
# description that is later fed to the LLM prompts.
enhanced_graph = Neo4jGraph(
    uri,
    username,
    password,
    enhanced_schema=True,
)
print(enhanced_graph.schema)

Node properties:
- **Stock**
  - `symbol`: STRING Example: "AAPL"
  - `name`: STRING Example: "Apple Inc."
  - `sector`: STRING Example: "Information Technology"
  - `eigenvectorCentrality`: FLOAT Min: 1.9054538849600147E-170, Max: 0.28373309268337865
  - `betweennessCentrality`: FLOAT Min: 0.0, Max: 111.3809523809524
  - `degreeCentrality`: FLOAT Min: 0.0, Max: 17.0
  - `louvainCommunityId`: INTEGER Min: 9, Max: 12
- **Sector**
  - `name`: STRING Example: "Information Technology"
Relationship properties:
- **CORRELATED**
  - `correlation`: FLOAT Min: 0.6009239619998398, Max: 0.8585241254499043
The relationships:
(:Stock)-[:BELONGS_TO]->(:Sector)
(:Stock)-[:CORRELATED]->(:Stock)


In [None]:
load_dotenv()
# Read the key explicitly and hand it to the model, so a missing key is reported
# here instead of surfacing later as an authentication error.
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OPENAI_API_KEY is not set. Add it to your environment or .env file.")

# temperature=0 keeps the generated Cypher as deterministic as possible.
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=openai_api_key)

# allow_dangerous_requests=True opts in to running LLM-generated Cypher directly
# against the database; ideally connect with a read-only user.
chain = GraphCypherQAChain.from_llm(
    graph=enhanced_graph, llm=llm, verbose=True, allow_dangerous_requests=True
)

response = chain.invoke({"query": "What are the top 10 stocks by eigenvector centrality?"})
response



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mcypher
MATCH (s:Stock)
RETURN s.symbol, s.name, s.eigenvectorCentrality
ORDER BY s.eigenvectorCentrality DESC
LIMIT 10
[0m
Full Context:
[32;1m[1;3m[{'s.symbol': 'JPM', 's.name': 'JPMorgan Chase', 's.eigenvectorCentrality': 0.28373309268337865}, {'s.symbol': 'MET', 's.name': 'MetLife', 's.eigenvectorCentrality': 0.28275347463734357}, {'s.symbol': 'BK', 's.name': 'BNY Mellon', 's.eigenvectorCentrality': 0.27668837004417635}, {'s.symbol': 'MS', 's.name': 'Morgan Stanley', 's.eigenvectorCentrality': 0.27668837004417635}, {'s.symbol': 'BAC', 's.name': 'Bank of America', 's.eigenvectorCentrality': 0.27668837004417635}, {'s.symbol': 'BRK-B', 's.name': 'Berkshire Hathaway (Class B)', 's.eigenvectorCentrality': 0.2764113801812669}, {'s.symbol': 'USB', 's.name': 'U.S. Bancorp', 's.eigenvectorCentrality': 0.2671514470823249}, {'s.symbol': 'WFC', 's.name': 'Wells Fargo', 's.eigenvectorCentrality': 0.26715144708

{'query': 'What are the top 10 stocks by eigenvector centrality?',
 'result': 'The top 10 stocks by eigenvector centrality are JPMorgan Chase, MetLife, BNY Mellon, Morgan Stanley, Bank of America, Berkshire Hathaway (Class B), U.S. Bancorp, Wells Fargo, Goldman Sachs, and American Express.'}

## 4. Advanced Implementation with LangGraph

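Here we use LangGraph to build a multi-step workflow: a guardrails check, Cypher generation, validation and correction of the query, query execution, and final answer generation.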

In [169]:
class InputState(TypedDict):
    """Input accepted by the compiled graph: only the user's question."""

    question: str


class OverallState(TypedDict):
    """State shared by every node of the workflow."""

    question: str
    # Routing label consumed by the conditional-edge functions
    # (e.g. "stock", "end", "correct_cypher", "execute_cypher").
    next_action: str
    cypher_statement: str
    # Errors collected by validate_cypher and read by correct_cypher.
    cypher_errors: List[str]
    # Query rows from execute_cypher, or a plain-text message (guardrail refusal /
    # "no results"); None when guardrails lets the question through.
    database_records: Union[List[dict], str, None]
    # The `add` reducer concatenates each node's ["step"] list instead of replacing it.
    steps: Annotated[List[str], add]


class OutputState(TypedDict):
    """Fields returned to the caller of the compiled graph."""

    answer: str
    steps: List[str]
    cypher_statement: str

Implement a guardrails step that checks whether the question relates to our topic, the stock market. Off-topic questions get a refusal message instead of being turned into a database query.

In [170]:
# System prompt for the topic guardrail. The model must answer with one of the
# two labels allowed by GuardrailsOutput.decision.
guardrails_system = """
As an intelligent assistant, your primary objective is to decide whether a given question is related to stocks or finance.
If the question is related to stocks, stock market data, or financial analysis, output "stock". Otherwise, output "end".
To make this decision, assess the content of the question and determine if it refers to stocks, stock market, financial terms, or stock-related data.
Provide only the specified output: "stock" or "end".
"""

guardrails_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", guardrails_system),
        ("human", "{question}"),
    ]
)


class GuardrailsOutput(BaseModel):
    """Structured output of the guardrail: the routing label for the question."""

    decision: Literal["stock", "end"] = Field(
        description="Decision on whether the question is related to stocks or finance"
    )


# with_structured_output parses the LLM reply into a GuardrailsOutput instance,
# so `decision` is constrained to "stock" or "end".
guardrails_chain = guardrails_prompt | llm.with_structured_output(GuardrailsOutput)

def guardrails(state:InputState) -> OverallState:
    """
    Classify the question as stock/finance related ("stock") or not ("end").

    The label is stored in ``next_action``. For off-topic questions a refusal
    message is stored in ``database_records`` so the final-answer node can relay
    it; otherwise ``database_records`` is None.
    """
    decision = guardrails_chain.invoke({"question": state.get("question")}).decision
    refusal = (
        "This question is not related to stocks or finance. Therefore, I cannot answer it."
        if decision == "end"
        else None
    )
    return {
        "next_action": decision,
        "database_records": refusal,
        "steps": ["guardrail"],
    }

**Few-shot prompting**

The idea is to guide the LLM during query generation by showing it a few example questions paired with their Cypher queries.

In [175]:
# Few-shot (question, Cypher) pairs. The most similar ones are injected into the
# text-to-Cypher prompt, so every query here must be valid against the schema:
# the model copies their patterns.
examples = [
    {
        "question": "How many stocks belong to the Information Technology sector?",
        "query": "MATCH (s:Stock)-[:BELONGS_TO]->(:Sector {name: 'Information Technology'}) RETURN count(DISTINCT s)",
    },
    {
        "question": "Which stocks have a correlation with AAPL?",
        "query": "MATCH (a:Stock {symbol: 'AAPL'})-[:CORRELATED]->(b:Stock) RETURN b.symbol",
    },
    {
        "question": "How many stocks belong to the Health Care sector?",
        "query": "MATCH (s:Stock)-[:BELONGS_TO]->(:Sector {name: 'Health Care'}) RETURN count(s)",
    },
    {
        "question": "List the sectors that AAPL belongs to.",
        "query": "MATCH (s:Stock {symbol: 'AAPL'})-[:BELONGS_TO]->(sector:Sector) RETURN sector.name",
    },
    {
        "question": "Which stocks are highly correlated with Amazon?",
        # The relationship must be bound to `r` for ORDER BY r.correlation to be valid.
        "query": "MATCH (a:Stock {symbol: 'AMZN'})-[r:CORRELATED]->(b:Stock) RETURN b.symbol, r.correlation ORDER BY r.correlation DESC LIMIT 10",
    },
    {
        "question": "Which stocks have an eigenvector centrality higher than 0.1?",
        "query": "MATCH (s:Stock) WHERE s.eigenvectorCentrality > 0.1 RETURN s.symbol, s.eigenvectorCentrality",
    },
    {
        "question": "Find the sector with the most stocks.",
        "query": "MATCH (s:Stock)-[:BELONGS_TO]->(sector:Sector) RETURN sector.name, COUNT(s) AS num_stocks ORDER BY num_stocks DESC LIMIT 1",
    },
    {
        "question": "Which stocks have a high degree centrality?",
        "query": "MATCH (s:Stock) WHERE s.degreeCentrality > 10 RETURN s.symbol, s.degreeCentrality ORDER BY s.degreeCentrality DESC LIMIT 10",
    },
]

In [None]:
# Embed the few-shot examples into a Neo4j vector index; for each incoming
# question, the k=2 most similar examples (matched on the "question" key) are
# retrieved. Neo4jVector opens its own connection, so the same credentials used
# for `graph` are passed explicitly rather than relying on its own lookup.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    OpenAIEmbeddings(),
    Neo4jVector,
    k=2,
    input_keys=["question"],
    url=uri,
    username=username,
    password=password,
)

ValueError: Could not connect to Neo4j database. Please ensure that the username and password are correct

In [None]:
# Prompt for translating a natural-language question into Cypher. The human
# message receives the graph schema, the selected few-shot examples and the
# user's question. Adjacent string literals are concatenated, so each sentence
# ends with a space.
text2cypher_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            (
                "Given an input question, convert it to a Cypher query. No pre-amble. "
                "Do not wrap the response in any backticks or anything else. Respond with a Cypher statement only!"
            ),
        ),
        (
            "human",
            (
                """
                You are a Neo4j expert. Given an input question, create a syntactically correct Cypher query to run.
                Do not wrap the response in any backticks or anything else. Respond with a Cypher statement only!
                Here is the schema information:
                {schema}
                Below are a number of examples of questions and their corresponding Cypher queries.

                {fewshot_examples}

                User input: {question}
                Cypher query:
                """
            ),
        ),
    ]
)

# prompt -> LLM -> plain string holding the Cypher statement
text2cypher_chain = text2cypher_prompt | llm | StrOutputParser()

def generate_cypher(state:OverallState) -> OverallState:
    """
    Generate a Cypher statement for the question in ``state``.

    The most similar few-shot examples are fetched from ``example_selector``,
    rendered as "Question: ...\\nCypher:..." pairs separated by a blank line,
    and passed to ``text2cypher_chain`` together with the enhanced schema.
    """
    question = state.get("question")
    selected_examples = example_selector.select_examples({"question": question})
    rendered_examples = [
        f"Question: {example['question']}\nCypher:{example['query']}"
        for example in selected_examples
    ]
    fewshot_examples = "\n\n".join(rendered_examples)

    prompt_inputs = {
        "question": question,
        "fewshot_examples": fewshot_examples,
        "schema": enhanced_graph.schema,
    }
    generated_cypher = text2cypher_chain.invoke(prompt_inputs)
    return {"cypher_statement": generated_cypher, "steps": ["generate_cypher"]}

In [None]:
# System message: role, schema overview and review checklist for the LLM validator.
# The "good error" examples use names that do NOT exist in the schema, so they
# don't teach the model to reject valid labels/properties.
validate_cypher_system = """
You are a Cypher expert reviewing a statement written by a developer working with a stock market graph.
The graph contains nodes labeled as `Stock` and `Sector`, and relationships like `BELONGS_TO` and `CORRELATED`.
You must ensure that the Cypher query conforms to the graph schema, with correct labels and properties.

The schema includes:
- Node labels: `Stock`, `Sector`
- Relationship types: `BELONGS_TO`, `CORRELATED`
- Node properties:
    - Stock: `name`, `eigenvectorCentrality`, `betweennessCentrality`, `degreeCentrality`, `symbol`, `sector`, `louvainCommunityId`
    - Sector: `name`
- Relationship properties:
    - CORRELATED: `correlation`

Check the following:
* Are there any syntax errors in the Cypher statement?
* Are there any missing node labels or relationship types from the schema?
* Are any properties missing or incorrectly referenced in the schema?
* Does the Cypher statement include enough information to answer the question?

Examples of good errors:
* Label (:Company) does not exist, did you mean (:Stock)?
* Property marketCap does not exist for label Stock, did you mean eigenvectorCentrality?
* Relationship CORRELATED_WITH does not exist, did you mean CORRELATED?
"""

# Human message: the concrete schema, question and statement to review.
# Template variables {schema}, {question}, {cypher} are supplied by validate_cypher.
validate_cypher_user = """
Schema:
{schema}

The question is:
{question}

The Cypher statement is:
{cypher}

Make sure you don't make any mistakes!
"""

validate_cypher_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", validate_cypher_system),
        ("human", validate_cypher_user),
    ]
)


class Property(BaseModel):
    """A property filter found in the Cypher statement (label, key, compared value)."""

    node_label: str = Field(
        description="The label of the node to which this property belongs."
    )
    property_key: str = Field(description="The key of the property being filtered.")
    property_value: str = Field(
        description="The value that the property is being matched against."
    )


class ValidateCypherOutput(BaseModel):
    """Structured result of the LLM review: detected errors and extracted filters."""

    errors: Optional[List[str]] = Field(
        description="A list of syntax or semantical errors in the Cypher statement. Always explain the discrepancy between schema and Cypher statement"
    )
    filters: Optional[List[Property]] = Field(
        description="A list of property-based filters applied in the Cypher statement."
    )


# prompt -> LLM constrained to return a ValidateCypherOutput instance
validate_cypher_chain = validate_cypher_prompt | llm.with_structured_output(
    ValidateCypherOutput
)

In [None]:
# CypherQueryCorrector needs the allowed (start label, relationship type, end label)
# triples; they come from the relationship section of the structured schema.
corrector_schema = [
    Schema(relationship["start"], relationship["type"], relationship["end"])
    for relationship in enhanced_graph.structured_schema.get("relationships")
]

cypher_query_corrector = CypherQueryCorrector(corrector_schema)

In [None]:
def validate_cypher(state:OverallState) -> OverallState:
    """
    Validate the generated Cypher statement and map its property filters to the database.

    Checks performed:
    1. ``EXPLAIN`` against Neo4j to surface syntax errors.
    2. ``cypher_query_corrector`` to check relationship directions against the
       schema (an empty result means the statement does not fit the schema).
    3. An LLM review (``validate_cypher_chain``) that reports errors and extracts
       the property filters used in the statement.
    Each filter on a STRING property of a known label is then looked up
    case-insensitively in the graph.

    Returns a partial state update:
    - ``next_action``: "end" if a filter value has no match in the database,
      "correct_cypher" if any error was found, otherwise "execute_cypher".
    - ``cypher_statement``: the direction-corrected statement, or the original
      one when the corrector rejected it (so correct_cypher has a query to fix).
    - ``cypher_errors``: every collected error message.
    """
    cypher_statement = state.get("cypher_statement")
    errors = []
    mapping_errors = []

    # 1. Syntax check without executing the query.
    try:
        enhanced_graph.query(f"EXPLAIN {cypher_statement}")
    except CypherSyntaxError as e:
        errors.append(e.message)

    # 2. Relationship-direction check.
    corrected_cypher = cypher_query_corrector(cypher_statement)
    if not corrected_cypher:
        errors.append("The generated Cypher statement doesn't fit the stock graph schema")
    elif corrected_cypher != cypher_statement:
        print("Relationship direction was corrected")

    # 3. LLM-based review.
    llm_output = validate_cypher_chain.invoke(
        {
            "question": state.get("question"),
            "schema": enhanced_graph.schema,
            "cypher": cypher_statement,
        }
    )
    if llm_output.errors:
        errors.extend(llm_output.errors)

    node_props = enhanced_graph.structured_schema.get("node_props", {})
    for prop_filter in llm_output.filters or []:
        # Only STRING properties of labels present in the schema are checked. This
        # skips labels the LLM invented (instead of raising KeyError) and keeps
        # unvalidated label/key text out of the query string below.
        label_props = node_props.get(prop_filter.node_label, [])
        if not any(
            prop["property"] == prop_filter.property_key and prop["type"] == "STRING"
            for prop in label_props
        ):
            continue

        mapping = enhanced_graph.query(
            f"MATCH (n:`{prop_filter.node_label}`) WHERE toLower(n.`{prop_filter.property_key}`) = toLower($value) RETURN 'yes' LIMIT 1",
            {"value": prop_filter.property_value},
        )
        if not mapping:
            message = (
                f"Missing value mapping for {prop_filter.node_label} on property "
                f"{prop_filter.property_key} with value {prop_filter.property_value}"
            )
            print(message)
            mapping_errors.append(message)

    if mapping_errors:
        next_action = "end"
    elif errors:
        next_action = "correct_cypher"
    else:
        next_action = "execute_cypher"

    return {
        "next_action": next_action,
        # Fall back to the original statement if the corrector rejected it.
        "cypher_statement": corrected_cypher or cypher_statement,
        # Key must match OverallState / correct_cypher ("cypher_errors").
        "cypher_errors": errors,
        "steps": ["validate_cypher"],
    }

In [147]:
# Prompt for repairing a Cypher statement given the validation errors. Adjacent
# string literals are concatenated, so each sentence ends with a space.
correct_cypher_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            (
                "You are a Cypher expert reviewing a statement written by a junior developer. "
                "You need to correct the Cypher statement based on the provided errors. No pre-amble. "
                "Do not wrap the response in any backticks or anything else. Respond with a Cypher statement only!"
            ),
        ),
        (
            "human",
            (
                """
                Check for invalid syntax or semantics and return a corrected Cypher statement.
                Schema:
                {schema}
                Note: Do not include any explanations or apologies in your responses.
                Do not wrap the response in any backticks or anything else.
                Respond with a Cypher statement only!

                Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.
                The question is:
                {question}
                The cypher statement is:
                {cypher}
                The errors are:
                {errors}

                Corrected Cypher statement:
                """
            ),
        ),
    ]
)

# prompt -> LLM -> plain string holding the corrected Cypher statement
correct_cypher_chain = correct_cypher_prompt | llm | StrOutputParser()

def correct_cypher(state:OverallState) -> OverallState:
    """
    Ask the LLM to rewrite the Cypher statement using the collected validation errors.

    Reads ``question``, ``cypher_statement`` and ``cypher_errors`` from the state.
    Returns the rewritten statement with ``next_action`` set to "validate_cypher",
    so that the corrected query is validated again.
    """
    corrected_cypher = correct_cypher_chain.invoke(
        {
            "question": state.get("question"),
            "errors": state.get("cypher_errors"),
            "cypher": state.get("cypher_statement"),
            "schema": enhanced_graph.schema,
        }
    )
    return {
        "next_action": "validate_cypher",
        "cypher_statement": corrected_cypher,
        "steps": ["correct_cypher"],
    }

In [None]:
# Stored as database_records when a query returns no rows.
no_results = "I couldn't find any relevant information in the database"


def execute_cypher(state:OverallState) -> OverallState:
    """
    Execute the validated Cypher statement against the graph.

    Stores the returned records in ``database_records`` (or ``no_results`` when
    the query returns nothing) and sets ``next_action`` to "end".
    """
    records = enhanced_graph.query(state.get("cypher_statement"))

    return {
        "database_records": records if records else no_results,
        "next_action": "end",
        "steps": ["execute_cypher"],
    }

In [None]:
# Turns the database records (or a fallback/refusal message) plus the question
# into a short, direct natural-language answer.
final_answer_system_message = "You are a helpful assistant for a stock market database."
final_answer_human_message = """
                Use the following results retrieved from a stock market database to provide a succinct,
                definitive answer to the user's question.
                
                Respond as if you are answering the question directly.
                Results: {results}
                Question:{question}
                """

generate_final_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", final_answer_system_message),
        ("human", final_answer_human_message),
    ]
)

generate_final_chain = generate_final_prompt | llm | StrOutputParser()

def generate_final_answer(state:OverallState) -> OutputState:
    """
    Generate the final answer from the database results and the user's question.

    ``database_records`` may hold query rows or a plain-text message (guardrail
    refusal, "no results", etc.); either is passed to the LLM as the results.
    """
    final_answer = generate_final_chain.invoke(
        {"question": state.get("question"), "results": state.get("database_records")}
    )
    return {"answer": final_answer, "steps": ["generate_final_answer"]}

In [None]:
def guardrails_condition(
    state: OverallState,
) -> Literal["generate_cypher","generate_final_answer"]:
    """
    Route after the guardrails node.

    guardrails sets ``next_action`` to "stock" for on-topic questions and "end"
    otherwise. Only "stock" proceeds to Cypher generation; any other value goes
    straight to the final answer (fail closed).
    """
    if state.get("next_action") == "stock":
        return "generate_cypher"
    return "generate_final_answer"


def validate_cypher_condition(
    state:OverallState,
) -> Literal["generate_final_answer","correct_cypher","execute_cypher"]:
    """
    Route after validate_cypher based on its ``next_action``.

    "end" (unmapped filter values) -> generate_final_answer;
    "correct_cypher" / "execute_cypher" -> the node of the same name.

    Raises:
        ValueError: for any other value, since there is no matching node.
    """
    next_action = state.get("next_action")
    if next_action == "end":
        return "generate_final_answer"
    if next_action in ("correct_cypher", "execute_cypher"):
        return next_action
    raise ValueError(f"Unexpected next_action after validate_cypher: {next_action!r}")

In [None]:
# Assemble the workflow. `workflow` is the mutable builder; the compiled,
# runnable graph is stored in `langgraph` (used by the cells below).
workflow = StateGraph(OverallState, input=InputState, output=OutputState)

# Nodes are named after their functions.
workflow.add_node(guardrails)
workflow.add_node(generate_cypher)
workflow.add_node(validate_cypher)
workflow.add_node(correct_cypher)
workflow.add_node(execute_cypher)
workflow.add_node(generate_final_answer)

workflow.add_edge(START, "guardrails")
workflow.add_conditional_edges(
    "guardrails",
    guardrails_condition,
)

workflow.add_edge("generate_cypher", "validate_cypher")
workflow.add_conditional_edges(
    "validate_cypher",
    validate_cypher_condition,
)

workflow.add_edge("execute_cypher", "generate_final_answer")
# Corrected statements are always re-validated.
workflow.add_edge("correct_cypher", "validate_cypher")
workflow.add_edge("generate_final_answer", END)

langgraph = workflow.compile()

# draw_mermaid_png() must be called: Image needs the PNG bytes, not the method.
display(Image(langgraph.get_graph().draw_mermaid_png()))

In [None]:
# NOTE(review): the sector node is named "Information Technology" (with a space);
# "InformationTechnology" may trip the value-mapping check in validate_cypher. Confirm this is intended.
question_input = {"question": "How many stocks belong to the InformationTechnology sector?"}
langgraph.invoke(question_input)