# train_ensemble_final.py

Robust ensemble trainer (Spatial + Temporal) — final, stable version.

Behavior:
- If predictions/ (spatial + temporal) exist, uses them (fast).
- Otherwise computes features from embeddings using spatial head + temporal model.
- Logs bad files to CHECKPOINT_DIR/bad_embeddings.txt and CHECKPOINT_DIR/bad_preds.txt (each log is cleared whenever a split's features are rebuilt).
- Trains a logistic regression with sigmoid calibration via stratified K-fold CV when the training set allows it; otherwise falls back to a plain, uncalibrated LR.
- Saves trained ensemble and a small results text file.


In [None]:
from pathlib import Path
import json, time, warnings, sys
import numpy as np
import random
from collections import Counter

# sklearn / torch / joblib imports
from sklearn.linear_model import LogisticRegression
from sklearn.calibration import CalibratedClassifierCV
from sklearn.model_selection import StratifiedKFold
from sklearn.exceptions import NotFittedError
from sklearn.metrics import roc_auc_score
import joblib

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import timm

In [None]:
# ---------------- CONFIG ----------------
# Every path hangs off the parent of the current working directory
# (presumably the notebook runs from a subfolder of the project — TODO confirm).
ROOT = Path.cwd().parent

# ---- inputs ----
DATA_DIR = ROOT / "data"
LABELS_JSON = DATA_DIR / "labels.json"
EMB_ROOT = ROOT / "embeddings"                          # embeddings/<split>/<stem>.npy
PRED_SPATIAL_DIR = ROOT / "predictions" / "spatial"     # optional per-video scalar preds: <split>/<stem>.npy
PRED_TEMPORAL_DIR = ROOT / "predictions" / "temporal"
SPATIAL_CKPT = ROOT / "checkpoints" / "spatial" / "spatial_best_valAUC.pth"
TEMPORAL_CKPT = ROOT / "checkpoints" / "temporal" / "temporal_best.pth"

# ---- outputs (directories are created eagerly) ----
CHECKPOINT_DIR = ROOT / "checkpoints" / "ensemble"
OUT_CACHE = ROOT / "ensemble_features"                  # per-split X/y feature caches (.npz)
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
OUT_CACHE.mkdir(parents=True, exist_ok=True)

BAD_EMB_LOG = CHECKPOINT_DIR / "bad_embeddings.txt"
BAD_PRED_LOG = CHECKPOINT_DIR / "bad_preds.txt"
RESULTS_TXT = CHECKPOINT_DIR / "ensemble_results.txt"
ENSEMBLE_OUT = CHECKPOINT_DIR / "ensemble_best.pkl"

SPLITS = ["train", "val", "test"]

# ---- ensemble / training hyperparameters ----
RNG_SEED = 42
CALIBRATION_CV = 5        # upper bound on calibration folds; reduced when the training set is small
LR_MAX_ITER = 2000

In [None]:
# ------------ determinism ------------
def set_seed(seed=RNG_SEED):
    """Seed Python, NumPy and torch (CPU and, if present, all CUDA devices).

    Also forces deterministic cuDNN kernels and disables benchmark autotuning,
    which may slow training down.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

set_seed(RNG_SEED)

In [None]:
# ------------ helpers ------------
def safe_load_checkpoint(path):
    """Load a checkpoint with ``torch.load`` onto the CPU.

    Returns None (instead of raising) when *path* does not exist; any other
    load error propagates.
    NOTE(review): torch.load may unpickle arbitrary objects depending on the
    installed torch's ``weights_only`` default — only use on trusted files.
    """
    if Path(path).exists():
        return torch.load(path, map_location="cpu")
    return None

def strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with one leading ``"module."`` removed from each key.

    DataParallel/DDP checkpoints prefix parameter names with ``module.``. Only that
    leading prefix is stripped; occurrences of ``module.`` deeper in a key
    (e.g. ``module.submodule.w``) are preserved.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

def is_bad_array(arr):
    """Check an embedding array for problems.

    Returns ``(bad, reason)``. ``reason`` is ``""`` for a good array, otherwise a
    tagged string ("[NOT_NDARRAY]", "[DIM]", "[EMPTY]", "[DTYPE]", "[NaN]",
    "[INF]", "[EXTREME]"). A good array is a 2-D numeric ndarray with at least
    one element, all finite, and |x| <= 1e6.
    """
    if not isinstance(arr, np.ndarray):
        return True, f"[NOT_NDARRAY] type={type(arr)}"
    if arr.ndim != 2:
        return True, f"[DIM] ndim={arr.ndim} shape={getattr(arr,'shape',None)}"
    # size catches both (0, D) and (T, 0); the latter would crash np.max below.
    if arr.size == 0:
        return True, f"[EMPTY] shape={arr.shape}"
    # np.isnan raises TypeError on string/object dtypes; reject them explicitly.
    if arr.dtype.kind not in "biufc":
        return True, f"[DTYPE] dtype={arr.dtype}"
    if np.isnan(arr).any():
        with warnings.catch_warnings():
            # all-NaN input makes nanmin/nanmax/nanmean warn; the nan values are fine for the report.
            warnings.simplefilter("ignore", RuntimeWarning)
            return True, f"[NaN] min={np.nanmin(arr)} max={np.nanmax(arr)} mean={np.nanmean(arr)}"
    if np.isinf(arr).any():
        return True, "[INF]"
    # extreme values guard (NaN was already excluded above)
    mx = np.max(np.abs(arr))
    if mx > 1e6:
        return True, f"[EXTREME] abs_max={mx}"
    return False, ""

def safe_auc(y_true, y_pred):
    """ROC AUC that warns and returns nan instead of raising on degenerate input.

    Degenerate means: any NaN among the predictions, or fewer than two distinct
    classes in the ground truth. Otherwise defers to ``roc_auc_score``.
    """
    truth, scores = np.asarray(y_true), np.asarray(y_pred)
    if np.isnan(scores).any():
        warnings.warn("NaNs in predictions — skipping AUC")
        return float("nan")
    present = np.unique(truth)
    if len(present) >= 2:
        return roc_auc_score(truth, scores)
    warnings.warn(f"Single-class y_true={present} — AUC undefined")
    return float("nan")

In [None]:
# ------------ Model classes used if we compute features ------------
class SpatialHead(nn.Module):
    """timm backbone (classifier removed via ``num_classes=0``) followed by an MLP giving one logit per input.

    Attributes:
        backbone: timm model built with ``pretrained=False``.
        feat_dim: ``backbone.num_features``, the input width of ``head``.
        head: Linear -> ReLU -> Dropout(0.4) -> Linear(512, 1).
    """

    def __init__(self, backbone_name="efficientnet_b3"):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0)
        self.feat_dim = self.backbone.num_features
        # Layer order fixes the state_dict keys (head.0 / head.3) that checkpoints rely on.
        layers = [
            nn.Linear(self.feat_dim, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits of shape [B] (the trailing size-1 dim is squeezed)."""
        return self.head(self.backbone(x)).squeeze(1)

In [None]:
class TemporalModel(nn.Module):
    """(Bi)LSTM over per-frame features, masked attention pooling, and an MLP head.

    Args:
        feat_dim: size of each per-frame feature vector.
        hidden: LSTM hidden size per direction.
        layers: number of stacked LSTM layers (inter-layer dropout only applies when > 1).
        dropout: LSTM inter-layer dropout.
        bidirectional: use a bidirectional LSTM (doubles the pooled width).
    """
    def __init__(self, feat_dim, hidden=512, layers=2, dropout=0.3, bidirectional=True):
        super().__init__()
        self.lstm = nn.LSTM(feat_dim, hidden, layers, batch_first=True,
                            bidirectional=bidirectional, dropout=dropout if layers>1 else 0)
        self.out_dim = hidden * (2 if bidirectional else 1)
        self.att = nn.Linear(self.out_dim, 1)
        self.head = nn.Sequential(
            nn.Linear(self.out_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1)
        )

    def forward(self, x, lengths):
        """Return one logit per sequence, shape [B].

        Args:
            x: padded batch [B, T, feat_dim] (batch_first).
            lengths: valid lengths [B]; each must be >= 1 (required by
                pack_padded_sequence). May live on a different device than x.
        Raises:
            ValueError: if lengths is empty.
        """
        if lengths.numel() == 0:
            raise ValueError("Empty lengths in TemporalModel")
        # Sort by length (descending) so the sequence can be packed with enforce_sorted=True.
        lengths_sorted, perm = lengths.sort(descending=True)
        x_sorted = x[perm.to(x.device)]
        packed = rnn_utils.pack_padded_sequence(x_sorted, lengths_sorted.cpu(), batch_first=True, enforce_sorted=True)
        packed_out, _ = self.lstm(packed)
        out_sorted, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
        # Restore the caller's batch order. `lengths` is already in that order, so it must NOT be permuted.
        unperm = perm.argsort()
        out = out_sorted[unperm.to(out_sorted.device)]
        # Attention pooling: steps at or beyond each sequence's length get a large negative score.
        _, T, _ = out.shape
        valid_len = lengths.to(out.device)
        mask = torch.arange(T, device=out.device).unsqueeze(0) >= valid_len.unsqueeze(1)
        scores = self.att(out).squeeze(-1).masked_fill(mask, -1e9)   # [B, T]
        weights = torch.nan_to_num(torch.softmax(scores, dim=1), nan=0.0, posinf=0.0, neginf=0.0)
        pooled = (out * weights.unsqueeze(-1)).sum(dim=1)
        return self.head(pooled).squeeze(1)

In [None]:
# ------------ Feature builders ------------
def _emb_log_bad(line):
    """Append one line to BAD_EMB_LOG."""
    with open(BAD_EMB_LOG, "a") as f:
        f.write(f"{line}\n")


def _emb_load_spatial_head(spatial_ckpt):
    """Return a SpatialHead loaded non-strictly from *spatial_ckpt*, or None if the file is absent.

    Only ``head`` is used downstream, so missing head weights are reported explicitly:
    without them the head stays randomly initialised.
    """
    spatial_state = safe_load_checkpoint(spatial_ckpt)
    if spatial_state is None:
        return None
    state = strip_module_prefix(spatial_state.get("model_state", spatial_state))
    spatial = SpatialHead()
    try:
        result = spatial.load_state_dict(state, strict=False)
        missing_head = [k for k in result.missing_keys if k.startswith("head.")]
        if missing_head:
            print(f"Warning: spatial checkpoint has no weights for {missing_head}; head is randomly initialised.")
        else:
            print("Spatial checkpoint loaded into SpatialHead (non-strict).")
    except Exception as e:
        # strict=False still raises on shape mismatches; parameters that matched may already be copied.
        print(f"Warning: couldn't fully load spatial checkpoint ({e}); proceeding with backbone defaults.")
    spatial.eval()
    spatial.to("cpu")
    return spatial


def _emb_spatial_probs(spatial, arr):
    """Per-frame scores for one video *arr* (T x D) as a 1-D array of length T.

    With a spatial model, each row is fed straight into ``spatial.head`` (assumes the
    embeddings come from the same backbone, i.e. D == spatial.feat_dim — TODO confirm)
    and squashed with a sigmoid. Without one, min-max-normalised row norms are used
    as a proxy score.
    """
    if spatial is None:
        norms = np.linalg.norm(arr, axis=1)
        return (norms - norms.min()) / (norms.max() - norms.min() + 1e-12)
    emb_t = torch.from_numpy(arr.astype(np.float32))
    with torch.no_grad():
        logits = torch.nan_to_num(spatial.head(emb_t), nan=0.0, posinf=0.0, neginf=0.0)
        # head returns [T, 1]; flatten so per-frame stats (notably top-3) see T values, not T rows.
        return torch.sigmoid(logits).cpu().numpy().reshape(-1)


def _emb_build_temporal_model(tstate, feat_dim):
    """Create TemporalModel(feat_dim), load *tstate* non-strictly, and switch it to eval mode.

    Propagates load errors (e.g. shape mismatches) so the caller never uses a
    randomly initialised model by accident.
    """
    model = TemporalModel(feat_dim)
    model.load_state_dict(tstate, strict=False)
    model.eval()
    return model


def _emb_temporal_prob(model, arr):
    """Sigmoid of the temporal model's logit for a single video *arr* (T x D)."""
    x = torch.from_numpy(arr.astype(np.float32)).unsqueeze(0)  # [1, T, D]
    lengths = torch.tensor([arr.shape[0]], dtype=torch.long)
    with torch.no_grad():
        logits = torch.nan_to_num(model(x, lengths), nan=0.0, posinf=0.0, neginf=0.0)
    return float(torch.sigmoid(logits).item())


def _emb_lookup_label(labels_map, stem):
    """Return the int label for *stem*: exact key first, else the first key containing it; None if absent."""
    lab = labels_map.get(stem, None)
    if lab is not None:
        return int(lab)
    for key, value in labels_map.items():
        if stem in key:
            return int(value)
    return None


def build_features_from_embeddings(split, spatial_ckpt=SPATIAL_CKPT, temporal_ckpt=TEMPORAL_CKPT, overwrite=False):
    """
    Build per-video features (mean, max, std, top3, temporal_score) from embeddings/<split>/*.npy.

    Spatial stats come from per-frame scores (see ``_emb_spatial_probs``). The temporal
    score comes from TemporalModel when a temporal checkpoint exists; otherwise it is the
    mean spatial score. Files that fail to load, validate, get a label, or be scored are
    logged to BAD_EMB_LOG and skipped. A video is kept only when all steps succeed, so
    X and y always stay row-aligned.

    Saves the cache to OUT_CACHE/<split>.npz, tagged source="embeddings". A cache built
    by a different source, or with mismatched X/y lengths, is rebuilt. An untagged
    cache is accepted.

    Returns:
        X (N x 5 float32), y (N, int64)
    Raises:
        RuntimeError: no embeddings dir for the split, or no valid feature rows.
    """
    cache_path = OUT_CACHE / f"{split}.npz"
    if cache_path.exists() and not overwrite:
        with np.load(cache_path) as d:
            source = str(d["source"]) if "source" in d.files else "embeddings"
            X_cached, y_cached = d["X"], d["y"]
        if source == "embeddings" and len(X_cached) == len(y_cached):
            print(f"Loading cached features for {split} from {cache_path}")
            return X_cached, y_cached
        print(f"Ignoring stale cache {cache_path} (source={source}, X rows={len(X_cached)}, "
              f"y rows={len(y_cached)}); rebuilding from embeddings.")

    # load labels map once for the whole split
    with open(LABELS_JSON, "r") as f:
        labels_map = json.load(f)

    spatial = _emb_load_spatial_head(spatial_ckpt)

    temporal_state = safe_load_checkpoint(temporal_ckpt)
    tstate = None
    if temporal_state is not None:
        tstate = strip_module_prefix(temporal_state.get("model_state", temporal_state))
        # The model itself is built lazily once the embedding width is known.
        print("Temporal checkpoint loaded (keys found).")

    split_dir = EMB_ROOT / split
    if not split_dir.exists():
        raise RuntimeError(f"No embeddings directory for split: {split_dir}")

    files = sorted(split_dir.glob("*.npy"))
    X_list, y_list = [], []

    # NOTE: the log is reset on every rebuild, so after building several splits it only
    # holds entries for the last one rebuilt.
    if BAD_EMB_LOG.exists():
        BAD_EMB_LOG.unlink()

    temporal = None              # built on the first file that needs it
    temporal_init_error = None   # remembered so a failed build isn't retried per file

    for p in files:
        try:
            arr = np.load(p)
        except Exception as e:
            _emb_log_bad(f"[LOAD_ERROR] {p} | {e}")
            continue

        bad, reason = is_bad_array(arr)
        if bad:
            _emb_log_bad(f"{p} | {reason}")
            continue

        # Resolve the label before computing features so a missing label can't misalign X and y.
        try:
            lab = _emb_lookup_label(labels_map, p.stem)
            if lab is None:
                raise KeyError("Label missing")
        except Exception as e:
            _emb_log_bad(f"[LABEL_ERROR] {p} | {e}")
            continue

        try:
            probs = _emb_spatial_probs(spatial, arr)
        except Exception as e:
            _emb_log_bad(f"[SPATIAL_EVAL_ERROR] {p} | {e}")
            continue

        s_mean = float(probs.mean())
        s_max  = float(probs.max())
        s_std  = float(probs.std())
        s_top3 = float(np.sort(probs)[-3:].mean()) if len(probs) >= 3 else s_mean

        if tstate is None:
            # no temporal checkpoint: use mean of spatial probs as temporal proxy
            t_prob = s_mean
        else:
            try:
                if temporal is None:
                    if temporal_init_error is not None:
                        raise RuntimeError(f"temporal model unavailable: {temporal_init_error}")
                    try:
                        temporal = _emb_build_temporal_model(tstate, arr.shape[1])
                    except Exception as e:
                        temporal_init_error = e
                        raise
                t_prob = _emb_temporal_prob(temporal, arr)
            except Exception as e:
                _emb_log_bad(f"[TEMPORAL_EVAL_ERROR] {p} | {e}")
                continue

        X_list.append([s_mean, s_max, s_std, s_top3, t_prob])
        y_list.append(lab)

    if len(X_list) == 0:
        raise RuntimeError(f"No valid features built for split {split}. Check {BAD_EMB_LOG}")

    X = np.array(X_list, dtype=np.float32)
    y = np.array(y_list, dtype=np.int64)
    np.savez(cache_path, X=X, y=y, source="embeddings")
    print(f"Saved feature cache: {cache_path} (n={len(y)})")
    return X, y


In [None]:
def _pred_log_bad(line):
    """Append one line to BAD_PRED_LOG."""
    with open(BAD_PRED_LOG, "a") as f:
        f.write(f"{line}\n")


def _pred_to_scalar(value, kind):
    """Coerce a loaded prediction to float; raise ValueError unless it holds exactly one value."""
    if np.isscalar(value):
        return float(value)
    if isinstance(value, np.ndarray) and value.size == 1:
        return float(value.reshape(-1)[0])
    raise ValueError(f"Bad {kind} pred shape: {getattr(value,'shape',None)}")


def _pred_lookup_label(labels_map, stem):
    """Return the int label for *stem*: exact key first, else the first key containing it; None if absent."""
    lab = labels_map.get(stem, None)
    if lab is not None:
        return int(lab)
    for key, value in labels_map.items():
        if stem in key:
            return int(value)
    return None


def build_features_from_predictions(split, overwrite=False):
    """
    Build 5-dim features from per-video scalar predictions.

    Reads PRED_SPATIAL_DIR/<split>/<stem>.npy and PRED_TEMPORAL_DIR/<split>/<stem>.npy
    for every stem present in both. Each video gets the same layout as the
    embedding-based features ([mean, max, std, top3, temporal]). The single spatial
    score s fills every spatial slot: [s, s, 0.0, s, t]. Bad predictions and missing or
    invalid labels are logged to BAD_PRED_LOG and skipped.

    Saves the cache to OUT_CACHE/<split>.npz, tagged source="predictions". A cache
    built by a different source, or with mismatched X/y lengths, is rebuilt. An
    untagged cache is accepted.

    Returns:
        X (N x 5 float32), y (N, int64)
    Raises:
        RuntimeError: prediction dirs missing, no common stems, or no valid rows.
    """
    cache_path = OUT_CACHE / f"{split}.npz"
    if cache_path.exists() and not overwrite:
        with np.load(cache_path) as d:
            source = str(d["source"]) if "source" in d.files else "predictions"
            X_cached, y_cached = d["X"], d["y"]
        if source == "predictions" and len(X_cached) == len(y_cached):
            print(f"Loaded cached features for {split} from {cache_path}")
            return X_cached, y_cached
        print(f"Ignoring stale cache {cache_path} (source={source}, X rows={len(X_cached)}, "
              f"y rows={len(y_cached)}); rebuilding from predictions.")

    spatial_dir = PRED_SPATIAL_DIR / split
    temporal_dir = PRED_TEMPORAL_DIR / split
    if not spatial_dir.exists() or not temporal_dir.exists():
        raise RuntimeError("Prediction directories missing; cannot build from predictions")

    # NOTE: reset on every rebuild, so only the last rebuilt split's entries survive.
    if BAD_PRED_LOG.exists():
        BAD_PRED_LOG.unlink()

    spatial_files = {p.stem: p for p in spatial_dir.glob("*.npy")}
    temporal_files = {p.stem: p for p in temporal_dir.glob("*.npy")}
    stems = sorted(spatial_files.keys() & temporal_files.keys())
    if not stems:
        raise RuntimeError("No common prediction stems found between spatial and temporal predictions")

    X_list, y_list = [], []
    with open(LABELS_JSON, "r") as f:
        labels_map = json.load(f)

    for stem in stems:
        try:
            # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
            # only run this on trusted prediction files.
            s = _pred_to_scalar(np.load(spatial_files[stem], allow_pickle=True), "spatial")
            t = _pred_to_scalar(np.load(temporal_files[stem], allow_pickle=True), "temporal")
            if not np.isfinite(s) or not np.isfinite(t):
                raise ValueError(f"Non-finite pred s={s}, t={t}")
        except Exception as e:
            _pred_log_bad(f"{stem} | {e}")
            continue

        try:
            lab = _pred_lookup_label(labels_map, stem)
        except (TypeError, ValueError) as e:
            # a non-integer label value used to abort the whole split
            _pred_log_bad(f"{stem} | LABEL_INVALID {e}")
            continue
        if lab is None:
            _pred_log_bad(f"{stem} | LABEL_MISSING")
            continue

        X_list.append([s, s, 0.0, s, t])
        y_list.append(lab)

    if len(X_list) == 0:
        raise RuntimeError(f"No valid prediction-based features for split {split}. Check {BAD_PRED_LOG}")

    X = np.array(X_list, dtype=np.float32)
    y = np.array(y_list, dtype=np.int64)
    np.savez(cache_path, X=X, y=y, source="predictions")
    print(f"Saved prediction-based cache: {cache_path} (n={len(y)})")
    return X, y


In [None]:
# ------------ Determine mode (predictions exist?) ------------
def have_prediction_dirs():
    """Return True when both prediction root directories exist.

    Only PRED_SPATIAL_DIR and PRED_TEMPORAL_DIR themselves are checked, not their
    per-split subdirectories.
    """
    return all(d.exists() for d in (PRED_SPATIAL_DIR, PRED_TEMPORAL_DIR))

mode_from_preds = have_prediction_dirs()
mode_label = "predictions" if mode_from_preds else "embeddings (will compute features)"
print("Mode:", mode_label)


In [None]:
# ------------ Build / load features for all splits ------------
def build_all_features(overwrite=False):
    """Build (or load cached) features for every split in SPLITS.

    Uses the prediction-based builder when ``mode_from_preds`` is set, otherwise the
    embedding-based one. A split whose builder raises is reported and replaced by
    empty arrays (X: (0, 5) float32, y: (0,) int64) so later cells still run.

    Returns:
        (Xs, ys): dicts keyed by split name.
    """
    builder = build_features_from_predictions if mode_from_preds else build_features_from_embeddings
    log_path = BAD_PRED_LOG if mode_from_preds else BAD_EMB_LOG
    Xs, ys = {}, {}
    for split in SPLITS:
        try:
            X_split, y_split = builder(split, overwrite=overwrite)
        except Exception as e:
            print(f"Error building features for split={split}: {e}. Check logs {log_path}")
            X_split = np.zeros((0, 5), dtype=np.float32)
            y_split = np.zeros((0,), dtype=np.int64)
        Xs[split], ys[split] = X_split, y_split
    return Xs, ys

X_dict, y_dict = build_all_features(overwrite=False)


In [None]:
# quick dataset stats: per-split row counts, then the train label distribution
print("Train size: {} Val size: {} Test size: {}".format(
    *(X_dict[name].shape[0] for name in ("train", "val", "test"))))
print("Train label counts: {}".format(Counter(y_dict["train"].tolist())))


In [None]:
# ------------ Train calibrated logistic regression (with fallbacks) ------------
# Prepare training arrays
X_train = X_dict["train"]
y_train = y_dict["train"]
X_val = X_dict["val"]
y_val = y_dict["val"]
X_test = X_dict["test"]
y_test = y_dict["test"]

if len(y_train) == 0:
    raise RuntimeError("No training samples available. Cannot train ensemble.")

# Feature caches written by older builds could have more X rows than labels; fail loudly.
for split_name, X_split, y_split in (("train", X_train, y_train), ("val", X_val, y_val), ("test", X_test, y_test)):
    if X_split.shape[0] != len(y_split):
        raise RuntimeError(
            f"{split_name}: {X_split.shape[0]} feature rows but {len(y_split)} labels; "
            "rebuild the feature cache with overwrite=True."
        )


def _choose_calibration_cv(labels, max_cv=CALIBRATION_CV):
    """Return the number of calibration folds to use, or None to skip CV calibration.

    StratifiedKFold / CalibratedClassifierCV need every class to occur at least
    n_splits times, so the fold count is capped by the size of the rarest class.
    """
    counts = Counter(np.asarray(labels).tolist())
    if len(counts) < 2:
        return None
    folds = min(max_cv, min(counts.values()))
    return folds if folds >= 2 else None


train_class_counts = Counter(y_train.tolist())
if len(train_class_counts) < 2:
    raise RuntimeError(
        f"Training labels contain a single class {dict(train_class_counts)}; "
        "LogisticRegression needs at least two classes."
    )

cv = _choose_calibration_cv(y_train)

base_clf = LogisticRegression(max_iter=LR_MAX_ITER, solver="lbfgs")
ensemble_clf = None

if cv is None:
    print("Calibration cv too small — training plain LogisticRegression (no calibration).")
    base_clf.fit(X_train, y_train)
    ensemble_clf = base_clf
else:
    try:
        skf = StratifiedKFold(n_splits=cv, shuffle=True, random_state=RNG_SEED)
        ensemble = CalibratedClassifierCV(base_clf, cv=skf, method="sigmoid")
        ensemble.fit(X_train, y_train)
        ensemble_clf = ensemble
    except Exception as e:
        warnings.warn(f"CalibratedClassifierCV failed: {e}. Falling back to plain LogisticRegression.")
        base_clf.fit(X_train, y_train)
        ensemble_clf = base_clf


In [None]:
# ------------ Evaluate and save ------------
def _positive_scores(clf, X):
    """Positive-class probability column if *clf* has predict_proba, else its decision_function scores."""
    if hasattr(clf, "predict_proba"):
        return clf.predict_proba(X)[:, 1]
    return clf.decision_function(X)


train_p = _positive_scores(ensemble_clf, X_train)
# Empty val/test splits get empty score arrays and a nan AUC.
val_p = _positive_scores(ensemble_clf, X_val) if X_val.shape[0] > 0 else np.array([])
test_p = _positive_scores(ensemble_clf, X_test) if X_test.shape[0] > 0 else np.array([])

train_auc = safe_auc(y_train, train_p)
val_auc = safe_auc(y_val, val_p) if X_val.shape[0] > 0 else float("nan")
test_auc = safe_auc(y_test, test_p) if X_test.shape[0] > 0 else float("nan")

# Save model and results
joblib.dump(ensemble_clf, ENSEMBLE_OUT)
mode_name = 'predictions' if mode_from_preds else 'embeddings'
report = [
    f"Mode: {mode_name}\n",
    f"Train samples: {len(y_train)}  Train AUC: {train_auc}\n",
    f"Val samples:   {len(y_val)}  Val AUC:   {val_auc}\n",
    f"Test samples:  {len(y_test)}  Test AUC:  {test_auc}\n",
    "\nNotes:\n",
]
if BAD_EMB_LOG.exists():
    report.append(f"Bad embeddings log: {BAD_EMB_LOG}\n")
if BAD_PRED_LOG.exists():
    report.append(f"Bad preds log: {BAD_PRED_LOG}\n")
with open(RESULTS_TXT, "w") as f:
    f.writelines(report)

print("Saved ensemble to:", ENSEMBLE_OUT)
print("Results written to:", RESULTS_TXT)
for log_label, log_path in (("Bad embeddings logged at:", BAD_EMB_LOG),
                            ("Bad predictions logged at:", BAD_PRED_LOG)):
    if log_path.exists():
        print(log_label, log_path)

# Print summary to console
for auc_label, auc_value in (("Train AUC:", train_auc), ("Val AUC:", val_auc), ("Test AUC:", test_auc)):
    print(auc_label, auc_value)