# 05 CAFA E2E — Run LogReg for MF (wrapper)

Self-contained runner: executes `notebooks/05_cafa_e2e.ipynb` with `TARGET_ASPECT=MF`, stopping before the DNN cell.

Key behaviour:
- No subprocess calls
- Logging-first diagnostics (search paths, stop-marker, failures)
- Produces per-aspect artefacts: `oof_pred_logreg_MF.npy`, `test_pred_logreg_MF.npy`

In [None]:
# CELL 01 - Setup (NO REPO)
import os
import sys
import ctypes
from pathlib import Path

# CUDA loader fix (PyTorch/RAPIDS coexistence): preload the venv's nvjitlink so we don't
# pick the copy under /usr/local/cuda/lib64.
# Best-effort: any failure is printed and the notebook carries on.
try:
    # Assumes the interpreter lives at <venv>/bin/python, so two parents up is the venv root.
    _venv_root = Path(sys.executable).resolve().parent.parent
    _nvjit_dir = (
        _venv_root
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages"
        / "nvidia"
        / "nvjitlink"
        / "lib"
    )
    _nvjit_so = _nvjit_dir / "libnvJitLink.so.12"
    if _nvjit_so.exists():
        # RTLD_GLOBAL exposes the library's symbols to libraries loaded afterwards.
        ctypes.CDLL(str(_nvjit_so), mode=ctypes.RTLD_GLOBAL)
        # Prepend our directory. Only append the previous value when it is non-empty:
        # a trailing ':' would add an empty entry, which the dynamic loader treats as
        # the current working directory.
        _ld_prev = os.environ.get("LD_LIBRARY_PATH", "")
        os.environ["LD_LIBRARY_PATH"] = f"{_nvjit_dir}:{_ld_prev}" if _ld_prev else str(_nvjit_dir)
        print(f"[ENV] Preloaded nvjitlink: {_nvjit_so}")
except Exception as _e:
    print(f"[ENV] nvjitlink preload skipped: {_e}")

# Always run from a simple writable location; never cd into a repo.
# '/content' becomes the working directory whenever it exists; otherwise the current one is kept.
if os.path.exists('/content'):
    os.chdir('/content')
# Every runtime path below derives from the (possibly changed) working directory.
RUNTIME_ROOT = Path.cwd()
DATA_ROOT = (RUNTIME_ROOT / 'cafa6_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)
# Gate for the Level-1 LogReg cell: when False, that cell only prints a skip message.
TRAIN_LEVEL1 = True
print(f'CWD: {Path.cwd()}')
print(f'DATA_ROOT: {DATA_ROOT.resolve()}')

In [None]:
# CELL 13c - Level 1: Logistic Regression — Long tail 13,500 (Aspect Split BP/MF/CC)
# ==============================================================================
# Track B (auditor): regularised linear model over the full 13,500-term space.
# Goal: Train LR per GO aspect (BP/MF/CC) with RAM-safe target handling.
#
# Critical fix: Avoid NumPy fancy-indexing on memmaps.
# - DO NOT do: Y_aspect = Y_full[:, aspect_indices]  (this materialises a huge dense copy in host RAM)
# - Instead: slice Y_full per fold+target-chunk via np.ix_(idx_tr, cols)
#
# Protocol (recommended for stability): run ONE aspect per fresh runtime.
# If `TARGET_ASPECT` is not set (neither as an environment variable nor as a notebook global), we default to MF.
# ==============================================================================

if not TRAIN_LEVEL1:
    print('Skipping LogReg (TRAIN_LEVEL1=False).')
else:
    import os
    import sys
    import time
    import threading
    import gc
    import warnings
    import psutil
    import json
    import numpy as np
    import pandas as pd
    import joblib
    import torch
    from pathlib import Path
    from tqdm.auto import tqdm
    from sklearn.model_selection import KFold
    from sklearn.preprocessing import StandardScaler
    from sklearn.linear_model import SGDClassifier
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.metrics import f1_score
    from sklearn.exceptions import ConvergenceWarning

    warnings.filterwarnings('ignore', category=ConvergenceWarning)

    # FAIL FAST: RAPIDS Requirement
    # The imports are an availability probe only: if either package is missing, the cell
    # aborts with a RuntimeError chained to the original ImportError.
    try:
        import cuml
        import cupy
        print("RAPIDS (cuml, cupy): present")
    except ImportError as e:
        raise RuntimeError("RAPIDS (cuml/cupy) is REQUIRED for LogReg but missing. "
                           "Aborting to avoid slow CPU fallback.") from e

    def _stage(msg: str) -> None:
        """Print ``msg`` and flush stdout so the line appears immediately.

        The flush is best-effort: any error it raises is ignored so that
        progress reporting can never interrupt the run.
        """
        print(msg)
        try:
            sys.stdout.flush()
        except Exception:
            return None

    # -----------------------------
    # WORK_ROOT recovery (safety)
    # -----------------------------
    # Only runs when WORK_ROOT has not been defined yet in this session.
    if 'WORK_ROOT' not in locals() and 'WORK_ROOT' not in globals():
        candidates = [
            Path('/content/cafa6_data'),
            Path('/content/work'),
            Path('/kaggle/working/work'),
            Path.cwd() / 'cafa6_data',
            Path.cwd() / 'artefacts_local' / 'work',
        ]
        # First choice: the first candidate that already holds the parsed training terms.
        WORK_ROOT = next(
            (cand for cand in candidates if (cand / 'parsed' / 'train_terms.parquet').exists()),
            None,
        )
        # Second choice: the first candidate directory that exists at all.
        if WORK_ROOT is None:
            WORK_ROOT = next((cand for cand in candidates if cand.exists()), None)
        # Last resort: a data directory under the current working directory.
        if WORK_ROOT is None:
            WORK_ROOT = Path.cwd() / 'cafa6_data'
        _stage(f"[AUDITOR] Recovered WORK_ROOT: {WORK_ROOT}")

    # Directory layout under WORK_ROOT: feature matrices and parsed inputs.
    FEAT_DIR = Path(WORK_ROOT) / 'features'
    PARSED_DIR = Path(WORK_ROOT) / 'parsed'

    # Level-1 predictions and the other per-aspect artefacts are written here.
    PRED_DIR = FEAT_DIR / 'level1_preds'
    PRED_DIR.mkdir(parents=True, exist_ok=True)

    # -----------------------------
    # Load top_terms + term_to_ns
    # -----------------------------
    # top_terms is an ordered list: later code addresses target columns by position in it.
    top_terms_path = FEAT_DIR / 'top_terms_13500.json'
    if not top_terms_path.exists():
        raise FileNotFoundError(f"Missing {top_terms_path}. Run the Phase 2 setup cell first.")
    top_terms = json.loads(top_terms_path.read_text(encoding='utf-8'))
    # Normalise every entry to str so lookups against term_to_ns / IA maps use one key type.
    top_terms = [str(t) for t in top_terms]

    # Build / reuse term_to_ns mapping
    # term_to_ns maps each GO graph node to its OBO 'namespace' string ('unknown' when absent).
    if 'term_to_ns' in globals():
        # A mapping already present in the notebook globals is reused as-is (no OBO re-parse).
        term_to_ns = globals()['term_to_ns']
    else:
        try:
            import obonet
        except Exception as e:
            raise RuntimeError('obonet is required for aspect split (term namespaces).') from e

        # Candidate OBO locations, checked in order; an explicit PATH_GO_OBO global goes first.
        possible_paths = []
        if 'PATH_GO_OBO' in globals():
            try:
                possible_paths.append(Path(globals()['PATH_GO_OBO']))
            except Exception:
                # A PATH_GO_OBO value that Path() rejects is ignored; the defaults below still apply.
                pass
        possible_paths += [
            Path(WORK_ROOT) / 'Train' / 'go-basic.obo',
            Path(WORK_ROOT) / 'go-basic.obo',
            Path('/content/cafa6_data/Train/go-basic.obo'),
            Path('Train/go-basic.obo'),
            Path('go-basic.obo'),
        ]
        # First existing candidate wins.
        obo_path = None
        for p in possible_paths:
            if p is not None and Path(p).exists():
                obo_path = Path(p)
                break
        if obo_path is None:
            raise FileNotFoundError(f"go-basic.obo not found. Candidates: {[str(p) for p in possible_paths]}")

        _stage(f"[AUDITOR] Loading GO OBO for namespaces: {obo_path}")
        graph = obonet.read_obo(obo_path)
        term_to_ns = {node: data.get('namespace', 'unknown') for node, data in graph.nodes(data=True)}

    # GO namespace name -> two-letter aspect code.
    aspect_map = dict(
        biological_process='BP',
        molecular_function='MF',
        cellular_component='CC',
    )

    def _aspect_of_term(term: str) -> str:
        """Return 'BP', 'MF' or 'CC' for ``term``, or 'UNK' when its namespace is missing or unrecognised."""
        namespace = term_to_ns.get(term)
        return aspect_map.get(namespace, 'UNK')

    # -----------------------------
    # Load X / X_test (memmap)
    # -----------------------------
    # Both feature matrices must already exist on disk; nothing is built here.
    x_path = FEAT_DIR / 'X_train_mmap.npy'
    xt_path = FEAT_DIR / 'X_test_mmap.npy'
    if not x_path.exists() or not xt_path.exists():
        raise FileNotFoundError(f"Missing X memmaps ({x_path} / {xt_path}). Run the Phase 2 setup cell first.")

    # mmap_mode='r' keeps the arrays disk-backed and read-only; rows are read on access.
    X = np.load(x_path, mmap_mode='r')
    X_test = np.load(xt_path, mmap_mode='r')

    # -----------------------------
    # Load / build Y (full 13,500) as memmap
    # -----------------------------
    # Reuse the cached target matrix when present; otherwise build it once from the parsed
    # training terms, save it, and reopen it disk-backed.
    y_full_path = FEAT_DIR / 'Y_target_13500.npy'
    if y_full_path.exists():
        Y_full = np.load(y_full_path, mmap_mode='r')
    else:
        _stage('[AUDITOR] Building Y_target_13500.npy (disk-backed) ...')
        train_terms = pd.read_parquet(PARSED_DIR / 'train_terms.parquet')
        train_ids_raw = pd.read_feather(PARSED_DIR / 'train_seq.feather')['id'].astype(str)
        # Use the text between the first pair of '|' when present; otherwise keep the raw id.
        train_ids = train_ids_raw.str.extract(r"\|(.*?)\|")[0].fillna(train_ids_raw)

        train_terms_top = train_terms[train_terms['term'].isin(top_terms)]
        # aggfunc='size' counts rows per (EntryID, term), so a repeated annotation yields a value > 1.
        Y_df = train_terms_top.pivot_table(index='EntryID', columns='term', aggfunc='size', fill_value=0)
        # Align rows to the training-sequence order and columns to the top_terms order;
        # ids/terms with no annotations become all-zero rows/columns.
        Y_df = Y_df.reindex(train_ids, fill_value=0)
        Y_df = Y_df.reindex(columns=top_terms, fill_value=0)
        # NOTE(review): `.values.astype(...)` materialises the full dense matrix in host RAM before saving.
        np.save(y_full_path, Y_df.values.astype(np.float32))
        # Drop the build-time frames before reopening the result as a read-only memmap.
        del train_terms, train_ids_raw, train_ids, train_terms_top, Y_df
        gc.collect()
        Y_full = np.load(y_full_path, mmap_mode='r')

    # -----------------------------
    # IA weights (optional)
    # -----------------------------
    # weights_full stays None when no IA file is found or when loading fails for any reason.
    weights_full = None
    try:
        ia_candidates = [Path(WORK_ROOT) / 'IA.tsv', FEAT_DIR / 'IA.tsv', Path('IA.tsv')]
        ia_path = next((p for p in ia_candidates if p.exists()), None)
        if ia_path is not None:
            # NOTE(review): read_csv treats the first row as a header by default; if IA.tsv is
            # headerless, that first row is consumed as column names — confirm the file format.
            ia = pd.read_csv(ia_path, sep='\t')
            # Column detection is case-insensitive, with positional fallbacks:
            # term column -> 'term' or the first column;
            # weight column -> 'ia', then 'information_accretion', then the second (or only) column.
            cols = [c.lower() for c in ia.columns]
            term_col = ia.columns[cols.index('term')] if 'term' in cols else ia.columns[0]
            if 'ia' in cols:
                ia_col = ia.columns[cols.index('ia')]
            elif 'information_accretion' in cols:
                ia_col = ia.columns[cols.index('information_accretion')]
            else:
                ia_col = ia.columns[1] if len(ia.columns) > 1 else ia.columns[0]
            ia_map = dict(zip(ia[term_col].astype(str), ia[ia_col].astype(np.float32)))
            # One weight per entry of top_terms, in the same order; terms absent from the file get 1.0.
            weights_full = np.asarray([ia_map.get(t, np.float32(1.0)) for t in top_terms], dtype=np.float32)
            _stage(f"[AUDITOR] Loaded IA weights from {ia_path}")
    except Exception as e:
        # Best-effort: report the problem and continue without weights.
        _stage(f"[AUDITOR] IA weights unavailable (continuing): {e}")
        weights_full = None

    def _log_mem(msg: str = ''):
        """Print one ``[MEM]`` line tagged with ``msg``.

        Reports process RSS, available host RAM and torch's CUDA
        allocated/reserved memory (zeros when CUDA is unavailable), each
        divided by 1024**3. Best-effort: any exception is swallowed.
        """
        unit = 1024**3
        try:
            rss = psutil.Process(os.getpid()).memory_info().rss / unit
            host_free = psutil.virtual_memory().available / unit
            cuda_alloc = cuda_res = 0.0
            if torch.cuda.is_available():
                cuda_alloc = torch.cuda.memory_allocated() / unit
                cuda_res = torch.cuda.memory_reserved() / unit
            print(f"[MEM] {msg:22s} | RSS: {rss:6.2f}GB | RAM_avail: {host_free:6.2f}GB | torch_alloc: {cuda_alloc:5.2f}GB torch_res: {cuda_res:5.2f}GB")
            sys.stdout.flush()
        except Exception:
            pass

    # Placeholder; replaced after RAPIDS discovery.
    def _log_gpu_mem(msg: str = ''):
        """No-op stand-in, used until RAPIDS discovery installs the real GPU memory logger."""
        return None

    # -----------------------------
    # RAPIDS discovery
    # -----------------------------
    # Imports cuML / CuPy / RMM, tries to switch RMM to managed memory, and installs the
    # real `_log_gpu_mem`. If anything in this section raises, HAS_RAPIDS is set to False,
    # `cp` is set to None and the reason is logged.
    try:
        import cuml  # noqa: F401
        from cuml.linear_model import LogisticRegression as cuLogReg
        from cuml.multiclass import OneVsRestClassifier as cuOVR
        import cupy as cp
        import rmm

        # Managed memory is optional: a failure here is reported but does not disable RAPIDS.
        try:
            rmm.reinitialize(managed_memory=True)
            _stage('[AUDITOR] RAPIDS (cuML) detected. RMM Managed Memory ENABLED.')
        except Exception as e:
            _stage(f'[AUDITOR] RAPIDS detected but RMM init failed ({e}); proceeding with default memory.')

        HAS_RAPIDS = True
        _stage(
            f"[AUDITOR] versions: cupy={getattr(cp, '__version__', '?')} cuml={getattr(cuml, '__version__', '?')} rmm={getattr(rmm, '__version__', '?')}"
        )

        def _maybe_bytes(obj, method_name: str):
            """Call ``obj.<method_name>()`` and return it as int; None if the method is absent or fails."""
            fn = getattr(obj, method_name, None)
            if callable(fn):
                try:
                    return int(fn())
                except Exception:
                    return None
            return None

        def _pinned_used_bytes(pinned_pool):
            """Return the pinned pool's used bytes, derived from total - free when needed; None if unknown."""
            used = _maybe_bytes(pinned_pool, 'used_bytes')
            if used is not None:
                return used
            # Older CuPy versions sometimes expose total/free but not used.
            total = _maybe_bytes(pinned_pool, 'total_bytes')
            free = _maybe_bytes(pinned_pool, 'free_bytes')
            if total is not None and free is not None:
                return int(total - free)
            return None

        # Mutable flag so the RMM resource line is printed only on the first GPU log call.
        _GPU_LOG_STATE = {'rmm_printed': False}

        def _log_gpu_mem(msg: str = ''):
            """Print one ``[GPU]`` line: device free/total memory plus CuPy pool and pinned-pool usage.

            On the first successful call an extra line names the current RMM device
            resource. Never raises: any failure is reported as an "unavailable" line.
            """
            try:
                free_b, total_b = cp.cuda.runtime.memGetInfo()
                pool = cp.get_default_memory_pool()
                pinned = cp.get_default_pinned_memory_pool()

                pool_b = _maybe_bytes(pool, 'used_bytes')
                pinned_b = _pinned_used_bytes(pinned)

                pool_txt = f"{pool_b/1e9:6.2f}GB" if pool_b is not None else "   n/a"
                pinned_txt = f"{pinned_b/1e9:6.2f}GB" if pinned_b is not None else "   n/a"

                print(
                    f"[GPU] {msg:22s} | free {free_b/1e9:6.2f}GB / {total_b/1e9:6.2f}GB | "
                    f"cupy_pool {pool_txt} | pinned {pinned_txt}"
                )
                if not _GPU_LOG_STATE.get('rmm_printed', False):
                    try:
                        res = rmm.mr.get_current_device_resource()
                        res_name = type(res).__name__
                        # Keep this single-line and stable (no memory addresses).
                        print(f"[GPU] rmm: enabled=True resource={res_name}")
                    except Exception:
                        print('[GPU] rmm: enabled=True (resource unavailable)')
                    _GPU_LOG_STATE['rmm_printed'] = True
                sys.stdout.flush()
            except Exception as e:
                print(f"[GPU] {msg:22s} | unavailable ({e})")
                sys.stdout.flush()

    except Exception as e:
        HAS_RAPIDS = False
        cp = None  # type: ignore
        # Include the actual failure: the fail-fast probe above already proved cuml/cupy import,
        # so reaching this branch usually means something else went wrong (e.g. the rmm import
        # or a cuml submodule), and the reason is needed to diagnose it.
        _stage(f'[AUDITOR] RAPIDS NOT detected. Falling back to CPU (slow). Reason: {e!r}')

    def _gpu_mem_okay(n_rows: int, n_cols: int, safety: float = 1.35) -> bool:
        """Return True when free device memory exceeds ``safety`` x the size of an (n_rows, n_cols) matrix at 4 bytes per element.

        Returns False straight away when RAPIDS was not discovered or torch
        reports no CUDA device. If the memory query itself (or the logging)
        raises, the check is treated as passed and True is returned.
        """
        gpu_usable = HAS_RAPIDS and torch.cuda.is_available()
        if not gpu_usable:
            return False
        try:
            free_b, _total_b = cp.cuda.runtime.memGetInfo()
            need_b = int(n_rows) * int(n_cols) * 4
            if free_b > int(safety * need_b):
                return True
            _stage(f"[AUDITOR] VRAM insufficient: free={free_b/1e9:.1f}GB need~{need_b/1e9:.1f}GB")
            return False
        except Exception:
            return True

    # -----------------------------
    # Aspect selection (default MF)
    # -----------------------------
    # Resolution order: TARGET_ASPECT environment variable, then a TARGET_ASPECT notebook
    # global, then the hard-coded default 'MF'. Values are stripped and upper-cased.
    target_aspect = (os.environ.get('TARGET_ASPECT') or '').strip().upper()
    if not target_aspect:
        target_aspect = (globals().get('TARGET_ASPECT') or '').strip().upper()
    if not target_aspect:
        target_aspect = 'MF'
        _stage('[AUDITOR] TARGET_ASPECT not set -> defaulting to MF (recommended: run one aspect per fresh runtime).')

    # Reject anything other than the three aspect codes before any training work starts.
    if target_aspect not in {'BP', 'MF', 'CC'}:
        raise RuntimeError(
            f"Invalid TARGET_ASPECT={target_aspect!r}. Must be one of: BP, MF, CC."
        )

    # The training loop iterates over `aspects`; here it always holds exactly one entry.
    aspects = [target_aspect]
    _stage(f"[AUDITOR] Training only aspect: {target_aspect}")

    # Runtime knobs (keep quality the same; only batching/chunking changes)
    # - TARGET_CHUNK: larger chunk reduces per-chunk overhead and improves GPU utilisation.
    # - VAL_BS/TEST_BS: larger batches better saturate A100; reduce if you hit OOM.
    TARGET_CHUNK = 500  # number of target columns fitted per one-vs-rest model
    VAL_BS = 1024  # rows per validation prediction batch
    TEST_BS = 8192  # rows per test scaling / prediction batch

    # Fixed seed: fold membership is identical on every run, which the fold-indexed
    # done-markers and saved scalers rely on when a run is resumed.
    n_splits = 5
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

    tmp_dir = FEAT_DIR / 'tmp_folds_logreg'
    tmp_dir.mkdir(parents=True, exist_ok=True)

    # Per-aspect results, keyed by aspect code ('BP' / 'MF' / 'CC').
    oof_pred_logreg_by_aspect = {}
    test_pred_logreg_by_aspect = {}
    aspect_indices_map = {}

    # Wall-clock start and baseline memory snapshots.
    t0_all = time.time()
    _log_mem('start')
    _log_gpu_mem('start')

    # -----------------------------
    # Performance: materialise X_test into RAM (best-effort)
    # -----------------------------
    # X_test is used repeatedly across folds and target-chunks; keeping it disk-backed can become the bottleneck.
    # Only MemoryError triggers the fallback to the memmap; any other exception propagates.
    try:
        _stage('[AUDITOR] LogReg: materialising X_test into RAM for faster inference...')
        X_test_ram = np.ascontiguousarray(np.asarray(X_test, dtype=np.float32))
        _stage(f'[AUDITOR] LogReg: X_test_ram shape={X_test_ram.shape} dtype={X_test_ram.dtype}')
    except MemoryError:
        X_test_ram = X_test
        _stage('[AUDITOR] WARNING: could not materialise X_test (MemoryError); using memmap (slower).')

    # -----------------------------
    # Fold streaming helpers
    # -----------------------------
    def _fmt_seconds(sec: float) -> str:
        """Render a duration in seconds as ``'Ns'``, ``'MmSSs'`` or ``'HhMMm'``.

        The value is rounded half-up to whole seconds. Anything that cannot
        be converted to float, is negative, or is not finite yields ``'?'``.
        """
        try:
            value = float(sec)
        except Exception:
            return '?'
        if value < 0 or not np.isfinite(value):
            return '?'
        hours, remainder = divmod(int(value + 0.5), 3600)
        minutes, seconds = divmod(remainder, 60)
        if hours > 0:
            return f'{hours}h{minutes:02d}m'
        if minutes > 0:
            return f'{minutes}m{seconds:02d}s'
        return f'{seconds}s'

    def _start_heartbeat(tag: str, every_s: float = 60.0):
        """Start a daemon thread that logs a ``[HEARTBEAT]`` line every ``every_s`` seconds.

        Useful while a long, silent operation runs (e.g. a cuML fit).
        Returns ``(stop_event, thread)``; set the event to end the heartbeat.
        """
        stop_event = threading.Event()
        started_at = time.time()

        def _beat():
            # Event.wait() returns False on timeout (keep beating) and True once stop is requested.
            while True:
                if stop_event.wait(float(every_s)):
                    break
                _stage(f"[HEARTBEAT] {tag} (elapsed={_fmt_seconds(time.time() - started_at)})")

        worker = threading.Thread(target=_beat, daemon=True)
        worker.start()
        return stop_event, worker

    def _iter_indexed_rows(src_mm, idx: np.ndarray, bs: int, desc: str):
        """Yield ``(i, j, xb)`` for consecutive batches of at most ``bs`` entries of ``idx``.

        ``i``/``j`` are the start/stop positions within ``idx`` and ``xb`` is
        ``src_mm[idx[i:j]]`` as a float32 numpy array. Progress is shown with
        a transient tqdm bar labelled ``desc``.
        """
        total = int(len(idx))
        step = int(bs)
        for lo in tqdm(range(0, total, step), desc=desc, unit='batch', leave=False):
            hi = min(lo + step, total)
            batch = np.asarray(src_mm[idx[lo:hi]], dtype=np.float32)
            yield lo, hi, batch

    def _fit_scaler_from_indexed_rows(src_mm, idx: np.ndarray, bs: int, desc: str):
        """Fit a StandardScaler incrementally over ``src_mm[idx]``, read in batches of ``bs`` rows.

        Returns a dict holding the fitted ``'mean'`` and ``'scale'`` vectors as float32 arrays.
        """
        running = StandardScaler(with_mean=True, with_std=True)
        for batch in _iter_indexed_rows(src_mm, idx, bs=bs, desc=desc):
            # Each item is (start, stop, rows); only the rows feed the scaler.
            running.partial_fit(batch[2])
        return {
            'mean': running.mean_.astype(np.float32),
            'scale': running.scale_.astype(np.float32),
        }

    def _alloc_and_fill_gpu_from_indexed_rows(src_mm, idx: np.ndarray, mean: np.ndarray, scale: np.ndarray, bs: int, desc: str):
        """Build the standardised matrix for ``src_mm[idx]`` directly on the GPU.

        Rows are streamed in batches of ``bs``, standardised on the device as
        ``(x - mean) / (scale + 1e-12)`` and written into a preallocated
        float32 CuPy array. Returns ``(out, mean_gpu, scale_gpu)`` so callers
        can reuse the device-side statistics.
        """
        shape = (int(len(idx)), int(src_mm.shape[1]))
        out = cp.empty(shape, dtype=cp.float32)
        mean_gpu, scale_gpu = cp.asarray(mean), cp.asarray(scale)
        for lo, hi, host_batch in _iter_indexed_rows(src_mm, idx, bs=bs, desc=desc):
            out[lo:hi, :] = (cp.asarray(host_batch) - mean_gpu) / (scale_gpu + 1e-12)
            del host_batch
            # Return cached blocks to the device/pinned pools after every batch (best-effort).
            try:
                cp.get_default_memory_pool().free_all_blocks()
                cp.get_default_pinned_memory_pool().free_all_blocks()
            except Exception:
                pass
        return out, mean_gpu, scale_gpu

    def _to_numpy(a):
        """Convert an array-like (host or device) to a numpy array.

        Dispatch order:
        1. ``a.to_numpy()`` for dataframe/series-style objects;
        2. ``a.get()`` for device arrays exposing a zero-argument ``get``
           (e.g. CuPy ndarrays);
        3. ``np.asarray(a)`` for everything else.

        ``to_numpy`` is checked first on purpose: pandas-style containers also
        have a dict-like ``get(key)`` method, so testing ``get`` first would
        call ``a.get()`` with no key and raise TypeError.
        """
        if hasattr(a, 'to_numpy'):
            return a.to_numpy()
        if hasattr(a, 'get'):
            return a.get()
        return np.asarray(a)

    def _sigmoid_np(z):
        """Logistic sigmoid of ``z``, computed on a float32 copy clipped to [-50, 50] so ``exp`` cannot overflow."""
        clipped = np.clip(np.asarray(z, dtype=np.float32), -50.0, 50.0)
        return 1.0 / (np.exp(-clipped) + 1.0)

    def _ensure_2d(p):
        """Return ``p`` as a numpy array, turning a 1-D input into a single column of shape (n, 1).

        Inputs of any other rank are returned unchanged apart from the array conversion.
        """
        arr = np.asarray(p)
        return arr.reshape(-1, 1) if arr.ndim == 1 else arr

    def _predict_proba_like(clf, x, expect_cols: int | None = None):
        """Return float32, 2-D probability-like scores for ``x`` from ``clf``.

        The first available method is used, in this order:
        ``predict_proba``; ``decision_function`` passed through a sigmoid;
        hard ``predict`` output as a last resort. When exactly one target
        column is expected and ``predict_proba`` yields two columns, only the
        second column (class 1) is kept.
        """
        if hasattr(clf, 'predict_proba'):
            probs = _ensure_2d(_to_numpy(clf.predict_proba(x)))
            # Binary single-label output can come back as (n, 2); keep P(class=1) only.
            single_target = expect_cols == 1
            if single_target and probs.shape[1] == 2:
                probs = probs[:, 1:2]
        elif hasattr(clf, 'decision_function'):
            # RAPIDS/cuML OVR may not implement predict_proba; squash the raw scores instead.
            scores = _ensure_2d(_to_numpy(clf.decision_function(x)))
            probs = _sigmoid_np(scores)
        else:
            # Last resort: hard predictions (0/1).
            probs = _ensure_2d(_to_numpy(clf.predict(x)))
        return np.asarray(probs, dtype=np.float32)

    def _predict_proba_gpu_batched(clf, src_mm, idx: np.ndarray, mean_gpu, scale_gpu, bs: int, out_np: np.ndarray, out_rows: np.ndarray, col_slice: slice, desc: str):
        """Predict in batches on the GPU and write the results into ``out_np`` in place.

        For each batch of ``src_mm[idx]`` the rows are copied to the device,
        standardised as ``(x - mean_gpu) / (scale_gpu + 1e-12)``, scored with
        ``_predict_proba_like`` and stored at ``out_np[out_rows[i:j], col_slice]``.
        Nothing is returned.
        """
        # Expected number of output columns, known only when both slice bounds are explicit.
        expect_cols = None
        try:
            bounds_known = (
                (col_slice is not None)
                and (col_slice.start is not None)
                and (col_slice.stop is not None)
            )
            if bounds_known:
                expect_cols = int(col_slice.stop - col_slice.start)
        except Exception:
            expect_cols = None
        for lo, hi, host_batch in _iter_indexed_rows(src_mm, idx, bs=bs, desc=desc):
            scaled = (cp.asarray(host_batch) - mean_gpu) / (scale_gpu + 1e-12)
            out_np[out_rows[lo:hi], col_slice] = _predict_proba_like(clf, scaled, expect_cols=expect_cols)
            del host_batch, scaled
            # Return cached blocks to the device/pinned pools after every batch (best-effort).
            try:
                cp.get_default_memory_pool().free_all_blocks()
                cp.get_default_pinned_memory_pool().free_all_blocks()
            except Exception:
                pass

    for aspect in aspects:
        _stage(f"\n=== LogReg Aspect: {aspect} ===")

        aspect_indices = [i for i, t in enumerate(top_terms) if _aspect_of_term(t) == aspect]
        if not aspect_indices:
            raise RuntimeError(f"No terms found for aspect {aspect} within top_terms_13500.json")

        aspect_indices = np.asarray(aspect_indices, dtype=np.int64)
        aspect_indices_map[aspect] = aspect_indices
        n_targets = int(aspect_indices.shape[0])
        aspect_terms = [top_terms[i] for i in aspect_indices.tolist()]

        # Slice IA weights for diagnostics only
        weights_aspect = weights_full[aspect_indices] if weights_full is not None else None

        # Persist per-aspect term contract
        (PRED_DIR / f'top_terms_{aspect}.json').write_text(json.dumps(aspect_terms), encoding='utf-8')

        # Per-aspect artefacts
        lr_oof_path = PRED_DIR / f'oof_pred_logreg_{aspect}.npy'
        lr_test_path = PRED_DIR / f'test_pred_logreg_{aspect}.npy'
        marker_dir = PRED_DIR / f'fold_markers_logreg_{aspect}'
        marker_dir.mkdir(parents=True, exist_ok=True)

        # Done check
        if lr_oof_path.exists() and lr_test_path.exists() and (len(list(marker_dir.glob('fold_*_done.flag'))) == n_splits):
            _stage(f"[AUDITOR] {aspect}: all folds completed; loading from disk")
            oof_pred_logreg_by_aspect[aspect] = np.load(lr_oof_path, mmap_mode='r')
            test_pred_logreg_by_aspect[aspect] = np.load(lr_test_path, mmap_mode='r')
            continue

        # Prepare memmaps for this aspect
        mode_oof = 'r+' if lr_oof_path.exists() else 'w+'
        mode_test = 'r+' if lr_test_path.exists() else 'w+'

        oof_pred = np.lib.format.open_memmap(str(lr_oof_path), mode=mode_oof, dtype=np.float32, shape=(X.shape[0], n_targets))
        test_pred = np.lib.format.open_memmap(str(lr_test_path), mode=mode_test, dtype=np.float32, shape=(X_test.shape[0], n_targets))
        if mode_oof == 'w+':
            oof_pred[:] = 0.0
            oof_pred.flush()
        if mode_test == 'w+':
            test_pred[:] = 0.0
            test_pred.flush()

        fold_iter = tqdm(kf.split(np.arange(X.shape[0])), total=kf.get_n_splits(), desc=f'LogReg {aspect} folds', unit='fold')
        fold_wall_s = []
        already_done = int(len(list(marker_dir.glob('fold_*_done.flag'))))

        for fold, (idx_tr, idx_val) in enumerate(fold_iter):
            marker_path = marker_dir / f'fold_{fold}_done.flag'
            if marker_path.exists():
                _stage(f"{aspect} Fold {fold+1}/{n_splits} already completed (marker found). Skipping.")
                continue

            t0_fold = time.time()
            _stage(f"{aspect} Fold {fold+1}/{n_splits}")
            _log_mem(f"{aspect} fold {fold+1} start")
            _log_gpu_mem(f"{aspect} fold {fold+1} start")

            idx_tr = np.asarray(idx_tr, dtype=np.int64)
            idx_val = np.asarray(idx_val, dtype=np.int64)

            # Performance: materialise fold validation rows into RAM once.
            # This avoids re-reading X (disk-backed) for every target-chunk during predict_proba.
            try:
                X_val_ram = np.ascontiguousarray(np.asarray(X[idx_val], dtype=np.float32))
                val_src = X_val_ram
                val_rows = np.arange(int(len(idx_val)), dtype=np.int64)
                _stage(f'[AUDITOR] {aspect} Fold {fold+1}: X_val_ram in RAM: {X_val_ram.shape}')
            except MemoryError:
                val_src = X
                val_rows = idx_val
                _stage(f'[AUDITOR] {aspect} Fold {fold+1}: WARNING: could not materialise X_val (MemoryError); using memmap (slower).')

            # Fit scaler by streaming fold rows from the global memmap (no fold copies).
            scaler_path = PRED_DIR / f'logreg_scaler_{aspect}_fold{fold}.pkl'
            if scaler_path.exists():
                _stage(f"[AUDIT] {aspect} Fold {fold+1}: Found existing scaler. Loading.")
                scaler_state = joblib.load(scaler_path)
            else:
                t0 = time.time()
                scaler_state = _fit_scaler_from_indexed_rows(X, idx_tr, bs=2048, desc=f'{aspect} Fold {fold+1} Fit Scaler')
                joblib.dump(scaler_state, scaler_path)
                _stage(f"[AUDIT] {aspect} Fold {fold+1}: Scaler fit in {time.time() - t0:.1f}s")

            mean = scaler_state['mean']
            scale = scaler_state['scale']

            # Performance: pre-scale X_test ONCE per fold (math-equivalent; avoids repeating same scaling per target-chunk).
            test_src = X_test_ram
            test_prescaled = False
            try:
                _stage(f'[AUDITOR] {aspect} Fold {fold+1}: pre-scaling X_test for this fold...')
                X_test_scaled = np.empty((int(X_test_ram.shape[0]), int(X_test_ram.shape[1])), dtype=np.float32)
                for b0 in tqdm(
                    range(0, int(X_test_ram.shape[0]), int(TEST_BS)),
                    desc=f'{aspect} Fold {fold+1} pre-scale X_test',
                    unit='batch',
                    leave=False,
                ):
                    b1 = min(b0 + int(TEST_BS), int(X_test_ram.shape[0]))
                    xb = np.asarray(X_test_ram[b0:b1], dtype=np.float32)
                    xb = (xb - mean) / (scale + 1e-12)
                    X_test_scaled[b0:b1, :] = xb
                    del xb
                test_src = X_test_scaled
                test_prescaled = True
                _stage(f'[AUDITOR] {aspect} Fold {fold+1}: X_test pre-scaled: shape={X_test_scaled.shape} dtype={X_test_scaled.dtype}')
            except MemoryError:
                test_src = X_test_ram
                test_prescaled = False
                _stage(f'[AUDITOR] {aspect} Fold {fold+1}: WARNING: X_test pre-scale skipped (MemoryError); scaling per batch.')
            except Exception as e:
                test_src = X_test_ram
                test_prescaled = False
                _stage(f'[AUDITOR] {aspect} Fold {fold+1}: WARNING: X_test pre-scale failed; scaling per batch: {e!r}')
            _log_mem(f"{aspect} fold {fold+1} post scaler")
            _log_gpu_mem(f"{aspect} fold {fold+1} post scaler")

            # GPU selection: only if we can afford the full training fold on GPU.
            use_gpu = bool(HAS_RAPIDS and _gpu_mem_okay(int(len(idx_tr)), int(X.shape[1])))
            gpu_success = False

            if use_gpu:
                try:
                    _stage(f"[AUDITOR] {aspect} Fold {fold+1}: Using RAPIDS/cuML")

                    _stage(f"[AUDIT] {aspect} Fold {fold+1}: PRE X_tr_gpu build")
                    _log_mem(f"{aspect} pre X_tr_gpu")
                    _log_gpu_mem(f"{aspect} pre X_tr_gpu")

                    # Build X_tr on GPU by streaming scaled batches (avoid full host materialisation).
                    X_tr_gpu, mean_gpu, scale_gpu = _alloc_and_fill_gpu_from_indexed_rows(
                        X,
                        idx_tr,
                        mean=mean,
                        scale=scale,
                        bs=1024,
                        desc=f'{aspect} Fold {fold+1} ->GPU X_tr',
                    )

                    _stage(f"[AUDIT] {aspect} Fold {fold+1}: POST X_tr_gpu build")
                    _log_mem(f"{aspect} post X_tr_gpu")
                    _log_gpu_mem(f"{aspect} post X_tr_gpu")

                    # Performance: keep full X_test on GPU for the entire fold when possible.
                    # This removes repeated host->device transfers for every target-chunk.
                    X_test_gpu = None
                    try:
                        _stage(f"[AUDITOR] {aspect} Fold {fold+1}: attempting to keep X_test on GPU (fold-scoped)")
                        _log_gpu_mem(f"{aspect} pre X_test_gpu")
                        n_te = int(test_src.shape[0])
                        n_cols = int(test_src.shape[1])
                        X_test_gpu = cp.empty((n_te, n_cols), dtype=cp.float32)
                        for b0 in tqdm(
                            range(0, n_te, int(TEST_BS)),
                            desc=f'{aspect} Fold {fold+1} ->GPU X_test',
                            unit='batch',
                            leave=False,
                        ):
                            b1 = min(b0 + int(TEST_BS), n_te)
                            xb = np.asarray(test_src[b0:b1], dtype=np.float32)
                            xg = cp.asarray(xb)
                            if not test_prescaled:
                                xg = (xg - mean_gpu) / (scale_gpu + 1e-12)
                            X_test_gpu[b0:b1, :] = xg
                            del xb, xg
                        _stage(f"[AUDITOR] {aspect} Fold {fold+1}: X_test_gpu ready: shape={tuple(X_test_gpu.shape)} dtype={X_test_gpu.dtype}")
                        _log_gpu_mem(f"{aspect} post X_test_gpu")
                    except Exception as e:
                        X_test_gpu = None
                        _stage(f"[AUDITOR] {aspect} Fold {fold+1}: WARNING: X_test not kept on GPU; streaming per chunk. Reason: {e!r}")
                        _log_gpu_mem(f"{aspect} after X_test_gpu fail")

                    n_chunks = int((n_targets + TARGET_CHUNK - 1) // TARGET_CHUNK)
                    chunk_total_s = []
                    chunk_fit_s = []
                    chunk_val_s = []
                    chunk_test_s = []

                    pbar = tqdm(
                        range(0, n_targets, TARGET_CHUNK),
                        total=n_chunks,
                        desc=f'{aspect} Fold {fold+1} target chunks',
                        unit='chunk',
                        leave=False,
                    )

                    # GPU path: fit and predict one block of up to TARGET_CHUNK target columns per iteration.
                    # `start`/`end` are positions inside this aspect's column list (aspect_indices), and the
                    # same positions are used as the column range of oof_pred / test_pred.
                    for start in pbar:
                        t0_chunk = time.time()
                        end = min(start + TARGET_CHUNK, n_targets)

                        # Extra stage + host/GPU memory logging is emitted only around the first chunk.
                        if start == 0:
                            _stage(f"[AUDIT] {aspect} Fold {fold+1} chunk0 PRE Y_tr_chunk")
                            _log_mem(f"{aspect} chunk0 pre Y")
                            _log_gpu_mem(f"{aspect} chunk0 pre Y")

                        # Pull only (training rows x this chunk's columns) from Y_full to host as float32,
                        # then copy that block to the GPU.
                        cols = aspect_indices[start:end]
                        y_host = np.asarray(Y_full[np.ix_(idx_tr, cols)], dtype=np.float32)
                        Y_tr_chunk = cp.asarray(y_host)

                        if start == 0:
                            _stage(f"[AUDIT] {aspect} Fold {fold+1} chunk0 PRE fit")
                            _log_mem(f"{aspect} chunk0 pre fit")
                            _log_gpu_mem(f"{aspect} chunk0 pre fit")

                        # A fresh one-vs-rest L2 logistic regression is built for every chunk.
                        clf_chunk = cuOVR(cuLogReg(solver='qn', penalty='l2', C=1.0, max_iter=2000, tol=1e-3))
                        chunk_i = int(start // TARGET_CHUNK) + 1  # 1-based chunk number for log messages
                        # The heartbeat is started just before fit and always stopped in `finally`, so it
                        # does not outlive a failed fit (presumably a background logging thread — confirm
                        # against _start_heartbeat).
                        hb_stop, hb_thr = _start_heartbeat(f'{aspect} Fold {fold+1} chunk {chunk_i}/{n_chunks} cuML fit', every_s=60.0)
                        t0 = time.time()
                        try:
                            clf_chunk.fit(X_tr_gpu, Y_tr_chunk)
                        finally:
                            hb_stop.set()
                            try:
                                # Bounded join: wait at most 2 s; join errors are ignored on purpose.
                                hb_thr.join(timeout=2.0)
                            except Exception:
                                pass
                        fit_s = time.time() - t0

                        if start == 0:
                            _stage(f"[AUDIT] {aspect} Fold {fold+1} chunk0 POST fit")
                            _log_mem(f"{aspect} chunk0 post fit")
                            _log_gpu_mem(f"{aspect} chunk0 post fit")

                        # Validation probs (batched) -> write directly into oof memmap
                        # (the helper is handed oof_pred, the idx_val rows and the start:end column slice).
                        t0 = time.time()
                        _predict_proba_gpu_batched(
                            clf_chunk,
                            val_src,
                            val_rows,
                            mean_gpu,
                            scale_gpu,
                            bs=VAL_BS,
                            out_np=oof_pred,
                            out_rows=idx_val,
                            col_slice=slice(start, end),
                            desc=f'{aspect} Fold {fold+1} val proba',
                        )
                        val_s = time.time() - t0

                        # Test probs (batched) -> accumulate
                        # Each fold adds its prediction divided by n_splits, so the sum over folds is the mean.
                        # NOTE(review): accumulation is in place (+=). The CPU fallback later in this cell
                        # restarts at column 0 and also uses +=, so if this loop raises after some test
                        # batches were added, those contributions look like they are counted twice for
                        # this fold — confirm and guard (e.g. stage per-fold sums before adding).
                        t0 = time.time()
                        if X_test_gpu is not None:
                            # Fast path: test features already on GPU (no repeated host->device copies).
                            for b0 in range(0, int(test_src.shape[0]), TEST_BS):
                                b1 = min(b0 + TEST_BS, int(test_src.shape[0]))
                                xb_gpu = X_test_gpu[b0:b1, :]
                                p_te = _predict_proba_like(clf_chunk, xb_gpu, expect_cols=int(end - start))
                                test_pred[b0:b1, start:end] += (np.asarray(p_te, dtype=np.float32) / float(n_splits))
                                del xb_gpu, p_te
                        else:
                            # Fallback: stream from host per batch (slower; keeps VRAM lower).
                            for b0 in range(0, int(test_src.shape[0]), TEST_BS):
                                b1 = min(b0 + TEST_BS, int(test_src.shape[0]))
                                xb = np.asarray(test_src[b0:b1], dtype=np.float32)
                                xb_gpu = cp.asarray(xb)
                                # Standardise on the GPU unless the test source is flagged as already scaled;
                                # the 1e-12 term guards against division by a zero scale.
                                if not test_prescaled:
                                    xb_gpu = (xb_gpu - mean_gpu) / (scale_gpu + 1e-12)

                                p_te = _predict_proba_like(clf_chunk, xb_gpu, expect_cols=int(end - start))
                                test_pred[b0:b1, start:end] += (np.asarray(p_te, dtype=np.float32) / float(n_splits))
                                del xb, xb_gpu, p_te
                        test_s = time.time() - t0

                        # Drop per-chunk objects and hand CuPy's pooled device/pinned memory back before
                        # the next chunk; failures while freeing are ignored (best effort).
                        del y_host, Y_tr_chunk, clf_chunk
                        try:
                            cp.get_default_memory_pool().free_all_blocks()
                            cp.get_default_pinned_memory_pool().free_all_blocks()
                        except Exception:
                            pass
                        gc.collect()
                        # Per-chunk timings feed the ETA below and the progress log line.
                        total_s = time.time() - t0_chunk
                        chunk_fit_s.append(float(fit_s))
                        chunk_val_s.append(float(val_s))
                        chunk_test_s.append(float(test_s))
                        chunk_total_s.append(float(total_s))

                        # ETA: update tqdm postfix every chunk (smoothed after >=2 chunks).
                        chunk_i = int(start // TARGET_CHUNK) + 1  # same value as computed before the fit
                        if len(chunk_total_s) >= 2:
                            # Average over at most the last 5 chunk timings.
                            recent = chunk_total_s[-min(5, len(chunk_total_s)):]
                            # Drop the very first chunk from ETA if we have enough signal.
                            if len(chunk_total_s) >= 3:
                                recent = chunk_total_s[-min(5, len(chunk_total_s) - 1):]
                            avg_s = float(np.mean(recent)) if recent else float(chunk_total_s[-1])
                            eta_fold_s = avg_s * float(n_chunks - chunk_i)  # remaining chunks in this fold
                            eta_txt = _fmt_seconds(eta_fold_s)
                            rate_txt = _fmt_seconds(avg_s) + '/chunk'
                        else:
                            eta_txt = '?'
                            rate_txt = '?'
                        try:
                            pbar.set_postfix_str(f'avg~{rate_txt} ETA~{eta_txt}')
                        except Exception:
                            # Progress-bar decoration is cosmetic; never let it abort training.
                            pass

                        # Progress: first few chunks are often slow (warmup/JIT).
                        # Logged for chunks 1-3, every 10th chunk, and the final chunk.
                        if chunk_i <= 3 or (chunk_i % 10 == 0) or (chunk_i == n_chunks):
                            _stage(f"[PROGRESS] {aspect} Fold {fold+1}: chunk {chunk_i}/{n_chunks} fit={_fmt_seconds(fit_s)} val={_fmt_seconds(val_s)} test={_fmt_seconds(test_s)} total={_fmt_seconds(total_s)} avg~{rate_txt} ETA~{eta_txt}")

                    try:
                        del X_test_gpu
                    except Exception:
                        pass
                    del X_tr_gpu, mean_gpu, scale_gpu
                    try:
                        cp.get_default_memory_pool().free_all_blocks()
                        cp.get_default_pinned_memory_pool().free_all_blocks()
                    except Exception:
                        pass
                    gc.collect()
                    gpu_success = True

                except Exception as e:
                    _stage(f"[CRITICAL] {aspect} Fold {fold+1}: GPU Training Failed (likely OOM): {e}")
                    _stage('[AUDITOR] Cleaning up GPU memory and falling back to CPU...')
                    gpu_success = False
                    if HAS_RAPIDS and cp is not None:
                        try:
                            cp.get_default_memory_pool().free_all_blocks()
                            cp.get_default_pinned_memory_pool().free_all_blocks()
                        except Exception:
                            pass
                    gc.collect()

            # CPU path: runs when the GPU is disabled or the GPU attempt for this fold failed.
            # oof_pred is overwritten per chunk; test_pred is accumulated (+=) as prediction / n_splits.
            # NOTE(review): chunks restart at column 0 here. If the GPU loop already added some test
            # batches to test_pred for this fold before failing, they appear to be counted again —
            # confirm and guard upstream.
            if (not use_gpu) or (not gpu_success):
                _stage(f"[AUDITOR] {aspect} Fold {fold+1}: Using CPU SGD (fallback)")

                # CPU fallback: build one scaled X_tr memmap, then fit/predict target chunks.
                X_trs_path = tmp_dir / f'X_tr_scaled_{aspect}_fold{fold}.npy'
                X_trs_tmp_path = X_trs_path.with_name(X_trs_path.name + '.partial')
                X_trs_shape = (int(len(idx_tr)), int(X.shape[1]))

                # A cached file is only reused when it is readable and has the expected shape/dtype.
                X_trs_reusable = False
                if X_trs_path.exists():
                    try:
                        _cached = np.load(X_trs_path, mmap_mode='r')
                        X_trs_reusable = (tuple(_cached.shape) == X_trs_shape) and (_cached.dtype == np.float32)
                        del _cached
                    except (OSError, ValueError, EOFError):
                        X_trs_reusable = False

                if not X_trs_reusable:
                    # open_memmap('w+') creates the file at full size before any row is written, so
                    # build under a temporary name and rename only once every row is in place. An
                    # interrupted run then never leaves a half-filled file under the final name.
                    mm = np.lib.format.open_memmap(str(X_trs_tmp_path), mode='w+', dtype=np.float32, shape=X_trs_shape)
                    for i, j, xb in _iter_indexed_rows(X, idx_tr, bs=2048, desc=f'{aspect} Fold {fold+1} scale X_tr (CPU)'):
                        xb = (xb - mean) / (scale + 1e-12)
                        mm[i:j, :] = xb
                    mm.flush()
                    del mm
                    gc.collect()
                    os.replace(str(X_trs_tmp_path), str(X_trs_path))

                X_trs = np.load(X_trs_path, mmap_mode='r')

                for start in tqdm(
                    range(0, n_targets, TARGET_CHUNK),
                    total=(n_targets + TARGET_CHUNK - 1) // TARGET_CHUNK,
                    desc=f'{aspect} Fold {fold+1} CPU target chunks',
                    unit='chunk',
                    leave=False,
                ):
                    end = min(start + TARGET_CHUNK, n_targets)
                    # Materialise only (training rows x this chunk's columns) of the label matrix.
                    cols = aspect_indices[start:end]
                    Y_tr_chunk = np.asarray(Y_full[np.ix_(idx_tr, cols)], dtype=np.float32)

                    # NOTE: keep n_jobs=1 to avoid loky subprocesses importing torch (can crash on CUDA mismatches).
                    # max_iter=1 with tol=None means exactly one pass over the training data per target.
                    clf_logreg = OneVsRestClassifier(
                        SGDClassifier(loss='log_loss', penalty='l2', alpha=0.0001, max_iter=1, tol=None, n_jobs=1),
                        n_jobs=1,
                    )
                    clf_logreg.fit(X_trs, Y_tr_chunk)

                    # Validation preds (batched, scale on the fly)
                    for i, j, xb in _iter_indexed_rows(val_src, val_rows, bs=VAL_BS, desc=f'{aspect} Fold {fold+1} val predict (CPU)'):
                        xb = (xb - mean) / (scale + 1e-12)
                        pb = _predict_proba_like(clf_logreg, xb, expect_cols=int(end - start)).astype(np.float32)
                        oof_pred[idx_val[i:j], start:end] = pb

                    # Test preds (batched)
                    for b0 in tqdm(range(0, int(test_src.shape[0]), TEST_BS), desc=f'{aspect} Fold {fold+1} test predict (CPU)', unit='batch', leave=False):
                        b1 = min(b0 + TEST_BS, int(test_src.shape[0]))
                        xb = np.asarray(test_src[b0:b1], dtype=np.float32)
                        if not test_prescaled:
                            xb = (xb - mean) / (scale + 1e-12)
                        pb = _predict_proba_like(clf_logreg, xb, expect_cols=int(end - start)).astype(np.float32)
                        test_pred[b0:b1, start:end] += pb / float(n_splits)

                    del Y_tr_chunk, clf_logreg
                    gc.collect()

                # Release the read-only mapping before deleting the backing file.
                del X_trs
                gc.collect()

                # Cleanup temp scaled fold (best effort: a leftover file is logged, not fatal).
                try:
                    os.remove(X_trs_path)
                except OSError as e:
                    _stage(f"[WARN] {aspect} Fold {fold+1}: could not remove {X_trs_path.name}: {e!r}")

            # Fold diagnostics: micro-F1 on a capped slice of this fold's OOF predictions
            # (at most 20,000 validation rows x 2,000 target columns, to avoid huge allocations).
            # Purely informational; any failure is reported as a warning and the fold continues.
            try:
                sample_n = int(min(20000, len(idx_val)))
                sample_k = int(min(2000, n_targets))
                if min(sample_n, sample_k) > 0:
                    diag_rows = idx_val[:sample_n]
                    sample_probs = np.asarray(oof_pred[diag_rows, :sample_k], dtype=np.float32)
                    sample_true = np.asarray(
                        Y_full[np.ix_(diag_rows, aspect_indices[:sample_k])],
                        dtype=np.float32,
                    )

                    # Score 20 evenly spaced thresholds in [0.01, 0.20]. max() returns the first
                    # maximal entry, i.e. the lowest threshold on ties; if no threshold scores
                    # above zero, report 0.0 / 0.0.
                    scored = [
                        (
                            f1_score(sample_true, (sample_probs > thr).astype(np.int8), average='micro'),
                            float(thr),
                        )
                        for thr in np.linspace(0.01, 0.20, 20)
                    ]
                    top_f1, top_thr = max(scored, key=lambda pair: pair[0])
                    best_f1, best_thr = (top_f1, top_thr) if top_f1 > 0.0 else (0.0, 0.0)

                    _stage(f"  >> {aspect} Fold {fold+1} (sample) micro-F1={best_f1:.4f} best_thr={best_thr:.2f} (k={sample_k})")
            except Exception as e:
                _stage('  [WARNING] Diagnostics skipped: ' + repr(e))

            # Persist both prediction memmaps first, then drop the fold marker: the marker must
            # never exist for a fold whose predictions were not flushed.
            oof_pred.flush()
            test_pred.flush()
            marker_path.touch()

            fold_tag = f"{aspect} Fold {fold+1}"
            _stage(f"{fold_tag} completed and flushed.")
            _stage(f"[TIMER] {fold_tag} wall: {time.time() - t0_fold:.1f}s")

            # Fold-level ETA: mean wall time of the last (up to) three folds run in this session,
            # multiplied by the number of folds still to do (folds done earlier count as finished).
            fold_wall_s.append(float(time.time() - t0_fold))
            done_now = int(already_done + len(fold_wall_s))
            left = max(0, int(n_splits - done_now))
            avg_fold = float(np.mean(fold_wall_s[-3:]))
            try:
                fold_iter.set_postfix_str(f'avg~{_fmt_seconds(avg_fold)}/fold ETA~{_fmt_seconds(avg_fold * left)}')
            except Exception:
                # Progress-bar decoration is cosmetic; ignore failures.
                pass

            _log_mem(f"{aspect} fold {fold+1} end")
            _log_gpu_mem(f"{aspect} fold {fold+1} end")

        # Aspect finished: flush once more, drop the writable memmaps, then re-open the saved
        # files as read-only memmaps and register them under this aspect.
        oof_pred.flush()
        test_pred.flush()
        del oof_pred, test_pred
        gc.collect()

        for _registry, _saved_path in (
            (oof_pred_logreg_by_aspect, lr_oof_path),
            (test_pred_logreg_by_aspect, lr_test_path),
        ):
            _registry[aspect] = np.load(_saved_path, mmap_mode='r')
        _stage(f"[AUDITOR] {aspect}: saved {lr_oof_path.name}, {lr_test_path.name}")

    # -----------------------------
    # Backwards compatibility: combined outputs are only assembled once BP+MF+CC exist.
    # -----------------------------
    def _open_combined_memmap(path, shape):
        """Open the combined prediction file at *path* as a writable float32 memmap.

        An existing file is reused only if it is readable and already has the expected
        shape and dtype; otherwise it is recreated and zero-filled.
        """
        shape = tuple(int(s) for s in shape)
        if path.exists():
            try:
                existing = np.lib.format.open_memmap(str(path), mode='r+')
            except (OSError, ValueError, EOFError) as e:
                _stage(f"[WARN] {path.name}: existing file unreadable ({e!r}); recreating")
            else:
                if tuple(existing.shape) == shape and existing.dtype == np.float32:
                    return existing
                _stage(f"[WARN] {path.name}: found {existing.shape}/{existing.dtype}, expected {shape}/float32; recreating")
                del existing
        fresh = np.lib.format.open_memmap(str(path), mode='w+', dtype=np.float32, shape=shape)
        fresh[:] = 0.0
        return fresh

    def _copy_aspect_columns(dst, src, col_idx, bs=4096):
        """Copy *src* into columns *col_idx* of *dst*, *bs* rows at a time.

        Row batching keeps host RAM bounded by one batch instead of loading the whole
        per-aspect prediction matrix at once.

        Raises:
            ValueError: if *src* and *dst* have different row counts.
        """
        n_rows = int(dst.shape[0])
        if int(src.shape[0]) != n_rows:
            raise ValueError(f'row mismatch: src has {int(src.shape[0])} rows, dst has {n_rows}')
        for r0 in range(0, n_rows, bs):
            r1 = min(r0 + bs, n_rows)
            dst[r0:r1, col_idx] = np.asarray(src[r0:r1], dtype=np.float32)

    if {'BP', 'MF', 'CC'}.issubset(oof_pred_logreg_by_aspect.keys()):
        lr_oof_full_path = PRED_DIR / 'oof_pred_logreg.npy'
        lr_test_full_path = PRED_DIR / 'test_pred_logreg.npy'

        oof_full = _open_combined_memmap(lr_oof_full_path, (X.shape[0], len(top_terms)))
        te_full = _open_combined_memmap(lr_test_full_path, (X_test.shape[0], len(top_terms)))

        # Scatter each aspect's columns into their positions in the full term space.
        for asp in ['BP', 'MF', 'CC']:
            idx = aspect_indices_map[asp]
            _copy_aspect_columns(oof_full, oof_pred_logreg_by_aspect[asp], idx)
            _copy_aspect_columns(te_full, test_pred_logreg_by_aspect[asp], idx)

        oof_full.flush()
        te_full.flush()
        del oof_full, te_full
        gc.collect()

        # Re-open read-only so downstream code cannot modify the saved predictions.
        oof_pred_logreg = np.load(lr_oof_full_path, mmap_mode='r')
        test_pred_logreg = np.load(lr_test_full_path, mmap_mode='r')
        _stage(f"[AUDITOR] Combined preds saved: {lr_oof_full_path.name}, {lr_test_full_path.name}")
    else:
        # Not all three aspects are available in this run (e.g. a single-aspect wrapper).
        oof_pred_logreg = None
        test_pred_logreg = None

    # Final checkpoint push: always the top-terms file plus three artefacts per trained aspect;
    # the combined OOF/test files are added only when they were assembled above.
    # Skipped when no STORE object is defined; a failed push is logged and does not raise.
    if globals().get('STORE') is not None:
        required = [top_terms_path.as_posix()]
        for asp in oof_pred_logreg_by_aspect:
            required.extend(
                (PRED_DIR / fname).as_posix()
                for fname in (
                    f'oof_pred_logreg_{asp}.npy',
                    f'test_pred_logreg_{asp}.npy',
                    f'top_terms_{asp}.json',
                )
            )
        if not (oof_pred_logreg is None or test_pred_logreg is None):
            required.extend(
                (PRED_DIR / fname).as_posix()
                for fname in ('oof_pred_logreg.npy', 'test_pred_logreg.npy')
            )

        try:
            STORE.maybe_push(
                stage='stage_07a_level1_logreg_aspect_split',
                required_paths=required,
                note='Level-1 Logistic Regression predictions (OOF + test), split by GO aspect (BP/MF/CC).',
            )
        except Exception as e:
            _stage(f"[WARN] STORE push failed: {e}")

    # Wall-clock time for the whole LogReg stage, measured from t0_all.
    _stage(f"[TIMER] LogReg total wall: {time.time() - t0_all:.1f}s")
