# Introduction to Graph-based RAG using LangChain and Neo4j
Based on "https://medium.com/neo4j/enhancing-the-accuracy-of-rag-applications-with-knowledge-graphs-ad5e2ffab663" by [Tomaz Bratanic](https://medium.com/@bratanic-tomaz), with permission. 

Original notebook here: https://github.com/tomasonjo/blogs/blob/master/llm/enhancing_rag_with_graph.ipynb

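Prereqs:

* A Neo4j database (Neo4j 5.11 or later is needed for vector indexes): either [Neo4j Desktop](https://neo4j.com/download/) (local instance) or [Aura](https://neo4j.com/product/auradb/) (cloud instance)
* An OpenAI API key, used by `ChatOpenAI` and `OpenAIEmbeddings`
* The Python packages imported below, plus `wikipedia` (used by `WikipediaLoader`) and `tiktoken` (used by `TokenTextSplitter`)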

## Load libraries

```python
import warnings
# Set the filter before the imports so import-time deprecation warnings are hidden too
warnings.filterwarnings('ignore')

import os
from importlib import reload
from typing import Tuple, List, Optional

# langchain_core.pydantic_v1 is deprecated as of langchain-core 0.3; import pydantic directly
from pydantic import BaseModel, Field
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    ConfigurableField,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_community.graphs import Neo4jGraph
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.vectorstores import Neo4jVector
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain_text_splitters import TokenTextSplitter
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_experimental.graph_transformers import LLMGraphTransformer
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
import utils  # local helper module (not shown); provides the `neo4j_driver` used by showGraph

#reload(utils)
```




### Connect to the Neo4j database 
Connect to the graph database using the [LangChain Neo4jGraph connector](https://api.python.langchain.com/en/latest/graphs/langchain_community.graphs.neo4j_graph.Neo4jGraph.html).

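```python
# Create a Neo4jGraph object to interact with the graph database.
# With no arguments, the connection details are read from the
# NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD environment variables.
graph = Neo4jGraph()
```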
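
Because `Neo4jGraph()` is called without arguments, it takes its connection settings from the `NEO4J_URI`, `NEO4J_USERNAME` and `NEO4J_PASSWORD` environment variables, and `ChatOpenAI` and `OpenAIEmbeddings` likewise read `OPENAI_API_KEY`. A quick pre-flight check such as the following sketch can catch a missing variable before a later cell fails. The helper name `missing_env_vars` and the placeholder values are our own, for illustration only.

```python
import os
from typing import List, Mapping, Optional

# Variables read implicitly by Neo4jGraph() (connection) and by the
# langchain_openai classes (API key) when no arguments are passed
REQUIRED_ENV_VARS = ["NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD", "OPENAI_API_KEY"]


def missing_env_vars(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required variables that are unset or empty in `env` (defaults to os.environ)."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_ENV_VARS if not env.get(name)]


# Example: an Aura-style configuration that is missing the OpenAI key (placeholder values)
example_env = {
    "NEO4J_URI": "neo4j+s://<your-instance>.databases.neo4j.io",
    "NEO4J_USERNAME": "neo4j",
    "NEO4J_PASSWORD": "<password>",
}
print(missing_env_vars(example_env))  # ['OPENAI_API_KEY']
```

In the notebook itself, calling `missing_env_vars()` with no argument checks the real environment.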

### Data Ingestion
Load the Wikipedia search results for "Elizabeth I" using the [LangChain Wikipedia loader](https://api.python.langchain.com/en/latest/document_loaders/langchain_community.document_loaders.wikipedia.WikipediaLoader.html), then split the first three documents into token-based chunks.

```python
# Read the Wikipedia search results for "Elizabeth I" using WikipediaLoader
raw_documents = WikipediaLoader(query="Elizabeth I").load()
# Define the chunking strategy: chunks of 512 tokens, with 24 tokens of overlap
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=24)
# Split the first three documents into chunks based on the chunking strategy
documents = text_splitter.split_documents(raw_documents[:3])
```
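
To make `chunk_size` and `chunk_overlap` concrete, here is a minimal sketch of fixed-size windowing with overlap. It is our own illustration, not LangChain's implementation: whitespace-separated words stand in for the tiktoken tokens that `TokenTextSplitter` actually counts, and the function name `sliding_token_windows` is hypothetical.

```python
from typing import List


def sliding_token_windows(tokens: List[str], chunk_size: int, chunk_overlap: int) -> List[List[str]]:
    """Split a token list into windows of at most `chunk_size` tokens,
    where consecutive windows share `chunk_overlap` tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window starts after the previous one
    windows = []
    start = 0
    while start < len(tokens):
        windows.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
        start += step
    return windows


tokens = "a b c d e f g h i j".split()  # 10 stand-in "tokens"
for window in sliding_token_windows(tokens, chunk_size=4, chunk_overlap=1):
    print(window)
```

With `chunk_size=512` and `chunk_overlap=24`, consecutive windows start 488 tokens apart, so any passage of up to 24 tokens that straddles the end of one chunk appears whole at the start of the next.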

### Construct a graph based on the retrieved documents
The [LLMGraphTransformer](https://python.langchain.com/v0.2/api_reference/experimental/graph_transformers/langchain_experimental.graph_transformers.llm.LLMGraphTransformer.html) uses an LLM to extract entities and relationships from each chunk and returns graph documents, which can be imported into the Neo4j graph database with the [add_graph_documents](https://api.python.langchain.com/en/latest/graphs/langchain_community.graphs.neo4j_graph.Neo4jGraph.html#langchain_community.graphs.neo4j_graph.Neo4jGraph.add_graph_documents) method. The `baseEntityLabel` parameter adds an extra `__Entity__` label to every node, so that a single index can cover all entity types and queries run faster. The `include_source` parameter links each node to the document it was extracted from, which keeps the data traceable and preserves the original text as context. First, define the LLM that the knowledge graph generation chain will use. At the time of writing, only OpenAI and Mistral function-calling models were supported.

```python
# Create the LLM using gpt-4o
llm = ChatOpenAI(temperature=0, model_name="gpt-4o")  # gpt-4-0125-preview occasionally has issues

# Create the LLMGraphTransformer
llm_transformer = LLMGraphTransformer(llm=llm)

# Convert the documents to graph documents (one LLM call per chunk)
graph_documents = llm_transformer.convert_to_graph_documents(documents)

# Add the documents to the graph
graph.add_graph_documents(
    graph_documents,
    baseEntityLabel=True,
    include_source=True
)
```
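
`convert_to_graph_documents` returns graph documents made of nodes (an `id` plus a `type`) and relationships (a source node, a target node and a `type`); the real classes live in `langchain_community.graphs.graph_document`. The following sketch uses simplified stand-in dataclasses to show roughly what the two flags passed to `add_graph_documents` add: `baseEntityLabel=True` gives every entity a shared `__Entity__` label, and `include_source=True` adds a `Document` node for the chunk with a `MENTIONS` relationship to each extracted entity. `SimpleNode`, `SimpleRelationship`, `SimpleGraphDocument`, `plan_import` and the `chunk-0` id are hypothetical names; the sketch models the resulting graph only, not the Cypher LangChain actually runs.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass
class SimpleNode:
    id: str    # entity name, e.g. "Elizabeth I"
    type: str  # entity type, e.g. "Person"; stored as the node label


@dataclass
class SimpleRelationship:
    source: SimpleNode
    target: SimpleNode
    type: str  # relationship type, e.g. "CHILD_OF"


@dataclass
class SimpleGraphDocument:
    nodes: List[SimpleNode]
    relationships: List[SimpleRelationship]
    source_text: str  # the chunk the entities were extracted from


def plan_import(
    doc: SimpleGraphDocument, base_entity_label: bool, include_source: bool
) -> Tuple[Dict[str, Set[str]], List[Tuple[str, str, str]]]:
    """Return node labels (keyed by node id) and (source, TYPE, target) triples
    for a simplified model of what ends up in Neo4j."""
    labels: Dict[str, Set[str]] = {}
    for node in doc.nodes:
        labels[node.id] = {node.type}
        if base_entity_label:
            labels[node.id].add("__Entity__")  # one shared label for every entity
    triples = [(r.source.id, r.type, r.target.id) for r in doc.relationships]
    if include_source:
        chunk_id = "chunk-0"  # hypothetical id for the source Document node
        labels[chunk_id] = {"Document"}
        # Link the chunk to every entity extracted from it
        triples += [(chunk_id, "MENTIONS", node.id) for node in doc.nodes]
    return labels, triples


elizabeth = SimpleNode("Elizabeth I", "Person")
henry = SimpleNode("Henry Viii", "Person")
doc = SimpleGraphDocument(
    nodes=[elizabeth, henry],
    relationships=[SimpleRelationship(elizabeth, henry, "CHILD_OF")],
    source_text="Elizabeth I was the daughter of Henry VIII.",
)
labels, triples = plan_import(doc, base_entity_label=True, include_source=True)
print(triples)
```

The `MENTIONS` edges are why the visualization and graph-retrieval queries in this notebook filter with `[r:!MENTIONS]`: the filter keeps only entity-to-entity relationships.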

### Inspect the generated graph with a yFiles visualization

```python
# Show the graph returned by a Cypher query. By default: all relationships
# except MENTIONS (entity-to-entity edges only), limited to 50 rows
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"

def showGraph(cypher: str = default_cypher):
    # Reuse the Neo4j driver provided by the local `utils` helper module
    driver = utils.neo4j_driver
    # Run the query and materialize the result as a graph; the session
    # is closed automatically once the graph has been fetched
    with driver.session() as session:
        widget = GraphWidget(graph=session.run(cypher).graph())
    # Label each node with its `id` property
    widget.node_label_mapping = 'id'
    return widget

showGraph()
```


## Hybrid Retrieval
In the following diagram, a user asks a question. The RAG retriever runs keyword and vector searches over the unstructured text data and combines the results with information collected from the knowledge graph. Because Neo4j supports both keyword and vector indexes, all three retrieval methods (keyword, vector, and graph) can run against a single database system. The retrieved data is then passed to an LLM, which generates the final answer.

![Hybrid retrieval diagram](https://raw.githubusercontent.com/tomasonjo/blogs/master/graphhybrid.png)

## Unstructured data retriever
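The `Neo4jVector.from_existing_graph` method computes embeddings for the `text` property of existing `Document` nodes that do not have one yet, stores them in the `embedding` property, and creates a vector index over them. With `search_type="hybrid"`, it also creates a keyword (full-text) index, so each search queries both indexes and merges the results.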

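```python
# Create the vector index from the graph: embed the `text` property of
# every `Document` node and store the result in its `embedding` property
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),
    search_type="hybrid",  # combine vector and keyword (full-text) search
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)
```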
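
With `search_type="hybrid"`, each search hits both the vector index and the keyword index, and their scores are on different scales (vector similarity scores between 0 and 1 versus unbounded Lucene relevance scores). A common way to merge them is to normalize each result list by its own top score and keep each document's best normalized score. The following sketch shows that technique on hard-coded scores. We believe `Neo4jVector`'s hybrid query works along similar lines, but treat that as an assumption and check the source of your installed version; `fuse_hybrid_scores` and the chunk ids are hypothetical.

```python
from typing import Dict, List, Tuple


def fuse_hybrid_scores(
    vector_hits: List[Tuple[str, float]],
    keyword_hits: List[Tuple[str, float]],
    k: int = 4,
) -> List[Tuple[str, float]]:
    """Merge two ranked result lists: divide each list's scores by that list's
    top score, keep each document's best normalized score, return the top k."""
    combined: Dict[str, float] = {}
    for hits in (vector_hits, keyword_hits):
        if not hits:
            continue
        top = max(score for _, score in hits)
        for doc_id, score in hits:
            normalized = score / top if top > 0 else 0.0
            combined[doc_id] = max(combined.get(doc_id, 0.0), normalized)
    # sorted() is stable, so documents with equal scores keep their first-seen order
    ranked = sorted(combined.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


vector_hits = [("chunk-a", 0.92), ("chunk-b", 0.88), ("chunk-c", 0.80)]
keyword_hits = [("chunk-d", 5.0), ("chunk-b", 2.5)]
for doc_id, score in fuse_hybrid_scores(vector_hits, keyword_hits):
    print(doc_id, round(score, 3))
```

Normalizing per list means the best keyword match and the best vector match both score 1.0, so neither search can dominate just because its raw scores are larger.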

## Graph retriever
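The graph retriever starts by identifying the relevant entities in the question. For simplicity, we instruct the LLM to identify people and organizations (including businesses). To do this, we use LCEL with the `with_structured_output` method, which makes the LLM return the entity names in a predefined schema. We also create a full-text index named `entity` on the `id` property of all `__Entity__` nodes; it is used later to map the detected names to nodes in the graph.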

```python
# Create a full-text index on the `id` property of all entity nodes;
# it is used later to map entity names in the question to graph nodes
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

# Schema for the entities the LLM should extract from the text
class Entities(BaseModel):
    """Identifying information about entities."""
    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that "
        "appear in the text",
    )

# Prompt asking the LLM to extract entities from the question
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are extracting organization and person entities from the text.",
        ),
        (
            "human",
            "Use the given format to extract information from the following "
            "input: {question}",
        ),
    ]
)
# Chain that returns the extracted entities as an `Entities` object
entity_chain = prompt | llm.with_structured_output(Entities)
```

### Test entity detection on a sample question

```python
entity_chain.invoke({"question": "Where was Amelia Earhart born?"}).names
```

```
['Amelia Earhart']
```

### Define query generation function
Now that we can detect entities in the question, we map them to the knowledge graph through the `entity` full-text index created above.
First, we define a function that generates full-text queries that tolerate a bit of misspelling.

```python
def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text search.
    It removes Lucene special characters, splits the input into words, appends
    a fuzzy-match operator (~2, i.e. up to two edits) to each word, and combines
    them with the AND operator. Useful for mapping entities from user questions
    to database values while allowing for some misspellings.
    """
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        return ""  # nothing searchable left after removing special characters
    return " AND ".join(f"{word}~2" for word in words)
```
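
The `~2` suffix that `generate_full_text_query` appends turns each word into a Lucene fuzzy term, which matches indexed terms within an edit distance of two. The following sketch illustrates that idea with a plain Levenshtein distance. `levenshtein` and `fuzzy_match` are our own illustrations, not part of Neo4j or LangChain, and Lucene's fuzzy query differs in details: for example, by default it counts swapping two adjacent characters as a single edit. We lowercase both sides on the assumption that the index's analyzer lowercases terms.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions or
    substitutions needed to turn `a` into `b` (dynamic programming, one row at a time)."""
    previous = list(range(len(b) + 1))  # distances from "" to each prefix of b
    for i, ca in enumerate(a, start=1):
        current = [i]  # distance from a[:i] to ""
        for j, cb in enumerate(b, start=1):
            current.append(min(
                previous[j] + 1,               # delete ca
                current[j - 1] + 1,            # insert cb
                previous[j - 1] + (ca != cb),  # substitute (free if the characters match)
            ))
        previous = current
    return previous[-1]


def fuzzy_match(term: str, candidate: str, max_edits: int = 2) -> bool:
    """Roughly what a `term~2` clause accepts for a single indexed term."""
    return levenshtein(term.lower(), candidate.lower()) <= max_edits


print(levenshtein("Elizabth", "Elizabeth"))     # 1: one missing letter
print(fuzzy_match("walsingam", "Walsingham"))   # True
print(fuzzy_match("Tudor", "Stuart"))           # False: more than two edits apart
```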

```python
# Retrieve the graph neighborhood of each entity mentioned in the question
def structured_retriever(question: str) -> str:
    """
    Collects the neighborhood of entities mentioned
    in the question
    """
    lines = []
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        query = generate_full_text_query(entity)
        if not query:
            continue  # skip names that reduce to an empty full-text query
        # Find the 2 best-matching entity nodes, then return their outgoing and
        # incoming relationships (except MENTIONS) as "source - TYPE -> target"
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": query},
        )
        lines.extend(el['output'] for el in response)
    # One relationship per line, across all entities
    return "\n".join(lines)
```

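The `structured_retriever` function first detects entities in the user question. For each detected entity, it uses the `entity` full-text index to find up to two matching nodes, then runs a Cypher template that returns up to 50 of their incoming and outgoing relationships (excluding `MENTIONS`), each formatted as `source - TYPE -> target`.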
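
The Cypher template combines two `MATCH` clauses with `UNION ALL`, one for outgoing and one for incoming relationships, and renders both as `source - TYPE -> target`. The following pure-Python sketch applies the same filtering and formatting to an in-memory edge list, so the output shape can be checked without a database. `format_neighborhood` and the sample edges are illustrative only; unlike the Cypher query, the sketch skips the full-text lookup and returns lines in edge-list order.

```python
from typing import Iterable, List, Tuple

Edge = Tuple[str, str, str]  # (source id, relationship type, target id)


def format_neighborhood(entity_id: str, edges: Iterable[Edge], limit: int = 50) -> List[str]:
    """Render every relationship touching `entity_id` as 'source - TYPE -> target',
    skipping the MENTIONS links that connect entities to their source chunks."""
    lines: List[str] = []
    for source, rel_type, target in edges:
        if rel_type == "MENTIONS":
            continue
        # Outgoing and incoming edges are both kept, always written source -> target
        if entity_id in (source, target):
            lines.append(f"{source} - {rel_type} -> {target}")
    return lines[:limit]


edges = [
    ("Elizabeth I", "CHILD_OF", "Henry Viii"),   # outgoing
    ("Edward Vi", "SIBLING_OF", "Elizabeth I"),  # incoming
    ("chunk-0", "MENTIONS", "Elizabeth I"),      # chunk link: ignored
    ("Henry Viii", "CHILD_OF", "Henry Vii"),     # not adjacent to Elizabeth I
]
print("\n".join(format_neighborhood("Elizabeth I", edges)))
```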

```python
print(structured_retriever("Who is Elizabeth I?"))
```

```
Elizabeth I - QUEEN_OF -> England
Elizabeth I - QUEEN_OF -> Ireland
Elizabeth I - LAST_MONARCH_OF -> House Of Tudor
Elizabeth I - CHILD_OF -> Henry Viii
Elizabeth I - CHILD_OF -> Anne Boleyn
Elizabeth I - ADVISED_BY -> William Cecil
Elizabeth I - SUPREME_GOVERNOR_OF -> Church Of England
Elizabeth I - SUCCEEDED_BY -> James Vi Of Scotland
Elizabeth I - SIBLING_OF -> Edward Vi
Elizabeth I - SIBLING_OF -> Mary I
Elizabeth I - PROTECTED_BY -> Sir Francis Walsingham
Elizabeth I - FOREIGN_AFFAIRS_WITH -> France
Elizabeth I - FOREIGN_AFFAIRS_WITH -> Spain
Elizabeth I - MILITARY_CAMPAIGN_IN -> Netherlands
Queen Elizabeth I - INFLUENCED -> English Royal Portraits
Queen Elizabeth I - SAT_FOR -> Cornelis Ketel
Queen Elizabeth I - SAT_FOR -> Federico Zuccaro
Queen Elizabeth I - SAT_FOR -> Isaac Oliver
Queen Elizabeth I - SAT_FOR -> Marcus Gheeraerts The Younger
Nicholas Hilliard - OFFICIAL_LIMNER -> Queen Elizabeth I
George Gower - SERJEANT_PAINTER -> Queen Elizabeth I
George Gower - APPROVED_PORTR
```

### Final retriever
Bringing it all together, we combine the unstructured and graph retrievers to create the final context that is passed to the LLM.

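```python
def retriever(question: str):
    print(f"Search query: {question}")
    # Graph facts about the entities in the question
    structured_data = structured_retriever(question)
    # Text chunks from the hybrid (vector + keyword) index
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    # Put each chunk on its own line with a marker, so the LLM can tell the chunks apart
    documents = "\n".join(f"#Document {text}" for text in unstructured_data)
    final_data = f"""Structured data:
{structured_data}
Unstructured data:
{documents}
    """
    return final_data
```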
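
The final context is just a string. To see roughly what the LLM receives without querying Neo4j, the following sketch builds a layout close to the one `retriever` produces from already-retrieved inputs. `format_context` and the sample strings are hypothetical illustrations.

```python
from typing import List


def format_context(structured_data: str, chunks: List[str]) -> str:
    """Build the context passed to the LLM: graph facts first, then each
    retrieved text chunk on its own line behind a '#Document' marker."""
    lines = ["Structured data:", structured_data, "Unstructured data:"]
    lines += [f"#Document {chunk}" for chunk in chunks]
    return "\n".join(lines)


context = format_context(
    "Elizabeth I - CHILD_OF -> Henry Viii\nElizabeth I - CHILD_OF -> Anne Boleyn",
    ["Elizabeth I was Queen of England and Ireland.", "She was the last Tudor monarch."],
)
print(context)
```

Keeping both section headers even when a retriever returns nothing lets the LLM see which source came back empty, instead of guessing from the missing text.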

### Defining the RAG chain
With the retrieval component in place, we can build the RAG chain. The first part is query rewriting, which condenses the chat history and a follow-up question into a standalone question so that conversational follow-ups can be answered.

```python
# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

# Turn (human, ai) pairs into alternating chat messages
def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0)
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(lambda x: x["question"]),
)
```
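
`RunnableBranch` checks its condition and runs the matching branch, falling back to the default otherwise. The sketch below mirrors the logic of `_search_query` in plain Python: `format_chat_history` plays the role of `_format_chat_history` (with role-tagged tuples instead of LangChain message objects), and the `condense` argument stands in for the `CONDENSE_QUESTION_PROMPT | ChatOpenAI | StrOutputParser` pipeline. All names here, including `toy_condense`, are hypothetical.

```python
from typing import Callable, Dict, List, Optional, Tuple

ChatHistory = List[Tuple[str, str]]  # (human message, AI answer) pairs
Messages = List[Tuple[str, str]]     # (role, content) pairs


def format_chat_history(chat_history: ChatHistory) -> Messages:
    """Flatten (human, ai) pairs into alternating role-tagged messages."""
    messages: Messages = []
    for human, ai in chat_history:
        messages.append(("human", human))
        messages.append(("ai", ai))
    return messages


def search_query(inputs: Dict, condense: Callable[[Messages, str], str]) -> str:
    """Rewrite the question only when non-empty chat history is present."""
    history: Optional[ChatHistory] = inputs.get("chat_history")
    if history:  # a missing key and an empty list both skip the rewrite
        return condense(format_chat_history(history), inputs["question"])
    return inputs["question"]


def toy_condense(messages: Messages, question: str) -> str:
    # Hypothetical stand-in for the LLM rewrite: resolve the pronoun by hand
    return question.replace("she", "Elizabeth I")


history = [("Which house did Elizabeth I belong to?", "House Of Tudor")]
print(search_query({"question": "When was she born?"}, toy_condense))
print(search_query({"question": "When was she born?", "chat_history": history}, toy_condense))
```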

### Defining the prompt
Next, we define a prompt that answers the question using the context produced by the hybrid retriever, which completes the RAG chain.

```python
template = """Answer the question based only on the following context:
{context}

Question: {question}
Use natural language and be concise. If you don't know the answer, you can say "I don't know."
Answer:"""
prompt = ChatPromptTemplate.from_template(template)

chain = (
    RunnableParallel(
        {
            # Rewrite the question if needed, then retrieve graph and text context
            "context": _search_query | retriever,
            # Forwards the whole input dict (question and any chat_history) to the prompt
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)
```
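
`RunnableParallel` gives the same input to every branch and collects the results in a dict keyed by branch name, which then fills the prompt variables. The following sketch reproduces that data flow sequentially in plain Python (LangChain may run the branches concurrently); `run_parallel`, `passthrough` and the stand-in `context` branch are hypothetical.

```python
from typing import Any, Callable, Dict


def run_parallel(branches: Dict[str, Callable[[Any], Any]], inputs: Any) -> Dict[str, Any]:
    """Give the same input to every branch and collect each result under its key."""
    return {key: branch(inputs) for key, branch in branches.items()}


def passthrough(x: Any) -> Any:
    """Like RunnablePassthrough(): return the input unchanged."""
    return x


branches = {
    # Hypothetical stand-in for `_search_query | retriever`
    "context": lambda x: f"<context for: {x['question']}>",
    "question": passthrough,
}
result = run_parallel(branches, {"question": "Which house did Elizabeth I belong to?"})
print(result["question"])  # {'question': 'Which house did Elizabeth I belong to?'}
```

So in this chain the `{question}` slot receives the whole input dict, including any `chat_history`, which also shows the answering LLM the conversation. If you only want the question string there, `itemgetter("question")` from Python's `operator` module (or `RunnableLambda(lambda x: x["question"])`) can replace the passthrough.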

### Test the hybrid RAG implementation

```python
chain.invoke({"question": "Which house did Elizabeth I belong to?"})
```

```
Search query: Which house did Elizabeth I belong to?

'Elizabeth I belonged to the House of Tudor.'
```

### Test a follow-up question based on chat history: does it get rewritten?
Because the vector and keyword searches only see the search query, a follow-up question such as "When was she born?" must first be rewritten into a standalone question for retrieval to find relevant results.

```python
chain.invoke(
    {
        "question": "When was she born?",
        "chat_history": [("Which house did Elizabeth I belong to?", "House Of Tudor")],
    }
)
```

```
Search query: When was Elizabeth I born?

'Elizabeth I was born on 7 September 1533.'
```