## (c). EM-Refinement loop (a continuation of the earlier steps)

### A. E-step (Expectation / Pseudo-Label Estimation)

In [None]:
from __future__ import annotations
import json
import math
import random
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import networkx as nx
import lightgbm as lgb
import joblib
from sklearn.isotonic import IsotonicRegression
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn

# Seed the stdlib, NumPy and PyTorch generators with one shared value.
GLOBAL_SEED = 20251127
random.seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)
torch.manual_seed(GLOBAL_SEED)

# Input locations. REAL_CORPUS_DIR is declared but is not listed in
# required_paths below, so its existence is never checked here.
REAL_CORPUS_DIR = Path("real_corpus")
E_PRIOR_PATH = Path("e_prior.jsonl")
DCORPUS_DIR = Path("d_corpus_output")
G_TRUE_PATH = Path("graphs/G_true.gpickle")     # ground-truth graph
# Output directory for this step; created eagerly (before the validations).
OUT_DIR = Path("out/e_step")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Fusion artifacts (from earlier training)
FUSION_MODEL_DIR = Path("out/fusion/model_bundle")
CLASSIFIER_PATH = FUSION_MODEL_DIR / "classifier.txt"
RANKER_PATH = FUSION_MODEL_DIR / "ranker.txt"
CALIBRATOR_PATH = FUSION_MODEL_DIR / "calibrator.pkl"
FEATURE_SCHEMA_PATH = FUSION_MODEL_DIR / "feature_schema.json"

# LPA artifacts
LPA_MODEL_DIR = Path("out/lpa_model")
LPA_ENCODER_PTH = LPA_MODEL_DIR / "encoder_finetuned.pth"
LPA_AGG_PTH = LPA_MODEL_DIR / "aggregator_finetuned.pth"
LPA_BUNDLE_JSON = LPA_MODEL_DIR / "lpa_model_bundle.json"

# fembed artifacts
FEMBED_DIR = Path("models/fembed_sota")

# Gemini LLM wrapper config
USE_GEMINI = True
# Models are tried in this order until one returns text.
MODEL_CANDIDATES = [
    "gemini-2.5-flash-lite",
    "gemini-2.5-flash",
    "gemini-2.0-flash-lite",
]
# Seconds slept before every LLM request attempt.
LLM_DELAY = 15.0
MAX_LLM_ATTEMPTS_PER_MODEL = 2

# E-step parameters
# Hop cutoff passed to nx.all_simple_paths.
MAX_PATH_CUTOFF = 4
# Number of null draws used per pair for the empirical p-value.
NULL_SAMPLES_PER_PAIR = 200
# Base value mixed into the per-pair RNG seed of the p-value routine.
P_VALUE_SEED_OFFSET = 1000000
# NOTE(review): not referenced in the E-step code visible here; the M-step
# defines its own W_MIN as a floor for sample weights.
W_MIN = 0.01

# Output files
C_PRIOR_OUTPUT = OUT_DIR / "C_prior_round_1.jsonl"
DDISTILL_OUTPUT = OUT_DIR / "Ddistill_round1.jsonl"
SUMMARY_OUTPUT = OUT_DIR / "E_step_round_1_summary.json"

# ----------------------------
# Validations
# ----------------------------
# Fail fast with a clear message if any upstream artifact is missing.
required_paths = [
    E_PRIOR_PATH,
    DCORPUS_DIR,
    G_TRUE_PATH,
    CLASSIFIER_PATH,
    RANKER_PATH,
    CALIBRATOR_PATH,
    FEATURE_SCHEMA_PATH,
    LPA_ENCODER_PTH,
    LPA_AGG_PTH,
    LPA_BUNDLE_JSON,
    FEMBED_DIR,
]
for p in required_paths:
    if not p.exists():
        raise FileNotFoundError(f"Required path missing: {p}")

# ----------------------------
# Gemini wrapper
# ----------------------------
try:
    import google.generativeai as genai
except Exception:
    genai = None

class GeminiClientStrict:
    """Rate-limited text client that tries several Gemini models in order.

    Every request attempt is preceded by a fixed sleep of ``llm_delay``
    seconds. Each model gets ``attempts_per_model`` attempts; the first
    non-empty text wins. If every model fails, ``generate_text`` raises
    ``RuntimeError`` chained to the last underlying error (if any).

    Raises:
        ImportError: at construction when ``google.generativeai`` could not
            be imported.
    """

    def __init__(self, models: List[str], attempts_per_model: int = 2, llm_delay: float = 15.0):
        if genai is None:
            raise ImportError("google.generativeai not available")
        self.models = models
        self.attempts = attempts_per_model
        self.delay = llm_delay
        self._last_error = None

    def _sleep(self):
        """Block for the configured inter-request delay."""
        time.sleep(self.delay)

    @staticmethod
    def _extract_text(resp: Any) -> Optional[str]:
        """Return the stripped text of a response, or None when it has none."""
        if resp is None:
            return None
        # generate_content responses expose ``.text``; the accessor raises
        # ValueError when the candidate carries no text part (e.g. blocked).
        try:
            text = resp.text
        except (AttributeError, ValueError, IndexError):
            text = None
        if text:
            return str(text).strip()
        # Fallback: candidate-list shaped responses (dicts or objects).
        if isinstance(resp, dict):
            cands = resp.get("candidates", [])
        else:
            cands = getattr(resp, "candidates", None)
        if isinstance(cands, list) and len(cands) > 0:
            first = cands[0]
            content = first.get("content") if isinstance(first, dict) else getattr(first, "content", None)
            if content:
                return str(content).strip()
        return None

    def _request(self, model: str, prompt: str, temperature: float, max_tokens: int) -> Any:
        """Issue one API call, preferring the Gemini ``generate_content`` API."""
        if hasattr(genai, "GenerativeModel"):
            # Gemini models are served through GenerativeModel; the legacy
            # generate_text endpoint only handled the older text models.
            return genai.GenerativeModel(model).generate_content(
                prompt,
                generation_config={"temperature": temperature, "max_output_tokens": max_tokens},
            )
        return genai.generate_text(model=model, prompt=prompt, temperature=temperature, max_output_tokens=max_tokens)

    def _try_model_call(self, model: str, prompt: str, temperature: float = 0.0, max_tokens: int = 128) -> Optional[str]:
        """Try one model up to ``self.attempts`` times; return text or None."""
        attempts = 0
        while attempts < self.attempts:
            attempts += 1
            self._sleep()
            try:
                text = self._extract_text(self._request(model, prompt, temperature, max_tokens))
            except Exception as err:
                # Retry boundary: remember the error so the final failure can
                # report it, then back off a little longer on each attempt.
                self._last_error = err
                time.sleep(0.5 * attempts)
                continue
            if text:
                return text
        return None

    def generate_text(self, prompt: str, temperature: float = 0.0, max_tokens: int = 128) -> str:
        """Return the first successful completion across ``self.models``.

        Raises:
            RuntimeError: when no model produced any text.
        """
        self._last_error = None
        for m in self.models:
            out = self._try_model_call(m, prompt, temperature=temperature, max_tokens=max_tokens)
            if out is not None:
                return out
        raise RuntimeError("All Gemini models failed") from self._last_error

# ----------------------------
# Load models & artifacts
# ----------------------------
# LightGBM boosters are restored from their text model files.
clf = lgb.Booster(model_file=str(CLASSIFIER_PATH))
ranker = lgb.Booster(model_file=str(RANKER_PATH))
# NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
calibrator = joblib.load(CALIBRATOR_PATH)
feature_schema = json.loads(FEATURE_SCHEMA_PATH.read_text(encoding="utf8"))
# Column order used whenever a feature dict is turned into a model input row.
feature_cols = feature_schema["feature_order"]

# f_embed loader (backbone only)
fembed_cfg = json.loads((FEMBED_DIR / "config.json").read_text(encoding="utf8"))
# Only the backbone name is read from the config; the SentenceTransformer is
# built from that name rather than from files inside FEMBED_DIR.
backbone_name = fembed_cfg.get("backbone", "all-MiniLM-L6-v2")
fembed = SentenceTransformer(backbone_name)
fembed.max_seq_length = 256

# LPA model definitions 
class PathTransformerEncoder(nn.Module):
    """Encode a sequence of node embeddings into one fixed-size path vector.

    The input is projected to ``d_model``, summed with learned positional
    embeddings, passed through a transformer encoder stack, mean-pooled over
    the sequence axis, projected to ``d_path`` and layer-normalised.
    """

    def __init__(self, node_emb_dim: int, d_model: int = 256, nhead: int = 8, num_layers: int = 4, dim_feedforward: int = 512, d_path: int = 256, max_len: int = 32):
        super().__init__()
        self.input_proj = nn.Linear(node_emb_dim, d_model)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, batch_first=True),
            num_layers=num_layers,
        )
        # One learned embedding per position; positions >= max_len have no row.
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.pool = nn.Linear(d_model, d_path)
        self.norm = nn.LayerNorm(d_path)
        self.max_len = max_len

    def forward(self, x):
        """Map ``x`` of shape (batch, steps, node_emb_dim) to (batch, d_path)."""
        n_batch, n_steps, _ = x.shape
        positions = torch.arange(n_steps, device=x.device).unsqueeze(0).expand(n_batch, -1)
        hidden = self.transformer(self.input_proj(x) + self.pos_emb(positions))
        pooled = hidden.mean(dim=1)
        return self.norm(self.pool(pooled))

class AttentionAggregator(nn.Module):
    """Attention-pool a set of path embeddings into one score in (0, 1).

    Attention logits come from an MLP over each path embedding, concatenated
    with a projection of its per-path features when ``path_feat_dim > 0``.
    The attention-weighted sum is concatenated with four statistics (the path
    count followed by three zeros) and fed to a sigmoid head.

    Args:
        d_path: width of each path embedding.
        path_feat_dim: width of the optional per-path feature vector; 0
            disables the feature projection.
        hidden_dim: hidden width of the attention MLP and the head.
    """

    def __init__(self, d_path: int = 256, path_feat_dim: int = 8, hidden_dim: int = 256):
        super().__init__()
        self.path_feat_proj = nn.Linear(path_feat_dim, d_path) if path_feat_dim>0 else None
        att_in_dim = d_path * 2 if self.path_feat_proj is not None else d_path
        self.att_mlp = nn.Sequential(nn.Linear(att_in_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))
        self.head = nn.Sequential(nn.Linear(d_path + 4, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, 1))

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights, also accepting checkpoints that name the projection ``path_proj``.

        The aggregator defined in the M-step cell of this file calls the same
        layer ``path_proj``; remapping the prefix lets one checkpoint load
        into either definition.
        """
        renamed = {}
        for key, value in state_dict.items():
            if key.startswith("path_proj."):
                key = "path_feat_proj." + key[len("path_proj."):]
            renamed[key] = value
        return super().load_state_dict(renamed, *args, **kwargs)

    def forward(self, path_embs, path_feats=None):
        """Score a batch of path sets.

        Args:
            path_embs: (B, m, d_path), or (m, d_path) for a single example.
            path_feats: optional (B, m, path_feat_dim) or (m, path_feat_dim).
                When omitted while a feature projection exists, an all-zero
                feature vector is used so the attention MLP still receives
                its expected input width.

        Returns:
            Tensor of shape (B,), or a scalar tensor for 2-D input.
        """
        single = path_embs.dim() == 2
        if single:
            path_embs = path_embs.unsqueeze(0)
            # Keep the features aligned with the batch axis just added.
            if path_feats is not None and path_feats.dim() == 2:
                path_feats = path_feats.unsqueeze(0)
        B, m, _ = path_embs.shape
        if self.path_feat_proj is not None:
            if path_feats is None:
                path_feats = path_embs.new_zeros((B, m, self.path_feat_proj.in_features))
            att_in = torch.cat([path_embs, self.path_feat_proj(path_feats)], dim=-1)
        else:
            att_in = path_embs
        logits = self.att_mlp(att_in).squeeze(-1)
        att = torch.softmax(logits, dim=-1)
        agg = (att.unsqueeze(-1) * path_embs).sum(dim=1)
        num_paths = torch.tensor([m], device=path_embs.device, dtype=torch.float32).expand(B).unsqueeze(-1)
        stats_cat = torch.cat([num_paths, torch.zeros((B, 3), device=path_embs.device)], dim=-1)
        head_in = torch.cat([agg, stats_cat], dim=-1)
        out = torch.sigmoid(self.head(head_in).squeeze(-1))
        if single:
            return out.squeeze(0)
        return out

# Architecture sizes come from the bundle's "config" section, with fallbacks
# when a key (or the whole section) is absent.
lpa_bundle = json.loads(LPA_BUNDLE_JSON.read_text(encoding="utf8"))
MAX_PATH_LENGTH = int(lpa_bundle.get("config", {}).get("MAX_PATH_LENGTH", 32))
D_PATH = int(lpa_bundle.get("config", {}).get("D_PATH", 256))
NODE_EMB_DIM = int(lpa_bundle.get("config", {}).get("NODE_EMB_DIM", 768))

encoder = PathTransformerEncoder(node_emb_dim=NODE_EMB_DIM, d_path=D_PATH, max_len=MAX_PATH_LENGTH)
aggregator = AttentionAggregator(d_path=D_PATH, path_feat_dim=8, hidden_dim=256)
# Checkpoints may either wrap the weights under a named key or be the bare
# state dict; both layouts are accepted.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
enc_ckpt = torch.load(LPA_ENCODER_PTH, map_location="cpu")
if "encoder_state" in enc_ckpt:
    encoder.load_state_dict(enc_ckpt["encoder_state"])
else:
    encoder.load_state_dict(enc_ckpt)
agg_ckpt = torch.load(LPA_AGG_PTH, map_location="cpu")
if "aggregator_state" in agg_ckpt:
    aggregator.load_state_dict(agg_ckpt["aggregator_state"])
else:
    aggregator.load_state_dict(agg_ckpt)
# Inference only in this cell: switch both modules to eval mode.
encoder.eval()
aggregator.eval()

# ----------------------------
# Helpers
# ----------------------------
def safe_jsonl_writer(p: Path):
    """Open ``p`` for UTF-8 text writing (truncating it) and return the handle.

    The caller owns the returned file object and must close it.
    """
    handle = open(p, mode="w", encoding="utf8")
    return handle

def node_identifier(n: int) -> str:
    """Return the string id for node ``n``: the letter N followed by ``int(n)``."""
    return "N" + str(int(n))

def load_d_corpus_entities(dcorpus_dir: Path) -> Dict[str, Dict[int, Dict[str,Any]]]:
    """Collect per-graph entity metadata from every ``*.jsonl`` file in a directory.

    Files are read in sorted order. Each JSON line may carry a ``graph_id``
    (default ``"unknown"``) and an ``entities`` mapping whose keys are node
    ids written either as ``"N<int>"`` or ``"<int>"``. The first record seen
    for a (graph, node) pair wins; later duplicates are ignored.

    Args:
        dcorpus_dir: directory holding the JSONL corpus files.

    Returns:
        ``{graph_id: {node_id: entity_dict}}``.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    ent_map: Dict[str, Dict[int, Dict[str, Any]]] = {}
    for f in sorted(dcorpus_dir.glob("*.jsonl")):
        with f.open("r", encoding="utf8") as fh:
            for ln in fh:
                # Blank lines are skipped, as the other JSONL loaders in this
                # file do.
                if not ln.strip():
                    continue
                rec = json.loads(ln)
                # ``"entities": null`` is treated like a missing mapping.
                ents = rec.get("entities") or {}
                graph_entities = ent_map.setdefault(rec.get("graph_id", "unknown"), {})
                for k, v in ents.items():
                    try:
                        nid = int(k[1:]) if k.startswith("N") else int(k)
                    except ValueError:
                        # Keys that are not node ids are ignored on purpose.
                        continue
                    graph_entities.setdefault(nid, v)
    return ent_map

def compute_structural_features(G: nx.DiGraph, i: int, j: int, max_hops: int = MAX_PATH_CUTOFF) -> Dict[str,Any]:
    """Compute graph-structure features for the ordered pair (i, j).

    Returns a dict with: a direct-edge flag, total/in/out degrees of both
    endpoints, the number and mean length (in edges) of simple paths from
    ``i`` to ``j`` with at most ``max_hops`` edges, a placeholder
    ``avg_path_internal_deg`` (always 0.0), and the number of node-disjoint
    paths counted up to a cap of 4. Any error raised while enumerating paths
    yields an empty path list; any error in the disjoint-path search yields 0.
    """
    feats: Dict[str, Any] = {"isdirect": 1 if G.has_edge(i, j) else 0}
    feats.update(
        deg_i=int(G.degree(i)),
        deg_j=int(G.degree(j)),
        in_deg_i=int(G.in_degree(i)),
        out_deg_i=int(G.out_degree(i)),
        in_deg_j=int(G.in_degree(j)),
        out_deg_j=int(G.out_degree(j)),
    )

    try:
        simple_paths = list(nx.all_simple_paths(G, source=i, target=j, cutoff=max_hops))
    except Exception:
        simple_paths = []
    n_paths = len(simple_paths)
    feats["num_paths_upto_k"] = n_paths
    # Path length is measured in edges, i.e. nodes minus one.
    feats["avg_path_len"] = float(sum(len(p) - 1 for p in simple_paths) / n_paths) if n_paths > 0 else 0.0
    feats["avg_path_internal_deg"] = 0.0

    try:
        n_disjoint = 0
        for n_disjoint, _ in enumerate(nx.algorithms.connectivity.disjoint_paths.node_disjoint_paths(G, i, j), start=1):
            # Stop counting once four disjoint paths have been seen.
            if n_disjoint >= 4:
                break
        feats["k_node_disjoint_paths"] = int(n_disjoint)
    except Exception:
        feats["k_node_disjoint_paths"] = 0
    return feats

def assemble_vij_from_struct(i:int, j:int, G: nx.DiGraph, rec_meta: Dict[str,Any], pos_score: float, neg_score: float, path_nodes: List[int]) -> Dict[str,Any]:
    """Build the fusion feature vector for pair (i, j) in schema order.

    Structural features come from ``compute_structural_features``. The LLM
    mean is ``pos_score`` and the LLM variance is the squared gap between
    ``pos_score`` and ``neg_score``; the conditional variants mirror them.
    ``p_plaus``/``p_temp``/``p_mech`` are read from ``rec_meta`` (default 0).
    ``SGAE`` is the mean degree over the distinct nodes in ``path_nodes``
    (0.0 when empty). The result contains exactly the columns in
    ``feature_cols``, each cast to float, with 0.0 for columns not produced
    here.
    """
    struct = compute_structural_features(G, i, j)

    score_gap = pos_score - neg_score
    llm_mean = float(pos_score)
    llm_var = float(score_gap ** 2)
    plaus = float(rec_meta.get("p_plaus", 0.0))
    temp = float(rec_meta.get("p_temp", 0.0))
    mech = float(rec_meta.get("p_mech", 0.0))

    node_degrees = [G.degree(n) for n in set(path_nodes)] if path_nodes else []
    sgae = float(np.mean(node_degrees)) if node_degrees else 0.0

    copied_keys = (
        "deg_i", "deg_j", "in_deg_i", "out_deg_i", "in_deg_j", "out_deg_j",
        "num_paths_upto_k", "avg_path_len", "avg_path_internal_deg",
    )
    feat = {key: struct[key] for key in copied_keys}
    feat["kdisjoint"] = struct["k_node_disjoint_paths"]
    feat.update({
        "mu_LLM": llm_mean,
        "var_LLM": llm_var,
        "mu_LLM_cond": llm_mean,
        "var_LLM_cond": llm_var,
        "p_plaus": plaus,
        "p_temp": temp,
        "p_mech": mech,
        "SGAE": sgae,
        "path_node_count": len(path_nodes),
        "pos_neg_diff": score_gap,
    })
    # Emit columns in the schema's order so downstream models see a stable layout.
    return {col: float(feat.get(col, 0.0)) for col in feature_cols}

def paths_to_node_embeddings(paths: List[List[int]], node_entity_map: Dict[int,Dict[str,Any]], batch_size:int=64):
    """Embed the node texts along each path into a fixed-size array.

    For every node the entity's ``description`` (falling back to ``name``)
    is used as text; nodes without an entity contribute an empty string.
    Each path is truncated or padded with empty strings to
    ``MAX_PATH_LENGTH`` entries, embedded with ``fembed``, and the embedding
    width is zero-padded or truncated to ``NODE_EMB_DIM``.

    Args:
        paths: node-id sequences.
        node_entity_map: node id -> entity dict.
        batch_size: batch size forwarded to the sentence encoder.

    Returns:
        float32 array of shape (num_paths, MAX_PATH_LENGTH, NODE_EMB_DIM);
        an empty first axis when ``paths`` is empty.
    """
    if not paths:
        # np.stack rejects an empty list; return a correctly shaped empty array.
        return np.zeros((0, MAX_PATH_LENGTH, NODE_EMB_DIM), dtype=np.float32)
    path_vecs = []
    for p in paths:
        texts = []
        for n in p:
            ent = node_entity_map.get(int(n))
            if ent is not None:
                texts.append(ent.get("description", "") or ent.get("name", ""))
            else:
                texts.append("")
        if not texts:
            # A path with no nodes becomes an all-zero block.
            path_vecs.append(np.zeros((MAX_PATH_LENGTH, NODE_EMB_DIM), dtype=np.float32))
            continue
        texts = texts[:MAX_PATH_LENGTH]
        if len(texts) < MAX_PATH_LENGTH:
            texts = texts + [""] * (MAX_PATH_LENGTH - len(texts))
        embs = fembed.encode(texts, batch_size=batch_size, convert_to_numpy=True, show_progress_bar=False)
        emb_dim = embs.shape[1]
        if emb_dim < NODE_EMB_DIM:
            # Encoder narrower than the LPA input: zero-pad the feature axis.
            embs = np.pad(embs, ((0, 0), (0, NODE_EMB_DIM - emb_dim)))
        elif emb_dim > NODE_EMB_DIM:
            embs = embs[:, :NODE_EMB_DIM]
        path_vecs.append(embs.astype(np.float32))
    return np.stack(path_vecs, axis=0)  # [num_paths, L, node_dim]

def encode_paths_with_lpa(paths: List[List[int]], node_entity_map: Dict[int,Dict[str,Any]]):
    """Return one LPA encoder embedding per path as nested Python lists.

    Node texts are embedded via ``paths_to_node_embeddings`` and the result
    is run through the module-level ``encoder`` without gradient tracking.
    """
    node_arr = paths_to_node_embeddings(paths, node_entity_map)
    with torch.no_grad():
        path_vectors = encoder(torch.tensor(node_arr, dtype=torch.float32))  # [num_paths, d_path]
    return path_vectors.cpu().numpy().tolist()

def compute_fusion_score_from_models(v_ordered: Dict[str, float]):
    """Score one feature vector with the classifier/ranker ensemble.

    The raw scores of both LightGBM boosters are blended with the weights
    stored in the feature schema (defaults 0.7 / 0.3), squashed with a
    sigmoid, and calibrated when the calibrator is an ``IsotonicRegression``.

    Args:
        v_ordered: feature dict containing every column in ``feature_cols``.

    Returns:
        dict with ``raw_clf``, ``raw_rank``, ``p_raw`` and ``p_cal``.
    """
    x = np.array([v_ordered[c] for c in feature_cols], dtype=float).reshape(1, -1)
    # predict() returns an array; take the single element explicitly instead
    # of relying on the deprecated float(ndarray) conversion.
    raw_clf = float(np.ravel(clf.predict(x, raw_score=True))[0])
    raw_rank = float(np.ravel(ranker.predict(x, raw_score=True))[0])
    classifier_weight = float(feature_schema.get("classifier_weight", 0.7))
    rank_weight = float(feature_schema.get("rank_weight", 0.3))
    combined_raw = classifier_weight * raw_clf + rank_weight * raw_rank
    # Numerically stable sigmoid: never exponentiate a large positive number,
    # so extreme raw scores cannot raise OverflowError.
    if combined_raw >= 0.0:
        p_raw = 1.0 / (1.0 + math.exp(-combined_raw))
    else:
        exp_pos = math.exp(combined_raw)
        p_raw = exp_pos / (1.0 + exp_pos)
    p_cal = float(calibrator.transform([p_raw])[0]) if isinstance(calibrator, IsotonicRegression) else p_raw
    return {"raw_clf": raw_clf, "raw_rank": raw_rank, "p_raw": p_raw, "p_cal": p_cal}

def empirical_p_value_for_pair(i:int, j:int, G: nx.DiGraph, node_entity_map: Dict[int,Dict[str,Any]], rec_meta: Dict[str,Any], pos_score: float, neg_score: float, null_B: int, seed_base:int):
    """Empirical p-value of the calibrated fusion score for pair (i, j).

    The observed statistic is the calibrated score of (i, j) computed with no
    path nodes. The null distribution is built from ``null_B`` draws of a
    replacement target with the same total degree as ``j`` (uniform over all
    nodes when no node has that degree). The p-value is
    ``(1 + #{null >= observed}) / (1 + null_B)``.

    Args:
        node_entity_map: accepted for interface compatibility; not used.
        seed_base: mixed with ``i`` and ``j`` to seed a private RNG, so the
            result is deterministic per pair.

    Returns:
        float in (0, 1].
    """
    # "+" binds tighter than "^", so the seed has always been evaluated as
    # written here; the parentheses only make that explicit.
    rng = random.Random((seed_base + i * 1315423911) ^ (j * 2654435761))

    def _score(target):
        feats = assemble_vij_from_struct(i, target, G, rec_meta, pos_score, neg_score, [])
        return compute_fusion_score_from_models(feats)["p_cal"]

    T_obs = _score(j)
    nodes = list(G.nodes())
    degs = {n: int(G.degree(n)) for n in nodes}
    target_deg = degs.get(j, 0)
    # Only the bucket matching j's degree is ever sampled, so build just that.
    same_degree = [n for n in nodes if degs[n] == target_deg]
    pool = same_degree if same_degree else nodes
    # NOTE(review): the pool can contain i and j themselves; a draw of j
    # always ties the observed score. Confirm whether that is intended.

    # A candidate's score is deterministic, so compute it once per candidate
    # instead of once per draw.
    score_cache = {j: T_obs}
    n_extreme = 0
    for _ in range(null_B):
        jb = rng.choice(pool)
        if jb not in score_cache:
            score_cache[jb] = _score(jb)
        if score_cache[jb] >= T_obs:
            n_extreme += 1
    p_val = (1.0 + n_extreme) / (1.0 + null_B)
    return p_val

# ----------------------------
# Load inputs: E_prior & node entity maps & G_true
# ----------------------------
# networkx 3.0 removed read_gpickle; a .gpickle file is a plain pickle of the
# graph object, so fall back to the pickle module on newer versions.
# NOTE: unpickling executes arbitrary code; only load graphs from trusted sources.
if hasattr(nx, "read_gpickle"):
    G_true = nx.read_gpickle(G_TRUE_PATH)
else:
    import pickle
    with G_TRUE_PATH.open("rb") as fh:
        G_true = pickle.load(fh)
# Entity metadata per graph id, then the candidate pairs (blank lines skipped).
d_entities = load_d_corpus_entities(DCORPUS_DIR)
with E_PRIOR_PATH.open("r", encoding="utf8") as fh:
    e_prior = [json.loads(ln) for ln in fh if ln.strip()]

# ----------------------------
# E-step loop (round 1)
# ----------------------------
if USE_GEMINI:
    gemini = GeminiClientStrict(models=MODEL_CANDIDATES, attempts_per_model=MAX_LLM_ATTEMPTS_PER_MODEL, llm_delay=LLM_DELAY)

summary = {"n_pairs": 0, "n_high_conf": 0, "p_hist": []}
pair_count = 0

# Both JSONL outputs are opened through context managers so that they are
# flushed and closed even if an LLM failure aborts the loop part-way through.
with safe_jsonl_writer(C_PRIOR_OUTPUT) as out_fh, safe_jsonl_writer(DDISTILL_OUTPUT) as ddistill_fh:
    for rec in e_prior:
        pair_count += 1
        i = int(rec["i"])
        j = int(rec["j"])
        graph_id = rec.get("graph_id", "G_true")
        node_map = d_entities.get(graph_id, {})
        # Simple paths i -> j up to the hop cutoff; any lookup error means "no paths".
        try:
            paths = list(nx.all_simple_paths(G_true, source=i, target=j, cutoff=MAX_PATH_CUTOFF))
        except Exception:
            paths = []
        # Flatten all path nodes (duplicates kept) for the path-based features.
        path_nodes_flat = [int(n) for p in paths for n in p]
        # No per-record metadata is available in this round, so the
        # p_plaus / p_temp / p_mech features default to 0.
        rec_meta = {}
        v_ordered = assemble_vij_from_struct(i, j, G_true, rec_meta, pos_score=0.5, neg_score=0.0, path_nodes=path_nodes_flat)
        # One LPA embedding per path (empty list when there are no paths).
        if len(paths) == 0:
            path_embeddings = []
        else:
            path_embeddings = encode_paths_with_lpa(paths, node_map)
        fusion_scores = compute_fusion_score_from_models(v_ordered)
        finalscore = fusion_scores["p_cal"]
        p_val = empirical_p_value_for_pair(i, j, G_true, node_map, rec_meta, pos_score=0.5, neg_score=0.0, null_B=NULL_SAMPLES_PER_PAIR, seed_base=P_VALUE_SEED_OFFSET)
        # Round 1: ask Gemini for the conditional LLM score (mu_LLM_cond).
        llm_cond_score = None
        if USE_GEMINI:
            prompt = (
                "Given a fusion feature vector and short metadata, return ONLY a JSON object "
                "{\"score\":<float 0..1>, \"reason\":\"short\"} estimating how likely the pair "
                "represents a causal direct relation. Feature dict:\n"
                + json.dumps(v_ordered)
            )
            raw = gemini.generate_text(prompt, temperature=0.0, max_tokens=40)
            start = raw.find("{")
            end = raw.rfind("}")
            if start != -1 and end != -1 and end > start:
                try:
                    jj = json.loads(raw[start:end+1])
                    llm_cond_score = float(jj.get("score", 0.0))
                except (ValueError, TypeError, AttributeError):
                    # Malformed JSON or a non-numeric score: treat as "no score".
                    llm_cond_score = None
            else:
                raise RuntimeError("Gemini returned unparsable output for pair {} {}".format(i,j))
            # Only emit a distillation pair when there is a usable target; a
            # null target cannot be used as a regression label downstream.
            if llm_cond_score is not None:
                ddistill_fh.write(json.dumps({"v": v_ordered, "t": llm_cond_score}) + "\n")
        # C_prior record
        c_record = {
            "i": int(i),
            "j": int(j),
            "graph_id": graph_id,
            "domain": rec.get("domain", "unknown"),
            "finalscore": float(finalscore),
            "p_value": float(p_val),
            "v_ij": v_ordered,
            "path_embeddings": path_embeddings,
            "path_stats": {"num_paths": len(paths), "avg_path_len": v_ordered.get("avg_path_len",0.0)},
            "cpc": {"plaus": v_ordered.get("p_plaus", 0.0), "temp": v_ordered.get("p_temp",0.0), "mech": v_ordered.get("p_mech",0.0)},
            "meta": {"round": 1, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())}
        }
        if llm_cond_score is not None:
            c_record["llm_conditional_score"] = float(llm_cond_score)
        out_fh.write(json.dumps(c_record) + "\n")
        summary["n_pairs"] += 1
        summary["p_hist"].append(float(p_val))
        # "High confidence" = significant at 5% and calibrated score above 0.5.
        if p_val < 0.05 and finalscore > 0.5:
            summary["n_high_conf"] += 1

SUMMARY_OUTPUT.write_text(json.dumps(summary, indent=2), encoding="utf8")
print("E-step finished. Outputs:")
print("  C_prior:", C_PRIOR_OUTPUT.resolve())
print("  Ddistill:", DDISTILL_OUTPUT.resolve())
print("  summary:", SUMMARY_OUTPUT.resolve())


### B. M-step (Maximisation / Student Update)

In [None]:
from __future__ import annotations
import json, math, random, time, os
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
import torch, torch.nn as nn, torch.optim as optim
import joblib

# -------------------- config --------------------
# Same seed value as the E-step cell; applied to the stdlib, NumPy and PyTorch RNGs.
GLOBAL_SEED = 20251127
random.seed(GLOBAL_SEED); np.random.seed(GLOBAL_SEED); torch.manual_seed(GLOBAL_SEED)

# Inputs: the E-step outputs plus the fusion schema and LPA checkpoints.
IN_DIR = Path("out/e_step")
C_PRIOR = IN_DIR / "C_prior_round_1.jsonl"
DDISTILL = IN_DIR / "Ddistill_round1.jsonl"
FEATURE_SCHEMA = Path("out/fusion/model_bundle/feature_schema.json")
LPA_BUNDLE_JSON = Path("out/lpa_model/lpa_model_bundle.json")
LPA_ENCODER_PTH = Path("out/lpa_model/encoder_finetuned.pth")
LPA_AGG_PTH = Path("out/lpa_model/aggregator_finetuned.pth")

# Output directories are created eagerly.
OUT_DIR = Path("out/em_refinement")
OUT_DIR.mkdir(parents=True, exist_ok=True)
TEACHER_OUT = OUT_DIR / "teacher_models"
TEACHER_OUT.mkdir(parents=True, exist_ok=True)

# hyperparams 
BATCH_SIZE = 512
EPOCHS_FUSION = 5
EPOCHS_LPA = 5
EPOCHS_FSTUDENT = 10
LR_FUSION = 3e-4
LR_LPA = 1e-4
LR_FSTUDENT = 1e-4
# Weight of the augmentation-consistency term in the fusion student's loss.
LAMBDA_CONSIST = 0.1
# Floor applied to per-sample weights derived from p-values.
W_MIN = 0.01
# NOTE(review): presumably the decay of an EMA teacher update; no such update
# is visible in this part of the file.
EMA_ALPHA = 0.999
# Max gradient norm passed to clip_grad_norm_.
GRAD_CLIP = 1.0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------- checks --------------------
# Fail fast if any required input file is missing.
for p in [C_PRIOR, DDISTILL, FEATURE_SCHEMA, LPA_BUNDLE_JSON, LPA_ENCODER_PTH, LPA_AGG_PTH]:
    if not p.exists():
        raise FileNotFoundError(f"Required file missing: {p}")

# -------------------- load feature schema --------------------
feature_schema = json.loads(FEATURE_SCHEMA.read_text(encoding="utf8"))
# Column order used to turn each stored feature dict into a vector.
FEATURE_COLS: List[str] = feature_schema["feature_order"]
# Input width of the fusion student and the regressor.
D_IN = len(FEATURE_COLS)

# -------------------- model definitions --------------------
class FusionStudent(nn.Module):
    """Three-layer MLP that maps a feature vector to a probability.

    Layout: Linear -> GELU -> LayerNorm at width ``hidden``, the same at
    width ``hidden // 2``, then a final Linear to ``out_dim``. The output is
    passed through a sigmoid after squeezing the last axis.
    """

    def __init__(self, input_dim: int, hidden: int = 1024, out_dim: int = 1):
        super().__init__()
        half = hidden // 2
        layers = [
            nn.Linear(input_dim, hidden),
            nn.GELU(),
            nn.LayerNorm(hidden),
            nn.Linear(hidden, half),
            nn.GELU(),
            nn.LayerNorm(half),
            nn.Linear(half, out_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return sigmoid probabilities for a batch of feature rows."""
        logits = self.net(x)
        return torch.sigmoid(logits.squeeze(-1))

class FStudentRegressor(nn.Module):
    """Plain MLP regressor: two GELU hidden layers and an unbounded scalar output."""

    def __init__(self, input_dim: int, hidden1: int = 512, hidden2: int = 256):
        super().__init__()
        widths = [input_dim, hidden1, hidden2]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden2, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one raw (non-squashed) value per input row."""
        raw = self.net(x)
        return raw.squeeze(-1)

# LPA encoder/aggregator must match the definitions used in E-step.
class PathTransformerEncoder(nn.Module):
    """Transformer encoder that turns a node-embedding sequence into a path vector.

    Steps: linear projection to ``d_model`` plus learned positional
    embeddings, a stack of transformer encoder layers, mean pooling over the
    sequence axis, a linear map to ``d_path`` and a final LayerNorm.
    """

    def __init__(self, node_emb_dim: int, d_model: int = 256, nhead: int = 8, num_layers: int = 4, dim_feedforward: int = 512, d_path: int = 256, max_len: int = 32):
        super().__init__()
        self.input_proj = nn.Linear(node_emb_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, batch_first=True)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Learned position table with max_len rows.
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.pool = nn.Linear(d_model, d_path)
        self.norm = nn.LayerNorm(d_path)
        self.max_len = max_len

    def forward(self, x):
        """Encode ``x`` of shape (batch, steps, node_emb_dim) to (batch, d_path)."""
        batch_size, seq_len, _ = x.shape
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
        tokens = self.input_proj(x) + self.pos_emb(pos_ids)
        encoded = self.transformer(tokens)
        return self.norm(self.pool(encoded.mean(dim=1)))

class AttentionAggregator(nn.Module):
    """Attention-pool a set of path embeddings into one score in (0, 1).

    Attention logits come from an MLP over each path embedding, concatenated
    with a projection of its per-path features when ``path_feat_dim > 0``.
    The attention-weighted sum is concatenated with four statistics (the path
    count followed by three zeros) and fed to a sigmoid head.

    Args:
        d_path: width of each path embedding.
        path_feat_dim: width of the optional per-path feature vector; 0
            disables the feature projection.
        hidden: hidden width of the attention MLP and the head.
    """

    def __init__(self, d_path: int = 256, path_feat_dim: int = 8, hidden: int = 256):
        super().__init__()
        self.path_proj = nn.Linear(path_feat_dim, d_path) if path_feat_dim>0 else None
        in_dim = d_path*2 if self.path_proj is not None else d_path
        self.att_mlp = nn.Sequential(nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden,1))
        self.head = nn.Sequential(nn.Linear(d_path+4, hidden), nn.GELU(), nn.Linear(hidden,1))

    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load weights, also accepting checkpoints that name the projection ``path_feat_proj``.

        The aggregator defined in the E-step cell of this file calls the same
        layer ``path_feat_proj``; remapping the prefix lets one checkpoint
        load into either definition.
        """
        renamed = {}
        for key, value in state_dict.items():
            if key.startswith("path_feat_proj."):
                key = "path_proj." + key[len("path_feat_proj."):]
            renamed[key] = value
        return super().load_state_dict(renamed, *args, **kwargs)

    def forward(self, path_embs, path_feats=None):
        """Score a batch of path sets.

        Args:
            path_embs: (B, m, d_path), or (m, d_path) for a single example.
            path_feats: optional (B, m, path_feat_dim) or (m, path_feat_dim).
                When omitted while a feature projection exists, an all-zero
                feature vector is used so the attention MLP still receives
                its expected input width.

        Returns:
            Tensor of shape (B,), or a scalar tensor for 2-D input.
        """
        single = path_embs.dim() == 2
        if single:
            path_embs = path_embs.unsqueeze(0)
            # Keep the features aligned with the batch axis just added.
            if path_feats is not None and path_feats.dim() == 2:
                path_feats = path_feats.unsqueeze(0)
        B, m, _ = path_embs.shape
        if self.path_proj is not None:
            if path_feats is None:
                path_feats = path_embs.new_zeros((B, m, self.path_proj.in_features))
            att_in = torch.cat([path_embs, self.path_proj(path_feats)], dim=-1)
        else:
            att_in = path_embs
        logits = self.att_mlp(att_in).squeeze(-1)
        att = torch.softmax(logits, dim=-1)
        agg = (att.unsqueeze(-1) * path_embs).sum(dim=1)
        num_paths = torch.tensor([m], device=path_embs.device, dtype=torch.float32).expand(B).unsqueeze(-1)
        stats_cat = torch.cat([num_paths, torch.zeros((B,3), device=path_embs.device)], dim=-1)
        head_in = torch.cat([agg, stats_cat], dim=-1)
        out = torch.sigmoid(self.head(head_in).squeeze(-1))
        if single:
            return out.squeeze(0)
        return out

# -------------------- load LPA teacher -> student initializations --------------------
# Architecture sizes are read from the bundle's "config" section; unlike the
# E-step cell there are no fallbacks, so a missing key raises KeyError.
lpa_bundle = json.loads(LPA_BUNDLE_JSON.read_text(encoding="utf8"))
NODE_EMB_DIM = int(lpa_bundle["config"]["NODE_EMB_DIM"])
D_PATH = int(lpa_bundle["config"]["D_PATH"])
MAX_PATH_LEN = int(lpa_bundle["config"]["MAX_PATH_LENGTH"])

# Teacher LPA modules: restored from the fine-tuned checkpoints and kept in
# eval mode.
lpa_teacher_encoder = PathTransformerEncoder(node_emb_dim=NODE_EMB_DIM, d_path=D_PATH, max_len=MAX_PATH_LEN).to(DEVICE)
lpa_teacher_agg = AttentionAggregator(d_path=D_PATH, path_feat_dim=8, hidden=256).to(DEVICE)
# Checkpoints may either wrap the weights under a named key or be the bare
# state dict; both layouts are accepted.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
enc_ckpt = torch.load(LPA_ENCODER_PTH, map_location="cpu")
if "encoder_state" in enc_ckpt:
    lpa_teacher_encoder.load_state_dict(enc_ckpt["encoder_state"])
else:
    lpa_teacher_encoder.load_state_dict(enc_ckpt)
agg_ckpt = torch.load(LPA_AGG_PTH, map_location="cpu")
if "aggregator_state" in agg_ckpt:
    lpa_teacher_agg.load_state_dict(agg_ckpt["aggregator_state"])
else:
    lpa_teacher_agg.load_state_dict(agg_ckpt)
lpa_teacher_encoder.eval(); lpa_teacher_agg.eval()

# create student copies
# Students start from exactly the teacher weights.
lpa_student_encoder = PathTransformerEncoder(node_emb_dim=NODE_EMB_DIM, d_path=D_PATH, max_len=MAX_PATH_LEN).to(DEVICE)
lpa_student_agg = AttentionAggregator(d_path=D_PATH, path_feat_dim=8, hidden=256).to(DEVICE)
lpa_student_encoder.load_state_dict(lpa_teacher_encoder.state_dict())
lpa_student_agg.load_state_dict(lpa_teacher_agg.state_dict())

# fusion: create student and teacher networks (neural)
# Both are freshly initialised networks; no pretrained fusion weights are
# loaded here.
fusion_teacher = FusionStudent(input_dim=D_IN).to(DEVICE)
fusion_student = FusionStudent(input_dim=D_IN).to(DEVICE)
# init teachers to students identical initially (teacher will be EMA-updated)
# NOTE(review): the EMA update itself is not in the code visible here.
fusion_teacher.load_state_dict(fusion_student.state_dict())

# fstudent regressor
fstudent = FStudentRegressor(input_dim=D_IN).to(DEVICE)
fstudent_teacher = FStudentRegressor(input_dim=D_IN).to(DEVICE)
fstudent_teacher.load_state_dict(fstudent.state_dict())

# optimizers
# Only student parameters are optimised; the LPA encoder and aggregator
# share one optimizer.
opt_fusion = optim.AdamW(fusion_student.parameters(), lr=LR_FUSION, weight_decay=1e-5)
opt_lpa = optim.AdamW(list(lpa_student_encoder.parameters()) + list(lpa_student_agg.parameters()), lr=LR_LPA, weight_decay=1e-5)
opt_fst = optim.AdamW(fstudent.parameters(), lr=LR_FSTUDENT, weight_decay=1e-5)

# BCE is kept per-element so per-sample weights can be applied before averaging.
bce_loss = nn.BCELoss(reduction="none")
mse_loss = nn.MSELoss(reduction="mean")

# -------------------- load C_prior --------------------
# Threshold for turning the calibrated score into a hard pseudo-label. The
# E-step summary uses the same 0.5 cut for its high-confidence count; the
# previous cut at 0.0 labelled every pair with a positive score as 1.
Y_BIN_THRESHOLD = 0.5

records = []
with C_PRIOR.open("r", encoding="utf8") as fh:
    for ln in fh:
        if not ln.strip():
            continue
        records.append(json.loads(ln))

if len(records) == 0:
    raise RuntimeError("No C_prior records found")

# assemble feature matrix, labels, weights, path data
X_list = []
y_soft = []
y_bin = []
w_list = []
domains = []
paths_list = []
for rec in records:
    v = rec["v_ij"]
    # Features in schema order; columns absent from the record default to 0.
    x = np.array([float(v.get(c, 0.0)) for c in FEATURE_COLS], dtype=np.float32)
    X_list.append(x)
    fs = float(rec.get("finalscore", 0.0))
    p = float(rec.get("p_value", 1.0))
    # Sample weight grows as the p-value shrinks, floored at W_MIN.
    w = max(W_MIN, 1.0 - p)
    y_soft.append(np.float32(fs))
    y_bin.append(np.float32(1.0 if fs > Y_BIN_THRESHOLD else 0.0))
    w_list.append(np.float32(w))
    domains.append(rec.get("domain","unknown"))
    # Stored path embeddings (one vector per path), or an empty list.
    pe = rec.get("path_embeddings", [])
    if pe and isinstance(pe, list):
        pe_arr = [np.array(pv, dtype=np.float32) for pv in pe]
    else:
        pe_arr = []
    paths_list.append(pe_arr)

X = np.stack(X_list, axis=0)
y_soft = np.array(y_soft, dtype=np.float32)
y_bin = np.array(y_bin, dtype=np.float32)
w_arr = np.array(w_list, dtype=np.float32)

from collections import defaultdict
# Record indices grouped by domain; domains are visited in sorted order.
idx_by_domain = defaultdict(list)
for idx, dom in enumerate(domains):
    idx_by_domain[dom].append(idx)
domain_keys = sorted(idx_by_domain.keys())
def domain_balanced_batches(batch_size, seed=GLOBAL_SEED):
    """Yield index batches forever, drawing round-robin across domains.

    Each domain's indices are shuffled once with a private RNG and then
    consumed cyclically, so small domains are oversampled until every batch
    holds ``batch_size`` indices. The generator never terminates on its own
    when data exists; callers pull as many batches as they need with
    ``next``.

    Args:
        batch_size: number of indices per yielded batch (must be positive).
        seed: seed for the per-domain shuffles.

    Raises:
        ValueError: if ``batch_size`` is not positive (raised on first
            ``next``).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    rng = random.Random(seed)
    domain_shuffled = {}
    for d in domain_keys:
        arr = idx_by_domain[d][:]
        rng.shuffle(arr)
        domain_shuffled[d] = arr
    active = [d for d in domain_keys if domain_shuffled[d]]
    if not active:
        # Nothing to sample: end the generator instead of spinning forever.
        return
    cursors = {d: 0 for d in active}
    batch = []
    while True:
        # Keep filling across sweeps; a batch is only emitted once it is
        # full, not after every pass over the domains.
        for d in active:
            pool = domain_shuffled[d]
            batch.append(pool[cursors[d] % len(pool)])
            cursors[d] += 1
            if len(batch) >= batch_size:
                yield batch
                batch = []

# -------------------- augmentation function --------------------
def augment_batch_features(x_batch: np.ndarray, seed_offset: int = 0):
    """Return a noisy copy of a 2-D feature batch; the input is not modified.

    Each entry is zeroed independently with probability 0.1, then Gaussian
    noise with a fixed standard deviation of 1e-2 is added to every entry
    (including the zeroed ones). The RNG is seeded with
    ``GLOBAL_SEED + seed_offset``, so equal offsets give identical output.
    """
    drop_p = 0.1
    noise_std = 1e-2
    rng = np.random.RandomState(GLOBAL_SEED + seed_offset)
    n_rows, n_cols = x_batch.shape
    # True where the feature survives dropout.
    keep = rng.rand(n_rows, n_cols) > drop_p
    dropped = x_batch.copy() * keep
    jitter = rng.normal(loc=0.0, scale=noise_std, size=dropped.shape).astype(np.float32)
    return dropped + jitter

# -------------------- training fusion student --------------------
fusion_student.train()
n = len(X)
bgen = domain_balanced_batches(BATCH_SIZE, seed=GLOBAL_SEED + 1)
steps_per_epoch = math.ceil(n / BATCH_SIZE)
for epoch in range(EPOCHS_FUSION):
    total_loss = 0.0
    for step in range(steps_per_epoch):
        batch_idxs = next(bgen)
        x_np = X[batch_idxs]
        xb = torch.tensor(x_np, dtype=torch.float32, device=DEVICE)
        ys = torch.tensor(y_soft[batch_idxs], dtype=torch.float32, device=DEVICE)
        ww = torch.tensor(w_arr[batch_idxs], dtype=torch.float32, device=DEVICE)
        # Two augmented views with distinct seeds. The earlier offsets
        # (epoch*1000 + step and epoch*2000 + step) coincide at epoch 0 and
        # collide across epochs, which made both views identical and the
        # consistency loss exactly zero. Even/odd offsets of a global step
        # counter are unique for every view of every step.
        global_step = epoch * steps_per_epoch + step
        x_aug1 = torch.tensor(augment_batch_features(x_np, seed_offset=2 * global_step), dtype=torch.float32, device=DEVICE)
        x_aug2 = torch.tensor(augment_batch_features(x_np, seed_offset=2 * global_step + 1), dtype=torch.float32, device=DEVICE)
        opt_fusion.zero_grad()
        p = fusion_student(xb)
        p1 = fusion_student(x_aug1)
        p2 = fusion_student(x_aug2)
        # Weighted pseudo-label loss on the clean view (per-element BCE
        # scaled by the sample weights before averaging).
        loss_pseudo = (bce_loss(p, ys) * ww).mean()
        # Consistency: the two augmented views should agree.
        loss_consist = ((p1 - p2)**2).mean()
        loss = loss_pseudo + LAMBDA_CONSIST * loss_consist
        loss.backward()
        torch.nn.utils.clip_grad_norm_(fusion_student.parameters(), GRAD_CLIP)
        opt_fusion.step()
        total_loss += float(loss.detach().cpu().item())
    print(f"[fusion] epoch {epoch+1}/{EPOCHS_FUSION} loss {total_loss/steps_per_epoch:.6f}")

# -------------------- training LPA student --------------------
lpa_student_encoder.train(); lpa_student_agg.train()
# prepare path batches: we will batch by indices; pad to max_paths_in_batch and pad each path to MAX_PATH_LEN x NODE_EMB_DIM
def collate_paths(batch_indices: List[int]):
    """Stack the stored path vectors of a batch into one padded 4-D array.

    Each example's path vectors are padded or truncated to ``NODE_EMB_DIM``
    and repeated along a sequence axis of length ``MAX_PATH_LEN``; examples
    with fewer paths than the batch maximum are zero-padded.

    Args:
        batch_indices: indices into ``paths_list``.

    Returns:
        ``(arr, path_counts)`` where ``arr`` is float32 with shape
        (B, m, MAX_PATH_LEN, NODE_EMB_DIM), ``m`` is the largest path count
        in the batch (1 when no example has paths), and ``path_counts`` is an
        int32 vector with the true number of paths per example.
    """
    batch_paths = [paths_list[i] for i in batch_indices]
    path_counts = np.array([len(p) for p in batch_paths], dtype=np.int32)
    m = max((len(p) for p in batch_paths), default=0)
    if m == 0:
        # Keep the 4-D layout the training loop unpacks, even with no paths.
        arr = np.zeros((len(batch_indices), 1, MAX_PATH_LEN, NODE_EMB_DIM), dtype=np.float32)
        return arr, path_counts
    arrs = []
    for pvecs in batch_paths:
        # One zero-initialised slot per possible path: (m, MAX_PATH_LEN, NODE_EMB_DIM).
        example = np.zeros((m, MAX_PATH_LEN, NODE_EMB_DIM), dtype=np.float32)
        for pi, pv in enumerate(pvecs):
            vec = pv
            if vec.shape[0] < NODE_EMB_DIM:
                # Narrower than the encoder input: zero-pad on the right.
                vec = np.pad(vec, (0, NODE_EMB_DIM - vec.shape[0])).astype(np.float32)
            elif vec.shape[0] > NODE_EMB_DIM:
                vec = vec[:NODE_EMB_DIM].astype(np.float32)
            # The same vector is repeated at every sequence position.
            example[pi, :, :] = np.tile(vec.reshape(1, -1), (MAX_PATH_LEN, 1))
        arrs.append(example)
    return np.stack(arrs, axis=0), path_counts

# training loop
# Train the LPA student (encoder + aggregator): weighted pseudo-label loss on
# the full path set plus a consistency penalty between two path-dropout views.
indices = list(range(len(X)))
rng = random.Random(GLOBAL_SEED+2)
drop_p = 0.2
batch_size_small = 64  # encoder chunk size, keeps peak memory low
for epoch in range(EPOCHS_LPA):
    rng.shuffle(indices)
    total_loss = 0.0
    # One mask generator per epoch. It used to be re-created with the same seed
    # for every batch, so all batches of an epoch shared identical dropout masks.
    rand = np.random.RandomState(GLOBAL_SEED + epoch)
    for s in range(0, len(indices), BATCH_SIZE):
        batch_idxs = indices[s:s+BATCH_SIZE]
        arr, counts = collate_paths(batch_idxs)
        arr_t = torch.tensor(arr, dtype=torch.float32, device=DEVICE)  # [B, m, L, node_dim]
        B_, m, L, node_dim = arr_t.shape
        # collapse to [B*m, L, node_dim] and run the student encoder chunk by chunk
        arr_flat = arr_t.reshape(B_*m, L, node_dim)
        batch_embs = []
        for start in range(0, arr_flat.shape[0], batch_size_small):
            chunk = arr_flat[start:start+batch_size_small]
            he = lpa_student_encoder(chunk)  # [chunk, d_path]
            batch_embs.append(he)
        he_all = torch.cat(batch_embs, dim=0)
        he_all = he_all.reshape(B_, m, -1)  # [B, m, d_path]
        # placeholder path features: all zeros, 8 channels (hard-coded here)
        path_feats = torch.zeros((B_, m, 8), device=DEVICE)
        preds = lpa_student_agg(he_all, path_feats)  # [B]
        ys = torch.tensor(y_bin[batch_idxs], dtype=torch.float32, device=DEVICE)
        ww = torch.tensor(w_arr[batch_idxs], dtype=torch.float32, device=DEVICE)
        loss_pseudo = bce_loss(preds, ys)
        loss_pseudo = (loss_pseudo * ww).mean()
        # two independent path-dropout views of the same embeddings
        mask = rand.rand(B_, m) > drop_p
        mask2 = rand.rand(B_, m) > drop_p
        he_aug1 = he_all * torch.tensor(mask.astype(np.float32), device=DEVICE).unsqueeze(-1)
        he_aug2 = he_all * torch.tensor(mask2.astype(np.float32), device=DEVICE).unsqueeze(-1)
        p1 = lpa_student_agg(he_aug1, path_feats)
        p2 = lpa_student_agg(he_aug2, path_feats)
        loss_cons = ((p1 - p2)**2).mean()
        loss = loss_pseudo + LAMBDA_CONSIST * loss_cons
        opt_lpa.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(list(lpa_student_encoder.parameters()) + list(lpa_student_agg.parameters()), GRAD_CLIP)
        opt_lpa.step()
        total_loss += float(loss.detach().cpu().item())
    print(f"[LPA] epoch {epoch+1}/{EPOCHS_LPA} loss {total_loss/ (max(1, math.ceil(len(indices)/BATCH_SIZE))):.6f}")

# -------------------- training fstudent (distillation) --------------------
# load distill pairs
# Load the distillation pairs. A missing file is treated like an empty one so
# the "skipping" branch below handles both cases instead of crashing on open().
d_pairs = []
if DDISTILL.exists():
    with DDISTILL.open("r", encoding="utf8") as fh:
        for ln in fh:
            if not ln.strip(): continue
            d_pairs.append(json.loads(ln))
if not d_pairs:
    print("No Ddistill data found; skipping fstudent training")
else:
    # Each pair supplies a feature dict under "v" and a regression target under "t".
    Xd = []
    yd = []
    for d in d_pairs:
        v = d["v"]
        # features absent from the dict default to 0.0
        x = np.array([float(v.get(c,0.0)) for c in FEATURE_COLS], dtype=np.float32)
        Xd.append(x)
        yd.append(float(d["t"]))
    Xd = np.stack(Xd, axis=0)
    yd = np.array(yd, dtype=np.float32)
    fstudent.train()
    n_d = len(Xd)
    idxs = list(range(n_d))
    rng = random.Random(GLOBAL_SEED+3)
    for epoch in range(EPOCHS_FSTUDENT):
        rng.shuffle(idxs)
        total_loss = 0.0
        for s in range(0, n_d, BATCH_SIZE):
            batch = idxs[s:s+BATCH_SIZE]
            xb = torch.tensor(Xd[batch], dtype=torch.float32, device=DEVICE)
            yb = torch.tensor(yd[batch], dtype=torch.float32, device=DEVICE)
            opt_fst.zero_grad()
            # NOTE(review): assumes fstudent returns shape [B] to match yb; a
            # [B, 1] output would broadcast inside the loss -- confirm.
            pred = fstudent(xb)
            loss = mse_loss(pred, yb)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(fstudent.parameters(), GRAD_CLIP)
            opt_fst.step()
            total_loss += float(loss.detach().cpu().item())
        print(f"[fstudent] epoch {epoch+1}/{EPOCHS_FSTUDENT} loss {total_loss / max(1, math.ceil(n_d/BATCH_SIZE)):.6f}")

# -------------------- EMA teacher updates --------------------
@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, alpha: float):
    """Blend student parameters into the teacher in place (exponential moving average).

    Each teacher parameter becomes ``alpha * teacher + (1 - alpha) * student``.
    Parameters are paired by position in ``.parameters()``; buffers are not
    touched.

    Raises:
        RuntimeError: if the two modules differ in parameter count or in the
            shape of any paired parameter. The check runs before any update,
            so the teacher is never left half-updated.
    """
    teacher_params = list(teacher.parameters())
    student_params = list(student.parameters())
    # zip() would silently drop trailing parameters on a count mismatch.
    if len(teacher_params) != len(student_params):
        raise RuntimeError(
            f"Parameter count mismatch between teacher ({len(teacher_params)}) "
            f"and student ({len(student_params)})"
        )
    for idx, (tp, sp) in enumerate(zip(teacher_params, student_params)):
        # In-place add_ would silently broadcast a smaller student tensor.
        if tp.shape != sp.shape:
            raise RuntimeError(
                f"Parameter shape mismatch at position {idx}: "
                f"teacher {tuple(tp.shape)} vs student {tuple(sp.shape)}"
            )
    for tp, sp in zip(teacher_params, student_params):
        tp.data.mul_(alpha).add_(sp.data * (1.0 - alpha))

# Blend every trained student into its teacher counterpart, one pair at a time.
for _teacher, _student in (
    (fusion_teacher, fusion_student),
    (lpa_teacher_encoder, lpa_student_encoder),
    (lpa_teacher_agg, lpa_student_agg),
    (fstudent_teacher, fstudent),
):
    ema_update(_teacher, _student, EMA_ALPHA)

# -------------------- save teacher models --------------------
# Persist the four teacher state dicts, then a small JSON record of the run settings.
_teacher_checkpoints = (
    ("fusion_teacher.pth", fusion_teacher),
    ("lpa_teacher_encoder.pth", lpa_teacher_encoder),
    ("lpa_teacher_agg.pth", lpa_teacher_agg),
    ("fstudent_teacher.pth", fstudent_teacher),
)
for _fname, _model in _teacher_checkpoints:
    torch.save(_model.state_dict(), TEACHER_OUT / _fname)

meta = {}
# local wall-clock time of the save
meta["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
meta["n_records"] = len(X)
meta["epochs_fusion"] = EPOCHS_FUSION
meta["epochs_lpa"] = EPOCHS_LPA
meta["epochs_fstudent"] = EPOCHS_FSTUDENT
meta["lambda_consist"] = LAMBDA_CONSIST
meta["ema_alpha"] = EMA_ALPHA
_meta_path = TEACHER_OUT / "mstep_meta.json"
_meta_path.write_text(json.dumps(meta, indent=2))
print("M-step complete. Teacher models written to:", TEACHER_OUT.resolve())


### C. Teacher Update (EMA) and D. Stopping Criteria

In [None]:
from __future__ import annotations
import argparse, json, os, time
from pathlib import Path
from typing import Dict, Any, Set, Tuple, List
import torch
import numpy as np

# -------------------- Defaults & hyperparams --------------------
DEFAULT_ALPHA = 0.999     # default for the --alpha option: weight kept on the previous teacher in the EMA
W_THRESHOLD = 0.9         # high-confidence weight threshold (applied to w = 1 - p_value, clipped to [0, 1])
SCORE_THRESHOLD = 0.5     # finalscore threshold for high-confidence
JACCARD_STOP = 0.99       # stop once the Jaccard overlap of H_r with the previous H reaches this value
VAL_PATIENCE = 3          # stop once this many validation entries have passed since the last best score
VAL_EPS = 1e-4            # NOTE(review): not referenced by main() in this cell, which uses a fixed 1e-12 tolerance

# -------------------- Utilities --------------------
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSON-lines file into a list, ignoring blank lines."""
    with path.open("r", encoding="utf8") as fh:
        # strip each line first so whitespace-only lines are skipped
        return [json.loads(text) for text in (raw.strip() for raw in fh) if text]

def save_jsonl(objects: List[Dict[str, Any]], path: Path):
    """Write each object as one JSON line (UTF-8, non-ASCII kept verbatim), overwriting *path*."""
    with path.open("w", encoding="utf8") as sink:
        sink.writelines(json.dumps(item, ensure_ascii=False) + "\n" for item in objects)

def make_pair_key(i: int, j: int) -> str:
    """Return the canonical ``"<i>::<j>"`` key for an ordered node pair (both coerced with ``int``)."""
    source = int(i)
    target = int(j)
    return "::".join((str(source), str(target)))

def jaccard_from_sets(A: Set[str], B: Set[str]) -> float:
    """Return the Jaccard similarity ``|A & B| / |A | B|`` of two sets.

    Two empty sets are defined as identical (1.0). The earlier separate
    "empty union" branch could never run after that guard and is removed.
    """
    if not A and not B:
        return 1.0
    # At least one set is non-empty here, so the union is non-empty too.
    return float(len(A & B)) / float(len(A | B))


def ema_update_state_dict(teacher_sd: Dict[str, torch.Tensor], student_sd: Dict[str, torch.Tensor], alpha: float):
    """EMA-blend ``student_sd`` into ``teacher_sd`` in place, key by key.

    Floating-point (and complex) tensors become
    ``alpha * teacher + (1 - alpha) * student`` in the teacher's dtype.
    Integer and boolean tensors (counter-style buffers) cannot be averaged
    without turning into floats, so the student's value is copied instead and
    cast to the teacher's dtype. Entries that are not tensors on both sides
    are left untouched.

    Raises:
        RuntimeError: if the key sets differ or a paired tensor differs in shape.
    """
    t_keys = set(teacher_sd.keys())
    s_keys = set(student_sd.keys())
    if t_keys != s_keys:
        missing_in_t = sorted(list(s_keys - t_keys))
        missing_in_s = sorted(list(t_keys - s_keys))
        raise RuntimeError(f"State-dict key mismatch between teacher and student.\nMissing in teacher: {missing_in_t}\nMissing in student: {missing_in_s}")
    # Iterate over a snapshot of the keys because values are reassigned below.
    for k in list(teacher_sd.keys()):
        t = teacher_sd[k]
        s = student_sd[k]
        if not isinstance(t, torch.Tensor) or not isinstance(s, torch.Tensor):
            continue
        if t.shape != s.shape:
            # Broadcasting would otherwise silently change the stored shape.
            raise RuntimeError(f"State-dict shape mismatch for '{k}': teacher {tuple(t.shape)} vs student {tuple(s.shape)}")
        if t.dtype != s.dtype:
            s = s.to(dtype=t.dtype)
        if t.is_floating_point() or t.is_complex():
            teacher_sd[k] = (alpha * t + (1.0 - alpha) * s).clone()
        else:
            # Blending an integer tensor with a Python float yields a float
            # tensor; copying keeps the stored dtype stable.
            teacher_sd[k] = s.clone()

# -------------------- Main pipeline --------------------
def _is_missing(value) -> bool:
    """Return True when a record field holds no usable node id (None or empty string)."""
    return value is None or value == ""


def _pair_from_record(rec: Dict[str, Any]):
    """Resolve the (i, j) endpoints of one C_prior record; either may come back as None.

    Lookup order per endpoint: the explicit "i"/"j" field, then "node_pair",
    then the first/last element of "path". Only None and "" count as missing,
    so a legitimate node id of 0 is kept.
    """
    i = rec.get("i")
    j = rec.get("j")
    node_pair = rec.get("node_pair")
    if isinstance(node_pair, (list, tuple)) and len(node_pair) >= 2:
        if _is_missing(i):
            i = node_pair[0]
        if _is_missing(j):
            j = node_pair[1]
    path = rec.get("path")
    if isinstance(path, (list, tuple)) and len(path) > 0:
        if _is_missing(i):
            i = path[0]
        if _is_missing(j):
            j = path[-1]
    return (None if _is_missing(i) else i), (None if _is_missing(j) else j)


def _unwrap_state_dict(ckpt):
    """Return the inner state dict when a checkpoint wraps it under "state_dict" or "model_state"."""
    if isinstance(ckpt, dict):
        for wrapper_key in ("state_dict", "model_state"):
            if wrapper_key in ckpt and isinstance(ckpt[wrapper_key], dict):
                return ckpt[wrapper_key]
    return ckpt


def _tensorize_state(state) -> None:
    """Best-effort conversion of non-tensor state-dict values to tensors, in place."""
    for k in list(state.keys()):
        if not isinstance(state[k], torch.Tensor):
            try:
                state[k] = torch.tensor(state[k])
            except (TypeError, ValueError, RuntimeError):
                # Not convertible: keep the raw value; the EMA step skips
                # entries that are not tensors on both sides.
                pass


# -------------------- Main pipeline --------------------
def main():
    """Run one teacher-update round from the command line.

    Steps: validate inputs, build the high-confidence pair set H_r from
    C_prior and write it, optionally compute its Jaccard overlap with the
    previous H, EMA-update each of the four teacher checkpoints from its
    student, evaluate the stop criteria, and write a JSON summary.

    Raises:
        FileNotFoundError: if a required checkpoint or input file is missing.
        RuntimeError: if an EMA update fails for any model.

    NOTE(review): teachers are written as
    ``<model_key>_teacher_round<N>_<timestamp>.pth`` while ``--teacher-prev``
    expects ``<model_key>_teacher.pth``, so the output directory cannot be
    passed back as the next round's ``--teacher-prev`` without renaming.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--round", type=int, required=True)
    p.add_argument("--student-checks", type=Path, required=True,
                   help="Directory with student checkpoints named: fusion_student.pth, lpa_encoder_student.pth, lpa_agg_student.pth, fstudent_student.pth (exact names)")
    p.add_argument("--teacher-prev", type=Path, required=True,
                   help="Directory containing previous teacher checkpoints with matching names: fusion_teacher.pth, lpa_encoder_teacher.pth, lpa_agg_teacher.pth, fstudent_teacher.pth")
    p.add_argument("--c-prior", type=Path, required=True, help="C_prior JSONL for current round")
    p.add_argument("--h-prev", type=Path, default=None, help="Optional previous H(r-1) JSONL to compute Jaccard")
    p.add_argument("--val-history", type=Path, default=None, help="Optional validation history JSON file (list of {round:int, val_score:float})")
    p.add_argument("--alpha", type=float, default=DEFAULT_ALPHA)
    p.add_argument("--out", type=Path, required=True, help="Directory to write updated teacher checkpoints and artifacts")
    args = p.parse_args()

    # checks
    stud_dir: Path = args.student_checks
    teach_prev_dir: Path = args.teacher_prev
    c_prior_path: Path = args.c_prior
    out_dir: Path = args.out
    out_dir.mkdir(parents=True, exist_ok=True)

    required_student_files = {
        "fusion": stud_dir / "fusion_student.pth",
        "lpa_encoder": stud_dir / "lpa_encoder_student.pth",
        "lpa_agg": stud_dir / "lpa_agg_student.pth",
        "fstudent": stud_dir / "fstudent_student.pth"
    }
    required_teacher_files = {
        "fusion": teach_prev_dir / "fusion_teacher.pth",
        "lpa_encoder": teach_prev_dir / "lpa_encoder_teacher.pth",
        "lpa_agg": teach_prev_dir / "lpa_agg_teacher.pth",
        "fstudent": teach_prev_dir / "fstudent_teacher.pth"
    }

    for k, v in required_student_files.items():
        if not v.exists():
            raise FileNotFoundError(f"Missing student checkpoint {k}: {v}")
    for k, v in required_teacher_files.items():
        if not v.exists():
            raise FileNotFoundError(f"Missing previous teacher checkpoint {k}: {v}")
    if not c_prior_path.exists():
        raise FileNotFoundError(f"C_prior file not found: {c_prior_path}")

    # load current C_prior
    c_records = load_jsonl(c_prior_path)

    # build high-confidence set H_r
    H_r = set()
    H_r_list = []
    for rec in c_records:
        # Endpoints are resolved with explicit None checks: the earlier `or`
        # chain treated node id 0 as missing and dropped such records.
        i, j = _pair_from_record(rec)
        if i is None or j is None:
            continue
        finalscore = float(rec.get("finalscore", rec.get("pos_score", 0.0) or 0.0))
        pval = float(rec.get("p_value", rec.get("p", 1.0)))
        # confidence weight: 1 - p_value, clipped to [0, 1]
        w = max(0.0, min(1.0, 1.0 - pval))
        if w >= W_THRESHOLD and finalscore >= SCORE_THRESHOLD:
            key = make_pair_key(i, j)
            H_r.add(key)
            H_r_list.append({"i": int(i), "j": int(j), "finalscore": float(finalscore), "p_value": float(pval), "w": float(w)})

    # write H_r deterministically (sorted)
    H_r_path = out_dir / f"H_round_{args.round}.jsonl"
    H_r_list_sorted = sorted(H_r_list, key=lambda x: (x["i"], x["j"]))
    save_jsonl(H_r_list_sorted, H_r_path)

    # compute Jaccard vs previous H (stays None when no previous H is given)
    J_r = None
    if args.h_prev:
        hprev_path = Path(args.h_prev)
        if not hprev_path.exists():
            raise FileNotFoundError(f"Provided h-prev file does not exist: {hprev_path}")
        hprev = load_jsonl(hprev_path)
        H_prev = set(make_pair_key(item["i"], item["j"]) for item in hprev)
        J_r = jaccard_from_sets(H_r, H_prev)

    # perform EMA update per model
    alpha = float(args.alpha)
    saved_teacher_paths = {}
    # One timestamp for the whole round so the four snapshots share a suffix
    # even when the loop crosses a second boundary.
    ts = time.strftime("%Y%m%dT%H%M%S", time.gmtime())
    for model_key in ["fusion", "lpa_encoder", "lpa_agg", "fstudent"]:
        stud_path = required_student_files[model_key]
        teach_prev_path = required_teacher_files[model_key]
        # SECURITY: torch.load unpickles arbitrary objects; only point this
        # tool at checkpoints from a trusted source.
        stud_sd = torch.load(stud_path, map_location="cpu")
        teach_sd = torch.load(teach_prev_path, map_location="cpu")

        stud_state = _unwrap_state_dict(stud_sd)
        teach_state = _unwrap_state_dict(teach_sd)
        _tensorize_state(stud_state)
        _tensorize_state(teach_state)

        # compute EMA in-place on teach_state
        try:
            ema_update_state_dict(teach_state, stud_state, alpha)
        except RuntimeError as e:
            raise RuntimeError(f"EMA update failed for model {model_key}: {e}") from e

        # save to out dir (only the unwrapped state dict is written)
        out_path = out_dir / f"{model_key}_teacher_round{args.round}_{ts}.pth"
        torch.save(teach_state, out_path)
        saved_teacher_paths[model_key] = str(out_path)

    # validation-based stopping: count entries since the latest best score
    stop_by_val = False
    val_info = None
    if args.val_history:
        val_path = Path(args.val_history)
        if not val_path.exists():
            raise FileNotFoundError(f"val_history file not found: {val_path}")
        val_hist = json.loads(val_path.read_text(encoding="utf8"))
        scores = [float(x["val_score"]) for x in val_hist if "val_score" in x]
        if len(scores) > 0:
            best_score = max(scores)
            last_score = scores[-1]
            # NOTE(review): the module-level VAL_EPS is not used here; ties are
            # detected with a fixed 1e-12 tolerance.
            last_best_idx = max(idx for idx, s in enumerate(scores) if s >= best_score - 1e-12)
            rounds_since_best = len(scores) - 1 - last_best_idx
            if rounds_since_best >= VAL_PATIENCE:
                stop_by_val = True
            val_info = {"best_score": float(best_score), "last_score": float(last_score), "rounds_since_best": int(rounds_since_best)}
        else:
            val_info = {"info": "no scores in val_history"}
    else:
        val_info = {"info": "val_history not provided"}

    stop_by_jaccard = bool(J_r is not None and J_r >= JACCARD_STOP)
    stop_flag = bool(stop_by_jaccard or stop_by_val)

    summary = {
        "round": int(args.round),
        "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "alpha": float(alpha),
        "w_threshold": float(W_THRESHOLD),
        "score_threshold": float(SCORE_THRESHOLD),
        "J_r": (float(J_r) if J_r is not None else None),
        "J_threshold": float(JACCARD_STOP),
        "stop_by_jaccard": stop_by_jaccard,
        "stop_by_val": bool(stop_by_val),
        "val_info": val_info,
        "saved_teacher_checkpoints": saved_teacher_paths,
        "H_r_path": str(H_r_path),
        "n_H_r": len(H_r_list_sorted),
        "stop": bool(stop_flag)
    }
    summary_path = out_dir / f"teacher_update_summary_round{args.round}.json"
    summary_path.write_text(json.dumps(summary, indent=2))
    print("Wrote summary:", summary_path)
    if stop_flag:
        print("STOP condition met (stop=True). Final teacher snapshots written to:", out_dir)
    else:
        print("STOP condition not met. Updated teacher snapshots written to:", out_dir)

# Script entry point: main() parses required command-line flags, so it only
# runs when this module is executed directly.
if __name__ == "__main__":
    main()


### Evaluating the produced $C_{prior}$

In [None]:
from __future__ import annotations
import json, time
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
import pandas as pd
import networkx as nx
from sklearn.metrics import roc_auc_score, average_precision_score, precision_score, recall_score, f1_score, brier_score_loss

C_PRIOR_PATH = Path("out/e_step/C_prior_round_1.jsonl")  # C_prior records to evaluate (JSON lines)
ENTITIES_PATH = Path("out/d_corpus/entities.json")   # entity metadata used to attach names to node keys
G_TRUE_PATH = Path("graphs/G_true.gpickle")  # pickled ground-truth graph
OUT_DIR = Path("out/eval_cprior")  # destination for the CSV/JSON artifacts of this cell
OUT_DIR.mkdir(parents=True, exist_ok=True)  # created eagerly at cell execution time

# === Helper functions ===
def load_jsonl(p: Path):
    """Return the JSON objects stored one per line in *p* (UTF-8); blank lines are skipped."""
    with p.open("r", encoding="utf8") as fh:
        stripped = map(str.strip, fh)
        return [json.loads(text) for text in stripped if text]

def load_entities_map(p: Path):
    """Load entity metadata from *p* and return it as an ``{id: entity}`` dict.

    Accepted layouts:
      * a JSON list of entity objects (each must have an "id" key),
      * a JSON object that already maps ids to entities (returned as is),
      * JSON lines, one entity object per line -- including a file that
        holds a single entity on a single line.

    An empty file, or a JSON document of any other top-level type, yields
    ``{}`` so callers always receive a dict.

    Raises:
        KeyError: if an entity in a list or JSONL file has no "id".
        json.JSONDecodeError: if a JSONL line is not valid JSON.
    """
    txt = p.read_text(encoding="utf8").strip()
    if not txt:
        return {}
    try:
        parsed = json.loads(txt)
    except json.JSONDecodeError:
        # Not one JSON document: fall back to one entity per line.
        out = {}
        for ln in txt.splitlines():
            if not ln.strip():
                continue
            obj = json.loads(ln)
            out[obj["id"]] = obj
        return out
    if isinstance(parsed, list):
        return {e["id"]: e for e in parsed}
    if isinstance(parsed, dict):
        # A one-line JSONL file parses as a plain dict. A scalar "id" field
        # marks it as a single entity record rather than an id -> entity map.
        if "id" in parsed and not isinstance(parsed["id"], (dict, list)):
            return {parsed["id"]: parsed}
        return parsed
    # Scalar top-level JSON (number, string, ...): nothing usable.
    return {}

def normalize_node(n):
    """Convert a node label to ``int`` where possible.

    Strings starting with "N" have the prefix removed before conversion and
    are returned unchanged when the remainder is not an integer. Every other
    value goes straight through ``int()``, which may raise.
    """
    has_prefix = isinstance(n, str) and n.startswith("N")
    if not has_prefix:
        return int(n)
    suffix = n[1:]
    try:
        return int(suffix)
    except Exception:
        return n

def build_df(records: List[Dict[str,Any]], entities_map: Dict[str,Any]):
    """Flatten C_prior records into a DataFrame with one row per usable record.

    Endpoints come from "i"/"j"; when either is absent they are taken from
    "node_pair", or else from the first and last element of "path". Records
    without resolvable endpoints are dropped. Each row carries the normalized
    endpoints, their "N<int>" lookup keys, entity names when known, the
    score, the p-value, the derived weight ``w = clip(1 - p_value, 0, 1)``,
    domain, graph id and the raw record.
    """
    def _entity_name(key):
        # unknown keys yield None rather than raising
        if key in entities_map:
            return entities_map.get(key, {}).get("name")
        return None

    table = []
    for rec in records:
        src, dst = rec.get("i"), rec.get("j")
        if src is None or dst is None:
            pair = rec.get("node_pair")
            if isinstance(pair, list) and len(pair) >= 2:
                src, dst = pair[0], pair[1]
            else:
                route = rec.get("path")
                if isinstance(route, list) and len(route) >= 2:
                    src, dst = route[0], route[-1]
        if src is None or dst is None:
            continue
        # fall back to the raw labels if either endpoint cannot be normalized
        try:
            inum = normalize_node(src)
            jnum = normalize_node(dst)
        except Exception:
            inum, jnum = src, dst
        finalscore = float(rec.get("finalscore", rec.get("pos_score", 0.0) or 0.0))
        pval = float(rec.get("p_value", rec.get("p", 1.0)))
        w = max(0.0, min(1.0, 1.0 - pval))
        key_i = f"N{int(inum)}" if isinstance(inum, int) else str(inum)
        key_j = f"N{int(jnum)}" if isinstance(jnum, int) else str(jnum)
        table.append({
            "i": inum,
            "j": jnum,
            "i_key": key_i,
            "j_key": key_j,
            "i_name": _entity_name(key_i),
            "j_name": _entity_name(key_j),
            "finalscore": finalscore,
            "p_value": pval,
            "w": w,
            "domain": rec.get("domain", "unknown"),
            "graph_id": rec.get("graph_id", "unknown"),
            "raw": rec,
        })
    return pd.DataFrame(table)

def expected_calibration_error(y_true, probs, n_bins=10):
    """Compute the expected calibration error over equal-width probability bins.

    Args:
        y_true: array-like of 0/1 labels.
        probs: array-like of predicted probabilities.
        n_bins: number of equal-width bins spanning [0, 1].

    Returns:
        ``(ece, info)`` where ``ece`` is the count-weighted mean of
        ``|mean(prob) - mean(label)|`` over non-empty bins, and ``info`` holds
        one dict per bin (``bin``, ``count``, ``avg_prob``, ``avg_true``,
        ``abs_err``; the last three are None for empty bins).
    """
    y_true = np.asarray(y_true, dtype=float)
    probs = np.asarray(probs, dtype=float)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    # np.digitize puts a probability of exactly 1.0 past the last edge (index
    # n_bins), which dropped those samples from every bin. Clipping folds them
    # into the top bin, and values below 0 into the bottom bin.
    bin_indices = np.clip(np.digitize(probs, bins) - 1, 0, n_bins - 1)
    total = len(probs)
    ece = 0.0
    info = []
    for b in range(n_bins):
        mask = bin_indices == b
        count = int(mask.sum())
        if count == 0:
            info.append({"bin": b, "count": 0, "avg_prob": None, "avg_true": None, "abs_err": None})
            continue
        avg_prob = float(probs[mask].mean())
        avg_true = float(y_true[mask].mean())
        abs_err = abs(avg_prob - avg_true)
        ece += (count / total) * abs_err
        info.append({"bin": b, "count": count, "avg_prob": avg_prob, "avg_true": avg_true, "abs_err": abs_err})
    return float(ece), info

def topk_precision(df, k_list=(1, 5, 10, 20, 50), score_col="finalscore", gt_edges=frozenset()):
    """Precision among the k highest-scoring rows, for each k in *k_list*.

    Args:
        df: DataFrame with integer-convertible "i" and "j" columns and *score_col*.
        k_list: cut-offs to evaluate (any iterable of ints).
        score_col: column to rank by, descending.
        gt_edges: set of ``(i, j)`` tuples counted as correct.

    Returns:
        One dict per k with ``k``, ``precision`` (None when no rows are
        available) and ``num`` (rows actually evaluated, at most k).

    The defaults are immutable so they cannot be shared and mutated across calls.
    """
    df_sorted = df.sort_values(score_col, ascending=False).reset_index(drop=True)
    out = []
    for k in k_list:
        topk = df_sorted.head(k)
        num = len(topk)
        if num == 0:
            out.append({"k": k, "precision": None, "num": 0})
            continue
        # Column-wise iteration avoids building a Series per row.
        correct = sum(1 for i, j in zip(topk["i"], topk["j"]) if (int(i), int(j)) in gt_edges)
        out.append({"k": k, "precision": float(correct / num), "num": int(num)})
    return out

# Fail fast when a required input artifact is missing.
if not C_PRIOR_PATH.exists():
    raise FileNotFoundError(f"C_prior not found: {C_PRIOR_PATH}")
if not ENTITIES_PATH.exists():
    raise FileNotFoundError(f"entities file not found: {ENTITIES_PATH}")
if not G_TRUE_PATH.exists():
    raise FileNotFoundError(f"G_true not found: {G_TRUE_PATH}")

c_records = load_jsonl(C_PRIOR_PATH)
entities_map = load_entities_map(ENTITIES_PATH)
# nx.read_gpickle was removed in NetworkX 3.0; a gpickle file is a plain pickle,
# so fall back to pickle.load there.
# SECURITY: unpickling executes arbitrary code -- only load a trusted G_true file.
if hasattr(nx, "read_gpickle"):
    G_true = nx.read_gpickle(G_TRUE_PATH)
else:
    import pickle
    with G_TRUE_PATH.open("rb") as _fh:
        G_true = pickle.load(_fh)
if not isinstance(G_true, nx.DiGraph):
    G_true = nx.DiGraph(G_true)

df = build_df(c_records, entities_map)
n_records = len(df)
print(f"Loaded {n_records} C_prior records")


def _gt_node_to_int(node):
    """Parse a ground-truth node label (int-like or "N<int>") to int; None if impossible."""
    try:
        return int(node)
    except (TypeError, ValueError):
        pass
    if isinstance(node, str) and node.startswith("N"):
        try:
            return int(node[1:])
        except ValueError:
            return None
    return None


# build ground-truth edge set; each endpoint is parsed on its own so edges that
# mix plain and "N"-prefixed labels are kept, and unparseable edges are skipped
gt_edges = set()
for u, v in G_true.edges():
    uu = _gt_node_to_int(u)
    vv = _gt_node_to_int(v)
    if uu is None or vv is None:
        continue
    gt_edges.add((uu, vv))

# labels and scores (one pass over the rows instead of four)
_rows = df.to_dict("records")
y_true = np.array([(1 if (int(r["i"]), int(r["j"])) in gt_edges else 0) for r in _rows], dtype=int)
scores = np.array([float(r["finalscore"]) for r in _rows], dtype=float)
weights = np.array([float(r["w"]) for r in _rows], dtype=float)
domains = np.array([r["domain"] for r in _rows], dtype=object)

# metrics (ranking metrics need both classes present)
results = {}
if len(scores)>0 and len(np.unique(y_true))>1:
    try:
        results["roc_auc"] = float(roc_auc_score(y_true, scores))
    except Exception:
        results["roc_auc"] = None
    try:
        results["average_precision"] = float(average_precision_score(y_true, scores))
    except Exception:
        results["average_precision"] = None
else:
    results["roc_auc"] = None
    results["average_precision"] = None

# threshold metrics
thresholds = [0.5, 0.7, 0.9]
thr_metrics = {}
for t in thresholds:
    preds = (scores >= t).astype(int)
    if len(np.unique(preds))==1 and len(np.unique(y_true))==1:
        prec = recall = f1 = None
    else:
        prec = float(precision_score(y_true, preds, zero_division=0))
        recall = float(recall_score(y_true, preds, zero_division=0))
        f1 = float(f1_score(y_true, preds, zero_division=0))
    thr_metrics[t] = {"precision": prec, "recall": recall, "f1": f1}
results["threshold_metrics"] = thr_metrics

# calibration
try:
    results["brier_score"] = float(brier_score_loss(y_true, scores))
except Exception:
    results["brier_score"] = None
ece, calib_info = expected_calibration_error(y_true, scores, n_bins=10)
results["ece"] = ece
pd.DataFrame(calib_info).to_csv(OUT_DIR / "calibration.csv", index=False)

# top-k
tk = topk_precision(df, k_list=[1,5,10,20,50], score_col="finalscore", gt_edges=gt_edges)
pd.DataFrame(tk).to_csv(OUT_DIR / "topk.csv", index=False)
results["topk"] = tk

# per-domain
per_domain = []
for dom in sorted(set(domains)):
    mask = domains == dom
    if mask.sum()==0: continue
    y_d = y_true[mask]; s_d = scores[mask]
    entry = {"domain": dom, "n": int(mask.sum())}
    if len(np.unique(y_d))>1:
        try:
            entry["roc_auc"] = float(roc_auc_score(y_d, s_d))
        except Exception:
            entry["roc_auc"] = None
        try:
            entry["average_precision"] = float(average_precision_score(y_d, s_d))
        except Exception:
            entry["average_precision"] = None
    else:
        entry["roc_auc"] = None; entry["average_precision"] = None
    preds = (s_d >= 0.5).astype(int)
    entry["precision_0.5"] = float(precision_score(y_d, preds, zero_division=0))
    entry["recall_0.5"] = float(recall_score(y_d, preds, zero_division=0))
    entry["f1_0.5"] = float(f1_score(y_d, preds, zero_division=0))
    per_domain.append(entry)
pd.DataFrame(per_domain).to_csv(OUT_DIR / "per_domain.csv", index=False)
results["per_domain"] = per_domain

# confidence vs correctness
correctness = (y_true == 1).astype(float)
corr = None
if len(correctness)>1:
    _c = float(np.corrcoef(weights, correctness)[0,1])
    # corrcoef is NaN when either series is constant; NaN is not valid JSON,
    # so report it as None in the summary.
    corr = _c if np.isfinite(_c) else None
results["confidence_correctness_corr"] = corr

# high-confidence set summary
hc_mask = weights >= 0.9
hc_n = int(hc_mask.sum())
hc_prec = float(correctness[hc_mask].mean()) if hc_n>0 else None
results["high_conf"] = {"n": hc_n, "precision": hc_prec}

# optional overlap with a previous high-confidence set, if one was placed in OUT_DIR
prev_h = OUT_DIR / "H_round_prev.jsonl"
jaccard_vs_prev = None
if prev_h.exists():
    prev_list = load_jsonl(prev_h)
    H_prev = set(f"{int(x['i'])}::{int(x['j'])}" for x in prev_list)
    H_cur = set(f"{int(r['i'])}::{int(r['j'])}" for r in df[df['w']>=0.9].to_dict("records"))
    if len(H_cur | H_prev) == 0:
        jaccard_vs_prev = 1.0
    else:
        jaccard_vs_prev = float(len(H_cur & H_prev) / len(H_cur | H_prev))
results["jaccard_vs_prev_high_conf"] = jaccard_vs_prev

# summary
summary = {
    "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "n_records": int(n_records),
    # y_true already marks the records whose pair is a ground-truth edge
    "n_positive_gt_in_records": int(y_true.sum()),
    "metrics": results
}
(OUT_DIR / "eval_summary.json").write_text(json.dumps(summary, indent=2))
print("EVALUATION SUMMARY")
print(json.dumps(summary, indent=2))
print("Artifacts written to:", OUT_DIR.resolve())
print("Calibration CSV:", OUT_DIR / "calibration.csv")
print("Per-domain CSV:", OUT_DIR / "per_domain.csv")
print("Top-K CSV:", OUT_DIR / "topk.csv")


### Comparing $C_{prior}$ against a few causal baselines

In [None]:

from __future__ import annotations
import json, math, time, os, random
from pathlib import Path
from typing import List, Dict, Any, Tuple
import numpy as np
import pandas as pd
import networkx as nx

# ML / causal libs 
try:
    from causallearn.search.ConstraintBased.PC import pc  # causal-learn PC entrypoint
    from causallearn.utils.GraphUtils import GraphUtils
    HAS_CAUSALLEARN = True
except Exception:
    HAS_CAUSALLEARN = False

try:
    import lingam
    HAS_LINGAM = True
except Exception:
    HAS_LINGAM = False

try:
    import causalnex.structure.notears as cn_notears
    HAS_CAUSALNEX = True
except Exception:
    HAS_CAUSALNEX = False

try:
    import statsmodels.api as sm
    from statsmodels.tsa.stattools import grangercausalitytests
    HAS_STATSMODELS = True
except Exception:
    HAS_STATSMODELS = False

try:
    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
    HAS_TRANSFORMERS = True
except Exception:
    HAS_TRANSFORMERS = False

try:
    from sentence_transformers import SentenceTransformer
    HAS_SBERT = True
except Exception:
    HAS_SBERT = False

try:
    import cdt
    HAS_CDT = True
except Exception:
    HAS_CDT = False

try:
    import sklearn
    from sklearn.metrics import roc_auc_score, average_precision_score, precision_score, recall_score, f1_score, brier_score_loss
    from sklearn.metrics import ndcg_score
    from sklearn.model_selection import train_test_split
    HAS_SKLEARN = True
except Exception:
    HAS_SKLEARN = False

# Gemini wrapper for LLM baselines 
try:
    import google.generativeai as genai
    HAS_GEMINI = True
except Exception:
    HAS_GEMINI = False

# -------------------- Configuration --------------------
DCORPUS_DIR = Path("d_corpus_output")       # directory with D_corpus JSONL files
G_TRUE_PATH = Path("graphs/G_true.gpickle") # ground-truth graph
C_PRIOR_PATH = Path("out/e_step/C_prior_round_1.jsonl")  # C_prior records to compare against the baselines
OUT_DIR = Path("out/baseline_eval")  # destination for baseline evaluation artifacts
OUT_DIR.mkdir(parents=True, exist_ok=True)  # created eagerly at cell execution time

MODEL_SBERT = "all-MiniLM-L6-v2"  # for entity/document embeddings

# Top-K list: cut-offs used for top-k precision in evaluate_scores
TOPK_LIST = [1,5,10,20,50,100]

# -------------------- Utilities --------------------
def load_jsonl(path: Path) -> List[Dict[str,Any]]:
    """Read a UTF-8 JSON-lines file; a missing file yields an empty list, blank lines are skipped."""
    if not path.exists():
        return []
    with path.open("r", encoding="utf8") as fh:
        payloads = [line.strip() for line in fh]
    return [json.loads(text) for text in payloads if text]

def load_d_corpus_records(dcorpus_dir: Path) -> List[Dict[str,Any]]:
    """Concatenate the records of every ``*.jsonl`` file in *dcorpus_dir*, in sorted file order.

    Raises:
        FileNotFoundError: if the directory does not exist.
    """
    if not dcorpus_dir.exists():
        raise FileNotFoundError(f"D_corpus dir {dcorpus_dir} not found.")
    collected = []
    for shard in sorted(dcorpus_dir.glob("*.jsonl")):
        collected.extend(load_jsonl(shard))
    return collected

def build_pair_df_from_records(records: List[Dict[str,Any]], entities_map: Dict[str,Any], require_unique_pairs=False) -> pd.DataFrame:
    """Flatten corpus / C_prior records into one row per node pair.

    Endpoints are taken from "node_pair", else from the first and last node of
    "path" (a list of paths uses its first path), else from explicit "i"/"j"
    fields. Labels may be ints or "N<int>" strings; records whose endpoints
    cannot be resolved to ints are skipped.

    Args:
        records: raw record dicts.
        entities_map: accepted for interface compatibility; not used here.
        require_unique_pairs: keep only the first row per (i, j) when True.

    Returns:
        DataFrame with columns i, j, finalscore, p_value, domain, graph_id,
        pos_snippet, neg_snippet, raw. The columns exist even when no record
        qualifies, so de-duplicating an empty frame no longer raises KeyError.
    """
    columns = ["i", "j", "finalscore", "p_value", "domain", "graph_id",
               "pos_snippet", "neg_snippet", "raw"]
    rows = []
    for rec in records:
        # determine i,j
        node_pair = rec.get("node_pair")
        path = rec.get("path")
        if isinstance(node_pair, (list, tuple)) and len(node_pair) >= 2:
            i, j = node_pair[0], node_pair[1]
        elif isinstance(path, list) and len(path) >= 2:
            if isinstance(path[0], list):  # sometimes paths is list of paths
                path = path[0]
            if not path:
                # an empty first path has no endpoints
                continue
            i, j = path[0], path[-1]
        elif rec.get("i") is not None and rec.get("j") is not None:
            # explicit endpoints, as the other C_prior readers in this notebook accept
            i, j = rec["i"], rec["j"]
        else:
            continue
        try:
            i_int = int(str(i).lstrip("N"))
            j_int = int(str(j).lstrip("N"))
        except (TypeError, ValueError):
            continue
        finalscore = float(rec.get("finalscore", rec.get("pos_score", 0.0) or 0.0))
        pval = float(rec.get("p_value", rec.get("p", 1.0)))
        rows.append({
            "i": i_int, "j": j_int, "finalscore": finalscore, "p_value": pval,
            "domain": rec.get("domain","unknown"), "graph_id": rec.get("graph_id","unknown"),
            "pos_snippet": rec.get("pos_snippet",""), "neg_snippet": rec.get("neg_snippet",""),
            "raw": rec
        })
    df = pd.DataFrame(rows, columns=columns)
    if require_unique_pairs:
        df = df.drop_duplicates(subset=["i","j"])
    return df

def load_entities(path: Path) -> Dict[str,Any]:
    """Load entity metadata from *path* and return it as an ``{id: entity}`` dict.

    Accepted layouts:
      * a JSON object that already maps ids to entities (returned as is),
      * a JSON list of entity objects (each must have an "id" key),
      * JSON lines, one entity object per line -- including a file that
        holds a single entity on a single line.

    A missing or empty file, or a JSON document of any other top-level type,
    yields ``{}`` so the declared return type always holds.

    Raises:
        KeyError: if an entity in a list or JSONL file has no "id".
        json.JSONDecodeError: if a JSONL line is not valid JSON.
    """
    if not path.exists():
        return {}
    txt = path.read_text(encoding="utf8").strip()
    if not txt:
        return {}
    try:
        obj = json.loads(txt)
    except json.JSONDecodeError:
        # fallback to jsonl per-line
        out = {}
        for ln in txt.splitlines():
            if not ln.strip():
                continue
            e = json.loads(ln)
            out[e["id"]] = e
        return out
    if isinstance(obj, dict):
        # A one-line JSONL file parses as a plain dict. A scalar "id" field
        # marks it as a single entity record rather than an id -> entity map.
        if "id" in obj and not isinstance(obj["id"], (dict, list)):
            return {obj["id"]: obj}
        return obj
    if isinstance(obj, list):
        return {e["id"]: e for e in obj}
    # Scalar top-level JSON (number, string, ...): nothing usable.
    return {}

def evaluate_scores(df_pairs: pd.DataFrame, score_col: str, gt_edges:set, out_prefix: Path):
    """Evaluate one score column against the ground-truth edge set and save the report.

    A pair (i, j) is a positive when ``(int(i), int(j))`` is in ``gt_edges``.

    Args:
        df_pairs: Frame with integer-like columns ``i`` and ``j`` plus ``score_col``.
        score_col: Name of the column holding the scores to evaluate.
        gt_edges: Set of ``(i, j)`` int tuples regarded as true edges.
        out_prefix: Output path; the report is written next to it with a
            ``.json`` suffix (parent directories are created).

    Returns:
        dict with keys ``auc`` / ``ap`` (None when only one class is present),
        ``thresholds`` (precision/recall/F1 at 0.5, 0.7, 0.9), ``topk``
        (precision at each k in the module-level ``TOPK_LIST``), ``ece`` and
        ``calibration_table`` (ten equal-width bins over [0, 1]).
    """
    # Labels and scores are computed once and reused by every metric below.
    y_true = np.array(
        [1 if (int(i), int(j)) in gt_edges else 0
         for i, j in zip(df_pairs["i"], df_pairs["j"])],
        dtype=int,
    )
    scores = df_pairs[score_col].astype(float).to_numpy()
    res = {}
    if len(np.unique(y_true))>1:
        res["auc"] = float(roc_auc_score(y_true, scores))
        res["ap"] = float(average_precision_score(y_true, scores))
    else:
        res["auc"] = None; res["ap"]=None
    # thresholds
    thr_metrics = {}
    for thr in (0.5, 0.7, 0.9):
        preds = (scores>=thr).astype(int)
        thr_metrics[thr] = {
            "precision": float(precision_score(y_true, preds, zero_division=0)),
            "recall": float(recall_score(y_true, preds, zero_division=0)),
            "f1": float(f1_score(y_true, preds, zero_division=0))
        }
    res["thresholds"] = thr_metrics
    # top-k precision: sort through pandas exactly as before so ties rank the
    # same way, but carry the precomputed labels along instead of re-deriving them.
    label_col = "__gt_label__"
    ranked_labels = (
        df_pairs.assign(**{label_col: y_true})
        .sort_values(score_col, ascending=False)[label_col]
        .to_numpy()
    )
    topk_res=[]
    for k in TOPK_LIST:
        top = ranked_labels[:k]
        if len(top)==0:
            topk_res.append({"k":k,"precision":None})
            continue
        topk_res.append({"k":k,"precision":float(top.sum()/len(top))})
    res["topk"] = topk_res
    # calibration (ECE over ten equal-width bins)
    n_bins = 10
    bins = np.linspace(0.0,1.0,n_bins+1)
    # digitize() sends a score of exactly 1.0 to index n_bins, which used to
    # fall outside the loop below; clip so the last bin is closed on the right
    # (out-of-range scores land in the nearest edge bin).
    bin_idx = np.clip(np.digitize(scores, bins) - 1, 0, n_bins - 1)
    # Non-finite scores cannot be binned meaningfully; leave them out.
    bin_idx[~np.isfinite(scores)] = -1
    ece = 0.0
    calib_table=[]
    for b in range(n_bins):
        mask = bin_idx==b
        count = int(mask.sum())
        if count==0:
            calib_table.append({"bin":b,"count":0,"avg_prob":None,"avg_true":None})
            continue
        avg_prob = float(scores[mask].mean())
        avg_true = float(y_true[mask].mean())
        ece += (count/len(scores))*abs(avg_prob-avg_true)
        calib_table.append({"bin":b,"count":count,"avg_prob":avg_prob,"avg_true":avg_true})
    res["ece"]=float(ece)
    res["calibration_table"]=calib_table
    # save
    out_prefix.with_suffix("").parent.mkdir(parents=True, exist_ok=True)
    json_path = out_prefix.with_suffix(".json")
    json_path.write_text(json.dumps(res, indent=2))
    return res

# -------------------- Load inputs --------------------
# Fail fast when the corpus is missing or empty: nothing below can run without it.
if not DCORPUS_DIR.exists():
    raise FileNotFoundError(f"D_corpus dir not found: {DCORPUS_DIR}")
d_records = load_d_corpus_records(DCORPUS_DIR)
if len(d_records)==0:
    raise RuntimeError("No d_corpus records found in " + str(DCORPUS_DIR))

# The entity table is optional; a missing/empty file simply means "no entities".
entities_path = DCORPUS_DIR / "entities.json"
entities_map = (load_entities(entities_path) if entities_path.exists() else {}) or {}

if not G_TRUE_PATH.exists():
    raise FileNotFoundError(f"G_true not found at {G_TRUE_PATH}")
if hasattr(nx, "read_gpickle"):
    G_true = nx.read_gpickle(G_TRUE_PATH)
else:
    # networkx >= 3.0 removed read_gpickle; a .gpickle file is a plain pickle.
    # NOTE(review): pickle executes arbitrary code on load - only use trusted files.
    import pickle
    with open(G_TRUE_PATH, "rb") as _gpickle_fh:
        G_true = pickle.load(_gpickle_fh)
if not isinstance(G_true, nx.DiGraph):
    G_true = nx.DiGraph(G_true)

# Build canonical pair dataframe (unique pairs)
df_pairs = build_pair_df_from_records(d_records, entities_map, require_unique_pairs=True)
if df_pairs.empty:
    raise RuntimeError("No pairs parsed from D_corpus records.")

# Ground-truth edges as (int, int) tuples, the form evaluate_scores looks up.
gt_edges = set((int(u),int(v)) for u,v in G_true.edges())

# -------------------- Embedding utilities --------------------
# Everything below needs a sentence encoder, so stop here when the
# sentence-transformers availability flag is falsy.
if not HAS_SBERT:
    raise ImportError("Please install sentence-transformers (pip install sentence-transformers) for embeddings baseline.")
# Module-level encoder shared by compute_entity_embeddings and pair_to_feature_vector.
embedder = SentenceTransformer(MODEL_SBERT)

def compute_entity_embeddings(entities_map: Dict[str,Any]) -> Dict[str, np.ndarray]:
    """Embed every entity as ``"<name> . <description>"`` with the shared ``embedder``.

    A missing or null ``name`` / ``description`` is treated as an empty string
    (previously a null ``name`` raised TypeError while a null ``description``
    was tolerated).

    Args:
        entities_map: Mapping of entity key -> entity dict.

    Returns:
        Mapping of the same keys -> embedding, or ``{}`` when there are no entities.
    """
    if not entities_map:
        return {}
    keys = list(entities_map)
    texts = [
        str(ent.get("name") or "") + " . " + str(ent.get("description") or "")
        for ent in entities_map.values()
    ]
    # One batched encode call for all entities.
    embs = embedder.encode(texts, convert_to_numpy=True, show_progress_bar=False)
    return dict(zip(keys, embs))

# Embed all known entities once up front; keyed by the same keys as entities_map
# (empty when no entity table was loaded).
entity_emb_map = compute_entity_embeddings(entities_map)

def pair_to_feature_vector(i:int, j:int, df_pairs_row:Dict[str,Any]) -> np.ndarray:
    """Return the feature vector of pair (i, j).

    When both ``N<i>`` and ``N<j>`` have an entry in ``entity_emb_map`` the
    result is the two entity embeddings concatenated. Otherwise the row's
    ``pos_snippet`` (empty string if missing/falsy) is encoded and that single
    vector is concatenated with itself.
    """
    try:
        halves = [entity_emb_map[f"N{i}"], entity_emb_map[f"N{j}"]]
    except KeyError:
        # At least one endpoint has no entity embedding: fall back to the snippet text.
        snippet = df_pairs_row.get("pos_snippet","") or ""
        encoded = embedder.encode([snippet], convert_to_numpy=True, show_progress_bar=False)[0]
        halves = [encoded, encoded]
    return np.concatenate(halves)

# Feature matrix for the structure-based baselines: one row per entry of
# df_pairs, in df_pairs order; pair_keys[k] is the (i, j) pair of row k.
feat_dim = None
pair_keys = [(int(src), int(dst)) for src, dst in zip(df_pairs["i"], df_pairs["j"])]
pair_X = np.vstack([
    pair_to_feature_vector(src, dst, row)
    for (src, dst), (_, row) in zip(pair_keys, df_pairs.iterrows())
])
feat_dim = pair_X.shape[1]

# -------------------- Baseline 1: PC algorithm (causal-learn) --------------------
def baseline_pc(X: np.ndarray, pair_keys: List[Tuple[int,int]], variable_names: List[str]=None):
    """Produce the ``pc_score`` column for the module-level ``df_pairs``.

    The PC algorithm is run on ``X`` (passed as-is; the original author's note
    says causal-learn takes samples x variables), but the learned graph is NOT
    used for scoring. The score is a proxy: each row's mean cosine similarity
    to all rows of ``X``, min-max scaled to [0, 1] (all zeros when every row
    has the same mean similarity).

    ``pair_keys`` and ``variable_names`` are accepted but not read.

    Returns:
        A copy of ``df_pairs`` with a ``pc_score`` column.

    Raises:
        ImportError: when ``HAS_CAUSALLEARN`` is falsy.
    """
    if not HAS_CAUSALLEARN:
        raise ImportError("Install 'causal-learn' (pip install causal-learn) to run PC baseline.")
    samples = X
    try:
        learned = pc(samples, alpha=0.01, indep_test='fisherz')
    except Exception:
        # Retry without naming the independence test.
        learned = pc(samples, alpha=0.01)
    # The adjacency is read but deliberately unused by the proxy score below.
    _adjacency = learned.G.graph
    from sklearn.metrics.pairwise import cosine_similarity
    # NOTE: builds a dense N x N similarity matrix - heavy for large N.
    mean_sim = cosine_similarity(X).mean(axis=1)
    lowest, highest = mean_sim.min(), mean_sim.max()
    spread = highest - lowest
    scaled = (mean_sim - lowest) / spread if spread > 0 else np.zeros_like(mean_sim)
    scored = df_pairs.copy()
    scored["pc_score"] = scaled.tolist()
    return scored

# -------------------- Baseline 2: GES (causal-learn) --------------------
def baseline_ges(X: np.ndarray):
    """Produce the ``ges_score`` column for the module-level ``df_pairs``.

    GES is run on ``X`` but its result is not used for scoring; as with the PC
    baseline, the score is each row's mean cosine similarity to all rows of
    ``X``, min-max scaled (with a 1e-12 guard in the denominator).

    Returns:
        A copy of ``df_pairs`` with a ``ges_score`` column.

    Raises:
        ImportError: when ``HAS_CAUSALLEARN`` is falsy.
    """
    if not HAS_CAUSALLEARN:
        raise ImportError("Install 'causal-learn' to run GES baseline.")
    from causallearn.search.ScoreBased.GES import ges
    # Result intentionally unused: the score below is a similarity proxy.
    _learned = ges(X)
    from sklearn.metrics.pairwise import cosine_similarity
    mean_sim = cosine_similarity(X).mean(axis=1)
    lowest = mean_sim.min()
    scaled = (mean_sim - lowest) / (mean_sim.max() - lowest + 1e-12)
    scored = df_pairs.copy()
    scored["ges_score"] = scaled.tolist()
    return scored

# -------------------- Baseline 3: NOTEARS (causalnex) --------------------
def _notears_weight_matrix(smodel: Any, columns: List[str]) -> np.ndarray:
    """Dense ``[source, target]`` weight matrix from a graph's ``weight`` edge attribute."""
    index = {name: pos for pos, name in enumerate(columns)}
    weights = np.zeros((len(columns), len(columns)), dtype=float)
    for src, dst, w in smodel.edges(data="weight"):
        if src in index and dst in index and w is not None:
            weights[index[src], index[dst]] = float(w)
    return weights


def baseline_notears(pd_df_features: pd.DataFrame):
    """Produce the ``notears_score`` column for the module-level ``df_pairs``.

    NOTEARS is fitted on the module-level ``pair_X`` with one variable per
    feature column (``f_0`` ... ``f_{d-1}``). Each pair's score is
    ``|features| @ incoming_weight`` where ``incoming_weight[c]`` is the sum of
    absolute learned weights into variable ``c``; scores are min-max scaled.

    NOTE(review): ``pd_df_features`` is accepted for interface compatibility
    but is not read; the features always come from ``pair_X``.

    Returns:
        A copy of ``df_pairs`` with a ``notears_score`` column.

    Raises:
        ImportError: when ``HAS_CAUSALNEX`` is falsy.
    """
    if not HAS_CAUSALNEX:
        raise ImportError("Install 'causalnex' (pip install causalnex) to run NOTEARS baseline.")
    columns = [f"f_{k}" for k in range(pair_X.shape[1])]
    Xdf = pd.DataFrame(pair_X, columns=columns)
    smodel = cn_notears.from_pandas(Xdf, max_iter=100, w_threshold=0.0)
    legacy = getattr(smodel, "structure", None)
    if legacy is not None:
        # Keep the previous code path for any object that does expose ``.structure``.
        W = np.asarray(getattr(legacy, "values", legacy), dtype=float)
    else:
        # from_pandas returns a graph object (StructureModel); read the learned
        # weights from its edges instead of a non-existent ``.structure`` table.
        W = _notears_weight_matrix(smodel, columns)
    col_weight = np.abs(W).sum(axis=0)
    scores = np.abs(pair_X) @ col_weight
    score_norm = (scores - scores.min()) / (scores.max()-scores.min()+1e-12)
    df_out = df_pairs.copy()
    df_out["notears_score"] = score_norm.tolist()
    return df_out

# -------------------- Baseline 4: LiNGAM --------------------
def baseline_lingam(X: np.ndarray):
    """Produce the ``lingam_score`` column for the module-level ``df_pairs``.

    DirectLiNGAM is fitted on ``X`` with feature columns as variables. Each
    row's score is ``|X[row]| @ w`` where ``w`` is the column-wise sum of the
    absolute adjacency matrix; scores are min-max scaled (1e-12 guard).

    Returns:
        A copy of ``df_pairs`` with a ``lingam_score`` column.

    Raises:
        ImportError: when ``HAS_LINGAM`` is falsy.
    """
    if not HAS_LINGAM:
        raise ImportError("Install 'lingam' (pip install lingam) to run LiNGAM baseline.")
    estimator = lingam.DirectLiNGAM()
    estimator.fit(X)
    # Collapse the feature-by-feature adjacency into one weight per feature column.
    per_feature = np.abs(estimator.adjacency_matrix_).sum(axis=0)
    raw = np.abs(X) @ per_feature
    lowest = raw.min()
    scaled = (raw - lowest) / (raw.max() - lowest + 1e-12)
    scored = df_pairs.copy()
    scored["lingam_score"] = scaled.tolist()
    return scored

# -------------------- Baseline 5: Granger (requires timeseries) --------------------
def baseline_granger_from_time_series(entity_time_series: Dict[int, np.ndarray], maxlag:int=3):
    """Produce the ``granger_score`` column for the module-level ``df_pairs``.

    For each pair (i, j) the two series ``N<j>`` and ``N<i>`` are passed, in
    that column order, to ``grangercausalitytests``; the score is
    ``1 - min p-value`` of the ``ssr_ftest`` over lags 1..``maxlag``. Pairs
    with a missing series, or whose test raises, score 0.0.

    Args:
        entity_time_series: node id -> 1-D series (same length for all nodes).
        maxlag: Largest lag tested.

    Returns:
        A copy of ``df_pairs`` with a ``granger_score`` column.

    Raises:
        ImportError: when ``HAS_STATSMODELS`` is falsy.
    """
    if not HAS_STATSMODELS:
        raise ImportError("Install statsmodels to run Granger baseline.")
    node_ids = sorted(entity_time_series.keys())
    # Kept from the original: the length itself is unused, but this lookup
    # raises on an empty mapping.
    _series_len = len(next(iter(entity_time_series.values())))
    series_frame = pd.DataFrame({f"N{n}": entity_time_series[n] for n in node_ids})

    def _score(src, dst) -> float:
        src_col, dst_col = f"N{src}", f"N{dst}"
        if src_col not in series_frame.columns or dst_col not in series_frame.columns:
            return 0.0
        try:
            tests = grangercausalitytests(series_frame[[dst_col,src_col]], maxlag=maxlag, verbose=False)
            # p-value of the F-test at each lag
            return 1.0 - float(min(tests[lag][0]["ssr_ftest"][1] for lag in tests.keys()))
        except Exception:
            return 0.0

    scored = df_pairs.copy()
    scored["granger_score"] = [_score(i, j) for (i, j) in zip(df_pairs["i"], df_pairs["j"])]
    return scored

# -------------------- Baseline 6: LLM Zero-shot & Chain-of-Thought (Gemini) --------------------
def _gemini_candidate_text(resp: Any) -> str:
    """Return the ``content`` text of the first candidate in ``resp``, or "" if there is none."""
    if isinstance(resp, dict):
        candidates = resp.get("candidates")
    else:
        candidates = getattr(resp, "candidates", None)
    if not candidates:
        # Covers a missing attribute, None and an empty list (which used to raise IndexError).
        return ""
    first = candidates[0]
    content = first.get("content", "") if isinstance(first, dict) else getattr(first, "content", "")
    return content if isinstance(content, str) else ""


def _parse_llm_unit_score(txt: str) -> float:
    """Extract a score from an LLM reply and clamp it to [0, 1].

    Tries the ``{"score": ...}`` JSON object first, then the first number-like
    token in the text; returns 0.0 when nothing usable is found.
    """
    import re
    score = None
    start, end = txt.find("{"), txt.rfind("}")
    if 0 <= start < end:
        try:
            score = float(json.loads(txt[start:end+1]).get("score",0.0))
        except (ValueError, TypeError, AttributeError):
            score = None
    if score is None:
        m = re.search(r"([0-1]?\.\d+|0|1)", txt)
        score = float(m.group(1)) if m else 0.0
    if score != score:
        # NaN would slip through min/max unchanged or as 1.0; treat it as "no signal".
        return 0.0
    # Clamp on BOTH parse paths: the regex can match e.g. "1.5".
    return max(0.0,min(1.0,score))


def llm_score_pairs_gemini(df_pairs: pd.DataFrame, model_name: str="gemini-2.5-flash-lite", prompt_template: str=None, sleep_between_calls: float=0.0):
    """Ask the LLM, once per pair, how strongly ``pos_snippet`` supports "i causes j".

    Args:
        df_pairs: Frame with columns ``i``, ``j`` and ``pos_snippet``.
        model_name: Model identifier passed to the SDK.
        prompt_template: Optional ``str.format`` template with ``{i}``, ``{j}``
            and ``{snippet}`` fields; a built-in prompt is used when None.
        sleep_between_calls: Seconds to sleep before each request (0 disables).

    Returns:
        A copy of ``df_pairs`` with an ``llm_gemini_score`` column in [0, 1];
        pairs with an empty or unparseable reply score 0.0.

    Raises:
        ImportError: when ``HAS_GEMINI`` is falsy.

    NOTE(review): ``genai.generate_text`` looks like the legacy text endpoint of
    the SDK; confirm that it serves the configured Gemini model names.
    """
    if not HAS_GEMINI:
        raise ImportError("Install/Configure Google Gemini SDK 'google.generativeai' and set credentials.")
    genai.configure()
    scores=[]
    for _, r in df_pairs.iterrows():
        a = r.pos_snippet or ""
        # simple prompt that asks for a score 0-1
        if prompt_template is None:
            prompt = (
                f"Given the short evidence sentence:\n\n{a}\n\n"
                f"Question: On a scale 0.0 to 1.0, how strongly does this sentence support that node {r.i} causes node {r.j}? "
                "Return only a JSON: {\"score\": <float>}."
            )
        else:
            prompt = prompt_template.format(i=r.i, j=r.j, snippet=a)
        if sleep_between_calls>0:
            time.sleep(sleep_between_calls)
        resp = genai.generate_text(model=model_name, prompt=prompt, max_output_tokens=80, temperature=0.0)
        txt = _gemini_candidate_text(resp)
        scores.append(_parse_llm_unit_score(txt) if txt else 0.0)
    df_out = df_pairs.copy()
    df_out["llm_gemini_score"] = scores
    return df_out

# -------------------- Baseline 7: LLM Self-Consistency (multiple samples mean) --------------------
def _sc_candidate_text(resp: Any) -> str:
    """Return the ``content`` text of the first candidate in ``resp``, or "" if there is none."""
    if isinstance(resp, dict):
        candidates = resp.get("candidates")
    else:
        candidates = getattr(resp, "candidates", None)
    if not candidates:
        # Covers a missing attribute, None and an empty list (which used to raise IndexError).
        return ""
    first = candidates[0]
    content = first.get("content", "") if isinstance(first, dict) else getattr(first, "content", "")
    return content if isinstance(content, str) else ""


def llm_self_consistency(df_pairs: pd.DataFrame, model_name="gemini-2.5-flash-lite", samples:int=8):
    """Score each pair as the mean of several sampled LLM answers.

    For every row the model is queried ``samples`` times at temperature 0.7;
    the first number-like token of each reply is clamped to [0, 1] and the
    parsed values are averaged. Rows with no parseable reply score 0.0.

    Args:
        df_pairs: Frame with columns ``i``, ``j`` and ``pos_snippet``.
        model_name: Model identifier passed to the SDK.
        samples: Number of samples drawn per pair.

    Returns:
        A copy of ``df_pairs`` with an ``llm_selfconsistency`` column.

    Raises:
        ImportError: when ``HAS_GEMINI`` is falsy.
    """
    if not HAS_GEMINI:
        raise ImportError("Install/Configure Google Gemini SDK 'google.generativeai' and set credentials.")
    import re
    # Hoisted out of the sampling loop; same pattern as before.
    number_re = re.compile(r"([0-1]?\.\d+|0|1)")
    agg_scores=[]
    for _, r in df_pairs.iterrows():
        scores=[]
        prompt = f"Does {r.pos_snippet} indicate {r.i} causes {r.j}? Give score 0-1."
        for _sample in range(samples):
            resp = genai.generate_text(model=model_name, prompt=prompt, temperature=0.7, max_output_tokens=40)
            txt = _sc_candidate_text(resp)
            if not txt:
                continue
            m = number_re.search(txt)
            if m:
                # The pattern can match values such as "1.5"; keep scores in [0, 1].
                scores.append(max(0.0, min(1.0, float(m.group(1)))))
        agg_scores.append(float(np.mean(scores)) if scores else 0.0)
    df_out = df_pairs.copy()
    df_out["llm_selfconsistency"] = agg_scores
    return df_out

# -------------------- Baseline 8: CausalBERT / Fine-tuned transformer classifier --------------------
def train_causalbert_classifier(training_df: pd.DataFrame, val_df: pd.DataFrame, model_name_or_path="distilbert-base-uncased", out_dir:Path=Path("out/causalbert_model"), epochs:int=2, batch_size:int=16):
    """Fine-tune a single-logit transformer on ``pos_snippet ||| domain`` texts.

    Rows without a ``label`` column are labelled 1 when ``(i, j)`` is in the
    module-level ``gt_edges`` and 0 otherwise. The caller's frames are not
    modified.

    Args:
        training_df: Training rows (columns ``pos_snippet``, ``domain``, ``i``, ``j``).
        val_df: Validation rows, same layout.
        model_name_or_path: Pretrained checkpoint for tokenizer and model.
        out_dir: Directory for checkpoints and the final saved model.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for training and evaluation.

    Returns:
        ``out_dir``, where the trained model was saved.

    Raises:
        ImportError: when ``HAS_TRANSFORMERS`` is falsy.
    """
    if not HAS_TRANSFORMERS:
        raise ImportError("Install transformers and torch to train classifier.")
    import inspect
    import torch
    out_dir.mkdir(parents=True, exist_ok=True)
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=1)  # single logit -> sigmoid

    def with_labels(df):
        # Label from ground-truth membership on a copy, so the caller's frame
        # (often a train_test_split slice) is never mutated.
        if "label" in df.columns:
            return df
        labelled = df.copy()
        labelled["label"] = [
            1 if (int(i), int(j)) in gt_edges else 0
            for i, j in zip(labelled["i"], labelled["j"])
        ]
        return labelled

    def df_to_hf_dataset(df):
        texts = (df["pos_snippet"].fillna("") + " ||| " + df["domain"].fillna("")).tolist()
        labels = df["label"].astype(float).tolist()
        enc = tok(texts, padding=True, truncation=True, max_length=256)
        return torch.utils.data.TensorDataset(
            torch.tensor(enc["input_ids"], dtype=torch.long),
            torch.tensor(enc["attention_mask"], dtype=torch.long),
            torch.tensor(labels, dtype=torch.float32),
        )

    train_dataset = df_to_hf_dataset(with_labels(training_df))
    val_dataset = df_to_hf_dataset(with_labels(val_df))

    def collate_fn(batch):
        # TensorDataset yields (ids, mask, label) tuples; Trainer wants a dict.
        ids = torch.stack([b[0] for b in batch])
        masks = torch.stack([b[1] for b in batch])
        labels = torch.stack([b[2] for b in batch])
        return {"input_ids": ids, "attention_mask": masks, "labels": labels}

    # Newer transformers releases renamed evaluation_strategy -> eval_strategy.
    ta_params = inspect.signature(TrainingArguments.__init__).parameters
    eval_key = "eval_strategy" if "eval_strategy" in ta_params else "evaluation_strategy"
    training_args = TrainingArguments(
        output_dir=str(out_dir),
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        # load_best_model_at_end requires the save and eval strategies to
        # match; the default step-based save strategy makes construction fail.
        save_strategy="epoch",
        num_train_epochs=epochs,
        save_total_limit=1,
        seed=42,
        logging_steps=50,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        **{eval_key: "epoch"},
    )

    def compute_metrics(p):
        logits = p.predictions
        preds = 1/(1+np.exp(-logits.reshape(-1)))
        labels = p.label_ids
        return {"roc_auc": float(roc_auc_score(labels, preds)) if len(np.unique(labels))>1 else 0.0}

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, tokenizer=tok, data_collator=collate_fn, compute_metrics=compute_metrics)
    trainer.train()
    trainer.save_model(str(out_dir))
    return out_dir

def predict_causalbert_scores(model_dir: Path, df_pairs: pd.DataFrame, model_name_or_path="distilbert-base-uncased", batch_size: int = 32):
    """Score every pair with the fine-tuned single-logit classifier.

    Texts are built as ``pos_snippet ||| domain`` and scored in mini-batches
    (the previous single forward pass over the whole corpus does not scale and
    failed on an empty frame).

    Args:
        model_dir: Directory holding the saved model.
        df_pairs: Frame with ``pos_snippet`` and ``domain`` columns.
        model_name_or_path: Checkpoint the tokenizer is loaded from.
        batch_size: Number of texts per forward pass (must be >= 1).

    Returns:
        A copy of ``df_pairs`` with a ``causalbert_score`` column holding
        ``sigmoid(logit)``.

    Raises:
        ImportError: when ``HAS_TRANSFORMERS`` is falsy.
        ValueError: when ``batch_size`` is smaller than 1.
    """
    if not HAS_TRANSFORMERS:
        raise ImportError("Install transformers and torch to run prediction.")
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    import torch
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(str(model_dir))
    texts = (df_pairs["pos_snippet"].fillna("") + " ||| " + df_pairs["domain"].fillna("")).tolist()
    model.eval()
    probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            enc = tok(texts[start:start + batch_size], padding=True, truncation=True, max_length=256, return_tensors="pt")
            logits = model(**enc).logits.reshape(-1)
            # torch.sigmoid is numerically stable for large-magnitude logits.
            probs.extend(torch.sigmoid(logits.double()).cpu().tolist())
    df_out = df_pairs.copy()
    df_out["causalbert_score"] = probs
    return df_out

# -------------------- Baseline 9: Neural Causation Coefficient (NCC) using CDT --------------------
def baseline_ncc(df_pairs: pd.DataFrame, pair_X: np.ndarray):
    """Produce the ``ncc_score`` column using cdt's NCC model.

    Each row of ``pair_X`` is split into two equal halves that are passed to
    ``NCC.predict`` as column vectors; a row whose prediction raises scores
    0.0. Scores are min-max scaled (1e-12 guard).

    Returns:
        A copy of ``df_pairs`` with an ``ncc_score`` column.

    Raises:
        ImportError: when ``HAS_CDT`` is falsy.
    """
    if not HAS_CDT:
        raise ImportError("Install cdt (pip install cdt) which includes NCC implementation.")
    from cdt.causality.pairwise import NCC
    estimator = NCC()

    def _row_score(vec) -> float:
        half = vec.shape[0]//2
        # Two equal-length halves; for an odd length the last element is dropped.
        first, second = vec[:half], vec[half:2*half]
        try:
            return float(estimator.predict(first.reshape(-1,1), second.reshape(-1,1)))
        except Exception:
            return 0.0

    raw = [_row_score(pair_X[row]) for row in range(pair_X.shape[0])]
    scaled = (np.array(raw) - np.min(raw)) / (np.max(raw)-np.min(raw)+1e-12)
    scored = df_pairs.copy()
    scored["ncc_score"] = scaled.tolist()
    return scored

# -------------------- Run baselines and evaluate --------------------
# Collects one entry per baseline: "<name>" -> metrics dict on success,
# "<name>_error" / "<name>_note" -> message when it failed or was skipped.
results_summary = {}

# The original message contained a mis-decoded dash character; use plain ASCII.
print("Running baselines - this may take a while depending on installed libs and dataset size.")

# Each baseline below follows the same pattern: score all pairs, evaluate the
# score column against gt_edges (which also writes "<name>_eval.json" under
# OUT_DIR), and record the metrics. Any exception is caught so that one
# failing baseline (e.g. an optional library that is not installed) does not
# stop the others; the message is stored under "<name>_error".

# PC baseline
try:
    df_pc = baseline_pc(pair_X, pair_keys)
    res_pc = evaluate_scores(df_pc, "pc_score", gt_edges, OUT_DIR / "pc_eval")
    results_summary["pc"] = res_pc
    print("PC baseline done.")
except Exception as e:
    results_summary["pc_error"] = str(e)
    print("PC baseline error:", e)

# GES baseline
try:
    df_ges = baseline_ges(pair_X)
    res_ges = evaluate_scores(df_ges, "ges_score", gt_edges, OUT_DIR / "ges_eval")
    results_summary["ges"] = res_ges
    print("GES baseline done.")
except Exception as e:
    results_summary["ges_error"] = str(e)
    print("GES baseline error:", e)

# NOTEARS baseline
try:
    # The argument is not read by baseline_notears, which uses the module-level pair_X.
    df_notears = baseline_notears(None)
    res_notears = evaluate_scores(df_notears, "notears_score", gt_edges, OUT_DIR / "notears_eval")
    results_summary["notears"] = res_notears
    print("NOTEARS baseline done.")
except Exception as e:
    results_summary["notears_error"] = str(e)
    print("NOTEARS baseline error:", e)

# LiNGAM baseline
try:
    df_lingam = baseline_lingam(pair_X)
    res_lingam = evaluate_scores(df_lingam, "lingam_score", gt_edges, OUT_DIR / "lingam_eval")
    results_summary["lingam"] = res_lingam
    print("LiNGAM baseline done.")
except Exception as e:
    results_summary["lingam_error"] = str(e)
    print("LiNGAM baseline error:", e)

# NCC baseline (cdt)
try:
    df_ncc = baseline_ncc(df_pairs, pair_X)
    res_ncc = evaluate_scores(df_ncc, "ncc_score", gt_edges, OUT_DIR / "ncc_eval")
    results_summary["ncc"] = res_ncc
    print("NCC baseline done.")
except Exception as e:
    results_summary["ncc_error"] = str(e)
    print("NCC baseline error:", e)

# CausalBERT baseline: runs only when TrainingSetCPC.jsonl exists in the
# D_corpus directory and transformers is available; otherwise a note is recorded.
try:
    train_cpc = DCORPUS_DIR / "TrainingSetCPC.jsonl"
    if not train_cpc.exists() or not HAS_TRANSFORMERS:
        results_summary["causalbert_note"] = "TrainingSetCPC.jsonl not found or transformers not available; skipping CausalBERT."
        print("Skipping CausalBERT (no train data or transformers missing).")
    else:
        train_records = load_jsonl(train_cpc)
        train_df = build_pair_df_from_records(train_records, entities_map, require_unique_pairs=True)
        # hold out 10% of the training pairs for validation
        tr, val = train_test_split(train_df, test_size=0.1, random_state=42)
        model_out = train_causalbert_classifier(
            tr,
            val,
            model_name_or_path="distilbert-base-uncased",
            out_dir=OUT_DIR/"causalbert_model",
            epochs=1,
            batch_size=8,
        )
        df_cb = predict_causalbert_scores(model_out, df_pairs)
        res_cb = evaluate_scores(df_cb, "causalbert_score", gt_edges, OUT_DIR / "causalbert_eval")
        results_summary["causalbert"] = res_cb
        print("CausalBERT baseline done.")
except Exception as exc:
    results_summary["causalbert_error"] = str(exc)
    print("CausalBERT error:", exc)

# LLM zero-shot baseline (Gemini) - only if the SDK flag is set; otherwise a note is recorded.
try:
    if not HAS_GEMINI:
        results_summary["llm_gemini_note"] = "Gemini SDK not installed/configured; skipping."
        print("Skipping Gemini LLM baseline.")
    else:
        df_llm = llm_score_pairs_gemini(
            df_pairs,
            model_name="gemini-2.5-flash-lite",
            sleep_between_calls=0.0,
        )
        res_llm = evaluate_scores(df_llm, "llm_gemini_score", gt_edges, OUT_DIR / "llm_gemini_eval")
        results_summary["llm_gemini"] = res_llm
        print("LLM Gemini baseline done.")
except Exception as exc:
    results_summary["llm_gemini_error"] = str(exc)
    print("Gemini baseline error:", exc)

# Save combined summary: everything collected in results_summary (metrics as
# well as any "*_error" / "*_note" entries) goes into one JSON file in OUT_DIR.
(OUT_DIR / "baseline_summary.json").write_text(json.dumps(results_summary, indent=2))
print("Baseline evaluation finished. Results saved to:", OUT_DIR)
