In [1]:
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from pathlib import Path
from scipy import sparse
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DATA_DIR   = Path("smb_data")  # Input CSVs directory
ART_DIR    = Path("artifacts") # Output artifacts directory
ART_DIR.mkdir(exist_ok=True)


# ENSEMBLE OF ML MODELS - DEPRECATED

In [1]:
#!/usr/bin/env python3
# train_ensemble_with_reports.py  ·  dual-task (within / across census)
#                                  + streaming diagnostics to avoid OOM
# -----------------------------------------------------------------------------
# – trains 4 tree models + simple-average ensemble
# – evaluates on “within-census” and “cross-census” record pairs
# – full metrics / ROC / clustering diagnostics
# – diagnostics scored in RAM-friendly batches (default 20 000 rows subsample)
# -----------------------------------------------------------------------------

import itertools
import json
import joblib
import logging
import random
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier
from scipy import sparse
from sklearn.ensemble import RandomForestClassifier
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import (
    adjusted_rand_score,
    accuracy_score,
    auc,
    f1_score,
    precision_recall_fscore_support,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from xgboost import XGBClassifier

# ───────────── paths / constants ─────────────
ART_DIR      = Path("artifacts")
BASE_MODEL   = Path("models_ensemble"); BASE_MODEL.mkdir(exist_ok=True)
BASE_REPORT  = Path("reports");         BASE_REPORT.mkdir(exist_ok=True)

RNG_SEED     = 42
NEG_PER_POS  = 2
TOP_K        = 5
SUBSET_DEF   = 1000        # default rows for diagnostics subsample

# ───────────── helpers ─────────────
class Timer:
    def __init__(self, msg):
        self.msg, self.t0 = msg, time.perf_counter()
    def __enter__(self):
        logging.info(f"▶ {self.msg}")
        return self
    def __exit__(self, *_):
        logging.info(f"⏱ {self.msg}: {time.perf_counter() - self.t0:.1f}s")

def prf_auc(y_true, y_prob):
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    return {"roc_auc": auc(fpr, tpr)}, (fpr, tpr)

# ───────────────────────────────────────
#  Data loading
# ───────────────────────────────────────
def load_X_labels_heimild():
    """Load sparse feature matrix, person-ids and ‘heimild’ source labels."""
    X = sparse.load_npz(ART_DIR / "iceid_ml_ready.npz")
    lbls = pd.read_csv(ART_DIR / "row_labels.csv")
    lbls["person"] = (
        pd.to_numeric(lbls["person"], errors="coerce")
          .fillna(-1)
          .astype(int)
    )
    heimild_df = (
        pd.read_csv("smb_data/people.csv", usecols=["id", "heimild"])
          .set_index("id")
    )
    heimild = (
        heimild_df.reindex(lbls.row_id)["heimild"]
                  .fillna(-1)
                  .astype(int)
                  .values
    )
    return X, lbls["person"].values, heimild

# ───────────────────────────────────────
#  Pair sampling
# ───────────────────────────────────────
def build_pairs(labels, heimild, mode):
    idx_lab = [i for i, p in enumerate(labels) if p != -1]
    id_by_person = {}
    for i in idx_lab:
        id_by_person.setdefault(labels[i], []).append(i)
    # positives
    pos = []
    for grp in id_by_person.values():
        for a, b in itertools.combinations(grp, 2):
            same = heimild[a] == heimild[b]
            if (mode == "within" and same) or (mode == "across" and not same):
                pos.append((a, b))
    # negatives
    neg = set()
    target = len(pos) * NEG_PER_POS
    rng = random.Random(RNG_SEED)
    while len(neg) < target:
        a, b = rng.sample(idx_lab, 2)
        if labels[a] == labels[b]:
            continue
        same = heimild[a] == heimild[b]
        if (mode == "within" and same) or (mode == "across" and not same):
            neg.add(tuple(sorted((a, b))))
    return pos, list(neg)

# ───────────────────────────────────────
#  Sparse pair builder
# ───────────────────────────────────────
def pair_matrix(X, pairs):
    return sparse.vstack([sparse.hstack([X[i], X[j]]) for i, j in pairs])

# ───────────────────────────────────────
#  Threshold tuning
# ───────────────────────────────────────
def tune_thr(y, p):
    best_t, best_f = 0.5, 0.0
    for t in np.linspace(0, 1, 101):
        f = f1_score(y, p >= t)
        if f > best_f:
            best_f, best_t = f, t
    return best_t

# ───────────────────────────────────────
#  Streaming diagnostics helpers
# ───────────────────────────────────────
def batched_pairs(idx_rows, batch_sz=10_000):
    combos = itertools.combinations(range(len(idx_rows)), 2)
    while True:
        batch = list(itertools.islice(combos, batch_sz))
        if not batch:
            break
        yield batch

def score_pairs_stream(X, idx_rows, ens_fn, batch_sz=10_000):
    """Compute ensemble scores in streaming batches with progress bar."""
    scores = []
    total_pairs = len(idx_rows) * (len(idx_rows)-1) // 2
    # approximate number of batches
    n_batches = (total_pairs + batch_sz - 1)//batch_sz
    for chunk in tqdm(
        batched_pairs(idx_rows, batch_sz),
        total=n_batches,
        desc="diagnostics batches",
        unit="batch"
    ):
        mat = pair_matrix(X, [(idx_rows[i], idx_rows[j]) for i, j in chunk])
        scores.extend(ens_fn(mat))
    return np.asarray(scores, dtype=np.float32)

# ───────────────────────────────────────
#  Diagnostics metrics
# ───────────────────────────────────────
def clustering_cc(labels, probs, thr):
    n = len(labels)
    parent = list(range(n))
    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x
    k = 0
    for i, j in itertools.combinations(range(n), 2):
        if probs[k] >= thr:
            parent[find(i)] = find(j)
        k += 1
    clusters = [find(i) for i in range(n)]
    gold = [l if l != -1 else -1 for l in labels]
    return adjusted_rand_score(gold, clusters)

def clustering_aggl(labels, probs, n_clusters):
    n = len(labels)
    P = np.zeros((n, n))
    k = 0
    for i, j in itertools.combinations(range(n), 2):
        P[i,j] = P[j,i] = probs[k]
        k += 1
    D = 1 - P
    pred = AgglomerativeClustering(
        n_clusters=n_clusters, metric="precomputed", linkage="average"
    ).fit_predict(D)
    gold = [l if l != -1 else -1 for l in labels]
    return adjusted_rand_score(gold, pred)

def retrieval_at_k(labels, probs, k=TOP_K):
    n = len(labels)
    P = np.zeros((n, n))
    idx = 0
    for i, j in itertools.combinations(range(n), 2):
        P[i,j] = P[j,i] = probs[idx]
        idx += 1
    correct = 0
    total_true = sum(
        1 for i in range(n) for j in range(n)
        if i != j and labels[i] == labels[j] != -1
    )
    retrieved = 0
    for i in range(n):
        topk = np.argsort(P[i])[::-1][:k]
        retrieved += k
        correct += sum(1 for j in topk if labels[i] == labels[j] != -1)
    precision = correct / retrieved if retrieved else 0
    recall    = correct / total_true if total_true else 0
    return precision, recall

# ───────────────────────────────────────
#  Pipeline per task
# ───────────────────────────────────────
def run_task(mode, X, labels, heimild, subset_sz=SUBSET_DEF):
    tag = mode
    model_dir  = BASE_MODEL  / tag; model_dir.mkdir(exist_ok=True)
    report_dir = BASE_REPORT / tag; report_dir.mkdir(exist_ok=True)

    logging.info("=== %s ===", tag)
    pos, neg = build_pairs(labels, heimild, mode)
    if not pos:
        logging.warning("no positives for %s", tag)
        return

    # build train/test pairs
    X_pos = pair_matrix(X, pos)
    X_neg = pair_matrix(X, neg)
    y_pos = np.ones(len(pos), int)
    y_neg = np.zeros(len(neg), int)
    X_pairs = sparse.vstack([X_pos, X_neg])
    y_pairs = np.concatenate([y_pos, y_neg])
    X_tr, X_te, y_tr, y_te = train_test_split(
        X_pairs, y_pairs, stratify=y_pairs, test_size=0.2, random_state=RNG_SEED
    )

    # initialize models
    models = {
        "xgb": XGBClassifier(
            tree_method="hist", eval_metric="logloss",
            n_estimators=300, random_state=RNG_SEED
        ),
        "lgb": LGBMClassifier(
            n_estimators=300, random_state=RNG_SEED, verbosity=-1
        ),
        "cat": CatBoostClassifier(
            iterations=300, depth=6, learning_rate=0.1,
            random_state=RNG_SEED, verbose=False
        ),
        "rf": RandomForestClassifier(
            n_estimators=300, n_jobs=-1, random_state=RNG_SEED
        ),
    }

    # train models
    with Timer(f"fit-{tag}"):
        for name, mdl in tqdm(models.items(), desc=f"fit-{tag}", unit="model"):
            mdl.fit(X_tr, y_tr)
            joblib.dump(mdl, model_dir / f"{name}.pkl")

    # ensemble prediction function
    ens_prob = lambda M: np.column_stack(
        [m.predict_proba(M)[:,1] for m in models.values()]
    ).mean(1)

    # evaluate on hold-out
    y_prob = ens_prob(X_te)
    thr    = tune_thr(y_te, y_prob)
    preds  = y_prob >= thr

    pr, rc, f1, _ = precision_recall_fscore_support(
        y_te, preds, average="binary", zero_division=0
    )
    acc           = accuracy_score(y_te, preds)
    fpr, tpr, _   = roc_curve(y_te, y_prob)
    auc_val       = auc(fpr, tpr)

    # write metrics
    metrics = dict(
        precision=pr, recall=rc, f1=f1,
        accuracy=acc, auc=auc_val, threshold=thr
    )
    (report_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
    pd.DataFrame([metrics]).to_csv(
        report_dir / "metrics_summary.csv", index=False
    )

    # plot ROC
    plt.figure()
    plt.plot(fpr, tpr, label=f"AUC={auc_val:.3f}")
    plt.plot([0,1], [0,1], "--")
    plt.legend()
    plt.title(f"ROC – {tag}")
    plt.tight_layout()
    plt.savefig(report_dir / "roc.png")
    plt.close()

    # ─────── diagnostics ───────
    subset_idx = np.random.default_rng(RNG_SEED).choice(
        len(labels), size=min(subset_sz, len(labels)), replace=False
    )
    logging.info("%s diagnostics on %d rows", tag, len(subset_idx))
    probs_sub = score_pairs_stream(
        X, subset_idx, ens_prob, batch_sz=100_000
    )

    y_sub = labels[subset_idx]
    ari_cc = clustering_cc(y_sub, probs_sub, thr)
    ari_ag = clustering_aggl(
        y_sub, probs_sub, len(np.unique(y_sub[y_sub != -1]))
    )
    p_k, r_k = retrieval_at_k(y_sub, probs_sub)

    diag = dict(
        ari_cc=ari_cc, ari_ag=ari_ag,
        precision_at_k=p_k, recall_at_k=r_k
    )
    (report_dir / "diagnostics.json").write_text(json.dumps(diag, indent=2))

    # plot cluster sizes
    plt.figure()
    pd.Series(y_sub[y_sub != -1]).value_counts().hist(bins=50)
    plt.title("Cluster size distribution")
    plt.tight_layout()
    plt.savefig(report_dir / "cluster_size_dist.png")
    plt.close()

    logging.info("✔ %s done – results in %s", tag, report_dir)

# ───────────────────────────────────────
#  Main entrypoint
# ───────────────────────────────────────
if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    X, labels, heimild = load_X_labels_heimild()

    run_task("within" , X, labels, heimild, SUBSET_DEF)
    run_task("across" , X, labels, heimild, SUBSET_DEF)


FileNotFoundError: [Errno 2] No such file or directory: 'smb_data/people.csv'

# FINAL ENSEMBLE TRAINING

In [None]:
#!/usr/bin/env python3
# train_ensemble_gpu.py  ·  dual-task (within / across census)
#                          + GPU support + age disparity calculation
#                          + repeated random-sampling runs + averaged results
# -----------------------------------------------------------------------------
# – trains 4 tree models + simple-average ensemble
# – runs end-to-end pipeline 10 times with new random pairs each run
# – evaluates on fresh train/test split each run
# – diagnostics on held-out records each run
# – clears memory between runs
# – averages all metrics across runs and saves final averages
# -----------------------------------------------------------------------------

import itertools
import json
import joblib
import logging
import random
import time
import warnings
from pathlib import Path
import gc

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.datasets import dump_svmlight_file
from sklearn.cluster import AgglomerativeClustering
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    adjusted_rand_score,
    accuracy_score,
    auc,
    f1_score,
    precision_recall_fscore_support,
    roc_curve,
)
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from joblib import Parallel, delayed

# Attempt to import GPU-accelerated libraries
try:
    import xgboost as xgb
    from xgboost import XGBClassifier
except ImportError:
    logging.warning("XGBoost not found. XGBoost models will not be trained.")
    XGBClassifier = None
try:
    from lightgbm import LGBMClassifier
except ImportError:
    logging.warning("LightGBM not found. LightGBM models will not be trained.")
    LGBMClassifier = None
try:
    from catboost import CatBoostClassifier
except ImportError:
    logging.warning("CatBoost not found. CatBoost models will not be trained.")
    CatBoostClassifier = None
try:
    import torch  # For CUDA check
except ImportError:
    torch = None
    logging.warning("PyTorch not installed. CUDA availability check might be limited.")


# ───────────── paths / constants ─────────────
DATA_DIR       = Path("raw_data")
ART_DIR        = Path("artifacts")
BASE_MODEL_DIR = Path("models_ensemble_gpu"); BASE_MODEL_DIR.mkdir(exist_ok=True, parents=True)
BASE_REPORT    = Path("reports_gpu");         BASE_REPORT.mkdir(exist_ok=True, parents=True)
ART_DIR.mkdir(exist_ok=True)

RNG_SEED                = 42
NEG_PER_POS             = 2
TOP_K                   = 5
AGE_DISPARITY_PERCENTILE = 95

# Caps & diagnostics batching
MAX_TOTAL_PAIRS         = 500_000    # cap on training pairs
MAX_DIAG_PAIRS          = 100_000    # cap on diagnostics pairs per run
DIAG_CLUSTER_SUBSET_SZ  = 2000       # clustering diag batch size
DIAG_CLUSTER_BATCHES    = 10         # clustering diag num batches

# Number of repeated runs
NUM_RUNS = 10


# ───────────── helpers ─────────────
class Timer:
    def __init__(self, msg):
        self.msg, self.t0 = msg, time.perf_counter()
    def __enter__(self):
        logging.info(f"▶ {self.msg}")
        return self
    def __exit__(self, *_):
        logging.info(f"⏱ {self.msg}: {time.perf_counter() - self.t0:.1f}s")


def pair_matrix(X, pairs):
    """Builds [X[i] | X[j]] for each (i,j)."""
    if not pairs:
        return sparse.csr_matrix((0, X.shape[1]*2), dtype=X.dtype)
    return sparse.vstack([
        sparse.hstack([X[i], X[j]]) for i,j in pairs
    ])


def tune_thr(y_true, y_prob):
    best_t, best_f1 = 0.5, 0.0
    for t in np.linspace(0,1,101):
        f1 = f1_score(y_true, y_prob>=t, zero_division=0)
        if f1>best_f1:
            best_f1, best_t = f1, t
    return best_t


# ───────────────────────────────────────
#  Data loading & Age threshold
# ───────────────────────────────────────
def load_training_data_with_details():
    logging.info("Loading features + metadata for training...")
    X_path = ART_DIR/"iceid_ml_ready.npz"
    X = sparse.load_npz(X_path)
    rows_df = pd.read_csv(ART_DIR/"row_labels.csv")
    rows_df["person"] = pd.to_numeric(rows_df["person"],errors="coerce").fillna(-1).astype(int)
    labels      = rows_df["person"].values
    orig_ids    = rows_df["row_id"].values

    people_df = pd.read_csv(
        DATA_DIR/"people.csv",
        usecols=["id","heimild","birthyear"],
        dtype={"id":orig_ids.dtype}
    ).set_index("id")
    aligned    = people_df.reindex(orig_ids)
    heimild    = aligned["heimild"].fillna(-1).astype(int).values
    birthyrs   = pd.to_numeric(aligned["birthyear"],errors="coerce").fillna(0).astype(int).values

    logging.info(f"Loaded X {X.shape}, labels {len(labels)}, heimild len {len(heimild)}")
    return X, labels, heimild, birthyrs


def calculate_and_save_age_disparity_threshold(labels, birthyrs, out_dir):
    p2b = {}
    for i,p in enumerate(labels):
        by = birthyrs[i]
        if p!=-1 and by>1000:
            p2b.setdefault(p,[]).append(by)
    spans = [(max(bs)-min(bs)) if len(bs)>1 else 0 for bs in p2b.values()]
    if not spans:
        thr=10
    else:
        thr=int(np.percentile(spans,AGE_DISPARITY_PERCENTILE))
        if thr<1 and AGE_DISPARITY_PERCENTILE>=90 and any(s>0 for s in spans):
            thr=1
    out_dir.mkdir(exist_ok=True, parents=True)
    (out_dir/"age_disparity_threshold.json").write_text(
        json.dumps({"age_disparity_threshold_years":thr})
    )
    logging.info(f"Age‐disp threshold={thr}")
    return thr


# ───────────────────────────────────────
#  Pair sampling
# ───────────────────────────────────────
def build_pairs_for_training(labels, heimild, mode, seed):
    rng = random.Random(seed)
    idx_lab = [i for i,p in enumerate(labels) if p!=-1]
    by_person={}
    for i in idx_lab:
        by_person.setdefault(labels[i],[]).append(i)

    pos=[]
    for inds in by_person.values():
        for a,b in itertools.combinations(inds,2):
            if heimild[a]==-1 or heimild[b]==-1: continue
            same=(heimild[a]==heimild[b])
            if mode=="within" and same: pos.append((a,b))
            if mode=="across" and not same: pos.append((a,b))

    target_neg = len(pos)*NEG_PER_POS
    neg_set=set()
    attempts,max_attempts=0,target_neg*20+1000
    while len(neg_set)<target_neg and attempts<max_attempts:
        attempts+=1
        a,b = rng.sample(idx_lab,2)
        if labels[a]==labels[b] or heimild[a]==-1 or heimild[b]==-1: continue
        same=(heimild[a]==heimild[b])
        if mode=="within" and same: neg_set.add(tuple(sorted((a,b))))
        if mode=="across" and not same: neg_set.add(tuple(sorted((a,b))))

    pos_final=pos if len(pos)+len(neg_set)<=MAX_TOTAL_PAIRS else rng.sample(pos,MAX_TOTAL_PAIRS//2)
    neg_final=list(neg_set) if len(pos)+len(neg_set)<=MAX_TOTAL_PAIRS else rng.sample(list(neg_set),MAX_TOTAL_PAIRS//2)
    return pos_final, neg_final


def sample_diag_pairs(indices, heimild, mode, n_pairs, seed):
    rng, pairs = random.Random(seed), set()
    attempts, max_attempts = 0, n_pairs*20+1000
    while len(pairs)<n_pairs and attempts<max_attempts:
        attempts+=1
        a,b = rng.sample(indices,2)
        if heimild[a]==-1 or heimild[b]==-1: continue
        same=(heimild[a]==heimild[b])
        if mode=="within" and not same: continue
        if mode=="across" and same: continue
        pairs.add(tuple(sorted((a,b))))
    return list(pairs)


# ───────────────────────────────────────
#  Batched diagnostics averaging
# ───────────────────────────────────────
def _diag_batch_metrics(batch_ids, pairs, scores, thr, labels):
    # same as before: CC, AG, P@K, R@K on just batch_ids subset
    local_map={rid:i for i,rid in enumerate(batch_ids)}
    N=len(batch_ids)
    gold=[labels[r] for r in batch_ids]
    # treat unknowns
    neg=-2
    for i,v in enumerate(gold):
        if v==-1: gold[i]=neg; neg-=1
    # CC
    parent=list(range(N))
    def find(x):
        if parent[x]!=x: parent[x]=find(parent[x])
        return parent[x]
    def union(a,b):
        ra,rb=find(a),find(b)
        if ra!=rb: parent[ra]=rb
    for (a,b),s in zip(pairs,scores):
        if a in local_map and b in local_map and s>=thr:
            union(local_map[a],local_map[b])
    pred_cc=[find(i) for i in range(N)]
    ari_cc=adjusted_rand_score(gold,pred_cc)
    # AG
    ari_ag=float('nan')
    true_clusters=len({g for g in gold if g>=0})
    if true_clusters==0: true_clusters=N
    if 2<=N and 1<=true_clusters<=N:
        sim=np.zeros((N,N),dtype=np.float32)
        for (a,b),s in zip(pairs,scores):
            if a in local_map and b in local_map:
                i,j=local_map[a],local_map[b]
                sim[i,j]=sim[j,i]=s
        dist=1-sim; np.fill_diagonal(dist,0)
        try:
            agg=AgglomerativeClustering(n_clusters=true_clusters,
                                        metric="precomputed",
                                        linkage="average")
            pred_ag=agg.fit_predict(dist)
            ari_ag=adjusted_rand_score(gold,pred_ag)
        except:
            ari_ag=float('nan')
    # Retrieval
    prob=np.full((N,N),-1.0,dtype=np.float32)
    for (a,b),s in zip(pairs,scores):
        if a in local_map and b in local_map:
            i,j=local_map[a],local_map[b]
            prob[i,j]=prob[j,i]=s
    correct=0;retr=0;rel=0
    for i in range(N):
        lbl=labels[batch_ids[i]]
        if lbl==-1: continue
        # relevant
        r_count=sum(1 for j in range(N)
                    if j!=i and labels[batch_ids[j]]==lbl)
        rel+=r_count
        # top-K
        neigh=[(prob[i,j],j) for j in range(N) if prob[i,j]>-0.5]
        neigh.sort(key=lambda x:x[0],reverse=True)
        for s,j in neigh[:TOP_K]:
            retr+=1
            if labels[batch_ids[j]]==lbl: correct+=1
    rel/=2
    p_at_k=correct/retr if retr else 0.0
    r_at_k=correct/rel if rel else 1.0

    return ari_cc, ari_ag, p_at_k, r_at_k


def batched_diagnostics(pairs, scores, thr, heldout, labels):
    rng=random.Random(RNG_SEED)
    results=Parallel(n_jobs=-1)(
        delayed(_diag_batch_metrics)(
            rng.sample(heldout,min(DIAG_CLUSTER_SUBSET_SZ,len(heldout))),
            pairs,scores,thr,labels
        ) for _ in range(DIAG_CLUSTER_BATCHES)
    )
    arr=np.array(results,dtype=np.float32)  # (batches,4)
    return tuple(arr.mean(axis=0).tolist())


# ───────────────────────────────────────
#  Main pipeline per run
# ───────────────────────────────────────
def run_once(run_id, X, labels, heimild, birthyrs, heldout):
    """
    Execute one full pipeline: sample pairs, train, test, diagnose.
    Returns dict of metrics.
    """
    np.random.seed(RNG_SEED+run_id)
    random.seed(RNG_SEED+run_id)

    # Age threshold (constant across runs)
    age_thr = calculate_and_save_age_disparity_threshold(
        labels, birthyrs, BASE_REPORT/"_global_training_outputs"
    )

    all_results={}
    for mode in ("within","across"):
        # 1) sample train pairs
        pos,neg = build_pairs_for_training(labels,heimild,mode,seed=RNG_SEED+run_id)
        Xp=pair_matrix(X,pos); Xn=pair_matrix(X,neg)
        yp=np.ones(len(pos),int); yn=np.zeros(len(neg),int)
        Xall=sparse.vstack([Xp,Xn]); yall=np.concatenate([yp,yn])
        Xtr,Xte,ytr,yte = train_test_split(
            Xall,yall,stratify=yall,test_size=0.2,random_state=RNG_SEED+run_id
        )

        # 2) train ensemble
        models={}
        if XGBClassifier:
            models["xgb"]=XGBClassifier(
                tree_method="hist",
                device="cuda" if torch and torch.cuda.is_available() else "cpu",
                eval_metric="logloss",n_estimators=300,random_state=RNG_SEED+run_id
            )
        if LGBMClassifier:
            models["lgb"]=LGBMClassifier(
                n_estimators=300,random_state=RNG_SEED+run_id,verbosity=-1,
                device="gpu" if torch and torch.cuda.is_available() else "cpu"
            )
        if CatBoostClassifier:
            models["cat"]=CatBoostClassifier(
                iterations=300,depth=6,learning_rate=0.1,
                random_state=RNG_SEED+run_id,verbose=False,
                task_type="GPU" if torch and torch.cuda.is_available() else "CPU"
            )
        models["rf"]=RandomForestClassifier(n_estimators=300,n_jobs=-1,random_state=RNG_SEED+run_id)

        trained={}
        with Timer(f"[run {run_id}] Fitting models for {mode}"):
            for name,mdl in models.items():
                try:
                    if name=="xgb":
                        tmp=BASE_MODEL_DIR/f"{mode}_run{run_id}_xgb.libsvm"
                        dump_svmlight_file(Xtr,ytr,str(tmp))
                        dtrain=xgb.DMatrix(f"{tmp}?format=libsvm#dtrain.cache")
                        params={"objective":"binary:logistic","tree_method":"gpu_hist",
                                "eval_metric":"logloss","seed":RNG_SEED+run_id}
                        bst=xgb.train(params,dtrain,num_boost_round=300)
                        trained["xgb"]=bst
                    else:
                        mdl.fit(Xtr,ytr); trained[name]=mdl
                except Exception as e:
                    logging.error(f"Run {run_id} {mode} {name} train error: {e}")

        def _pred(m, M):
            if isinstance(m,xgb.Booster):
                return m.predict(xgb.DMatrix(M))
            else:
                return m.predict_proba(M)[:,1]
        ensemble_fn=lambda M: np.column_stack([_pred(m,M) for m in trained.values()]).mean(axis=1)

        # 3) evaluate
        p_te=ensemble_fn(Xte)
        thr = tune_thr(yte,p_te)
        pr,rc,f1,_=precision_recall_fscore_support(yte,p_te>=thr,average="binary",zero_division=0)
        acc=accuracy_score(yte,p_te>=thr)
        try:
            fpr,tpr,_=roc_curve(yte,p_te); aucv=auc(fpr,tpr)
        except: aucv=float('nan')
        all_results[f"{mode}_test"]={"precision":pr,"recall":rc,"f1":f1,
                                     "accuracy":acc,"threshold":thr,"auc":aucv}

        # 4) diagnostics
        diag_pairs=sample_diag_pairs(heldout,heimild,mode,MAX_DIAG_PAIRS,seed=RNG_SEED+run_id)
        diag_scores=[]
        B=10_000
        for i in range(0,len(diag_pairs),B):
            sub=diag_pairs[i:i+B]
            M=pair_matrix(X,sub)
            diag_scores.extend(ensemble_fn(M))
        ari_cc,ari_ag,p_at_k,r_at_k = batched_diagnostics(
            diag_pairs,diag_scores,thr,heldout,labels
        )
        all_results[f"{mode}_diag"]={
            "ari_cc":ari_cc,"ari_ag":ari_ag,
            "precision_at_k":p_at_k,"recall_at_k":r_at_k
        }

        # free memory
        del Xp,Xn,Xall,Xtr,Xte,trained,ensemble_fn,diag_pairs,diag_scores
        gc.collect()

    return all_results


if __name__=="__main__":
    warnings.filterwarnings("ignore",category=UserWarning)
    warnings.filterwarnings("ignore",category=FutureWarning)
    logging.basicConfig(level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    if torch:
        logging.info(f"PyTorch {torch.__version__}; CUDA={'yes' if torch.cuda.is_available() else 'no'}")

    # load data once
    X, labels, heimild, birthyrs = load_training_data_with_details()

    # prepare held-out record split (constant across runs)
    full_labeled = np.where(labels!=-1)[0]
    _, heldout = train_test_split(full_labeled, test_size=0.2, random_state=RNG_SEED)
    heldout = heldout.tolist()

    # run NUM_RUNS times
    all_runs = []
    for run_id in range(NUM_RUNS):
        res = run_once(run_id, X, labels, heimild, birthyrs, heldout)
        all_runs.append(res)

    # average metrics across runs
    avg = {}
    for mode_stage in all_runs[0].keys():
        # collect values per run
        vals = {k:[] for k in all_runs[0][mode_stage]}
        for r in all_runs:
            for k,v in r[mode_stage].items():
                vals[k].append(v)
        avg[mode_stage] = {k: float(np.mean(vs)) for k,vs in vals.items()}

    # — save per-mode & combined averaged metrics —
    for mode in ("within", "across"):
        mode_avg = {
            "test":       avg[f"{mode}_test"],
            "diagnostics": avg[f"{mode}_diag"]
        }
        mode_dir = BASE_REPORT / mode
        mode_dir.mkdir(exist_ok=True, parents=True)
        (mode_dir / "averaged_metrics.json").write_text(json.dumps(mode_avg, indent=2))

        flat = {"mode": mode}
        flat.update({f"test_{k}": v for k, v in mode_avg["test"].items()})
        flat.update({f"diag_{k}": v for k, v in mode_avg["diagnostics"].items()})
        pd.DataFrame([flat]).to_csv(mode_dir / "averaged_metrics.csv", index=False)

    # combined average across both modes
    combined = {}
    for k in avg["within_test"]:
        combined[f"test_{k}"] = 0.5 * (avg["within_test"][k] + avg["across_test"][k])
    for k in avg["within_diag"]:
        combined[f"diag_{k}"] = 0.5 * (avg["within_diag"][k] + avg["across_diag"][k])

    comb_dir = BASE_REPORT / "combined"
    comb_dir.mkdir(exist_ok=True, parents=True)
    (comb_dir / "averaged_metrics.json").write_text(json.dumps(combined, indent=2))
    pd.DataFrame([{"mode": "combined", **combined}]) \
        .to_csv(comb_dir / "averaged_metrics.csv", index=False)

    logging.info("✔ All runs complete; per-mode and combined averaged metrics written to reports/")


2025-05-11 22:37:45 [INFO] PyTorch 2.7.0+cu126; CUDA=yes
2025-05-11 22:37:45 [INFO] Loading features + metadata for training...
2025-05-11 22:37:46 [INFO] Loaded X (984028, 50799), labels 984028, heimild len 984028
2025-05-11 22:37:47 [INFO] Age‐disp threshold=76
2025-05-11 22:37:52 [INFO] ▶ [run 0] Fitting models for within
2025-05-11 22:39:11 [INFO] ⏱ [run 0] Fitting models for within: 78.6s
2025-05-11 22:40:45 [INFO] ▶ [run 0] Fitting models for across
2025-05-11 23:00:12 [INFO] ⏱ [run 0] Fitting models for across: 1166.8s
2025-05-11 23:00:43 [INFO] Age‐disp threshold=76
2025-05-11 23:00:48 [INFO] ▶ [run 1] Fitting models for within
2025-05-11 23:02:15 [INFO] ⏱ [run 1] Fitting models for within: 86.4s
2025-05-11 23:03:44 [INFO] ▶ [run 1] Fitting models for across
2025-05-11 23:23:04 [INFO] ⏱ [run 1] Fitting models for across: 1159.6s
2025-05-11 23:23:35 [INFO] Age‐disp threshold=76
2025-05-11 23:23:40 [INFO] ▶ [run 2] Fitting models for within
2025-05-11 23:25:02 [INFO] ⏱ [run 2] Fi

In [3]:
#!/usr/bin/env python3
# deploy_ensemble_blocking_final.py
# -----------------------------------------------------------------------------
# - Loads trained ensemble models and age/task thresholds.
# - Loads ML features for deployment records (X_deploy.npz).
# - Loads raw attributes for deployment records by:
#   - Reading row_labels_deploy.csv to identify deployment record IDs.
#   - Loading and filtering full people.csv and manntol_einstaklingar_new.csv.
#   - Joining to get necessary attributes (names, birthyear, sex, parish_id).
# - Derives primary name part and given name for blocking.
# - Generates a custom string prefix key for blocking.
# - Applies filters: gender, age disparity, cross-census parish mismatch.
# - Predicts on candidate pairs in batches & outputs clusters.
# -----------------------------------------------------------------------------

import itertools
import json
import joblib
import logging
import time
import warnings
from pathlib import Path

import numpy as np
import pandas as pd
from scipy import sparse
from tqdm.auto import tqdm

# Attempt to import for model loading if models were saved as specific types
# (joblib usually handles this fine if the libraries are in the environment)
# from xgboost import XGBClassifier 
# from lightgbm import LGBMClassifier
# from catboost import CatBoostClassifier
# from sklearn.ensemble import RandomForestClassifier


# ───────────── paths / constants ─────────────
# --- Input Paths ---
# Root directory for original data files (people.csv, manntol_einstaklingar_new.csv)
ORIGINAL_DATA_DIR = Path("raw_data") 

# Directory where preprocessing script outputs for the DEPLOYMENT SUBSET are stored
# This dir should contain: X_deploy.npz (as iceid_ml_ready.npz) and row_labels_deploy.csv (as row_labels.csv)
# If your preprocessing script always outputs to 'artifacts', then this might be 'artifacts'
# Or, you might copy the relevant outputs for the deployment set to a dedicated directory.
# For clarity, let's assume a dedicated dir for these inputs, or adjust if using 'artifacts' directly.
DEPLOY_PREPROCESSED_INPUT_DIR = Path("artifacts_deploy_subset") # Contains outputs of preprocessing on 800k
DEPLOY_PREPROCESSED_INPUT_DIR.mkdir(exist_ok=True, parents=True)
# Expected files in DEPLOY_PREPROCESSED_INPUT_DIR:
#   - iceid_ml_ready.npz (features for the deployment records)
#   - row_labels.csv (mapping rows of the above NPZ to original 'id's)

# Paths to trained models and training reports (for thresholds)
TRAINED_MODEL_BASE_DIR = Path("models_ensemble_gpu")
TRAINING_REPORTS_BASE_DIR = Path("reports_gpu")

# --- Output Paths ---
DEPLOY_OUTPUT_DIR = Path("deploy_output_final"); DEPLOY_OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# --- Blocking / Prediction Constants ---
CANDIDATE_PAIR_BATCH_SIZE = 50000 
# For custom_blocking_key generation:
# Using patronym/surname and first given name.
# Adjust prefix lengths based on name characteristics and desired block granularity.
PRIMARY_NAME_PART_PREFIX_LEN = 5 
GIVEN_NAME_PREFIX_LEN = 3   

# ───────────── helpers ─────────────
class Timer:
    def __init__(self, msg):
        self.msg, self.t0 = msg, time.perf_counter()
    def __enter__(self):
        logging.info(f"▶ {self.msg}")
        return self
    def __exit__(self, *_):
        logging.info(f"⏱ {self.msg}: {time.perf_counter() - self.t0:.1f}s")

def generate_custom_blocking_key_from_parts(primary_name_val, first_given_name_val):
    pn_str = str(primary_name_val).strip().upper()
    gn_str = str(first_given_name_val).strip().upper()

    pn_prefix = pn_str[:PRIMARY_NAME_PART_PREFIX_LEN]
    gn_prefix = gn_str[:GIVEN_NAME_PREFIX_LEN]
    
    if not pn_prefix: pn_prefix = "XPN" 
    if not gn_prefix: gn_prefix = "XGN" 
        
    return f"{pn_prefix}_{gn_prefix}"

# ───────────────────────────────────────
#  Load Models and Thresholds from Training
# ───────────────────────────────────────
def load_ensemble_models_for_deployment(mode):
    model_dir = TRAINED_MODEL_BASE_DIR / mode
    models_loaded = {}
    # First, look for XGB Booster
    xgb_path = model_dir / "xgb.booster"
    if xgb_path.exists():
        booster = xgb.Booster()
        booster.load_model(str(xgb_path))
        models_loaded["xgb"] = booster
        logging.info(f"Loaded XGBoost Booster for '{mode}'")
    else:
        # fallback to old .pkl if you still have it
        pkl = model_dir / "xgb.pkl"
        if pkl.exists():
            models_loaded["xgb"] = joblib.load(pkl)
            logging.info(f"Loaded XGBClassifier(pkl) for '{mode}'")

    # Now the other three
    for name in ("lgb","cat","rf"):
        p = model_dir / f"{name}.pkl"
        if p.exists():
            models_loaded[name] = joblib.load(p)
            logging.info(f"Loaded model '{name}' for '{mode}'")

    if not models_loaded:
        raise FileNotFoundError(f"No models found in {model_dir}")

    # unify Booster vs sklearn API
    def _prob(m, X):
        if isinstance(m, xgb.Booster):
            return m.predict(xgb.DMatrix(X))
        else:
            return m.predict_proba(X)[:,1]

    return lambda M: np.column_stack([_prob(m, M) for m in models_loaded.values()]).mean(axis=1)

def load_age_threshold_from_training():
    # Path where training script saved the age disparity threshold
    threshold_file = TRAINING_REPORTS_BASE_DIR / "_global_training_outputs" / "age_disparity_threshold.json"
    if not threshold_file.exists():
        logging.error(f"Age disparity threshold file not found: {threshold_file}. Using fallback default of 10 years.")
        return 10 
    with open(threshold_file, 'r') as f:
        data = json.load(f)
    age_thresh = data.get("age_disparity_threshold_years", 10) # Default if key missing
    logging.info(f"Loaded age disparity threshold: {age_thresh} years.")
    return age_thresh

def load_decision_threshold_for_task(mode):
    # Path where training script saved metrics, including the decision threshold
    metrics_file = TRAINING_REPORTS_BASE_DIR / mode / "metrics.json"
    try:
        with open(metrics_file, 'r') as f:
            decision_thresh = json.load(f).get('threshold', 0.5) # Default if key missing
        logging.info(f"Loaded decision threshold for '{mode}' model: {decision_thresh:.3f}")
        return decision_thresh
    except FileNotFoundError:
        logging.warning(f"Metrics file for '{mode}' ({metrics_file}) not found. Using default decision threshold 0.5.")
        return 0.5

# ───────────────────────────────────────
#  Load and Prepare Deployment Data (ML Features + Raw Attributes for Blocking)
# ───────────────────────────────────────
def load_and_prepare_deployment_data():
    logging.info("Loading and preparing deployment data...")
    
    # 1. Load ML features (X_deploy) and the mapping from its rows to original IDs
    X_deploy_path = DEPLOY_PREPROCESSED_INPUT_DIR / "iceid_ml_ready.npz"
    if not X_deploy_path.exists():
        # fallback to the main artifacts directory
        fallback = ART_DIR / "iceid_ml_ready.npz"
        if fallback.exists():
            logging.warning(f"{X_deploy_path} not found; falling back to {fallback}")
            X_deploy_path = fallback
        else:
            raise FileNotFoundError(
                f"Deployment feature file not found in either {DEPLOY_PREPROCESSED_INPUT_DIR} or {ART_DIR}"
            )
    X_deploy_features = sparse.load_npz(X_deploy_path)
    
    row_labels_for_deploy_path = DEPLOY_PREPROCESSED_INPUT_DIR / "row_labels.csv"
    if not row_labels_for_deploy_path.exists():
        # fallback to the main artifacts directory
        fallback = ART_DIR / "row_labels.csv"
        if fallback.exists():
            logging.warning(f"{row_labels_for_deploy_path} not found; falling back to {fallback}")
            row_labels_for_deploy_path = fallback
        else:
            raise FileNotFoundError(
                f"Row labels file not found in either {DEPLOY_PREPROCESSED_INPUT_DIR} or {ART_DIR}"
            )
    # Columns: 'row_id' (original ID), 'person' (will be -1 or empty for unlabeled)
    deploy_row_to_original_id_df = pd.read_csv(row_labels_for_deploy_path, dtype={'row_id': str})
    deploy_original_ids = deploy_row_to_original_id_df['row_id'].unique() # Get unique original IDs for the deployment set
    
    logging.info(f"Loaded X_deploy: {X_deploy_features.shape}. Deployment set has {len(deploy_original_ids)} unique original IDs.")

    # 2. Load raw attributes from original full data files, then filter for deployment IDs
    # Load people.csv (full)
    people_full_path = ORIGINAL_DATA_DIR / "people.csv"
    if not people_full_path.exists(): raise FileNotFoundError(f"{people_full_path} not found.")
    people_full_df = pd.read_csv(
        people_full_path,
        usecols=["id", "heimild", "first_name", "patronym", "surname", "birthyear", "sex"],
        low_memory=False,
        dtype={"id": str, "first_name": str, "patronym": str, "surname": str, "sex": str}
    )
    # Filter for deployment IDs
    people_deploy_subset_df = people_full_df[people_full_df['id'].isin(deploy_original_ids)].set_index('id')

    # Load manntol_einstaklingar_new.csv (full) for parish_id ('bi_sokn')
    manntol_full_path = ORIGINAL_DATA_DIR / "manntol_einstaklingar_new.csv" # Make sure this name is correct
    if not manntol_full_path.exists(): raise FileNotFoundError(f"{manntol_full_path} not found.")
    manntol_full_df = pd.read_csv(
        manntol_full_path,
        usecols=["id", "bi_sokn"], # 'bi_sokn' is used as parish_id in preprocessing
        dtype={"id": str, "bi_sokn": str}
    )
    # Filter for deployment IDs and select parish_id
    manntol_deploy_subset_df = manntol_full_df[manntol_full_df['id'].isin(deploy_original_ids)].set_index('id')
    
    # Join to get parish_id with other people attributes
    raw_attrs_deploy_df = people_deploy_subset_df.join(manntol_deploy_subset_df['bi_sokn'].rename("parish_id_raw"), how="left")
    raw_attrs_deploy_df = raw_attrs_deploy_df.reset_index() # 'id' becomes a column again

    # 3. Align these raw attributes with X_deploy_features (order by X_deploy's row index)
    # deploy_row_to_original_id_df has 'row_id' (original id) and implicitly its index is X_deploy's row index
    # We need to ensure people_attrs_aligned_df is indexed 0..N-1 corresponding to X_deploy rows
    
    # Map: X_row_index -> original_id -> raw_attributes
    # Start with deploy_row_to_original_id_df which has 'row_id' (original ID)
    # Its index (0..X_deploy.shape[0]-1) is what we want for the final aligned DFl.
    final_aligned_attrs_df = deploy_row_to_original_id_df.copy()
    final_aligned_attrs_df = final_aligned_attrs_df.rename(columns={'row_id': 'id'}) # 'id' now holds original ID
    final_aligned_attrs_df = final_aligned_attrs_df.set_index('id').join(
        raw_attrs_deploy_df.set_index('id'), # raw_attrs_deploy_df has all needed raw cols
        how="left"
    )
    final_aligned_attrs_df = final_aligned_attrs_df.reset_index() # 'id' is original_id
    # Ensure it's sorted/reindexed to match X_deploy_features's implicit 0..N-1 row indexing
    # If deploy_row_to_original_id_df was already sorted by its original index (0..N-1), this should be fine.
    # For safety, explicitly create a row_index column if not present and sort by it.
    if 'row_index' not in final_aligned_attrs_df.columns: # Should not happen if from row_labels.csv format
         final_aligned_attrs_df['row_index'] = range(len(final_aligned_attrs_df)) # Assume it was implicitly ordered
    final_aligned_attrs_df = final_aligned_attrs_df.set_index('row_index').sort_index()


    if len(final_aligned_attrs_df) != X_deploy_features.shape[0]:
        raise ValueError("CRITICAL: Mismatch after aligning raw attributes with X_deploy features. "
                         f"Aligned: {len(final_aligned_attrs_df)}, X_deploy: {X_deploy_features.shape[0]}")

    # 4. Derive additional fields needed for blocking from raw attributes
    # Sex_male_raw (0 for female, 1 for male, -1 for unknown)
    final_aligned_attrs_df['sex_male_raw'] = final_aligned_attrs_df['sex'].apply(
        lambda x: 1 if isinstance(x, str) and x.lower() == "karl" else (0 if isinstance(x, str) and x.lower() == "kona" else -1)
    ).astype(int)

    # Primary name part for blocking (Patronym if available, else Surname)
    final_aligned_attrs_df['patronym'] = final_aligned_attrs_df['patronym'].fillna('')
    final_aligned_attrs_df['surname'] = final_aligned_attrs_df['surname'].fillna('')
    final_aligned_attrs_df['primary_name_part_for_blocking'] = final_aligned_attrs_df.apply(
        lambda row: row['patronym'] if row['patronym'] else row['surname'], axis=1
    )
    # Given name for blocking
    final_aligned_attrs_df['given_name_for_blocking'] = final_aligned_attrs_df['first_name'].fillna('')

    # Custom blocking key
    final_aligned_attrs_df['custom_blocking_key'] = final_aligned_attrs_df.apply(
        lambda row: generate_custom_blocking_key_from_parts(
            row['primary_name_part_for_blocking'], row['given_name_for_blocking']
        ), axis=1
    )
    
    # Convert relevant columns to expected types
    final_aligned_attrs_df['heimild'] = pd.to_numeric(final_aligned_attrs_df['heimild'], errors='coerce').fillna(-1).astype(int)
    final_aligned_attrs_df['birthyear'] = pd.to_numeric(final_aligned_attrs_df['birthyear'], errors='coerce').fillna(0).astype(int)
    final_aligned_attrs_df['parish_id_raw'] = final_aligned_attrs_df['parish_id_raw'].astype(str).fillna('')


    logging.info(f"Successfully prepared {len(final_aligned_attrs_df)} aligned records with raw attributes and blocking keys.")
    # Important columns in final_aligned_attrs_df (indexed 0..N-1 like X_deploy):
    # 'id' (original ID), 'heimild', 'birthyear', 'sex_male_raw', 'parish_id_raw', 'custom_blocking_key'
    return X_deploy_features, final_aligned_attrs_df

# ───────────────────────────────────────
#  Blocking and Candidate Pair Generation (using derived attributes)
# ───────────────────────────────────────
def generate_candidate_pairs_for_deployment(df_with_blocking_attrs, age_disp_thresh):
    # df_with_blocking_attrs is indexed 0..N-1 (like X_deploy)
    # It contains: 'custom_blocking_key', 'sex_male_raw', 'birthyear', 'heimild', 'parish_id_raw'
    logging.info("Generating candidate pairs for deployment using custom string key blocking...")
    cand_pairs_within = []
    cand_pairs_across = []

    # Group by the generated custom_blocking_key
    grouped_by_key = df_with_blocking_attrs.groupby('custom_blocking_key')
    
    for block_key, group_from_key in tqdm(grouped_by_key, desc="Blocking by custom_key"):
        if len(group_from_key) < 2: continue # Need at least two records in a block

        # Get record indices (0..N-1 for X_deploy) within this block
        indices_in_block = group_from_key.index.tolist()

        for i in range(len(indices_in_block)):
            for j in range(i + 1, len(indices_in_block)):
                idx1, idx2 = indices_in_block[i], indices_in_block[j]
                
                rec1_attrs = df_with_blocking_attrs.loc[idx1]
                rec2_attrs = df_with_blocking_attrs.loc[idx2]

                # Filter 1: Gender (must match if both known and different from -1)
                if rec1_attrs['sex_male_raw'] != -1 and rec2_attrs['sex_male_raw'] != -1 and \
                   rec1_attrs['sex_male_raw'] != rec2_attrs['sex_male_raw']:
                    continue

                # Filter 2: Age Disparity (if both birthyears are valid)
                # Assuming birthyear 0 means unknown/invalid after preprocessing
                if rec1_attrs['birthyear'] > 0 and rec2_attrs['birthyear'] > 0:
                    if abs(rec1_attrs['birthyear'] - rec2_attrs['birthyear']) > age_disp_thresh:
                        continue
                
                # Determine if pair is for "within" or "across" model based on 'heimild'
                is_same_heimild = rec1_attrs['heimild'] == rec2_attrs['heimild']
                
                # Ensure heimilds are valid if comparing (not -1)
                heimild1_valid = rec1_attrs['heimild'] != -1
                heimild2_valid = rec2_attrs['heimild'] != -1

                if is_same_heimild and heimild1_valid: # Both records must have the same, valid heimild
                    cand_pairs_within.append(tuple(sorted((idx1, idx2))))
                elif not is_same_heimild and heimild1_valid and heimild2_valid: # Different, valid heimilds
                    # Filter 3: Cross-Census Parish Mismatch (if parishes are known and different, not a match)
                    parish1 = str(rec1_attrs['parish_id_raw']).strip()
                    parish2 = str(rec2_attrs['parish_id_raw']).strip()
                    if parish1 and parish2 and parish1 != parish2 : # Both non-empty and different
                        continue 
                    cand_pairs_across.append(tuple(sorted((idx1, idx2))))

    # Remove duplicates
    final_cand_pairs_within = sorted(list(set(cand_pairs_within)))
    final_cand_pairs_across = sorted(list(set(cand_pairs_across)))

    logging.info(f"Generated {len(final_cand_pairs_within)} 'within-census' candidate pairs after filtering.")
    logging.info(f"Generated {len(final_cand_pairs_across)} 'across-census' candidate pairs after filtering.")
    return final_cand_pairs_within, final_cand_pairs_across

# ───────────────────────────────────────
#  Predict on Candidates & Cluster
# ───────────────────────────────────────
def predict_and_cluster_deployment(
    X_features_deploy, 
    aligned_raw_attributes_df, # Indexed 0..N-1, contains 'id' (original_id)
    list_of_candidate_pairs, 
    ensemble_predict_fn, 
    decision_thresh, 
    output_file_tag
    ):
    
    logging.info(f"Predicting on {len(list_of_candidate_pairs)} candidate pairs for '{output_file_tag}'...")
    
    # For mapping X_features_deploy row indices back to original record IDs
    # aligned_raw_attributes_df['id'] contains the original string IDs
    original_id_lookup = aligned_raw_attributes_df['id'] 

    if not list_of_candidate_pairs:
        logging.info(f"No candidate pairs for '{output_file_tag}'. Creating empty output files.")
        empty_links_df = pd.DataFrame(columns=['original_id_1', 'original_id_2', 'probability', 'is_match'])
        empty_links_df.to_csv(DEPLOY_OUTPUT_DIR / f"{output_file_tag}_predicted_links.csv", index=False)
        
        empty_clusters_df_data = [{"original_id": original_id_lookup.loc[i], "cluster_id": i} for i in range(X_features_deploy.shape[0])]
        empty_clusters_df = pd.DataFrame(empty_clusters_df_data)
        empty_clusters_df.to_csv(DEPLOY_OUTPUT_DIR / f"{output_file_tag}_clusters.csv", index=False)
        return

    all_predicted_probs = []
    num_batches = (len(list_of_candidate_pairs) + CANDIDATE_PAIR_BATCH_SIZE - 1) // CANDIDATE_PAIR_BATCH_SIZE
    
    for i in tqdm(range(0, len(list_of_candidate_pairs), CANDIDATE_PAIR_BATCH_SIZE), 
                  total=num_batches, desc=f"Predicting '{output_file_tag}' batches"):
        current_batch_pairs = list_of_candidate_pairs[i : i + CANDIDATE_PAIR_BATCH_SIZE]
        if not current_batch_pairs: continue

        # Construct feature matrix for this batch of pairs
        # pair_matrix is defined in the training script, copy or re-import if separate
        X_batch_for_model = sparse.vstack([sparse.hstack([X_features_deploy[p_idx1], X_features_deploy[p_idx2]]) 
                                           for p_idx1, p_idx2 in current_batch_pairs])
        
        batch_probs = ensemble_predict_fn(X_batch_for_model)
        all_predicted_probs.extend(batch_probs)
    
    all_predicted_probs = np.array(all_predicted_probs)
    
    # Store all candidate pairs with their probabilities and match prediction
    output_links_data = []
    for k, (idx1, idx2) in enumerate(list_of_candidate_pairs):
        prob = all_predicted_probs[k]
        is_match_prediction = prob >= decision_thresh
        output_links_data.append({
            "original_id_1": original_id_lookup.loc[idx1], 
            "original_id_2": original_id_lookup.loc[idx2],
            "probability": prob,
            "is_match": is_match_prediction
        })
    
    predicted_links_output_df = pd.DataFrame(output_links_data)
    predicted_links_output_df.to_csv(DEPLOY_OUTPUT_DIR / f"{output_file_tag}_predicted_links.csv", index=False)
    logging.info(f"Saved {len(predicted_links_output_df)} candidate pair predictions for '{output_file_tag}'.")

    # Clustering using Connected Components on pairs predicted as matches
    logging.info(f"Clustering predicted matches for '{output_file_tag}'...")
    # Get pairs that are actual matches based on threshold
    pairs_confirmed_as_matches = []
    for k, (idx1, idx2) in enumerate(list_of_candidate_pairs):
        if all_predicted_probs[k] >= decision_thresh:
            pairs_confirmed_as_matches.append((idx1, idx2)) # Using X_features_deploy row indices

    if not pairs_confirmed_as_matches:
        logging.info(f"No pairs predicted as matches for '{output_file_tag}'. All records are singletons.")
        # Each record is its own cluster
        singleton_clusters_data = [{"original_id": original_id_lookup.loc[i], "cluster_id": i} 
                                   for i in range(X_features_deploy.shape[0])]
        singleton_clusters_df = pd.DataFrame(singleton_clusters_data)
        singleton_clusters_df.to_csv(DEPLOY_OUTPUT_DIR / f"{output_file_tag}_clusters.csv", index=False)
        return

    # Build graph only with confirmed matches
    nodes_in_match_graph = sorted(list(set(idx for pair in pairs_confirmed_as_matches for idx in pair)))
    
    local_node_map = {node_idx: i for i, node_idx in enumerate(nodes_in_match_graph)}
    num_local_graph_nodes = len(nodes_in_match_graph)
    cc_parent = list(range(num_local_graph_nodes))

    def find_cc_set(i_local_idx):
        if cc_parent[i_local_idx] == i_local_idx: return i_local_idx
        cc_parent[i_local_idx] = find_cc_set(cc_parent[i_local_idx])
        return cc_parent[i_local_idx]

    def unite_cc_sets(i_local_idx, j_local_idx):
        root_i, root_j = find_cc_set(i_local_idx), find_cc_set(j_local_idx)
        if root_i != root_j: cc_parent[root_j] = root_i # Union by making one root parent of other

    for idx1_row, idx2_row in pairs_confirmed_as_matches:
        # These indices are already for X_features_deploy
        if idx1_row in local_node_map and idx2_row in local_node_map: # Should always be true
            unite_cc_sets(local_node_map[idx1_row], local_node_map[idx2_row])
    
    # Assign final cluster IDs
    # All records from X_features_deploy get a cluster_id
    # Records not in any confirmed match pair become their own singleton cluster.
    
    # Map X_features_deploy row_index to a final cluster_id value
    final_record_to_cluster_map = {} 
    current_global_cluster_id = 0
    # Map local CC root index to a global cluster_id
    local_root_to_global_cluster_id = {} 

    # Assign cluster IDs for records involved in matches
    for i_row_idx in nodes_in_match_graph: # Iterate over X_features_deploy row indices in graph
        local_idx = local_node_map[i_row_idx]
        root_for_node = find_cc_set(local_idx)
        if root_for_node not in local_root_to_global_cluster_id:
            local_root_to_global_cluster_id[root_for_node] = current_global_cluster_id
            current_global_cluster_id += 1
        final_record_to_cluster_map[i_row_idx] = local_root_to_global_cluster_id[root_for_node]

    # Assign cluster IDs for singleton records (not in any match)
    output_cluster_assignment_data = []
    for i_row_idx in range(X_features_deploy.shape[0]):
        original_id_val = original_id_lookup.loc[i_row_idx]
        if i_row_idx in final_record_to_cluster_map:
            assigned_cluster_id = final_record_to_cluster_map[i_row_idx]
        else: # Singleton
            assigned_cluster_id = current_global_cluster_id
            current_global_cluster_id += 1
        output_cluster_assignment_data.append({"original_id": original_id_val, "cluster_id": assigned_cluster_id})
        
    final_clusters_df = pd.DataFrame(output_cluster_assignment_data)
    final_clusters_df.to_csv(DEPLOY_OUTPUT_DIR / f"{output_file_tag}_clusters.csv", index=False)
    logging.info(f"Saved {len(final_clusters_df)} records with cluster assignments for '{output_file_tag}'. "
                 f"Found {final_clusters_df['cluster_id'].nunique()} unique clusters.")

# ───────────────────────────────────────
#  Main Deployment Logic
# ───────────────────────────────────────
if __name__ == "__main__":
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt='%Y-%m-%d %H:%M:%S')

    with Timer("Total Deployment Run"):
        # 1. Load trained models and necessary thresholds
        with Timer("Loading trained models and thresholds"):
            age_disparity_param = load_age_threshold_from_training()
            # Load ensemble prediction functions
            ensemble_model_within = load_ensemble_models_for_deployment("within")
            ensemble_model_across = load_ensemble_models_for_deployment("across")
            # Load decision thresholds for each task
            decision_threshold_within = load_decision_threshold_for_task("within")
            decision_threshold_across = load_decision_threshold_for_task("across")

        # 2. Load and prepare deployment data (ML features and raw attributes for blocking)
        with Timer("Loading and preparing all deployment data"):
            # X_deploy is the ML feature matrix (sparse)
            # people_attributes_for_deploy_df is indexed 0..N-1 like X_deploy,
            # and contains 'id'(original_id), 'heimild', 'birthyear', 'sex_male_raw', 
            # 'parish_id_raw', 'custom_blocking_key'.
            X_deploy, people_attributes_for_deploy_df = load_and_prepare_deployment_data()

        # 3. Generate candidate pairs using blocking rules
        with Timer("Generating candidate pairs via blocking logic"):
            candidate_pairs_within_mode, candidate_pairs_across_mode = generate_candidate_pairs_for_deployment(
                people_attributes_for_deploy_df, age_disparity_param
            )

        # 4. Process "within-census" candidate pairs
        with Timer("Processing 'within-census' candidate pairs"):
            predict_and_cluster_deployment(
                X_deploy, people_attributes_for_deploy_df, 
                candidate_pairs_within_mode, 
                ensemble_model_within, decision_threshold_within, 
                "deployment_within_census" # output file tag
            )

        # 5. Process "across-census" candidate pairs
        with Timer("Processing 'across-census' candidate pairs"):
            predict_and_cluster_deployment(
                X_deploy, people_attributes_for_deploy_df, 
                candidate_pairs_across_mode, 
                ensemble_model_across, decision_threshold_across, 
                "deployment_across_census" # output file tag
            )

    logging.info("--- Deployment Script Finished ---")
    logging.info(f"All deployment outputs are in directory: {DEPLOY_OUTPUT_DIR}")
    logging.info("Review *_predicted_links.csv for pair probabilities and match status.")
    logging.info("Review *_clusters.csv for final entity cluster assignments (using original_id).")

2025-05-11 19:35:42 [INFO] ▶ Total Deployment Run
2025-05-11 19:35:42 [INFO] ▶ Loading trained models and thresholds
2025-05-11 19:35:42 [INFO] Loaded age disparity threshold: 76 years.
2025-05-11 19:35:42 [INFO] Loaded XGBoost Booster for 'within'
2025-05-11 19:35:42 [INFO] Loaded model 'lgb' for 'within'
2025-05-11 19:35:42 [INFO] Loaded model 'cat' for 'within'
2025-05-11 19:35:43 [INFO] Loaded model 'rf' for 'within'
2025-05-11 19:35:43 [INFO] Loaded XGBoost Booster for 'across'
2025-05-11 19:35:43 [INFO] Loaded model 'lgb' for 'across'
2025-05-11 19:35:43 [INFO] Loaded model 'cat' for 'across'
2025-05-11 19:35:43 [INFO] Loaded decision threshold for 'within' model: 0.540
2025-05-11 19:35:43 [INFO] ⏱ Loading trained models and thresholds: 0.9s
2025-05-11 19:35:43 [INFO] ▶ Loading and preparing all deployment data
2025-05-11 19:35:43 [INFO] Loading and preparing deployment data...
2025-05-11 19:35:43 [INFO] Loaded X_deploy: (984028, 50799). Deployment set has 984028 unique original 

Blocking by custom_key:   0%|          | 0/47046 [00:00<?, ?it/s]

2025-05-11 21:51:00 [INFO] ⏱ Generating candidate pairs via blocking logic: 8105.7s
2025-05-11 21:51:00 [INFO] ⏱ Total Deployment Run: 8118.5s


KeyboardInterrupt: 