In [3]:
## 1. Setup & Imports
import os
import asyncio
from rdflib import Graph as RDFGraph
from neo4j import GraphDatabase

from neo4j_graphrag.indexes import create_vector_index
from neo4j_graphrag.embeddings import OpenAIEmbeddings
from neo4j_graphrag.llm.openai_llm import OpenAILLM
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline
from neo4j_graphrag.experimental.components.resolver import SinglePropertyExactMatchResolver
from neo4j_graphrag.retrievers import VectorRetriever, Text2CypherRetriever
from neo4j_graphrag.generation import GraphRAG

# from src.utils import getSchemaFromOnto
from dotenv import load_dotenv

In [4]:
## 2. Load Environment Variables
# load_dotenv() runs first so the os.getenv lookups below can also see values
# supplied through a local .env file.
load_dotenv()

# OpenAI credential handed to the LLM / embedding clients created further down.
api_key = os.getenv("OPENAI_API_KEY")

# Neo4j connection settings; os.getenv yields None for any variable that is unset.
uri, user, password = (
    os.getenv(var_name)
    for var_name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
)

# (username, password) pair passed as `auth` when the driver is created.
AUTH = user, password

In [5]:
# Sanity-check the configuration WITHOUT echoing credentials: a saved notebook
# keeps its cell outputs, so printing a key or password leaks it to every reader.
def _presence(secret):
    """Return a non-revealing status string for a secret value."""
    return "<set>" if secret else "<missing>"

print("OPENAI_API_KEY:", _presence(api_key))
print("NEO4J_URI:", uri)
print("NEO4J_USERNAME:", user)
print("NEO4J_PASSWORD:", _presence(password))

[REDACTED: OpenAI API key removed from saved output; revoke and rotate this key]
bolt://localhost:7687
neo4j
[REDACTED]


In [6]:
# Connect to Neo4j
# Creates the driver from the URI and the (username, password) pair read from the
# environment in the cell above.
# NOTE(review): no visible cell closes this driver; consider driver.close() once
# the notebook is finished with the database.
driver = GraphDatabase.driver(uri, auth=AUTH)

In [7]:
## 2. Create Vector Index
# Name of the vector index and the embedding width it is declared with.
INDEX_NAME = "chunk-index"
DIMENSION = 3072

# Index the `embedding` property of nodes labelled Chunk, compared by cosine similarity.
vector_index_options = dict(
    label="Chunk",
    embedding_property="embedding",
    dimensions=DIMENSION,
    similarity_fn="cosine",
)
create_vector_index(driver, INDEX_NAME, **vector_index_options)

In [6]:
from langchain.document_loaders import PyMuPDFLoader  

# Creating a function to read Multiple PDF files  
def process_pdfs_in_directory(directory_path):
    """Load every PDF file found directly inside ``directory_path``.

    Files are visited in sorted name order so repeated runs return the
    documents in the same order (``os.listdir`` gives no ordering guarantee).
    The ``.pdf`` extension is matched case-insensitively so a file such as
    ``REPORT.PDF`` is not silently skipped. Sub-directories are not searched,
    and directory entries whose name happens to end in ``.pdf`` are ignored.

    Args:
        directory_path: Directory to scan for PDF files.

    Returns:
        A list with one entry per PDF file; each entry is whatever
        ``PyMuPDFLoader(file_path=...).load()`` returned for that file.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``directory_path`` does
            not exist or is not a directory.
    """
    documents = []

    for filename in sorted(os.listdir(directory_path)):
        if not filename.lower().endswith(".pdf"):
            continue
        file_path = os.path.join(directory_path, filename)
        if not os.path.isfile(file_path):
            # Only regular files are handed to the loader.
            continue
        pdf_loader = PyMuPDFLoader(file_path=file_path)
        documents.append(pdf_loader.load())
        print(f"File loading done for: {filename}")

    return documents

In [7]:
all_docs = process_pdfs_in_directory("data/")  # Specify the directory containing your PDF files

File loading done for: 2022 Q3 AAPL.pdf


In [8]:
# Flatten all_docs (one list of pages per PDF) into a single flat list of pages
base_docs = [page for doc_list in all_docs for page in doc_list]

print('Length of basedocs is now ' + str(len(base_docs)))

Length of basedocs is now 28


In [9]:
import pickle
file_path = 'data/Pickle_File/base_docs.pkl'

# Create the target directory first: open(..., 'wb') does not create missing
# parent directories and would raise FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Serialize the list of Document objects and save it to a file
with open(file_path, 'wb') as file:
    pickle.dump(base_docs, file)

print(f"List of documents saved to {file_path}")

List of documents saved to data/Pickle_File/base_docs.pkl


In [10]:
import pickle   
  
# Specify the file path where your list of documents is saved  
file_path = 'data/Pickle_File/base_docs.pkl' 
  
# Deserialize the file content back into a list of Document objects  
# NOTE(review): pickle.load can execute arbitrary code embedded in the file; only
# load pickles this notebook produced itself (the previous cell writes this path).
with open(file_path, 'rb') as file:  
    base_docs = pickle.load(file)
  
print("List of documents loaded successfully: " + str(len(base_docs)))

List of documents loaded successfully: 28


In [11]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Split the loaded pages into overlapping chunks: chunk_size=1500 with
# chunk_overlap=300 (presumably measured in characters, the splitter's default
# length function; confirm if a custom length function is ever passed).
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=300)

chunks = text_splitter.split_documents(base_docs)
# Number of chunks produced (shown as the cell output).
len(chunks)

67

In [12]:
chunks[0]

Document(metadata={'producer': 'EDGRpdf Service w/ EO.Pdf 22.0.40.0', 'creator': 'EDGAR Filing HTML Converter', 'creationdate': '2022-07-29T06:03:21-04:00', 'source': 'data/2022 Q3 AAPL.pdf', 'file_path': 'data/2022 Q3 AAPL.pdf', 'total_pages': 28, 'format': 'PDF 1.4', 'title': '0000320193-22-000070', 'author': 'EDGAR Online, a division of Donnelley Financial Solutions', 'subject': 'Form 10-Q filed on 2022-07-29 for the period ending 2022-06-25', 'keywords': '0000320193-22-000070; ; 10-Q', 'moddate': '2022-07-29T06:03:28-04:00', 'trapped': '', 'encryption': 'Standard V2 R3 128-bit RC4', 'modDate': "D:20220729060328-04'00'", 'creationDate': "D:20220729060321-04'00'", 'page': 0}, page_content='UNITED STATES\nSECURITIES AND EXCHANGE COMMISSION\nWashington, D.C. 20549\nFORM 10-Q\n(Mark One)\n☒ QUARTERLY REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934\nFor the quarterly period ended June\xa025, 2022\nor\n☐ TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(d) OF TH

In [13]:
# Replace the path-bearing metadata keys ('source', 'file_path') with a bare
# 'file_name'. pop() with a default is used instead of `del` so the cell can be
# re-run: on a second run the keys are already gone and `del` would raise KeyError.
for doc in chunks:
    doc.metadata.pop('source', None)
    file_path = doc.metadata.pop('file_path', None)
    if file_path is not None:
        doc.metadata['file_name'] = os.path.basename(file_path)

In [14]:
from pprint import pprint
pprint(chunks[0])

Document(metadata={'producer': 'EDGRpdf Service w/ EO.Pdf 22.0.40.0', 'creator': 'EDGAR Filing HTML Converter', 'creationdate': '2022-07-29T06:03:21-04:00', 'total_pages': 28, 'format': 'PDF 1.4', 'title': '0000320193-22-000070', 'author': 'EDGAR Online, a division of Donnelley Financial Solutions', 'subject': 'Form 10-Q filed on 2022-07-29 for the period ending 2022-06-25', 'keywords': '0000320193-22-000070; ; 10-Q', 'moddate': '2022-07-29T06:03:28-04:00', 'trapped': '', 'encryption': 'Standard V2 R3 128-bit RC4', 'modDate': "D:20220729060328-04'00'", 'creationDate': "D:20220729060321-04'00'", 'page': 0, 'file_name': '2022 Q3 AAPL.pdf'}, page_content='UNITED STATES\nSECURITIES AND EXCHANGE COMMISSION\nWashington, D.C. 20549\nFORM 10-Q\n(Mark One)\n☒ QUARTERLY REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934\nFor the quarterly period ended June\xa025, 2022\nor\n☐ TRANSITION REPORT PURSUANT TO SECTION 13 OR 15(d) OF THE SECURITIES EXCHANGE ACT OF 1934\nFor t

In [15]:
# Read the ontology file into memory, then echo it for inspection.
with open("onto/test_2.ttl", "r") as f:
    onto = f.read()
print(onto)

@prefix fr: <http://example.org/financial-report#> .
@prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#> .
@prefix owl:  <http://www.w3.org/2002/07/owl#> .
@prefix xsd:  <http://www.w3.org/2001/XMLSchema#> .

######################################################
# Ontology Declaration
######################################################
fr:FinancialReportOntology a owl:Ontology ;
    rdfs:comment "Exhaustive ontology for capturing structured info from financial filings (like 10-Q, 10-K)." ;
    rdfs:label "Financial Report Ontology (Exhaustive Example)" .

######################################################
# Classes
######################################################
fr:Company a owl:Class ;
    rdfs:label "Company" ;
    rdfs:comment "A corporate entity, e.g., Apple Inc." .

fr:FinancialReport a owl:Class ;
    rdfs:label "Financial Report" ;
    rdfs:comment "A formal published report, e.g. 10-Q, 10-K, that includes statements." .

fr:FinancialStatement a owl:Class ;
    

In [16]:
from rdflib import Graph
# getSchemaFromOnto is a project helper. Its import in the setup cell is commented
# out, so without this line the cell raises NameError on a fresh kernel.
from src.utils import getSchemaFromOnto

# Parse the Turtle ontology into an RDF graph.
g = Graph()
g.parse("onto/test_2.ttl", format="turtle")  # Ensure this path is correct!

# Derive the schema object (entities, relations, potential_schema) from the graph.
neo4j_schema = getSchemaFromOnto(g)

print("Entities:", [e["label"] for e in neo4j_schema.entities.values()])
print("Object Properties:", [r["label"] for r in neo4j_schema.relations.values()])
print("Data Properties:", ({p["name"] for e in neo4j_schema.entities.values() for p in e.get("properties", [])}))
print("Schema Triples:", neo4j_schema.potential_schema)


Entities: ['Company', 'FinancialReport', 'FinancialStatement', 'FinancialItem', 'Product', 'Service', 'MonetaryValue', 'DatePeriod', 'ShareholderEquity']
Object Properties: ['reportsOn', 'hasStatement', 'hasLineItem', 'relatedToProduct', 'relatedToService', 'hasValue', 'coversPeriod']
Data Properties: {'periodEndDate', 'amountCurrency', 'serviceName', 'companyName', 'dateLabel', 'lineItemName', 'reportType', 'amountValue', 'productName'}
Schema Triples: [('Company', 'reportsOn', 'FinancialReport'), ('FinancialReport', 'hasStatement', 'FinancialStatement'), ('FinancialStatement', 'hasLineItem', 'FinancialItem'), ('FinancialItem', 'relatedToProduct', 'Product'), ('FinancialItem', 'relatedToService', 'Service'), ('FinancialItem', 'hasValue', 'MonetaryValue'), ('FinancialItem', 'coversPeriod', 'DatePeriod')]


In [17]:
# Collect the ontology vocabulary as sets (duplicates collapse; set order is arbitrary).
classes = {entity["label"] for entity in neo4j_schema.entities.values()}
print("Classes:", classes)

obj_props = {relation["label"] for relation in neo4j_schema.relations.values()}
print("Object Properties:", obj_props)

data_props = {
    prop["name"]
    for entity in neo4j_schema.entities.values()
    for prop in entity.get("properties", [])
}
print("Data Properties:", data_props)

Classes: {'FinancialStatement', 'DatePeriod', 'FinancialReport', 'MonetaryValue', 'FinancialItem', 'Company', 'Product', 'ShareholderEquity', 'Service'}
Object Properties: {'hasLineItem', 'relatedToService', 'hasStatement', 'relatedToProduct', 'reportsOn', 'hasValue', 'coversPeriod'}
Data Properties: {'periodEndDate', 'amountCurrency', 'serviceName', 'companyName', 'dateLabel', 'lineItemName', 'reportType', 'amountValue', 'productName'}


In [18]:
print(chunks[21].page_content)

The following table shows the fair value of the Company’s non-current marketable debt securities, by contractual maturity, as of June 25, 2022 (in millions):
Due after 1 year through 5 years
$
92,970 
Due after 5 years through 10 years
19,317 
Due after 10 years
18,790 
Total fair value
$
131,077 
Derivative Instruments and Hedging
The Company may use derivative instruments to partially offset its business exposure to foreign exchange and interest rate risk. However, the Company may
choose not to hedge certain exposures for a variety of reasons, including accounting considerations or the prohibitive economic cost of hedging particular
exposures. There can be no assurance the hedges will offset more than a portion of the financial impact resulting from movements in foreign exchange or
interest rates.
Foreign Exchange Risk
To protect gross margins from fluctuations in foreign currency exchange rates, the Company may enter into forward contracts, option contracts or other
instruments, and

In [47]:
# Concatenate the text of chunks 20..29, each followed by a blank line, and show it.
total = "".join(chunk.page_content + "\n\n" for chunk in chunks[20:30])
print(total)

171 
(124)
20,576 
— 
775 
19,801 
Subtotal
162,873 
1,742 
(1,197)
163,418 
8,027 
27,514 
127,877 
Total 
$
189,961 
$
1,753 
$
(1,198)
$
190,516 
$
34,940 
$
27,699 
$
127,877 
(1)
Level 1 fair value estimates are based on quoted prices in active markets for identical assets or liabilities.
(2)
Level 2 fair value estimates are based on observable inputs other than quoted prices in active markets for identical assets and liabilities, quoted prices for
identical or similar assets or liabilities in inactive markets, or other inputs that are observable or can be corroborated by observable market data for substantially
the full term of the assets or liabilities.
(3)
As of June 25, 2022 and September 25, 2021, total marketable securities included $14.1 billion and $17.9 billion, respectively, that were restricted from general
use, related to the European Commission decision finding that Ireland granted state aid to the Company, and other agreements.
(1)
(2)
(3)
(1)
(2)
(3)
Apple Inc. | Q3

In [19]:
# nl_ontology = getNLOntology(g)
# print(nl_ontology) # Load the NL ontology from the graph

In [20]:
from src.models import KnowledgeGraphLLM
# from langchain_ollama import ChatOllama
# NOTE(review): this rebinds the name OpenAIEmbeddings, which the setup cell
# already imported from neo4j_graphrag.embeddings; from here on the name refers
# to the langchain_openai class.
from langchain_openai import OpenAIEmbeddings, ChatOpenAI

# Project wrapper around the chat model used for triple extraction.
llm_gpt_4o_mini = KnowledgeGraphLLM(model_name="gpt-4o-mini", max_tokens=10000, api_key=api_key)

# Alternative kept for reference: plain ChatOpenAI without the project wrapper.
# llm_gpt_4o_mini = ChatOpenAI(model_name="gpt-4o-mini", max_tokens=10000, api_key=api_key)

# Embedding client later passed to CustomKGPipeline as `embedder`.
llm_embedding_large_3 = OpenAIEmbeddings(api_key=api_key, model="text-embedding-3-large")

# Alternative kept for reference: a llama3.1 model served through Ollama.
# llm_llama31_8b = ChatOllama(
#     model="llama3.1:8b",
#     temperature=0,
#     base_url = "http://10.5.61.140:5000/"
# )

In [28]:
############# RUN EXAMPLE #############

import json

prompt = """Extract a triple from the statement:
'Apple Inc. introduced a new iPhone and it has an MSRP of $999.'
"""

response = llm_gpt_4o_mini.invoke(prompt)

print("Raw structured object:", response)
# print("JSON form:\n", json.dumps(response.model_dump(), indent=2))
# print("JSON form:\n", json.dumps(json.load(response).model_dump(), indent=2))


Raw structured object: {"triples": [{"subject": "Apple Inc.", "subject_type": "Company", "relation": "introduced", "object": "iPhone", "object_type": "Product", "subject_attributes": [], "object_attributes": [{"key": "MSRP", "value": "$999"}]}]}


In [29]:
response 

'{"triples": [{"subject": "Apple Inc.", "subject_type": "Company", "relation": "introduced", "object": "iPhone", "object_type": "Product", "subject_attributes": [], "object_attributes": [{"key": "MSRP", "value": "$999"}]}]}'

In [30]:
# Parse the JSON string returned by the LLM into Python objects. The isinstance
# guard keeps the cell re-runnable: after the first run `response` is already a
# dict, and json.loads on a dict raises TypeError.
if isinstance(response, str):
    response = json.loads(response)
print("JSON form:\n", response)

JSON form:
 {'triples': [{'subject': 'Apple Inc.', 'subject_type': 'Company', 'relation': 'introduced', 'object': 'iPhone', 'object_type': 'Product', 'subject_attributes': [], 'object_attributes': [{'key': 'MSRP', 'value': '$999'}]}]}


In [150]:
from typing import List, Optional
from pydantic import BaseModel, Field, Extra
from langchain_openai import ChatOpenAI
import json
import os

########################################################
#  1) Define a KeyValueAttribute model (no free dicts)  #
########################################################

class KeyValueAttribute(BaseModel):
    """Represents one attribute as a key-value pair."""

    # Pydantic v2 configuration. This notebook already relies on v2 (it calls
    # `model_dump()`), where a nested `class Config` is deprecated in favour of
    # `model_config`.
    model_config = {
        # Reject fields that are not declared on the model.
        "extra": "forbid",
        # Merged into the generated JSON schema.
        "json_schema_extra": {
            "type": "object",
            "additionalProperties": False,
        },
    }

    key: str = Field(..., description="Attribute name")
    value: str = Field(..., description="Attribute value")

########################################################
#  2) Define your Triple + Triples with KV attributes  #
########################################################

class Triple(BaseModel):
    """One extracted (subject, relation, object) statement with typed endpoints.

    Both endpoints carry a free-text value, a class/type name and an optional
    list of key/value attributes.
    """

    # Pydantic v2 configuration. This notebook already relies on v2 (it calls
    # `model_dump()`), where a nested `class Config` is deprecated in favour of
    # `model_config`.
    model_config = {
        # Reject fields that are not declared on the model.
        "extra": "forbid",
        # Merged into the generated JSON schema.
        "json_schema_extra": {
            "type": "object",
            "additionalProperties": False,
        },
    }

    subject: str = Field(..., description="Subject text")
    subject_type: str = Field(..., description="Class/type of the Subject")
    relation: str = Field(..., description="Relation from subject to object")
    object: str = Field(..., description="Object text")
    object_type: str = Field(..., description="Class/type of the Object")

    # default_factory=list gives every instance its own empty list.
    subject_attributes: List[KeyValueAttribute] = Field(
        default_factory=list,
        description="List of arbitrary attributes describing the Subject"
    )
    object_attributes: List[KeyValueAttribute] = Field(
        default_factory=list,
        description="List of arbitrary attributes describing the Object"
    )

class Triples(BaseModel):
    """Top-level container: the list of triples extracted from one input."""

    # Pydantic v2 configuration. This notebook already relies on v2 (it calls
    # `model_dump()`), where a nested `class Config` is deprecated in favour of
    # `model_config`.
    model_config = {
        # Reject fields that are not declared on the model.
        "extra": "forbid",
        # Merged into the generated JSON schema.
        "json_schema_extra": {
            "type": "object",
            "additionalProperties": False,
        },
    }

    triples: List[Triple] = Field(..., description="All extracted knowledge graph triples")
        

############# LLM SETUP #############

# Re-read the OpenAI key from the environment (same variable as in the setup cell).
api_key = os.getenv("OPENAI_API_KEY")
# Chat model used for this structured-output demo.
model = ChatOpenAI(
    model="gpt-4o-mini",  # or "gpt-4"
    max_tokens=1000,
    api_key=api_key
)

# Wrap it so the model returns typed output:
# the captured output below shows invoke() returning a Triples instance.
structured_model = model.with_structured_output(Triples)

############# RUN EXAMPLE #############

# One-sentence smoke test for the extraction schema.
prompt = """Extract a triple from the statement:
'Apple Inc. introduced a new iPhone and it has an MSRP of $999.'
"""

response = structured_model.invoke(prompt)

# Show the typed object and its JSON rendering side by side.
print("Raw structured object:", response)
print("JSON form:\n", json.dumps(response.model_dump(), indent=2))


Raw structured object: triples=[Triple(subject='Apple Inc.', subject_type='Company', relation='introduced', object='new iPhone', object_type='Product', subject_attributes=[], object_attributes=[KeyValueAttribute(key='MSRP', value='$999')])]
JSON form:
 {
  "triples": [
    {
      "subject": "Apple Inc.",
      "subject_type": "Company",
      "relation": "introduced",
      "object": "new iPhone",
      "object_type": "Product",
      "subject_attributes": [],
      "object_attributes": [
        {
          "key": "MSRP",
          "value": "$999"
        }
      ]
    }
  ]
}


In [142]:
# llm_gpt_4o_mini.invoke("What is the capital of France?")

In [None]:
def build_dynamic_prompt(classes, obj_props, data_props, chunk_text):
    """Assemble the triple-extraction prompt for one chunk of text.

    ``classes``, ``obj_props`` and ``data_props`` are each rendered as a
    "- item" bullet list (one item per line, in iteration order) and inserted,
    together with ``chunk_text``, into a fixed instruction template that ends
    with an example of the expected JSON output and the context text.

    Returns:
        The complete prompt as a single string.
    """
    def as_bullets(items):
        # One "- <item>" line per element, newline-separated, no trailing newline.
        return '\n'.join(map("- {}".format, items))

    sections = {
        "class_block": as_bullets(classes),
        "relation_block": as_bullets(obj_props),
        "attribute_block": as_bullets(data_props),
        "context_block": chunk_text,
    }

    # Doubled braces are str.format escapes and come out as single literal braces.
    return """
You are an expert in extracting structured knowledge from financial documents.

Classes:
{class_block}

Object Properties:
{relation_block}

Data Properties (can be attached to subject_attributes or object_attributes):
{attribute_block}

Extract JSON triples like:
[
  {{
    "subject_text": "Apple Inc.",
    "subject_type": "Company",
    "subject_attributes": {{ "companyName": "Apple Inc." }},
    "relation": "reports on",
    "object_text": "Q3 2022 10-Q",
    "object_type": "Financial Report",
    "object_attributes": {{ "reportType": "10-Q" }}
  }}
]

Context:
{context_block}
""".format(**sections)

In [105]:
from langchain_core.prompts import ChatPromptTemplate

# Prompt template for triple extraction. The single-brace fields (cls_str,
# rel_str, attr_str, chunk_text) are supplied when the chain is invoked; doubled
# braces are escapes for literal braces in the rendered prompt.
# NOTE(review): the example output here uses subject_text/object_text with
# dict-valued attributes, while CustomKGPipeline reads "subject"/"object" and
# key/value lists; confirm which format the pipeline's prompt file requests.
template_for_extracting_triples = """
You are an expert in extracting structured knowledge from financial documents.

Classes:
{cls_str}

Object Properties:
{rel_str}

Data Properties (can be attached to subject_attributes or object_attributes):
{attr_str}

Context:
{chunk_text}

# #### Important notes:
# - No additional explanations, numbering, or extra text.

Extract JSON triples like:
## Output format:
[
  {{
    "subject_text": "Apple Inc.",
    "subject_type": "Company",
    "subject_attributes": {{ "companyName": "Apple Inc." }},
    "relation": "reports on",
    "object_text": "Q3 2022 10-Q",
    "object_type": "Financial Report",
    "object_attributes": {{ "reportType": "10-Q" }}
  }}
]
"""
# Compile the raw string into a LangChain prompt template object.
prompt_template_for_extracting_triples = ChatPromptTemplate.from_template(template_for_extracting_triples)

In [106]:
# Values for the four template fields, built from the ontology schema and the
# first chunk.
# NOTE(review): cls_str and rel_str interpolate each whole entity/relation object
# from .values(), whereas other cells read e["label"] / r["label"]; confirm that
# the full objects (not just the labels) are meant to appear in the prompt.
input_data = {
    "cls_str": '\n'.join(f"- {c}" for c in neo4j_schema.entities.values()),
    "rel_str": '\n'.join(f"- {r}" for r in neo4j_schema.relations.values()),
    "attr_str": '\n'.join(f"- {a}" for a in ({p["name"] for e in neo4j_schema.entities.values() for p in e.get("properties", [])})),
    "chunk_text": chunks[0].page_content
}

In [107]:
from langchain_core.output_parsers import StrOutputParser

# Compose prompt -> LLM -> string parser into a single runnable chain.
chain = prompt_template_for_extracting_triples | llm_gpt_4o_mini | StrOutputParser()

# Run the chain once with the values prepared in input_data.
result = chain.invoke(input_data)

In [110]:
print(result)

[
  {
    "subject_text": "Apple Inc.",
    "subject_type": "Company",
    "subject_attributes": { "companyName": "Apple Inc." },
    "relation": "reports on",
    "object_text": "Q2 2022 10-Q",
    "object_type": "Financial Report",
    "object_attributes": { "reportType": "10-Q", "periodEndDate": "2022-06-25" }
  }
]


In [111]:
import json

result_json = json.loads(result)
print(result_json)

[{'subject_text': 'Apple Inc.', 'subject_type': 'Company', 'subject_attributes': {'companyName': 'Apple Inc.'}, 'relation': 'reports on', 'object_text': 'Q2 2022 10-Q', 'object_type': 'Financial Report', 'object_attributes': {'reportType': '10-Q', 'periodEndDate': '2022-06-25'}}]


In [31]:
# from typing import List
# from langchain.schema import Document

# class CustomKGPipeline:
#     def __init__(self, driver, embedder, kg_llm, classes, object_properties, data_properties, prompt_template_path, neo4j_database=None):
#         self.driver = driver
#         self.embedder = embedder
#         self.kg_llm = kg_llm
#         self.classes = classes
#         self.object_properties = object_properties
#         self.data_properties = data_properties
#         self.neo4j_database = neo4j_database
#         self.prompt_template_path = prompt_template_path

#     def _create_chunk_nodes(self, docs: List[Document]):
#         with self.driver.session(database=self.neo4j_database) as session:
#             for i, doc in enumerate(docs):
#                 props = {
#                     "chunk_index": i,
#                     "text": doc.page_content,
#                     "embedding": self.embedder.embed_query(doc.page_content)
#                 }
#                 props.update(doc.metadata or {})
#                 session.run("""
#                 MERGE (c:Chunk {chunk_index: $i})
#                 SET c += $props
#                 """, i=i, props=props)
#                 if i > 0:
#                     session.run("""
#                     MATCH (c1:Chunk {chunk_index: $prev}), (c2:Chunk {chunk_index: $curr})
#                     MERGE (c1)-[:NEXT_CHUNK]->(c2)
#                     MERGE (c2)-[:PREVIOUS_CHUNK]->(c1)
#                     """, prev=i-1, curr=i)

#     def _prompt_template(self):
#         with open(self.prompt_template_path, "r") as f:
#             template_for_extracting_triples = f.read()
#         prompt_template_for_extracting_triples = ChatPromptTemplate.from_template(template_for_extracting_triples)
#         return prompt_template_for_extracting_triples

#     def _create_kg_from_chunks(self, docs: List[Document]):
#         for i, doc in enumerate(docs):
#             try:
#                 # Generate the prompt template
#                 prompt_template_for_extracting_triples = self._prompt_template()
                
#                 # Create the LLM chain
#                 llm_chain = prompt_template_for_extracting_triples | self.kg_llm | StrOutputParser()
                
#                 # Prepare input data
#                 input_data = {
#                     'cls_str': self.classes,
#                     'rel_str': self.object_properties,
#                     'attr_str': self.data_properties,
#                     'chunk_text': doc.page_content
#                 }
                
#                 # Invoke the LLM chain
#                 result = llm_chain.invoke(input_data)
                
#                 # Parse the result into triples
#                 triples = json.loads(result)
                
#                 # Process triples and store them in Neo4j
#                 with self.driver.session(database=self.neo4j_database) as session:
#                     for t in triples:
#                         s_txt, s_type, s_attr = t['subject_text'], t['subject_type'], t.get('subject_attributes', {})
#                         o_txt, o_type, o_attr = t['object_text'], t['object_type'], t.get('object_attributes', {})
#                         rel = t['relation']
                        
#                         # Construct Cypher query
#                         cypher = f"""
#                         MERGE (subj:{s_type.replace(' ', '_')} {{name: $sName}})
#                           ON CREATE SET subj.entity_type = $sType
#                         MERGE (obj:{o_type.replace(' ', '_')} {{name: $oName}})
#                           ON CREATE SET obj.entity_type = $oType
#                         MERGE (subj)-[r:{rel.upper().replace(' ', '_')}]->(obj)
#                         WITH subj, obj
#                         MATCH (c:Chunk {{chunk_index: $chunk_index}})
#                         MERGE (subj)-[:MENTIONED_IN]->(c)
#                         MERGE (obj)-[:MENTIONED_IN]->(c)
#                         """
#                         params = {
#                             "sName": s_txt, "sType": s_type, "oName": o_txt,
#                             "oType": o_type, "chunk_index": i
#                         }
#                         for idx, (k, v) in enumerate(s_attr.items()):
#                             cypher += f"\nSET subj.{k} = coalesce(subj.{k}, $sa{idx})"
#                             params[f"sa{idx}"] = v
#                         for idx, (k, v) in enumerate(o_attr.items()):
#                             cypher += f"\nSET obj.{k} = coalesce(obj.{k}, $oa{idx})"
#                             params[f"oa{idx}"] = v
                        
#                         # Execute Cypher query
#                         session.run(cypher, params)
#             except Exception as e:
#                 print(f"Error processing chunk {i}: {e}")


#     def _deduplicate_entities(self):
#         with self.driver.session(database=self.neo4j_database) as session:
#             session.run("""
#             MATCH (e1), (e2)
#             WHERE e1.name = e2.name AND e1.entity_type = e2.entity_type AND id(e1) < id(e2)
#             CALL apoc.refactor.mergeNodes([e1, e2], {properties:'combine'}) YIELD node
#             RETURN count(node)
#             """)

#     def run(self, docs: List[Document]):
#         print("📄 Creating chunk nodes...")
#         self._create_chunk_nodes(docs)
#         print("🔍 Extracting triples...")
#         self._create_kg_from_chunks(docs)
#         print("🧹 Deduplicating entities...")
#         self._deduplicate_entities()
#         print("✅ Done.")

In [41]:
import re
import json
from typing import List
from langchain.schema import Document
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate

def sanitize_key(key: str) -> str:
    """Turn an arbitrary attribute name into a safe Neo4j property identifier.

    - Every character outside ``[0-9A-Za-z_]`` (spaces, punctuation, non-ASCII
      letters, ...) is replaced with an underscore; nothing is removed.
    - If the result starts with a digit, an underscore is prepended.
    - An empty or ``None`` key yields ``"_"`` so the result is never an empty
      identifier, which would produce invalid Cypher such as ``SET n. = ...``.

    Args:
        key: Attribute name, typically produced by the LLM. Non-string values
            are converted with ``str()`` first.

    Returns:
        A non-empty string consisting only of ASCII letters, digits and
        underscores that does not start with a digit.
    """
    text = "" if key is None else str(key)
    # Replace all non-(alphanumeric/underscore) characters with underscore
    safe_key = re.sub(r'[^0-9A-Za-z_]', '_', text)
    if not safe_key:
        return "_"
    # If it starts with a digit, prepend an underscore
    if safe_key[0].isdigit():
        safe_key = f"_{safe_key}"
    return safe_key

class CustomKGPipeline:
    """Populate Neo4j with chunk nodes and LLM-extracted triples.

    The pipeline runs three stages (see ``run``):
      1. store every document chunk as a ``Chunk`` node holding its text,
         metadata and embedding, linked to its neighbours;
      2. ask the LLM for triples per chunk and MERGE them as nodes and
         relationships, linking each entity to its chunk via ``MENTIONED_IN``;
      3. merge nodes that share the same ``(name, entity_type)``.

    NOTE: ``chunk_index`` is the position of a chunk inside the list passed to
    ``run``. Running the pipeline on different slices therefore reuses the
    indices 0..n-1 and updates the ``Chunk`` nodes written by an earlier run.
    """

    def __init__(
        self,
        driver,
        embedder,
        kg_llm,
        classes,
        object_properties,
        data_properties,
        prompt_template_path,
        neo4j_database=None
    ):
        """Store the collaborators and the ontology vocabulary.

        Args:
            driver: Neo4j driver used to open sessions.
            embedder: Object exposing ``embed_query(text)``.
            kg_llm: Runnable placed between the prompt template and the
                string output parser.
            classes: Value substituted for ``cls_str`` in the prompt.
            object_properties: Value substituted for ``rel_str`` in the prompt.
            data_properties: Value substituted for ``attr_str`` in the prompt.
            prompt_template_path: Path of the text file holding the prompt template.
            neo4j_database: Database name passed to ``driver.session``
                (``None`` selects the driver's default).
        """
        self.driver = driver
        self.embedder = embedder
        self.kg_llm = kg_llm
        self.classes = classes
        self.object_properties = object_properties
        self.data_properties = data_properties
        self.neo4j_database = neo4j_database
        self.prompt_template_path = prompt_template_path

    def _create_chunk_nodes(self, docs: List[Document]):
        """
        Creates 'Chunk' nodes in Neo4j for each document chunk,
        storing text, metadata, and embeddings, and links consecutive
        chunks with NEXT_CHUNK / PREVIOUS_CHUNK relationships.
        """
        with self.driver.session(database=self.neo4j_database) as session:
            for i, doc in enumerate(docs):
                props = {
                    "chunk_index": i,
                    "text": doc.page_content,
                    "embedding": self.embedder.embed_query(doc.page_content)
                }
                # Metadata is merged last, so a metadata key with the same name
                # as one of the three above overrides it.
                props.update(doc.metadata or {})
                session.run(
                    """
                    MERGE (c:Chunk {chunk_index: $i})
                    SET c += $props
                    """,
                    i=i, props=props
                )
                if i > 0:
                    session.run(
                        """
                        MATCH (c1:Chunk {chunk_index: $prev}), (c2:Chunk {chunk_index: $curr})
                        MERGE (c1)-[:NEXT_CHUNK]->(c2)
                        MERGE (c2)-[:PREVIOUS_CHUNK]->(c1)
                        """,
                        prev=i-1, curr=i
                    )

    def _prompt_template(self):
        """
        Loads a textual prompt template from file and
        returns a ChatPromptTemplate object.
        """
        with open(self.prompt_template_path, "r") as f:
            template_for_extracting_triples = f.read()
        return ChatPromptTemplate.from_template(template_for_extracting_triples)

    @staticmethod
    def _attributes_to_properties(attr_list):
        """Convert ``[{"key": ..., "value": ...}, ...]`` into ``{safe_key: value}``.

        Keys are passed through ``sanitize_key``; pairs whose key is missing or
        blank are skipped. When two keys sanitize to the same name, the later
        pair wins.
        """
        props = {}
        for pair in attr_list or []:
            raw_key = pair.get("key")
            if raw_key is None or not str(raw_key).strip():
                continue
            safe_k = sanitize_key(raw_key)
            if not safe_k:
                continue
            props[safe_k] = pair.get("value")
        return props

    def _merge_triple(self, session, t, chunk_index):
        """MERGE one extracted triple into Neo4j and link it to its chunk.

        Labels and the relationship type cannot be passed as Cypher parameters,
        so they are interpolated into the query text. Because they come from
        LLM output they are sanitized first, which keeps the query valid and
        prevents injected Cypher.
        """
        s_txt = t["subject"]
        s_type = t["subject_type"]
        o_txt = t["object"]
        o_type = t["object_type"]
        rel = t["relation"]

        # A triple with a blank endpoint, type or relation cannot be expressed as a valid MERGE.
        if not all(isinstance(v, str) and v.strip() for v in (s_txt, s_type, o_txt, o_type, rel)):
            print(f"Skipping incomplete triple in chunk {chunk_index}: {t}")
            return

        subj_label = sanitize_key(s_type)
        obj_label = sanitize_key(o_type)
        rel_type = sanitize_key(rel.upper())

        # subject_attributes / object_attributes are lists of {"key": "...", "value": "..."}
        s_attr_sanitized = self._attributes_to_properties(t.get("subject_attributes", []))
        o_attr_sanitized = self._attributes_to_properties(t.get("object_attributes", []))

        # Construct the Cypher query
        cypher = f"""
        MERGE (subj:{subj_label} {{name: $sName}})
          ON CREATE SET subj.entity_type = $sType
        MERGE (obj:{obj_label} {{name: $oName}})
          ON CREATE SET obj.entity_type = $oType
        MERGE (subj)-[r:{rel_type}]->(obj)
        WITH subj, obj
        MATCH (c:Chunk {{chunk_index: $chunk_index}})
        MERGE (subj)-[:MENTIONED_IN]->(c)
        MERGE (obj)-[:MENTIONED_IN]->(c)
        """

        params = {
            "sName": s_txt,
            "sType": s_type,
            "oName": o_txt,
            "oType": o_type,
            "chunk_index": chunk_index
        }

        # coalesce keeps an existing property value and only fills in missing ones.
        for idx, (k, v) in enumerate(s_attr_sanitized.items()):
            cypher += f"\nSET subj.{k} = coalesce(subj.{k}, $sa{idx})"
            params[f"sa{idx}"] = v

        for idx, (k, v) in enumerate(o_attr_sanitized.items()):
            cypher += f"\nSET obj.{k} = coalesce(obj.{k}, $oa{idx})"
            params[f"oa{idx}"] = v

        # consume() forces the query to finish here, so a failure is raised for
        # this triple instead of surfacing on a later statement.
        session.run(cypher, params).consume()

    def _create_kg_from_chunks(self, docs: List[Document]):
        """
        For each chunk, calls the LLM to extract triples,
        then merges the resulting data into Neo4j.

        Errors are reported with print() and processing continues: a failure
        while extracting or parsing skips that chunk, a failure while storing
        one triple skips only that triple.
        """
        # The chain (prompt -> LLM -> raw string) does not depend on the chunk,
        # so it is built once, inside the try so a broken template is still
        # reported per chunk rather than aborting the run.
        llm_chain = None

        for i, doc in enumerate(docs):
            try:
                if llm_chain is None:
                    llm_chain = self._prompt_template() | self.kg_llm | StrOutputParser()

                result = llm_chain.invoke({
                    'cls_str': self.classes,
                    'rel_str': self.object_properties,
                    'attr_str': self.data_properties,
                    'chunk_text': doc.page_content
                })

                # result should be valid JSON, e.g.:
                # {
                #   "triples": [
                #     { "subject": ..., "subject_type": ..., "relation": ..., "object": ...,
                #       "object_type": ..., "subject_attributes": [...], "object_attributes": [...] }
                #   ]
                # }
                result_obj = json.loads(result)
                triples = result_obj["triples"]

                with self.driver.session(database=self.neo4j_database) as session:
                    for t in triples:
                        try:
                            self._merge_triple(session, t, i)
                        except Exception as e:
                            print(f"Error processing triple in chunk {i}: {e}")

            except Exception as e:
                print(f"Error processing chunk {i}: {e}")

    def _deduplicate_entities(self):
        """
        Merges duplicate nodes that share the same (name, entity_type)
        by combining their properties via apoc.refactor.mergeNodes.

        Nodes are grouped first and each group is merged in a single call.
        This avoids comparing every pair of nodes and avoids re-merging a node
        that an earlier pair already merged away when three or more duplicates
        exist.
        """
        with self.driver.session(database=self.neo4j_database) as session:
            session.run("""
            MATCH (e)
            WHERE e.name IS NOT NULL AND e.entity_type IS NOT NULL
            WITH e.name AS name, e.entity_type AS entity_type, collect(e) AS nodes
            WHERE size(nodes) > 1
            CALL apoc.refactor.mergeNodes(nodes, {properties:'combine'}) YIELD node
            RETURN count(node)
            """)

    def run(self, docs: List[Document]):
        """
        High-level orchestration method:
          1) Create chunk nodes
          2) Extract triples from each chunk & store them
          3) Deduplicate entities
        """
        print("📄 Creating chunk nodes...")
        self._create_chunk_nodes(docs)
        print("🔍 Extracting triples...")
        self._create_kg_from_chunks(docs)
        print("🧹 Deduplicating entities...")
        self._deduplicate_entities()
        print("✅ Done.")


In [42]:
# Wire the pipeline to the Neo4j driver, the embedding client, the extraction LLM
# and the ontology vocabulary sets computed earlier in the notebook.
# neo4j_database is not passed, so it keeps its default of None.
pipeline = CustomKGPipeline(
    driver=driver,
    embedder=llm_embedding_large_3,
    kg_llm=llm_gpt_4o_mini,
    classes=classes,
    object_properties=obj_props,
    data_properties=data_props,
    prompt_template_path="prompt/prompt_2.txt",
)

In [46]:
pipeline.run(chunks[20:30])

📄 Creating chunk nodes...
🔍 Extracting triples...




🧹 Deduplicating entities...
✅ Done.


In [8]:
# DESTRUCTIVE: resets the connected database.
#   1. deletes every node together with its relationships,
#   2. drops the `chunk-index` vector index if it exists,
#   3. calls apoc.schema.assert with two empty maps.
#      NOTE(review): presumably this drops all remaining indexes and
#      constraints; confirm against the APOC documentation before running on a
#      shared database.
# NOTE(review): this cell sits at the end of the notebook but carries an early
# execution count; running all cells top to bottom would wipe the graph that
# was just built.
with driver.session() as session:
    session.run("MATCH (n) DETACH DELETE n")
    session.run("DROP INDEX `chunk-index` IF EXISTS")
    session.run("CALL apoc.schema.assert({}, {})")
