# MODELS & ALGOs

In [None]:
#!/usr/bin/env python3
# ICE‑ID ER Runner · TriBERTa‑ER | DistilBERT | MiniLM‑CE
# dual task (within / across census) · thread pool · resumable · verbose progress
# ------------------------------------------------------------------------------

import os, itertools, json, logging, gc, random, time
from pathlib import Path
from multiprocessing import cpu_count
from multiprocessing.dummy import Pool                     # thread‑based

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, accuracy_score
from sklearn.model_selection import train_test_split
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments)
from sentence_transformers import (SentenceTransformer, CrossEncoder,
                                   InputExample, losses)

# ───────────────────────── config ─────────────────────────
# Input locations, relative to the working directory.
ART_DIR = Path("artifacts")
DATA_DIR = Path("raw_data")

# Output root; created on import so later writes can assume it exists.
OUTPUT = Path("models_er")
OUTPUT.mkdir(exist_ok=True)

# Prefer the GPU when one is visible to torch.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

N_RUNS = 10            # number of repeated train/test splits
MAX_TR = 100_000       # cap on sampled training pairs per mode
MAX_TE = 100_000       # cap on sampled test pairs per mode
PAIR_SAMPLES = 25_000  # default number of bi-encoder training examples
BATCH_EMB = 512        # default bi-encoder scoring batch size
DISTIL_BS = 16         # default DistilBERT scoring batch size
MINILM_BS = 16         # default cross-encoder scoring batch size
RNG_SEED = 42
CKPT_FILE = OUTPUT / "checkpoint.json"
MODES = ("within", "across")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)

# ───────────────────── helpers ─────────────────────
class Timer:
    """Context manager that logs a start banner and, on exit, the elapsed time.

    The clock starts when the object is constructed (not when the ``with``
    block is entered), and the start banner is logged at that moment too.
    """

    def __init__(self, msg):
        self.msg = msg
        self.t0 = time.perf_counter()
        logging.info("▶ %s …", msg)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        # Returns None, so exceptions raised inside the block propagate.
        elapsed = time.perf_counter() - self.t0
        logging.info("⏱ %s: %.1fs", self.msg, elapsed)

def load_people():
    """Load ``people.csv`` and attach the ground-truth person label to each row.

    Reads ``DATA_DIR/people.csv`` (column names are lower-cased, ``id`` becomes
    the temporary index) and ``ART_DIR/row_labels.csv`` (``row_id`` -> ``person``).

    Returns:
        A DataFrame with ``id`` restored as a regular column plus:
        * the four name-part columns as strings (missing values -> ``""``),
        * ``full_name``: the lower-cased, stripped, non-empty name parts joined
          by single spaces,
        * ``birthyear`` and ``heimild`` as ints (unparseable/missing -> 0),
        * ``person``: nullable-integer label looked up by row id, NA when the
          row has no label.

    A name or numeric column that is absent from the CSV is created with its
    default value instead of crashing.
    """
    with Timer("load people.csv"):
        ppl = (pd.read_csv(DATA_DIR / "people.csv", low_memory=False)
                 .rename(columns=str.lower).set_index("id"))

    name_cols = ["first_name", "middle_name", "patronym", "surname"]
    for c in name_cols:
        # DataFrame.get(c, "") would hand back the bare default for a missing
        # column, which has no .fillna; test for the column explicitly.
        if c in ppl.columns:
            ppl[c] = ppl[c].fillna("").astype(str)
        else:
            ppl[c] = ""

    # Skip parts that are empty *after* stripping so whitespace-only fields do
    # not leave leading/trailing/double spaces in the joined name.
    ppl["full_name"] = ppl[name_cols].apply(
        lambda r: " ".join(w.strip().lower() for w in r if w.strip()), axis=1)

    for c in ("birthyear", "heimild"):
        if c in ppl.columns:
            ppl[c] = pd.to_numeric(ppl[c], errors="coerce").fillna(0).astype(int)
        else:
            ppl[c] = 0

    lbl = (pd.read_csv(ART_DIR / "row_labels.csv")
             .set_index("row_id")["person"]
             .pipe(pd.to_numeric, errors="coerce").astype("Int64"))
    # Look each row id up in the label table; ids without a label become NA.
    ppl["person"] = pd.Series(ppl.index, index=ppl.index).map(lbl)
    return ppl.reset_index()

def trigram(s):
    """Return the set of character trigrams of *s*.

    The string is lower-cased and every non-alphanumeric character is dropped
    first. If fewer than three characters remain, the result is a one-element
    set holding the cleaned string itself (possibly the empty string).
    """
    cleaned = "".join(filter(str.isalnum, s.lower()))
    if len(cleaned) < 3:
        return {cleaned}
    return {cleaned[start:start + 3] for start in range(len(cleaned) - 2)}

def build_blocks(df):
    """Group row labels of *df* into blocking buckets.

    Two rows share a block when they have the same set of name trigrams
    (from ``full_name``) and the same birth decade (``birthyear // 10``,
    or -1 when ``birthyear`` is not positive).

    Args:
        df: frame with ``full_name`` (str) and ``birthyear`` (int) columns.

    Returns:
        dict mapping ``(frozenset_of_trigrams, decade)`` to the list of index
        labels of the rows in that block, in row order.
    """
    with Timer("build trigram blocks"):
        # trigram() is pure Python, so a thread pool cannot run it in parallel
        # under the GIL; computing it inline avoids the pool overhead. Plain
        # column lists also avoid building one Series per row (iterrows) and
        # the full-frame copy that only existed to hold a scratch column.
        names = df["full_name"].tolist()
        years = df["birthyear"].tolist()
        rows = zip(df.index.tolist(), names, years)
        blocks = {}
        for idx, name, year in tqdm(rows, total=len(df), desc="index blocks"):
            decade = (year // 10) if year > 0 else -1
            blocks.setdefault((frozenset(trigram(name)), decade), []).append(idx)
    logging.info("blocks: %d", len(blocks))
    return blocks

def _pairs(bucket, k, rng):
    """Draw up to *k* distinct unordered pairs from *bucket*.

    Args:
        bucket: list of distinct, orderable items.
        k: number of pairs wanted; values <= 0 yield no pairs.
        rng: a ``random.Random`` instance (makes the draw reproducible).

    Returns:
        A list of at most ``min(k, n*(n-1)//2)`` distinct 2-tuples.
    """
    n = len(bucket)
    max_p = n * (n - 1) // 2
    if n < 2 or k <= 0:
        return []
    # Never ask for more pairs than exist: the rejection loop below could
    # otherwise never reach its target and would spin forever.
    k = min(k, max_p)
    # Enumerate when the pair space is small, or when so much of it is wanted
    # that rejection sampling would mostly redraw pairs it already has.
    if max_p <= 3_000 or 2 * k >= max_p:
        return rng.sample(list(itertools.combinations(bucket, 2)), k)
    seen = set()
    while len(seen) < k:
        # Sorting canonicalises (a, b) / (b, a) so the set deduplicates them.
        seen.add(tuple(sorted(rng.sample(bucket, 2))))
    return list(seen)

def sample_pairs(blocks, df, tr_idx, te_idx, k_tr, k_te, mode, rng):
    """Sample candidate record pairs from blocks for the train and test splits.

    Blocks are visited in a random order. Within each block, pairs are drawn
    separately among the block's training rows and its test rows, then
    filtered by *mode*.

    Args:
        blocks: mapping of block key -> list of row labels of *df*.
        df: frame with an integer ``heimild`` column.
        tr_idx, te_idx: row labels belonging to the train / test split.
        k_tr, k_te: maximum number of train / test pairs to return.
        mode: ``"within"`` keeps pairs whose two rows share the same positive
            ``heimild``; any other value keeps the complementary pairs.
        rng: ``random.Random`` used for block order and pair sampling.

    Returns:
        ``(train_pairs, test_pairs)``, each a list of at most k 2-tuples.
        Fewer are returned when the blocks run out before the quota is met.
    """
    # Membership is tested once per row of every visited block; sets make
    # that O(1) instead of a linear scan through a list of row labels.
    tr_members, te_members = set(tr_idx), set(te_idx)

    def keep(i, j):
        same = df.at[i, "heimild"] == df.at[j, "heimild"] and df.at[i, "heimild"] > 0
        return same if mode == "within" else not same

    tr, te = [], []
    for bucket in tqdm(rng.sample(list(blocks.values()), len(blocks)),
                       total=len(blocks),
                       desc=f"sampling {mode}",
                       leave=False):
        if len(tr) < k_tr:
            tr += [p for p in _pairs([i for i in bucket if i in tr_members],
                                      k_tr - len(tr), rng) if keep(*p)]
        if len(te) < k_te:
            te += [p for p in _pairs([i for i in bucket if i in te_members],
                                      k_te - len(te), rng) if keep(*p)]
        if len(tr) >= k_tr and len(te) >= k_te:
            break
    # The last block can overshoot a quota, so trim to the requested size.
    return tr[:k_tr], te[:k_te]

def labels_for(pairs, df):
    """Return the 0/1 match label for each pair of row labels.

    A pair is labelled 1 only when both rows have a non-missing ``person``
    value and the two values are equal; otherwise 0.

    Args:
        pairs: iterable of ``(i, j)`` row labels of *df*.
        df: frame with a ``person`` column (missing labels may be NA/NaN).

    Returns:
        1-D integer numpy array, one entry per pair.
    """
    out = []
    for i, j in pairs:
        a, b = df.at[i, "person"], df.at[j, "person"]
        # Test for missing values first: comparing pd.NA yields pd.NA, whose
        # truth value is ambiguous and raises TypeError.
        out.append(int(not pd.isna(a) and not pd.isna(b) and a == b))
    return np.array(out, dtype=int)

# ───────────────────── models ─────────────────────
class TriBERTaER:
    """Bi-encoder matcher built on a SentenceTransformer.

    ``train`` fine-tunes the encoder with a cosine-similarity loss on
    same-person (label 1.0) and different-person (label 0.0) record pairs;
    ``score`` rates a pair by the cosine similarity of the two embeddings.
    """

    def __init__(self, name="sentence-transformers/all-MiniLM-L6-v2"):
        self.model = SentenceTransformer(name, device=DEVICE)
        self.loss = losses.CosineSimilarityLoss(self.model)

    @staticmethod
    def serial(row, exclude=("person",)):
        """Serialise a record as ``"key: value ; key: value ; ..."``.

        Args:
            row: mapping-like record (e.g. a DataFrame row) exposing ``items()``.
            exclude: keys left out of the text. ``person`` is excluded by
                default because it is the ground-truth label: writing it into
                the model input would leak the answer to every model that
                consumes this serialisation.
        """
        return " ; ".join(f"{k}: {v}" for k, v in row.items() if k not in exclude)

    def train(self, df, idx, out_dir, n_pairs=PAIR_SAMPLES):
        """Fine-tune the encoder for one epoch on sampled record pairs.

        Each iteration picks a random person; if that person has at least two
        rows a positive example is added, and a negative example pairing one
        of their rows with a row of a different person is always added. The
        sampling RNG is seeded with ``RNG_SEED``, so it is deterministic.

        Args:
            df: record frame with a ``person`` column.
            idx: row labels of *df* to draw training rows from.
            out_dir: ``Path`` the fitted model is written to.
            n_pairs: minimum number of examples to generate.

        Raises:
            ValueError: if fewer than two distinct persons are available, so
                no negative example could ever be formed.
        """
        rng = random.Random(RNG_SEED)
        labelled = df.loc[idx].dropna(subset=["person"])
        # Materialise person -> row labels once instead of slicing the frame
        # on every sampling step.
        rows_of = {pid: members.tolist()
                   for pid, members in labelled.groupby("person").groups.items()}
        persons = list(rows_of)
        if len(persons) < 2:
            raise ValueError("TriBERTaER.train needs at least two distinct persons")

        examples = []
        while len(examples) < n_pairs:
            pid = rng.choice(persons)
            rows = rows_of[pid]
            if len(rows) >= 2:
                a, p = rng.sample(rows, 2)
                examples.append(InputExample(
                    texts=[self.serial(df.loc[a]), self.serial(df.loc[p])], label=1.0))
            # Redraw until a different person comes up; this avoids building
            # an all-but-one copy of the person list on every iteration.
            pid2 = rng.choice(persons)
            while pid2 == pid:
                pid2 = rng.choice(persons)
            n1 = rng.choice(rows)
            n2 = rng.choice(rows_of[pid2])
            examples.append(InputExample(
                texts=[self.serial(df.loc[n1]), self.serial(df.loc[n2])], label=0.0))

        loader = DataLoader(examples, shuffle=True, batch_size=32)
        with Timer(f"TriBERTa fit {out_dir.name}"):
            self.model.fit([(loader, self.loss)], epochs=1,
                           show_progress_bar=False, output_path=str(out_dir))

    def score(self, df, pairs, bs=BATCH_EMB):
        """Return a match score in [0, 1] for every pair.

        Args:
            df: record frame indexed by the labels used in *pairs*.
            pairs: sequence of ``(i, j)`` row labels.
            bs: number of pairs encoded per batch.

        Returns:
            1-D numpy array: cosine similarity rescaled from [-1, 1] to
            [0, 1]. Empty when *pairs* is empty.
        """
        if len(pairs) == 0:
            # np.concatenate rejects an empty list of arrays.
            return np.zeros(0)
        sims = []
        for i in tqdm(range(0, len(pairs), bs), desc="TriBERTa score", leave=False):
            chunk = pairs[i:i + bs]
            a = [self.serial(df.loc[x]) for x, _ in chunk]
            b = [self.serial(df.loc[y]) for _, y in chunk]
            ea = self.model.encode(a, convert_to_tensor=True, device=DEVICE)
            eb = self.model.encode(b, convert_to_tensor=True, device=DEVICE)
            sims.append(torch.cosine_similarity(ea, eb).cpu().numpy())
        return np.concatenate(sims) / 2 + 0.5

class EncodeDataset(Dataset):
    """Dataset pairing a dict of per-example encoded fields with labels.

    Item ``i`` is a dict holding the ``i``-th entry of every field in ``enc``
    plus the ``i``-th label under the key ``"labels"``.
    """

    def __init__(self, enc, y):
        self.enc = enc
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        item = {}
        for field, values in self.enc.items():
            item[field] = values[i]
        item["labels"] = self.y[i]
        return item

class DistilBERTER:
    """Pair classifier: a two-label sequence-classification transformer.

    Both records of a pair are serialised into a single ``A: ... / B: ...``
    text; the softmax probability of class 1 is the match score.
    """

    def __init__(self, name="distilbert-base-uncased"):
        self.tok = AutoTokenizer.from_pretrained(name)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            name, num_labels=2).to(DEVICE)

    @staticmethod
    def _pair_texts(df, pairs):
        """Serialise each ``(i, j)`` pair of rows of *df* into one input text."""
        return [f"A: {TriBERTaER.serial(df.loc[i])}\nB: {TriBERTaER.serial(df.loc[j])}"
                for i, j in pairs]

    def train(self, df, pairs, y, out_dir):
        """Fine-tune for one epoch on labelled pairs and save the model.

        Args:
            df: record frame indexed by the labels used in *pairs*.
            pairs: sequence of ``(i, j)`` row labels.
            y: 0/1 labels aligned with *pairs*.
            out_dir: ``Path`` used as the trainer output dir and save target.
        """
        txt = self._pair_texts(df, pairs)
        enc = self.tok(txt, truncation=True, padding=True,
                       max_length=256, return_tensors="pt")
        ds = EncodeDataset(enc, torch.tensor(y))
        args = TrainingArguments(output_dir=str(out_dir),
                                 per_device_train_batch_size=8,
                                 num_train_epochs=1, save_strategy="no",
                                 # Mixed precision needs a GPU; requesting it
                                 # on a CPU-only host makes training fail.
                                 fp16=DEVICE.type == "cuda",
                                 remove_unused_columns=False,
                                 logging_steps=50)
        with Timer(f"DistilBERT fit {out_dir.name}"):
            Trainer(self.model, args, train_dataset=ds,
                    tokenizer=self.tok).train()
        self.model.save_pretrained(out_dir)

    def score(self, df, pairs, bs=DISTIL_BS):
        """Return the predicted match probability for every pair.

        Args:
            df: record frame indexed by the labels used in *pairs*.
            pairs: sequence of ``(i, j)`` row labels.
            bs: number of pairs per forward pass.

        Returns:
            1-D numpy array of class-1 probabilities; empty when *pairs* is.
        """
        if len(pairs) == 0:
            # np.concatenate rejects an empty list of arrays.
            return np.zeros(0)
        txt = self._pair_texts(df, pairs)
        # Training leaves the module in train mode; switch to eval so dropout
        # is disabled and the scores are deterministic.
        self.model.eval()
        out = []
        for i in tqdm(range(0, len(txt), bs), desc="Distil score", leave=False):
            enc = self.tok(txt[i:i + bs], truncation=True, padding=True,
                           max_length=256, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                logits = self.model(**enc).logits
            out.append(torch.softmax(logits, -1)[:, 1].cpu().numpy())
        return np.concatenate(out)

class MiniLMCE:
    """Cross-encoder matcher: scores a pair of serialised records jointly.

    Wraps a sentence-transformers ``CrossEncoder`` with a single output.
    """

    def __init__(self, name="sentence-transformers/paraphrase-MiniLM-L6-v2"):
        self.model = CrossEncoder(name, num_labels=1, device=DEVICE)

    def train(self, df, pairs, y, out_dir):
        """Fit for one epoch on labelled pairs, then save the model to *out_dir*.

        Args:
            df: record frame indexed by the labels used in *pairs*.
            pairs: sequence of ``(i, j)`` row labels.
            y: labels aligned with *pairs* (converted to float targets).
            out_dir: ``Path`` the fitted model is saved to.
        """
        samples = []
        for (left, right), target in zip(pairs, y):
            texts = [TriBERTaER.serial(df.loc[left]),
                     TriBERTaER.serial(df.loc[right])]
            samples.append(InputExample(texts=texts, label=float(target)))
        loader = DataLoader(samples, shuffle=True, batch_size=16)
        with Timer(f"MiniLM‑CE fit {out_dir.name}"):
            self.model.fit(train_dataloader=loader, epochs=1)
        self.model.save(out_dir)

    def score(self, df, pairs, bs=MINILM_BS):
        """Return the cross-encoder prediction for every pair as a numpy array.

        Args:
            df: record frame indexed by the labels used in *pairs*.
            pairs: sequence of ``(i, j)`` row labels.
            bs: prediction batch size.
        """
        sentence_pairs = []
        for left, right in pairs:
            sentence_pairs.append((TriBERTaER.serial(df.loc[left]),
                                   TriBERTaER.serial(df.loc[right])))
        predictions = self.model.predict(sentence_pairs, batch_size=bs)
        return np.array(predictions)

# ───────────────────── metrics / ckpt ─────────────────────
def metrics(y, p, t=0.5):
    """Compute binary classification metrics for scores *p* against labels *y*.

    Args:
        y: true 0/1 labels.
        p: numpy array of scores; a score >= *t* counts as a positive.
        t: decision threshold.

    Returns:
        dict with ``precision``, ``recall``, ``f1``, ``accuracy`` (all at the
        threshold) and ``auc`` (ROC AUC of the raw scores, NaN when *y*
        contains a single class).
    """
    predicted = (p >= t).astype(int)
    precision, recall, f1, _support = precision_recall_fscore_support(
        y, predicted, average="binary", zero_division=0)
    accuracy = accuracy_score(y, predicted)
    if len(np.unique(y)) > 1:
        auc = roc_auc_score(y, p)
    else:
        # ROC AUC is undefined when only one class is present.
        auc = float("nan")
    return {"precision": precision, "recall": recall, "f1": f1,
            "accuracy": accuracy, "auc": auc}

def ckpt_load():
    """Return the number of completed runs recorded in ``CKPT_FILE``.

    Returns 0 when no checkpoint file exists yet.
    """
    if not CKPT_FILE.exists():
        return 0
    state = json.loads(CKPT_FILE.read_text())
    return state["done_runs"]
def ckpt_save(done, log):
    """Persist the latest run's metrics, then mark *done* runs as complete.

    Args:
        done: number of runs completed so far (also names the metrics file,
            ``metrics_runs/run_<done>.json`` under ``OUTPUT``).
        log: either a sequence whose last element is the latest run's metrics,
            or a mapping of mode -> list of per-run metrics, in which case the
            last entry of every non-empty list is saved under its mode.
    """
    if isinstance(log, dict):
        # A mode-keyed mapping has no element -1; collect the most recent
        # entry of each mode instead of indexing the mapping itself.
        latest = {mode: runs[-1] for mode, runs in log.items() if runs}
    else:
        latest = log[-1]
    metrics_dir = OUTPUT / "metrics_runs"
    metrics_dir.mkdir(exist_ok=True)
    (metrics_dir / f"run_{done}.json").write_text(json.dumps(latest, indent=2))
    # Advance the checkpoint only after the metrics are on disk, so a failure
    # above cannot mark a run as done without its metrics.
    CKPT_FILE.write_text(json.dumps({"done_runs": done}))

# ───────────────────── main ─────────────────────
if __name__ == "__main__":
    # Driver: for every run, draw a fresh train/test split, sample candidate
    # pairs per mode, train and evaluate the three matchers, checkpoint the
    # run's metrics, and finally write per-mode averages over all runs.
    df = load_people()
    df_lab = df[df["person"].notna()].reset_index(drop=True)
    blocks = build_blocks(df_lab)
    collected = {m: [] for m in MODES}

    # When resuming, reload the metrics of runs finished in earlier sessions
    # so the final averages cover every run, not only those executed now.
    first_run = ckpt_load()
    for finished in range(1, first_run + 1):
        run_file = OUTPUT / "metrics_runs" / f"run_{finished}.json"
        if not run_file.exists():
            logging.warning("metrics of finished run %d not found (%s); "
                            "excluded from averages", finished, run_file)
            continue
        saved = json.loads(run_file.read_text())
        for mode in MODES:
            if mode in saved:
                collected[mode].append(saved[mode])

    for run in range(first_run, N_RUNS):
        logging.info("=== RUN %d/%d ===", run+1, N_RUNS)
        rng = random.Random(RNG_SEED + run)
        tr_idx, te_idx = train_test_split(df_lab.index.tolist(),
                                          test_size=0.2,
                                          random_state=run)

        run_metrics = {}
        for mode in MODES:
            with Timer(f"pair sampling {mode}"):
                p_tr, p_te = sample_pairs(blocks, df_lab, tr_idx, te_idx,
                                          MAX_TR, MAX_TE, mode, rng)
            y_tr, y_te = labels_for(p_tr, df_lab), labels_for(p_te, df_lab)

            # Build, train, score and release one model at a time so that
            # only a single model occupies (GPU) memory at any moment.
            tb = TriBERTaER()
            tb.train(df_lab, tr_idx, OUTPUT/f"{mode}_triberta_{run}")
            p_tb = tb.score(df_lab, p_te)
            del tb
            gc.collect(); torch.cuda.empty_cache()

            db = DistilBERTER()
            db.train(df_lab, p_tr, y_tr, OUTPUT/f"{mode}_distil_{run}")
            p_db = db.score(df_lab, p_te)
            del db
            gc.collect(); torch.cuda.empty_cache()

            ce = MiniLMCE()
            ce.train(df_lab, p_tr, y_tr, OUTPUT/f"{mode}_minilm_{run}")
            p_ce = ce.score(df_lab, p_te)
            del ce
            gc.collect(); torch.cuda.empty_cache()

            run_metrics[mode] = {"TriBERTa": metrics(y_te, p_tb),
                                 "DistilBERT": metrics(y_te, p_db),
                                 "MiniLM-CE": metrics(y_te, p_ce)}
            collected[mode].append(run_metrics[mode])

        # Hand over a one-element list: its last item is this run's metrics
        # for both modes, which is what gets written to the run's JSON file.
        ckpt_save(run+1, [run_metrics])
        logging.info("✔ run %d finished", run+1)

    ks = ["precision", "recall", "f1", "accuracy", "auc"]
    for mode in MODES:
        if not collected[mode]:
            logging.warning("no metrics collected for mode %s; summary skipped", mode)
            continue
        avg = {m: {k: np.mean([r[m][k] for r in collected[mode]]) for k in ks}
               for m in ("TriBERTa", "DistilBERT", "MiniLM-CE")}
        pd.DataFrame(avg).T.to_csv(OUTPUT/f"{mode}_metrics_summary_avg.csv")

    logging.info("🎉 all runs complete")


12:01:10 INFO ▶ load people.csv …
12:01:11 INFO ⏱ load people.csv: 1.9s
12:01:15 INFO ▶ build trigram blocks …


index blocks:   0%|          | 0/476683 [00:00<?, ?it/s]

12:01:27 INFO ⏱ build trigram blocks: 12.7s
12:01:27 INFO blocks: 161413
12:01:28 INFO === RUN 1/10 ===
12:01:28 INFO ▶ pair sampling within …
12:28:05 INFO ⏱ pair sampling within: 1597.6s


KeyboardInterrupt: 