In [None]:
!pip install torch transformers pandas scikit-learn tqdm textstat spacy accelerate xgboost

# Install and download all at once
!python -m spacy download en_core_web_sm
!python -m spacy download fr_core_news_sm
!python -m spacy download de_core_news_sm
!python -m spacy download es_core_news_sm

In [None]:
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# ============================================
# Block A — Setup + Load datasets (Drive in, RUN out)
# ============================================

# (optional) mount Drive first in Colab
# from google.colab import drive
# drive.mount('/content/drive')

import os, re, math, random, json, contextlib
from typing import List, Dict, Any, Optional

import numpy as np
import pandas as pd

# Torch / HF
import torch
# cuDNN autotuning is switched off and deterministic-algorithm enforcement is
# explicitly not requested for this run.
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(False)
# Best effort: request "high" float32 matmul precision. Any failure is ignored
# so the notebook keeps running on torch builds where this call is unavailable.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Quiet AMP deprecation warnings with thin wrappers
from torch.amp import autocast as _autocast, GradScaler as _GradScaler
def autocast(enabled=True):
    """Return a CUDA autocast context, or a no-op context.

    The no-op (`contextlib.nullcontext`) is returned when `enabled` is falsy
    or when CUDA is not available, so callers can always write
    `with autocast(...):` regardless of the device.
    """
    use_amp = bool(enabled) and torch.cuda.is_available()
    if not use_amp:
        return contextlib.nullcontext()
    return _autocast(device_type="cuda")
def GradScaler(**kw):
    if torch.cuda.is_available():
        return _GradScaler(device="cuda", **kw)
    return _GradScaler(**kw)

# Device string shared by all later cells: "cuda" when torch can see a GPU, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def seed_everything(seed: int = 13):
    """Seed Python's `random`, NumPy's global RNG and torch with the same value.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# -----------------------------
# Config: READ from Drive; SAVE to /content/emc_run
# -----------------------------
class Config:
    """Central, class-level settings for the notebook.

    Groups the input paths (read from Drive), the artifact paths (written
    under RUN_DIR), and the model / training / few-shot adaptation
    hyper-parameters. All path attributes are computed once, when the class
    body executes, so reassigning DATA_DIR or RUN_DIR afterwards does NOT
    update the derived paths.

    NOTE(review): run_mlm_corpus_training reads Config.MLM_MAX_LEN, which is
    not defined in this class — confirm where it is meant to be set.
    """
    # READ inputs/artifacts FROM Drive (put your CSVs here)
    DATA_DIR = "/content/drive/MyDrive/emc"

    # SAVE all new artifacts/checkpoints TO local runtime (won't sync to Drive)
    RUN_DIR  = "/content/emc_run"

    # Data (read from Drive)
    TRAIN_PATH     = os.path.join(DATA_DIR, "train.csv")
    OOD_DATA_PATH  = os.path.join(DATA_DIR, "ood_dataset.csv")
    TERMS_PATH     = os.path.join(DATA_DIR, "engineering_terms.csv")  # used later by FeatureExtractor

    # Artifacts (save to /content/emc_run)
    XLM_R_MODEL_PATH   = os.path.join(RUN_DIR, "xlmr_only_outputs/pytorch_model.bin")
    SIMPLE_PT_PATH     = os.path.join(RUN_DIR, "simple_fusion_outputs/fusion_simple.pt")
    GATED_PT_PATH      = os.path.join(RUN_DIR, "gated_fusion_outputs/fusion_gated.pt")
    SIMPLE_SCALER_PATH = os.path.join(RUN_DIR, "simple_fusion_outputs/scaler12.pkl")
    GATED_SCALER_PATH  = os.path.join(RUN_DIR, "gated_fusion_outputs/scaler12.pkl")
    RESULTS_CSV_PATH   = os.path.join(RUN_DIR, "domain_adaptation_results.csv")
    DAPT_SAVE_DIR      = os.path.join(RUN_DIR, "dapt_mlm")

    # Model settings (will be reused later)
    MODEL_NAME = "xlm-roberta-base"
    MAX_LEN = 256      # tokenizer max_length used by TextFeatDataset
    FEAT_DIM = 12      # size of the hand-crafted feature vector (FeatureExtractor12)
    NUM_LABELS = 2  # will be overwritten after we infer

    # Training/eval defaults (used later)
    VAL_SPLIT = 0.1    # validation fraction passed to train_test_split
    BATCH_SIZE = 32
    NUM_WORKERS = 2    # DataLoader worker processes
    PIN_MEMORY = (DEVICE == "cuda")  # evaluated once, when this class body runs

    # Few-shot DA
    DA_ADAPT_RATIO = 0.4
    DA_EPOCHS = 4
    DA_LR = 1.5e-5
    DA_UNFREEZE_TOP_N = 16

# Make sure local run dirs exist
# Create RUN_DIR itself first, then the parent directory of every artifact file
# plus the DAPT output directory. exist_ok=True makes the cell safe to re-run.
os.makedirs(Config.RUN_DIR, exist_ok=True)
for p in [
    os.path.dirname(Config.XLM_R_MODEL_PATH),
    os.path.dirname(Config.SIMPLE_PT_PATH),
    os.path.dirname(Config.GATED_PT_PATH),
    os.path.dirname(Config.SIMPLE_SCALER_PATH),
    os.path.dirname(Config.GATED_SCALER_PATH),
    Config.DAPT_SAVE_DIR,
]:
    os.makedirs(p, exist_ok=True)

# -----------------------------
# Robust data loading helpers
# -----------------------------
def _require_file(path: str, name: str):
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Missing {name} at: {path}")

def _clean_text_series(s: pd.Series) -> pd.Series:
    # ensure str; strip; replace NaN with ""
    s = s.astype(str).fillna("").map(lambda x: x.strip())
    # collapse super long whitespace
    s = s.map(lambda x: re.sub(r"\s+", " ", x))
    return s

def _ensure_lang_col(df: pd.DataFrame) -> pd.DataFrame:
    if 'lang' not in df.columns:
        df = df.copy()
        df['lang'] = 'en'
    return df

def _drop_empty_rows(df: pd.DataFrame, text_col="content") -> pd.DataFrame:
    before = len(df)
    df = df[df[text_col].astype(str).str.strip().str.len() > 0].copy()
    after = len(df)
    if after < before:
        print(f"• Dropped {before - after} empty-text rows")
    return df

def build_global_label_mapping(train_df: pd.DataFrame, ood_df: pd.DataFrame) -> Dict[Any, int]:
    """Map every label seen in train + OOD to a contiguous index 0..K-1.

    Integer labels that already form exactly 0..K-1 are kept (identity
    mapping). Anything else is ordered by its string representation and
    numbered in that order, so the result is deterministic for a given
    label set.

    Returns a dict: original_label -> index.
    """
    pooled = pd.concat([train_df['label'], ood_df['label']], axis=0)
    distinct = pooled.unique().tolist()

    if pd.api.types.is_integer_dtype(pooled):
        ordered = sorted(distinct)
        # sorted unique ints starting at 0 and ending at K-1 are necessarily contiguous
        if ordered and ordered[0] == 0 and ordered[-1] == len(ordered) - 1:
            return {lab: lab for lab in ordered}

    by_text = sorted(distinct, key=str)
    return dict(zip(by_text, range(len(by_text))))

def apply_label_mapping(df: pd.DataFrame, mapping: Dict[Any, int]) -> pd.DataFrame:
    """Return a copy of `df` whose 'label' column is replaced by `mapping[label]`.

    Raises:
        ValueError: if any label has no entry in `mapping`. The message lists
            up to five distinct ORIGINAL label values that could not be mapped.
    """
    df = df.copy()
    mapped = df['label'].map(mapping)
    unmapped = mapped.isna()
    if unmapped.any():
        # Read the offenders from the untouched original column; the mapped
        # column only holds NaN for them, which is useless for debugging.
        offenders = df.loc[unmapped, 'label'].drop_duplicates().tolist()[:5]
        raise ValueError(f"Found labels not in mapping: {offenders} ...")
    df['label'] = mapped.astype(int)
    return df

def summarize_dataset(train_df: pd.DataFrame, ood_df: pd.DataFrame):
    """Print a short overview of the loaded data.

    Shows the active device, the row counts, the train label distribution
    (sorted by label) and, when the OOD frame has a 'domain' column, the
    number of rows per domain.
    """
    header_lines = (
        "\n=== Dataset summary ===",
        f"device: {DEVICE}",
        f"train: {len(train_df):,} rows | ood: {len(ood_df):,} rows",
    )
    for line in header_lines:
        print(line)

    label_counts = train_df['label'].value_counts().sort_index().to_dict()
    print("train label distribution:", label_counts)

    if 'domain' in ood_df.columns:
        print("ood domains:", ood_df['domain'].value_counts().to_dict())
    print("=======================")

# -----------------------------
# Load, clean, map labels
# -----------------------------
# Fix all RNG seeds before anything stochastic happens in this notebook.
seed_everything(13)
print(f"Using device: {DEVICE}")

# require CSVs on Drive (fail fast with a clear message instead of a pandas error)
_require_file(Config.TRAIN_PATH, "train.csv")
_require_file(Config.OOD_DATA_PATH, "ood_dataset.csv")
_require_file(Config.TERMS_PATH, "engineering_terms.csv")  # used later; just verifying now

# read
train_df = pd.read_csv(Config.TRAIN_PATH)
ood_df   = pd.read_csv(Config.OOD_DATA_PATH)

# basic column checks: train needs content+label, OOD additionally needs domain
for c in ['content', 'label']:
    if c not in train_df.columns:
        raise KeyError(f"train.csv missing required column: {c}")
for c in ['content', 'label', 'domain']:
    if c not in ood_df.columns:
        raise KeyError(f"ood_dataset.csv missing required column: {c}")

# clean text; ensure lang; drop empty rows
train_df = train_df.copy()
train_df['content'] = _clean_text_series(train_df['content'])
train_df = _ensure_lang_col(train_df)
train_df = _drop_empty_rows(train_df, 'content')

ood_df = ood_df.copy()
ood_df['content'] = _clean_text_series(ood_df['content'])
ood_df = _ensure_lang_col(ood_df)
ood_df = _drop_empty_rows(ood_df, 'content')

# build & apply a GLOBAL label mapping  (IMPORTANT for consistency)
# The mapping is also written to RUN_DIR/label_mapping.json; keys are stringified
# because JSON object keys must be strings.
label_map = build_global_label_mapping(train_df, ood_df)
with open(os.path.join(Config.RUN_DIR, "label_mapping.json"), "w") as f:
    json.dump({str(k): int(v) for k, v in label_map.items()}, f, indent=2)

train_df = apply_label_mapping(train_df, label_map)
ood_df   = apply_label_mapping(ood_df,   label_map)

# infer NUM_LABELS from mapped data and store into Config
# (number of distinct mapped labels across train and OOD together)
Config.NUM_LABELS = int(pd.unique(pd.concat([train_df['label'], ood_df['label']], axis=0)).shape[0])
print("Detected NUM_LABELS =", Config.NUM_LABELS)

# (optional) persist cleaned copies used for all runs (for reproducibility)
train_clean_path = os.path.join(Config.RUN_DIR, "train_clean.csv")
ood_clean_path   = os.path.join(Config.RUN_DIR, "ood_clean.csv")
train_df.to_csv(train_clean_path, index=False)
ood_df.to_csv(ood_clean_path, index=False)
print(f"Saved cleaned datasets to:\n  {train_clean_path}\n  {ood_clean_path}")

# quick summary
summarize_dataset(train_df, ood_df)

# NOTE:
# - Next cells (training, DAPT loading, fusion, adaptation, etc.) should
#   reference paths from Config.* and use the already-loaded train_df/ood_df.
# - All artifacts will be saved to /content/emc_run (RUN_DIR), keeping Drive clean.


In [None]:
# ============================================
# Block B — Tokenizer, features, datasets, fusion models & utils
# ============================================

from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tqdm import tqdm
import joblib

# --- Tokenizer
# Single shared tokenizer for Config.MODEL_NAME; the dataset classes below receive it as an argument.
tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_NAME)

# --- Terms lexicon & 12D feature extractor -----------------
import re

_WORD_RE = re.compile(r"\w+", re.UNICODE)


def simple_words(t: str):
    """Return all runs of word characters (\\w+) found in `t`, in order.

    Falsy input (None, "") yields an empty list.
    """
    if not t:
        return []
    return _WORD_RE.findall(t)

def sent_count(t: str) -> int:
    """Rough sentence count: number of pieces after splitting on [.!?]+ followed by whitespace.

    Empty or None text counts as 0; any non-empty text counts as at least 1.
    """
    if not t:
        return 0
    pieces = re.split(r'[.!?]+[\s\n]+', t)
    n_pieces = len(pieces)
    return n_pieces if n_pieces > 1 else 1

def punct_count(t: str) -> int:
    """Count occurrences of the punctuation marks . , ; : ! ? in `t` (None counts as empty)."""
    text = t or ""
    return sum(text.count(mark) for mark in ".,;:!?")

class TermsLexicon:
    """Per-language set of lower-cased domain terms loaded from a CSV file.

    `by_lang` maps a lower-cased language code to the set of stripped,
    lower-cased terms listed for that language.
    """

    def __init__(self, csv_path: str, term_col="terms", lang_col="lang"):
        """Load the lexicon.

        Args:
            csv_path: CSV file with one term per row.
            term_col: column holding the term text.
            lang_col: column holding the language code; when the column is
                absent every term is filed under 'en'.

        Raises:
            FileNotFoundError: `csv_path` does not exist.
            ValueError: `term_col` is not a column of the CSV.
        """
        import pandas as pd, os
        if not os.path.exists(csv_path): raise FileNotFoundError(csv_path)
        df = pd.read_csv(csv_path)
        if term_col not in df.columns: raise ValueError(f"Missing '{term_col}'")
        if lang_col not in df.columns: df[lang_col] = 'en'

        # Language codes that differ only by case ("EN" / "en") arrive as
        # separate groups; merge them into one bucket instead of letting the
        # later group overwrite the earlier one.
        by_lang = {}
        for raw_lang, group in df.groupby(lang_col):
            bucket = by_lang.setdefault(str(raw_lang).lower(), set())
            for raw_term in group[term_col].dropna().tolist():
                term = str(raw_term).strip().lower()
                if term:
                    bucket.add(term)
        self.by_lang = by_lang

    def pct_in_text(self, text: str, lang: str) -> float:
        """Fraction of word tokens in `text` that are lexicon terms for `lang`.

        Matching is per single token, so lexicon entries containing spaces can
        never match. Returns 0.0 for empty text, an unknown language or a text
        without word tokens. A missing or non-string `lang` (e.g. NaN from
        pandas) falls back to 'en'.
        """
        if not text: return 0.0
        lang_key = lang.lower() if isinstance(lang, str) and lang else "en"
        terms = self.by_lang.get(lang_key, set())
        if not terms: return 0.0
        ws = [w.lower() for w in simple_words(text)]
        if not ws: return 0.0
        return sum(1 for w in ws if w in terms) / len(ws)

import numpy as np

_NUM_RE = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?")
def extract_numbers(text: str):
    """Find numeric literals in `text`.

    Returns:
        (nums, dec): `nums` is the list of absolute values of every match, in
        order of appearance; `dec` counts the matches written with a decimal
        point or an exponent.
    """
    nums, dec = [], 0
    for m in _NUM_RE.finditer(text or ""):
        s = m.group(0)
        try:
            v = float(s)
        except ValueError:
            # Only a failed conversion is skipped; anything else (e.g.
            # KeyboardInterrupt) must propagate rather than be swallowed.
            continue
        if ('.' in s) or ('e' in s.lower()): dec += 1
        nums.append(abs(v))
    return nums, dec

def _finite_or_zero(x: float) -> float:
    return float(x) if np.isfinite(x) else 0.0

# Optional textstat for English readability.
# When the import fails for any reason, _HAS_TEXTSTAT is False and
# FeatureExtractor12 reports 0.0 for both readability scores.
try:
    import textstat
    _HAS_TEXTSTAT = True
except Exception:
    _HAS_TEXTSTAT = False

# Lower-case keyword sets that FeatureExtractor12 looks for in the lower-cased
# text to set its has_std / has_saf flags.
STD_TERMS = {"iso","asme","ieee","din","ansi","iec","ul","astm","en"}
SAFETY_TERMS = {"safety","hazard","warning","risk","caution","danger","emergency"}

class FeatureExtractor12:
    """Hand-crafted 12-dimensional surface-feature extractor for one text.

    Feature order (float32):
        0 chars, 1 words, 2 sentences, 3 Flesch reading ease, 4 Gunning fog,
        5 share of lexicon terms, 6 punctuation count, 7 number count,
        8 has_std flag, 9 has_saf flag, 10 mean absolute number magnitude,
        11 share of numbers written with a decimal point or exponent.

    Counts and magnitudes are clipped to the MAX_* class constants. The two
    readability scores are only computed for English text when textstat is
    importable; otherwise they are 0.0.
    """
    N_FEATS = 12
    MAX_CHARS, MAX_WORDS, MAX_SENTS, MAX_PUNCT, MAX_NUM_COUNT = 200_000, 40_000, 10_000, 50_000, 10_000
    MAX_NUM_ABS, MAX_AVG_MAG = 1e12, 1e12

    def __init__(self, tlex: TermsLexicon):
        self.tlex = tlex
        # Standards acronyms must match as a unit, not as a substring:
        # a plain `"en" in text` test fires on "then", "ul" on "would",
        # "iso" on "isolated", making the flag almost always 1. The pattern
        # requires that the acronym is neither preceded nor followed by a
        # letter, so "ISO 9001", "ISO9001" and "EN-1090" still match.
        alternation = "|".join(sorted(re.escape(t) for t in STD_TERMS))
        self._std_re = re.compile(rf"(?<![^\W\d_])(?:{alternation})(?![^\W\d_])")

    def extract_one(self, text: str, lang: str) -> np.ndarray:
        """Return the 12 features for one text as a float32 vector.

        `text` of None is treated as "". A missing or non-string `lang`
        (pandas yields float NaN for an empty cell) falls back to 'en'.
        """
        text = "" if text is None else str(text)
        # NaN is truthy and has no .lower(), so `(lang or "en").lower()` would crash on it
        lang = lang.lower() if isinstance(lang, str) and lang else "en"
        ws = simple_words(text); n_words = len(ws)
        chars = min(len(text), self.MAX_CHARS)
        words = min(n_words, self.MAX_WORDS)
        sents = min(sent_count(text), self.MAX_SENTS)

        if lang == "en" and _HAS_TEXTSTAT and text.strip():
            fre = _finite_or_zero(textstat.flesch_reading_ease(text))
            fog = _finite_or_zero(textstat.gunning_fog(text))
        else:
            fre = fog = 0.0

        eng_pct = self.tlex.pct_in_text(text, lang)
        punc = min(punct_count(text), self.MAX_PUNCT)

        nums, dec_cnt = extract_numbers(text)
        nnums = min(len(nums), self.MAX_NUM_COUNT)
        # clip each magnitude before averaging so one huge value cannot dominate
        avg_mag = min(float(np.mean([min(v, self.MAX_NUM_ABS) for v in nums])) if nums else 0.0, self.MAX_AVG_MAG)
        dec_ratio = float(dec_cnt / len(nums)) if nums else 0.0

        low = text.lower()
        has_std = 1.0 if self._std_re.search(low) else 0.0
        # Safety keywords stay substring-based on purpose so that inflected
        # forms ("hazardous", "warnings", "risks") are still detected.
        has_saf = 1.0 if any(t in low for t in SAFETY_TERMS) else 0.0

        feats = np.array([chars, words, sents, fre, fog, eng_pct, punc, nnums, has_std, has_saf, avg_mag, dec_ratio], dtype=np.float32)
        if not np.all(np.isfinite(feats)):
            feats = np.nan_to_num(feats, nan=0.0, posinf=self.MAX_AVG_MAG, neginf=0.0)
        return feats.astype(np.float32)

    def extract_df(self, df: pd.DataFrame) -> np.ndarray:
        """Extract features for every row of `df`; returns an array of shape (len(df), 12).

        Requires a 'content' column; a missing 'lang' column defaults to 'en'.
        An empty frame yields an empty (0, 12) array.
        """
        assert 'content' in df.columns
        if 'lang' not in df.columns:
            df = df.copy(); df['lang'] = 'en'
        if len(df) == 0:
            # np.stack raises on an empty list
            return np.zeros((0, self.N_FEATS), dtype=np.float32)
        pairs = zip(df['content'].tolist(), df['lang'].tolist())
        rows = [self.extract_one(text, lang) for text, lang in tqdm(pairs, total=len(df), desc="Extracting features")]
        return np.stack(rows, axis=0).astype(np.float32)

# Notebook-wide singletons: the terms lexicon read from Config.TERMS_PATH and
# the feature extractor that uses it.
tlex = TermsLexicon(Config.TERMS_PATH)
fe = FeatureExtractor12(tlex)

# --- Datasets ----------------------------------------------
class TextFeatDataset(Dataset):
    """Torch dataset over a frame with a 'content' column.

    Optionally carries a row-aligned feature matrix (`feats`) and, when the
    frame has a 'label' column, integer labels. Rows are addressed by
    position: the frame's index is reset on construction.
    """

    def __init__(self, df: pd.DataFrame, tokenizer, feats: Optional[np.ndarray] = None):
        self.df = df.reset_index(drop=True)
        self.tok = tokenizer
        self.feats = feats
        has_labels = 'label' in df.columns
        self.labels = df['label'].values.astype(int) if has_labels else None

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Tokenise row `idx` (padded/truncated to Config.MAX_LEN) and attach feats/labels if present."""
        text = str(self.df.iloc[idx]['content'])
        encoded = self.tok(
            text,
            truncation=True,
            max_length=Config.MAX_LEN,
            padding="max_length",
            return_tensors="pt"
        )
        # the tokenizer returns a batch of one; drop that leading dimension
        sample = {key: encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        if self.feats is not None:
            sample['feats'] = torch.tensor(self.feats[idx], dtype=torch.float32)
        if self.labels is not None:
            sample['labels'] = torch.tensor(int(self.labels[idx]), dtype=torch.long)
        return sample

def make_loader(dataset, batch_size, sampler=None, shuffle=False):
    """Build a DataLoader using the worker / pin-memory settings from Config.

    `shuffle` is honoured only when no `sampler` is given; with a sampler it
    is forced to False. The last partial batch is kept.
    """
    use_shuffle = shuffle if sampler is None else False
    loader_kwargs = dict(
        batch_size=batch_size,
        sampler=sampler,
        shuffle=use_shuffle,
        num_workers=Config.NUM_WORKERS,
        pin_memory=Config.PIN_MEMORY,
        drop_last=False,
    )
    return DataLoader(dataset, **loader_kwargs)

# --- Fusion models ------------------------------------------
def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average the hidden states over dim 1, counting only positions where the mask is non-zero.

    The mask is broadcast over the last dimension. The divisor is clamped to
    1e-6, so a fully masked row yields zeros instead of NaN.
    """
    weights = attention_mask.unsqueeze(-1).float()
    summed = (last_hidden_state * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-6)
    return summed / counts

class SimpleFusion(nn.Module):
    """Encoder + hand-crafted features, fused by plain concatenation.

    The mean-pooled encoder output is concatenated with the feature vector,
    passed through dropout and classified by one linear layer. Attribute
    names (`encoder`, `dropout`, `classifier`) are part of the checkpoint
    format and must not change.
    """

    def __init__(self, model_name: str, n_feats: int, n_labels: int = None):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # best effort: not every encoder supports gradient checkpointing
        try:
            self.encoder.gradient_checkpointing_enable()
        except Exception:
            pass
        hidden_dim = self.encoder.config.hidden_size
        out_dim = n_labels or Config.NUM_LABELS  # falls back to the inferred label count
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_dim + n_feats, out_dim)

    def forward(self, input_ids, attention_mask, feats):
        """Return classification logits for a batch."""
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        text_vec = masked_mean_pool(hidden, attention_mask)
        joint = self.dropout(torch.cat([text_vec, feats], dim=1))
        return self.classifier(joint)

class GatedFusion(nn.Module):
    """Encoder + hand-crafted features, fused through a learned scalar gate.

    The features are projected to `feat_proj` dimensions; a sigmoid gate,
    computed from the pooled text vector and the raw features, scales that
    projection before it is concatenated with the pooled text vector and
    classified. Attribute names are part of the checkpoint format and must
    not change.
    """

    def __init__(self, model_name: str, n_feats: int, n_labels: int = None, feat_proj: int = 64):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # best effort: not every encoder supports gradient checkpointing
        try:
            self.encoder.gradient_checkpointing_enable()
        except Exception:
            pass
        hidden_dim = self.encoder.config.hidden_size
        self.fe_proj = nn.Sequential(nn.Linear(n_feats, feat_proj), nn.ReLU())
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim + n_feats, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )
        self.dropout = nn.Dropout(0.1)
        out_dim = n_labels or Config.NUM_LABELS  # falls back to the inferred label count
        self.classifier = nn.Linear(hidden_dim + feat_proj, out_dim)

    def forward(self, input_ids, attention_mask, feats):
        """Return classification logits for a batch."""
        hidden = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, return_dict=True
        ).last_hidden_state
        text_vec = masked_mean_pool(hidden, attention_mask)
        gate_value = self.gate(torch.cat([text_vec, feats], dim=1))  # one value in (0, 1) per sample
        projected = self.fe_proj(feats)
        joint = torch.cat([text_vec, gate_value * projected], dim=1)
        return self.classifier(self.dropout(joint))

def unfreeze_top_n(model: nn.Module, top_n: int):
    """Make only the top `top_n` transformer layers (plus the classifier) trainable.

    The encoder is looked up under `base_model`, `roberta`, `xlm_roberta` or
    `encoder` (the last one covers the fusion models). Layers below the top
    `top_n` are frozen. Embeddings and any other sub-modules are left
    untouched. If no encoder attribute exists the function returns without
    changing anything.
    """
    enc = None
    for attr in ("base_model", "roberta", "xlm_roberta", "encoder"):
        enc = getattr(model, attr, None)
        if enc:
            break
    if enc is None:
        return

    if hasattr(enc, "encoder") and hasattr(enc.encoder, "layer"):
        blocks = enc.encoder.layer
        first_trainable = len(blocks) - top_n
        for depth, block in enumerate(blocks):
            block.requires_grad_(depth >= first_trainable)

    # always keep classifier trainable
    if hasattr(model, "classifier"):
        model.classifier.requires_grad_(True)

print("Block B ready ✔️")


In [None]:
# ============================================
# Block C — DAPT training + DAPT/Head loaders (robust copy)
# ============================================

from transformers import (
    AutoModelForMaskedLM, AutoModelForSequenceClassification,
    DataCollatorForLanguageModeling, TrainingArguments, Trainer
)

# --- Unlabeled dataset for MLM
class UnlabeledTextDataset(Dataset):
    """Label-free text dataset for masked-language-model training.

    Each item is the tokenised text, padded/truncated to `max_len`, as a dict
    with 'input_ids' and 'attention_mask'.
    """

    def __init__(self, texts: List[str], tokenizer, max_len: int):
        self.texts = texts
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tok(
            str(self.texts[idx]),
            truncation=True,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt"
        )
        # the tokenizer returns a batch of one; drop that leading dimension
        return {name: encoded[name].squeeze(0) for name in ('input_ids', 'attention_mask')}

def run_mlm_corpus_training(
    texts: List[str],
    save_dir: str,
    epochs: int,
    batch: int,
    lr: float,
    warmup_ratio: float,
    max_steps: Optional[int],
    tokenizer,
):
    """Domain-adaptive pre-training: continue MLM training of Config.MODEL_NAME on `texts`.

    Args:
        texts: raw, unlabeled texts; an empty list skips training.
        save_dir: output directory; the final weights are written to
            `<save_dir>/pytorch_model.bin` as a plain state_dict.
        epochs, batch, lr, warmup_ratio: forwarded to TrainingArguments.
        max_steps: optional cap on optimisation steps. None or a non-positive
            value means "no cap" (train for `epochs`).
        tokenizer: tokenizer matching Config.MODEL_NAME.

    Returns:
        None.
    """
    if len(texts) == 0:
        print("⚠️ DAPT skipped: no texts")
        return
    os.makedirs(save_dir, exist_ok=True)
    model = AutoModelForMaskedLM.from_pretrained(Config.MODEL_NAME)
    model.to(DEVICE)
    # ✅ use shorter seq length for MLM to be memory-safe
    # Config.MLM_MAX_LEN is optional: fall back to Config.MAX_LEN when it has
    # not been set, instead of failing with AttributeError.
    mlm_max_len = getattr(Config, "MLM_MAX_LEN", Config.MAX_LEN)
    dataset = UnlabeledTextDataset(texts, tokenizer, max_len=mlm_max_len)
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)

    # TrainingArguments expects an int for max_steps; -1 is its "disabled"
    # value. Passing None breaks the trainer's `max_steps > 0` comparisons.
    effective_max_steps = max_steps if (max_steps is not None and max_steps > 0) else -1

    training_args = TrainingArguments(
        output_dir=save_dir,
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch,
        learning_rate=lr,
        warmup_ratio=warmup_ratio,
        logging_steps=50,
        save_strategy="no",   # keep only final weights we save manually
        fp16=(DEVICE=="cuda"),
        gradient_checkpointing=True,
        max_steps=effective_max_steps,
        dataloader_num_workers=Config.NUM_WORKERS,
        report_to="none",
        disable_tqdm=True,
    )
    # NOTE(review): newer transformers releases rename Trainer's `tokenizer`
    # argument to `processing_class` — confirm against the installed version.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    print(f"Starting DAPT MLM on {DEVICE} …")
    trainer.train()
    # Write a plain state_dict() to a simple bin path
    torch.save(model.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))
    print("DAPT complete.")

# ---- Helper: materialize the MLM model then copy its encoder 1:1
def _materialize_mlm_encoder_state(bin_path: str, base_model_name: str):
    """Return the encoder state_dict of an MLM checkpoint, or None.

    A fresh MLM model for `base_model_name` is created, the checkpoint at
    `bin_path` is loaded into it non-strictly (after unwrapping an optional
    "state_dict" container and stripping a "model." prefix shared by all
    keys), and the state_dict of its `roberta` / `xlm_roberta` sub-module is
    returned. None is returned when the file is missing or no such
    sub-module exists.
    """
    if not os.path.isfile(bin_path):
        return None

    # Create a fresh MLM model with the same architecture
    mlm = AutoModelForMaskedLM.from_pretrained(base_model_name)

    weights = torch.load(bin_path, map_location="cpu")
    if isinstance(weights, dict) and "state_dict" in weights:
        weights = weights["state_dict"]

    # Strip a leading "model." only when every key carries it
    if weights and all(key.startswith("model.") for key in weights.keys()):
        cut = len("model.")
        weights = {key[cut:]: tensor for key, tensor in weights.items()}

    # Load what we can (strict=False is intentional)
    mlm.load_state_dict(weights, strict=False)

    src_enc = None
    for attr in ("roberta", "xlm_roberta"):
        src_enc = getattr(mlm, attr, None)
        if src_enc:
            break
    return None if src_enc is None else src_enc.state_dict()

def safe_load_seqcls_from_dapt(dapt_dir: str, num_labels: int, base_model_name: str):
    """Create a sequence classifier whose encoder carries the DAPT weights.

    The classifier is built from `base_model_name`; encoder tensors whose
    keys also exist in the DAPT MLM encoder (`<dapt_dir>/pytorch_model.bin`)
    are overwritten with the DAPT values. When the checkpoint is unusable or
    the classifier's encoder cannot be located, the base model is returned
    unchanged. The classification head is never touched.
    """
    clf = AutoModelForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels)
    try:
        clf.gradient_checkpointing_enable()
    except Exception:
        pass

    bin_path = os.path.join(dapt_dir, "pytorch_model.bin")
    src_enc_sd = _materialize_mlm_encoder_state(bin_path, base_model_name)
    if src_enc_sd is None:
        print(f"⚠️ No usable DAPT bin at {bin_path}; using base encoder.")
        return clf

    dest_enc = None
    for attr in ("base_model", "roberta", "xlm_roberta"):
        dest_enc = getattr(clf, attr, None)
        if dest_enc:
            break
    if dest_enc is None:
        print("⚠️ Could not locate encoder in classifier; skipping DAPT copy.")
        return clf

    dest_sd = dest_enc.state_dict()
    # keep only the DAPT tensors whose key exists in the destination encoder
    overrides = {key: src_enc_sd[key] for key in dest_sd.keys() if key in src_enc_sd}
    copied = len(overrides)
    merged = dict(dest_sd)
    merged.update(overrides)
    dest_enc.load_state_dict(merged, strict=False)
    print(f"✅ DAPT encoder load (robust): copied {copied}/{len(dest_sd)} tensors from {bin_path}")
    return clf

def load_ft_classifier_only_into(model, ft_state_path: str):
    """Load ONLY the `classifier.*` tensors of a fine-tuned checkpoint into `model`.

    The encoder weights of `model` are left as they are. A missing checkpoint
    file is reported and otherwise ignored. The printed `missing` count comes
    from a non-strict load of the head alone, so it includes every
    non-classifier key of the model.
    """
    if not os.path.isfile(ft_state_path):
        print(f"⚠️ Classifier ckpt not found at {ft_state_path}")
        return

    checkpoint = torch.load(ft_state_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    head_only = {}
    for key, tensor in checkpoint.items():
        if key.startswith("classifier."):
            head_only[key] = tensor

    missing, unexpected = model.load_state_dict(head_only, strict=False)
    print(f"✅ Loaded classifier head from {ft_state_path} (kept encoder). "
          f"missing={len(missing)} unexpected={len(unexpected)}")

def fusion_available(pt_path: str, scaler_path: str) -> bool:
    """True only when both the fusion checkpoint and its scaler exist as regular files."""
    return all(os.path.isfile(path) for path in (pt_path, scaler_path))

def load_dapt_into_fusion_encoder(fusion_model, dapt_dir: str, base_model_name: str):
    """Overwrite `fusion_model.encoder` tensors with the DAPT encoder weights.

    Only keys present in both the fusion encoder and the DAPT MLM encoder
    (`<dapt_dir>/pytorch_model.bin`) are replaced. The call is a reported
    no-op when the checkpoint is unusable or the model has no `.encoder`.
    """
    bin_path = os.path.join(dapt_dir, "pytorch_model.bin")
    src_enc_sd = _materialize_mlm_encoder_state(bin_path, base_model_name)
    if src_enc_sd is None:
        print(f"⚠️ No usable DAPT bin at {bin_path}; skipping DAPT→fusion.")
        return

    target = getattr(fusion_model, "encoder", None)
    if target is None:
        print("⚠️ Fusion has no .encoder; skip.")
        return

    dest_sd = target.state_dict()
    # keep only the DAPT tensors whose key exists in the fusion encoder
    overrides = {key: src_enc_sd[key] for key in dest_sd.keys() if key in src_enc_sd}
    copied = len(overrides)
    merged = dict(dest_sd)
    merged.update(overrides)
    target.load_state_dict(merged, strict=False)
    print(f"✅ DAPT→Fusion encoder (robust): copied {copied}/{len(dest_sd)} tensors from {bin_path}")

print("Block C ready ✔️")


In [None]:
# ============================================
# Block D — Supervised fine-tuning (base classifier)
# ============================================

from transformers import get_linear_schedule_with_warmup, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split

def train_transformer_supervised(train_df: pd.DataFrame, tokenizer, num_labels: int):
    """Fine-tune a sequence classifier on the in-domain training set.

    The data is split (stratified, Config.VAL_SPLIT) into train/validation.
    Embeddings and the bottom 4 encoder layers are frozen. Training runs for
    up to 3 epochs with AdamW, linear warm-up/decay and AMP on CUDA. After
    each epoch the validation macro-F1 is computed; the best state_dict is
    saved to Config.XLM_R_MODEL_PATH and training stops after 2 epochs
    without improvement.

    Returns:
        None (the best checkpoint on disk is the result).
    """
    n_epochs = 3

    tr, va = train_test_split(train_df, test_size=Config.VAL_SPLIT, stratify=train_df['label'], random_state=13)
    ds_tr = TextFeatDataset(tr, tokenizer, feats=None)
    ds_va = TextFeatDataset(va, tokenizer, feats=None)
    dl_tr = make_loader(ds_tr, batch_size=32, shuffle=True)
    dl_va = make_loader(ds_va, batch_size=64)

    model = AutoModelForSequenceClassification.from_pretrained(Config.MODEL_NAME, num_labels=num_labels).to(DEVICE)
    try: model.gradient_checkpointing_enable()
    except Exception: pass

    # Freeze embeddings + bottom layers for stability
    enc = (getattr(model, "base_model", None)
           or getattr(model, "roberta", None)
           or getattr(model, "xlm_roberta", None))
    if enc is not None and hasattr(enc, "embeddings"):
        for p in enc.embeddings.parameters(): p.requires_grad = False
    if enc is not None and hasattr(enc, "encoder") and hasattr(enc.encoder, "layer"):
        for i, layer in enumerate(enc.encoder.layer):
            if i < 4:  # freeze bottom 4
                for p in layer.parameters(): p.requires_grad = False

    # With the embeddings frozen, the tensor entering the first checkpointed
    # layer does not require grad. Under reentrant gradient checkpointing a
    # checkpointed segment whose inputs do not require grad produces an
    # output that does not either, so the trainable encoder layers above
    # would silently receive no gradients (torch warns "None of the inputs
    # have requires_grad=True"). Marking the embedding output as requiring
    # grad restores the backward path.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=2e-5, weight_decay=0.01)
    total_steps = len(dl_tr) * n_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.06*total_steps), total_steps)
    loss_fn = nn.CrossEntropyLoss()
    scaler = GradScaler()

    best_f1, bad = -1.0, 0
    for epoch in range(n_epochs):
        # evaluate_model() switches to eval mode, so re-enter train mode every epoch
        model.train(); running = 0.0
        for b in tqdm(dl_tr, desc=f"FT Epoch {epoch+1}/{n_epochs}", leave=False):
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE); y=b['labels'].to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=(DEVICE=="cuda")):
                out = model(input_ids=ids, attention_mask=am, return_dict=True)
                loss = loss_fn(out.logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer); scaler.update(); scheduler.step()
            running += float(loss.item())

        # val
        f1 = evaluate_model(model, dl_va, 'transformer')
        print(f"Epoch {epoch+1}: loss={running/max(1,len(dl_tr)):.4f}  val_macroF1={f1:.4f}")
        if f1 > best_f1:
            best_f1, bad = f1, 0
            os.makedirs(os.path.dirname(Config.XLM_R_MODEL_PATH), exist_ok=True)
            torch.save(model.state_dict(), Config.XLM_R_MODEL_PATH)
            print("✅ Saved best model to RUN_DIR.")
        else:
            bad += 1
            if bad >= 2:
                print("⏹️ Early stopping")
                break

print("Block D ready ✔️")


In [None]:
# ============================================
# Block E — Eval/proba/uncertainty + adaptation
# ============================================

from sklearn.metrics import f1_score

def evaluate_model(model, data_loader, model_type: str = 'transformer') -> float:
    """Run `model` over `data_loader` and return the macro-F1 of its argmax predictions.

    `model_type` selects the call convention: 'transformer' (HF model, run
    under autocast on CUDA) or 'simple' / 'gated' (fusion models that also
    take `feats`). Any other value raises ValueError on the first batch.
    Returns NaN when no batch carried labels. The model is left in eval mode.
    """
    model.eval()
    y_pred, y_true = [], []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc=f"Evaluating {model_type}", leave=False):
            ids = batch['input_ids'].to(DEVICE)
            am = batch['attention_mask'].to(DEVICE)
            if 'labels' in batch:
                y_true += batch['labels'].cpu().numpy().tolist()

            if model_type == 'transformer':
                with autocast(enabled=(DEVICE=="cuda")):
                    logits = model(input_ids=ids, attention_mask=am, return_dict=True).logits
            elif model_type in ('simple','gated'):
                logits = model(ids, am, batch['feats'].to(DEVICE))
            else:
                raise ValueError(model_type)

            y_pred += logits.argmax(1).cpu().tolist()

    if not y_true:
        return float("nan")
    return f1_score(y_true, y_pred, average='macro')

def get_probs_transformer(model, ds, idx, batch=64):
    """Return P(class 1) for the dataset rows listed in `idx`, IN THE ORDER OF `idx`.

    Args:
        model: HF sequence-classification model.
        ds: dataset yielding 'input_ids' and 'attention_mask'.
        idx: positions into `ds`.
        batch: inference batch size.

    Returns:
        1-D numpy array with one probability per entry of `idx` (empty array
        for empty `idx`). Only column 1 of the softmax is returned, i.e. the
        output is meaningful for binary problems.
    """
    if len(idx) == 0: return np.array([])
    from torch.utils.data import Subset
    # Callers index the result positionally against `idx`, so iteration must
    # be sequential. A SubsetRandomSampler would shuffle the rows and silently
    # misalign every probability with its sample.
    ordered = Subset(ds, np.asarray(idx).astype(int).tolist())
    dl = DataLoader(ordered, batch_size=batch, shuffle=False)
    model.eval(); out = []
    with torch.no_grad():
        for b in dl:
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE)
            with autocast(enabled=(DEVICE=="cuda")):
                logits = model(input_ids=ids, attention_mask=am, return_dict=True).logits
            p = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
            out.append(p)
    return np.concatenate(out) if out else np.array([])

def get_probs_fusion(model, ds, idx, batch=64):
    """Return P(class 1) from a fusion model for the rows in `idx`, IN THE ORDER OF `idx`.

    Args:
        model: fusion model called as model(input_ids, attention_mask, feats).
        ds: dataset yielding 'input_ids', 'attention_mask' and 'feats'.
        idx: positions into `ds`.
        batch: inference batch size.

    Returns:
        1-D numpy array with one probability per entry of `idx` (empty array
        for empty `idx`). Only column 1 of the softmax is returned, i.e. the
        output is meaningful for binary problems.
    """
    if len(idx)==0: return np.array([])
    from torch.utils.data import Subset
    # Callers index the result positionally against `idx`, so iteration must
    # be sequential. A SubsetRandomSampler would shuffle the rows and silently
    # misalign every probability with its sample.
    ordered = Subset(ds, np.asarray(idx).astype(int).tolist())
    dl = DataLoader(ordered, batch_size=batch, shuffle=False)
    model.eval(); out=[]
    with torch.no_grad():
        for b in dl:
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE); feats=b['feats'].to(DEVICE)
            logits = model(ids, am, feats)
            p = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
            out.append(p)
    return np.concatenate(out) if out else np.array([])

def pick_uncertain_indices_with_proba(proba_fn, k, n):
    """Split positions 0..n-1 into the `k` most uncertain ones and the rest.

    Uncertainty is the binary entropy of `proba_fn(all_positions)`, which
    must return one probability per position, in position order. Both
    returned arrays are ordered by decreasing entropy.

    Special cases:
        * k <= 0 or n <= k: nothing is selected; returns (empty, all positions).
        * `proba_fn` returns an empty array: a fixed-seed random split is used.

    Returns:
        (adapt_idx, test_idx) as numpy integer arrays.
    """
    candidates = np.arange(n)
    nothing_to_pick = (k <= 0) or (n <= k)
    if nothing_to_pick:
        return np.array([], dtype=int), candidates

    scores = proba_fn(candidates)
    if scores.size == 0:
        # no probabilities available: fall back to a reproducible random split
        np.random.default_rng(13).shuffle(candidates)
        return candidates[:k], candidates[k:]

    clipped = np.clip(scores, 1e-6, 1-1e-6)  # keep log() finite
    entropy = -(clipped*np.log(clipped) + (1-clipped)*np.log(1-clipped))
    ranked = candidates[np.argsort(-entropy)]
    return ranked[:k], ranked[k:]

def tune_weight_and_threshold(p_t, p_x, y_true, weights=None, thresholds=None):
    """Grid-search a blend weight and decision threshold maximising macro-F1.

    The blended score is `w * p_t + (1 - w) * p_x`; `p_x` is treated as 0.0
    when it is None or its length differs from `p_t`. Predictions are
    `score >= threshold`.

    Defaults: 11 weights evenly spaced in [0, 1] and 25 thresholds evenly
    spaced in [0.2, 0.8].

    Returns:
        (best_f1, best_weight, best_threshold); (-1.0, 1.0, 0.5) when `p_t`
        or `y_true` is empty. Ties keep the first combination found.
    """
    fallback = (-1.0, 1.0, 0.5)
    if len(p_t)==0 or len(y_true)==0:
        return fallback
    if weights is None:
        weights = np.linspace(0.0, 1.0, 11)
    if thresholds is None:
        thresholds = np.linspace(0.2, 0.8, 25)

    winner = fallback
    for w in weights:
        other = p_x if p_x is not None and len(p_x)==len(p_t) else 0.0
        blended = w * p_t + (1 - w) * other
        for thr in thresholds:
            decisions = (blended >= thr).astype(int)
            score = f1_score(y_true, decisions, average='macro')
            if score > winner[0]:
                winner = (score, w, thr)
    return winner

from transformers import get_linear_schedule_with_warmup

def adapt_transformer_fewshot(model, ds_all, adapt_idx, lr, epochs):
    """Few-shot adaptation of an HF classifier on the rows `adapt_idx` of `ds_all`.

    Trains the parameters that currently require grad with AdamW, a linear
    warm-up (6% of the steps) / decay schedule, cross-entropy loss and AMP on
    CUDA. The model is updated in place and left in train mode.
    """
    loader = make_loader(ds_all, batch_size=Config.BATCH_SIZE, sampler=SubsetRandomSampler(adapt_idx))
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=lr)
    total_steps = max(1, len(loader)*epochs)
    scheduler = get_linear_schedule_with_warmup(optimizer, int(0.06*total_steps), total_steps)
    amp_scaler = GradScaler()
    criterion = nn.CrossEntropyLoss()

    model.train()
    use_amp = (DEVICE=="cuda")
    for _epoch in range(epochs):
        for batch in loader:
            ids = batch['input_ids'].to(DEVICE)
            am = batch['attention_mask'].to(DEVICE)
            y = batch['labels'].to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=use_amp):
                logits = model(input_ids=ids, attention_mask=am, return_dict=True).logits
                loss = criterion(logits, y)
            amp_scaler.scale(loss).backward()
            amp_scaler.step(optimizer)
            amp_scaler.update()
            scheduler.step()

def adapt_fusion_fewshot(model, ds_all, adapt_idx, lr, epochs):
    """Fine-tune a fusion model (text + hand-crafted features) on a subset of ``ds_all``.

    Only parameters with ``requires_grad=True`` are optimised (AdamW, constant
    learning rate — no scheduler is used here). Mixed precision is used when
    ``DEVICE`` is ``"cuda"``.

    Args:
        model: Fusion model called positionally as ``model(input_ids,
            attention_mask, feats)`` and returning logits.
        ds_all: Dataset whose batches provide ``input_ids``, ``attention_mask``,
            ``labels`` and ``feats``.
        adapt_idx: Indices of ``ds_all`` to train on (sampled in random order).
        lr: Learning rate for AdamW.
        epochs: Number of passes over ``adapt_idx``.

    Returns:
        None. The model is updated in place and left in training mode.
    """
    loader = make_loader(
        ds_all,
        batch_size=Config.BATCH_SIZE,
        sampler=SubsetRandomSampler(adapt_idx),
    )
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    grad_scaler = GradScaler()
    criterion = nn.CrossEntropyLoss()
    use_amp = (DEVICE == "cuda")

    model.train()
    for _epoch in range(epochs):
        for batch in loader:
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)
            targets = batch['labels'].to(DEVICE)
            feats = batch['feats'].to(DEVICE)

            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=use_amp):
                logits = model(input_ids, attention_mask, feats)
                loss = criterion(logits, targets)

            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()

# Cell-completion marker: only reached if every definition above executed without error.
print("Block E ready ✔️")


In [None]:
# ============================================
# Block F — Fit scalers on train, init XGB + Fusion (no leakage)
# ============================================

from sklearn.preprocessing import StandardScaler
import xgboost as xgb


def _load_or_fit_scaler(path, train_feats):
    """Return the scaler cached at ``path``; otherwise fit one on ``train_feats`` and cache it.

    Args:
        path: File path used with ``joblib.load`` / ``joblib.dump``.
        train_feats: Training-set feature matrix used only when no cache exists.

    Returns:
        The loaded or freshly fitted ``StandardScaler``.
    """
    if os.path.isfile(path):
        return joblib.load(path)
    scaler = StandardScaler().fit(train_feats)
    joblib.dump(scaler, path)
    return scaler


def _restore_weights_if_present(model, path, tag):
    """Load a saved state dict into ``model`` if ``path`` exists, reporting key mismatches.

    The load stays non-strict (as before), but a partial match is now printed
    instead of being silently ignored: parameters that are not matched keep
    whatever values they had, which would otherwise go unnoticed.

    Args:
        model: Module to load into.
        path: Checkpoint file; either a raw state dict or a dict with a
            ``"state_dict"`` entry.
        tag: Short label used in the warning message.
    """
    if not os.path.isfile(path):
        return
    # NOTE(review): torch.load unpickles the file — only use checkpoints you trust.
    state = torch.load(path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    outcome = model.load_state_dict(state, strict=False)
    if outcome.missing_keys or outcome.unexpected_keys:
        print(f"⚠️ {tag}: checkpoint only partially matched "
              f"(missing={len(outcome.missing_keys)}, unexpected={len(outcome.unexpected_keys)}); "
              f"unmatched parameters keep their current values.")


print("Extracting features for in-domain training (XGB & fusion scalers)…")
feats_tr_raw = fe.extract_df(train_df)
y_tr = train_df['label'].astype(int).values

# XGB scaler (fit on train only)
xgb_scaler = StandardScaler()
feats_tr_scaled = xgb_scaler.fit_transform(feats_tr_raw)

# XGB baseline
try:
    xgb_model = xgb.XGBClassifier(
        n_estimators=1500,
        objective="binary:logistic" if Config.NUM_LABELS == 2 else "multi:softprob",
        num_class=(Config.NUM_LABELS if Config.NUM_LABELS>2 else None),
        tree_method="hist", n_jobs=4, random_state=13,
        eval_metric="logloss" if Config.NUM_LABELS==2 else "mlogloss"
    )
    xgb_model.fit(feats_tr_scaled, y_tr)
    print("XGBoost baseline trained.")
except Exception as e:
    # Best-effort baseline: downstream cells check `xgb_model is not None`.
    xgb_model = None
    print("⚠️ xgboost not installed or failed; skipping XGB baseline.", e)

# Fusion scalers (fit on train only; persist in RUN_DIR)
os.makedirs(os.path.dirname(Config.SIMPLE_SCALER_PATH), exist_ok=True)
os.makedirs(os.path.dirname(Config.GATED_SCALER_PATH), exist_ok=True)

simple_scaler = _load_or_fit_scaler(Config.SIMPLE_SCALER_PATH, feats_tr_raw)
gated_scaler = _load_or_fit_scaler(Config.GATED_SCALER_PATH, feats_tr_raw)

# Init fusion models; load any saved heads if present
simple = SimpleFusion(Config.MODEL_NAME, Config.FEAT_DIM, n_labels=Config.NUM_LABELS).to(DEVICE)
gated  = GatedFusion(Config.MODEL_NAME, Config.FEAT_DIM, n_labels=Config.NUM_LABELS, feat_proj=64).to(DEVICE)

_restore_weights_if_present(simple, Config.SIMPLE_PT_PATH, "SimpleFusion")
_restore_weights_if_present(gated, Config.GATED_PT_PATH, "GatedFusion")

print("Block F ready ✔️")


In [None]:
# ============================================
# Block G — FT if needed, per-domain DAPT + adaptation loop, save
# ============================================

from transformers import AutoModelForSequenceClassification
from sklearn.metrics import f1_score

# Weight files that mark a finished DAPT run. Both names are accepted so an
# existing checkpoint is reused whichever serialisation format was written.
_DAPT_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")


def _has_dapt_checkpoint(dapt_dir):
    """Return True when ``dapt_dir`` already contains a saved model weight file."""
    return any(os.path.isfile(os.path.join(dapt_dir, fname)) for fname in _DAPT_WEIGHT_FILES)


def _snapshot_model(model):
    """Capture a model's weights (copied to CPU) and per-parameter ``requires_grad`` flags."""
    weights = {key: val.detach().cpu().clone() for key, val in model.state_dict().items()}
    trainable = {name: p.requires_grad for name, p in model.named_parameters()}
    return weights, trainable


def _restore_model(model, snapshot):
    """Reset ``model`` to a state previously captured with ``_snapshot_model``."""
    weights, trainable = snapshot
    # Non-strict so a helper that swapped encoder sub-modules cannot break the reset;
    # the encoder is overwritten with the domain's DAPT weights right afterwards anyway.
    model.load_state_dict(weights, strict=False)
    for name, p in model.named_parameters():
        p.requires_grad = trainable.get(name, p.requires_grad)


def _eval_xgb_blend(model, ds, feats, y, fit_idx, eval_idx):
    """Tune (weight, threshold) on ``fit_idx`` and score the transformer+XGB blend on ``eval_idx``.

    Returns:
        ``(macro_f1_on_eval, weight, threshold)``.
    """
    p_t_fit = get_probs_transformer(model, ds, fit_idx, batch=Config.BATCH_SIZE)
    p_x_fit = xgb_model.predict_proba(xgb_scaler.transform(feats[fit_idx]))[:, 1]
    _, w, thr = tune_weight_and_threshold(p_t_fit, p_x_fit, y[fit_idx])

    p_t_eval = get_probs_transformer(model, ds, eval_idx, batch=Config.BATCH_SIZE)
    p_x_eval = xgb_model.predict_proba(xgb_scaler.transform(feats[eval_idx]))[:, 1]
    blended = w * p_t_eval + (1 - w) * p_x_eval
    preds = (blended >= thr).astype(int)
    return f1_score(y[eval_idx], preds, average='macro'), w, thr


# 1) Supervised FT (train once if no checkpoint)
if not os.path.isfile(Config.XLM_R_MODEL_PATH):
    print("🧪 Training base classifier (supervised FT)…")
    train_transformer_supervised(train_df, tokenizer, num_labels=Config.NUM_LABELS)
else:
    print("ℹ️ Found existing fine-tuned checkpoint; skipping FT.")

results: List[Dict[str, Any]] = []
domains = ood_df['domain'].unique()
has_xgb = 'xgb_model' in globals() and xgb_model is not None

# The fusion models are shared objects that get fine-tuned inside the loop.
# Without a reset, every domain after the first would be scored "zero-shot"
# with weights already adapted on the previous domains. Keep a pristine copy
# (on CPU, to avoid doubling GPU memory) and restore it per domain.
# NOTE: re-running this cell without re-running Block F snapshots the models
# as they were left by the previous run.
simple_baseline = _snapshot_model(simple)
gated_baseline = _snapshot_model(gated)

for domain in domains:
    print(f"\n--- Processing domain: {domain} ---")
    seed_everything(13)                    # reproducibility per domain
    torch.cuda.empty_cache()               # avoid memory fragmentation

    ddf = ood_df[ood_df['domain']==domain].copy().reset_index(drop=True)
    y_true = ddf['label'].astype(int).values

    # DAPT: train per-domain if missing (str() so non-string domain labels work too)
    dapt_dir = os.path.join(Config.DAPT_SAVE_DIR, str(domain).replace(" ", "_"))
    if not _has_dapt_checkpoint(dapt_dir):
        print(f"🧪 Running DAPT (MLM) for domain '{domain}' …")
        run_mlm_corpus_training(
            ddf['content'].astype(str).tolist(),
            save_dir=dapt_dir,
            epochs=Config.DAPT_EPOCHS,
            batch=Config.DAPT_BATCH,
            lr=Config.DAPT_LR,
            warmup_ratio=Config.DAPT_WARMUP_RATIO,
            max_steps=Config.DAPT_MAX_STEPS,
            tokenizer=tokenizer,
        )

    # Build DAPT+head classifier (robust encoder copy)
    xlm_r = safe_load_seqcls_from_dapt(
        dapt_dir=dapt_dir,
        num_labels=Config.NUM_LABELS,
        base_model_name=Config.MODEL_NAME
    )
    load_ft_classifier_only_into(xlm_r, Config.XLM_R_MODEL_PATH)
    xlm_r.to(DEVICE)

    # Reset the shared fusion models, then push this domain's DAPT encoder into them
    _restore_model(simple, simple_baseline)
    _restore_model(gated, gated_baseline)
    load_dapt_into_fusion_encoder(simple, dapt_dir, Config.MODEL_NAME)
    load_dapt_into_fusion_encoder(gated,  dapt_dir, Config.MODEL_NAME)

    # Build OOD datasets
    feats_ood_raw = fe.extract_df(ddf)
    ds_all        = TextFeatDataset(ddf, tokenizer, feats=None)
    ds_simple_all = TextFeatDataset(ddf, tokenizer, feats=simple_scaler.transform(feats_ood_raw))
    ds_gated_all  = TextFeatDataset(ddf, tokenizer, feats=gated_scaler.transform(feats_ood_raw))

    n = len(ddf)
    k_target = max(int(Config.DA_ADAPT_RATIO * n), 64)
    k = min(k_target, max(0, n//2))      # never adapt on more than half the domain

    # ✅ Select *separate* uncertain sets per model (better adaptation for each)
    proba_t = lambda idxs: get_probs_transformer(xlm_r, ds_all, idxs, batch=Config.BATCH_SIZE)
    adapt_t, test_t = pick_uncertain_indices_with_proba(proba_t, k=k, n=n)
    print(f"[Transformer] Adapt: {len(adapt_t)} | Test: {len(test_t)}")

    proba_s = lambda idxs: get_probs_fusion(simple, ds_simple_all, idxs, batch=Config.BATCH_SIZE)
    adapt_s, test_s = pick_uncertain_indices_with_proba(proba_s, k=k, n=n)
    print(f"[SimpleFusion] Adapt: {len(adapt_s)} | Test: {len(test_s)}")

    proba_g = lambda idxs: get_probs_fusion(gated, ds_gated_all, idxs, batch=Config.BATCH_SIZE)
    adapt_g, test_g = pick_uncertain_indices_with_proba(proba_g, k=k, n=n)
    print(f"[GatedFusion] Adapt: {len(adapt_g)} | Test: {len(test_g)}")

    # NOTE(review): zero-shot scores below use all n samples, while the
    # domain-adapted scores use only each model's test split, so the two
    # columns are not measured on the same examples.
    all_idx = np.arange(n)

    # ---- Zero-shot: transformer
    f1 = evaluate_model(xlm_r, make_loader(ds_all, Config.BATCH_SIZE, sampler=SubsetRandomSampler(all_idx)), 'transformer')
    results.append({'Domain':domain,'Model':'XLM-R Only','Evaluation':'Zero-Shot','Macro F1-Score':float(f1)})

    # ---- Zero-shot: XGB
    if has_xgb:
        preds = xgb_model.predict(xgb_scaler.transform(feats_ood_raw))
        f1 = f1_score(y_true, preds, average='macro')
        results.append({'Domain':domain,'Model':'XGBoost + Features','Evaluation':'Zero-Shot','Macro F1-Score':float(f1)})

    # ---- Zero-shot: fusion (DAPT encoders)
    f1 = evaluate_model(simple, make_loader(ds_simple_all, Config.BATCH_SIZE, sampler=SubsetRandomSampler(all_idx)), 'simple')
    results.append({'Domain':domain,'Model':'Simple Fusion','Evaluation':'Zero-Shot (DAPT encoder)','Macro F1-Score':float(f1)})

    f1 = evaluate_model(gated, make_loader(ds_gated_all, Config.BATCH_SIZE, sampler=SubsetRandomSampler(all_idx)), 'gated')
    results.append({'Domain':domain,'Model':'Gated Fusion','Evaluation':'Zero-Shot (DAPT encoder)','Macro F1-Score':float(f1)})

    # ---- Zero-shot Ensemble tuned on transformer-adapt split (consistent)
    can_blend = has_xgb and len(adapt_t)>0 and len(test_t)>0
    if can_blend:
        f1_blend, best_w, best_thr = _eval_xgb_blend(xlm_r, ds_all, feats_ood_raw, y_true, adapt_t, test_t)
        results.append({'Domain':domain,'Model':f'Ensemble(T+XGB) w={best_w:.2f} thr={best_thr:.2f}',
                        'Evaluation':'Zero-Shot (tuned on adapt)','Macro F1-Score':float(f1_blend)})

    # ---- Domain Adaptation: transformer + fusion (each on its own adapt set)
    if len(adapt_t)>0 and len(test_t)>0:
        print("  Running Domain Adaptation…")
        unfreeze_top_n(xlm_r, Config.DA_UNFREEZE_TOP_N)
        adapt_transformer_fewshot(xlm_r, ds_all, adapt_t, lr=Config.DA_LR, epochs=Config.DA_EPOCHS)
        f1 = evaluate_model(xlm_r, make_loader(ds_all, Config.BATCH_SIZE, sampler=SubsetRandomSampler(test_t)), 'transformer')
        results.append({'Domain':domain,'Model':'XLM-R Only','Evaluation':'Domain Adapted','Macro F1-Score':float(f1)})

    if len(adapt_s)>0 and len(test_s)>0:
        unfreeze_top_n(simple, Config.DA_UNFREEZE_TOP_N)
        adapt_fusion_fewshot(simple, ds_simple_all, adapt_s, lr=Config.DA_LR, epochs=Config.DA_EPOCHS)
        f1 = evaluate_model(simple, make_loader(ds_simple_all, Config.BATCH_SIZE, sampler=SubsetRandomSampler(test_s)), 'simple')
        results.append({'Domain':domain,'Model':'Simple Fusion','Evaluation':'Domain Adapted','Macro F1-Score':float(f1)})

    if len(adapt_g)>0 and len(test_g)>0:
        unfreeze_top_n(gated, Config.DA_UNFREEZE_TOP_N)
        adapt_fusion_fewshot(gated, ds_gated_all, adapt_g, lr=Config.DA_LR, epochs=Config.DA_EPOCHS)
        f1 = evaluate_model(gated, make_loader(ds_gated_all, Config.BATCH_SIZE, sampler=SubsetRandomSampler(test_g)), 'gated')
        results.append({'Domain':domain,'Model':'Gated Fusion','Evaluation':'Domain Adapted','Macro F1-Score':float(f1)})

    # Optional: ensemble after adaptation (still using transformer split for comparability)
    if can_blend:
        f1_blend2, best_w, best_thr = _eval_xgb_blend(xlm_r, ds_all, feats_ood_raw, y_true, adapt_t, test_t)
        results.append({'Domain':domain,'Model':f'Ensemble(T+XGB) w={best_w:.2f} thr={best_thr:.2f}',
                        'Evaluation':'Domain Adapted (tuned)','Macro F1-Score':float(f1_blend2)})

# Save & pretty print
results_df = pd.DataFrame(results)
os.makedirs(os.path.dirname(Config.RESULTS_CSV_PATH), exist_ok=True)
results_df.to_csv(Config.RESULTS_CSV_PATH, index=False)

print("\n\n" + "="*80)
print("CROSS-DOMAIN GENERALIZATION FINAL RESULTS")
print("="*80)
from pandas import option_context
with option_context('display.max_rows', None, 'display.max_columns', None):
    print(nice_pivot(results_df))
print(f"\n✅ Results summary saved to: {Config.RESULTS_CSV_PATH}")

print("Block G done ✔️")


In [None]:
# ============================================
# Sanity Checks v2: fixed splits + class weighting + threshold calibration
# Run AFTER Blocks A–G (where models, configs, datasets are defined)
# ============================================

import numpy as np, pandas as pd, os, torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.metrics import f1_score
from tqdm import tqdm

# ---------- Utilities from the pipeline (light wrappers) ----------
def _predict_argmax_transformer(model, ds, idx, batch=64):
    if len(idx) == 0: return np.array([], dtype=int)
    dl = DataLoader(ds, batch_size=batch, sampler=SubsetRandomSampler(idx))
    model.eval(); out=[]
    with torch.no_grad():
        for b in dl:
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE)
            logits = model(input_ids=ids, attention_mask=am, return_dict=True).logits
            out.append(logits.argmax(1).cpu().numpy())
    return np.concatenate(out) if out else np.array([], dtype=int)

def _predict_argmax_fusion(model, ds, idx, batch=64):
    if len(idx) == 0: return np.array([], dtype=int)
    dl = DataLoader(ds, batch_size=batch, sampler=SubsetRandomSampler(idx))
    model.eval(); out=[]
    with torch.no_grad():
        for b in dl:
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE); feats=b['feats'].to(DEVICE)
            logits = model(input_ids=ids, attention_mask=am, feats=feats)
            out.append(logits.argmax(1).cpu().numpy())
    return np.concatenate(out) if out else np.array([], dtype=int)

def _get_probs_transformer(model, ds, idx, batch=64):
    if Config.NUM_LABELS != 2: raise ValueError("Threshold calibration only coded for binary.")
    if len(idx) == 0: return np.array([])
    dl = DataLoader(ds, batch_size=batch, sampler=SubsetRandomSampler(idx))
    model.eval(); out=[]
    with torch.no_grad():
        for b in dl:
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE)
            p = torch.softmax(model(input_ids=ids, attention_mask=am, return_dict=True).logits, dim=-1)[:,1]
            out.append(p.cpu().numpy())
    return np.concatenate(out) if out else np.array([])

def _get_probs_fusion(model, ds, idx, batch=64):
    if Config.NUM_LABELS != 2: raise ValueError("Threshold calibration only coded for binary.")
    if len(idx) == 0: return np.array([])
    dl = DataLoader(ds, batch_size=batch, sampler=SubsetRandomSampler(idx))
    model.eval(); out=[]
    with torch.no_grad():
        for b in dl:
            ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE); feats=b['feats'].to(DEVICE)
            p = torch.softmax(model(input_ids=ids, attention_mask=am, feats=feats), dim=-1)[:,1]
            out.append(p.cpu().numpy())
    return np.concatenate(out) if out else np.array([])

def _bootstrap_ci_macro_f1(y_true, y_pred, n_boot=1000, seed=123):
    if len(y_true) == 0: return (float("nan"), float("nan"))
    rng = np.random.default_rng(seed)
    n = len(y_true); scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        scores.append(f1_score(y_true[idx], y_pred[idx], average='macro'))
    lo, hi = np.percentile(scores, [2.5, 97.5])
    return float(lo), float(hi)

def _entropy(p):
    p = np.clip(p, 1e-9, 1-1e-9)
    return -(p*np.log(p) + (1-p)*np.log(1-p))

def _calibrate_threshold(p, y, grid=None):
    # maximize macro-F1 on adapt set
    if grid is None: grid = np.linspace(0.2, 0.8, 61)
    best = (0.0, 0.5)
    for t in grid:
        preds = (p >= t).astype(int)
        f1 = f1_score(y, preds, average='macro')
        if f1 > best[0]: best = (f1, t)
    return best[1]

def _class_weights_from_labels(y, num_classes):
    counts = np.bincount(y, minlength=num_classes).astype(np.float32)
    w = counts.sum() / np.maximum(1.0, counts)
    w = w / w.mean()
    return torch.tensor(w, dtype=torch.float32, device=DEVICE)

# ---------- Weighted adaptation (transformer + fusion) ----------
def adapt_transformer_weighted(model, ds_all, adapt_idx, lr=1.5e-5, epochs=4, label_smoothing=0.05):
    """Fine-tune a transformer classifier on ``adapt_idx`` with class-weighted, smoothed cross-entropy.

    Class weights come from ``_class_weights_from_labels`` applied to the labels
    of the adapt subset. Only parameters with ``requires_grad=True`` are
    optimised (AdamW); the learning rate decays linearly with no warm-up.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...,
            return_dict=True)`` whose output exposes ``.logits``.
        ds_all: Dataset exposing ``.labels`` (indexable with ``adapt_idx``) and
            yielding ``input_ids``, ``attention_mask`` and ``labels``.
        adapt_idx: Indices of ``ds_all`` to train on.
        lr: AdamW learning rate.
        epochs: Number of passes over ``adapt_idx``.
        label_smoothing: Label-smoothing factor passed to the loss.

    Returns:
        None. Returns immediately (model untouched) when ``adapt_idx`` is empty;
        otherwise the model is updated in place and left in training mode.
    """
    # Nothing to train on; also avoids NaN class weights from an empty label set.
    if len(adapt_idx) == 0:
        return

    import contextlib
    from torch.optim import AdamW
    from transformers import get_linear_schedule_with_warmup
    from torch.amp import autocast, GradScaler

    use_amp = (DEVICE == "cuda")

    y_adapt = ds_all.labels[adapt_idx]
    ce = torch.nn.CrossEntropyLoss(weight=_class_weights_from_labels(y_adapt, Config.NUM_LABELS),
                                   label_smoothing=label_smoothing)
    opt = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    steps = max(1, math.ceil(len(adapt_idx)/max(1, Config.BATCH_SIZE)))*epochs
    sch = get_linear_schedule_with_warmup(opt, 0, steps)
    # Disabled scaler on CPU: scale()/step()/update() become pass-throughs, no warning.
    scaler = GradScaler(device="cuda", enabled=use_amp)
    dl = DataLoader(ds_all, batch_size=Config.BATCH_SIZE, sampler=SubsetRandomSampler(adapt_idx))

    model.train()
    for _ in range(epochs):
        for b in dl:
            opt.zero_grad(set_to_none=True)
            ids = b['input_ids'].to(DEVICE)
            am = b['attention_mask'].to(DEVICE)
            y = b['labels'].to(DEVICE)
            # nullcontext replaces the deprecated torch.cuda.amp.autocast(enabled=False).
            amp_ctx = autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
            with amp_ctx:
                loss = ce(model(input_ids=ids, attention_mask=am, return_dict=True).logits, y)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sch.step()

def adapt_fusion_weighted(model, ds_all, adapt_idx, lr=1.5e-5, epochs=4, label_smoothing=0.05):
    """Fine-tune a fusion model on ``adapt_idx`` with class-weighted, smoothed cross-entropy.

    Class weights come from ``_class_weights_from_labels`` applied to the labels
    of the adapt subset. Only parameters with ``requires_grad=True`` are
    optimised (AdamW); the learning rate decays linearly with no warm-up.

    Args:
        model: Fusion model called as ``model(input_ids=..., attention_mask=...,
            feats=...)`` and returning logits.
        ds_all: Dataset exposing ``.labels`` (indexable with ``adapt_idx``) and
            yielding ``input_ids``, ``attention_mask``, ``feats`` and ``labels``.
        adapt_idx: Indices of ``ds_all`` to train on.
        lr: AdamW learning rate.
        epochs: Number of passes over ``adapt_idx``.
        label_smoothing: Label-smoothing factor passed to the loss.

    Returns:
        None. Returns immediately (model untouched) when ``adapt_idx`` is empty;
        otherwise the model is updated in place and left in training mode.
    """
    # Nothing to train on; also avoids NaN class weights from an empty label set.
    if len(adapt_idx) == 0:
        return

    import contextlib
    from torch.optim import AdamW
    from transformers import get_linear_schedule_with_warmup
    from torch.amp import autocast, GradScaler

    use_amp = (DEVICE == "cuda")

    y_adapt = ds_all.labels[adapt_idx]
    ce = torch.nn.CrossEntropyLoss(weight=_class_weights_from_labels(y_adapt, Config.NUM_LABELS),
                                   label_smoothing=label_smoothing)
    opt = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    steps = max(1, math.ceil(len(adapt_idx)/max(1, Config.BATCH_SIZE)))*epochs
    sch = get_linear_schedule_with_warmup(opt, 0, steps)
    # Disabled scaler on CPU: scale()/step()/update() become pass-throughs, no warning.
    scaler = GradScaler(device="cuda", enabled=use_amp)
    dl = DataLoader(ds_all, batch_size=Config.BATCH_SIZE, sampler=SubsetRandomSampler(adapt_idx))

    model.train()
    for _ in range(epochs):
        for b in dl:
            opt.zero_grad(set_to_none=True)
            ids = b['input_ids'].to(DEVICE)
            am = b['attention_mask'].to(DEVICE)
            feats = b['feats'].to(DEVICE)
            y = b['labels'].to(DEVICE)
            # nullcontext replaces the deprecated torch.cuda.amp.autocast(enabled=False).
            amp_ctx = autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
            with amp_ctx:
                loss = ce(model(input_ids=ids, attention_mask=am, feats=feats), y)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            sch.step()

# ---------- Precompute FIXED adapt/test splits (once) ----------
def make_fixed_splits(entropy_seed=13):
    """Compute one adapt/test split per OOD domain from transformer prediction entropy.

    For every domain in ``ood_df`` the DAPT encoder with the fine-tuned
    classifier head is loaded, class-1 probabilities are computed for all
    samples, and the ``k`` highest-entropy samples become the adapt set (the
    rest is the test set). ``k`` is ``max(int(Config.DA_ADAPT_RATIO * n), 64)``
    capped at ``max(1, n // 2)``.

    NOTE(review): the ranking is only meaningful if ``_get_probs_transformer``
    returns probabilities in the order of the indices it is given — confirm.

    Args:
        entropy_seed: Seed passed to ``seed_everything`` before each domain's
            model is loaded.

    Returns:
        Dict mapping domain -> ``(adapt_idx, test_idx)`` NumPy index arrays
        into the domain's re-indexed frame.
    """
    fixed = {}
    for domain in ood_df['domain'].unique():
        ddf = ood_df[ood_df['domain']==domain].copy().reset_index(drop=True)
        n = len(ddf)
        # Text-only dataset: hand-crafted features are not needed to rank by
        # transformer entropy, so the costly feature extraction is skipped.
        ds_all = TextFeatDataset(ddf, tokenizer, feats=None)

        # transformer with deterministic seed for split
        seed_everything(entropy_seed)
        dapt_dir = os.path.join(Config.DAPT_SAVE_DIR, str(domain).replace(" ", "_"))
        m = safe_load_seqcls_from_dapt(dapt_dir, Config.NUM_LABELS, Config.MODEL_NAME)
        load_ft_classifier_only_into(m, Config.XLM_R_MODEL_PATH); m.to(DEVICE)

        p = _get_probs_transformer(m, ds_all, np.arange(n), batch=Config.BATCH_SIZE)
        ent = _entropy(p)

        k_target = max(int(Config.DA_ADAPT_RATIO*n), 64)
        k = min(k_target, max(1, n//2))
        order = np.argsort(-ent)   # highest uncertainty first
        adapt_idx = order[:k]
        test_idx  = order[k:]
        fixed[domain] = (adapt_idx, test_idx)

        # Release this domain's model before the next one is loaded.
        del m
        torch.cuda.empty_cache()
    return fixed

# ---------- Main runner with fixed splits + weighted adaptation + threshold calibration ----------
def _build_xlmr_for_domain(dapt_dir):
    """Load the DAPT encoder with the fine-tuned classifier head and move it to ``DEVICE``."""
    model = safe_load_seqcls_from_dapt(dapt_dir, Config.NUM_LABELS, Config.MODEL_NAME)
    load_ft_classifier_only_into(model, Config.XLM_R_MODEL_PATH)
    return model.to(DEVICE)


def _build_fusion_for_domain(fusion_cls, dapt_dir, **kwargs):
    """Create a fresh fusion model of type ``fusion_cls`` and load the domain's DAPT encoder into it."""
    model = fusion_cls(Config.MODEL_NAME, Config.FEAT_DIM, n_labels=Config.NUM_LABELS, **kwargs).to(DEVICE)
    load_dapt_into_fusion_encoder(model, dapt_dir, Config.MODEL_NAME)
    return model


def _adapt_and_score(model, ds, adapt_idx, test_idx, y_all,
                     adapt_fn, probs_fn, argmax_fn, n_boot, ci_seed):
    """Adapt ``model`` on the adapt split, then score it on the test split.

    For binary tasks the decision threshold is calibrated on the adapt split;
    otherwise argmax predictions are used.

    Returns:
        ``(macro_f1, (ci_lo, ci_hi), per_class_f1)``.
    """
    adapt_fn(model, ds, adapt_idx, lr=Config.DA_LR, epochs=Config.DA_EPOCHS, label_smoothing=0.05)

    if Config.NUM_LABELS == 2:
        p_adapt = probs_fn(model, ds, adapt_idx, batch=Config.BATCH_SIZE)
        thr = _calibrate_threshold(p_adapt, y_all[adapt_idx])
        p_test = probs_fn(model, ds, test_idx, batch=Config.BATCH_SIZE)
        y_pred = (p_test >= thr).astype(int)
    else:
        y_pred = argmax_fn(model, ds, test_idx, batch=Config.BATCH_SIZE)

    y_test = y_all[test_idx]
    macro = f1_score(y_test, y_pred, average='macro')
    ci = _bootstrap_ci_macro_f1(y_test, y_pred, n_boot=n_boot, seed=ci_seed)
    per_class = f1_score(y_test, y_pred, average=None, labels=np.arange(Config.NUM_LABELS))
    return macro, ci, per_class


def run_sanity_checks_v2(
    seeds=(7,11,19),
    models=("xlmr","simple","gated"),
    n_boot=1000,
    fixed_splits=None
):
    """Seed-robustness check: adapt each model on fixed splits and report macro-F1 statistics.

    For every OOD domain and seed, each requested model is rebuilt, adapted on
    the domain's fixed adapt split with the weighted loss, and scored on the
    fixed test split. Three CSVs (mean ± std summary, per-seed bootstrap CIs,
    per-class F1) are written to ``Config.RUN_DIR`` and the tables are printed.

    NOTE(review): fusion models are created fresh here and only receive the
    DAPT encoder; no saved fusion checkpoint is loaded in this function —
    confirm that this is intended.

    Args:
        seeds: Seeds to repeat the adaptation with.
        models: Any subset of ``("xlmr", "simple", "gated")``. A fusion model
            is skipped when its scaler is not defined globally.
        n_boot: Bootstrap resamples per confidence interval.
        fixed_splits: Mapping domain -> ``(adapt_idx, test_idx)``; computed with
            ``make_fixed_splits(entropy_seed=13)`` when ``None``.

    Returns:
        None.
    """
    from functools import partial

    if fixed_splits is None:
        fixed_splits = make_fixed_splits(entropy_seed=13)

    display_name = {
        "xlmr": "XLM-R Only (DA)",
        "simple": "Simple Fusion (DA)",
        "gated": "Gated Fusion (DA)",
    }
    transformer_fns = (adapt_transformer_weighted, _get_probs_transformer, _predict_argmax_transformer)
    fusion_fns = (adapt_fusion_weighted, _get_probs_fusion, _predict_argmax_fusion)

    rows_summary = []
    rows_ci = []
    rows_perclass = []

    for domain in ood_df['domain'].unique():
        ddf = ood_df[ood_df['domain']==domain].copy().reset_index(drop=True)
        ddf = remap_labels_if_needed(ddf)
        y_all = ddf['label'].astype(int).values

        want_simple = ('simple' in models) and ('simple_scaler' in globals())
        want_gated = ('gated' in models) and ('gated_scaler' in globals())

        ds_all = TextFeatDataset(ddf, tokenizer, feats=None)
        # Hand-crafted features are only needed by the fusion models.
        feats_ood_raw = fe.extract_df(ddf) if (want_simple or want_gated) else None
        ds_simple_all = TextFeatDataset(ddf, tokenizer, feats=simple_scaler.transform(feats_ood_raw)) if want_simple else None
        ds_gated_all = TextFeatDataset(ddf, tokenizer, feats=gated_scaler.transform(feats_ood_raw)) if want_gated else None

        adapt_idx, test_idx = fixed_splits[domain]
        per_seed = {m: [] for m in models}
        dapt_dir = os.path.join(Config.DAPT_SAVE_DIR, str(domain).replace(" ", "_"))

        # (key, model builder, dataset, (adapt, probs, argmax) functions, bootstrap-seed offset)
        plans = []
        if "xlmr" in models:
            plans.append(("xlmr", partial(_build_xlmr_for_domain, dapt_dir), ds_all, transformer_fns, 101))
        if ds_simple_all is not None:
            plans.append(("simple", partial(_build_fusion_for_domain, SimpleFusion, dapt_dir),
                          ds_simple_all, fusion_fns, 202))
        if ds_gated_all is not None:
            plans.append(("gated", partial(_build_fusion_for_domain, GatedFusion, dapt_dir, feat_proj=64),
                          ds_gated_all, fusion_fns, 303))

        for seed in seeds:
            seed_everything(seed)
            for key, build, ds, (adapt_fn, probs_fn, argmax_fn), ci_offset in plans:
                # Build each model at its turn so random initialisation happens
                # in the same order as the models are trained.
                model = build()
                unfreeze_top_n(model, Config.DA_UNFREEZE_TOP_N)
                macro, (lo, hi), per_class = _adapt_and_score(
                    model, ds, adapt_idx, test_idx, y_all,
                    adapt_fn, probs_fn, argmax_fn, n_boot, seed + ci_offset,
                )

                label = display_name[key]
                per_seed.setdefault(key, []).append(macro)
                rows_ci.append({"Domain":domain,"Model":label,"Seed":seed,"Macro F1 (95% CI)":f"[{lo:.3f}, {hi:.3f}]"})
                for cls_id, cls_f1 in enumerate(per_class):
                    rows_perclass.append({"Domain":domain,"Model":label,"Seed":seed,"Class":cls_id,"F1":float(cls_f1)})

                # Free the model before the next one is created.
                del model
                torch.cuda.empty_cache()

        # summarize mean±std
        for m, label in display_name.items():
            if m not in models: continue
            arr = np.array(per_seed.get(m,[]), dtype=float)
            if arr.size:
                # Sample std is undefined for a single seed; report nan without a warning.
                spread = arr.std(ddof=1) if arr.size > 1 else float("nan")
                rows_summary.append({"Domain":domain,"Model":label,"Seeds":str(list(seeds)),"Macro F1 (mean ± std)":f"{arr.mean():.4f} ± {spread:.4f}"})

    # Explicit columns keep sort_values working even when there are no rows.
    df_sum = pd.DataFrame(rows_summary, columns=["Domain","Model","Seeds","Macro F1 (mean ± std)"]).sort_values(["Domain","Model"])
    df_ci  = pd.DataFrame(rows_ci, columns=["Domain","Model","Seed","Macro F1 (95% CI)"]).sort_values(["Domain","Model","Seed"])
    df_pc  = pd.DataFrame(rows_perclass, columns=["Domain","Model","Seed","Class","F1"]).sort_values(["Domain","Model","Seed","Class"])

    out_dir = getattr(Config, "RUN_DIR", "/content/emc_run"); os.makedirs(out_dir, exist_ok=True)
    df_sum.to_csv(os.path.join(out_dir,"sanity_v2_seed_robustness_summary.csv"), index=False)
    df_ci.to_csv(os.path.join(out_dir,"sanity_v2_bootstrap_cis.csv"), index=False)
    df_pc.to_csv(os.path.join(out_dir,"sanity_v2_per_class_f1.csv"), index=False)

    print("\n=== Seed robustness (Macro F1 mean ± std) — v2 ===")
    print(df_sum.to_string(index=False))
    print("\n=== Bootstrap 95% CI (per seed) — v2 ===")
    print(df_ci.to_string(index=False))
    print("\n=== Per-class F1 (per seed) — v2 ===")
    print(df_pc.to_string(index=False))
    print(f"\n✅ Saved:\n- {os.path.join(out_dir,'sanity_v2_seed_robustness_summary.csv')}\n- {os.path.join(out_dir,'sanity_v2_bootstrap_cis.csv')}\n- {os.path.join(out_dir,'sanity_v2_per_class_f1.csv')}")

# ---- Run it
# Model keys to evaluate; passed as `models` to run_sanity_checks_v2.
models_to_check = ("xlmr","simple","gated")   # adjust if needed
# Build the adapt/test splits once and pass them in, so every seed and model uses the same split.
fixed_splits = make_fixed_splits(entropy_seed=13)
run_sanity_checks_v2(seeds=(7,11,19), models=models_to_check, n_boot=1000, fixed_splits=fixed_splits)


In [None]:
# ================================
# Sanity Checks: seeds + bootstrap
# Run this AFTER Blocks A–G
# ================================

import numpy as np, pandas as pd, math, os
from tqdm import tqdm
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch

# ---------- Helpers (predictions, selection, bootstrap) ----------
def _predict_transformer(model, ds, idx, batch=64):
    if len(idx) == 0: return np.array([], dtype=int)
    dl = DataLoader(ds, batch_size=batch, sampler=SubsetRandomSampler(idx))
    model.eval(); preds = []
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(DEVICE); am = b['attention_mask'].to(DEVICE)
            out = model(input_ids=ids, attention_mask=am, return_dict=True)
            p = out.logits.argmax(1).cpu().numpy()
            preds.append(p)
    return np.concatenate(preds) if preds else np.array([], dtype=int)

def _predict_fusion(model, ds, idx, batch=64):
    if len(idx) == 0: return np.array([], dtype=int)
    dl = DataLoader(ds, batch_size=batch, sampler=SubsetRandomSampler(idx))
    model.eval(); preds = []
    with torch.no_grad():
        for b in dl:
            ids = b['input_ids'].to(DEVICE); am = b['attention_mask'].to(DEVICE); feats = b['feats'].to(DEVICE)
            logits = model(input_ids=ids, attention_mask=am, feats=feats)
            p = logits.argmax(1).cpu().numpy()
            preds.append(p)
    return np.concatenate(preds) if preds else np.array([], dtype=int)

def _entropy_from_probs_binary(p):
    p = np.clip(p, 1e-9, 1 - 1e-9)
    return -(p * np.log(p) + (1 - p) * np.log(1 - p))

def _pick_uncertain_by_entropy(proba_fn, n, k):
    """proba_fn: function(idxs)->probs_of_class1 (np.ndarray)"""
    idx_all = np.arange(n)
    probs = proba_fn(idx_all)
    ent = _entropy_from_probs_binary(probs)
    order = np.argsort(-ent)          # high entropy = more uncertain
    k = min(max(k, 1), max(1, n - 1)) # ensure non-empty test
    adapt_idx = idx_all[order[:k]]
    test_idx  = idx_all[order[k:]]
    return adapt_idx, test_idx

def _bootstrap_ci_macro_f1(y_true, y_pred, n_boot=1000, seed=123):
    if len(y_true) == 0: return (float("nan"), float("nan"))
    rng = np.random.default_rng(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        scores.append(f1_score(y_true[idx], y_pred[idx], average='macro'))
    lo, hi = np.percentile(scores, [2.5, 97.5])
    return float(lo), float(hi)

# ---------- Minimal fallbacks for fusion adaptation if not defined ----------
try:
    adapt_fusion_fewshot
except NameError:
    def adapt_fusion_fewshot(model, ds_all, adapt_idx, lr=1.5e-5, epochs=4):
        """Fallback few-shot adaptation for fusion models (defined only if the pipeline's version is missing).

        Only parameters with ``requires_grad=True`` are optimised (AdamW); the
        learning rate decays linearly, with no warm-up, over exactly the number
        of optimiser steps that will be taken.

        Args:
            model: Fusion model called as ``model(input_ids=...,
                attention_mask=..., feats=...)`` and returning logits.
            ds_all: Dataset yielding ``input_ids``, ``attention_mask``,
                ``feats`` and ``labels``.
            adapt_idx: Indices of ``ds_all`` to train on.
            lr: AdamW learning rate.
            epochs: Number of passes over ``adapt_idx``.

        Returns:
            None. Returns immediately (model untouched) when ``adapt_idx`` is
            empty; otherwise the model is updated in place and left in
            training mode.
        """
        if len(adapt_idx) == 0:
            return

        import contextlib
        from torch.optim import AdamW
        from transformers import get_linear_schedule_with_warmup
        from torch.amp import GradScaler, autocast

        use_amp = (DEVICE == "cuda")
        # Disabled scaler on CPU: scale()/step()/update() become pass-throughs.
        scaler = GradScaler(device="cuda", enabled=use_amp)
        opt = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
        dl = DataLoader(ds_all, batch_size=Config.BATCH_SIZE,
                        sampler=SubsetRandomSampler(adapt_idx), drop_last=False)
        # Count real batches (the partial last batch is kept), so the schedule
        # does not hit a zero learning rate before training finishes.
        steps = max(1, len(dl) * epochs)
        sch = get_linear_schedule_with_warmup(opt, 0, steps)
        ce = torch.nn.CrossEntropyLoss()

        model.train()
        for _ in range(epochs):
            for b in dl:
                ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE)
                feats=b['feats'].to(DEVICE); y=b['labels'].to(DEVICE)
                opt.zero_grad(set_to_none=True)
                # nullcontext replaces the deprecated torch.cuda.amp.autocast(enabled=False).
                amp_ctx = autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
                with amp_ctx:
                    logits = model(input_ids=ids, attention_mask=am, feats=feats)
                    loss = ce(logits, y)
                scaler.scale(loss).backward()
                scaler.step(opt); scaler.update(); sch.step()

try:
    get_probs_fusion
except NameError:
    def get_probs_fusion(model, ds, idx, batch=64):
        """Fallback: softmax probability of class 1 from a fusion model for ``ds[idx]``, aligned with ``idx``.

        Defined only if the pipeline's ``get_probs_fusion`` is missing.

        Args:
            model: Fusion model called as ``model(input_ids=...,
                attention_mask=..., feats=...)`` and returning logits.
            ds: Dataset whose items provide ``input_ids``, ``attention_mask``
                and ``feats``.
            idx: Sequence of dataset indices to score.
            batch: Batch size used for inference.

        Returns:
            1-D NumPy array; element ``j`` is P(class 1) for ``idx[j]``.
            Empty when ``idx`` is empty.
        """
        from torch.utils.data import Subset

        if len(idx) == 0: return np.array([])
        # Iterate in the caller's order. SubsetRandomSampler would shuffle, so
        # the output could not be ranked or compared by position.
        ordered = Subset(ds, [int(i) for i in idx])
        dl = DataLoader(ordered, batch_size=batch, shuffle=False)
        model.eval(); out=[]
        with torch.no_grad():
            for b in dl:
                ids=b['input_ids'].to(DEVICE); am=b['attention_mask'].to(DEVICE); feats=b['feats'].to(DEVICE)
                logits = model(input_ids=ids, attention_mask=am, feats=feats)
                p = torch.softmax(logits, dim=-1)[:, 1].cpu().numpy()
                out.append(p)
        return np.concatenate(out) if len(out)>0 else np.array([])

# ---------- Main sanity runner ----------
def run_sanity_checks(
    seeds=(7, 11, 19),
    models=("xlmr", "simple", "gated"),   # choose any subset
    n_boot=1000
):
    """Seed-robustness check of few-shot domain adaptation on the OOD data.

    For every domain in the global ``ood_df`` and every seed this function
      1. loads fresh models from the domain's DAPT directory (an XLM-R
         classifier, plus Simple/Gated fusion models when requested),
      2. splits the domain rows into adapt/test indices with
         ``_pick_uncertain_by_entropy`` on the XLM-R probabilities,
      3. adapts each requested model on the adapt indices and scores it on
         the test indices (macro F1, bootstrap interval, per-class F1).

    The three result tables are written as CSVs under ``Config.RUN_DIR``
    (falling back to ``/content/emc_run``) and printed.

    Args:
        seeds: iterable of int seeds; one adapt/evaluate run per seed.
        models: subset of ``{"xlmr", "simple", "gated"}``. ``"simple"`` and
            ``"gated"`` are skipped when the globals ``simple_scaler`` /
            ``gated_scaler`` are not defined.
        n_boot: forwarded to ``_bootstrap_ci_macro_f1``.

    Returns:
        None.
    """
    model_labels = {
        "xlmr": "XLM-R Only (DA)",
        "simple": "Simple Fusion (DA)",
        "gated": "Gated Fusion (DA)",
    }
    # Added to the run seed so each model family gets its own bootstrap seed.
    boot_seed_offset = {"xlmr": 42, "simple": 99, "gated": 123}

    rows_summary = []
    rows_ci = []
    rows_perclass = []

    def _record_scores(key, domain, seed, y_true, y_pred, per_seed_scores):
        """Store macro F1, bootstrap interval and per-class F1 for one run."""
        label = model_labels[key]
        per_seed_scores[key].append(f1_score(y_true, y_pred, average='macro'))
        lo, hi = _bootstrap_ci_macro_f1(y_true, y_pred, n_boot=n_boot,
                                        seed=seed + boot_seed_offset[key])
        rows_ci.append({"Domain": domain, "Model": label, "Seed": seed,
                        "Macro F1 (95% CI)": f"[{lo:.3f}, {hi:.3f}]"})
        per_class = f1_score(y_true, y_pred, average=None,
                             labels=np.arange(Config.NUM_LABELS))
        for cls_id, cls_f1 in enumerate(per_class):
            rows_perclass.append({"Domain": domain, "Model": label, "Seed": seed,
                                  "Class": cls_id, "F1": float(cls_f1)})

    for domain in ood_df['domain'].unique():
        print(f"\n=== Sanity checks for domain: {domain} ===")
        ddf = ood_df[ood_df['domain']==domain].copy().reset_index(drop=True)
        y_all = ddf['label'].astype(int).values
        n = len(ddf)

        have_simple = ("simple" in models) and ('simple_scaler' in globals())
        have_gated  = ("gated"  in models) and ('gated_scaler'  in globals())

        # Build fixed datasets once per domain. Hand-crafted features are
        # only extracted when at least one fusion model will consume them.
        ds_simple_all = None; ds_gated_all = None
        if have_simple or have_gated:
            feats_ood_raw = fe.extract_df(ddf)
            if have_simple:
                ds_simple_all = TextFeatDataset(ddf, tokenizer, feats=simple_scaler.transform(feats_ood_raw))
            if have_gated:
                ds_gated_all = TextFeatDataset(ddf, tokenizer, feats=gated_scaler.transform(feats_ood_raw))
        ds_all = TextFeatDataset(ddf, tokenizer, feats=None)

        # Adapt-set size: a ratio of the domain with a floor of 64 rows,
        # capped at half the domain so a test split always remains.
        k_target = max(int(Config.DA_ADAPT_RATIO * n), 64)
        k = min(k_target, max(1, n//2))

        per_seed_scores = {m: [] for m in models}

        for seed in seeds:
            seed_everything(seed)

            dapt_dir = os.path.join(Config.DAPT_SAVE_DIR, domain.replace(" ", "_"))

            # XLM-R model (DAPT encoder + FT head). Always loaded because it
            # also drives the adapt/test split below.
            xlm_r = safe_load_seqcls_from_dapt(dapt_dir, Config.NUM_LABELS, Config.MODEL_NAME)
            load_ft_classifier_only_into(xlm_r, Config.XLM_R_MODEL_PATH)
            xlm_r.to(DEVICE)

            # Fusion models (fresh instances per seed)
            simple_local = gated_local = None
            if have_simple:
                simple_local = SimpleFusion(Config.MODEL_NAME, Config.FEAT_DIM)
                simple_local.to(DEVICE)
                load_dapt_into_fusion_encoder(simple_local, dapt_dir, Config.MODEL_NAME)
            if have_gated:
                gated_local  = GatedFusion(Config.MODEL_NAME, Config.FEAT_DIM, feat_proj=64)
                gated_local.to(DEVICE)
                load_dapt_into_fusion_encoder(gated_local,  dapt_dir, Config.MODEL_NAME)

            # One split per seed, shared by every model.
            # NOTE(review): p_all is indexed positionally, so this assumes
            # get_probs_transformer returns probabilities in the order of the
            # indices it is given — confirm against its definition.
            p_all = get_probs_transformer(xlm_r, ds_all, np.arange(n), batch=Config.BATCH_SIZE)
            adapt_idx, test_idx = _pick_uncertain_by_entropy(lambda idxs: p_all[idxs], n=n, k=k)

            # ---- Adapt & evaluate: XLM-R
            if "xlmr" in models:
                unfreeze_top_n(xlm_r, Config.DA_UNFREEZE_TOP_N)
                adapt_transformer_fewshot(xlm_r, ds_all, adapt_idx, lr=Config.DA_LR, epochs=Config.DA_EPOCHS)
                y_pred = _predict_transformer(xlm_r, ds_all, test_idx, batch=Config.BATCH_SIZE)
                _record_scores("xlmr", domain, seed, y_all[test_idx], y_pred, per_seed_scores)

            # ---- Adapt & evaluate: Simple Fusion, then Gated Fusion
            fusion_runs = (
                ("simple", simple_local, ds_simple_all),
                ("gated", gated_local, ds_gated_all),
            )
            for key, fusion_model, fusion_ds in fusion_runs:
                if fusion_model is None:
                    continue  # not requested, or its scaler is unavailable
                unfreeze_top_n(fusion_model, Config.DA_UNFREEZE_TOP_N)
                adapt_fusion_fewshot(fusion_model, fusion_ds, adapt_idx, lr=Config.DA_LR, epochs=Config.DA_EPOCHS)
                y_pred = _predict_fusion(fusion_model, fusion_ds, test_idx, batch=Config.BATCH_SIZE)
                _record_scores(key, domain, seed, y_all[test_idx], y_pred, per_seed_scores)

        # Summarize mean±std across seeds
        for m in models:
            if len(per_seed_scores[m]) == 0: continue
            arr = np.array(per_seed_scores[m], dtype=float)
            # Sample std is undefined for a single run; report nan explicitly
            # instead of triggering numpy's degrees-of-freedom warning.
            std = arr.std(ddof=1) if arr.size > 1 else float("nan")
            rows_summary.append({
                "Domain": domain,
                "Model": model_labels[m],
                "Seeds": f"{list(seeds)}",
                "Macro F1 (mean ± std)": f"{arr.mean():.4f} ± {std:.4f}"
            })

    # Explicit columns keep sort_values working even when nothing was
    # collected (e.g. no domains, or no requested model was available).
    df_summary = (
        pd.DataFrame(rows_summary, columns=["Domain", "Model", "Seeds", "Macro F1 (mean ± std)"])
        .sort_values(["Domain", "Model"]).reset_index(drop=True)
    )
    df_ci = (
        pd.DataFrame(rows_ci, columns=["Domain", "Model", "Seed", "Macro F1 (95% CI)"])
        .sort_values(["Domain", "Model", "Seed"]).reset_index(drop=True)
    )
    df_perclass = (
        pd.DataFrame(rows_perclass, columns=["Domain", "Model", "Seed", "Class", "F1"])
        .sort_values(["Domain", "Model", "Seed", "Class"]).reset_index(drop=True)
    )

    # Save & display
    out_dir = getattr(Config, "RUN_DIR", os.path.join("/content", "emc_run"))
    os.makedirs(out_dir, exist_ok=True)
    summary_path = os.path.join(out_dir, "sanity_seed_robustness_summary.csv")
    ci_path = os.path.join(out_dir, "sanity_bootstrap_cis.csv")
    perclass_path = os.path.join(out_dir, "sanity_per_class_f1.csv")
    df_summary.to_csv(summary_path, index=False)
    df_ci.to_csv(ci_path, index=False)
    df_perclass.to_csv(perclass_path, index=False)

    print("\n=== Seed robustness (Macro F1 mean ± std) ===")
    print(df_summary.to_string(index=False))
    print("\n=== Bootstrap 95% CI (per seed) ===")
    print(df_ci.to_string(index=False))
    print("\n=== Per-class F1 (per seed) ===")
    print(df_perclass.to_string(index=False))
    print(f"\n✅ Saved:\n- {summary_path}\n- {ci_path}\n- {perclass_path}")

# ---------- Run the sanity checks ----------
# Model families to evaluate; drop the entries you do not need,
# e.g. ("xlmr", "simple").
models_to_check = ("xlmr", "simple", "gated")
run_sanity_checks(
    seeds=(7, 11, 19),
    models=models_to_check,
    n_boot=1000,
)


In [None]:
# Create bar charts (with error bars) for your v2 seed-robust Macro F1 results
# Requirements followed: matplotlib only, one chart per figure, no explicit colors/styles.

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# caas_jupyter_tools is only installed by a later cell in this notebook, so a
# fresh-kernel run would fail on a hard import; fall back to a plain-text table.
try:
    from caas_jupyter_tools import display_dataframe_to_user
except ImportError:
    def display_dataframe_to_user(name, dataframe):
        """Plain-text stand-in: print ``name`` as a banner, then the frame."""
        print(f"\n=== {name} ===")
        print(dataframe.to_string(index=False))

# ---- Seed-robust results (mean ± std across seeds {7,11,19}) ----
# Values are entered by hand here rather than read back from a results file.
data = [
    # Domain, Model, mean, std
    ("Biomedical Engineering", "Gated Fusion (DA)", 0.4631, 0.0956),
    ("Biomedical Engineering", "Simple Fusion (DA)", 0.4773, 0.0628),
    ("Biomedical Engineering", "XLM-R Only (DA)", 0.4582, 0.0355),
    ("Chemical Engineering",   "Gated Fusion (DA)", 0.5036, 0.0696),
    ("Chemical Engineering",   "Simple Fusion (DA)", 0.4842, 0.0412),
    ("Chemical Engineering",   "XLM-R Only (DA)", 0.4655, 0.0329),
    ("Software Engineering",   "Gated Fusion (DA)", 0.5053, 0.0635),
    ("Software Engineering",   "Simple Fusion (DA)", 0.4882, 0.0526),
    ("Software Engineering",   "XLM-R Only (DA)", 0.4689, 0.0440),
]

df = pd.DataFrame(data, columns=["Domain", "Model", "MacroF1_mean", "MacroF1_std"])

# Show the data as a table in the UI
display_dataframe_to_user("Seed-robust Macro F1 (mean ± std)", df)

# Folder that receives every saved figure; created (with parents) if absent.
out_dir = "/mnt/data/figures"
os.makedirs(name=out_dir, exist_ok=True)

# --- Helper to annotate bars ---
def annotate_bars(ax, rects, values):
    """Label each bar with its value (three decimals) just above the bar top.

    ``rects`` and ``values`` are paired positionally; extra items in the
    longer one are ignored.
    """
    for bar, value in zip(rects, values):
        top = bar.get_height()
        centre_x = bar.get_x() + bar.get_width() / 2
        ax.annotate(
            f"{value:.3f}",
            xy=(centre_x, top),
            xytext=(0, 3),                # nudge the text 3 points upward
            textcoords="offset points",
            ha='center',
            va='bottom',
            fontsize=9,
        )

# --- Create a separate figure per domain ---
# Fixed model order so every chart lists its bars identically; it does not
# depend on the domain, so it is built once outside the loop.
model_order = ["XLM-R Only (DA)", "Simple Fusion (DA)", "Gated Fusion (DA)"]
saved_paths = []
for domain in df["Domain"].unique():
    sub = df[df["Domain"] == domain].copy()
    sub["Model"] = pd.Categorical(sub["Model"], categories=model_order, ordered=True)
    sub = sub.sort_values("Model")

    fig, ax = plt.subplots(figsize=(7, 5))
    x = np.arange(len(sub))
    bars = ax.bar(x, sub["MacroF1_mean"].values, yerr=sub["MacroF1_std"].values, capsize=5)
    ax.set_xticks(x)
    ax.set_xticklabels(sub["Model"].tolist(), rotation=20, ha="right")
    ax.set_ylim(0.0, 1.0)
    ax.set_ylabel("Macro F1")
    ax.set_title(f"Out-of-Domain Performance after DA — {domain}")
    annotate_bars(ax, bars, sub["MacroF1_mean"].values)

    fname = f"{domain.lower().replace(' ', '_')}_da_seed_robust_bars.png"
    fpath = os.path.join(out_dir, fname)
    # Act on this figure explicitly rather than on pyplot's "current" figure.
    fig.tight_layout()
    fig.savefig(fpath, dpi=200, bbox_inches="tight")
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures.
    plt.close(fig)
    saved_paths.append(fpath)

# Also produce a single combined figure with all domains grouped by model.
# (Still one chart; each model gets one bar per domain, with error bars.)

model_order = ["XLM-R Only (DA)", "Simple Fusion (DA)", "Gated Fusion (DA)"]
domains = ["Biomedical Engineering", "Chemical Engineering", "Software Engineering"]

fig, ax = plt.subplots(figsize=(8, 5))
width = 0.22
x = np.arange(len(model_order))

for i, domain in enumerate(domains):
    sub = df[df["Domain"] == domain].copy()
    sub["Model"] = pd.Categorical(sub["Model"], categories=model_order, ordered=True)
    sub = sub.sort_values("Model")
    means = sub["MacroF1_mean"].values
    stds = sub["MacroF1_std"].values

    # Centre each group of bars on its tick for any number of domains
    # (for three domains this equals the offsets -width, 0, +width).
    offset = (i - (len(domains) - 1) / 2) * width
    rects = ax.bar(x + offset, means, width, yerr=stds, capsize=4, label=domain)
    annotate_bars(ax, rects, means)

ax.set_xticks(x)
ax.set_xticklabels(model_order, rotation=15, ha="right")
ax.set_ylim(0.0, 1.0)
ax.set_ylabel("Macro F1")
ax.set_title("Out-of-Domain Performance after DA — Seed-Robust Means ± SD")
ax.legend()

combined_path = os.path.join(out_dir, "combined_da_seed_robust_bars.png")
# Act on this figure explicitly rather than on pyplot's "current" figure.
fig.tight_layout()
fig.savefig(combined_path, dpi=200, bbox_inches="tight")
plt.show()
plt.close(fig)

saved_paths.append(combined_path)

# Last expression of the cell: shows the list of written figure files.
saved_paths


In [None]:
# NOTE(review): this install sits after the plotting cell that imports
# caas_jupyter_tools, so on a fresh kernel it must be run before that cell
# (or moved into the notebook's first install cell) for the import to succeed.
!pip install caas_jupyter_tools