In [1]:
from __future__ import annotations
import os
import re
import tempfile
import random
import itertools
import pathlib
import warnings
import logging
import gc

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from scipy import sparse
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, roc_auc_score, roc_curve, adjusted_rand_score
)
from pytorch_tabnet.tab_model import TabNetClassifier
from tabpfn import TabPFNClassifier
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, TrainerConfig, OptimizerConfig
from pytorch_tabular.models import TabTransformerConfig, FTTransformerConfig
from joblib import Parallel, delayed
import multiprocessing

# ─── Settings ────────────────────────────────────────────────
# Filesystem locations, relative to the current working directory.
DATA_DIR = pathlib.Path("raw_data")
ART_DIR = pathlib.Path("artifacts")

# Hardware: the CUDA device when one is visible, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = multiprocessing.cpu_count()

# Experiment knobs (presumably consumed by the evaluation loop; not read in
# this cell).
THR = 0.5
NEG_RATIO = 1
N_RUNS = 10

# Model keys for the experiment loop. NOTE(review): "attendem" has no branch
# in get_model(), which raises ValueError for it.
MODELS = [
    "logreg", "tabnet", "tabpfn", "saint",
    "tabtransformer", "fttransformer", "nars",
    "ditto", "attendem", "hf_llm",
]

# ───────────────────────────── Baseline ──────────────────────────────
class LogReg(BaseEstimator, ClassifierMixin):
    """Wrapper around sklearn.linear_model.LogisticRegression.

    Constructor arguments are also stored under their own names because
    ``BaseEstimator.get_params`` (used by ``repr``, ``clone`` and grid
    search) reads every ``__init__`` parameter back with ``getattr``; without
    these attributes even printing the estimator raises AttributeError.
    Extra ``**kwargs`` are forwarded to LogisticRegression but, being
    varargs, are not reported by ``get_params``.
    """
    def __init__(self, penalty="l2", C=1.0, solver="lbfgs",
                 max_iter=2000, n_jobs=-1, **kwargs):
        self.penalty = penalty
        self.C = C
        self.solver = solver
        self.max_iter = max_iter
        self.n_jobs = n_jobs
        self._lr = LogisticRegression(
            penalty=penalty, C=C, solver=solver,
            max_iter=max_iter, n_jobs=n_jobs, **kwargs
        )

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Fit the wrapped LogisticRegression and return ``self``."""
        self._lr.fit(X, y)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Hard class predictions of the wrapped model."""
        return self._lr.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Class probabilities of the wrapped model."""
        return self._lr.predict_proba(X)


# ───────────────────── Tabular DL baselines ─────────────────────────
class TabNet(BaseEstimator, ClassifierMixin):
    """sklearn-style wrapper around pytorch-tabnet's TabNetClassifier."""
    def __init__(self):
        # ``verbose`` is a TabNetClassifier constructor option; ``fit`` does
        # not accept it.
        self.tab = TabNetClassifier(verbose=0)

    def fit(self, X, y):
        """Train for up to 200 epochs with early stopping (patience 20).

        TabNetClassifier.fit names its training inputs ``X_train`` /
        ``y_train``; passing ``X=`` / ``y=`` raises TypeError.
        NOTE(review): the eval set is the training data itself, so early
        stopping tracks training performance rather than generalisation.
        """
        self.tab.fit(X_train=X, y_train=y, eval_set=[(X, y)],
                     max_epochs=200, patience=20)
        return self

    def predict_proba(self, X):
        """Class probabilities from the fitted TabNet model."""
        return self.tab.predict_proba(X)

    def predict(self, X):
        """Hard class predictions from the fitted TabNet model."""
        return self.tab.predict(X)


class TabPFN(TabPFNClassifier):
    """TabPFNClassifier on the GPU when available, with 32 ensemble configs.

    NOTE(review): ``N_ensemble_configurations`` is the keyword of the
    original (0.x) tabpfn API; confirm it matches the installed version.
    """
    def __init__(self):
        use_cuda = torch.cuda.is_available()
        super().__init__(
            device="cuda" if use_cuda else "cpu",
            N_ensemble_configurations=32,
        )


# ───────────────────── PyTorch-Tabular wrappers ─────────────────────
class _PTab(BaseEstimator, ClassifierMixin):
    """Shared sklearn-style adapter for pytorch-tabular models.

    Subclasses set ``CFG`` to a pytorch-tabular model-config class. Input
    arrays are rebuilt into a DataFrame whose columns are ``num_cols``
    followed by ``cat_cols``, so X must use that column order.
    """
    CFG = None

    def __init__(self, cat_cols: list[str], num_cols: list[str]):
        # Stored under the constructor names (for sklearn get_params/clone)
        # and under the short names used by the rest of the class.
        self.cat_cols, self.num_cols = cat_cols, num_cols
        self.cat, self.num = cat_cols, num_cols

    def _frame(self, X) -> pd.DataFrame:
        """Wrap X in a DataFrame with numeric then categorical column names."""
        return pd.DataFrame(X, columns=self.num + self.cat)

    def fit(self, X, y):
        """Train the configured pytorch-tabular model on (X, y) for 50 epochs."""
        if self.CFG is None:
            raise TypeError(
                f"{type(self).__name__} must set CFG to a model config class")
        df = self._frame(X)
        df["_y_"] = y
        data_cfg = DataConfig(
            target=["_y_"],
            continuous_cols=self.num,
            categorical_cols=self.cat
        )
        # NOTE(review): ``gpus`` and a boolean ``progress_bar`` follow an older
        # TrainerConfig API; confirm against the installed pytorch-tabular.
        trainer_cfg = TrainerConfig(
            max_epochs=50, gpus=0, progress_bar=False
        )
        # pytorch-tabular model configs have no default ``task``; building
        # one without it fails.
        self.tm = TabularModel(
            data_cfg, self.CFG(task="classification"), trainer_cfg,
            OptimizerConfig()
        )
        self.tm.fit(train=df)
        return self

    @staticmethod
    def _positive_proba(pred: pd.DataFrame) -> np.ndarray:
        """Extract P(y=1) from a pytorch-tabular prediction frame.

        Prefers the class-1 probability column (its name ends in
        "1_probability"; the exact prefix varies between versions) and only
        falls back to the hard "...prediction" label column.
        """
        for suffix in ("1_probability", "prediction"):
            cols = [c for c in pred.columns if str(c).endswith(suffix)]
            if cols:
                return pred[cols[-1]].to_numpy(dtype=float)
        raise KeyError(
            "pytorch-tabular output has no probability or prediction column")

    def predict_proba(self, X):
        """Return an (n, 2) array [[1 - p, p]] with p = P(y=1)."""
        p = self._positive_proba(self.tm.predict(self._frame(X)))
        return np.vstack([1 - p, p]).T

    def predict(self, X):
        """Threshold P(y=1) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)


class TabTransformer(_PTab):
    """pytorch-tabular TabTransformer behind the shared _PTab adapter."""
    CFG = TabTransformerConfig


class FTTransformer(_PTab):
    """pytorch-tabular FT-Transformer behind the shared _PTab adapter."""
    CFG = FTTransformerConfig


# ════════════════════  IN-FILE NARS IMPLEMENTATION  ════════════════════
# NARS hyper-parameters: maximum pattern-pool size and the number of pooled
# patterns consulted per query (half taken from each end of the pool).
P_POOL_SIZE = 100
N_PTRS = 10

# Positional column names of a tokenised record: token i is interpreted as
# ULTIMATE_COLS[i] by preprocess() (later positions become "COL<i>").
ULTIMATE_COLS = [
    "ID", "Manntal", "Nafn", "Fornafn", "Millinafn", "Eftirnafn",
    "Aettarnafn", "Faedingarar", "Kyn", "Stada", "Hjuskapur",
    "bi_einstaklingur", "bi_baer", "bi_hreppur", "bi_sokn", "bi_sysla",
    "cleaned_status", "uniqueness_score", "id_individual", "score",
]

# Columns that preprocess() skips entirely.
BAN_LIST = {
    "ID", "Millinafn", "Aettarnafn", "bi_hreppur", "bi_sokn",
    "bi_sysla", "uniqueness_score", "id_individual", "score",
}


class Truth:
    """Truth value with frequency ``f`` and confidence ``c``.

    ``c`` is clamped into [0.01, 0.99] on construction so the evidence
    weights below never divide by zero.
    """

    def __init__(self, f: float, c: float):
        self.f = f
        upper_capped = min(c, 0.99)
        self.c = max(upper_capped, 0.01)

    @property
    def wp(self):
        """Positive evidence weight, f * c / (1 - c)."""
        return self.f * self.c / (1 - self.c)

    @property
    def wn(self):
        """Negative evidence weight, (1 - f) * c / (1 - c)."""
        return (1 - self.f) * self.c / (1 - self.c)

    def revise(self, other: "Truth"):
        """Return a new Truth pooling the evidence of ``self`` and ``other``."""
        positive = self.wp + other.wp
        negative = self.wn + other.wn
        total = positive + negative
        return Truth(positive / total, total / (total + 1))

    @property
    def e(self):
        """Expectation c * (f - 0.5) + 0.5."""
        return self.c * (self.f - 0.5) + 0.5


class Pattern:
    """A set of symbolic statements paired with a Truth value.

    Hashing is order-independent over the statements; equality is the
    default identity comparison (no ``__eq__`` is defined).
    """

    def __init__(self, stmts: set[str], truth: Truth):
        self.statements = stmts
        self.truth = truth

    def __len__(self):
        return len(self.statements)

    def __hash__(self):
        return sum(map(hash, self.statements))

    # Convenience views onto the underlying Truth.
    @property
    def f(self):
        return self.truth.f

    @property
    def c(self):
        return self.truth.c

    @property
    def e(self):
        return self.truth.e

    def match(self, other: "Pattern"):
        """Compare with ``other`` and split both patterns into three parts.

        Returns ``((sim, conf_e), self_only, shared, other_only)``:
        ``sim`` is |shared| / max(len(self), len(other)) (denominator 1 when
        both are empty); ``conf_e`` is the larger of the two expectations;
        ``shared`` carries the revised (pooled) truth; each ``*_only``
        pattern keeps its owner's truth.
        """
        common = self.statements & other.statements
        denom = max(len(self), len(other)) or 1
        score = (len(common) / denom, max(self.truth.e, other.truth.e))
        self_only = Pattern(self.statements - common, self.truth)
        shared = Pattern(common, self.truth.revise(other.truth))
        other_only = Pattern(other.statements - common, other.truth)
        return score, self_only, shared, other_only


class PatternPool:
    """Bounded list of Patterns kept sorted by expectation ``e``, highest first."""

    def __init__(self, size: int = P_POOL_SIZE):
        self.size = size
        self.pool: list[Pattern] = []

    def add(self, p: Pattern):
        """Insert ``p``, deduplicating by statement set.

        If a pattern with the same statements is already pooled, it is
        replaced in place when ``p`` has strictly higher confidence;
        otherwise ``p`` is dropped. An in-place replacement does not re-sort
        the pool. A new pattern goes before the first entry with strictly
        lower ``e``. On overflow the middle entry is evicted, so the highest-
        and lowest-``e`` extremes survive.
        """
        for i, old in enumerate(self.pool):
            if old.statements == p.statements:
                if p.truth.c > old.truth.c:
                    self.pool[i] = p
                return
        for i, old in enumerate(self.pool):
            if p.e > old.e:
                self.pool.insert(i, p)
                break
        else:
            self.pool.append(p)
        if len(self.pool) > self.size:
            self.pool.pop(len(self.pool) // 2)

    def get_ptrs(self, n: int = N_PTRS):
        """Return up to ``n`` pooled patterns.

        The result holds the ``n // 2`` highest-``e`` and the ``n // 2``
        lowest-``e`` entries; entries in both halves appear once. When
        ``n < 2`` the result is empty.
        NOTE(review): set iteration order follows the string hashes of the
        statements (PYTHONHASHSEED). Truth.revise stops being associative
        once confidence hits its clamp, so callers folding these in
        iteration order may see small run-to-run differences.
        """
        half = max(n, 0) // 2
        if half == 0:
            # Guard: pool[-0:] would be the whole pool, not an empty tail.
            return set()
        return set(self.pool[:half] + self.pool[-half:])


def preprocess(r1: np.ndarray, r2: np.ndarray) -> Pattern:
    """Turn two tokenised records into a Pattern of comparison statements.

    Tokens are mapped to column names by position (ULTIMATE_COLS, then
    ``COL<i>``), and columns in BAN_LIST are ignored. Emitted statements:
      * ``year_diff_<d>`` for Manntal, with d = int(|value1 - value2|).
        Tokens that cannot be parsed as numbers, NaN or inf emit nothing.
      * ``same_<col>`` / ``diff_<col>`` (lower-cased) for Nafn, Fornafn,
        Eftirnafn, Faedingarar, Kyn and Hjuskapur.
    The pattern's truth is Truth(label, 0.9). The label is 1 when the two
    ``bi_einstaklingur`` tokens are equal, 0 when they differ, and 0.5 when
    that column is absent.

    NOTE(review): the label is derived from the records' own
    ``bi_einstaklingur`` values. match_ultimate uses this truth at
    prediction time too. If that column is the ground-truth individual id,
    predictions leak the label — confirm.
    """
    stmts: set[str] = set()
    label = -1
    for idx, (a, b) in enumerate(zip(r1, r2)):
        col = ULTIMATE_COLS[idx] if idx < len(ULTIMATE_COLS) else f"COL{idx}"
        if col in BAN_LIST:
            continue
        if col == "Manntal" and a and b:
            # Tokens may be bare values or "col=val" items (the format
            # Ditto._wrap assumes); parse the text after the last "=".
            try:
                gap = int(abs(float(str(a).rpartition("=")[2])
                              - float(str(b).rpartition("=")[2])))
            except (ValueError, OverflowError):
                gap = None  # non-numeric / NaN / inf: no year statement
            if gap is not None:
                stmts.add(f"year_diff_{gap}")
        elif col in ("Nafn", "Fornafn", "Eftirnafn", "Faedingarar", "Kyn", "Hjuskapur"):
            eq = "same" if a == b else "diff"
            stmts.add(f"{eq}_{col.lower()}")
        elif col == "bi_einstaklingur":
            label = 1 if a == b else 0
    return Pattern(stmts, Truth(label if label != -1 else 0.5, 0.9))


def match_ultimate(r1: np.ndarray, r2: np.ndarray,
                   pool: PatternPool, n_ptr: int,
                   just_eval: bool=False) -> float:
    """Score a record pair against the pool and optionally learn from it.

    ``preprocess`` turns the pair into a query Pattern. Each pattern from
    ``pool.get_ptrs(n_ptr)`` contributes Truth(conf_e, sim) taken from
    ``Pattern.match``. These are folded left-to-right with
    ``Truth.revise`` and the combined expectation is returned; with no
    pointers the score is 0.5. Unless ``just_eval`` is set, the non-empty
    split parts (query-only, shared, pointer-only) are added to the pool.
    When the pool yielded no pointers, the query itself is added instead.
    """
    query = preprocess(r1, r2)
    evidence = []
    for anchor in pool.get_ptrs(n_ptr):
        (similarity, expectation), query_only, shared, anchor_only = query.match(anchor)
        evidence.append(Truth(expectation, similarity))
        if just_eval:
            continue
        for piece in (query_only, shared, anchor_only):
            if piece.statements:
                pool.add(piece)

    if not evidence:
        if not just_eval:
            pool.add(query)
        return 0.5

    combined = evidence[0]
    for extra in evidence[1:]:
        combined = combined.revise(extra)
    return combined.e


class NARS(BaseEstimator, ClassifierMixin):
    """Neuro-Analogical Reasoning System (symbolic).

    Records are strings of " ; "-separated tokens. ``fit`` feeds up to
    ``train_pairs`` pairs through match_ultimate in learning mode to build a
    PatternPool. Prediction scores each pair against that pool without
    modifying it.

    NOTE(review): ``y`` is ignored — the only supervision is what
    preprocess() derives from the records themselves.
    """
    def __init__(self, train_pairs: int=3000):
        self.train_pairs = train_pairs
        self.pool = PatternPool()

    def _tok(self, s: str) -> np.ndarray:
        """Split a record on " ; " and strip whitespace from each token."""
        return np.array([tok.strip() for tok in s.split(" ; ")])

    def fit(self, Xpair: np.ndarray, y=None):
        """Learn a fresh pattern pool from the first ``train_pairs`` pairs.

        The pool is recreated on every call so refitting does not accumulate
        patterns from earlier training data (usual sklearn ``fit`` semantics).
        """
        self.pool = PatternPool()
        for a, b in Xpair[:self.train_pairs]:
            match_ultimate(self._tok(a), self._tok(b),
                           self.pool, N_PTRS, False)
        return self

    def _p(self, a: str, b: str) -> float:
        """Match score for one pair; the pool is not modified."""
        return match_ultimate(self._tok(a), self._tok(b),
                              self.pool, N_PTRS, True)

    def predict_proba(self, Xpair: np.ndarray):
        """Return an (n, 2) array [[1 - p, p]] of per-pair match scores."""
        p = np.fromiter((self._p(a, b) for a, b in Xpair),
                        float, len(Xpair))
        return np.vstack([1 - p, p]).T

    def predict(self, Xpair: np.ndarray):
        """Threshold the match score at 0.5."""
        return (self.predict_proba(Xpair)[:, 1] >= 0.5).astype(int)


# ════════════════════  Minimal Ditto  ════════════════════════════
class Ditto(BaseEstimator, ClassifierMixin):
    """
    Simplified Ditto: pair-wise transformer with CE + margin loss.

    Each record is a " ; "-separated string (assumed "col=val" items). A pair
    is rendered as "<wrapped a> [SEP] <wrapped b>" and classified by a
    2-label sequence-classification head, where column 1 means "match".
    """
    def __init__(self,
                 plm: str="microsoft/deberta-v3-small",
                 max_len: int=256,
                 lr: float=2e-5,
                 margin: float=0.4,
                 bs: int=16,
                 epochs: int=3):
        # Constructor arguments kept under their own names for sklearn
        # get_params()/clone().
        self.plm, self.max_len = plm, max_len
        self.bs, self.epochs = bs, epochs
        self.lr, self.margin = lr, margin
        self.tok = AutoTokenizer.from_pretrained(plm)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            plm, num_labels=2).to(DEVICE)

    @staticmethod
    def _wrap(rec: str) -> str:
        """Prefix every " ; "-separated item of ``rec`` with "[COL] "."""
        return " ".join(f"[COL] {kv}" for kv in rec.split(" ; "))

    def _encode(self, pairs, device=None):
        """Tokenise pairs (truncated to ``max_len``) onto ``device`` (default DEVICE)."""
        texts = [f"{self._wrap(a)} [SEP] {self._wrap(b)}" for a, b in pairs]
        enc = self.tok(texts, truncation=True, padding=True,
                       max_length=self.max_len, return_tensors="pt")
        return enc.to(DEVICE if device is None else device)

    def _margin_loss(self, logits, lab):
        """Hinge pushing each positive's match logit above the hardest negative.

        The margin is ``self.margin``. Returns 0 when the batch has no
        positives, because ``.mean()`` of an empty tensor is NaN.
        """
        pos = logits[lab == 1][:, 1]
        if pos.numel() == 0:
            return logits.new_zeros(())
        neg = logits[lab == 0][:, 1]
        hard_neg = neg.max() if neg.numel() > 0 else 0.0
        return torch.relu(self.margin - pos + hard_neg).mean()

    def fit(self, Xpair: np.ndarray, y: np.ndarray):
        """Fine-tune for ``epochs`` epochs with AdamW on CE + margin loss."""
        # The dataset stays on the CPU and each batch is moved to DEVICE.
        enc = self._encode(Xpair, device="cpu")
        # CrossEntropyLoss needs int64 targets (y may arrive as e.g. int8).
        labels = torch.as_tensor(np.asarray(y), dtype=torch.long)
        ds = TensorDataset(enc["input_ids"], enc["attention_mask"], labels)
        loader = DataLoader(ds, batch_size=self.bs, shuffle=True,
                            pin_memory=DEVICE.type == "cuda")
        opt = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        ce = nn.CrossEntropyLoss()

        # Hugging Face from_pretrained() returns the model in eval mode;
        # switch to train mode so dropout is active while fine-tuning.
        self.model.train()
        for _ in range(self.epochs):
            for ids, att, lab in loader:
                ids, att, lab = ids.to(DEVICE), att.to(DEVICE), lab.to(DEVICE)
                opt.zero_grad()
                logits = self.model(input_ids=ids, attention_mask=att).logits
                loss = ce(logits, lab) + self._margin_loss(logits, lab)
                loss.backward()
                opt.step()
        self.model.eval()
        return self

    @torch.no_grad()
    def predict_proba(self, Xpair: np.ndarray):
        """Return [[1 - p, p]] with p = softmax match probability.

        Pairs are scored in batches of ``bs`` to bound memory. Inputs are the
        same ids and attention mask the model was trained on.
        """
        self.model.eval()
        step = max(1, int(self.bs))
        chunks = []
        for start in range(0, len(Xpair), step):
            enc = self._encode(Xpair[start:start + step])
            logits = self.model(input_ids=enc["input_ids"],
                                attention_mask=enc["attention_mask"]).logits
            chunks.append(logits.softmax(-1)[:, 1].float().cpu().numpy())
        p = np.concatenate(chunks) if chunks else np.empty(0)
        return np.vstack([1 - p, p]).T

    def predict(self, Xpair: np.ndarray):
        """Threshold the match probability at 0.5."""
        return (self.predict_proba(Xpair)[:, 1] >= 0.5).astype(int)


# ════════════════════  Minimal SAINT  ════════════════════════════
class SAINTBlock(nn.Module):
    """Transformer encoder over per-column tokens with a learned CLS token.

    Each scalar feature x[:, j] becomes a token by adding it to learned
    column embedding j. A CLS token is prepended, and its final hidden state
    is projected to one logit per row.

    NOTE(review): the stack is ``2 * num_layers`` identical
    TransformerEncoderLayers, all attending across the columns of a single
    row. The layers originally labelled "row-wise" do not attend across
    samples (no SAINT-style intersample attention) — confirm intent.
    """
    def __init__(self, num_cols: int, d_model=64, nhead=4,
                 d_ff=128, num_layers=3, dropout=0.1):
        super().__init__()
        self.col_emb = nn.Embedding(num_cols, d_model)
        self.row_cls = nn.Parameter(torch.randn(1, 1, d_model))

        def encoder_layer():
            return nn.TransformerEncoderLayer(
                d_model, nhead, dim_feedforward=d_ff,
                dropout=dropout, batch_first=True
            )

        # Two layers per block, created in order (parameter init unchanged).
        self.tr = nn.Sequential(*(encoder_layer() for _ in range(2 * num_layers)))
        self.out = nn.Linear(d_model, 1)

    def forward(self, x):
        """Map a (batch, n_cols) float tensor to (batch,) logits."""
        batch, n_cols = x.shape
        tokens = x.unsqueeze(-1) + self.col_emb.weight[:n_cols]   # (B, C, D)
        cls = self.row_cls.expand(batch, -1, -1)                  # (B, 1, D)
        hidden = self.tr(torch.cat([cls, tokens], dim=1))         # (B, C+1, D)
        return self.out(hidden[:, 0]).squeeze(-1)


class Saint(BaseEstimator, ClassifierMixin):
    """
    SAINT-style tabular ER model: trains a SAINTBlock with BCE-with-logits.

    X is a dense (n_samples, num_cols) numeric array; y holds 0/1 labels.
    """
    def __init__(self, num_cols: int, epochs=80, lr=1e-3,
                 d_model=64, nhead=4, d_ff=128, num_layers=3, dropout=0.1):
        # Every constructor argument is stored under its own name so
        # sklearn's get_params()/repr()/clone() work.
        self.num_cols = num_cols
        self.d_model, self.nhead, self.d_ff = d_model, nhead, d_ff
        self.num_layers, self.dropout = num_layers, dropout
        self.device = DEVICE
        self.net = SAINTBlock(
            num_cols, d_model, nhead, d_ff, num_layers, dropout
        ).to(self.device)
        self.epochs = epochs
        self.lr = lr
        self.loss_fn = nn.BCEWithLogitsLoss()

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Train for ``epochs`` epochs with AdamW on shuffled batches of 1024."""
        # The dataset stays on the CPU; each batch is moved to the device.
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        y_t = torch.as_tensor(np.asarray(y).reshape(-1, 1), dtype=torch.float32)
        loader = DataLoader(TensorDataset(X_t, y_t), batch_size=1024,
                            shuffle=True, pin_memory=self.device.type == "cuda")
        opt = torch.optim.AdamW(self.net.parameters(), lr=self.lr)

        self.net.train()
        for _ in range(self.epochs):
            for xb, yb in loader:
                xb, yb = xb.to(self.device), yb.to(self.device)
                opt.zero_grad()
                loss = self.loss_fn(self.net(xb).unsqueeze(-1), yb)
                loss.backward()
                opt.step()
        self.net.eval()
        return self

    @torch.no_grad()
    def predict_proba(self, X: np.ndarray, batch_size: int = 4096):
        """Return [[1 - p, p]] with p = sigmoid(logit), scored in batches."""
        # Without eval(), dropout from the training mode makes predictions
        # random.
        self.net.eval()
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        probs = [self.net(xb.to(self.device)).sigmoid().cpu()
                 for xb in torch.split(X_t, max(1, int(batch_size)))]
        p = torch.cat(probs).numpy() if probs else np.empty(0, dtype=np.float32)
        return np.vstack([1 - p, p]).T

    def predict(self, X: np.ndarray):
        """Threshold P(match) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)


# ════════════════════  HFZeroShot ─═══════════════════════════
class HFZeroShot(BaseEstimator, ClassifierMixin):
    """
    Zero-shot entity match via Llama-3 8B-Instruct GPTQ-4bit.

    Each pair is put into a "Same person? Yes or No." prompt. p is 1.0 when
    the generated continuation contains yes / sí / ja (case-insensitive
    whole word), otherwise 0.0.
    """
    def __init__(self,
                 repo_id="astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit",
                 max_new=1,
                 bs=8):
        # Constructor arguments kept under their own names for get_params().
        self.repo_id, self.max_new, self.bs = repo_id, max_new, bs
        # NOTE(review): the default repo is already GPTQ-quantised;
        # transformers may warn about or ignore this bitsandbytes config in
        # favour of the checkpoint's own quantization_config. Verify.
        cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4"
        )
        self.tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
        # Llama-3 tokenizers define no pad token, so padding=True would
        # raise. Decoder-only batch generation also needs left padding so
        # every prompt ends where generation starts.
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
        self.tok.padding_side = "left"
        self.llm = AutoModelForCausalLM.from_pretrained(
            repo_id, device_map="auto", quantization_config=cfg
        ).eval()
        self.yes = re.compile(r"\b(yes|sí|ja)\b", re.I)

    def fit(self, X, y=None):
        """No-op: the model is used zero-shot."""
        return self

    @torch.no_grad()
    def predict_proba(self, Xpair: np.ndarray):
        """Return [[1 - p, p]] with p in {0.0, 1.0}, generated ``bs`` prompts at a time."""
        prompts = [
            f"Same person? Yes or No.\nA:{a}\nB:{b}\nAnswer:"
            for a, b in Xpair
        ]
        step = max(1, int(self.bs))
        answers = []
        for start in range(0, len(prompts), step):
            batch = self.tok(
                prompts[start:start + step], return_tensors="pt",
                padding=True, truncation=True, max_length=512
            ).to(self.llm.device)
            # Greedy decoding; a temperature is meaningless without sampling.
            gen = self.llm.generate(
                **batch, max_new_tokens=self.max_new, do_sample=False,
                pad_token_id=self.tok.pad_token_id
            )
            answers.extend(self.tok.batch_decode(
                gen[:, batch["input_ids"].shape[1]:],
                skip_special_tokens=True
            ))
        p = np.array([1.0 if self.yes.search(o) else 0.0 for o in answers],
                     dtype=float)
        return np.vstack([1 - p, p]).T

    def predict(self, Xpair: np.ndarray):
        """Return 1 where the model answered yes, else 0."""
        return (self.predict_proba(Xpair)[:, 1] >= 0.5).astype(int)


# ═════════════════════ registry helper ═════════════════════
def _unavail(pkg):
    raise ImportError(f"Please pip install {pkg} to use this model.")

def get_model(
    key: str,
    cat_cols: list[str] | None = None,
    num_cols: list[str] | None = None
):
    """Build a fresh, unfitted model for the registry key ``key``.

    Keys are case-insensitive, and "hf_llm" / "llm" are aliases.
    ``num_cols`` is required for "saint", which uses only its length. The
    PyTorch-Tabular models receive ``cat_cols`` / ``num_cols`` with None
    replaced by [].

    Raises:
        ValueError: unknown key, or "saint" requested without ``num_cols``.
    """
    name = key.lower()
    if name == "saint":
        if num_cols is None:
            raise ValueError("SAINT requires `num_cols` length")
        return Saint(num_cols=len(num_cols))

    # Lambdas defer class lookup and construction (and any model
    # download) until the key is actually selected.
    builders = {
        "logreg": lambda: LogReg(max_iter=2000, n_jobs=-1),
        "tabnet": lambda: TabNet(),
        "tabpfn": lambda: TabPFN(),
        "tabtransformer": lambda: TabTransformer(cat_cols or [], num_cols or []),
        "fttransformer": lambda: FTTransformer(cat_cols or [], num_cols or []),
        "nars": lambda: NARS(),
        "ditto": lambda: Ditto(),
        "hf_llm": lambda: HFZeroShot(),
        "llm": lambda: HFZeroShot(),
    }
    build = builders.get(name)
    if build is None:
        raise ValueError(f"Unknown model key: {key}")
    return build()


In [None]:
from __future__ import annotations

import itertools, logging, multiprocessing, pathlib, random, warnings, gc, os, re
from typing import List, Tuple

import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score)
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn

# ───────────────────────── constants ──────────────────────────
# Hardware: CUDA when available; torch intra-op CPU threads capped at 4.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_num_threads(4)                       # keep BLAS polite
NUM_WORKERS = multiprocessing.cpu_count()

# Filesystem locations, relative to the working directory.
ART_DIR = pathlib.Path("artifacts")
DATA_DIR = pathlib.Path("raw_data")

# Experiment knobs.
THR = 0.5
NEG_RATIO = 1
N_RUNS = 10

# Pair-sampling limits.
MAX_PAIRS_PER_SPLIT = 40_000   # hard cap per train / test split
BATCH_SIZE_MATRIX = 4_000      # rows per sparse-diff batch

# Model keys for the experiment loop. NOTE(review): "attendem" has no
# get_model() branch and raises ValueError there.
MODELS = [
    "logreg", "tabnet", "tabpfn", "saint",
    "tabtransformer", "fttransformer",
    "nars", "ditto", "attendem", "hf_llm",
]

# ───────────────────── baseline logistic reg ──────────────────────
class LogReg(BaseEstimator, ClassifierMixin):
    """Wrapper around sklearn.linear_model.LogisticRegression.

    Constructor arguments are also stored under their own names because
    ``BaseEstimator.get_params`` (used by ``repr``, ``clone`` and grid
    search) reads every ``__init__`` parameter back with ``getattr``; without
    these attributes even printing the estimator raises AttributeError.
    Extra ``**kw`` are forwarded to LogisticRegression but, being varargs,
    are not reported by ``get_params``.
    """
    def __init__(self, penalty="l2", C=1.0, solver="lbfgs",
                 max_iter=2000, n_jobs=-1, **kw):
        self.penalty = penalty
        self.C = C
        self.solver = solver
        self.max_iter = max_iter
        self.n_jobs = n_jobs
        self._lr = LogisticRegression(
            penalty=penalty, C=C, solver=solver,
            max_iter=max_iter, n_jobs=n_jobs, **kw)

    def fit(self, X, y):
        """Fit the wrapped LogisticRegression and return ``self``."""
        self._lr.fit(X, y)
        return self

    def predict(self, X):
        """Hard class predictions of the wrapped model."""
        return self._lr.predict(X)

    def predict_proba(self, X):
        """Class probabilities of the wrapped model."""
        return self._lr.predict_proba(X)

# ─────────────────────── tabular DL baselines ─────────────────────
class TabNet(BaseEstimator, ClassifierMixin):
    """sklearn-style wrapper around pytorch-tabnet's TabNetClassifier."""
    def __init__(self):
        # ``verbose`` is a TabNetClassifier constructor option; ``fit`` does
        # not accept it.
        self.tab = TabNetClassifier(verbose=0)

    def fit(self, X, y):
        """Train for up to 200 epochs with early stopping (patience 20).

        TabNetClassifier.fit names its training inputs ``X_train`` /
        ``y_train``; passing ``X=`` / ``y=`` raises TypeError.
        NOTE(review): the eval set is the training data itself, so early
        stopping tracks training performance rather than generalisation.
        """
        self.tab.fit(X_train=X, y_train=y, eval_set=[(X, y)],
                     max_epochs=200, patience=20)
        return self

    def predict_proba(self, X):
        """Class probabilities from the fitted TabNet model."""
        return self.tab.predict_proba(X)

    def predict(self, X):
        """Hard class predictions from the fitted TabNet model."""
        return self.tab.predict(X)

class TabPFN(TabPFNClassifier):
    """TabPFNClassifier on the GPU when available, with 32 ensemble configs.

    NOTE(review): ``N_ensemble_configurations`` is the keyword of the
    original (0.x) tabpfn API; confirm it matches the installed version.
    """
    def __init__(self):
        use_cuda = torch.cuda.is_available()
        super().__init__(
            device="cuda" if use_cuda else "cpu",
            N_ensemble_configurations=32,
        )

# ────────────────────── PyTorch-Tabular wrappers ──────────────────
class _PTab(BaseEstimator, ClassifierMixin):
    """Shared sklearn-style adapter for pytorch-tabular models.

    Subclasses set ``CFG`` to a pytorch-tabular model-config class. Input
    arrays are rebuilt into a DataFrame whose columns are ``num_cols``
    followed by ``cat_cols``, so X must use that column order.
    """
    CFG = None

    def __init__(self, cat_cols: list[str], num_cols: list[str]):
        # Stored under the constructor names (for sklearn get_params/clone)
        # and under the short names used by the rest of the class.
        self.cat_cols, self.num_cols = cat_cols, num_cols
        self.cat, self.num = cat_cols, num_cols

    def _frame(self, X) -> pd.DataFrame:
        """Wrap X in a DataFrame with numeric then categorical column names."""
        return pd.DataFrame(X, columns=self.num + self.cat)

    def fit(self, X, y):
        """Train the configured pytorch-tabular model on (X, y) for 50 epochs."""
        if self.CFG is None:
            raise TypeError(
                f"{type(self).__name__} must set CFG to a model config class")
        df = self._frame(X)
        df["_y_"] = y
        data_cfg = DataConfig(
            target=["_y_"],
            continuous_cols=self.num,
            categorical_cols=self.cat
        )
        # NOTE(review): ``gpus`` and a boolean ``progress_bar`` follow an older
        # TrainerConfig API; confirm against the installed pytorch-tabular.
        trainer_cfg = TrainerConfig(
            max_epochs=50, gpus=0, progress_bar=False
        )
        # pytorch-tabular model configs have no default ``task``; building
        # one without it fails.
        self.tm = TabularModel(
            data_cfg, self.CFG(task="classification"), trainer_cfg,
            OptimizerConfig()
        )
        self.tm.fit(train=df)
        return self

    @staticmethod
    def _positive_proba(pred: pd.DataFrame) -> np.ndarray:
        """Extract P(y=1) from a pytorch-tabular prediction frame.

        Prefers the class-1 probability column (its name ends in
        "1_probability"; the exact prefix varies between versions) and only
        falls back to the hard "...prediction" label column.
        """
        for suffix in ("1_probability", "prediction"):
            cols = [c for c in pred.columns if str(c).endswith(suffix)]
            if cols:
                return pred[cols[-1]].to_numpy(dtype=float)
        raise KeyError(
            "pytorch-tabular output has no probability or prediction column")

    def predict_proba(self, X):
        """Return an (n, 2) array [[1 - p, p]] with p = P(y=1)."""
        p = self._positive_proba(self.tm.predict(self._frame(X)))
        return np.vstack([1 - p, p]).T

    def predict(self, X):
        """Threshold P(y=1) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

class TabTransformer(_PTab):
    """pytorch-tabular TabTransformer behind the shared _PTab adapter."""
    CFG = TabTransformerConfig
class FTTransformer(_PTab):
    """pytorch-tabular FT-Transformer behind the shared _PTab adapter."""
    CFG = FTTransformerConfig

# ───────────────────────── NARS (symbolic) ────────────────────────
# NARS hyper-parameters: maximum pattern-pool size and the number of pooled
# patterns consulted per query (half taken from each end of the pool).
P_POOL_SIZE = 100
N_PTRS = 10

# Positional column names of a tokenised record: token i is interpreted as
# ULTIMATE_COLS[i] by preprocess() (later positions become "COL<i>").
ULTIMATE_COLS = [
    "ID", "Manntal", "Nafn", "Fornafn", "Millinafn", "Eftirnafn",
    "Aettarnafn", "Faedingarar", "Kyn", "Stada", "Hjuskapur",
    "bi_einstaklingur", "bi_baer", "bi_hreppur", "bi_sokn", "bi_sysla",
    "cleaned_status", "uniqueness_score", "id_individual", "score",
]

# Columns that preprocess() skips entirely.
BAN_LIST = {
    "ID", "Millinafn", "Aettarnafn", "bi_hreppur", "bi_sokn",
    "bi_sysla", "uniqueness_score", "id_individual", "score",
}

class Truth:
    """Truth value with frequency ``f`` and confidence ``c``.

    ``c`` is clamped into [0.01, 0.99] on construction so the evidence
    weights below never divide by zero.
    """

    def __init__(self, f: float, c: float):
        self.f = f
        upper_capped = min(c, 0.99)
        self.c = max(upper_capped, 0.01)

    @property
    def wp(self):
        """Positive evidence weight, f * c / (1 - c)."""
        return self.f * self.c / (1 - self.c)

    @property
    def wn(self):
        """Negative evidence weight, (1 - f) * c / (1 - c)."""
        return (1 - self.f) * self.c / (1 - self.c)

    def revise(self, o: "Truth"):
        """Return a new Truth pooling the evidence of ``self`` and ``o``."""
        positive = self.wp + o.wp
        negative = self.wn + o.wn
        total = positive + negative
        return Truth(positive / total, total / (total + 1))

    @property
    def e(self):
        """Expectation c * (f - 0.5) + 0.5."""
        return self.c * (self.f - 0.5) + 0.5

class Pattern:
    """A set of symbolic statements paired with a Truth value.

    Hashing is order-independent over the statements; equality is the
    default identity comparison (no ``__eq__`` is defined). The ``f`` / ``c``
    / ``e`` views match the other Pattern definition in this file, whose
    PatternPool reads ``p.c``.
    """

    def __init__(self, stmts: set[str], truth: Truth):
        self.statements = stmts
        self.truth = truth

    def __len__(self):
        return len(self.statements)

    def __hash__(self):
        return sum(map(hash, self.statements))

    # Convenience views onto the underlying Truth.
    @property
    def f(self):
        return self.truth.f

    @property
    def c(self):
        return self.truth.c

    @property
    def e(self):
        return self.truth.e

    def match(self, o: "Pattern"):
        """Compare with ``o`` and split both patterns into three parts.

        Returns ``((sim, conf_e), self_only, shared, o_only)``:
        ``sim`` is |shared| / max(len(self), len(o)) (denominator 1 when both
        are empty); ``conf_e`` is the larger of the two expectations;
        ``shared`` carries the revised (pooled) truth; each ``*_only``
        pattern keeps its owner's truth.
        """
        common = self.statements & o.statements
        denom = max(len(self), len(o)) or 1
        score = (len(common) / denom, max(self.truth.e, o.truth.e))
        self_only = Pattern(self.statements - common, self.truth)
        shared = Pattern(common, self.truth.revise(o.truth))
        o_only = Pattern(o.statements - common, o.truth)
        return score, self_only, shared, o_only

class PatternPool:
    """Bounded list of Patterns kept sorted by expectation ``e``, highest first."""

    def __init__(self, size: int = P_POOL_SIZE):
        self.size = size
        self.pool: list[Pattern] = []

    def add(self, p: Pattern):
        """Insert ``p``, deduplicating by statement set.

        If a pattern with the same statements is already pooled, it is
        replaced in place when ``p`` has strictly higher confidence;
        otherwise ``p`` is dropped, so no duplicate is ever inserted. An
        in-place replacement does not re-sort the pool. A new pattern goes
        before the first entry with strictly lower ``e``. On overflow the
        middle entry is evicted, so the highest- and lowest-``e`` extremes
        survive.
        """
        for i, old in enumerate(self.pool):
            if old.statements == p.statements:
                if p.truth.c > old.truth.c:
                    self.pool[i] = p
                return
        for i, old in enumerate(self.pool):
            if p.e > old.e:
                self.pool.insert(i, p)
                break
        else:
            self.pool.append(p)
        if len(self.pool) > self.size:
            self.pool.pop(len(self.pool) // 2)

    def get_ptrs(self, n: int = N_PTRS):
        """Return up to ``n`` pooled patterns.

        The result holds the ``n // 2`` highest-``e`` and the ``n // 2``
        lowest-``e`` entries; entries in both halves appear once. When
        ``n < 2`` the result is empty.
        NOTE(review): set iteration order follows the string hashes of the
        statements (PYTHONHASHSEED). Truth.revise stops being associative
        once confidence hits its clamp, so callers folding these in
        iteration order may see small run-to-run differences.
        """
        half = max(n, 0) // 2
        if half == 0:
            # Guard: pool[-0:] would be the whole pool, not an empty tail.
            return set()
        return set(self.pool[:half] + self.pool[-half:])

def preprocess(r1: np.ndarray, r2: np.ndarray) -> Pattern:
    """Turn two tokenised records into a Pattern of comparison statements.

    Tokens are mapped to column names by position (ULTIMATE_COLS, then
    ``COL<i>``), and columns in BAN_LIST are ignored. Emitted statements:
      * ``year_diff_<d>`` for Manntal, with d = int(|value1 - value2|).
        Tokens that cannot be parsed as numbers, NaN or inf emit nothing.
      * ``same_<col>`` / ``diff_<col>`` (lower-cased) for Nafn, Fornafn,
        Eftirnafn, Faedingarar, Kyn and Hjuskapur.
    The pattern's truth is Truth(label, 0.9). The label is 1 when the two
    ``bi_einstaklingur`` tokens are equal, 0 when they differ, and 0.5 when
    that column is absent.

    NOTE(review): the label is derived from the records' own
    ``bi_einstaklingur`` values. match_ultimate uses this truth at
    prediction time too. If that column is the ground-truth individual id,
    predictions leak the label — confirm.
    """
    stmts: set[str] = set()
    label = -1
    for idx, (a, b) in enumerate(zip(r1, r2)):
        col = ULTIMATE_COLS[idx] if idx < len(ULTIMATE_COLS) else f"COL{idx}"
        if col in BAN_LIST:
            continue
        if col == "Manntal" and a and b:
            # Tokens may be bare values or "col=val" items (the format
            # Ditto._wrap assumes); parse the text after the last "=".
            try:
                gap = int(abs(float(str(a).rpartition("=")[2])
                              - float(str(b).rpartition("=")[2])))
            except (ValueError, OverflowError):
                gap = None  # non-numeric / NaN / inf: no year statement
            if gap is not None:
                stmts.add(f"year_diff_{gap}")
        elif col in ("Nafn", "Fornafn", "Eftirnafn", "Faedingarar", "Kyn", "Hjuskapur"):
            eq = "same" if a == b else "diff"
            stmts.add(f"{eq}_{col.lower()}")
        elif col == "bi_einstaklingur":
            label = 1 if a == b else 0
    return Pattern(stmts, Truth(label if label != -1 else 0.5, 0.9))

def match_ultimate(r1, r2, pool: PatternPool, n_ptr: int,
                   just_eval: bool = False) -> float:
    """Score a record pair against the pool and optionally learn from it.

    ``preprocess`` turns the pair into a query Pattern. Each pattern from
    ``pool.get_ptrs(n_ptr)`` contributes Truth(conf_e, sim) taken from
    ``Pattern.match``. These are folded left-to-right with
    ``Truth.revise`` and the combined expectation is returned; with no
    pointers the score is 0.5. Unless ``just_eval`` is set, the non-empty
    split parts (query-only, shared, pointer-only) are added to the pool.
    When the pool yielded no pointers, the query itself is added instead.
    """
    query = preprocess(r1, r2)
    evidence = []
    for anchor in pool.get_ptrs(n_ptr):
        (similarity, expectation), query_only, shared, anchor_only = query.match(anchor)
        evidence.append(Truth(expectation, similarity))
        if just_eval:
            continue
        for piece in (query_only, shared, anchor_only):
            if piece.statements:
                pool.add(piece)

    if not evidence:
        if not just_eval:
            pool.add(query)
        return 0.5

    combined = evidence[0]
    for extra in evidence[1:]:
        combined = combined.revise(extra)
    return combined.e

class NARS(BaseEstimator, ClassifierMixin):
    """Neuro-Analogical Reasoning System (symbolic).

    Records are strings of " ; "-separated tokens. ``fit`` feeds up to
    ``train_pairs`` pairs through match_ultimate in learning mode to build a
    PatternPool. Prediction scores each pair against that pool without
    modifying it.

    NOTE(review): ``y`` is ignored — the only supervision is what
    preprocess() derives from the records themselves.
    """
    def __init__(self, train_pairs: int = 3000):
        self.train_pairs = train_pairs
        self.pool = PatternPool()

    def _tok(self, s: str):
        """Split a record on " ; " and strip whitespace from each token."""
        return np.array([t.strip() for t in s.split(" ; ")])

    def fit(self, Xpair, y=None):
        """Learn a fresh pattern pool from the first ``train_pairs`` pairs.

        The pool is recreated on every call so refitting does not accumulate
        patterns from earlier training data (usual sklearn ``fit`` semantics).
        """
        self.pool = PatternPool()
        for a, b in Xpair[:self.train_pairs]:
            match_ultimate(self._tok(a), self._tok(b), self.pool, N_PTRS, False)
        return self

    def _p(self, a: str, b: str):
        """Match score for one pair; the pool is not modified."""
        return match_ultimate(self._tok(a), self._tok(b), self.pool, N_PTRS, True)

    def predict_proba(self, Xpair):
        """Return an (n, 2) array [[1 - p, p]] of per-pair match scores."""
        p = np.fromiter((self._p(a, b) for a, b in Xpair), float, len(Xpair))
        return np.vstack([1 - p, p]).T

    def predict(self, Xpair):
        """Threshold the match score at 0.5."""
        return (self.predict_proba(Xpair)[:, 1] >= 0.5).astype(int)

# ─────────────────────────── Ditto  (pair PLM) ────────────────────
class Ditto(BaseEstimator, ClassifierMixin):
    """
    Simplified Ditto: pair-wise transformer with CE + margin loss.

    Each record is a " ; "-separated string (assumed "col=val" items). A pair
    is rendered as "<wrapped a> [SEP] <wrapped b>" and classified by a
    2-label sequence-classification head, where column 1 means "match".
    """
    def __init__(self, plm="microsoft/deberta-v3-small", max_len=256, lr=2e-5,
                 margin=0.4, bs=16, epochs=3):
        # Constructor arguments kept under their own names for sklearn
        # get_params()/clone().
        self.plm, self.max_len = plm, max_len
        self.bs, self.epochs, self.lr, self.margin = bs, epochs, lr, margin
        self.tok = AutoTokenizer.from_pretrained(plm)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            plm, num_labels=2).to(DEVICE)

    @staticmethod
    def _wrap(rec: str):
        """Prefix every " ; "-separated item of ``rec`` with "[COL] "."""
        return " ".join(f"[COL] {kv}" for kv in rec.split(" ; "))

    def _encode(self, pairs, device=None):
        """Tokenise pairs (truncated to ``max_len``) onto ``device`` (default DEVICE)."""
        texts = [f"{self._wrap(a)} [SEP] {self._wrap(b)}" for a, b in pairs]
        enc = self.tok(texts, truncation=True, padding=True,
                       max_length=self.max_len, return_tensors="pt")
        return enc.to(DEVICE if device is None else device)

    def _margin_loss(self, logits, lab):
        """Hinge pushing each positive's match logit above the hardest negative.

        The margin is ``self.margin``. Returns 0 when the batch has no
        positives, because ``.mean()`` of an empty tensor is NaN.
        """
        pos = logits[lab == 1][:, 1]
        if pos.numel() == 0:
            return logits.new_zeros(())
        neg = logits[lab == 0][:, 1]
        hard_neg = neg.max() if neg.numel() > 0 else 0.0
        return torch.relu(self.margin - pos + hard_neg).mean()

    def fit(self, Xpair, y):
        """Fine-tune for ``epochs`` epochs with AdamW on CE + margin loss."""
        # Keep the dataset on the CPU: DataLoader cannot pin CUDA tensors,
        # so tensors already on the GPU with pin_memory=True would crash.
        # Batches are moved to DEVICE inside the loop. The tensors live in
        # memory, so no worker processes are needed.
        enc = self._encode(Xpair, device="cpu")
        # CrossEntropyLoss needs int64 targets (y may arrive as e.g. int8).
        labels = torch.as_tensor(np.asarray(y), dtype=torch.long)
        ds = TensorDataset(enc["input_ids"], enc["attention_mask"], labels)
        loader = DataLoader(ds, batch_size=self.bs, shuffle=True,
                            pin_memory=DEVICE.type == "cuda")
        opt = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        ce = nn.CrossEntropyLoss()

        # Hugging Face from_pretrained() returns the model in eval mode;
        # switch to train mode so dropout is active while fine-tuning.
        self.model.train()
        for _ in range(self.epochs):
            for ids, att, lab in loader:
                ids, att, lab = ids.to(DEVICE), att.to(DEVICE), lab.to(DEVICE)
                opt.zero_grad()
                logits = self.model(input_ids=ids, attention_mask=att).logits
                loss = ce(logits, lab) + self._margin_loss(logits, lab)
                loss.backward()
                opt.step()
        self.model.eval()
        return self

    @torch.no_grad()
    def predict_proba(self, Xpair):
        """Return [[1 - p, p]] with p = softmax match probability.

        Pairs are scored in batches of ``bs`` to bound memory. Inputs are the
        same ids and attention mask the model was trained on.
        """
        self.model.eval()
        step = max(1, int(self.bs))
        chunks = []
        for start in range(0, len(Xpair), step):
            enc = self._encode(Xpair[start:start + step])
            logits = self.model(input_ids=enc["input_ids"],
                                attention_mask=enc["attention_mask"]).logits
            chunks.append(logits.softmax(-1)[:, 1].float().cpu().numpy())
        p = np.concatenate(chunks) if chunks else np.empty(0)
        return np.vstack([1 - p, p]).T

    def predict(self, Xpair):
        """Threshold the match probability at 0.5."""
        return (self.predict_proba(Xpair)[:, 1] >= 0.5).astype(int)

# ───────────────────────────── SAINT ──────────────────────────────
class SAINTBlock(nn.Module):
    """Transformer encoder over per-column tokens with a learned CLS token.

    Each scalar feature x[:, j] becomes a token by adding it to learned
    column embedding j. A CLS token is prepended, and its final hidden state
    is projected to one logit per row.

    NOTE(review): the stack is ``2 * num_layers`` identical
    TransformerEncoderLayers, all attending across the columns of a single
    row. There is no SAINT-style intersample (row) attention — confirm
    intent.
    """
    def __init__(self, num_cols: int, d_model=64, nhead=4, d_ff=128,
                 num_layers=3, dropout=0.1):
        super().__init__()
        self.col_emb = nn.Embedding(num_cols, d_model)
        self.row_cls = nn.Parameter(torch.randn(1, 1, d_model))

        def encoder_layer():
            return nn.TransformerEncoderLayer(
                d_model, nhead, dim_feedforward=d_ff,
                dropout=dropout, batch_first=True
            )

        # Two layers per block, created in order (parameter init unchanged).
        self.tr = nn.Sequential(*(encoder_layer() for _ in range(2 * num_layers)))
        self.out = nn.Linear(d_model, 1)

    def forward(self, x):
        """Map a (batch, n_cols) float tensor to (batch,) logits."""
        batch, n_cols = x.shape
        tokens = x.unsqueeze(-1) + self.col_emb.weight[:n_cols]   # (B, C, D)
        cls = self.row_cls.expand(batch, -1, -1)                  # (B, 1, D)
        hidden = self.tr(torch.cat([cls, tokens], dim=1))         # (B, C+1, D)
        return self.out(hidden[:, 0]).squeeze(-1)

class Saint(BaseEstimator, ClassifierMixin):
    """SAINT-style tabular ER model: trains a SAINTBlock with BCE-with-logits.

    X is a dense (n_samples, num_cols) numeric array; y holds 0/1 labels.
    """
    def __init__(self, num_cols: int, epochs=80, lr=1e-3, d_model=64,
                 nhead=4, d_ff=128, num_layers=3, dropout=0.1):
        # Every constructor argument is stored under its own name so
        # sklearn's get_params()/repr()/clone() work.
        self.num_cols = num_cols
        self.d_model, self.nhead, self.d_ff = d_model, nhead, d_ff
        self.num_layers, self.dropout = num_layers, dropout
        self.net = SAINTBlock(num_cols, d_model, nhead, d_ff,
                              num_layers, dropout).to(DEVICE)
        self.epochs, self.lr = epochs, lr
        self.loss = nn.BCEWithLogitsLoss()

    def fit(self, X, y):
        """Train for ``epochs`` epochs with AdamW on shuffled batches of 1024."""
        # Keep the dataset on the CPU: pinning CUDA tensors raises. Batches
        # are moved to DEVICE inside the loop. The tensors live in memory,
        # so no worker processes are needed.
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        y_t = torch.as_tensor(np.asarray(y).reshape(-1, 1), dtype=torch.float32)
        loader = DataLoader(TensorDataset(X_t, y_t), batch_size=1024,
                            shuffle=True, pin_memory=DEVICE.type == "cuda")
        opt = torch.optim.AdamW(self.net.parameters(), lr=self.lr)

        self.net.train()
        for _ in range(self.epochs):
            for xb, yb in loader:
                xb, yb = xb.to(DEVICE), yb.to(DEVICE)
                opt.zero_grad()
                loss = self.loss(self.net(xb).unsqueeze(-1), yb)
                loss.backward()
                opt.step()
        self.net.eval()
        return self

    @torch.no_grad()
    def predict_proba(self, X, batch_size: int = 4096):
        """Return [[1 - p, p]] with p = sigmoid(logit), scored in batches."""
        # Without eval(), dropout from the training mode makes predictions
        # random.
        self.net.eval()
        X_t = torch.as_tensor(np.asarray(X), dtype=torch.float32)
        probs = [self.net(xb.to(DEVICE)).sigmoid().cpu()
                 for xb in torch.split(X_t, max(1, int(batch_size)))]
        p = torch.cat(probs).numpy() if probs else np.empty(0, dtype=np.float32)
        return np.vstack([1 - p, p]).T

    def predict(self, X):
        """Threshold P(match) at 0.5."""
        return (self.predict_proba(X)[:, 1] >= 0.5).astype(int)

# ───────────────────── HF Llama-3 zero-shot baseline ───────────────
class HFZeroShot(BaseEstimator, ClassifierMixin):
    """
    Zero-shot entity match via Llama-3 8B-Instruct GPTQ-4bit.

    Each pair is put into a "Same person? Yes or No." prompt. p is 1.0 when
    the generated continuation contains yes / sí / ja (case-insensitive
    whole word), otherwise 0.0.
    """
    def __init__(self, repo_id="astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit",
                 max_new=1, bs=8):
        # Constructor arguments kept under their own names for get_params().
        self.repo_id, self.max_new, self.bs = repo_id, max_new, bs
        # NOTE(review): the default repo is already GPTQ-quantised;
        # transformers may warn about or ignore this bitsandbytes config in
        # favour of the checkpoint's own quantization_config. Verify.
        cfg = BitsAndBytesConfig(load_in_4bit=True,
                                 bnb_4bit_compute_dtype=torch.bfloat16,
                                 bnb_4bit_use_double_quant=True,
                                 bnb_4bit_quant_type="nf4")
        self.tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
        # Llama-3 tokenizers define no pad token, so padding=True would
        # raise. Decoder-only batch generation also needs left padding so
        # every prompt ends where generation starts.
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
        self.tok.padding_side = "left"
        self.llm = AutoModelForCausalLM.from_pretrained(
            repo_id, device_map="auto", quantization_config=cfg).eval()
        self.yes = re.compile(r"\b(yes|sí|ja)\b", re.I)

    def fit(self, X, y=None):
        """No-op: the model is used zero-shot."""
        return self

    @torch.no_grad()
    def predict_proba(self, Xpair):
        """Return [[1 - p, p]] with p in {0.0, 1.0}, generated ``bs`` prompts at a time."""
        prompts = [f"Same person? Yes or No.\nA:{a}\nB:{b}\nAnswer:" for a, b in Xpair]
        step = max(1, int(self.bs))
        answers = []
        for start in range(0, len(prompts), step):
            batch = self.tok(prompts[start:start + step], return_tensors="pt",
                             padding=True, truncation=True,
                             max_length=512).to(self.llm.device)
            # Greedy decoding; a temperature is meaningless without sampling.
            gen = self.llm.generate(**batch, max_new_tokens=self.max_new,
                                    do_sample=False,
                                    pad_token_id=self.tok.pad_token_id)
            answers.extend(self.tok.batch_decode(
                gen[:, batch["input_ids"].shape[1]:], skip_special_tokens=True))
        p = np.array([1.0 if self.yes.search(o) else 0.0 for o in answers],
                     dtype=float)
        return np.vstack([1 - p, p]).T

    def predict(self, Xpair):
        """Return 1 where the model answered yes, else 0."""
        return (self.predict_proba(Xpair)[:, 1] >= 0.5).astype(int)

# ─────────────────────────── registry helper ──────────────────────
def get_model(key: str, cat_cols: list[str] | None = None,
              num_cols: list[str] | None = None):
    """Build a fresh, unfitted model for the registry key ``key``.

    Keys are case-insensitive, and "hf_llm" / "llm" are aliases.
    ``num_cols`` is required for "saint", which uses only its length. The
    PyTorch-Tabular models receive ``cat_cols`` / ``num_cols`` with None
    replaced by [].

    Raises:
        ValueError: unknown key, or "saint" requested without ``num_cols``.
    """
    name = key.lower()
    if name == "saint":
        if num_cols is None:
            raise ValueError("SAINT requires `num_cols`")
        return Saint(num_cols=len(num_cols))

    # Lambdas defer class lookup and construction (and any model
    # download) until the key is actually selected.
    builders = {
        "logreg": lambda: LogReg(max_iter=2000, n_jobs=-1),
        "tabnet": lambda: TabNet(),
        "tabpfn": lambda: TabPFN(),
        "tabtransformer": lambda: TabTransformer(cat_cols or [], num_cols or []),
        "fttransformer": lambda: FTTransformer(cat_cols or [], num_cols or []),
        "nars": lambda: NARS(),
        "ditto": lambda: Ditto(),
        "hf_llm": lambda: HFZeroShot(),
        "llm": lambda: HFZeroShot(),
    }
    build = builders.get(name)
    if build is None:
        raise ValueError(f"Unknown model key: {key}")
    return build()

# ───────────────────────── helpers ───────────────────────────────
def load_people() -> pd.DataFrame:
    """Load DATA_DIR/people.csv with a normalised ``full_name`` and ``person`` label.

    - Column names are lower-cased, and ``id`` is used as the index while
      building the frame.
    - Missing name columns (first_name, middle_name, patronym, surname) are
      created empty.
    - ``full_name`` joins the name parts with single spaces. Each part is
      cast to str, stripped and lower-cased; empty parts are skipped and
      whitespace runs are collapsed. This is vectorised instead of a
      per-row apply.
    - ``person`` comes from ART_DIR/row_labels.csv (row_id -> person),
      coerced to numeric; rows without a label get NaN.
      NOTE(review): ``Series.map`` raises if row_id contains duplicates.
    Returns the frame with ``id`` restored as a regular column.
    """
    name_cols = ["first_name", "middle_name", "patronym", "surname"]
    ppl = (pd.read_csv(DATA_DIR / "people.csv", low_memory=False)
             .rename(columns=str.lower)
             .set_index("id"))
    for c in name_cols:
        if c not in ppl.columns:
            ppl[c] = ""

    # astype(str) guards against non-string cells (e.g. numeric values).
    parts = ppl[name_cols].fillna("").astype(str)
    joined = parts[name_cols[0]].str.strip().str.lower()
    for c in name_cols[1:]:
        joined = joined + " " + parts[c].str.strip().str.lower()
    ppl["full_name"] = joined.str.split().str.join(" ")

    lbl = (pd.read_csv(ART_DIR / "row_labels.csv")
             .set_index("row_id")["person"]
             .pipe(pd.to_numeric, errors="coerce"))
    ppl["person"] = ppl.index.to_series().map(lbl)
    return ppl.reset_index()


# ─────────── pair-sampling primitives with hard caps ───────────
def _sample_pairs(bucket: List[int], k: int,
                  rng: np.random.RandomState) -> List[Tuple[int, int]]:
    """
    Uniformly sample up to k unordered pairs from `bucket`
    without ever constructing a 2-D object array (fixes ValueError).
    """
    n = len(bucket)
    if n < 2 or k == 0:
        return []
    all_pairs = list(itertools.combinations(bucket, 2))
    if len(all_pairs) <= k:
        return all_pairs
    idx = rng.choice(len(all_pairs), k, replace=False)
    return [all_pairs[i] for i in idx]


def make_pairs(idxs: np.ndarray, y: np.ndarray, neg_ratio: int,
               rng: np.random.RandomState):
    """Generic positive / negative generator with per-label cap."""
    lab2idx: dict[int, list[int]] = {}
    for i in idxs:
        lab2idx.setdefault(y[i], []).append(i)

    # positives – capped per label
    pos = []
    for inds in lab2idx.values():
        if len(inds) < 2:
            continue
        want = min(len(inds) * (len(inds) - 1) // 2, 200)
        pos.extend(_sample_pairs(inds, want, rng))

    # negatives
    n_neg = min(len(pos) * neg_ratio, MAX_PAIRS_PER_SPLIT - len(pos))
    neg = set()
    labels = list(lab2idx)
    while len(neg) < n_neg and len(labels) > 1:
        l1, l2 = rng.choice(labels, 2, replace=False)
        neg.add(tuple(sorted((rng.choice(lab2idx[l1]),
                              rng.choice(lab2idx[l2])))))

    pairs  = pos + list(neg)
    if len(pairs) > MAX_PAIRS_PER_SPLIT:
        pairs = rng.choice(pairs, MAX_PAIRS_PER_SPLIT, replace=False).tolist()

    labels_arr = np.array([1] * len(pos) + [0] * len(neg), dtype=np.int8)
    return pairs, labels_arr


def make_pairs_within(idxs: np.ndarray, y: np.ndarray,
                      heim: np.ndarray, rng: np.random.RandomState):
    """Pairs only among rows that share the SAME heimild value."""
    buckets: dict[int, list[int]] = {}
    for i in idxs:
        buckets.setdefault(heim[i], []).append(i)

    pairs, label_chunks = [], []
    for sub in buckets.values():
        p, l = make_pairs(sub, y, NEG_RATIO, rng)
        pairs.extend(p); label_chunks.append(l)

    labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.int8)
    return pairs[:MAX_PAIRS_PER_SPLIT], labels[:MAX_PAIRS_PER_SPLIT]


def make_pairs_across(idxs: np.ndarray, y: np.ndarray,
                      heim: np.ndarray, rng: np.random.RandomState):
    """Pairs whose two rows come from DIFFERENT heimild values."""
    lab2idx: dict[int, list[int]] = {}
    for i in idxs:
        lab2idx.setdefault(y[i], []).append(i)

    pos = []
    for inds in lab2idx.values():
        if len(inds) < 2:
            continue
        by_census: dict[int, list[int]] = {}
        for i in inds:
            by_census.setdefault(heim[i], []).append(i)
        if len(by_census) < 2:
            continue
        c_keys = list(by_census)
        for a in range(len(c_keys)):
            for b in range(a + 1, len(c_keys)):
                pairs_ab = list(itertools.product(by_census[c_keys[a]],
                                                  by_census[c_keys[b]]))
                if len(pairs_ab) > 200:
                    pairs_ab = rng.choice(pairs_ab, 200, replace=False).tolist()
                pos.extend(pairs_ab)

    # negatives = diff person & diff census
    neg = set()
    want_neg = min(len(pos) * NEG_RATIO,
                   MAX_PAIRS_PER_SPLIT - len(pos))
    while len(neg) < want_neg:
        i, j = rng.choice(idxs, 2, replace=False)
        if y[i] != y[j] and heim[i] != heim[j]:
            neg.add(tuple(sorted((i, j))))

    pairs = pos + list(neg)
    if len(pairs) > MAX_PAIRS_PER_SPLIT:
        pairs = rng.choice(pairs, MAX_PAIRS_PER_SPLIT, replace=False).tolist()

    labels = np.array([1] * len(pos) + [0] * len(neg), dtype=np.int8)
    return pairs, labels


def pair_matrix(X: sparse.spmatrix, pairs: List[Tuple[int, int]],
                batch_size: int = BATCH_SIZE_MATRIX) -> sparse.csr_matrix:
    """Build |pairs|×d absolute-difference matrix in memory-capped batches."""
    rows = []
    for i in range(0, len(pairs), batch_size):
        chunk = pairs[i:i+batch_size]
        a = X[[p[0] for p in chunk]]
        b = X[[p[1] for p in chunk]]
        rows.append(abs(a - b))
    return sparse.vstack(rows).tocsr()

# ───────────────────────── run one task ──────────────────────────
_DENSE_MODELS = {"tabnet", "tabpfn", "saint",
                 "tabtransformer", "fttransformer", "logreg"}

# Models fed raw (left_text, right_text) pairs instead of the feature matrix.
_TEXT_MODELS = {"nars", "ditto", "attendem", "hf_llm"}

_RESULT_COLUMNS = ["model", "accuracy", "precision", "recall", "f1", "roc_auc"]


def _pair_text(txt: np.ndarray, pairs) -> np.ndarray:
    """Column-stack the texts of the left and right member of each pair.

    Assumes ``txt`` is 1-D (one string per row), giving shape (len(pairs), 2).
    """
    left = [i for i, _ in pairs]
    right = [j for _, j in pairs]
    return np.column_stack([txt[left], txt[right]])


def _safe_auc(y_true, scores) -> float:
    """ROC-AUC, or NaN when it is undefined (e.g. only one class in y_true)."""
    try:
        return roc_auc_score(y_true, scores)
    except ValueError:
        return float("nan")


def run_task(scenario: str, tag: str,
             idx_tr, idx_te, X, y, txt, heim,
             rng: np.random.RandomState):
    """Build train/test pairs for one scenario and evaluate every model in MODELS.

    ``scenario == "within"`` pairs rows from the same heimild; anything else
    pairs rows across heimild values. Text models (``_TEXT_MODELS``) get name
    pairs from ``txt``; ``_DENSE_MODELS`` get the dense float32 pair matrix;
    the rest get the sparse one. A model that fails is reported via
    ``warnings.warn`` and skipped (best effort).

    Returns
    -------
    DataFrame indexed by ``model`` with accuracy/precision/recall/f1/roc_auc
    at threshold ``THR``. ``roc_auc`` is NaN when undefined. The frame is
    empty (same columns) when no pairs could be built or every model failed.
    """
    logging.info("=== %s | %s ===", tag.upper(), scenario.upper())

    make_pairs_fn = make_pairs_within if scenario == "within" else make_pairs_across
    pairs_tr, y_tr = make_pairs_fn(idx_tr, y, heim, rng)
    pairs_te, y_te = make_pairs_fn(idx_te, y, heim, rng)

    empty = pd.DataFrame(columns=_RESULT_COLUMNS).set_index("model")
    if len(pairs_tr) == 0 or len(pairs_te) == 0:
        warnings.warn(f"{tag}/{scenario}: no train or test pairs, skipping")
        return empty

    X_tr_sp = pair_matrix(X, pairs_tr)
    X_te_sp = pair_matrix(X, pairs_te)

    # .toarray() instead of .A: the .A shortcut was deprecated in SciPy 1.11
    # and removed in 1.14.
    X_tr_dense = X_tr_sp.toarray().astype(np.float32)
    X_te_dense = X_te_sp.toarray().astype(np.float32)
    num_cols = [f"f{i}" for i in range(X_tr_dense.shape[1])]

    text_pairs = None  # built once, on first use by a text model
    results = []
    for name in MODELS:
        try:
            mdl = get_model(name, [], num_cols)

            if name in _TEXT_MODELS:
                if text_pairs is None:
                    text_pairs = (_pair_text(txt, pairs_tr),
                                  _pair_text(txt, pairs_te))
                Xtr, Xte = text_pairs
            elif name in _DENSE_MODELS:
                Xtr, Xte = X_tr_dense, X_te_dense
            else:
                Xtr, Xte = X_tr_sp, X_te_sp

            mdl.fit(Xtr, y_tr)
            probs = mdl.predict_proba(Xte)[:, 1]

            pred = (probs >= THR).astype(np.int8)
            results.append({
                "model":     name,
                "accuracy":  accuracy_score(y_te, pred),
                "precision": precision_score(y_te, pred, zero_division=0),
                "recall":    recall_score(y_te, pred, zero_division=0),
                "f1":        f1_score(y_te, pred, zero_division=0),
                "roc_auc":   _safe_auc(y_te, probs),
            })
        except Exception as e:  # best effort: one bad model must not abort the run
            warnings.warn(f"{name} failed: {e}")
            logging.debug("%s traceback", name, exc_info=True)

    if not results:
        return empty
    return pd.DataFrame(results, columns=_RESULT_COLUMNS).set_index("model")


# ─────────────────────────── main loop ───────────────────────────
def main():
    """Run the whole benchmark and write per-(scenario, model) averages.

    Loads the people table and the sparse feature matrix. Keeps rows whose
    ``person`` label is present, is not -1, and occurs at least twice. Then,
    for each of ``N_RUNS`` seeds, splits the persons 80/20 into train/test
    and evaluates every model in the "within" and "across" scenarios.
    Means over runs are printed and saved to ``average_results.csv``.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
    )

    people = load_people()
    features = sparse.load_npz(ART_DIR / "iceid_ml_ready.npz")
    person_ids = people["person"].values
    names = people["full_name"].values
    sources = people["heimild"].values

    # Labelled rows only (not NaN, not -1), and only persons seen twice or more.
    labelled = ~np.isnan(person_ids) & (person_ids != -1)
    counts = pd.Series(person_ids[labelled]).value_counts()
    repeated = counts.index[counts.to_numpy() >= 2]
    keep = labelled & np.isin(person_ids, repeated)

    features = features[keep]
    person_ids = person_ids[keep]
    names = names[keep]
    sources = sources[keep]

    unique_persons = np.unique(person_ids)
    n_train_persons = int(0.8 * len(unique_persons))

    frames = []
    for seed in range(N_RUNS):
        rng = np.random.RandomState(seed)

        train_persons = rng.choice(unique_persons, size=n_train_persons,
                                   replace=False)
        is_train = np.isin(person_ids, train_persons)
        idx_tr = np.flatnonzero(is_train)
        idx_te = np.flatnonzero(~is_train)

        for scenario in ("within", "across"):
            frame = run_task(scenario, f"run_{seed + 1}", idx_tr, idx_te,
                             features, person_ids, names, sources, rng)
            frame["scenario"] = scenario
            frames.append(frame.set_index("scenario", append=True))

        torch.cuda.empty_cache()
        gc.collect()

    stacked = pd.concat(frames)
    # Index is (model, scenario); swap so the averages read (scenario, model).
    avg = stacked.swaplevel().groupby(level=[0, 1]).mean()

    print("\n=== AVERAGE OVER RUNS ===")
    print(avg.round(4))
    avg.to_csv("average_results.csv")


if __name__ == "__main__":
    main()

10:06:11 INFO === RUN_1 | WITHIN ===


ValueError: a must be 1-dimensional