In [14]:
from neo4j import GraphDatabase
from pathlib import Path
from typing import Dict, Tuple, List, Optional
import os
import csv
import sys

# ======================
# --- CONFIG SECTION ---
# ======================
NEO4J_URI  = "neo4j://localhost:7687"
NEO4J_USER = "neo4j"
NEO4J_PASS = "YourStrongPass123!"

# Path to the directory that holds the FinDKG files
# If you run this script from anywhere, point this to your Neo4j project's import dir.
DATA_DIR = Path("/Users/agniksaha/Downloads/FinDKG-main/FinDKG_dataset/FinDKG-full")

# Batch sizes (tune for speed/memory)
NODE_BATCH = 20_000
EDGE_BATCH = 20_000

# Tag so you can filter queries if you like
DATASET_TAG = "FinDKG"

# Isolated labels / rel type (to avoid collisions with prior data)
L_ENTITY = "FDEntity"
L_REL    = "FDRelation"
L_TIME   = "FDTime"
R_FACT   = "FD_FACT"

# =================================
# --- FILE READERS (exact fit) ---
# =================================
def read_time2id_csv(fp: Path) -> Dict[int, str]:
    """
    time2id.txt format (CSV):
      TimeID,DATE_WK
      0,2018-01-07
      ...
    """
    out: Dict[int, str] = {}
    if not fp.exists():
        return out
    with fp.open("r", encoding="utf-8") as f:
        rdr = csv.DictReader(f)
        for row in rdr:
            tid_s = row.get("TimeID") or row.get("timeid") or row.get("time_id")
            date_s = row.get("DATE_WK") or row.get("date") or row.get("DATE")
            if tid_s is None or date_s is None:
                continue
            try:
                tid = int(tid_s.strip())
            except ValueError:
                continue
            out[tid] = date_s.strip()
    return out

def read_relation2id_tsv(fp: Path) -> Dict[int, str]:
    """
    relation2id.txt format (TSV):
      Relate_To<TAB>0
      Control<TAB>1
      ...
    """
    out: Dict[int, str] = {}
    with fp.open("r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            parts = ln.split("\t")
            if len(parts) != 2:
                parts = ln.split()
                if len(parts) != 2:
                    continue
            name, rid_s = parts[0], parts[1]
            try:
                rid = int(rid_s)
            except ValueError:
                continue
            out[rid] = name
    return out

def read_entity2id(fp: Path) -> Dict[int, str]:
    """
    entity2id.txt format (TSV with 4 columns):
      name<TAB>id<TAB>...<TAB>...
    Use the FIRST numeric token after name as eid.
    """
    out: Dict[int, str] = {}
    with fp.open("r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.rstrip("\n\r")
            if not ln:
                continue
            parts = ln.split("\t") if "\t" in ln else ln.rsplit(maxsplit=3)
            if len(parts) < 2:
                continue
            name = parts[0].strip()
            eid_s = None
            for tok in parts[1:]:
                tok = tok.strip()
                if tok.isdigit() or (tok.startswith("-") and tok[1:].isdigit()):
                    eid_s = tok
                    break
            if eid_s is None:
                continue
            try:
                eid = int(eid_s)
            except ValueError:
                continue
            out[eid] = name
    return out

def read_edges_ids(fp: Path) -> List[Tuple[int, int, int, Optional[int]]]:
    """
    train/valid/test format (IDs only):
      h_id  t_id  r_id  [time_id]
    separated by TAB or SPACE.
    """
    edges: List[Tuple[int, int, int, Optional[int]]] = []
    with fp.open("r", encoding="utf-8") as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            parts = ln.split("\t")
            if len(parts) < 3:
                parts = ln.split()
            if len(parts) < 3:
                continue
            try:
                h = int(parts[0]); t = int(parts[1]); r = int(parts[2])
                tm = int(parts[3]) if len(parts) >= 4 and parts[3] != "" else None
            except ValueError:
                raise ValueError(f"{fp.name} must contain numeric IDs only (got: {ln[:80]}...)")
            edges.append((h, t, r, tm))
    return edges

# ==================================
# --- CYPHER / INGESTION ROUTES ---
# ==================================
def create_constraints(driver):
    with driver.session() as session:
        # Only on our isolated labels
        session.run(f"CREATE CONSTRAINT IF NOT EXISTS FOR (e:{L_ENTITY}) REQUIRE e.eid IS UNIQUE;")
        session.run(f"CREATE CONSTRAINT IF NOT EXISTS FOR (r:{L_REL})    REQUIRE r.rid IS UNIQUE;")
        session.run(f"CREATE CONSTRAINT IF NOT EXISTS FOR (t:{L_TIME})   REQUIRE t.tid IS UNIQUE;")
        # Neo4j 5+: unique key on relationship
        try:
            session.run(f"CREATE CONSTRAINT IF NOT EXISTS FOR ()-[f:{R_FACT}]-() REQUIRE f.key IS UNIQUE;")
        except Exception as e:
            print("Note: FD_FACT(key) uniqueness not created (likely Neo4j < 5):", e, file=sys.stderr)

def ingest_entities(driver, id2name: Dict[int, str]):
    rows = [{"eid": k, "name": v, "dataset": DATASET_TAG} for k, v in id2name.items()]
    for i in range(0, len(rows), NODE_BATCH):
        chunk = rows[i:i+NODE_BATCH]
        with driver.session() as session:
            session.run(f"""
                UNWIND $rows AS row
                MERGE (e:{L_ENTITY} {{eid: row.eid}})
                ON CREATE SET e.name = row.name, e.dataset = row.dataset
                ON MATCH  SET e.name = coalesce(e.name, row.name),
                              e.dataset = coalesce(e.dataset, row.dataset)
            """, rows=chunk)

def ingest_relations(driver, id2name: Dict[int, str]):
    rows = [{"rid": k, "name": v, "dataset": DATASET_TAG} for k, v in id2name.items()]
    for i in range(0, len(rows), NODE_BATCH):
        chunk = rows[i:i+NODE_BATCH]
        with driver.session() as session:
            session.run(f"""
                UNWIND $rows AS row
                MERGE (r:{L_REL} {{rid: row.rid}})
                ON CREATE SET r.name = row.name, r.dataset = row.dataset
                ON MATCH  SET r.name = coalesce(r.name, row.name),
                              r.dataset = coalesce(r.dataset, row.dataset)
            """, rows=chunk)

def ingest_times(driver, id2date: Dict[int, str]):
    if not id2date:
        return
    rows = [{"tid": k, "value": v, "dataset": DATASET_TAG} for k, v in id2date.items()]
    for i in range(0, len(rows), NODE_BATCH):
        chunk = rows[i:i+NODE_BATCH]
        with driver.session() as session:
            session.run(f"""
                UNWIND $rows AS row
                MERGE (t:{L_TIME} {{tid: row.tid}})
                ON CREATE SET t.value = row.value, t.dataset = row.dataset
                ON MATCH  SET t.value = coalesce(t.value, row.value),
                              t.dataset = coalesce(t.dataset, row.dataset)
            """, rows=chunk)

def ingest_edges(driver,
                 edges: List[Tuple[int, int, int, Optional[int]]],
                 split_name: str,
                 rid2name: Dict[int, str],
                 tid2date: Dict[int, str]):
    """
    One relationship type (:FD_FACT) with a unique 'key' so we can MERGE without duplicates.
    Properties: rid, rel, tid, time, split, dataset
    """
    if not edges:
        print(f"No {split_name} edges to ingest.")
        return

    rows = []
    for (h, t, r, tm) in edges:
        tm_str = "NA" if tm is None else str(tm)
        key = f"{h}|{t}|{r}|{tm_str}|{split_name}|{DATASET_TAG}"
        rows.append({
            "h": h, "t": t,
            "rid": r, "rel": rid2name.get(r, str(r)),
            "tid": tm, "time": (tid2date.get(tm) if tm is not None else None),
            "split": split_name, "dataset": DATASET_TAG, "key": key
        })

    print(f"Ingesting {len(rows):,} {split_name} edges…")
    for i in range(0, len(rows), EDGE_BATCH):
        chunk = rows[i:i+EDGE_BATCH]
        with driver.session() as session:
            session.run(f"""
                UNWIND $rows AS row
                MATCH (h:{L_ENTITY} {{eid: row.h}})
                MATCH (t:{L_ENTITY} {{eid: row.t}})
                MERGE (h)-[e:{R_FACT} {{key: row.key}}]->(t)
                ON CREATE SET e.rid=row.rid, e.rel=row.rel, e.tid=row.tid, e.time=row.time,
                              e.split=row.split, e.dataset=row.dataset
                ON MATCH  SET e.rid=coalesce(e.rid,row.rid),
                              e.rel=coalesce(e.rel,row.rel),
                              e.tid=coalesce(e.tid,row.tid),
                              e.time=coalesce(e.time,row.time),
                              e.split=coalesce(e.split,row.split),
                              e.dataset=coalesce(e.dataset,row.dataset)
            """, rows=chunk)

# ==========================
# --- SANITY QUERIES -------
# ==========================
def sample_queries(driver):
    with driver.session() as session:
        print("\n# of FDEntity nodes:",
              session.run(f"MATCH (e:{L_ENTITY}) RETURN count(e) AS n").single()["n"])
        print("# of FDRelation nodes:",
              session.run(f"MATCH (r:{L_REL}) RETURN count(r) AS n").single()["n"])
        print("# of FDTime nodes:",
              session.run(f"MATCH (t:{L_TIME}) RETURN count(t) AS n").single()["n"])
        print("# of {R_FACT} relationships:",
              session.run(f"MATCH ()-[f:{R_FACT}]-() RETURN count(f) AS n").single()["n"])

        print("\nTop 10 entities by FACT-degree (isolated labels):")
        for rec in session.run(f"""
            MATCH (e:{L_ENTITY})
            OPTIONAL MATCH (e)-[f:{R_FACT}]-()
            WITH e, count(f) AS deg
            RETURN e.name AS entity, deg
            ORDER BY deg DESC LIMIT 10
        """):
            print(f"{rec['entity']}: {rec['deg']}")

        target = "Apple Inc."
        print(f"\nNeighbors of '{target}' (limit 10):")
        for rec in session.run(f"""
            MATCH (e:{L_ENTITY} {{name:$name}})
            OPTIONAL MATCH (e)-[r:{R_FACT}]-(nbr:{L_ENTITY})
            RETURN nbr.name AS neighbor, r.rel AS relation, r.time AS time, r.split AS split
            LIMIT 10
        """, name=target):
            print(rec)

        year_prefix = "2022"
        print(f"\nEdges with time starting with '{year_prefix}':")
        for rec in session.run(f"""
            MATCH (h:{L_ENTITY})-[r:{R_FACT}]->(t:{L_ENTITY})
            WHERE r.time IS NOT NULL AND r.time STARTS WITH $pref
            RETURN h.name AS head, r.rel AS rel, t.name AS tail, r.time AS time
            LIMIT 10
        """, pref=year_prefix):
            print(rec)

# ===============
# --- MAIN ------
# ===============
def main():
    assert DATA_DIR.exists(), f"DATA_DIR not found: {DATA_DIR}"

    # Read IDs in their true formats
    time_map     = read_time2id_csv(DATA_DIR / "time2id.txt")          # {tid -> date_str}
    relation_map = read_relation2id_tsv(DATA_DIR / "relation2id.txt")  # {rid -> name}
    entity_map   = read_entity2id(DATA_DIR / "entity2id.txt")          # {eid -> name}

    print(f"Loaded IDs — Entities:{len(entity_map):,} Relations:{len(relation_map):,} Times:{len(time_map):,}")

    # Read splits (IDs only). If this errors, your split files are names — ping me and I’ll swap readers.
    train_edges = read_edges_ids(DATA_DIR / "train.txt")
    valid_edges = read_edges_ids(DATA_DIR / "valid.txt")
    test_edges  = read_edges_ids(DATA_DIR / "test.txt")
    print(f"Parsed edges — train:{len(train_edges):,} valid:{len(valid_edges):,} test:{len(test_edges):,}")

    # Connect + ingest into isolated labels
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASS))
    try:
        create_constraints(driver)
        ingest_entities(driver, entity_map)
        ingest_relations(driver, relation_map)
        ingest_times(driver, time_map)
        ingest_edges(driver, train_edges, "train", relation_map, time_map)
        ingest_edges(driver, valid_edges, "valid", relation_map, time_map)
        ingest_edges(driver, test_edges,  "test",  relation_map, time_map)
        sample_queries(driver)
    finally:
        driver.close()

if __name__ == "__main__":
    main()

Loaded IDs — Entities:13,645 Relations:15 Times:261
Parsed edges — train:222,732 valid:9,404 test:10,013
Ingesting 222,732 train edges…
Ingesting 9,404 valid edges…
Ingesting 10,013 test edges…

# of FDEntity nodes: 13645
# of FDRelation nodes: 15
# of FDTime nodes: 261
# of {R_FACT} relationships: 439113

Top 10 entities by FACT-degree (isolated labels):
Donald Trump: 64642
United States: 52086
China: 36322
U.S. Federal Reserve: 33537
Joe Biden: 21698
COVID-19: 8703
Russia: 8415
US Government: 7587
Meta Platforms Inc.: 6378
Apple Inc.: 6203

Neighbors of 'Apple Inc.' (limit 10):
<Record neighbor='House Speaker Paul Ryan' relation='Relate_To' time='2018-03-18' split='train'>
<Record neighbor='Ukrainian Government' relation='2266' time='2022-11-27' split='test'>
<Record neighbor='Mr. Macron' relation='10178' time='2022-12-18' split='test'>
<Record neighbor='Governor Gavin Newsom' relation='10203' time='2022-12-25' split='test'>
<Record neighbor='The Biden Administration' relation='2560'