<a href="https://colab.research.google.com/github/Miyazaki-Kohei/LLMS/blob/main/GraphRAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 参考 https://blog.langchain.dev/enhancing-rag-based-applications-accuracy-by-constructing-and-leveraging-knowledge-graphs/

In [1]:
%%capture
%pip install --upgrade --quiet  langchain langchain-community langchain-openai langchain-experimental neo4j wikipedia tiktoken yfiles_jupyter_graphs

In [2]:
import os
from neo4j import GraphDatabase
from yfiles_jupyter_graphs import GraphWidget
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough, ConfigurableField
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from typing import List
from langchain_core.output_parsers import StrOutputParser
from langchain_community.graphs import Neo4jGraph
from langchain.document_loaders import TextLoader
from langchain.text_splitter import TokenTextSplitter
from langchain_openai import ChatOpenAI
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_community.vectorstores import Neo4jVector
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars

# Enable custom widget rendering (needed for the yFiles GraphWidget) when the
# notebook runs on Google Colab; elsewhere there is nothing to enable.
try:
    from google.colab import output
except ImportError:
    # Not running on Colab: only a failed import is expected and ignored, so
    # genuine errors from enable_custom_widget_manager() still surface.
    pass
else:
    output.enable_custom_widget_manager()

In [4]:
import getpass

# Interactively collect credentials, but only those not already present in the
# environment, so "Restart & Run All" (or running with the variables exported
# beforehand) does not prompt for every secret again.
_CREDENTIAL_PROMPTS = {
    "OPENAI_API_KEY": 'OpenAIのAPIキーを入力してください',
    "NEO4J_URI": 'NEO4JのURIを入力してください',
    "NEO4J_USERNAME": 'NEO4JのUSERNAMEを入力してください',
    "NEO4J_PASSWORD": 'NEO4Jのパスワードを入力してください',
}
for _env_var, _prompt_text in _CREDENTIAL_PROMPTS.items():
    # An empty value is treated as missing and prompted for.
    if not os.environ.get(_env_var):
        os.environ[_env_var] = getpass.getpass(prompt=_prompt_text)

OpenAIのAPIキーを入力してください··········
NEO4JのURIを入力してください··········
NEO4JのUSERNAMEを入力してください··········
NEO4Jのパスワードを入力してください··········


In [5]:
# Path of the source text (Japanese) that the knowledge graph is built from.
SOURCE_TEXT_PATH = '孫悟空少年編.txt'

# Read explicitly as UTF-8 rather than the platform default encoding, so the
# Japanese text decodes the same way on every OS.
# NOTE(review): assumes the file is saved as UTF-8 — confirm.
raw_documents = TextLoader(SOURCE_TEXT_PATH, encoding='utf-8').load()

# Split into 512-token chunks that overlap by 125 tokens, so text near a chunk
# boundary appears in both neighbouring chunks.
text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=125)
documents = text_splitter.split_documents(raw_documents)

In [6]:
# Chat model with temperature 0, reused by the graph transformer here and by the
# entity-extraction and QA chains further down.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo-0125",
    temperature=0,
)

# Ask the LLM to turn every text chunk into graph documents (nodes + relationships).
llm_transformer = LLMGraphTransformer(llm=llm)
graph_documents = llm_transformer.convert_to_graph_documents(documents)

In [7]:
# Neo4j connection used below for writing the graph and running Cypher queries.
# No arguments are passed, so connection settings presumably come from the
# NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD environment variables — confirm.
graph = Neo4jGraph()

In [9]:
# Persist the extracted nodes/relationships into Neo4j.
# - include_source: presumably also stores the source chunks (the `Document`
#   nodes indexed below) and links them to entities via MENTIONS edges.
# - baseEntityLabel: presumably adds the extra `__Entity__` label that the
#   full-text index created below targets — confirm against LangChain docs.
graph.add_graph_documents(graph_documents, include_source=True, baseEntityLabel=True)

In [10]:
# Default visualisation query: up to 50 relationships, excluding the MENTIONS
# edges (which link source chunks to entities rather than entities to each other).
default_cypher = "MATCH (s)-[r:!MENTIONS]->(t) RETURN s,r,t LIMIT 50"


def showGraph(cypher: str = default_cypher):
    """Run a Cypher query against Neo4j and return its result as a yFiles widget.

    Connection settings are read from the NEO4J_URI, NEO4J_USERNAME and
    NEO4J_PASSWORD environment variables.

    Args:
        cypher: Query whose returned nodes and relationships are drawn.

    Returns:
        A GraphWidget whose nodes are labelled by their ``id`` property.
    """
    auth = (os.environ["NEO4J_USERNAME"], os.environ["NEO4J_PASSWORD"])
    # Context managers close the driver and session even if the query fails.
    # .graph() materialises the whole result before the session is closed.
    with GraphDatabase.driver(uri=os.environ["NEO4J_URI"], auth=auth) as driver:
        with driver.session() as session:
            result_graph = session.run(cypher).graph()
    widget = GraphWidget(graph=result_graph)
    widget.node_label_mapping = 'id'
    return widget

In [11]:
# Render the default sample of entity relationships as an interactive widget.
showGraph()

GraphWidget(layout=Layout(height='800px', width='100%'))

In [12]:
# Hybrid search index over the chunk nodes labelled `Document`: each node's
# `text` property is embedded with OpenAI embeddings, and the vector is kept in
# the node's `embedding` property.
vector_index = Neo4jVector.from_existing_graph(
    OpenAIEmbeddings(),
    node_label="Document",
    text_node_properties=["text"],
    embedding_node_property="embedding",
    search_type="hybrid",
)

In [13]:
# Full-text index named `entity` over the `id` of every `__Entity__` node.
# structured_retriever queries it through db.index.fulltext.queryNodes('entity', ...).
graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (e:__Entity__) ON EACH [e.id]")

[]

In [14]:
# Schema passed to `llm.with_structured_output` to extract entity names from a
# question.
# NOTE(review): the class docstring and the field description are likely sent
# to the LLM as part of the structured-output schema, so treat them as prompt
# text and edit them deliberately.
class Entities(BaseModel):
    """Identifying information about entities."""

    # Required list of entity names as they appear in the input text.
    names: List[str] = Field(
        ...,
        description="All the person, organization, or business entities that "
        "appear in the text",
    )

In [17]:
# Prompt for the entity-extraction step. It is named `entity_prompt` rather than
# `prompt` so that re-running cells out of order can never mix it up with the
# QA prompt used by the final chain.
entity_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are extracting organization and person entities from the text."),
        (
            "human",
            "Use the given format to extract information from the following "
            "input: {question}",
        ),
    ]
)

# Invoked with {"question": ...}; returns an `Entities` instance whose `.names`
# holds the extracted entity names.
entity_chain = entity_prompt | llm.with_structured_output(Entities)

In [18]:
# Quick sanity check of the entity extractor on a short Japanese sentence.
# NOTE(review): "梧空" uses 梧 rather than 悟 ("悟空"). This may be a typo, or a
# deliberate test of fuzzy matching — confirm.
entity_chain.invoke({"question": "梧空とにゃんたは戦った"}).names

['梧空', 'にゃんた']

In [15]:
def generate_full_text_query(input: str) -> str:
    """
    Generate a Lucene full-text search query for a given input string.

    Lucene special characters are removed, the remainder is split on
    whitespace, and each word gets a fuzzy-match suffix (``~2``, i.e. up to two
    edits) before all words are combined with AND. Useful for mapping entities
    from user questions to database values while tolerating some misspellings.

    Args:
        input: Free text, typically one entity name.

    Returns:
        The query string, e.g. ``"foo~2 AND bar~2"``, or ``""`` when the input
        contains no searchable words (empty, whitespace, or only special
        characters).
    """
    words = [el for el in remove_lucene_chars(input).split() if el]
    if not words:
        # Nothing to search for; the old code raised IndexError on words[-1].
        return ""
    return " AND ".join(f"{word}~2" for word in words)

# Graph-side retrieval via the `entity` full-text index
def structured_retriever(question: str) -> str:
    """
    Collect the graph neighbourhood of every entity mentioned in the question.

    Entities are extracted with ``entity_chain``. Each one is fuzzy-matched
    against the ``entity`` full-text index (top 20 nodes). Every non-MENTIONS
    relationship touching a matched node is then rendered as a
    ``"source - TYPE -> target"`` line, at most 1000 lines per entity.

    Args:
        question: Natural-language question to extract entities from.

    Returns:
        Newline-separated relationship lines for all entities, or ``""`` when
        nothing matches.
    """
    lines = []
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        query = generate_full_text_query(entity)
        # An entity consisting only of whitespace or Lucene special characters
        # yields no search terms; skip it instead of sending an empty query.
        if not query:
            continue
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:20})
            YIELD node,score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' +  node.id AS output
            }
            RETURN output LIMIT 1000
            """,
            {"query": query},
        )
        lines.extend(el['output'] for el in response)
    # Join once across all entities, so the last line of one entity's results is
    # not glued onto the first line of the next entity's results.
    return "\n".join(lines)

In [19]:
# Inspect the raw relationship lines the graph retriever returns for a sample
# question ("I want to know the entities related to Son Goku").
print(structured_retriever("孫悟空と関わりがあるエンティティを知りたい"))

孫悟空 - ADOPTED_BY -> 孫悟飯
悟空 - REUNITED_WITH -> 孫悟飯
悟空 - FATHER -> 孫悟飯
悟空 - FRIEND -> クリリン
悟空 - FRIEND -> 悟飯
悟空 - MENTOR -> 亀仙人
悟空 - MENTOR -> 界王
悟空 - MENTOR -> 界王神
悟空 - PARTICIPATED_IN -> 天下一武道会
悟空 - FOUGHT_AGAINST -> レッドリボン軍
悟空 - FOUGHT_AGAINST -> 天津飯
悟空 - FOUGHT_AGAINST -> ピッコロ大魔王
悟空 - TRAINING_UNDER -> 仙猫カリン
悟空 - REUNITED_WITH -> 孫悟飯
悟空 - RECEIVED_HELP -> ヤジロベー
悟空 - RECEIVED_HELP -> カリン
悟空 - RESURRECTED -> 神龍
悟空 - MET -> 神様
悟空 - MARRIED -> チチ
悟空 - FOUGHT -> ピッコロ大魔王
悟空 - FOUGHT -> ラディッツ
悟空 - FOUGHT -> 悟飯
悟空 - FOUGHT -> セル
悟空 - FATHER -> 孫悟飯
悟空 - CONFLICT -> フリーザ
悟空 - USED -> ドラゴンボール
悟空 - ESCAPED -> ナメック星
悟空 - ARRIVED -> ヤードラット星
悟空 - FIGHT -> 魔人ブウ
悟空 - TEAM_UP -> ベジータ
悟空 - FINAL_BATTLE -> ベジータ
悟空 - FAMILY -> パン
悟空 - TRAINING -> ウーブ
ピッコロ大魔王 - CHALLENGED -> 悟空
フリーザ - RIVAL -> 悟空
悟飯 - RESCUE -> 悟空
悟飯 - ENCOUNTERED -> 魔人ブウ
悟飯 - FOUGHT -> セル
悟飯 - USED -> ドラゴンボール
悟飯 - REVIVED -> フリーザ
悟飯 - REVIVED -> サイヤ人
悟飯 - FIGHT -> 魔人ブウ
悟飯 - FUSION -> ゴテンクス
悟飯 - RESCUE -> 悟空
悟空 - FRIEND -> 悟飯
悟空 - FOUGHT 

In [20]:
def retriever(question: str):
    """Build the LLM context for ``question`` from both retrieval paths.

    Logs the query, then combines the graph relationships from
    ``structured_retriever`` with the page contents returned by
    ``vector_index.similarity_search`` (joined with ``"#Document "``).

    Returns:
        A single context string with "Structured data:" and
        "Unstructured data:" sections.
    """
    print(f"Search query: {question}")
    graph_context = structured_retriever(question)
    hits = vector_index.similarity_search(question)
    joined_chunks = "#Document ".join(doc.page_content for doc in hits)
    return (
        "Structured data:\n"
        f"    {graph_context}\n"
        "    Unstructured data:\n"
        f"    {joined_chunks}\n"
        "    "
    )

In [21]:
# Pull the question string out of the chain's input dict ({"question": ...}).
_search_query = RunnableLambda(lambda x: x["question"])

In [22]:
# QA prompt (Japanese): "You are an excellent AI. Use the context below to answer
# the user's question politely. Generate the answer using only information that
# can be inferred from the context."
template = """あなたは優秀なAIです。下記のコンテキストを利用してユーザーの質問に丁寧に答えてください。
必ず文脈からわかる情報のみを使用して回答を生成してください。
{context}

ユーザーの質問: {question}"""
prompt = ChatPromptTemplate.from_template(template)

# Invoked with {"question": ...}; returns the LLM's answer as a plain string.
chain = (
    RunnableParallel(
        {
            # Graph relationships + similar text chunks for the question.
            "context": _search_query | retriever,
            # Pass only the question string. RunnablePassthrough() would forward
            # the whole input dict, and the prompt would render "{'question': ...}".
            "question": _search_query,
        }
    )
    | prompt
    | llm
    | StrOutputParser()
)

In [24]:
# End-to-end run: hybrid graph + vector context, then the LLM answer
# ("Who is Goku on good terms with?").
# NOTE(review): the question spells 梧空 rather than 悟空 — confirm whether this
# is intentional (e.g. to exercise fuzzy matching).
chain.invoke({"question": "梧空と仲がいいのは誰ですか？"})

Search query: 梧空と仲がいいのは誰ですか？


'悟空と仲がいいのは、悟飯、クリリン、悟飯、ブルマ、そして界王神です。'