In [1]:
import nest_asyncio
import os
import pickle
import re
import spacy

from alive_progress import alive_bar
from fastembed import TextEmbedding
from langchain_anthropic import ChatAnthropic
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index.core import Settings, Document, PropertyGraphIndex 
from llama_index.core.node_parser import MarkdownElementNodeParser
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.extractors.relik.base import RelikPathExtractor
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
from llama_index.llms.anthropic import Anthropic
from llama_parse import LlamaParse
from typing import List, Optional, Tuple

In [2]:
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
CLAUDE_API_KEY = os.getenv('CLAUDE_API_KEY')
LLAMA_API_KEY = os.getenv('LLAMA_API_KEY')

os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
os.environ["ANTHROPIC_API_KEY"] = CLAUDE_API_KEY
os.environ["LLAMA_CLOUD_API_KEY"] = LLAMA_API_KEY

In [3]:
llm = ChatAnthropic(
    model="claude-3-5-sonnet-20240620",
    max_tokens=4096,
    temperature=0.0,
    stop=["\n\nHuman"],
)

llama_llm = Anthropic(
    model="claude-3-5-sonnet-20240620",
    max_tokens=4096,
    temperature=0.0
)

In [4]:
bge_embed_model = TextEmbedding(model_name="BAAI/bge-large-en-v1.5")
llama_openai_embed_model = OpenAIEmbedding(model_name="text-embedding-3-small")

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

In [5]:
relik = RelikPathExtractor(
    model="relik-ie/relik-relation-extraction-small"
)

                ___              __         
               /\_ \      __    /\ \        
 _ __     __   \//\ \    /\_\   \ \ \/'\    
/\`'__\ /'__`\   \ \ \   \/\ \   \ \ , <    
\ \ \/ /\  __/    \_\ \_  \ \ \   \ \ \\`\  
 \ \_\ \ \____\   /\____\  \ \_\   \ \_\ \_\
  \/_/  \/____/   \/____/   \/_/    \/_/\/_/
                                            
                                            







  embeddings = torch.load(embedding_path, map_location="cpu")


In [6]:
coref_nlp = spacy.load('en_core_web_lg')
coref_nlp.add_pipe('coreferee')



<coreferee.manager.CorefereeBroker at 0x3d32d9990>

In [32]:
# instantiate doc parser
parser = LlamaParse(
    result_type="markdown",
    num_workers=8,
    verbose = False,
    show_progress=True,
    ignore_errors=True,
    language="en",
)

# instantiate node parser
node_parser = MarkdownElementNodeParser(llm=llama_llm, num_workers=8)

In [8]:
nest_asyncio.apply()

In [33]:
LONG_CHUNK_SIZE = 2000

def coref_text(text):
    coref_doc = coref_nlp(text.strip())
    resolved_text = ""

    for token in coref_doc:
        repres = coref_doc._.coref_chains.resolve(token)
        if repres:
            resolved_text += " " + " and ".join(
                [
                    t.text
                    if t.ent_type_ == ""
                    else [e.text for e in coref_doc.ents if t in e][0]
                    for t in repres
                ]
            )
        else:
            resolved_text += " " + token.text

    return resolved_text.strip()

def remove_table_of_contents(text):
    pattern = r"TABLE OF CONTENTS.*?(?=#)"
    cleaned_text = re.sub(pattern, "", text, flags=re.DOTALL)
    return cleaned_text.strip()

def convert_nodes_to_documents(text_nodes, object_nodes, source):
    """
    Converts nodes to Documents

    Args:
        text_nodes (List[Nodes]): List of text nodes
        object_nodes (List[Nodes]): List of object nodes
        source (str): Source of the document

    Returns:
        documents (List[Documents]): List of Documents
    """
    documents = []
    for node in text_nodes:
        text = coref_text(node.text).lower()
        doc = Document(
            text= text,
            metadata = {
                "is_table": False,
                "source": source
            }
        )
        documents.append(doc)
        
    for node in object_nodes:
        text = coref_text(node.text).lower()
        doc = Document(
            text= text,
            metadata = {
                "is_table": True,
                "source": source
            }
        )
        documents.append(doc)
        
    return documents

def make_dir(data_folder):
    os.makedirs(data_folder, exist_ok=True)

def parse_docs(file_location: str, data_folder: Optional[str] = None) -> List[Document]:
    """
    Parses PDF Folder and returns a list of Documents

    Args:
        file_location (str): PDF Folder Location
        data_folder (Optional[str], optional): Folder to save pickles (Optional). Defaults to None.

    Returns:
        List[Document]: _description_
    """
    all_docs = []
    for file_name in os.listdir(file_location):
        if not file_name.endswith(".pdf"):
            continue

        print("File: " + str(file_name))
        doc_path = os.path.join(file_location, file_name)
        modified_file_name = os.path.splitext(file_name)[0].lower().replace(' ', '_')

        # results in a list of Document Objects
        documents = parser.load_data(doc_path)
        
        for idx, doc in enumerate(documents):
            doc.text = remove_table_of_contents(doc.text)
            if idx > 4:
                break

        raw_nodes = node_parser.get_nodes_from_documents(documents)
        # list of text_nodes, list of objects
        text_nodes, objects = node_parser.get_nodes_and_objects(raw_nodes)
        
        final_docs = convert_nodes_to_documents(text_nodes, objects, modified_file_name)
        all_docs.append(final_docs)
        
        if data_folder:
            data_path = os.path.join(data_folder, modified_file_name + '.pkl')
            pickle.dump(final_docs, open(data_path, "wb"))
    
    return [item for sublist in all_docs for item in sublist]

def read_pickles(data_folder: str) -> List[Document]:
    doc_list = []
    for file_name in os.listdir(data_folder):
        doc_path = os.path.join(data_folder, file_name)
        if file_name.endswith(".pkl"):
            with open(doc_path, 'rb') as file:
                # data will be a doc_list
                data = pickle.load(file)
                doc_list.append(data)
                
    # since doc_list is a list of list of documents, we need to flatten it
    doc_list = [item for sublist in doc_list for item in sublist]
    return doc_list

def further_split_long_docs(doc_list: List[Document]) -> Tuple[List[Document], List[Document]]:
    long_docs, short_docs = [], []
    for doc in doc_list:
        is_table = doc.metadata["is_table"]
        if not is_table:
            if len(doc.text) > LONG_CHUNK_SIZE:
                long_docs.append(doc)
            else:
                short_docs.append(doc)
        else:
            short_docs.append(doc)
    return long_docs, short_docs
                
def chunk_doc(doc: Document, text_splitter: RecursiveCharacterTextSplitter) -> List[Document]:
    chunks = text_splitter.split_text(doc.text)
    return [
        Document(
            text=chunk,
            metadata={
                'is_table': doc.metadata['is_table'],
                'source': doc.metadata.get('source', '')
            }
        )
        for i, chunk in enumerate(chunks)
    ]
    
def recursive_chunk_documents(long_docs: List[Document],
                              short_docs: List[Document], 
                              chunk_size: int = 1024, 
                              chunk_overlap: int = 128,
                              separators: List[str] = ["\n\n", "\n", " ", ""]) -> List[Document]:
    
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=separators
    )

    for doc in long_docs:
        short_docs.extend(chunk_doc(doc, text_splitter))

    return short_docs

def get_final_docs(data_folder: Optional[str] = None, doc_list: Optional[List[Document]] = None) -> List[Document]:
    if doc_list is None:
        if data_folder is None:
            raise ValueError("Either data_folder or doc_list must be provided")
        doc_list = read_pickles(data_folder)
    
    long_docs, short_docs = further_split_long_docs(doc_list)
    final_docs = recursive_chunk_documents(long_docs, short_docs)
    return final_docs
        
def parse_and_process_docs(file_location, data_folder: Optional[str] = None) -> List[Document]:
    if data_folder:
        make_dir(data_folder)
        all_docs = parse_docs(file_location=file_location, data_folder=data_folder)
    else:
        all_docs = parse_docs(file_location=file_location)
        
    final_docs = get_final_docs(doc_list=all_docs)
    return final_docs

In [34]:
final_docs = parse_and_process_docs(file_location="pdfs")
pickle.dump(final_docs, open('data/final_docs.pkl', "wb"))

File: Diabetes Medications.pdf


0it [00:00, ?it/s]
1it [00:00, 22795.13it/s]
0it [00:00, ?it/s]


File: managing-pre-diabetes-(updated-on-27-jul-2021)c2bfc77474154c2abf623156a4b93002.pdf


0it [00:00, ?it/s]
1it [00:00, 22192.08it/s]
0it [00:00, ?it/s]
1it [00:00, 15592.21it/s]
3it [00:00, 59918.63it/s]
0it [00:00, ?it/s]


File: Diabetic Foot Ulcer_ Symptoms and Treatment.pdf


0it [00:00, ?it/s]
0it [00:00, ?it/s]


File: Diabetes Treatment_ Insulin.pdf


0it [00:00, ?it/s]
0it [00:00, ?it/s]
0it [00:00, ?it/s]
1it [00:00, 16710.37it/s]


In [40]:
final_docs

[Document(id_='6b7e062c-7f1d-47ea-824b-0bd2c11f822e', embedding=None, metadata={'is_table': False, 'source': 'diabetes_medications'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='diabetes medications \n\n learn about the different diabetes treatment pills available , the guidelines when taking guidelines and more . \n\n  treatment of diabetes mellitus : tablets \n\n treatment of type 2 diabetes begins with diet control . if diet alone is unable to control blood sugar levels , then tablets have to be taken . if both diet and tablets fail to control the blood sugar levels , insulin injections may be needed . \n\n oral anti - diabetes tablets are used for treating type 2 diabetes . \n\n  tell your doctor if you \n\n - are allergic to any medicines \n - are pregnant , or intend to become pregnant \n - are breastfeeding \n - are taking any other medicines \n - have any other medical problems \n\n  some guidelines to follow when taking your diabetes

In [36]:
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = "15082001"
NEO4J_DATABASE = "neo4j"

graph_store = Neo4jPropertyGraphStore(
    username=NEO4J_USER,
    password=NEO4J_PASSWORD,
    url=NEO4J_URI,
    refresh_schema=False,
)

# gds = GraphDataScience(NEO4J_URI, database=NEO4J_DATABASE, auth=(NEO4J_USER, NEO4J_PASSWORD))

In [37]:
def delete_all_nodes(graph_store):
    graph_store.structured_query("""
    MATCH (n)
    DETACH DELETE n
    """)
    print("All nodes deleted")

In [38]:
delete_all_nodes(graph_store)

All nodes deleted


In [41]:
def remove_all_neo4j_restrictions(graph_store):
    graph_store.structured_query("""
    CALL apoc.schema.assert({}, {});
    """)

In [42]:
index = PropertyGraphIndex.from_documents(
    final_docs,
    kg_extractors=[relik],
    llm=llama_llm,
    embed_model=llama_openai_embed_model,
    property_graph_store=graph_store,
    show_progress=True,
)

Parsing nodes:   0%|          | 0/38 [00:00<?, ?it/s]

Extracting triples: 100%|██████████| 38/38 [09:31<00:00, 15.04s/it]
Generating embeddings: 100%|██████████| 1/1 [00:01<00:00,  1.17s/it]
Generating embeddings: 100%|██████████| 6/6 [00:02<00:00,  2.47it/s]


## Graph De-duplication

In [43]:
def create_vector_index(graph_store):
    graph_store.structured_query("""
    CREATE VECTOR INDEX entity IF NOT EXISTS
    FOR (m:`__Entity__`)
    ON m.embedding
    OPTIONS {indexConfig: {
    `vector.dimensions`: 1536,
    `vector.similarity_function`: 'cosine'
    }}
    """)

In [44]:
create_vector_index(graph_store)

In [45]:
def check_graph_deduplication(graph_store, similarity_threshold = 0.90, word_edit_distance = 5):
    data = graph_store.structured_query("""
    MATCH (e:__Entity__)
    CALL {
    WITH e
    CALL db.index.vector.queryNodes('entity', 10, e.embedding)
    YIELD node, score
    WITH node, score
    WHERE score > toFLoat($cutoff)
        AND (toLower(node.name) CONTAINS toLower(e.name) OR toLower(e.name) CONTAINS toLower(node.name)
            OR apoc.text.distance(toLower(node.name), toLower(e.name)) < $distance)
        AND labels(e) = labels(node)
    WITH node, score
    ORDER BY node.name
    RETURN collect(node) AS nodes
    }
    WITH distinct nodes
    WHERE size(nodes) > 1
    WITH collect([n in nodes | n.name]) AS results
    UNWIND range(0, size(results)-1, 1) as index
    WITH results, index, results[index] as result
    WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
            CASE WHEN index <> index2 AND
                size(apoc.coll.intersection(acc, results[index2])) > 0
                THEN apoc.coll.union(acc, results[index2])
                ELSE acc
            END
    )) as combinedResult
    WITH distinct(combinedResult) as combinedResult
    // extra filtering
    WITH collect(combinedResult) as allCombinedResults
    UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
    WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
    WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1) 
        WHERE x <> combinedResultIndex
        AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
    )
    RETURN combinedResult  
    """, param_map={'cutoff': similarity_threshold, 'distance': word_edit_distance})
    for row in data:
        print(row)

In [46]:
check_graph_deduplication(graph_store, similarity_threshold = 0.90)

{'combinedResult': ['tablet(s', 'the tablet(s )']}
{'combinedResult': ['blood sugar', 'blood sugar level', 'blood sugar levels', 'low blood sugar levels']}
{'combinedResult': ['glucophage', 'glucose', 'glucose metabolism', 'glucose tolerance', 'impaired glucose tolerance', 'oral glucose tolerance test', 'the glucose']}
{'combinedResult': ['glibenclamide', 'gliclazide']}
{'combinedResult': ['km / hr', 'km / hr ) \n\n & gt;=', 'km / hr ) \n\n & gt;= 75']}
{'combinedResult': ['vigorous - intensity', 'vigorous - intensity exercise']}
{'combinedResult': ['smokers', 'smoking', 'smoking cessation']}
{'combinedResult': ['ingrown', 'ingrown nails']}
{'combinedResult': ['doctor', 'food', 'your']}
{'combinedResult': ['body cells', 'cells', 'nurse', 'units']}
{'combinedResult': ['a', 'anti - diabetes', 'diabetes', 'diabetes mellitus', 'diabetes_treatment__insulin', 'diabetic', 'diabetic foot ulcer', 'diabetic ulcer', 'diabetic ulcers', 'diabetic_foot_ulcer__symptoms_and_treatment', 'humulin', 'ins

In [51]:
check_graph_deduplication(graph_store, similarity_threshold = 0.954)

{'combinedResult': ['blood sugar', 'blood sugar level', 'blood sugar levels']}
{'combinedResult': ['glucose', 'the glucose']}
{'combinedResult': ['pre - diabetes', 'pre - diabetes \n\n  1', 'pre - diabetes.†']}
{'combinedResult': ['2021', '27 july 2021', '3 july 2017', 'a', 'ace', 'igt', 'jul-2021)c2bfc77474154c2abf623156a4b93002', 'july 2017', 'ms.', 'qr']}
{'combinedResult': ['diabetes mellitus', 'type 2 diabetes mellitus']}
{'combinedResult': ['lifestyle', 'lifestyle changes', 'lifestyle intervention', 'lifestyle intervention.9,15']}
{'combinedResult': ['republic of singapore', 'singapore', 'singaporeans']}
{'combinedResult': ['obese', 'obesity']}
{'combinedResult': ['diabetes-(updated', 'diabetes-(updated-on-27-jul-2021)c2bfc77474154c2abf623156a4b93002']}
{'combinedResult': ['km / hr', 'km / hr ) \n\n & gt;=', 'km / hr ) \n\n & gt;= 75']}
{'combinedResult': ['vigorous - intensity', 'vigorous - intensity exercise']}
{'combinedResult': ['diabetic ulcer', 'diabetic ulcers']}
{'combine

In [52]:
def graph_deduplication(graph_store, similarity_threshold = 0.90, word_edit_distance = 5):
    graph_store.structured_query("""
        MATCH (e:__Entity__)
        CALL {
        WITH e
        CALL db.index.vector.queryNodes('entity', 10, e.embedding)
        YIELD node, score
        WITH node, score
        WHERE score > toFLoat($cutoff)
            AND (toLower(node.name) CONTAINS toLower(e.name) OR toLower(e.name) CONTAINS toLower(node.name)
                OR apoc.text.distance(toLower(node.name), toLower(e.name)) < $distance)
            AND labels(e) = labels(node)
        WITH node, score
        ORDER BY node.name
        RETURN collect(node) AS nodes
        }
        WITH distinct nodes
        WHERE size(nodes) > 1
        WITH collect([n in nodes | n.name]) AS results
        UNWIND range(0, size(results)-1, 1) as index
        WITH results, index, results[index] as result
        WITH apoc.coll.sort(reduce(acc = result, index2 IN range(0, size(results)-1, 1) |
                CASE WHEN index <> index2 AND
                    size(apoc.coll.intersection(acc, results[index2])) > 0
                    THEN apoc.coll.union(acc, results[index2])
                    ELSE acc
                END
        )) as combinedResult
        WITH distinct(combinedResult) as combinedResult
        // extra filtering
        WITH collect(combinedResult) as allCombinedResults
        UNWIND range(0, size(allCombinedResults)-1, 1) as combinedResultIndex
        WITH allCombinedResults[combinedResultIndex] as combinedResult, combinedResultIndex, allCombinedResults
        WHERE NOT any(x IN range(0,size(allCombinedResults)-1,1)
            WHERE x <> combinedResultIndex
            AND apoc.coll.containsAll(allCombinedResults[x], combinedResult)
        )
        CALL {
        WITH combinedResult
            UNWIND combinedResult AS name
            MATCH (e:__Entity__ {name:name})
            WITH e
            ORDER BY size(e.name) DESC // prefer longer names to remain after merging
            RETURN collect(e) AS nodes
        }
        CALL apoc.refactor.mergeNodes(nodes, {properties: {
            `.*`: 'discard'
        }})
        YIELD node
        RETURN count(*)
        """, param_map={'cutoff': similarity_threshold, 'distance': word_edit_distance}
    )

In [53]:
graph_deduplication(graph_store, similarity_threshold = 0.954)