<a href="https://colab.research.google.com/github/Vapour-Exchange/ai-swarm/blob/main/GRAPH_TIME.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to Graph based RAG using LangChain and Neo4j
Based on "https://medium.com/neo4j/enhancing-the-accuracy-of-rag-applications-with-knowledge-graphs-ad5e2ffab663" by [Tomaz Bratanic](https://medium.com/@bratanic-tomaz), with permission.

Original notebook here: https://github.com/tomasonjo/blogs/blob/master/llm/enhancing_rag_with_graph.ipynb

Prereqs:
    
*  A Neo4j database: Either [Neo4j Desktop](https://neo4j.com/download/) (local instance) or [Aura](https://neo4j.com/product/auradb/) (cloud instance)



## Load libraries

In [None]:
# Install the dependencies used in the following code cells

!pip install neo4j
!pip install langchain
!pip install langchain_community
!pip install python-dotenv
!pip install openai
!pip install langchain_openai
!pip install langchain_experimental
!pip install yfiles_jupyter_graphs
!pip install utils



### Connect to the Neo4j database
Connect to the graph database using the [LangChain Neo4jGraph connector](https://api.python.langchain.com/en/latest/graphs/langchain_community.graphs.neo4j_graph.Neo4jGraph.html).

In [None]:
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    ConfigurableField, RunnableParallel, RunnablePassthrough
)
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import Tuple, List, Optional
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
import os
from langchain_community.graphs import Neo4jGraph
from langchain_community.document_loaders import WikipediaLoader
from langchain_text_splitters import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from importlib import reload
import utils

import warnings
warnings.filterwarnings('ignore')
#reload(utils)

In [None]:
# Create a Neo4jGraph object to interact with the graph database
from google.colab import userdata
graph = Neo4jGraph(
    url=userdata.get('url'), username=userdata.get('user'), password=userdata.get('password')
)

### Data Ingestion
Load the tweet data (`webhook_logs`, last 7 days) from a CSV file. `CSVLoader` creates one document per CSV row.

In [None]:

from langchain_community.document_loaders.csv_loader import CSVLoader

loader = CSVLoader(file_path='./sample_data/posts-3days.csv')
documents = loader.load()
print(documents)
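To make the row-to-document mapping concrete, here is a minimal standard-library sketch of the idea. It is an illustrative stand-in, not the `CSVLoader` implementation: the function name `rows_to_documents`, the sample columns, and the dict layout are assumptions made for this example.

```python
import csv
import io


def rows_to_documents(csv_text: str, source: str = "posts.csv") -> list[dict]:
    """Turn each CSV row into one document dict (illustrative stand-in)."""
    documents = []
    reader = csv.DictReader(io.StringIO(csv_text))
    for row_number, row in enumerate(reader):
        # One "column: value" line per field forms the document text
        page_content = "\n".join(f"{key}: {value}" for key, value in row.items())
        documents.append(
            {
                "page_content": page_content,
                "metadata": {"source": source, "row": row_number},
            }
        )
    return documents


# Hypothetical two-row sample; the real column names depend on the CSV export
sample = "author,text\nalice,Graphs are neat\nbob,Vectors too\n"
docs = rows_to_documents(sample)
print(len(docs))
print(docs[0]["page_content"])
```

Because every row becomes its own document, the graph construction step further down processes one tweet at a time.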



### Construct a graph based on the retrieved documents
The [LLMGraphTransformer](https://python.langchain.com/v0.2/api_reference/experimental/graph_transformers/langchain_experimental.graph_transformers.llm.LLMGraphTransformer.html) converts documents into graph documents, which can be imported into the Neo4j graph database via the [add_graph_documents](https://api.python.langchain.com/en/latest/graphs/langchain_community.graphs.neo4j_graph.Neo4jGraph.html#langchain_community.graphs.neo4j_graph.Neo4jGraph.add_graph_documents) method. The `baseEntityLabel` parameter assigns an additional `__Entity__` label to each node, which improves indexing and query performance. The `include_source` parameter links nodes to their originating documents, which makes the extracted data traceable. First, define the LLM used by the knowledge graph generation chain. Currently, only OpenAI and Mistral function-calling models are supported.
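The effect of the two parameters can be pictured with a small in-memory model. This is only a sketch: the `Node` and `Relationship` classes, the helper names, and the node types are hypothetical, and the `MENTIONS` link is assumed to match the relationship type that the retrieval query later in this notebook filters out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    id: str
    type: str


@dataclass(frozen=True)
class Relationship:
    source: Node
    target: Node
    type: str


def labels_for(node: Node, base_entity_label: bool) -> list[str]:
    """Labels a node would carry: its own type plus the shared entity label."""
    labels = [node.type]
    if base_entity_label:
        labels.append("__Entity__")
    return labels


def source_links(document_id: str, nodes: list[Node], include_source: bool):
    """Links from the source document to every node extracted from it."""
    if not include_source:
        return []
    return [(document_id, "MENTIONS", node.id) for node in nodes]


# Hypothetical extraction result for a single document
bob = Node(id="Bob", type="Person")
bitcoin = Node(id="Bitcoin", type="Cryptocurrency")
unlocks = Relationship(source=bob, target=bitcoin, type="UNLOCKS")

print(labels_for(bob, base_entity_label=True))
print(source_links("doc-1", [bob, bitcoin], include_source=True))
```

A shared label gives a single place to build an index over all extracted entities, regardless of their individual types.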

In [None]:
# Create the LLM (an Azure OpenAI deployment of gpt-4o-mini)
from langchain_openai import AzureChatOpenAI

llm = AzureChatOpenAI(
    api_key = userdata.get('AZURE_OPENAI_API_KEY'),
    azure_endpoint = userdata.get('AZURE_OPENAI_ENDPOINT'),
    azure_deployment="tigestfouroomini",
    api_version="2024-08-01-preview",
    temperature=0,
    model_name="gpt-4o-mini"
)

# Create the LLMGraphTransformer
llm_transformer = LLMGraphTransformer(llm=llm)

# Process documents one at a time
for doc in documents:
    try:

        # Convert single document to graph document
        graph_document = await llm_transformer.aconvert_to_graph_documents([doc])

        # Add the single document to graph
        graph.add_graph_documents(
            graph_document,
            baseEntityLabel=True,
            include_source=True
        )

        print("Successfully added document to graph")

    except Exception as e:
        print(f"Error processing document: {str(e)}")
        continue

print("Finished processing all documents")

Successfully added document to graph

CancelledError: 

### Inspect the generated graph with yfiles visualization.

In [None]:
# Show the graph returned by the given Cypher query
from neo4j import GraphDatabase
default_cypher = "MATCH (s)-[r]->(t) RETURN s, r, t"

URI = userdata.get('url')
AUTH = (userdata.get('user'), userdata.get('password'))

def showGraph(cypher: str = default_cypher):
    # Open a driver and a session; both are closed when the `with` block exits
    with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session() as session:
        # Build a graph widget from the query result
        widget = GraphWidget(graph=session.run(cypher).graph())
    # Label each node with its `id` property
    widget.node_label_mapping = 'id'
    return widget

showGraph()

GraphWidget(layout=Layout(height='800px', width='100%'))

In [None]:
from google.colab import output
output.enable_custom_widget_manager()

Support for third party widgets will remain active for the duration of the session. To disable support:

In [None]:
from google.colab import output
output.disable_custom_widget_manager()

## Hybrid Retrieval
In the following diagram, a user asks a question. A RAG retriever employs keyword and vector searches to search unstructured text data, combining it with collected knowledge graph information. As Neo4j supports both keyword and vector indexes, you can implement all three retrieval options with a single database system. The collected data from these sources is fed into an LLM to generate and deliver the final answer.

![image.png](https://raw.githubusercontent.com/tomasonjo/blogs/master/graphhybrid.png)

## Unstructured data retriever
The `Neo4jVector.from_existing_graph` method adds both keyword and vector retrieval to documents. It configures keyword and vector search indexes for a hybrid retrieval approach.

In [None]:
from langchain_openai import AzureOpenAIEmbeddings
import os

# Neo4jVector reads the connection details from these environment variables
os.environ['NEO4J_URI'] = userdata.get('url')
os.environ['NEO4J_USERNAME'] = userdata.get('user')
os.environ['NEO4J_PASSWORD'] = userdata.get('password')

# Create the vector index from the graph
vector_index = Neo4jVector.from_existing_graph(
    AzureOpenAIEmbeddings(
        # Read the key from Colab secrets instead of hard-coding it in the notebook.
        # The secret name below is an assumption; use the name of the secret that
        # holds the key for the embeddings resource.
        api_key=userdata.get('AZURE_OPENAI_EMBEDDINGS_API_KEY'),
        azure_endpoint='https://abhis-m4rhtb7g-eastus2.cognitiveservices.azure.com/',
        azure_deployment="text-embedding-3-small-2",
        api_version="2023-05-15",
    ),
    search_type="hybrid",
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding"
)
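One simple way to think about hybrid search is as a merge of two ranked lists whose scores live on different scales. The sketch below normalizes each list by its top score and keeps the best score per document. It illustrates the general idea under that assumption and is not a description of the connector's internals; `normalize`, `hybrid_merge`, and the sample scores are hypothetical.

```python
def normalize(scores: dict[str, float]) -> dict[str, float]:
    """Scale scores so that the best hit in the list gets 1.0."""
    if not scores:
        return {}
    top = max(scores.values())
    if top <= 0:
        return {doc: 0.0 for doc in scores}
    return {doc: score / top for doc, score in scores.items()}


def hybrid_merge(vector_scores: dict[str, float],
                 keyword_scores: dict[str, float],
                 k: int = 4) -> list[tuple[str, float]]:
    """Merge both result lists, keeping the best normalized score per document."""
    merged: dict[str, float] = {}
    for scores in (normalize(vector_scores), normalize(keyword_scores)):
        for doc, score in scores.items():
            merged[doc] = max(score, merged.get(doc, 0.0))
    ranked = sorted(merged.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


# Hypothetical scores: cosine similarities and keyword relevance scores
vector_hits = {"doc-a": 0.9, "doc-b": 0.45}
keyword_hits = {"doc-b": 4.0, "doc-c": 2.0}
print(hybrid_merge(vector_hits, keyword_hits))
```

Normalizing first matters because keyword relevance scores are unbounded, while similarity scores typically fall between 0 and 1.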

## Graph retriever
The graph retriever starts by identifying relevant entities in the input. For simplicity, we instruct the LLM to identify people, organizations, and businesses. To achieve this, we use LCEL with the `with_structured_output` method.

In [None]:

# Create the full-text index used to look up entities by their id
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

# Extract entities from text
class Entities(BaseModel):
    """Identifying information about entities."""
    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that "
        "appear in the text",
    )

# Prompt that asks the LLM to extract entities from the question
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are extracting organization and person entities from the text.",
        ),
        (
            "human",
            "Use the given format to extract information from the following "
            "input: {question}",
        ),
    ]
)
# Chain the prompt with the LLM, constrained to return an Entities object
entity_chain = prompt | llm.with_structured_output(Entities)



### Test Graph retriever to detect entities in the question

In [None]:

entity_chain.invoke({"question": "beast games"})

Entities(names=['beast games'])

### Define query generation function
Now that we can detect entities in the question, let's use the full-text index created above to map them to the knowledge graph.
First, we define a function that generates full-text queries that tolerate a bit of misspelling.
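To see what "a bit of misspelling" means, the following sketch computes the edit distance between two terms and accepts a candidate when it is within two edits. It uses plain Levenshtein distance as an assumption for illustration; Lucene's own fuzzy matching may count edits slightly differently (for example, by treating a swap of two adjacent characters as a single edit). The names `edit_distance` and `fuzzy_match` are hypothetical.

```python
def edit_distance(a: str, b: str) -> int:
    """Levenshtein distance: minimum number of single-character edits."""
    previous = list(range(len(b) + 1))
    for i, char_a in enumerate(a, start=1):
        current = [i]
        for j, char_b in enumerate(b, start=1):
            cost = 0 if char_a == char_b else 1
            current.append(
                min(
                    previous[j] + 1,         # delete a character from a
                    current[j - 1] + 1,      # insert a character into a
                    previous[j - 1] + cost,  # substitute (or keep) a character
                )
            )
        previous = current
    return previous[-1]


def fuzzy_match(term: str, candidate: str, max_edits: int = 2) -> bool:
    """True when the candidate is within max_edits of the term, ignoring case."""
    return edit_distance(term.lower(), candidate.lower()) <= max_edits


print(edit_distance("bitcoin", "bitcon"))
print(fuzzy_match("Bitcoin", "bitcion"))
print(fuzzy_match("Bitcoin", "ethereum"))
```

A threshold of two edits is forgiving for long names but can over-match very short ones, which is worth keeping in mind when entity ids are only a few characters long.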

In [None]:
def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text search.
    It processes the input string by splitting it into words and appending a
    similarity threshold (~2 changed characters) to each word, then combines
    them using the AND operator. Useful for mapping entities from user questions
    to database values, and allows for some misspellings.
    """
    words = [el for el in remove_lucene_chars(input).split() if el]
    # Joining avoids an IndexError when the input contains no searchable words;
    # in that case the result is an empty string
    return " AND ".join(f"{word}~2" for word in words)

# Fulltext index query
def structured_retriever(question: str) -> str:
    """
    Collects the neighborhood of entities mentioned
    in the question
    """
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
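        lines = [el['output'] for el in response]
        if lines:
            # End each entity's block with a newline so that results for
            # different entities do not run together on one line
            result += "\n".join(lines) + "\n"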
    return result

The `structured_retriever` function starts by detecting entities in the user question. Next, it iterates over the detected entities and uses a Cypher template to retrieve the neighborhood of relevant nodes.
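The neighborhood lookup can be mirrored on an in-memory edge list, which may help when reasoning about what the Cypher template returns. This is a sketch under stated assumptions: the `neighborhood` helper and the sample edges are hypothetical, and only the `Bob - UNLOCKS -> Bitcoin` triple comes from the notebook output.

```python
Edge = tuple[str, str, str]  # (source id, relationship type, target id)


def neighborhood(entity: str, edges: list[Edge], limit: int = 50) -> list[str]:
    """Outgoing, then incoming relationships of an entity, skipping MENTIONS."""
    outgoing = [
        f"{source} - {rel_type} -> {target}"
        for source, rel_type, target in edges
        if source == entity and rel_type != "MENTIONS"
    ]
    incoming = [
        f"{source} - {rel_type} -> {target}"
        for source, rel_type, target in edges
        if target == entity and rel_type != "MENTIONS"
    ]
    return (outgoing + incoming)[:limit]


# Hypothetical graph content
edges = [
    ("Bob", "UNLOCKS", "Bitcoin"),
    ("doc-1", "MENTIONS", "Bitcoin"),
    ("Bitcoin", "TRADED_ON", "Exchange"),
]
print("\n".join(neighborhood("Bitcoin", edges)))
```

Skipping `MENTIONS` keeps the output focused on relationships between entities rather than on links back to source documents.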

In [None]:
print(structured_retriever("bitcoin"))

Bob - UNLOCKS -> Bitcoin


### Final retriever
Bringing it all together, combine the unstructured and graph retrievers to create the final context that will be passed to an LLM.

In [None]:
def retriever(question: str):
    print(f"Search query: {question}")
    structured_data = structured_retriever(question)
    unstructured_data = [el.page_content for el in vector_index.similarity_search(question)]
    final_data = f"""Structured data:
{structured_data}
Unstructured data:
{"#Document ".join(unstructured_data)}
    """
    return final_data

### Defining the RAG chain
We have successfully implemented the retrieval component of the RAG application. Next, we add the query rewriting step, which uses the chat history to support conversational follow-up questions.

In [None]:
# Condense a chat history and follow-up question into a standalone question
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question,
in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""  # noqa: E501
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)

def _format_chat_history(chat_history: List[Tuple[str, str]]) -> List:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer

_search_query = RunnableBranch(
    # If input includes chat_history, we condense it with the follow-up question
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),  # Condense follow-up question and chat into a standalone_question
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | llm  # reuse the AzureChatOpenAI model defined earlier
        | StrOutputParser(),
    ),
    # Else, we have no chat history, so just pass through the question
    RunnableLambda(lambda x : x["question"]),
)
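The branching logic above boils down to a single condition, shown here as a plain-Python sketch. The `build_search_query` and `fake_condense` functions are hypothetical; `fake_condense` stands in for the LLM call and only tags the question with the previous human turn.

```python
from typing import Callable


def build_search_query(inputs: dict, condense: Callable[[list, str], str]) -> str:
    """Condense the question when chat history exists, else pass it through."""
    history = inputs.get("chat_history")
    if history:
        return condense(history, inputs["question"])
    return inputs["question"]


def fake_condense(history: list, question: str) -> str:
    # Stand-in for the LLM: append the previous human turn as context
    last_human_turn = history[-1][0]
    return f"{question} (follow-up to: {last_human_turn})"


print(build_search_query({"question": "Who unlocks Bitcoin?"}, fake_condense))
print(
    build_search_query(
        {
            "question": "When was she born?",
            "chat_history": [("Which house did Elizabeth I belong to?", "House Of Tudor")],
        },
        fake_condense,
    )
)
```

An empty or missing `chat_history` takes the pass-through path, so first-turn questions never incur an extra LLM call.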

### Defining the prompt
Next, we introduce a prompt that leverages the context provided by the integrated hybrid retriever to produce the response, completing the implementation of the RAG chain.

In [None]:
template = """Answer the question based only on the following context. The answer should be detailed enough to be used for creating content such as a Twitter post, a LinkedIn post, or a reply to one:
{context}

Question: {question}
Use natural language and be concise. If you don't know the answer, you can say "null"
Answer:"""
prompt = ChatPromptTemplate.from_template(template)

chain = (
    RunnableParallel(
        {
            "context": _search_query | retriever,
            "question": RunnablePassthrough(),
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)
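The data flow of the chain can be traced with ordinary functions: the same input is handed to every branch, and the collected results fill the prompt. The sketch below is a simplification under stated assumptions. `run_parallel`, `fake_retriever`, `build_prompt`, and the shortened prompt text are hypothetical, and the question branch here picks out the question string explicitly.

```python
PROMPT = "Context:\n{context}\n\nQuestion: {question}\nAnswer:"


def run_parallel(branches: dict, value):
    """Give the same input to every branch and collect the results by name."""
    return {name: branch(value) for name, branch in branches.items()}


def fake_retriever(query: str) -> str:
    # Stand-in for the hybrid retriever
    return f"(context retrieved for '{query}')"


def build_prompt(inputs: dict) -> str:
    prompt_inputs = run_parallel(
        {
            "context": lambda x: fake_retriever(x["question"]),
            "question": lambda x: x["question"],
        },
        inputs,
    )
    return PROMPT.format(**prompt_inputs)


filled = build_prompt({"question": "Who unlocks Bitcoin?"})
print(filled)
```

In the real chain the filled prompt is then passed to the LLM, and the output parser extracts the answer text.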

### Test the hybrid RAG implementation

In [None]:
chain.invoke({"question": "psychosort's opinions with tweet content"})

Search query: psychosort's opinions with tweet content




KeyboardInterrupt: 

### Test a follow-up question based on chat history: is it rewritten?
Because vector and keyword search operate on the query text alone, a follow-up question must be rewritten into a standalone question before it is used for retrieval.

In [None]:
chain.invoke(
    {
        "question": "When was she born?",
        "chat_history": [("Which house did Elizabeth I belong to?", "House Of Tudor")],
    }
)