In [1]:
import re
import os
from langchain.vectorstores.neo4j_vector import Neo4jVector
from langchain.document_loaders import WikipediaLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import (
    CharacterTextSplitter,
    RecursiveCharacterTextSplitter,
)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from secret import *
from langchain.graphs.graph_document import (
    GraphDocument,
    Node as BaseNode,
    Relationship as BaseRelationship,
)
from typing import List, Dict, Any, Optional
from langchain.pydantic_v1 import Field, BaseModel
import chromadb

# Publish the OpenAI key through the process environment.
# NOTE(review): OPENAI_API_KEY is presumably supplied by `from secret import *` above — confirm.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

In [2]:
# Key/value pair model; used as the element type of the `properties` lists
# declared on Node and Relationship below.
# NOTE(review): the class docstring is kept verbatim because pydantic may surface
# it as the schema description — confirm before rewording it.
class Property(BaseModel):
    """A single property consisting of key and value"""

    key: str = Field(default=..., description="key")
    value: str = Field(default=..., description="value")


# Node variant used in the extraction schema: declares `properties` as an
# optional list of Property items (defaulting to None).
class Node(BaseNode):
    properties: Optional[List[Property]] = Field(default=None, description="List of node properties")


# Relationship variant used in the extraction schema: declares `properties` as
# an optional list of Property items (defaulting to None).
class Relationship(BaseRelationship):
    properties: Optional[List[Property]] = Field(default=None, description="List of relationship properties")


# Top-level extraction schema: get_extraction_chain hands this class to
# create_structured_output_chain as the expected output structure.
class KnowledgeGraph(BaseModel):
    """Generate a knowledge graph with entities and relationships."""

    nodes: List[Node] = Field(default=..., description="List of nodes in the knowledge graph")
    rels: List[Relationship] = Field(default=..., description="List of relationships in the knowledge graph")

In [None]:
# ChatOpenAI is otherwise imported only by a later cell; import it here too so
# this cell does not raise NameError when the notebook is run top to bottom on
# a fresh kernel.
from langchain.chat_models import ChatOpenAI

# Chat model handle; temperature=0 is passed explicitly.
llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)

In [3]:
from langchain.prompts.chat import ChatPromptTemplate
from langchain.chains.openai_functions import create_structured_output_chain
from langchain.chat_models import ChatOpenAI

# Module-level chat model; get_extraction_chain below passes it to
# create_structured_output_chain.
llm = ChatOpenAI(model="gpt-3.5-turbo-16k", temperature=0)


def get_extraction_chain(
    allowed_nodes: Optional[List[str]] = None, allowed_rels: Optional[List[str]] = None
):
    """Build the structured-output chain that extracts a KnowledgeGraph from text.

    Args:
        allowed_nodes: Optional node labels to list in the system prompt. When
            None or empty, the corresponding prompt line is left empty.
        allowed_rels: Optional relationship types to list in the system prompt.
            When None or empty, the corresponding prompt line is left empty.

    Returns:
        Whatever ``create_structured_output_chain`` returns for the
        ``KnowledgeGraph`` schema, the module-level ``llm`` and the prompt
        assembled here (called with ``verbose=False``).
    """
    # Optional prompt lines: each one collapses to "" when no list was supplied.
    node_label_rule = (
        "- **Allowed Node Labels:**" + ", ".join(allowed_nodes) if allowed_nodes else ""
    )
    rel_type_rule = (
        "- **Allowed Relationship Types**:" + ", ".join(allowed_rels) if allowed_rels else ""
    )

    # The system message is an f-string, so the two optional lines are filled in
    # here, before the text reaches ChatPromptTemplate.
    system_text = f"""# Knowledge Graph Instructions for GPT
## 1. Overview
You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph.
- **Nodes** represent entities and concepts. They're akin to Wikipedia nodes.
- The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience.
## 2. Labeling Nodes
- **Consistency**: Ensure you use basic or elementary types for node labels.
  - For example, when you identify an entity representing a person, always label it as **"person"**. Avoid using more specific terms like "mathematician" or "scientist".
- **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text.
{node_label_rule}
{rel_type_rule}
## 3. Handling Numerical Data and Dates
- Numerical data, like age or other related information, should be incorporated as attributes or properties of the respective nodes.
- **No Separate Nodes for Dates/Numbers**: Do not create separate nodes for dates or numerical values. Always attach them as attributes or properties of nodes.
- **Property Format**: Properties must be in a key-value format.
- **Quotation Marks**: Never use escaped single or double quotes within property values.
- **Naming Convention**: Use camelCase for property keys, e.g., `birthDate`.
## 4. Coreference Resolution
- **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency.
If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he"), 
always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the entity ID.  
Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. 
## 5. Strict Compliance
Adhere to the rules strictly. Non-compliance will result in termination."""

    # `{input}` in the first human message is a template variable, not an f-string field.
    chat_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_text),
            (
                "human",
                "Use the given format to extract information from the following input: {input}",
            ),
            ("human", "Tip: Make sure to answer in the correct format"),
        ]
    )
    return create_structured_output_chain(KnowledgeGraph, llm, chat_prompt, verbose=False)

In [4]:
def format_property_key(s: str) -> str:
    """Convert a whitespace-separated property key to camelCase.

    Words written in a single case ("birth date", "BIRTH DATE") are normalised:
    the first word is lower-cased and every following word is capitalised.
    Words that already mix upper and lower case ("birthDate", "X-rays") keep
    their interior casing and only have their first character adjusted, so a
    key that is already camelCase is not flattened to lowercase.

    Args:
        s: The raw property key.

    Returns:
        The camelCase key, or ``s`` unchanged when it contains no words
        (empty or whitespace-only input).
    """
    words = s.split()
    if not words:
        return s

    def _keeps_inner_case(word: str) -> bool:
        # True for mixed-case words; also true for words without cased
        # characters, where the case adjustments below are no-ops.
        return not (word.islower() or word.isupper())

    head = words[0]
    if _keeps_inner_case(head):
        parts = [head[0].lower() + head[1:]]
    else:
        parts = [head.lower()]

    for word in words[1:]:
        if _keeps_inner_case(word):
            parts.append(word[0].upper() + word[1:])
        else:
            parts.append(word.capitalize())

    return "".join(parts)

def props_to_dict(props) -> dict:
    """Build a plain dict from an iterable of key/value property objects.

    Each item's ``key`` is passed through ``format_property_key`` and mapped to
    the item's ``value``. A falsy ``props`` (None or empty) yields an empty dict.
    """
    if not props:
        return {}
    return {format_property_key(item.key): item.value for item in props}

def map_to_base_node(node: Node) -> BaseNode:
    """Convert an extracted ``Node`` into a ``BaseNode``.

    The node id is title-cased and used both as the ``BaseNode`` id and as an
    extra ``name`` property; the node type is capitalised.
    """
    node_props = props_to_dict(node.properties) if node.properties else {}
    display_name = node.id.title()
    # A `name` property is added for better Cypher statement generation.
    node_props["name"] = display_name
    return BaseNode(id=display_name, type=node.type.capitalize(), properties=node_props)


def map_to_base_relationship(rel: Relationship) -> BaseRelationship:
    """Convert an extracted ``Relationship`` into a ``BaseRelationship``.

    Both endpoints go through ``map_to_base_node``; the relationship type is
    passed through unchanged and the properties are converted to a dict.
    """
    return BaseRelationship(
        source=map_to_base_node(rel.source),
        target=map_to_base_node(rel.target),
        type=rel.type,
        properties=props_to_dict(rel.properties) if rel.properties else {},
    )

In [5]:
# from langchain.schema.document import Document
# from langchain.graphs import Neo4jGraph

# def extract_and_store_graph(
#     graph: Neo4jGraph,
#     document: Document,
#     nodes: Optional[List[str]] = None,
#     rels: Optional[List[str]] = None,
# ) -> None:
#     # Extract graph data using OpenAI functions
#     extract_chain = get_extraction_chain(nodes, rels)
#     # return extract_chain
#     data = extract_chain.run(document.page_content)
#     # Construct a graph document
#     graph_document = GraphDocument(
#         nodes=[map_to_base_node(node) for node in data.nodes],
#         relationships=[map_to_base_relationship(rel) for rel in data.rels],
#         source=document,
#     )
#     # Store information into a graph
#     graph.add_graph_documents([graph_document])

In [6]:
from langchain.document_loaders import WikipediaLoader
from langchain.text_splitter import TokenTextSplitter
# Load Wikipedia content for the query "Atelectasis".
raw_documents = WikipediaLoader(query="Atelectasis").load()
# Splitter configured with chunk_size=2048 and chunk_overlap=24
# (presumably measured in tokens, given TokenTextSplitter — confirm).
text_splitter = TokenTextSplitter(chunk_size=2048, chunk_overlap=24)
# `documents` holds the resulting chunks; the extraction loop further down iterates over it.
documents = text_splitter.split_documents(raw_documents)

In [7]:
from langchain.graphs import Neo4jGraph
# Connect to the graph database as user "neo4j".
# NOTE(review): NOE4J_URL / NOE4J_PASSWORD are presumably supplied by `from secret import *`;
# the "NOE4J" spelling looks like a typo for "NEO4J" — confirm against the secret module
# before renaming, since the names must match there.
graph = Neo4jGraph(url=NOE4J_URL, username="neo4j", password=NOE4J_PASSWORD)
# Print the schema as seen right after connecting.
print(graph.schema)


        Node properties are the following:
        []
        Relationship properties are the following:
        []
        The relationships are the following:
        []
        


In [8]:
from pydantic.v1.error_wrappers import ValidationError

In [9]:
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
    retry_if_exception_type
)  # for exponential backoff
from http.client import SERVICE_UNAVAILABLE
from pydantic.v1.error_wrappers import ValidationError

# retry_if_exception retries on an arbitrary predicate over the raised exception;
# tenacity is already used by this cell.
from tenacity import retry_if_exception


def _is_service_unavailable(exc: BaseException) -> bool:
    """Return True when ``exc`` is an instance of a class named ``ServiceUnavailable``.

    ``http.client.SERVICE_UNAVAILABLE`` is an HTTP status constant, not an
    exception class, so it cannot be given to ``retry_if_exception_type``: the
    ``isinstance`` check performed on failure raises ``TypeError`` and hides the
    real error instead of retrying. The exception is matched by class name so
    this cell does not need to import the database driver directly.
    NOTE(review): presumably the intended error is the Neo4j driver's
    ``ServiceUnavailable`` — confirm.
    """
    return any(klass.__name__ == "ServiceUnavailable" for klass in type(exc).__mro__)


@retry(
    retry=retry_if_exception(_is_service_unavailable),
    wait=wait_random_exponential(multiplier=1, max=60),
    stop=stop_after_attempt(10),
)
def add_graph_document(graph, graph_document):
    """Store a single graph document in ``graph`` and return ``graph``.

    The call is retried (up to 10 attempts, randomized exponential wait with
    max=60) when the raised exception matches ``_is_service_unavailable``.

    Args:
        graph: Object exposing ``add_graph_documents``.
        graph_document: The document to store; it is wrapped in a one-item list.

    Returns:
        The same ``graph`` object that was passed in.
    """
    graph.add_graph_documents([graph_document])
    return graph


@retry(
    retry=retry_if_exception(_is_service_unavailable),
    wait=wait_random_exponential(multiplier=1, max=60),
    stop=stop_after_attempt(10),
)
def add_graph_documents(graph, graph_documents):
    """Store several graph documents in ``graph`` in one call and return ``graph``.

    Uses the same retry policy as ``add_graph_document``: up to 10 attempts
    with randomized exponential wait (max=60) when the raised exception matches
    ``_is_service_unavailable``.

    Args:
        graph: Object exposing ``add_graph_documents``.
        graph_documents: The documents to store, passed through unchanged.

    Returns:
        The same ``graph`` object that was passed in.
    """
    graph.add_graph_documents(graph_documents)
    return graph

@retry(
    stop=stop_after_attempt(10),
    wait=wait_random_exponential(multiplier=1, max=60),
    retry=retry_if_exception_type(ValidationError),
)
def chain_run(chain, content):
    """Run ``chain`` on ``content`` and return its result.

    The call is retried (up to 10 attempts, randomized exponential wait with
    max=60) whenever it raises ``ValidationError``.
    """
    result = chain.run(content)
    return result


In [10]:
from tqdm import tqdm

# allowed_nodes = None
# Labels and relationship types that get listed in the extraction prompt.
allowed_nodes = ["Symptom", "Disease"]
allowed_rels = ["CAN_CAUSE", "DESCRIBE", "HAS"]

# One chain is built up front and reused for every chunk.
extract_chain = get_extraction_chain(allowed_nodes, allowed_rels)
# Accumulates one GraphDocument per chunk; nothing is written to the graph until after the loop.
gds = []

for d in tqdm(documents, total=len(documents)):
    # chain_run retries the extraction when a ValidationError is raised.
    data = chain_run(extract_chain, d.page_content)
    # data = extract_chain.run(d.page_content)
    # Convert the extracted nodes/relationships to the base graph types and keep
    # the originating chunk as the document source.
    graph_document = GraphDocument(
        nodes=[map_to_base_node(node) for node in data.nodes],
        relationships=[map_to_base_relationship(rel) for rel in data.rels],
        source=d,
    )
    # add_graph_document(graph, graph_document)
    gds.append(graph_document)

# Single batched write of everything collected above. As the last expression of
# the cell, the returned graph object is also what the cell displays.
add_graph_documents(graph, gds)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [06:15<00:00, 37.59s/it]


<langchain.graphs.neo4j_graph.Neo4jGraph at 0x200d51c1e80>

In [16]:
# Re-read the schema after the graph documents were added, then print it.
graph.refresh_schema()
print(graph.schema)


        Node properties are the following:
        [{'labels': 'Condition', 'properties': [{'property': 'name', 'type': 'STRING'}, {'property': 'id', 'type': 'STRING'}, {'property': 'description', 'type': 'STRING'}]}, {'labels': 'Symptom', 'properties': [{'property': 'name', 'type': 'STRING'}, {'property': 'definition', 'type': 'STRING'}, {'property': 'id', 'type': 'STRING'}, {'property': 'description', 'type': 'STRING'}]}, {'labels': 'Cause', 'properties': [{'property': 'id', 'type': 'STRING'}, {'property': 'name', 'type': 'STRING'}]}, {'labels': 'Diagnosis', 'properties': [{'property': 'id', 'type': 'STRING'}, {'property': 'name', 'type': 'STRING'}]}, {'labels': 'Anatomy', 'properties': [{'property': 'name', 'type': 'STRING'}, {'property': 'definition', 'type': 'STRING'}, {'property': 'id', 'type': 'STRING'}, {'property': 'description', 'type': 'STRING'}]}, {'labels': 'Measurement', 'properties': [{'property': 'name', 'type': 'STRING'}, {'property': 'id', 'type': 'STRING'}, {'proper

In [17]:
# Query the knowledge graph in a RAG application
from langchain.chains import GraphCypherQAChain

# Question-answering chain over `graph`. Separate ChatOpenAI instances with the
# same settings (temperature=0, gpt-3.5-turbo-16k) are passed for Cypher
# generation (`cypher_llm`) and for answering (`qa_llm`).
cypher_chain = GraphCypherQAChain.from_llm(
    graph=graph,
    cypher_llm=ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k"),
    qa_llm=ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k"),
    validate_cypher=True,  # Validate relationship directions
    verbose=True,  # the recorded outputs below show the generated Cypher and the retrieved context
)

In [18]:
# Definitional question; the chain was built with verbose=True, so the generated
# Cypher and the retrieved context are printed along with the answer.
cypher_chain.run("What is Atelectasis?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (d:Disease {name: 'Atelectasis'})
RETURN d[0m
Full Context:
[32;1m[1;3m[{'d': {'pulmonaryConsolidation': 'It is distinct from pulmonary consolidation, in which the alveoli are filled with liquid.', 'riskFactors': 'Type of surgery, use of muscle relaxation, obesity, high oxygen, lower lung segments, age, chronic obstructive pulmonary disease or asthma, type of anesthetic', 'alveoli': 'The alveoli are deflated down to little or no volume.', 'chestX-rays': 'It is a very common finding in chest X-rays and other radiological studies.', 'causes': 'Post-surgical atelectasis, pulmonary tuberculosis, blockage of bronchiole or bronchus, poor surfactant spreading, suction', 'name': 'Atelectasis', 'diagnosis': 'Clinically significant atelectasis is generally visible on chest X-ray; findings can include lung opacification and/or loss of lung volume. Post-surgical atelectasis will be bibasal in pattern. Ches

'Atelectasis is the collapse or closure of a lung, which leads to reduced or absent gas exchange. It is usually unilateral, affecting part or all of one lung. Clinically significant atelectasis can be seen on chest X-rays, where findings may include lung opacification and/or loss of lung volume. Some common causes of atelectasis include post-surgical atelectasis, pulmonary tuberculosis, blockage of bronchiole or bronchus, poor surfactant spreading, and suction. It is distinct from pulmonary consolidation, where the alveoli are filled with liquid. Risk factors for atelectasis include type of surgery, use of muscle relaxation, obesity, high oxygen, lower lung segments, age, chronic obstructive pulmonary disease or asthma, and type of anesthetic. If the cause of atelectasis is not clinically apparent, further diagnostic tests such as chest CT or bronchoscopy may be necessary.'

In [14]:
# NOTE(review): in the recorded run the generated query followed a DESCRIBE
# relationship from the Disease node and the retrieved context was empty, so the
# chain could not answer.
cypher_chain.run("What are the symptoms of Atelectasis?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (:Disease {name: 'Atelectasis'})-[:DESCRIBE]->(s:Symptom)
RETURN s.name[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


"I'm sorry, but I don't have the information to answer your question. It would be best to consult a medical professional for accurate information on the symptoms of Atelectasis."

In [15]:
# NOTE(review): in the recorded run the retrieved context was empty, yet an answer
# was still produced — it does not appear to be grounded in the graph. Verify
# before relying on it.
cypher_chain.run("What can cause Atelectasis?")



[1m> Entering new GraphCypherQAChain chain...[0m
Generated Cypher:
[32;1m[1;3mMATCH (:Condition {name: 'Atelectasis'})-[:CAN_CAUSE]->(cause:Condition)
RETURN cause.name[0m
Full Context:
[32;1m[1;3m[][0m

[1m> Finished chain.[0m


'There are several factors that can cause atelectasis. Some common causes include blockage of the airways due to mucus, a foreign object, or a tumor, as well as lung diseases such as chronic obstructive pulmonary disease (COPD) or pneumonia. Additionally, certain medical procedures, such as anesthesia or prolonged bed rest, can also contribute to the development of atelectasis. It is important to consult with a healthcare professional for a proper diagnosis and treatment plan.'