In [None]:
#Dapt Implementation
import os, re, math, json, time, random, warnings
from typing import List
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    get_cosine_schedule_with_warmup
)

# Report each distinct UserWarning a single time instead of on every occurrence.
warnings.filterwarnings("once", category=UserWarning)


# Progress bars are optional: fall back to a do-nothing stand-in when tqdm cannot be imported.
try:
    from tqdm import tqdm
except Exception:
    _HAS_TQDM = False

    def tqdm(*args, **kwargs):
        """Accept any tqdm-style arguments and return a bar whose methods are no-ops."""

        class _Dummy:
            """Progress-bar stub exposing the methods this notebook calls."""

            def update(self, *a, **k):
                pass

            def set_postfix(self, *a, **k):
                pass

            def set_postfix_str(self, *a, **k):
                pass

            def close(self):
                pass

        return _Dummy()
else:
    _HAS_TQDM = True


PARTS_DIR = "."
PART_BASENAMES = ["PartOne", "PartTwo", "PartThree", "PartFour"]
# Extensions probed, in priority order, for basenames given without a suffix.
EXT_ORDER = [".xlsx", ".csv", ".tsv", ".xls"]


def resolve_parts(parts_dir: str, basenames: List[str]) -> List[str]:
    """Map each basename to an existing file path under ``parts_dir``.

    A basename that already carries a suffix must exist as given. A basename
    without a suffix is tried with every extension in ``EXT_ORDER`` and the
    first existing candidate wins.

    Returns:
        Resolved paths as strings, in the order of ``basenames``.

    Raises:
        FileNotFoundError: when no matching file exists for a basename.
    """
    root = Path(parts_dir)
    resolved: List[str] = []
    for base in basenames:
        target = root / base
        if target.suffix:
            # Explicit extension: accept the file only exactly as named.
            if not target.exists():
                raise FileNotFoundError(f"[DAPT] File not found: {target}")
            resolved.append(str(target))
            continue

        candidates = [target.with_suffix(ext) for ext in EXT_ORDER]
        match = next((cand for cand in candidates if cand.exists()), None)
        if match is None:
            tried = ", ".join(str(cand) for cand in candidates)
            raise FileNotFoundError(f"[DAPT] Could not find a file for '{base}'. Tried: {tried}")
        resolved.append(str(match))
    return resolved

# Resolve the corpus part files once, at cell start; raises if any part is missing.
PART_PATHS: List[str] = resolve_parts(PARTS_DIR, PART_BASENAMES)
print(f"[DAPT] Using parts: {PART_PATHS}")


# Optional file whose texts are removed from the DAPT corpus when it exists on disk.
TASKB_LEAK_GUARD = "taskB_youtube_raw.cleaned.csv"

# Training budget. Steps are optimizer steps (counted after gradient accumulation).
TOTAL_STEPS       = 120_000
# Run validation every this many optimizer steps (and at the final step).
VAL_EVERY_STEPS   = 5_000
# Save a LoRA adapter snapshot every this many optimizer steps.
CKPT_EVERY_STEPS  = 20_000
# Micro-batches accumulated per optimizer step.
GRAD_ACCUM_STEPS  = 2


# Tokenizer truncation length and DataLoader settings.
MAX_LEN     = 160  
BATCH_SIZE  = 16
NUM_WORKERS = 0
PIN_MEMORY  = torch.cuda.is_available()


# Optimizer / schedule settings. LR is lowered later if LoRA is not used.
WARMUP_RATIO = 0.06
WEIGHT_DECAY = 0.01
LR           = 1e-4  


# LoRA settings; MERGE_AND_SAVE additionally writes a merged backbone after training.
USE_LORA        = True
MERGE_AND_SAVE  = True
LORA_R          = 8
LORA_ALPHA      = 16
LORA_DROPOUT    = 0.10
USE_GRAD_CHKPT  = False  


# Module-name patterns targeted by LoRA; dict.fromkeys de-duplicates while keeping order.
_LORA_BASE_TARGETS  = ["query","key","value","dense"]
_LORA_EXTRA_TARGETS = ["q_proj","k_proj","v_proj","out_proj","intermediate.dense","output.dense"]
LORA_TARGETS        = list(dict.fromkeys(_LORA_BASE_TARGETS + _LORA_EXTRA_TARGETS))


# Validation split: a fraction of the corpus, clamped to [VAL_MIN, VAL_MAX] lines.
VAL_FRAC, VAL_MIN, VAL_MAX = 0.01, 2_000, 20_000


SEED = 42

# Output directories for the LoRA adapter and for the merged backbone.
SAVE_ADAPTER_DIR = "dapt_lora_adapter"
SAVE_MERGED_DIR  = "dapt_merged_backbone"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[DAPT] device={DEVICE}, steps={TOTAL_STEPS}, batch={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}")

def set_all_seeds(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    cuDNN is configured with ``benchmark=True`` and ``deterministic=False``,
    so seeding fixes the random streams but does not request deterministic
    cuDNN kernels.
    """
    import numpy as _np

    for seeder in (random.seed, _np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# http(s) links and bare "www." links, matched case-insensitively.
URL_RE  = re.compile(r'https?://\S+|www\.\S+', re.IGNORECASE)
# "@" followed by word characters.
MENT_RE = re.compile(r'@\w+')


def normalize_text(s: str) -> str:
    """Normalize one comment for MLM training.

    Steps, in order: coerce to ``str``, repair four known mojibake sequences,
    replace URLs with ``<url>`` and mentions with ``@user``, then collapse
    all whitespace runs to single spaces and strip the ends.
    """
    text = str(s)
    # Mis-decoded punctuation -> intended character; applied in this exact order.
    for broken, fixed in (("â€™", "’"), ("â€œ", "“"), ("â€\x9d", "”"), ("â€“", "–")):
        text = text.replace(broken, fixed)
    text = MENT_RE.sub("@user", URL_RE.sub("<url>", text))
    return re.sub(r"\s+", " ", text).strip()

# Column names tried first, in priority order, when locating the text column.
TEXT_COLNAME_PREFS = ["CommentText","comment_text","text","raw_text","Comment","body","content"]

def guess_text_column(df: pd.DataFrame) -> str:
    """Pick the column of ``df`` most likely to hold free text.

    A column named in ``TEXT_COLNAME_PREFS`` wins (first match in that order).
    Otherwise the text-like column with the largest mean string length is
    returned.

    Raises:
        ValueError: if no preferred name matches and no text-like column exists.
    """
    for c in TEXT_COLNAME_PREFS:
        if c in df.columns:
            return c
    # Accept both legacy object columns and pandas' dedicated string dtypes;
    # comparing dtype to "object" alone misses the latter.
    text_cols = [
        c for c in df.columns
        if pd.api.types.is_object_dtype(df[c]) or pd.api.types.is_string_dtype(df[c])
    ]
    if not text_cols:
        raise ValueError("No text-like column found.")
    mean_len = {c: df[c].astype(str).str.len().mean() for c in text_cols}
    return max(mean_len, key=mean_len.get)

def _read_excel(path: Path) -> pd.DataFrame:
    """Read an Excel file with the openpyxl engine, retrying with pandas' default engine on any error."""
    try:
        frame = pd.read_excel(path, engine="openpyxl")
    except Exception:
        frame = pd.read_excel(path)
    return frame

def load_table_texts(path: str) -> List[str]:
    """Load one table file and return its normalized, non-blank text lines.

    Excel files go through ``_read_excel``; ``.tsv`` is read tab-separated and
    every other extension as comma-separated, both as UTF-8 with malformed
    rows skipped. The text column is chosen by ``guess_text_column``.
    """
    src = Path(path)
    suffix = src.suffix.lower()
    if suffix in (".xlsx", ".xls"):
        df = _read_excel(src)
    else:
        extra = {"sep": "\t"} if suffix == ".tsv" else {}
        df = pd.read_csv(src, encoding="utf-8", on_bad_lines="skip", **extra)

    col = guess_text_column(df)
    print(f"[DAPT] using text column '{col}' from {src.name} (columns={list(df.columns)})")

    texts: List[str] = []
    for raw in df[col].dropna().astype(str):
        # Blank rows are dropped before normalization.
        if raw.strip():
            texts.append(normalize_text(raw))
    return texts

def as_mlm_checkpoint(backbone_name: str) -> str:
    """Return the checkpoint name to load for masked-LM training.

    Names containing ``twitter-roberta-base-sentiment``, or containing both
    ``sentiment`` and ``cardiffnlp`` (case-insensitive), are redirected to
    ``cardiffnlp/twitter-roberta-base``; any other name is returned unchanged.
    """
    lowered = backbone_name.lower()
    is_twitter_sentiment = "twitter-roberta-base-sentiment" in lowered
    is_cardiff_sentiment = "sentiment" in lowered and "cardiffnlp" in lowered
    if is_twitter_sentiment or is_cardiff_sentiment:
        return "cardiffnlp/twitter-roberta-base"
    return backbone_name


# Build the DAPT corpus: load every part, drop lines that also occur in the
# Task-B file (when present), shuffle, and split into train / validation.
set_all_seeds(SEED)

all_texts: List[str] = []
for p in PART_PATHS:
    all_texts.extend(load_table_texts(p))
print(f"[DAPT] loaded {len(all_texts):,} lines from {len(PART_PATHS)} part(s)")

if TASKB_LEAK_GUARD and Path(TASKB_LEAK_GUARD).exists():
    tb = load_table_texts(TASKB_LEAK_GUARD)
    # Exact-match removal on the normalized text.
    leak_set = set(tb)
    before = len(all_texts)
    all_texts = [t for t in all_texts if t not in leak_set]
    print(f"[DAPT] leak-guard removed {before - len(all_texts):,} lines that matched TaskB")
else:
    print("[DAPT] leak-guard skipped (TaskB file not found)")

random.shuffle(all_texts)
n_total = len(all_texts)
n_val   = min(max(int(n_total*VAL_FRAC), VAL_MIN), VAL_MAX)
# VAL_MIN can exceed a small corpus; never let validation take more than half,
# otherwise the training split can end up empty.
n_val   = min(n_val, n_total // 2)
val_texts = all_texts[:n_val]; trn_texts = all_texts[n_val:]
print(f"[DAPT] split -> train={len(trn_texts):,}, val={len(val_texts):,}")

# The training loader drops incomplete batches, so fewer than BATCH_SIZE lines
# would yield no batches and the step-driven training loop would never finish.
if len(trn_texts) < BATCH_SIZE:
    raise ValueError(
        f"[DAPT] Not enough training lines ({len(trn_texts)}) for batch size {BATCH_SIZE}."
    )


# Tokenizer and masked-LM model are both loaded from the base checkpoint.
BACKBONE_NAME = "cardiffnlp/twitter-roberta-base"  
TOKENIZER = AutoTokenizer.from_pretrained(BACKBONE_NAME, use_fast=True)

# Add the placeholders emitted by normalize_text() when the tokenizer maps them to <unk>.
added_tokens = 0
for tok in ["<url>","@user"]:
    if TOKENIZER.convert_tokens_to_ids(tok) == TOKENIZER.unk_token_id:
        TOKENIZER.add_tokens([tok]); added_tokens += 1
if added_tokens: print(f"[DAPT] tokenizer extended by {added_tokens} tokens")

mlm_name = as_mlm_checkpoint(BACKBONE_NAME)
mlm = AutoModelForMaskedLM.from_pretrained(mlm_name)   
# Grow the embedding matrix only when tokens were actually added.
# NOTE(review): with LoRA enabled the embedding layer is not in LORA_TARGETS, so rows
# for newly added tokens are presumably never trained — confirm this is intended.
if added_tokens: mlm.resize_token_embeddings(len(TOKENIZER))
if USE_GRAD_CHKPT and hasattr(mlm, "gradient_checkpointing_enable"):
    mlm.gradient_checkpointing_enable()


# PEFT is optional: when it cannot be imported, training falls back to the full model.
peft_available = True
try:
    from peft import LoraConfig, get_peft_model
    try:
        from peft import TaskType
        # Use the first task type this PEFT version defines, in this preference order.
        # NOTE(review): if MASKED_LM is absent this falls back to unrelated task types —
        # confirm the resulting PEFT wrapper is appropriate for an MLM model.
        _TASKTYPE = (
            getattr(TaskType, "MASKED_LM", None)
            or getattr(TaskType, "TOKEN_CLS", None)
            or getattr(TaskType, "FEATURE_EXTRACTION", None)
            or getattr(TaskType, "SEQ_CLS", None)
            or getattr(TaskType, "CAUSAL_LM", None)
        )
    except Exception:
        TaskType = None
        _TASKTYPE = None
except Exception as e:
    peft_available = False
    print("[DAPT] PEFT not available; falling back to full-model. Err:", e)

if USE_LORA and peft_available:
    lora_kwargs = dict(
        r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
        target_modules=LORA_TARGETS, bias="none",
    )
    # task_type is passed only when one was resolved above.
    if _TASKTYPE is not None:
        lora_kwargs["task_type"] = _TASKTYPE
    lcfg = LoraConfig(**lora_kwargs)
    mlm = get_peft_model(mlm, lcfg)
    # Re-apply gradient checkpointing on the wrapped model.
    if USE_GRAD_CHKPT and hasattr(mlm, "gradient_checkpointing_enable"):
        mlm.gradient_checkpointing_enable()
    print("[DAPT] LoRA attached for MLM:", LORA_TARGETS)
else:
    # Full-model training: unfreeze everything and cap the learning rate.
    for p in mlm.parameters(): p.requires_grad = True
    if LR > 1e-5:
        print(f"[DAPT] Lowering LR for full-model DAPT from {LR:g} -> 5e-6")
        LR = 5e-6

mlm.to(DEVICE)


def count_params(m):
    """Return ``(total, trainable)`` element counts over ``m.parameters()``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    """
    total = 0
    train = 0
    for param in m.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            train += size
    return total, train
# Report model size in millions of parameters (total vs. trainable).
_tot, _tr = count_params(mlm)
print(f"[DAPT] params total={_tot/1e6:.1f}M | trainable={_tr/1e6:.2f}M")


class LineDataset(Dataset):
    """Map-style dataset over a list of text lines; each item is the raw string."""

    def __init__(self, lines: List[str]):
        self.lines = lines

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        return self.lines[idx]

train_ds = LineDataset(trn_texts); val_ds = LineDataset(val_texts)
# MLM collator masking 15% of tokens.
collator = DataCollatorForLanguageModeling(tokenizer=TOKENIZER, mlm=True, mlm_probability=0.15)

def collate_fn(batch_lines: List[str]):
    """Tokenize a batch of raw lines and apply MLM masking.

    Lines are padded to the longest in the batch and truncated to MAX_LEN.
    Only ``input_ids`` are handed to the MLM collator; the tokenizer's
    ``attention_mask`` is then written into the collator output.
    Masking is random on every call, including for validation batches.
    """
    enc = TOKENIZER(batch_lines, padding=True, truncation=True, max_length=MAX_LEN, return_tensors="pt")
    features = [{"input_ids": ids} for ids in enc["input_ids"]]
    masked = collator(features)
    # Use the tokenizer's padding mask rather than whatever the collator produced.
    masked["attention_mask"] = enc["attention_mask"]
    return masked

# Training batches are shuffled and incomplete final batches are dropped.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn, drop_last=True)
# Validation uses double the batch size and keeps order.
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE*2, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=collate_fn)


from torch.optim import AdamW
# Optimizer, cosine schedule with warmup, and AMP loss scaler for the DAPT loop.
optim = AdamW(mlm.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Warmup is a fraction of the optimizer-step budget, with a floor of 10 steps.
warmup_steps = max(10, int(TOTAL_STEPS * WARMUP_RATIO))
sched = get_cosine_schedule_with_warmup(optim, num_warmup_steps=warmup_steps, num_training_steps=TOTAL_STEPS)
# Enable loss scaling only when CUDA is available, matching the autocast gate
# used in the training loop.
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())

os.makedirs(SAVE_ADAPTER_DIR, exist_ok=True); os.makedirs(SAVE_MERGED_DIR, exist_ok=True)


# Loop state: best validation loss so far, optimizer-step counter, last log timestamp.
best_val_loss = float("inf"); global_step = 0; last_log_t = time.time()

def eval_mlm(model) -> float:
    """Return the mean MLM loss of ``model`` over ``val_loader``.

    Each batch loss is weighted by its number of sequences. The model is put
    in eval mode for the pass and switched to train mode before returning.
    Returns 0.0 when the loader yields no batches.
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            ids, msk, lab = (
                batch[key].to(DEVICE, non_blocking=True)
                for key in ("input_ids", "attention_mask", "labels")
            )
            out = model(input_ids=ids, attention_mask=msk, labels=lab)
            bsz = ids.size(0)
            loss_sum += out.loss.item() * bsz
            n_seen += bsz
    model.train()
    return loss_sum / max(1, n_seen)


# ---- DAPT training loop: runs until TOTAL_STEPS optimizer steps have been taken ----
print(f"[DAPT] Config: steps={TOTAL_STEPS}, val_every={VAL_EVERY_STEPS}, ckpt_every={CKPT_EVERY_STEPS}, "
      f"batch={BATCH_SIZE}×accum{GRAD_ACCUM_STEPS}, max_len={MAX_LEN}, lr={LR:g}, warmup={WARMUP_RATIO}, "
      f"lora={USE_LORA}, r={LORA_R}, alpha={LORA_ALPHA}, dropout={LORA_DROPOUT}, grad_ckpt={USE_GRAD_CHKPT}")

mlm.train(); torch.cuda.empty_cache()
print(f"[DAPT] START | TOTAL_STEPS={TOTAL_STEPS:,} | warmup_steps={warmup_steps:,} | LR={LR:g}")

# The progress bar is created only when the real tqdm was imported.
bar = tqdm(total=TOTAL_STEPS, desc="[DAPT] steps", unit="step") if _HAS_TQDM else None
# `accum` counts micro-batches and is not reset between passes over the loader.
accum = 0; done = False
# NOTE(review): if train_loader yields no batches, `done` is never set and this loop does not terminate.
while not done:
    for batch in train_loader:
        ids = batch["input_ids"].to(DEVICE, non_blocking=True)
        msk = batch["attention_mask"].to(DEVICE, non_blocking=True)
        lab = batch["labels"].to(DEVICE, non_blocking=True)

        # Mixed precision is active only when CUDA is available.
        with torch.amp.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            out = mlm(input_ids=ids, attention_mask=msk, labels=lab)
            loss = out.loss

        # Scale the loss down so accumulated gradients average over the micro-batches.
        scaler.scale(loss / GRAD_ACCUM_STEPS).backward(); accum += 1
        if accum % GRAD_ACCUM_STEPS == 0:
            # Unscale before clipping so the norm threshold applies to true gradients.
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(mlm.parameters(), 1.0)
            scaler.step(optim); scaler.update()
            optim.zero_grad(set_to_none=True); sched.step()
            global_step += 1

            # Logging: the reported loss is that of the most recent micro-batch only.
            if bar is not None:
                bar.update(1)
                if global_step % 50 == 0:
                    bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{sched.get_last_lr()[0]:.2e}")
            elif global_step % 100 == 0:
                dt = time.time() - last_log_t
                print(f"[DAPT] step {global_step:>6}/{TOTAL_STEPS} | loss={loss.item():.4f} | {dt:.1f}s/100 steps")
                last_log_t = time.time()

            # Periodic validation, plus one at the final step.
            if global_step % VAL_EVERY_STEPS == 0 or global_step == TOTAL_STEPS:
                val_loss = eval_mlm(mlm)
                # Cap the exponent at 20 to keep the reported perplexity finite.
                ppl = math.exp(min(20.0, val_loss))
                # Count as an improvement only when better by more than 1e-4.
                improved = val_loss < best_val_loss - 1e-4
                if improved: best_val_loss = val_loss
                msg = f"[DAPT]   VAL @ {global_step}: loss={val_loss:.4f} | ppl≈{ppl:.2f} | {'*BEST*' if improved else ''}"
                print(msg); 
                if bar is not None: bar.set_postfix_str(msg.replace("[DAPT]   ", ""))

            # Periodic adapter snapshot (LoRA runs only); saved regardless of validation
            # result, and a failed save is reported without stopping training.
            if global_step % CKPT_EVERY_STEPS == 0 and USE_LORA and peft_available:
                try:
                    mlm.save_pretrained(SAVE_ADAPTER_DIR); TOKENIZER.save_pretrained(SAVE_ADAPTER_DIR)
                    print(f"[DAPT]   adapter snapshot saved -> {SAVE_ADAPTER_DIR}")
                except Exception as e:
                    print("[DAPT]   snapshot save failed:", repr(e))

            if global_step >= TOTAL_STEPS:
                done = True; break

if bar is not None: bar.close()
print(f"[DAPT] DONE at step {global_step} | best_val_loss={best_val_loss:.4f}")


# Persist the results of a LoRA run: first the adapter (with tokenizer), then,
# when requested, the LoRA-merged backbone. Each save is best-effort and only
# reports a failure.
if USE_LORA and peft_available:
    try:
        mlm.save_pretrained(SAVE_ADAPTER_DIR)
        TOKENIZER.save_pretrained(SAVE_ADAPTER_DIR)
        print(f"[DAPT] adapter saved -> {SAVE_ADAPTER_DIR}")
    except Exception as e:
        print("[DAPT] adapter save failed:", repr(e))

    if MERGE_AND_SAVE:
        try:
            from peft import PeftModel
            # Merge only when the model really is PEFT-wrapped; otherwise save it as is.
            if isinstance(mlm, PeftModel):
                merged = mlm.merge_and_unload()
            else:
                merged = mlm
            merged.save_pretrained(SAVE_MERGED_DIR)
            TOKENIZER.save_pretrained(SAVE_MERGED_DIR)
            print(f"[DAPT] merged backbone saved -> {SAVE_MERGED_DIR}")
        except Exception as e:
            print("[DAPT] merge-and-save failed:", repr(e))

# Summary of the run for later cells and for the log.
ADAPTATION_STATUS = dict(
    used_lora=bool(USE_LORA and peft_available),
    save_adapter_dir=os.path.abspath(SAVE_ADAPTER_DIR),
    save_merged_dir=os.path.abspath(SAVE_MERGED_DIR),
    total_steps=int(TOTAL_STEPS),
    best_val_loss=float(best_val_loss),
)
print("[DAPT] Adaptation Status:", json.dumps(ADAPTATION_STATUS, indent=2))


In [None]:
import math, torch, os
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", DEVICE)


# Output directories written by the DAPT / TAPT cells.
_DAPT_MERGED  = "dapt_merged_backbone"
_DAPT_ADAPTER = "dapt_lora_adapter"
_TAPT_MERGED  = "tapt_merged_backbone"
_TAPT_ADAPTER = "tapt_lora_adapter"


# PEFT is optional; without it LoRA adapter directories cannot be used.
try:
    from peft import PeftModel
    _PEFT_OK = True
except Exception:
    _PEFT_OK = False


# Base checkpoint that LoRA adapters are applied to, and the fallback used when
# no adapted backbone is found on disk (base classifier-ish backbone).
_BASE_MLM_BACKBONE = "cardiffnlp/twitter-roberta-base"
_DEFAULT_BACKBONE  = "cardiffnlp/twitter-roberta-base-sentiment-latest"


def select_backbone(peft_ok, isdir=os.path.isdir):
    """Choose which backbone and tokenizer source to load.

    Candidates are checked in priority order: DAPT merged, DAPT adapter,
    TAPT merged, TAPT adapter, then the stock sentiment checkpoint.
    An adapter directory is usable only when PEFT is importable; otherwise it
    is skipped with a message and the next candidate is tried, because the
    DAPT/TAPT cells write only adapter weights and tokenizer files there.

    Args:
        peft_ok: whether the ``peft`` package could be imported.
        isdir: predicate used to test whether a directory exists.

    Returns:
        ``(backbone_name, tokenizer_source, use_adapter, adapter_dir)``.
    """
    stages = ((_DAPT_MERGED, _DAPT_ADAPTER), (_TAPT_MERGED, _TAPT_ADAPTER))
    for merged_dir, adapter_dir in stages:
        if isdir(merged_dir):
            return merged_dir, merged_dir, False, None
        if isdir(adapter_dir):
            if peft_ok:
                return _BASE_MLM_BACKBONE, adapter_dir, True, adapter_dir
            print(f"[ADAPT] Found adapter '{adapter_dir}' but PEFT is not installed; skipping it.")
    return _DEFAULT_BACKBONE, _DEFAULT_BACKBONE, False, None


BACKBONE_NAME, _TOK_SRC, _USE_ADAPTER, _ADAPTER_DIR = select_backbone(_PEFT_OK)


# Load the tokenizer from the source chosen above, then print what was found
# on disk and what was selected.
TOKENIZER = AutoTokenizer.from_pretrained(_TOK_SRC, use_fast=True)

print("[ADAPT] merged dirs exist?  DAPT:", os.path.isdir(_DAPT_MERGED), "| TAPT:", os.path.isdir(_TAPT_MERGED))
print("[ADAPT] adapter dirs exist? DAPT:", os.path.isdir(_DAPT_ADAPTER), "| TAPT:", os.path.isdir(_TAPT_ADAPTER))
print("[ADAPT] PEFT available?:", _PEFT_OK)
print("[ADAPT] BACKBONE_NAME:", BACKBONE_NAME, "| TOKENIZER src:", _TOK_SRC, "| USE_ADAPTER:", _USE_ADAPTER)


def _gn_groups(C: int) -> int:
    """Return the largest of 8, 4, 2, 1 that divides ``C`` evenly."""
    return next((g for g in (8, 4, 2, 1) if C % g == 0), 1)

def get_act(name: str):
    """Build a fresh activation module: 'relu', 'silu' or 'mish'; any other name gives GELU."""
    cls_name = {'relu': 'ReLU', 'silu': 'SiLU', 'mish': 'Mish'}.get(name, 'GELU')
    return getattr(nn, cls_name)()

def get_norm(name: str, C: int):
    """Build a normalization module for ``C`` channels.

    'bn' -> BatchNorm1d, 'gn8' -> GroupNorm with the group count from
    ``_gn_groups``, 'ln' -> GroupNorm with a single group; any other name
    gives Identity.
    """
    if name == 'bn':
        return nn.BatchNorm1d(C)
    if name in ('gn8', 'ln'):
        n_groups = _gn_groups(C) if name == 'gn8' else 1
        return nn.GroupNorm(n_groups, C)
    return nn.Identity()

class SE1d(nn.Module):
    """Squeeze-and-excitation style channel gate for (batch, channels, length) tensors.

    Channels are averaged over the last dimension, passed through a two-layer
    bottleneck (SiLU, then sigmoid), and the resulting per-channel gate scales
    the input.
    """

    def __init__(self, C: int, r: int):
        super().__init__()
        # Bottleneck width is C // r, but never below one unit.
        bottleneck = max(1, C // r)
        self.fc1 = nn.Linear(C, bottleneck)
        self.fc2 = nn.Linear(bottleneck, C)

    def forward(self, x):
        squeezed = x.mean(dim=2)
        hidden = F.silu(self.fc1(squeezed))
        gate = torch.sigmoid(self.fc2(hidden))
        return x * gate.unsqueeze(-1)

class ResidualBranch(nn.Module):
    """Add a shortcut to ``core``: output is ``core(x)`` plus ``x`` or ``skip(x)``.

    When ``skip`` is an ``nn.Identity`` the input itself is added; otherwise
    ``skip`` is applied to the input first.
    """

    def __init__(self, core: nn.Module, skip: nn.Module):
        super().__init__()
        self.core = core
        self.skip = skip

    def forward(self, x):
        main = self.core(x)
        if isinstance(self.skip, nn.Identity):
            return main + x
        return main + self.skip(x)


def _autocast_off():
    """Return a context manager that disables autocast.

    Uses ``torch.amp.autocast`` for the 'cuda' device type when CUDA is
    available and 'cpu' otherwise; if that call fails, falls back to
    ``torch.cuda.amp.autocast``.
    """
    try:
        device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
        ctx = torch.amp.autocast(device_type=device_type, enabled=False)
    except Exception:
        ctx = torch.cuda.amp.autocast(enabled=False)
    return ctx


class CNNHead(nn.Module):
    """Multi-branch dilated 1-D CNN classification head over encoder hidden states.

    One convolutional branch is built per entry of ``kernels`` (paired with the
    matching entry of ``dilations``, or dilation 1 if there is none). Branch
    outputs are concatenated on the channel axis, optionally projected to
    ``filters`` channels, pooled over the sequence axis, passed through dropout
    and mapped to ``num_classes`` logits.

    Args:
        hidden: channel size of the incoming hidden states.
        num_classes: number of output logits.
        layers: conv layers per branch.
        filters: channel width after the branches are combined.
        branches: divisor used to derive the per-branch width (``filters // branches``).
        kernels: kernel size of each branch.
        dilations: dilation of each branch.
        dropout: dropout probability applied to the pooled features.
        pooling: one of 'attn', 'max', 'avg', 'gem', 'kmax'.
        act: activation name understood by ``get_act``.
        norm: normalization name understood by ``get_norm``.
        sep: use depthwise + pointwise convolutions instead of a grouped convolution.
        groups: requested group count for the grouped convolution.
        se_ratio: reduction ratio of the ``SE1d`` gate; falsy or <= 0 disables it.
        residual: wrap each branch in a ``ResidualBranch``.
        kmax_k: number of top activations averaged by 'kmax' pooling.
        gem_p: initial exponent of the learnable 'gem' pooling.

    Raises:
        ValueError: if ``pooling`` is not one of the supported names.
    """

    _POOLINGS = ('attn', 'max', 'avg', 'gem', 'kmax')

    def __init__(self,
                 hidden: int,
                 num_classes: int,
                 layers=2,
                 filters=256,
                 branches=3,
                 kernels=(1,5,7),
                 dilations=(2,8,8),
                 dropout=0.2,
                 pooling='attn',
                 act='gelu',
                 norm='bn',
                 sep=False,
                 groups=8,
                 se_ratio=4,
                 residual=True,
                 kmax_k=1,
                 gem_p=4.0):
        super().__init__()
        # Fail at construction: an unknown name used to reach the 'attn' path in
        # forward() without an attention layer having been created.
        if pooling not in self._POOLINGS:
            raise ValueError(f"Unknown pooling {pooling!r}; expected one of {self._POOLINGS}")
        self.cfg = dict(layers=layers, filters=filters, branches=branches, kernels=tuple(kernels),
                        dilations=tuple(dilations), dropout=dropout, pooling=pooling, act=act, norm=norm,
                        sep=sep, groups=groups, se_ratio=se_ratio, residual=residual, kmax_k=kmax_k, gem_p=gem_p)
        f_total = filters
        f_per = max(1, f_total // branches)

        self.branches = nn.ModuleList()
        for bi, k in enumerate(kernels):
            core = []
            in_ch = hidden
            d = dilations[bi] if bi < len(dilations) else 1
            # Padding that keeps the sequence length unchanged for odd kernel sizes.
            pad = ((k - 1) * d) // 2

            use_res = residual
            skip = None

            for li in range(layers):
                if sep:
                    # Depthwise conv, then activation + norm, then pointwise conv to f_per channels.
                    core.append(nn.Conv1d(in_ch, in_ch, kernel_size=k, padding=pad, dilation=d, groups=in_ch))
                    core.append(get_act(act)); core.append(get_norm(norm, in_ch))
                    core.append(nn.Conv1d(in_ch, f_per, kernel_size=1))
                else:
                    # Grouped conv: use the requested group count when it divides both
                    # channel sizes, otherwise the largest of 8/4/2/1 that does.
                    g_req = int(groups)
                    if (in_ch % g_req == 0) and (f_per % g_req == 0):
                        g_ok = g_req
                    else:
                        g_ok = 1
                        for g in (8,4,2,1):
                            if (in_ch % g == 0) and (f_per % g == 0):
                                g_ok = g
                                break
                    core.append(nn.Conv1d(in_ch, f_per, kernel_size=k, padding=pad, dilation=d, groups=g_ok))
                core.append(get_act(act))
                core.append(get_norm(norm, f_per))
                if se_ratio and se_ratio > 0:
                    core.append(SE1d(f_per, se_ratio))

                in_ch = f_per
                # The shortcut is created once, after the first layer, and maps the
                # branch input to the branch width when the two differ.
                if use_res and skip is None:
                    skip = nn.Identity() if hidden == f_per else nn.Conv1d(hidden, f_per, 1)

            branch_core = nn.Sequential(*core)
            self.branches.append(ResidualBranch(branch_core, skip if (use_res and skip is not None) else nn.Identity()) if use_res else branch_core)

        # The concatenated width follows the number of branches actually built
        # (one per kernel), which may differ from the `branches` argument.
        concat_ch = f_per * len(self.branches)
        self.proj = nn.Conv1d(concat_ch, f_total, kernel_size=1) if concat_ch != f_total else None

        self.pooling = pooling
        if self.pooling == 'attn':
            self.attn = nn.Linear(f_total, 1)
        elif self.pooling == 'gem':
            self.gem_p = nn.Parameter(torch.tensor(float(gem_p)))
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(f_total, num_classes)

    def _gem(self, x, eps=1e-6):
        """Generalized-mean pooling over the last axis with learnable exponent ``gem_p``."""
        p = self.gem_p
        x = x.clamp(min=eps).pow(p)
        x = x.mean(dim=2)
        return x.pow(1.0/p)

    def forward(self, last_hidden: torch.Tensor, attn_mask: torch.Tensor):
        """Return class logits for a batch of hidden states.

        ``last_hidden`` is transposed to channels-first for the convolutions and
        ``attn_mask`` marks padding positions with 0. The whole head runs in
        float32 with autocast disabled.
        """
        with _autocast_off():
            x = last_hidden.transpose(1, 2).contiguous().float()

            feats_list = [branch(x) for branch in self.branches]
            feats = torch.cat(feats_list, dim=1)
            if self.proj is not None:
                feats = self.proj(feats)

            # NOTE(review): 'max', 'avg' and 'gem' pool over every position,
            # including padding; only 'kmax' and 'attn' use attn_mask.
            if self.pooling == 'max':
                x_out = feats.amax(dim=2)
            elif self.pooling == 'avg':
                x_out = feats.mean(dim=2)
            elif self.pooling == 'gem':
                x_out = self._gem(feats)
            elif self.pooling == 'kmax':
                mask = (attn_mask == 0)[:, None, :]
                feats_m = feats.masked_fill(mask, torch.finfo(feats.dtype).min)
                # topk raises when k exceeds the sequence length, so clamp it.
                k = min(max(1, int(self.cfg.get('kmax_k', 1))), feats_m.size(2))
                vals, _ = torch.topk(feats_m, k, dim=2)
                x_out = vals.mean(dim=2)
            else:  # 'attn'
                feats_T = feats.transpose(1, 2)
                logits = self.attn(feats_T)
                # Padding positions get the most negative logit so softmax ignores them.
                mask = (attn_mask == 0).unsqueeze(-1)
                logits = logits.masked_fill(mask, torch.finfo(logits.dtype).min)
                w = torch.softmax(logits, dim=1)
                x_out = (feats_T * w).sum(dim=1)

            x_out = self.dropout(x_out)
            return self.out(x_out)

class Backbone(nn.Module):
    """Encoder wrapper that returns the last hidden state.

    Loads ``name`` with ``AutoModel``, optionally applies the LoRA adapter
    selected at module level, and resizes the token embeddings when the
    tokenizer's vocabulary size differs from the model's.
    """

    def __init__(self, name: str):
        super().__init__()
        
        self.model = AutoModel.from_pretrained(name)

        # Adapter loading is best-effort: a failure is printed and the plain base model is kept.
        # NOTE(review): the adapter directories are saved from AutoModelForMaskedLM-based
        # PEFT models in the DAPT/TAPT cells, while this wraps an AutoModel — confirm the
        # saved LoRA weights actually match this module tree and are not silently unused.
        if _USE_ADAPTER and _PEFT_OK and (_ADAPTER_DIR is not None) and os.path.isdir(_ADAPTER_DIR):
            try:
                self.model = PeftModel.from_pretrained(self.model, _ADAPTER_DIR)
                print("Loaded adapter from:", _ADAPTER_DIR)
            except Exception as _e:
                print("PEFT adapter load skipped:", _e)

        # Keep the embedding matrix in sync with the module-level TOKENIZER; any
        # error in this check is printed and ignored.
        try:
            vocab_model = self.model.get_input_embeddings().weight.shape[0]
            vocab_tok   = len(TOKENIZER)
            if vocab_tok != vocab_model and hasattr(self.model, "resize_token_embeddings"):
                print(f"[ADAPT] Resizing token embeddings: {vocab_model} -> {vocab_tok}")
                self.model.resize_token_embeddings(vocab_tok)
        except Exception as _e:
            print("Embedding resize check failed/skipped:", _e)

        # Hidden size exposed for building the head.
        self.hidden = self.model.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Run the encoder and return ``last_hidden_state``."""
        out = self.model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        return out.last_hidden_state

class SentimentModel(nn.Module):
    """Backbone encoder followed by a ``CNNHead`` classifier.

    The backbone is loaded from the module-level ``BACKBONE_NAME``. Gradient
    checkpointing is requested on the encoder when it supports it.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = Backbone(BACKBONE_NAME)
        head_cfg = dict(
            layers=2, filters=256, branches=3,
            kernels=(1,5,7), dilations=(2,8,8),
            dropout=0.2, pooling='attn',
            act='gelu', norm='bn',
            sep=False, groups=8, se_ratio=4, residual=True,
            kmax_k=1, gem_p=4.0,
        )
        self.head = CNNHead(hidden=self.backbone.hidden, num_classes=num_classes, **head_cfg)
        self.backbone_frozen = False

        # Best-effort: encoders without checkpointing support are left as they are.
        try:
            self.backbone.model.gradient_checkpointing_enable()
        except Exception:
            pass

    def freeze_backbone(self, freeze=True):
        """Freeze (or unfreeze) every backbone parameter and record the state."""
        trainable = not freeze
        for param in self.backbone.parameters():
            param.requires_grad = trainable
        self.backbone_frozen = freeze

    def forward(self, input_ids, attention_mask):
        """Return class logits; a frozen backbone runs without gradient tracking."""
        if not self.backbone_frozen:
            return self.head(self.backbone(input_ids, attention_mask), attention_mask)
        with torch.no_grad():
            frozen_hidden = self.backbone(input_ids, attention_mask)
        return self.head(frozen_hidden.detach(), attention_mask)


# Build the classifier on the selected device and print its configuration.
NUM_CLASSES = 3
model = SentimentModel(num_classes=NUM_CLASSES).to(DEVICE)


print("Backbone:", BACKBONE_NAME, "| hidden:", model.backbone.hidden)
cfg = model.head.cfg
print("Head cfg:", cfg)
# For each branch, list the `groups` value of every Conv1d directly inside its core.
for i, br in enumerate(model.head.branches):
    # Residual branches keep their conv stack under `.core`.
    core = br.core if isinstance(br, ResidualBranch) else br
    groups_seq = [m.groups for m in core if isinstance(m, nn.Conv1d)]
    print(f"branch {i} kernel={cfg['kernels'][i]} dil={cfg['dilations'][i]} groups(seq)={groups_seq}")


In [None]:
import os, hashlib
import pandas as pd, numpy as np
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from transformers import DataCollatorWithPadding


# Seed NumPy and PyTorch for this cell.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Validation fraction, tokenizer truncation length, and loader settings.
VAL_FRAC = 0.10
MAX_LEN  = 160   
BATCH    = 64
NUM_WORKERS = 0


# This cell relies on a TOKENIZER defined by an earlier cell.
assert 'TOKENIZER' in globals(), "Please run Cell 1 first so TOKENIZER is defined."


# Dataset selector: 'A' picks CSV_A, anything else picks CSV_B.
DATASET = 'B'   
CSV_A = "taskA_youtube_raw.csv"
CSV_B = "taskB_youtube_raw.csv"
CSV_PATH = CSV_A if DATASET=='A' else CSV_B

# Column names read from the CSV.
TEXT_COL  = "text"
LABEL_COL = "label"

class YouTubeCommentDataset(Dataset):
    """Dataset of (text, label) pairs tokenized lazily, one item at a time.

    Items are dicts with ``input_ids``, ``attention_mask`` and an integer
    ``labels`` entry. Texts are truncated to ``max_len`` and left unpadded.
    """

    def __init__(self, texts, labels, tokenizer, max_len=160):
        self.texts, self.labels = texts, labels
        self.tok, self.max_len = tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        encoded = self.tok(
            str(self.texts[i]),
            truncation=True,
            padding=False,
            max_length=self.max_len,
        )
        item = {key: encoded[key] for key in ('input_ids', 'attention_mask')}
        item['labels'] = int(self.labels[i])
        return item


# Pads each batch to its longest item, rounded up to a multiple of 8.
COLLATE = DataCollatorWithPadding(tokenizer=TOKENIZER, pad_to_multiple_of=8, return_tensors="pt")

def build_loader(texts, labels, batch=BATCH, shuffle=False):
    """Build a DataLoader over ``texts`` / ``labels``.

    Args:
        texts: sequence of input texts.
        labels: sequence of labels aligned with ``texts``.
        batch: batch size.
        shuffle: when True, sample in a random order drawn from a generator
            seeded with ``SEED``; otherwise iterate in order.

    Returns:
        A ``DataLoader`` yielding padded batches from ``COLLATE``.
    """
    ds = YouTubeCommentDataset(texts, labels, TOKENIZER, MAX_LEN)
    sampler = None
    if shuffle:
        # A dedicated seeded generator keeps the shuffle order reproducible.
        g = torch.Generator()
        g.manual_seed(SEED)
        sampler = RandomSampler(ds, generator=g)
    return DataLoader(
        ds, batch_size=batch, sampler=sampler, shuffle=False,
        num_workers=NUM_WORKERS,
        # Pin host memory only when a CUDA device can consume it.
        pin_memory=torch.cuda.is_available(),
        collate_fn=COLLATE,
    )

# Load the labeled CSV and make a stratified train/validation split.
df = pd.read_csv(CSV_PATH, encoding="utf-8", engine="python").copy()
# Labels must be convertible to int; this raises on missing or non-numeric labels.
df[LABEL_COL] = df[LABEL_COL].astype(int)

# Stratify on the label so both splits keep the class proportions; seeded for repeatability.
train_df, val_df = train_test_split(
    df, test_size=VAL_FRAC, stratify=df[LABEL_COL], random_state=SEED
)

def _checksum_index(idx):
    """Return the first 10 hex digits of the MD5 of ``idx`` viewed as int64 bytes."""
    raw = np.asarray(idx, dtype=np.int64).tobytes()
    digest = hashlib.md5(raw).hexdigest()
    return digest[:10]

# Number of distinct labels present in the CSV.
NUM_CLASSES = int(df[LABEL_COL].nunique())  # for Cell 3

# Print split sizes and a short checksum of each split's index so runs can be compared.
print(f"Dataset: {CSV_PATH}")
print("Counts -> train/val:", len(train_df), len(val_df))
print("Index checksums -> train:", _checksum_index(train_df.index),
      " val:", _checksum_index(val_df.index))
print("NUM_CLASSES:", NUM_CLASSES)

# Plain Python lists of texts and integer labels for each split.
# Missing texts are converted by astype(str) rather than dropped.
Xtr, ytr = train_df[TEXT_COL].astype(str).tolist(), train_df[LABEL_COL].astype(int).tolist()
Xva, yva = val_df[TEXT_COL].astype(str).tolist(),   val_df[LABEL_COL].astype(int).tolist()

# Only the training loader is shuffled.
train_loader = build_loader(Xtr, ytr, shuffle=True)
val_loader   = build_loader(Xva, yva, shuffle=False)


In [None]:
#TAPT (This doesn't get used)
import os, re, shutil
import pandas as pd
from typing import List

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM,
    DataCollatorForLanguageModeling, Trainer, TrainingArguments
)
from peft import LoraConfig, get_peft_model  



# Start TAPT from the DAPT-merged backbone when that directory exists, else from the base checkpoint.
BASE_FOR_TAPT = "dapt_merged_backbone" if os.path.isdir("dapt_merged_backbone") else "cardiffnlp/twitter-roberta-base"

# Each of the following settings is reused when an earlier cell already defined it,
# and otherwise set to the default shown in the except branch.
try:
    MAX_LEN
except NameError:
    MAX_LEN = 160


try:
    VAL_FRAC
except NameError:
    VAL_FRAC = 0.10
try:
    SEED
except NameError:
    SEED = 42


try:
    CSV_PATH
except NameError:
    
    CSV_PATH = "taskB_youtube_raw.csv"


try:
    TEXT_COL
except NameError:
    TEXT_COL = "text"
try:
    LABEL_COL
except NameError:
    LABEL_COL = "label"

# TAPT training budget and optimizer settings (passed to TrainingArguments below).
TOTAL_STEPS_TAPT     = 12_000     
WARMUP_RATIO_TAPT    = 0.06
LR_TAPT              = 1e-4
BATCH_TRAIN_PER_DEV  = 16
GRAD_ACCUM_STEPS     = 2
MLM_PROB             = 0.15

# LoRA settings and output directories.
# NOTE: GRAD_ACCUM_STEPS and the LORA_* names are plain assignments, so they replace
# any values of the same name defined by earlier cells.
LORA_R               = 8
LORA_ALPHA           = 16
LORA_DROPOUT         = 0.10
OUT_ADAPTER_DIR      = "tapt_lora_adapter"
OUT_MERGED_DIR       = "tapt_merged_backbone"


# Load the labeled CSV and keep only the training split's texts for TAPT.
assert os.path.exists(CSV_PATH), f"TAPT needs your labeled dataset CSV (got {CSV_PATH!r} not found)."

df = pd.read_csv(CSV_PATH)
# If the configured column names are absent, fall back to the first alternative present.
if TEXT_COL not in df.columns:
    
    for c in ["clean_text", "comment", "body", "content"]:
        if c in df.columns:
            TEXT_COL = c; break
if LABEL_COL not in df.columns:
    for c in ["label", "labels", "target"]:
        if c in df.columns:
            LABEL_COL = c; break

assert TEXT_COL in df.columns, f"Could not find a text column in {CSV_PATH}"
assert LABEL_COL in df.columns, f"Could not find a label column in {CSV_PATH}"



from sklearn.model_selection import train_test_split
# Same VAL_FRAC / SEED split parameters as the supervised data cell; the held-out part
# is discarded here. Stratification is skipped when there is only one label value.
train_df, _ = train_test_split(
    df[[TEXT_COL, LABEL_COL]],
    test_size=VAL_FRAC,
    stratify=df[LABEL_COL] if df[LABEL_COL].nunique() > 1 else None,
    random_state=SEED,
)

# Rows with a missing text are dropped before conversion to str.
train_texts = [str(x) for x in train_df[TEXT_COL].dropna().tolist()]

def _basic_clean(s: str) -> str:
    """Collapse every whitespace run in ``s`` to one space and strip both ends."""
    return re.sub(r"\s+", " ", s).strip()

# Drop blank texts, normalize whitespace, then de-duplicate while keeping first-seen order.
train_texts = [_basic_clean(s) for s in train_texts if s.strip()]
train_texts = list(dict.fromkeys(train_texts))  

print(f"[TAPT] Source: {CSV_PATH} | train_texts={len(train_texts)}")

class _MLMDataset(Dataset):
    """Dataset that tokenizes all texts once up front and serves per-item tensors.

    Texts are truncated to ``max_len`` and left unpadded; each item maps every
    key of the tokenizer output to a tensor built from that item's entry.
    """

    def __init__(self, texts: List[str], tokenizer, max_len: int):
        self.enc = tokenizer(
            texts,
            truncation=True,
            max_length=max_len,
            padding=False,
            return_attention_mask=True,
        )

    def __len__(self):
        return len(self.enc["input_ids"])

    def __getitem__(self, i):
        item = {}
        for key, column in self.enc.items():
            item[key] = torch.tensor(column[i])
        return item


# Load tokenizer and masked-LM model from the chosen TAPT base.
print(f"[TAPT] Base backbone: {BASE_FOR_TAPT}")
tokenizer = AutoTokenizer.from_pretrained(BASE_FOR_TAPT, use_fast=True)
model = AutoModelForMaskedLM.from_pretrained(BASE_FOR_TAPT)


# Module-name patterns LoRA is attached to.
LORA_TARGETS = [
    "query","key","value","dense",
    "q_proj","k_proj","v_proj","out_proj",
    "intermediate.dense","output.dense"
]


# Use the first task type this PEFT version defines, in this preference order;
# leave it unset when TaskType cannot be imported.
try:
    from peft import TaskType
    _TASKTYPE = (
        getattr(TaskType, "MASKED_LM", None)
        or getattr(TaskType, "TOKEN_CLS", None)
        or getattr(TaskType, "FEATURE_EXTRACTION", None)
        or getattr(TaskType, "SEQ_CLS", None)
        or getattr(TaskType, "CAUSAL_LM", None)
    )
except Exception:
    TaskType = None
    _TASKTYPE = None

lora_kwargs = dict(
    r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGETS, bias="none",
)
if _TASKTYPE is not None:
    lora_kwargs["task_type"] = _TASKTYPE

# Wrap the model with LoRA. NOTE: this rebinds the module-level name `model`.
lcfg = LoraConfig(**lora_kwargs)
model = get_peft_model(model, lcfg)

try:
    model.print_trainable_parameters()
except Exception as e:
    print("[TAPT] print_trainable_parameters() unavailable:", e)

# Pre-tokenized training texts and the MLM masking collator.
dataset = _MLMDataset(train_texts, tokenizer, MAX_LEN)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=MLM_PROB)


# Train with the Hugging Face Trainer; fp16 is enabled only when CUDA is available.
fp16 = torch.cuda.is_available()
args = TrainingArguments(
    output_dir=OUT_ADAPTER_DIR,
    overwrite_output_dir=True,
    max_steps=TOTAL_STEPS_TAPT,
    per_device_train_batch_size=BATCH_TRAIN_PER_DEV,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LR_TAPT,
    warmup_ratio=WARMUP_RATIO_TAPT,
    weight_decay=0.01,
    logging_steps=100,
    save_steps=2_000,
    save_total_limit=2,
    prediction_loss_only=True,
    dataloader_drop_last=False,
    fp16=fp16,
    report_to=[],
)

# No evaluation dataset is passed; only the training split is used.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collator,
)

print("[TAPT] Starting training…")
trainer.train()
print("[TAPT] Done.")


# Clear both output directories before writing the final artifacts. OUT_ADAPTER_DIR is
# also the Trainer's output_dir, so anything the Trainer wrote there is removed too.
for d in [OUT_ADAPTER_DIR, OUT_MERGED_DIR]:
    if os.path.isdir(d):
        shutil.rmtree(d)

# Save the adapter together with the tokenizer.
trainer.save_model(OUT_ADAPTER_DIR)
tokenizer.save_pretrained(OUT_ADAPTER_DIR)

# Merge the LoRA weights into the base model and save it with the tokenizer.
print("[TAPT] Merging LoRA into base weights…")
merged = trainer.model.merge_and_unload()
os.makedirs(OUT_MERGED_DIR, exist_ok=True)
merged.save_pretrained(OUT_MERGED_DIR)
tokenizer.save_pretrained(OUT_MERGED_DIR)

print(f"[TAPT] Saved adapter -> {OUT_ADAPTER_DIR}")
print(f"[TAPT] Saved merged backbone -> {OUT_MERGED_DIR}")


In [None]:
#Final Training loop
import os, re, math, torch, numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.amp import GradScaler
from sklearn.metrics import accuracy_score, f1_score, classification_report
from transformers import get_cosine_schedule_with_warmup
from tqdm import tqdm


# ---- Schedule ----------------------------------------------------------------
# EPOCHS: upper bound on the number of training epochs.
# PATIENCE: number of consecutive non-improving epochs after which training stops early.
# FREEZE_BACKBONE_EPOCHS: the backbone is kept frozen while `ep <= FREEZE_BACKBONE_EPOCHS`.
EPOCHS                  = 26
PATIENCE                = 8
FREEZE_BACKBONE_EPOCHS  = 4

# ---- Optimisation / regularisation --------------------------------------------
HEAD_LR        = 1e-4    # learning rate of the head's optimizer parameter groups
BACKBONE_LR    = 1e-5    # base learning rate of the backbone before layer-wise decay
LR_LAYER_DECAY = 1      # per-depth LR multiplier; 1 gives every backbone layer BACKBONE_LR
WEIGHT_DECAY   = 0.01    # applied to the "decay" parameter groups only
MIXED_PREC     = True    # autocast / GradScaler are enabled only when this is set AND the device is CUDA
LABEL_SMOOTH   = 0.06    # label_smoothing passed to CrossEntropyLoss
CLASS_WEIGHT_EXP = 0.35  # class weights are 1 / count ** CLASS_WEIGHT_EXP, then mean-normalised
WARMUP_RATIO   = 0.08    # fraction of total optimizer steps used as scheduler warmup
EMA_DECAY      = 0.999   # decay of the exponential moving average of the weights
R_DROP_ALPHA   = 1.00   # weight of the symmetric-KL consistency term; values > 0 enable the two-pass branch


# The task is inferred from the dataset path; when CSV_PATH is not defined the
# run is treated as Task B. The flag only affects the checkpoint file name here.
try:
    IS_TASK_A = ('taskA' in CSV_PATH)
except NameError:
    IS_TASK_A = False
CKPT_OUT = f"task{'A' if IS_TASK_A else 'B'}_from_scratch_mtnas_parity.pt"


# Directory names that _detect_adapt_status looks for on disk and matches
# against the end of the backbone's name_or_path.
_DAPT_MERGED  = "dapt_merged_backbone"
_DAPT_ADAPTER = "dapt_lora_adapter"
_TAPT_MERGED  = "tapt_merged_backbone"
_TAPT_ADAPTER = "tapt_lora_adapter"

def _safe_name_or_path(m):
    """Return ``m.name_or_path`` if it can be read, otherwise ``None``.

    Any exception raised while reading the attribute is treated the same as
    the attribute being absent.
    """
    try:
        found = getattr(m, "name_or_path", None)
    except Exception:
        found = None
    return found

def _is_peft_model(m):
    """Heuristically tell whether *m* is a PEFT-wrapped model.

    True when the object exposes a ``peft_config`` attribute or when its class
    name starts with "peft" (case-insensitive). Any error during inspection
    yields False.
    """
    try:
        if hasattr(m, "peft_config"):
            return True
        class_name = m.__class__.__name__
        return class_name.lower().startswith("peft")
    except Exception:
        return False

def _detect_adapt_status(model, tokenizer):
    """Summarise which DAPT/TAPT adaptation the current backbone appears to use.

    The result is a heuristic built from three sources: which artifact
    directories exist in the working directory, the backbone's
    ``name_or_path``, and whether the backbone looks like a PEFT model.

    Returns a dict with the keys ``verdict``, ``name_or_path``, ``is_peft``,
    ``tok_vocab``, ``embed_rows``, ``specials``, ``has_dirs`` and ``warnings``
    (a list of human-readable consistency problems, possibly empty).
    """
    # Artifact directories present on disk (relative to the working directory).
    on_disk = {
        "dapt_merged": os.path.isdir(_DAPT_MERGED),
        "dapt_adapter": os.path.isdir(_DAPT_ADAPTER),
        "tapt_merged": os.path.isdir(_TAPT_MERGED),
        "tapt_adapter": os.path.isdir(_TAPT_ADAPTER),
    }
    dapt_on_disk = on_disk["dapt_merged"] or on_disk["dapt_adapter"]
    tapt_on_disk = on_disk["tapt_merged"] or on_disk["tapt_adapter"]

    # The backbone may wrap the actual model in a `.model` attribute.
    backbone = model.backbone
    core = getattr(backbone, "model", backbone)
    name_path = _safe_name_or_path(core) or "<unknown>"
    is_peft = _is_peft_model(core)

    # Sizes default to distinct negative sentinels when they cannot be read.
    try:
        tok_vocab = len(tokenizer)
    except Exception:
        tok_vocab = -1
    try:
        emb_rows = core.get_input_embeddings().weight.size(0)
    except Exception:
        emb_rows = -2

    # Presence of the two special tokens; None means "could not be determined".
    try:
        vocab = tokenizer.get_vocab()
        specials = {token: (token in vocab) for token in ("<url>", "@user")}
    except Exception:
        specials = {"<url>": None, "@user": None}

    # First matching rule wins; DAPT rules take precedence over TAPT rules.
    rules = (
        (name_path.endswith(_DAPT_MERGED), "DAPT merged"),
        (is_peft and on_disk["dapt_adapter"], "DAPT adapter (PEFT)"),
        (name_path.endswith(_DAPT_ADAPTER), "DAPT direct-load (no PEFT)"),
        (name_path.endswith(_TAPT_MERGED), "TAPT merged"),
        (is_peft and on_disk["tapt_adapter"], "TAPT adapter (PEFT)"),
        (name_path.endswith(_TAPT_ADAPTER), "TAPT direct-load (no PEFT)"),
    )
    verdict = next((label for matched, label in rules if matched), "Base only")

    issues = []
    if emb_rows > 0 and tok_vocab != emb_rows:
        issues.append(f"Tokenizer vocab ({tok_vocab}) != embed rows ({emb_rows}) — embeddings NOT resized.")
    if dapt_on_disk and "DAPT" not in verdict:
        issues.append("DAPT artifacts exist, but training is not using them.")
    # The TAPT notice is suppressed when DAPT artifacts are also present.
    if tapt_on_disk and "TAPT" not in verdict and not dapt_on_disk:
        issues.append("TAPT artifacts exist, but training is not using them.")
    if specials.get("<url>") is False or specials.get("@user") is False:
        issues.append("Tokenizer missing <url> and/or @user special tokens.")

    return {
        "verdict": verdict,
        "name_or_path": name_path,
        "is_peft": bool(is_peft),
        "tok_vocab": int(tok_vocab),
        "embed_rows": int(emb_rows),
        "specials": specials,
        "has_dirs": on_disk,
        "warnings": issues,
    }

# Compute the adaptation summary once and print it as a small report.
_status = _detect_adapt_status(model, TOKENIZER)

print("\n==== ADAPTATION STATUS ====")
for _label, _field in (
    (" verdict         :", "verdict"),
    (" backbone path   :", "name_or_path"),
    (" PEFT attached?  :", "is_peft"),
):
    print(_label, _status[_field])
print(" vocab vs embeds :", _status["tok_vocab"], "vs", _status["embed_rows"])
for _label, _field in (
    (" specials        :", "specials"),
    (" dirs present    :", "has_dirs"),
):
    print(_label, _status[_field])

# Either list every detected problem or confirm that none was found.
if not _status["warnings"]:
    print(" All adaptation checks look consistent")
else:
    for w in _status["warnings"]:
        print(" Warning", w)
print("========================================\n")


def _emb_vocab_ok_verify(core_model, tok):
    """Check that the input-embedding row count equals ``len(tok)``.

    The check is lenient: if either size cannot be obtained (any exception),
    the function reports True rather than failing.
    """
    try:
        embedding_rows = core_model.get_input_embeddings().weight.shape[0]
        sizes_agree = (embedding_rows == len(tok))
    except Exception:
        sizes_agree = True
    return sizes_agree

# Resolve the innermost model object: `model.backbone.model` when present,
# otherwise `model.backbone`, otherwise `model` itself.
_backbone_or_model = getattr(model, "backbone", model)
_core_model = getattr(_backbone_or_model, "model", _backbone_or_model)
_model_path = getattr(_core_model, "name_or_path", str(type(_core_model)))
_tok_path   = getattr(TOKENIZER, "name_or_path", "?")

# Evaluate the vocab/embedding check once and reuse it for both the report
# line and the hard guard below.
_vocab_matches = _emb_vocab_ok_verify(_core_model, TOKENIZER)

print(f"[Verify] model path:     {_model_path}")
print(f"[Verify] tokenizer path: {_tok_path}")
print(f"[Verify] vocab~emb match: {_vocab_matches}")

# LoRA is considered active when any parameter name contains "lora_".
_lora_active = any("lora_" in n for n, _ in model.named_parameters())
print(f"[Verify] PEFT/LoRA active? {_lora_active} (False is expected when using merged backbones)")


# Raised explicitly instead of using an `assert` statement, so the guard is
# still enforced when Python runs with -O (which strips asserts). The
# exception type stays AssertionError to match the previous behaviour.
if not _vocab_matches:
    raise AssertionError(
        "Tokenizer vocab size and model embedding size mismatch — likely loaded the wrong tokenizer/backbone."
    )


def _set_gradient_checkpointing(model, enabled: bool):
    """Best-effort toggle of gradient checkpointing on the backbone's core model.

    The core model is ``model.backbone.model`` when that attribute exists,
    otherwise ``model.backbone``. If the core model does not provide the
    matching ``gradient_checkpointing_enable`` / ``gradient_checkpointing_disable``
    method, nothing happens.

    Args:
        model: Object exposing a ``backbone`` attribute.
        enabled: True to enable checkpointing, False to disable it.

    This function never raises: a failure is reported on stdout (instead of
    being silently discarded) and training proceeds without the toggle.
    """
    action = "enable" if enabled else "disable"
    try:
        core = getattr(model.backbone, "model", model.backbone)
        toggle = getattr(core, f"gradient_checkpointing_{action}", None)
        if toggle is not None:
            toggle()
    except Exception as exc:
        # Checkpointing is a memory optimisation only, so keep going, but make
        # the failure visible.
        print(f"[GradCkpt] Could not {action} gradient checkpointing: {exc!r}")


def _is_nodecay(name: str, p: nn.Parameter) -> bool:
    """Decide whether a parameter belongs in a zero-weight-decay group.

    A parameter is exempt from weight decay when it has fewer than two
    dimensions, when its lower-cased name contains "layernorm" or
    "layer_norm", or when that name ends with ".bias".
    """
    lowered = name.lower()
    if p.ndim < 2:
        return True
    if any(marker in lowered for marker in ('layernorm', 'layer_norm')):
        return True
    return lowered.endswith('.bias')

# Number of hidden layers reported by the backbone core model's config
# (`model.backbone.model` when present, else `model.backbone`); falls back to
# 12 when the config has no `num_hidden_layers` attribute.
NUM_LAYERS = getattr(getattr(model.backbone, "model", model.backbone).config, "num_hidden_layers", 12)
# Captures the layer index from parameter names containing "encoder.layer.<N>.".
_layer_pat = re.compile(r'encoder\.layer\.(\d+)\.')

def _layer_id_from_name(n: str) -> int:
    """Map a backbone parameter name to a layer id used for per-layer LRs.

    Returns -1 for any name containing "embeddings", the captured index for
    names matching ``_layer_pat``, and ``NUM_LAYERS`` for everything else.
    The "embeddings" test takes precedence over the layer pattern.
    """
    if "embeddings" not in n:
        hit = _layer_pat.search(n)
        return NUM_LAYERS if hit is None else int(hit.group(1))
    return -1

def _lr_for_layer(layer_id: int) -> float:
    """Return the learning rate for a layer id under layer-wise LR decay.

    The rate is ``BACKBONE_LR * LR_LAYER_DECAY ** depth`` where depth is:
      * ``NUM_LAYERS + 1`` for the embeddings (layer id -1),
      * ``0`` for ids at or above ``NUM_LAYERS`` (parameters outside the
        numbered encoder layers),
      * ``NUM_LAYERS - 1 - layer_id`` for every other id.

    NOTE(review): with this mapping layer 0 sits at depth NUM_LAYERS - 1 while
    the embeddings sit at NUM_LAYERS + 1, so depth NUM_LAYERS is never used.
    Confirm this gap is intended; it has no effect while LR_LAYER_DECAY == 1.
    """
    is_embeddings = (layer_id == -1)
    is_above_encoder = (layer_id >= NUM_LAYERS)
    depth = (
        NUM_LAYERS + 1 if is_embeddings
        else 0 if is_above_encoder
        else NUM_LAYERS - 1 - layer_id
    )
    return BACKBONE_LR * (LR_LAYER_DECAY ** depth)


# ---- Backbone: bucket trainable parameters by layer id and decay policy ------
layer_groups = {}
for n, p in model.backbone.named_parameters():
    if p.requires_grad:
        bucket = layer_groups.setdefault(_layer_id_from_name(n), {"decay": [], "nodecay": []})
        bucket["nodecay" if _is_nodecay(n, p) else "decay"].append(p)

# One optimizer group per (layer, decay policy); empty buckets are omitted.
bb_param_groups = []
for lid in sorted(layer_groups):
    lr = _lr_for_layer(lid)
    for _policy, _wd in (("decay", WEIGHT_DECAY), ("nodecay", 0.0)):
        _members = layer_groups[lid][_policy]
        if _members:
            bb_param_groups.append({"params": _members, "lr": lr, "weight_decay": _wd})


# ---- Head: a single learning rate, split only by decay policy ----------------
hd_decay, hd_nodecay = [], []
for n, p in model.head.named_parameters():
    if p.requires_grad:
        _destination = hd_nodecay if _is_nodecay(n, p) else hd_decay
        _destination.append(p)

param_groups = [
    *bb_param_groups,
    {"params": hd_decay,   "lr": HEAD_LR, "weight_decay": WEIGHT_DECAY},
    {"params": hd_nodecay, "lr": HEAD_LR, "weight_decay": 0.0},
]

optimizer = AdamW(param_groups)


def _layer_lr_table():
    """List ``(layer_id, learning_rate)`` pairs for ids -1 through NUM_LAYERS.

    Covers the embeddings id (-1), every numbered layer, and the catch-all id
    ``NUM_LAYERS``, in ascending order.
    """
    return [(lid, _lr_for_layer(lid)) for lid in range(-1, NUM_LAYERS + 1)]

# ---- Report the layer-wise learning-rate layout ------------------------------
print(f"LLRD: BASE_LR={BACKBONE_LR}, DECAY={LR_LAYER_DECAY}, layers={NUM_LAYERS}")
_lr_rows = []
for _lid, _rate in _layer_lr_table():
    _lr_rows.append((_lid, f"{_rate:.8f}"))
print("Backbone per-layer LRs:", _lr_rows)
print("Total param groups:", len(optimizer.param_groups))


# ---- Gradient scaler: active only for mixed precision on CUDA ----------------
_scaler_enabled = (MIXED_PREC and DEVICE.type == 'cuda')
scaler = GradScaler(enabled=_scaler_enabled)


# ---- Class-weighted, label-smoothed cross entropy ----------------------------
# Weights are inverse class counts raised to CLASS_WEIGHT_EXP (counts are
# clamped to at least 1), then rescaled so that their mean is 1.
_train_labels = np.array(ytr, dtype=np.int64)
counts = np.bincount(_train_labels, minlength=NUM_CLASSES).astype(float)
_clamped_counts = np.maximum(counts, 1.0)
w = 1.0 / _clamped_counts ** CLASS_WEIGHT_EXP
w = w / w.mean()
class_weights = torch.tensor(w, dtype=torch.float32, device=DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=LABEL_SMOOTH)

# ---- Cosine schedule with warmup, sized in optimizer steps -------------------
steps_per_epoch = len(train_loader)
total_steps = max(1, steps_per_epoch * EPOCHS)
warmup_steps = max(1, int(WARMUP_RATIO * total_steps))
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

def init_ema(m: nn.Module):
    """Create the initial EMA table for *m*.

    Returns a dict mapping the name of every parameter that currently has
    ``requires_grad=True`` to a detached copy of its value.
    """
    return {
        name: param.detach().clone()
        for name, param in m.named_parameters()
        if param.requires_grad
    }

@torch.no_grad()
def update_ema(m: nn.Module, ema: dict, decay: float = EMA_DECAY):
    """Advance the EMA table in place by one step.

    For each parameter with ``requires_grad=True`` the stored value becomes
    ``decay * stored + (1 - decay) * current``. A parameter that has no entry
    yet is first seeded with a copy of its current value. Parameters with
    ``requires_grad=False`` are left untouched.
    """
    blend = 1.0 - decay
    for name, param in m.named_parameters():
        if not param.requires_grad:
            continue
        current = param.detach()
        if name not in ema:
            # First time this parameter is seen as trainable: seed its average.
            ema[name] = current.clone()
        ema[name].mul_(decay).add_(current, alpha=blend)

@torch.no_grad()
def swap_in_ema(m: nn.Module, ema: dict):
    """Overwrite trainable parameters of *m* with their EMA values.

    Only parameters that have ``requires_grad=True`` and an entry in *ema*
    are touched. Returns a dict of detached copies of the values that were
    overwritten, keyed by parameter name, for use with ``restore_from_backup``.
    """
    targets = [
        (name, param)
        for name, param in m.named_parameters()
        if param.requires_grad and name in ema
    ]
    # Snapshot the live values before any of them is overwritten.
    saved = {name: param.detach().clone() for name, param in targets}
    for name, param in targets:
        param.copy_(ema[name])
    return saved

@torch.no_grad()
def restore_from_backup(m: nn.Module, backup: dict):
    """Copy saved values back into the trainable parameters of *m*.

    A parameter is restored only if it has ``requires_grad=True`` and its
    name is a key of *backup*; everything else is left as is.
    """
    for name, param in m.named_parameters():
        if name not in backup or not param.requires_grad:
            continue
        param.copy_(backup[name])

def build_ema_state_dict(m: nn.Module, ema: dict):
    """Return a standalone state dict for *m* with EMA weights substituted in.

    Args:
        m: Module whose ``state_dict()`` provides the keys and the fallback values.
        ema: Mapping from parameter name to its EMA tensor.

    Returns:
        The mapping produced by ``m.state_dict()`` in which
          * every parameter that has ``requires_grad=True`` and an entry in
            *ema* is replaced by a copy of its EMA tensor, placed on the
            device of the live tensor, and
          * every other tensor entry (frozen parameters, buffers) is replaced
            by a copy of its current value.

    ``state_dict()`` hands out tensors that share storage with the live
    module, so the entries that are not overridden by EMA values must be
    cloned as well; otherwise the returned snapshot would silently change as
    the module keeps training.
    """
    use_ema = {
        name
        for name, param in m.named_parameters()
        if param.requires_grad and name in ema
    }
    sd = m.state_dict()
    for key, live in list(sd.items()):
        if key in use_ema:
            sd[key] = ema[key].detach().clone().to(live.device)
        elif isinstance(live, torch.Tensor):
            sd[key] = live.detach().clone()
        # Non-tensor entries (if any) are passed through unchanged.
    return sd

# EMA table seeded from the parameters that are trainable right now.
ema_state = init_ema(model)

# ---- Pre-training diagnostics ------------------------------------------------
_n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
_backbone_trainable = any(p.requires_grad for p in model.backbone.parameters())
print("Trainable params:", _n_trainable)
print("Backbone trainable?:", _backbone_trainable)
print(f"Class counts: {counts.tolist()} | weights(exp={CLASS_WEIGHT_EXP}): {np.round(w, 4).tolist()} | label_smooth={LABEL_SMOOTH}")

# ---- Best-so-far tracking for model selection / early stopping ---------------
# Both metrics start below any attainable score so the first epoch always wins.
best_acc = best_f1 = -1.0
best_state = None
epochs_no_improve = 0
TARGET_NAMES = dict(enumerate(("negative", "neutral", "positive")))


# Main fine-tuning loop. Each epoch:
#   1. (un)freezes the backbone and toggles gradient checkpointing accordingly,
#   2. trains for one pass over train_loader (optionally with the two-pass
#      consistency loss enabled by R_DROP_ALPHA > 0),
#   3. validates with the EMA weights swapped into the model,
#   4. saves a checkpoint when validation improves, otherwise counts towards
#      early stopping.
# Autocast is used only for mixed precision on CUDA.
_amp_enabled = (MIXED_PREC and DEVICE.type == 'cuda')

for ep in range(1, EPOCHS + 1):
    print(f"\nEpoch {ep}/{EPOCHS}")

    # The backbone stays frozen for the first FREEZE_BACKBONE_EPOCHS epochs;
    # gradient checkpointing is only requested while it is trainable.
    frozen = (ep <= FREEZE_BACKBONE_EPOCHS)
    model.freeze_backbone(frozen)
    _set_gradient_checkpointing(model, enabled=not frozen)
    print("  Backbone frozen?", frozen)

    # ------------------------------ train ------------------------------------
    model.train()
    total_loss = 0.0
    y_true, y_pred = [], []

    for batch in tqdm(train_loader, desc="Training"):
        ids = batch['input_ids'].to(DEVICE, non_blocking=True)
        msk = batch['attention_mask'].to(DEVICE, non_blocking=True)
        lab = batch['labels'].to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=_amp_enabled):
            if R_DROP_ALPHA > 0.0:
                # Two forward passes over the same batch; the loss is the mean
                # cross entropy plus a symmetric KL term between the two
                # predicted distributions.
                logits1 = model(ids, msk)
                logits2 = model(ids, msk)
                ce = 0.5 * (criterion(logits1, lab) + criterion(logits2, lab))

                p = F.log_softmax(logits1, dim=1)
                q = F.log_softmax(logits2, dim=1)
                kl = 0.5 * (F.kl_div(p, q.exp(), reduction="batchmean") +
                            F.kl_div(q, p.exp(), reduction="batchmean"))
                loss = ce + R_DROP_ALPHA * kl
                logits_for_metrics = (logits1 + logits2) / 2.0
            else:
                logits_for_metrics = model(ids, msk)
                loss = criterion(logits_for_metrics, lab)

        scaler.scale(loss).backward()
        # Unscale before clipping so the threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # GradScaler skips optimizer.step() when it finds inf/NaN gradients and
        # then lowers its scale. Only advance the LR schedule when the step
        # was actually taken, so the schedule stays aligned with real updates.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        if scaler.get_scale() >= scale_before:
            scheduler.step()
        update_ema(model, ema_state, decay=EMA_DECAY)

        bs = lab.size(0)
        total_loss += loss.item() * bs
        y_true.extend(lab.detach().cpu().numpy().tolist())
        y_pred.extend(logits_for_metrics.argmax(dim=1).detach().cpu().numpy().tolist())

    train_acc = accuracy_score(y_true, y_pred)
    # Average over the samples actually processed this epoch; this stays
    # correct when the loader drops the last batch or uses a sampler.
    train_loss = total_loss / max(1, len(y_true))
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

    # ----------------------- validate with EMA weights -----------------------
    model.eval()
    backup_params = swap_in_ema(model, ema_state)
    try:
        vy_true, vy_pred = [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                ids = batch["input_ids"].to(DEVICE, non_blocking=True)
                msk = batch["attention_mask"].to(DEVICE, non_blocking=True)
                lab = batch["labels"].to(DEVICE, non_blocking=True)
                with torch.amp.autocast("cuda", enabled=_amp_enabled):
                    logits = model(ids, msk)
                vy_true.extend(lab.detach().cpu().tolist())
                vy_pred.extend(logits.argmax(dim=1).detach().cpu().tolist())

        val_acc = accuracy_score(vy_true, vy_pred)
        val_f1  = f1_score(vy_true, vy_pred, average="macro")
        print(f"Validation Accuracy: {val_acc:.4f} | Macro-F1: {val_f1:.4f}")

        # Report only the classes that occur in the validation labels.
        present = sorted(np.unique(vy_true).tolist())
        print(classification_report(
            vy_true, vy_pred,
            labels=present,
            target_names=[TARGET_NAMES.get(i, f"class_{i}") for i in present],
            digits=4, zero_division=0
        ))

        # Accuracy is the primary criterion; macro-F1 breaks (near-)ties.
        improved = (val_acc > best_acc) or (abs(val_acc - best_acc) < 1e-6 and val_f1 > best_f1)
        if improved:
            best_acc, best_f1 = val_acc, val_f1
            epochs_no_improve = 0
            # copy=True forces a real copy even when the tensors already live
            # on the CPU, so the snapshot can never alias the live weights.
            best_state = {
                k: v.detach().to("cpu", copy=True)
                for k, v in build_ema_state_dict(model, ema_state).items()
            }
            _backbone = getattr(model, "backbone", None)
            torch.save({
                "arch": getattr(model.head, "cfg", None),
                "state_dict": best_state,
                "backbone": _safe_name_or_path(_backbone.model) if hasattr(_backbone, "model") else None,
                "num_classes": NUM_CLASSES,
                "val_acc": float(val_acc),
                "val_f1": float(val_f1),
                "ema_decay": EMA_DECAY,
                "label_smoothing": LABEL_SMOOTH,
                "class_weights": w.tolist(),
                "class_weight_exp": CLASS_WEIGHT_EXP,
                "adapt_meta": _status,
                "r_drop_alpha": R_DROP_ALPHA,
                "llrd": {"base_lr": BACKBONE_LR, "layer_decay": LR_LAYER_DECAY},
            }, CKPT_OUT)
            print(f" New best (EMA) saved (acc={val_acc:.4f}, f1={val_f1:.4f}) → {CKPT_OUT}")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} epoch(s).")
            if epochs_no_improve >= PATIENCE:
                print("\n Early stopping triggered.")
    finally:
        # Always put the raw training weights back, even if validation or
        # checkpointing failed, so the model is never left holding EMA values.
        restore_from_backup(model, backup_params)

    if epochs_no_improve >= PATIENCE:
        break


# Load the best snapshot recorded by the training loop, if one was recorded,
# and leave the model in eval mode either way.
if best_state is not None:
    model.load_state_dict(best_state, strict=True)
model.eval()
# NOTE(review): this message is printed even when best_state is None, in which
# case nothing was reloaded and best_acc still holds its initial value.
print("\n Best (EMA) model reloaded. Best Validation Accuracy:", f"{best_acc:.4f}")
