In [None]:
!pip install spacy transformers scikit-learn
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m80.2 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [None]:
import os
import json
import random
import hashlib
from copy import deepcopy
from functools import lru_cache
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, set_seed
from tqdm.auto import tqdm

# ============================================================
# Narrative Similarity — Pairwise Ranking with Strong Upgrades
# - Cross-encoder scoring (anchor,candidate) run twice (A and B)
# - Train-only hardening augmentations (no data leakage)
# - In-batch negatives + R-Drop + FGM
# - Optional model-based hard-negative mining (train-only)
# - Salience-guided TTA using DeBERTa sentence scoring
# - Per-example test-time adaptation (only norm+scorer, no labels)
# ============================================================

# ----------------------------
# SEED / DEVICE
# ----------------------------
SEED = 42

# transformers.set_seed already covers python/numpy/torch per its docs; the
# explicit per-library calls keep the seeding visible in the notebook itself.
set_seed(SEED)
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# Use the GPU whenever torch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Avoid the HF tokenizers fork-after-parallelism warning in DataLoader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# ----------------------------
# CONFIG
# ----------------------------
CHECKPOINT    = "microsoft/deberta-v3-large"
MAX_LEN       = 512      # token cap for each tokenized (anchor, candidate) pair

BATCH_SIZE    = 2
GRAD_ACC      = 4        # micro-batches per optimizer step
EPOCHS        = 6
PATIENCE      = 4        # presumably early-stopping patience in epochs (see train) — confirm

FREEZE_LAYERS = 4        # presumably passed as PairwiseRanker(freeze_layers=...) — confirm
LR_BACKBONE   = 1e-5     # every trainable param whose name does not start with "scorer"
LR_HEAD       = 1e-4     # params whose name starts with "scorer"
WEIGHT_DECAY  = 0.05     # backbone group only; the head group uses 0.0

MARGIN        = 0.10
POOLING       = "mean"   # "mean" or "cls"

# Fail fast on settings that would otherwise be misread downstream:
# PairwiseRanker treats every non-"cls" POOLING value as mean pooling, and the
# scheduler step count in train() divides by GRAD_ACC.
if POOLING not in ("mean", "cls"):
    raise ValueError(f"POOLING must be 'mean' or 'cls', got {POOLING!r}")
if BATCH_SIZE < 1 or GRAD_ACC < 1:
    raise ValueError(f"BATCH_SIZE and GRAD_ACC must be >= 1, got {BATCH_SIZE}, {GRAD_ACC}")

# Word-level truncation budgets (applied before tokenization)
ANC_WORDS  = 240
CAND_WORDS = 260

# Deterministic preprocessing applied in BOTH train & eval (label-free)
P_ENTITY_ANON = 0.60     # approx. fraction of texts whose spaCy entities are anonymized
P_DROP_PARENS = 0.35     # approx. fraction of texts whose "(...)" spans are removed
STYLE_NORMALIZE_ALWAYS = True

# Train-only augmentations in harden_training_data
DO_WORD_DROPOUT_POS   = True
DO_SENT_DROPOUT_POS   = True
DO_SHUFFLE_NEG        = True
DO_OUTCOME_SWAP_NEG   = True
DO_SPLICED_NEG        = True
DO_MINED_NEG          = True

P_WORD_DROPOUT_POS    = 0.55
P_SENT_DROPOUT_POS    = 0.55
P_SHUFFLE_NEG         = 0.25
P_OUTCOME_SWAP_NEG    = 0.30
P_SPLICED_NEG         = 0.10
P_MINED_NEG           = 1.00

# Mining controls (train-only)
MINED_TOPK = 20          # TF-IDF neighbours inspected per anchor
MINED_MIN_SIM = 0.10     # stop scanning once cosine similarity drops below this

# Train-time truncation jitter
P_RAND_TRUNC_TRAIN = 0.15  # chance a training triplet uses a random window instead of head+tail

# Fail fast on out-of-range settings: every P_* above is compared against a
# uniform draw in [0, 1), so a typo such as 55 for 0.55 would silently mean "always".
for _name, _value in {
    "P_ENTITY_ANON": P_ENTITY_ANON,
    "P_DROP_PARENS": P_DROP_PARENS,
    "P_WORD_DROPOUT_POS": P_WORD_DROPOUT_POS,
    "P_SENT_DROPOUT_POS": P_SENT_DROPOUT_POS,
    "P_SHUFFLE_NEG": P_SHUFFLE_NEG,
    "P_OUTCOME_SWAP_NEG": P_OUTCOME_SWAP_NEG,
    "P_SPLICED_NEG": P_SPLICED_NEG,
    "P_MINED_NEG": P_MINED_NEG,
    "P_RAND_TRUNC_TRAIN": P_RAND_TRUNC_TRAIN,
}.items():
    if not 0.0 <= _value <= 1.0:
        raise ValueError(f"{_name} must be in [0, 1], got {_value!r}")
del _name, _value
if ANC_WORDS < 1 or CAND_WORDS < 1 or MINED_TOPK < 1:
    raise ValueError("ANC_WORDS, CAND_WORDS and MINED_TOPK must be >= 1")

# ----------------------------
# TRAINING OBJECTIVE UPGRADES
# ----------------------------
# In-batch negatives (NCE-style)
USE_INBATCH_NCE         = True

# R-Drop regularisation and its KL weight
USE_R_DROP              = True
R_DROP_BETA             = 0.5

# Auxiliary pairwise margin loss and its weight
USE_PAIRWISE_AUX        = True
PAIRWISE_ALPHA          = 0.5

# FGM adversarial training on word embeddings; eps is the perturbation L2 norm
# (values around 0.25-1.0 are common)
USE_FGM                 = True
FGM_EPS                 = 0.5

# Model-based hard-negative mining (train-only): runs once at the start of
# epoch MODEL_MINING_EPOCH, i.e. after epoch 1 has finished
USE_MODEL_MINING        = True
MODEL_MINING_EPOCH      = 2
MODEL_MINING_SHORTLIST  = 60    # TF-IDF shortlist size re-scored by the model
MODEL_MINING_ADD_PER_EX = 1     # mined rows appended per training row

# ----------------------------
# TTA + Per-example TTA adaptation
# ----------------------------
USE_SALIENCE_TTA          = True

# Sentence-salience selection
SALIENCE_TOPK_LIST        = (2, 4)
SALIENCE_KEEP_LAST        = 2
SALIENCE_KEEP_FIRST       = True
SALIENCE_MAX_SENTS        = 40       # sentence cap, for speed

# Truncation views combined with the salience views; kept few on purpose
TTA_TRUNC_VIEWS           = ("headtail", "tail")
TTA_AGG                   = "mean"   # aggregation over per-view diffs

# Per-example adaptation (label-free); only applied to uncertain examples
USE_PER_EX_ADAPT          = True
ADAPT_STEPS               = 3
ADAPT_LR                  = 1e-3
ADAPT_ONLY_IF_ABS_DIFF_LT = 0.20
ADAPT_VAR_W               = 1.0
ADAPT_PSEUDO_W            = 1.0
ADAPT_DRIFT_W             = 0.10

# Fast tokenizer for CHECKPOINT; used by collate_eval_pairs and score_pairs to encode (anchor, candidate) pairs.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, use_fast=True)

# ============================================================
# spaCy backend (no regex entity detection)
# ============================================================

def load_spacy():
    """Load spaCy's ``en_core_web_sm`` pipeline, trying a one-time download if it is missing.

    Raises:
        RuntimeError: the model could not be loaded, even after a download attempt
            (e.g. no network, or the freshly installed package needs a kernel restart).
    """
    import spacy

    try:
        return spacy.load("en_core_web_sm")
    except Exception:
        pass  # not installed (or unusable): fall through to the download attempt

    try:
        from spacy.cli import download as spacy_download

        spacy_download("en_core_web_sm")
        return spacy.load("en_core_web_sm")
    except Exception as e:
        raise RuntimeError(
            "spaCy model en_core_web_sm not available. Install: "
            "pip install spacy && python -m spacy download en_core_web_sm"
        ) from e


# Module-level pipeline used for sentence splitting and entity anonymization.
NLP = load_spacy()
print("[NLP] Backend = spacy")

# ============================================================
# Utils (no regex)
# ============================================================

def clean_text(text: str) -> str:
    """Turn backslashes into spaces and collapse every whitespace run to a single space.

    Falsy input (None, "", 0, ...) yields "". Other non-strings are str()-converted first.
    """
    return " ".join(str(text).replace("\\", " ").split()) if text else ""

def safe_bool(val):
    """Coerce a label-like value to bool.

    Strings are false only for "false"/"0"/"no"/"off"/"" (case- and
    whitespace-insensitive); every other value uses Python truthiness.
    """
    if isinstance(val, str):
        return val.strip().lower() not in ("false", "0", "no", "off", "")
    # bools pass through bool() unchanged
    return bool(val)

def hash_int(s: str) -> int:
    """Stable 32-bit integer from the first 4 bytes (big-endian) of MD5(utf-8(s))."""
    digest = hashlib.md5(s.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

def deterministic_gate(seed_int: int, p: float) -> bool:
    """Deterministic Bernoulli(p) decision keyed on an integer seed.

    p <= 0 is always False and p >= 1 is always True. Otherwise the seed is run
    through a seeded PRNG, so gates keyed on nearby seeds (``s + k1``, ``s + k2``)
    are effectively independent. The earlier ``(seed % 10_000) / 10_000`` form
    turned such gates into shifted windows of the same uniform value, which made
    them strongly correlated.
    """
    if p <= 0:
        return False
    if p >= 1:
        return True
    # random.Random(int) seeding is deterministic across runs (no hash randomization).
    return random.Random(seed_int).random() < p

def text_seed(text: str) -> int:
    """Stable integer seed derived from the whitespace-normalized text."""
    normalized = clean_text(text)
    return hash_int(normalized)

def triplet_seed(anc: str, a: str, b: str) -> int:
    """Stable, order-sensitive integer seed for an (anchor, a, b) triplet."""
    key = "||".join(clean_text(part) for part in (anc, a, b))
    return hash_int(key)

def row_fingerprint(r: dict) -> str:
    """MD5 hex digest of a row's (anchor, text_a, text_b) head/tail-truncated views.

    Texts that agree on the truncated views collide. Swapping text_a and text_b
    changes the fingerprint.
    """
    views = (
        truncate_view(r["anchor_text"], ANC_WORDS, "headtail", 0),
        truncate_view(r["text_a"], CAND_WORDS, "headtail", 0),
        truncate_view(r["text_b"], CAND_WORDS, "headtail", 0),
    )
    return hashlib.md5("||".join(views).encode("utf-8")).hexdigest()

def report_overlap(train_rows: List[dict], val_rows: List[dict], name="val"):
    """Print how many distinct row fingerprints occur in both splits (duplicate-leakage check)."""
    train_fps = {row_fingerprint(row) for row in train_rows}
    shared = {row_fingerprint(row) for row in val_rows} & train_fps
    print(f"[Leakage Check] Exact-duplicate overlap train vs {name}: {len(shared)}")

# ============================================================
# Style normalization + entity anonymization (train+test, label-free)
# ============================================================

def squash_runs(s: str, ch: str) -> str:
    """Collapse each run of consecutive ``ch`` characters in ``s`` to a single ``ch``."""
    kept = []
    prev_was_ch = False
    for c in s:
        is_ch = c == ch
        if not (is_ch and prev_was_ch):
            kept.append(c)
        prev_was_ch = is_ch
    return "".join(kept)

def remove_parentheticals(text: str) -> str:
    """Remove every "(...)" span, including nested ones, then re-normalize whitespace.

    A stray ")" is dropped and never drives the depth below zero. Everything after
    an unmatched "(" is dropped to the end of the string.
    """
    kept = []
    depth = 0
    for ch in text:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(0, depth - 1)
        elif not depth:
            kept.append(ch)
    return clean_text("".join(kept))

def style_normalize(text: str, seed_int: int) -> str:
    """Normalize the typographic style of ``text``. Label-free and deterministic given ``seed_int``.

    - Curly quotes become straight quotes. Both the left (‘) and right (’) single
      quotes are mapped; previously ‘ was left alone, so ‘quoted’ spans came out
      half-normalized.
    - Runs of "!", "?" and "." collapse to one character (so "..." becomes ".").
    - "(...)" spans are removed when ``deterministic_gate(seed_int + 911, P_DROP_PARENS)`` fires.
    """
    t = clean_text(text)
    if not t:
        return ""
    t = (
        t.replace("“", '"').replace("”", '"')
         .replace("‘", "'").replace("’", "'")
    )
    for mark in ("!", "?", "."):
        t = squash_runs(t, mark)
    if deterministic_gate(seed_int + 911, P_DROP_PARENS):
        t = remove_parentheticals(t)
    return clean_text(t)

def anonymize_entities_spacy(text: str) -> str:
    """Replace selected spaCy entities with stable per-text placeholders.

    Entities labelled PERSON/GPE/LOC/ORG/DATE/TIME become "<LABEL>_<k>". k counts
    distinct (surface text, label) pairs per label in order of first appearance,
    so repeated mentions of the same entity share one placeholder.
    """
    t = clean_text(text)
    if not t:
        return ""
    wanted = {"PERSON", "GPE", "LOC", "ORG", "DATE", "TIME"}
    ents = [ent for ent in NLP(t).ents if ent.label_ in wanted]
    if not ents:
        return t

    placeholder = {}
    per_label = {}
    parts = []
    cursor = 0
    for ent in ents:
        key = (ent.text, ent.label_)
        if key not in placeholder:
            per_label[ent.label_] = per_label.get(ent.label_, 0) + 1
            placeholder[key] = f"{ent.label_}_{per_label[ent.label_]}"
        parts.append(t[cursor:ent.start_char])
        parts.append(placeholder[key])
        cursor = ent.end_char
    parts.append(t[cursor:])
    return clean_text("".join(parts))

def preprocess_text(text: str) -> str:
    """
    Deterministic & label-free (depends only on the text string).
    Applied in BOTH train and eval to avoid train/test mismatch.

    The result is memoized on the cleaned text via ``_preprocess_cleaned``. The
    collate functions re-apply this to the same texts every epoch, and model
    mining applies it to every shortlist candidate. Without the cache, each call
    re-runs spaCy NER.
    """
    t = clean_text(text)
    if not t:
        return ""
    return _preprocess_cleaned(t)


@lru_cache(maxsize=32_768)
def _preprocess_cleaned(t: str) -> str:
    """Cached core of ``preprocess_text`` for already-cleaned, non-empty text.

    The output depends only on ``t`` plus module-level config and ``NLP``. If those
    change without re-running this cell, call ``_preprocess_cleaned.cache_clear()``.
    """
    s = text_seed(t)

    if STYLE_NORMALIZE_ALWAYS:
        t = style_normalize(t, s)

    if deterministic_gate(s + 1337, P_ENTITY_ANON):
        t = anonymize_entities_spacy(t)

    return clean_text(t)

# ============================================================
# Truncation views
# ============================================================

def truncate_view(text: str, max_words: int, mode: str, seed_int: int) -> str:
    """Return at most ``max_words`` words of the cleaned ``text``.

    Modes:
      - "head": the first ``max_words`` words
      - "tail": the last ``max_words`` words
      - "rand": a contiguous window whose start is drawn from ``random.Random(seed_int)``
      - anything else ("headtail"): the opening and ending joined by " ... ".
        For odd budgets the head gets the extra word.
    Texts already within budget are returned unchanged. A non-positive budget
    yields "". Previously ``words[-0:]`` returned the whole text for
    ``max_words=0`` ("tail") and for ``max_words=1`` ("headtail").
    """
    t = clean_text(text)
    if not t or max_words <= 0:
        return ""
    words = t.split()
    n = len(words)
    if n <= max_words:
        return t

    if mode == "head":
        return " ".join(words[:max_words])
    if mode == "tail":
        return " ".join(words[n - max_words:])
    if mode == "rand":
        rng = random.Random(seed_int)
        start = rng.randint(0, n - max_words)  # n > max_words here, so the range is non-empty
        return " ".join(words[start:start + max_words])

    n_tail = max_words // 2
    n_head = max_words - n_tail
    if n_tail == 0:
        return " ".join(words[:n_head])
    return " ".join(words[:n_head]) + " ... " + " ".join(words[n - n_tail:])

# ============================================================
# Sentence splitting and augmentations (train-only)
# ============================================================

def split_sentences(text: str) -> List[str]:
    """Split ``text`` into stripped, non-empty spaCy sentences.

    Returns [] for empty input, and [cleaned text] when spaCy yields no sentences.
    """
    t = clean_text(text)
    if not t:
        return []
    sentences = []
    for span in NLP(t).sents:
        stripped = span.text.strip()
        if stripped:
            sentences.append(stripped)
    return sentences or [t]

def augment_word_dropout(text: str, rate: float, seed_int: int) -> str:
    """Drop each word independently with probability ``rate`` (seeded).

    Texts under 6 words are returned cleaned and unchanged. If fewer than 3
    words survive, the first 3 words are kept instead.
    """
    rng = random.Random(seed_int)
    tokens = clean_text(text).split()
    if len(tokens) < 6:
        return clean_text(text)
    survivors = []
    for tok in tokens:
        if rng.random() > rate:
            survivors.append(tok)
    return " ".join(survivors if len(survivors) >= 3 else tokens[:3])

def drop_sentences(text: str, drop_prob: float, seed_int: int, keep_ends: bool = True) -> str:
    """Randomly drop sentences (seeded), each with probability ``drop_prob``.

    Texts with fewer than 5 sentences are returned unchanged. With ``keep_ends``
    the first and last sentences are always kept, and no random draw is spent on
    them. If fewer than 2 sentences would remain, the cleaned original is returned.
    """
    sents = split_sentences(text)
    if len(sents) < 5:
        return clean_text(text)
    rng = random.Random(seed_int)
    last = len(sents) - 1
    kept = [
        sent for i, sent in enumerate(sents)
        if (keep_ends and i in (0, last)) or rng.random() > drop_prob
    ]
    return clean_text(" ".join(kept)) if len(kept) >= 2 else clean_text(text)

def augment_shuffle_sentences(text: str, seed_int: int) -> str:
    """Return ``text`` with its sentences permuted by a seeded shuffle.

    Texts with fewer than 3 sentences are returned cleaned but otherwise unchanged.
    The shuffle may reproduce the original order, so callers should compare.
    """
    sents = split_sentences(text)
    if len(sents) >= 3:
        random.Random(seed_int).shuffle(sents)
        return clean_text(" ".join(sents))
    return clean_text(text)

def _choose_donor_text(pool_texts: List[str], forbidden: set, seed_int: int, min_sents: int) -> Optional[str]:
    if not pool_texts:
        return None
    rng = random.Random(seed_int)
    n = len(pool_texts)
    for _ in range(min(60, n)):
        idx = rng.randrange(n)
        cand = pool_texts[idx]
        if not cand or cand in forbidden:
            continue
        if len(split_sentences(cand)) >= min_sents:
            return cand
    return None

def outcome_swap_negative(
    pos_text: str,
    pool_texts: List[str],
    seed_int: int,
    forbidden: Optional[set] = None,
    min_pos_sents: int = 4,
    swap_last_k: int = 2
) -> Optional[str]:
    """Build a hard negative by replacing the positive's ending with a donor's ending.

    The last ``swap_last_k`` sentences of ``pos_text`` are swapped for the last
    ``swap_last_k`` sentences of a donor drawn from ``pool_texts`` (excluding
    ``forbidden``). Returns None if the positive is too short (fewer than
    max(min_pos_sents, swap_last_k + 1) sentences), no donor qualifies, or the
    result equals the positive.
    """
    pos_text = clean_text(pos_text)
    if not pos_text:
        return None
    banned = forbidden or set()

    pos_sents = split_sentences(pos_text)
    if len(pos_sents) < max(min_pos_sents, swap_last_k + 1):
        return None

    donor = _choose_donor_text(pool_texts, banned, seed_int, min_sents=swap_last_k)
    if donor is None:
        return None
    donor_sents = split_sentences(donor)
    if len(donor_sents) < swap_last_k:
        return None

    head = pos_sents[:-swap_last_k]
    tail = donor_sents[-swap_last_k:]
    swapped = clean_text(" ".join(head + tail))
    return swapped if swapped and swapped != pos_text else None

def spliced_negative(pos_text: str, pool_texts: List[str], seed_int: int) -> Optional[str]:
    """Build a negative from the positive's first half plus sentences from a random donor.

    The donor is drawn uniformly from ``pool_texts`` with ``random.Random(seed_int)``.
    The spliced text keeps ``sents[:mid]`` of the positive and appends
    ``len(sents) - mid`` donor sentences: the aligned slice ``donor[mid:]`` when
    the donor is long enough, otherwise the donor's final sentences. Previously a
    donor with <= mid sentences gave an empty slice, so the "negative" was just a
    truncated copy of the positive, which amounts to label noise.

    Returns None when the positive has fewer than 4 sentences, the pool is empty,
    the donor has fewer than 2 sentences, or the splice equals the positive.
    """
    sents = split_sentences(pos_text)
    if len(sents) < 4 or not pool_texts:
        return None
    rng = random.Random(seed_int)
    donor = pool_texts[rng.randrange(len(pool_texts))]
    donor_sents = split_sentences(donor)
    if len(donor_sents) < 2:
        return None
    mid = len(sents) // 2
    need = len(sents) - mid  # >= 2 because len(sents) >= 4
    donor_part = donor_sents[mid:mid + need] or donor_sents[-need:]
    splice = clean_text(" ".join(sents[:mid] + donor_part))
    if not splice or splice == clean_text(pos_text):
        return None
    return splice

# ============================================================
# Data loading
# ============================================================

def normalize_row(r: dict) -> Optional[dict]:
    """Map a raw JSONL record to the canonical row schema, or None if it is unusable.

    Two input layouts are accepted:
      - {anchor_text, text_a, text_b, text_a_is_closer | label}: label parsed via safe_bool
        (missing label -> False)
      - {anchor_story, similar_story, dissimilar_story}: text_a is always the closer text
    Rows missing any of the three texts (or with a falsy text) yield None.
    """
    if "anchor_text" in r:
        anc, a, b = r.get("anchor_text"), r.get("text_a"), r.get("text_b")
        label = safe_bool(r.get("text_a_is_closer", r.get("label", 0)))
    elif "anchor_story" in r:
        anc, a, b = r.get("anchor_story"), r.get("similar_story"), r.get("dissimilar_story")
        label = True
    else:
        return None

    if not (anc and a and b):
        return None

    return {
        "anchor_text": clean_text(anc),
        "text_a": clean_text(a),
        "text_b": clean_text(b),
        "text_a_is_closer": bool(label),
    }

def load_jsonl(path: str) -> List[dict]:
    """Load and normalize rows from a JSON-Lines file.

    Each non-blank line is parsed as JSON and passed through ``normalize_row``.
    Lines that fail to parse, are not JSON objects, or normalize to None are
    skipped. A missing file yields [].

    Both cases print a warning instead of being swallowed silently, so a wrong
    path or a corrupt file is visible rather than producing an empty or short
    dataset.
    """
    rows = []
    if not os.path.exists(path):
        print(f"[load_jsonl] WARNING: file not found: {path}")
        return rows

    n_bad_json = 0
    n_dropped = 0
    # utf-8-sig strips a leading BOM, which json.loads would otherwise reject,
    # dropping the first record.
    with open(path, encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                n_bad_json += 1
                continue
            nr = normalize_row(obj) if isinstance(obj, dict) else None
            if nr:
                rows.append(nr)
            else:
                n_dropped += 1

    if n_bad_json or n_dropped:
        print(
            f"[load_jsonl] {path}: skipped {n_bad_json} unparseable line(s) "
            f"and {n_dropped} record(s) without usable fields"
        )
    return rows

# ============================================================
# Train-only hardening (NO data leakage)
# ============================================================

def harden_training_data(train_rows: List[dict]) -> Tuple[List[dict], List[str]]:
    """
    Build a train-only augmented copy of ``train_rows`` (no val/test access).

    Every row is re-emitted in canonical form (text_a = closer text,
    text_a_is_closer=True, type="original"). Extra canonical rows are then
    appended, gated by the DO_* flags and P_* probabilities (drawn from the
    global ``random`` stream):
      - pos_sent_dropout / pos_word_dropout: perturbed positive vs. original negative
      - neg_shuffle: the positive with shuffled sentences, used as the negative
      - neg_outcome_swap: the positive with its last 2 sentences taken from a donor
      - neg_spliced: the positive's first half plus donor sentences
      - neg_mined: the most TF-IDF-similar pool text to the anchor (pool > 50 texts)

    No text that is the closer text of ANY row with the same anchor is used as a
    mined negative or an outcome-swap donor. Neither is the row's own negative
    (mining it would only duplicate the original row).

    Returns:
      - augmented_rows
      - train_pool_texts (unique_texts) used for donor selection/mining (TRAIN ONLY)
    """
    print("\n[Hardening] Mining & Augmenting Training Set (TRAIN ONLY)...")

    # dict.fromkeys de-duplicates but keeps first-seen order. list(set(...)) ordered
    # strings by their per-process randomized hash, so index-based donor draws and
    # mining tie-breaks changed between runs despite the fixed seeds.
    unique_texts = list(dict.fromkeys(
        [r["anchor_text"] for r in train_rows] +
        [r["text_a"] for r in train_rows] +
        [r["text_b"] for r in train_rows]
    ))
    unique_texts = [t for t in unique_texts if t and t.strip()]

    # Closer text of every row, grouped by anchor. An anchor can recur across rows,
    # and its other positives are usually its most similar pool texts. Mining them
    # as negatives would inject false negatives.
    known_positives = {}
    for r in train_rows:
        closer = r["text_a"] if r["text_a_is_closer"] else r["text_b"]
        known_positives.setdefault(r["anchor_text"], set()).add(closer)

    use_mining = DO_MINED_NEG and (len(unique_texts) > 50)
    if use_mining:
        vectorizer = TfidfVectorizer(stop_words="english", min_df=1, max_features=5000)
        tfidf_matrix = vectorizer.fit_transform(unique_texts)  # TRAIN-ONLY fit
        text_to_idx = {t: i for i, t in enumerate(unique_texts)}
    else:
        tfidf_matrix = None
        text_to_idx = {}

    augmented = []
    stats = {
        "orig": 0,
        "mined": 0,
        "shuffle_neg": 0,
        "spliced_neg": 0,
        "outcome_swap_neg": 0,
        "pos_word_dropout": 0,
        "pos_sent_dropout": 0,
    }

    def _add(anc: str, text_a: str, text_b: str, kind: str) -> None:
        # Every emitted row is canonical: text_a is the closer text.
        augmented.append({
            "anchor_text": anc, "text_a": text_a, "text_b": text_b,
            "text_a_is_closer": True, "type": kind,
        })

    for r in tqdm(train_rows, desc="Augmenting"):
        anc = r["anchor_text"]
        if r["text_a_is_closer"]:
            pos, neg = r["text_a"], r["text_b"]
        else:
            pos, neg = r["text_b"], r["text_a"]
        # Texts that must never act as this anchor's new negative / outcome donor.
        not_negative = {anc, pos, neg} | known_positives.get(anc, set())

        base_seed = hash_int(anc + "||" + pos + "||" + neg)

        _add(anc, pos, neg, "original")
        stats["orig"] += 1

        if DO_SENT_DROPOUT_POS and (random.random() < P_SENT_DROPOUT_POS):
            pos_sd = drop_sentences(pos, drop_prob=0.22, seed_int=base_seed + 10, keep_ends=True)
            # Short texts come back unchanged; adding them would only duplicate the original row.
            if pos_sd != clean_text(pos):
                _add(anc, pos_sd, neg, "pos_sent_dropout")
                stats["pos_sent_dropout"] += 1

        if DO_WORD_DROPOUT_POS and (random.random() < P_WORD_DROPOUT_POS):
            pos_wd = augment_word_dropout(pos, rate=0.12, seed_int=base_seed + 11)
            if pos_wd != clean_text(pos):
                _add(anc, pos_wd, neg, "pos_word_dropout")
                stats["pos_word_dropout"] += 1

        if DO_SHUFFLE_NEG and (random.random() < P_SHUFFLE_NEG):
            shuf = augment_shuffle_sentences(pos, seed_int=base_seed + 20)
            if shuf and shuf != clean_text(pos):
                _add(anc, pos, shuf, "neg_shuffle")
                stats["shuffle_neg"] += 1

        if DO_OUTCOME_SWAP_NEG and (random.random() < P_OUTCOME_SWAP_NEG):
            swapped = outcome_swap_negative(
                pos_text=pos,
                pool_texts=unique_texts,  # TRAIN ONLY
                seed_int=base_seed + 30,
                forbidden=not_negative,
                min_pos_sents=4,
                swap_last_k=2
            )
            if swapped:
                _add(anc, pos, swapped, "neg_outcome_swap")
                stats["outcome_swap_neg"] += 1

        if DO_SPLICED_NEG and (random.random() < P_SPLICED_NEG):
            splice = spliced_negative(pos, unique_texts, seed_int=base_seed + 40)
            if splice:
                _add(anc, pos, splice, "neg_spliced")
                stats["spliced_neg"] += 1

        if use_mining and (random.random() < P_MINED_NEG) and (anc in text_to_idx):
            anc_vec = tfidf_matrix[text_to_idx[anc]]
            scores = cosine_similarity(anc_vec, tfidf_matrix).flatten()
            top_indices = np.argsort(scores)[::-1][:MINED_TOPK]

            mined = None
            for idx in top_indices:
                cand = unique_texts[idx]
                if cand in not_negative:
                    continue
                if scores[idx] < MINED_MIN_SIM:
                    break  # scores are sorted descending; the rest are lower
                mined = cand
                break

            if mined:
                _add(anc, pos, mined, "neg_mined")
                stats["mined"] += 1

    print(f"\n[Augmentation Stats] {stats}")
    return augmented, unique_texts

# ============================================================
# Dataset
# ============================================================

class NarrativeDataset(Dataset):
    """Map-style dataset of (anchor, text_a, text_b, label) items.

    label is a float32 scalar tensor: 1.0 when text_a is the closer text, else 0.0.
    In training mode each access has a 50% chance (global ``random``) of swapping
    text_a/text_b, with the label flipped to stay consistent.
    """

    def __init__(self, rows: List[dict], is_train: bool = False):
        self.rows = rows
        self.is_train = is_train

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        anchor = row["anchor_text"]
        first, second = row["text_a"], row["text_b"]
        a_closer = bool(row["text_a_is_closer"])

        if self.is_train and random.random() < 0.5:
            first, second = second, first
            a_closer = not a_closer

        target = torch.tensor(1.0 if a_closer else 0.0, dtype=torch.float32)
        return anchor, first, second, target

# ============================================================
# Shared-anchor preprocessing for triplets
# ============================================================

def prepare_triplet_views(anc: str, a: str, b: str, mode: str, seed_int: int) -> Tuple[str, str, str]:
    """
    Preprocess and truncate an (anchor, a, b) triplet.

    The single anchor view is shared by both comparisons, which keeps the two
    scores comparable. Preprocessing is deterministic and label-free. The anchor
    uses ANC_WORDS, both candidates use CAND_WORDS, and each text gets its own
    truncation seed offset (101/102/103).
    """
    anc_p, a_p, b_p = (preprocess_text(x) for x in (anc, a, b))

    budgets = (
        (anc_p, ANC_WORDS, 101),
        (a_p, CAND_WORDS, 102),
        (b_p, CAND_WORDS, 103),
    )
    anc_t, a_t, b_t = (truncate_view(txt, limit, mode, seed_int + off) for txt, limit, off in budgets)
    return anc_t, a_t, b_t

def collate_train_strings(batch):
    """
    Collate training items into three aligned lists of preprocessed/truncated strings:
      anchors: [B]
      pos:     [B]  (the closer text, per the item's label)
      neg:     [B]
    Keeping plain strings allows in-batch negatives to be built efficiently.
    Each triplet independently uses "rand" truncation with probability P_RAND_TRUNC_TRAIN.
    """
    anc, a, b, lab = zip(*batch)
    anchors, pos_list, neg_list = [], [], []

    for x_anc, x_a, x_b, y in zip(anc, a, b, lab):
        seed = triplet_seed(x_anc, x_a, x_b)
        view_mode = "rand" if random.random() < P_RAND_TRUNC_TRAIN else "headtail"
        anc_t, a_t, b_t = prepare_triplet_views(x_anc, x_a, x_b, mode=view_mode, seed_int=seed)

        a_is_pos = float(y.item()) == 1.0
        anchors.append(anc_t)
        pos_list.append(a_t if a_is_pos else b_t)
        neg_list.append(b_t if a_is_pos else a_t)

    return anchors, pos_list, neg_list

def collate_eval_pairs(batch):
    """Tokenize each triplet as two cross-encoder pairs: (anchor, a) and (anchor, b).

    Uses deterministic "headtail" views. Returns (tok_a, tok_b, labels), where
    labels is the stacked label tensor.
    """
    anc, a, b, lab = zip(*batch)

    anchors, cands_a, cands_b = [], [], []
    for x_anc, x_a, x_b in zip(anc, a, b):
        anc_t, a_t, b_t = prepare_triplet_views(
            x_anc, x_a, x_b, mode="headtail", seed_int=triplet_seed(x_anc, x_a, x_b)
        )
        anchors.append(anc_t)
        cands_a.append(a_t)
        cands_b.append(b_t)

    def _encode(cands):
        return tokenizer(anchors, cands, padding=True, truncation=True, max_length=MAX_LEN, return_tensors="pt")

    return _encode(cands_a), _encode(cands_b), torch.stack(lab)

def collate_raw(batch):
    """Transpose (anchor, a, b, label) items into three string lists plus a stacked label tensor."""
    anchors, texts_a, texts_b, labels = (list(column) for column in zip(*batch))
    return anchors, texts_a, texts_b, torch.stack(labels)

# ============================================================
# Model
# ============================================================

class PairwiseRanker(torch.nn.Module):
    """Cross-encoder scorer: one scalar logit per tokenized (anchor, candidate) pair.

    Args:
        model_name: checkpoint loaded with ``AutoModel.from_pretrained``.
        freeze_layers: number of lowest encoder layers whose parameters get
            ``requires_grad=False``. Skipped if the backbone exposes no
            ``encoder.layer`` list.
        pooling: "cls" takes the first token's hidden state; any other value
            mean-pools over ``attention_mask``.
    """

    def __init__(self, model_name: str, freeze_layers: int = 0, pooling: str = "mean"):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.pooling = pooling

        encoder = getattr(self.backbone, "encoder", None)
        if hasattr(self.backbone, "deberta"):
            encoder = self.backbone.deberta.encoder

        # Explicit None/attribute checks: ``if encoder`` relied on nn.Module
        # truthiness, and an encoder without ``.layer`` used to raise AttributeError.
        layers = getattr(encoder, "layer", None)
        if layers is not None and freeze_layers > 0:
            for layer in list(layers)[:freeze_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        hidden = self.backbone.config.hidden_size
        self.norm = torch.nn.LayerNorm(hidden)
        self.scorer = torch.nn.Linear(hidden, 1)

    def forward(self, **tok):
        """Return logits of shape [batch] for the tokenized pairs in ``tok``."""
        out = self.backbone(**tok)
        hidden = out.last_hidden_state
        if self.pooling == "cls":
            rep = hidden[:, 0]
        else:
            mask = tok["attention_mask"].unsqueeze(-1).float()
            # clamp avoids 0/0 for an all-padding row
            rep = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.scorer(self.norm(rep)).squeeze(-1)

def get_word_embedding_param(model: torch.nn.Module):
    """Return ``(name, parameter)`` for the trainable word-embedding matrix of ``model``.

    A parameter whose name ends in ``word_embeddings.weight`` is preferred.
    Otherwise the first trainable parameter whose name contains
    ``word_embeddings`` is returned. Frozen parameters are ignored. Returns
    ``(None, None)`` when nothing matches, e.g. when the embeddings are frozen.
    """
    trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    for name, p in trainable:
        if name.endswith("word_embeddings.weight"):
            return name, p
    for name, p in trainable:
        if "word_embeddings" in name:
            return name, p
    return None, None

class FGM:
    """Fast Gradient Method: adversarial perturbation of the word-embedding matrix.

    Presumed usage in the (not fully visible) train loop: backward on the clean
    loss, then ``attack()`` adds ``eps * g / ||g||`` to the embedding weights, then
    forward/backward the adversarial loss, then ``restore()`` puts the original
    weights back before the optimizer step.
    """

    def __init__(self, model: torch.nn.Module, eps: float = 0.5):
        self.model = model
        self.eps = eps
        self.backup = {}
        self.embed_name, self.embed_param = get_word_embedding_param(model)

    def attack(self):
        """Perturb the embeddings along the normalized gradient. Returns True if applied.

        Returns False and leaves the weights untouched when there is no embedding
        parameter, no gradient yet, or the gradient norm is zero or non-finite.
        The old NaN-only check let an inf norm through (e.g. an fp16 overflow
        step), which wrote NaNs into the weights for the adversarial pass.
        """
        if self.embed_param is None or self.embed_param.grad is None:
            return False
        grad = self.embed_param.grad
        norm = torch.norm(grad)
        if not torch.isfinite(norm) or norm == 0:
            return False
        self.backup[self.embed_name] = self.embed_param.data.clone()
        self.embed_param.data.add_(self.eps * grad / (norm + 1e-12))
        return True

    def restore(self):
        """Copy the backed-up embedding weights back in place and clear the backup."""
        if self.embed_name in self.backup:
            self.embed_param.data.copy_(self.backup[self.embed_name])
        self.backup = {}

# ============================================================
# Scoring helpers
# ============================================================

def score_pairs(model: torch.nn.Module, anchors: List[str], cands: List[str]) -> torch.Tensor:
    """Score each aligned (anchors[i], cands[i]) pair with ``model`` on DEVICE.

    For PairwiseRanker the result has shape [len(anchors)].
    """
    encoded = tokenizer(anchors, cands, padding=True, truncation=True, max_length=MAX_LEN, return_tensors="pt")
    return model(**{name: tensor.to(DEVICE) for name, tensor in encoded.items()})

def score_matrix(model: torch.nn.Module, anchors: List[str], cands: List[str]) -> torch.Tensor:
    """
    Compute logits [B, C] with logits[i, j] = score(anchor_i, cand_j).

    A cross-encoder cannot reuse encodings, so all B*C pairs are scored in one
    call (row-major: every candidate for anchor 0, then anchor 1, ...).
    """
    n_anchors, n_cands = len(anchors), len(cands)
    flat_anchors = [anchor for anchor in anchors for _ in range(n_cands)]
    flat_cands = list(cands) * n_anchors
    return score_pairs(model, flat_anchors, flat_cands).view(n_anchors, n_cands)

# ============================================================
# Train-time model-based mining (TRAIN ONLY)
# ============================================================

@torch.no_grad()
def model_mine_hard_negatives(
    model: torch.nn.Module,
    train_rows: List[dict],
    train_pool_texts: List[str],
    add_per_ex: int = 1,
    shortlist_k: int = 60,
    min_sim: float = 0.05,
    score_batch_size: int = 16,
) -> List[dict]:
    """
    Mine one model-scored hard negative per training row (train-only, leakage-safe).

    Leakage-safe:
      - uses only train_rows and train_pool_texts
      - no val/test access
    Strategy:
      - build a TF-IDF shortlist of up to ``shortlist_k`` pool texts per anchor
        (cosine similarity >= ``min_sim``)
      - exclude the anchor, the row's pos/neg, and every known positive of that
        anchor (the closer text of any row sharing it). The highest-scoring
        shortlist texts are otherwise likely to be genuine positives, i.e. false
        negatives.
      - score the shortlist with the current model in chunks of
        ``score_batch_size``, and take the argmax as the hard negative

    The RAW candidate text is stored as ``text_b``. The collate functions apply
    preprocess_text + truncation themselves, so storing the processed view (as
    before) preprocessed and truncated it twice. That made mined negatives
    look different from every other text.

    Returns the new rows (``add_per_ex`` copies per mined row). The model's
    train/eval mode is restored on exit.
    """
    if len(train_pool_texts) < 50:
        return []

    vectorizer = TfidfVectorizer(stop_words="english", min_df=1, max_features=8000)
    pool_mat = vectorizer.fit_transform(train_pool_texts)
    pool_idx = {t: i for i, t in enumerate(train_pool_texts)}

    # Closer text of every row, grouped by anchor: never mined as a negative for it.
    known_positives = {}
    for r in train_rows:
        closer = r["text_a"] if r["text_a_is_closer"] else r["text_b"]
        known_positives.setdefault(r["anchor_text"], set()).add(closer)

    # Memo of (raw text, word budget) -> preprocessed + truncated view. The same
    # pool texts recur across shortlists, and preprocess_text runs spaCy.
    view_cache = {}

    def _view(text: str, max_words: int) -> str:
        key = (text, max_words)
        if key not in view_cache:
            processed = preprocess_text(text)
            view_cache[key] = truncate_view(processed, max_words, "headtail", text_seed(processed))
        return view_cache[key]

    step = max(1, score_batch_size)
    extra = []
    was_training = model.training
    model.eval()
    try:
        for r in tqdm(train_rows, desc="Model-mining (train-only)"):
            anc = r["anchor_text"]
            if r["text_a_is_closer"]:
                pos, neg = r["text_a"], r["text_b"]
            else:
                pos, neg = r["text_b"], r["text_a"]

            if anc not in pool_idx:
                continue

            sims = cosine_similarity(pool_mat[pool_idx[anc]], pool_mat).flatten()
            excluded = {anc, pos, neg} | known_positives.get(anc, set())

            candidates = []
            for idx in np.argsort(sims)[::-1][: max(shortlist_k, 10)]:
                cand = train_pool_texts[idx]
                if not cand or cand in excluded:
                    continue
                if sims[idx] < min_sim:
                    break  # sorted descending; the rest are lower
                candidates.append(cand)
                if len(candidates) >= shortlist_k:
                    break

            if not candidates:
                continue

            anc_t = _view(anc, ANC_WORDS)
            cand_views = [_view(c, CAND_WORDS) for c in candidates]

            # Chunked scoring: a full shortlist of up to MAX_LEN-token pairs in a
            # single forward pass can exhaust GPU memory on a large backbone.
            chunk_scores = []
            for start in range(0, len(cand_views), step):
                chunk = cand_views[start:start + step]
                chunk_scores.append(score_pairs(model, [anc_t] * len(chunk), chunk).float().cpu())
            scores = torch.cat(chunk_scores).numpy()
            hard_neg = candidates[int(np.argmax(scores))]

            for _ in range(add_per_ex):
                extra.append({
                    "anchor_text": anc,
                    "text_a": pos,
                    "text_b": hard_neg,
                    "text_a_is_closer": True,
                    "type": "neg_model_mined",
                })
    finally:
        model.train(was_training)

    return extra

# ============================================================
# Training
# ============================================================

def _train_optimizer_update(optimizer, scaler, scheduler) -> None:
    """Apply one AMP-scaled optimizer step, advance the LR schedule and clear grads."""
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    optimizer.zero_grad(set_to_none=True)


def _train_lr_scheduler(optimizer, plan: dict):
    """Linear warmup then linear decay to 0 (the shape of
    ``get_linear_schedule_with_warmup``).

    ``plan["warmup"]`` and ``plan["total"]`` are re-read on every step, so the
    horizon can be extended mid-run (e.g. after hard-negative mining grows the
    training set) without restarting warmup.
    """
    def lr_lambda(step: int) -> float:
        warmup, total = plan["warmup"], plan["total"]
        if step < warmup:
            return float(step) / float(max(1, warmup))
        return max(0.0, float(total - step) / float(max(1, total - warmup)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


def _train_ranking_losses(logits: torch.Tensor, targets: torch.Tensor):
    """Return ``(ce, pair_loss)`` for an anchor x candidate logits matrix.

    Candidates are ``pos_list + neg_list``, so column ``i`` holds anchor ``i``'s
    positive and column ``i + B`` its own negative (``B = len(targets)``).

    ``ce`` is the in-batch cross-entropy when ``USE_INBATCH_NCE``. Otherwise it is
    a softplus loss on own-positive vs own-negative. ``pair_loss`` is the
    margin-shifted softplus auxiliary loss when ``USE_PAIRWISE_AUX``, else 0.
    """
    batch = targets.size(0)
    rows = torch.arange(batch, device=logits.device)
    own_margin = logits[rows, targets] - logits[rows, rows + batch]

    if USE_INBATCH_NCE:
        ce = F.cross_entropy(logits, targets)
    else:
        # fallback: only own neg
        ce = F.softplus(-own_margin).mean()

    if USE_PAIRWISE_AUX:
        pair_loss = F.softplus(-own_margin + MARGIN).mean()
    else:
        pair_loss = torch.tensor(0.0, device=logits.device)
    return ce, pair_loss


def train(
    model: PairwiseRanker,
    train_rows: List[dict],
    train_pool_texts: List[str],
    val_loader,
):
    """
    Fine-tune ``model`` on ``train_rows`` and keep the best epoch by validation accuracy.

    Each batch yields ``(anchors, pos_list, neg_list)``. ``score_matrix`` scores every
    anchor against all ``2B`` candidates. The loss combines:
      - an in-batch ranking loss;
      - an optional pairwise margin auxiliary loss;
      - an optional R-Drop symmetric KL between two stochastic passes;
      - an optional FGM adversarial pass.
    Gradients are accumulated over ``GRAD_ACC`` batches under AMP (CUDA only).

    If ``USE_MODEL_MINING``, hard negatives from ``model_mine_hard_negatives`` are
    appended at the start of epoch ``MODEL_MINING_EPOCH``.

    Whenever val accuracy improves by more than 0.001, the state_dict is written to
    ``best_model.pt``. Training stops after ``PATIENCE`` epochs without improvement.

    Returns the model in its final-epoch state (the checkpoint is NOT reloaded).
    """
    optimizer = torch.optim.AdamW([
        {
            "params": [p for n, p in model.named_parameters()
                       if not n.startswith("scorer") and p.requires_grad],
            "lr": LR_BACKBONE,
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if n.startswith("scorer") and p.requires_grad],
            "lr": LR_HEAD,
            "weight_decay": 0.0,
        }
    ])
    scaler = torch.amp.GradScaler("cuda", enabled=(DEVICE == "cuda"))

    fgm = FGM(model, eps=FGM_EPS) if USE_FGM else None

    # A single LR schedule spans the whole run. Rebuilding it every epoch would
    # restart warmup from LR 0 each epoch and never let the LR decay.
    lr_plan = {"warmup": 0, "total": 1}
    scheduler = None
    planned_loader_len = None
    global_step = 0

    best_acc, bad_epochs = 0.0, 0

    for epoch in range(1, EPOCHS + 1):
        # optional: model-based mining once, at the start of epoch MODEL_MINING_EPOCH
        if USE_MODEL_MINING and epoch == MODEL_MINING_EPOCH:
            print("\n[Model Mining] Adding train-only hard negatives (no label usage).")
            mined_extra = model_mine_hard_negatives(
                model=model,
                train_rows=train_rows,
                train_pool_texts=train_pool_texts,
                add_per_ex=MODEL_MINING_ADD_PER_EX,
                shortlist_k=MODEL_MINING_SHORTLIST,
            )
            if mined_extra:
                print(f"[Model Mining] Added {len(mined_extra)} extra rows.")
                train_rows = train_rows + mined_extra
            else:
                print("[Model Mining] No rows added.")

        train_loader = DataLoader(
            NarrativeDataset(train_rows, is_train=True),
            batch_size=BATCH_SIZE,
            shuffle=True,
            collate_fn=collate_train_strings,
        )

        updates_per_epoch = (len(train_loader) + GRAD_ACC - 1) // GRAD_ACC
        if scheduler is None:
            lr_plan["total"] = EPOCHS * updates_per_epoch
            lr_plan["warmup"] = int(0.1 * lr_plan["total"])
            scheduler = _train_lr_scheduler(optimizer, lr_plan)
        elif len(train_loader) != planned_loader_len:
            # The training set changed size (model mining). Keep the current position
            # and stretch the decay so it reaches 0 at the end of the last epoch.
            lr_plan["total"] = global_step + (EPOCHS - epoch + 1) * updates_per_epoch
        planned_loader_len = len(train_loader)

        model.train()
        optimizer.zero_grad(set_to_none=True)
        accum = 0
        loss_sum = 0.0

        print("\nStarting Training...")
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")

        for step, (anchors, pos_list, neg_list) in enumerate(pbar, 1):
            B = len(anchors)
            candidates = pos_list + neg_list  # 2B candidates
            targets = torch.arange(B, device=DEVICE)  # pos indices are [0..B-1]

            with torch.amp.autocast("cuda", enabled=(DEVICE == "cuda")):
                # R-Drop: two stochastic passes for logits matrix
                logits1 = score_matrix(model, anchors, candidates)
                logits2 = score_matrix(model, anchors, candidates) if USE_R_DROP else None

                # In-batch CE (InfoNCE-style) or own-neg fallback, plus pairwise aux loss
                ce, pair_loss = _train_ranking_losses(logits1, targets)
                if USE_INBATCH_NCE and logits2 is not None:
                    ce = 0.5 * (ce + F.cross_entropy(logits2, targets))

                # R-Drop symmetric KL on per-anchor candidate distribution
                if logits2 is not None:
                    p1 = F.log_softmax(logits1, dim=1)
                    p2 = F.log_softmax(logits2, dim=1)
                    kl12 = F.kl_div(p1, p2.exp(), reduction="batchmean")
                    kl21 = F.kl_div(p2, p1.exp(), reduction="batchmean")
                    kl = 0.5 * (kl12 + kl21)
                else:
                    kl = torch.tensor(0.0, device=DEVICE)

                loss = ce + PAIRWISE_ALPHA * pair_loss + R_DROP_BETA * kl

            scaler.scale(loss / GRAD_ACC).backward()

            # FGM adversarial step (single extra forward/backward). It uses the same
            # ranking objective as the clean pass, so USE_INBATCH_NCE is respected.
            if fgm is not None and fgm.attack():
                with torch.amp.autocast("cuda", enabled=(DEVICE == "cuda")):
                    logits_adv = score_matrix(model, anchors, candidates)
                    ce_adv, pair_adv = _train_ranking_losses(logits_adv, targets)
                    loss_adv = ce_adv + PAIRWISE_ALPHA * pair_adv
                scaler.scale(loss_adv / GRAD_ACC).backward()
                fgm.restore()

            loss_sum += float(loss.item())
            accum += 1

            if accum == GRAD_ACC:
                _train_optimizer_update(optimizer, scaler, scheduler)
                global_step += 1
                accum = 0

            pbar.set_postfix(loss=loss_sum / step)

        # Flush a trailing partial accumulation group (len(train_loader) not a
        # multiple of GRAD_ACC). Otherwise the zero_grad at the start of the next
        # epoch would discard these gradients. They remain scaled by 1/GRAD_ACC.
        if accum > 0:
            _train_optimizer_update(optimizer, scaler, scheduler)
            global_step += 1

        val_acc = evaluate(model, val_loader)
        print(f"Epoch {epoch} | Val Acc: {val_acc:.4f}")

        if val_acc > best_acc + 0.001:
            best_acc = val_acc
            bad_epochs = 0
            torch.save(model.state_dict(), "best_model.pt")
            print("  [*] New Best Model Saved!")
        else:
            bad_epochs += 1
            if bad_epochs >= PATIENCE:
                print(f"Early stopping. Best: {best_acc:.4f}")
                break

    return model

# ============================================================
# Standard eval
# ============================================================

@torch.no_grad()
def evaluate(model, loader) -> float:
    """
    Pairwise validation accuracy.

    Each batch ``(tok_a, tok_b, lab)`` from ``loader`` provides two dicts of tensors
    (passed to ``model`` as keyword arguments after moving them to DEVICE) and a
    label tensor. An example counts as correct when ``float(s_a > s_b)`` equals its
    label. Returns correct / total, or 0.0 when the loader yields no examples.
    Switches ``model`` to eval mode and leaves it there.
    """
    model.eval()

    def on_device(batch):
        return {key: tensor.to(DEVICE) for key, tensor in batch.items()}

    hits = 0
    seen = 0
    for tok_a, tok_b, lab in loader:
        inputs_a = on_device(tok_a)
        inputs_b = on_device(tok_b)
        labels = lab.to(DEVICE)
        prefers_a = (model(**inputs_a) > model(**inputs_b)).float()
        hits += (prefers_a == labels).sum().item()
        seen += labels.size(0)

    if seen == 0:
        return 0.0
    return hits / seen

# ============================================================
# Salience-guided sentence TTA using DeBERTa scoring
# ============================================================

@torch.no_grad()
def sentence_salience_view(model: PairwiseRanker, anchor: str, cand: str, topk: int,
                           keep_last: int = 2, keep_first: bool = True,
                           max_sents: int = 40) -> str:
    """
    Build a compressed view of ``cand`` from its most anchor-relevant sentences.

    The salience of a sentence is the model score of (anchor, sentence) from
    ``score_pairs``. The view keeps, in original order:
      - the first sentence, if ``keep_first``;
      - the ``topk`` highest-scoring sentences among the rest;
      - the last ``keep_last`` sentences. At least the final sentence is kept, and
        only the final one when the text has <= ``keep_last`` sentences.
    Duplicates after ``clean_text`` are dropped. This is label-free and uses only
    the instance's own inputs.

    Args:
        max_sents: cap on sentences considered. Longer texts keep the first
            ``max_sents // 2`` sentences and fill the rest of the budget from the end.
            A value ``<= 0`` disables the cap.

    Returns:
        The cleaned, space-joined view, or the preprocessed candidate when it
        splits into no sentences.
    """
    anchor_p = preprocess_text(anchor)
    cand_p = preprocess_text(cand)

    sents = split_sentences(cand_p)
    if not sents:
        return cand_p
    if max_sents > 0 and len(sents) > max_sents:
        # Keep early and late sentences to preserve narrative arcs
        head = sents[: max_sents // 2]
        tail = sents[len(sents) - (max_sents - len(head)):]
        sents = head + tail

    # Protected zones. When keep_first is False the opening sentence is not
    # protected, so it competes on salience instead of being silently dropped.
    first_sent = sents[:1] if keep_first else []
    if keep_last > 0 and len(sents) > keep_last:
        last_sents = sents[-keep_last:]
    else:
        last_sents = sents[-1:]
    mid_sents = sents[len(first_sent): len(sents) - len(last_sents)]

    # score mid sentences
    if mid_sents:
        anc_rep = [anchor_p] * len(mid_sents)
        scores = score_pairs(model, anc_rep, mid_sents).detach().cpu().numpy()
        top_idx = np.argsort(scores)[::-1][: min(topk, len(mid_sents))]
        chosen = [mid_sents[i] for i in sorted(top_idx)]  # keep original order
    else:
        chosen = []

    out = first_sent + chosen + last_sents

    # de-dup while preserving order
    seen = set()
    final = []
    for s in out:
        ss = clean_text(s)
        if ss and ss not in seen:
            seen.add(ss)
            final.append(ss)

    return clean_text(" ".join(final))

def build_tta_views_for_triplet(model: PairwiseRanker, anc: str, a: str, b: str) -> List[Tuple[str, str, str]]:
    """
    Return the de-duplicated list of ``(anc_view, a_view, b_view)`` TTA views.

    Every view uses the same "headtail"-truncated anchor. Candidate views:
      - one truncation view per mode in ``TTA_TRUNC_VIEWS`` (cheap);
      - if ``USE_SALIENCE_TTA``: one salience-selected view per ``k`` in
        ``SALIENCE_TOPK_LIST`` (scored by ``model``), plus an outcome-heavy
        view built from the last 4 sentences.
    The list is never empty. If no view is configured, a single "headtail"
    truncation view is returned so callers can ``torch.stack`` per-view scores.
    De-duplication is exact and keeps first-occurrence order.
    """
    anc_p = preprocess_text(anc)
    # Use stable anchor view
    anc_view = truncate_view(anc_p, ANC_WORDS, "headtail", text_seed(anc_p))

    # Candidate preprocessing does not depend on the truncation mode, so do it once.
    a_p = preprocess_text(a)
    b_p = preprocess_text(b)
    a_seed = text_seed(a_p)
    b_seed = text_seed(b_p)

    def trunc_view(mode: str) -> Tuple[str, str, str]:
        return (
            anc_view,
            truncate_view(a_p, CAND_WORDS, mode, a_seed),
            truncate_view(b_p, CAND_WORDS, mode, b_seed),
        )

    # truncation-only views (cheap)
    views = [trunc_view(mode) for mode in TTA_TRUNC_VIEWS]

    if USE_SALIENCE_TTA:
        for k in SALIENCE_TOPK_LIST:
            a_sv = sentence_salience_view(
                model=model, anchor=anc_view, cand=a, topk=k,
                keep_last=SALIENCE_KEEP_LAST, keep_first=SALIENCE_KEEP_FIRST,
                max_sents=SALIENCE_MAX_SENTS,
            )
            b_sv = sentence_salience_view(
                model=model, anchor=anc_view, cand=b, topk=k,
                keep_last=SALIENCE_KEEP_LAST, keep_first=SALIENCE_KEEP_FIRST,
                max_sents=SALIENCE_MAX_SENTS,
            )
            # Ensure final truncation to fit
            a_sv = truncate_view(a_sv, CAND_WORDS, "headtail", text_seed(a_sv))
            b_sv = truncate_view(b_sv, CAND_WORDS, "headtail", text_seed(b_sv))
            views.append((anc_view, a_sv, b_sv))

        # outcome-heavy view
        def last_k_sent_view(x: str, k: int = 4) -> str:
            xp = preprocess_text(x)
            sents = split_sentences(xp)
            tail = sents[-k:] if len(sents) >= k else sents
            return clean_text(" ".join(tail)) if tail else xp

        a_tail = truncate_view(last_k_sent_view(a, 4), CAND_WORDS, "headtail", text_seed(a))
        b_tail = truncate_view(last_k_sent_view(b, 4), CAND_WORDS, "headtail", text_seed(b))
        views.append((anc_view, a_tail, b_tail))

    if not views:
        # Nothing configured: fall back to one default view so downstream
        # aggregation never sees an empty list.
        views.append(trunc_view("headtail"))

    # Exact de-dup on the string triples (order-preserving). This avoids the
    # ambiguity of hashing a "||"-joined string.
    return list(dict.fromkeys(views))

@torch.no_grad()
def diff_for_view(model: PairwiseRanker, anc_view: str, a_view: str, b_view: str) -> torch.Tensor:
    """
    Score margin for a single TTA view: score(anc, a) - score(anc, b).

    Each side is a one-pair ``score_pairs`` call, of which the first element is
    used. The function runs under ``torch.no_grad``, so the result carries no
    autograd graph.
    """
    def first_score(cand_view: str):
        return score_pairs(model, [anc_view], [cand_view])[0]

    margin = first_score(a_view) - first_score(b_view)
    return margin

# ============================================================
# Per-example test-time adaptation (least risky)
# ============================================================

def get_adapt_params(model: PairwiseRanker):
    """
    Collect the parameters eligible for per-example test-time adaptation.

    Returns, in ``named_parameters()`` order, every parameter that currently has
    ``requires_grad=True`` and whose qualified name starts with ``norm.`` or
    ``scorer.``. Parameters nested deeper, such as norms inside the backbone, are
    not included.
    """
    adapt_prefixes = ("norm.", "scorer.")
    return [
        param
        for name, param in model.named_parameters()
        if param.requires_grad and name.startswith(adapt_prefixes)
    ]

def snapshot_adapt_params(model: PairwiseRanker):
    """
    Return ``{name: detached clone}`` for every ``norm.*`` / ``scorer.*`` parameter.

    The ``requires_grad`` flag is ignored here. The clones are independent copies,
    so later in-place updates to the parameters do not affect the snapshot.
    """
    adapt_prefixes = ("norm.", "scorer.")
    return {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if name.startswith(adapt_prefixes)
    }

def restore_adapt_params(model: PairwiseRanker, snap: dict):
    """
    Copy saved tensors from ``snap`` back into the same-named parameters, in place.

    Parameters whose name is not a key of ``snap`` are left untouched. The copy runs
    under ``torch.no_grad`` so it is not recorded by autograd.
    """
    with torch.no_grad():
        saved_params = ((name, param) for name, param in model.named_parameters() if name in snap)
        for name, param in saved_params:
            param.copy_(snap[name])

def predict_with_tta_and_per_example_adapt(model: PairwiseRanker, anc: str, a: str, b: str) -> float:
    """
    Return the aggregated TTA margin ``s(anc, a) - s(anc, b)``. A value > 0 means
    ``a`` is predicted to be closer to the anchor.

    1. Score every view from ``build_tta_views_for_triplet`` without gradients.
    2. Adapt only when ``USE_PER_EX_ADAPT`` is on, there are >= 2 views and the mean
       margin is uncertain (``|mean| < ADAPT_ONLY_IF_ABS_DIFF_LT``). Adaptation runs
       ``ADAPT_STEPS`` Adam steps (lr ``ADAPT_LR``) on the trainable ``norm.*`` /
       ``scorer.*`` parameters. The label-free loss combines:
         - cross-view variance;
         - a pseudo-label loss whose sign comes from the initial mean margin;
         - a drift penalty toward that initial mean.
       The views are then re-scored. The adapted parameters and every parameter's
       ``requires_grad`` flag are restored afterwards.
    Both paths aggregate the per-view margins with ``TTA_AGG`` ("mean", otherwise max).
    """
    model.eval()

    views = build_tta_views_for_triplet(model, anc, a, b)

    def aggregate(diffs: torch.Tensor) -> float:
        return diffs.mean().item() if TTA_AGG == "mean" else diffs.max().item()

    def view_margins() -> torch.Tensor:
        # Same per-view margins as diff_for_view, but without its @torch.no_grad,
        # so the adaptation loss can back-propagate. Grad mode is decided by the
        # caller's context. NOTE(review): assumes score_pairs itself does not
        # disable autograd (its outputs are .detach()-ed elsewhere) — TODO confirm.
        return torch.stack([
            score_pairs(model, [av], [aa])[0] - score_pairs(model, [av], [bb])[0]
            for (av, aa, bb) in views
        ])

    # compute initial diffs across views (no adaptation)
    with torch.no_grad():
        diffs0 = view_margins()
    mean0 = diffs0.mean()

    # optional: only adapt uncertain examples
    if (not USE_PER_EX_ADAPT) or (abs(mean0.item()) >= ADAPT_ONLY_IF_ABS_DIFF_LT) or (len(views) < 2):
        return aggregate(diffs0)

    adapt_params = get_adapt_params(model)
    if not adapt_params:
        # Nothing trainable to adapt (Adam rejects an empty parameter list).
        return aggregate(diffs0)

    # pseudo-label from the sign of the initial mean margin (label-free)
    y_sign = 1.0 if mean0.item() > 0 else -1.0

    # Snapshot values and per-parameter requires_grad flags, so frozen backbone
    # layers stay frozen after adaptation.
    snap = snapshot_adapt_params(model)
    grad_flags = {n: p.requires_grad for n, p in model.named_parameters()}
    opt = torch.optim.Adam(adapt_params, lr=ADAPT_LR)

    try:
        # freeze everything except the adapted (originally trainable) LN + head params
        for n, p in model.named_parameters():
            p.requires_grad_(grad_flags[n] and n.startswith(("norm.", "scorer.")))

        # Callers may run under torch.no_grad (e.g. an evaluation loop), so
        # re-enable autograd explicitly for the adaptation steps.
        with torch.enable_grad():
            for _ in range(ADAPT_STEPS):
                # keep dropout off for stability
                model.eval()

                diffs = view_margins()
                var = diffs.var(unbiased=False)

                # pseudo-label logistic loss encourages consistent sign across views
                pseudo_loss = F.softplus(-y_sign * diffs).mean()

                # prevent drift away from original aggregated prediction
                drift = (diffs.mean() - mean0).pow(2)

                loss = ADAPT_VAR_W * var + ADAPT_PSEUDO_W * pseudo_loss + ADAPT_DRIFT_W * drift

                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()

        # final agg
        model.eval()
        with torch.no_grad():
            agg = aggregate(view_margins())

    finally:
        # restore original parameter values and the exact original requires_grad flags
        restore_adapt_params(model, snap)
        for n, p in model.named_parameters():
            p.requires_grad_(grad_flags[n])
        # drop gradients left on the adapted parameters by the last step
        for p in adapt_params:
            p.grad = None

    return agg

def evaluate_tta_adapt(model: PairwiseRanker, loader_raw) -> float:
    """
    Accuracy of ``predict_with_tta_and_per_example_adapt`` over ``loader_raw``.

    Each batch is ``(anchors, texts_a, texts_b, labels)``. The three text fields
    are iterated in parallel with ``labels.tolist()``. An example is correct when
    ``float(diff > 0)`` equals its label. Returns 0.0 for an empty loader.

    Deliberately NOT wrapped in ``torch.no_grad``. The per-example adaptation calls
    ``loss.backward()``, which needs autograd. The scoring helpers used for plain
    inference (``diff_for_view``, ``sentence_salience_view``) already disable
    gradients themselves.
    """
    correct, total = 0, 0
    for anchors, texts_a, texts_b, labels in tqdm(loader_raw, desc="Eval TTA+Adapt"):
        for x_anc, x_a, x_b, y in zip(anchors, texts_a, texts_b, labels.tolist()):
            diff = predict_with_tta_and_per_example_adapt(model, x_anc, x_a, x_b)
            pred = 1.0 if diff > 0 else 0.0
            correct += int(pred == float(y))
            total += 1
    return correct / total if total else 0.0

# ============================================================
# Main
# ============================================================

def main():
    """
    End-to-end run. The steps are:
      1. Load the dev (train) and sample (val) JSONL splits.
      2. Report exact-duplicate overlap between them.
      3. Harden the training data.
      4. Train, then reload the best checkpoint.
      5. Print validation accuracy with and without salience TTA + per-example
         adaptation.
    Returns early, printing a message, if the training file is missing.
    """
    train_path = "/content/dev_track_a.jsonl"
    val_path   = "/content/sample_track_a.jsonl"
    # Must match the filename train() writes its best checkpoint to.
    ckpt_path  = "best_model.pt"

    if not os.path.exists(train_path):
        print(f"Missing {train_path}")
        return

    train_rows = load_jsonl(train_path)
    val_rows   = load_jsonl(val_path)

    if val_rows:
        report_overlap(train_rows, val_rows, name="val")

    hard_train, train_pool_texts = harden_training_data(train_rows)

    val_loader = DataLoader(
        NarrativeDataset(val_rows, is_train=False),
        batch_size=8,
        shuffle=False,
        collate_fn=collate_eval_pairs,
    )
    val_loader_raw = DataLoader(
        NarrativeDataset(val_rows, is_train=False),
        batch_size=8,
        shuffle=False,
        collate_fn=collate_raw,
    )

    # Remove a checkpoint left over from an earlier run. train() only writes one
    # on improvement, so a run that never improves would otherwise silently
    # reload stale weights below.
    if os.path.exists(ckpt_path):
        os.remove(ckpt_path)

    model = PairwiseRanker(CHECKPOINT, freeze_layers=FREEZE_LAYERS, pooling=POOLING).to(DEVICE)
    train(model, hard_train, train_pool_texts, val_loader)

    if os.path.exists(ckpt_path):
        # The file is a plain state_dict, so the safer weights-only loader suffices.
        model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE, weights_only=True))

    print(f"\nFinal Eval (no TTA): {evaluate(model, val_loader):.4f}")

    # TTA + per-example adaptation
    # (transductive inference — allowed only if your rules allow weight updates at test time)
    print(f"Final Eval (Salience TTA + per-example Adapt): {evaluate_tta_adapt(model, val_loader_raw):.4f}")

# Entry point. Inside a notebook kernel __name__ is "__main__", so executing this
# cell runs the full train + evaluate pipeline (see the captured output below).
if __name__ == "__main__":
    main()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


[NLP] Backend = spacy
[Leakage Check] Exact-duplicate overlap train vs val: 0

[Hardening] Mining & Augmenting Training Set (TRAIN ONLY)...


Augmenting:   0%|          | 0/200 [00:00<?, ?it/s]


[Augmentation Stats] {'orig': 200, 'mined': 138, 'shuffle_neg': 44, 'spliced_neg': 17, 'outcome_swap_neg': 52, 'pos_word_dropout': 104, 'pos_sent_dropout': 103}

Starting Training...


Epoch 1:   0%|          | 0/329 [00:00<?, ?it/s]

Epoch 1 | Val Acc: 0.6410
  [*] New Best Model Saved!

[Model Mining] Adding train-only hard negatives (no label usage).


Model-mining (train-only):   0%|          | 0/658 [00:00<?, ?it/s]

[Model Mining] Added 656 extra rows.

Starting Training...


Epoch 2:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch 2 | Val Acc: 0.6923
  [*] New Best Model Saved!

Starting Training...


Epoch 3:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch 3 | Val Acc: 0.6923

Starting Training...


Epoch 4:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch 4 | Val Acc: 0.6923

Starting Training...


Epoch 5:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch 5 | Val Acc: 0.6667

Starting Training...


Epoch 6:   0%|          | 0/657 [00:00<?, ?it/s]

Epoch 6 | Val Acc: 0.7179
  [*] New Best Model Saved!

Final Eval (no TTA): 0.7179


Eval TTA+Adapt:   0%|          | 0/5 [00:00<?, ?it/s]

Final Eval (Salience TTA + per-example Adapt): 0.7692
