In [None]:
import sys

# Install/upgrade `dataparser` using the interpreter that runs this kernel.
# NOTE(review): the version is not pinned, so a re-run may pull a different release.
!{sys.executable} -m pip install -U dataparser

# Initial Loading (From M2)

In [None]:
## Retrieve the K1
# Reads the JSON list of pairs saved at extracted_output/k1.json and rebuilds
# it as a list of (int, np.int64) tuples.

import json
import numpy as np

load_path = "extracted_output/k1.json"

with open(load_path, "r") as f:
    loaded_list = json.load(f)

# Convert back to (int, np.int64) tuple format
# (each JSON entry must unpack into exactly two values)
K1_loaded = [(int(a), np.int64(b)) for (a, b) in loaded_list]

print(f"Loaded K1 from {load_path}")
print(K1_loaded[:10])  # preview of the first 10 pairs only

## Retrieve the Z (GAE embeddings)
# NOTE: `load_path` is reused for each artifact in this cell; after the cell
# runs it holds the path of the last file loaded.

import torch

load_path = "extracted_output/Z.pt"

# map_location="cpu" places the loaded storage on the CPU regardless of the
# device it was saved from.
Z_loaded = torch.load(load_path, map_location="cpu")  # safe for CPU use

print(f"Loaded Z from: {load_path}")
print("Shape:", Z_loaded.shape)
print("Dtype:", Z_loaded.dtype)

## Retrieve the A_co_occur
# Sparse matrix stored in SciPy's .npz format.

from scipy.sparse import load_npz

load_path = "extracted_output/A_co_occur.npz"

A_co_occur_loaded = load_npz(load_path)

print(f"Loaded A_co_occur from: {load_path}")
print(A_co_occur_loaded)


In [None]:
# Short aliases used by the rest of the notebook. These bind the same objects
# (no copies are made).
E_prior = K1_loaded            # candidate / prior pairs loaded from k1.json
A_co_occur = A_co_occur_loaded  # sparse matrix loaded from A_co_occur.npz
Z = Z_loaded                   # embeddings loaded from Z.pt

In [None]:
import os
from pprint import pprint

## Load our sentences_list
# Reads extracted_output/segmented_sentences.json and pulls out its
# "sentences" entry.
sentence_filename = "segmented_sentences.json"
sentence_file_path = os.path.join("extracted_output", sentence_filename)
retrieved_sentences = None  # stays None if opening/parsing the file raises

print(f"\nAttempting to load sentences from: {sentence_file_path}")

with open(sentence_file_path, 'r', encoding='utf-8') as f:
    retrieved_sentences = json.load(f)

print(f"✅ Success: Sentences loaded from {sentence_file_path}")

# Access the list of sentences
# NOTE: falls back to an empty list when the "sentences" key is absent, so a
# malformed file yields [] rather than an error.
sentences_list = retrieved_sentences.get("sentences", [])

In [None]:
## Load the saved relations

import os
import json

def load_extracted_relations(base_dir="extracted_output"):
    """
    Read the relation-extraction artifacts stored under ``base_dir``.

    Files read (both must exist):
      - relations.json
      - relations_per_chunk_debug.json

    Returns:
      (relations, per_chunk_relations): the parsed JSON content of the two
      files, in that order.

    Raises:
      FileNotFoundError: if either file is missing. Both paths are checked
      before anything is read.
    """
    relations_path = os.path.join(base_dir, "relations.json")
    per_chunk_path = os.path.join(base_dir, "relations_per_chunk_debug.json")

    # Fail fast: validate both paths (relations first) before opening either.
    for required_path in (relations_path, per_chunk_path):
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"Missing file: {required_path}")

    def _read_json(path):
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    relations = _read_json(relations_path)
    per_chunk_relations = _read_json(per_chunk_path)

    for line in ("Loaded:", " - relations.json", " - relations_per_chunk_debug.json"):
        print(line)

    return relations, per_chunk_relations


# Example usage
if __name__ == "__main__":
    # Load both relation artifacts from the default "extracted_output" directory
    # and report their sizes.
    relations, relations_per_chunk = load_extracted_relations()

    print(f"Total relations loaded: {len(relations)}")
    print(f"Total chunks returned: {len(relations_per_chunk)}")


## Load the saved entities

import json
import os

def load_extracted_entities(base_dir="extracted_output"):
    """
    Read the entity-extraction artifacts stored under ``base_dir``.

    Files read (both must exist):
      - entities.json
      - entities_per_chunk_debug.json

    Returns:
      (entities, per_chunk): the parsed JSON content of the two files,
      in that order.

    Raises:
      FileNotFoundError: if either file is missing. Both paths are checked
      before anything is read.
    """
    entities_path = os.path.join(base_dir, "entities.json")
    per_chunk_path = os.path.join(base_dir, "entities_per_chunk_debug.json")

    # Fail fast: validate both paths (entities first) before opening either.
    for required_path in (entities_path, per_chunk_path):
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"Missing file: {required_path}")

    def _read_json(path):
        with open(path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    entities = _read_json(entities_path)
    per_chunk = _read_json(per_chunk_path)

    for line in ("Loaded:", " - entities.json", " - entities_per_chunk_debug.json"):
        print(line)

    return entities, per_chunk


# Example usage:
if __name__ == "__main__":
    # Load both entity artifacts from the default "extracted_output" directory
    # and report their sizes.
    nodes, nodes_per_chunk = load_extracted_entities()
    print(f"Total nodes loaded: {len(nodes)}")
    print(f"Total chunks loaded: {len(nodes_per_chunk)}")


# 3.3.2 Core CoCaD

## (a) Building $W_{direct}$

### (i).  $f_{structural}$ (Our GNN intervention score)

In [None]:
import json
import pprint
from typing import List, Tuple, Dict, Any, Union

import numpy as np
from scipy.sparse import csr_matrix, lil_matrix
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GAE
from torch_geometric.utils import from_scipy_sparse_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

# =========================================================
# HYPERPARAMETERS
# =========================================================
B_ENSEMBLE   = 5    # number of GAE models in the ensemble
GCN_HIDDEN   = 64   # hidden dimension
GCN_LATENT   = 64   # latent dimension
GCN_EPOCHS   = 200  # epochs per GAE
LEARNING_RATE = 0.005  # `lr` passed to Adam in train_single_gae
WEIGHT_DECAY  = 5e-4   # `weight_decay` passed to Adam in train_single_gae

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# =========================================================
# 1. GCN ENCODER (for GAE)
# =========================================================

class GCNEncoder(nn.Module):
    """
    Encoder handed to GAE: GCNConv -> ReLU -> GCNConv.

    Both convolutions are created with cached=False so that a different
    edge_index passed to forward() is respected on every call.
    """
    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int):
        super().__init__()
        # Attribute names are part of the state_dict layout; keep them stable.
        self.conv1 = GCNConv(in_channels, hidden_channels, cached=False)
        self.conv2 = GCNConv(hidden_channels, out_channels, cached=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both convolution layers (first conv1, then conv2)."""
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()

    def forward(self, x, edge_index):
        """Encode node features ``x`` over the graph given by ``edge_index``."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)


# =========================================================
# 2. BUILD BASE GRAPH FROM Z, A_co_occur, E_prior
# =========================================================

def build_base_graph(
    Z_loaded: Union[np.ndarray, torch.Tensor],
    A_co_occur: csr_matrix,
    E_prior: List[Tuple[int, int]],
) -> Tuple[torch.Tensor, csr_matrix, torch.Tensor]:
    """
    Assemble the graph the GAE ensemble is trained on.

    Args:
      Z_loaded   : (N x d) embeddings, np.ndarray or torch.Tensor; used as the
                   node feature matrix. Row 0 is the node with 1-based id 1.
      A_co_occur : (N x N) CSR adjacency used as the starting edge set.
      E_prior    : (i, j) pairs with 1-based node ids; each valid pair is added
                   as an undirected edge.

    Processing:
      1) Z_loaded is converted to a float32 torch tensor.
      2) Every E_prior pair is inserted in both directions; pairs that are
         self-loops or whose ids fall outside 1..N are silently skipped.
      3) The result is symmetrised and binarised (entry is 1.0 wherever
         A + A^T > 0).
      4) The binary matrix is converted to a PyG edge_index.

    Returns:
      x          : torch.Tensor of shape (N, d)
      A_sym      : symmetric, binary float32 adjacency (N x N)
      edge_index : edge index tensor produced by from_scipy_sparse_matrix

    Raises:
      AssertionError: A_co_occur is not a csr_matrix or its shape is not (N, N).
      TypeError     : Z_loaded is neither np.ndarray nor torch.Tensor.
    """
    assert isinstance(A_co_occur, csr_matrix), "A_co_occur must be a CSR matrix."

    # ---- Node features ----
    if isinstance(Z_loaded, torch.Tensor):
        x = Z_loaded.float()
    elif isinstance(Z_loaded, np.ndarray):
        x = torch.from_numpy(Z_loaded.astype(np.float32))
    else:
        raise TypeError(f"Z_loaded must be np.ndarray or torch.Tensor, got {type(Z_loaded)}")

    N_nodes = x.shape[0]
    assert A_co_occur.shape == (N_nodes, N_nodes), \
        "A_co_occur shape must match number of rows in Z_loaded."

    # ---- Editable copy of the adjacency (LIL supports cheap item assignment) ----
    adjacency: lil_matrix = A_co_occur.tolil()

    # ---- Insert prior edges, shifting ids from 1-based to 0-based ----
    for src_raw, dst_raw in E_prior:
        src, dst = int(src_raw) - 1, int(dst_raw) - 1
        inside = 0 <= src < N_nodes and 0 <= dst < N_nodes
        if not inside or src == dst:
            # out-of-range ids and self-pairs contribute no edge
            continue
        adjacency[src, dst] = 1.0
        adjacency[dst, src] = 1.0

    # ---- Symmetrise + binarise ----
    augmented = adjacency.tocsr().astype(np.float32)
    A_sym = ((augmented + augmented.T) > 0).astype(np.float32)

    # ---- PyG edge list (edge weights are discarded) ----
    edge_index, _ = from_scipy_sparse_matrix(A_sym)

    return x, A_sym, edge_index


# =========================================================
# 3. TRAIN A SINGLE GAE
# =========================================================

def train_single_gae(
    x: torch.Tensor,
    edge_index: torch.Tensor,
    d_in: int,
    d_hidden: int,
    d_latent: int,
    epochs: int,
    seed: int,
) -> GAE:
    """
    Train a single GAE with a 2-layer GCN encoder on the given graph.

    Args:
      x          : node feature matrix; moved to the module-level `device`.
      edge_index : graph connectivity; moved to the module-level `device`.
      d_in       : input feature dimension of the encoder.
      d_hidden   : hidden dimension of the encoder.
      d_latent   : output (latent) dimension of the encoder.
      epochs     : number of full-graph optimisation steps.
      seed       : seed applied to Python's `random`, NumPy and PyTorch before
                   the model is created.

    Returns:
      The trained GAE (left in train mode; callers switch to eval as needed).

    Side effects:
      Reseeds the global Python, NumPy and PyTorch RNGs and prints the loss at
      epoch 1 and every 50th epoch.
    """
    import random  # local import: only needed for seeding here

    # Seed every RNG that can influence training so that a given `seed`
    # reproduces the same model.
    # NOTE(review): PyG's negative sampling (used by recon_loss) is believed to
    # draw from Python's `random` in some versions — confirm for the installed one.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    x = x.to(device)
    edge_index = edge_index.to(device)

    encoder = GCNEncoder(d_in, d_hidden, d_latent).to(device)
    model = GAE(encoder).to(device)

    optimizer = Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    model.train()
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        z = model.encode(x, edge_index)
        # Reconstruction loss over the observed edges (positives); negatives are
        # sampled inside recon_loss because none are passed explicitly.
        loss = model.recon_loss(z, edge_index)
        loss.backward()
        optimizer.step()

        if epoch == 1 or epoch % 50 == 0:
            print(f"[Seed {seed}] Epoch {epoch:03d} | Loss: {loss.item():.4f}")

    return model


# =========================================================
# 4. BUILD ENSEMBLE OF GNNs
# =========================================================

def build_gnn_ensemble(
    x: torch.Tensor,
    edge_index: torch.Tensor,
    d_in: int,
    d_hidden: int,
    d_latent: int,
    epochs: int,
    B: int,
) -> List[GAE]:
    """
    Train ``B`` GAE models on the same graph, one per seed.

    Member ``b`` (0-based) is trained with seed ``100 + b``; all other
    arguments are forwarded unchanged to train_single_gae. Returns the trained
    models in training order.
    """
    def _train_member(b: int) -> GAE:
        seed = 100 + b
        print(f"\n=== Training GAE model {b+1}/{B} (seed={seed}) ===")
        return train_single_gae(
            x=x,
            edge_index=edge_index,
            d_in=d_in,
            d_hidden=d_hidden,
            d_latent=d_latent,
            epochs=epochs,
            seed=seed,
        )

    return [_train_member(b) for b in range(B)]


# =========================================================
# 5. GNN PREDICTION FOR NODE j
# =========================================================

def gnn_predict_node_embedding(
    model: GAE,
    x: torch.Tensor,
    edge_index: torch.Tensor,
    node_id_1based: int,
) -> np.ndarray:
    """
    Encode the whole graph with ``model`` and return one node's latent vector.

    The model is switched to eval mode, the encoder is run without gradient
    tracking on the module-level ``device``, and the row for ``node_id_1based``
    (1-based, so row ``node_id_1based - 1``) is returned as a NumPy array
    scaled to (approximately) unit L2 norm.
    """
    model.eval()

    with torch.no_grad():
        latent = model.encode(x.to(device), edge_index.to(device))  # [N_nodes, d_latent]

    row = latent[node_id_1based - 1].detach().cpu().numpy()

    # The small epsilon keeps the division finite for an all-zero embedding.
    return row / (np.linalg.norm(row) + 1e-12)


# =========================================================
# 6. COSINE DISTANCE
# =========================================================

def cosine_distance(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Cosine distance (1 - cosine similarity) between two arrays.

    Both inputs are flattened first, so any shapes with the same number of
    elements are accepted. A 1e-12 term in the denominator avoids division by
    zero; a zero vector therefore yields a distance of exactly 1.0.
    """
    a, b = v1.flatten(), v2.flatten()
    scale = (np.linalg.norm(a) * np.linalg.norm(b)) + 1e-12
    similarity = float(np.dot(a, b) / scale)
    return 1.0 - similarity


# =========================================================
# 7. MAIN: STRUCTURAL EFFECT SCORES (NODE DELETION)
# =========================================================

def compute_structural_effect_scores(
    Z_loaded: Union[np.ndarray, torch.Tensor],
    A_co_occur: csr_matrix,
    E_prior: List[Tuple[int, int]],
    B_ensemble: int = B_ENSEMBLE,
    gcn_hidden: int = GCN_HIDDEN,
    gcn_latent: int = GCN_LATENT,
    gcn_epochs: int = GCN_EPOCHS,
) -> Dict[Tuple[int, int], Dict[str, float]]:
    """
    Structural effect f_structural via an ensemble of GNNs using a
    NODE-DELETION intervention.

    For each (i, j) in E_prior (1-based indices):
      For each model b in ensemble:
        - z_normal_{j,b}     : embedding of node j on full graph
        - z_intervened_{j,b} : embedding of node j when node i is isolated
        - score_b(i, j)      : cosine_distance(z_normal, z_intervened)

      μ_GNN(i, j)   = mean_b(score_b)
      σ^2_GNN(i, j) = var_b(score_b)   (population variance, np.var default)

    Pairs with i == j, or with an index outside 1..N, are skipped. This mirrors
    build_base_graph, which adds no edge for such pairs.

    Args:
      Z_loaded   : (N x d) node embeddings used as GNN input features.
      A_co_occur : (N x N) CSR adjacency.
      E_prior    : candidate (i, j) pairs, 1-based.
      B_ensemble : number of GAE models to train.
      gcn_hidden : encoder hidden dimension.
      gcn_latent : encoder latent dimension.
      gcn_epochs : training epochs per model.

    Returns:
      dict mapping (i, j) → {
        "mean_ensemble_score": float (rounded to 6 decimals),
        "variance_ensemble_score": float (rounded to 8 decimals)
      }

    Raises:
      TypeError: Z_loaded is neither np.ndarray nor torch.Tensor.
    """

    # Shape sanity
    if not isinstance(Z_loaded, (np.ndarray, torch.Tensor)):
        raise TypeError(f"Z_loaded must be np.ndarray or torch.Tensor, got {type(Z_loaded)}")
    N_nodes, d_embed = Z_loaded.shape

    print(f"Z_loaded shape = ({N_nodes}, {d_embed})")

    # 1) Build base graph (features + adjacency)
    x, A_sym_base, edge_index_base = build_base_graph(Z_loaded, A_co_occur, E_prior)

    # 2) Train ensemble on full graph
    print("\n=== Building GNN Ensemble ===")
    ensemble_models = build_gnn_ensemble(
        x=x,
        edge_index=edge_index_base,
        d_in=d_embed,
        d_hidden=gcn_hidden,
        d_latent=gcn_latent,
        epochs=gcn_epochs,
        B=B_ensemble,
    )

    x_device = x.to(device)
    edge_index_full_device = edge_index_base.to(device)

    results: Dict[Tuple[int, int], Dict[str, float]] = {}

    # The un-intervened embedding of node j under model b does not depend on i,
    # so compute it once per (model, j) instead of once per pair.
    baseline_cache: Dict[Tuple[int, int], np.ndarray] = {}

    # 3) For each (i, j) pair: delete node i's edges and measure effect on j
    print("\n=== Computing structural effects for E_prior (NODE DELETION) ===")
    for (i_raw, j_raw) in E_prior:
        i = int(i_raw)
        j = int(j_raw)

        if i == j:
            # ignore self-edges if any
            continue

        if not (1 <= i <= N_nodes and 1 <= j <= N_nodes):
            # An id of 0 would otherwise wrap around to the last node and an
            # id > N would raise; neither describes a real pair.
            print(f"\n-- Skipping pair (i={i}, j={j}): index outside 1..{N_nodes} --")
            continue

        print(f"\n-- Pair (i={i}, j={j}) --")

        # ---- NODE DELETION: isolate node i ----
        A_int_lil = A_sym_base.tolil()
        idx = i - 1  # 0-based index in adjacency

        # zero out row and column
        A_int_lil[idx, :] = 0.0
        A_int_lil[:, idx] = 0.0

        A_int = A_int_lil.tocsr()
        # Explicitly stored zeros would still be emitted as edges by
        # from_scipy_sparse_matrix, so make sure none survive.
        A_int.eliminate_zeros()
        edge_index_int, _ = from_scipy_sparse_matrix(A_int)
        edge_index_int_device = edge_index_int.to(device)

        scores_b: List[float] = []

        for b_idx, model_b in enumerate(ensemble_models):
            # normal embedding (cached across pairs sharing the same j)
            cache_key = (b_idx, j)
            if cache_key not in baseline_cache:
                baseline_cache[cache_key] = gnn_predict_node_embedding(
                    model=model_b,
                    x=x_device,
                    edge_index=edge_index_full_device,
                    node_id_1based=j,
                )
            z_normal_j = baseline_cache[cache_key]

            # embedding after deleting node i's edges
            z_intervened_j = gnn_predict_node_embedding(
                model=model_b,
                x=x_device,
                edge_index=edge_index_int_device,
                node_id_1based=j,
            )

            score_b = cosine_distance(z_normal_j, z_intervened_j)
            scores_b.append(score_b)
            print(f"  Model {b_idx+1}/{B_ensemble}: score_b = {score_b:.6f}")

        mean_score = float(np.mean(scores_b))
        var_score = float(np.var(scores_b))

        results[(i, j)] = {
            "mean_ensemble_score": round(mean_score, 6),
            "variance_ensemble_score": round(var_score, 8),
        }

        print(f"  -> mean = {mean_score:.6f}, var = {var_score:.8f}")

    return results


# =========================================================
# 8. EXAMPLE USAGE (plug in your Z, A_co_occur, E_prior)
# =========================================================

if __name__ == "__main__":
    """
    Example wiring. You already have:

      - entities (list of dicts; not directly used here)
      - relations (list of dicts; used earlier to build A_co_occur and E_prior)
      - Z.pt (latent embeddings) saved earlier
      - A_co_occur.npz (symmetric co-occurrence adjacency) saved earlier
      - E_prior.json (candidate edges from ANN / prior stage), 1-based indices

    Replace paths as needed.
    """
    # NOTE(review): neither import below is used inside this block.
    import os
    from scipy.sparse import load_npz

    # 4) Compute structural effects
    # Z, A_co_occur and E_prior are the notebook-level names bound in an earlier
    # cell. This call trains the whole ensemble and scores every pair.
    results = compute_structural_effect_scores(
        Z_loaded=Z,
        A_co_occur=A_co_occur,
        E_prior=E_prior,
    )

    print("\n=== Structural Effect Scores (μ, σ²) ===")
    # `pprint` is used as a module here (pprint.pprint), matching `import pprint`.
    pprint.pprint(results)


In [None]:
# Keep the structural scores under a stage-specific name (same dict object as `results`).
w_direct_structural_results = results

In [None]:
# Inspection only: show the container type of the structural scores.
type(w_direct_structural_results)

In [None]:
# Save the above
import os
import pickle

def save_w_direct_structural_results(w_direct_structural_results: dict,
                                     base_path="extracted_output/cocad/w_direct"):
    """
    Pickle ``w_direct_structural_results`` to
    ``<base_path>/w_direct_structural_results.pkl``.

    ``base_path`` (including missing parents) is created when needed. An
    existing file is overwritten. Returns None.
    """
    os.makedirs(base_path, exist_ok=True)
    save_path = os.path.join(base_path, "w_direct_structural_results.pkl")

    with open(save_path, "wb") as sink:
        pickle.dump(w_direct_structural_results, sink)

    print(f"Saved w_direct_structural_results → {save_path}")


# Persist the scores to the default location (extracted_output/cocad/w_direct).
save_w_direct_structural_results(w_direct_structural_results)

In [None]:
## Retrieve from the above
import os
import pickle

def load_w_direct_structural_results(base_path="extracted_output/cocad/w_direct"):
    """
    Unpickle and return the object stored at
    ``<base_path>/w_direct_structural_results.pkl``.

    Raises:
      FileNotFoundError: the pickle file does not exist.

    Note: unpickling can execute arbitrary code, so only load files produced
    by a trusted run of this notebook.
    """
    load_path = os.path.join(base_path, "w_direct_structural_results.pkl")

    if os.path.exists(load_path):
        with open(load_path, "rb") as source:
            data = pickle.load(source)
        print(f"Loaded w_direct_structural_results ← {load_path}")
        return data

    raise FileNotFoundError(f"No file found at {load_path}")

# Reload the scores from disk; this rebinds the name, so the expensive
# computation cell above does not have to be re-run in a fresh kernel.
w_direct_structural_results = load_w_direct_structural_results()

### (ii).  $f_{llm}$ (The LLM Counterfactual)

#### A. Context Retrieval (RAG-HyDE-RAV)

In [None]:
from pprint import pprint

In [None]:
# Read entities.json as UTF-8 explicitly, matching load_extracted_entities
# (which reads this same file), instead of relying on the platform's default
# text encoding.
# NOTE: this rebinds `nodes`, which an earlier cell also assigns.
with open("extracted_output/entities.json", "r", encoding="utf-8") as f:
    nodes = json.load(f)

# Convert list → dict keyed by each node's "id" field
# (a later duplicate id overwrites an earlier one).
entities_by_id = { node["id"]: node for node in nodes }


In [None]:
# Inspection only. NOTE(review): this prints every entity; the output can be
# very long for large graphs.
pprint(entities_by_id)

In [None]:
import os
import json
import time
import numpy as np
import faiss
import torch
from typing import List, Tuple, Set, Dict, Any, Optional

# External Libraries required: google-genai, sentence-transformers, faiss-cpu, numpy, scikit-learn
from google.genai import Client, types
from google.genai import errors as genai_errors
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel
from typing import Literal

# --- HYPERPARAMETERS ---
K_HYPOTHETICAL = 2       # Number of causal hypotheses to generate per pair
R_STOCHASTIC_SAMPLES = 3 # Number of samples for Semantic Entropy estimation
K_RAG = 3                # Top-k items retrieved for verification (per view)
K_BASE = 5               # Base pool size for RAG-MMR
K_EXPANSION = 5          # Pool expansion factor for uncertainty
LAMBDA_MMR = 0.5         # MMR trade-off: 0.5 favors both relevance and diversity
KAPPA_RRF = 60           # RRF constant
TAU_SUPPORT = 0.80       # Min confidence (p_support) threshold for verification
TAU_ENTROPY = 0.15       # Max Semantic Entropy threshold (low entropy = stable)
DELAY_SECONDS = 10       # Delay between API calls for stability (free tier friendly)

# Models tried, in this order, by safe_generate_content.
PRIMARY_MODEL = "gemini-2.5-flash-lite"
FALLBACK_MODEL = "gemini-2.0-flash-lite"

# --- ASSUMED INPUTS (Must be defined in the execution environment) ---
# N_map (Dict[str, Any]): Node ID -> node info (string or rich dict)
#   e.g. {
#     "N1": "Global Magnitsky Sanctions Bill passage"
#   }
#   OR
#   {
#     "N1": {
#       "name": "...",
#       "type": "...",
#       "time": "...",
#       "location": "...",
#       "description": "..."
#     },
#     ...
#   }
#
# candidate_set (Set[Tuple[int, int]]): Candidate edges in *0-indexed* integer space,
#   aligned with sorted(N_map.keys()) order.
#
# sentences (List[str]): The full document split into sentences (sentence chunking).
#
# relations (List[Dict]): 
#   [
#     {
#       "source_id": "N1",
#       "target_id": "N2",
#       "relation": "targets",
#       "description": "...",
#       "evidence": "...",
#       "confidence": 0.95,
#       "source_chunks": [0]
#     },
#     ...
#   ]

# ---------------------------------------------------------------------

# Initialize Client and Embedder
# The API key is read from the GEMINI_API_KEY environment variable; a missing
# variable raises KeyError here rather than at the first API call.
client = Client(api_key=os.environ["GEMINI_API_KEY"])
# Sentence embedding model shared by the document and relation indexes below.
embedder = SentenceTransformer("all-MiniLM-L6-v2")


# ====================================================================
# GENERAL LLM CALL WRAPPER WITH FALLBACK
# ====================================================================

def safe_generate_content(
    contents: str,
    config: types.GenerateContentConfig,
    max_retries_per_model: int = 2,
) -> types.GenerateContentResponse:
    """
    Call the LLM with retry and model fallback on quota / rate-limit errors.

    PRIMARY_MODEL is tried first, then FALLBACK_MODEL. Each model gets up to
    ``max_retries_per_model`` attempts; after every 429 / RESOURCE_EXHAUSTED
    failure the function sleeps DELAY_SECONDS before the next attempt (or the
    next model).

    Args:
      contents              : prompt text sent to the model.
      config                : generation config forwarded unchanged.
      max_retries_per_model : attempts allowed per model.

    Returns:
      The first successful response.

    Raises:
      genai_errors.ClientError: immediately for any client error that is not a
        quota / rate-limit error, or the last quota error once every attempt on
        every model has failed.
      RuntimeError: if no attempt was made at all (max_retries_per_model < 1).
    """

    models_to_try = [PRIMARY_MODEL, FALLBACK_MODEL]
    last_exc: Optional[Exception] = None

    for model_name in models_to_try:
        for attempt in range(1, max_retries_per_model + 1):
            try:
                print(f"[safe_generate_content] Calling model={model_name}, attempt={attempt}")
                return client.models.generate_content(
                    model=model_name,
                    contents=contents,
                    config=config,
                )
            except genai_errors.ClientError as e:
                msg = str(e)
                # NOTE(review): the SDK's error class is believed to expose the
                # HTTP status as `code`; `status_code` is kept as a fallback.
                status_code = getattr(e, "status_code", None) or getattr(e, "code", None)
                print(f"[safe_generate_content] ClientError on {model_name}: {msg}")

                is_quota_error = status_code == 429 or "RESOURCE_EXHAUSTED" in msg
                if not is_quota_error:
                    # Other client errors are not transient
                    raise

                last_exc = e
                # Back off, then fall through to the next attempt on this
                # model; once its attempts are used up the outer loop moves on
                # to the next model.
                time.sleep(DELAY_SECONDS)

    # If we reach here, every attempt on every model failed with quota/rate errors
    if last_exc is not None:
        raise last_exc
    raise RuntimeError("safe_generate_content failed with unknown error.")


# ====================================================================
# PHASE 1: PREPARATION & INDEXING
# ====================================================================

def setup_document_index(sentences: List[str]) -> Tuple[faiss.IndexFlatL2, np.ndarray]:
    """
    Embed every sentence with the shared ``embedder`` and index the vectors.

    Returns:
      (idx_doc, VD): a FAISS IndexFlatL2 holding all sentence vectors, and the
      embedding matrix itself (one row per sentence, same order as the input).
    """
    print("1. Embedding document sentences (VD)...")
    VD = embedder.encode(sentences, convert_to_numpy=True)

    # Index dimensionality is taken from the embedding width.
    idx_doc = faiss.IndexFlatL2(VD.shape[1])
    idx_doc.add(VD)
    print(f"2. FAISS Index built on M={len(sentences)} sentences.")

    return idx_doc, VD


def build_relation_chunks(
    relations: List[Dict[str, Any]],
    N_map: Dict[str, Any],
) -> List[str]:
    """
    Render each relation record as one retrieval text chunk.

    Chunk layout (surrounding whitespace stripped):

        "<src_id> (<src_name>) -> <tgt_id> (<tgt_name>) [<relation>]. <description> Evidence: <evidence>"

    The " Evidence: ..." suffix is only present when the record has non-empty
    evidence. A node's name comes from N_map: the "name" field for dict
    entries (falling back to the node id when empty/missing), str(entry) for
    other entries, and the node id itself when the id is not in N_map.
    """
    def label(node_id: str) -> str:
        entry = N_map.get(node_id, node_id)
        if isinstance(entry, dict):
            name = entry.get("name") or node_id
        else:
            name = str(entry)
        return f"{node_id} ({name})"

    def render(record: Dict[str, Any]) -> str:
        src_str = label(record.get("source_id", ""))
        tgt_str = label(record.get("target_id", ""))
        rel = record.get("relation", "")
        desc = record.get("description") or ""
        ev = record.get("evidence") or ""

        text = f"{src_str} -> {tgt_str} [{rel}]. {desc}"
        if ev:
            text += f" Evidence: {ev}"
        return text.strip()

    return [render(record) for record in relations]


def setup_relation_index(
    relations: Optional[List[Dict[str, Any]]],
    N_map: Dict[str, Any],
) -> Tuple[Optional[faiss.IndexFlatL2], Optional[np.ndarray], Optional[List[str]]]:
    """
    Embed and index the textual relation chunks, when relations are available.

    Returns:
      (idx_rel, VR, rel_chunks): the FAISS IndexFlatL2, the embedding matrix
      and the chunk strings; or (None, None, None) when ``relations`` is
      empty/None or yields no chunks.
    """
    nothing = (None, None, None)

    if not relations:
        print("No relations provided – skipping relation index.")
        return nothing

    print("Embedding relation chunks (VR)...")
    rel_chunks = build_relation_chunks(relations, N_map)
    if not rel_chunks:
        print("No relation chunks built – skipping relation index.")
        return nothing

    VR = embedder.encode(rel_chunks, convert_to_numpy=True)

    # Index dimensionality is taken from the embedding width.
    idx_rel = faiss.IndexFlatL2(VR.shape[1])
    idx_rel.add(VR)
    print(f"Relation FAISS Index built on R={len(rel_chunks)} relation chunks.")
    return idx_rel, VR, rel_chunks


# ====================================================================
# PHASE 2: LLM CALLS AND UNCERTAINTY ESTIMATION
# ====================================================================

def format_node_for_prompt(node_id: str, N_map: Dict[str, Any]) -> str:
    """
    Build the human-readable node description inserted into LLM prompts.

    - Node id missing from N_map      -> the id itself.
    - Non-dict entry                   -> str(entry).
    - Dict entry                       -> "<name> (<type>, time=<time>, location=<location>) - <description>"
      where name falls back to the node id, and every empty/missing field
      (and the parentheses, if no metadata remains) is left out.
    """
    entry = N_map.get(node_id, node_id)
    if not isinstance(entry, dict):
        return str(entry)

    # Metadata in fixed order: type (verbatim), then labelled time / location.
    renderers = (
        ("type", lambda value: value),
        ("time", lambda value: f"time={value}"),
        ("location", lambda value: f"location={value}"),
    )
    meta_bits = []
    for key, render in renderers:
        value = entry.get(key)
        if value:
            meta_bits.append(render(value))

    parts = [entry.get("name") or node_id]
    if meta_bits:
        parts.append(f"({', '.join(meta_bits)})")

    desc = entry.get("description")
    if desc:
        parts.append(f"- {desc}")

    return " ".join(parts)


def llm_generate_hypotheses(node_i_str: str, node_j_str: str) -> List[str]:
    """
    f_hypothetical: ask the LLM for up to K_HYPOTHETICAL causal claims.

    Args:
      node_i_str : description of the cause node.
      node_j_str : description of the effect node.

    Returns:
      The stripped response lines that contain 'It is plausible that',
      truncated to K_HYPOTHETICAL entries; [] when the response is empty.
    """

    prompt = f"""
You are a causal reasoning assistant.

Generate {K_HYPOTHETICAL} short, hypothetical sentences describing a plausible causal connection
from the *cause* to the *effect*.

Cause node: {node_i_str}
Effect node: {node_j_str}

Rules:
- Start each sentence with 'It is plausible that'.
- Be concise (max 25 words).
- Only output the sentences, one per line, no numbering, no extra text.
"""

    generation_config = types.GenerateContentConfig(temperature=0.8)
    response = safe_generate_content(contents=prompt, config=generation_config)

    text = response.text or ""
    if not text.strip():
        print("[llm_generate_hypotheses] Empty LLM response; returning no hypotheses.")
        return []

    kept: List[str] = []
    for raw_line in text.split("\n"):
        candidate = raw_line.strip()
        # Keep only non-blank lines that follow the requested phrasing.
        if candidate and "It is plausible that" in raw_line:
            kept.append(candidate)

    return kept[:K_HYPOTHETICAL]


class VerifySupport(BaseModel):
    """Schema for claim verification result."""
    # The model's verdict; restricted to the two literal strings "YES" / "NO".
    support: Literal["YES", "NO"]


def llm_verify_claim(claim: str, evidence: str, temperature: float) -> Tuple[float, str]:
    """
    f_verify: ask the LLM whether ``evidence`` strongly supports ``claim``.

    The request asks for a JSON answer constrained by the VerifySupport schema.
    Any failure to obtain a valid "YES"/"NO" from the structured response is
    logged and treated as "NO".

    Returns:
      (p_support, choice): choice is "YES" or "NO"; p_support is 1.0 for
      "YES" and 0.0 for "NO".
    """

    prompt_verify = f"""
Claim: '{claim}'.
Evidence:
\"\"\"
{evidence}
\"\"\"

Based ONLY on this Evidence, does the Evidence strongly support the Claim?

Answer strictly as a JSON object of the form:
{{"support": "YES"}} or {{"support": "NO"}}.
"""

    verify_config = types.GenerateContentConfig(
        temperature=temperature,
        response_mime_type="application/json",
        response_schema=VerifySupport,
    )
    resp = safe_generate_content(contents=prompt_verify, config=verify_config)

    try:
        parsed = getattr(resp, "parsed", None)
        support_val = parsed.support if isinstance(parsed, VerifySupport) else None

        if support_val not in ("YES", "NO"):
            raise ValueError("Could not find valid 'support' field in structured response")

        choice = support_val

    except Exception as e:
        # Conservative default: an unreadable answer counts as "not supported".
        print(f"[llm_verify_claim] Failed to parse structured response, defaulting to NO: {e}")
        choice = "NO"

    return (1.0 if choice == "YES" else 0.0), choice


def estimate_semantic_entropy(claim: str, evidence: str) -> Tuple[float, float]:
    """
    Estimate support and answer stability from repeated stochastic verification.

    llm_verify_claim is called R_STOCHASTIC_SAMPLES times at temperature 0.7,
    sleeping DELAY_SECONDS after every call.

    Returns:
      (p_support_mean, h_semantic_proxy):
        p_support_mean   = fraction of "YES" answers
        h_semantic_proxy = 1 - frequency of the majority answer
                           (0.0 when all samples agree)
    """
    TEMPERATURE = 0.7

    yes_count = 0
    for _ in range(R_STOCHASTIC_SAMPLES):
        _, verdict = llm_verify_claim(claim, evidence, temperature=TEMPERATURE)
        yes_count += 1 if verdict == "YES" else 0
        time.sleep(DELAY_SECONDS)  # Respect rate limits

    p_support_mean = yes_count / R_STOCHASTIC_SAMPLES

    # Majority frequency measures consistency; its complement is the proxy.
    majority_frequency = max(p_support_mean, 1.0 - p_support_mean)
    return p_support_mean, 1.0 - majority_frequency


# ====================================================================
# PHASE 3: RAG-MMR AND RANK FUSION (RRF) LOGIC
# ====================================================================

def mmr_rerank(query_vec, candidates_vectors, k_rag: int) -> List[int]:
    """
    Reranks candidate indices using Maximal Marginal Relevance (MMR).

    At every step the not-yet-selected candidate with the highest

        LAMBDA_MMR * sim(candidate, query)
            - (1 - LAMBDA_MMR) * max sim(candidate, already selected)

    is picked (cosine similarity; the diversity term is 0 for the first pick;
    ties go to the lowest index).

    Args:
      query_vec          : shape (d,) or (1, d). If several rows are given,
                           only the first is used as the query.
      candidates_vectors : array-like of shape (K, d).
      k_rag              : maximum number of candidates to select.

    Returns:
      Row indices into ``candidates_vectors`` in selection order; at most
      min(k_rag, K) entries, [] when there is nothing to select.
    """
    # Ensure numpy arrays
    q = np.array(query_vec)
    C = np.array(candidates_vectors)

    # Make sure both are 2D: (1, d) and (K, d)
    if q.ndim == 1:
        q = q.reshape(1, -1)
    elif q.ndim > 2:
        q = q.reshape(q.shape[0], -1)

    if C.ndim == 1:
        C = C.reshape(1, -1)
    elif C.ndim > 2:
        C = C.reshape(C.shape[0], -1)

    num_candidates = C.shape[0]
    n_select = min(k_rag, num_candidates)
    if n_select <= 0:
        return []

    # Relevance to the query never changes between rounds: compute it once for
    # all candidates instead of once per candidate per round.
    relevance = cosine_similarity(C, q)[:, 0]

    # Running "max similarity to anything selected so far", updated after each
    # pick with a single vectorised call.
    max_sim_to_selected = np.zeros(num_candidates)
    available = np.ones(num_candidates, dtype=bool)

    selected_indices: List[int] = []

    for _ in range(n_select):
        mmr_scores = (LAMBDA_MMR * relevance) - ((1.0 - LAMBDA_MMR) * max_sim_to_selected)
        # Already-selected candidates must never win again.
        mmr_scores = np.where(available, mmr_scores, -np.inf)

        best_candidate_index = int(np.argmax(mmr_scores))  # first max wins ties
        selected_indices.append(best_candidate_index)
        available[best_candidate_index] = False

        sim_to_best = cosine_similarity(
            C, C[best_candidate_index:best_candidate_index + 1, :]
        )[:, 0]
        if len(selected_indices) == 1:
            # The first pick replaces the neutral 0.0 baseline outright
            # (similarities may be negative).
            max_sim_to_selected = sim_to_best
        else:
            max_sim_to_selected = np.maximum(max_sim_to_selected, sim_to_best)

    return selected_indices


def reciprocal_rank_fusion(all_ranked_lists: List[List[int]], k_final: int) -> List[int]:
    """
    Fuse several ranked lists with Reciprocal Rank Fusion (RRF).

    Each appearance of a document at 1-based rank r contributes
    1 / (KAPPA_RRF + r) to that document's score. Documents are returned by
    descending total score (ties keep first-seen order), truncated to
    ``k_final`` entries.
    """
    fused: Dict[int, float] = {}
    for ranking in all_ranked_lists:
        for rank, doc_index in enumerate(ranking, start=1):
            fused[doc_index] = fused.get(doc_index, 0.0) + 1.0 / (KAPPA_RRF + rank)

    # sorted() is stable, so equal scores stay in insertion order.
    ordered = sorted(fused, key=fused.get, reverse=True)
    return ordered[:k_final]


# ====================================================================
# PHASE 4: MAIN PIPELINE EXECUTION
# ====================================================================

def run_causal_verification_pipeline(
    candidate_set: Set[Tuple[int, int]],
    N_map: Dict[str, Any],
    sentences: List[str],
    relations: Optional[List[Dict[str, Any]]] = None,
) -> Dict[Tuple[int, int], Dict[str, Any]]:
    """
    Executes the full hypothesis verification and RAG-MMR pipeline for all candidate pairs.

    For each pair (i, j) the pipeline:
      1. asks the LLM for causal hypotheses linking node i to node j;
      2. retrieves the top-K_RAG sentence chunks (plus relation chunks when a
         relation index exists) per hypothesis and scores the hypothesis with
         `estimate_semantic_entropy`;
      3. keeps hypotheses with support > TAU_SUPPORT and entropy < TAU_ENTROPY;
      4. for each kept hypothesis, retrieves an adaptively sized sentence pool
         and reranks it with MMR;
      5. fuses the per-hypothesis rankings with RRF into at most 3 evidence
         sentences.

    Args:
        candidate_set: iterable of (i, j) index pairs; any pair containing 0 is
            skipped. Indices are looked up in the numerically sorted N_map keys,
            falling back to the label f"N{i}" when the index is missing.
        N_map: node id ("N1", "N2", ...) -> node metadata, used to render prompts.
        sentences: sentence chunks that form the sentence-level index.
        relations: optional relation records for the relation-level index.

    Returns:
        Dict mapping each processed (i, j) to
        {"verified_causal_hypothesis": [...], "evidence_text": str}.

    Notes:
        Search results are filtered to valid positions before use. The indexes
        are presumably FAISS-style and pad missing hits with -1 (TODO confirm);
        unfiltered, -1 would silently select the last sentence / last vector.
    """

    def valid_hits(raw_ids, upper: int) -> List[int]:
        """Return the ids in [0, upper) as plain ints, dropping padding/out-of-range hits."""
        return [int(k) for k in raw_ids if 0 <= k < upper]

    # --- Setup ---
    # Sort node IDs numerically (handles 'N0' if present)
    sorted_node_ids = sorted(
        N_map.keys(),
        key=lambda x: int(x[1:]) if x != "N0" else 0
    )

    # idx -> node_id str map
    idx_to_node_id = {idx: node_id for idx, node_id in enumerate(sorted_node_ids)}

    # 1. Prepare sentence-level document index
    idx_doc, VD = setup_document_index(sentences)

    # 2. Prepare relation-level index
    idx_rel, VR, rel_chunks = setup_relation_index(relations, N_map)

    final_output_map: Dict[Tuple[int, int], Dict[str, Any]] = {}

    # Precompute valid pairs (exclude ones with 0) and fix an order
    # NOTE: candidate_set is assumed to be 0-based indices aligned with sorted_node_ids
    candidate_pairs: List[Tuple[int, int]] = [
        (i, j) for (i, j) in sorted(candidate_set) if i != 0 and j != 0
    ]
    total_pairs = len(candidate_pairs)

    print(f"\nTotal valid pairs to process: {total_pairs}")

    def print_progress(current_idx: int):
        """Print simple progress bar on a single line."""
        if total_pairs == 0:
            return
        progress = current_idx / total_pairs
        bar_len = 30
        filled = int(bar_len * progress)
        bar = "#" * filled + "-" * (bar_len - filled)
        print(
            f"\rProgress {current_idx}/{total_pairs} [{bar}] {progress * 100:5.1f}%",
            end="",
            flush=True,
        )

    # --- Main Loop: Iterate through each plausible link (i, j) ---
    for idx_pair, (i, j) in enumerate(candidate_pairs, start=1):
        # Update progress bar
        print_progress(idx_pair)

        node_i_id = idx_to_node_id.get(i, f"N{i}")
        node_j_id = idx_to_node_id.get(j, f"N{j}")

        node_i_str = format_node_for_prompt(node_i_id, N_map)
        node_j_str = format_node_for_prompt(node_j_id, N_map)

        print(f"\n--- Processing Pair: {node_i_id} ({i}) -> {node_j_id} ({j}) ---")
        print(f"Cause: {node_i_str}")
        print(f"Effect: {node_j_str}")

        H_verified_list: List[str] = []
        all_ranked_lists_for_pair: List[List[int]] = []

        # --- 1. Hypothesis Generation ---
        hypotheses = llm_generate_hypotheses(node_i_str, node_j_str)
        if not hypotheses:
            print("Skipped: No hypotheses generated.")
            final_output_map[(i, j)] = {
                "verified_causal_hypothesis": [],
                "evidence_text": "No strong, stable evidence found.",
            }
            time.sleep(DELAY_SECONDS)
            continue

        # --- 2. Verification and Dual Filtering ---
        for hl in hypotheses:
            # 2a. Embed hypothesis once
            v_hl = embedder.encode([hl])[0].reshape(1, -1)

            # --- Sentence-based retrieval ---
            _, I_rag = idx_doc.search(v_hl, K_RAG)
            sentence_snippets = " ".join(
                sentences[k] for k in valid_hits(I_rag[0], len(sentences))
            )

            # --- Relation-based retrieval (if available) ---
            relation_snippets = ""
            if idx_rel is not None and rel_chunks is not None and VR is not None:
                _, I_rel = idx_rel.search(v_hl, K_RAG)
                relation_snippets = " ".join(
                    rel_chunks[k] for k in valid_hits(I_rel[0], len(rel_chunks))
                )

            # Combine both views into a single evidence text for verification
            if relation_snippets:
                combined_evidence = (
                    "SENTENCE EVIDENCE:\n" + sentence_snippets +
                    "\n\nRELATION EVIDENCE:\n" + relation_snippets
                )
            else:
                combined_evidence = sentence_snippets

            # 2b. Estimate support + semantic entropy
            p_support, h_semantic = estimate_semantic_entropy(hl, combined_evidence)

            # 2c. Dual Filtering: Check thresholds
            # Track the verdict for THIS hypothesis explicitly; a membership test on
            # H_verified_list would mislabel a rejected duplicate of a kept hypothesis.
            kept = p_support > TAU_SUPPORT and h_semantic < TAU_ENTROPY
            if kept:
                H_verified_list.append(hl)

                # --- 3. Adaptive Pooling and MMR (For verified hypotheses only) ---
                score_h = p_support  # Use p_support as the score for adaptive pooling

                # Adaptive Pool Size: kpool = kbase + (1 - score) * kexpansion
                k_pool = int(K_BASE + (1.0 - score_h) * K_EXPANSION)
                k_pool = max(K_RAG, k_pool)  # Ensure pool is at least k_RAG

                # Retrieve larger pool from sentence index (for final evidence selection)
                _, I_pool = idx_doc.search(v_hl, k_pool)
                pool_ids = valid_hits(I_pool[0], len(sentences))

                if pool_ids:
                    pool_vectors = VD[pool_ids]

                    # MMR Reranking
                    final_indices_h = mmr_rerank(v_hl, pool_vectors, k_rag=K_RAG)

                    # Store indices (mapped back to global sentence IDs)
                    global_indices_h = [pool_ids[idx] for idx in final_indices_h]
                    if global_indices_h:
                        all_ranked_lists_for_pair.append(global_indices_h)

            print(
                f"  Claim Verified: {p_support:.2f}/{TAU_SUPPORT:.2f} (Support) | "
                f"{h_semantic:.2f}/{TAU_ENTROPY:.2f} (Entropy) -> "
                f"{'KEPT' if kept else 'REJECTED'}"
            )

        # --- 4. Final Rank Fusion (RRF) over sentence indices ---
        if all_ranked_lists_for_pair:
            final_indices_total = reciprocal_rank_fusion(all_ranked_lists_for_pair, k_final=3)
            evidence_text = "\n".join([sentences[idx] for idx in final_indices_total])
        else:
            evidence_text = "No strong, stable evidence found."

        # --- 5. Final Output Format ---
        final_output_map[(i, j)] = {
            "verified_causal_hypothesis": H_verified_list,
            "evidence_text": evidence_text,
        }

        # One delay per pair to avoid hammering the API
        time.sleep(DELAY_SECONDS)

    # Finish progress bar line cleanly
    if total_pairs > 0:
        print_progress(total_pairs)
        print()  # newline

    return final_output_map


In [None]:
# Run the full verification pipeline over every candidate pair.
# This is the expensive step: it calls the LLM and sleeps between pairs.
# E_prior is the K1 list loaded earlier as (int, np.int64) tuples; the
# pipeline only sorts and iterates it, so a list works despite the Set hint.
final_causal_map = run_causal_verification_pipeline(
    candidate_set=E_prior,          # still 0-indexed ints
    N_map=entities_by_id,      # dict: "N1" -> rich node dict or name string
    sentences=sentences_list,  # sentence chunks
    relations=relations,       # full relations list (optional but recommended)
)

# Alias used by the later cells (same dict object, not a copy).
pairwise_structural_context_retrieval = final_causal_map

## Save the above

import os
import pickle

def save_final_causal_map(final_causal_map,
                          base_path="extracted_output/cocad/w_direct"):
    """
    Persist final_causal_map to <base_path>/final_causal_map.pkl.

    The target directory is created when missing and an existing file is
    overwritten. Nothing is returned; the written path is printed.
    """
    os.makedirs(base_path, exist_ok=True)
    pkl_file = os.path.join(base_path, "final_causal_map.pkl")

    serialized = pickle.dumps(final_causal_map)
    with open(pkl_file, "wb") as handle:
        handle.write(serialized)

    print(f"[SAVE] final_causal_map saved → {pkl_file}")

# Save map
save_final_causal_map(pairwise_structural_context_retrieval)

In [None]:
## Load the above 

def load_final_causal_map(base_path="extracted_output/cocad/w_direct"):
    """
    Read <base_path>/final_causal_map.pkl and return the unpickled object.

    Raises:
        FileNotFoundError: when the pickle file does not exist.
    """
    pkl_file = os.path.join(base_path, "final_causal_map.pkl")

    if os.path.exists(pkl_file):
        # Unpickling can execute arbitrary code: only load files this
        # notebook wrote itself.
        with open(pkl_file, "rb") as handle:
            restored = pickle.load(handle)
        print(f"[LOAD] final_causal_map loaded ← {pkl_file}")
        return restored

    raise FileNotFoundError(f"No file found at: {pkl_file}")

pairwise_structural_context_retrieval = load_final_causal_map()

In [None]:
pprint(pairwise_structural_context_retrieval)

In [None]:
# Spot-check a single pair: print the first entry whose key equals `target`.
target = (10, 2)

for k, v in pairwise_structural_context_retrieval.items():
    # Convert both elements of the key to int for safe comparison
    # (keys may carry numpy integer scalars rather than plain ints).
    if int(k[0]) == target[0] and int(k[1]) == target[1]:
        print("Key:", k)
        print("Value:", v)
        break


##### evidence_text vs verified_causal_hypothesis

- evidence_text = The actual text snippets retrieved from the document that support that causal hypothesis.

- verified_causal_hypothesis = A sentence generated by the LLM saying what causal relationship might hold between node i and node j — if the model believes it is plausible AND well-supported.

#### B. Counterfactual Reasoning (CoT Prompt)

In [None]:
import time
from typing import Dict, Tuple, Any, List, Optional

from google.genai import types
from google.genai import errors as genai_errors

# Uses your global (must already be defined by an earlier cell):
#   client = Client(api_key=os.environ["GEMINI_API_KEY"])
#   PRIMARY_MODEL = "gemini-2.5-flash-lite"

# ------------------------------
# HYPERPARAMETERS
# ------------------------------
NS_SAMPLES_CF = 5              # Ns counterfactual samples drawn per pair
CF_TEMPERATURE = 0.7           # sampling temperature for the counterfactual prompts
CF_DELAY_SECONDS = 5         # Avoid rate-limit hammering (sleep between calls)

# ------------------------------
# MODEL ROTATION (same pool as earlier)
# ------------------------------
# Ordered fallback pool; _switch_to_next_model() walks it cyclically.
MODEL_CANDIDATES = [
    PRIMARY_MODEL,             # "gemini-2.5-flash-lite"
    "gemini-2.5-flash",
    "gemini-2.0-flash-lite",
]

# Mutable module state shared by _switch_to_next_model() and safe_llm_text().
# Re-running this cell resets the rotation and clears the exhausted set.
MODEL: str = MODEL_CANDIDATES[0]
CURRENT_MODEL_INDEX: int = 0
EXHAUSTED_MODELS = set()


def _switch_to_next_model():
    """
    Retire the active MODEL and rotate to the next non-exhausted candidate.

    The search starts right after CURRENT_MODEL_INDEX and wraps around
    MODEL_CANDIDATES once. On success MODEL and CURRENT_MODEL_INDEX are
    updated in place.

    Raises:
        RuntimeError: when every candidate is already in EXHAUSTED_MODELS.
    """
    global MODEL, CURRENT_MODEL_INDEX

    EXHAUSTED_MODELS.add(MODEL)

    pool_size = len(MODEL_CANDIDATES)
    rotation = (
        (CURRENT_MODEL_INDEX + step) % pool_size
        for step in range(1, pool_size + 1)
    )
    next_slot = next(
        (slot for slot in rotation if MODEL_CANDIDATES[slot] not in EXHAUSTED_MODELS),
        None,
    )

    if next_slot is None:
        raise RuntimeError("All configured models appear exhausted/quota-limited for today (counterfactual).")

    MODEL = MODEL_CANDIDATES[next_slot]
    CURRENT_MODEL_INDEX = next_slot
    print(f"[LLM] (CF) Switching to backup model: {MODEL}")


def safe_llm_text(
    prompt: str,
    temperature: float = CF_TEMPERATURE,
    max_retries: int = 3,
) -> str:
    """
    Generate plain text for `prompt`, rotating models on quota / rate-limit errors.

    Each attempt calls the currently active MODEL. A ClientError whose message
    mentions RESOURCE_EXHAUSTED, 429 or an exceeded quota retires that model,
    switches to the next candidate, sleeps CF_DELAY_SECONDS and retries. Any
    other ClientError is re-raised immediately.

    Returns:
        The response text, or "" when the response has no (or empty) text.

    Raises:
        RuntimeError: when no backup model is left to switch to.
        genai_errors.ClientError: for non-quota client errors, or the last
            quota error once max_retries attempts are used up.
    """
    quota_failure: Optional[Exception] = None
    attempt = 0

    while attempt < max_retries:
        attempt += 1
        active_model = MODEL
        # Plain-text generation: no response_schema is configured.
        gen_config = types.GenerateContentConfig(
            temperature=temperature,
        )

        try:
            print(f"[LLM] (CF) Call attempt {attempt} with {active_model}")
            response = client.models.generate_content(
                model=active_model,
                contents=prompt,
                config=gen_config,
            )
            return getattr(response, "text", "") or ""
        except genai_errors.ClientError as err:
            detail = str(err)
            quota_hit = (
                "RESOURCE_EXHAUSTED" in detail
                or "429" in detail
                or "exceeded your current quota" in detail.lower()
            )
            if not quota_hit:
                raise

            print(f"[LLM] (CF) Model {active_model} hit quota / rate limit: {detail}")
            quota_failure = err
            try:
                _switch_to_next_model()
            except RuntimeError as switch_err:
                print("[LLM] (CF) No backup models left.")
                raise switch_err
            time.sleep(CF_DELAY_SECONDS)

    if quota_failure:
        raise quota_failure
    raise RuntimeError("safe_llm_text (CF) failed unexpectedly.")


# =========================================================
# 1. PARSER FOR LLM OUTPUT
# =========================================================
def _parse_reasoning_and_score(raw_text: str) -> Tuple[str, float]:
    """
    Parses the LLM output into (reasoning_str, score_float).

    Expected soft pattern:
        <reasoning text...>
        SCORE: <float>

    The last line that starts with "SCORE" (case-insensitive, after stripping
    leading markdown decoration such as "**", "#", "-" or ">") is the score
    line. The first number after the ":" / "=" separator is used (or the first
    number on the line when there is no separator), so "**SCORE:** 0.8",
    "SCORE: 0.7 (strong)" and "SCORE: 0.8/1.0" all parse.

    Args:
        raw_text: raw model output; None is tolerated.

    Returns:
        (reasoning, score) where reasoning is the text before the score line
        (the whole text when no score line exists) and score is clamped to
        [0.0, 1.0]. If parsing fails, returns score = 0.0.
    """
    # Local import keeps this helper self-contained regardless of cell order.
    import re

    if raw_text is None:
        return "", 0.0

    text = raw_text.strip()
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]

    reasoning = text
    score = 0.0

    number_re = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")

    for ln in reversed(lines):
        # Models often decorate the line ("**SCORE:** 0.8", "- SCORE: 0.8").
        bare = ln.lstrip("*_#>`- \t")
        if not bare.upper().startswith("SCORE"):
            continue

        remainder = bare[len("SCORE"):]
        separator = re.search(r"[:=]", remainder)
        if separator:
            # Skip anything between the label and its separator, e.g. "SCORE (0-1): 0.7".
            remainder = remainder[separator.end():]

        match = number_re.search(remainder)
        if match:
            # The prompt asks for a value in [0, 1]; clamp so one malformed
            # sample cannot distort the ensemble mean/variance.
            score = min(1.0, max(0.0, float(match.group(0))))

        # Reasoning = text before SCORE line
        idx = text.rfind(ln)
        if idx != -1:
            reasoning = text[:idx].strip()
        break

    return reasoning, score


# =========================================================
# 2. COUNTERFACTUAL REASONING
# =========================================================
def counterfactual_reasoning_for_pairs(
    pairs: List[Tuple[int, int]],
    context_map: Dict[Tuple[int, int], str],   # evidence_text from earlier stage
    ns_samples: int = NS_SAMPLES_CF,
    temperature: float = CF_TEMPERATURE,
    delay_seconds: int = CF_DELAY_SECONDS,
) -> Dict[Tuple[int, int], Dict[str, Any]]:
    """
    Sample counterfactual LLM judgements for every (i, j) pair.

    For each pair (i, j), retrieve context_map[(i, j)] = evidence_text
    from the earlier RAG verification stage (empty string when missing).

    Then run Ns LLM samples asking:
        "If i were NOT present, what is the effect on j?"

    Args:
        pairs: (i, j) pairs to evaluate; each is normalised to (int, int).
        context_map: (i, j) -> evidence text used as the prompt context.
        ns_samples: number of LLM samples drawn per pair.
        temperature: sampling temperature forwarded to safe_llm_text.
        delay_seconds: sleep after every sample, successful or not.

    Returns:
      {
        (i, j): {
           "context": <evidence_text>,
           "reasoning_strings": [...],
           "causal_strength_scores": [...]
        }
      }
      Both lists always hold ns_samples entries: a failed call contributes
      "" and 0.0 instead of being dropped.
    """
    results: Dict[Tuple[int, int], Dict[str, Any]] = {}

    total_pairs = len(pairs)
    if total_pairs == 0:
        print("No pairs provided for counterfactual reasoning.")
        return results

    def print_progress(current_idx: int):
        """Simple pair-level progress bar."""
        progress = current_idx / total_pairs
        bar_len = 30
        filled = int(bar_len * progress)
        bar = "#" * filled + "-" * (bar_len - filled)
        print(
            f"\r[CF] Progress {current_idx}/{total_pairs} [{bar}] {progress * 100:5.1f}%",
            end="",
            flush=True,
        )

    print(f"Starting counterfactual reasoning for {total_pairs} pairs...")

    for pair_idx, (i, j) in enumerate(pairs, start=1):
        # Normalise to plain ints so result keys do not carry numpy scalars.
        key = (int(i), int(j))

        context_txt = context_map.get(key, "")
        print(f"\n\n=== Counterfactual for (i={i}, j={j}) ===")
        print(f"Context length: {len(context_txt)} chars")

        reasoning_list: List[str] = []
        scores_list: List[float] = []

        # --------------------------------------------------
        # Build prompt with evidence_text as context
        # --------------------------------------------------
        # NOTE(review): the prompt refers to the nodes only as "i" and "j";
        # their names/descriptions are never inserted, so the model has just
        # the context text to infer what i and j are. Confirm this is intended.
        base_prompt = f"""
You are a precise causal reasoning assistant.

Context (evidence supporting the causal relation i -> j):
\"\"\"{context_txt}\"\"\"

We are analyzing the counterfactual influence of i on j.

1. Reasoning:
   If i were NOT present, what would happen to j?
   - Use only the information in the context.
   - Think step-by-step.
   - Explicitly state whether the causal impact is strong, weak, or negligible.

2. After your explanation, on a new line write:
   SCORE: <a single number between 0.0 and 1.0>

Where:
- SCORE ≈ 0.0 → almost no causal effect of i on j.
- SCORE ≈ 1.0 → strong causal effect of i on j.

Return only the explanation + the SCORE line.
""".strip()

        # ======================================================
        #  Run Ns LLM samples
        # ======================================================
        for s_idx in range(ns_samples):
            print(f"  -> Sample {s_idx+1}/{ns_samples} ...")

            try:
                raw_text = safe_llm_text(
                    prompt=base_prompt,
                    temperature=temperature,
                )

                reasoning_s, score_s = _parse_reasoning_and_score(raw_text)

                reasoning_list.append(reasoning_s)
                scores_list.append(score_s)

                print(f"     SCORE parsed = {score_s:.3f}")

            except Exception as e:
                # Best-effort: keep going, but record a neutral placeholder.
                # NOTE(review): the 0.0 / "" placeholders are later averaged
                # and embedded like real samples, which biases the metrics.
                print(f"     [Warning] LLM call failed: {e}")
                reasoning_list.append("")
                scores_list.append(0.0)

            # Respect rate-limits
            time.sleep(delay_seconds)

        # ======================================================
        # Store results
        # ======================================================
        results[key] = {
            "context": context_txt,
            "reasoning_strings": reasoning_list,
            "causal_strength_scores": scores_list,
        }

        # Update pair-level progress bar
        print_progress(pair_idx)

    # Final newline for clean console
    print()
    return results


In [None]:
# Every pair that went through the verification stage gets a counterfactual run.
pairs = list(pairwise_structural_context_retrieval.keys())

# (i, j) -> evidence_text, used as the prompt context for that pair.
context_map = {
    k: pairwise_structural_context_retrieval[k]["evidence_text"]
    for k in pairwise_structural_context_retrieval
}

# Expensive: ns_samples LLM calls per pair, each followed by a sleep.
cot_results_counterfactual = counterfactual_reasoning_for_pairs(pairs, context_map)

# Save the above:
import os, pickle

save_path = "extracted_output/cocad/w_direct/cot_results_counterfactual.pkl"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

with open(save_path, "wb") as f:
    pickle.dump(cot_results_counterfactual, f)

print("Saved to:", save_path)


In [None]:
## Load the above
# Restores the counterfactual results without re-running the LLM stage.
# Unpickling can execute arbitrary code: only load files this notebook wrote.
import pickle

load_path = "extracted_output/cocad/w_direct/cot_results_counterfactual.pkl"

with open(load_path, "rb") as f:
    cot_results_counterfactual = pickle.load(f)

print("Loaded.")


In [None]:
pprint(cot_results_counterfactual)

In [None]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

# Load embedder for semantic coherence
coh_embedder = SentenceTransformer("all-MiniLM-L6-v2")


def compute_llm_metrics(cot_results_counterfactual):
    """
    Summarise the counterfactual samples of every pair into three metrics.

    Args:
        cot_results_counterfactual: dict mapping a pair to a dict holding
            "causal_strength_scores" (list of floats) and
            "reasoning_strings" (list of strings).

    Returns:
        Dict mapping each pair to:
          - llm_mean_ensemble_score: mean of the scores,
          - llm_variance_ensemble_score: population variance of the scores,
          - llm_semantic_coherence_score (Rcoh): mean pairwise cosine
            similarity between the reasoning embeddings (upper triangle only).
        All values are rounded to 6 decimals. A pair with no scores gets
        0.0 mean/variance, and a pair with fewer than two reasoning strings
        gets 0.0 coherence (there is no pair of samples to compare).
    """
    w_direct_llm_results = {}

    for pair, data in cot_results_counterfactual.items():
        scores = list(data["causal_strength_scores"])        # list of floats
        reasons = list(data["reasoning_strings"])            # list of strings

        # ------------ Mean (µLLM) / Variance (σ²LLM) ------------
        # np.mean / np.var of an empty list would yield nan plus a warning.
        if scores:
            mean_score = float(np.mean(scores))
            variance_score = float(np.var(scores))
        else:
            mean_score = 0.0
            variance_score = 0.0

        # ------------ Semantic Coherence (Rcoh) ------------
        n_reasons = len(reasons)
        if n_reasons >= 2:
            # Embed each reasoning string
            embeddings = coh_embedder.encode(
                reasons, convert_to_numpy=True, normalize_embeddings=True
            )

            # Compute pairwise cosine similarity matrix
            sim_matrix = cosine_similarity(embeddings)

            # Strict upper triangle (p < q): every unordered pair exactly once.
            upper_tri_vals = sim_matrix[np.triu_indices(n_reasons, k=1)]
            semantic_coherence = float(np.mean(upper_tri_vals))
        else:
            # Skip the encoder entirely when no pair of samples exists.
            semantic_coherence = 0.0

        # ------------ Store ------------
        w_direct_llm_results[pair] = {
            "llm_mean_ensemble_score": round(mean_score, 6),
            "llm_variance_ensemble_score": round(variance_score, 6),
            "llm_semantic_coherence_score": round(semantic_coherence, 6),
        }

    return w_direct_llm_results


In [None]:
w_direct_llm_results = compute_llm_metrics(cot_results_counterfactual)


In [None]:
## Save the above

import os
import pickle

# Ensure output directory exists
# (here `save_path` is a directory; the pickle path is `file_path`)
save_path = "extracted_output/cocad/w_direct"
os.makedirs(save_path, exist_ok=True)

file_path = os.path.join(save_path, "w_direct_llm_results.pkl")

# Overwrites any previous result file.
with open(file_path, "wb") as f:
    pickle.dump(w_direct_llm_results, f)

print(f"Saved w_direct_llm_results → {file_path}")


In [None]:
# Load the above

import pickle
import os

file_path = "extracted_output/cocad/w_direct/w_direct_llm_results.pkl"

# Unpickling can execute arbitrary code: only load files this notebook wrote.
# NOTE: this binds `w_direct_llm_results_loaded`; the later cells read
# `w_direct_llm_results`, which this cell does not set.
with open(file_path, "rb") as f:
    w_direct_llm_results_loaded = pickle.load(f)

print("Loaded w_direct_llm_results:")
print(type(w_direct_llm_results_loaded))
print(len(w_direct_llm_results_loaded))


In [None]:
pprint(w_direct_llm_results)

In [None]:
import numpy as np
# NOTE: this rebinds `pprint` from the function (imported earlier via
# `from pprint import pprint`) to the module. After this cell runs, earlier
# cells that call bare `pprint(...)` fail if re-run; later cells use
# `pprint.pprint(...)`.
import pprint

# Extract each metric across all pairs
mean_scores = [v["llm_mean_ensemble_score"] for v in w_direct_llm_results.values()]
coherence_scores = [v["llm_semantic_coherence_score"] for v in w_direct_llm_results.values()]
variance_scores = [v["llm_variance_ensemble_score"] for v in w_direct_llm_results.values()]

# Distribution summary used to choose the tau_score / tau_coh thresholds.
results = {
    # Averages and variances (population variance, ddof=0)
    "mean_score_avg": float(np.mean(mean_scores)),
    "mean_score_var": float(np.var(mean_scores, ddof=0)),

    "coherence_score_avg": float(np.mean(coherence_scores)),
    "coherence_score_var": float(np.var(coherence_scores, ddof=0)),

    "variance_score_avg": float(np.mean(variance_scores)),
    "variance_score_var": float(np.var(variance_scores, ddof=0)),

    # Medians
    "mean_score_median": float(np.median(mean_scores)),
    "coherence_score_median": float(np.median(coherence_scores)),
    "variance_score_median": float(np.median(variance_scores)),

    # 60th percentile cutoffs (score required to be in the top 40%)
    "mean_score_top40_cutoff": float(np.percentile(mean_scores, 60)),
    "coherence_score_top40_cutoff": float(np.percentile(coherence_scores, 60)),
    "variance_score_top40_cutoff": float(np.percentile(variance_scores, 60)),
}

pprint.pprint(results)


### (iii).  $f_{CI}$ and $f_{llmconfounder}$

#### A. Finding our confounders ($A_{set-confounded}$)

To find a set of *potential* confounders, we need a reasonably reliable asymmetric (directed) signal. At the moment the only such signal we have is the counterfactual reasoning from the previous step ($f_{llm}$).

In [None]:
%pip install dateparser

In [None]:
import re
from typing import Dict, Tuple, List, Set, Any

import numpy as np
import dateparser


# ---------------------------------------------------------------------
# 1. Time parsing utilities using "dateparser"
# ---------------------------------------------------------------------

def normalize_time_to_ordinal(time_str: str) -> int | None:
    
    if not time_str or not time_str.strip():
        return None

    text = time_str.strip()

    text = re.sub(
        r"\b(around|about|approximately|circa|around the|early|late|mid|beginning of|end of)\b",
        "",
        text,
        flags=re.IGNORECASE,
    )

    text = text.strip(" ,.;")

    if not text:
        return None

    dt = dateparser.parse(
        text,
        settings={
            "PREFER_DATES_FROM": "past",
            "RETURN_AS_TIMEZONE_AWARE": False,
        },
    )
    if dt is None:
        return None

    return dt.date().toordinal()



def build_node_time_ordinals(entities: List[Dict[str, Any]]) -> Dict[int, int | None]:
    """
    From the rich `entities` list, build:

        node_time_ordinals[idx] = ordinal or None

    where idx is the numeric node index (1-based, from 'N1', 'N2', ...).
    """
    node_time_ordinals: Dict[int, int | None] = {}

    for ent in entities:
        node_id = ent.get("id", "")
        if not node_id or not node_id.startswith("N"):
            continue

        try:
            idx = int(node_id[1:])
        except ValueError:
            continue

        time_str = ent.get("time")
        ordinal = normalize_time_to_ordinal(time_str) if time_str else None
        node_time_ordinals[idx] = ordinal

    return node_time_ordinals


# ---------------------------------------------------------------------
# 2. Build A_quick_causal from LLM scores + temporal gating
# ---------------------------------------------------------------------

def build_quick_causal_adjacency(
    entities: List[Dict[str, Any]],
    w_direct_llm_results: Dict[Tuple[int, int], Dict[str, float]],
    tau_score: float = 0.7,
    tau_coh: float = 0.9,
) -> np.ndarray:
    """
    Build the directed, asymmetric adjacency matrix A_quick_causal (N+1 x N+1)
    using:

      - μ_LLM(i, j) = w_direct_llm_results[(i,j)]["llm_mean_ensemble_score"]
      - R_coh(i, j) = w_direct_llm_results[(i,j)]["llm_semantic_coherence_score"]

    An edge i -> j is kept only when μ_LLM > tau_score AND R_coh > tau_coh.

    AND enforce a temporal constraint:
      - if both i and j have parseable times and time(i) > time(j),
        then we forbid edge i -> j (cannot cause the past).

    N is the largest numeric index among well-formed "N<int>" entity ids.
    Malformed ids are ignored, as in build_node_time_ordinals, and an empty
    entity list yields a 1 x 1 zero matrix. Pairs whose indices fall outside
    1..N are skipped with a printed warning instead of raising IndexError or
    wrapping around through negative indexing.

    NOTE(review): pair indices are treated here as 1-based node numbers
    ("N3" -> 3). Confirm the keys of w_direct_llm_results use that base.

    Returns:
        A_quick_causal: np.ndarray of shape (N+1, N+1), dtype=int8
        (index 0 is unused; nodes are 1..N)
    """
    # Precompute node time ordinals; its keys are exactly the well-formed
    # numeric node indices, so the matrix size is derived from them.
    node_time_ordinals = build_node_time_ordinals(entities)

    num_nodes = max(max(node_time_ordinals, default=0), 0)
    A_quick_causal = np.zeros((num_nodes + 1, num_nodes + 1), dtype=np.int8)

    skipped_out_of_range = 0

    for (i, j), scores in w_direct_llm_results.items():
        mu = float(scores.get("llm_mean_ensemble_score", 0.0))
        r_coh = float(scores.get("llm_semantic_coherence_score", 0.0))

        # Threshold on score & coherence
        if mu <= tau_score or r_coh <= tau_coh:
            continue  # do not include this edge at all

        # Plain ints keep indexing and dict lookups independent of numpy scalars.
        i, j = int(i), int(j)
        if not (1 <= i <= num_nodes and 1 <= j <= num_nodes):
            skipped_out_of_range += 1
            continue

        # Temporal gating (if we have both times)
        t_i = node_time_ordinals.get(i)
        t_j = node_time_ordinals.get(j)

        if t_i is not None and t_j is not None:
            # If i's time is strictly AFTER j's time, i -> j is forbidden.
            if t_i > t_j:
                # e.g., event on July 15 should not cause event on July 2
                continue

        # Passed all filters → keep directed edge i -> j
        A_quick_causal[i, j] = 1

    if skipped_out_of_range:
        print(
            f"[A_quick_causal] Skipped {skipped_out_of_range} pair(s) "
            f"with node indices outside 1..{num_nodes}."
        )

    return A_quick_causal


# ---------------------------------------------------------------------
# 3. Confounder search via BFS on A_quick_causal^T
# ---------------------------------------------------------------------

def build_transpose_adj_list(A_quick_causal: np.ndarray) -> Dict[int, List[int]]:
    """
    Turn the adjacency matrix into parent lists (the transpose graph).

    Every entry A_quick_causal[u, v] == 1 (edge u -> v) contributes u to the
    list stored under v. Row 0 and column 0 are ignored.

    Returns:
        Dict with one key per node 1..N; each value lists that node's
        parents in ascending order.
    """
    node_count = A_quick_causal.shape[0] - 1  # index 0 is not a node
    parents_of: Dict[int, List[int]] = {node: [] for node in range(1, node_count + 1)}

    # Scan rows 1..N in one pass; np.nonzero reports hits in row-major order,
    # so parents are appended in ascending order of the source node.
    src_offsets, dst_nodes = np.nonzero(A_quick_causal[1:node_count + 1] == 1)
    for src_offset, dst in zip(src_offsets.tolist(), dst_nodes.tolist()):
        if dst == 0:
            continue
        parents_of[dst].append(src_offset + 1)

    return parents_of


def bfs_ancestors(start: int, adj_T: Dict[int, List[int]], max_hops: int) -> Set[int]:
    """
    Collect the ancestors of `start` reachable within `max_hops` levels.

    The search walks `adj_T` (node -> parents) breadth-first, one level per
    hop, so it finds every k with a path k -> ... -> start of at most
    `max_hops` edges in the original graph.

    Returns:
        Set of ancestor indices; `start` itself is never included.
    """
    seen: Set[int] = {start}
    reached: Set[int] = set()
    current_level: List[int] = [start]
    hops_left = max_hops

    while current_level and hops_left > 0:
        upcoming: List[int] = []
        for child in current_level:
            for candidate in adj_T.get(child, []):
                if candidate in seen:
                    continue
                seen.add(candidate)
                reached.add(candidate)
                upcoming.append(candidate)
        current_level = upcoming
        hops_left -= 1

    return reached


def compute_potential_confounders(
    A_quick_causal: np.ndarray,
    max_hops: int = 3,
) -> Dict[Tuple[int, int], List[int]]:
    """
    For every directed edge (i, j) in A_quick_causal (i.e., A[i,j] == 1),
    compute the set of potential confounders:

        M_set_confound(i, j) = Ancestors_i ∩ Ancestors_j

    where ancestors are computed on the transpose graph up to 'max_hops'.
    i and j themselves are always excluded from the result.

    Ancestor sets are computed once per node and reused across edges, since
    they depend only on the node and max_hops, not on the edge.

    Args:
        A_quick_causal: (N+1, N+1) adjacency matrix; row/column 0 are ignored.
        max_hops: maximum BFS depth for the ancestor search.

    Returns:
        M_set_potential_confounder: dict mapping
            (i, j) -> sorted list of node indices k.
        Keys are plain Python ints (not numpy scalars).
    """
    num_nodes = A_quick_causal.shape[0] - 1
    adj_T = build_transpose_adj_list(A_quick_causal)

    ancestor_cache: Dict[int, Set[int]] = {}

    def ancestors_of(node: int) -> Set[int]:
        """Return the cached ancestor set of `node`, running BFS on first use."""
        if node not in ancestor_cache:
            ancestor_cache[node] = bfs_ancestors(node, adj_T, max_hops=max_hops)
        return ancestor_cache[node]

    M_set_potential_confounder: Dict[Tuple[int, int], List[int]] = {}

    # Iterate over all directed edges in A_quick_causal
    for i in range(1, num_nodes + 1):
        targets = np.where(A_quick_causal[i] == 1)[0]  # all j where i -> j
        for j in targets:
            j = int(j)  # np.where yields numpy integers; keep keys as plain ints
            if j == 0:
                continue

            confounders = (ancestors_of(i) & ancestors_of(j)) - {i, j}
            M_set_potential_confounder[(i, j)] = sorted(confounders)

    return M_set_potential_confounder


# ---------------------------------------------------------------------
# 4. High-level helper tying everything together
# ---------------------------------------------------------------------

def build_quick_causal_and_confounders(
    entities: List[Dict[str, Any]],
    w_direct_llm_results: Dict[Tuple[int, int], Dict[str, float]],
    tau_score: float = 0.7,
    tau_coh: float = 0.9,
    max_hops: int = 3,
) -> Tuple[np.ndarray, Dict[Tuple[int, int], List[int]]]:
    """
    Convenience wrapper: build the quick causal graph, then find confounders.

    Step 1 delegates to build_quick_causal_adjacency (LLM score threshold
    `tau_score`, coherence threshold `tau_coh`, temporal gating).
    Step 2 delegates to compute_potential_confounders, which searches the
    transpose graph up to `max_hops` hops for each edge.

    Returns:
        (A_quick_causal, M_set_potential_confounder)
    """
    adjacency_kwargs = {
        "entities": entities,
        "w_direct_llm_results": w_direct_llm_results,
        "tau_score": tau_score,
        "tau_coh": tau_coh,
    }
    quick_graph = build_quick_causal_adjacency(**adjacency_kwargs)

    confounders_by_edge = compute_potential_confounders(quick_graph, max_hops=max_hops)

    return quick_graph, confounders_by_edge


# ---------------------------------------------------------------------
# 5. Example usage (you would plug in your real data here)
# ---------------------------------------------------------------------
# The guarded body is only `pass`; the example below is kept purely as a
# commented-out reference for how to call the entry point.
if __name__ == "__main__":
    # Example: plug your `entities` list and `w_direct_llm_results` here.
    # entities = [...]  # from entities.json
    # w_direct_llm_results = {...}  # built from your CoT counterfactual stage

    # A_quick_causal, M_conf = build_quick_causal_and_confounders(
    #     entities,
    #     w_direct_llm_results,
    #     tau_score=0.7,
    #     tau_coh=0.9,
    #     max_hops=3,
    # )
    #
    # print("A_quick_causal shape:", A_quick_causal.shape)
    # print("Potential confounders per edge:")
    # for (i, j), ks in M_conf.items():
    #     print(f"({i}, {j}): {ks}")
    pass

# Parameter reference for build_quick_causal_and_confounders:
#     entities: List[Dict[str, Any]],
#     w_direct_llm_results: Dict[Tuple[int, int], Dict[str, float]],
#     tau_score: float = 0.7,
#     tau_coh: float = 0.9,
#     max_hops: int = 3,


- `A_quick_causal`: the final directed (asymmetric) causal adjacency matrix.

  It is built from the LLM counterfactual results (μ_LLM and R_coh), with a temporal gate that drops an edge i → j when i is dated after j.

  It is only used for ancestor traversal during confounder detection.

- `M_set_potential_confounder`: for every directed candidate edge i → j, the set of nodes k that are ancestors of both i and j (within `max_hops` hops) in `A_quick_causal`.

  This set is different for each edge — it is completely pair-specific.

In [None]:
# num_nodes: e.g. from your entities list
# NOTE: this value is not used before num_nodes is reassigned from the
# matrix shape further down; `nodes` must come from an earlier cell.
num_nodes = max(int(ent["id"][1:]) for ent in nodes)

# Thresholds here differ from the function defaults (0.7 / 0.9).
# NOTE(review): presumably tuned from the summary statistics printed
# earlier — confirm.
A_quick_causal, M_set_potential_confounder = build_quick_causal_and_confounders(
    nodes,
    w_direct_llm_results,
    tau_score=0.57, # LLM score
    tau_coh=0.89, # Coh score
    max_hops=3,
)

# Example:
print("Directed edges in A_quick_causal:")
num_nodes = A_quick_causal.shape[0] - 1  # ignore index 0
for i in range(1, num_nodes + 1):
    for j in range(1, num_nodes + 1):
        if A_quick_causal[i, j] == 1:
            print(f"{i} -> {j}")

print("\nPotential confounders per edge:")
for (i, j), confs in M_set_potential_confounder.items():
    print(f"({i}, {j}): {confs}")


In [None]:
## Save the above
## Persists A_quick_causal and M_set_potential_confounder as pickle files so
## that later sessions can reload them without recomputing.

import os
import pickle

save_dir = "extracted_output/cocad/w_direct"
os.makedirs(save_dir, exist_ok=True)  # no error if the directory already exists

# Save adjacency matrix
with open(os.path.join(save_dir, "A_quick_causal.pkl"), "wb") as f:
    pickle.dump(A_quick_causal, f)

# Save confounder map
with open(os.path.join(save_dir, "M_set_potential_confounder.pkl"), "wb") as f:
    pickle.dump(M_set_potential_confounder, f)

print("Saved A_quick_causal and M_set_potential_confounder into:", save_dir)


In [None]:
## Load the above
## Restores A_quick_causal and M_set_potential_confounder from the pickle
## files written by the save cell, rebinding both names in the kernel.
## NOTE(review): pickle.load can execute arbitrary code; only load files that
## this notebook produced itself.

import pickle
import os

load_dir = "extracted_output/cocad/w_direct"

with open(os.path.join(load_dir, "A_quick_causal.pkl"), "rb") as f:
    A_quick_causal = pickle.load(f)

with open(os.path.join(load_dir, "M_set_potential_confounder.pkl"), "rb") as f:
    M_set_potential_confounder = pickle.load(f)

print("Loaded A_quick_causal:", type(A_quick_causal), A_quick_causal.shape)
print("Example edges:")
# Index 0 is skipped, matching the 1-based node ids used when printing edges.
num_nodes = A_quick_causal.shape[0] - 1
for i in range(1, num_nodes + 1):
    for j in range(1, num_nodes + 1):
        if A_quick_causal[i, j] == 1:
            print(f"{i} -> {j}")

print("\nLoaded confounder map size:", len(M_set_potential_confounder))


In [None]:
type(M_set_potential_confounder)

In [None]:
# Other cells in this notebook run `from pprint import pprint`, which binds the
# bare name `pprint` to the function; `pprint.pprint(...)` then fails with
# AttributeError. Importing the module under an alias keeps this cell working
# regardless of which cells ran before it, without rebinding `pprint`.
import pprint as pprint_module

# Largest integer found after the first character of each node "id".
num_nodes = max(int(ent["id"][1:]) for ent in nodes)

pprint_module.pprint(num_nodes)

In [None]:
# Other cells in this notebook run `from pprint import pprint`, which binds the
# bare name `pprint` to the function; `pprint.pprint(...)` then fails with
# AttributeError. Importing the module under an alias keeps this cell working
# regardless of which cells ran before it, without rebinding `pprint`.
import pprint as pprint_module

pprint_module.pprint(entities_by_id)

#### B. Dual Check on confounders

Now we have (i,j) pairs which possibly have a potential confounder, so we just need to evaluate for (i,j) pairs in the A_quick_causal.

We want a mathematically principled way to evaluate whether the link $i \rightarrow j$ is still meaningful after removing the influence of all potential confounders $M_{set-confound}$(i,j).

For a given pair (i,j), we condition on all of its potential confounders at once and then compute the score of the direct link $i \rightarrow j$.

For a given link (i,j), the resulting score falls into one of three cases:
- Positive $f_{CI}(i,j)$: i and j still move together even after removing all confounders → suggests a true direct causal link likely exists.
- Zero $f_{CI}(i,j)$: i and j become independent once confounders are removed → the original link was likely spurious / confounded, not causal.
- Negative $f_{CI}(i,j)$: i and j move in opposite directions after conditioning on confounders → indicates a negative / inhibitory causal influence or a suppressed relationship.

##### The statistical check $f_{CI}$

(Note: Ledoit–Wolf shrinkage and GraphicalLasso have been added here for regularization, with an option to toggle between them; neither is described in the report yet.)

In [None]:
import numpy as np
from typing import Dict, Tuple, List, Union, Any
from sklearn.covariance import LedoitWolf, GraphicalLasso

ArrayLike = Union[np.ndarray, "torch.Tensor"]  # torch optional


def compute_f_ci_scores(
    Z: ArrayLike,
    M_set_potential_confounder: Dict[Tuple[int, int], List[int]],
    lambda_shrink: float = 0.25,           # 0 = no LW, 1 = full LW, 0.2–0.4 is "slight"
    use_graphical_lasso: bool = False,
    graphical_lasso_alpha: float = 0.01,
) -> Dict[Tuple[int, int], float]:
    """
    Compute fCI(i, j) for each directed pair (i, j), using a REGULARIZED precision
    matrix and supporting multiple confounders.

    V' = {i, j} ∪ M_set_confound(i, j)
    Z_V' : dz x n'   (dz = latent dim, n' = |V'|)
    Σ_reg : n' x n'  regularized covariance over nodes
    Θ     : n' x n'  precision matrix (inverse of Σ_reg or learned via GraphicalLasso)

    fCI(i, j) = - Θ_ij / sqrt(Θ_ii * Θ_jj)

    Inputs
    ------
    Z : np.ndarray or torch.Tensor, 2-D, shape (rows, dz)
        Latent embeddings. Node id k is read from row k; ids <= 0 or >= rows
        are treated as invalid (row 0 is never used).

    M_set_potential_confounder : dict
        {
          (i, j): [k1, k2, ...]   # all indices are 1-based node IDs
        }
        Invalid or repeated confounder ids (including i and j themselves) are
        ignored. A value of None is treated as "no confounders".

    lambda_shrink : float in [0, 1] (clipped)
        How much to blend Ledoit–Wolf covariance into the raw covariance:
            Σ_reg = (1 - λ) * Σ_raw + λ * Σ_LW
        With λ = 0 the Ledoit–Wolf fit is skipped entirely.

    use_graphical_lasso : bool
        If True, ignore lambda_shrink and use GraphicalLasso directly to learn Θ
        (falling back to the pseudo-inverse of Σ_raw if the fit fails).

    graphical_lasso_alpha : float
        Regularization strength passed to GraphicalLasso.

    Returns
    -------
    f_ci_scores : dict
        {
          (i, j): f_ci_value (float in [-1, 1], or 0.0 if degenerate)
        }
        One entry per key of M_set_potential_confounder, in the same order.
        "Degenerate" covers: invalid indices, i == j, dz < 2, non-finite
        embeddings or precision entries, and failed matrix computations.

    Raises
    ------
    ValueError
        If Z is not 2-dimensional.
    """

    # --- 0. Normalize Z to a 2-D float64 numpy array ---
    if hasattr(Z, "detach"):  # torch.Tensor
        Z_np = Z.detach().cpu().numpy().astype(np.float64)
    else:
        Z_np = np.asarray(Z, dtype=np.float64)

    if Z_np.ndim != 2:
        raise ValueError(f"Z must be 2-D (nodes x latent dims); got shape {Z_np.shape}")

    num_nodes, dz = Z_np.shape
    # Loop-invariant; unused when GraphicalLasso is selected.
    lam = 0.0 if use_graphical_lasso else float(np.clip(lambda_shrink, 0.0, 1.0))
    f_ci_scores: Dict[Tuple[int, int], float] = {}

    for (i, j), confounders in M_set_potential_confounder.items():
        # Every degenerate path below leaves this default in place.
        f_ci_scores[(i, j)] = 0.0

        i_idx = int(i)
        j_idx = int(j)

        # Skip invalid indices (row 0 or out-of-range).
        if not (0 < i_idx < num_nodes and 0 < j_idx < num_nodes):
            continue

        # A self-pair has no partial correlation; the duplicated row would
        # otherwise produce a singular covariance and a spurious score.
        # At least two embedding dims are needed to estimate any covariance.
        if i_idx == j_idx or dz < 2:
            continue

        # --- 1. Build V' = {i, j} ∪ confounders (order-preserving, no repeats) ---
        V_prime: List[int] = [i_idx, j_idx]
        seen = {i_idx, j_idx}
        if confounders is None:
            confounders = ()
        for k in confounders:
            k_idx = int(k)
            if 0 < k_idx < num_nodes and k_idx not in seen:
                seen.add(k_idx)
                V_prime.append(k_idx)

        n_prime = len(V_prime)

        # --- 2. Extract embeddings for nodes in V' ---
        # Rows of Z_np are indexed directly by node id.
        Z_Vprime = Z_np[V_prime, :].T   # (dz, n_prime): rows = embedding dims, cols = nodes

        # NaN/inf embeddings cannot yield a meaningful score.
        if not np.isfinite(Z_Vprime).all():
            continue

        # --- 3. Raw covariance Σ_raw over nodes ---
        # rowvar=False: each column (node) is a variable, each row (dim) a sample.
        try:
            Sigma_raw = np.cov(Z_Vprime, rowvar=False)  # (n_prime, n_prime)
        except Exception:
            # Best-effort by design: a failed pair scores 0.0 instead of aborting the run.
            continue

        if Sigma_raw.shape != (n_prime, n_prime):
            continue

        # --- 3B. Regularization ---
        if use_graphical_lasso:
            # GraphicalLasso learns the precision directly.
            # X: samples x features = dz x n_prime
            try:
                Theta = GraphicalLasso(alpha=graphical_lasso_alpha).fit(Z_Vprime).precision_
            except Exception:
                # Fallback: plain pseudo-inverse of the raw covariance.
                try:
                    Theta = np.linalg.pinv(Sigma_raw)
                except Exception:
                    continue
        else:
            # Ledoit–Wolf shrinkage across nodes (same shape as Sigma_raw).
            # Skipped when lam == 0 because the blend would discard it anyway.
            Sigma_LW = Sigma_raw
            if lam > 0.0:
                try:
                    Sigma_LW = LedoitWolf().fit(Z_Vprime).covariance_  # (n_prime, n_prime)
                except Exception:
                    Sigma_LW = Sigma_raw

            Sigma_reg = (1.0 - lam) * Sigma_raw + lam * Sigma_LW

            # Precision matrix via pseudo-inverse (tolerates rank deficiency).
            try:
                Theta = np.linalg.pinv(Sigma_reg)
            except Exception:
                continue

        # --- 4. Partial correlation fCI(i, j) from Θ ---
        # i and j sit at positions 0 and 1 of V_prime by construction.
        theta_ij = Theta[0, 1]
        denom = Theta[0, 0] * Theta[1, 1]

        # Reject non-finite or non-positive normalizers so NaN never escapes.
        if not (np.isfinite(theta_ij) and np.isfinite(denom)) or denom <= 0:
            continue

        f_ci_scores[(i, j)] = float(np.clip(-theta_ij / np.sqrt(denom), -1.0, 1.0))

    return f_ci_scores


In [None]:
# Run the statistical confounder check on every candidate edge.
# Z is your GAE embeddings (torch.Tensor or np.ndarray), shape (N_nodes+1, dz)
# M_set_potential_confounder is what we built earlier:
#   {(i, j): [k1, k2, ...], ...}   with 1-based node IDs

# Result: {(i, j): fCI score}, one entry per key of M_set_potential_confounder.
f_ci_dict = compute_f_ci_scores(
    Z,
    M_set_potential_confounder,
    lambda_shrink=0.25   # <-- relaxed and works best for your size
)


# Example: see fCI for edge (1, 3)
# (.get returns None if (1, 3) is not one of the candidate edges)
print("fCI(1,3) =", f_ci_dict.get((1, 3)))


In [None]:
# Other cells in this notebook run `from pprint import pprint`, which binds the
# bare name `pprint` to the function; `pprint.pprint(...)` then fails with
# AttributeError. Importing the module under an alias keeps this cell working
# regardless of which cells ran before it, without rebinding `pprint`.
import pprint as pprint_module

pprint_module.pprint(f_ci_dict)

##### The LLM check $f_{llm-confounder}$

w_direct_llm_confounders

In [None]:
import os
import json
import time
import math
from typing import Dict, Tuple, List, Any, Optional

import numpy as np
from sentence_transformers import SentenceTransformer

from google.genai import Client, types
from google.genai import errors as genai_errors
from pydantic import BaseModel, ValidationError

# -------------------------
# Configuration / Hyperparams
# -------------------------
# Models to try, in list order; _switch_to_next_model() rotates through them
# when the active one hits a quota / rate-limit error.
MODEL_CANDIDATES = [
    "gemini-2.5-flash-lite",
    "gemini-2.5-flash",
    "gemini-2.0-flash-lite",
]
MODEL = MODEL_CANDIDATES[0]     # currently active model name
CURRENT_MODEL_INDEX = 0         # index of MODEL within MODEL_CANDIDATES
EXHAUSTED_MODELS = set()        # models already abandoned after a quota error

NS_SAMPLES = 5               # number of CoT samples per pair
CF_TEMPERATURE = 0.7
CF_DELAY_SECONDS = 15         # seconds between requests
TAU_COARSE_PARSE_SLEEP = 0.5  # NOTE(review): not referenced in the code visible in this section

# Persistence paths
OUT_DIR = "extracted_output"
COCAD_DIR = os.path.join(OUT_DIR, "cocad", "w_direct")
RAW_REASONINGS_DIR = os.path.join(COCAD_DIR, "llm_confounder_raw_reasonings")
RESULTS_JSON = os.path.join(COCAD_DIR, "w_direct_llm_confounders.json")
EMBEDDINGS_NPZ = os.path.join(COCAD_DIR, "llm_confounder_reasoning_embeddings.npz")

# RAW_REASONINGS_DIR lives inside COCAD_DIR, so the first call creates both.
os.makedirs(RAW_REASONINGS_DIR, exist_ok=True)
os.makedirs(COCAD_DIR, exist_ok=True)

# Initialize LLM client + embedder
# NOTE(review): if GEMINI_API_KEY is unset, the client is created with an empty
# key; presumably the first request then fails — confirm and consider failing fast.
client = Client(api_key=os.environ.get("GEMINI_API_KEY", ""))
embedder = SentenceTransformer("all-MiniLM-L6-v2")


# -------------------------
# Pydantic model for structured LLM response
# -------------------------
# Schema for one LLM answer; it is passed as `response_schema` in
# safe_generate_structured so the model is asked for exactly these fields.
# (Documented with comments rather than a class docstring so the schema
# handed to the LLM is left untouched.)
class ConfounderResponse(BaseModel):
    # Step-by-step explanation produced by the model.
    reasoning: str
    # Model's score for the direct effect; the prompt asks for a decimal in [0.0, 1.0].
    score: float


# -------------------------
# Helpers: model rotation + safe structured call
# -------------------------
def _switch_to_next_model():
    """
    Retire the active model and activate the next candidate that is still usable.

    Adds the current MODEL to EXHAUSTED_MODELS, then walks MODEL_CANDIDATES
    cyclically, starting just after CURRENT_MODEL_INDEX, and switches to the
    first entry that has not been exhausted (updating MODEL and
    CURRENT_MODEL_INDEX).

    Raises:
        RuntimeError: if every configured model has been exhausted.
    """
    global MODEL, CURRENT_MODEL_INDEX
    EXHAUSTED_MODELS.add(MODEL)

    n_models = len(MODEL_CANDIDATES)
    for step in range(1, n_models + 1):
        position = (CURRENT_MODEL_INDEX + step) % n_models
        if MODEL_CANDIDATES[position] in EXHAUSTED_MODELS:
            continue
        MODEL = MODEL_CANDIDATES[position]
        CURRENT_MODEL_INDEX = position
        print(f"[LLM] (confounder) switching to backup model: {MODEL}")
        return

    raise RuntimeError("All configured models exhausted (confounder).")


def safe_generate_structured(
    prompt: str,
    temperature: float = CF_TEMPERATURE,
    max_retries: int = 3,
) -> Optional[ConfounderResponse]:
    """
    Request a structured JSON answer and return it as a ConfounderResponse.

    The call asks for `application/json` output constrained by the
    ConfounderResponse schema. On quota / rate-limit client errors the active
    model is rotated via _switch_to_next_model() and the call is retried, up to
    `max_retries` attempts in total.

    Parsing is tried in this order:
      1. `resp.parsed` when it already is a ConfounderResponse;
      2. `resp.parsed` validated through ConfounderResponse.parse_obj;
      3. `resp.text` decoded as a JSON object and validated. Previously the raw
         text was read and then discarded, which forced the caller into a
         second API request even when the text was valid JSON.

    Args:
        prompt: full prompt text sent as the request contents.
        temperature: sampling temperature for the request.
        max_retries: maximum number of attempts.

    Returns:
        The parsed ConfounderResponse, or None when the model answered but no
        parsing step succeeded (or when max_retries < 1).

    Raises:
        genai_errors.ClientError: for non-quota client errors, or the last
            quota error once all attempts are used up.
        RuntimeError: from _switch_to_next_model() when every configured model
            is exhausted.
    """
    # The request config does not depend on the attempt, so build it once.
    config = types.GenerateContentConfig(
        temperature=temperature,
        response_mime_type="application/json",
        response_schema=ConfounderResponse,
    )

    last_exc = None
    for attempt in range(1, max_retries + 1):
        # Re-read on every attempt: a previous iteration may have rotated MODEL.
        from_model = MODEL
        try:
            print(f"[LLM] (confounder) structured call attempt {attempt} model={from_model}")
            resp = client.models.generate_content(model=from_model, contents=prompt, config=config)
        except genai_errors.ClientError as e:
            msg = str(e)
            last_exc = e
            print(f"[LLM] (confounder) client error on {from_model}: {msg}")
            # detect quota/rate-limit
            if "RESOURCE_EXHAUSTED" in msg or "429" in msg or "exceeded your current quota" in msg.lower():
                # Propagates RuntimeError when no model is left.
                _switch_to_next_model()
                time.sleep(CF_DELAY_SECONDS)
                continue
            # otherwise escalate
            raise

        parsed = getattr(resp, "parsed", None)
        if isinstance(parsed, ConfounderResponse):
            return parsed

        # parsed may be a dict-like that pydantic can validate
        if parsed is not None:
            try:
                return ConfounderResponse.parse_obj(parsed)
            except ValidationError:
                pass

        # Last chance before giving up: the raw text may itself be the JSON object.
        raw_text = getattr(resp, "text", "") or ""
        if raw_text:
            try:
                payload = json.loads(raw_text)
                if isinstance(payload, dict):
                    return ConfounderResponse(**payload)
            except (ValueError, TypeError, ValidationError):
                pass

        # The model answered but nothing could be parsed; the caller decides how to fall back.
        return None

    if last_exc:
        raise last_exc
    return None


# -------------------------
# Fallback parser for unstructured/raw responses
# -------------------------
def parse_json_like_response(raw_text: str) -> Tuple[str, Optional[float]]:
    """
    Best-effort parser for an LLM reply expected to be JSON like:
      {"reasoning": "...", "score": 0.72}

    Strategies, tried in order:
      1. Parse the whole text as JSON; if that fails, parse the span from the
         first "{" to the last "}" (covers replies wrapped in code fences or
         surrounded by prose). The embedded span is only accepted when it
         carries a "reasoning" or "score" key.
      2. Use the last line starting with "SCORE" (case-insensitive): the first
         number after its ":" is the score and everything before that line is
         the reasoning.
      3. Take the first standalone number in [0, 1] anywhere in the text.
         Digits that are part of a longer token (e.g. "2021", "1.5", "v1")
         are ignored.

    Args:
        raw_text: raw reply text from the model.

    Returns:
        (reasoning, score): reasoning is stripped text ("" for empty input or
        a JSON object without a string "reasoning"); score is a float or None
        when no score could be extracted.
    """
    import re  # function-local, as in the original: `re` is not imported at cell level

    if not raw_text:
        return "", None

    # 1) JSON: the whole text first, then the outermost {...} span.
    candidates = [(raw_text, False)]
    brace_start, brace_end = raw_text.find("{"), raw_text.rfind("}")
    if 0 <= brace_start < brace_end:
        embedded = raw_text[brace_start:brace_end + 1]
        if embedded != raw_text:
            candidates.append((embedded, True))

    for candidate, is_embedded in candidates:
        try:
            payload = json.loads(candidate)
        except ValueError:
            continue
        if not isinstance(payload, dict):
            continue
        # An unrelated {...} fragment inside prose must not short-circuit the heuristics.
        if is_embedded and "reasoning" not in payload and "score" not in payload:
            continue
        reasoning_field = payload.get("reasoning")
        reasoning = reasoning_field if isinstance(reasoning_field, str) else ""
        score_field = payload.get("score")
        try:
            score = float(score_field) if score_field is not None else None
        except (TypeError, ValueError):
            score = None
        return (reasoning.strip(), score)

    # 2) Last line of the form "SCORE: <num>"
    score = None
    reasoning = raw_text.strip()
    number_re = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")
    lines = [ln.strip() for ln in raw_text.splitlines() if ln.strip()]
    for ln in reversed(lines):
        if not ln.upper().startswith("SCORE"):
            continue
        parts = ln.split(":")
        if len(parts) >= 2:
            # Tolerate trailing punctuation or units such as "0.8." or "0.8 (high)".
            hit = number_re.search(parts[1])
            if hit:
                try:
                    score = float(hit.group(0))
                except ValueError:
                    score = None
        # reasoning = everything before this line
        idx = raw_text.rfind(ln)
        if idx != -1:
            reasoning = raw_text[:idx].strip()
        break

    # 3) Still no score: first standalone number between 0 and 1.
    if score is None:
        bounded = re.search(
            r"(?<![\w.\-])(0(?:\.\d+)?|1(?:\.0+)?)(?![\w]|\.\d)",
            raw_text,
        )
        if bounded:
            cand = float(bounded.group(1))
            if 0.0 <= cand <= 1.0:
                score = cand

    return (reasoning.strip(), score)


# -------------------------
# Semantic coherence helper
# -------------------------
def mean_pairwise_cosine(embeddings: List[np.ndarray]) -> float:
    """
    Average cosine similarity over all unordered pairs of embedding vectors.

    embeddings: list of 1D numpy arrays (vectors) of equal length.
    Returns the mean of cos(e_p, e_q) over every pair p < q.
    With fewer than two vectors there is no pair, and 1.0 (fully coherent)
    is returned by convention.
    """
    count = len(embeddings)
    if count < 2:
        return 1.0

    stacked = np.vstack(embeddings)  # shape (count, d)
    # Small epsilon keeps all-zero vectors from dividing by zero.
    lengths = np.linalg.norm(stacked, axis=1, keepdims=True) + 1e-12
    unit_rows = stacked / lengths
    cosine = unit_rows @ unit_rows.T  # (count, count) cosine matrix

    # Strict upper triangle = each unordered pair exactly once, no self-pairs.
    upper_rows, upper_cols = np.triu_indices(count, k=1)
    return float(cosine[upper_rows, upper_cols].mean())


# -------------------------
# Main Option B Implementation (uses structured responses with fallback)
# -------------------------
def llm_confounder_cot_check_and_persist(
    M_set_potential_confounder: Dict[Tuple[int, int], List[int]],
    entities_by_id: Dict[int, Dict[str, Any]],
    pairwise_context_map: Dict[Tuple[int, int], str],
    ns_samples: int = NS_SAMPLES,
    temperature: float = CF_TEMPERATURE,
    delay_seconds: int = CF_DELAY_SECONDS,
) -> Dict[Tuple[int, int], Dict[str, Any]]:
    """
    Runs the LLM CoT confounder checks for each (i,j) in M_set_potential_confounder.

    For every pair, the LLM is sampled `ns_samples` times with a counterfactual
    prompt that holds the pair's potential confounders constant. Each sample
    yields a reasoning text and a score; the reasoning is embedded, and the
    samples are aggregated into mean/variance/coherence statistics.

    Inputs:
      - M_set_potential_confounder: {(i,j): [k1,k2,...], ...} (1-based ints)
      - entities_by_id: {id_int: entity_dict} where entity_dict contains 'name' etc.
        Missing ids fall back to the label "N<id>".
      - pairwise_context_map: {(i,j): evidence_text} from earlier RAG stage.
        Pairs with no (or empty) context are not sent to the LLM; they get an
        all-zero record with "note": "no context".
      - ns_samples: number of LLM samples per pair.
      - temperature: sampling temperature for every LLM call.
      - delay_seconds: pause after every sample (including failed ones).

    Returns:
      - results: dict keyed by (i,j) containing:
          {
            "mu_score": float,        # mean of raw_scores
            "var_score": float,       # population variance (ddof=0) of raw_scores
            "R_coh": float,           # mean pairwise cosine of reasoning embeddings
            "Cconfidence": float,     # 1 - var / (var + mu^2 + eps), clipped to [0, 1]
            "S_final": float,         # mu_score * Cconfidence
            "raw_reasonings": [...],
            "raw_scores": [...],
            "reasoning_filenames": [...],
            "reasoning_embedding_keys": [...],
            "structured_responses": [...],  # raw parsed objects if available
            "confounder_names": [...],
            "i_name": str,
            "j_name": str,
            "context_snippet_len": int,
          }
        (The last four keys are absent from "no context" records.)
    Also persists:
      - RESULTS_JSON (summary under extracted_output/cocad/w_direct)
      - EMBEDDINGS_NPZ (all embeddings per pair)
      - RAW_REASONINGS_DIR/* (raw .txt files per sample)

    NOTE(review): a failed sample is recorded with score 0.0 and an all-zero
    embedding, and an unparseable score is also recorded as 0.0. Both enter
    mu_score / var_score / R_coh like genuine answers — confirm this is intended.
    NOTE(review): the JSON and NPZ files are only written after ALL pairs have
    been processed, so an interrupted run keeps only the per-sample .txt files.
    """

    all_results: Dict[Tuple[int, int], Dict[str, Any]] = {}
    # Flat store of every sample embedding, keyed "pair_<i>_<j>_s<n>".
    embeddings_store: Dict[str, np.ndarray] = {}

    pairs = list(M_set_potential_confounder.keys())
    total_pairs = len(pairs)
    print(f"[LLM-Confounder-CoT] Running on {total_pairs} directed pairs...")

    for idx_pair, (i, j) in enumerate(pairs, start=1):
        print(f"\n[{idx_pair}/{total_pairs}] Pair (i={i}, j={j})")

        # Resolve human-readable names; unknown ids fall back to "N<id>".
        confounders = M_set_potential_confounder.get((i, j), [])
        i_name = entities_by_id.get(int(i), {}).get("name", f"N{int(i)}")
        j_name = entities_by_id.get(int(j), {}).get("name", f"N{int(j)}")
        confounder_names = [
            entities_by_id.get(int(k), {}).get("name", f"N{int(k)}") for k in confounders
        ]

        context_text = pairwise_context_map.get((i, j), "")
        if not context_text:
            # Without evidence text there is nothing for the LLM to reason over.
            print("  Warning: no context / evidence_text for this pair; skipping.")
            all_results[(i, j)] = {
                "mu_score": 0.0,
                "var_score": 0.0,
                "R_coh": 1.0,
                "Cconfidence": 0.0,
                "S_final": 0.0,
                "raw_reasonings": [],
                "raw_scores": [],
                "reasoning_filenames": [],
                "reasoning_embedding_keys": [],
                "structured_responses": [],
                "note": "no context"
            }
            continue

        # Build structured / constrained prompt
        confounder_list_str = (
            "[" + ", ".join([f'"{n}"' for n in confounder_names]) + "]"
            if confounder_names
            else "[]"
        )
        prompt = f"""
You are a careful causal-reasoning assistant.

Context (evidence):
\"\"\"{context_text}\"\"\" 

Claim to test: Does '{i_name}' -> '{j_name}' represent a direct causal influence?

Potential Confounders (held constant): {confounder_list_str}

Task (2 steps):
1) Reasoning: If '{i_name}' were NOT present, while all Potential Confounders above are HELD CONSTANT,
   explain step-by-step only using the Context why there would still / would not be a direct effect on '{j_name}'.
2) Return a single JSON object EXACTLY in the form:
   {{ "reasoning": "<your step-by-step reasoning>", "score": <a decimal between 0.0 and 1.0> }}

Constraints:
- Use only information from the provided Context (do NOT hallucinate new facts).
- Keep the "reasoning" concise but explicit (2-6 sentences is fine).
- Return strictly one JSON object (no extra text).
"""

        # Per-pair accumulators; the lists stay index-aligned (one entry per sample).
        raw_reasonings: List[str] = []
        raw_scores: List[float] = []
        structured_responses: List[Dict[str, Any]] = []
        per_pair_embedding_keys: List[str] = []
        per_pair_reasoning_filenames: List[str] = []
        per_pair_embeddings: List[np.ndarray] = []

        # stochastic samples
        for s_idx in range(ns_samples):
            print(f"   sample {s_idx+1}/{ns_samples} ...", end=" ")
            parsed_obj: Optional[ConfounderResponse] = None
            raw_text_for_file = ""
            try:
                # first try to get structured parsed response via the typed schema call
                parsed_obj = safe_generate_structured(prompt=prompt, temperature=temperature)
                if parsed_obj is not None:
                    reasoning_str = parsed_obj.reasoning.strip()
                    score_val = float(parsed_obj.score)
                    raw_text_for_file = json.dumps(parsed_obj.dict(), ensure_ascii=False)
                else:
                    # fallback: ask raw text using safe llm and parse heuristically
                    # NOTE(review): this call goes straight to the client with the
                    # current MODEL; it does not retry or rotate models. Any error
                    # raised here is handled by the except block below.
                    print("(structured parse failed; falling back to raw text) ", end="")
                    from_model = MODEL
                    config = types.GenerateContentConfig(temperature=temperature)
                    resp = client.models.generate_content(model=from_model, contents=prompt, config=config)
                    raw_text = getattr(resp, "text", "") or ""
                    raw_text_for_file = raw_text
                    # attempt to parse heuristically
                    reasoning_str, score_val = parse_json_like_response(raw_text)
                    if score_val is None:
                        # No score could be extracted: recorded as 0.0.
                        score_val = 0.0

                # persist raw reasoning text to file
                fn = os.path.join(RAW_REASONINGS_DIR, f"pair_{i}_{j}_sample_{s_idx+1}.txt")
                with open(fn, "w", encoding="utf-8") as fw:
                    fw.write(raw_text_for_file)
                per_pair_reasoning_filenames.append(fn)

                raw_reasonings.append(reasoning_str)
                raw_scores.append(float(score_val))

                # save structured dict if available (for auditing)
                if parsed_obj is not None:
                    structured_responses.append(parsed_obj.dict())
                else:
                    # store a minimal structured fallback
                    structured_responses.append({"reasoning": reasoning_str, "score": float(score_val)})

                # embed the reasoning (or empty vector if empty)
                if isinstance(reasoning_str, str) and reasoning_str.strip():
                    emb = embedder.encode([reasoning_str], convert_to_numpy=True)[0]
                else:
                    emb = np.zeros(embedder.get_sentence_embedding_dimension(), dtype=np.float32)

                emb_key = f"pair_{i}_{j}_s{s_idx+1}"
                embeddings_store[emb_key] = emb
                per_pair_embedding_keys.append(emb_key)
                per_pair_embeddings.append(emb)

                print(f"SCORE={score_val:.3f}")

            except Exception as e:
                # Best-effort: one failed sample must not abort the whole run.
                print(f"[ERROR] LLM/sample failed: {e}")
                # store fallback placeholders, persist error file
                raw_reasonings.append("")
                raw_scores.append(0.0)
                structured_responses.append({"reasoning": "", "score": 0.0})
                fn = os.path.join(RAW_REASONINGS_DIR, f"pair_{i}_{j}_sample_{s_idx+1}_error.txt")
                with open(fn, "w", encoding="utf-8") as fw:
                    fw.write(f"[ERROR] {e}\n")
                per_pair_reasoning_filenames.append(fn)
                emb_key = f"pair_{i}_{j}_s{s_idx+1}"
                emb = np.zeros(embedder.get_sentence_embedding_dimension(), dtype=np.float32)
                embeddings_store[emb_key] = emb
                per_pair_embedding_keys.append(emb_key)
                per_pair_embeddings.append(emb)

            # rate-limit pause
            time.sleep(delay_seconds)

        # Aggregate stats
        scores_arr = np.asarray(raw_scores, dtype=float)
        mu_score = float(np.mean(scores_arr)) if len(scores_arr) > 0 else 0.0
        var_score = float(np.var(scores_arr, ddof=0)) if len(scores_arr) > 0 else 0.0

        # compute normalized variance in [0,1] for Cconfidence
        # (eps keeps the ratio defined when both variance and mean are 0)
        eps = 1e-8
        normalized_variance = var_score / (var_score + (mu_score ** 2) + eps)
        normalized_variance = float(np.clip(normalized_variance, 0.0, 1.0))
        Cconfidence = float(1.0 - normalized_variance)

        # semantic coherence Rcoh
        R_coh = float(mean_pairwise_cosine(per_pair_embeddings))

        # final combined score (as proposed): mu * Cconfidence
        S_final = float(mu_score * Cconfidence)

        # save per-pair summary
        all_results[(i, j)] = {
            "mu_score": mu_score,
            "var_score": var_score,
            "R_coh": R_coh,
            "Cconfidence": Cconfidence,
            "S_final": S_final,
            "raw_reasonings": raw_reasonings,
            "raw_scores": raw_scores,
            "reasoning_filenames": per_pair_reasoning_filenames,
            "reasoning_embedding_keys": per_pair_embedding_keys,
            "structured_responses": structured_responses,
            "confounder_names": confounder_names,
            "i_name": i_name,
            "j_name": j_name,
            "context_snippet_len": len(context_text),
        }

        # progress info
        print(f"  -> mu={mu_score:.3f}, var={var_score:.4f}, R_coh={R_coh:.3f}, Cconf={Cconfidence:.3f}, Sfinal={S_final:.3f}")

    # Persist summary JSON (convert tuple keys to "i_j")
    # JSON object keys must be strings, hence the "i_j" encoding.
    serializable_results = {}
    for (i, j), rec in all_results.items():
        key = f"{i}_{j}"
        serializable_results[key] = rec
    with open(RESULTS_JSON, "w", encoding="utf-8") as fout:
        json.dump(serializable_results, fout, indent=2, ensure_ascii=False)
    print(f"\nSaved summary results to: {RESULTS_JSON}")

    # Persist embeddings_store to npz
    # One array per sample, stored under its "pair_<i>_<j>_s<n>" key.
    np.savez_compressed(EMBEDDINGS_NPZ, **embeddings_store)
    print(f"Saved all reasoning embeddings to: {EMBEDDINGS_NPZ}")

    return all_results


# -------------------------
# Example usage:
# -------------------------
if __name__ == "__main__":
    # Runs the full LLM confounder check over every candidate edge and binds
    # the outcome to `results`, which later cells print and save.
    # NOTE(review): this makes ns_samples LLM requests per pair with a
    # delay_seconds pause after each one, so a full run can take a long time.

    # Run the pipeline (this will call the LLM and persist files)
    results = llm_confounder_cot_check_and_persist(
        M_set_potential_confounder=M_set_potential_confounder,
        entities_by_id=entities_by_id,
        pairwise_context_map=pairwise_structural_context_retrieval,
        ns_samples=NS_SAMPLES,
        temperature=CF_TEMPERATURE,
        delay_seconds=CF_DELAY_SECONDS,
    )

    # Show only the first record as a sanity check.
    print("\nDone. Example item (first) :")
    if results:
        k0 = next(iter(results.keys()))
        print(k0, results[k0])
    else:
        print("no results (empty input).")


In [None]:
from pprint import pprint

pprint(results)

In [None]:
## To save the above results to the extracted_output/cocad/w_direct/w_direct_llm_confounders.json
import os
import json
import numpy as np
from typing import Any

def _to_json_compatible(o: Any):
    """Recursively convert numpy types/arrays and other non-JSON-native types to JSON-friendly types."""
    # Numpy scalar
    if isinstance(o, (np.floating, np.float32, np.float64)):
        return float(o)
    if isinstance(o, (np.integer, np.int32, np.int64)):
        return int(o)
    if isinstance(o, np.ndarray):
        return o.tolist()
    # basic types + lists/dicts/tuples
    if isinstance(o, dict):
        return {str(k): _to_json_compatible(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [_to_json_compatible(v) for v in o]
    # fallback to str for unknown objects
    if isinstance(o, (str, int, float, bool)) or o is None:
        return o
    return str(o)

def save_w_direct_llm_confounders(results: dict, save_path: str):
    """
    Save `results` to `save_path` as indented UTF-8 JSON.

    Keys that are tuples (i, j) are written as "i_j" strings (both parts cast
    to int); any other key is written as `str(key)`. Values are made JSON-safe
    with `_to_json_compatible`.

    The parent directory of `save_path` is created when missing. A bare
    filename (no directory part) is written to the current working directory;
    previously that case failed because `os.makedirs("")` raises
    FileNotFoundError.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    serializable = {}
    for k, v in results.items():
        # Turn tuple key (i,j) into "i_j"
        if isinstance(k, tuple):
            key_str = f"{int(k[0])}_{int(k[1])}"
        else:
            key_str = str(k)
        serializable[key_str] = _to_json_compatible(v)

    with open(save_path, "w", encoding="utf-8") as fw:
        json.dump(serializable, fw, indent=2, ensure_ascii=False)

    print(f"[OK] Saved w_direct LLM confounders to: {save_path}")

# example usage:
# NOTE(review): this is the same path as RESULTS_JSON, which
# llm_confounder_cot_check_and_persist already writes; this call overwrites
# that file with the JSON-sanitized version of `results`.
save_dir = "extracted_output/cocad/w_direct"
save_path = os.path.join(save_dir, "w_direct_llm_confounders.json")
# results is your dict from earlier
save_w_direct_llm_confounders(results, save_path)


In [None]:
## Load the above 


import json
from typing import Dict, Tuple, Any

def load_w_direct_llm_confounders(load_path: str) -> Dict[Tuple[int,int], Any]:
    """
    Read the JSON written by save_w_direct_llm_confounders and rebuild its keys.

    Key restoration rules:
      - "i_j" (split on the first underscore, both parts int-convertible)
        becomes the tuple (i, j);
      - a key containing "_" that cannot be converted stays a string;
      - a key without "_" becomes an int when possible, otherwise stays a string.

    Returns a dict mapping the restored keys to the stored records.
    """
    def _restore_key(text):
        # Tuple-style key: "i_j"
        if isinstance(text, str) and "_" in text:
            try:
                left, right = text.split("_", 1)
                return (int(left), int(right))
            except Exception:
                # not a numeric pair: keep the key string as-is
                return text
        # Plain key: prefer int, otherwise keep the original string
        try:
            return int(text)
        except Exception:
            return text

    with open(load_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    return {_restore_key(name): record for name, record in payload.items()}

# example usage:
# Reload the saved LLM confounder results, keyed by (i, j) tuples again.
load_path = "extracted_output/cocad/w_direct/w_direct_llm_confounders.json"
w_direct_llm_confounders = load_w_direct_llm_confounders(load_path)
print(f"[OK] Loaded {len(w_direct_llm_confounders)} items from {load_path}")
# inspect one:
if w_direct_llm_confounders:
    first_k = next(iter(w_direct_llm_confounders))
    print("Example key:", first_k)
    # Only the first 10 field names of that record are shown.
    print("Example value keys:", list(w_direct_llm_confounders[first_k].keys())[:10])


What all is saved above and where:

- Summary JSON:

Saved in extracted_output/cocad/w_direct/w_direct_llm_confounders.json

- Raw LLM outputs for audit/provenances

Saved to extracted_output/cocad/w_direct/llm_confounder_raw_reasonings/.
There is one .txt file per sample:

```
pair_<i>_<j>_sample_<s>.txt
pair_<i>_<j>_sample_<s>_error.txt   (if LLM call failed)

```

- Reasoning Embeddings (for semantic coherence R_coh):

Saved to extracted_output/cocad/w_direct/llm_confounder_reasoning_embeddings.npz

```
{
  "pair_i_j_s1": <384-d vector>,
  "pair_i_j_s2": <384-d vector>,
  ...
}
```

Each key corresponds to a single CoT reasoning sample for a pair.

Used to compute semantic coherence:

$R_{coh}(i,j)$ = mean pairwise cosine similarity over all pairs of reasoning samples for $(i,j)$

In [None]:
## Lets find our initial W_direct results from earlier stages
## These will be refined later anyway, in the EM-refinement stage

## For now, take W_direct as the product of the structural and LLM stage scores, skipping f_fusion



In [None]:
from pprint import pprint

pprint(w_direct_structural_results)

In [None]:
pprint(w_direct_llm_results)