# 3.2 Learning our correlational matrices $A_w$

### Our correlational matrices $A_w = \{W_1, W_2, \ldots, W_K\}$ are a set of $K$ weighted adjacency matrices. Each $W_r \in [0,1]^{N \times N}$ above represents the relationships between any of the $N$ nodes and any of the other $N$ nodes for the specific relation $r$. Hence $|E_r|$ is the number of non-zero entries in $W_r$.
Note: Here we are just finding the 1-hop correlational links.

### Hence the total number of edges is $|E| = \sum_{r=1}^{K} |E_r|$

## 3.2.1 Document Parsing and Global Node identification

In [None]:
# Shared imports and environment setup (loads variables from .env) used throughout the notebook

from dotenv import load_dotenv 
load_dotenv()

import os
from pprint import pprint

import google.generativeai as genai

In [None]:
import google.generativeai as genai

# Authenticate the google.generativeai SDK with the key read from the environment.
# os.environ[...] raises KeyError if GEMINI_API_KEY is not set.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

# Print the name of every model whose supported_generation_methods
# includes 'generateContent'.
print("Available models for content generation:")
for m in genai.list_models():
    if 'generateContent' in m.supported_generation_methods:
        print(m.name)

#### Chunk the text appropriately (for now, just basic fixed-size character chunking)

In [None]:
if __name__ == "__main__":
    # Read the whole sample document into `document_text`. Later cells
    # (entity extraction, sentence splitting) read this variable directly.
    file_path = "testing_text/sample_text.txt"
    with open(file_path, "r", encoding="utf-8") as f:
        document_text = f.read()

In [None]:
import os
import json
from typing import List

# -------------------------
# Chunking logic (same as before)
# -------------------------
def chunk_text(text: str, size: int = 1800) -> List[str]:
    """
    Split ``text`` into consecutive, non-overlapping character chunks.

    Args:
        text: The string to split.
        size: Maximum number of characters per chunk; must be positive.

    Returns:
        The chunks in document order. Every chunk holds exactly ``size``
        characters except possibly the last one. An empty ``text`` yields ``[]``.

    Raises:
        ValueError: If ``size`` is not positive. A manual cursor loop never
            advances for such a size and would spin forever.
    """
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    # Slicing clips at the end of the string, so the final chunk may be shorter.
    return [text[start:start + size] for start in range(0, len(text), size)]


# -------------------------
# Save chunks into JSON + per-chunk text files
# -------------------------
def save_document_chunks(document_text: str, out_dir: str = "extracted_output/chunks", chunk_size: int = 1800):
    """
    Chunk a document and persist the chunks to disk.

    Writes one ``chunk_<index>.txt`` file per chunk plus a ``chunks.json``
    manifest listing every chunk, all inside ``out_dir`` (created if missing).

    Args:
        document_text: Full text of the document.
        out_dir: Directory that receives the chunk files and the manifest.
        chunk_size: Chunk size forwarded to ``chunk_text``.

    Returns:
        A list of ``{"chunk_index": int, "chunk_text": str}`` records, one per
        chunk, in document order (the same data written to ``chunks.json``).
    """
    os.makedirs(out_dir, exist_ok=True)

    pieces = chunk_text(document_text, size=chunk_size)

    # One plain-text file per chunk, named after the chunk's position.
    for position, piece in enumerate(pieces):
        piece_path = os.path.join(out_dir, f"chunk_{position}.txt")
        with open(piece_path, "w", encoding="utf-8") as txt_file:
            txt_file.write(piece)

    chunk_records = [
        {"chunk_index": position, "chunk_text": piece}
        for position, piece in enumerate(pieces)
    ]

    # Manifest with every chunk in a single JSON list.
    manifest_path = os.path.join(out_dir, "chunks.json")
    with open(manifest_path, "w", encoding="utf-8") as json_file:
        json.dump(chunk_records, json_file, indent=2, ensure_ascii=False)

    print(f"Saved {len(pieces)} chunks into {out_dir}")
    return chunk_records


# -------------------------
# Example usage
# -------------------------
if __name__ == "__main__":
    # Read the sample document from disk and write its chunks
    # (per-chunk .txt files plus chunks.json) using the default settings.
    file_path = "testing_text/sample_text.txt"
    with open(file_path, "r", encoding="utf-8") as f:
        document_text = f.read()

    save_document_chunks(document_text)



#### Now let's extract our necessary entities (do note the node format)

In [None]:
import os
import json
import time
from typing import List, Dict, Any, Optional, Tuple

from google.genai import Client, types
from google.genai import errors as genai_errors
from pydantic import BaseModel
from typing import Literal

# =========================
# CONFIG
# =========================
# Primary model name.
# NOTE(review): the safe_llm_structured defined in this cell iterates
# MODEL_CANDIDATES and does not read MODEL itself.
MODEL = "gemini-2.5-flash-lite"

# Fallback queue: safe_llm_structured tries these names in order.
MODEL_CANDIDATES = [
    "gemini-2.5-flash-lite",
    "gemini-2.5-flash",
    "gemini-2.0-flash-lite",
]

CHUNK_SIZE = 1800        # characters per chunk; default size used by chunk_text
DELAY_SECONDS = 15        # seconds slept before retrying after a rate-limit error

# Gemini client shared by this cell; os.environ[...] raises KeyError if
# GEMINI_API_KEY is not set.
client = Client(api_key=os.environ["GEMINI_API_KEY"])


# =========================
# Pydantic Schema for Entities
# =========================

# Schema for one extracted node, passed (wrapped in EntityList) as the
# response schema for entity extraction.
# NOTE(review): pydantic may export the class docstring as the schema
# description, so rewording it could change what the model sees — confirm first.
class EntityNode(BaseModel):
    """
    Single entity/event node with metadata.
    We let the LLM fill these fields; we will assign global IDs later.
    """
    # Short label for the node (required).
    name: str
    # Node category (required); restricted to the five literals below.
    type: Literal["event", "entity", "actor", "amount", "location"]
    # The three optional fields default to None when the model omits them.
    time: Optional[str] = None        # e.g. "July 15th"
    location: Optional[str] = None    # e.g. "Valeron banking district"
    description: Optional[str] = None # short gloss / summary


# =========================
# Safe LLM Wrapper
# =========================

def safe_llm_structured(
    prompt: str,
    schema,
    temperature: float = 0.2,
    max_retries: int = 3,
):
    """
    Request a schema-constrained JSON response from Gemini with retry and model fallback.

    Every name in MODEL_CANDIDATES is tried in order, with up to ``max_retries``
    calls per model:
      * a client error whose message mentions "quota" together with "exceeded"
        or "exhausted" abandons the current model and moves to the next one;
      * a client error mentioning "RESOURCE_EXHAUSTED" or "429" sleeps
        DELAY_SECONDS and retries the same model;
      * any other client error is re-raised immediately.

    Args:
        prompt: Prompt text sent as the request contents.
        schema: Response schema handed to the SDK (Pydantic / typing).
        temperature: Sampling temperature.
        max_retries: Maximum number of calls per model.

    Returns:
        The response object of the first successful call.

    Raises:
        genai_errors.ClientError: For a non-retryable client error, or the
            last recorded error once every model has been tried.
        RuntimeError: If no error was recorded and no call succeeded
            (e.g. no candidate models or ``max_retries`` below 1).
    """
    most_recent_error: Optional[Exception] = None

    for model_name in MODEL_CANDIDATES:
        print(f"\n[LLM] Trying model: {model_name}")

        request_config = types.GenerateContentConfig(
            temperature=temperature,
            response_mime_type="application/json",
            response_schema=schema,
        )

        for attempt in range(1, max_retries + 1):
            print(f"[LLM] Call attempt {attempt} with {model_name}")
            try:
                return client.models.generate_content(
                    model=model_name,
                    contents=prompt,
                    config=request_config,
                )
            except genai_errors.ClientError as err:
                most_recent_error = err
                text = str(err)
                lowered = text.lower()

                # Hard quota exhaustion: give up on this model entirely.
                quota_gone = "quota" in lowered and any(
                    word in lowered for word in ("exceeded", "exhausted")
                )
                if quota_gone:
                    print(f"[LLM] Daily quota exhausted for {model_name}. Switching to next model...")
                    break

                # Anything that is not a rate-limit signal is not retryable.
                if "RESOURCE_EXHAUSTED" not in text and "429" not in text:
                    raise

                print("[LLM] Rate limit / quota hit, backing off...")
                time.sleep(DELAY_SECONDS)

        print(f"[LLM] Model {model_name} failed after retries, moving to next fallback...")

    # Every model was tried without success: surface the last recorded error.
    if most_recent_error:
        raise most_recent_error
    raise RuntimeError("safe_llm_structured failed unexpectedly.")


# =========================
# Chunking
# =========================

def chunk_text(text: str, size: int = CHUNK_SIZE) -> List[str]:
    """
    Simple character-based chunking.
    (You can replace this later with a semantic chunker if you want.)

    Args:
        text: The string to split.
        size: Maximum number of characters per chunk; must be positive.

    Returns:
        Consecutive, non-overlapping chunks in document order. Only the last
        chunk may be shorter than ``size``; an empty ``text`` yields ``[]``.

    Raises:
        ValueError: If ``size`` is not positive. A manual cursor loop never
            advances for such a size and would spin forever.
    """
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    # Slicing clips at the end of the string, so the final chunk may be shorter.
    return [text[start:start + size] for start in range(0, len(text), size)]


# =========================
# Entity Extraction per Chunk (GraphRAG-style “graph extraction”)
# =========================

# Prompt for per-chunk entity extraction. It is filled with
# str.format(chunk=...), so literal JSON braces are doubled ({{ }}) and
# {chunk} is the only placeholder.
# NOTE(review): the prompt asks for a bare JSON array, while the response
# schema used with it (EntityList) wraps the list in an "items" field —
# confirm which shape the model actually returns.
ENTITY_PROMPT_TEMPLATE = """
You are an information extraction system building a knowledge graph.

From the following text chunk, extract all important EVENTS and ENTITIES as
structured JSON.

For each item, return an object with:
- "name": concise description of the event/entity (no longer than 15 words)
- "type": one of ["event", "entity", "actor", "amount", "location"]
- "time":  date/time phrase exactly as in the text, if any
- "location":  location phrase exactly as in the text, if any
- "description":  1–2 sentence summary of the role of this node

Return ONLY a JSON array, for example:

[
  {{
    "name": "FCU announces formal investigation",
    "type": "event",
    "time": "July 15th",
    "location": "Valeron",
    "description": "The Financial Crimes Unit officially launches its probe into Titan."
  }},
  ...
]

TEXT CHUNK:
{chunk}
"""


# Wrapper model because Gemini cannot return a top-level List[..].
# extract_entities_from_chunk passes this class as the response schema and
# reads the extracted nodes from `.items`.
class EntityList(BaseModel):
    items: List[EntityNode]


def extract_entities_from_chunk(chunk: str) -> List[EntityNode]:
    """
    Send one text chunk to Gemini and return the entities it extracted.

    The chunk is inserted into ENTITY_PROMPT_TEMPLATE and sent through
    safe_llm_structured with EntityList as the response schema. If the
    response carries a parsed EntityList, its items are returned directly.
    Otherwise the raw response text is decoded as JSON, accepting either the
    schema's wrapped form ``{"items": [...]}`` or a bare array as requested
    by the prompt.

    Args:
        chunk: Text of a single document chunk.

    Returns:
        The extracted nodes, or ``[]`` (after printing a warning) when the
        response cannot be decoded or validated.

    Raises:
        Whatever safe_llm_structured raises when the call itself fails.
    """
    prompt = ENTITY_PROMPT_TEMPLATE.format(chunk=chunk)

    resp = safe_llm_structured(
        prompt=prompt,
        schema=EntityList,
        temperature=0.2,
    )

    parsed = getattr(resp, "parsed", None)
    if isinstance(parsed, EntityList):
        return parsed.items

    # Fallback: no parsed EntityList on the response, so decode the raw text.
    try:
        raw = json.loads(resp.text)
        # The schema wraps the list in {"items": [...]}. Iterating that dict
        # directly would yield its keys, so unwrap it first.
        if isinstance(raw, dict):
            raw = raw.get("items", [])
        return [EntityNode(**obj) for obj in raw]
    except Exception:
        # Deliberately best-effort, but no longer a bare `except:` — that
        # would also swallow KeyboardInterrupt / SystemExit.
        print("[WARN] Empty/invalid entity response; returning [].")
        return []


# =========================
# High-level Pipeline
# =========================

def extract_entities_with_metadata(document_text: str) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Run the full entity-extraction pipeline over a document.

    Steps:
      1. Split the text with ``chunk_text``.
      2. Extract EntityNodes from every chunk via the LLM.
      3. Tag every entity with the index of the chunk it came from.
      4. Merge entities sharing the same (name, type, time, location),
         uniting their chunk indices and keeping the longest description.
      5. Number the merged entities N1, N2, ... in first-seen order.

    Args:
        document_text: Full text of the document.

    Returns:
        nodes: list of dicts with id, name, type, time, location,
            description, source_chunks (sorted chunk indices).
        raw_per_chunk: one dict per chunk for debugging/inspection:
            {"chunk_index", "chunk_text", "entities"}.
    """
    pieces = chunk_text(document_text)
    total = len(pieces)
    tagged_entities: List[Dict[str, Any]] = []
    per_chunk_log: List[Dict[str, Any]] = []

    print(f"Total chunks: {total}")

    for position, piece in enumerate(pieces):
        print(f"\n--- Processing chunk {position + 1}/{total} ---")
        found = extract_entities_from_chunk(piece)

        # Untouched per-chunk view, kept for debugging / analysis.
        per_chunk_log.append(
            {
                "chunk_index": position,
                "chunk_text": piece,
                "entities": [item.dict() for item in found],
            }
        )

        # Separate copy per entity, tagged with its originating chunk.
        for item in found:
            tagged = item.dict()
            tagged["_source_chunks"] = {position}  # a set, so merging is a union
            tagged_entities.append(tagged)

        time.sleep(0.5)  # very light additional pacing

    # ------------- Deduplication across chunks -------------
    by_identity: Dict[Tuple[str, str, Optional[str], Optional[str]], Dict[str, Any]] = {}

    for candidate in tagged_entities:
        identity = (
            candidate["name"],
            candidate["type"],
            candidate.get("time"),
            candidate.get("location"),
        )

        kept = by_identity.get(identity)
        if kept is None:
            by_identity[identity] = candidate
            continue

        # Duplicate: unite provenance and prefer the longer description.
        kept["_source_chunks"] |= candidate["_source_chunks"]
        current_desc = kept.get("description") or ""
        incoming_desc = candidate.get("description") or ""
        if len(incoming_desc) > len(current_desc):
            kept["description"] = incoming_desc

    # Global IDs follow first-seen order: N1, N2, ...
    nodes: List[Dict[str, Any]] = [
        {
            "id": f"N{number}",
            "name": entity["name"],
            "type": entity["type"],
            "time": entity.get("time"),
            "location": entity.get("location"),
            "description": entity.get("description"),
            # chunk-index set becomes a sorted list so it is JSON-serializable
            "source_chunks": sorted(entity["_source_chunks"]),
        }
        for number, entity in enumerate(by_identity.values(), start=1)
    ]

    print(f"\nExtracted {len(nodes)} unique entity/event nodes.")
    return nodes, per_chunk_log


# =========================
# Example Usage
# =========================

if __name__ == "__main__":
    # 1) This cell expects `document_text` to have been defined by an earlier
    #    cell, e.g. loaded from a file:
    #
    # with open("my_long_doc.txt", "r", encoding="utf-8") as f:
    #     document_text = f.read()
    #
    # `dedent` is not used in this cell; the import is kept in case other
    # cells rely on the name being defined.
    from textwrap import dedent

    # 2) Run extraction
    nodes, per_chunk = extract_entities_with_metadata(document_text)

    # 3) Save results (GraphRAG-style “graph docs” but only nodes for now)
    os.makedirs("extracted_output", exist_ok=True)

    with open("extracted_output/entities.json", "w", encoding="utf-8") as f:
        json.dump(nodes, f, indent=2, ensure_ascii=False)

    with open("extracted_output/entities_per_chunk_debug.json", "w", encoding="utf-8") as f:
        json.dump(per_chunk, f, indent=2, ensure_ascii=False)

    # Report the paths exactly as written above (the second message
    # previously misspelled the file name as "entities_par_chunk_debug.json").
    print("\nSaved:")
    print("  extracted_output/entities.json")
    print("  extracted_output/entities_per_chunk_debug.json")


In [None]:
## Load the saved entities

import json
import os

def load_extracted_entities(base_dir="extracted_output"):
    """
    Read the saved entity-extraction results from ``base_dir``.

    Files read:
      - entities.json                 → list of deduplicated entity/event nodes
      - entities_per_chunk_debug.json → raw per-chunk extraction data

    Both files must exist before either one is opened.

    Returns:
        (entities, per_chunk) as decoded from the two JSON files.

    Raises:
        FileNotFoundError: If either file is missing (entities.json is
            checked first).
    """
    file_names = ("entities.json", "entities_per_chunk_debug.json")
    paths = [os.path.join(base_dir, file_name) for file_name in file_names]

    # Check everything up front so nothing is opened when a file is absent.
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing file: {path}")

    loaded = []
    for path in paths:
        with open(path, "r", encoding="utf-8") as handle:
            loaded.append(json.load(handle))

    print("Loaded:")
    for file_name in file_names:
        print(f" - {file_name}")

    entities, per_chunk = loaded
    return entities, per_chunk


# Example usage:
if __name__ == "__main__":
    # Reload the saved extraction results into `nodes` / `nodes_per_chunk`
    # and report how many records each file contained.
    nodes, nodes_per_chunk = load_extracted_entities()
    print(f"Total nodes loaded: {len(nodes)}")
    print(f"Total chunks loaded: {len(nodes_per_chunk)}")


#### Now let's extract the necessary relations (do note the format) between the node entities extracted above

In [None]:
import os
import json
import time
from typing import List, Dict, Any, Optional, Tuple

from google.genai import Client, types
from google.genai import errors as genai_errors
from pydantic import BaseModel, Field

# =========================
# CONFIG
# =========================
# List of models to try, in order. _switch_to_next_model walks this list
# (wrapping around) to pick a replacement.
MODEL_CANDIDATES = [
    "gemini-2.5-flash-lite",
    "gemini-2.5-flash",
    "gemini-2.0-flash-lite",
]

# Mutable module state for model rotation:
#   MODEL               - the model the relations wrapper currently calls
#   CURRENT_MODEL_INDEX - position of MODEL inside MODEL_CANDIDATES
#   EXHAUSTED_MODELS    - names already given up on; never selected again
MODEL = MODEL_CANDIDATES[0]
CURRENT_MODEL_INDEX = 0
EXHAUSTED_MODELS = set()

DELAY_SECONDS = 15  # seconds slept after switching models on a quota / rate-limit error

# Gemini client for this cell; os.environ[...] raises KeyError if
# GEMINI_API_KEY is not set.
client = Client(api_key=os.environ["GEMINI_API_KEY"])


# =========================
# Pydantic Schemas for Relations
# =========================

# Schema for one directed edge, used (wrapped in RelationList) as the response
# schema for relation extraction.
# NOTE(review): the docstring and the Field descriptions below are presumably
# part of the schema the model sees — confirm before rewording them.
class RelationEdge(BaseModel):
    """
    Single relation between two graph nodes (by node ID).
    """
    # Required: ID of the node the relation starts from.
    source_id: str = Field(
        description="ID of the source node (e.g., 'N3'). Must come from the provided node list."
    )
    # Required: ID of the node the relation points to.
    target_id: str = Field(
        description="ID of the target node (e.g., 'N7'). Must come from the provided node list."
    )
    # Required: short relation label; it is part of the edge dedup key downstream.
    relation: str = Field(
        description=(
            "A short relation phrase (2-4 words max) describing the semantic relation "
            "between source and target (e.g., 'leads to', 'investigates', 'causes', "
            "'associated with'). Do NOT output a full sentence."
        )
    )
    # Optional; None when omitted.
    description: Optional[str] = Field(
        default=None,
        description="One short sentence explaining the relation in natural language."
    )
    # Optional; None when omitted.
    evidence: Optional[str] = Field(
        default=None,
        description="A sentence or short snippet copied from the text that supports this relation."
    )
    # Required. The 0.0-1.0 range is stated only in the description; no
    # numeric bounds are declared on the field itself.
    confidence: float = Field(
        description="A float between 0.0 and 1.0 indicating confidence in the correctness of this relation."
    )


# Response-schema wrapper: extract_relations_for_chunk passes this class as
# the schema and reads the edges from `.relations`.
class RelationList(BaseModel):
    """
    Wrapper for a list of relations (Gemini can't return top-level bare list).
    """
    relations: List[RelationEdge] = Field(
        description="List of directional relations between the provided nodes."
    )


# =========================
# Helper for model switching
# =========================

def _switch_to_next_model():
    """
    Retire the current MODEL and activate the next usable candidate.

    The current MODEL is added to EXHAUSTED_MODELS. MODEL_CANDIDATES is then
    scanned circularly, starting just after CURRENT_MODEL_INDEX, and the first
    name not yet exhausted becomes the new MODEL (CURRENT_MODEL_INDEX is
    updated to match).

    Raises:
        RuntimeError: If every candidate is already exhausted.
    """
    global MODEL, CURRENT_MODEL_INDEX

    EXHAUSTED_MODELS.add(MODEL)

    pool_size = len(MODEL_CANDIDATES)
    first_probe = CURRENT_MODEL_INDEX + 1
    for probe in range(first_probe, first_probe + pool_size):
        slot = probe % pool_size  # wrap around the end of the list
        name = MODEL_CANDIDATES[slot]
        if name in EXHAUSTED_MODELS:
            continue
        MODEL = name
        CURRENT_MODEL_INDEX = slot
        print(f"[LLM] (relations) Switching to backup model: {MODEL}")
        return

    raise RuntimeError("All configured models appear exhausted/quota-limited for today (relations).")


# =========================
# Safe LLM Wrapper (Structured)
# =========================

def safe_llm_structured(
    prompt: str,
    schema,
    temperature: float = 0.1,
    max_retries: int = 3,
):
    """
    Request a schema-constrained JSON response from Gemini, rotating models on quota errors.

    Up to ``max_retries`` calls are made. Each call uses whatever the
    module-level MODEL is at that moment. On a client error mentioning
    "RESOURCE_EXHAUSTED", "429" or "exceeded your current quota", the current
    model is retired via ``_switch_to_next_model``, the function sleeps
    DELAY_SECONDS and the next attempt runs with the replacement model. Any
    other client error is re-raised immediately.

    Args:
        prompt: Prompt text sent as the request contents.
        schema: Response schema handed to the SDK.
        temperature: Sampling temperature.
        max_retries: Total number of calls allowed.

    Returns:
        The response object of the first successful call.

    Raises:
        genai_errors.ClientError: For a non-quota client error, or the last
            quota error once the attempts are used up.
        RuntimeError: If no backup model is left, or if no call was made and
            no error recorded (e.g. ``max_retries`` below 1).
    """
    most_recent_error: Optional[Exception] = None

    for attempt in range(1, max_retries + 1):
        active_model = MODEL  # snapshot: MODEL may be switched below
        request_config = types.GenerateContentConfig(
            temperature=temperature,
            response_mime_type="application/json",
            response_schema=schema,
        )

        print(f"[LLM] (relations) Call attempt {attempt} with {active_model}")
        try:
            return client.models.generate_content(
                model=active_model,
                contents=prompt,
                config=request_config,
            )
        except genai_errors.ClientError as err:
            text = str(err)
            throttled = any(
                (
                    "RESOURCE_EXHAUSTED" in text,
                    "429" in text,
                    "exceeded your current quota" in text.lower(),
                )
            )
            if not throttled:
                raise

            print(f"[LLM] (relations) Model {active_model} hit quota / rate limit: {text}")
            most_recent_error = err
            try:
                _switch_to_next_model()
            except RuntimeError:
                print("[LLM] (relations) No backup models left.")
                raise
            time.sleep(DELAY_SECONDS)

    if most_recent_error:
        raise most_recent_error
    raise RuntimeError("safe_llm_structured (relations) failed unexpectedly.")


# =========================
# Relation Extraction Prompt
# =========================

# Prompt for per-chunk relation extraction. It is filled with
# str.format(chunk_text=..., nodes_json=...), so literal JSON braces are
# doubled ({{ }}) and {chunk_text} / {nodes_json} are the only placeholders.
RELATION_PROMPT_TEMPLATE = """
You are building a knowledge graph from a document.

You are given:
1) A TEXT CHUNK from the document.
2) A list of NODES that appear in this chunk. Each node has:
   - "id": global node ID like "N3"
   - "name": short label
   - "type": one of ["event", "entity", "actor", "amount", "location"]
   - "time": optional date/time
   - "location": optional location
   - "description": optional summary

Your task: Extract all meaningful, **directional** relations between these nodes
based ONLY on the given TEXT CHUNK.

Rules:
- Only use node IDs that appear in the provided NODES list.
- Each relation is directional: (source_id -> target_id).
- The "relation" field must be a **short phrase (2–4 words)**, NOT a full sentence.
  Examples: "leads to", "causes", "results in", "investigates", "freezes assets of",
  "occurs after", "associated with".
- "description": one short sentence in natural language explaining the relation.
- "evidence": copy a sentence or short snippet from the TEXT CHUNK that justifies this relation.
- If you are NOT confident about any relation between a pair, do not invent one.
- It is allowed to return an empty list if no strong relations are present.
- Include a 'confidence' score between 0.0 and 1.0. 
  Use ≥0.9 only when the relation is explicitly and unambiguously stated.


Return ONLY a JSON object in this exact form:
{{
  "relations": [
    {{
      "source_id": "N3",
      "target_id": "N7",
      "relation": "leads to",
      "description": "FCU investigation follows the spike in black market sales.",
      "evidence": "On July 15th, the Financial Crimes Unit (FCU) announced a formal investigation...",
      "confidence": 0.92
    }},
    ...
  ]
}}

TEXT CHUNK:
{chunk_text}

NODES IN THIS CHUNK:
{nodes_json}
"""


# =========================
# Utilities to Map Entities -> Node IDs
# =========================

def build_entity_key(ent: Dict[str, Any]) -> Tuple[str, str, Optional[str], Optional[str]]:
    """
    Derive the identity tuple of an entity dict.

    The tuple is (name, type, time, location), the same fields used as the
    dedup key when entities.json is built. "name" and "type" are required
    (KeyError if absent); "time" and "location" fall back to None.
    """
    name = ent["name"]
    kind = ent["type"]
    when = ent.get("time")
    where = ent.get("location")
    return name, kind, when, where


def build_node_key_index(nodes: List[Dict[str, Any]]) -> Dict[Tuple[str, str, Optional[str], Optional[str]], str]:
    """
    Index the global node list (entities.json) by identity tuple.

    Returns a mapping
        (name, type, time, location) -> node_id ("N1", "N2", ...)
    When several nodes share one identity tuple, the ID of the last one wins.
    """
    return {
        (
            entry["name"],
            entry["type"],
            entry.get("time"),
            entry.get("location"),
        ): entry["id"]
        for entry in nodes
    }


# =========================
# Per-chunk Relation Extraction
# =========================

def extract_relations_for_chunk(
    chunk_record: Dict[str, Any],
    node_key_index: Dict[Tuple[str, str, Optional[str], Optional[str]], str],
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Extract directed relations between the nodes that appear in one chunk.

    Steps:
    - Map the chunk's entities to global node IDs through ``node_key_index``.
      Entities without a global node are skipped, and each node ID is listed
      only once.
    - If fewer than two distinct nodes remain, return without calling the LLM.
    - Otherwise ask Gemini for relations between those nodes.
    - Discard any returned relation whose source or target ID is not one of
      the nodes that were provided (the prompt forbids such IDs, but the model
      is not guaranteed to comply).

    Args:
        chunk_record: Dict with "chunk_index", "chunk_text" and "entities",
            as stored in entities_per_chunk_debug.json.
        node_key_index: Mapping (name, type, time, location) -> node ID.

    Returns:
        relations_flat: List[dict] with source_id, target_id, relation,
            description, evidence, confidence, _chunk_index.
        debug_record: Dict with chunk_index, chunk_text, nodes, relations
            (for the debug JSON).
    """
    chunk_index = chunk_record["chunk_index"]
    chunk_text = chunk_record["chunk_text"]
    entities_in_chunk = chunk_record["entities"]

    # Map chunk entities -> node IDs using key (name, type, time, location).
    chunk_nodes: List[Dict[str, Any]] = []
    known_ids = set()
    for ent in entities_in_chunk:
        node_id = node_key_index.get(build_entity_key(ent))
        if not node_id:
            # Best-effort: skip entities that didn't make it into the final node set
            continue
        if node_id in known_ids:
            # The same entity extracted twice in this chunk: listing it again
            # would make a single node look like a pair.
            continue
        known_ids.add(node_id)
        chunk_nodes.append(
            {
                "id": node_id,
                "name": ent["name"],
                "type": ent["type"],
                "time": ent.get("time"),
                "location": ent.get("location"),
                "description": ent.get("description"),
            }
        )

    debug_record: Dict[str, Any] = {
        "chunk_index": chunk_index,
        "chunk_text": chunk_text,
        "nodes": chunk_nodes,
        "relations": [],
    }

    # A relation needs two distinct endpoints.
    if len(chunk_nodes) < 2:
        return [], debug_record

    prompt = RELATION_PROMPT_TEMPLATE.format(
        chunk_text=chunk_text,
        nodes_json=json.dumps(chunk_nodes, indent=2, ensure_ascii=False),
    )

    resp = safe_llm_structured(
        prompt=prompt,
        schema=RelationList,
        temperature=0.1,
    )

    parsed = getattr(resp, "parsed", None)
    relations: List[RelationEdge] = []

    if isinstance(parsed, RelationList):
        relations = parsed.relations
    else:
        # Fallback: try raw JSON
        try:
            raw = json.loads(resp.text)
            raw_rels = raw.get("relations", [])
            relations = [RelationEdge(**r) for r in raw_rels]
        except Exception:
            print(f"[WARN] Invalid relation response for chunk {chunk_index}; treating as empty.")
            relations = []

    # Convert to plain dicts for serialization, tagging with chunk index.
    relations_flat: List[Dict[str, Any]] = []
    for r in relations:
        if r.source_id not in known_ids or r.target_id not in known_ids:
            # An edge pointing at an ID outside the provided list cannot be
            # placed in the graph, so drop it instead of passing it downstream.
            print(
                f"[WARN] Dropping relation {r.source_id} -> {r.target_id} in chunk "
                f"{chunk_index}: node ID not in the provided node list."
            )
            continue
        d = r.dict()
        d["_chunk_index"] = chunk_index
        relations_flat.append(d)

    debug_record["relations"] = relations_flat
    return relations_flat, debug_record


# =========================
# High-level Pipeline
# =========================

def extract_relations_from_entities_and_chunks(
    nodes: List[Dict[str, Any]],
    per_chunk: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Run relation extraction over every chunk and merge the results.

    Inputs:
      - nodes:       global entity/event nodes (id, name, type, time, location, ...)
      - per_chunk:   list of {"chunk_index", "chunk_text", "entities": [...]}
                     as produced by entities_per_chunk_debug.json

    Relations from all chunks are merged by (source_id, target_id, relation).
    A merged edge keeps the first non-empty description and evidence seen,
    the union of the chunk indices it came from, and the mean of the
    confidences of its occurrences.

    Returns:
      - edges: list of deduplicated relations across the document, each:
            {
              "source_id": "N3",
              "target_id": "N7",
              "relation": "leads to",
              "description": "...",
              "evidence": "...",
              "confidence": 0.91,
              "source_chunks": [0, 2, ...]
            }
      - per_chunk_relations: list of debug records for each chunk:
            {
              "chunk_index": ...,
              "chunk_text": ...,
              "nodes": [...],
              "relations": [...]
            }
    """
    key_to_node_id = build_node_key_index(nodes)
    total_chunks = len(per_chunk)

    collected: List[Dict[str, Any]] = []
    debug_records: List[Dict[str, Any]] = []

    print(f"Total chunks to process for relations: {total_chunks}")

    for position, record in enumerate(per_chunk, start=1):
        print(f"\n--- Extracting relations for chunk {position}/{total_chunks} ---")

        chunk_relations, chunk_debug = extract_relations_for_chunk(record, key_to_node_id)
        collected += chunk_relations
        debug_records.append(chunk_debug)

        time.sleep(0.5)  # light pacing

    # -------- Deduplicate relations across document ----------
    accumulators: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

    for rel in collected:
        triple = (rel["source_id"], rel["target_id"], rel["relation"])
        origin = rel.get("_chunk_index")
        score = float(rel.get("confidence", 0.0))

        bucket = accumulators.get(triple)
        if bucket is None:
            # First occurrence of this edge: start a new accumulator.
            accumulators[triple] = {
                "source_id": rel["source_id"],
                "target_id": rel["target_id"],
                "relation": rel["relation"],
                "description": rel.get("description"),
                "evidence": rel.get("evidence"),
                "confidence_sum": score,
                "count": 1,
                "source_chunks": set() if origin is None else {origin},
            }
            continue

        # Repeat occurrence: accumulate confidence and provenance.
        bucket["confidence_sum"] += score
        bucket["count"] += 1
        if origin is not None:
            bucket["source_chunks"].add(origin)
        # Fill description / evidence only if nothing non-empty is stored yet.
        for field in ("description", "evidence"):
            if not bucket.get(field) and rel.get(field):
                bucket[field] = rel[field]

    edges: List[Dict[str, Any]] = [
        {
            "source_id": source,
            "target_id": target,
            "relation": label,
            "description": bucket.get("description"),
            "evidence": bucket.get("evidence"),
            # mean confidence over all occurrences of the edge
            "confidence": (
                bucket["confidence_sum"] / bucket["count"]
                if bucket["count"] > 0 else 0.0
            ),
            "source_chunks": sorted(bucket["source_chunks"]),
        }
        for (source, target, label), bucket in accumulators.items()
    ]

    print(f"\nExtracted {len(edges)} unique relations across all chunks.")
    return edges, debug_records


# =========================
# Example Usage
# =========================

if __name__ == "__main__":
    # The string below is a bare expression statement used as a note, not a docstring.
    """
    Assumes you already ran the entity-extraction script and have:
      extracted_output/entities.json
      extracted_output/entities_per_chunk_debug.json
    """

    os.makedirs("extracted_output", exist_ok=True)

    # 1) Load entities + per-chunk debug from previous step
    #    (open() raises FileNotFoundError if either file is absent).
    with open("extracted_output/entities.json", "r", encoding="utf-8") as f:
        nodes = json.load(f)

    with open("extracted_output/entities_per_chunk_debug.json", "r", encoding="utf-8") as f:
        per_chunk = json.load(f)

    # 2) Run relation extraction (calls the LLM for chunks with at least two nodes)
    relations, relations_per_chunk = extract_relations_from_entities_and_chunks(
        nodes=nodes,
        per_chunk=per_chunk,
    )

    # 3) Save outputs: merged edges plus the per-chunk debug records
    with open("extracted_output/relations.json", "w", encoding="utf-8") as f:
        json.dump(relations, f, indent=2, ensure_ascii=False)

    with open("extracted_output/relations_per_chunk_debug.json", "w", encoding="utf-8") as f:
        json.dump(relations_per_chunk, f, indent=2, ensure_ascii=False)

    print("\nSaved:")
    print("  extracted_output/relations.json")
    print("  extracted_output/relations_per_chunk_debug.json")


In [None]:
## Load the saved relations

import os
import json

def load_extracted_relations(base_dir="extracted_output"):
    """
    Read the saved relation-extraction results from ``base_dir``.

    Files read:
      - relations.json
      - relations_per_chunk_debug.json

    Both files must exist before either one is opened.

    Returns:
      relations: list of all deduplicated relation edges
      per_chunk_relations: list of per-chunk debug relation records

    Raises:
      FileNotFoundError: If either file is missing (relations.json is
          checked first).
    """
    file_names = ("relations.json", "relations_per_chunk_debug.json")
    paths = [os.path.join(base_dir, file_name) for file_name in file_names]

    # Check everything up front so nothing is opened when a file is absent.
    for path in paths:
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing file: {path}")

    loaded = []
    for path in paths:
        with open(path, "r", encoding="utf-8") as handle:
            loaded.append(json.load(handle))

    print("Loaded:")
    for file_name in file_names:
        print(f" - {file_name}")

    relations, per_chunk_relations = loaded
    return relations, per_chunk_relations


# Example usage
if __name__ == "__main__":
    # Reload the saved relation results and report how many records each
    # file contained.
    relations, relations_per_chunk = load_extracted_relations()

    print(f"Total relations loaded: {len(relations)}")
    print(f"Total chunks returned: {len(relations_per_chunk)}")


#### Extract our sentences

In [None]:
from wtpsplit import SaT

# --- 2. Load the SaT Model ---
# Use 'sat-3l-sm' for a good balance of quality and fast inference.
# The model will be downloaded automatically the first time this runs.
print("Loading SaT model...")
try:
    # Use the small/medium model for general sentence segmentation tasks
    sat = SaT("sat-3l-sm")
except Exception as e:
    print(f"Error loading WtpSplit model: {e}")
    # Stop here: without a model the split below cannot run. Continuing would
    # either fail with an unrelated NameError or silently reuse a `sat` left
    # over from an earlier run of this cell.
    raise

# --- 3. Split the text ---
# `document_text` must already be defined by an earlier cell.
# The result is stored in `sentence_list` and consumed by the following cells.
sentence_list = sat.split(document_text)

print("\n--- Split Sentences (SaT) ---")
for i, sent in enumerate(sentence_list):
    # Strip whitespace/newlines that the model might leave at the start/end of sentences
    print(f"{i+1}: {sent.strip()}")

print("\nFinal Output Type:", type(sentence_list))

In [None]:
# Trim every sentence and drop the ones that are empty after trimming.
sentence_list = [trimmed for trimmed in (raw.strip() for raw in sentence_list) if trimmed]

In [None]:
# Print the sentences again with 1-based numbering, for visual inspection.
print("\n--- Split Sentences (SaT) ---")
for i, sent in enumerate(sentence_list):
    # Strip whitespace/newlines that the model might leave at the start/end of sentences
    print(f"{i+1}: {sent.strip()}")

In [None]:
import os
import json
# Persist `sentence_list` (populated earlier in the notebook) as a JSON file.

# --- CODE TO SAVE TO FILE ---

# Target location: <output_directory>/<json_filename>
output_directory = "extracted_output"
json_filename = "segmented_sentences.json"
file_path = os.path.join(output_directory, json_filename)

# Make sure the output directory exists. A failure is only reported here;
# it will show up again as a save error when the file is opened below.
try:
    os.makedirs(output_directory, exist_ok=True)
except Exception as e:
    print(f"❌ Error creating directory: {e}")

# Payload: a single "sentences" key holding the sentence list.
data_to_save = {"sentences": sentence_list}

# Write the payload as indented JSON; any failure is reported, not raised.
try:
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(data_to_save, f, indent=2)
except Exception as e:
    print(f"\n❌ Error saving JSON file: {e}")
else:
    print(f"\n✅ Success: Sentences saved to {file_path} (JSON Format)")

# --- END SAVE CODE ---

In [None]:
# Read the saved sentences back from extracted_output/segmented_sentences.json.
sentence_filename = "segmented_sentences.json"
sentence_file_path = os.path.join("extracted_output", sentence_filename)
retrieved_sentences = None  # stays None if opening or parsing the file raises

print(f"\nAttempting to load sentences from: {sentence_file_path}")

with open(sentence_file_path, mode='r', encoding='utf-8') as handle:
    retrieved_sentences = json.load(handle)

print(f"✅ Success: Sentences loaded from {sentence_file_path}")

# The list lives under the "sentences" key; fall back to an empty list if absent.
# NOTE: this is `sentences_list` (plural), a different name from `sentence_list`.
sentences_list = retrieved_sentences.get("sentences", [])

In [None]:
import numpy as np
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer


# -------------------------------------------------------
# 2. Assemble phrases & combined weighted representation
# -------------------------------------------------------
# Each entry of `relations` is read through its "relation" key and its
# optional "description" key.
if not relations:
    # np.vstack below would fail on an empty list with a far less helpful message.
    raise ValueError("`relations` is empty; there is nothing to cluster.")

relation_texts = []    # "<phrase> ; <description>" strings (not used further in this cell)
relation_keys = []     # to preserve 1–1 mapping

for rel in relations:
    relation_phrase = rel["relation"]
    desc = rel.get("description", "")

    # store the original phrase for lookup
    relation_keys.append(relation_phrase)

    # combined semantic signal
    combined = f"{relation_phrase} ; {desc}"
    relation_texts.append(combined)

unique_phrases = sorted(list(set(relation_keys)))
print(f"Unique relation labels: {len(unique_phrases)}")


# -------------------------------------------------------
# 3. Build combined embeddings
# -------------------------------------------------------
embedder = SentenceTransformer("all-MiniLM-L6-v2")

print("\nEncoding relation phrases and descriptions...")
relation_embeddings = []

for rel in relations:
    rel_vec = embedder.encode(rel["relation"])
    desc_vec = embedder.encode(rel.get("description", ""))

    # The phrase dominates; the description contributes with weight 0.3.
    combined_vec = rel_vec + 0.3 * desc_vec
    relation_embeddings.append(combined_vec)

# One row per relation instance (not per unique phrase).
relation_embeddings = np.vstack(relation_embeddings)
print("Embedding shape:", relation_embeddings.shape)


# -------------------------------------------------------
# 4. Choose K
# -------------------------------------------------------
# K is capped at K_MAX and never exceeds the number of distinct phrases.
num_unique = len(unique_phrases)
K_MAX = 10
K = min(K_MAX, max(1, num_unique))

print(f"\nK selected = {K}")


# -------------------------------------------------------
# 5. Cluster
# -------------------------------------------------------
kmeans = KMeans(n_clusters=K, random_state=42, n_init=10)
cluster_ids = kmeans.fit_predict(relation_embeddings)

# -------------------------------------------------------
# 6. Build ClusterMap (relation → cluster_id)
# -------------------------------------------------------
# Clustering runs per relation *instance*, so one phrase can land in several
# clusters when its descriptions differ. Instead of letting the last instance
# win, count the votes per phrase and keep the most frequent cluster
# (ties go to the cluster seen first for that phrase).
cluster_votes = {}
for rel, cid in zip(relations, cluster_ids):
    label = int(cid) + 1     # convert 0-index → 1-index
    tally = cluster_votes.setdefault(rel["relation"], {})
    tally[label] = tally.get(label, 0) + 1

ClusterMap = {
    phrase: max(tally, key=tally.get)
    for phrase, tally in cluster_votes.items()
}


# -------------------------------------------------------
# 7. Display clusters grouped
# -------------------------------------------------------
clusters_grouped = {}
for phrase, c_id in ClusterMap.items():
    clusters_grouped.setdefault(c_id, []).append(phrase)

print("\n--- Final Cluster Map ---")
for cid, items in sorted(clusters_grouped.items()):
    print(f"\nCluster {cid}: (n={len(items)})")
    for phrase in sorted(set(items)):
        print("  -", phrase)

pprint(ClusterMap)


In [None]:
# Inspect the relation phrase → 1-based cluster id mapping.
pprint(ClusterMap)

In [None]:
# Dump `nodes` for inspection before the index mapping is built from it.
pprint(nodes)

#### Getting our $W_k$ matrices

In [None]:
import numpy as np
from scipy.sparse import coo_matrix
from typing import List, Dict, Any, Tuple

# -----------------------------------------------------------
# 1. Build a mapping from node_id ("N1") → integer index
# -----------------------------------------------------------

node_to_idx: Dict[str, int] = {}

# nodes is your loaded entities.json list
for node in nodes:
    node_id_str = node["id"]          # e.g. "N10"
    numeric_id = int(node_id_str[1:]) # remove "N"
    node_to_idx[node_id_str] = numeric_id

NUM_NODES = len(nodes)

# Size the matrices by the largest index actually in use, not only by the
# node count: if the numeric ids have gaps (e.g. N1, N2, N7), an index can
# exceed len(nodes) and would fall outside a (NUM_NODES + 1)-wide matrix.
# Taking the max with NUM_NODES keeps the previous size whenever it sufficed.
max_node_index = max(node_to_idx.values(), default=0)
MATRIX_DIM = max(NUM_NODES, max_node_index) + 1   # keeping index 0 unused

print(f"Total nodes: {NUM_NODES}")
print(f"Matrix dimension: {MATRIX_DIM} x {MATRIX_DIM}")

In [None]:
# Inspect the node id string → integer index mapping.
pprint(node_to_idx)

In [None]:


# -----------------------------------------------------------
# 2. Number of clusters (relation categories)
# -----------------------------------------------------------

# ClusterMap values are cluster ids, so the largest one is K. An empty
# ClusterMap yields K = 0 (no matrices) instead of a ValueError from max().
K_CLUSTER = max(ClusterMap.values()) if ClusterMap else 0
print(f"Total relation clusters (K): {K_CLUSTER}")

# -----------------------------------------------------------
# 3. Prepare COO builders for each W_k
# -----------------------------------------------------------

# One (data, row, col) triple of lists per cluster; list position k holds W_{k+1}.
W_data = [[] for _ in range(K_CLUSTER)]
W_row  = [[] for _ in range(K_CLUSTER)]
W_col  = [[] for _ in range(K_CLUSTER)]

print("\nPopulating sparse matrix builders…")

# -----------------------------------------------------------
# 4. Iterate over all relations
# -----------------------------------------------------------

for rel in relations:

    relation_phrase = rel["relation"]       # "causes", "targets", etc.
    # A missing confidence is treated as weight 0.0 rather than aborting the pass.
    confidence      = float(rel.get("confidence", 0.0))

    # find cluster ID for this relation phrase
    cluster_id = ClusterMap.get(relation_phrase)
    if cluster_id is None:
        print(f"[WARN] relation phrase '{relation_phrase}' not in ClusterMap. Skipping.")
        continue

    matrix_idx = cluster_id - 1  # convert to 0-indexed

    # subject and object node ids; skip edges that reference unknown nodes
    # instead of aborting the whole pass with a KeyError
    src_id = rel["source_id"]
    tgt_id = rel["target_id"]
    if src_id not in node_to_idx or tgt_id not in node_to_idx:
        print(f"[WARN] node id(s) {src_id}, {tgt_id} not in node_to_idx. Skipping edge.")
        continue

    row_idx = node_to_idx[src_id]
    col_idx = node_to_idx[tgt_id]

    # append to the cluster builder
    W_data[matrix_idx].append(confidence)
    W_row[matrix_idx].append(row_idx)
    W_col[matrix_idx].append(col_idx)

# -----------------------------------------------------------
# 5. Build the final sparse COO matrices W₁…Wₖ
# -----------------------------------------------------------

W_k_matrices: List[coo_matrix] = []

print("\nConstructing W_k sparse matrices…")

for k in range(K_CLUSTER):
    # Repeated (row, col) pairs are kept as separate COO entries here, so nnz
    # and the weight sum below count every relation instance individually.
    W_k = coo_matrix(
        (W_data[k], (W_row[k], W_col[k])),
        shape=(MATRIX_DIM, MATRIX_DIM)
    )
    W_k_matrices.append(W_k)

    print(f"W_{k+1}: shape={W_k.shape}, nnz={W_k.nnz}, "
          f"sum(weights)={W_k.data.sum():.3f}")

print(f"\nSuccessfully created {len(W_k_matrices)} sparse cluster matrices.")
##

In [None]:
## Save the above W_k sparse matrices

import os
from scipy.sparse import save_npz

def save_sparse_matrices(W_k_matrices, directory="extracted_output/sparse_W"):
    """
    Write every matrix in ``W_k_matrices`` to ``directory`` as ``W_<i>.npz``.

    Files are numbered from 1 in list order. The directory is created when
    missing; other files already in it are not removed.
    """
    os.makedirs(directory, exist_ok=True)

    position = 0
    for matrix in W_k_matrices:
        position += 1
        target = os.path.join(directory, f"W_{position}.npz")
        save_npz(target, matrix)
        print(f"[SAVE] Saved {target} (shape={matrix.shape}, nnz={matrix.nnz})")

    total = len(W_k_matrices)
    print(f"\nSaved {total} sparse matrices to '{directory}'.")

# Write the W_k matrices to the default directory (extracted_output/sparse_W).
save_sparse_matrices(W_k_matrices)



In [None]:
## Load the above saved matrices

import os
from scipy.sparse import load_npz

def load_sparse_matrices(directory="extracted_output/sparse_W"):
    """
    Load every sparse matrix stored as ``W_<number>.npz`` in ``directory``.

    Only file names that match ``W_<digits>.npz`` exactly are considered, so
    unrelated files such as ``W_backup.npz`` or ``W_1_old.npz`` are ignored
    instead of crashing the index parse or being loaded as extra matrices.

    Args:
        directory: Folder to scan for the saved matrices.

    Returns:
        A list of COO matrices ordered by their numeric file index.

    Raises:
        FileNotFoundError: If ``directory`` does not exist (from ``os.listdir``).
    """
    import re  # local import so the function does not depend on other cells

    name_pattern = re.compile(r"W_(\d+)\.npz")

    # Collect (index, filename) pairs and sort numerically, so W_10 follows W_9.
    indexed_files = []
    for entry in os.listdir(directory):
        match = name_pattern.fullmatch(entry)
        if match:
            indexed_files.append((int(match.group(1)), entry))
    indexed_files.sort()

    matrices = []
    for _, filename in indexed_files:
        full_path = os.path.join(directory, filename)
        W_k = load_npz(full_path).tocoo()  # ensure COO format
        matrices.append(W_k)
        print(f"[LOAD] Loaded {full_path} (shape={W_k.shape}, nnz={W_k.nnz})")

    print(f"\nLoaded {len(matrices)} sparse matrices from '{directory}'.")
    return matrices


# Reload the saved matrices from the default directory into W_k_loaded.
W_k_loaded = load_sparse_matrices()


In [None]:
## Building our final A_w, which is the set of our W_k matrices

import numpy as np
from scipy.sparse import coo_matrix, dok_matrix, csr_matrix
from typing import List, Dict, Any


# ---------------------------------------------------
# 0. Build node_to_idx and MATRIX_DIM from `nodes`
#    nodes: list of dicts like:
#      {
#        "id": "N1",
#        "name": "...",
#        "type": "...",
#        "time": ...,
#        ...
#      }
#    Only the "id" field is read here.
# ---------------------------------------------------
# Everything after the first character of the id ("N10" -> 10) is the index.
node_to_idx: Dict[str, int] = {entry["id"]: int(entry["id"][1:]) for entry in nodes}

# Largest index in use, floored at 0 so an empty `nodes` still gives a 1 x 1 matrix.
max_index = max([0, *node_to_idx.values()])

NUM_NODES = len(nodes)            # length of the `nodes` list
MATRIX_DIM = max_index + 1        # rows/cols = max index + 1 (0 unused)

print(f"Total number of unique nodes (len(nodes)): {NUM_NODES}")
print(f"Max node index (from IDs): {max_index}")
print(f"Matrix dimension (MATRIX_DIM): {MATRIX_DIM} (index 0 unused)\n")


# ---------------------------------------------------
# 1. Number of clusters from ClusterMap
#    ClusterMap: dict mapping relation text -> cluster_id (1..K)
# ---------------------------------------------------
K_CLUSTER = max(ClusterMap.values(), default=0)   # 0 when ClusterMap is empty
print(f"Total number of relation clusters (K): {K_CLUSTER}")


# ---------------------------------------------------
# 2. Initialize K builders for COO-like accumulators
#    One set of (data, row, col) lists per cluster
# ---------------------------------------------------
W_data: List[List[float]] = [[] for _ in range(K_CLUSTER)]
W_row:  List[List[int]]   = [[] for _ in range(K_CLUSTER)]
W_col:  List[List[int]]   = [[] for _ in range(K_CLUSTER)]

print("\nPopulating sparse matrix builders in a single pass over relations...")

for rel in relations:
    # rel: {
    #   "source_id": "N1",
    #   "target_id": "N2",
    #   "relation": "targets",
    #   "description": "...",
    #   "evidence": "...",
    #   "confidence": 0.95,
    #   "source_chunks": [...]
    # }

    relation_phrase = rel["relation"]
    confidence      = float(rel.get("confidence", 0.0))   # missing confidence counts as 0.0

    # Relations whose phrase has no cluster assignment are reported and skipped.
    cluster_id = ClusterMap.get(relation_phrase)
    if cluster_id is None:
        print(f"Warning: relation '{relation_phrase}' not found in ClusterMap. Skipping.")
        continue

    # cluster ids start at 1, builder lists start at 0
    matrix_index = cluster_id - 1

    # Both endpoints must be known nodes, otherwise the edge is reported and dropped.
    src_id, tgt_id = rel["source_id"], rel["target_id"]
    if not (src_id in node_to_idx and tgt_id in node_to_idx):
        print(f"Warning: node id(s) {src_id}, {tgt_id} not in node_to_idx. Skipping edge.")
        continue

    # Record the edge (weight, source row, target column) in this cluster's builder.
    W_data[matrix_index].append(confidence)
    W_row[matrix_index].append(node_to_idx[src_id])
    W_col[matrix_index].append(node_to_idx[tgt_id])


# ---------------------------------------------------
# 3. Helper: max-aggregate multiple edges into final CSR
# ---------------------------------------------------
def finalize_relation_matrix(
    data_list: List[float],
    row_list: List[int],
    col_list: List[int],
    n_dim: int
) -> csr_matrix:
    """
    Collapse parallel edge lists into one (n_dim x n_dim) float32 CSR matrix.

    When the same (row, col) pair occurs more than once, only the largest
    score is kept. A score is stored only if it is greater than the value
    currently in the cell, and cells start at 0.0.
    """
    shape = (n_dim, n_dim)

    if not data_list:
        # No edges at all: hand back an all-zero matrix of the right size.
        return csr_matrix(shape, dtype=np.float32)

    # DOK format allows cheap per-cell reads and writes while accumulating.
    accumulator = dok_matrix(shape, dtype=np.float32)

    for score, i, j in zip(data_list, row_list, col_list):
        candidate = float(score)
        if not candidate > accumulator[i, j]:
            continue  # the cell already holds an equal or larger value
        accumulator[i, j] = candidate

    return accumulator.tocsr()


# ---------------------------------------------------
# 4. Build final aggregated CSR matrices per cluster
# ---------------------------------------------------
# Position k of Final_Aggregated_W holds the matrix for cluster k + 1.
Final_Aggregated_W: List[csr_matrix] = []

print("\n--- Starting Max Aggregation for All Clusters ---")

for k_index in range(K_CLUSTER):
    cluster_label = k_index + 1
    raw_data_list = W_data[k_index]
    row_list_k = W_row[k_index]
    col_list_k = W_col[k_index]

    if not raw_data_list:
        # Cluster without edges: keep a zero matrix so list positions stay aligned.
        empty_csr = csr_matrix((MATRIX_DIM, MATRIX_DIM), dtype=np.float32)
        Final_Aggregated_W.append(empty_csr)
        print(f"✅ Final W_{cluster_label}: EMPTY matrix, Shape={empty_csr.shape}")
        continue

    # Convert the scores to float32 first; a failed conversion is reported and
    # replaced by a zero matrix so positions still line up with cluster ids.
    try:
        data_array_k = np.asarray(raw_data_list, dtype=np.float32)
    except Exception as e:
        print(f"❌ ERROR in W_{cluster_label} data conversion: {e}")
        empty_csr = csr_matrix((MATRIX_DIM, MATRIX_DIM), dtype=np.float32)
        Final_Aggregated_W.append(empty_csr)
        continue

    # Max-aggregate duplicate edges into a single CSR matrix for this cluster.
    Final_W_k_CSR = finalize_relation_matrix(
        data_array_k.tolist(),
        row_list_k,
        col_list_k,
        MATRIX_DIM
    )
    Final_Aggregated_W.append(Final_W_k_CSR)

    summary = (
        f"✅ Final W_{cluster_label}: "
        f"Shape={Final_W_k_CSR.shape}, "
        f"Unique Edges={Final_W_k_CSR.nnz}, "
        f"Total Weight={Final_W_k_CSR.data.sum():.4f}"
    )
    print(summary)

print(f"\nSuccessfully created {len(Final_Aggregated_W)} final, maximized CSR matrices.")


# ---------------------------------------------------
# 5. Pretty-print non-zero entries for sanity check
# ---------------------------------------------------
print("\n--- Final Maximized Relational Matrices (W_k) ---")
print(f"Total Matrices in List: {len(Final_Aggregated_W)}\n")

# Cluster ids are 1-based, so start the enumeration at 1.
for cluster_id, final_Wk in enumerate(Final_Aggregated_W, start=1):
    print(f"\n**Matrix W_{cluster_id} (Cluster {cluster_id})**")
    header = (
        f"Shape: {final_Wk.shape}, "
        f"Unique Edges (nnz): {final_Wk.nnz}, "
        f"Total Weight: {final_Wk.data.sum():.4f}"
    )
    print(header)

    if not final_Wk.nnz:
        print("  [Matrix is empty (no relationships in this cluster)]")
        continue

    # COO exposes parallel row / col / data arrays, convenient for listing edges.
    coo_wk: coo_matrix = final_Wk.tocoo()
    print("  Non-Zero Edges (Subject Index -> Object Index | Max Confidence Score):")
    for edge in zip(coo_wk.row, coo_wk.col, coo_wk.data):
        src, dst, score = edge
        print(f"  ({src} -> {dst}) | Score: {score:.4f}")


In [None]:
# A_w is the list of per-cluster aggregated matrices (position k holds W_{k+1}).
A_w = Final_Aggregated_W