#1) Imports, config, and seeding

In [6]:
# --- UltraHybrid XL++ (Notebook Edition): Cell 1 / Config & Helpers ---

from __future__ import annotations
import os, math, json, random, copy, warnings
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, f1_score, precision_score, recall_score,
    confusion_matrix, balanced_accuracy_score, average_precision_score
)

warnings.filterwarnings("ignore")
np.seterr(all="ignore")

# ----------------- user-editable paths -----------------
data_root_path = Path(r"D:\ML_CHALLANGE\data")  # <- edit if needed
human_train_npy_path = data_root_path / "train" / "train_human.npy"
ai_train_npy_path    = data_root_path / "train" / "train_ai.npy"
validation_jsonl_path = data_root_path / "val" / "validation.jsonl"  # optional; not required here

# Where to save artifacts (kept dims, whitening, weights, report)
output_dir_path = Path(".")
output_dir_path.mkdir(parents=True, exist_ok=True)

# ----------------- reproducibility -----------------
def set_global_seed(seed: int = 42, deterministic: bool = False):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True

# ----------------- configuration -----------------
training_config_dict = dict(
    random_seed_integer = 42,
    deterministic_cudnn_boolean = False,

    # data
    test_size_fraction = 0.15,

    # model
    input_embedding_dimensionality = 768,   # will be overwritten AFTER masking
    sequence_length_tokens = 100,
    model_hidden_dimensionality = 320,
    transformer_num_layers = 2,
    transformer_num_heads = 8,
    transformer_ffn_multiplier = 4,
    attention_dropout_rate = 0.10,
    ffn_dropout_rate = 0.10,
    use_rnn_boolean = True,
    rnn_num_layers = 1,
    use_cnn_boolean = True,
    cnn_kernel_sizes_list = [3, 5, 7],
    cnn_out_channels = 96,
    use_multiscale_pooling_boolean = True,
    use_checkpointing_boolean = True,

    # head
    head_type_string = "cosine",    # "cosine" (ArcFace-style) or "linear"
    arcface_scale_float = 16.0,
    arcface_margin_float = 0.20,
    arcface_easy_margin_boolean = False,

    # training
    total_epochs_integer = 12,
    batch_size_train = 64,
    batch_size_eval = 128,
    grad_accumulation_steps = 2,
    base_learning_rate = 2e-4,
    weight_decay_rate = 1e-2,
    lr_warmup_fraction = 0.10,
    max_grad_norm = 2.0,
    early_stop_patience_epochs = 6,

    # regularization & tricks
    use_amp_boolean = True,
    use_class_weights_boolean = True,
    use_rdrop_alpha = 2e-2,      # 0 to disable
    token_mix_probability = 0.15,

    # evaluation
    optimize_for_string = "acc",  # "acc" or "f1"
    mc_dropout_samples = 1,       # >1 enables MC-dropout

    # numerics
    mask_topk_high_variance_dims = 64,  # 0 to disable
    out_dir = str(output_dir_path)
)

device_torch = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_global_seed(training_config_dict["random_seed_integer"], training_config_dict["deterministic_cudnn_boolean"])

# ----------------- numerical guards -----------------
def safe_softmax_from_logits(logits_tensor: torch.Tensor) -> torch.Tensor:
    logits_tensor = logits_tensor.float()
    logits_tensor = torch.nan_to_num(logits_tensor, nan=0.0, neginf=-50.0, posinf=50.0).clamp(-50.0, 50.0)
    logits_tensor = logits_tensor - logits_tensor.max(dim=1, keepdim=True).values
    prob_tensor = torch.softmax(logits_tensor, dim=1)
    prob_tensor = torch.nan_to_num(prob_tensor, nan=0.5, neginf=0.0, posinf=1.0)
    return prob_tensor

def np_clean_probabilities(array_1d: np.ndarray) -> np.ndarray:
    return np.nan_to_num(array_1d, nan=0.5, neginf=0.0, posinf=1.0)

def cosine_warmup_lr(optimizer_like, current_step: int, total_steps: int, base_lr: float, warmup_fraction: float):
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    if current_step < warmup_steps:
        lr_value = base_lr * (current_step + 1) / warmup_steps
    else:
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        lr_value = 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
    for pg in optimizer_like.param_groups:
        pg["lr"] = lr_value

def mixup_batch(inputs_tensor: torch.Tensor, targets_tensor: torch.Tensor, alpha: float = 0.3):
    if alpha <= 0:
        return inputs_tensor, targets_tensor, targets_tensor, 1.0
    lam = np.random.beta(alpha, alpha)
    idx = torch.randperm(inputs_tensor.size(0), device=inputs_tensor.device)
    mixed_inputs = lam * inputs_tensor + (1 - lam) * inputs_tensor[idx]
    targets_a, targets_b = targets_tensor, targets_tensor[idx]
    return mixed_inputs, targets_a, targets_b, lam


#Cell 2 — Load data, preprocessing, split, and loaders

In [7]:
# --- UltraHybrid XL++: Cell 2 / Data loading & preprocessing ---

# 1) load arrays
assert human_train_npy_path.exists() and ai_train_npy_path.exists(), "NPY files not found."
human_embeddings_np = np.load(human_train_npy_path)   # (Nh, 100, 768)
ai_embeddings_np    = np.load(ai_train_npy_path)      # (Na, 100, 768)
assert human_embeddings_np.ndim == 3 and ai_embeddings_np.ndim == 3, "Bad shapes for input arrays."

# 2) labels and split (stratified)
labels_human = np.zeros(len(human_embeddings_np), dtype=np.int64)
labels_ai    = np.ones(len(ai_embeddings_np), dtype=np.int64)
all_embeddings_np = np.concatenate([human_embeddings_np, ai_embeddings_np], axis=0)
all_labels_np     = np.concatenate([labels_human, labels_ai], axis=0)

X_train_np, X_hold_np, y_train_np, y_hold_np = train_test_split(
    all_embeddings_np, all_labels_np,
    test_size=training_config_dict["test_size_fraction"],
    random_state=training_config_dict["random_seed_integer"],
    stratify=all_labels_np
)

# 3) per-sample layer-norm (stable)
def layernorm_2d(sample_np: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # sample_np: (T=100, D)
    mean = sample_np.mean(axis=1, keepdims=True)
    var  = sample_np.var(axis=1, keepdims=True)
    return ((sample_np - mean) / np.sqrt(var + eps)).astype(np.float32)

def preprocess_batch_seqnorm(X3d: np.ndarray) -> np.ndarray:
    out = np.empty_like(X3d, dtype=np.float32)
    for i in range(len(X3d)):
        out[i] = layernorm_2d(X3d[i])
    return out

X_train_seq = preprocess_batch_seqnorm(X_train_np)
X_hold_seq  = preprocess_batch_seqnorm(X_hold_np)

# 4) mask top-K high-variance dimensions (computed on TRAIN only)
def mask_noisiest_dimensions(X_train: np.ndarray, X_hold: np.ndarray, topk: int):
    if topk <= 0:
        keep_idx = np.arange(X_train.shape[-1])
        return X_train, X_hold, keep_idx
    flat = X_train.reshape(-1, X_train.shape[-1])  # (N*T, D)
    var_per_dim = flat.var(axis=0)
    keep_idx = np.argsort(var_per_dim)[:-topk]     # drop highest-variance dims
    return X_train[..., keep_idx], X_hold[..., keep_idx], keep_idx

X_train_seq_mask, X_hold_seq_mask, kept_dim_indices = mask_noisiest_dimensions(
    X_train_seq, X_hold_seq, training_config_dict["mask_topk_high_variance_dims"]
)

# 5) set the true input dim AFTER masking
effective_input_dim = int(X_train_seq_mask.shape[-1])
training_config_dict["input_embedding_dimensionality"] = effective_input_dim

# 6) compute & save whitening on the *masked* TRAIN set
flat_train = X_train_seq_mask.reshape(-1, effective_input_dim)
whiten_mu = flat_train.mean(axis=0).astype(np.float32)
whiten_sd = np.clip(flat_train.std(axis=0), 1e-6, None).astype(np.float32)

np.save(output_dir_path / "kept_dim_indices.npy", kept_dim_indices.astype(np.int32))
np.save(output_dir_path / "whiten_mu.npy", whiten_mu)
np.save(output_dir_path / "whiten_sd.npy", whiten_sd)

# 7) Torch datasets/loaders
train_dataset = TensorDataset(torch.from_numpy(X_train_seq_mask), torch.from_numpy(y_train_np))
hold_dataset  = TensorDataset(torch.from_numpy(X_hold_seq_mask),  torch.from_numpy(y_hold_np))

train_loader = DataLoader(train_dataset, batch_size=training_config_dict["batch_size_train"],
                          shuffle=True, num_workers=0, pin_memory=True)
hold_loader  = DataLoader(hold_dataset,  batch_size=training_config_dict["batch_size_eval"],
                          shuffle=False, num_workers=0, pin_memory=True)

print("Data ready.")
print(f"Train: {X_train_seq_mask.shape} | Holdout: {X_hold_seq_mask.shape} | Kept dims: {kept_dim_indices.shape}")
print(f"Effective input dim = {effective_input_dim} | Whitening saved to: {output_dir_path.resolve()}")


Data ready.
Train: (13873, 100, 704) | Holdout: (2449, 100, 704) | Kept dims: (704,)
Effective input dim = 704 | Whitening saved to: D:\ML_CHALLANGE\models\models


#Cell 3 — Model blocks (Transformer, ArcFace head, unified model forward)

In [8]:
# --- UltraHybrid XL++: Cell 3 / Model definition ---

from torch.utils.checkpoint import checkpoint as ckpt

class TransformerBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int,
                 gated_attention: bool = True, attn_dropout: float = 0.1, ffn_dropout: float = 0.1):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.gated_attention = gated_attention

        self.Wq = nn.Linear(dim, dim)
        self.Wk = nn.Linear(dim, dim)
        self.Wv = nn.Linear(dim, dim)
        self.Wo = nn.Linear(dim, dim)
        self.gate_proj = nn.Linear(dim, 1) if gated_attention else None

        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim), nn.GELU(), nn.Dropout(ffn_dropout),
            nn.Linear(ffn_dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.drop_attn = nn.Dropout(attn_dropout)
        self.drop_ffn  = nn.Dropout(ffn_dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, D = x.shape
        q = self.Wq(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.Wk(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        if self.gated_attention:
            gate = torch.sigmoid(self.gate_proj(x)).unsqueeze(1)  # (B,1,L,1)
            k = k * gate
            v = v * gate

        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        weights = self.drop_attn(torch.softmax(scores, dim=-1))
        attn_out = (weights @ v).transpose(1, 2).reshape(B, L, D)

        x = self.norm1(x + self.Wo(attn_out))
        z = self.drop_ffn(self.ffn(x))
        return self.norm2(x + z)

class CosineMarginHead(nn.Module):
    def __init__(self, in_dim: int, scale: float = 16.0, margin: float = 0.20, easy_margin: bool = False):
        super().__init__()
        self.scale = scale
        self.margin = margin
        self.easy = easy_margin
        self.weight_matrix = nn.Parameter(nn.init.orthogonal_(torch.empty(2, in_dim)))

    def forward(self, feature_vectors, labels: Optional[torch.Tensor] = None):
        features_norm = F.normalize(feature_vectors, dim=1)
        weights_norm  = F.normalize(self.weight_matrix, dim=1)
        cosine_vals = features_norm @ weights_norm.t()              # (B,2)

        if labels is None:
            return self.scale * cosine_vals

        # apply additive margin only on the target class
        clamped = cosine_vals.clamp(-1 + 1e-7, 1 - 1e-7)
        target = clamped[torch.arange(len(feature_vectors)), labels]
        target_m = torch.where(target > 0, target - self.margin, target) if self.easy else (target - self.margin)
        cosine_with_margin = cosine_vals.clone()
        cosine_with_margin[torch.arange(len(feature_vectors)), labels] = target_m
        return self.scale * cosine_with_margin

class UltraHybridModel(nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = dict(cfg)
        d_in = int(self.cfg["input_embedding_dimensionality"])
        d    = int(self.cfg["model_hidden_dimensionality"])
        ffn_dim = int(self.cfg["transformer_ffn_multiplier"] * d)

        # input projection (if masked dim != model dim)
        self.input_projection = nn.Linear(d_in, d) if d_in != d else None

        # transformer stack
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d, self.cfg["transformer_num_heads"], ffn_dim,
                             gated_attention=True,
                             attn_dropout=self.cfg["attention_dropout_rate"],
                             ffn_dropout=self.cfg["ffn_dropout_rate"])
            for _ in range(self.cfg["transformer_num_layers"])
        ])

        # optional heads
        self.use_rnn = bool(self.cfg["use_rnn_boolean"])
        if self.use_rnn:
            self.rnn = nn.GRU(input_size=d, hidden_size=d, num_layers=self.cfg["rnn_num_layers"],
                              batch_first=True, bidirectional=True, dropout=0.0)
            self.rnn_out_dim = 2 * d
        else:
            self.rnn_out_dim = 0

        self.use_cnn = bool(self.cfg["use_cnn_boolean"])
        if self.use_cnn:
            ch = self.cfg["cnn_out_channels"]; ks = self.cfg["cnn_kernel_sizes_list"]
            self.conv_layers = nn.ModuleList([nn.Conv1d(d, ch, k, padding=(k-1)//2) for k in ks])
            self.cnn_out_dim = len(ks) * (2 * ch if self.cfg["use_multiscale_pooling_boolean"] else ch)
        else:
            self.cnn_out_dim = 0

        self.use_multiscale_pool = bool(self.cfg["use_multiscale_pooling_boolean"])
        transformer_out_dim = 2 * d if self.use_multiscale_pool else d

        # minimal engineered extras
        self.extra_dim = 3  # mean L2, std L2, mean cos to seq centroid

        total_feature_dim = transformer_out_dim + self.rnn_out_dim + self.cnn_out_dim + self.extra_dim

        head_type = self.cfg.get("head_type_string", "linear")
        if head_type == "cosine":
            self.classifier_head = CosineMarginHead(total_feature_dim,
                                                    self.cfg["arcface_scale_float"],
                                                    self.cfg["arcface_margin_float"],
                                                    self.cfg["arcface_easy_margin_boolean"])
        else:
            self.classifier_head = nn.Linear(total_feature_dim, 2)

        # whitening buffers (load if present, else neutral) — dimension-safe
        mu_path = Path(self.cfg["out_dir"]) / "whiten_mu.npy"
        sd_path = Path(self.cfg["out_dir"]) / "whiten_sd.npy"
        if mu_path.exists() and sd_path.exists():
            mu = np.load(mu_path).astype(np.float32)
            sd = np.load(sd_path).astype(np.float32)
            if mu.shape[0] == d_in and sd.shape[0] == d_in:
                self.register_buffer("whiten_mu", torch.from_numpy(mu))
                self.register_buffer("whiten_sd", torch.from_numpy(sd))
            else:
                # fallback if stale files exist
                self.register_buffer("whiten_mu", torch.zeros(d_in))
                self.register_buffer("whiten_sd", torch.ones(d_in))
        else:
            self.register_buffer("whiten_mu", torch.zeros(d_in))
            self.register_buffer("whiten_sd", torch.ones(d_in))

        self._warned_whiten_shape = False  # one-time warning

    # ----- engineered features -----
    @staticmethod
    def token_l2_mean_std(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        norms = torch.linalg.vector_norm(x, dim=2)  # (B,T)
        return norms.mean(1, keepdim=True), norms.std(1, keepdim=True)

    @staticmethod
    def mean_cosine_to_seq_centroid(x: torch.Tensor) -> torch.Tensor:
        mean_vec = x.mean(1)
        mean_vec = mean_vec / (mean_vec.norm(dim=1, keepdim=True) + 1e-9)
        x_normed = x / (x.norm(dim=2, keepdim=True) + 1e-9)
        return ((x_normed * mean_vec.unsqueeze(1)).sum(2)).mean(1, keepdim=True)

    def time_weighted_mean(self, z: torch.Tensor) -> torch.Tensor:
        T = z.size(1)
        w = torch.linspace(1.3, 0.7, T, device=z.device, dtype=z.dtype).view(1, T, 1)
        return (z * w).sum(1) / (w.sum().clamp_min(1e-8))

    # ----- robust whitening -----
    def _apply_whitening(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.shape
        if self.whiten_mu.numel() != D or self.whiten_sd.numel() != D:
            if not self._warned_whiten_shape:
                print(f"[warn] Whitening dim mismatch: x:{D} vs buffers:{self.whiten_mu.numel()}. Using batch stats.")
                self._warned_whiten_shape = True
            flat = x.reshape(-1, D)
            mu = flat.mean(dim=0)
            sd = flat.std(dim=0).clamp_min(1e-6)
            return (x - mu.view(1,1,-1)) / sd.view(1,1,-1)
        return (x - self.whiten_mu.view(1,1,-1)) / (self.whiten_sd.view(1,1,-1) + 1e-6)

    # ----- unified forward -----
    def forward(self, token_embeddings: torch.Tensor,
                return_features: bool = False,
                labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        assert token_embeddings.dim() == 3, "Expect [batch, seq, dim]"
        x = self._apply_whitening(token_embeddings)

        xp = self.input_projection(x) if self.input_projection is not None else x

        z = xp
        for blk in self.transformer_blocks:
            z = ckpt(blk, z) if (self.cfg["use_checkpointing_boolean"] and self.training) else blk(z)

        transformer_feat = (torch.cat([self.time_weighted_mean(z), z.max(1).values], dim=1)
                            if self.use_multiscale_pool else self.time_weighted_mean(z))

        rnn_feat = None
        if self.use_rnn:
            _, h = self.rnn(xp)
            h_fwd = h[-2] if self.cfg["rnn_num_layers"] >= 1 else h[0]
            h_bwd = h[-1] if self.cfg["rnn_num_layers"] >= 1 else h[0]
            rnn_feat = torch.cat([h_fwd, h_bwd], dim=1)

        cnn_feat = None
        if self.use_cnn:
            xc = xp.transpose(1, 2)
            feats = []
            for conv in self.conv_layers:
                y = F.relu(conv(xc))
                if self.use_multiscale_pool:
                    feats.extend([y.max(2).values, y.mean(2)])
                else:
                    feats.append(y.max(2).values)
            cnn_feat = torch.cat(feats, dim=1)

        mean_l2, std_l2 = self.token_l2_mean_std(x)
        mean_cos = self.mean_cosine_to_seq_centroid(x)
        extra = torch.cat([mean_l2, std_l2, mean_cos], dim=1)

        parts = [transformer_feat, extra]
        if rnn_feat is not None: parts.append(rnn_feat)
        if cnn_feat is not None: parts.append(cnn_feat)
        features = torch.cat(parts, dim=1) if len(parts) > 1 else parts[0]

        if return_features:
            return features

        if self.cfg["head_type_string"] == "cosine":
            return self.classifier_head(features, labels=labels)
        else:
            return self.classifier_head(features)


#Cell 4 — Training/eval utilities (metrics, evaluation loop with safe softmax)

In [9]:
# --- UltraHybrid XL++: Cell 4 / Metrics & Eval ---

def tune_threshold(y_true: np.ndarray, p_pos: np.ndarray, mode: str = "acc") -> float:
    grid = np.linspace(0.0, 1.0, 501)
    if mode == "acc":
        best_t, best_s = 0.5, -1.0
        for t in grid:
            s = (y_true == (p_pos >= t).astype(int)).mean()
            if s > best_s: best_t, best_s = t, s
        return float(best_t)
    # mode == "f1"
    best_t, best_f1, best_j = 0.5, -1.0, -1.0
    for t in grid:
        pred = (p_pos >= t).astype(int)
        p = precision_score(y_true, pred, zero_division=0)
        r = recall_score(y_true, pred, zero_division=0)
        f1 = 2*p*r/(p+r) if (p+r) > 0 else 0.0
        tn, fp, fn, tp = confusion_matrix(y_true, pred, labels=[0,1]).ravel()
        j = (tp/(tp+fn+1e-9)) - (fp/(fp+tn+1e-9))
        if (f1 > best_f1) or (abs(f1-best_f1) < 1e-9 and j > best_j):
            best_t, best_f1, best_j = t, f1, j
    return float(best_t)

@torch.no_grad()
def evaluate_model_probabilities(model: nn.Module, data_loader: DataLoader) -> np.ndarray:
    model.eval()
    probs_list = []
    # force float32 eval for stability, even on CUDA
    ctx = torch.autocast(device_type="cuda", dtype=torch.float16, enabled=False) if torch.cuda.is_available() else torch.cpu.amp.autocast(enabled=False)
    with ctx:
        for xb, _ in data_loader:
            xb = xb.to(device_torch, non_blocking=True)
            logits = model(xb)  # labels=None → eval path
            probs = safe_softmax_from_logits(logits)[:, 1]
            probs_list.append(probs.cpu().numpy())
    probs = np.concatenate(probs_list)
    return np_clean_probabilities(probs)

def print_classification_report(y_true: np.ndarray, p_pos: np.ndarray, tag: str):
    thr = tune_threshold(y_true, p_pos, training_config_dict["optimize_for_string"])
    y_pred = (p_pos >= thr).astype(int)
    auc = roc_auc_score(y_true, p_pos)
    pr  = average_precision_score(y_true, p_pos)
    f1  = f1_score(y_true, y_pred)
    acc = (y_true == y_pred).mean()
    bal = balanced_accuracy_score(y_true, y_pred)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()
    print(f"\n[{tag}] AUC={auc:.4f}  PR-AUC={pr:.4f}  F1@thr={f1:.4f}  ACC@thr={acc:.4f}  Thr={thr:.3f}")
    print(f"        CM [[{tn:4d} {fp:4d}]\n            [{fn:4d} {tp:4d}]]")


#Cell 5 — Train loop (unified forward, stable AMP, clear prints, saving best)

In [10]:
# --- UltraHybrid XL++: Cell 5 / Training loop ---

# Build model AFTER Cell 2 set input dim & saved whitening
training_config_dict["out_dir"] = str(output_dir_path)
model_instance = UltraHybridModel(training_config_dict).to(device_torch)

# class weights (optional)
if training_config_dict["use_class_weights_boolean"]:
    n0 = int((y_train_np == 0).sum()); n1 = int((y_train_np == 1).sum())
    inv = np.array([1.0/max(1,n0), 1.0/max(1,n1)], dtype=np.float32)
    class_weights_tensor = torch.tensor(inv / inv.mean(), device=device_torch)
else:
    class_weights_tensor = None

criterion_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)

optimizer_torch = torch.optim.AdamW(
    model_instance.parameters(),
    lr=training_config_dict["base_learning_rate"],
    weight_decay=training_config_dict["weight_decay_rate"]
)
scaler_amp = torch.cuda.amp.GradScaler(enabled=(training_config_dict["use_amp_boolean"] and device_torch.type=="cuda"))

total_update_steps = training_config_dict["total_epochs_integer"] * max(1, math.ceil(len(train_loader)/max(1, training_config_dict["grad_accumulation_steps"])))
global_step_counter = 0

best_holdout_auc = -1.0
best_state_dict = None
best_epoch_index = -1
epochs_without_improvement = 0

for epoch_index in range(1, training_config_dict["total_epochs_integer"] + 1):
    model_instance.train()
    running_loss_sum = 0.0
    step_in_epoch = 0

    for xb, yb in train_loader:
        xb = xb.to(device_torch, non_blocking=True)
        yb = yb.to(device_torch, non_blocking=True)

        # optional token-wise mixup (disabled by default for sequences; left here for completeness)
        if training_config_dict["token_mix_probability"] > 0.0 and np.random.rand() < training_config_dict["token_mix_probability"]:
            xb, ya, yb2, lam = mixup_batch(xb, yb, alpha=0.3)
            use_mixup = True
        else:
            use_mixup = False

        cosine_warmup_lr(optimizer_torch, global_step_counter, total_update_steps,
                         training_config_dict["base_learning_rate"], training_config_dict["lr_warmup_fraction"])

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=(training_config_dict["use_amp_boolean"] and device_torch.type=="cuda")):
            logits = model_instance(xb, labels=yb)  # unified: labels→ArcFace margin used internally

            if training_config_dict["use_rdrop_alpha"] > 0.0:
                logits2 = model_instance(xb, labels=yb)
                ce1 = criterion_fn(logits,  yb)
                ce2 = criterion_fn(logits2, yb)
                p1 = F.log_softmax(logits,  dim=1); p2 = F.log_softmax(logits2, dim=1)
                kl = 0.5 * (F.kl_div(p1, p2.exp(), reduction="batchmean") + F.kl_div(p2, p1.exp(), reduction="batchmean"))
                loss = 0.5*(ce1+ce2) + training_config_dict["use_rdrop_alpha"]*kl
            else:
                if use_mixup:
                    loss = lam * criterion_fn(logits, ya) + (1 - lam) * criterion_fn(logits, yb2)
                else:
                    loss = criterion_fn(logits, yb)

            loss = loss / training_config_dict["grad_accumulation_steps"]

        scaler_amp.scale(loss).backward()
        step_in_epoch += 1

        if (step_in_epoch % training_config_dict["grad_accumulation_steps"]) == 0:
            torch.nn.utils.clip_grad_norm_(model_instance.parameters(), training_config_dict["max_grad_norm"])
            scaler_amp.step(optimizer_torch)
            scaler_amp.update()
            optimizer_torch.zero_grad(set_to_none=True)
            global_step_counter += 1

        running_loss_sum += loss.item() * training_config_dict["grad_accumulation_steps"]

    # ----- epoch end: evaluate -----
    hold_probs = evaluate_model_probabilities(model_instance, hold_loader)
    hold_auc = roc_auc_score(y_hold_np, hold_probs)

    print(f"[Epoch {epoch_index:02d}] avg_loss={running_loss_sum/max(1, step_in_epoch):.4f} | holdout AUC={hold_auc:.5f}")

    if hold_auc > best_holdout_auc + 1e-4:
        best_holdout_auc = hold_auc
        best_state_dict = copy.deepcopy(model_instance.state_dict())
        best_epoch_index = epoch_index
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= training_config_dict["early_stop_patience_epochs"]:
            print("Early stopping triggered.")
            break

# restore best
if best_state_dict is not None:
    model_instance.load_state_dict(best_state_dict)

print(f"Best epoch: {best_epoch_index} | Best holdout AUC: {best_holdout_auc:.5f}")

# final report on holdout (thresholded metrics)
final_hold_probs = evaluate_model_probabilities(model_instance, hold_loader)
print_classification_report(y_hold_np, final_hold_probs, tag="HOLDOUT-FINAL")

# save artifacts
torch.save(model_instance.state_dict(), output_dir_path / "ultrahybrid_xlpp_best.pt")
np.save(output_dir_path / "holdout_probabilities.npy", final_hold_probs)
with open(output_dir_path / "training_report.json", "w") as f:
    json.dump(dict(
        best_epoch=int(best_epoch_index),
        best_holdout_auc=float(best_holdout_auc),
        kept_dimensions=kept_dim_indices.tolist(),
        effective_input_dim=effective_input_dim,
        config=training_config_dict
    ), f, indent=2)
print("Saved: ultrahybrid_xlpp_best.pt, holdout_probabilities.npy, training_report.json, kept_dim_indices.npy, whiten_mu.npy, whiten_sd.npy")


[Epoch 01] avg_loss=2.1514 | holdout AUC=0.92340
[Epoch 02] avg_loss=1.2664 | holdout AUC=0.95212
[Epoch 03] avg_loss=1.2074 | holdout AUC=0.95699
[Epoch 04] avg_loss=1.1040 | holdout AUC=0.95530
[Epoch 05] avg_loss=1.1105 | holdout AUC=0.95826
[Epoch 06] avg_loss=1.0424 | holdout AUC=0.96061
[Epoch 07] avg_loss=1.1109 | holdout AUC=0.96003
[Epoch 08] avg_loss=1.0761 | holdout AUC=0.95744
[Epoch 09] avg_loss=0.9249 | holdout AUC=0.96130
[Epoch 10] avg_loss=1.0502 | holdout AUC=0.96205
[Epoch 11] avg_loss=0.9938 | holdout AUC=0.96249
[Epoch 12] avg_loss=0.8428 | holdout AUC=0.96244
Best epoch: 11 | Best holdout AUC: 0.96249

[HOLDOUT-FINAL] AUC=0.9625  PR-AUC=0.9610  F1@thr=0.9039  ACC@thr=0.9008  Thr=0.108
        CM [[1063  162]
            [  81 1143]]
Saved: ultrahybrid_xlpp_best.pt, holdout_probabilities.npy, training_report.json, kept_dim_indices.npy, whiten_mu.npy, whiten_sd.npy
