In [None]:
# CELL 13a - Setup & Data Loading (Phase 2 canonical)
# =============================================
# 4. PHASE 2: LEVEL-1 MODELS (DIVERSE ENSEMBLE)
# =============================================
# Target selection source-of-truth: Colab_04b_first_submission_no_ankh.ipynb (aspect-split Top-K)


if TRAIN_LEVEL1:
    import gc
    import json
    import os
    from pathlib import Path

    import numpy as np
    import pandas as pd
    import psutil

    # Hardware audit (best-effort): report the first CUDA device when one is visible.
    # Any failure here (torch missing, driver problems) is deliberately ignored.
    try:
        import torch

        if not torch.cuda.is_available():
            print("[AUDITOR] WARNING: No GPU detected.")
        else:
            print(f"[AUDITOR] GPU Detected: {torch.cuda.get_device_name(0)}")
            print(f"[AUDITOR] VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    except Exception:
        pass

    def log_mem(tag: str = "") -> None:
        """Print a one-line host RAM snapshot labelled with ``tag``.

        Best-effort: any failure (e.g. psutil unavailable) is silently ignored.
        """
        try:
            vm = psutil.virtual_memory()
            used_gb = vm.used / 1e9
            total_gb = vm.total / 1e9
            print(f"[MEM] {tag:<30} | Used: {used_gb:.2f}GB / {total_gb:.2f}GB ({vm.percent}%)")
        except Exception:
            pass

    # WORK_ROOT recovery (safety): only runs when no earlier cell defined WORK_ROOT.
    # Preference: the first candidate that already contains parsed/train_terms.parquet,
    # then the first candidate directory that exists, then ./cafa6_data as a last resort.
    if "WORK_ROOT" not in locals() and "WORK_ROOT" not in globals():
        candidates = [
            Path("/content/cafa6_data"),
            Path("/content/work"),
            Path("/kaggle/working/work"),
            Path.cwd() / "cafa6_data",
            Path.cwd() / "artefacts_local" / "work",
        ]

        WORK_ROOT = next(
            (c for c in candidates if (c / "parsed" / "train_terms.parquet").exists()),
            None,
        )
        if WORK_ROOT is None:
            WORK_ROOT = next((c for c in candidates if c.exists()), None)
        if WORK_ROOT is None:
            WORK_ROOT = Path.cwd() / "cafa6_data"

        print(f"WORK_ROOT recovered: {WORK_ROOT}")

    # -----------------------------
    # Load targets + ids
    # -----------------------------
    print("Loading targets...")
    train_terms = pd.read_parquet(WORK_ROOT / "parsed" / "train_terms.parquet")
    # train_seq.feather first, then test_seq.feather; ids are coerced to str.
    train_ids, test_ids = (
        pd.read_feather(WORK_ROOT / "parsed" / f"{which}_seq.feather")["id"].astype(str)
        for which in ("train", "test")
    )

    # FIX: ids of the form "xx|ACCESSION|yy" are reduced to the text between the
    # first two pipes so they match EntryID; ids without such a segment are kept as-is.
    print("Applying ID cleaning fix...")
    train_ids_clean = train_ids.str.extract(r"\|(.*?)\|")[0].fillna(train_ids)

    # -----------------------------
    # Target Matrix Construction (Champion Strategy: 13,500 Terms)
    # 10,000 BP + 2,000 MF + 1,500 CC
    # -----------------------------
    print("Selecting Top-K terms per aspect (Champion Strategy)...")

    # Only a failing obonet import is turned into a RuntimeError; every other failure
    # below (missing OBO file, parse errors) propagates with its own type.
    try:
        import obonet
    except ImportError as e:
        raise RuntimeError("obonet not installed. Please install it.") from e

    # Robust OBO Path Search: the first existing candidate wins.
    possible_paths = [
        WORK_ROOT / "go-basic.obo",
        WORK_ROOT / "Train" / "go-basic.obo",
        WORK_ROOT.parent / "go-basic.obo",
        Path("go-basic.obo"),
        Path("Train/go-basic.obo"),
        Path("../Train/go-basic.obo"),
        Path("/content/cafa6_data/Train/go-basic.obo"),
    ]
    obo_path = next((p for p in possible_paths if p.exists()), None)
    if obo_path is None:
        raise FileNotFoundError(
            f"CRITICAL: go-basic.obo not found. Searched: {[str(p) for p in possible_paths]}"
        )

    global PATH_GO_OBO
    PATH_GO_OBO = obo_path
    print(f"Global PATH_GO_OBO set to: {PATH_GO_OBO}")

    print(f"Loading OBO from {obo_path}...")
    graph = obonet.read_obo(obo_path)
    # GO id -> OBO namespace string ("unknown" when the node has no namespace).
    term_to_ns = {
        node: data.get("namespace", "unknown") for node, data in graph.nodes(data=True)
    }

    # Keep compatibility with downstream code that expects go_namespaces
    go_namespaces = term_to_ns

    ns_map = {
        "biological_process": "BP",
        "molecular_function": "MF",
        "cellular_component": "CC",
    }

    # Normalise any existing aspect column. Artefacts may store full namespace
    # strings, two-letter codes, or CAFA-style short codes (P/F/C, BPO/MFO/CCO).
    aspect_aliases = {
        "biological_process": "BP",
        "molecular_function": "MF",
        "cellular_component": "CC",
        "BP": "BP",
        "MF": "MF",
        "CC": "CC",
        "P": "BP",
        "F": "MF",
        "C": "CC",
        "BPO": "BP",
        "MFO": "MF",
        "CCO": "CC",
    }
    obo_aspect = train_terms["term"].map(lambda t: ns_map.get(term_to_ns.get(t), "UNK"))
    if "aspect" in train_terms.columns:
        alias_aspect = train_terms["aspect"].map(
            lambda a: aspect_aliases.get(str(a).strip(), "UNK")
        )
        # Values the alias table does not recognise fall back to the term's OBO
        # namespace instead of being dropped as "UNK".
        alias_np = alias_aspect.to_numpy()
        train_terms["aspect"] = np.where(alias_np == "UNK", obo_aspect.to_numpy(), alias_np)
    else:
        train_terms["aspect"] = obo_aspect

    # Canonical aspect split (04b): Top-K terms per aspect by annotation row count.
    term_counts = train_terms.groupby(["aspect", "term"]).size().reset_index(name="count")
    targets_bp, targets_mf, targets_cc = (
        term_counts.loc[term_counts["aspect"] == aspect_code].nlargest(k, "count")["term"].tolist()
        for aspect_code, k in (("BP", 10000), ("MF", 2000), ("CC", 1500))
    )

    # Guardrail: avoid silently switching target strategy due to aspect encoding mismatch
    ALLOW_GLOBAL_FALLBACK = False
    if not (targets_bp or targets_mf or targets_cc):
        aspect_vc = (
            train_terms["aspect"].value_counts().to_dict()
            if "aspect" in train_terms.columns
            else {}
        )
        msg = (
            "No BP/MF/CC aspect split found after normalisation. "
            f"aspect_vc={aspect_vc}. This would fall back to global Top-13,500; "
            "set ALLOW_GLOBAL_FALLBACK=True to override."
        )
        if not ALLOW_GLOBAL_FALLBACK:
            raise RuntimeError(msg)
        print("  [WARNING] " + msg)
        top_terms = train_terms["term"].value_counts().head(13500).index.tolist()
    else:
        # Deterministic order BP -> MF -> CC; dict.fromkeys keeps the first occurrence.
        top_terms = list(dict.fromkeys(targets_bp + targets_mf + targets_cc))
        print(f"  Selected: {len(targets_bp)} BP + {len(targets_mf)} MF + {len(targets_cc)} CC")

    # Persist label contract for downstream stages. A non-empty list already on disk
    # takes precedence over the freshly computed one; otherwise the new list is written.
    top_terms_path = WORK_ROOT / "features" / "top_terms_13500.json"
    top_terms_path.parent.mkdir(parents=True, exist_ok=True)
    if not top_terms_path.exists():
        with open(top_terms_path, "w", encoding="utf-8") as f:
            json.dump(list(top_terms), f)
        print("Saved: top_terms_13500.json")
    else:
        try:
            with open(top_terms_path, "r", encoding="utf-8") as f:
                top_terms_disk = json.load(f)
            if isinstance(top_terms_disk, list) and top_terms_disk:
                top_terms = [str(x) for x in top_terms_disk]
                print(f"Loaded existing top_terms_13500.json (n={len(top_terms)})")
        except Exception as e:
            print(f"[WARNING] Failed to load existing top_terms_13500.json: {e}")

    # -----------------------------
    # Stable target contract (audited: 1,585 terms)
    # Definition: GO terms with >= 50 positives AND valid namespace (BP/MF/CC)
    # Stored separately from top_terms_13500.json (do not mix contracts).
    # -----------------------------
    # An existing stable_terms_1585.json is authoritative. Otherwise the list is
    # recomputed from term_counts and written together with a small meta JSON.
    stable_terms_path = WORK_ROOT / "features" / "stable_terms_1585.json"
    stable_meta_path = WORK_ROOT / "features" / "stable_terms_1585_meta.json"
    noise_floor = 50  # minimum train_terms row count per (aspect, term) to be "stable"

    if stable_terms_path.exists():
        try:
            stable_terms = json.loads(stable_terms_path.read_text(encoding="utf-8"))
            stable_terms = [str(t) for t in stable_terms]
            print(f"Loaded existing stable_terms_1585.json (n={len(stable_terms)})")
        except Exception as e:
            raise RuntimeError(f"Failed to load {stable_terms_path}: {e}")
    else:
        # Compute from Phase-1 truth (train_terms.parquet) and OBO namespace mapping already loaded above.
        # Within each aspect: descending count, ties broken by ascending term id.
        stable_bp = (
            term_counts[(term_counts["aspect"] == "BP") & (term_counts["count"] >= noise_floor)]
            .sort_values(["count", "term"], ascending=[False, True])["term"]
            .astype(str)
            .tolist()
        )
        stable_mf = (
            term_counts[(term_counts["aspect"] == "MF") & (term_counts["count"] >= noise_floor)]
            .sort_values(["count", "term"], ascending=[False, True])["term"]
            .astype(str)
            .tolist()
        )
        stable_cc = (
            term_counts[(term_counts["aspect"] == "CC") & (term_counts["count"] >= noise_floor)]
            .sort_values(["count", "term"], ascending=[False, True])["term"]
            .astype(str)
            .tolist()
        )
        # Concatenation order: BP, then MF, then CC.
        stable_terms = stable_bp + stable_mf + stable_cc
        # NOTE(review): the list is persisted BEFORE the size check below, so a list of
        # the wrong size is written to disk and re-loaded (and rejected again) on later runs.
        stable_terms_path.write_text(json.dumps(stable_terms), encoding="utf-8")
        stable_meta_path.write_text(
            json.dumps(
                {
                    "noise_floor": noise_floor,
                    "counts": {"BP": len(stable_bp), "MF": len(stable_mf), "CC": len(stable_cc)},
                    "total": len(stable_terms),
                },
                indent=2,
            ),
            encoding="utf-8",
        )
        print(f"Saved: stable_terms_1585.json (n={len(stable_terms)})")

    # Hard contract: the stable list size is pinned; any drift aborts the run.
    if len(stable_terms) != 1585:
        raise RuntimeError(f"Stable term contract mismatch: expected 1585, got {len(stable_terms)}")

    # Every stable term must be one of the selected top terms.
    top_term_to_idx = {t: i for i, t in enumerate(top_terms)}
    missing_stable = [t for t in stable_terms if t not in top_term_to_idx]
    if missing_stable:
        raise RuntimeError(
            "Stable terms contain items not present in top_terms_13500.json. "
            f"Missing={len(missing_stable)} (example: {missing_stable[:10]})"
        )

    # Positions of the stable terms within top_terms (int64).
    # NOTE(review): these index Y's columns only if Y's columns follow top_terms order — confirm.
    stable_idx = np.asarray([top_term_to_idx[t] for t in stable_terms], dtype=np.int64)
    print(f"Stable targets ready: n={int(stable_idx.shape[0])} (expected 1585)")

    # Label matrix Y (float32 0/1): rows follow train_ids_clean, columns follow
    # top_terms exactly. pivot_table orders its columns lexicographically, so the
    # explicit `columns=top_terms` reindex is what makes Y[:, j] correspond to
    # top_terms[j] (the positional convention stable_idx is built on). Top terms
    # with no positive rows become all-zero columns instead of disappearing.
    train_terms_top = train_terms[train_terms["term"].isin(top_terms)]
    Y_df = train_terms_top.pivot_table(index="EntryID", columns="term", aggfunc="size", fill_value=0)
    Y_df = Y_df.reindex(index=train_ids_clean, columns=top_terms, fill_value=0)
    # Duplicate (EntryID, term) rows would otherwise produce counts > 1; labels are binary.
    Y = (Y_df.to_numpy() > 0).astype(np.float32)
    print(f"Targets: Y={Y.shape}")

    # -----------------------------
    # Feature loading helper (Memory Optimised)
    # -----------------------------
    FEAT_DIR = WORK_ROOT / "features"

    def load_features_dict(split: str = "both"):
        """Load every mandatory feature modality for the requested split(s).

        Embeddings are opened with ``np.load(..., mmap_mode="r")`` so they stay on
        disk. Taxa are one-hot encoded with an encoder fitted on the union of train
        and test ``taxon_id`` values, then aligned to ``train_ids`` / ``test_ids``
        (ids missing from the taxa file get taxon_id 0).

        Args:
            split: "train", "test" or "both". Any other value loads no arrays and
                returns ``({}, {})`` after the file-existence checks.

        Returns:
            ``ft_train`` for "train", ``ft_test`` for "test", otherwise ``(ft_train, ft_test)``.

        Raises:
            FileNotFoundError: if an embedding pair or either taxa feather file is missing.
        """
        log_mem(f"Start load_features_dict({split})")
        print(f"Loading multimodal features (mode={split})...")

        want_train = split in ["both", "train"]
        want_test = split in ["both", "test"]
        ft_train = {}
        ft_test = {}

        # (file stem, dict key) — all modalities are mandatory.
        modalities = (
            ("t5", "t5"),
            ("esm2", "esm2_650m"),
            ("esm2_3b", "esm2_3b"),
            ("ankh", "ankh"),
            ("text", "text"),
        )
        for stem, key in modalities:
            tr_path = FEAT_DIR / f"train_embeds_{stem}.npy"
            te_path = FEAT_DIR / f"test_embeds_{stem}.npy"
            if not tr_path.exists() or not te_path.exists():
                raise FileNotFoundError(f"Missing mandatory embeddings for {stem}: {tr_path} or {te_path}")
            if want_train:
                ft_train[key] = np.load(tr_path, mmap_mode="r")
            if want_test:
                ft_test[key] = np.load(te_path, mmap_mode="r")

        taxa_train_path = WORK_ROOT / "parsed" / "train_taxa.feather"
        taxa_test_path = WORK_ROOT / "parsed" / "test_taxa.feather"
        if not taxa_train_path.exists() or not taxa_test_path.exists():
            raise FileNotFoundError(f"Missing mandatory taxa features: {taxa_train_path} or {taxa_test_path}")

        from sklearn.preprocessing import OneHotEncoder

        taxa_tr = pd.read_feather(taxa_train_path).astype({"id": str})
        taxa_te = pd.read_feather(taxa_test_path).astype({"id": str})
        encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False, dtype=np.float32)
        encoder.fit(pd.concat([taxa_tr[["taxon_id"]], taxa_te[["taxon_id"]]], axis=0))

        def _onehot_aligned(taxa_df, ids):
            # Reorder rows to follow `ids`; unknown ids get taxon_id 0 before encoding.
            aligned = taxa_df.set_index("id").reindex(ids, fill_value=0).reset_index()
            return encoder.transform(aligned[["taxon_id"]]).astype(np.float32)

        if want_train:
            ft_train["taxa"] = _onehot_aligned(taxa_tr, train_ids)
        if want_test:
            ft_test["taxa"] = _onehot_aligned(taxa_te, test_ids)

        log_mem(f"End load_features_dict({split})")
        if split == "train":
            return ft_train
        if split == "test":
            return ft_test
        return ft_train, ft_test

    # Materialise feature dicts (mmap arrays where possible)
    features_train, features_test = load_features_dict(split="both")

    # Flat concatenation order for classical models (LR/GBDT)
    FLAT_KEYS = [k for k in ["t5", "esm2_650m", "esm2_3b", "ankh", "text", "taxa"] if k in features_train]
    if "ankh" not in FLAT_KEYS:
        raise RuntimeError("Ankh is mandatory but was not loaded into features_train.")
    print(f"Flat X keys={FLAT_KEYS}")

    # -----------------------------
    # Disk-backed X / X_test (for RAM-safe downstream cells)
    # -----------------------------
    X_train_path = FEAT_DIR / "X_train_mmap.npy"
    X_test_path = FEAT_DIR / "X_test_mmap.npy"

    def _expected_X_shapes():
        """Return ((n_train, D), (n_test, D)) implied by FLAT_KEYS and the id lists."""
        total_dim = int(sum(int(features_train[k].shape[1]) for k in FLAT_KEYS))
        return (int(len(train_ids)), total_dim), (int(len(test_ids)), total_dim)

    def _X_memmaps_match() -> bool:
        """True when both X files exist and their stored shapes equal the expected shapes."""
        if not (X_train_path.exists() and X_test_path.exists()):
            return False
        exp_tr, exp_te = _expected_X_shapes()
        try:
            got_tr = tuple(np.load(X_train_path, mmap_mode="r").shape)
            got_te = tuple(np.load(X_test_path, mmap_mode="r").shape)
        except Exception as e:
            print(f"[WARNING] Could not read existing X memmaps ({e!r}); rebuilding.")
            return False
        if (got_tr, got_te) != (exp_tr, exp_te):
            print(
                f"[WARNING] Existing X memmaps have shapes train={got_tr} test={got_te}, "
                f"expected train={exp_tr} test={exp_te}; rebuilding."
            )
            return False
        return True

    def _build_X_memmaps(chunk_size: int = 10000) -> None:
        """Stream each FLAT_KEYS modality into its column block of the X / X_test memmaps.

        Rows are copied ``chunk_size`` at a time, so only one chunk per modality is
        held in RAM. Data goes to ``*.tmp.npy`` files that are renamed over the final
        paths only after a complete, flushed write; an interrupted build therefore
        never leaves files that a later run would accept as finished.

        Raises:
            ValueError: if a modality's row count differs from train_ids/test_ids, or
                its train and test widths differ.
        """
        dims = {k: int(features_train[k].shape[1]) for k in FLAT_KEYS}
        total_dim = int(sum(dims.values()))
        n_tr = int(len(train_ids))
        n_te = int(len(test_ids))

        # Validate up front: extra rows would otherwise be silently truncated and
        # misalign X with Y.
        for k in FLAT_KEYS:
            rows_tr = int(features_train[k].shape[0])
            rows_te = int(features_test[k].shape[0])
            if rows_tr != n_tr or rows_te != n_te:
                raise ValueError(
                    f"Row mismatch for '{k}': train={rows_tr} (expected {n_tr}), "
                    f"test={rows_te} (expected {n_te})"
                )
            width_te = int(features_test[k].shape[1])
            if width_te != dims[k]:
                raise ValueError(f"Width mismatch for '{k}': train={dims[k]} test={width_te}")

        tmp_train = X_train_path.with_name(X_train_path.stem + ".tmp.npy")
        tmp_test = X_test_path.with_name(X_test_path.stem + ".tmp.npy")

        print(f"Building X memmaps: train=({n_tr}, {total_dim}) test=({n_te}, {total_dim})")
        X_mm = np.lib.format.open_memmap(
            str(tmp_train), mode="w+", dtype=np.float32, shape=(n_tr, total_dim)
        )
        Xte_mm = np.lib.format.open_memmap(
            str(tmp_test), mode="w+", dtype=np.float32, shape=(n_te, total_dim)
        )

        col = 0
        for k in FLAT_KEYS:
            d = dims[k]
            print(f"  Streaming {k} into cols {col}:{col + d}")
            for i in range(0, n_tr, chunk_size):
                j = min(i + chunk_size, n_tr)
                X_mm[i:j, col : col + d] = np.asarray(features_train[k][i:j], dtype=np.float32)
            for i in range(0, n_te, chunk_size):
                j = min(i + chunk_size, n_te)
                Xte_mm[i:j, col : col + d] = np.asarray(features_test[k][i:j], dtype=np.float32)
            col += d

        X_mm.flush()
        Xte_mm.flush()
        # Release the mappings, then atomically publish the completed files.
        del X_mm, Xte_mm
        os.replace(tmp_train, X_train_path)
        os.replace(tmp_test, X_test_path)

    if _X_memmaps_match():
        print("X memmaps already exist; skipping build.")
    else:
        _build_X_memmaps(chunk_size=5000)

    X = np.load(X_train_path, mmap_mode="r")
    X_test = np.load(X_test_path, mmap_mode="r")

    log_mem("Phase 2 setup done")


In [None]:
# CELL 13E - KNN (GPU-accelerated, IA-weighted) + checkpoint push
# ===================================================================
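# RANK-1 Winning Strategy Implementation:
#   1. cuML GPU nearest-neighbour search when available (reported ~10-20× speedup on A100); sklearn CPU fallback
#   2. Row-wise L2 normalisation of embeddings (on unit vectors Euclidean and cosine give the same neighbour order)
#   3. IA-weighted neighbour voting (each term's vote is scaled by its IA weight from IA.tsv)
#   4. Aspect thresholds (defaults BP: 0.25, MF: 0.35, CC: 0.35) are loaded here for later stages; they are not applied to the scores in this cell
#   5. Finite-value quality gates on OOF/test predictions (NaN/Inf replaced and clipped before saving)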

if not TRAIN_LEVEL1:
    print('Skipping KNN (TRAIN_LEVEL1=False).')
else:
    import json
    import time
    from pathlib import Path

    import numpy as np
    from sklearn.model_selection import KFold

    # Optional telemetry backends: each name is bound to None when its import fails,
    # and _sys_parts checks for None before using it.
    try:
        import psutil
    except Exception:
        psutil = None
    try:
        import cupy as cp
    except Exception:
        cp = None

    def _fmt_s(seconds: float) -> str:
        seconds = float(max(0.0, seconds))
        if seconds < 60:
            return f'{seconds:.1f}s'
        minutes = seconds / 60.0
        if minutes < 60:
            return f'{minutes:.1f}min'
        hours = minutes / 60.0
        return f'{hours:.2f}h'

    def _sys_parts() -> list[str]:
        """Collect optional memory readings as ``key=value`` strings for heartbeat lines.

        Host RSS / available RAM come from psutil; VRAM and CuPy pool usage come from
        CuPy. A source is skipped when its module is None or when a query raises.
        """
        gib = 1024**3
        parts: list[str] = []
        if psutil is not None:
            try:
                rss_gb = psutil.Process().memory_info().rss / gib
                avail_gb = psutil.virtual_memory().available / gib
                parts += [f'rss={rss_gb:.1f}GiB', f'ram_avail={avail_gb:.1f}GiB']
            except Exception:
                pass
        if cp is not None:
            try:
                free_b, total_b = cp.cuda.runtime.memGetInfo()
                parts.append(f'vram_used={(total_b - free_b) / gib:.1f}GiB')
                parts.append(f'vram_total={total_b / gib:.1f}GiB')
                mem_pool = cp.get_default_memory_pool()
                parts.append(f'cp_pool_total={mem_pool.total_bytes() / gib:.1f}GiB')
                parts.append(f'cp_pool_used={mem_pool.used_bytes() / gib:.1f}GiB')
            except Exception:
                pass
        return parts

    def _hb(prefix: str, **kv):
        """Print one heartbeat line: prefix, the given key=value pairs, then system readings."""
        tokens = [prefix, *(f'{k}={v}' for k, v in kv.items()), *_sys_parts()]
        print(' '.join(tokens))

    overall_t0 = time.time()

    # Prefer cuML's GPU NearestNeighbors; fall back to sklearn on the CPU.
    try:
        from cuml.neighbors import NearestNeighbors as cuNearestNeighbors
    except ImportError:
        from sklearn.neighbors import NearestNeighbors

        USE_CUML = False
        print('[WARN] cuML not available; using sklearn NearestNeighbors (CPU, slower). Install cuML for Rank-1 performance.')
    else:
        USE_CUML = True
        print('[KNN] Using cuML NearestNeighbors (GPU-accelerated)')

    WORK_ROOT = Path(WORK_ROOT)
    FEAT_DIR = WORK_ROOT / 'features'
    PRED_DIR = FEAT_DIR / 'level1_preds'
    PRED_DIR.mkdir(parents=True, exist_ok=True)

    # Per-aspect thresholds: read from disk when present, otherwise hard-coded defaults.
    # NOTE(review): ASPECT_THRESHOLDS is not referenced again in this cell.
    thr_path = FEAT_DIR / 'aspect_thresholds.json'
    if thr_path.exists():
        ASPECT_THRESHOLDS = json.loads(thr_path.read_text(encoding='utf-8'))
        print(f'[KNN] Loaded aspect thresholds: {ASPECT_THRESHOLDS}')
    else:
        print(
            f'[WARN] aspect_thresholds.json not found at {thr_path}. Run Cell 13F first for per-aspect thresholds (proven +3.3% F1 boost).'
        )
        ASPECT_THRESHOLDS = {'ALL': 0.3, 'BP': 0.25, 'MF': 0.35, 'CC': 0.35, 'UNK': 0.3}

    knn_oof_path = PRED_DIR / 'oof_pred_knn.npy'
    knn_test_path = PRED_DIR / 'test_pred_knn.npy'

    # Backwards-compatible copies (some downstream code loads from WORK_ROOT/features)
    knn_oof_compat = FEAT_DIR / 'oof_pred_knn.npy'
    knn_test_compat = FEAT_DIR / 'test_pred_knn.npy'

    if knn_oof_path.exists() and knn_test_path.exists():
        # Cached predictions from an earlier run; per-protein similarities are not stored.
        oof_pred_knn = np.load(knn_oof_path)
        test_pred_knn = np.load(knn_test_path)
        oof_max_sim = None
        print('Loaded existing KNN preds:', knn_oof_path.name, knn_test_path.name)
    else:
        if 'features_train' not in globals() or 'features_test' not in globals():
            raise RuntimeError('Missing `features_train`/`features_test`. Run the Phase 2 feature load cell first.')
        if 'Y' not in globals():
            raise RuntimeError('Missing Y. Run Cell 13a first (targets).')
        if 'esm2_3b' not in features_train:
            raise FileNotFoundError(
                "Missing required modality 'esm2_3b' in features_train. Ensure features/train_embeds_esm2_3b.npy exists."
            )

        # ---- Load IA weights for weighted neighbour voting ----
        top_terms_path = FEAT_DIR / 'top_terms_13500.json'
        if not top_terms_path.exists():
            raise FileNotFoundError(f'Missing {top_terms_path}. Run Cell 13a first.')
        top_terms = [str(t) for t in json.loads(top_terms_path.read_text(encoding='utf-8'))]

        ia_path = WORK_ROOT / 'IA.tsv'
        if not ia_path.exists():
            raise FileNotFoundError(f'Missing IA.tsv at {ia_path}')

        import pandas as pd

        # IA.tsv may be headerless (one "GO:xxxxxxx<TAB>value" row per term). Parse all
        # rows as text so the first term is not swallowed as a header; a header row is
        # assumed only when the first cell is not a GO id, and its names are then used.
        ia_df = pd.read_csv(ia_path, sep='\t', header=None, dtype=str)
        if len(ia_df) > 0 and not str(ia_df.iat[0, 0]).strip().startswith('GO:'):
            ia_df.columns = [str(c) for c in ia_df.iloc[0]]
            ia_df = ia_df.iloc[1:].reset_index(drop=True)
        ia_cols = list(ia_df.columns)
        term_col = 'term' if 'term' in ia_cols else ('#term' if '#term' in ia_cols else ia_cols[0])
        ia_col = 'ia' if 'ia' in ia_cols else (ia_cols[1] if len(ia_cols) > 1 else ia_cols[0])
        ia_values = pd.to_numeric(ia_df[ia_col], errors='raise').astype(np.float32)
        ia_map = dict(zip(ia_df[term_col].astype(str).values, ia_values.values))
        # Terms absent from IA.tsv get weight 1.0.
        weights_full = np.asarray([ia_map.get(t, np.float32(1.0)) for t in top_terms], dtype=np.float32)
        if int(weights_full.shape[0]) != int(Y.shape[1]):
            raise RuntimeError(
                f'IA weight vector length ({int(weights_full.shape[0])}) != number of label columns in Y '
                f'({int(Y.shape[1])}); top_terms_13500.json and Y are out of sync.'
            )
        print(
            f'[KNN] IA weights ready: shape={weights_full.shape} min={float(weights_full.min()):.4f} max={float(weights_full.max()):.4f}'
        )

        # ---- Prepare embeddings ----
        X_knn = features_train['esm2_3b'].astype(np.float32)
        X_knn_test = features_test['esm2_3b'].astype(np.float32)

        # Scale each embedding row to unit L2 length. For unit vectors
        # ||a - b||^2 = 2 - 2*cos(a, b), so Euclidean and cosine rank neighbours identically.
        def _l2_norm(X):
            """Return X divided row-wise by its L2 norm as float32 (norms floored at 1e-12)."""
            safe_norms = np.maximum(np.linalg.norm(X, axis=1, keepdims=True), 1e-12)
            return (X / safe_norms).astype(np.float32)

        t_norm = time.time()
        X_knn = _l2_norm(X_knn)
        X_knn_test = _l2_norm(X_knn_test)
        print(f'[KNN] L2-normalised embeddings: train={X_knn.shape} test={X_knn_test.shape} (took {_fmt_s(time.time()-t_norm)})')

        Y_knn = Y.astype(np.float32)

        # ---- KNN parameters ----
        KNN_K = int(globals().get('KNN_K', 50))
        KNN_BATCH = int(globals().get('KNN_BATCH', 4096))  # query rows aggregated per step
        n_splits = 5

        print(f'[KNN] Config: k={KNN_K} batch={KNN_BATCH} folds={n_splits} backend={"cuML" if USE_CUML else "sklearn"}')

        # ---- Initialise accumulators ----
        train_n = int(X_knn.shape[0])
        test_n = int(X_knn_test.shape[0])
        out_dim = int(Y_knn.shape[1])

        oof_pred_knn = np.zeros((train_n, out_dim), dtype=np.float32)
        test_pred_knn = np.zeros((test_n, out_dim), dtype=np.float32)
        oof_max_sim = np.zeros((train_n,), dtype=np.float32)

        if int(weights_full.shape[0]) != out_dim:
            raise RuntimeError(f'IA weights length {int(weights_full.shape[0])} != label columns {out_dim}')

        def _make_knn():
            """Return an unfitted neighbour index using cosine distance on the active backend.

            Both backends use metric='cosine' (distance = 1 - cosine similarity), so a
            single distance->similarity conversion is correct for either. The previous
            cuML setup used 'euclidean' with a `1 - d` conversion, which is not cosine
            similarity (on unit vectors cos = 1 - d**2 / 2).
            """
            if USE_CUML:
                return cuNearestNeighbors(n_neighbors=KNN_K, metric='cosine')
            return NearestNeighbors(n_neighbors=KNN_K, metric='cosine', n_jobs=-1)

        def _dist_to_sim(dists):
            """Cosine distance -> cosine similarity as float32, clipped to [0, 1]."""
            return np.clip(1.0 - np.asarray(dists, dtype=np.float32), 0.0, 1.0).astype(np.float32)

        def _aggregate(neigh_idx, sims, out, out_rows, phase, hb_prefix):
            """Write similarity-weighted, IA-scaled neighbour label votes into ``out``.

            For a query with neighbours n_1..n_K and similarities s_1..s_K:
                score = IA * sum_k(s_k * Y[n_k]) / max(sum_k s_k, 1e-8)
            Votes are accumulated one neighbour column at a time, so the extra memory
            per batch is O(batch * n_labels) rather than a gathered (batch, K, n_labels)
            label tensor.

            NOTE(review): scaling by IA means scores exceed 1.0 wherever IA > 1 and
            are always 0 for terms with IA == 0; kept as in the original strategy.

            Args:
                neigh_idx: (n, K) row indices into Y_knn.
                sims: (n, K) similarities aligned with ``neigh_idx``.
                out: destination array with ``out_dim`` columns.
                out_rows: destination row indices (length n), or None for rows 0..n-1.
                phase, hb_prefix: heartbeat labels.
            """
            n_rows = int(neigh_idx.shape[0])
            n_batches = int((n_rows + KNN_BATCH - 1) // KNN_BATCH)
            agg_t0 = time.time()
            for b, i in enumerate(range(0, n_rows, KNN_BATCH), start=1):
                j = min(i + KNN_BATCH, n_rows)
                neigh_b = neigh_idx[i:j]  # (B, K)
                sims_b = sims[i:j]        # (B, K)

                votes = np.zeros((j - i, out_dim), dtype=np.float32)
                for k in range(neigh_b.shape[1]):
                    votes += sims_b[:, k : k + 1] * Y_knn[neigh_b[:, k]]
                denom = np.maximum(sims_b.sum(axis=1, keepdims=True), 1e-8)
                scores = (votes * weights_full[np.newaxis, :] / denom).astype(np.float32)

                dest = slice(i, j) if out_rows is None else out_rows[i:j]
                out[dest] = scores

                if b == 1 or b == n_batches or (b % 10 == 0):
                    elapsed = time.time() - agg_t0
                    avg = elapsed / b
                    eta = avg * (n_batches - b)
                    _hb(
                        hb_prefix,
                        phase=phase,
                        batch=f'{b}/{n_batches}',
                        pct=f'{(100.0*b/n_batches):.0f}%',
                        elapsed=_fmt_s(elapsed),
                        eta=_fmt_s(eta),
                    )

        # ---- 5-Fold Cross-Validation ----
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        fold_t0 = time.time()
        folds_done = 0

        for fold, (tr_idx, va_idx) in enumerate(kf.split(X_knn), start=1):
            fold_start = time.time()
            print(f'\n[KNN] ═══ Fold {fold}/{n_splits} ═══')
            print(f'  Train samples: {len(tr_idx)}, Val samples: {len(va_idx)}')
            _hb('  [KNN][HB]', phase='start_fold')

            # Fit KNN on training fold
            knn = _make_knn()
            t_fit = time.time()
            knn.fit(X_knn[tr_idx])
            print(f'  [KNN] Fit complete in {_fmt_s(time.time()-t_fit)}')

            # Compute neighbours for validation fold
            t_kn = time.time()
            dists_va, neigh_va = knn.kneighbors(X_knn[va_idx], return_distance=True)
            print(f'  [KNN] kneighbors(val) complete in {_fmt_s(time.time()-t_kn)}')
            del knn  # release the fold index (GPU memory on cuML) before aggregation

            sims_va = _dist_to_sim(dists_va)
            # Map fold-local neighbour positions to global training rows.
            neigh_global = tr_idx[np.asarray(neigh_va)]

            _aggregate(neigh_global, sims_va, oof_pred_knn, va_idx, phase='oof_agg', hb_prefix='  [KNN][HB]')

            # Track max similarity for diagnostics
            oof_max_sim[va_idx] = sims_va.max(axis=1)
            print(f'  [KNN] Fold {fold} OOF complete in {_fmt_s(time.time()-fold_start)}')

            folds_done += 1
            fold_elapsed = time.time() - fold_t0
            avg_fold = fold_elapsed / folds_done
            eta_folds = avg_fold * (n_splits - folds_done)
            print(f'  [KNN] Fold progress: {folds_done}/{n_splits} | ETA(folds)={_fmt_s(eta_folds)} (avg_fold={_fmt_s(avg_fold)})')

        # Quality gate - Check OOF predictions
        if not np.isfinite(oof_pred_knn).all():
            nan_count = int((~np.isfinite(oof_pred_knn)).sum())
            print(f'[WARN] KNN OOF contains {nan_count} NaN/Inf values; clipping to valid range [0, 1]')
            oof_pred_knn = np.clip(np.nan_to_num(oof_pred_knn, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)

        # ---- Test Predictions ----
        print('\n[KNN] Computing test predictions...')
        _hb('[KNN][HB]', phase='test_start', test_n=test_n)

        knn_final = _make_knn()
        t_fit_final = time.time()
        knn_final.fit(X_knn)
        print(f'[KNN] Fit (full train) complete in {_fmt_s(time.time()-t_fit_final)}')

        t_kn_te = time.time()
        dists_te, neigh_te = knn_final.kneighbors(X_knn_test, return_distance=True)
        print(f'[KNN] kneighbors(test) complete in {_fmt_s(time.time()-t_kn_te)}')

        sims_te = _dist_to_sim(dists_te)
        _aggregate(np.asarray(neigh_te), sims_te, test_pred_knn, None, phase='test_agg', hb_prefix='[KNN][HB]')

        # Quality gate - Check test predictions
        if not np.isfinite(test_pred_knn).all():
            nan_count = int((~np.isfinite(test_pred_knn)).sum())
            print(f'[WARN] KNN test contains {nan_count} NaN/Inf values; clipping to valid range [0, 1]')
            test_pred_knn = np.clip(np.nan_to_num(test_pred_knn, nan=0.0, posinf=1.0, neginf=0.0), 0.0, 1.0)

        # Save predictions
        t_save = time.time()
        np.save(knn_oof_path, oof_pred_knn)
        np.save(knn_test_path, test_pred_knn)
        np.save(knn_oof_compat, oof_pred_knn)
        np.save(knn_test_compat, test_pred_knn)
        print(f'[KNN] Saved predictions in {_fmt_s(time.time()-t_save)}')
        print('Saved:', knn_oof_path)
        print('Saved:', knn_test_path)
        print('Saved (compat):', knn_oof_compat)
        print('Saved (compat):', knn_test_compat)

        print(f'\n[KNN] ═══ COMPLETE ═══ total_time={_fmt_s(time.time()-overall_t0)}')

    # Checkpoint push (always): runs whether predictions were computed or loaded from cache.
    STORE.maybe_push(
        stage='stage_07d_level1_knn',
        required_paths=[
            p.as_posix()
            for p in (WORK_ROOT / 'features' / 'top_terms_13500.json', knn_oof_path, knn_test_path)
        ],
        note='Level-1 KNN (GPU-accelerated, IA-weighted) predictions using ESM2-3B embeddings (OOF + test).',
    )

    # Diagnostics (best-effort; any exception is printed and the plots are skipped):
    #   1. histogram of each protein's highest neighbour similarity -- only when the KNN
    #      ran in this session (oof_max_sim is None when cached predictions were loaded);
    #   2. IA-weighted F1 vs threshold on an evenly spaced row subsample, when an
    #      `ia_weighted_f1` helper exists in the notebook globals.
    try:
        import os
        import matplotlib.pyplot as plt

        plt.rcParams.update({'font.size': 12})

        if oof_max_sim is not None:
            plt.figure(figsize=(10, 4))
            plt.hist(oof_max_sim, bins=50, alpha=0.7)
            plt.title('KNN OOF diagnostic: max cosine similarity to neighbours (per protein)')
            plt.xlabel('max similarity')
            plt.ylabel('count')
            plt.grid(True, alpha=0.3)
            plt.show()

        if 'ia_weighted_f1' in globals():
            DIAG_N = int(os.getenv('CAFA_DIAG_N', '20000'))

            def _sub(y_true: np.ndarray, y_score: np.ndarray):
                """Select at most DIAG_N evenly spaced rows (the same rows from both arrays)."""
                total = int(y_true.shape[0])
                take = min(total, int(DIAG_N))
                if take > 0:
                    rows = np.linspace(0, total - 1, num=take, dtype=np.int64)
                    return y_true[rows], y_score[rows]
                return y_true[:0], y_score[:0]

            y_t, y_s = _sub(Y, oof_pred_knn)

            thrs = np.linspace(0.05, 0.60, 23)
            curves = {k: [] for k in ['ALL', 'MF', 'BP', 'CC']}
            for thr in thrs:
                s = ia_weighted_f1(y_t, y_s, thr=float(thr))
                for k, series in curves.items():
                    series.append(s[k])

            plt.figure(figsize=(10, 3))
            for k, series in curves.items():
                plt.plot(thrs, series, label=k)
            plt.title('KNN OOF: IA-weighted F1 vs threshold (sampled)')
            plt.xlabel('threshold')
            plt.ylabel('IA-F1')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.show()

    except Exception as e:
        print('KNN diagnostics skipped:', repr(e))
