In [1]:
## 1. Setup & Imports
import os
import asyncio
from rdflib import Graph as RDFGraph
from neo4j import GraphDatabase

from neo4j_graphrag.indexes import create_vector_index
from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.llm.openai_llm import OpenAILLM
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver
from neo4j_graphrag.retrievers import VectorRetriever, Text2CypherRetriever
from neo4j_graphrag.generation import GraphRAG

from src.utils import getSchemaFromOnto, getNLOntology, getPKs

from dotenv import load_dotenv

In [2]:
## 2. Load Environment Variables
# Read a local .env file (if one exists) into the process environment first.
load_dotenv()

api_key = os.getenv("OPENAI_API_KEY")
uri = os.getenv("NEO4J_URI")
user = os.getenv("NEO4J_USERNAME")
password = os.getenv("NEO4J_PASSWORD")

# Fail fast with a readable message instead of handing None to the Neo4j
# driver or the OpenAI clients several cells further down.
_required_env = {
    "OPENAI_API_KEY": api_key,
    "NEO4J_URI": uri,
    "NEO4J_USERNAME": user,
    "NEO4J_PASSWORD": password,
}
_missing_env = [name for name, value in _required_env.items() if value is None]
if _missing_env:
    raise RuntimeError(
        "Missing required environment variables: " + ", ".join(_missing_env)
    )

# (username, password) tuple in the form expected by GraphDatabase.driver(auth=...).
AUTH = (user, password)

In [3]:
# Sanity-check the configuration WITHOUT echoing secrets: cell outputs are
# stored inside the notebook file and are easily committed or shared.
print("OPENAI_API_KEY set:", bool(api_key))
print("NEO4J_URI:", uri)
print("NEO4J_USERNAME:", user)
print("NEO4J_PASSWORD set:", bool(password))

[REDACTED: an OpenAI API key was printed here and saved with the notebook; revoke and rotate that key]
bolt://localhost:7687
neo4j
[REDACTED]


In [4]:
# Connect to Neo4j
driver = GraphDatabase.driver(uri, auth=AUTH)

In [5]:
## 2. Create Vector Index
# Creates a vector index named INDEX_NAME over the `embedding` property of
# nodes labelled `Chunk`, using cosine similarity.
INDEX_NAME = "chunk-index"
# Must equal the length of the vectors written to Chunk.embedding.
# NOTE(review): 3072 presumably matches the "text-embedding-3-large" model
# configured later in this notebook; confirm whenever the model changes.
DIMENSION = 3072
create_vector_index(
    driver,
    INDEX_NAME,
    label="Chunk",
    embedding_property="embedding",
    dimensions=DIMENSION,
    similarity_fn="cosine",
)

In [6]:
from langchain.document_loaders import PyMuPDFLoader  

# Read every PDF file found in a directory
def process_pdfs_in_directory(directory_path):
    """Load every PDF located directly inside ``directory_path``.

    Files are visited in sorted name order so that repeated runs produce the
    same document (and therefore chunk) ordering; ``os.listdir`` alone gives
    no ordering guarantee. The ``.pdf`` extension is matched
    case-insensitively, and entries that are not regular files are skipped.

    Args:
        directory_path: Directory to scan (not recursive).

    Returns:
        A list with one entry per PDF; each entry is whatever
        ``PyMuPDFLoader(...).load()`` returned for that file.
    """
    documents = []

    for filename in sorted(os.listdir(directory_path)):
        file_path = os.path.join(directory_path, filename)
        if not filename.lower().endswith(".pdf") or not os.path.isfile(file_path):
            continue
        document = PyMuPDFLoader(file_path=file_path).load()
        print(f"File loading done for: {filename}")
        documents.append(document)

    return documents

In [7]:
all_docs = process_pdfs_in_directory("data/")  # Specify the directory containing your PDF files

File loading done for: 2022 Q3 AAPL.pdf


In [8]:
# Collapse the per-file lists held in all_docs into one flat list.
base_docs = [page for pages_of_file in all_docs for page in pages_of_file]

print(f"Length of basedocs is now {len(base_docs)}")

Length of basedocs is now 28


In [9]:
import pickle
file_path = 'data/Pickle_File/base_docs.pkl'

# Create the target folder on first run: open(..., 'wb') does not create
# missing parent directories and would raise FileNotFoundError.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Serialize the list of Document objects and save it to a file
with open(file_path, 'wb') as file:
    pickle.dump(base_docs, file)

print(f"List of documents saved to {file_path}")

List of documents saved to data/Pickle_File/base_docs.pkl


In [10]:
import pickle

# Location of the pickled document list
file_path = 'data/Pickle_File/base_docs.pkl'

# Restore the list of Document objects from disk.
# NOTE: only unpickle files you trust; pickle.load can execute arbitrary code.
with open(file_path, 'rb') as pickle_in:
    base_docs = pickle.load(pickle_in)

print(f"List of documents loaded successfully: {len(base_docs)}")

List of documents loaded successfully: 28


In [11]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=300)

chunks = text_splitter.split_documents(base_docs)
len(chunks)

67

In [12]:
chunks[0]

Document(metadata={'producer': 'EDGRpdf Service w/ EO.Pdf 22.0.40.0', 'creator': 'EDGAR Filing HTML Converter', 'creationdate': '2022-07-29T06:03:21-04:00', 'source': 'data/2022 Q3 AAPL.pdf', 'file_path': 'data/2022 Q3 AAPL.pdf', 'total_pages': 28, 'format': 'PDF 1.4', 'title': '0000320193-22-000070', 'author': 'EDGAR Online, a division of Donnelley Financial Solutions', 'subject': 'Form 10-Q filed on 2022-07-29 for the period ending 2022-06-25', 'keywords': '0000320193-22-000070; ; 10-Q', 'moddate': '2022-07-29T06:03:28-04:00', 'trapped': '', 'encryption': 'Standard V2 R3 128-bit RC4', 'modDate': "D:20220729060328-04'00'", 'creationDate': "D:20220729060321-04'00'", 'page': 0}, page_content='UNITED STATES\nSECURITIES AND EXCHANGE COMMISSION\nWashington, D.C. 20549\nFORM 10-Q\n(Mark One)\n☒ QUARTERLY REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934\nFor the quarterly period ended June\xa025, 2022\nor\n☐ TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(d) OF TH

In [13]:
# Replace the path-bearing metadata ('source', 'file_path') with a bare
# 'file_name'. pop() with a default keeps this cell re-runnable: on a second
# execution the keys are already gone and the existing 'file_name' is kept.
for doc in chunks:
    doc.metadata.pop('source', None)
    source_path = doc.metadata.pop('file_path', None)
    if source_path is not None:
        doc.metadata['file_name'] = os.path.basename(source_path)

In [14]:
from pprint import pprint
pprint(chunks[0])

Document(metadata={'producer': 'EDGRpdf Service w/ EO.Pdf 22.0.40.0', 'creator': 'EDGAR Filing HTML Converter', 'creationdate': '2022-07-29T06:03:21-04:00', 'total_pages': 28, 'format': 'PDF 1.4', 'title': '0000320193-22-000070', 'author': 'EDGAR Online, a division of Donnelley Financial Solutions', 'subject': 'Form 10-Q filed on 2022-07-29 for the period ending 2022-06-25', 'keywords': '0000320193-22-000070; ; 10-Q', 'moddate': '2022-07-29T06:03:28-04:00', 'trapped': '', 'encryption': 'Standard V2 R3 128-bit RC4', 'modDate': "D:20220729060328-04'00'", 'creationDate': "D:20220729060321-04'00'", 'page': 0, 'file_name': '2022 Q3 AAPL.pdf'}, page_content='UNITED STATES\nSECURITIES AND EXCHANGE COMMISSION\nWashington, D.C. 20549\nFORM 10-Q\n(Mark One)\n☒ QUARTERLY REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934\nFor the quarterly period ended June\xa025, 2022\nor\n☐ TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934\nFor t

In [15]:
with open("onto/improved_financial_ontology.ttl", "r") as f:
    onto = f.read()
    print(onto)

@prefix ex: <http://example.org/financial/> .
@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .

ex:disclosesRisk a <http://www.w3.org/2002/07/owl#ObjectProperty> ;
    rdfs:label "Discloses Risk" ;
    rdfs:comment "A report that describes risk factors." ;
    rdfs:domain ex:FinancialReport ;
    rdfs:range ex:RiskFactor .

ex:dividendDeclared a <http://www.w3.org/2002/07/owl#DatatypeProperty> ;
    rdfs:label "Dividend Declared" ;
    rdfs:comment "Reported dividend for the time period." ;
    rdfs:domain ex:FinancialReport .

ex:hasStockInfo a <http://www.w3.org/2002/07/owl#ObjectProperty> ;
    rdfs:label "Has Stock Info" ;
    rdfs:comment "The company has associated stock information." ;
    rdfs:domain ex:Company ;
    rdfs:range ex:StockInformation .

ex:includes a <http://www.w3.org/2002/07/owl#ObjectProperty> ;
    rdfs:label "Includes" ;
    rdfs:comment "A report includes an income statement." ;
    rdfs:domain ex:FinancialReport ;
    rdfs:range ex:IncomeStatement .


In [16]:
## 3. Load Ontology and Schema
from IPython.display import display

# Parse the Turtle ontology into an rdflib graph and derive the schema from it.
g = RDFGraph()
# NOTE(review): the previous cell prints "onto/improved_financial_ontology.ttl",
# but a different file is parsed here; confirm which ontology is intended.
g.parse("onto/financial_report_ontology.ttl")
neo4j_schema = getSchemaFromOnto(g) # Load the schema from the ontology
display(neo4j_schema)

SchemaConfig(entities={'FinancialReport': {'label': 'FinancialReport', 'description': 'A 10-Q or 10-K financial disclosure submitted by a public company.', 'properties': [{'name': 'reportNumber', 'type': 'STRING', 'description': 'The SEC form number such as 10-Q or 10-K.'}]}, 'Company': {'label': 'Company', 'description': 'A publicly traded corporation that submits financial reports.', 'properties': []}, 'FinancialStatement': {'label': 'FinancialStatement', 'description': 'A general class for financial statements.', 'properties': []}, 'IncomeStatement': {'label': 'IncomeStatement', 'description': 'An income statement detailing revenue, expenses, and net income.', 'properties': []}, 'BalanceSheet': {'label': 'BalanceSheet', 'description': 'A balance sheet showing assets, liabilities, and equity.', 'properties': []}, 'CashFlowStatement': {'label': 'CashFlowStatement', 'description': 'A statement of cash inflows and outflows.', 'properties': []}, 'ShareholdersEquityStatement': {'label': '

In [94]:
neo4j_schema.potential_schema

[('FinancialReport', 'hasFinancialStatement', 'FinancialStatement'),
 ('FinancialStatement', 'hasMetric', 'FinancialMetric'),
 ('FinancialReport', 'hasRiskFactor', 'RiskFactor'),
 ('FinancialReport', 'hasLegalProceeding', 'LegalProceeding'),
 ('FinancialReport', 'hasStockInfo', 'StockInformation'),
 ('MarketDisclosure', 'relatedTo', 'Company')]

In [91]:
entities=neo4j_schema.entities.values()
entities

dict_values([{'label': 'FinancialReport', 'description': 'A 10-Q or 10-K financial disclosure submitted by a public company.', 'properties': [{'name': 'reportNumber', 'type': 'STRING', 'description': 'The SEC form number such as 10-Q or 10-K.'}]}, {'label': 'Company', 'description': 'A publicly traded corporation that submits financial reports.', 'properties': []}, {'label': 'FinancialStatement', 'description': 'A general class for financial statements.', 'properties': []}, {'label': 'IncomeStatement', 'description': 'An income statement detailing revenue, expenses, and net income.', 'properties': []}, {'label': 'BalanceSheet', 'description': 'A balance sheet showing assets, liabilities, and equity.', 'properties': []}, {'label': 'CashFlowStatement', 'description': 'A statement of cash inflows and outflows.', 'properties': []}, {'label': 'ShareholdersEquityStatement', 'description': 'Tracks changes in equity of shareholders over a period.', 'properties': []}, {'label': 'FinancialMetric

In [92]:
relations=neo4j_schema.relations.values()
relations

dict_values([{'label': 'hasFinancialStatement', 'description': 'Links a report to the financial statements it contains.', 'properties': []}, {'label': 'hasMetric', 'description': 'Connects a statement to its metrics like revenue or net income.', 'properties': []}, {'label': 'hasRiskFactor', 'description': 'Connects a report to its described risk factors.', 'properties': []}, {'label': 'hasLegalProceeding', 'description': 'Legal proceedings discussed in the report.', 'properties': []}, {'label': 'hasStockInfo', 'description': 'Connects the report to stock/share-related data.', 'properties': []}, {'label': 'relatedTo', 'description': 'General relationship to a company or segment.', 'properties': []}])

In [19]:
nl_ontology = getNLOntology(g)
print(nl_ontology) # Load the NL ontology from the graph


Node Labels:
FinancialReport: A 10-Q or 10-K financial disclosure submitted by a public company.
Company: A publicly traded corporation that submits financial reports.
FinancialStatement: A general class for financial statements.
IncomeStatement: An income statement detailing revenue, expenses, and net income.
BalanceSheet: A balance sheet showing assets, liabilities, and equity.
CashFlowStatement: A statement of cash inflows and outflows.
ShareholdersEquityStatement: Tracks changes in equity of shareholders over a period.
FinancialMetric: A quantitative value in a financial report, like Net Income or Revenue.
StockInformation: Details about company shares, dividends, etc.
RiskFactor: Qualitative risks disclosed by a company.
LegalProceeding: A legal action or case referenced in the financial report.
MarketDisclosure: Narrative discussion on business, market conditions, and trends.

Node Properties:
reportNumber: Attribute that applies to entities of type FinancialReport. It represent

In [20]:
from src.models import KnowledgeGraphLLM
# from langchain_ollama import ChatOllama
# NOTE: this import rebinds the name OpenAIEmbeddings, which the first cell
# imported from neo4j_graphrag.embeddings. From this point on the name refers
# to the langchain_openai class.
from langchain_openai import OpenAIEmbeddings

# Chat model used for triple extraction.
llm_gpt_4o_mini = KnowledgeGraphLLM(model_name="gpt-4o-mini", max_tokens=10000, api_key=api_key)

# Embedding model used to embed chunk text.
llm_embedding_large_3 = OpenAIEmbeddings(api_key=api_key, model="text-embedding-3-large")

# llm_llama31_8b = ChatOllama(
#     model="llama3.1:8b",
#     temperature=0,
#     base_url = "http://10.5.61.140:5000/"
# )

In [21]:
llm_gpt_4o_mini.invoke("What is the capital of France?")

'[{"s": "France", "p": "hasCapital", "o": "Paris", "stype": "Country", "otype": "City", "value": 0.0, "unitCode": "", "temporalCoverage": "", "tickerSymbol": "", "dividendDeclared": 0.0, "sharesOutstanding": 0.0, "reportNumber": ""}]'

In [None]:
# from neo4j_graphrag.embeddings import OpenAIEmbeddings
# from neo4j_graphrag.llm.openai_llm import OpenAILLM

# embedder = OpenAIEmbeddings(api_key=api_key, model="text-embedding-3-large")
# llm = OpenAILLM(api_key=api_key, model_name="gpt-4o-mini", model_params={"temperature": 0, "max_tokens": 3000, "response_format": {"type": "json_object"}})


In [None]:
# from langchain_core.prompts import ChatPromptTemplate

# # Multi Query: Different Perspectives
# template_for_extracting_triples = """
# ### Follow the INSTRUCTION carefully:
# You are an expert in **Domain** constructing a **knowledge graph**. Given a **Context** and **Ontologies defined in Natural Language**, **Entities Types**, **Relations Types** which is provided at the end, perform the following tasks:

# #### **Step 1: Extract Relevant Entities**
# - Extract entites based on the **Ontologies defined in Natural Language** provided below for the Given **Domain** and **Context**.

# #### **Step 2: Use only and only the 7 predefined relation types, otherwise, you will be given penalty:**
# - Extract relations based on the **Ontologies defined in Natural Language** provided below for the Given **Domain** and **Context**.

# #### **Step 3: Identify Relationships and Generate Triplets**
# - Determine the **relationships** between the *Entity Types** and the using the **triplet format**: (head_concept, relation, tail_concept)
    
# - Relationship Directionality: 
#     - Some relations are strictly directional, meaning (A, Evaluate-for, B) is valid, but (B, Evaluate-for, A) is not.
#     - The relations "Compare" and "Conjunction" are bidirectional.
#     - The query concept may be the head or tail in a triplet, but additional triplets between extracted concepts are encouraged.

# #### **Step 4: Format the Output only and only use the given format, otherwise, you will be given penalty**
# - Return ONLY and ONLY a list of triplets in this format: (concept, relation, concept)
# - For Example:
#     (natural language explanation, Used-for, model reasoning)
#     (natural language explanation, Evaluate-for, classification performance)

# #### Boundary Conditions:
# - A tuple of triplet is considered eligible if it has only and only 3 items that is 2 concepts and one relationship between them

# #### Important notes:
# - No additional explanations, numbering, or extra text.

# #### Here is the Domain: 
# {domain}

# #### Here is the Ontologies defined in Natural Language:
# {nl_ontology}

# #### Here is the Context:
# {chunk_text}
# """
# prompt_template_for_extracting_triples = ChatPromptTemplate.from_template(template_for_extracting_triples)

In [None]:
# prompt = prompt_template_for_extracting_triples.invoke({"domain":'Financial Report', "chunk_text":base_docs[5].page_content, "nl_ontology":nl_ontology})
# print(prompt)

In [None]:
# llm_chain = prompt_template_for_extracting_triples | llm_gpt_4o_mini 
# response = llm_chain.invoke({"domain":'Financial Report', "chunk_text":base_docs[5].page_content, "nl_ontology":nl_ontology})

In [None]:
# print(response)

In [None]:
def safe_json_parse(s):
    """Parse an LLM response string into Python data without raising.

    Strict JSON is tried first so that ``true`` / ``false`` / ``null`` are
    understood. ``ast.literal_eval`` is kept as a fallback for Python-literal
    output (single quotes, tuples, ``None``), which is the only form this
    function accepted before.

    Args:
        s: Text to parse.

    Returns:
        The parsed object, or ``None`` (after printing the error) when the
        text is neither valid JSON nor a valid Python literal.
    """
    import ast
    import json

    try:
        return json.loads(s)
    except (TypeError, ValueError):
        pass  # not strict JSON; fall through to the Python-literal parser

    try:
        return ast.literal_eval(s)
    except Exception as e:
        print("Parsing error:", e)
        return None


In [None]:
# response = safe_json_parse(response) # This will return a list of tuples if the parsing is successful, or None if it fails.
# print(type(response[0]))
# print(response)

In [None]:
# from typing import List, Optional
# from neo4j import GraphDatabase
# from langchain.schema import Document
# from langchain_core.output_parsers import StrOutputParser
# import os

# class CustomKGPipeline:
#     def __init__(
#         self,
#         driver: GraphDatabase,
#         embedder,
#         llm,
#         domain: str,
#         ontology: str,
#         prompt_path: str,
#         neo4j_database: Optional[str] = None,
#     ):
#         self.driver = driver
#         self.embedder = embedder
#         self.llm = llm
#         self.domain = domain
#         self.ontology = ontology
#         self.prompt_path = prompt_path
#         self.neo4j_database = neo4j_database

#         # We will load these later
#         self.prompt = None

#     def _load_prompt(self):
#         if not os.path.exists(self.prompt_path):
#             raise FileNotFoundError(f"Prompt file not found at {self.prompt_path}")
#         with open(self.prompt_path, "r") as f:
#             self.template_for_extracting_triples = f.read()


#     def _render_prompt(self):
#         if self.prompt is None:
#             template = self._load_prompt()
#             self.prompt_template = ChatPromptTemplate.from_template(self.template_for_extracting_triples)
#         return self.prompt_template
    
#     def _create_chunk_nodes(self, docs: List[Document]):
#         with self.driver.session(database=self.neo4j_database) as session:
#             for i, doc in enumerate(docs):
#                 text = doc.page_content
#                 embedding = self.embedder.embed_query(text)

#                 props = {
#                     "chunk_index": i,
#                     "text": text,
#                     "embedding": embedding,
#                 }

#                 # Merge metadata into props
#                 props.update(doc.metadata)

#                 # Create the chunk node
#                 session.run(
#                     """
#                     MERGE (c:Chunk {chunk_index: $chunk_index})
#                     SET c += $props
#                     """,
#                     chunk_index=i,
#                     props=props,
#                 )

#                 # Create NEXT and PREVIOUS links if applicable
#                 if i > 0:
#                     session.run(
#                         """
#                         MATCH (c1:Chunk {chunk_index: $prev}), (c2:Chunk {chunk_index: $curr})
#                         MERGE (c1)-[:NEXT_CHUNK]->(c2)
#                         MERGE (c2)-[:PREVIOUS_CHUNK]->(c1)
#                         """,
#                         prev=i - 1,
#                         curr=i,
#                     )


#     def _extract_triples(self, prompt, chunk_text) -> List[tuple[str, str, str]]:
#         """
#         Sends the prompt to LLM and parses the output into (s, p, o) triples.
#         The LLM is expected to return JSON in the format:
#         [{"s": "Subject", "p": "Predicate", "o": "Object"}, ...]
#         """
#         triple_extracter_chain = prompt | self.llm | StrOutputParser()
#         response = triple_extracter_chain.invoke({
#             "chunk_text": chunk_text,
#             "domain": self.domain,
#             "nl_ontology": self.ontology
#         })

#         try:
#             import ast
#             data = ast.literal_eval(response)
#         except Exception as e:
#             print("Failed to parse LLM output:", e)
#             return []

#         triples = []
#         for item in data:
#             s, p, o = item.get("s"), item.get("p"), item.get("o")
#             if s and p and o:
#                 triples.append((s.strip(), p.strip(), o.strip()))

#         return triples


#     def _create_kg_from_chunks(self, docs: List[Document]):
#         for i, doc in enumerate(docs):
#             prompt = self._render_prompt()
#             triples = self._extract_triples(prompt, doc.page_content)

#             with self.driver.session(database=self.neo4j_database) as session:
#                 for s, p, o in triples:
#                     if not (s and p and o):
#                         continue  # skip malformed ones

#                     # Create entity nodes and connect
#                     session.run(
#                         f"""
#                         MERGE (subj:Entity {{name: $s}})
#                         MERGE (obj:Entity {{name: $o}})
#                         MERGE (subj)-[:{p.upper()}]->(obj)
#                         WITH subj, obj
#                         MATCH (c:Chunk {{chunk_index: $chunk_index}})
#                         MERGE (subj)-[:MENTIONED_IN]->(c)
#                         MERGE (obj)-[:MENTIONED_IN]->(c)
#                         """,
#                         {"s": s, "o": o, "chunk_index": i},
#                     )

#     def _deduplicate_entities(self):
#         with self.driver.session(database=self.neo4j_database) as session:
#             session.run("""
#             MATCH (e1:Entity), (e2:Entity)
#             WHERE e1.name = e2.name AND id(e1) < id(e2)
#             CALL apoc.refactor.mergeNodes([e1, e2]) YIELD node
#             RETURN count(node)
#             """)

#     def run(self, docs: List[Document]):
#         print("📄 Creating chunk nodes...")
#         self._create_chunk_nodes(docs)

#         print("🔍 Extracting triples...")
#         self._create_kg_from_chunks(docs)

#         print("🧹 Deduplicating entities...")
#         self._deduplicate_entities()

#         print("✅ Done.")

In [None]:
from typing import List, Optional
from neo4j import GraphDatabase
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
import os
import json
import ast
from pydantic import BaseModel, ValidationError
from typing import List


class KGTriple(BaseModel):
    """One (subject, predicate, object) triple extracted by the LLM.

    Keys beyond the declared fields are kept on the model so that
    ``model_dump()`` returns them; the pipeline copies those extra keys onto
    the object node as properties.
    """

    # pydantic ignores undeclared keys by default, which left the pipeline's
    # "optional properties" dict permanently empty. Allow and retain them.
    model_config = {"extra": "allow"}

    # Subject name, predicate (relationship type) and object name.
    s: str
    p: str
    o: str
    # Node labels for subject and object; fall back to a generic label.
    stype: str = "Entity"
    otype: str = "Entity"


class CustomKGPipeline:
    """Populate Neo4j with chunk nodes and LLM-extracted triples.

    ``run`` makes three passes over the supplied documents:

    1. store every document as a ``:Chunk`` node (text, embedding, metadata)
       and link consecutive chunks;
    2. ask the LLM for triples per chunk and merge them as typed nodes, a
       relationship named after the predicate, and ``MENTIONED_IN`` links to
       the chunk;
    3. merge ``:Entity`` nodes that share the same ``name`` (requires APOC).

    ``chunk_index`` is the position of a document inside the list passed to
    ``run``. Running the pipeline on a different list therefore reuses the
    same indices, and the ``MERGE`` on ``chunk_index`` updates existing chunk
    nodes rather than adding new ones.
    """

    # Keys that describe the triple itself; every other key of a dumped
    # triple is treated as a property of the object node.
    _STRUCTURAL_KEYS = frozenset({"s", "p", "o", "stype", "otype"})

    def __init__(
        self,
        driver: GraphDatabase,
        embedder,
        llm,
        domain: str,
        ontology: str,
        prompt_path: str,
        neo4j_database: Optional[str] = None,
    ):
        """Store collaborators and configuration; no I/O happens here.

        Args:
            driver: Neo4j driver used to open sessions.
            embedder: Object exposing ``embed_query(text)``.
            llm: Runnable placed between the prompt and ``StrOutputParser``.
            domain: Value substituted for ``{domain}`` in the prompt.
            ontology: Value substituted for ``{nl_ontology}`` in the prompt.
            prompt_path: Path of the prompt template file.
            neo4j_database: Database name passed to ``driver.session``.
        """
        self.driver = driver
        self.embedder = embedder
        self.llm = llm
        self.domain = domain
        self.ontology = ontology
        self.prompt_path = prompt_path
        self.neo4j_database = neo4j_database

        # Built lazily by _render_prompt() and cached afterwards.
        self.prompt = None

    @staticmethod
    def _safe_label(label: Optional[str], default: str = "Entity") -> str:
        """Reduce ``label`` to alphanumerics/underscores, or return ``default``.

        Labels and relationship types cannot be passed as Cypher parameters,
        so they are interpolated into the query text and must be sanitised.
        """
        if not label:
            return default
        return "".join(c for c in label if c.isalnum() or c == "_") or default

    @staticmethod
    def _parse_llm_response(response: str):
        """Parse the LLM output as JSON, falling back to a Python literal.

        Raises whatever ``ast.literal_eval`` raises when neither form parses.
        """
        try:
            return json.loads(response)
        except (TypeError, ValueError):
            # Not strict JSON (e.g. single quotes); try Python-literal syntax.
            return ast.literal_eval(response)

    def _load_prompt(self):
        """Read the prompt file into ``self.template_for_extracting_triples``.

        Returns:
            The template text that was read.

        Raises:
            FileNotFoundError: If ``self.prompt_path`` does not exist.
        """
        if not os.path.exists(self.prompt_path):
            raise FileNotFoundError(f"Prompt file not found at {self.prompt_path}")
        with open(self.prompt_path, "r") as f:
            self.template_for_extracting_triples = f.read()
        return self.template_for_extracting_triples

    def _render_prompt(self):
        """Return the chat prompt template, building it on first use only."""
        if self.prompt is None:
            template_text = self._load_prompt()
            self.prompt_template = ChatPromptTemplate.from_template(template_text)
            # Mark the template as built so the file is not re-read per chunk.
            self.prompt = self.prompt_template
        return self.prompt_template

    def _create_chunk_nodes(self, docs: List[Document]):
        """Write one ``:Chunk`` node per document and link neighbours.

        Each node carries the document metadata plus ``chunk_index``, ``text``
        and ``embedding``. Consecutive chunks are connected with
        ``NEXT_CHUNK`` and ``PREVIOUS_CHUNK`` relationships.
        """
        with self.driver.session(database=self.neo4j_database) as session:
            for i, doc in enumerate(docs):
                text = doc.page_content
                embedding = self.embedder.embed_query(text)

                # Metadata first, so a metadata key can never overwrite the
                # pipeline's own chunk_index / text / embedding values.
                props = {
                    **doc.metadata,
                    "chunk_index": i,
                    "text": text,
                    "embedding": embedding,
                }

                # Create (or update) the chunk node
                session.run(
                    """
                    MERGE (c:Chunk {chunk_index: $chunk_index})
                    SET c += $props
                    """,
                    chunk_index=i,
                    props=props,
                )

                # Create NEXT and PREVIOUS links if applicable
                if i > 0:
                    session.run(
                        """
                        MATCH (c1:Chunk {chunk_index: $prev}), (c2:Chunk {chunk_index: $curr})
                        MERGE (c1)-[:NEXT_CHUNK]->(c2)
                        MERGE (c2)-[:PREVIOUS_CHUNK]->(c1)
                        """,
                        prev=i - 1,
                        curr=i,
                    )

    def _extract_triples(self, prompt, chunk_text) -> List[KGTriple]:
        """Ask the LLM for triples describing ``chunk_text``.

        Args:
            prompt: Prompt template expecting ``chunk_text``, ``domain`` and
                ``nl_ontology`` variables.
            chunk_text: Text of the chunk to analyse.

        Returns:
            Validated ``KGTriple`` objects. An unparsable response yields an
            empty list; individual invalid items are reported and skipped.
        """
        triple_extracter_chain = prompt | self.llm | StrOutputParser()
        response = triple_extracter_chain.invoke({
            "chunk_text": chunk_text,
            "domain": self.domain,
            "nl_ontology": self.ontology
        })

        try:
            raw_triples = self._parse_llm_response(response)
        except Exception as e:
            print(f"⚠️ Failed to parse response: {e}")
            return []

        # A single triple returned as a bare object is still usable.
        if isinstance(raw_triples, dict):
            raw_triples = [raw_triples]
        if not isinstance(raw_triples, (list, tuple)):
            print(f"⚠️ Failed to parse response: expected a list of triples, got {type(raw_triples).__name__}")
            return []

        triples = []
        for item in raw_triples:
            try:
                triples.append(KGTriple(**item))
            except Exception as e:
                print(f"⛔ Invalid triple skipped: {item} due to {e}")
                continue

        return triples

    def _create_kg_from_chunks(self, docs: List[Document]):
        """Extract triples for each document and merge them into the graph.

        For every triple the subject and object nodes are merged by
        ``(label, name)``, extra triple keys are set as properties on the
        object node, the predicate is merged as a relationship from subject to
        object, and both nodes are linked to their chunk via ``MENTIONED_IN``.
        """
        # The template does not change between chunks; build it once.
        prompt = self._render_prompt()

        for i, doc in enumerate(docs):
            triples = self._extract_triples(prompt, doc.page_content)

            with self.driver.session(database=self.neo4j_database) as session:
                for triple in triples:
                    s, p, o = triple.s, triple.p, triple.o

                    if not (s and p and o):
                        continue

                    # Optional properties: only scalar values, which Neo4j can
                    # store directly as node properties.
                    o_props = {
                        k: v for k, v in triple.model_dump().items()
                        if k not in self._STRUCTURAL_KEYS
                        and isinstance(v, (str, int, float, bool))
                    }

                    stype = self._safe_label(triple.stype)
                    otype = self._safe_label(triple.otype)
                    rel_type = self._safe_label(p, default="RELATED_TO")

                    # Back-ticks keep sanitised names valid even when they
                    # start with a digit or collide with a Cypher keyword.
                    cypher = f"""
                    MERGE (subj:`{stype}` {{name: $s}})
                    MERGE (obj:`{otype}` {{name: $o}})
                    SET obj += $o_props
                    MERGE (subj)-[:`{rel_type}`]->(obj)
                    WITH subj, obj
                    MATCH (c:Chunk {{chunk_index: $chunk_index}})
                    MERGE (subj)-[:MENTIONED_IN]->(c)
                    MERGE (obj)-[:MENTIONED_IN]->(c)
                    """

                    session.run(
                        cypher,
                        {"s": s, "o": o, "o_props": o_props, "chunk_index": i},
                    )

    def _deduplicate_entities(self):
        """Merge ``:Entity`` nodes that share the same ``name`` (needs APOC).

        Nodes are grouped by name and each group is merged in one call, so a
        node removed by an earlier merge is never referenced again. Only the
        generic ``Entity`` label is covered here.
        """
        with self.driver.session(database=self.neo4j_database) as session:
            session.run("""
            MATCH (e:Entity)
            WITH e.name AS name, collect(e) AS nodes
            WHERE size(nodes) > 1
            CALL apoc.refactor.mergeNodes(nodes) YIELD node
            RETURN count(node)
            """)

    def run(self, docs: List[Document]):
        """Run the three stages (chunks, triples, de-duplication) in order."""
        print("📄 Creating chunk nodes...")
        self._create_chunk_nodes(docs)

        print("🔍 Extracting triples...")
        self._create_kg_from_chunks(docs)

        print("🧹 Deduplicating entities...")
        self._deduplicate_entities()

        print("✅ Done.")



In [None]:
# Wire the pipeline to the Neo4j driver, the embedding model, the extraction
# LLM and the natural-language ontology. The domain string is passed to the
# pipeline for use in the extraction prompt, so its spelling matters.
pipeline = CustomKGPipeline(
    driver=driver,
    embedder=llm_embedding_large_3,
    llm=llm_gpt_4o_mini,
    domain="Quarterly Financial Reports of Companies",
    ontology=nl_ontology,
    prompt_path="prompt/updated_prompt.txt",
)

In [None]:
len(chunks)

67

In [None]:
pipeline.run(chunks[20:30])  # Pass each chunk to the pipeline

📄 Creating chunk nodes...
🔍 Extracting triples...




🧹 Deduplicating entities...
✅ Done.


In [None]:
# WARNING: destructive reset. Running this cell (for example via "Run All")
# wipes the graph that the cells above have just built.
with driver.session() as session:
    # Delete every node together with its relationships.
    session.run("MATCH (n) DETACH DELETE n")
    # Remove the vector index created earlier in this notebook.
    session.run("DROP INDEX `chunk-index` IF EXISTS")
    # NOTE(review): asserting an empty schema presumably drops all remaining
    # indexes and constraints; requires the APOC plugin. Confirm against APOC docs.
    session.run("CALL apoc.schema.assert({}, {})")
