# 05 CAFA E2E — Run DNN for CC (wrapper)


In [None]:
# CELL 01 - Setup (NO REPO)
# Best-effort preload of the venv's nvjitlink, then pin the working directory and
# create the data root. Defines RUNTIME_ROOT, DATA_ROOT and TRAIN_LEVEL1.
import os
import sys
import ctypes
from pathlib import Path

# CUDA loader fix (PyTorch/RAPIDS coexistence): preload venv nvjitlink so we don't pick /usr/local/cuda/lib64
try:
    _venv_root = Path(sys.executable).resolve().parent.parent
    _nvjit_dir = (
        _venv_root
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages"
        / "nvidia"
        / "nvjitlink"
        / "lib"
    )
    _nvjit_so = _nvjit_dir / "libnvJitLink.so.12"
    if _nvjit_so.exists():
        ctypes.CDLL(str(_nvjit_so), mode=ctypes.RTLD_GLOBAL)
        _prev_ld = os.environ.get("LD_LIBRARY_PATH", "")
        # Skip when the directory is already first, so re-running the cell does not stack duplicates.
        if _prev_ld.split(":")[0] != str(_nvjit_dir):
            # Never emit a trailing ':' — an empty entry makes the loader search the current directory.
            os.environ["LD_LIBRARY_PATH"] = f"{_nvjit_dir}:{_prev_ld}" if _prev_ld else str(_nvjit_dir)
        print(f"[ENV] Preloaded nvjitlink: {_nvjit_so}")
except Exception as _e:
    # Deliberately best-effort: a failed preload must not stop the notebook.
    print(f"[ENV] nvjitlink preload skipped: {_e}")

# Always run from a simple writable location; never cd into a repo.
if os.path.exists('/content'):
    os.chdir('/content')
RUNTIME_ROOT = Path.cwd()
DATA_ROOT = (RUNTIME_ROOT / 'cafa6_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)
TRAIN_LEVEL1 = True
print(f'CWD: {Path.cwd()}')
print(f'DATA_ROOT: {DATA_ROOT.resolve()}')

In [None]:
# CELL 13.1 - Calibrate per-aspect thresholds (BP/MF/CC) from OOF (read-only)
# ========================================================================
# Goal: derive thresholds per GO aspect and persist to features/aspect_thresholds.json.
# These thresholds are used later in Phase 3/4 (stacking + submission filtering).
#
# We fit thresholds on a deterministic subsample of rows to keep runtime bounded.
# This is *calibration*, not training.

import json
from pathlib import Path
import numpy as np
import pandas as pd

# WORK_ROOT is read before it is assigned here, so an earlier cell must already have defined it.
WORK_ROOT = Path(WORK_ROOT)
FEAT_DIR = WORK_ROOT / 'features'
PRED_DIR = FEAT_DIR / 'level1_preds'
FEAT_DIR.mkdir(parents=True, exist_ok=True)
PRED_DIR.mkdir(parents=True, exist_ok=True)

# Outputs of this cell: the threshold map plus a sidecar describing how it was derived.
thr_path = FEAT_DIR / 'aspect_thresholds.json'
thr_meta_path = FEAT_DIR / 'aspect_thresholds_meta.json'

# A previously saved threshold file short-circuits the whole calibration.
if thr_path.exists():
    thr_map = json.loads(thr_path.read_text(encoding='utf-8'))
    print('Loaded existing aspect thresholds:', thr_map)
else:
    # ---- term contract ----
    # The term list fixes the column order of every matrix below and must hold exactly 13,500 entries.
    top_terms_path = FEAT_DIR / 'top_terms_13500.json'
    if not top_terms_path.exists():
        raise FileNotFoundError(f'Missing {top_terms_path}. Run Phase 2 setup first.')
    top_terms = [str(t) for t in json.loads(top_terms_path.read_text(encoding='utf-8'))]
    if len(top_terms) != 13500:
        raise RuntimeError(f'Expected top_terms=13500, got {len(top_terms)}')

    # ---- GO namespaces -> aspects ----
    # Prefer a term->namespace dict already in the notebook namespace (`go_namespaces`, then
    # `term_to_ns`); otherwise parse the first go-basic.obo found among the candidate paths.
    if 'go_namespaces' in globals() and isinstance(globals()['go_namespaces'], dict):
        go_namespaces = globals()['go_namespaces']
    elif 'term_to_ns' in globals() and isinstance(globals()['term_to_ns'], dict):
        go_namespaces = globals()['term_to_ns']
    else:
        import obonet
        obo_path = None
        for p in [WORK_ROOT / 'Train' / 'go-basic.obo', WORK_ROOT / 'go-basic.obo', Path('Train/go-basic.obo'), Path('go-basic.obo')]:
            if p.exists():
                obo_path = p
                break
        if obo_path is None:
            raise FileNotFoundError('go-basic.obo not found for namespace mapping')
        graph = obonet.read_obo(obo_path)
        go_namespaces = {node: data.get('namespace', 'unknown') for node, data in graph.nodes(data=True)}

    ns_to_aspect = {
        'biological_process': 'BP',
        'molecular_function': 'MF',
        'cellular_component': 'CC',
    }
    # One aspect tag per column of top_terms; terms with a missing or unrecognised namespace become 'UNK'.
    aspects = np.array([ns_to_aspect.get(go_namespaces.get(t, 'unknown'), 'UNK') for t in top_terms], dtype='<U3')

    # ---- IA weights ----
    # Column names are matched case-insensitively. When no known name is present, the first column
    # is taken as the term and the second (if there is one) as the weight.
    ia_path = WORK_ROOT / 'IA.tsv'
    if not ia_path.exists():
        raise FileNotFoundError(f'Missing IA.tsv at {ia_path}')
    ia_df = pd.read_csv(ia_path, sep='\t')
    cols = [c.lower() for c in ia_df.columns]
    term_col = ia_df.columns[cols.index('term')] if 'term' in cols else ia_df.columns[0]
    if 'ia' in cols:
        ia_col = ia_df.columns[cols.index('ia')]
    elif 'information_accretion' in cols:
        ia_col = ia_df.columns[cols.index('information_accretion')]
    else:
        ia_col = ia_df.columns[1] if len(ia_df.columns) > 1 else ia_df.columns[0]
    ia_map = dict(zip(ia_df[term_col].astype(str), ia_df[ia_col].astype(np.float32)))
    # Terms absent from IA.tsv fall back to a weight of 1.0.
    weights = np.asarray([ia_map.get(t, np.float32(1.0)) for t in top_terms], dtype=np.float32)

    # ---- load Y + OOF predictions ----
    # Reuse an in-memory `Y` when one exists; otherwise rebuild the label matrix from the parsed files.
    if 'Y' in globals():
        Y_full = np.asarray(globals()['Y'], dtype=np.float32)
    else:
        train_terms = pd.read_parquet(WORK_ROOT / 'parsed' / 'train_terms.parquet')
        train_ids_raw = pd.read_feather(WORK_ROOT / 'parsed' / 'train_seq.feather')['id'].astype(str)
        # Use the text between the first pair of '|' characters when present, else the raw id.
        train_ids = train_ids_raw.str.extract(r"\|(.*?)\|", expand=False).fillna(train_ids_raw)
        train_terms_top = train_terms[train_terms['term'].isin(top_terms)]
        # Count matrix (rows = EntryID, columns = term), then aligned to the train id order and the
        # top_terms column order; ids or terms without annotations become all-zero rows/columns.
        y_df = train_terms_top.pivot_table(index='EntryID', columns='term', aggfunc='size', fill_value=0)
        y_df = y_df.reindex(train_ids.tolist(), fill_value=0)
        y_df = y_df.reindex(columns=top_terms, fill_value=0)
        Y_full = y_df.values.astype(np.float32)
        # Only Y_full is needed from here on; drop the intermediate frames.
        del train_terms, train_ids_raw, train_ids, train_terms_top, y_df
    def _load_oof(name: str) -> np.ndarray | None:
        """Load the array saved as *name*, looking in PRED_DIR first and FEAT_DIR second.

        Returns None when neither directory contains the file.
        """
        hit = next((folder / name for folder in (PRED_DIR, FEAT_DIR) if (folder / name).exists()), None)
        return None if hit is None else np.load(hit)

    # Gather whichever level-1 OOF matrices exist on disk; insertion order fixes the stacking order.
    oof_candidates = {}
    for _tag, _fname in (
        ('logreg', 'oof_pred_logreg.npy'),
        ('gbdt', 'oof_pred_gbdt.npy'),
        ('dnn', 'oof_pred_dnn.npy'),
        ('knn', 'oof_pred_knn.npy'),
    ):
        _arr = _load_oof(_fname)
        if _arr is not None:
            oof_candidates[_tag] = _arr
    if len(oof_candidates) == 0:
        raise FileNotFoundError('No OOF preds found. Expected features/level1_preds/oof_pred_*.npy')

    # Every OOF matrix must line up with the label matrix; report the first offender.
    _bad = next((tag for tag, arr in oof_candidates.items() if arr.shape != Y_full.shape), None)
    if _bad is not None:
        raise RuntimeError(f'OOF shape mismatch for {_bad}: got {oof_candidates[_bad].shape}, expected {Y_full.shape}')

    # Calibration scores = element-wise mean over the available models.
    oof_mean = np.stack(list(oof_candidates.values()), axis=0).mean(axis=0).astype(np.float32)
    print('Threshold calibration base = mean OOF of:', sorted(oof_candidates.keys()))

    # deterministic subsample: at most DIAG_N evenly spaced rows (notebook variable CAFA_DIAG_N, default 20000)
    DIAG_N = int(globals().get('CAFA_DIAG_N', 20000))
    n = int(Y_full.shape[0])
    m = min(n, int(DIAG_N))
    idx = np.linspace(0, n - 1, num=m, dtype=np.int64)
    Y_sub = (Y_full[idx] > 0).astype(np.uint8)
    S_sub = oof_mean[idx].astype(np.float32, copy=False)

    # Threshold grid searched by _best_thr, and the column batch size used by _ia_f1_for_cols.
    THRS = np.linspace(0.05, 0.60, 23, dtype=np.float32)
    COL_CHUNK = 512

    def _ia_f1_for_cols(cols: np.ndarray, thr: float) -> float:
        """IA-weighted F1 over the term columns *cols* at score threshold *thr*.

        Works on the subsample (Y_sub, S_sub) in batches of COL_CHUNK columns. Per-column
        true-positive, predicted-positive and actual-positive counts are weighted by
        `weights` and pooled across columns before precision and recall are formed.
        Returns 0.0 whenever a denominator is zero.
        """
        cutoff = float(thr)
        total_cols = int(cols.shape[0])
        acc_tp = 0.0
        acc_pred = 0.0
        acc_true = 0.0
        start = 0
        while start < total_cols:
            batch = cols[start : start + COL_CHUNK]
            truth = Y_sub[:, batch].astype(bool, copy=False)
            called = S_sub[:, batch] >= cutoff
            ia = weights[batch].astype(np.float64)
            acc_tp += float((ia * (truth & called).sum(axis=0).astype(np.float64)).sum())
            acc_pred += float((ia * called.sum(axis=0).astype(np.float64)).sum())
            acc_true += float((ia * truth.sum(axis=0).astype(np.float64)).sum())
            start += COL_CHUNK

        precision = acc_tp / acc_pred if acc_pred > 0 else 0.0
        recall = acc_tp / acc_true if acc_true > 0 else 0.0
        if (precision + recall) > 0:
            return 2 * precision * recall / (precision + recall)
        return 0.0

    def _best_thr(cols: np.ndarray) -> tuple[float, float]:
        """Scan THRS and return ``(threshold, score)`` with the highest IA-weighted F1 for *cols*.

        Ties keep the earliest (lowest) threshold because only a strictly better score replaces it.
        """
        winner = (float(THRS[0]), -1.0)
        for candidate in (float(t) for t in THRS):
            score = float(_ia_f1_for_cols(cols, candidate))
            if score > winner[1]:
                winner = (candidate, score)
        return winner

    cols_all = np.arange(len(top_terms), dtype=np.int64)
    cols_bp = np.where(aspects == 'BP')[0].astype(np.int64)
    cols_mf = np.where(aspects == 'MF')[0].astype(np.int64)
    cols_cc = np.where(aspects == 'CC')[0].astype(np.int64)
    cols_unk = np.where(aspects == 'UNK')[0].astype(np.int64)

    thr_all, s_all = _best_thr(cols_all)
    thr_bp, s_bp = _best_thr(cols_bp)
    thr_mf, s_mf = _best_thr(cols_mf)
    thr_cc, s_cc = _best_thr(cols_cc)
    thr_unk, s_unk = _best_thr(cols_unk) if int(cols_unk.shape[0]) > 0 else (thr_all, s_all)

    thr_map = {
        'ALL': thr_all,
        'BP': thr_bp,
        'MF': thr_mf,
        'CC': thr_cc,
        'UNK': thr_unk,
    }
    thr_path.write_text(json.dumps(thr_map, indent=2), encoding='utf-8')
    thr_meta_path.write_text(
        json.dumps(
            {
                'calibration_base': 'mean_ensemble_oof',
                'models_used': sorted(oof_candidates.keys()),
                'n_rows_total': int(Y_full.shape[0]),
                'n_rows_used': int(m),
                'thr_grid': THRS.tolist(),
                'ia_f1_at_best': {'ALL': s_all, 'BP': s_bp, 'MF': s_mf, 'CC': s_cc, 'UNK': s_unk},
                'aspect_counts': {'BP': int(cols_bp.shape[0]), 'MF': int(cols_mf.shape[0]), 'CC': int(cols_cc.shape[0]), 'UNK': int(cols_unk.shape[0])},
            },
            indent=2,
        ),
        encoding='utf-8',
    )
    print('Saved aspect thresholds:', thr_map)

# Expose to later cells
ASPECT_THRESHOLDS = thr_map

In [None]:
# CELL 13D - DNN (PyTorch, multi-branch per modality) + checkpoint push
# ===================================================================
# Rank-1 style correction: each modality gets its own head before fusion.
# Extreme ensembling correction: 5 seeds × 5 folds = 25 models.
# Implementation guardrails:
#   - IA-weighted BCE is mandatory (class_weight per term).
#   - Outputs must remain full 13,500 columns for the Phase 3 GCN contract.

# The whole cell is gated on TRAIN_LEVEL1; everything below runs inside the `else` branch.
if not TRAIN_LEVEL1:
    print('Skipping DNN (TRAIN_LEVEL1=False).')
else:
    import gc
    import json
    from pathlib import Path
    import numpy as np
    import pandas as pd
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from sklearn.model_selection import KFold

    # WORK_ROOT is read before it is assigned here, so an earlier cell must already have defined it.
    WORK_ROOT = Path(WORK_ROOT)
    FEAT_DIR = WORK_ROOT / 'features'
    PRED_DIR = FEAT_DIR / 'level1_preds'
    PRED_DIR.mkdir(parents=True, exist_ok=True)

    # RANK-1: Load aspect-specific thresholds (prerequisite: Cell 13F must run first)
    # NOTE(review): the calibration cell visible in this notebook is labelled "CELL 13.1", not
    # "13F" — confirm which label is current. When the file is missing, hard-coded defaults are used.
    thr_path = FEAT_DIR / 'aspect_thresholds.json'
    if not thr_path.exists():
        print(f'[WARN] aspect_thresholds.json not found at {thr_path}. Run Cell 13F first for per-aspect thresholds (proven +3.3% F1 boost).')
        ASPECT_THRESHOLDS = {'ALL': 0.3, 'BP': 0.25, 'MF': 0.35, 'CC': 0.35, 'UNK': 0.3}  # fallback defaults
    else:
        ASPECT_THRESHOLDS = json.loads(thr_path.read_text(encoding='utf-8'))
        print(f'[DNN] Loaded aspect thresholds: {ASPECT_THRESHOLDS}')

    dnn_oof_path = PRED_DIR / 'oof_pred_dnn.npy'
    dnn_test_path = PRED_DIR / 'test_pred_dnn.npy'

    # Training is skipped only when BOTH prediction files already exist; otherwise it all reruns.
    if dnn_oof_path.exists() and dnn_test_path.exists():
        oof_pred_dnn = np.load(dnn_oof_path)
        test_pred_dnn = np.load(dnn_test_path)
        print('Loaded existing DNN preds:', dnn_oof_path.name, dnn_test_path.name)
    else:
        # Inputs expected from earlier cells: per-modality feature dicts and the label matrix Y.
        if 'features_train' not in globals() or 'features_test' not in globals():
            raise RuntimeError('Missing `features_train`/`features_test`. Run Cell 13a first (Phase 2 setup).')
        if 'Y' not in globals():
            raise RuntimeError('Missing Y. Run Cell 13a first (targets).')

        # ---- IA class weights (per-term) ----
        top_terms_path = FEAT_DIR / 'top_terms_13500.json'
        if not top_terms_path.exists():
            raise FileNotFoundError(f'Missing {top_terms_path}. Run Cell 13a first.')
        top_terms = [str(t) for t in json.loads(top_terms_path.read_text(encoding='utf-8'))]

        ia_path = WORK_ROOT / 'IA.tsv'
        if not ia_path.exists():
            raise FileNotFoundError(f'Missing IA.tsv at {ia_path}')
        ia_df = pd.read_csv(ia_path, sep='\t')
        # robust column naming (some IA files use 'term' or '#term')
        # Matching here is case-sensitive; unknown headers fall back to the first/second column.
        term_col = 'term' if 'term' in ia_df.columns else ('#term' if '#term' in ia_df.columns else ia_df.columns[0])
        ia_col = 'ia' if 'ia' in ia_df.columns else (ia_df.columns[1] if len(ia_df.columns) > 1 else ia_df.columns[0])
        ia_map = dict(zip(ia_df[term_col].astype(str).values, ia_df[ia_col].astype(np.float32).values))
        # Terms absent from IA.tsv fall back to a weight of 1.0.
        weights = np.asarray([ia_map.get(t, np.float32(1.0)) for t in top_terms], dtype=np.float32)
        # weights broadcast over classes: (1, L)
        w_t = torch.from_numpy(weights).view(1, -1)
        print(f'[DNN] IA weights ready: shape={weights.shape} min={float(weights.min()):.4f} max={float(weights.max()):.4f}')

        # ---- label contract (must stay 13,500-wide) ----
        out_dim = int(Y.shape[1])
        if out_dim != len(top_terms):
            raise RuntimeError(f'DNN label contract mismatch: Y has {out_dim} cols but top_terms has {len(top_terms)}')
        if out_dim != 13500:
            raise RuntimeError(f'DNN expects 13,500 labels; got out_dim={out_dim}')

        # ---- modality inputs ----
        # Use per-modality arrays directly (no flat X copy).
        # Only features_train is checked for the required keys here.
        required_keys = ['t5', 'esm2_650m', 'esm2_3b', 'ankh', 'text', 'taxa']
        missing = [k for k in required_keys if k not in features_train]
        if missing:
            raise FileNotFoundError(f'Missing mandatory DNN modalities: {missing}. Run embeddings/text/taxa stages to materialise them.')

        # Optional 7th branch: PB/GBDT probabilities (teacher features).
        # If Cell 13b already produced OOF/test predictions, treat them as an additional modality input.
        # This is leakage-safe because we use OOF for train rows and a proper test_pred for test rows.
        gbdt_oof_path = PRED_DIR / 'oof_pred_gbdt.npy'
        gbdt_test_path = PRED_DIR / 'test_pred_gbdt.npy'
        use_pb = bool(gbdt_oof_path.exists() and gbdt_test_path.exists())

        # Build local feature dicts for the DNN (may include pb)
        # Shallow copies: adding 'pb' below must not mutate features_train / features_test.
        dnn_train = dict(features_train)
        dnn_test = dict(features_test)
        dnn_keys = list(required_keys)

        if use_pb:
            # Memory-mapped so the full prediction matrices are not read into RAM up front.
            pb_oof = np.load(gbdt_oof_path, mmap_mode='r')
            pb_test = np.load(gbdt_test_path, mmap_mode='r')
            if int(pb_oof.shape[0]) != int(Y.shape[0]):
                raise RuntimeError(f'PB/GBDT OOF rows mismatch: {pb_oof.shape[0]} vs Y rows {Y.shape[0]}')
            # Fail fast on a stale test-prediction file: without this check a row mismatch only
            # shows up while predicting on the test set, after a complete model has been trained.
            n_test_rows = int(dnn_test[required_keys[0]].shape[0])
            if int(pb_test.shape[0]) != n_test_rows:
                raise RuntimeError(f'PB/GBDT test rows mismatch: {pb_test.shape[0]} vs {required_keys[0]} test rows {n_test_rows}')
            if int(pb_oof.shape[1]) != out_dim or int(pb_test.shape[1]) != out_dim:
                raise RuntimeError(f'PB/GBDT pred cols mismatch: expected {out_dim}, got oof={pb_oof.shape}, test={pb_test.shape}')
            dnn_train['pb'] = pb_oof
            dnn_test['pb'] = pb_test
            dnn_keys.append('pb')
            print('[DNN] Using 7th modality branch: pb (= GBDT OOF/test probabilities)')
        else:
            print('[DNN] PB/GBDT branch not available (missing GBDT OOF/test preds); proceeding with 6 modalities.')

        print(f'DNN modality heads: {dnn_keys} (n={len(dnn_keys)})')
        # Input width of each modality head, in dnn_keys order.
        dims = {k: int(dnn_train[k].shape[1]) for k in dnn_keys}

        class MultiModalDataset(Dataset):
            """Row-indexed view over a dict of per-modality arrays with optional targets.

            Item ``i`` refers to source row ``idx[i]``. With targets it yields
            ``(features, target)`` as a tuple; without targets it yields the bare
            feature list. All arrays are returned as float32.
            """

            def __init__(self, X_dict, y, keys, idx):
                self.X_dict, self.y, self.keys = X_dict, y, keys
                self.idx = np.asarray(idx, dtype=np.int64)

            def __len__(self):
                return int(self.idx.shape[0])

            def __getitem__(self, i):
                row = int(self.idx[i])
                feats = []
                for key in self.keys:
                    feats.append(np.asarray(self.X_dict[key][row], dtype=np.float32))
                if self.y is not None:
                    return feats, np.asarray(self.y[row], dtype=np.float32)
                return feats

        def _collate(batch):
            # batch: list of (xs, y) or xs
            if isinstance(batch[0], tuple):
                xs_list, ys = zip(*batch)
                xs_by_key = list(zip(*xs_list))
                xs_t = [torch.from_numpy(np.stack(v, axis=0)) for v in xs_by_key]
                y_t = torch.from_numpy(np.stack(ys, axis=0))
                return xs_t, y_t
            else:
                xs_by_key = list(zip(*batch))
                xs_t = [torch.from_numpy(np.stack(v, axis=0)) for v in xs_by_key]
                return xs_t

        class ModalityHead(nn.Module):
            """Per-modality encoder: two (Linear -> BatchNorm -> ReLU -> Dropout) stages.

            Hidden widths depend on the input width; the width of the final stage is
            exposed as ``out_dim`` so the fusion trunk can size its first layer.
            """

            def __init__(self, in_dim: int, p: float = 0.2):
                super().__init__()
                # Keep heads uniform but avoid huge parameter blow-ups for very wide modalities
                if int(in_dim) >= 8000:
                    hidden1, hidden2 = 1024, 1024
                elif int(in_dim) >= 2000:
                    hidden1, hidden2 = 2048, 1024
                else:
                    hidden1, hidden2 = 1024, 512
                self.out_dim = int(hidden2)
                self.net = nn.Sequential(
                    nn.Linear(int(in_dim), int(hidden1)),
                    nn.BatchNorm1d(int(hidden1)),
                    nn.ReLU(),
                    nn.Dropout(p),
                    nn.Linear(int(hidden1), int(hidden2)),
                    nn.BatchNorm1d(int(hidden2)),
                    nn.ReLU(),
                    nn.Dropout(p),
                )

            def forward(self, x):
                return self.net(x)

        class MultiBranchDNN(nn.Module):
            """One ModalityHead per key, concatenated and fed to a shared trunk that emits logits.

            ``dims_by_key`` maps modality name -> input width; its iteration order defines
            the order in which ``forward`` expects the modality tensors.
            """

            def __init__(self, dims_by_key: dict, out_dim: int, p: float = 0.2):
                super().__init__()
                self.keys = list(dims_by_key.keys())
                self.heads = nn.ModuleDict({k: ModalityHead(in_dim=int(dims_by_key[k]), p=p) for k in self.keys})
                fused_dim = int(sum(self.heads[k].out_dim for k in self.keys))
                self.trunk = nn.Sequential(
                    nn.Linear(fused_dim, 2048),
                    nn.BatchNorm1d(2048),
                    nn.ReLU(),
                    nn.Dropout(p),
                    nn.Linear(2048, out_dim),
                )

            def forward(self, xs):
                """Return logits for a sequence of modality tensors given in ``self.keys`` order.

                Raises:
                    ValueError: if the number of tensors differs from the number of heads.
                """
                xs = list(xs)
                # A bare zip() would silently drop surplus tensors and hide a misaligned input list.
                if len(xs) != len(self.keys):
                    raise ValueError(
                        f'MultiBranchDNN expected {len(self.keys)} modality tensors for keys {self.keys}, got {len(xs)}'
                    )
                hs = [self.heads[k](x) for k, x in zip(self.keys, xs)]
                h = torch.cat(hs, dim=1)
                return self.trunk(h)

        def train_one_seed_fold(train_idx, val_idx, seed: int, epochs: int, batch_size: int, lr: float, device: torch.device):
            """Train a fresh MultiBranchDNN on the ``train_idx`` rows and return it.

            RNGs are seeded with ``42 + seed``. The loss is element-wise BCE-with-logits,
            scaled per term by the IA weights ``w_t`` and averaged. When an
            ``ia_weighted_f1`` function exists in the notebook namespace, the ``val_idx``
            rows are scored after every epoch at the mean of the BP/MF/CC thresholds and
            the result is printed; this score does not influence training.
            """
            rng_seed = 42 + seed
            torch.manual_seed(rng_seed)
            np.random.seed(rng_seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

            net = MultiBranchDNN(dims_by_key=dims, out_dim=out_dim, p=0.2).to(device)
            optimiser = torch.optim.AdamW(net.parameters(), lr=lr)
            term_weights = w_t.to(device)

            shared_kwargs = dict(batch_size=batch_size, drop_last=False, num_workers=0, collate_fn=_collate)
            train_loader = DataLoader(MultiModalDataset(dnn_train, Y, dnn_keys, train_idx), shuffle=True, **shared_kwargs)
            val_loader = DataLoader(MultiModalDataset(dnn_train, Y, dnn_keys, val_idx), shuffle=False, **shared_kwargs)

            def _on_device(tensors):
                # Move one batch of modality tensors to the target device as float32.
                return [t.to(device, non_blocking=True).float() for t in tensors]

            for ep in range(1, epochs + 1):
                net.train()
                for feats, targets in train_loader:
                    feats = _on_device(feats)
                    targets = targets.to(device, non_blocking=True).float()
                    # IA-weighted BCE: class weights applied per term (broadcast over batch)
                    per_term = F.binary_cross_entropy_with_logits(net(feats), targets, reduction='none')
                    batch_loss = (per_term * term_weights).mean()
                    optimiser.zero_grad(set_to_none=True)
                    batch_loss.backward()
                    optimiser.step()

                # RANK-1 UPDATE #2: Aspect-specific threshold validation (only if the metric is available)
                if 'ia_weighted_f1' not in globals():
                    continue

                net.eval()
                with torch.no_grad():
                    prob_parts, truth_parts = [], []
                    for feats, targets in val_loader:
                        probs = torch.sigmoid(net(_on_device(feats)))
                        prob_parts.append(probs.cpu().numpy())
                        truth_parts.append(targets.numpy())

                    # Use aspect-specific thresholds (not hardcoded 0.3):
                    # the metric is evaluated at the average of the BP/MF/CC thresholds.
                    aspect_thrs = [ASPECT_THRESHOLDS.get(tag, dflt) for tag, dflt in (('BP', 0.25), ('MF', 0.35), ('CC', 0.35))]
                    avg_thr = np.mean(aspect_thrs)
                    s = ia_weighted_f1(np.vstack(truth_parts), np.vstack(prob_parts), thr=float(avg_thr))
                print(f'  seed={seed} ep={ep}/{epochs} IA-F1={s} (thr={avg_thr:.2f})')

            return net

        # RANK-1 UPDATE #1: GPU Fast Path for predict_on_split
        def predict_on_split(model: nn.Module, X_dict, idx, batch_size: int, device: torch.device):
            """Return sigmoid probabilities for the rows ``idx`` of ``X_dict`` as a float32 array.

            Batches are written into one pre-allocated ``(len(idx), out_dim)`` tensor on
            ``device`` and copied to host memory once at the end. The model is switched
            to eval mode and gradients are disabled.
            """
            loader = DataLoader(
                MultiModalDataset(X_dict, None, dnn_keys, idx),
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                collate_fn=_collate,
            )

            # Pre-allocated output buffer on the target device
            out_buf = torch.zeros((len(idx), out_dim), dtype=torch.float32, device=device)

            model.eval()
            row = 0
            with torch.no_grad():
                for feats in loader:
                    feats = [t.to(device, non_blocking=True).float() for t in feats]
                    probs = torch.sigmoid(model(feats))
                    stop = row + probs.shape[0]
                    out_buf[row:stop] = probs  # stays on device, no per-batch host copy
                    row = stop

            # Single transfer to host at the end
            return out_buf.cpu().numpy().astype(np.float32)
        
        # Train on GPU when one is visible, otherwise fall back to CPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print('DNN device:', device)

        # Ensemble size and training hyperparameters (n_splits x n_seeds models are trained in total).
        n_splits = 5
        n_seeds = 5
        epochs = 10
        batch_size = 128
        lr = 1e-3

        # Initialize accumulators
        # oof_pred_dnn / test_pred_dnn hold running sums until they are normalised after the loops;
        # counts tracks how many models contributed to each OOF row.
        train_n = int(Y.shape[0])
        test_n = int(dnn_test[dnn_keys[0]].shape[0])
        oof_pred_dnn = np.zeros((train_n, out_dim), dtype=np.float32)
        test_pred_dnn = np.zeros((test_n, out_dim), dtype=np.float32)
        counts = np.zeros((train_n, 1), dtype=np.float32)

        # RANK-1: Extreme Ensembling 5×5 (Auditor-approved v2)
        # Outer loop over folds, inner loop over seeds; the per-model cleanup at the end of the
        # inner loop is intended to prevent cuBLAS handle accumulation.
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        for fold, (tr_idx, va_idx) in enumerate(kf.split(np.arange(train_n)), start=1):
            print(f'[DNN] Fold {fold}/{n_splits}')
            # Sum of this fold's test predictions over seeds.
            fold_test = np.zeros((test_n, out_dim), dtype=np.float32)

            for seed in range(n_seeds):
                print(f'  [Seed {seed+1}/{n_seeds}] Training...')
                model = train_one_seed_fold(tr_idx, va_idx, seed=seed, epochs=epochs, batch_size=batch_size, lr=lr, device=device)

                # OOF preds
                preds_va = predict_on_split(model, dnn_train, va_idx, batch_size=1024, device=device)

                # RANK-1 UPDATE #3: Quality Gate - Check for finite values
                if not np.isfinite(preds_va).all():
                    raise RuntimeError(f"DNN produced non-finite predictions at Fold {fold}, Seed {seed}")

                oof_pred_dnn[va_idx] += preds_va
                counts[va_idx] += 1.0

                # TEST preds
                te_idx = np.arange(test_n, dtype=np.int64)
                preds_te = predict_on_split(model, dnn_test, te_idx, batch_size=1024, device=device)

                # RANK-1 UPDATE #3: Quality Gate - Check for finite values
                if not np.isfinite(preds_te).all():
                    raise RuntimeError(f"DNN produced non-finite test predictions at Fold {fold}, Seed {seed}")

                fold_test += preds_te

                # RANK-1: Aggressive VRAM cleanup (prevents A100 handle exhaustion across 25 models)
                del model
                gc.collect()
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

            # Average over seeds within the fold; the division by n_splits happens after the loop.
            test_pred_dnn += (fold_test / float(n_seeds))
            print(f'  Fold {fold} complete: {n_seeds} seeds averaged')

        # OOF rows are divided by their contribution count (floored at 1 to avoid division by zero);
        # test predictions are the mean over folds of the per-fold seed averages.
        oof_pred_dnn = (oof_pred_dnn / np.maximum(counts, 1.0)).astype(np.float32)
        test_pred_dnn = (test_pred_dnn / float(n_splits)).astype(np.float32)

        # Contract guardrail: MUST remain (n_train, 13500) and (n_test, 13500)
        if int(oof_pred_dnn.shape[1]) != 13500 or int(test_pred_dnn.shape[1]) != 13500:
            raise RuntimeError(f'DNN output contract violated: oof={oof_pred_dnn.shape} test={test_pred_dnn.shape}')

        # Persisted files are what the cached-load branch at the top of this cell picks up on a re-run.
        np.save(dnn_oof_path, oof_pred_dnn)
        np.save(dnn_test_path, test_pred_dnn)
        print('Saved:', dnn_oof_path)
        print('Saved:', dnn_test_path)

    # Checkpoint push (always): runs whether the predictions were freshly trained or loaded from disk.
    STORE.maybe_push(
        stage='stage_07c_level1_dnn',
        required_paths=[
            str(artifact.as_posix())
            for artifact in (
                WORK_ROOT / 'features' / 'top_terms_13500.json',
                dnn_oof_path,
                dnn_test_path,
            )
        ],
        note='Level-1 DNN predictions (OOF + test).',
    )

    # Diagnostics: probability histograms + IA-F1 vs threshold (sampled)
    # Everything here is best-effort: any exception is caught at the bottom and only reported.
    try:
        import os
        import matplotlib.pyplot as plt
        plt.rcParams.update({'font.size': 12})
        # Row cap for the diagnostics, read from the CAFA_DIAG_N environment variable.
        # NOTE(review): the calibration cell reads CAFA_DIAG_N from notebook globals instead —
        # confirm which source is intended.
        DIAG_N = int(os.getenv('CAFA_DIAG_N', '20000'))

        def _sub(y_true: np.ndarray, y_score: np.ndarray):
            # Return the same evenly spaced subset of rows (at most DIAG_N) from both arrays;
            # empty slices when DIAG_N <= 0 or there are no rows.
            n = int(y_true.shape[0])
            m = min(n, int(DIAG_N))
            if m <= 0:
                return y_true[:0], y_score[:0]
            idx = np.linspace(0, n - 1, num=m, dtype=np.int64)
            return y_true[idx], y_score[idx]

        # Requires `Y` in the notebook namespace; if it is missing the resulting error is caught below.
        y_t, y_s = _sub(Y, oof_pred_dnn)
        row_max_oof = y_s.max(axis=1)
        row_mean_oof = y_s.mean(axis=1)

        # Histogram 1: largest predicted probability per sampled protein (OOF).
        plt.figure(figsize=(10, 4))
        plt.hist(row_max_oof, bins=60, alpha=0.7)
        plt.title('DNN OOF: max probability per protein (sampled)')
        plt.xlabel('max prob')
        plt.ylabel('count')
        plt.grid(True, alpha=0.3)
        plt.show()

        # Histogram 2: mean predicted probability per sampled protein (OOF).
        plt.figure(figsize=(10, 4))
        plt.hist(row_mean_oof, bins=60, alpha=0.7)
        plt.title('DNN OOF: mean probability per protein (sampled)')
        plt.xlabel('mean prob')
        plt.ylabel('count')
        plt.grid(True, alpha=0.3)
        plt.show()

        # Overlay of OOF vs test max-probability distributions, using the same sampling scheme for test rows.
        if test_pred_dnn is not None:
            te_s = test_pred_dnn
            te_m = min(int(te_s.shape[0]), int(DIAG_N))
            te_idx = np.linspace(0, int(te_s.shape[0]) - 1, num=te_m, dtype=np.int64) if te_m > 0 else np.array([], dtype=np.int64)
            row_max_te = te_s[te_idx].max(axis=1) if te_m > 0 else np.array([], dtype=np.float32)

            plt.figure(figsize=(10, 4))
            plt.hist(row_max_oof, bins=60, alpha=0.5, label='OOF')
            plt.hist(row_max_te, bins=60, alpha=0.5, label='test')
            plt.title('DNN: max probability per protein (OOF vs test; sampled)')
            plt.xlabel('max prob')
            plt.ylabel('count')
            plt.grid(True, alpha=0.3)
            plt.legend()
            plt.show()

        # IA-F1 curve over the threshold grid, only when the metric function is available.
        # The result of ia_weighted_f1 is indexed by 'ALL'/'MF'/'BP'/'CC' below, so it is expected
        # to be a mapping with those keys — TODO confirm against its definition.
        if 'ia_weighted_f1' in globals():
            thrs = np.linspace(0.05, 0.60, 23)
            curves = {k: [] for k in ['ALL', 'MF', 'BP', 'CC']}
            for thr in thrs:
                s = ia_weighted_f1(y_t, y_s, thr=float(thr))
                for k in curves.keys():
                    curves[k].append(s[k])

            plt.figure(figsize=(10, 3))
            for k in ['ALL', 'MF', 'BP', 'CC']:
                plt.plot(thrs, curves[k], label=k)
            plt.title('DNN OOF: IA-weighted F1 vs threshold (sampled)')
            plt.xlabel('threshold')
            plt.ylabel('IA-F1')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.show()

    except Exception as e:
        # Diagnostics must never fail the cell; report the reason and carry on.
        print('DNN diagnostics skipped:', repr(e))
