In [9]:
import json
from neo4j import GraphDatabase
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from dotenv import load_dotenv
import os
import re
import time
import hashlib

In [10]:
dotenv_loaded = load_dotenv()
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
#NEO4J_URI = os.getenv("NEO4J_URI")
#NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
#NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
#NEO4J_DATABASE = os.getenv("NEO4J_DATABASE")

In [11]:
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "cancer_monograph"
NEO4J_URI = "neo4j://127.0.0.1:7687"
NEO4J_DATABASE = "monograph"

In [12]:
def get_extraction_chain():
    llm = ChatGroq(temperature=0, model_name="llama-3.1-8b-instant")

    prompt_template = """
    You are an expert in oncology and data extraction. From the following text, extract entities and relationships
    based on the provided schema. Output ONLY a valid JSON object containing two lists: "nodes" and "relationships".

    Schema:
    Node Labels: Cancer, Location, Institution, RiskFactor, Gene, Researcher, Study
    Relationship Types: HAS_HIGH_INCIDENCE_IN, ASSOCIATED_WITH, LINKED_TO_GENE, AFFILIATED_WITH, INVESTIGATES, CONDUCTED

    Rules:
    1. Only extract information explicitly mentioned in the text. Do not infer or add outside knowledge.
    2. For "HAS_HIGH_INCIDENCE_IN", the source must be a Cancer and the target a Location.
    3. For "ASSOCIATED_WITH", the source must be a Cancer and the target a RiskFactor.
    4. Normalize entity names (e.g., "oesophageal cancer", "cancer of the oesophagus" -> "Esophageal Cancer").
    5. A "Study" is an official research program or project, like "National Cancer Registry Programme".
    6. If no entities or relationships are found, return an empty JSON object: {{"nodes": [], "relationships": []}}.

    Text to process:
    "{text}"

    Output JSON:
    """

    prompt = ChatPromptTemplate.from_template(prompt_template)
    return prompt | llm | StrOutputParser()

In [13]:
class Neo4jGraph:
    def __init__(self, uri, user, password, database):
        self._driver = GraphDatabase.driver(uri, auth=(user, password), database=database)

    def close(self):
        self._driver.close()
    
    def execute_query(self, query, parameters=None):
        with self._driver.session() as session:
            return session.run(query, parameters)

    def setup_constraints(self):
        # sets up unique constraings on node labels to prevent duplicate entities.
        # it will only create constraints if they dont already exsits

        constraints = [
            "CREATE CONSTRAINT IF NOT EXISTS FOR (c:Cancer) REQUIRE c.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (l:Location) REQUIRE l.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (i:Institution) REQUIRE i.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (r:RiskFactor) REQUIRE r.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (g:Gene) REQUIRE g.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (p:Researcher) REQUIRE p.name IS UNIQUE",
            "CREATE CONSTRAINT IF NOT EXISTS FOR (s:Study) REQUIRE s.name IS UNIQUE" 
        ]
        for constraint in constraints:
            self.execute_query(constraint)
        
        print("neo4j constraints are set up")

    def setup_indexes(self):
        # sets up indexes on node properties for faster lookups
        # it will only create indexes if they dont already exists

        indexes = [
            "CREATE INDEX IF NOT EXISTS FOR (c:Cancer) ON (c.name)",
            "CREATE INDEX IF NOT EXISTS FOR (l:Location) ON (l.name)",
            "CREATE INDEX IF NOT EXISTS FOR (i:Institution) ON (i.name)",
            "CREATE INDEX IF NOT EXISTS FOR (r:RiskFactor) ON (r.name)",
            "CREATE INDEX IF NOT EXISTS FOR (g:Gene) ON (g.name)",
            "CREATE INDEX IF NOT EXISTS FOR (p:Researcher) ON (p.name)",
            "CREATE INDEX IF NOT EXISTS FOR (s:Study) ON (s.name)"
        ]
        for index in indexes:
            self.execute_query(index)
        print("neo4j indexes are set up")

    def load_graph_data(self, graph_data):
        # Loads a batch of extracted nodes and relationships into Neo4j using APOC procedures.

        if not graph_data or (not graph_data.get("nodes") and not graph_data.get("relationships")):
           print (" - NO data to load, Skipping")
           return 0, 0

        nodes = graph_data.get("nodes", [])
        relationships = graph_data.get("relationships", []) 

        # --- Data Validation Step ---
        # Filter out nodes that are missing a name property, as this is required by the database.
        valid_nodes = [node for node in nodes if node.get("name")]
        valid_node_names = {node['name'] for node in valid_nodes}

        # Filter out relationships that connect to or from an invalid node.
        valid_relationships = [
            rel for rel in relationships
            if rel.get("source") in valid_node_names and rel.get("target") in valid_node_names
        ]
        
        if not valid_nodes and not valid_relationships:
            print("   -> No valid data to load after filtering. Skipping.")
            return 0, 0

        # create query using apoc.merge.node to create or merge nodes with dynamic lables
        apoc_node_query = """
        UNWIND $nodes AS node_data
        CALL apoc.merge.node([node_data.label], {name: node_data.name}) YIELD node
        RETURN count(node) AS nodes_created
        """
        
        # cypher query using apoc.create.relationship to create relationships with dynamic types
        apoc_rel_query = """
        UNWIND $relationships AS rel_data
        MATCH (source {name: rel_data.source})
        MATCH (target {name: rel_data.target})
        CALL apoc.create.relationship(source, rel_data.type, {}, target) YIELD rel
        RETURN count(rel) AS rels_created
        """

        nodes_created = 0
        rels_created = 0

        with self._driver.session() as session:
            if nodes:
                result = session.run(apoc_node_query, nodes = nodes)
                nodes_created = result.single()["nodes_created"]
            if relationships:
                result = session.run(apoc_rel_query, relationships = relationships)
                rels_created = result.single()["rels_created"]

        return nodes_created, rels_created

In [14]:
# Table processing
def process_leading_cancer_tables(tables):
    # specifically processes tables that list leading sites of cancer for a location
    # uses heuristics and regex to parse structured data, which is more reliable than an LLM for this task

    graph_data = {"nodes" : [], "relationships" : []}

    for table in tables:
        if "Leading Sites of Cancer" in table.get("table_name", "") or "Fig. 2" in table.get("table_name", ""):
            #Regex to find the location name from the figure title
            location_name_match = re.search(r"Fig\. \d+\.\d+:\s*Ten Leading Sites of Cancer.*?\((.*?)\)", table["table_name"])
            if not location_name_match:
                # fallback regex for other table name formats
                location_name_match = re.search(r"in\s(.*?)\s*\()", table["table_name"])

            if location_name_match:
                location_name = location_name_match.group(1).strip()
                graph_data["nodes"].append({"name": location_name, "label": "location"})

                for row in table["data"]:
                    # assumes the cancer name is in the first column of the table row
                    first_key = next(iter(row))
                    cancer_name = row[first_key].strip()

                    # simple normalizations to ensure consistency
                    if cancer_name and "cancer" not in cancer_name.lower():
                        cancer_name = f"{cancer_name} Cancer"

                    if cancer_name:
                        # add the cancer node
                        graph_data["nodes"].append({"name": cancer_name, "label": "cancer"})
                        # create relationship between cancer and the location
                        graph_data["relationships"].append({
                            "source": cancer_name,
                            "target": location_name,
                            "type": "HAS_HIGH_INCIDENCE_IN"
                        })
    
    return graph_data




                                            

In [15]:
# MAIN PROCESSING

def main():
    # main function to orchestrate the knowledge graph creation process

    if not GROQ_API_KEY:
        print("error: GROQ_API_KEY not found")
        return
    
    # initialize neo4j conncetion and set up database constraints and indexes
    try:
        graph_db = Neo4jGraph(NEO4J_URI, NEO4J_USERNAME, NEO4J_PASSWORD, NEO4J_DATABASE)
        graph_db.setup_constraints()
        graph_db.setup_indexes()
    except Exception as e:
        print(f"ERROR CONNECTING TO NEO4J OR SETTING UP SCHEMA: {e}")
        print(" PLEASE ENSURE THE DATABASE IS RUNNING AND CREDENTIALS ARE CORRECT")
        return
    
    # Load the cleaned JSON data from a local file
    try:
        with open("cleaned.json", "r", encoding="utf-8") as f:
            monograph = json.load(f)
    except FileNotFoundError:
        print("Error: 'cleaned.json' not found. Please ensure the file is in the same directory")
        return
    except json.JSONDecodeError:
        print("Error: 'Cleaned.json' is ont valid Json file")
        return

    # Initialize the LLM extraction chain and text splitter
    extraction_chain = get_extraction_chain()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)

    # set to store hashes of processes chunks to avoid duplicates
    processed_chunks_hashes = set()

    print("\nProcessing narrative text in LLM...")
    for chapter in monograph.get("chapters", []):
        print(f" - Extracting from chapter {chapter.get('chapter', 'N/A')}: {chapter.get('title', 'No Title')}")
        content = chapter.get("content", "")
        if not content:
            continue

        # chunk the content to fit into the LLM's context window
        chunks = text_splitter.split_text(content)

        for i, chunk in enumerate(chunks):
            try: 
                # check for duplicate chunks to avoid redundant LLM calls
                chunk_hash = hashlib.sha256(chunk.encode('utf-8')).hexdigest()
                if chunk_hash in processed_chunks_hashes:
                    print(f" -> Skipping duplicate chunk {i+1}/{len(chunks)}.")
                    continue
                processed_chunks_hashes.add(chunk_hash)

                # get the extracted data from LLM
                extracted_data_str = extraction_chain.invoke({"text":chunk})

                # parse the JSON output from the LLM
                try:
                    json_match = re.search(r"```json\s*([\s\S]*?)\s*```", extracted_data_str) # llm sometimes returns markdown'''json..''', this cleans it
                    if json_match:
                        extracted_data_str = json_match.group(1)
                    extracted_data = json.loads(extracted_data_str)
                except:
                    print(f" -< LLM returned invalid json for chunk {i+1}, skipping")
                    continue

                # load the extracted data into neo4j
                nodes_loaded, rels_loaded = graph_db.load_graph_data(extracted_data)
                print(f" -> chunk {i+1}/{len(chunks)}: loaded {nodes_loaded} nodes and {rels_loaded} relationships")
                time.sleep(1) # simple rate limiting to avoid overwhelming the API

            except Exception as e:
                #print(f" -> An error occurred processsing chunk {i+1} in chaoter {chapter.get('chapter', 'N/A')} ")
                print(f"-> An error occurred processing chunk {i+1} in chapter {chapter['chapter']}: {e}")

    # process structured table data
    print("\nProcessing structured table data...")
    for chapter in monograph.get("chapters", []):
        if chapter.get("tables"):
            print(f"  - Processing tables from Chapter {chapter.get('chapter', 'N/A')}")

            # Use our specialized function for leading cancer tables.
            table_graph_data = process_leading_cancer_tables(chapter["tables"])

            if table_graph_data.get("nodes") or table_graph_data.get("relationships"):
                nodes_loaded, rels_loaded = graph_db.load_graph_data(table_graph_data)
                print(f"    -> Loaded {nodes_loaded} nodes and {rels_loaded} relationships from tables.")

    # cleand up the database connection
    graph_db.close()
    print("\nKnowledge graph construction complete")
    

if __name__ == "__main__":
    main()




neo4j constraints are set up
neo4j indexes are set up

Processing narrative text in LLM...
 - Extracting from chapter 1: Population and Cancer Incidence
   -> No valid data to load after filtering. Skipping.
 -> chunk 1/10: loaded 0 nodes and 0 relationships
   -> No valid data to load after filtering. Skipping.
 -> chunk 2/10: loaded 0 nodes and 0 relationships
   -> No valid data to load after filtering. Skipping.
 -> chunk 3/10: loaded 0 nodes and 0 relationships
   -> No valid data to load after filtering. Skipping.
 -> chunk 4/10: loaded 0 nodes and 0 relationships
   -> No valid data to load after filtering. Skipping.
 -> chunk 5/10: loaded 0 nodes and 0 relationships
   -> No valid data to load after filtering. Skipping.
 -> chunk 6/10: loaded 0 nodes and 0 relationships
 -> chunk 7/10: loaded 14 nodes and 9 relationships
   -> No valid data to load after filtering. Skipping.
 -> chunk 8/10: loaded 0 nodes and 0 relationships
   -> No valid data to load after filtering. Skipping