# KNN Standalone Notebook

This notebook runs **only** the KNN model training from the e2e pipeline.

**Requirements:**
- Pre-computed embeddings (ESM2-3B)
- Parsed training data
- Top terms list (13,500 terms)
- IA weights

**What this does NOT include:**
- HuggingFace file downloading
- HuggingFace file uploading
- Unnecessary setup cells


In [2]:
# CELL 01 - Setup (NO REPO)
import os
import sys
import ctypes
from pathlib import Path

# CUDA loader fix (PyTorch/RAPIDS coexistence): preload venv nvjitlink so we don't pick /usr/local/cuda/lib64
try:
    _venv_root = Path(sys.executable).resolve().parent.parent
    _nvjit_dir = (
        _venv_root
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages"
        / "nvidia"
        / "nvjitlink"
        / "lib"
    )
    _nvjit_so = _nvjit_dir / "libnvJitLink.so.12"
    if _nvjit_so.exists():
        ctypes.CDLL(str(_nvjit_so), mode=ctypes.RTLD_GLOBAL)
        os.environ["LD_LIBRARY_PATH"] = f"{_nvjit_dir}:{os.environ.get('LD_LIBRARY_PATH','')}"
        print(f"[ENV] Preloaded nvjitlink: {_nvjit_so}")
except Exception as _e:
    print(f"[ENV] nvjitlink preload skipped: {_e}")

# Always run from a simple writable location; never cd into a repo.
if os.path.exists('/content'):
    os.chdir('/content')
RUNTIME_ROOT = Path.cwd()
DATA_ROOT = (RUNTIME_ROOT / 'cafa6_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)
TRAIN_LEVEL1 = True
print(f'CWD: {Path.cwd()}')
print(f'DATA_ROOT: {DATA_ROOT.resolve()}')

CWD: c:\Users\Olale\Documents\Codebase\Science\cafa-6-protein-function-prediction
DATA_ROOT: C:\Users\Olale\Documents\Codebase\Science\cafa-6-protein-function-prediction\cafa6_data


In [None]:
# CELL 03 - Simplified Setup & Config (KNN Standalone)
# This is a stripped-down version without HuggingFace download/upload

import json
import os
from pathlib import Path
import numpy as np


# Environment Detection
def _detect_kaggle() -> bool:
    return bool(
        os.environ.get('KAGGLE_KERNEL_RUN_TYPE')
        or os.environ.get('KAGGLE_URL_BASE')
        or os.environ.get('KAGGLE_DATA_PROXY_URL')
    )


def _detect_colab() -> bool:
    return bool(
        os.environ.get('COLAB_RELEASE_TAG')
        or os.environ.get('COLAB_GPU')
        or os.environ.get('COLAB_TPU_ADDR')
    )


IS_KAGGLE = _detect_kaggle()
IS_COLAB = (not IS_KAGGLE) and _detect_colab()

if IS_KAGGLE:
    print('Environment: Kaggle')
    WORKING_ROOT = Path('/kaggle/working')
elif IS_COLAB:
    print('Environment: Colab')
    WORKING_ROOT = Path('/content')
else:
    print('Environment: Local')
    WORKING_ROOT = Path.cwd()

# Setup WORK_ROOT
if 'DATA_ROOT' in globals():
    WORK_ROOT = Path(DATA_ROOT)
    WORKING_ROOT = WORK_ROOT.parent
else:
    WORK_ROOT = WORKING_ROOT / 'cafa6_data'

WORK_ROOT.mkdir(parents=True, exist_ok=True)
for _d in ['parsed', 'features', 'external', 'Train', 'Test']:
    (WORK_ROOT / _d).mkdir(parents=True, exist_ok=True)

print(f'WORK_ROOT: {WORK_ROOT}')

# Training flag
TRAIN_LEVEL1 = bool(int(os.getenv('CAFA_TRAIN_LEVEL1', '1')))
print(f'TRAIN_LEVEL1: {TRAIN_LEVEL1}')

# Stub CheckpointStore (no HuggingFace operations)
class CheckpointStore:
    """Simplified checkpoint store without HuggingFace integration."""
    
    def __init__(self, work_root: Path):
        self.work_root = work_root
    
    def maybe_pull(self, stage: str, required_files: list[str] = None, note: str = '') -> bool:
        """Stub: does nothing in standalone mode."""
        print(f'[CHECKPOINT] {stage}: pull skipped (standalone mode)')
        return False
    
    def maybe_push(self, stage: str, required_paths: list[str] = None, note: str = '') -> bool:
        """Stub: does nothing in standalone mode."""
        print(f'[CHECKPOINT] {stage}: push skipped (standalone mode)')
        return False


STORE = CheckpointStore(work_root=WORK_ROOT)
print('CheckpointStore initialized (stub mode - no HF operations)')


In [None]:
# CELL 13a - Setup & Data Loading (Phase 2 canonical)
# =============================================
# 4. PHASE 2: LEVEL-1 MODELS (DIVERSE ENSEMBLE)
# =============================================
# Target selection source-of-truth: Colab_04b_first_submission_no_ankh.ipynb (aspect-split Top-K)


if TRAIN_LEVEL1:
    import gc
    import json
    import os
    from pathlib import Path

    import numpy as np
    import pandas as pd
    import psutil

    # AUDITOR: Hardware Check
    try:
        import torch

        if torch.cuda.is_available():
            print(f"[AUDITOR] GPU Detected: {torch.cuda.get_device_name(0)}")
            print(
                f"[AUDITOR] VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB"
            )
        else:
            print("[AUDITOR] WARNING: No GPU detected.")
    except Exception:
        pass

    def log_mem(tag: str = "") -> None:
        try:
            mem = psutil.virtual_memory()
            print(
                f"[MEM] {tag:<30} | Used: {mem.used/1e9:.2f}GB / {mem.total/1e9:.2f}GB ({mem.percent}%)"
            )
        except Exception:
            pass

    # WORK_ROOT recovery (safety)
    # Prefer canonical dataset root (cafa6_data/) and validate by presence of parsed artefacts.
    if "WORK_ROOT" not in locals() and "WORK_ROOT" not in globals():
        candidates = [
            Path("/content/cafa6_data"),
            Path("/content/work"),
            Path("/kaggle/working/work"),
            Path.cwd() / "cafa6_data",
            Path.cwd() / "artefacts_local" / "work",
        ]

        WORK_ROOT = None
        for c in candidates:
            if (c / "parsed" / "train_terms.parquet").exists():
                WORK_ROOT = c
                break

        if WORK_ROOT is None:
            for c in candidates:
                if c.exists():
                    WORK_ROOT = c
                    break

        if WORK_ROOT is None:
            WORK_ROOT = Path.cwd() / "cafa6_data"

        print(f"WORK_ROOT recovered: {WORK_ROOT}")

    # -----------------------------
    # Load targets + ids
    # -----------------------------
    print("Loading targets...")
    train_terms = pd.read_parquet(WORK_ROOT / "parsed" / "train_terms.parquet")
    train_ids = pd.read_feather(WORK_ROOT / "parsed" / "train_seq.feather")["id"].astype(str)
    test_ids = pd.read_feather(WORK_ROOT / "parsed" / "test_seq.feather")["id"].astype(str)

    # FIX: Clean IDs in train_ids to match EntryID format
    print("Applying ID cleaning fix...")
    train_ids_clean = train_ids.str.extract(r"\|(.*?)\|")[0]
    train_ids_clean = train_ids_clean.fillna(train_ids)

    # -----------------------------
    # Target Matrix Construction (Champion Strategy: 13,500 Terms)
    # 10,000 BP + 2,000 MF + 1,500 CC
    # -----------------------------
    print("Selecting Top-K terms per aspect (Champion Strategy)...")

    try:
        import obonet

        # Robust OBO Path Search
        possible_paths = [
            WORK_ROOT / "go-basic.obo",
            WORK_ROOT / "Train" / "go-basic.obo",
            WORK_ROOT.parent / "go-basic.obo",
            Path("go-basic.obo"),
            Path("Train/go-basic.obo"),
            Path("../Train/go-basic.obo"),
            Path("/content/cafa6_data/Train/go-basic.obo"),
        ]

        obo_path = None
        for p in possible_paths:
            if p.exists():
                obo_path = p
                break

        if obo_path is None:
            raise FileNotFoundError(
                f"CRITICAL: go-basic.obo not found. Searched: {[str(p) for p in possible_paths]}"
            )

        global PATH_GO_OBO
        PATH_GO_OBO = obo_path
        print(f"Global PATH_GO_OBO set to: {PATH_GO_OBO}")

        print(f"Loading OBO from {obo_path}...")
        graph = obonet.read_obo(obo_path)
        term_to_ns = {
            node: data.get("namespace", "unknown") for node, data in graph.nodes(data=True)
        }

        # Keep compatibility with downstream code that expects go_namespaces
        go_namespaces = term_to_ns

        ns_map = {
            "biological_process": "BP",
            "molecular_function": "MF",
            "cellular_component": "CC",
        }

        # Normalise any existing aspect column (some artefacts store full namespace strings)
        aspect_aliases = {
            "biological_process": "BP",
            "molecular_function": "MF",
            "cellular_component": "CC",
            "BP": "BP",
            "MF": "MF",
            "CC": "CC",
        }
        if "aspect" in train_terms.columns:
            train_terms["aspect"] = train_terms["aspect"].map(
                lambda a: aspect_aliases.get(str(a), "UNK")
            )
        else:
            train_terms["aspect"] = train_terms["term"].map(
                lambda t: ns_map.get(term_to_ns.get(t), "UNK")
            )

    except ImportError as e:
        raise RuntimeError("obonet not installed. Please install it.") from e

    # Canonical aspect split (04b)
    term_counts = train_terms.groupby(["aspect", "term"]).size().reset_index(name="count")
    targets_bp = (
        term_counts[term_counts["aspect"] == "BP"].nlargest(10000, "count")["term"].tolist()
    )
    targets_mf = (
        term_counts[term_counts["aspect"] == "MF"].nlargest(2000, "count")["term"].tolist()
    )
    targets_cc = (
        term_counts[term_counts["aspect"] == "CC"].nlargest(1500, "count")["term"].tolist()
    )

    # Guardrail: avoid silently switching target strategy due to aspect encoding mismatch
    ALLOW_GLOBAL_FALLBACK = False
    if len(targets_bp) == 0 and len(targets_mf) == 0 and len(targets_cc) == 0:
        aspect_vc = train_terms["aspect"].value_counts().to_dict() if "aspect" in train_terms.columns else {}
        msg = (
            "No BP/MF/CC aspect split found after normalisation. "
            f"aspect_vc={aspect_vc}. This would fall back to global Top-13,500; "
            "set ALLOW_GLOBAL_FALLBACK=True to override."
        )
        if ALLOW_GLOBAL_FALLBACK:
            print("  [WARNING] " + msg)
            top_terms = train_terms["term"].value_counts().head(13500).index.tolist()
        else:
            raise RuntimeError(msg)
    else:
        # Stable, deterministic ordering: BP then MF then CC with de-dup preserving order
        top_terms = []
        seen = set()
        for t in (targets_bp + targets_mf + targets_cc):
            if t not in seen:
                top_terms.append(t)
                seen.add(t)
        print(f"  Selected: {len(targets_bp)} BP + {len(targets_mf)} MF + {len(targets_cc)} CC")

    # Persist label contract for downstream stages
    top_terms_path = WORK_ROOT / "features" / "top_terms_13500.json"
    top_terms_path.parent.mkdir(parents=True, exist_ok=True)
    if top_terms_path.exists():
        try:
            with open(top_terms_path, "r", encoding="utf-8") as f:
                top_terms_disk = json.load(f)
            if isinstance(top_terms_disk, list) and len(top_terms_disk) > 0:
                top_terms = [str(x) for x in top_terms_disk]
                print(f"Loaded existing top_terms_13500.json (n={len(top_terms)})")
        except Exception as e:
            print(f"[WARNING] Failed to load existing top_terms_13500.json: {e}")
    else:
        with open(top_terms_path, "w", encoding="utf-8") as f:
            json.dump(list(top_terms), f)
        print("Saved: top_terms_13500.json")

    # -----------------------------
    # Stable target contract (audited: 1,585 terms)
    # Definition: GO terms with >= 50 positives AND valid namespace (BP/MF/CC)
    # Stored separately from top_terms_13500.json (do not mix contracts).
    # -----------------------------
    stable_terms_path = WORK_ROOT / "features" / "stable_terms_1585.json"
    stable_meta_path = WORK_ROOT / "features" / "stable_terms_1585_meta.json"
    noise_floor = 50

    if stable_terms_path.exists():
        try:
            stable_terms = json.loads(stable_terms_path.read_text(encoding="utf-8"))
            stable_terms = [str(t) for t in stable_terms]
            print(f"Loaded existing stable_terms_1585.json (n={len(stable_terms)})")
        except Exception as e:
            raise RuntimeError(f"Failed to load {stable_terms_path}: {e}")
    else:
        # Compute from Phase-1 truth (train_terms.parquet) and OBO namespace mapping already loaded above.
        stable_bp = (
            term_counts[(term_counts["aspect"] == "BP") & (term_counts["count"] >= noise_floor)]
            .sort_values(["count", "term"], ascending=[False, True])["term"]
            .astype(str)
            .tolist()
        )
        stable_mf = (
            term_counts[(term_counts["aspect"] == "MF") & (term_counts["count"] >= noise_floor)]
            .sort_values(["count", "term"], ascending=[False, True])["term"]
            .astype(str)
            .tolist()
        )
        stable_cc = (
            term_counts[(term_counts["aspect"] == "CC") & (term_counts["count"] >= noise_floor)]
            .sort_values(["count", "term"], ascending=[False, True])["term"]
            .astype(str)
            .tolist()
        )
        stable_terms = stable_bp + stable_mf + stable_cc
        stable_terms_path.write_text(json.dumps(stable_terms), encoding="utf-8")
        stable_meta_path.write_text(
            json.dumps(
                {
                    "noise_floor": noise_floor,
                    "counts": {"BP": len(stable_bp), "MF": len(stable_mf), "CC": len(stable_cc)},
                    "total": len(stable_terms),
                },
                indent=2,
            ),
            encoding="utf-8",
        )
        print(f"Saved: stable_terms_1585.json (n={len(stable_terms)})")

    if len(stable_terms) != 1585:
        raise RuntimeError(f"Stable term contract mismatch: expected 1585, got {len(stable_terms)}")

    top_term_to_idx = {t: i for i, t in enumerate(top_terms)}
    missing_stable = [t for t in stable_terms if t not in top_term_to_idx]
    if missing_stable:
        raise RuntimeError(
            "Stable terms contain items not present in top_terms_13500.json. "
            f"Missing={len(missing_stable)} (example: {missing_stable[:10]})"
        )

    stable_idx = np.asarray([top_term_to_idx[t] for t in stable_terms], dtype=np.int64)
    print(f"Stable targets ready: n={int(stable_idx.shape[0])} (expected 1585)")

    train_terms_top = train_terms[train_terms["term"].isin(top_terms)]
    Y_df = train_terms_top.pivot_table(index="EntryID", columns="term", aggfunc="size", fill_value=0)
    Y_df = Y_df.reindex(train_ids_clean, fill_value=0)
    Y = Y_df.values.astype(np.float32)
    print(f"Targets: Y={Y.shape}")

    # -----------------------------
    # Feature loading helper (Memory Optimised)
    # -----------------------------
    FEAT_DIR = WORK_ROOT / "features"

    def load_features_dict(split: str = "both"):
        log_mem(f"Start load_features_dict({split})")
        print(f"Loading multimodal features (mode={split})...")

        ft_train = {}
        ft_test = {}

        def _load_pair(stem: str):
            tr = FEAT_DIR / f"train_embeds_{stem}.npy"
            te = FEAT_DIR / f"test_embeds_{stem}.npy"
            return tr, te

        # All modalities are mandatory.
        stems = [
            ("t5", "t5"),
            ("esm2", "esm2_650m"),
            ("esm2_3b", "esm2_3b"),
            ("ankh", "ankh"),
            ("text", "text"),
        ]

        for stem, key in stems:
            tr_path, te_path = _load_pair(stem)
            if not (tr_path.exists() and te_path.exists()):
                raise FileNotFoundError(f"Missing mandatory embeddings for {stem}: {tr_path} or {te_path}")

            if split in ["both", "train"]:
                ft_train[key] = np.load(tr_path, mmap_mode="r")
            if split in ["both", "test"]:
                ft_test[key] = np.load(te_path, mmap_mode="r")

        taxa_train_path = WORK_ROOT / "parsed" / "train_taxa.feather"
        taxa_test_path = WORK_ROOT / "parsed" / "test_taxa.feather"

        if not (taxa_train_path.exists() and taxa_test_path.exists()):
            raise FileNotFoundError(f"Missing mandatory taxa features: {taxa_train_path} or {taxa_test_path}")

        from sklearn.preprocessing import OneHotEncoder

        tax_tr = pd.read_feather(taxa_train_path).astype({"id": str})
        tax_te = pd.read_feather(taxa_test_path).astype({"id": str})
        enc = OneHotEncoder(handle_unknown="ignore", sparse_output=False, dtype=np.float32)
        enc.fit(pd.concat([tax_tr[["taxon_id"]], tax_te[["taxon_id"]]], axis=0))

        if split in ["both", "train"]:
            tax_tr = tax_tr.set_index("id").reindex(train_ids, fill_value=0).reset_index()
            ft_train["taxa"] = enc.transform(tax_tr[["taxon_id"]]).astype(np.float32)
        if split in ["both", "test"]:
            tax_te = tax_te.set_index("id").reindex(test_ids, fill_value=0).reset_index()
            ft_test["taxa"] = enc.transform(tax_te[["taxon_id"]]).astype(np.float32)

        log_mem(f"End load_features_dict({split})")
        if split == "train":
            return ft_train
        if split == "test":
            return ft_test
        return ft_train, ft_test

    # Materialise feature dicts (mmap arrays where possible)
    features_train, features_test = load_features_dict(split="both")

    # Flat concatenation order for classical models (LR/GBDT)
    FLAT_KEYS = [k for k in ["t5", "esm2_650m", "esm2_3b", "ankh", "text", "taxa"] if k in features_train]
    if "ankh" not in FLAT_KEYS:
        raise RuntimeError("Ankh is mandatory but was not loaded into features_train.")
    print(f"Flat X keys={FLAT_KEYS}")

    # -----------------------------
    # Disk-backed X / X_test (for RAM-safe downstream cells)
    # -----------------------------
    X_train_path = FEAT_DIR / "X_train_mmap.npy"
    X_test_path = FEAT_DIR / "X_test_mmap.npy"

    def _build_X_memmaps(chunk_size: int = 10000) -> None:
        dims = {k: int(features_train[k].shape[1]) for k in FLAT_KEYS}
        total_dim = int(sum(dims.values()))
        n_tr = int(len(train_ids))
        n_te = int(len(test_ids))

        print(f"Building X memmaps: train=({n_tr}, {total_dim}) test=({n_te}, {total_dim})")
        X_mm = np.lib.format.open_memmap(
            str(X_train_path), mode="w+", dtype=np.float32, shape=(n_tr, total_dim)
        )
        Xte_mm = np.lib.format.open_memmap(
            str(X_test_path), mode="w+", dtype=np.float32, shape=(n_te, total_dim)
        )

        col = 0
        for k in FLAT_KEYS:
            d = dims[k]
            print(f"  Streaming {k} into cols {col}:{col + d}")
            for i in range(0, n_tr, chunk_size):
                j = min(i + chunk_size, n_tr)
                X_mm[i:j, col : col + d] = np.asarray(features_train[k][i:j], dtype=np.float32)
            for i in range(0, n_te, chunk_size):
                j = min(i + chunk_size, n_te)
                Xte_mm[i:j, col : col + d] = np.asarray(features_test[k][i:j], dtype=np.float32)
            col += d

        X_mm.flush()
        Xte_mm.flush()

    if X_train_path.exists() and X_test_path.exists():
        print("X memmaps already exist; skipping build.")
    else:
        _build_X_memmaps(chunk_size=5000)

    X = np.load(X_train_path, mmap_mode="r")
    X_test = np.load(X_test_path, mmap_mode="r")

    log_mem("Phase 2 setup done")


[AUDITOR] GPU Detected: NVIDIA GeForce RTX 2070
[AUDITOR] VRAM: 8.59 GB
Loading targets...
Applying ID cleaning fix...
Selecting Top-K terms per aspect (Champion Strategy)...
Global PATH_GO_OBO set to: c:\Users\Olale\Documents\Codebase\Science\cafa-6-protein-function-prediction\cafa6_data\Train\go-basic.obo
Loading OBO from c:\Users\Olale\Documents\Codebase\Science\cafa-6-protein-function-prediction\cafa6_data\Train\go-basic.obo...
  Selected: 10000 BP + 2000 MF + 1500 CC
Loaded existing top_terms_13500.json (n=13500)
Loaded existing stable_terms_1585.json (n=1585)
Stable targets ready: n=1585 (expected 1585)


MemoryError: Unable to allocate 8.23 GiB for an array with shape (13500, 81865) and data type int64

In [None]:
# CELL 3b - KNN Helper Functions & Variable Setup
# This cell defines the helper functions and variables needed by the KNN cell

import numpy as np
import pandas as pd
import json
from pathlib import Path


# Helper function: L2 normalization
def _l2_norm(X):
    """L2-normalize rows of X (convert to unit vectors).
    This transforms cosine similarity to dot product for faster GPU computation.
    """
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    norms = np.maximum(norms, 1e-12)  # Avoid division by zero
    return X / norms


# Load IA weights for weighted neighbor voting
print('[KNN Setup] Loading IA weights...')
FEAT_DIR = WORK_ROOT / 'features'
top_terms_path = FEAT_DIR / 'top_terms_13500.json'

if not top_terms_path.exists():
    raise FileNotFoundError(f'Missing {top_terms_path}. Ensure Phase 2 data loading completed.')

top_terms = json.loads(top_terms_path.read_text(encoding='utf-8'))

# Load IA weights
weights_full = None
try:
    ia_path = next(
        (p for p in [WORK_ROOT / 'IA.tsv', FEAT_DIR / 'IA.tsv'] if p.exists()),
        None
    )
    if ia_path:
        ia_df = pd.read_csv(ia_path, sep='\t', header=None, names=['term', 'ia'])
        ia_map = dict(zip(ia_df['term'].astype(str), ia_df['ia'].astype(np.float32)))
        # Map weights to our master top_terms list
        weights_full = np.asarray([ia_map.get(t, 1.0) for t in top_terms], dtype=np.float32)
        print(f'  Loaded IA weights from {ia_path}')
    else:
        print('  [WARN] IA.tsv not found; using uniform weights')
        weights_full = np.ones(len(top_terms), dtype=np.float32)
except Exception as e:
    print(f'  [WARN] Failed to load IA weights: {e}; using uniform weights')
    weights_full = np.ones(len(top_terms), dtype=np.float32)

print(f'  IA weights shape: {weights_full.shape}')

# Prepare Y_knn (target labels matrix)
print('[KNN Setup] Preparing target labels (Y_knn)...')
if 'Y' not in globals():
    raise RuntimeError('Y not defined. Ensure the Data Loading cell (Cell 3) ran successfully.')

Y_knn = Y  # Y is already defined in Cell 3
print(f'  Y_knn shape: {Y_knn.shape}')

# Prepare X_knn_test (test features)
print('[KNN Setup] Preparing test features (X_knn_test)...')
if 'features_test' not in globals():
    raise RuntimeError('features_test not defined. Ensure the Data Loading cell (Cell 3) ran successfully.')

if 'esm2_3b' not in features_test:
    raise FileNotFoundError("Missing required modality 'esm2_3b' in features_test.")

X_knn_test = features_test['esm2_3b'].astype(np.float32)
print(f'  X_knn_test shape: {X_knn_test.shape}')

print('[KNN Setup] All helper functions and variables ready! ✓')


In [None]:
# CELL 13E - KNN (cosine; ESM2-3B) + checkpoint push
if not TRAIN_LEVEL1:
    print('Skipping KNN (TRAIN_LEVEL1=False).')
else:
    # RANK-1: cuML migration with runtime detection (Auditor-approved v2)
    # Transforms cosine→dot-product (Manual GEMM Fast Path); 10-20× speedup on A100
    try:
        from cuml.neighbors import NearestNeighbors as cuNearestNeighbors
        USE_CUML = True
        print('[KNN] Using cuML NearestNeighbors (GPU-accelerated)')
    except ImportError:
        from sklearn.neighbors import NearestNeighbors
        USE_CUML = False
        print('[WARN] cuML not available; using sklearn NearestNeighbors (CPU, slower). Install cuML for Rank-1 performance.')
    
    from sklearn.model_selection import KFold
    import json
    PRED_DIR = WORK_ROOT / 'features' / 'level1_preds'
    PRED_DIR.mkdir(parents=True, exist_ok=True)

    # RANK-1: Load aspect-specific thresholds (same as DNN)
    FEAT_DIR = WORK_ROOT / 'features'
    thr_path = FEAT_DIR / 'aspect_thresholds.json'
    if not thr_path.exists():
        print(f'[WARN] aspect_thresholds.json not found at {thr_path}. Run Cell 13F first for per-aspect thresholds (proven +3.3% F1 boost).')
        ASPECT_THRESHOLDS = {'ALL': 0.3, 'BP': 0.25, 'MF': 0.35, 'CC': 0.35, 'UNK': 0.3}  # fallback defaults
    else:
        ASPECT_THRESHOLDS = json.loads(thr_path.read_text(encoding='utf-8'))
        print(f'[KNN] Loaded aspect thresholds: {ASPECT_THRESHOLDS}')
    knn_oof_path = PRED_DIR / 'oof_pred_knn.npy'
    knn_test_path = PRED_DIR / 'test_pred_knn.npy'
    # Backwards-compatible copies (some downstream code loads from WORK_ROOT/features)
    knn_oof_compat = WORK_ROOT / 'features' / 'oof_pred_knn.npy'
    knn_test_compat = WORK_ROOT / 'features' / 'test_pred_knn.npy'
    # Option to force retraining (set FORCE_RETRAIN=True to retrain)
    FORCE_RETRAIN = bool(int(os.getenv("FORCE_RETRAIN", "0")))
    
    if not FORCE_RETRAIN and knn_oof_path.exists() and knn_test_path.exists():
        print("[KNN] Predictions already exist. Loading from disk...")
        print("  Set FORCE_RETRAIN=1 environment variable to retrain.")
        oof_pred_knn = np.load(knn_oof_path)
        test_pred_knn = np.load(knn_test_path)
        oof_max_sim = None
    else:
        if 'features_train' not in globals() or 'features_test' not in globals():
            raise RuntimeError('Missing `features_train`/`features_test`. Run the Phase 2 feature load cell first.')
        if 'esm2_3b' not in features_train:
            raise FileNotFoundError("Missing required modality 'esm2_3b' in features_train. Ensure features/train_embeds_esm2_3b.npy exists.")
        X_knn = features_train['esm2_3b'].astype(np.float32)
        
        # RANK-1: L2 pre-normalization (Auditor requirement)
        # Transforms cosine→dot-product for Manual GEMM Fast Path on A100
        print('[KNN] Applying L2 normalization...')
        X_knn = _l2_norm(X_knn)
        X_knn_test_norm = _l2_norm(X_knn_test)
        
        KNN_K = int(globals().get('KNN_K', 10))  # FIXED: reduced from 50 to 10 (matches baseline)
        KNN_BATCH = int(globals().get('KNN_BATCH', 256))
        n_splits = 5
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
        oof_pred_knn = np.zeros((X_knn.shape[0], Y_knn.shape[1]), dtype=np.float32)
        test_pred_knn = np.zeros((X_knn_test_norm.shape[0], Y_knn.shape[1]), dtype=np.float32)
        oof_max_sim = np.zeros((X_knn.shape[0],), dtype=np.float32)
        
        # IA-weighted broadcast for neighbor voting
        w_ia_broadcast = weights_full[np.newaxis, np.newaxis, :]  # (1, 1, L)
        
        for fold, (tr_idx, va_idx) in enumerate(kf.split(X_knn), start=1):
            print(f'[KNN] Fold {fold}/{n_splits}')
            X_tr = X_knn[tr_idx]
            X_va = X_knn[va_idx]
            Y_tr = Y_knn[tr_idx]
            
            # Fit KNN on training fold
            if USE_CUML:
                # cuML: euclidean distance on L2-normalized vectors = cosine similarity (Fast Path)
                knn = cuNearestNeighbors(n_neighbors=KNN_K, metric='euclidean')
            else:
                # sklearn fallback: cosine metric
                knn = NearestNeighbors(n_neighbors=KNN_K, metric='cosine', n_jobs=-1)
            
            knn.fit(X_tr)
            
            # Predict on validation fold
            dists_va, neigh_va = knn.kneighbors(X_va, return_distance=True)
            
            # Convert distances to similarities
            if USE_CUML:
                # For euclidean on L2-normalized vectors: sim = 1 - dist^2/2
                sims_va = np.clip(1.0 - (dists_va ** 2) / 2.0, 0.0, 1.0).astype(np.float32)
            else:
                # For cosine metric: sim = 1 - dist
                sims_va = np.clip(1.0 - dists_va, 0.0, 1.0).astype(np.float32)
            
            # IA-weighted neighbor voting
            for i in range(0, len(va_idx), KNN_BATCH):
                j = min(i + KNN_BATCH, len(va_idx))
                neigh_b = neigh_va[i:j]
                sims_b = sims_va[i:j]
                denom = np.maximum(sims_b.sum(axis=1, keepdims=True), 1e-8)
                Y_nei = Y_tr[neigh_b]  # (B, K, L)
                # IA-weighted aggregation: sims @ (Y_nei * w_ia)
                scores = ((sims_b[:, :, np.newaxis] * Y_nei * w_ia_broadcast).sum(axis=1) / denom).astype(np.float32)
                oof_pred_knn[va_idx[i:j]] = scores
                # Track max similarity
                oof_max_sim[va_idx[i:j]] = sims_b.max(axis=1)
        
        # Train final model on all training data for test predictions
        print('[KNN] Training final model on all training data...')
        if USE_CUML:
            knn_final = cuNearestNeighbors(n_neighbors=KNN_K, metric='euclidean')
        else:
            knn_final = NearestNeighbors(n_neighbors=KNN_K, metric='cosine', n_jobs=-1)
        knn_final.fit(X_knn)
        
        # Predict on test set
        dists_te, neigh_te = knn_final.kneighbors(X_knn_test_norm, return_distance=True)
        
        # Convert distances to similarities
        if USE_CUML:
            sims_te = np.clip(1.0 - (dists_te ** 2) / 2.0, 0.0, 1.0).astype(np.float32)
        else:
            sims_te = np.clip(1.0 - dists_te, 0.0, 1.0).astype(np.float32)
        
        denom_te = np.maximum(sims_te.sum(axis=1, keepdims=True), 1e-8)
        
        # RANK-1: IA-weighted voting for test predictions
        for i in range(0, X_knn_test_norm.shape[0], KNN_BATCH):
            j = min(i + KNN_BATCH, X_knn_test_norm.shape[0])
            neigh_b = neigh_te[i:j]
            sims_b = sims_te[i:j]
            Y_nei = Y_knn[neigh_b]
            scores = ((sims_b[:, :, np.newaxis] * Y_nei * w_ia_broadcast).sum(axis=1) / denom_te[i:j]).astype(np.float32)
            test_pred_knn[i:j] = scores
        
        # NOTE: Scores are already normalized by sum of similarities
        # DO NOT apply per-protein max normalization - it destroys calibration
        # Scores are naturally in [0, 1] range from IA-weighted aggregation
        
        # This ensures each protein's scores are calibrated to [0, 1] range
        # Normalize OOF predictions
        # Normalize test predictions
        # ===== F1 EVALUATION AFTER TRAINING =====
        print('\n[KNN] Evaluating F1 performance on OOF predictions...')
        
        # Simple IA-weighted F1 calculation
        def calc_ia_f1(y_true, y_pred, weights, threshold=0.3):
            y_true_bin = (y_true > 0).astype(np.int8)
            y_pred_bin = (y_pred >= threshold).astype(np.int8)
            
            tp = (y_pred_bin & y_true_bin).sum(axis=0).astype(np.float64)
            pred_pos = y_pred_bin.sum(axis=0).astype(np.float64)
            true_pos = y_true_bin.sum(axis=0).astype(np.float64)
            
            w = weights.astype(np.float64)
            w_tp = float((w * tp).sum())
            w_pred = float((w * pred_pos).sum())
            w_true = float((w * true_pos).sum())
            
            precision = (w_tp / w_pred) if w_pred > 0 else 0.0
            recall = (w_tp / w_true) if w_true > 0 else 0.0
            f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
            
            return f1, precision, recall
        
        # Calculate F1 at multiple thresholds
        thresholds = [0.01, 0.02, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]
        print('\n  Threshold   F1      Precision  Recall')
        print('  ' + '-'*42)
        
        best_f1 = 0.0
        best_thr = 0.3
        
        for thr in thresholds:
            f1, prec, rec = calc_ia_f1(Y_knn, oof_pred_knn, weights_full, threshold=thr)
            print(f'  {thr:5.2f}      {f1:.4f}  {prec:.4f}     {rec:.4f}')
            if f1 > best_f1:
                best_f1 = f1
                best_thr = thr
        
        print('  ' + '-'*42)
        print(f'  Best: F1={best_f1:.4f} @ threshold={best_thr:.2f}')
        print()
        
        # RANK-1: Finite-value quality gates (Auditor requirement: prevents NaN/Inf propagation to GCN stacker)
        if not np.isfinite(oof_pred_knn).all():
            print('[WARN] KNN OOF contains NaN/Inf; clipping to valid range [0, 1]')
            oof_pred_knn = np.clip(oof_pred_knn, 0.0, 1.0)
        if not np.isfinite(test_pred_knn).all():
            print('[WARN] KNN test contains NaN/Inf; clipping to valid range [0, 1]')
            test_pred_knn = np.clip(test_pred_knn, 0.0, 1.0)
        
        np.save(knn_oof_path, oof_pred_knn)
        np.save(knn_test_path, test_pred_knn)
        np.save(knn_oof_compat, oof_pred_knn)
        np.save(knn_test_compat, test_pred_knn)
        print('Saved:', knn_oof_path)
        print('Saved:', knn_test_path)
        print('Saved:', knn_oof_compat)
        print('Saved:', knn_test_compat)
    
    # Checkpoint push (stub in standalone mode)
    STORE.maybe_push(
        stage='stage_07d_level1_knn',
        required_paths=[
            str((WORK_ROOT / 'features' / 'top_terms_13500.json').as_posix()),
            str(knn_oof_path.as_posix()),
            str(knn_test_path.as_posix()),
        ],
        note='Level-1 KNN (cosine) predictions using ESM2-3B embeddings (OOF + test).',
    )
    
    # Diagnostics: similarity distribution
    try:
        import matplotlib.pyplot as plt
        if oof_max_sim is not None:
            plt.figure(figsize=(10, 4))
            plt.hist(oof_max_sim, bins=50)
            plt.title('KNN OOF diagnostic: max cosine similarity to neighbours (per protein)')
            plt.xlabel('max similarity')
            plt.ylabel('count')
            plt.grid(True, alpha=0.3)
            plt.show()
    except Exception as e:
        print('KNN diagnostics skipped:', repr(e))


In [None]:
# CELL 06 - Generate Submission File from KNN Predictions
# This cell creates a submission.tsv file from the KNN test predictions

import json
from pathlib import Path
import numpy as np
import pandas as pd

print('[SUBMISSION] Generating submission file from KNN predictions...')

# Check if submission already exists
submission_path = WORK_ROOT / 'submission.tsv'
if submission_path.exists():
    print(f'  submission.tsv already exists at {submission_path}')
    print('  To regenerate, delete it first.')
else:
    # Load KNN test predictions
    knn_test_path = WORK_ROOT / 'features' / 'level1_preds' / 'test_pred_knn.npy'
    if not knn_test_path.exists():
        knn_test_path = WORK_ROOT / 'features' / 'test_pred_knn.npy'
    
    if not knn_test_path.exists():
        raise FileNotFoundError(f'Missing KNN test predictions. Expected at {knn_test_path}. Run the KNN training cell first.')
    
    preds = np.load(knn_test_path).astype(np.float32)
    print(f'  Loaded KNN predictions: {preds.shape}')
    
    # Load test IDs
    test_seq_path = WORK_ROOT / 'parsed' / 'test_seq.feather'
    if not test_seq_path.exists():
        raise FileNotFoundError(f'Missing {test_seq_path}. Run the data loading cell first.')
    
    test_ids = pd.read_feather(test_seq_path)['id'].astype(str)
    # Extract UniProt IDs from FASTA format (e.g., >sp|P12345|NAME)
    test_ids = test_ids.str.extract(r'\|(.+?)\|', expand=False).fillna(test_ids)
    print(f'  Loaded {len(test_ids)} test IDs')
    
    # Load term list
    terms_path = WORK_ROOT / 'features' / 'top_terms_13500.json'
    if not terms_path.exists():
        raise FileNotFoundError(f'Missing {terms_path}. Run the data loading cell first.')
    
    top_terms = json.loads(terms_path.read_text(encoding='utf-8'))
    print(f'  Loaded {len(top_terms)} terms')
    
    if preds.shape[1] != len(top_terms):
        raise ValueError(f'Shape mismatch: preds has {preds.shape[1]} terms, top_terms has {len(top_terms)}')
    
    # Load GO ontology for hierarchy propagation (optional but recommended)
    go_obo_path = WORK_ROOT / 'Train' / 'go-basic.obo'
    if go_obo_path.exists():
        print('  Loading GO ontology for hierarchy propagation...')
        
        def parse_obo(path: Path):
            parents = {}
            namespaces = {}
            cur_id, cur_ns = None, None
            with path.open('r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line == '[Term]':
                        if cur_id and cur_ns:
                            namespaces[cur_id] = cur_ns
                        cur_id, cur_ns = None, None
                    elif line.startswith('id: GO:'):
                        cur_id = line.split('id: ', 1)[1]
                    elif line.startswith('namespace:'):
                        cur_ns = line.split('namespace: ', 1)[1]
                    elif line.startswith('is_a:') and cur_id:
                        parent = line.split('is_a: ', 1)[1].split(' ! ')[0]
                        parents.setdefault(cur_id, set()).add(parent)
                if cur_id and cur_ns:
                    namespaces[cur_id] = cur_ns
            return parents, namespaces
        
        go_parents, go_namespaces = parse_obo(go_obo_path)
        
        # Apply hierarchy propagation (Max/Min)
        print('  Applying hierarchy propagation...')
        df_pred = pd.DataFrame(preds, columns=top_terms)
        term_set = set(top_terms)
        term_to_parents = {}
        term_to_children = {}
        
        for term in top_terms:
            parents = go_parents.get(term, set())
            if parents:
                parents = parents.intersection(term_set)
                if parents:
                    term_to_parents[term] = list(parents)
                    for p in parents:
                        term_to_children.setdefault(p, []).append(term)
        
        # Max Propagation (Child -> Parent)
        N_PROP_ITERS = 12
        for _ in range(N_PROP_ITERS):
            for child, parents in term_to_parents.items():
                child_scores = df_pred[child].values
                for parent in parents:
                    df_pred[parent] = np.maximum(df_pred[parent].values, child_scores)
        
        # Min Propagation (Parent -> Child)
        for _ in range(N_PROP_ITERS):
            for parent, children in term_to_children.items():
                parent_scores = df_pred[parent].values
                for child in children:
                    df_pred[child] = np.minimum(df_pred[child].values, parent_scores)
        
        preds = df_pred.values.astype(np.float32)
        print('  Hierarchy propagation complete')
    else:
        print('  [WARN] GO ontology not found; skipping hierarchy propagation')
        df_pred = pd.DataFrame(preds, columns=top_terms)
    
    # Apply per-aspect thresholds (if available)
    thr_path = WORK_ROOT / 'features' / 'aspect_thresholds.json'
    if thr_path.exists():
        print('  Applying per-aspect thresholds...')
        ASPECT_THRESHOLDS = json.loads(thr_path.read_text(encoding='utf-8'))
        print(f'    Thresholds: {ASPECT_THRESHOLDS}')
        
        ns_to_aspect = {
            'molecular_function': 'MF',
            'biological_process': 'BP',
            'cellular_component': 'CC',
        }
        
        if go_obo_path.exists():
            aspects = np.array([
                ns_to_aspect.get(go_namespaces.get(t, 'unknown'), 'UNK')
                for t in top_terms
            ], dtype='<U3')
        else:
            # Fallback: use ALL threshold
            aspects = np.array(['ALL'] * len(top_terms), dtype='<U3')
        
        thr_vec = np.array([
            float(ASPECT_THRESHOLDS.get(a, ASPECT_THRESHOLDS.get('ALL', 0.3)))
            for a in aspects
        ], dtype=np.float32)
        
        pred_np = preds
        pred_np = np.where(pred_np >= thr_vec[None, :], pred_np, 0.0).astype(np.float32)
        df_pred = pd.DataFrame(pred_np, columns=top_terms)
    else:
        print('  [WARN] Aspect thresholds not found; using default threshold 0.3')
        df_pred = pd.DataFrame(
            np.where(preds >= 0.05, preds, 0.0).astype(np.float32),
            columns=top_terms
        )
    
    # Format submission (CAFA rules)
    print('  Formatting submission...')
    df_pred['EntryID'] = test_ids.values
    submission = df_pred.melt(id_vars='EntryID', var_name='term', value_name='score')
    
    # Enforce score range + remove zeros
    submission['score'] = submission['score'].clip(lower=0.0, upper=1.0)
    submission = submission[submission['score'] > 0.0]
    
    # Keep top 1500 per protein (CAFA rule)
    submission = submission.sort_values(['EntryID', 'score'], ascending=[True, False])
    submission = submission.groupby('EntryID', sort=False).head(1500)
    
    # Write with <= 3 significant figures
    submission.to_csv(
        submission_path,
        sep='\t',
        index=False,
        header=False,
        float_format='%.3g',
    )
    
    print(f'\n✅ Submission saved to {submission_path}')
    print(f'   Total predictions: {len(submission):,}')
    print(f'   Proteins: {submission["EntryID"].nunique():,}')
    print(f'   Avg predictions/protein: {len(submission) / submission["EntryID"].nunique():.1f}')


In [None]:
# CELL 07 - Submit to Kaggle Competition
# This cell submits the generated submission.tsv to the CAFA-6 competition

import os
import subprocess
from pathlib import Path

# Load Kaggle credentials from environment (including .env if present)
def _load_kaggle_credentials():
    """Load Kaggle credentials from .env file or environment variables."""
    # Try to load from .env file
    env_path = Path.cwd() / '.env'
    if env_path.exists():
        print('[KAGGLE] Loading credentials from .env file...')
        try:
            with open(env_path, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith('#') and '=' in line:
                        key, value = line.split('=', 1)
                        key = key.strip()
                        value = value.strip().strip('"').strip("'")
                        if key in ['KAGGLE_USERNAME', 'KAGGLE_KEY'] and value:
                            os.environ[key] = value
        except Exception as e:
            print(f'  [WARN] Failed to load .env: {e}')
    
    # Check if credentials are set
    username = os.environ.get('KAGGLE_USERNAME', '').strip()
    key = os.environ.get('KAGGLE_KEY', '').strip()
    
    if username and key:
        print(f'  ✓ Kaggle credentials loaded: {username}')
        return True
    else:
        print('  ✗ Kaggle credentials not found in environment or .env')
        return False

# Configuration
COMPETITION_NAME = 'cafa-6-protein-function-prediction'
SUBMISSION_MESSAGE = 'KNN standalone submission (ESM2-3B embeddings)'

# Check if submission file exists
submission_path = WORK_ROOT / 'submission.tsv'
if not submission_path.exists():
    raise FileNotFoundError(
        f'Submission file not found at {submission_path}. '
        'Run the submission generation cell first.'
    )

print(f'[KAGGLE] Preparing to submit to {COMPETITION_NAME}...')

# Load credentials
if not _load_kaggle_credentials():
    print('\n' + '='*60)
    print('ERROR: Kaggle credentials not configured')
    print('='*60)
    print('\nPlease set KAGGLE_USERNAME and KAGGLE_KEY:')
    print('\n1. Create a .env file in the project root with:')
    print('   KAGGLE_USERNAME=your_username')
    print('   KAGGLE_KEY=your_api_key')
    print('\n2. Or set environment variables:')
    print('   export KAGGLE_USERNAME=your_username')
    print('   export KAGGLE_KEY=your_api_key')
    print('\nGet your API key from: https://www.kaggle.com/settings')
    print('='*60)
    raise RuntimeError('Kaggle credentials not configured')

# Check if Kaggle CLI is installed
try:
    result = subprocess.run(
        ['kaggle', '--version'],
        capture_output=True,
        text=True
    )
    print(f'  ✓ Kaggle CLI version: {result.stdout.strip()}')
except FileNotFoundError:
    print('\n' + '='*60)
    print('ERROR: Kaggle CLI not installed')
    print('='*60)
    print('\nPlease install the Kaggle CLI:')
    print('  pip install kaggle')
    print('\nOr if using conda:')
    print('  conda install -c conda-forge kaggle')
    print('='*60)
    raise RuntimeError('Kaggle CLI not installed')

# Build the kaggle command
print(f'\n[KAGGLE] Submitting {submission_path}...')
print(f'  Message: "{SUBMISSION_MESSAGE}"')

cmd = [
    'kaggle',
    'competitions',
    'submit',
    '-c', COMPETITION_NAME,
    '-f', str(submission_path),
    '-m', SUBMISSION_MESSAGE
]

# Execute submission
try:
    result = subprocess.run(
        cmd,
        check=True,
        capture_output=True,
        text=True,
        encoding='utf-8',
        errors='replace'
    )
    print('\n' + '='*60)
    print('✅ SUBMISSION SUCCESSFUL!')
    print('='*60)
    if result.stdout:
        print('\nKaggle output:')
        print(result.stdout)
    if result.stderr:
        print('\nAdditional info:')
        print(result.stderr)
    print('\nCheck your submission at:')
    print(f'https://www.kaggle.com/competitions/{COMPETITION_NAME}/submissions')
    print('='*60)
except subprocess.CalledProcessError as e:
    print('\n' + '='*60)
    print('❌ SUBMISSION FAILED')
    print('='*60)
    if e.stderr:
        print('\nError output:')
        print(e.stderr)
    if e.stdout:
        print('\nStdout:')
        print(e.stdout)
    print('\nCommon issues:')
    print('  - Competition rules not accepted')
    print('  - Invalid submission format')
    print('  - Daily submission limit reached')
    print('  - Incorrect credentials')
    print('='*60)
    raise
