# 02 — Train Models (RF, IF, ET-SSL)

This notebook trains the three selected approaches under **two scenarios**:

- **Base** — use all classes normally.  
- **Zero-Day** — exclude `Bot`, `Web Attack - Brute Force`, and `Infiltration` from train/val; they only reappear in test.  

### Implemented approaches
- **Random Forest (RF, supervised)** — with randomized hyperparameter search, class-weight balancing, and threshold tuning.  
- **Isolation Forest (IF, unsupervised)** — trained on BENIGN only, small grid over core hyperparameters.  
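- **ET-SSL (self-supervised)** — contrastive (NT-Xent) encoder trained with noise/scaling/feature-dropout augmentations, batch norm and dropout; anomalies are scored by their distance to a "normal" centroid versus an "anomaly" centroid, with the weighting and threshold tuned on validation.  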

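### Notes
- All runs are **seeded** (NumPy and PyTorch) for reproducibility; GPU runs of ET-SSL may still vary slightly because cuDNN autotuning (`cudnn.benchmark`) is enabled.  
- RF/IF: implemented via scikit-learn pipelines, including the `CorrelatedGroupsPCA` transformer (PCA per group of highly correlated features; toggle with `USE_CG_PCA`).  
- ET-SSL: trained in PyTorch, GPU-accelerated if available.  
- Model-specific metadata (e.g., thresholds, centroids, scaling stats) is saved for later evaluation.  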

**Inputs**: `data/splits/*` from notebook 01  
**Outputs**:  
- RF/IF: `models/*.joblib` + `*_meta.json`  
- ET-SSL: `models/etssl_*_encoder.pt` + `etssl_*_meta.json`
- Combined summary of all runs: `models/training_summary.json`


In [None]:
# %% [markdown]
# ## Imports & configuration

import os, json, time, math, warnings
from pathlib import Path
from typing import Tuple, Dict, Any, List
import numpy as np
import pandas as pd

import joblib
from joblib import dump, load
from tqdm.auto import tqdm, trange

from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MinMaxScaler
from sklearn.feature_selection import VarianceThreshold
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit, ParameterSampler
from sklearn.preprocessing import RobustScaler, QuantileTransformer
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier, IsolationForest

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# ---- Reproducibility ------------------------------------------------------
SEED = 42
np.random.seed(SEED)     # legacy global NumPy RNG (used implicitly by some libraries)
torch.manual_seed(SEED)  # PyTorch generator(s)

# ---- Filesystem layout (relative to the notebook's working directory) -----
ROOT = Path(".").resolve()
SPLITS_DIR = ROOT.joinpath("data", "splits")  # inputs written by notebook 01
MODELS_DIR = ROOT.joinpath("models")          # every artifact this notebook writes
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# ---- Data / experiment knobs ----------------------------------------------
LABEL_COL = "label"          # target column; 1 is treated as the positive (attack) class
BENIGN_NAME = "BENIGN"
USE_CG_PCA = True            # toggle CorrelatedGroupsPCA in the RF/IF pipelines
CV_SUBSAMPLE_CAP = 150_000   # max rows used for RF cross-validation
FINAL_RF_TREES = 800         # size the final RF is grown to via warm start

In [None]:
# %% [markdown]
# ## Utilities: split loader, threshold search, CorrelatedGroupsPCA

def load_split(approach: str, scenario: str, split: str) -> pd.DataFrame:
    """Read one pre-computed split written by notebook 01.

    The file is looked up at ``SPLITS_DIR/<approach>/<scenario>/<split>.parquet``.
    Raises ``FileNotFoundError`` (carrying the missing path) when it does not exist.
    """
    parquet_path = SPLITS_DIR.joinpath(approach, scenario, f"{split}.parquet")
    if parquet_path.exists():
        return pd.read_parquet(parquet_path)
    raise FileNotFoundError(parquet_path)

def features_and_labels(df: pd.DataFrame) -> Tuple[pd.DataFrame, np.ndarray]:
    """Return ``(X, y)``: every non-label column, and ``LABEL_COL`` as an int64 array.

    Both ``LABEL_COL`` and a raw ``"Label"`` column are excluded from ``X``
    (whichever is absent is ignored). Raises ``KeyError`` if ``LABEL_COL`` is missing.
    """
    label_columns = [LABEL_COL, "Label"]
    features = df.drop(columns=label_columns, errors="ignore")
    targets = df[LABEL_COL].values.astype(np.int64)
    return features, targets

def best_threshold_from_scores(y_true: np.ndarray, scores: np.ndarray, maximize: str = "f1_macro") -> Tuple[float, float]:
    """Find the decision threshold on ``scores`` that maximizes macro-F1.

    Candidate thresholds are the distinct values among the 1%, 2%, ..., 99%
    quantiles of ``scores``. A sample is predicted positive (label 1) when
    ``score >= threshold``. Macro-F1 matches scikit-learn's
    ``f1_score(average="macro", zero_division=0)``: per-label F1
    ``2*tp / (2*tp + fp + fn)`` averaged over every label that occurs in
    ``y_true`` or in the predictions. It is computed from NumPy counts, which
    avoids ~99 full scikit-learn metric calls per search (this function runs
    inside the IF grid and the ET-SSL per-epoch kappa sweep).

    Args:
        y_true: ground-truth labels (cast to int); 1 is the positive class.
        scores: one score per sample, higher meaning "more positive".
        maximize: objective name. Only ``"f1_macro"`` is implemented; any other
            value raises instead of being silently ignored.

    Returns:
        ``(threshold, best_macro_f1)`` as floats. On ties the smallest threshold wins.

    Raises:
        ValueError: unsupported ``maximize``, empty ``scores``, or length mismatch.
    """
    if maximize != "f1_macro":
        raise ValueError(f"unsupported objective {maximize!r}; only 'f1_macro' is implemented")
    y_true = np.asarray(y_true).astype(int)
    scores = np.asarray(scores)
    if scores.size == 0:
        raise ValueError("scores must be non-empty")
    if len(y_true) != len(scores):
        raise ValueError(f"y_true has {len(y_true)} entries but scores has {len(scores)}")

    is_neg = y_true == 0
    is_pos = y_true == 1
    # Labels other than 0/1 can never be predicted: each contributes F1 = 0 but
    # still counts towards the macro average (same as scikit-learn).
    n_other_labels = int(np.unique(y_true[~(is_neg | is_pos)]).size)

    def macro_f1(pred_pos: np.ndarray) -> float:
        total, n_labels = 0.0, n_other_labels
        for truth, pred in ((is_neg, ~pred_pos), (is_pos, pred_pos)):
            tp = int(np.count_nonzero(truth & pred))
            fp = int(np.count_nonzero(~truth & pred))
            fn = int(np.count_nonzero(truth & ~pred))
            if tp + fp + fn == 0:  # label absent from both truth and predictions
                continue
            total += 2.0 * tp / (2 * tp + fp + fn)
            n_labels += 1
        # n_labels >= 1: every (non-empty) prediction vector contains label 0 or 1.
        return total / n_labels

    # Quantile grid: fast and robust to outliers / huge score ranges.
    thr_grid = np.quantile(scores, np.linspace(0.01, 0.99, 99))
    best_f1, best_thr = -1.0, 0.5
    for thr in np.unique(thr_grid):  # ascending, so strict '>' keeps the smallest tied threshold
        f1m = macro_f1(scores >= thr)
        if f1m > best_f1:
            best_f1, best_thr = f1m, thr
    return float(best_thr), float(best_f1)

class CorrelatedGroupsPCA:
    """Compress blocks of highly correlated features with one PCA per block.

    Grouping is greedy: scanning columns left to right, every column not yet
    assigned seeds a new group and absorbs each later, still-unassigned column
    whose absolute Pearson correlation with the seed is >= ``rho``. Singleton
    groups pass through unchanged; a multi-column group is replaced by the
    fewest principal components whose cumulative explained-variance ratio
    reaches ``var_keep``.

    Fitted attributes:
        columns_: column names seen by ``fit`` (``f0..f{d-1}`` for ndarray input).
        groups_: list of sorted column-index lists, one per group.
        pca_models_: list of ``(kind, group_indices, model)`` where kind is
            ``"pass"`` (model is None) or ``"pca"`` (a fitted sklearn PCA).

    Not an sklearn ``BaseEstimator`` (no ``get_params``/``set_params``), but it
    implements ``fit``/``transform``/``fit_transform`` so it can be a Pipeline step.
    """

    def __init__(self, rho: float = 0.95, var_keep: float = 0.99):
        self.rho = float(rho)
        self.var_keep = float(var_keep)
        self.groups_ = None
        self.columns_ = None
        self.pca_models_ = None

    def fit(self, X, y=None):
        """Learn the correlation groups and one PCA per multi-column group; returns self."""
        from sklearn.decomposition import PCA

        if isinstance(X, np.ndarray):
            frame = pd.DataFrame(X, columns=[f"f{i}" for i in range(X.shape[1])])
        else:
            frame = X.copy()
        self.columns_ = list(frame.columns)

        # Constant columns yield NaN correlations; treat them as uncorrelated.
        abs_corr = np.abs(frame.corr(numeric_only=True).fillna(0.0).values)
        n_features = abs_corr.shape[0]
        assigned = np.zeros(n_features, dtype=bool)
        groups = []
        for seed_idx in range(n_features):
            if assigned[seed_idx]:
                continue
            assigned[seed_idx] = True
            later = np.arange(seed_idx + 1, n_features)
            hits = later[~assigned[later] & (abs_corr[seed_idx, later] >= self.rho)]
            assigned[hits] = True
            groups.append(sorted([seed_idx, *hits.tolist()]))

        self.pca_models_ = models = []
        for members in groups:
            if len(members) == 1:
                models.append(("pass", members, None))
                continue
            # First pass: full spectrum, to decide how many components reach var_keep.
            probe = PCA(n_components=None, svd_solver="full", random_state=SEED)
            probe.fit(frame.iloc[:, members])
            cumulative = np.cumsum(probe.explained_variance_ratio_)
            n_keep = int(np.searchsorted(cumulative, self.var_keep) + 1)
            n_keep = min(max(n_keep, 1), len(members))
            reducer = PCA(n_components=n_keep, svd_solver="full", random_state=SEED)
            models.append(("pca", members, reducer.fit(frame.iloc[:, members])))
        self.groups_ = groups
        return self

    def transform(self, X):
        """Apply each group's PCA (or pass the group through) and hstack the parts."""
        if isinstance(X, np.ndarray):
            frame = pd.DataFrame(X, columns=self.columns_)
        else:
            frame = X.copy()
        parts = [
            frame.iloc[:, members].values if kind == "pass" else model.transform(frame.iloc[:, members])
            for kind, members, model in self.pca_models_
        ]
        return np.concatenate(parts, axis=1)

    def fit_transform(self, X, y=None):
        """Equivalent to ``fit(X, y).transform(X)``."""
        return self.fit(X, y).transform(X)

In [None]:
# %% [markdown]
# ## Random Forest — with CV progress & warm-start final fit

from joblib import parallel_backend
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit, ParameterSampler

def _rf_pipe_for_cv():
    """Build a brand-new, unfitted RF pipeline (one per cross-validation fit).

    Steps: drop zero-variance columns -> median imputation -> min-max scaling ->
    optional correlated-group PCA (``USE_CG_PCA``) -> 200-tree random forest
    with per-bootstrap class balancing.
    """
    reducer = CorrelatedGroupsPCA(0.95, 0.99) if USE_CG_PCA else "passthrough"
    forest = RandomForestClassifier(
        n_estimators=200,
        max_depth=None,
        max_features="sqrt",
        min_samples_leaf=1,
        class_weight="balanced_subsample",
        n_jobs=-1,
        random_state=SEED,
    )
    steps = [
        ("var0", VarianceThreshold(0.0)),
        ("impute", SimpleImputer(strategy="median")),
        ("scaler", MinMaxScaler()),
        ("cgpca", reducer),
        ("rf", forest),
    ]
    return Pipeline(steps)

def _rf_random_search(scenario: str, X_cv: pd.DataFrame, y_cv: np.ndarray) -> Tuple[dict, float, float]:
    """Randomized RF search: 12 candidates x 3 stratified folds, scored by macro-F1 at p >= 0.5.

    Returns ``(best_params, best_mean_cv_f1, search_seconds)``; ties keep the earliest candidate.
    """
    param_space = {
        "rf__max_depth": [10, 20, 30, None],
        "rf__max_features": ["sqrt", 0.6, 0.8],
        "rf__min_samples_leaf": [1, 2, 4],
    }
    n_iter = 12
    sampler = list(ParameterSampler(param_space, n_iter=n_iter, random_state=SEED))
    cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=SEED)
    total_fits = n_iter * cv.get_n_splits()

    fit_times: List[float] = []
    results: List[Tuple[dict, float]] = []

    print(f"[RF/{scenario}] CV: {n_iter} candidates × {cv.get_n_splits()} folds = {total_fits} fits on {len(X_cv)} rows")
    t_search0 = time.time()
    with parallel_backend("threading"), \
            tqdm(total=total_fits, desc=f"RF {scenario} CV", unit="fit") as pbar:
        for params in sampler:
            scores = []
            for tr_idx, va_idx in cv.split(X_cv, y_cv):
                Xtr, Xv = X_cv.iloc[tr_idx], X_cv.iloc[va_idx]
                ytr, yv = y_cv[tr_idx], y_cv[va_idx]
                # A genuinely fresh pipeline per fold: Pipeline(template.steps) would
                # share (and keep refitting) the very same estimator objects.
                model = _rf_pipe_for_cv().set_params(**params)
                t0 = time.time()
                model.fit(Xtr, ytr)
                prob = model.predict_proba(Xv)[:, 1]
                yhat = (prob >= 0.5).astype(int)
                scores.append(f1_score(yv, yhat, average="macro", zero_division=0))
                fit_times.append(time.time() - t0)

                done = len(fit_times)
                avg = np.mean(fit_times)
                eta = max(0, (total_fits - done) * avg)
                pbar.set_postfix({"avg_s": f"{avg:.1f}", "ETA": time.strftime("%H:%M:%S", time.gmtime(eta))})
                pbar.update(1)
            results.append((params, float(np.mean(scores))))
    search_time = time.time() - t_search0

    best_params, best_cv_f1 = max(results, key=lambda kv: kv[1])  # first maximum wins
    return best_params, best_cv_f1, search_time


def _rf_grow_warm_start(rf: RandomForestClassifier, X, y, n_total: int, step: int, desc: str) -> float:
    """Grow a ``warm_start`` forest to ``n_total`` trees, ``step`` trees per ``fit`` call.

    Shows a progress bar with per-tree average time and ETA; returns elapsed seconds.
    """
    if n_total < 1:
        raise ValueError(f"n_total must be >= 1, got {n_total}")
    built, fit_seconds = 0, 0.0
    t_start = time.time()
    with tqdm(total=n_total, desc=desc, unit="tree") as pbar:
        while built < n_total:
            target = min(n_total, built + step)
            rf.set_params(n_estimators=target)
            t0 = time.time()
            rf.fit(X, y)  # warm_start: previously built trees are kept, only new ones are added
            fit_seconds += time.time() - t0
            avg_per_tree = fit_seconds / target
            eta = (n_total - target) * avg_per_tree
            pbar.set_postfix({"avg_s/tree": f"{avg_per_tree:.2f}", "ETA": time.strftime("%H:%M:%S", time.gmtime(eta))})
            pbar.update(target - built)  # exact increment, also for a partial last chunk
            built = target
    return time.time() - t_start


def train_rf(scenario: str) -> Dict[str, Any]:
    """Tune, fit, threshold and save a Random Forest for one scenario.

    1. Randomized search (3-fold stratified CV, macro-F1 at p >= 0.5) on a
       stratified subsample of at most ``CV_SUBSAMPLE_CAP`` training rows.
    2. Final forest of ``FINAL_RF_TREES`` trees grown by warm start on the full
       training split with the best hyperparameters (``max_depth=None`` is
       capped at 30 for runtime).
    3. Decision threshold chosen on the validation split to maximize macro-F1.

    Writes ``rf_<scenario>.joblib`` (preprocessing + forest) and
    ``rf_<scenario>_meta.json`` to ``MODELS_DIR``.

    Returns:
        The meta dict plus ``"model_path"``.
    """
    tr = load_split("rf", scenario, "train")
    va = load_split("rf", scenario, "val")
    X_tr, y_tr = features_and_labels(tr)
    X_va, y_va = features_and_labels(va)

    # Stratified subsample keeps the CV affordable on very large training splits.
    if len(X_tr) > CV_SUBSAMPLE_CAP:
        sss = StratifiedShuffleSplit(n_splits=1, train_size=CV_SUBSAMPLE_CAP, random_state=SEED)
        idx, _ = next(sss.split(X_tr, y_tr))
        X_cv, y_cv = X_tr.iloc[idx].copy(), y_tr[idx].copy()
    else:
        X_cv, y_cv = X_tr, y_tr

    best_params, best_cv_f1, search_time = _rf_random_search(scenario, X_cv, y_cv)
    print(f"[RF/{scenario}] best CV macro-F1={best_cv_f1:.4f} params={best_params}")

    # Same preprocessing as the CV pipeline (all steps except "rf"), refit on the full train split.
    prep = Pipeline(_rf_pipe_for_cv().steps[:-1])
    Xtr_tf = prep.fit_transform(X_tr, y_tr)
    Xva_tf = prep.transform(X_va)

    # If CV found max_depth=None, cap for runtime safety.
    max_depth_final = best_params.get("rf__max_depth", None)
    if max_depth_final is None:
        max_depth_final = 30

    rf = RandomForestClassifier(
        n_estimators=0, warm_start=True,  # actual size is set per chunk by the grower
        max_depth=max_depth_final,
        max_features=best_params.get("rf__max_features", "sqrt"),
        min_samples_leaf=best_params.get("rf__min_samples_leaf", 1),
        class_weight="balanced_subsample",
        n_jobs=-1, random_state=SEED
    )
    n_total = FINAL_RF_TREES
    final_time = _rf_grow_warm_start(rf, Xtr_tf, y_tr, n_total, step=100,
                                     desc=f"RF {scenario} final trees")

    # Threshold tuned on validation probabilities of the positive class.
    val_proba = rf.predict_proba(Xva_tf)[:, 1]
    thr, f1m = best_threshold_from_scores(y_va, val_proba)

    # Single object for saving: fitted preprocessing followed by the fitted forest.
    final_pipe = make_pipeline(prep, rf)

    model_path = MODELS_DIR / f"rf_{scenario}.joblib"
    joblib.dump(final_pipe, model_path)
    meta = {
        "scenario": scenario,
        "threshold": float(thr),
        "val_macro_f1_at_threshold": float(f1m),
        "best_params": {**best_params,
                        "final_n_estimators": int(n_total),
                        "max_depth_final": int(max_depth_final) if max_depth_final is not None else None,
                        "class_weight": "balanced_subsample"},
        "cv_subsample_rows": int(len(X_cv)),
        "search_time_sec": float(search_time),
        "final_train_time_sec": float(final_time),
        "use_cg_pca": bool(USE_CG_PCA)
    }
    with open(MODELS_DIR / f"rf_{scenario}_meta.json", "w") as f:
        json.dump(meta, f, indent=2)

    return {"model_path": str(model_path), **meta}

In [None]:
# %% [markdown]
# ## Isolation Forest — stronger search (contamination, seeds, scalers) with ETA

def _if_make_scaler(name: str, n_train_rows: int):
    """Return a fresh, unfitted scaler for the IF preprocessing variant ``name``."""
    if name == "minmax":
        return MinMaxScaler()
    if name == "robust":
        return RobustScaler()
    if name == "quantile":
        return QuantileTransformer(
            n_quantiles=min(1000, max(50, n_train_rows // 50)),
            output_distribution="normal", subsample=int(1e6), random_state=SEED
        )
    raise ValueError(f"unknown IF preprocessing {name!r}")


def _if_make_preprocessor(scaler) -> Pipeline:
    """IF preprocessing: drop constant columns, median-impute, scale, optional correlated-group PCA."""
    return Pipeline([
        ("var0", VarianceThreshold(0.0)),
        ("impute", SimpleImputer(strategy="median")),
        ("scaler", scaler),
        ("cgpca", CorrelatedGroupsPCA(0.95, 0.99) if USE_CG_PCA else "passthrough"),
    ])


def train_if(scenario: str, mode: str = "speed") -> Dict[str, Any]:
    """Grid-search and fit an Isolation Forest on the benign-only training split.

    Each candidate is fitted on the (preprocessed) train split; its validation
    anomaly scores (``-score_samples``) are thresholded to maximize macro-F1.
    The best configuration is refit from scratch as one pipeline and saved.

    Args:
        scenario: split scenario name.
        mode: ``"speed"`` (very fast, good starting point) or ``"balanced"``
            (somewhat broader grid incl. more scalers and seeds, still quick).

    Returns:
        The meta dict written to ``if_<scenario>_meta.json`` plus ``"model_path"``.

    Raises:
        ValueError: unknown ``mode``.
    """
    if mode not in {"speed", "balanced"}:
        raise ValueError(f"mode must be 'speed' or 'balanced', got {mode!r}")
    tr = load_split("if", scenario, "train")   # benign-only
    va = load_split("if", scenario, "val")
    X_tr_df, _ = features_and_labels(tr)
    X_va_df, y_va = features_and_labels(va)
    n_rows = len(X_tr_df)

    # `contamination` only sets IsolationForest.offset_ (used by predict /
    # decision_function); it does not change score_samples, which is all this
    # search ranks by. Gridding over it merely refit identical forests and the
    # resulting exact ties always kept the first listed value, so that value is
    # used directly.
    if mode == "speed":
        scaler_names = ["minmax"]
        grid_ne = [200, 400]
        grid_ms = ["auto", 256]        # "auto" ~ min(256, n_samples)
        contamination = 0.002          # CICIDS << 1%
        grid_seed = [SEED]
    else:  # balanced
        scaler_names = ["minmax", "robust", "quantile"]
        grid_ne = [200, 400, 800]
        grid_ms = ["auto", 256, 512]
        contamination = 0.001
        grid_seed = [SEED, SEED + 7]

    # Same visiting order as a nested ne -> ms -> seed loop (drives tie-breaking).
    combos = [(ne, ms, s) for ne in grid_ne for ms in grid_ms for s in grid_seed]
    fit_times: List[float] = []  # preprocessing + forest fit times, for the ETA
    best_f1 = -1.0
    best: Dict[str, Any] = {}

    with tqdm(total=len(scaler_names) * len(combos), desc=f"IF {scenario} ({mode})", unit="fit") as pbar:
        for pre_name in scaler_names:
            t0 = time.time()
            prep = _if_make_preprocessor(_if_make_scaler(pre_name, n_rows))
            X_tr_tf = prep.fit_transform(X_tr_df)
            X_va_tf = prep.transform(X_va_df)
            fit_times.append(time.time() - t0)

            for ne, ms, s in combos:
                t0 = time.time()
                iforest = IsolationForest(
                    n_estimators=ne, max_samples=ms, contamination=contamination,
                    n_jobs=-1, random_state=int(s), bootstrap=False, warm_start=False
                )
                iforest.fit(X_tr_tf)  # unsupervised on benigns
                fit_times.append(time.time() - t0)

                # score_samples is higher for normal points; negate so higher = more anomalous.
                scores = -iforest.score_samples(X_va_tf)
                thr, f1m = best_threshold_from_scores(y_va, scores)
                if f1m > best_f1:
                    best_f1 = float(f1m)
                    best = {"preprocess": pre_name, "n_estimators": ne, "max_samples": ms,
                            "seed": int(s), "threshold": float(thr)}

                avg = float(np.mean(fit_times))
                left = max(0.0, (pbar.total - pbar.n - 1) * avg)
                pbar.set_postfix({"avg_s": f"{avg:.1f}", "ETA": time.strftime("%H:%M:%S", time.gmtime(left))})
                pbar.update(1)

    pre_name, ne, ms, s = best["preprocess"], best["n_estimators"], best["max_samples"], best["seed"]
    final_pipe = Pipeline(
        _if_make_preprocessor(_if_make_scaler(pre_name, n_rows)).steps
        + [("iforest", IsolationForest(
            n_estimators=ne, max_samples=ms, contamination=contamination,
            n_jobs=-1, random_state=int(s), bootstrap=False, warm_start=False
        ))]
    )
    t0 = time.time()
    final_pipe.fit(X_tr_df)
    total_time = float(np.sum(fit_times) + (time.time() - t0))

    model_path = MODELS_DIR / f"if_{scenario}.joblib"
    dump(final_pipe, model_path)

    meta = {
        "scenario": scenario,
        "threshold": float(best["threshold"]),
        "val_macro_f1_at_threshold": best_f1,
        "best_params": {
            "preprocess": pre_name,
            "n_estimators": int(ne),
            "max_samples": ms if isinstance(ms, str) else int(ms),
            "contamination": float(contamination),
            "seed": int(s),
        },
        "train_time_sec": total_time,
        "use_cg_pca": bool(USE_CG_PCA),
        "mode": mode,
    }
    with open(MODELS_DIR / f"if_{scenario}_meta.json", "w") as f:
        json.dump(meta, f, indent=2)

    return {"model_path": str(model_path), **meta}

In [None]:
# %% [markdown]
# ## ET-SSL — stronger encoder, cosine LR, AMP, early stopping, better aug, live ETA

def train_etssl_for_scenario(
    scenario: str,
    epochs: int = 20,
    batch_size: int = 2048,
    emb_dim: int = 64,
    proj_dim: int = 128,
    tau: float = 0.5,
    alpha: float = 0.95,  # slower EMA for μ_norm
    gamma: float = 0.2,
    weight_decay: float = 1e-5,
    lr: float = 1e-3,
    patience: int = 4,
    grad_clip: float = 1.0,
) -> Dict[str, Any]:
    """Train the ET-SSL contrastive encoder for one scenario and save its best epoch.

    Features are median-imputed and min-max scaled (both fitted on train only).
    The MLP encoder is trained with NT-Xent on two augmented views of each batch
    plus a distance regulariser on the backbone embedding. After every epoch
    the validation split is embedded, a normal centroid (label 0) and an anomaly
    centroid (label 1, or the 2% points farthest from the normal centroid if no
    positives exist) are computed, and ``kappa``/threshold are chosen to
    maximize macro-F1 of ``score = ||z - mu_norm||^2 - kappa * ||z - mu_anom||^2``.
    Training stops after ``patience`` epochs without improvement.

    Args:
        scenario: split scenario name.
        epochs: maximum number of epochs.
        batch_size: training batch size (the incomplete last batch is dropped).
        emb_dim, proj_dim: backbone embedding / projection-head widths.
        tau: NT-Xent temperature.
        alpha: EMA factor of the running normal centroid used during training.
        gamma: weight of the distance regulariser added to the contrastive loss.
        weight_decay, lr: AdamW settings (LR cosine-annealed to ``0.05 * lr``).
        patience: epochs without validation macro-F1 improvement before stopping.
        grad_clip: max gradient norm; ``None`` or ``<= 0`` disables clipping.

    Returns:
        The metadata dict written to ``etssl_<scenario>_meta.json``; the encoder
        state_dict of the best epoch is saved to ``etssl_<scenario>_encoder.pt``.
    """
    # Only train (encoder fit) and val (epoch / kappa / threshold selection) are
    # needed here; the test split is left to the evaluation notebook.
    tr = load_split("etssl", scenario, "train")
    va = load_split("etssl", scenario, "val")
    X_tr_df, _ = features_and_labels(tr)
    X_va_df, y_va = features_and_labels(va)

    # Impute + scale on TRAIN only, then float32 NumPy for torch / sklearn interop.
    imp = SimpleImputer(strategy="median").fit(X_tr_df)
    X_tr = imp.transform(X_tr_df).astype(np.float32, copy=False)
    X_va = imp.transform(X_va_df).astype(np.float32, copy=False)

    sc = MinMaxScaler().fit(X_tr)
    X_tr = sc.transform(X_tr).astype(np.float32, copy=False)
    X_va = sc.transform(X_va).astype(np.float32, copy=False)

    d_in = X_tr.shape[1]

    # MLP backbone (BN + dropout) -> embedding h; small head -> projection z for
    # the contrastive loss. The layer layout determines the saved state_dict keys.
    class Encoder(nn.Module):
        def __init__(self, d, emb=64, proj=128, p_drop=0.1):
            super().__init__()
            self.back = nn.Sequential(
                nn.Linear(d, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(p_drop),
                nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(p_drop),
                nn.Linear(256, emb)
            )
            self.proj = nn.Sequential(
                nn.BatchNorm1d(emb), nn.ReLU(),
                nn.Linear(emb, proj)
            )

        def forward(self, x):
            h = self.back(x)
            z = self.proj(h)
            return z, h

    def augment(x, ns=0.03, drop_p=0.05):
        """Gaussian noise, per-row scaling in [0.9, 1.1], then random feature dropout."""
        noise = torch.randn_like(x) * ns
        scale = torch.empty((x.size(0), 1), device=x.device).uniform_(0.9, 1.1)
        x2 = (x + noise) * scale
        if drop_p > 0.0:
            mask = (torch.rand_like(x2) > drop_p).float()
            x2 = x2 * mask
        return x2

    def nt_xent(z1, z2, tau=0.5):
        """NT-Xent loss: for each of the 2B views, the other view of the same row is the positive."""
        z1 = nn.functional.normalize(z1, dim=1)
        z2 = nn.functional.normalize(z2, dim=1)
        z = torch.cat([z1, z2], dim=0)  # (2B, d)
        # Similarity in float32 to avoid fp16 over/underflow under autocast.
        sim = z.float() @ z.float().t()  # (2B, 2B)
        sim.fill_diagonal_(float("-inf"))  # a view is never its own positive
        B = z1.size(0)
        targets = torch.cat([torch.arange(B, 2 * B), torch.arange(0, B)], dim=0).to(sim.device)
        return nn.CrossEntropyLoss()(sim / tau, targets)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True
    torch.set_num_threads(os.cpu_count() or 8)

    enc = Encoder(d_in, emb=emb_dim, proj=proj_dim).to(device)
    opt = torch.optim.AdamW(enc.parameters(), lr=lr, weight_decay=weight_decay)
    # Cosine schedule to a small floor.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, epochs), eta_min=lr * 0.05)

    dl = DataLoader(
        TensorDataset(torch.from_numpy(X_tr), torch.zeros(len(X_tr))),
        batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0,
        pin_memory=(device == "cuda"),  # pinning only helps host->GPU copies
    )
    scaler_amp = torch.amp.GradScaler("cuda", enabled=(device == "cuda"))

    def embed(X):
        """L2-normalised backbone embeddings of X, in chunks of 4096 rows (caller sets eval mode)."""
        chunks = []
        with torch.no_grad():
            for i in range(0, len(X), 4096):
                xb = torch.from_numpy(X[i:i + 4096]).to(device)
                _, h = enc(xb)
                chunks.append(nn.functional.normalize(h, dim=1).cpu().numpy())
        return np.vstack(chunks)

    # Early stopping on validation macro-F1.
    best_f1 = -1.0
    best_state = None
    best_meta = None
    no_improve = 0
    t0 = time.time()

    enc.train()
    mu = None  # running centroid of "normal" normalised embeddings, initialised on the first batch
    for ep in trange(epochs, desc=f"ETSSL {scenario} epochs"):
        tb = time.time()
        steps = 0
        with tqdm(total=len(dl), desc=f"epoch {ep+1}/{epochs}", leave=False, unit="step") as pb:
            for xb, _ in dl:
                xb = xb.to(device, non_blocking=True)
                with torch.amp.autocast("cuda", enabled=(device == "cuda")):
                    z1, h1 = enc(augment(xb))
                    z2, h2 = enc(augment(xb))
                    loss_c = nt_xent(z1, z2, tau=tau)

                    # Distance regulariser on the 5% of the batch farthest from μ_norm.
                    # NOTE(review): adding this positive distance to the minimised loss
                    # pulls those points *towards* μ_norm; if the intent is to push
                    # predicted anomalies away, the sign must be flipped — confirm.
                    h1n = nn.functional.normalize(h1, dim=1)
                    if mu is None:
                        mu = h1n.mean(0).detach()
                    dn = ((h1n - mu) ** 2).sum(1)
                    theta = torch.quantile(dn.detach(), 0.95)
                    far = dn > theta
                    if far.any():
                        loss = loss_c + gamma * dn[far].mean()
                    else:
                        loss = loss_c

                opt.zero_grad(set_to_none=True)
                scaler_amp.scale(loss).backward()
                if grad_clip is not None and grad_clip > 0:
                    scaler_amp.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(enc.parameters(), grad_clip)
                scaler_amp.step(opt)
                scaler_amp.update()

                # EMA update of μ_norm from the batch's predicted normals.
                with torch.no_grad():
                    normals = h1n[dn <= theta]
                    if normals.size(0) > 0:
                        mu = alpha * mu + (1 - alpha) * normals.mean(0)

                steps += 1
                avg_s = (time.time() - tb) / max(1, steps)
                left = (len(dl) - steps) * avg_s
                pb.set_postfix({"avg_s": f"{avg_s:.2f}", "ETA": time.strftime("%H:%M:%S", time.gmtime(left))})
                pb.update(1)
        sched.step()

        # ---- validation at end of epoch ----
        enc.eval()
        Zva = embed(X_va)

        # Centroids from validation labels (fallback if a class is missing).
        mu_norm = Zva[y_va == 0].mean(0) if (y_va == 0).any() else Zva.mean(0)
        if (y_va == 1).any():
            mu_anom = Zva[y_va == 1].mean(0)
        else:
            dn_val = ((Zva - mu_norm) ** 2).sum(1)
            n_far = max(1, int(0.02 * len(dn_val)))
            mu_anom = Zva[np.argsort(dn_val)[-n_far:]].mean(0)

        # Distances do not depend on kappa: compute once, then sweep.
        dist_norm = ((Zva - mu_norm) ** 2).sum(1)
        dist_anom = ((Zva - mu_anom) ** 2).sum(1)
        best_local = {"kappa": 0.0, "theta": 0.0, "f1": -1.0}
        for kappa in (0.0, 0.25, 0.5, 0.75, 1.0, 1.5, 2.0):
            th, f1m = best_threshold_from_scores(y_va, dist_norm - kappa * dist_anom)
            if f1m > best_local["f1"]:
                best_local = {"kappa": float(kappa), "theta": float(th), "f1": float(f1m)}

        # Keep the best epoch.
        if best_local["f1"] > best_f1:
            best_f1 = best_local["f1"]
            no_improve = 0
            # clone(): on CPU, .cpu() returns the same storage, so without a copy the
            # "best" snapshot would keep changing with every later optimizer step.
            best_state = {k: v.detach().cpu().clone() if isinstance(v, torch.Tensor) else v
                          for k, v in enc.state_dict().items()}
            best_meta = {
                "scenario": scenario,
                "train_time_sec": float(time.time() - t0),
                "emb_dim": int(emb_dim), "proj_dim": int(proj_dim),
                "kappa": best_local["kappa"], "theta": best_local["theta"],
                "mu_norm": mu_norm.tolist(), "mu_anom": mu_anom.tolist(),
                "imputer_statistics_": imp.statistics_.tolist(),
                "scaler_min_": sc.min_.tolist() if hasattr(sc, "min_") else None,
                "scaler_scale_": sc.scale_.tolist() if hasattr(sc, "scale_") else None,
                "best_val_macro_f1": float(best_f1),
                "epochs_trained": int(ep + 1),
                "lr_final": float(sched.get_last_lr()[0])
            }
        else:
            no_improve += 1

        enc.train()
        if no_improve >= patience:
            print(f"[ETSSL/{scenario}] Early stopping at epoch {ep+1} (best F1={best_f1:.4f})")
            break

    # Fallback when no epoch was evaluated (e.g. epochs=0): save the current weights.
    if best_state is None:
        best_state = enc.state_dict()
        if best_meta is None:
            best_meta = {
                "scenario": scenario,
                "train_time_sec": float(time.time() - t0),
                "emb_dim": int(emb_dim), "proj_dim": int(proj_dim),
                "kappa": 0.0, "theta": 0.0,
                "mu_norm": np.zeros(emb_dim, dtype=float).tolist(),
                "mu_anom": np.zeros(emb_dim, dtype=float).tolist(),
                "imputer_statistics_": imp.statistics_.tolist(),
                "scaler_min_": sc.min_.tolist() if hasattr(sc, "min_") else None,
                "scaler_scale_": sc.scale_.tolist() if hasattr(sc, "scale_") else None,
                "best_val_macro_f1": float(-1.0),
                "epochs_trained": 0,
                "lr_final": float(sched.get_last_lr()[0])
            }

    enc_path = MODELS_DIR / f"etssl_{scenario}_encoder.pt"
    torch.save(best_state, enc_path)
    with open(MODELS_DIR / f"etssl_{scenario}_meta.json", "w") as f:
        json.dump(best_meta, f, indent=2)

    return best_meta


In [None]:
# %% [markdown]
# ## Train all models for both scenarios

# Run every approach for every scenario (RF -> IF -> ET-SSL per scenario) and
# collect their returned metadata under "<approach>/<scenario>" keys.
all_meta = {}

for scenario in ("base", "zeroday"):
    print("\n==============================")
    print(f"Scenario: {scenario.upper()}")
    print("==============================")

    for approach, trainer in (("rf", train_rf),
                              ("if", train_if),
                              ("etssl", train_etssl_for_scenario)):
        all_meta[f"{approach}/{scenario}"] = trainer(scenario)

# Compact summary of all runs for quick inspection.
summary_path = MODELS_DIR / "training_summary.json"
with open(summary_path, "w") as f:
    json.dump(all_meta, f, indent=2)

all_meta