In [1]:
import os
import json
from typing import List, Dict, Any, Tuple
from collections import defaultdict, Counter

import requests
import numpy as np
import pandas as pd
import networkx as nx

from dotenv import load_dotenv
from openai import OpenAI

# Load environment variables from a local .env file (if one is present).
load_dotenv()

# --- Weaviate connection ---
# A trailing slash is stripped so that f"{WEAVIATE_URL}/v1/graphql" never
# contains a double slash.
WEAVIATE_URL = (os.getenv("WEAVIATE_URL") or "").rstrip("/")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "MyDocs")

# --- OpenAI ---
OPENAI_API_KEY_CHAT = os.getenv("OPENAI_API_KEY_CHAT")  # chat key (not embed)
OPENAI_MODEL = os.getenv("OPENAI_MODEL")

# Fail fast on missing configuration instead of failing later inside a
# network call. OPENAI_MODEL is passed as `model=` to every chat request.
assert WEAVIATE_URL and WEAVIATE_API_KEY, "Weaviate env vars missing"
assert OPENAI_API_KEY_CHAT, "OPENAI_API_KEY_CHAT missing"
assert OPENAI_MODEL, "OPENAI_MODEL missing"

client = OpenAI(api_key=OPENAI_API_KEY_CHAT)

# Weaviate REST / GraphQL headers
weaviate_headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {WEAVIATE_API_KEY}",
}

# Pagination / data volume
PAGE_LIMIT = 500        # objects per GraphQL page
MAX_CHUNKS = 2000       # set to None for full ~50k, start small for dev

# Entity extraction settings
ENTITY_TYPES = ["PERSON", "ORG", "LOCATION", "EVENT", "WORK", "DATE", "OTHER"]
MAX_CHARS_PER_CHUNK = 4000   # safeguard, your chunks are small anyway

# OpenAI throughput control (very rough)
CHUNK_BATCH = 100             # chunks per loop before a short sleep
SLEEP_SECONDS = 3


In [2]:

def fetch_page(cursor: str = None, limit: int = PAGE_LIMIT) -> List[Dict[str, Any]]:
    """
    Fetch one page of objects from the Weaviate GraphQL endpoint.

    Args:
        cursor: id of the last object of the previous page; when falsy the
            `after` argument is left out of the query.
        limit: maximum number of objects requested for this page.

    Returns:
        The list stored under data.Get.<COLLECTION_NAME> in the response;
        each element carries chunk_id, doc_id, text and _additional.id.

    Raises:
        requests.HTTPError: on a non-2xx HTTP status.
        RuntimeError: when the response body contains an "errors" key.
    """
    after_clause = ""
    if cursor:
        after_clause = f'after: "{cursor}"'

    graphql = f"""
        {{
          Get {{
            {COLLECTION_NAME}(
              limit: {limit}
              {after_clause}
            ) {{
              chunk_id
              doc_id
              text
              _additional {{
                id
              }}
            }}
          }}
        }}
        """

    response = requests.post(
        f"{WEAVIATE_URL}/v1/graphql",
        json={"query": graphql},
        headers=weaviate_headers,
        timeout=60,
    )
    response.raise_for_status()

    body = response.json()
    # GraphQL errors arrive with HTTP 200, so they must be checked separately.
    if "errors" in body:
        raise RuntimeError(body["errors"])
    return body["data"]["Get"][COLLECTION_NAME]


In [3]:
# Page through the collection and collect the raw rows into a DataFrame
all_rows = []
cursor = None
rows = fetch_page(cursor, PAGE_LIMIT)

while rows:
    all_rows.extend(rows)
    # The id of the last object on this page is the cursor for the next one.
    cursor = rows[-1]["_additional"]["id"]
    print(f"Fetched {len(all_rows)} chunks...")

    if MAX_CHUNKS is not None and len(all_rows) >= MAX_CHUNKS:
        # Drop whatever the last page added beyond the cap.
        del all_rows[MAX_CHUNKS:]
        print(f"Reached MAX_CHUNKS={MAX_CHUNKS}, stopping.")
        break

    rows = fetch_page(cursor, PAGE_LIMIT)

print("Total chunks fetched:", len(all_rows))

# Flatten each raw row; a missing/None text becomes the empty string.
records = [
    {
        "chunk_id": raw.get("chunk_id"),
        "doc_id": raw.get("doc_id"),
        "text": raw.get("text") or "",
        "_id": raw["_additional"]["id"],
    }
    for raw in all_rows
]

df = pd.DataFrame(records)
df.head(), df.shape


Fetched 500 chunks...
Fetched 1000 chunks...
Fetched 1500 chunks...
Fetched 2000 chunks...
Reached MAX_CHUNKS=2000, stopping.
Total chunks fetched: 2000


(      chunk_id  doc_id                                               text  \
 0  016697_c937  016697  0024-001 3181903 0 3181903 RES N 50-43-43-10-1...   
 1  023731_c153  023731   scholars. But, in a way, he was very wrong. Y...   
 2  027333_c015  027333  5 Is Read: Yes Is Invitation: No GUID: 734F669...   
 3  025231_c005  025231  -mail message is subject to the Dubai World Gr...   
 4  017088_c153  017088   for a return of McCarthyism or for military h...   
 
                                     _id  
 0  000203a2-f584-43f1-8527-e167b0bf8c6e  
 1  00046a59-5d95-4bdd-9bef-2acc2feead61  
 2  0008b585-596a-42b9-8d0a-9be00a32b994  
 3  00091260-d45d-4f44-bb4e-8100c280de0e  
 4  00099043-8c80-4015-8d4c-6912225c5d60  ,
 (2000, 4))

In [None]:
# Entity labels used by the extraction prompts.
ENTITY_TYPES = ["PERSON", "ORG", "LOCATION", "EVENT", "WORK", "DATE", "OTHER"]

# System prompt for the single-chunk extractor (placeholder text).
SYSTEM_PROMPT = "\nsome prompt here...\n"

def extract_entities_for_chunk(text: str) -> Dict[str, Any]:
    """
    Run entity/relation extraction on a single chunk of text.

    The text is cut to MAX_CHARS_PER_CHUNK characters, sent to the chat model
    together with SYSTEM_PROMPT in JSON mode, and the JSON reply is parsed.

    Returns:
        {"entities": [...], "relations": [...]}; a key that is missing or
        falsy in the reply is returned as an empty list.
    """
    clipped = text[:MAX_CHARS_PER_CHUNK]

    prompt_parts = [
        "Extract entities and relationships from this text. ",
        "Stay strictly within what the text explicitly states.\n\n",
        f"Text:\n{clipped}\n\n",
        "Respond ONLY with valid JSON.",
    ]
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": "".join(prompt_parts)},
    ]

    completion = client.chat.completions.create(
        model=OPENAI_MODEL,
        messages=messages,
        response_format={"type": "json_object"},
    )
    payload = json.loads(completion.choices[0].message.content)

    extracted = {}
    for key in ("entities", "relations"):
        # `or []` also turns an explicit null into an empty list.
        extracted[key] = payload.get(key, []) or []
    return extracted


In [4]:
import tiktoken

# Configure context usage
MODEL_CTX = 400_000                    # GPT-5 nano context window
TARGET_FRACTION = 0.7                  # use 70 percent of context
MAX_INPUT_TOKENS_PER_REQ = int(MODEL_CTX * TARGET_FRACTION)

# Hard cap on number of chunks per request so JSON stays manageable
# NOTE(review): the run recorded in this notebook shows 200-item batches
# coming back with 18, 25 and 0 results, so this cap is probably too high
# for the model to answer every item - consider lowering it.
MAX_ITEMS_PER_BATCH = 200

# Token counts are only used for budgeting, but they should come from the
# tokenizer family of the target model. o200k_base is the encoding tiktoken
# maps GPT-4o-and-later models to; fall back to cl100k_base on tiktoken
# versions that do not ship it (get_encoding raises ValueError for unknown
# names).
try:
    encoding = tiktoken.get_encoding("o200k_base")
except ValueError:
    encoding = tiktoken.get_encoding("cl100k_base")

def iter_token_batches(df_subset, max_tokens=MAX_INPUT_TOKENS_PER_REQ, max_items=MAX_ITEMS_PER_BATCH):
    """
    Yield lists of rows (as namedtuples) so that the combined input tokens
    stay under max_tokens and we never exceed max_items per batch.

    Args:
        df_subset: DataFrame with a `text` column; rows are produced with
            itertuples(index=False).
        max_tokens: estimated token budget per request (text tokens plus the
            per-item and fixed overheads below).
        max_items: maximum number of rows per yielded batch.

    Yields:
        Non-empty lists of row namedtuples, in the original row order. A
        single row whose estimate alone exceeds max_tokens is still yielded,
        as a batch of one.
    """
    # Rough overhead per item in the JSON wrapper and prompt text
    PER_ITEM_OVERHEAD = 20
    BASE_OVERHEAD = 300   # prompt, instructions, JSON scaffolding

    batch = []
    token_count = 0

    for row in df_subset.itertuples(index=False):
        text = row.text or ""
        # disallowed_special=() makes tiktoken count text such as
        # "<|endoftext|>" as ordinary characters; with the default setting
        # encode() raises ValueError when a chunk happens to contain one.
        n_tokens = len(encoding.encode(text, disallowed_special=()))
        needed = n_tokens + PER_ITEM_OVERHEAD

        over_budget = token_count + needed + BASE_OVERHEAD > max_tokens
        at_item_cap = len(batch) >= max_items

        # If adding this item would exceed budget or item cap, emit current batch.
        # The `batch and` guard guarantees we never yield an empty batch.
        if batch and (over_budget or at_item_cap):
            yield batch
            batch = []
            token_count = 0

        batch.append(row)
        token_count += needed

    if batch:
        yield batch


In [8]:
from openai import OpenAI
import json
# Chat client and model name for the batched extractor, both read from the
# environment.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY_CHAT"))

GPT5_NANO_MODEL = os.environ.get("OPENAI_MODEL")  # e.g. "gpt-5-nano-2025-08-07"


def _align_results_to_rows(rows_batch, results):
    """Match model results to input rows by chunk_id and normalize their shape."""
    wanted = [str(row.chunk_id) for row in rows_batch]
    got = [
        str(res.get("chunk_id")) if isinstance(res, dict) else None
        for res in results
    ]

    if got != wanted:
        # Same set of ids in a different order can be repaired; anything else
        # (non-dict items, unknown or missing ids) cannot be trusted.
        if None in got or sorted(got) != sorted(wanted):
            raise ValueError("Model results do not match the input chunk_ids")
        pools = {}
        for key, res in zip(got, results):
            pools.setdefault(key, []).append(res)
        results = [pools[key].pop(0) for key in wanted]

    aligned = []
    for row, res in zip(rows_batch, results):
        entities = res.get("entities")
        relations = res.get("relations")
        aligned.append(
            {
                # Ids come from the input rows, never from the model's echo.
                "chunk_id": row.chunk_id,
                "doc_id": row.doc_id,
                "entities": entities if isinstance(entities, list) else [],
                "relations": relations if isinstance(relations, list) else [],
            }
        )
    return aligned


def extract_entities_batch(rows_batch):
    """
    Extract entities and relations for a batch of chunks in one chat request.

    Args:
        rows_batch: list of namedtuples from df.itertuples(); each row must
            expose chunk_id, doc_id and text.

    Returns:
        List of dicts {chunk_id, doc_id, entities, relations}, one per input
        row and in the same order as rows_batch. chunk_id and doc_id are
        copied from the input rows; entities and relations are always lists.
        An empty rows_batch returns [] without calling the model.

    Raises:
        ValueError: if the reply is empty or not valid JSON
            (json.JSONDecodeError is a ValueError), if the number of results
            differs from the number of inputs, or if the returned chunk_ids
            do not correspond to the inputs.
    """
    if not rows_batch:
        return []

    # Build compact payload for the model
    items = [
        {
            "chunk_id": r.chunk_id,
            "doc_id": r.doc_id,
            "text": r.text or "",
        }
        for r in rows_batch
    ]

    system_msg = (
        "Your task is to process short text snippets from chunked documents,\n"
        "identifying key entities and relationships WITHOUT adding any external knowledge,\n"
        "assumptions, or hallucinations. Stick strictly to what is explicitly stated in the text.\n"
        "Extraction rules:\n"
        "1. Only extract entities that are clearly and directly mentioned.\n"
        "2. Use these entity types (aligned with investigations, legal proceedings, and networks):\n"
        "  PERSON: Named individuals, aliases, or explicitly identified roles\n"
        "          (victims, witnesses, accused, associates, officials).\n"
        "  ORG: Organizations, companies, foundations, institutions, law firms, airlines, etc.\n"
        "  LOCATION: Places, addresses, properties, geographic areas, residences, islands, cities.\n"
        "  EVENT: Explicit incidents, meetings, trips, legal actions, interrogations, flights, parties.\n"
        "  WORK: Documents, logs, books, media, lists (for example “flight logs”, “black book”).\n"
        "  DATE: Explicit dates, times, periods, years. Only extract literal time references.\n"
        "  OTHER: Only if necessary. Tail numbers, phone numbers, financial amounts.\n"
        "          Use sparingly and with justification.\n"
        "3. Relationships:\n"
        "  Only extract relationships that are explicitly implied by the text.\n"
        "  Do NOT infer anything or rely on outside knowledge.\n"
        "  Use descriptive predicates such as:\n"
        "    'associated_with'\n"
        "    'victim_of'\n"
        "    'accused_by'\n"
        "    'traveled_with'\n"
        "    'traveled_to'\n"
        "    'met_with'\n"
        "    'owned_by'\n"
        "    'mentioned_in'\n"
        "    'connected_to'\n"
        "  Only include a relationship if the text directly suggests the connection.\n"
        "You will receive a JSON array under the key 'items'. Each element has:\n"
        "  - chunk_id (string)\n"
        "  - doc_id (string or number)\n"
        "  - text (string)\n"
        "For each item, extract entities and relations.\n"
        "Return a JSON object with a single key 'results' that is an array.\n"
        "Each result must be in the same order as the input and have:\n"
        "  - chunk_id\n"
        "  - doc_id\n"
        "  - entities: list of { name, type }\n"
        "  - relations: list of { subject, predicate, object }\n"
        "Do NOT output anything outside the JSON object.\n"
    )

    user_payload = {"items": items}
    user_msg = json.dumps(user_payload, ensure_ascii=False)

    # Use chat.completions with JSON mode
    resp = client.chat.completions.create(
        model=GPT5_NANO_MODEL,
        messages=[
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_msg},
        ],
        response_format={"type": "json_object"},
    )

    out_text = resp.choices[0].message.content
    if not out_text:
        # content can be None/empty; report that instead of letting
        # json.loads fail with an opaque TypeError.
        raise ValueError("Model returned an empty response")
    data = json.loads(out_text)

    results = data.get("results") if isinstance(data, dict) else None
    if not isinstance(results, list):
        results = []

    if len(results) != len(rows_batch):
        raise ValueError(
            f"Model returned {len(results)} results for {len(rows_batch)} inputs"
        )

    # A matching count alone does not prove that result i belongs to input i.
    return _align_results_to_rows(rows_batch, results)


In [9]:
from time import sleep, time
import json
import os

# ---------------- CONFIG FOR MONITORING ----------------

MAX_ROWS_FOR_RUN = len(df)      # lower this for a trial run; len(df) processes everything
PROGRESS_EVERY_BATCH = 5        # print every N batches
SAVE_EVERY_BATCH = 5            # flush every N batches
SLEEP_BETWEEN_BATCHES = 0       # set to small number if you hit RPM limits

results_path = "entity_extraction_raw.jsonl"
failed_path = "entity_extraction_failed.jsonl"

# When True, both output files are emptied before the run starts.
truncate_outputs = True

if truncate_outputs:
    open(results_path, "w", encoding="utf-8").close()
    open(failed_path, "w", encoding="utf-8").close()

df_subset = df.iloc[:MAX_ROWS_FOR_RUN].copy()
total_rows = len(df_subset)

print("=" * 80)
print(f"Starting batched entity extraction on {total_rows} chunks")
print(f"Model: {GPT5_NANO_MODEL}")
print(f"Results file: {results_path}")
print(f"Failed file:  {failed_path}")
print("=" * 80)

success_rows = 0
failed_rows = 0
batch_index = 0
start_time = time()

with open(results_path, "a", encoding="utf-8") as f_ok, \
     open(failed_path, "a", encoding="utf-8") as f_fail:

    for batch_rows in iter_token_batches(df_subset):
        batch_index += 1
        batch_size = len(batch_rows)

        try:
            batch_results = extract_entities_batch(batch_rows)

            # Serialize the whole batch BEFORE writing anything: if one result
            # is malformed, nothing from this batch reaches the success file,
            # so a chunk can never be recorded as both succeeded and failed.
            ok_lines = [
                json.dumps(
                    {
                        "chunk_id": r.get("chunk_id"),
                        "doc_id": r.get("doc_id"),
                        "entities": r.get("entities", []),
                        "relations": r.get("relations", []),
                    },
                    ensure_ascii=False,
                )
                + "\n"
                for r in batch_results
            ]
            # One line per chunk, same format as before
            f_ok.write("".join(ok_lines))

            success_rows += batch_size

        except Exception as e:
            # Deliberate best-effort: if the whole batch fails, record all of
            # its rows in the failed file and carry on with the next batch.
            print(f"[Batch {batch_index}] Error: {e}")
            for row in batch_rows:
                err = {
                    "chunk_id": row.chunk_id,
                    "doc_id": row.doc_id,
                    "error": str(e),
                }
                f_fail.write(json.dumps(err, ensure_ascii=False) + "\n")
            failed_rows += batch_size

        # Progress and ETA
        if batch_index % PROGRESS_EVERY_BATCH == 0 or success_rows + failed_rows >= total_rows:
            done = success_rows + failed_rows
            elapsed = time() - start_time
            rate = done / elapsed if elapsed > 0 else 0
            rows_per_min = rate * 60 if rate > 0 else 0
            remaining = total_rows - done
            eta_sec = remaining / rate if rate > 0 else 0
            pct = done / total_rows * 100

            print(
                f"[Batch {batch_index:4d}] rows {done:6d}/{total_rows:6d} "
                f"({pct:5.1f}%) succ={success_rows:6d} fail={failed_rows:5d} "
                f"elapsed={elapsed/60:5.1f}m eta={eta_sec/60:5.1f}m "
                f"rows/min={rows_per_min:6.1f}"
            )

        if batch_index % SAVE_EVERY_BATCH == 0:
            f_ok.flush()
            f_fail.flush()
            print(
                f"  Flushed at batch {batch_index}. "
                f"Success rows={success_rows}, Failed rows={failed_rows}"
            )

        if SLEEP_BETWEEN_BATCHES > 0:
            sleep(SLEEP_BETWEEN_BATCHES)

elapsed_total = time() - start_time
print("=" * 80)
print("Done batched entity extraction loop")
print(f"Success rows: {success_rows}, Failed rows: {failed_rows}")
print(f"Total time: {elapsed_total/60:5.1f} minutes")
print(f"Final results in {results_path}")
if failed_rows > 0:
    print(f"Failed rows in {failed_path}")
print("=" * 80)


Starting batched entity extraction on 2000 chunks
Model: gpt-5-nano-2025-08-07
Results file: entity_extraction_raw.jsonl
Failed file:  entity_extraction_failed.jsonl
[Batch 1] Error: Model returned 18 results for 200 inputs
[Batch 2] Error: Model returned 25 results for 200 inputs
[Batch 3] Error: Model returned 0 results for 200 inputs


KeyboardInterrupt: 

In [None]:
# Build a DataFrame of entity mentions (one row per mention)

# `results` is not produced at notebook scope by any earlier cell, so load the
# per-chunk extraction records from the JSONL file at `results_path`
# (one JSON object per line; blank lines are skipped).
with open(results_path, "r", encoding="utf-8") as f_results:
    results = [json.loads(line) for line in f_results if line.strip()]

mention_rows = []

for r in results:
    chunk_id = r.get("chunk_id")
    doc_id = r.get("doc_id")
    for e in r.get("entities") or []:
        if not isinstance(e, dict):
            continue  # skip malformed entity entries
        name = str(e.get("name") or "").strip()
        etype = str(e.get("type") or "OTHER").upper()  # missing type -> OTHER
        if not name:
            continue
        mention_rows.append(
            {
                "chunk_id": chunk_id,
                "doc_id": doc_id,
                "name": name,
                "type": etype,
            }
        )

# Explicit columns keep the frame usable downstream even with zero mentions.
mentions_df = pd.DataFrame(mention_rows, columns=["chunk_id", "doc_id", "name", "type"])
print("Mention rows:", len(mentions_df))
mentions_df.head()


In [None]:
def normalize_name(name: str) -> str:
    """Trim the name and collapse every run of whitespace to a single space."""
    # str.split() with no argument already discards leading/trailing whitespace.
    return " ".join(name.split())

mentions_df["norm_name"] = mentions_df["name"].apply(normalize_name)

# One canonical entity per (normalized name, type) pair
grouped = mentions_df.groupby(["norm_name", "type"])

entity_rows = []
entity_id_map: Dict[Tuple[str, str], int] = {}

# Ids are 1-based and follow the groupby iteration order.
for eid, ((norm_name, etype), group) in enumerate(grouped, start=1):
    entity_id_map[(norm_name, etype)] = eid

    doc_ids = sorted(set(group["doc_id"].dropna().tolist()))
    chunk_ids = sorted(set(group["chunk_id"].dropna().tolist()))
    freq = len(group)  # number of mentions of this entity

    entity_rows.append(
        dict(
            entity_id=eid,
            name=norm_name,
            type=etype,
            frequency=freq,
            doc_ids=json.dumps(doc_ids),      # stored as JSON text
            chunk_ids=json.dumps(chunk_ids),  # stored as JSON text
        )
    )

# Next unused id, should more entities be appended later.
next_eid = len(entity_rows) + 1

entities_df = pd.DataFrame(entity_rows)
print("Canonical entities:", len(entities_df))
entities_df.head()

def get_eid(row):
    """Return the canonical entity id for a mention row, or None if unmapped."""
    return entity_id_map.get((row["norm_name"], row["type"]))

# Attach the canonical entity id to every mention (row-wise lookup).
mentions_df["entity_id"] = mentions_df.apply(get_eid, axis="columns")
mentions_df.head()



In [None]:
# Chunk -> Entity edges, one (source, label, target) triple per mention
mention_edges = [
    ("chunk:" + str(row["chunk_id"]), "MENTIONS", "entity:" + str(row["entity_id"]))
    for _, row in mentions_df.iterrows()
]

len(mention_edges)
# Entity co-occurrence edges based on chunk co-appearance

co_counts = Counter()

for chunk_id, grp in mentions_df.groupby("chunk_id"):
    eids = sorted(set(grp["entity_id"].dropna().tolist()))
    # Count every unordered pair once per chunk; eids is sorted, so the
    # smaller id always comes first in the key.
    co_counts.update(
        (eids[i], eids[j])
        for i in range(len(eids))
        for j in range(i + 1, len(eids))
    )

co_rows = [
    {
        "source": "entity:" + str(a),
        "target": "entity:" + str(b),
        "relation": "RELATED_TO",
        "weight": w,
    }
    for (a, b), w in co_counts.items()
]

edges_df = pd.DataFrame(co_rows)
print("Entity-entity edges:", len(edges_df))
edges_df.head()


In [None]:
# Explicit (model-extracted) relations, resolved to canonical entity ids
relation_rows = []

for r in results:
    rels = r.get("relations") or []
    ents = r.get("entities") or []

    # Per-chunk map: normalized entity name -> canonical entity id.
    # Keys go through normalize_name so that relation endpoints, normalized
    # the same way below, match regardless of whitespace differences.
    local_map = {}
    for e in ents:
        if not isinstance(e, dict):
            continue
        name = str(e.get("name") or "").strip()
        etype = str(e.get("type") or "OTHER").upper()
        if not name:
            continue
        norm = normalize_name(name)
        eid = entity_id_map.get((norm, etype))
        if eid:
            local_map[norm] = eid

    for rel in rels:
        if not isinstance(rel, dict):
            continue
        subj = rel.get("subject")
        obj = rel.get("object")
        pred = rel.get("predicate") or "connected_to"

        # Endpoints must be non-empty strings to be resolvable (this also
        # avoids hashing errors on dict/list values).
        if not isinstance(subj, str) or not isinstance(obj, str):
            continue
        if not subj or not obj:
            continue

        eid_s = local_map.get(normalize_name(subj))
        eid_o = local_map.get(normalize_name(obj))
        # Keep only relations whose both ends are entities of this chunk.
        if not eid_s or not eid_o:
            continue

        relation_rows.append(
            {
                "source": "entity:" + str(eid_s),
                "target": "entity:" + str(eid_o),
                "relation": pred,
                "weight": 1,
            }
        )

# Explicit columns keep the edge schema intact when no relation was resolved.
explicit_rel_df = pd.DataFrame(
    relation_rows, columns=["source", "target", "relation", "weight"]
)
print("Explicit relations edges:", len(explicit_rel_df))
explicit_rel_df.head()

all_edges_df = pd.concat([edges_df, explicit_rel_df], ignore_index=True)
print("Total edges:", len(all_edges_df))
all_edges_df.head()


In [None]:
# Persist the entity table and the combined edge list as CSV (no index column)
for frame, csv_path in ((entities_df, "entities.csv"), (all_edges_df, "edges.csv")):
    frame.to_csv(csv_path, index=False)
print("Wrote entities.csv and edges.csv")