In [None]:
# 0) Setup
import os
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel, get_cosine_with_hard_restarts_schedule_with_warmup

In [None]:
def set_seed(seed: int = 42) -> None:
    """Seed Python `random`, NumPy's global RNG and PyTorch (CPU and every CUDA device)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

set_seed(42)

print("torch:", torch.__version__)

In [None]:
# 1) Config
@dataclass
class Config:
    """Run configuration for MRL (Matryoshka InfoNCE) training with optional SAMD distillation.

    A single instance (`cfg`) is created right after this class and read by the rest
    of the notebook. `nested_dims` and `kd_dims` may be left as None; the
    model-loading cell then derives them from the student hidden size d_s.
    """

    # Models (student is trained; teacher is frozen)
    student_name: str = "bert-base-uncased"
    teacher_name: str = "bert-base-uncased"

    # Data
    train_csv: str = "/kaggle/input/multitask-data/merged_9_data_3k_each_ver2.csv"
    text_col: str = "text"  # column holding the raw training text
    max_length: int = 256  # tokenizer truncation length (student and teacher)

    # Training
    batch_size: int = 8
    epochs: int = 10
    lr: float = 2e-5
    weight_decay: float = 0.01
    temperature: float = 0.07  # InfoNCE temperature: cosine logits are divided by this
    grad_clip: float = 1.0  # max global grad norm; None or <= 0 disables clipping

    # Pooling
    student_pool: str = "cls"  # "cls" or "mean"
    teacher_embed_token: Optional[str] = None  # token whose hidden state is the teacher embedding; None -> masked mean pooling

    # MRL baseline task (Matryoshka InfoNCE)
    nested_dims: Optional[List[int]] = None  # if None -> [16,32,...,1024] filtered to <= d_s, ensured to end at d_s

    # Optional SAMD components
    use_teacher_kd: bool = True  # prefix-cosine KD towards the projected teacher sentence embedding
    use_attn_cka: bool = True  # span-aware attention CKA; only runs when use_teacher_kd is also True

    # Teacher KD (prefix cosine) dims (auto-filtered to <= d_s)
    kd_dims: Optional[List[int]] = None  # if None -> [16,32,64,128,256,384,512,d_s] filtered to <= d_s

    # Weights + ramps. Each ramped weight is 0 before *_start and rises linearly to
    # *_max over *_ramp optimizer steps.
    w_task: float = 1.0
    alpha_kd: float = 1.0  # multiplies BOTH the KD term and the attention-CKA term in the training loop

    beta_kd_max: float = 1.0
    beta_kd_start: int = 0
    beta_kd_ramp: int = 1000

    alpha_attn_max: float = 0.1
    alpha_attn_start: int = 200
    alpha_attn_ramp: int = 1000

    # Attention distill controls
    att_every: int = 10  # compute the attention loss every N global steps
    att_layer: str = "last"   # "last" or "mid"
    min_coverage: float = 0.30  # skip samples whose span-alignment coverage is below this
    top_frac: float = 0.5  # fraction of valid tokens kept, ranked by attention importance
    min_tokens: int = 1  # minimum number of selected tokens for a sample to be used

    # Scheduler
    warmup_ratio: float = 0.04  # fraction of total steps used for LR warmup
    num_restarts: int = 1  # passed as num_cycles to the cosine-with-hard-restarts schedule

cfg = Config()

# Student always on the first GPU (or CPU); teacher gets its own GPU when a second one exists.
if torch.cuda.is_available():
    device_s = torch.device("cuda:0")
else:
    device_s = torch.device("cpu")
device_t = torch.device("cuda:1") if torch.cuda.device_count() > 1 else device_s
print("device_s:", device_s, "| device_t:", device_t, "| n_gpu:", torch.cuda.device_count())

In [None]:
# 2) Pooling + sentence embedding helpers
def mean_pooling(last_hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token states over positions where `attention_mask` is non-zero.

    Args:
        last_hidden: [B, L, D] token states.
        attention_mask: [B, L] mask (non-zero = real token).
    Returns:
        [B, D] masked mean. Rows with no real tokens yield zeros (the count is
        floored at 1e-9 instead of dividing by 0).
    """
    weights = attention_mask[..., None].type_as(last_hidden)
    denom = torch.clamp(weights.sum(dim=1), min=1e-9)
    return torch.sum(last_hidden * weights, dim=1) / denom

def get_sentence_emb(last_hidden: torch.Tensor, attention_mask: torch.Tensor, mode: str) -> torch.Tensor:
    """Pool [B, L, D] token states into [B, D] sentence embeddings.

    mode == "mean" -> masked mean pooling; any other value -> first-token (CLS) state.
    """
    if mode != "mean":
        return last_hidden[:, 0, :]  # CLS
    return mean_pooling(last_hidden, attention_mask)

@torch.inference_mode()
def extract_teacher_sentence_embedding(
    teacher_last_hidden: torch.Tensor,
    teacher_attention_mask: torch.Tensor,
    teacher_input_ids: torch.Tensor,
    teacher_tokenizer,
    embed_token: Optional[str] = None,
) -> torch.Tensor:
    """Build one sentence embedding per row from the teacher's last hidden states.

    If `embed_token` is given and maps to a non-UNK id, each row uses the hidden
    state at the FIRST occurrence of that token. Rows that do not contain it fall
    back to masked mean pooling. Otherwise every row is mean-pooled.
    Runs under inference mode, so the returned tensor is an inference tensor.
    """
    token_id = None
    if embed_token is not None:
        candidate = teacher_tokenizer.convert_tokens_to_ids(embed_token)
        if candidate is not None and candidate != teacher_tokenizer.unk_token_id:
            token_id = candidate

    if token_id is None:
        return mean_pooling(teacher_last_hidden, teacher_attention_mask)

    batch_size, _, _ = teacher_last_hidden.shape
    rows = []
    for b in range(batch_size):
        hits = (teacher_input_ids[b] == token_id).nonzero(as_tuple=False)
        if hits.numel() == 0:
            rows.append(mean_pooling(teacher_last_hidden[b:b + 1], teacher_attention_mask[b:b + 1])[0])
        else:
            rows.append(teacher_last_hidden[b, hits[0, 0], :])
    return torch.stack(rows, dim=0)

In [None]:
# 3) Losses: InfoNCE + Matryoshka InfoNCE (MRL task) + prefix KD + Linear CKA
def info_nce(q: torch.Tensor, k: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """In-batch InfoNCE: row i of `q` should match row i of `k` under cosine / temperature."""
    targets = torch.arange(q.size(0), device=q.device)
    sim = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).T
    return F.cross_entropy(sim / temperature, targets)

def matryoshka_infonce(
    a: torch.Tensor,
    b: torch.Tensor,
    temperature: float,
    nested_dims: List[int],
    dim_weight: str = "uniform",  # "uniform" or "inverse_sqrt"
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Matryoshka InfoNCE: weighted sum of in-batch InfoNCE over nested prefix dims.

    For every d in `nested_dims` with d <= a.size(1), the first d features of `a`
    and `b` are L2-normalized and scored with InfoNCE (positives on the diagonal).

    Args:
        a, b: [B, D] paired embeddings (same shape).
        temperature: logits = cosine / temperature.
        nested_dims: candidate prefix sizes; those larger than D are ignored.
        dim_weight: "uniform" (all 1.0) or "inverse_sqrt" (~1/sqrt(d), rescaled to mean 1).
    Returns:
        (total_loss, logits_by_dim) where logits_by_dim maps "dim_{d}" -> [B, B] logits.
    Raises:
        ValueError: if no prefix size fits, or `dim_weight` is unknown.
    """
    assert a.dim() == 2 and b.dim() == 2
    assert a.shape == b.shape

    full_dim = a.size(1)
    usable = [d for d in nested_dims if d <= full_dim]
    if not usable:
        raise ValueError(f"No nested_dims <= full_dim={full_dim}. nested_dims={nested_dims}")

    if dim_weight == "uniform":
        weights = [1.0] * len(usable)
    elif dim_weight == "inverse_sqrt":
        raw = [1.0 / math.sqrt(d) for d in usable]
        norm = sum(raw)
        weights = [w / norm * len(raw) for w in raw]  # mean weight ~ 1
    else:
        raise ValueError(f"Unknown dim_weight={dim_weight}")

    targets = torch.arange(a.size(0), device=a.device)
    total = a.new_tensor(0.0)
    logits_by_dim: Dict[str, torch.Tensor] = {}
    for d, w in zip(usable, weights):
        logits = (F.normalize(a[:, :d], dim=-1) @ F.normalize(b[:, :d], dim=-1).T) / temperature
        total = total + (w * F.cross_entropy(logits, targets))
        logits_by_dim[f"dim_{d}"] = logits
    return total, logits_by_dim

def matryoshka_prefix_cosine_kd(
    s: torch.Tensor,
    t: torch.Tensor,
    dims: List[int],
    dim_weight: str = "inverse_sqrt",
) -> torch.Tensor:
    """Prefix cosine distillation: sum_d w_d * mean_batch(1 - cos(s[:, :d], t[:, :d])).

    Only prefixes d <= s.size(1) are used; if none remain the loss is a 0-d zero.
    "inverse_sqrt" weights are proportional to 1/sqrt(d), rescaled to mean 1.

    Raises:
        ValueError: unknown `dim_weight` (only checked when at least one d is usable).
    """
    assert s.shape == t.shape and s.dim() == 2
    usable = [d for d in dims if d <= s.size(1)]
    if not usable:
        return s.new_tensor(0.0)

    if dim_weight == "inverse_sqrt":
        raw = [1.0 / math.sqrt(d) for d in usable]
        norm = sum(raw)
        weights = [w / norm * len(raw) for w in raw]
    elif dim_weight == "uniform":
        weights = [1.0] * len(usable)
    else:
        raise ValueError(f"Unknown dim_weight={dim_weight}")

    total = s.new_tensor(0.0)
    for d, w in zip(usable, weights):
        gap = 1.0 - F.cosine_similarity(s[:, :d], t[:, :d], dim=-1)
        total = total + w * gap.mean()
    return total

def _center_gram(K: torch.Tensor) -> torch.Tensor:
    n = K.size(0)
    one = torch.ones((n, n), device=K.device, dtype=K.dtype) / n
    return K - one @ K - K @ one + one @ K @ one

def linear_cka_from_grams(K: torch.Tensor, L: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Linear CKA between two n x n Gram-like matrices, clamped to [0, 1].

    CKA = <HKH, HLH>_F / (||HKH||_F * ||HLH||_F). Each squared norm is floored at
    `eps` before the square root, which avoids division by zero for constant inputs.
    """
    Kc, Lc = _center_gram(K), _center_gram(L)
    norm_k = (Kc * Kc).sum().clamp_min(eps).sqrt()
    norm_l = (Lc * Lc).sum().clamp_min(eps).sqrt()
    cross = (Kc * Lc).sum()
    return torch.clamp(cross / (norm_k * norm_l), 0.0, 1.0)

def linear_cka_loss(K: torch.Tensor, L: torch.Tensor) -> torch.Tensor:
    """Dissimilarity 1 - linear CKA(K, L), in [0, 1]; 0 means perfectly aligned Grams."""
    similarity = linear_cka_from_grams(K, L)
    return 1.0 - similarity

In [None]:
# 4) Span overlap alignment matrix A (student tokens x teacher tokens)
def build_span_overlap_matrix(
    offsets_s: torch.Tensor,  # [L_s, 2]
    offsets_t: torch.Tensor,  # [L_t, 2]
    eps: float = 1e-12,
) -> torch.Tensor:
    """Soft student->teacher token alignment from character-span overlap.

    A[i, j] = IoU of student span i and teacher span j (intersection / union of the
    [start, end) character ranges). Each row is then normalized to sum to 1.
    Rows with no overlap stay all-zero. This includes zero-length spans; for HF fast
    tokenizers, special and padding tokens typically have offsets (0, 0).

    Integer offsets are cast to float before any arithmetic. Clamping an integer
    tensor with the float `eps` otherwise depends on clamp's dtype-promotion
    behaviour. Casting first makes the eps floor, and hence the absence of 0/0
    NaNs, independent of the PyTorch version.

    Returns:
        [L_s, L_t] floating tensor on `offsets_s.device`.
    """
    device = offsets_s.device
    float_dtype = torch.get_default_dtype()

    def _to_float(x: torch.Tensor) -> torch.Tensor:
        # Keep existing float dtypes; promote integer offsets to the default float dtype.
        if x.is_floating_point():
            return x.to(device)
        return x.to(device=device, dtype=float_dtype)

    offsets_s = _to_float(offsets_s)
    offsets_t = _to_float(offsets_t)

    s_start = offsets_s[:, 0].unsqueeze(1)
    s_end   = offsets_s[:, 1].unsqueeze(1)
    t_start = offsets_t[:, 0].unsqueeze(0)
    t_end   = offsets_t[:, 1].unsqueeze(0)

    inter_start = torch.maximum(s_start, t_start)
    inter_end   = torch.minimum(s_end, t_end)
    inter = (inter_end - inter_start).clamp(min=0)

    len_s = (s_end - s_start).clamp(min=0)
    len_t = (t_end - t_start).clamp(min=0)
    union = (len_s + len_t - inter).clamp(min=eps)

    A = inter / union  # [Ls, Lt]
    row_sum = A.sum(dim=1, keepdim=True).clamp(min=eps)
    return A / row_sum

def coverage_from_A(A: torch.Tensor, eps: float = 1e-12) -> Tuple[float, float]:
    """Return (row_fraction, col_fraction) of A whose total alignment mass exceeds `eps`.

    Rows are student tokens and columns teacher tokens; every row/column of A counts
    in the denominator.
    """
    def _frac_covered(mass: torch.Tensor) -> float:
        return (mass > eps).float().mean().item()

    return _frac_covered(A.sum(dim=1)), _frac_covered(A.sum(dim=0))

In [None]:
# 5) Attention CKA loss (Span-aware)
def token_importance_from_attention(att: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Per-token importance from a head-averaged [L, L] attention map.

    importance[j] = sum_i |att[i, j]| + sum_k |att[j, k]| (attention received + given).
    When `mask` is provided, importance is multiplied by it (zero where mask is 0).
    """
    magnitude = att.abs()
    received = magnitude.sum(dim=0)
    given = magnitude.sum(dim=1)
    importance = received + given
    return importance if mask is None else importance * mask.float()

def select_top_tokens(importance: torch.Tensor, mask: torch.Tensor, top_frac: float, min_tokens: int) -> torch.Tensor:
    """Sequence indices of the most important valid (mask > 0) tokens.

    Keeps k = min(n_valid, max(min_tokens, ceil(top_frac * n_valid))) positions,
    ordered by decreasing importance. Returns an empty index tensor if no position
    is valid.
    """
    valid_positions = (mask > 0).nonzero(as_tuple=True)[0]
    n_valid = valid_positions.numel()
    if n_valid == 0:
        return valid_positions
    k = min(n_valid, max(min_tokens, math.ceil(top_frac * n_valid)))
    order = importance[valid_positions].topk(k, largest=True).indices
    return valid_positions[order]

def _masked_coverage(mass: torch.Tensor, valid: torch.Tensor, eps: float = 1e-12) -> float:
    """Fraction of valid positions (valid > 0) whose alignment mass exceeds eps; 0.0 if none are valid."""
    valid = valid > 0
    n_valid = int(valid.sum().item())
    if n_valid == 0:
        return 0.0
    return float(((mass > eps) & valid).sum().item()) / n_valid


def compute_span_cka_att_loss(
    att_s: torch.Tensor,               # [B, H, Ls, Ls]
    att_t: torch.Tensor,               # [B, H, Lt, Lt]
    offsets_s: torch.Tensor,           # [B, Ls, 2]
    offsets_t: torch.Tensor,           # [B, Lt, 2]
    mask_s: torch.Tensor,              # [B, Ls]
    mask_t: torch.Tensor,              # [B, Lt]
    min_coverage: float,
    top_frac: float,
    min_tokens: int,
) -> torch.Tensor:
    """Span-aware attention distillation loss (1 - linear CKA), averaged over accepted samples.

    Per sample b:
      1. Build the char-span alignment A [Ls, Lt] between student and teacher tokens.
      2. Coverage = fraction of NON-PADDING student rows / teacher columns of A that
         align to anything. The sample is skipped if min(row, col) coverage is below
         `min_coverage`. Coverage is measured over real tokens, so the decision does
         not depend on how much padding the batch added to this sequence. Tokens
         without a character span (e.g. special tokens) still count as uncovered.
      3. Keep the top `top_frac` teacher tokens by attention importance and project
         the teacher attention onto student positions: A_sel @ Att_t_sub @ A_sel^T.
      4. Keep the top `top_frac` aligned student tokens by importance, then compare
         the student attention submatrix to the projected teacher one with linear CKA.
    Each sample's (1 - CKA) is weighted by its coverage. Returns a 0-d zero tensor
    when no sample qualifies.
    """
    device = att_s.device
    B = att_s.size(0)

    att_s_mean = att_s.mean(dim=1)  # [B, Ls, Ls] head-averaged
    att_t_mean = att_t.mean(dim=1)  # [B, Lt, Lt] head-averaged

    losses = []
    for b in range(B):
        s_valid = mask_s[b].to(device)
        t_mask = mask_t[b].to(device)

        A = build_span_overlap_matrix(offsets_s[b].to(device), offsets_t[b].to(device))
        cov_s = _masked_coverage(A.sum(dim=1), s_valid)
        cov_t = _masked_coverage(A.sum(dim=0), t_mask)
        conf = min(cov_s, cov_t)
        if conf < min_coverage:
            continue

        t_imp = token_importance_from_attention(att_t_mean[b], t_mask)
        t_sel = select_top_tokens(t_imp, t_mask, top_frac=top_frac, min_tokens=min_tokens)
        if t_sel.numel() < min_tokens:
            continue

        A_sel = A[:, t_sel]  # [Ls, M]
        s_has = (A_sel.sum(dim=1) > 1e-12).float()  # student tokens aligned to a kept teacher token
        s_mask = s_valid * s_has
        if s_mask.sum() < min_tokens:
            continue

        att_t_sub = att_t_mean[b][t_sel][:, t_sel]               # [M, M]
        att_t_proj = A_sel @ att_t_sub @ A_sel.transpose(0, 1)   # [Ls, Ls]

        s_imp = token_importance_from_attention(att_s_mean[b], s_mask)
        s_sel = select_top_tokens(s_imp, s_mask, top_frac=top_frac, min_tokens=min_tokens)
        if s_sel.numel() < min_tokens:
            continue

        K = att_s_mean[b][s_sel][:, s_sel]
        L = att_t_proj[s_sel][:, s_sel]
        losses.append(linear_cka_loss(K, L) * conf)

    if not losses:
        return att_s.new_tensor(0.0)
    return torch.stack(losses).mean()

In [None]:
# 6) Dataset + Dual-tokenizer collator (student + teacher + offsets)
class TextOnlyDataset(Dataset):
    """Map-style dataset yielding {"text": str} items from one DataFrame column."""

    def __init__(self, df: pd.DataFrame, text_col: str):
        # Materialize the column once as Python strings (non-string cells are str-converted).
        self.texts = list(df[text_col].astype(str))

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        text = self.texts[idx]
        return dict(text=text)

class DualTokenizerCollate:
    """DataLoader collate_fn that tokenizes a batch with both the student and teacher tokenizers.

    Output keys are "texts" plus every tokenizer field (e.g. input_ids,
    attention_mask, offset_mapping) suffixed with view and model:
    "{field}1_stu", "{field}2_stu", "{field}1_tea", "{field}2_tea".
    """

    def __init__(self, tok_s, tok_t, max_length: int):
        self.tok_s = tok_s
        self.tok_t = tok_t
        self.max_length = max_length

    def _tokenize(self, tokenizer, texts: List[str]) -> Dict[str, torch.Tensor]:
        """Pad to the longest text, truncate to max_length, return tensors plus char offsets."""
        return tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
            return_offsets_mapping=True,  # requires fast tokenizer
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        texts = [x["text"] for x in batch]

        # The two "views" differ only by dropout inside the model; their tokens are
        # identical. Tokenize once per tokenizer (2 tokenizer calls instead of 4) and
        # give view 2 its own tensor copies.
        s1 = self._tokenize(self.tok_s, texts)
        t1 = self._tokenize(self.tok_t, texts)

        out: Dict[str, Any] = {"texts": texts}
        for k, v in s1.items():
            out[f"{k}1_stu"] = v
        for k, v in s1.items():
            out[f"{k}2_stu"] = v.clone()
        for k, v in t1.items():
            out[f"{k}1_tea"] = v
        for k, v in t1.items():
            out[f"{k}2_tea"] = v.clone()
        return out

def load_train_dataframe(path: str, text_col: str) -> pd.DataFrame:
    """Read the training CSV and drop rows whose `text_col` is missing (index is reset).

    Raises:
        ValueError: if `text_col` is not a column; the message lists up to 30 column names.
    """
    frame = pd.read_csv(path)
    columns = list(frame.columns)
    if text_col in columns:
        return frame.dropna(subset=[text_col]).reset_index(drop=True)
    raise ValueError(f"text_col='{text_col}' not in columns: {columns[:30]}")

In [None]:
# 7) Load tokenizers + models (teacher frozen) + projection
tok_student = AutoTokenizer.from_pretrained(cfg.student_name, use_fast=True)
tok_teacher = AutoTokenizer.from_pretrained(cfg.teacher_name, use_fast=True)

model_student = AutoModel.from_pretrained(cfg.student_name, output_hidden_states=True).to(device_s)
model_teacher = AutoModel.from_pretrained(cfg.teacher_name, output_hidden_states=True).to(device_t)

# The teacher is frozen: eval mode (no dropout) and no parameter gradients.
model_teacher.eval()
for p in model_teacher.parameters():
    p.requires_grad_(False)

d_s = model_student.config.hidden_size
d_t = model_teacher.config.hidden_size
print("d_s:", d_s, "| d_t:", d_t)

# Auto-fill Matryoshka prefix dims: standard sizes <= d_s, always ending at the full d_s.
if cfg.nested_dims is None:
    base = [16, 32, 64, 128, 256, 512, 1024]
    cfg.nested_dims = [d for d in base if d <= d_s]
    # `not cfg.nested_dims` covers d_s < 16, where the filtered list is empty and [-1] would raise.
    if not cfg.nested_dims or cfg.nested_dims[-1] != d_s:
        cfg.nested_dims.append(d_s)

# Auto-fill KD prefix dims. Dedupe and sort so that d_s is not listed twice when it
# already is one of the standard sizes (e.g. 384 or 512); a duplicate would add its
# KD term twice.
if cfg.kd_dims is None:
    cfg.kd_dims = sorted({d for d in (16, 32, 64, 128, 256, 384, 512, d_s) if d <= d_s})

# Trainable projection from teacher hidden size to student hidden size.
proj_t2s = nn.Linear(d_t, d_s, bias=False).to(device_s)

print("nested_dims (task):", cfg.nested_dims)
print("kd_dims:", cfg.kd_dims)

In [None]:
# ============================================================
# 8) Optimizer + scheduler + loader
# ============================================================
# Trainable parameters: the student encoder plus the teacher->student projection.
params = [*model_student.parameters(), *proj_t2s.parameters()]
optimizer = torch.optim.AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

df_train = load_train_dataframe(cfg.train_csv, cfg.text_col)
train_ds = TextOnlyDataset(df_train, cfg.text_col)
collate = DualTokenizerCollate(tok_student, tok_teacher, max_length=cfg.max_length)
train_loader = DataLoader(
    dataset=train_ds,
    collate_fn=collate,
    batch_size=cfg.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
)

# One scheduler step per optimizer step (one per batch).
total_steps = max(1, len(train_loader)) * cfg.epochs
warmup_steps = int(total_steps * cfg.warmup_ratio)

scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer,
    num_cycles=cfg.num_restarts,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)

# Loss scaling is active only when CUDA (and therefore autocast) is available.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
print("total_steps:", total_steps, "| warmup_steps:", warmup_steps)

In [None]:
# 9) Weight schedules (ramps)
def linear_ramp(step: int, start: int, ramp: int) -> float:
    """Weight multiplier: 0 before `start`, then linear up to 1 over `ramp` steps, then 1.

    ramp <= 0 gives a hard step: 1.0 from `start` onward.
    """
    if step < start:
        return 0.0
    if ramp <= 0:
        return 1.0
    progress = (step - start) / float(ramp)
    return progress if progress < 1.0 else 1.0

def get_beta_kd(step: int) -> float:
    """KD loss weight at `step`: cfg.beta_kd_max scaled by its linear warm-in ramp."""
    ramp_factor = linear_ramp(step, start=cfg.beta_kd_start, ramp=cfg.beta_kd_ramp)
    return cfg.beta_kd_max * ramp_factor

def get_alpha_attn(step: int) -> float:
    """Attention-CKA loss weight at `step`: cfg.alpha_attn_max scaled by its linear warm-in ramp."""
    ramp_factor = linear_ramp(step, start=cfg.alpha_attn_start, ramp=cfg.alpha_attn_ramp)
    return cfg.alpha_attn_max * ramp_factor

In [None]:
# 10) Training loop (MRL task + optional teacher KD + optional Span+CKA)
from tqdm.auto import tqdm

global_step = 0
model_student.train()
proj_t2s.train()

# Mixed precision only when CUDA is present (matches the GradScaler's `enabled` flag).
use_amp = torch.cuda.is_available()

for epoch in range(cfg.epochs):
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
    running = 0.0

    for step_in_epoch, batch in enumerate(pbar):
        optimizer.zero_grad(set_to_none=True)

        # Route tensors by key suffix: "*_stu" -> student device, "*_tea" -> teacher device.
        batch_s = {k: v.to(device_s, non_blocking=True) for k, v in batch.items() if torch.is_tensor(v) and k.endswith("_stu")}
        batch_t = {k: v.to(device_t, non_blocking=True) for k, v in batch.items() if torch.is_tensor(v) and k.endswith("_tea")}

        # Attention CKA needs teacher attentions, so it requires KD to be on; it only
        # runs every `att_every` global steps.
        need_att = cfg.use_attn_cka and cfg.use_teacher_kd and (global_step % cfg.att_every == 0)
        out_att = need_att

        with torch.autocast(device_type="cuda", enabled=use_amp):
            # ---- student: two dropout views ----
            s_out1 = model_student(
                input_ids=batch_s["input_ids1_stu"],
                attention_mask=batch_s["attention_mask1_stu"],
                output_attentions=out_att,
                return_dict=True,
            )
            s_out2 = model_student(
                input_ids=batch_s["input_ids2_stu"],
                attention_mask=batch_s["attention_mask2_stu"],
                output_attentions=False,   # attentions only on view1 for efficiency
                return_dict=True,
            )

            S1 = get_sentence_emb(s_out1.last_hidden_state, batch_s["attention_mask1_stu"], cfg.student_pool)
            S2 = get_sentence_emb(s_out2.last_hidden_state, batch_s["attention_mask2_stu"], cfg.student_pool)

            # ===== MAIN TASK: MRL baseline (Matryoshka InfoNCE) =====
            task_loss, _ = matryoshka_infonce(
                S1, S2,
                temperature=cfg.temperature,
                nested_dims=cfg.nested_dims,
                dim_weight="uniform",
            )

            kd_loss = S1.new_tensor(0.0)
            att_loss = S1.new_tensor(0.0)

            if cfg.use_teacher_kd:
                with torch.inference_mode():
                    t_out = model_teacher(
                        input_ids=batch_t["input_ids1_tea"],
                        attention_mask=batch_t["attention_mask1_tea"],
                        output_attentions=out_att,
                        return_dict=True,
                    )
                    T = extract_teacher_sentence_embedding(
                        t_out.last_hidden_state,
                        batch_t["attention_mask1_tea"],
                        batch_t["input_ids1_tea"],
                        tok_teacher,
                        embed_token=cfg.teacher_embed_token,
                    ).to(device_s)

                # T was created under inference mode. Inference tensors cannot be saved
                # for backward, but proj_t2s must save its input to get the weight
                # gradient. Clone outside inference mode to get a normal tensor;
                # otherwise runs where autocast does not re-cast the input (e.g.
                # CPU-only) can fail with "Inference tensors cannot be saved for backward".
                T = proj_t2s(T.clone())

                # ---- teacher prefix KD ----
                kd_loss = 0.5 * (
                    matryoshka_prefix_cosine_kd(S1, T, dims=cfg.kd_dims, dim_weight="inverse_sqrt") +
                    matryoshka_prefix_cosine_kd(S2, T, dims=cfg.kd_dims, dim_weight="inverse_sqrt")
                )

                # ---- span-aware attention CKA ----
                if need_att:
                    if cfg.att_layer == "mid":
                        idx_s = len(s_out1.attentions)//2
                        idx_t = len(t_out.attentions)//2
                    else:
                        idx_s = -1
                        idx_t = -1

                    att_s = s_out1.attentions[idx_s]             # [B,H,Ls,Ls]
                    att_t = t_out.attentions[idx_t].to(device_s) # [B,H,Lt,Lt]

                    offsets_s = batch_s["offset_mapping1_stu"]          # [B,Ls,2]
                    offsets_t = batch_t["offset_mapping1_tea"].to(device_s)  # [B,Lt,2]
                    mask_s = batch_s["attention_mask1_stu"]             # [B,Ls]
                    mask_t = batch_t["attention_mask1_tea"].to(device_s)     # [B,Lt]

                    att_loss = compute_span_cka_att_loss(
                        att_s=att_s,
                        att_t=att_t,
                        offsets_s=offsets_s,
                        offsets_t=offsets_t,
                        mask_s=mask_s,
                        mask_t=mask_t,
                        min_coverage=cfg.min_coverage,
                        top_frac=cfg.top_frac,
                        min_tokens=cfg.min_tokens,
                    )

            beta_kd = get_beta_kd(global_step)
            alpha_attn = get_alpha_attn(global_step)

            total_loss = cfg.w_task * task_loss
            if cfg.use_teacher_kd:
                total_loss = total_loss + cfg.alpha_kd * (beta_kd * kd_loss)
            if cfg.use_attn_cka:
                # alpha_kd scales the attention term as well as the KD term.
                total_loss = total_loss + cfg.alpha_kd * (alpha_attn * att_loss)

        scaler.scale(total_loss).backward()

        if cfg.grad_clip is not None and cfg.grad_clip > 0:
            scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(params, cfg.grad_clip)

        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Display-only EMA of the loss. `running` is reset every epoch, so seed it on
        # the first step of each epoch; otherwise later epochs would start the EMA at 0.
        loss_value = total_loss.item()
        running = loss_value if step_in_epoch == 0 else 0.95 * running + 0.05 * loss_value
        pbar.set_postfix({
            "loss": f"{running:.4f}",
            "task": f"{task_loss.item():.4f}",
            "kd": f"{kd_loss.item():.4f}" if cfg.use_teacher_kd else "0.0000",
            "att": f"{att_loss.item():.4f}" if need_att else "skip",
            "β_kd": f"{beta_kd:.2f}",
            "α_att": f"{alpha_attn:.3f}",
            "att?": int(need_att),
        })

        global_step += 1

print("Done. global_step =", global_step)

In [None]:
# 11) Fair comparison cheat sheet
print("\n".join([
    "PURE MRL baseline (same task as mrl-baseline.ipynb):",
    "  cfg.use_teacher_kd = False",
    "  cfg.use_attn_cka   = False",
    "",
    "SAMD (MRL task + KD + Span+CKA):",
    "  cfg.use_teacher_kd = True",
    "  cfg.use_attn_cka   = True",
]))

In [None]:
# ENCODE (for evaluation)
from tqdm.auto import tqdm

@torch.no_grad()
def encode_texts(
    texts,
    batch_size: int = 256,
    max_length: Optional[int] = None,
    pool: Optional[str] = None,
    normalize: bool = False,
) -> torch.Tensor:
    """Encode texts with the STUDENT model. Returns CPU tensor [N, d_s].

    Args:
        texts: iterable of items; each is converted with str().
        batch_size: number of texts per forward pass.
        max_length: truncation length (defaults to cfg.max_length).
        pool: "mean" for masked mean pooling, anything else uses the first token
            (defaults to cfg.student_pool).
        normalize: L2-normalize each embedding.
    Returns:
        Float CPU tensor [N, hidden_size]; shape [0, hidden_size] for empty input.

    The student's train/eval mode is restored on exit, so calling this mid-training
    does not silently leave dropout disabled.
    """
    if max_length is None:
        max_length = cfg.max_length
    if pool is None:
        pool = cfg.student_pool

    texts = [str(x) for x in texts]
    if not texts:
        # torch.cat([]) would raise; return a correctly shaped empty result instead.
        return torch.empty(0, model_student.config.hidden_size)

    was_training = model_student.training
    model_student.eval()
    try:
        all_emb = []
        for i in tqdm(range(0, len(texts), batch_size), desc="encode", leave=False):
            chunk = texts[i:i + batch_size]
            enc = tok_student(
                chunk,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            enc = {k: v.to(device_s, non_blocking=True) for k, v in enc.items()}

            out = model_student(**enc, return_dict=True)
            emb = get_sentence_emb(out.last_hidden_state, enc["attention_mask"], pool)

            if normalize:
                emb = F.normalize(emb, dim=-1)

            all_emb.append(emb.detach().cpu())

        return torch.cat(all_emb, dim=0)
    finally:
        model_student.train(was_training)

In [None]:
# Evaluation helpers (per slice)
import numpy as np
import pandas as pd

# Optional dependencies: a missing library only disables the evals that need it.
try:
    from scipy.stats import spearmanr
except Exception as scipy_import_error:
    spearmanr = None
    print("[WARN] scipy not available -> STS Spearman will be skipped.", scipy_import_error)

try:
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, f1_score
except Exception as sklearn_import_error:
    # Reset all three so a partial import never leaves a half-available toolkit.
    LogisticRegression = accuracy_score = f1_score = None
    print("[WARN] scikit-learn not available -> CLS eval will be skipped.", sklearn_import_error)

try:
    from IPython.display import display
except Exception:
    display = print  # plain-text fallback outside IPython

def _safe_spearman(a, b) -> float:
    """Spearman rho of (a, b); 0.0 when scipy is unavailable or the correlation is NaN."""
    if spearmanr is None:
        return 0.0
    rho = spearmanr(a, b).correlation
    return 0.0 if rho != rho else float(rho)  # NaN is the only value unequal to itself

def eval_cls_task(train_csv, test_csv, text_col="text", label_col="label", dims=None, batch_size=256):
    """Linear-probe classification on prefix slices of student embeddings.

    For each d in `dims`, fits LogisticRegression on X_train[:, :d] and reports
    accuracy and macro-F1 on the test split.

    Args:
        train_csv, test_csv: CSV paths containing `text_col` and an integer `label_col`.
        dims: prefix sizes to evaluate; None -> only the full embedding dimension.
        batch_size: encoding batch size.
    Returns:
        DataFrame indexed by "dim" with columns acc, f1_macro.
    Raises:
        RuntimeError: scikit-learn is not installed.
    """
    if LogisticRegression is None:
        raise RuntimeError("scikit-learn not available. Install scikit-learn to run CLS eval.")

    tr = pd.read_csv(train_csv)
    te = pd.read_csv(test_csv)

    X_tr = encode_texts(tr[text_col].astype(str).tolist(), batch_size=batch_size)
    X_te = encode_texts(te[text_col].astype(str).tolist(), batch_size=batch_size)

    if dims is None:
        # Iterating None would raise a TypeError; default to the full embedding.
        dims = [X_tr.shape[1]]

    y_tr = tr[label_col].astype(int).values
    y_te = te[label_col].astype(int).values

    rows = []
    for d in dims:
        clf = LogisticRegression(max_iter=2000, n_jobs=1)
        clf.fit(X_tr[:, :d].numpy(), y_tr)
        pred = clf.predict(X_te[:, :d].numpy())
        rows.append({
            "dim": d,
            "acc": float(accuracy_score(y_te, pred)),
            "f1_macro": float(f1_score(y_te, pred, average="macro")),
        })
    return pd.DataFrame(rows).set_index("dim")

def eval_pair_task(train_csv, test_csv, s1="sentence1", s2="sentence2", label_col="label", dims=None, batch_size=256):
    """Binary sentence-pair classification by thresholding prefix cosine similarity.

    For each d in `dims`, picks the threshold in linspace(-1, 1, 401) that maximizes
    train accuracy (the first best threshold wins), then applies it to the test split.

    Args:
        train_csv, test_csv: CSV paths with columns `s1`, `s2` and a 0/1 `label_col`.
        dims: prefix sizes to evaluate; None -> only the full embedding dimension.
        batch_size: encoding batch size.
    Returns:
        DataFrame indexed by "dim" with columns acc (test accuracy) and thr.
    """
    tr = pd.read_csv(train_csv)
    te = pd.read_csv(test_csv)

    A_tr = encode_texts(tr[s1].astype(str).tolist(), batch_size=batch_size)
    B_tr = encode_texts(tr[s2].astype(str).tolist(), batch_size=batch_size)
    A_te = encode_texts(te[s1].astype(str).tolist(), batch_size=batch_size)
    B_te = encode_texts(te[s2].astype(str).tolist(), batch_size=batch_size)

    if dims is None:
        # Iterating None would raise a TypeError; default to the full embedding.
        dims = [A_tr.shape[1]]

    y_tr = tr[label_col].astype(int).values
    y_te = te[label_col].astype(int).values

    rows = []
    for d in dims:
        sim_tr = (F.normalize(A_tr[:, :d], dim=-1) * F.normalize(B_tr[:, :d], dim=-1)).sum(dim=-1).numpy()
        sim_te = (F.normalize(A_te[:, :d], dim=-1) * F.normalize(B_te[:, :d], dim=-1)).sum(dim=-1).numpy()

        best_thr, best_acc = 0.0, -1.0
        for thr in np.linspace(-1, 1, 401):
            acc = ((sim_tr >= thr).astype(int) == y_tr).mean()
            if acc > best_acc:
                best_acc, best_thr = acc, thr

        pred = (sim_te >= best_thr).astype(int)
        rows.append({"dim": d, "acc": float((pred == y_te).mean()), "thr": float(best_thr)})
    return pd.DataFrame(rows).set_index("dim")

def eval_sts_task(test_csv, s1="sentence1", s2="sentence2", score_col="score", dims=None, batch_size=256):
    """Spearman correlation between prefix cosine similarity and gold STS scores.

    Args:
        test_csv: CSV path with columns `s1`, `s2` and a numeric `score_col`.
        dims: prefix sizes to evaluate; None -> only the full embedding dimension.
        batch_size: encoding batch size.
    Returns:
        DataFrame indexed by "dim" with column spearman. The value is 0.0 when scipy
        is missing or the correlation is undefined.
    """
    te = pd.read_csv(test_csv)

    A = encode_texts(te[s1].astype(str).tolist(), batch_size=batch_size)
    B = encode_texts(te[s2].astype(str).tolist(), batch_size=batch_size)
    y = te[score_col].astype(float).values

    if dims is None:
        # Iterating None would raise a TypeError; default to the full embedding.
        dims = [A.shape[1]]

    rows = []
    for d in dims:
        sim = (F.normalize(A[:, :d], dim=-1) * F.normalize(B[:, :d], dim=-1)).sum(dim=-1).numpy()
        rows.append({"dim": d, "spearman": _safe_spearman(sim, y)})
    return pd.DataFrame(rows).set_index("dim")

In [None]:
# Sanity check that the evaluation data root exists (prints 1 if it does, else 0).
# The root can be overridden with the EVAL_ROOT environment variable. The variable
# NAME is "EVAL_ROOT"; the Kaggle path is only the default value.
EVAL_ROOT = os.environ.get("EVAL_ROOT", "/kaggle/input/multitask-data")
if EVAL_ROOT and os.path.exists(EVAL_ROOT):
    print(1)
else:
    print(0)

In [None]:
# RUN EVAL (thorough, per slice) — OPTIONAL

# Both roots can be overridden with environment variables; EVAL_ROOT defaults to the
# Kaggle dataset folder, and STS_EXTRA_ROOT is disabled unless set.
EVAL_ROOT = os.environ.get("EVAL_ROOT", "/kaggle/input/multitask-data/multi-data")
STS_EXTRA_ROOT = os.environ.get("STS_EXTRA_ROOT", "")

NESTED_DIMS = cfg.nested_dims
print("NESTED_DIMS:", NESTED_DIMS)

results = {"cls": {}, "pair": {}, "sts": {}, "sts_extra": {}}

if EVAL_ROOT and os.path.exists(EVAL_ROOT):
    # (display name, train file, test file); tasks whose files are missing are skipped.
    cls_tasks = [
        ("Banking77", "banking_train.csv", "banking77_test.csv"),
        ("Emotion",   "emotion_train.csv", "emotion_test.csv"),
        ("TweetEval", "tweet_train.csv",   "tweet_test.csv"),
    ]
    for name, trf, tef in cls_tasks:
        tr = os.path.join(EVAL_ROOT, trf)
        te = os.path.join(EVAL_ROOT, tef)
        if os.path.exists(tr) and os.path.exists(te):
            results["cls"][name] = eval_cls_task(tr, te, text_col="text", label_col="label", dims=NESTED_DIMS)
            print(f"[CLS] {name}")
            display(results["cls"][name])

    # Pair tasks tune the similarity threshold on the validation split.
    pair_tasks = [
        ("MRPC",    "mrpc_validation.csv",    "mrpc_test.csv"),
        ("SciTail", "scitail_validation.csv", "scitail_test.csv"),
        ("WiC",     "wic_validation.csv",     "wic_test.csv"),
    ]
    for name, trf, tef in pair_tasks:
        tr = os.path.join(EVAL_ROOT, trf)
        te = os.path.join(EVAL_ROOT, tef)
        if os.path.exists(tr) and os.path.exists(te):
            results["pair"][name] = eval_pair_task(tr, te, s1="sentence1", s2="sentence2", label_col="label", dims=NESTED_DIMS)
            print(f"[PAIR] {name}")
            display(results["pair"][name])

    sts_tasks = [
        ("SICK",  "sick_test.csv"),
        ("STS12", "sts12_test.csv"),
        ("STS-B", "stsb_test.csv"),
    ]
    for name, tef in sts_tasks:
        te = os.path.join(EVAL_ROOT, tef)
        if os.path.exists(te):
            results["sts"][name] = eval_sts_task(te, s1="sentence1", s2="sentence2", score_col="score", dims=NESTED_DIMS)
            print(f"[STS] {name}")
            display(results["sts"][name])
else:
    print("EVAL_ROOT not found or not set; skipping core eval.")

if STS_EXTRA_ROOT and os.path.exists(STS_EXTRA_ROOT):
    candidates = [
        ("STS13", "sts13.csv"),
        ("STS14", "sts14.csv"),
        ("STS15", "sts15.csv"),
        ("STS16", "sts16.csv"),
        ("STS17", "sts17.csv"),
    ]
    for name, fn in candidates:
        p = os.path.join(STS_EXTRA_ROOT, fn)
        if os.path.exists(p):
            results["sts_extra"][name] = eval_sts_task(p, s1="sentence1", s2="sentence2", score_col="score", dims=NESTED_DIMS)
            print(f"[STS_EXTRA] {name}")
            display(results["sts_extra"][name])
else:
    print("STS_EXTRA_ROOT not found or not set; skipping extra STS eval.")

results