In [1]:
!pip install wikipedia-api neo4j cohere  --quiet

In [2]:
from neo4j import GraphDatabase
import re
import os
from dotenv import load_dotenv
import cohere
import requests
from dotenv import load_dotenv
import cohere
from time import sleep
import json
from difflib import SequenceMatcher


In [3]:
from myutils import fetch_raw_text, strip_gutenberg_header_footer, chunk_text, merge_entities, remap_relationships

### LLM API and Neo4j DB connections

In [4]:
# Load all keys from .env
load_dotenv()

# Access environment variables
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USER = os.getenv("NEO4J_USER")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")

# Debug check (optional – don’t print secrets in real projects)
print("Cohere key loaded:", bool(COHERE_API_KEY))
print("Neo4j URI:", NEO4J_URI)
print("Neo4j User:", NEO4J_USER)

Cohere key loaded: True
Neo4j URI: bolt://44.200.207.55:7687
Neo4j User: neo4j


In [5]:
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
#check connection
with driver.session() as session:
    result = session.run("RETURN 1")
    print("Connection test result:", result.single()[0])  # Should print 1 if successful

Connection test result: 1


### Fetch Data


In [6]:
GUTENBERG_TXT_URL = "https://www.gutenberg.org/cache/epub/244/pg244.txt"  # A Study in Scarlet (id=244)
#GUTENBERG_TXT_URL = "https://www.gutenberg.org/cache/epub/18897/pg18897.txt"  # The Epic of Gilgamish (Langdon, id=18897)
#GUTENBERG_TXT_URL = "https://www.gutenberg.org/cache/epub/11000/pg11000.txt"  # An Old Babylonian Version of the Gilgamesh Epic (Jastrow & Clay, id=11000)

raw = fetch_raw_text(GUTENBERG_TXT_URL)
core = strip_gutenberg_header_footer(raw)

with open("data/sherlock_raw.txt", "w", encoding="utf-8") as f:
    f.write(core)
print("Saved cleaned text -> data/sherlock_raw.txt")

Saved cleaned text -> data/sherlock_raw.txt


In [10]:
# Chunk Data
text = open("data/sherlock_raw.txt", encoding="utf-8").read()

# Clean text
# remove /n with .
text = text.replace("\n", " ").replace("\r", " ").replace("  ", " ")
# start text from CHAPTER I. MR. SHERLOCK HOLMES. to avoid preface
start_idx = text.find("In the year 1878 I took my degree of Doctor of Medicine of the")
text = text[start_idx:]


chunks = chunk_text(text, max_chars=4000, overlap=1000)
print(f"Chunks: {len(chunks)}")

Chunks: 61


### Call LLM on each chunk to identify nodes and relationship

In [11]:
co = cohere.ClientV2(COHERE_API_KEY, log_warning_experimental_features=False)

In [12]:
# Global lists
global_entities = []
global_entity_map = {}      # name/alias -> id
global_relationships = []
global_relation_types = set()  # unique relation types
existing_rels = set()       # (source_id, relation_type, target_id) tuples
entity_counter = 1

# Load the response schema
with open("response_schema.json") as f:
    response_schema = json.load(f)

# Loop over each chunk
for i, chunk in enumerate(chunks):
    print(f"Processing chunk {i+1}/{len(chunks)}")

    # Prepare global lists
    entity_list_str = json.dumps({"entities": global_entities}, ensure_ascii=False)
    relation_list_str = json.dumps(list(global_relation_types), ensure_ascii=False)

    # Load prompt template
    prompt = open("prompt_template.txt").read()
    #prompt = open("prompt_v2.txt").read()
    prompt = prompt.replace("{CHUNK}", chunk)
    prompt = prompt.replace("{ENTITYLIST}", entity_list_str)
    prompt = prompt.replace("{RELATIONLIST}", relation_list_str)

    # Call the LLM
    response = co.chat(
        model="command-a-03-2025",
        messages=[{"role": "user", "content": prompt}],
        response_format={
            "type": "json_object",
            "schema": response_schema
        }
    )

    # Parse model output
    data = json.loads(response.dict()["message"]["content"][0]["text"])
    sleep(8)  # avoid rate limits

    # 🔹 Merge entities
    for ent in data["entities"]:
        key = ent["name"].lower()
        if key in global_entity_map:
            ent["id"] = global_entity_map[key]
        else:
            ent_id = f"e{entity_counter}"
            ent["id"] = ent_id
            global_entity_map[key] = ent_id
            global_entities.append(ent)
            entity_counter += 1

    # 🔹 Merge relationships (deduplicate and normalize)
    for rel in data["relationships"]:
        src_id = global_entity_map.get(rel["source"].lower(), rel["source"])
        tgt_id = global_entity_map.get(rel["target"].lower(), rel["target"])
        rel_type = rel["relation"].lower()

        rel_key = (src_id, rel_type, tgt_id)
        if rel_key not in existing_rels:
            rel["source"] = src_id
            rel["target"] = tgt_id
            rel["relation"] = rel_type
            global_relationships.append(rel)
            existing_rels.add(rel_key)
            global_relation_types.add(rel_type)




Processing chunk 1/61
Processing chunk 2/61
Processing chunk 3/61
Processing chunk 4/61
Processing chunk 5/61
Processing chunk 6/61
Processing chunk 7/61
Processing chunk 8/61
Processing chunk 9/61
Processing chunk 10/61
Processing chunk 11/61
Processing chunk 12/61
Processing chunk 13/61
Processing chunk 14/61
Processing chunk 15/61
Processing chunk 16/61
Processing chunk 17/61
Processing chunk 18/61
Processing chunk 19/61
Processing chunk 20/61
Processing chunk 21/61
Processing chunk 22/61
Processing chunk 23/61
Processing chunk 24/61
Processing chunk 25/61
Processing chunk 26/61
Processing chunk 27/61
Processing chunk 28/61
Processing chunk 29/61
Processing chunk 30/61
Processing chunk 31/61
Processing chunk 32/61
Processing chunk 33/61
Processing chunk 34/61
Processing chunk 35/61
Processing chunk 36/61
Processing chunk 37/61
Processing chunk 38/61
Processing chunk 39/61
Processing chunk 40/61
Processing chunk 41/61
Processing chunk 42/61
Processing chunk 43/61
Processing chunk 44/

TooManyRequestsError: headers: {'access-control-expose-headers': 'X-Debug-Trace-ID', 'cache-control': 'no-cache, no-store, no-transform, must-revalidate, private, max-age=0', 'content-type': 'application/json', 'expires': 'Thu, 01 Jan 1970 00:00:00 GMT', 'pragma': 'no-cache', 'vary': 'Origin', 'x-accel-expires': '0', 'x-debug-trace-id': 'e7f70a52aabe9e4bcffa212d3de73520', 'x-trial-endpoint-call-limit': '10', 'x-trial-endpoint-call-remaining': '6', 'date': 'Tue, 26 Aug 2025 22:35:35 GMT', 'content-length': '373', 'x-envoy-upstream-service-time': '8', 'server': 'envoy', 'via': '1.1 google', 'alt-svc': 'h3=":443"; ma=2592000,h3-29=":443"; ma=2592000'}, status_code: 429, body: {'id': '8e6e8dde-e399-41ee-8054-b1fd46654a05', 'message': "You are using a Trial key, which is limited to 1000 API calls / month. You can continue to use the Trial key for free or upgrade to a Production key with higher rate limits at 'https://dashboard.cohere.com/api-keys'. Contact us on 'https://discord.gg/XW44jPfYJu' or email us at support@cohere.com with any questions"}

In [16]:
# Remove relationships with missing entities
# --- Step 1: Remove invalid relationships ---
valid_ids = {e["id"] for e in global_entities}

filtered_relationships = [
    rel for rel in global_relationships
    if rel["source"] in valid_ids and rel["target"] in valid_ids
]

# --- Step 2: Remove unused entities ---
# Collect all IDs that appear in valid relationships
used_ids = {rel["source"] for rel in filtered_relationships} | {rel["target"] for rel in filtered_relationships}

filtered_entities = [ent for ent in global_entities if ent["id"] in used_ids]

# --- Debug ---
removed_entities = [ent for ent in global_entities if ent["id"] not in used_ids]
print(f"[Cleanup] Removed {len(removed_entities)} unused entities")
for e in removed_entities:
    print(" -", e["name"], f"({e['id']})")

removed_rels = [
    rel for rel in global_relationships
    if rel not in filtered_relationships
]
print(f"[Cleanup] Removed {len(removed_rels)} invalid relationships")
for r in removed_rels:
    print(" -", r)


[Cleanup] Removed 10 unused entities
 - Afghanistan (e6)
 - Barts (e17)
 - Criterion Bar (e20)
 - Mason of Bradford (e27)
 - Lefevre of Montpellier (e29)
 - Ohio (e44)
 - Wedding Ring (e67)
 - Wiggins (e96)
 - Palmyra (e139)
 - Young (e147)
[Cleanup] Removed 18 invalid relationships
 - {'source': 'e186', 'relation': 'works_as', 'target': 'e200', 'evidence_span': 'Driving and riding are as natural to me as walking, so I applied at a cabowner’s office, and soon got employment.'}
 - {'source': 'e190', 'relation': 'travels_to', 'target': 'e200', 'evidence_span': 'to St. Petersburg'}
 - {'source': 'e189', 'relation': 'apprehended_by', 'target': 'e202', 'evidence_span': 'the next thing I knew, this young man here had the bracelets on my wrists'}
 - {'source': 'e189', 'relation': 'intends_to_travel_to', 'target': 'e200', 'evidence_span': 'I was about done up. I went on cabbing it for a day or so, intending to keep at it until I could save enough to take me back to America.'}
 - {'source': 'e1

In [None]:
# --- Final entity resolution ---
canonical_entities, resolved_map = merge_entities(filtered_entities, log_merges=True)
resolved_relationships = remap_relationships(filtered_relationships, resolved_map)

[Entity Resolution] Merged 'Holborn Restaurant' -> 'Holborn' (sim=0.56)
[Entity Resolution] Merged 'Rachel' -> 'RACHE' (sim=0.91)
[Entity Resolution] Merged 'Dr. Watson' -> 'Dr Watson' (sim=0.95)
[Entity Resolution] Merged 'Enoch Drebber' -> 'Enoch J. Drebber' (sim=0.90)


In [21]:
#  Final merged JSON
final_output = {
    "entities": filtered_entities,
    "relationships": filtered_relationships
}

# Save
with open("extracted.json", "w", encoding="utf-8") as f:
    json.dump(final_output, f, indent=2, ensure_ascii=False)

print("\n✅ Final KG saved to final_output.json")


✅ Final KG saved to final_output.json


In [13]:
filtered_entities

NameError: name 'filtered_entities' is not defined

In [14]:
# Load your extracted JSON
with open("extracted.json", "r") as f:
    data = json.load(f)

In [26]:
delete_query = "MATCH (n) DETACH DELETE n"

with driver.session() as session:
    session.run(delete_query)

print("All nodes and relationships have been deleted.")

All nodes and relationships have been deleted.


In [27]:
# Create nodes with dynamic labels
query_nodes = """
UNWIND $entities AS entity
CALL apoc.merge.node([entity.type], {id: entity.id}, 
                     {name: entity.name, aliases: entity.aliases, span: entity.span}, 
                     {}) YIELD node
RETURN node
"""

# Create relationships with dynamic types
query_rels = """
UNWIND $relationships AS rel
MATCH (src {id: rel.source})
MATCH (tgt {id: rel.target})
CALL apoc.merge.relationship(src, rel.relation, {}, {evidence: rel.evidence_span}, tgt) YIELD rel AS r
RETURN r
"""

with driver.session() as session:
    session.run(query_nodes, entities=data["entities"])
    session.run(query_rels, relationships=data["relationships"])
