In [1]:
# ============================================================
# ONE-MODEL, GATED-ONLY DEFENSE PIPELINE (build matrices from scratch)
#
# - Builds CLS matrices M_q ∈ R^{K×D} from scratch:
#     Contriever retrieve -> cross-encoder rerank -> take last-layer [CLS]
# - Trains a SINGLE MIL model with RUN labels ONLY (no chunk labels in train/val)
#     SetTransformer encoder + Attention-MIL pooling
# - Defense is GATED:
#     If run_pred==0 => do NOT localize/remove any chunks
#     If run_pred==1 => localize chunks using scores s_i
# - Chunk labels are used ONLY at TEST time for eval-only reporting:
#     gt_rule: ("gpt" in chunk_id) for malicious runs; all-0 for benign runs
#
# Outputs saved to a NEW folder.
# ============================================================

import hashlib
import os, json, random, math
import re
from pathlib import Path
from typing import List, Tuple, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification


# -------------------------
# CONFIG
# -------------------------
SEED = 7

K = 5  # top-k reranked passages per run (= number of rows in each run matrix M_q)
RETRIEVER_MODEL_NAME = "facebook/contriever"

# Cross-encoder used both to rerank retrieved candidates and as the source of the [CLS] features.
RERANKER_MODEL_NAME  = "cross-encoder/ms-marco-MiniLM-L6-v2"

# Filesystem-safe tag for output folder/filenames, DERIVED from RERANKER_MODEL_NAME so the two can
# never drift apart: every run of characters outside [A-Za-z0-9._-] (e.g. "/") becomes "_".
# For the default model this yields "cross-encoder_ms-marco-MiniLM-L6-v2".
RERANKER_TAG = re.sub(r"[^A-Za-z0-9._-]+", "_", RERANKER_MODEL_NAME)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Paths (EDIT if needed) ----
# Benign corpus: chunk records + their precomputed retriever embeddings
# (both loaded together by _load_chunks_and_embs).
CHUNKS_JSON_PATH_BENIGN = "/workspace/newrags/msmarcorag/msmarco_contriever_chunks.json"
EMB_JSON_PATH_BENIGN    = "/workspace/newrags/msmarcorag/msmarco_contriever_embeddings_contriever.json"

# Poisoned corpus: loaded the same way as the benign corpus.
BASE_DIR_MAL = "/workspace/newrags/poisoned_msmarco"
CHUNKS_JSON_PATH_MAL = f"{BASE_DIR_MAL}/msmarco_contriever_chunks_poisoned.json"
EMB_JSON_PATH_MAL    = f"{BASE_DIR_MAL}/msmarco_contriever_embeddings_poisoned2.json"

# JSON list of query items parsed by _parse_queries; every query produces one benign run
# and one poisoned run in build_dataset_npz.
QUERY_LIST_PATH = f"{BASE_DIR_MAL}/successful_poisoned_msmarco.json"

# ---- Output folder (distinct per reranker tag and K, so runs do not overwrite each other) ----
OUT_DIR = Path(f"{BASE_DIR_MAL}/msmarco_MIL_SetTransformer_GATEDONLY_{RERANKER_TAG}_K{K}_v1")
# NOTE: import-time side effect — the output directory is created as soon as this module loads.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Artifact paths; the reranker tag (and K, for the first two) is embedded in the filename
# so artifacts from different configurations stay distinct.
DATASET_NPZ   = OUT_DIR / f"dataset_with_chunkids_{RERANKER_TAG}_K{K}.npz"
BEST_CKPT_PTH = OUT_DIR / f"BEST_MIL_SetTransformer_GATEDONLY_{RERANKER_TAG}_K{K}.pth"
REPORT_JSON   = OUT_DIR / f"test_report_gated_only_{RERANKER_TAG}.json"
CALIB_JSON    = OUT_DIR / f"calibration_{RERANKER_TAG}.json"

# ---- Build matrices from scratch ----
# If BUILD_MATRICES is False, train_and_test() skips building and loads DATASET_NPZ as-is.
BUILD_MATRICES = True
FORCE_REBUILD  = True  # set False to reuse an existing DATASET_NPZ instead of rebuilding it

# ---- Split (by qid, see split_by_qid) ----
TRAIN_SPLIT = 0.70
VAL_SPLIT   = 0.15  # remaining fraction of qids is test

# ---- Training ----
BATCH_SIZE   = 32   # DataLoader uses drop_last=True, so a trailing partial batch is skipped
EPOCHS       = 35
LR           = 2e-4
WEIGHT_DECAY = 1e-4  # AdamW weight decay
GRAD_CLIP    = 1.0   # max global gradient norm (clip_grad_norm_)

# ---- Model ----
D_MODEL  = 256
N_HEADS  = 8
N_LAYERS = 3   # number of SetTransformerBlocks
DROPOUT  = 0.15
EMBED_DIM = 128  # output size of the run embedding used by SupCon

# ---- Loss knobs (RUN labels only) ----
W_DET = 1.0  # weight of run-level BCE

# Positive coverage: hinge penalty on poisoned runs whose mean chunk score is below RHO_COV
# (helps chunk TPR but can raise FPR if too strong)
W_COV   = 0.60
RHO_COV = 0.55

# Benign suppression on y=0 runs: penalizes mean, max and soft-max (log-mean-exp) chunk scores
# (main lever for reducing chunk FPR)
W_NEG_MEAN = 1.20
W_NEG_MAX  = 0.80
W_NEG_LSE  = 0.50  # tail penalty on benign

# Rewards high attention entropy on poisoned runs (discourages attention collapsing onto one chunk)
W_ATTN_ENT = 0.20

# Optional SupCon on the run embedding (still only run labels)
USE_SUPCON = True
W_SUPCON   = 0.50
SUPCON_TAU = 0.2  # temperature

# ---- Thresholds ----
THR_DET = 0.5  # run_pred = 1 when run probability >= THR_DET

# thr_loc is calibrated on benign VAL runs only (no chunk labels), then clamped to avoid extreme thresholds
TARGET_CHUNK_FPR_BENIGN_VAL = 0.10  # desired fraction of benign chunks flagged (before gating)
MIN_THR_LOC = 0.02                 # clamp low
MAX_THR_LOC = 0.98                 # clamp high

# Alternative: flag the top-r chunks of runs predicted poisoned instead of thresholding.
USE_TOPR_INSTEAD_OF_THRESHOLD = False
TOP_R = 2  # number of chunks flagged per run when run_pred==1 (only used if USE_TOPR...=True)


# -------------------------
# Reproducibility
# -------------------------
def seed_everything(seed: int):
    """Seed the Python, NumPy and PyTorch RNGs (plus every CUDA device, when CUDA is available).

    Args:
        seed: integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG once at import time, before any model or DataLoader is created below.
seed_everything(SEED)


# ============================================================
# JSON / data helpers
# ============================================================
def _load_json(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def _parse_queries(obj) -> list:
    if not isinstance(obj, list):
        raise ValueError("QUERY_LIST_PATH must be a JSON list.")
    out = []
    for it in obj:
        if not isinstance(it, dict):
            continue
        qid = it.get("qid", it.get("query_id", it.get("id", None)))
        q   = it.get("query", it.get("question", it.get("text", None)))
        if qid is None or q is None:
            continue
        out.append({"qid": int(qid), "query": str(q)})
    if not out:
        raise ValueError("Could not parse any (qid, query) items from QUERY_LIST_PATH.")
    return out

def _get_chunk_id(chunk: dict) -> str:
    for k in ["chunk_id", "id", "_id", "docid", "passage_id"]:
        if k in chunk:
            return str(chunk[k])
    return str(hash(chunk.get("text", "")))

def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    n = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / (n + eps)

def _pad_or_truncate_rows(mat: np.ndarray, k: int) -> np.ndarray:
    mat = np.asarray(mat, dtype=np.float32)
    if mat.ndim != 2:
        raise ValueError(f"Expected 2D matrix, got {mat.shape}")
    if mat.shape[0] == k:
        return mat
    if mat.shape[0] > k:
        return mat[:k, :]
    pad = np.zeros((k - mat.shape[0], mat.shape[1]), dtype=np.float32)
    return np.vstack([mat, pad])

def _load_chunks_and_embs(chunks_path: str, emb_path: str):
    """Load a chunk corpus and its embeddings, aligned row-for-row, with L2-normalized rows.

    The embeddings JSON may be either:
      * a dict keyed by chunk id (see _get_chunk_id): chunks without an embedding are dropped
        (a warning reports how many), or
      * a list aligned 1:1 with the chunks list.

    Returns:
        (chunks, P): the (possibly filtered) chunk list and a float32 matrix whose i-th row
        is the unit-norm embedding of chunks[i].

    Raises:
        ValueError: unknown embeddings format, list length mismatch, or no usable embeddings
            (an empty or non-2-D matrix would otherwise only fail later, inside dense_search).
    """
    chunks = _load_json(chunks_path)
    embs   = _load_json(emb_path)

    if isinstance(embs, dict):
        aligned_chunks = []
        emb_list = []
        for ch in chunks:
            cid = _get_chunk_id(ch)
            if cid in embs:
                aligned_chunks.append(ch)
                emb_list.append(embs[cid])
        n_dropped = len(chunks) - len(aligned_chunks)
        if n_dropped:
            print(f"[load] WARNING: {n_dropped}/{len(chunks)} chunks from {chunks_path} "
                  f"have no embedding in {emb_path}; they were dropped.")
        chunks = aligned_chunks
        P = np.array(emb_list, dtype=np.float32)
    elif isinstance(embs, list):
        P = np.array(embs, dtype=np.float32)
        if P.shape[0] != len(chunks):
            raise ValueError("Embeddings list length != chunks length.")
    else:
        raise ValueError("Unknown embeddings JSON format.")

    if P.ndim != 2 or P.shape[0] == 0:
        raise ValueError(
            f"Expected a non-empty 2D embedding matrix from {emb_path}, got shape {P.shape}."
        )

    return chunks, _l2_normalize(P)


# ============================================================
# Build matrices from scratch:
#   Contriever(q) -> retrieve topK in pre-embedded corpus -> cross-encoder rerank
#   Extract last-layer [CLS] hidden vector for each (q, p_i)
# ============================================================
@torch.no_grad()
def _mean_pool(last_hidden, mask):
    mask = mask.unsqueeze(-1)
    summed = (last_hidden * mask).sum(dim=1)
    denom = mask.sum(dim=1).clamp(min=1e-6)
    return summed / denom

@torch.no_grad()
def encode_query_contriever(query: str, retr_tok, retr_model) -> np.ndarray:
    """Embed a single query with the retriever.

    Tokenizes (max 256 tokens), mean-pools the last hidden states over the attention mask,
    L2-normalizes, and returns the vector as a 1-D float32 numpy array.
    """
    batch = retr_tok(query, return_tensors="pt", truncation=True, padding=True, max_length=256)
    batch = batch.to(DEVICE)
    hidden_states = retr_model(**batch).last_hidden_state
    pooled = F.normalize(_mean_pool(hidden_states, batch["attention_mask"]), p=2, dim=-1)
    return pooled.squeeze(0).detach().cpu().numpy().astype(np.float32)

def dense_search(q_emb: np.ndarray, P: np.ndarray, top_k: int):
    """Return the indices of the ``top_k`` rows of ``P`` by dot product with ``q_emb``, best first.

    The dot product equals cosine similarity when both sides are L2-normalized. When ``top_k``
    is at least the number of rows, every index is returned in ranked order.
    """
    neg_scores = -(P @ q_emb)
    if top_k < neg_scores.shape[0]:
        # Partial selection of the k best, then sort only those k.
        shortlist = np.argpartition(neg_scores, top_k - 1)[:top_k]
        ranked = shortlist[np.argsort(neg_scores[shortlist])]
    else:
        ranked = np.argsort(neg_scores)
    return ranked.tolist()

@torch.no_grad()
def rerank_and_get_cls_and_ids(query: str, candidate_chunks: list, rr_tok, rr_model) -> Tuple[np.ndarray, List[str]]:
    """Rerank candidates with the cross-encoder and return their [CLS] vectors and ids.

    Each (query, chunk text) pair is scored (max 512 tokens). The last-layer hidden state at
    position 0 of every pair is returned as a float32 array, rows sorted by descending
    relevance logit, together with the chunk ids in that same order.

    NOTE(review): ``logits.squeeze(-1)`` assumes a single-logit head (num_labels == 1) —
    confirm this when switching rerankers.
    """
    texts = [c.get("text","") for c in candidate_chunks]
    enc = rr_tok([[query, t] for t in texts], return_tensors="pt", padding=True,
                 truncation=True, max_length=512).to(DEVICE)
    outputs = rr_model(**enc, output_hidden_states=True)

    relevance = outputs.logits.squeeze(-1)                # (K,)
    cls_vectors = outputs.hidden_states[-1][:, 0, :]      # (K,H)
    ranking = torch.argsort(relevance, descending=True)

    original_ids = [_get_chunk_id(c) for c in candidate_chunks]
    ranked_ids = [original_ids[j] for j in ranking.detach().cpu().tolist()]
    ranked_cls = cls_vectors[ranking].detach().cpu().numpy().astype(np.float32)
    return ranked_cls, ranked_ids

def build_dataset_npz(force_rebuild: bool = False):
    """Build the run-matrix dataset from scratch and save it to DATASET_NPZ.

    For every query in QUERY_LIST_PATH two runs are built: one over the benign corpus
    (label 0) and one over the poisoned corpus (label 1). For each run:
    Contriever query embedding -> dense top-K over the pre-embedded corpus ->
    cross-encoder rerank -> last-layer [CLS] vector per (query, chunk) pair.

    Saved arrays (after class balancing and shuffling with SEED):
        X         float32 (N, K, D)  [CLS] rows in reranked order, zero-padded to K rows
        y         int64   (N,)       run label
        qids      int64   (N,)
        chunk_ids object  (N,)       each a list of K ids aligned with X rows (None = padding row)

    Does nothing if DATASET_NPZ exists and ``force_rebuild`` is False.

    Raises:
        ValueError: if the reranker hidden size cannot be read (and see _load_chunks_and_embs /
            _parse_queries).
    """
    if DATASET_NPZ.exists() and not force_rebuild:
        print(f"[build] dataset exists: {DATASET_NPZ}")
        return

    print(f"[build] Building matrices from scratch | device={DEVICE} | K={K}")
    retr_tok   = AutoTokenizer.from_pretrained(RETRIEVER_MODEL_NAME)
    retr_model = AutoModel.from_pretrained(RETRIEVER_MODEL_NAME).to(DEVICE).eval()

    rr_tok   = AutoTokenizer.from_pretrained(RERANKER_MODEL_NAME)
    rr_model = AutoModelForSequenceClassification.from_pretrained(RERANKER_MODEL_NAME).to(DEVICE).eval()

    D = int(getattr(rr_model.config, "hidden_size", -1))
    if D <= 0:
        raise ValueError("Could not read reranker hidden_size.")
    print(f"[build] Reranker hidden_size D={D}")

    ben_chunks, ben_P = _load_chunks_and_embs(CHUNKS_JSON_PATH_BENIGN, EMB_JSON_PATH_BENIGN)
    mal_chunks, mal_P = _load_chunks_and_embs(CHUNKS_JSON_PATH_MAL,    EMB_JSON_PATH_MAL)

    queries = _parse_queries(_load_json(QUERY_LIST_PATH))
    print(f"[build] queries={len(queries)} (we build 2 runs per query => 2*Q runs)")

    X_list, y_list, qid_list, chunkids_list = [], [], [], []

    def run_side(qid: int, query: str, q_emb: np.ndarray, chunks: list, P: np.ndarray, label: int):
        """Retrieve top-K from one corpus, rerank, and append one run to the accumulators."""
        idxs = dense_search(q_emb, P, top_k=K)
        cand = [chunks[i] for i in idxs]

        cls_mat, reranked_ids = rerank_and_get_cls_and_ids(query, cand, rr_tok, rr_model)  # (<=K, D)
        cls_mat = _pad_or_truncate_rows(cls_mat, K)  # (K, D); zero rows appended if < K candidates

        # Keep ids aligned 1:1 with the K rows: truncate, then pad with None for zero rows.
        ids = list(reranked_ids[:K]) + [None] * max(0, K - len(reranked_ids))

        X_list.append(cls_mat)
        y_list.append(label)
        qid_list.append(qid)
        chunkids_list.append(ids)

    for i, it in enumerate(queries, 1):
        qid, qtext = it["qid"], it["query"]

        # The query embedding does not depend on the corpus, so encode it once for both runs.
        q_emb = encode_query_contriever(qtext, retr_tok, retr_model)

        # Run on benign corpus
        run_side(qid, qtext, q_emb, ben_chunks, ben_P, label=0)

        # Run on malicious corpus
        run_side(qid, qtext, q_emb, mal_chunks, mal_P, label=1)

        if i % 50 == 0:
            print(f"[build] {i}/{len(queries)} queries done")

    X = np.stack(X_list, axis=0).astype(np.float32)      # (N,K,D)
    y = np.array(y_list, dtype=np.int64)                 # (N,)
    qids = np.array(qid_list, dtype=np.int64)            # (N,)

    # Build a genuine 1-D object array of lists. np.array(list_of_equal_length_lists, dtype=object)
    # would instead produce an (N, K) object matrix.
    chunk_ids = np.empty(len(chunkids_list), dtype=object)  # (N,) each is a list of length K
    for j, ids in enumerate(chunkids_list):
        chunk_ids[j] = ids

    # Balance classes for training stability (also shuffles the run order with SEED).
    idx0 = np.where(y==0)[0]
    idx1 = np.where(y==1)[0]
    n = min(len(idx0), len(idx1))
    rng = np.random.default_rng(SEED)
    idx0 = rng.choice(idx0, size=n, replace=False)
    idx1 = rng.choice(idx1, size=n, replace=False)
    idx = np.concatenate([idx0, idx1])
    rng.shuffle(idx)

    X = X[idx]; y = y[idx]; qids = qids[idx]; chunk_ids = chunk_ids[idx]

    np.savez_compressed(DATASET_NPZ, X=X, y=y, qids=qids, chunk_ids=chunk_ids)
    print(f"[build] saved {DATASET_NPZ}")
    print(f"        X={X.shape} y={y.shape} unique_qids={len(np.unique(qids))}")
    print(f"        X={X.shape} y={y.shape} unique_qids={len(np.unique(qids))}")


# ============================================================
# Split by qid (prevents leakage of same query across splits)
# ============================================================
def split_by_qid(qids: np.ndarray, train=0.7, val=0.15, seed=SEED):
    """Split runs into train/val/test boolean masks by query id.

    Unique qids are shuffled with ``seed``. The first ``int(train*n)`` go to train, the next
    ``int(val*n)`` go to val, and the rest go to test. This keeps both runs of a query (benign
    and poisoned) in the same split.

    Returns:
        (tr_mask, va_mask, te_mask), boolean arrays aligned with ``qids``.
    """
    shuffled = np.unique(qids)
    np.random.default_rng(seed).shuffle(shuffled)

    total = len(shuffled)
    cut_train = int(train*total)
    cut_val = cut_train + int(val*total)

    groups = (shuffled[:cut_train], shuffled[cut_train:cut_val], shuffled[cut_val:])
    return tuple(np.isin(qids, group) for group in groups)


# ============================================================
# SupCon loss (run labels only)
# ============================================================
def supervised_contrastive_loss(z: torch.Tensor, y: torch.Tensor, tau: float = 0.2) -> torch.Tensor:
    """Supervised contrastive (SupCon) loss over a batch of embeddings.

    Args:
        z: (B, E) embeddings; the model passes L2-normalized vectors, so z @ z.T is cosine similarity.
        y: (B,) integer labels; samples sharing a label (other than the anchor itself) are positives.
        tau: temperature (clamped to >= 1e-6).

    Returns:
        Scalar loss: for every anchor i with at least one positive,
        -mean_{p in P(i)} log softmax_{a != i}(z_i . z_a / tau)[p], averaged over those anchors.
        Anchors without positives are excluded rather than counted as zero. If no anchor has a
        positive (e.g. B < 2), a zero that stays attached to the graph is returned.
    """
    B = z.size(0)
    sim = (z @ z.t()) / max(tau, 1e-6)

    self_mask = torch.eye(B, dtype=torch.bool, device=z.device)
    # Drop self-similarity from the denominator with a large finite negative (not -inf, avoids NaN).
    sim = sim.masked_fill(self_mask, -1e9)
    # logsumexp is overflow-safe: a plain exp(sim) overflows float32 once similarities exceed ~88
    # (i.e. tau below ~0.011 for unit vectors), turning the loss into NaN.
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)

    y = y.view(-1, 1)
    pos_mask = (y == y.t()) & ~self_mask
    pos_count = pos_mask.sum(dim=1)
    has_pos = pos_count > 0
    if not bool(has_pos.any()):
        return z.sum() * 0.0

    pos_log_prob = torch.where(pos_mask, log_prob, torch.zeros_like(log_prob)).sum(dim=1)
    loss = -pos_log_prob[has_pos] / pos_count[has_pos].to(log_prob.dtype)
    return loss.mean()


# ============================================================
# Model: SetTransformer encoder + per-chunk head + attention MIL pooling
# ============================================================
class SetTransformerBlock(nn.Module):
    """Post-LayerNorm self-attention block over the chunk vectors of a run.

    x -> LN1(x + Dropout(MHA(x, x, x))) -> LN2(h + Dropout(FFN(h))), where FFN is
    Linear(d, 4d) -> ReLU -> Dropout -> Linear(4d, d). No positional encoding and no attention
    mask are applied, so the block is permutation-equivariant over the set dimension, and any
    zero padding rows are attended to like real rows.
    Attribute names (attn, ln1, ff, ln2, drop) define checkpoint state_dict keys; keep them stable.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        hidden_width = 4 * d_model
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_model),
        )
        self.ln2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attended = self.attn(x, x, x, need_weights=False)[0]
        x = self.ln1(x + self.drop(attended))
        return self.ln2(x + self.drop(self.ff(x)))

class AttentionMILPooling(nn.Module):
    """Attention-based MIL pooling (tanh scoring, no gating branch).

    Each instance h_k gets the score w^T tanh(V h_k). Scores are softmax-normalized over the K
    instances of a bag, and the bag vector is the attention-weighted sum of instances.
    forward returns (weights (B, K), bag (B, d_model)).
    """

    def __init__(self, d_model: int, d_attn: int = 128):
        super().__init__()
        self.V = nn.Linear(d_model, d_attn)
        self.w = nn.Linear(d_attn, 1)

    def forward(self, H):
        scores = self.w(torch.tanh(self.V(H)))[..., 0]         # (B,K)
        weights = scores.softmax(dim=1)                         # (B,K)
        bag = (weights.unsqueeze(-1) * H).sum(dim=1)            # (B,d)
        return weights, bag

class MIL_SetTransformer_AttnMIL(nn.Module):
    """Run-level MIL classifier over a bag of chunk vectors.

    Pipeline: per-chunk MLP projection D_in -> d_model, then ``n_layers`` SetTransformerBlocks,
    then three heads on the contextualized chunk states H (B, K, d_model):
      * chunk_head: one logit per chunk (localization score s = sigmoid(logit));
      * attention-MIL pooling -> run_head: run probability y_hat;
      * concat[mean_k H, max_k H] -> emb_head: L2-normalized embedding (used by SupCon).

    forward(M) returns (chunk_logits (B,K), s (B,K), attention (B,K), y_hat (B,), emb (B,embed_dim)).
    Submodule attribute names define checkpoint state_dict keys; keep them stable.
    """

    def __init__(self, D_in: int, d_model: int, n_heads: int, n_layers: int,
                 dropout: float, embed_dim: int):
        super().__init__()
        # Construction order is kept fixed so parameter initialization consumes the RNG identically.
        self.proj = nn.Sequential(
            nn.Linear(D_in, d_model),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model),
        )
        self.blocks = nn.ModuleList(
            SetTransformerBlock(d_model, n_heads, dropout) for _ in range(n_layers)
        )

        self.chunk_head = nn.Linear(d_model, 1)

        self.pool = AttentionMILPooling(d_model, d_attn=128)
        self.run_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
        )

        # Run-level embedding head for SupCon; input is [mean; max] of H, hence 2*d_model.
        widened = 2 * d_model
        self.emb_head = nn.Sequential(
            nn.Linear(widened, widened),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(widened, embed_dim),
        )

    def forward(self, M):
        hidden = self.proj(M)                                # (B,K,d)
        for block in self.blocks:
            hidden = block(hidden)

        chunk_logits = self.chunk_head(hidden).squeeze(-1)   # (B,K)

        attn_weights, bag = self.pool(hidden)
        run_logit = self.run_head(bag).squeeze(-1)           # (B,)

        summary = torch.cat([hidden.mean(dim=1), hidden.max(dim=1).values], dim=1)  # (B,2d)
        emb = F.normalize(self.emb_head(summary), p=2, dim=1)

        return chunk_logits, torch.sigmoid(chunk_logits), attn_weights, torch.sigmoid(run_logit), emb


# ============================================================
# Dataset
# ============================================================
class RunDataset(Dataset):
    """Map-style dataset of (run matrix, run label) pairs.

    Items are (float32 tensor of X[idx], int64 scalar tensor of y[idx]).
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        self.X = X.astype(np.float32)
        self.y = y.astype(np.int64)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = torch.from_numpy(self.X[idx])
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return features, label


# ============================================================
# Metrics + eval-only chunk GT
# ============================================================
def _counts_from_preds(y_true: np.ndarray, y_pred: np.ndarray):
    tp = int(((y_true == 1) & (y_pred == 1)).sum())
    tn = int(((y_true == 0) & (y_pred == 0)).sum())
    fp = int(((y_true == 0) & (y_pred == 1)).sum())
    fn = int(((y_true == 1) & (y_pred == 0)).sum())
    return tp, tn, fp, fn

def _rates(tp, tn, fp, fn):
    eps = 1e-9
    acc  = (tp + tn) / max(tp + tn + fp + fn, 1)
    tpr  = tp / (tp + fn + eps)
    fpr  = fp / (fp + tn + eps)
    prec = tp / (tp + fp + eps)
    return {"acc": acc, "tpr": tpr, "fpr": fpr, "prec": prec}

def _chunk_gt_eval_only(y_run: np.ndarray, chunk_ids_obj: np.ndarray) -> np.ndarray:
    """
    Eval-only chunk GT:
      - benign runs: all 0
      - malicious runs: chunk=1 if "gpt" in chunk_id (case-insensitive)
    """
    N = len(y_run)
    gt = np.zeros((N, K), dtype=np.int64)
    for i in range(N):
        if int(y_run[i]) == 0:
            continue
        ids = chunk_ids_obj[i]
        for j, cid in enumerate(ids):
            if cid is None:
                continue
            if "gpt" in str(cid).lower():
                gt[i, j] = 1
    return gt


# ============================================================
# Calibration for thr_loc (benign val only, no chunk labels)
# ============================================================
@torch.no_grad()
def calibrate_thr_loc_benign_only(model: nn.Module, X_val: np.ndarray, y_val: np.ndarray,
                                  target_chunk_fpr: float) -> float:
    """Pick thr_loc as the (1 - target_chunk_fpr) quantile of chunk scores on benign VAL runs.

    About ``target_chunk_fpr`` of benign chunks then score at or above the threshold (before
    gating). Only run labels are used. Returns 0.5 when the validation set has no benign runs.
    The whole validation set is scored in a single forward pass.
    """
    model.eval()
    inputs = torch.from_numpy(X_val.astype(np.float32)).to(DEVICE)
    labels = y_val.astype(np.int64)

    chunk_scores = model(inputs)[1].detach().cpu().numpy()  # s: (N,K)
    benign_scores = chunk_scores[labels == 0].reshape(-1)
    if benign_scores.size == 0:
        return 0.5

    quantile = max(0.0, min(1.0, 1.0 - float(target_chunk_fpr)))
    return float(np.quantile(benign_scores, quantile))


# ============================================================
# GATED-ONLY evaluation (what you asked for)
# ============================================================
@torch.no_grad()
def evaluate_gated_only(model: nn.Module, X_np: np.ndarray, y_np: np.ndarray, chunk_ids_obj: np.ndarray,
                        thr_det: float, thr_loc: float,
                        use_topr: bool = False, top_r: int = 2) -> Dict[str, Any]:
    """Score a split and report run-level and gated chunk-level metrics.

    Runs with probability >= thr_det are flagged (run_pred = 1). Only flagged runs may have
    chunks localized:
      * top-r mode: the ``top_r`` highest-scoring chunks of each flagged run;
      * threshold mode: chunks with score >= thr_loc (thr_loc must be a float here).
    Unflagged runs get all-zero chunk predictions. Chunk ground truth comes from
    _chunk_gt_eval_only and is used for reporting only.

    Returns:
        {"run_level": {...}, "chunk_level_eval_only_gated": {...}} with counts and rates.
    """
    model.eval()
    inputs = torch.from_numpy(X_np.astype(np.float32)).to(DEVICE)
    labels = y_np.astype(np.int64)

    _, chunk_scores_t, _, run_prob_t, _ = model(inputs)

    run_prob = run_prob_t.detach().cpu().numpy()
    run_pred = (run_prob >= thr_det).astype(np.int64)
    flagged = run_pred == 1

    # ---- run level ----
    tp, tn, fp, fn = _counts_from_preds(labels, run_pred)
    run_report = {
        "counts": {"tp": tp, "tn": tn, "fp": fp, "fn": fn},
        "rates": _rates(tp, tn, fp, fn),
        "thr_det": float(thr_det),
        "pooling": "AttentionMIL(Ilse2018)",
    }

    # ---- chunk level (gated) ----
    scores = chunk_scores_t.detach().cpu().numpy()  # (N,K)
    if use_topr:
        pred_chunk = np.zeros_like(scores, dtype=np.int64)
        for row in np.flatnonzero(flagged):
            pred_chunk[row, np.argsort(-scores[row])[:top_r]] = 1
        loc_mode = f"topr(r={top_r})"
    else:
        pred_chunk = (scores >= thr_loc).astype(np.int64)
        pred_chunk[~flagged] = 0  # GATING: unflagged runs never localize
        loc_mode = f"threshold(thr_loc={thr_loc:.6f})"

    gt_chunk = _chunk_gt_eval_only(labels, chunk_ids_obj)
    tp2, tn2, fp2, fn2 = _counts_from_preds(gt_chunk.reshape(-1), pred_chunk.reshape(-1))
    chunk_report = {
        "counts": {"tp": tp2, "tn": tn2, "fp": fp2, "fn": fn2},
        "rates": _rates(tp2, tn2, fp2, fn2),
        "localization": loc_mode,
        "gating": "Only localize when run_pred==1",
        "gt_rule": "eval-only: ('gpt' in chunk_id) for malicious runs; all-0 for benign runs",
    }

    return {"run_level": run_report, "chunk_level_eval_only_gated": chunk_report}

import torch
from pathlib import Path


# ============================================================
# Train (run-label-only) + test report (gated-only)
# ============================================================
def _run_label_only_loss(chunk_logits: torch.Tensor, s: torch.Tensor, a: torch.Tensor,
                         y_hat: torch.Tensor, emb: torch.Tensor, yb: torch.Tensor,
                         bce: nn.Module) -> torch.Tensor:
    """Return the weighted training loss for one batch, using RUN labels only (no chunk labels).

    Inputs are the model outputs for the batch (chunk_logits / s / attention a: (B,K),
    y_hat: (B,), emb: (B,E)) and run labels yb (B,). Terms, weighted by the W_* knobs:
      det       BCE(y_hat, yb)
      neg_*     benign runs: penalize mean, max and smooth-max (log-mean-exp) chunk scores
      cov       poisoned runs: hinge pushing the mean chunk score up to RHO_COV
      attn_ent  poisoned runs: negative attention entropy (so entropy is maximized)
      supcon    optional supervised contrastive loss on emb
    """
    yb_f = yb.float()
    neg_mask = (1 - yb_f)  # y=0
    pos_mask = yb_f        # y=1

    L_det = bce(y_hat, yb_f)

    mean_s = s.mean(dim=1)                # (B,)
    max_s  = s.max(dim=1).values          # (B,)
    # log-mean-exp of the chunk logits: a smooth maximum over the chunks of each run.
    lse = torch.logsumexp(chunk_logits, dim=1) - math.log(chunk_logits.size(1))  # (B,)

    # Benign suppression (lower chunk FPR)
    L_neg_mean = (mean_s * neg_mask).mean()
    L_neg_max  = (max_s  * neg_mask).mean()
    L_neg_lse  = (torch.sigmoid(lse) * neg_mask).mean()

    # Positive coverage (keep chunk TPR)
    L_cov = (F.relu(RHO_COV - mean_s) * pos_mask).mean()

    # Attention entropy on positives (avoid collapse when many chunks are poison)
    eps = 1e-9
    ent = -(a * (a + eps).log()).sum(dim=1)  # (B,)
    L_attn_ent = ((-ent) * pos_mask).mean()  # maximize entropy on positives

    if USE_SUPCON:
        L_sc = supervised_contrastive_loss(emb, yb, tau=SUPCON_TAU)
    else:
        L_sc = torch.tensor(0.0, device=yb.device)

    return (
        W_DET * L_det
        + W_COV * L_cov
        + W_NEG_MEAN * L_neg_mean
        + W_NEG_MAX  * L_neg_max
        + W_NEG_LSE  * L_neg_lse
        + W_ATTN_ENT * L_attn_ent
        + W_SUPCON * L_sc
    )


def train_and_test():
    """Build/load the dataset, train on run labels, calibrate thr_loc, and write the gated test report.

    Steps:
      1. Optionally (re)build DATASET_NPZ, then load it and split by qid.
      2. Train the SetTransformer + attention-MIL model with run-label-only losses. Model
         selection uses validation run BCE, and the best checkpoint goes to BEST_CKPT_PTH.
         The first epoch always writes a checkpoint, so a stale file from an earlier run is
         never reused.
      3. Calibrate thr_loc on benign validation runs (unless top-r mode) and save CALIB_JSON.
      4. Evaluate on the test split with gating and save REPORT_JSON.

    Returns:
        (model, report): the model holding the best-validation weights and the test report dict.

    Raises:
        ValueError: if the dataset's K differs from the config K, or the training split is too
            small to produce a single batch.
    """
    print(f"[out] OUT_DIR = {OUT_DIR}")

    if BUILD_MATRICES:
        build_dataset_npz(force_rebuild=FORCE_REBUILD)

    # allow_pickle is needed for the object-dtype chunk_ids array. Pickle can execute code, so
    # only load NPZ files produced by this pipeline.
    data = np.load(DATASET_NPZ, allow_pickle=True)
    X = data["X"].astype(np.float32)         # (N,K,D)
    y = data["y"].astype(np.int64)           # (N,)
    qids = data["qids"].astype(np.int64)     # (N,)
    chunk_ids = data["chunk_ids"]            # (N,) object

    _, K_, D = X.shape
    if K_ != K:
        raise ValueError(
            f"Dataset has K={K_} chunks per run but config K={K}; rebuild with FORCE_REBUILD=True."
        )
    print(f"[data] X={X.shape} y={y.shape} unique_qids={len(np.unique(qids))}")

    tr_mask, va_mask, te_mask = split_by_qid(qids, train=TRAIN_SPLIT, val=VAL_SPLIT, seed=SEED)
    X_tr, y_tr = X[tr_mask], y[tr_mask]
    X_va, y_va = X[va_mask], y[va_mask]
    X_te, y_te, ids_te = X[te_mask], y[te_mask], chunk_ids[te_mask]
    print(f"[split] train={len(X_tr)} val={len(X_va)} test={len(X_te)}")

    train_loader = DataLoader(RunDataset(X_tr, y_tr), batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    if len(train_loader) == 0:
        raise ValueError(
            f"Training split has {len(X_tr)} runs < BATCH_SIZE={BATCH_SIZE}; "
            "with drop_last=True no training batch would be produced."
        )

    model = MIL_SetTransformer_AttnMIL(
        D_in=D, d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS,
        dropout=DROPOUT, embed_dim=EMBED_DIM
    ).to(DEVICE)

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    bce = nn.BCELoss()

    # Validation tensors are constant across epochs: move them to the device once.
    Xv = torch.from_numpy(X_va.astype(np.float32)).to(DEVICE)
    yv = torch.from_numpy(y_va.astype(np.float32)).to(DEVICE)

    best_val = float("inf")
    best_state = None

    for ep in range(1, EPOCHS+1):
        model.train()
        tot = 0.0
        nseen = 0

        for xb, yb in train_loader:
            xb = xb.to(DEVICE)
            yb = yb.to(DEVICE)

            opt.zero_grad(set_to_none=True)
            chunk_logits, s, a, y_hat, emb = model(xb)
            loss = _run_label_only_loss(chunk_logits, s, a, y_hat, emb, yb, bce)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            opt.step()

            bs = xb.size(0)
            tot += float(loss.item()) * bs
            nseen += bs

        train_loss = tot / max(nseen, 1)

        # Val model selection uses ONLY run BCE (still no chunk labels)
        model.eval()
        with torch.no_grad():
            _, _, _, yhat_v, _ = model(Xv)
            val_loss = float(bce(yhat_v, yv).item())

        print(f"[ep {ep:02d}] train_loss={train_loss:.6f} val_runBCE={val_loss:.6f}")

        # A NaN val loss never compares as "improved". Always keeping at least the first epoch
        # guarantees BEST_CKPT_PTH is written by THIS run (never a leftover from a previous run).
        improved = math.isfinite(val_loss) and val_loss < best_val - 1e-6
        if improved or best_state is None:
            if math.isfinite(val_loss):
                best_val = val_loss
            best_state = {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}

            torch.save({
                "model_state_dict": best_state,
                "optimizer_state_dict": opt.state_dict(),
                "K": int(K),
                "D_EXPECTED": int(D),
                "embed_dim": int(EMBED_DIM),
                "best_val_loss": float(val_loss),
                "seed": int(SEED),
                "reranker_model_name": str(RERANKER_MODEL_NAME),
                "reranker_tag": str(RERANKER_TAG),
                "note": "GATEDONLY: SetTransformer+AttnMIL. Train/Val use RUN labels only. Chunk GT used only for TEST eval.",
            }, BEST_CKPT_PTH)
            print(f"  -> saved BEST: {BEST_CKPT_PTH}")

    if best_state is None:
        raise RuntimeError("No training epochs were run (EPOCHS < 1); nothing to evaluate.")

    # Restore the best weights from memory; they are identical to what was just saved to disk.
    model.load_state_dict(best_state)
    model.eval()

    # Calibrate thr_loc (benign val only, no chunk labels)
    if not USE_TOPR_INSTEAD_OF_THRESHOLD:
        thr_loc_raw = calibrate_thr_loc_benign_only(
            model, X_va, y_va, target_chunk_fpr=TARGET_CHUNK_FPR_BENIGN_VAL
        )
        thr_loc = float(np.clip(thr_loc_raw, MIN_THR_LOC, MAX_THR_LOC))
    else:
        thr_loc_raw = None
        thr_loc = None

    calib_obj = {
        "thr_det": float(THR_DET),
        "use_topr_instead_of_threshold": bool(USE_TOPR_INSTEAD_OF_THRESHOLD),
        "top_r": int(TOP_R),
        "target_chunk_fpr_benign_val": float(TARGET_CHUNK_FPR_BENIGN_VAL),
        "thr_loc_raw": None if thr_loc_raw is None else float(thr_loc_raw),
        "thr_loc_clamped": None if thr_loc is None else float(thr_loc),
        "thr_loc_clamp_min": float(MIN_THR_LOC),
        "thr_loc_clamp_max": float(MAX_THR_LOC),
        "reranker_model_name": str(RERANKER_MODEL_NAME),
        "reranker_tag": str(RERANKER_TAG),
    }
    with open(CALIB_JSON, "w", encoding="utf-8") as f:
        json.dump(calib_obj, f, indent=2)
    print(f"[calib] saved {CALIB_JSON}")

    # Test report (GATED ONLY)
    if USE_TOPR_INSTEAD_OF_THRESHOLD:
        report = evaluate_gated_only(
            model, X_te, y_te, ids_te,
            thr_det=THR_DET, thr_loc=0.5,
            use_topr=True, top_r=TOP_R
        )
    else:
        report = evaluate_gated_only(
            model, X_te, y_te, ids_te,
            thr_det=THR_DET, thr_loc=thr_loc,
            use_topr=False, top_r=TOP_R
        )

    print("\n=== TEST REPORT (run + chunk eval-only, GATED ONLY) ===")
    print(report)

    with open(REPORT_JSON, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2)
    print(f"[report] saved {REPORT_JSON}")

    return model, report


# ============================================================
# RUN
# ============================================================
if __name__ == "__main__":
    # Script entry point: runs the full pipeline in train_and_test() (dataset build/load,
    # training, thr_loc calibration, gated test report).
    train_and_test()

[out] OUT_DIR = /workspace/newrags/poisoned_msmarco/msmarco_MIL_SetTransformer_GATEDONLY_cross-encoder_ms-marco-MiniLM-L6-v2_K5_v1
[build] Building matrices from scratch | device=cuda | K=5




tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/794 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

[build] Reranker hidden_size D=384
[build] queries=800 (we build 2 runs per query => 2*Q runs)
[build] 50/800 queries done
[build] 100/800 queries done
[build] 150/800 queries done
[build] 200/800 queries done
[build] 250/800 queries done
[build] 300/800 queries done
[build] 350/800 queries done
[build] 400/800 queries done
[build] 450/800 queries done
[build] 500/800 queries done
[build] 550/800 queries done
[build] 600/800 queries done
[build] 650/800 queries done
[build] 700/800 queries done
[build] 750/800 queries done
[build] 800/800 queries done
[build] saved /workspace/newrags/poisoned_msmarco/msmarco_MIL_SetTransformer_GATEDONLY_cross-encoder_ms-marco-MiniLM-L6-v2_K5_v1/dataset_with_chunkids_cross-encoder_ms-marco-MiniLM-L6-v2_K5.npz
        X=(1600, 5, 384) y=(1600,) unique_qids=800
[data] X=(1600, 5, 384) y=(1600,) unique_qids=800
[split] train=1120 val=240 test=240
[ep 01] train_loss=1.605589 val_runBCE=0.123899
  -> saved BEST: /workspace/newrags/poisoned_msmarco/msmarco_MI