In [1]:
# CELL 01 - Bootstrap repo (clone/update)
# Colab bootstrap (no external notebook/script execution)
from pathlib import Path
import os
import shutil
import subprocess

# Both repo settings can be overridden through the environment.
REPO_URL = os.getenv('CAFA_REPO_GIT_URL', 'https://github.com/PeterOla/cafa-6-protein-function-prediction.git')
REPO_DIR = Path(os.getenv('CAFA_REPO_DIR', '/content/cafa-6-protein-function-prediction'))
# Directory that git commands run from: /content when it exists, otherwise the filesystem root.
_content_root = Path('/content')
SAFE_CWD = _content_root if _content_root.exists() else Path('/')
def run(cmd: list[str]) -> None:
    """Echo ``cmd``, execute it from ``SAFE_CWD`` and relay its captured output.

    Args:
        cmd: Command and arguments; joined with spaces for the echo and error message.

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    rendered = ' '.join(cmd)
    print('+', rendered)
    completed = subprocess.run(cmd, cwd=str(SAFE_CWD), capture_output=True, text=True)
    # stdout is relayed before stderr; a stream holding only whitespace is skipped.
    for stream in (completed.stdout, completed.stderr):
        if stream.strip():
            print(stream)
    if completed.returncode:
        raise RuntimeError(f'Command failed (exit={completed.returncode}): {rendered}')
# Leave the checkout before touching it so it can be replaced on a re-run.
os.chdir(SAFE_CWD)
_repo = str(REPO_DIR)
_is_checkout = REPO_DIR.exists() and (REPO_DIR / '.git').is_dir()
if not _is_checkout:
    # Whatever sits at REPO_DIR is not a git checkout: discard it, then clone afresh.
    if REPO_DIR.exists():
        shutil.rmtree(REPO_DIR, ignore_errors=True)
    run(['git', 'clone', '--depth', '1', REPO_URL, _repo])
else:
    # Existing checkout: shallow-fetch, then hard-reset to the remote's default branch.
    for _git_args in (['fetch', '--depth', '1', 'origin'], ['reset', '--hard', 'origin/HEAD']):
        run(['git', '-C', _repo, *_git_args])
os.chdir(REPO_DIR)
print('CWD:', Path.cwd())

+ git clone --depth 1 https://github.com/PeterOla/cafa-6-protein-function-prediction.git /content/cafa-6-protein-function-prediction
Cloning into '/content/cafa-6-protein-function-prediction'...

CWD: /content/cafa-6-protein-function-prediction


In [None]:
# Install missing dependencies
!pip install py-boost

In [None]:
# CELL 02 - Install dependencies
import importlib.util
import os
import subprocess
import sys

def _detect_kaggle() -> bool:
    return bool(os.environ.get('KAGGLE_KERNEL_RUN_TYPE') or os.environ.get('KAGGLE_URL_BASE') or os.environ.get('KAGGLE_DATA_PROXY_URL'))
def _detect_colab() -> bool:
    return bool(os.environ.get('COLAB_RELEASE_TAG') or os.environ.get('COLAB_GPU') or os.environ.get('COLAB_TPU_ADDR'))
IS_KAGGLE = _detect_kaggle()
# Kaggle takes precedence: the Colab probe is only consulted when Kaggle was not detected.
IS_COLAB = False if IS_KAGGLE else _detect_colab()
print(
    'Environment: Kaggle Detected' if IS_KAGGLE
    else 'Environment: Colab Detected' if IS_COLAB
    else 'Environment: Local Detected'
)
# Only the exact value '1' (after trimming) forces the full pip install on Kaggle.
FORCE_PIP = os.getenv('CAFA_FORCE_PIP', '0').strip() == '1'

def _pip_install(args: list[str]) -> None:
    print('+', sys.executable, '-m', 'pip', 'install', *args)
    subprocess.check_call([sys.executable, '-m', 'pip', '-q', 'install', *args])
# Kaggle has a heavily preinstalled environment; avoid upgrading core packages (pandas/notebook/requests/tornado/RAPIDS).
# Only install a small set if missing, unless CAFA_FORCE_PIP=1.
if IS_KAGGLE and not FORCE_PIP:
    minimal = ['obonet', 'biopython', 'pyyaml', 'py-boost']
    # pip distribution name -> importable top-level module. find_spec() only understands
    # module names; probing it with the distribution name reports biopython, pyyaml and
    # py-boost as missing even when installed, which re-runs pip on every execution.
    _import_names = {'biopython': 'Bio', 'pyyaml': 'yaml', 'py-boost': 'py_boost'}
    missing = [p for p in minimal if importlib.util.find_spec(_import_names.get(p, p)) is None]
    if missing:
        _pip_install(missing)
    else:
        print('Kaggle: skipping pip install (already satisfied). Set CAFA_FORCE_PIP=1 to force.')
else:
    # Everywhere else (or when forced) install the full pinned list; the path is relative
    # to the current working directory.
    _pip_install(['-r', 'requirements.txt'])


In [None]:
# CELL 02b - Safety switches (recommended for "reuse artefacts" runs)
# This notebook is designed to *reuse* an existing checkpoint dataset. To avoid accidental multi-GB uploads,
# we disable checkpoint pushes by default. Opt-in by setting CAFA_CHECKPOINT_PUSH=1.

import os

# Each default is applied only when the variable is not already present in the environment.
for _name, _default in (
    ('CAFA_CHECKPOINT_PULL', '1'),
    ('CAFA_CHECKPOINT_REQUIRED', '0'),
    # In Colab, pulling the entire checkpoint dataset in one go can get the kaggle process SIGKILL'd (return code -9).
    # Default to lean pulls: grab parsed + external first; only pull features.zip when you actually need embeddings/preds.
    ('CAFA_CHECKPOINT_PULL_FILES', 'parsed.zip,external.zip'),
    ('CAFA_CHECKPOINT_PUSH', '0'),
    ('CAFA_CHECKPOINT_PUSH_EXISTING', '0'),
):
    os.environ.setdefault(_name, _default)

# Echo the effective values so the run log shows which switches are active.
print('CAFA_CHECKPOINT_PULL:', os.getenv('CAFA_CHECKPOINT_PULL'))
print('CAFA_CHECKPOINT_REQUIRED:', os.getenv('CAFA_CHECKPOINT_REQUIRED'))
print('CAFA_CHECKPOINT_PULL_FILES:', os.getenv('CAFA_CHECKPOINT_PULL_FILES'))
print('CAFA_CHECKPOINT_PUSH :', os.getenv('CAFA_CHECKPOINT_PUSH'))
print('CAFA_CHECKPOINT_PUSH_EXISTING:', os.getenv('CAFA_CHECKPOINT_PUSH_EXISTING'))


In [None]:
# CELL 03 - Solution: 1. SETUP, CONFIG & DIAGNOSTICS
# 1. SETUP, CONFIG & DIAGNOSTICS
# ==========================================
# HARDWARE: CPU (Standard)
# ==========================================
import json
import os
import shutil
import subprocess
import sys
import time
import zipfile
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# ------------------------------------------
# Environment Detection & Paths
# ------------------------------------------
# Kaggle images can have `google-colab` installed; never use `import google.colab` as a signal.

def _detect_kaggle() -> bool:
    # Kaggle kernels reliably set at least one of these env vars.
    return bool(
        os.environ.get('KAGGLE_KERNEL_RUN_TYPE')
        or os.environ.get('KAGGLE_URL_BASE')
        or os.environ.get('KAGGLE_DATA_PROXY_URL')
    )


def _detect_colab() -> bool:
    # Colab sets these env vars; this avoids false positives on Kaggle.
    return bool(os.environ.get('COLAB_RELEASE_TAG') or os.environ.get('COLAB_GPU') or os.environ.get('COLAB_TPU_ADDR'))


# Resolve the runtime once: Kaggle takes precedence and Colab is only probed otherwise.
IS_KAGGLE = _detect_kaggle()
IS_COLAB = (not IS_KAGGLE) and _detect_colab()

# INPUT_ROOT is where the competition data is searched for; WORKING_ROOT is the base of
# the writable work and cache directories created further down.
if IS_KAGGLE:
    print('Environment: Kaggle Detected')
    INPUT_ROOT = Path('/kaggle/input')
    WORKING_ROOT = Path('/kaggle/working')
    if INPUT_ROOT.exists():
        # Diagnostic listing of every file under the input root.
        for dirname, _, filenames in os.walk(str(INPUT_ROOT)):
            for filename in filenames:
                print(os.path.join(dirname, filename))
elif IS_COLAB:
    print('Environment: Colab Detected')
    INPUT_ROOT = Path(os.environ.get('CAFA_INPUT_ROOT', str(Path('/content'))))
    WORKING_ROOT = Path(os.environ.get('CAFA_WORKING_ROOT', str(Path('/content'))))
else:
    print('Environment: Local Detected')
    CURRENT_DIR = Path.cwd()
    # When launched from the notebooks/ folder, the project root is one level up.
    if CURRENT_DIR.name == 'notebooks':
        PROJECT_ROOT = CURRENT_DIR.parent
    else:
        PROJECT_ROOT = CURRENT_DIR
    INPUT_ROOT = Path(os.environ.get('CAFA_INPUT_ROOT', str(PROJECT_ROOT)))
    WORKING_ROOT = Path(os.environ.get('CAFA_WORKING_ROOT', str(PROJECT_ROOT / 'artefacts_local')))
    # parents=True: a CAFA_WORKING_ROOT override may name a nested path whose parents do
    # not exist yet; without it mkdir raises FileNotFoundError.
    WORKING_ROOT.mkdir(parents=True, exist_ok=True)

# ------------------------------------------
# Local cache roots (ephemeral) + published artefacts root
# ------------------------------------------
# IMPORTANT: we always write locally first (every runtime needs a write path),
# but the Kaggle Dataset is the *single source of truth* for resumability.

WORK_ROOT = Path(os.environ.get('CAFA_WORK_ROOT', str(WORKING_ROOT / 'work')))
WORK_ROOT.mkdir(parents=True, exist_ok=True)
# One sub-folder per artefact family; all are created up front.
for _stage_dir in ('parsed', 'features', 'external'):
    (WORK_ROOT / _stage_dir).mkdir(parents=True, exist_ok=True)

# Keep caches OUT of WORK_ROOT so we never accidentally publish them.
CACHE_ROOT = Path(os.environ.get('CAFA_CACHE_ROOT', str(WORKING_ROOT / 'cache')))
CACHE_ROOT.mkdir(parents=True, exist_ok=True)
# Point the Hugging Face / torch caches below CACHE_ROOT unless the caller already set them.
for _cache_var, _cache_subdir in (
    ('HF_HOME', 'hf_home'),
    ('TRANSFORMERS_CACHE', 'hf_home'),
    ('HF_HUB_CACHE', 'hf_hub'),
    ('TORCH_HOME', 'torch_home'),
):
    os.environ.setdefault(_cache_var, str(CACHE_ROOT / _cache_subdir))

# ------------------------------------------
# Dataset Discovery (competition data)
# ------------------------------------------

# Competition identifier: used as the folder name looked up under INPUT_ROOT and as the
# `-c` argument of `kaggle competitions download`.
DATASET_SLUG = 'cafa-6-protein-function-prediction'


def _score_dataset_dir(p: Path) -> int:
    return (
        int((p / 'Train').exists())
        + int((p / 'Test').exists())
        + int((p / 'IA.tsv').exists())
        + int((p / 'sample_submission.tsv').exists())
    )


def find_dataset_root(input_root: Path, dataset_slug: str) -> Path:
    """Locate the directory that holds the competition files.

    Resolution order:
      1. ``CAFA_DATASET_ROOT`` from the environment (must exist; returned without scoring).
      2. ``input_root / dataset_slug`` when it scores at least 2.
      3. ``input_root`` itself when it scores at least 2.
      4. The best-scoring directory one or two levels below ``input_root``.

    Args:
        input_root: Directory to search.
        dataset_slug: Folder name tried first under ``input_root``.

    Returns:
        The selected dataset directory.

    Raises:
        FileNotFoundError: if the override points at a missing path, or nothing scores >= 2.
    """
    override = (os.environ.get('CAFA_DATASET_ROOT') or '').strip()
    if override:
        p = Path(override)
        if p.exists():
            return p
        raise FileNotFoundError(f'CAFA_DATASET_ROOT is set but does not exist: {p}')

    candidate = input_root / dataset_slug
    if candidate.exists() and _score_dataset_dir(candidate) >= 2:
        return candidate

    if _score_dataset_dir(input_root) >= 2:
        return input_root

    # Shallow scan: input_root/* and input_root/*/* (helps Colab + GitHub clones)
    candidates: list[Path] = []
    if input_root.exists():
        for p in input_root.iterdir():
            if not p.is_dir():
                continue
            candidates.append(p)
            try:
                candidates.extend(q for q in p.iterdir() if q.is_dir())
            except OSError:
                # Best effort: a sub-folder that cannot be listed is skipped.
                pass

    if candidates:
        # Highest score wins; ties are broken on the path string so the choice is the
        # same on every run (a set-based sort would leave tie order to hash randomisation).
        best = min(candidates, key=lambda c: (-_score_dataset_dir(c), str(c)))
        if _score_dataset_dir(best) >= 2:
            return best

    auto_flag = (os.environ.get('CAFA_COLAB_AUTO_DOWNLOAD') or '').strip() or '<unset>'

    # The credential probe below only enriches the error message; nothing is downloaded here.
    # Colab rule: secrets must be fetched ONLY via google.colab.userdata.get(...)
    if IS_COLAB:
        try:
            from google.colab import userdata  # type: ignore

            has_user = bool((userdata.get('KAGGLE_USERNAME') or '').strip())
            has_key = bool((userdata.get('KAGGLE_KEY') or '').strip())
        except Exception:
            has_user = False
            has_key = False
    else:
        has_user = bool((os.environ.get('KAGGLE_USERNAME') or '').strip())
        has_key = bool((os.environ.get('KAGGLE_KEY') or '').strip())

    raise FileNotFoundError(
        f'Dataset not found under {input_root}. '
        'Note: a GitHub clone does not include the CAFA competition files.\n'
        f'CAFA_COLAB_AUTO_DOWNLOAD={auto_flag}; Kaggle creds present={has_user and has_key}.\n'
        'Fix options:\n'
        '  1) Set CAFA_DATASET_ROOT to a folder containing Train/ and Test/ (e.g. Drive).\n'
        '  2) (Colab) Set Colab secrets KAGGLE_USERNAME and KAGGLE_KEY, then set CAFA_COLAB_AUTO_DOWNLOAD=1.'
    )


def _maybe_colab_auto_download_competition() -> None:
    if not IS_COLAB:
        return

    # Default-on in Colab.
    # - Disable by explicitly setting CAFA_COLAB_AUTO_DOWNLOAD=0.
    os.environ.setdefault('CAFA_COLAB_AUTO_DOWNLOAD', '1')

    flag = (os.environ.get('CAFA_COLAB_AUTO_DOWNLOAD') or '').strip()
    if flag == '0':
        print('Colab auto-download disabled (CAFA_COLAB_AUTO_DOWNLOAD=0).')
        return

    target_dir = Path(os.environ.get('CAFA_COLAB_DATA_DIR', str(Path('/content/cafa6_data'))))
    target_dir.mkdir(parents=True, exist_ok=True)
    if _score_dataset_dir(target_dir) >= 2:
        os.environ['CAFA_DATASET_ROOT'] = str(target_dir)
        return

    # Install kaggle CLI if missing
    try:
        subprocess.run(['kaggle', '--version'], check=True, capture_output=True, text=True)
    except Exception:
        subprocess.check_call([sys.executable, '-m', 'pip', '-q', 'install', 'kaggle'])

    # Colab rule: fetch secrets ONLY via google.colab.userdata.get(...)
    try:
        from google.colab import userdata  # type: ignore
    except Exception as e:
        raise RuntimeError('Colab detected but google.colab.userdata is unavailable.') from e

    username = (userdata.get('KAGGLE_USERNAME') or '').strip()
    key = (userdata.get('KAGGLE_KEY') or '').strip()
    if (not username) or (not key):
        raise RuntimeError(
            'CAFA_COLAB_AUTO_DOWNLOAD=1 but Kaggle API auth is missing. '
            'Set Colab secrets KAGGLE_USERNAME and KAGGLE_KEY.'
        )

    # Export into env for downstream subprocesses (but do NOT *source* from env in Colab).
    os.environ['KAGGLE_USERNAME'] = username
    os.environ['KAGGLE_KEY'] = key
    env = os.environ.copy()

    print(f'Downloading competition data to {target_dir} via Kaggle API...')
    # Do NOT rely on kaggle --unzip; unzip ourselves.
    p = subprocess.run(
        ['kaggle', 'competitions', 'download', '-c', DATASET_SLUG, '-p', str(target_dir)],
        text=True,
        capture_output=True,
        env=env,
    )
    if p.returncode != 0:
        print(p.stdout)
        print(p.stderr)
        raise RuntimeError(
            'Failed to download competition data. See logs above (you may need to accept the competition rules).'
        )

    for z in target_dir.glob('*.zip'):
        try:
            print(f'Unzipping {z.name}...')
            with zipfile.ZipFile(z, 'r') as zf:
                zf.extractall(target_dir)
            z.unlink()
        except Exception as e:
            print(f'Warning: failed to unzip {z}: {e}')

    os.environ['CAFA_DATASET_ROOT'] = str(target_dir)


# On Colab this may download the competition data; on other runtimes it returns immediately.
_maybe_colab_auto_download_competition()

DATASET_ROOT = find_dataset_root(INPUT_ROOT, DATASET_SLUG)
print(f'DATASET_ROOT: {DATASET_ROOT}')

# Define input paths
_train_dir = DATASET_ROOT / 'Train'
_test_dir = DATASET_ROOT / 'Test'
PATH_IA = DATASET_ROOT / 'IA.tsv'
PATH_SAMPLE_SUB = DATASET_ROOT / 'sample_submission.tsv'
PATH_TRAIN_FASTA = _train_dir / 'train_sequences.fasta'
PATH_TRAIN_TERMS = _train_dir / 'train_terms.tsv'
PATH_TRAIN_TAXON = _train_dir / 'train_taxonomy.tsv'
PATH_GO_OBO = _train_dir / 'go-basic.obo'
PATH_TEST_FASTA = _test_dir / 'testsuperset.fasta'
PATH_TEST_TAXON = _test_dir / 'testsuperset-taxon-list.tsv'

# ------------------------------------------
# Sanity Checks
# ------------------------------------------

# Only these four files are mandatory at this point; the remaining paths are not checked here.
required = {
    'IA.tsv': PATH_IA,
    'Train/train_sequences.fasta': PATH_TRAIN_FASTA,
    'Train/train_terms.tsv': PATH_TRAIN_TERMS,
    'Train/go-basic.obo': PATH_GO_OBO,
}
missing = dict(filter(lambda item: not item[1].exists(), required.items()))
if missing:
    raise FileNotFoundError(f'Missing files: {missing}')
print('All required inputs found.')

# ------------------------------------------
# Fail-fast: FASTA readability checks (path issues)
# ------------------------------------------

def _fasta_smoke_test(path: Path, label: str, max_lines: int = 20000) -> None:
    path = Path(path)

    if not path.exists():
        raise FileNotFoundError(
            f"{label} FASTA not found: {path}\n"
            f"DATASET_ROOT={DATASET_ROOT}\n"
            "If you're on Colab, ensure the competition data downloaded/unzipped correctly,\n"
            "or set CAFA_DATASET_ROOT to the folder containing Train/ and Test/."
        )

    try:
        size = path.stat().st_size
    except Exception:
        size = None

    if size is not None and size == 0:
        raise RuntimeError(f"{label} FASTA file is empty (0 bytes): {path}")

    headers = 0
    first_nonempty = None
    try:
        with path.open('r', encoding='utf-8', errors='ignore') as f:
            for i, line in enumerate(f):
                s = line.strip()
                if not s:
                    continue
                if first_nonempty is None:
                    first_nonempty = s
                if s.startswith('>'):
                    headers += 1
                if i >= max_lines:
                    break
    except Exception as e:
        raise RuntimeError(f"Failed reading {label} FASTA at {path}: {type(e).__name__}: {e}")

    if first_nonempty is None:
        raise RuntimeError(f"{label} FASTA appears empty/unreadable (no non-empty lines): {path}")

    if not first_nonempty.startswith('>'):
        raise RuntimeError(
            f"{label} FASTA does not look like FASTA (first content line doesn't start with '>'): {path}\n"
            f"First line was: {first_nonempty[:120]!r}"
        )

    if headers == 0:
        raise RuntimeError(f"{label} FASTA had zero headers in the first {max_lines} lines: {path}")


print('FASTA smoke tests:')
# The training FASTA is mandatory: any problem raises and stops the cell.
_fasta_smoke_test(PATH_TRAIN_FASTA, 'Train')
print('  Train FASTA: OK')

# The test FASTA is optional at this stage: it is validated only when present.
if not PATH_TEST_FASTA.exists():
    print(f'  Test FASTA: MISSING at {PATH_TEST_FASTA} (continuing; some steps may fail later)')
else:
    _fasta_smoke_test(PATH_TEST_FASTA, 'Test')
    print('  Test FASTA: OK')

# ------------------------------------------
# Checkpoint store (Kaggle Dataset = single source of truth)
# ------------------------------------------

def _get_secret(name: str) -> str:
    # Colab rule: secrets must be fetched ONLY via google.colab.userdata.get(...).
    if IS_COLAB:
        try:
            from google.colab import userdata  # type: ignore

            return (userdata.get(name) or '').strip()
        except Exception:
            return ''

    # Non-Colab: allow env var -> Kaggle Secrets.
    v = (os.environ.get(name, '') or '').strip()
    if v:
        return v

    try:
        from kaggle_secrets import UserSecretsClient  # type: ignore

        v = (UserSecretsClient().get_secret(name) or '').strip()
        if v:
            return v
    except Exception:
        pass

    return ''


# Kaggle Secrets are NOT automatically environment variables.
# Resolve CAFA_FORCE_REBUILD via the same secret lookup and export it so later cells
# (which use os.getenv) behave consistently.
_raw_force_rebuild = (_get_secret('CAFA_FORCE_REBUILD') or '').strip()
# An environment value that is already present wins over the secret.
if _raw_force_rebuild and 'CAFA_FORCE_REBUILD' not in os.environ:
    os.environ['CAFA_FORCE_REBUILD'] = _raw_force_rebuild


def _truthy(v: str) -> bool:
    return str(v).strip().lower() in {'1', 'true', 'yes', 'y'}


print('CAFA_FORCE_REBUILD (env):', repr(os.environ.get('CAFA_FORCE_REBUILD', '')))
print('CAFA_FORCE_REBUILD (effective):', int(_truthy(os.environ.get('CAFA_FORCE_REBUILD', '0'))))


# Checkpoint dataset id (<user>/<slug>); the second secret name is accepted as a fallback.
CHECKPOINT_DATASET_ID = (
    _get_secret('CAFA_CHECKPOINT_DATASET_ID')
    or _get_secret('CAFA_KAGGLE_DATASET_ID')
)
CHECKPOINT_DATASET_TITLE = os.environ.get('CAFA_CHECKPOINT_DATASET_TITLE', 'CAFA6 Checkpoints').strip()
# Parse both switches with _truthy so 'true'/'yes' behave like '1', matching how
# CAFA_FORCE_REBUILD is interpreted above.
CHECKPOINT_PULL = _truthy(os.environ.get('CAFA_CHECKPOINT_PULL', '1'))
# Pushing is opt-in: when the variable is unset the fallback is OFF, so running this cell
# on its own can never start an upload unless CAFA_CHECKPOINT_PUSH is set explicitly.
CHECKPOINT_PUSH = _truthy(os.environ.get('CAFA_CHECKPOINT_PUSH', '0'))
MANIFEST_PATH = WORK_ROOT / 'manifest.json'


def _get_kaggle_token() -> str:
    return _get_secret('KAGGLE_API_TOKEN')


def _get_kaggle_user_key() -> tuple[str, str]:
    # Kaggle CLI expects Kaggle API credentials: username + key.
    username = _get_secret('KAGGLE_USERNAME')
    key = _get_secret('KAGGLE_KEY')
    if username and key:
        return username, key

    # Back-compat: allow KAGGLE_API_TOKEN to carry either JSON ({username,key}) or 'username:key'.
    tok = _get_kaggle_token()
    if tok:
        try:
            obj = json.loads(tok)
            username = (obj.get('username') or '').strip()
            key = (obj.get('key') or '').strip()
            if username and key:
                return username, key
        except Exception:
            pass

        if ':' in tok:
            u, k = tok.split(':', 1)
            username = u.strip()
            key = k.strip()
            if username and key:
                return username, key

    return '', ''


def _kaggle_env(require: bool = False) -> dict[str, str]:
    env = os.environ.copy()
    username, key = _get_kaggle_user_key()
    if username and key:
        # Export into both subprocess env (this call) and process env (subsequent cells).
        env['KAGGLE_USERNAME'] = username
        env['KAGGLE_KEY'] = key
        os.environ['KAGGLE_USERNAME'] = username
        os.environ['KAGGLE_KEY'] = key

    if require and (not env.get('KAGGLE_USERNAME') or not env.get('KAGGLE_KEY')):
        raise RuntimeError(
            'Kaggle API auth missing. The `kaggle` CLI requires `KAGGLE_USERNAME` + `KAGGLE_KEY` '
            '(set them as env vars or Kaggle/Colab secrets). Or attach the '
            'checkpoint dataset as an Input so `STORE.pull()` can use the mounted copy.'
        )

    return env


def _ensure_kaggle_cli() -> None:
    """Make sure the ``kaggle`` command runs, pip-installing the package when it does not.

    Raises:
        subprocess.CalledProcessError: if the install fails or the CLI still exits non-zero.
    """
    version_probe = ['kaggle', '--version']
    try:
        subprocess.run(version_probe, check=True, capture_output=True, text=True)
    except Exception:
        # Any probe failure (missing binary, non-zero exit) triggers an install and a re-check.
        subprocess.check_call([sys.executable, '-m', 'pip', '-q', 'install', 'kaggle'])
        subprocess.run(version_probe, check=True)


def _copy_merge(src: Path, dst: Path) -> None:
    src = Path(src)
    dst = Path(dst)

    for p in src.rglob('*'):
        if p.is_dir():
            continue
        rel = p.relative_to(src)
        out = dst / rel
        out.parent.mkdir(parents=True, exist_ok=True)
        shutil.copy2(p, out)


def _maybe_unpack_dir_mode_zips(work_root: Path) -> None:
    # Kaggle CLI with `--dir-mode zip` uploads directories as `parsed.zip`, `external.zip`, etc.
    # We unpack them back into folders so the pipeline can resume from `WORK_ROOT/{parsed,external,features}/...`.
    work_root = Path(work_root)
    for name in ['parsed', 'external', 'features']:
        zpath = work_root / f'{name}.zip'
        if not zpath.exists():
            continue

        target_dir = work_root / name
        target_dir.mkdir(parents=True, exist_ok=True)
        needs_unpack = (not any(target_dir.rglob('*')))
        if not needs_unpack:
            # Folder already populated; keep the zip as-is (it might be a newer version).
            continue

        print(f'Unpacking checkpoint archive: {zpath} -> {target_dir}')
        with zipfile.ZipFile(zpath, 'r') as zf:
            zf.extractall(target_dir)
        try:
            zpath.unlink()
        except Exception:
            pass


def _load_manifest() -> dict:
    if MANIFEST_PATH.exists():
        try:
            return json.loads(MANIFEST_PATH.read_text(encoding='utf-8'))
        except Exception:
            return {}
    return {}


def _update_manifest(stage: str, required_paths: list[Path], note: str = '') -> None:
    m = _load_manifest()
    stages = m.get('stages', {})
    files = []

    for p in required_paths:
        p = Path(p)
        rel = str(p.relative_to(WORK_ROOT)) if str(p).startswith(str(WORK_ROOT)) else str(p)
        files.append({'path': rel, 'bytes': int(p.stat().st_size) if p.exists() else None})

    stages[stage] = {
        'ts_utc': time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime()),
        'note': note,
        'files': files,
    }

    m['stages'] = stages
    MANIFEST_PATH.write_text(json.dumps(m, indent=2), encoding='utf-8')


def _stage_files_signature(required_paths: list[Path]) -> list[dict]:
    sig = []
    for p in required_paths:
        p = Path(p)
        rel = str(p.relative_to(WORK_ROOT)) if str(p).startswith(str(WORK_ROOT)) else str(p)
        sig.append({'path': rel, 'bytes': int(p.stat().st_size) if p.exists() else None})
    return sorted(sig, key=lambda x: x['path'])


@dataclass
class KaggleCheckpointStore:
    work_root: Path
    dataset_id: str
    dataset_title: str
    pull_enabled: bool
    push_enabled: bool
    input_root: Path
    is_kaggle: bool

    @property
    def mount_dir(self) -> Path | None:
        if not self.is_kaggle or not self.dataset_id:
            return None
        slug = self.dataset_id.split('/')[-1]
        p = self.input_root / slug
        return p if p.exists() else None

    def pull(self) -> None:
        if not self.pull_enabled:
            print('Checkpoint pull disabled (CAFA_CHECKPOINT_PULL=0).')
            return

        checkpoint_required = str(os.environ.get('CAFA_CHECKPOINT_REQUIRED', '1')).strip().lower() in {'1', 'true', 'yes'}
        if not self.dataset_id:
            msg = 'Missing CAFA_CHECKPOINT_DATASET_ID=<user>/<slug>; cannot resume.'
            if checkpoint_required:
                raise ValueError(msg)
            print('WARNING: ' + msg)
            return

        if self.mount_dir is not None:
            print(f'Pulling checkpoints from Kaggle mounted dataset: {self.mount_dir}')
            _copy_merge(self.mount_dir, self.work_root)
            _maybe_unpack_dir_mode_zips(self.work_root)
            return

        print(f'Downloading checkpoints from Kaggle API: {self.dataset_id}')
        _ensure_kaggle_cli()
        env = _kaggle_env(require=checkpoint_required)

        if not env.get('KAGGLE_USERNAME') or not env.get('KAGGLE_KEY'):
            msg = (
                'Kaggle API auth missing. The `kaggle` CLI requires `KAGGLE_USERNAME` + `KAGGLE_KEY` '
                '(set them as env vars or Kaggle/Colab secrets). Either set them, or attach the '
                'checkpoint dataset as a Notebook Input so `STORE.pull()` can use the mounted copy.'
            )
            if checkpoint_required:
                raise RuntimeError(msg)
            print('WARNING: ' + msg)
            return

        tmp = self.work_root / '_tmp_kaggle_download'
        if tmp.exists():
            shutil.rmtree(tmp)
        tmp.mkdir(parents=True, exist_ok=True)

        # Debug (no secrets): confirm auth is present before calling kaggle.
        print('Kaggle auth present:', bool(env.get('KAGGLE_USERNAME')) and bool(env.get('KAGGLE_KEY')))
        print('Kaggle username length:', len(env.get('KAGGLE_USERNAME', '')))
        print('Kaggle key length:', len(env.get('KAGGLE_KEY', '')))

        # Colab can SIGKILL big unzip steps (returncode -9). Avoid `--unzip` and unzip ourselves.
        pull_files_raw = (os.environ.get('CAFA_CHECKPOINT_PULL_FILES') or '').strip()
        pull_files = [f.strip() for f in pull_files_raw.split(',') if f.strip()] if pull_files_raw else []

        def _download_via_kaggle_api(files: list[str] | None) -> tuple[int, str, str]:
            # Use the Kaggle Python API in Colab (more stable than spawning the CLI for large downloads).
            try:
                from kaggle.api.kaggle_api_extended import KaggleApi  # type: ignore
            except Exception as e:
                return 1, '', f'Failed to import KaggleApi: {type(e).__name__}: {e}'

            try:
                api = KaggleApi()
                api.authenticate()
            except Exception as e:
                return 1, '', f'KaggleApi.authenticate() failed: {type(e).__name__}: {e}'

            def _list_dataset_files() -> list[str]:
                try:
                    lf = api.dataset_list_files(self.dataset_id)
                    out = []
                    for f in getattr(lf, 'files', []) or []:
                        name = getattr(f, 'name', None)
                        if name:
                            out.append(str(name))
                    return sorted(out)
                except Exception as e:
                    print('WARNING: failed to list dataset files via KaggleApi:', type(e).__name__, str(e)[:200])
                    return []

            available = _list_dataset_files()
            if available:
                print('Checkpoint dataset files (first 60):')
                for n in available[:60]:
                    print(' -', n)
                if len(available) > 60:
                    print(f' - ... ({len(available)-60} more)')
            else:
                print('WARNING: could not list files for dataset via KaggleApi (might be private/not visible).')

            # Strategy: Force FULL DATASET DOWNLOAD (skipping individual files)
            print('Forcing FULL DATASET DOWNLOAD (skipping individual files)...')

            # Track what we need to extract
            needed_prefixes = set()
            if not files:
                needed_prefixes.add('') # Root
            else:
                for f in files:
                    if f.endswith('.zip'):
                        needed_prefixes.add(f[:-4] + '/') # parsed.zip -> parsed/
                    else:
                        needed_prefixes.add(f) # literal file

            try:
                # Download full zip
                api.dataset_download_files(self.dataset_id, path=str(tmp), force=True, quiet=False, unzip=False)
                
                # Find the zip
                zips = list(tmp.glob('*.zip'))
                if not zips:
                    return 1, '', 'Full download finished but no .zip file found.'
                
                main_zip = zips[0]
                print(f'Extracting relevant files from {main_zip.name}...')
                
                with zipfile.ZipFile(main_zip, 'r') as zf:
                    all_names = zf.namelist()
                    to_extract = []
                    for n in all_names:
                        if '' in needed_prefixes:
                            to_extract.append(n)
                            continue
                        
                        for p in needed_prefixes:
                            if n.startswith(p) or n in {'manifest.json', 'README.md'}:
                                to_extract.append(n)
                                break
                    
                    print(f'Extracting {len(to_extract)} files...')
                    zf.extractall(tmp, members=to_extract)
                
                main_zip.unlink()
                return 0, 'kaggle_api: ok (full download)', ''

            except Exception as e_full:
                return 1, '', f'Full download failed: {type(e_full).__name__}: {e_full}'

        def _extract_outer_zips(download_dir: Path) -> None:
            zips = sorted(download_dir.glob('*.zip'))
            for z in zips:
                try:
                    with zipfile.ZipFile(z, 'r') as zf:
                        zf.extractall(download_dir)
                    z.unlink()
                except Exception as e:
                    raise RuntimeError(f'Failed to unzip downloaded archive: {z}: {type(e).__name__}: {e}')

        if IS_COLAB:
            # In Colab, prefer KaggleApi (avoids subprocess SIGKILLs more often than the CLI).
            if pull_files:
                print('Checkpoint pull (file mode via KaggleApi):', ', '.join(pull_files))
            else:
                print('Checkpoint pull (full via KaggleApi).')
            rc, out_s, err_s = _download_via_kaggle_api(pull_files if pull_files else None)
            class _P:  # lightweight subprocess-like container
                def __init__(self, returncode, stdout, stderr):
                    self.returncode = returncode
                    self.stdout = stdout
                    self.stderr = stderr
            p = _P(rc, out_s, err_s)
            if p.returncode == 0:
                _extract_outer_zips(tmp)
        else:
            # Local/Kaggle: keep using the CLI.
            if pull_files:
                print('Checkpoint pull (file mode):', ', '.join(pull_files))
                for fname in pull_files:
                    p = subprocess.run(
                        ['kaggle', 'datasets', 'download', '-d', self.dataset_id, '-f', fname, '-p', str(tmp)],
                        text=True,
                        capture_output=True,
                        env=env,
                    )
                    if p.returncode != 0:
                        break
                    _extract_outer_zips(tmp)
            else:
                p = subprocess.run(
                    ['kaggle', 'datasets', 'download', '-d', self.dataset_id, '-p', str(tmp)],
                    text=True,
                    capture_output=True,
                    env=env,
                )
                if p.returncode == 0:
                    _extract_outer_zips(tmp)

        if p.returncode != 0:
            print('Kaggle return code:', p.returncode)
            print(p.stdout)
            print(p.stderr)

            stderr_raw = p.stderr or ''
            stdout_raw = p.stdout or ''
            err = (stderr_raw + '\n' + stdout_raw).strip()
            if (p.returncode in {-9, 137}) and (not err):
                err = (
                    f'kaggle process was killed (returncode={p.returncode}). This is usually OOM or a large unzip. '
                    'or set CAFA_CHECKPOINT_REQUIRED=0 to continue without checkpoints.'
                )

            # If kaggle returned non-zero but we captured nothing, rerun without capture_output so logs appear.
            if not err:
                print('Kaggle CLI returned non-zero but no output was captured; rerunning without capture_output for logs...')
                try:
                    subprocess.run(
                        ['kaggle', 'datasets', 'download', '-d', self.dataset_id, '-p', str(tmp)],
                        text=True,
                        env=env,
                        check=False,
                    )
                except Exception as e:
                    print('Rerun failed:', repr(e))
                err = '<no kaggle cli output captured; see the rerun logs above>'

            err_excerpt = err[:2000] + ('\n...<truncated>' if len(err) > 2000 else '')

            if '403' in err or 'Forbidden' in err:
                msg = (
                    'Checkpoint download was forbidden (HTTP 403). This almost always means the dataset is private '
                    'or not accessible from the current Kaggle account/runtime. Fix options:\n'
                    '  1) Attach the checkpoint dataset as a Kaggle notebook Input (fastest; no API call).\n'
                    '  2) Make the checkpoint dataset public (or share it with the account running the notebook).\n'
                    '  3) Ensure Secrets `KAGGLE_USERNAME`/`KAGGLE_KEY` belong to a user with access.'
                )
                if checkpoint_required:
                    raise RuntimeError(msg)
                print('WARNING: ' + msg)
                shutil.rmtree(tmp)
                return

            if '404' in err or 'Not Found' in err:
                msg = (
                    'Checkpoint dataset not found / not visible (HTTP 404). This can mean the dataset ID is wrong '
                    'or the dataset is private and not accessible to the current account. '
                    f'Dataset: {self.dataset_id}'
                )
                if checkpoint_required:
                    raise RuntimeError(msg + '\n\nKaggle CLI output:\n' + err_excerpt)
                print('WARNING: ' + msg)
                shutil.rmtree(tmp)
                return

            msg = 'Failed to download checkpoints from Kaggle. Check auth/network.'
            if checkpoint_required:
                raise RuntimeError(msg + '\n\nKaggle CLI output:\n' + err_excerpt)
            print('WARNING: ' + msg)
            shutil.rmtree(tmp)
            return

        _copy_merge(tmp, self.work_root)
        _maybe_unpack_dir_mode_zips(self.work_root)
        shutil.rmtree(tmp)

    def push(self, stage: str, required_paths: list[Path], note: str = '') -> None:
        if not self.push_enabled:
            print('Checkpoint push disabled (CAFA_CHECKPOINT_PUSH=0).')
            return

        if not self.dataset_id:
            raise ValueError('Missing CAFA_CHECKPOINT_DATASET_ID=<user>/<slug>; cannot checkpoint.')

        missing = [p for p in required_paths if not Path(p).exists()]
        if missing:
            raise FileNotFoundError(
                'Cannot checkpoint; missing required artefacts:\n' + '\n'.join([f' - {m}' for m in missing])
            )

        force_push = os.environ.get('CAFA_CHECKPOINT_FORCE_PUSH', '0').strip() == '1'
        if not force_push:
            m = _load_manifest()
            existing = (m.get('stages', {}) or {}).get(stage) if isinstance(m, dict) else None
            if isinstance(existing, dict):
                prev_files = existing.get('files', [])
                if isinstance(prev_files, list):
                    prev_sig = sorted(
                        [{'path': f.get('path'), 'bytes': f.get('bytes')} for f in prev_files if isinstance(f, dict)],
                        key=lambda x: str(x.get('path')),
                    )
                    cur_sig = _stage_files_signature(required_paths)
                    if prev_sig == cur_sig:
                        print(
                            f'Checkpoint stage {stage} unchanged; skipping publish '
                            '(set CAFA_CHECKPOINT_FORCE_PUSH=1 to force).'
                        )
                        return

        _update_manifest(stage, required_paths, note=note)

        # Publish WORK_ROOT directly (must not contain caches).
        (self.work_root / 'dataset-metadata.json').write_text(
            json.dumps({'title': self.dataset_title, 'id': self.dataset_id, 'licenses': [{'name': 'CC0-1.0'}]}, indent=2),
            encoding='utf-8',
        )
        (self.work_root / 'README.md').write_text(
            f'# {self.dataset_title}\n\nAuto-published checkpoint dataset for CAFA6.\n\nLatest stage: {stage}\n',
            encoding='utf-8',
        )

        _ensure_kaggle_cli()
        env = _kaggle_env(require=True)
        msg = f'{stage}: {note}'.strip() if note else stage

        # IMPORTANT: Kaggle CLI skips directories unless --dir-mode is set.
        p = subprocess.run(
            ['kaggle', 'datasets', 'version', '-p', str(self.work_root), '--dir-mode', 'zip', '-m', msg],
            text=True,
            capture_output=True,
            env=env,
        )

        if p.returncode != 0:
            # If dataset does not exist yet, create it.
            p2 = subprocess.run(
                ['kaggle', 'datasets', 'create', '-p', str(self.work_root), '--dir-mode', 'zip'],
                text=True,
                capture_output=True,
                env=env,
            )
            if p2.returncode != 0:
                print(p.stdout)
                print(p.stderr)
                print(p2.stdout)
                print(p2.stderr)
                raise RuntimeError('Kaggle dataset publish failed. See logs above.')
            print(p2.stdout)
            print(p2.stderr)
        else:
            print(p.stdout)
            print(p.stderr)
            print('Published new checkpoint dataset version:', self.dataset_id)


# Notebook-wide checkpoint store, configured from the settings defined earlier in this cell.
STORE = KaggleCheckpointStore(
    # Which dataset to talk to
    dataset_id=CHECKPOINT_DATASET_ID,
    dataset_title=CHECKPOINT_DATASET_TITLE,
    # Where artefacts live locally
    work_root=WORK_ROOT,
    input_root=INPUT_ROOT,
    is_kaggle=IS_KAGGLE,
    # Which directions are enabled
    pull_enabled=CHECKPOINT_PULL,
    push_enabled=CHECKPOINT_PUSH,
)

# Pull once at start-up so a fresh runtime resumes from the last published checkpoint.
STORE.pull()


# Post-pull diagnostics: make it obvious whether artefacts are present (and whether dir-mode zips were unpacked).

def _p(path: Path) -> str:
    return str(path)


def _exists_bytes(path: Path) -> str:
    if not path.exists():
        return 'MISSING'
    try:
        return f'OK ({path.stat().st_size / (1024**2):.1f} MB)'
    except Exception:
        return 'OK'


# One row per artefact to report: (label shown to the user, path that is checked).
_status_rows = [
    ('parsed/', WORK_ROOT / 'parsed'),
    ('external/', WORK_ROOT / 'external'),
    ('features/', WORK_ROOT / 'features'),
    ('parsed.zip', WORK_ROOT / 'parsed.zip'),
    ('external.zip', WORK_ROOT / 'external.zip'),
    ('features.zip', WORK_ROOT / 'features.zip'),
    ('external/entryid_text.tsv', WORK_ROOT / 'external' / 'entryid_text.tsv'),
    ('parsed/train_seq.feather', WORK_ROOT / 'parsed' / 'train_seq.feather'),
]

print('Checkpoint status (after pull):')
print('  WORK_ROOT:', _p(WORK_ROOT))
for _label, _target in _status_rows:
    print(f'  {_label}:', _exists_bytes(_target))


def stage_present(required_paths: list[Path]) -> bool:
    """Return True only when every path in ``required_paths`` exists (True for an empty list)."""
    for candidate in required_paths:
        if not Path(candidate).exists():
            return False
    return True


# ------------------------------------------
# Initial Diagnostics (Sequence Lengths)
# ------------------------------------------

# IPython magic: render matplotlib figures inline in the cell output (only valid inside IPython/Jupyter).
%matplotlib inline
# Set the default font size for every figure drawn after this point.
plt.rcParams.update({'font.size': 10})


def read_fasta_lengths(path: Path, max_records=20000):
    """Return the sequence lengths of the first ``max_records`` records in a FASTA file.

    Args:
        path: FASTA file to scan (read as UTF-8).
        max_records: Maximum number of records to measure; a falsy value
            (``0`` or ``None``) reads the whole file.

    Returns:
        A NumPy array with one length per record, in file order. Text that
        appears before the first ``>`` header is ignored.
    """
    lengths = []
    current = None  # residues counted for the open record; None while no record is open
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('>'):
                if current is not None:
                    lengths.append(current)
                    current = None
                # Stop *before* opening a record beyond the limit; previously the
                # over-limit record was opened and then recorded as a bogus length 0.
                if max_records and len(lengths) >= max_records:
                    break
                current = 0
            elif current is not None:
                current += len(line)
        if current is not None:
            lengths.append(current)
    return np.array(lengths)


train_lens = read_fasta_lengths(PATH_TRAIN_FASTA)
if PATH_TEST_FASTA.exists():
    test_lens = read_fasta_lengths(PATH_TEST_FASTA)
else:
    test_lens = np.array([])

# Shared fixed bins (0 to 3000) keep both histograms aligned and leave out very long outliers.
bins = np.linspace(0, 3000, 60)

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(train_lens, bins=bins, alpha=0.5, label='Train', density=True, color='tab:blue')
if len(test_lens) > 0:
    ax.hist(test_lens, bins=bins, alpha=0.5, label='Test', density=True, color='tab:orange')

ax.set_title('Sequence Length Distribution (0-3000aa, Normalized)')
ax.set_xlabel('Length (amino acids)')
ax.set_ylabel('Density')
ax.legend()
ax.grid(True, alpha=0.2)
plt.show()

# Text summary to accompany the figure.
print(f'Train sequences: {len(train_lens)}')
print(f'Test sequences : {len(test_lens)}')
if len(train_lens) > 0:
    print(f'Train max len  : {train_lens.max()}')
if len(test_lens) > 0:
    print(f'Test max len   : {test_lens.max()}')


In [None]:
# CELL 04 - Solution: 2. PHASE 1: DATA STRUCTURING & HIERARCHY
# 2. PHASE 1: DATA STRUCTURING & HIERARCHY
# ==========================================
# HARDWARE: CPU (Standard)
# ==========================================
# ------------------------------------------
# B. Parse OBO & Terms (needed in-memory downstream)
# ------------------------------------------

def parse_obo(path: Path):
    """Extract the ``is_a`` parents and the namespace of every GO term in an OBO file.

    Only ``[Term]`` stanzas are read. Any other stanza (for example ``[Typedef]``)
    closes the current term and is skipped, so its ``namespace:`` / ``is_a:`` lines
    can no longer be attributed to the last GO term of the file.

    Args:
        path: OBO file to read (UTF-8).

    Returns:
        ``(parents, namespaces)`` where ``parents`` maps a GO id to the set of its
        ``is_a`` parent ids (terms without parents are absent) and ``namespaces``
        maps a GO id to the text of its ``namespace:`` line.
    """
    parents = {}
    namespaces = {}
    cur_id, cur_ns = None, None
    in_term = False  # True only while inside a [Term] stanza

    with path.open('r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line.startswith('[') and line.endswith(']'):
                # Every stanza header ends the previous stanza: flush it and reset.
                if cur_id and cur_ns:
                    namespaces[cur_id] = cur_ns
                cur_id, cur_ns = None, None
                in_term = line == '[Term]'
            elif not in_term:
                continue
            elif line.startswith('id: GO:'):
                cur_id = line.split('id: ', 1)[1]
            elif line.startswith('namespace:'):
                cur_ns = line[len('namespace:'):].strip()
            elif line.startswith('is_a:') and cur_id:
                # The parent id is the first token; a trailing "! name" comment or
                # "{...}" modifier block is dropped.
                tokens = line[len('is_a:'):].split()
                if tokens:
                    parents.setdefault(cur_id, set()).add(tokens[0])
        if cur_id and cur_ns:
            namespaces[cur_id] = cur_ns
    return parents, namespaces
print("Parsing OBO...")
go_parents, go_namespaces = parse_obo(PATH_GO_OBO)
print(f"GO Graph: {len(go_parents)} nodes with parents, {len(go_namespaces)} terms with namespace.")

# ------------------------------------------
# Milestone checkpoint: stage_01_parsed
# ------------------------------------------
parsed_dir = WORK_ROOT / 'parsed'
parsed_dir.mkdir(parents=True, exist_ok=True)

# Output locations, grouped by what they hold.
out_train_seq, out_test_seq = parsed_dir / 'train_seq.feather', parsed_dir / 'test_seq.feather'
out_train_taxa, out_test_taxa = parsed_dir / 'train_taxa.feather', parsed_dir / 'test_taxa.feather'
out_train_terms = parsed_dir / 'train_terms.parquet'
out_term_counts = parsed_dir / 'term_counts.parquet'
out_term_priors = parsed_dir / 'term_priors.parquet'

# Train artefacts are always expected; test artefacts only when a test FASTA is available.
expected = [out_train_seq, out_train_terms, out_term_counts, out_term_priors, out_train_taxa]
expected.extend([out_test_seq, out_test_taxa] if PATH_TEST_FASTA.exists() else [])
missing = list(filter(lambda artefact: not artefact.exists(), expected))
if not missing:
    print("Parsed artefacts already exist; skipping Phase 1 writes.")
else:
    # ------------------------------------------
    # A. Parse FASTA to Feather
    # ------------------------------------------
    def parse_fasta(path: Path) -> pd.DataFrame:
        """Read a FASTA file into a frame with one row per record.

        The ``id`` column holds the first whitespace-delimited token of each header
        (without the leading ``>``); ``sequence`` holds the record's lines joined
        together. Text before the first header is discarded.
        """
        records = []  # (id, sequence) pairs in file order
        header, chunks = None, []
        with path.open('r', encoding='utf-8') as handle:
            for raw in handle:
                text = raw.strip()
                if not text.startswith('>'):
                    chunks.append(text)
                    continue
                # A new header closes the record that was being collected.
                if header:
                    records.append((header, ''.join(chunks)))
                header = text[1:].split()[0]
                chunks = []
        if header:
            records.append((header, ''.join(chunks)))
        return pd.DataFrame({
            'id': [rec_id for rec_id, _ in records],
            'sequence': [rec_seq for _, rec_seq in records],
        })
    print("Parsing FASTA...")
    parse_fasta(PATH_TRAIN_FASTA).to_feather(out_train_seq)
    if PATH_TEST_FASTA.exists():
        parse_fasta(PATH_TEST_FASTA).to_feather(out_test_seq)
    print("FASTA parsed and saved to artefacts.")
    # ------------------------------------------
    # C. Process Terms & Priors
    # ------------------------------------------
    terms = pd.read_csv(PATH_TRAIN_TERMS, sep='\t')
    # The GO term is taken from the second column, whatever its header is called.
    col_term = terms.columns[1]
    # Aspect is stored as the full OBO namespace text; terms absent from the OBO become 'UNK'.
    terms['aspect'] = terms[col_term].map(lambda x: go_namespaces.get(x, 'UNK'))
    # Plot Aspects
    plt.figure(figsize=(6, 3))
    terms['aspect'].value_counts().plot(kind='bar', title='Annotations by Namespace')
    plt.show()
    # Save Priors: annotation count of each term divided by the number of distinct first-column ids.
    priors = (terms[col_term].value_counts() / terms.iloc[:, 0].nunique()).reset_index()
    priors.columns = ['term', 'prior']
    if PATH_IA.exists():
        ia = pd.read_csv(PATH_IA, sep='\t', names=['term', 'ia'])
        # Terms without an IA entry get 0.
        priors = priors.merge(ia, on='term', how='left').fillna(0)
    priors.to_parquet(out_term_priors)
    print("Terms processed and priors saved.")
    # ------------------------------------------
    # D. Process Taxonomy
    # ------------------------------------------
    print("Processing Taxonomy...")
    # Train Taxonomy
    tax_train = pd.read_csv(PATH_TRAIN_TAXON, sep='\t', header=None, names=['id', 'taxon_id'])
    tax_train['taxon_id'] = tax_train['taxon_id'].astype(int)
    tax_train.to_feather(out_train_taxa)
    # Test Taxonomy (Extract from FASTA headers)
    if PATH_TEST_FASTA.exists():
        ids, taxons = [], []
        with PATH_TEST_FASTA.open('r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('>'):
                    parts = line[1:].split()
                    ids.append(parts[0])
                    # Assume second part is taxon if present; 0 marks "unknown taxon".
                    if len(parts) > 1:
                        try:
                            taxons.append(int(parts[1]))
                        except ValueError:
                            taxons.append(0)
                    else:
                        taxons.append(0)
        tax_test = pd.DataFrame({'id': ids, 'taxon_id': taxons})
        tax_test.to_feather(out_test_taxa)
        print(f"Taxonomy processed. Train: {len(tax_train)}, Test: {len(tax_test)}")
    else:
        print(f"Taxonomy processed. Train: {len(tax_train)}")
    # ------------------------------------------
    # E. Save Targets & Term List
    # ------------------------------------------
    print("Saving Targets & Term List...")
    # Save full terms list (long format)
    terms.to_parquet(out_train_terms)
    # Save unique term list with counts. Uses col_term like the priors above, so this
    # no longer fails when the term column carries a different header.
    term_counts = terms[col_term].value_counts().reset_index()
    term_counts.columns = ['term', 'count']
    term_counts.to_parquet(out_term_counts)
    print("Targets saved.")
    # Publish the stage only when the checkpoint store was set up earlier in the notebook.
    if 'STORE' in globals() and STORE is not None:
        STORE.push('stage_01_parsed', [p for p in expected if p.exists()], note='parsed FASTA/taxa/terms/priors')

In [None]:
# CELL 10b - Diagnostics: artefact manifest (existence + sizes)
%matplotlib inline
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams.update({'font.size': 12})

def _mb(p: Path) -> float:
    return p.stat().st_size / (1024**2)

WORK_ROOT = Path(WORK_ROOT)

# Minimal contract for *this* notebook (first submission, no Ankh):
# - parsed/*
# - core embeddings (t5, esm2, esm2_3b, text)
# - taxonomy
# - external priors (if you want them in the stacker)
_ARTEFACT_NAMES = [
    # Phase 1 parsed
    'parsed/train_seq.feather',
    'parsed/test_seq.feather',
    'parsed/train_terms.parquet',
    'parsed/term_priors.parquet',
    'parsed/train_taxa.feather',
    'parsed/test_taxa.feather',
    # Text pipeline
    'external/entryid_text.tsv',
    'features/text_vectorizer.joblib',
    'features/train_embeds_text.npy',
    'features/test_embeds_text.npy',
    # Sequence embeddings (core)
    'features/train_embeds_t5.npy',
    'features/test_embeds_t5.npy',
    'features/train_embeds_esm2.npy',
    'features/test_embeds_esm2.npy',
    'features/train_embeds_esm2_3b.npy',
    'features/test_embeds_esm2_3b.npy',
    # External priors (optional but used if present)
    'external/prop_train_no_kaggle.tsv.gz',
    'external/prop_test_no_kaggle.tsv.gz',
    # Downstream expectations
    'features/top_terms_1500.json',
    'features/oof_pred_logreg.npy',
    'features/oof_pred_gbdt.npy',
    'features/oof_pred_dnn.npy',
    'features/oof_pred_knn.npy',
    'features/test_pred_logreg.npy',
    'features/test_pred_gbdt.npy',
    'features/test_pred_dnn.npy',
    'features/test_pred_knn.npy',
    'features/test_pred_gcn.npy',
]
# Display name -> absolute location under WORK_ROOT (insertion order follows the list above).
paths = {rel: WORK_ROOT.joinpath(*rel.split('/')) for rel in _ARTEFACT_NAMES}

rows = [
    {
        'artefact': rel,
        'exists': target.exists(),
        'mb': _mb(target) if target.exists() else 0.0,
        'path': str(target),
    }
    for rel, target in paths.items()
]
# Missing artefacts sort first, then present ones from largest to smallest.
df = pd.DataFrame(rows).sort_values(['exists', 'mb'], ascending=[True, False])
print('WORK_ROOT:', WORK_ROOT)
display(df)

# Visual: top 25 largest artefacts
df2 = df[df['exists']].sort_values('mb', ascending=False).head(25)
if not df2.empty:
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(data=df2, y='artefact', x='mb', ax=ax)
    ax.set_title('Largest artefacts (MB)')
    ax.set_xlabel('MB')
    ax.set_ylabel('artefact')
    fig.tight_layout()
    plt.show()

# Strict check (what must exist before Phase 2)
required_for_phase2 = [
    'features/train_embeds_text.npy',
    'features/test_embeds_text.npy',
    'features/train_embeds_t5.npy',
    'features/test_embeds_t5.npy',
    'features/train_embeds_esm2.npy',
    'features/test_embeds_esm2.npy',
    'features/train_embeds_esm2_3b.npy',
    'features/test_embeds_esm2_3b.npy',
    'parsed/train_taxa.feather',
    'parsed/test_taxa.feather',
    'parsed/train_terms.parquet',
    'parsed/train_seq.feather',
    'parsed/test_seq.feather',
    # If you intend to use external priors in the GCN cell, these must exist too:
    # 'external/prop_train_no_kaggle.tsv.gz',
    # 'external/prop_test_no_kaggle.tsv.gz',
]
missing = [rel for rel in required_for_phase2 if not paths[rel].exists()]
if not missing:
    print('\nPhase 2 artefacts OK: embeddings + taxonomy + parsed targets present.')
else:
    print('\nMissing artefacts for Phase 2 (first submission, no Ankh):')
    for m in missing:
        print(' -', m)

In [None]:
# CELL 8a - Setup & Data Loading
# =============================================
# 4. PHASE 2: LEVEL-1 MODELS (DIVERSE ENSEMBLE)
# =============================================
# Master switch for the Level-1 training cells.
TRAIN_LEVEL1 = True
if TRAIN_LEVEL1:
    import joblib
    import json
    import pandas as pd
    import numpy as np
    import os
    import gc
    from pathlib import Path
    from sklearn.model_selection import KFold
    from sklearn.metrics import f1_score
    import psutil

    # AUDITOR: Hardware Check
    # Best-effort only: a missing or broken torch install must not stop the cell, but the
    # reason is reported, and KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        import torch
        if torch.cuda.is_available():
            print(f"[AUDITOR] GPU Detected: {torch.cuda.get_device_name(0)}")
            print(f"[AUDITOR] VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        else:
            print("[AUDITOR] WARNING: No GPU detected. RAPIDS will fail.")
    except ImportError:
        print("[AUDITOR] torch is not installed; skipping GPU check.")
    except Exception as exc:
        print(f"[AUDITOR] GPU check failed ({exc!r}); continuing.")

    def log_mem(tag=""):
        """Print a one-line host-RAM usage report labelled with ``tag``.

        Purely diagnostic: any failure is reported on the same ``[MEM]`` line instead
        of being raised, so a logging problem never interrupts training.
        """
        try:
            mem = psutil.virtual_memory()
            print(f"[MEM] {tag:<30} | Used: {mem.used/1e9:.2f}GB / {mem.total/1e9:.2f}GB ({mem.percent}%)")
        except Exception as exc:
            # Was a bare `except: pass`, which hid the failure and also caught KeyboardInterrupt.
            print(f"[MEM] {tag:<30} | unavailable ({exc!r})")

    # Recover WORK_ROOT when this cell runs without the setup cell having defined it.
    if not ('WORK_ROOT' in locals() or 'WORK_ROOT' in globals()):
        # Known runtime locations, tried in order; the first one that exists wins.
        for _candidate in ('/content/work', '/kaggle/working/work'):
            if os.path.exists(_candidate):
                WORK_ROOT = Path(_candidate)
                break
        else:
            WORK_ROOT = Path.cwd() / 'artefacts_local' / 'work'
        print(f"WORK_ROOT recovered: {WORK_ROOT}")

    # -----------------------------
    # Load targets + ids
    # -----------------------------
    print("Loading targets...")
    train_terms = pd.read_parquet(WORK_ROOT / 'parsed' / 'train_terms.parquet')
    train_ids = pd.read_feather(WORK_ROOT / 'parsed' / 'train_seq.feather')['id'].astype(str)
    test_ids = pd.read_feather(WORK_ROOT / 'parsed' / 'test_seq.feather')['id'].astype(str)

    # FIX: Clean IDs in train_ids to match EntryID format
    # (take the text between the first pair of '|'; ids without pipes are kept unchanged).
    print("Applying ID cleaning fix...")
    train_ids_clean = train_ids.str.extract(r'\|(.*?)\|')[0]
    train_ids_clean = train_ids_clean.fillna(train_ids)

    # Target Matrix Construction (Champion Strategy: 13,500 Terms)
    print("Selecting Top-K terms per aspect (Champion Strategy)...")

    # Spellings of an aspect that are translated to the short codes the selection below
    # filters on. Phase 1 stores the full OBO namespace text, which previously never matched
    # 'BP'/'MF'/'CC' and silently forced the global Top-K fallback.
    # NOTE(review): the 'BPO'/'P'-style codes are presumed dataset spellings - confirm.
    ASPECT_ALIASES = {
        'biological_process': 'BP', 'molecular_function': 'MF', 'cellular_component': 'CC',
        'BP': 'BP', 'MF': 'MF', 'CC': 'CC',
        'BPO': 'BP', 'MFO': 'MF', 'CCO': 'CC',
        'P': 'BP', 'F': 'MF', 'C': 'CC',
    }

    try:
        import obonet
        obo_path = WORK_ROOT.parent / 'go-basic.obo'
        if not obo_path.exists(): obo_path = Path('go-basic.obo')
        if not obo_path.exists(): obo_path = Path('/content/cafa-6-protein-function-prediction/Train/go-basic.obo')

        print(f"Loading OBO from {obo_path}...")
        graph = obonet.read_obo(obo_path)
        term_to_ns = {node: data.get('namespace', 'unknown') for node, data in graph.nodes(data=True)}
        ns_map = {'biological_process': 'BP', 'molecular_function': 'MF', 'cellular_component': 'CC'}
        if 'aspect' not in train_terms.columns:
            train_terms['aspect'] = train_terms['term'].map(lambda t: ns_map.get(term_to_ns.get(t), 'UNK'))
    except ImportError:
        if 'aspect' in train_terms.columns:
            # Keep the aspects stored with the parsed terms instead of overwriting them with 'UNK'.
            print("[WARNING] obonet not found. Using the aspect column stored in train_terms.")
        else:
            print("[WARNING] obonet not found. Falling back to global Top-K (13,500).")
            train_terms['aspect'] = 'UNK'

    # Normalise whatever spelling is present; unknown values become 'UNK'.
    train_terms['aspect'] = train_terms['aspect'].map(ASPECT_ALIASES).fillna('UNK')

    term_counts = train_terms.groupby(['aspect', 'term']).size().reset_index(name='count')
    targets_bp = term_counts[term_counts['aspect'] == 'BP'].nlargest(10000, 'count')['term'].tolist()
    targets_mf = term_counts[term_counts['aspect'] == 'MF'].nlargest(2000, 'count')['term'].tolist()
    targets_cc = term_counts[term_counts['aspect'] == 'CC'].nlargest(1500, 'count')['term'].tolist()

    if len(targets_bp) == 0 and len(targets_mf) == 0:
        print("  Using global Top-13,500 (OBO fallback)")
        top_terms = train_terms['term'].value_counts().head(13500).index.tolist()
    else:
        top_terms = list(set(targets_bp + targets_mf + targets_cc))
        print(f"  Selected: {len(targets_bp)} BP + {len(targets_mf)} MF + {len(targets_cc)} CC")

    # Fix one deterministic term order and force the columns of Y to follow it. Everything
    # indexed by top_terms later (IA weights, term aspects) relies on top_terms[j] describing
    # column j of Y; set order and frequency order did not guarantee that.
    top_terms = sorted(set(top_terms))

    train_terms_top = train_terms[train_terms['term'].isin(top_terms)]
    Y_df = train_terms_top.pivot_table(index='EntryID', columns='term', aggfunc='size', fill_value=0)
    # Rows follow the FASTA order (proteins without selected terms become all-zero rows).
    Y_df = Y_df.reindex(index=train_ids_clean, columns=top_terms, fill_value=0)
    Y = Y_df.values.astype(np.float32)
    assert Y.shape[1] == len(top_terms), "Y columns must line up with top_terms"
    print(f"Targets: Y={Y.shape}")

    
    # -----------------------------
    # Feature loading helper
    # -----------------------------
    FEAT_DIR = WORK_ROOT / 'features'

    def load_features_dict(split='both'):
        """Load the per-modality feature matrices found on disk.

        Args:
            split: ``'train'``, ``'test'`` or ``'both'``.

        Returns:
            A dict ``{modality: float32 array}`` for a single split, or the tuple
            ``(train_dict, test_dict)`` for ``'both'``. Embedding files that do not
            exist are skipped; ``'taxa'`` (one-hot taxon id) is added only when both
            taxonomy tables exist.
        """
        log_mem(f"Start load_features_dict({split})")
        print(f"Loading multimodal features (mode={split})...")
        def _load_pair(stem):
            tr = FEAT_DIR / f'train_embeds_{stem}.npy'
            te = FEAT_DIR / f'test_embeds_{stem}.npy'
            return tr, te

        ft_train = {}
        ft_test = {}
        for stem, key in [('t5', 't5'), ('esm2', 'esm2_650m'), ('esm2_3b', 'esm2_3b'), ('text', 'text')]:
            tr_path, te_path = _load_pair(stem)
            if split in ['both', 'train'] and tr_path.exists():
                arr = np.load(tr_path).astype(np.float32)
                ft_train[key] = arr
            if split in ['both', 'test'] and te_path.exists():
                arr = np.load(te_path).astype(np.float32)
                ft_test[key] = arr

        taxa_train_path = WORK_ROOT / 'parsed' / 'train_taxa.feather'
        taxa_test_path = WORK_ROOT / 'parsed' / 'test_taxa.feather'
        if taxa_train_path.exists() and taxa_test_path.exists():
            from sklearn.preprocessing import OneHotEncoder
            tax_tr = pd.read_feather(taxa_train_path).astype({'id': str})
            tax_te = pd.read_feather(taxa_test_path).astype({'id': str})
            enc = OneHotEncoder(handle_unknown='ignore', sparse_output=False, dtype=np.float32)
            enc.fit(pd.concat([tax_tr[['taxon_id']], tax_te[['taxon_id']]], axis=0))

            def _align_taxa(tax, *id_candidates):
                """Reorder ``tax`` to protein order using the id spelling that matches best."""
                keyed = tax.set_index('id')
                hits = [int(ids.isin(keyed.index).sum()) for ids in id_candidates]
                best = hits.index(max(hits))  # ties keep the earliest candidate
                if hits[best] == 0:
                    print("[WARNING] No protein id matches the taxonomy table; taxa features will be all zeros.")
                return keyed.reindex(id_candidates[best].to_numpy(), fill_value=0).reset_index()

            if split in ['both', 'train']:
                # The raw FASTA ids needed cleaning to match EntryID for the targets, so the
                # cleaned spelling is tried here too; aligning on raw ids alone can leave
                # every protein unmatched (taxon 0, an all-zero one-hot row).
                tax_tr = _align_taxa(tax_tr, train_ids, train_ids_clean)
                ft_train['taxa'] = enc.transform(tax_tr[['taxon_id']]).astype(np.float32)
            if split in ['both', 'test']:
                tax_te = _align_taxa(tax_te, test_ids)
                ft_test['taxa'] = enc.transform(tax_te[['taxon_id']]).astype(np.float32)

        log_mem(f"End load_features_dict({split})")
        if split == 'train': return ft_train
        if split == 'test': return ft_test
        return ft_train, ft_test

    # -----------------------------
    # IA-weighted F1 Helper
    # -----------------------------
    # Source of the IA table, in priority order: an `ia` frame already in memory,
    # then IA.tsv beside WORK_ROOT, then IA.tsv inside WORK_ROOT, else an empty table.
    if 'ia' in locals():
        ia_df = ia[['term', 'ia']].copy()
    else:
        for _ia_path in (WORK_ROOT.parent / 'IA.tsv', WORK_ROOT / 'IA.tsv'):
            if _ia_path.exists():
                ia_df = pd.read_csv(_ia_path, sep='\t', names=['term', 'ia'])
                break
        else:
            ia_df = pd.DataFrame({'term': [], 'ia': []})

    # One weight per selected term; terms without an IA value weigh 0.
    ia_map = dict(zip(ia_df['term'], ia_df['ia']))
    weights = np.array([ia_map.get(term, 0.0) for term in top_terms], dtype=np.float32)

    # Short aspect code per selected term ('UNK' when the GO namespaces are not in memory).
    if 'go_namespaces' in locals() or 'go_namespaces' in globals():
        ns_to_aspect = {'molecular_function': 'MF', 'biological_process': 'BP', 'cellular_component': 'CC'}
        term_aspects = np.array([ns_to_aspect.get(go_namespaces.get(term, ''), 'UNK') for term in top_terms])
    else:
        term_aspects = np.array(['UNK'] * len(top_terms))

    def ia_weighted_f1(y_true, y_score, thr=0.3):
        """Compute IA-weighted F1 overall and per aspect.

        Args:
            y_true: Target matrix; any value > 0 counts as a positive label.
            y_score: Score matrix of the same shape; a score >= ``thr`` is a prediction.
            thr: Decision threshold applied to ``y_score``.

        Returns:
            Dict with keys ``'ALL'``, ``'MF'``, ``'BP'``, ``'CC'``. Each value is the F1
            obtained from weighted per-column counts, using the module-level ``weights``
            (restricted through ``term_aspects`` for the per-aspect entries).
        """
        truth = (y_true > 0).astype(np.int8)
        guess = (y_score >= thr).astype(np.int8)

        # Per-column counts, accumulated in float64.
        tp = (guess & truth).sum(axis=0).astype(np.float64)
        pred = guess.sum(axis=0).astype(np.float64)
        true = truth.sum(axis=0).astype(np.float64)

        def _f1_for(w):
            w_tp = float((w * tp).sum())
            w_pred = float((w * pred).sum())
            w_true = float((w * true).sum())
            precision = (w_tp / w_pred) if w_pred > 0 else 0.0
            recall = (w_tp / w_true) if w_true > 0 else 0.0
            if (precision + recall) > 0:
                return 2 * precision * recall / (precision + recall)
            return 0.0

        scores = {'ALL': _f1_for(weights)}
        for asp in ['MF', 'BP', 'CC']:
            aspect_mask = (term_aspects == asp).astype(np.float32)
            scores[asp] = _f1_for(weights * aspect_mask)
        return scores


In [None]:
# CELL 8b - Phase 2a: Logistic Regression (Optimized)
if TRAIN_LEVEL1:
    print("\n=== Phase 2a: Classical Models (LR/GBDT) ===")
    
    # 1. Create Train X (Memory Optimized)
    log_mem("Before loading Train")
    features_train = load_features_dict(split='train')
    FLAT_KEYS = [k for k in ['t5', 'esm2_650m', 'esm2_3b', 'taxa', 'text'] if k in features_train]
    print(f"Creating Flat X from keys: {FLAT_KEYS}")
    if not FLAT_KEYS:
        # Previously surfaced as an opaque IndexError on FLAT_KEYS[0].
        raise RuntimeError("No training features were loaded; cannot build X.")

    # OPTIMIZATION: Use Memory-Mapped File for X to avoid 56GB RAM crash
    X_path = WORK_ROOT / 'features' / 'X_train_mmap.npy'

    # Calculate total shape
    n_samples = features_train[FLAT_KEYS[0]].shape[0]
    row_mismatch = {k: features_train[k].shape[0] for k in FLAT_KEYS if features_train[k].shape[0] != n_samples}
    if row_mismatch:
        # Fail before the output file is allocated rather than midway through writing it.
        raise ValueError(f"Feature blocks disagree on row count (expected {n_samples}): {row_mismatch}")
    n_features = sum(features_train[k].shape[1] for k in FLAT_KEYS)
    print(f"Target X shape: ({n_samples}, {n_features})")

    # Create mmap file
    X_mmap = np.lib.format.open_memmap(X_path, mode='w+', dtype=np.float32, shape=(n_samples, n_features))

    # Fill mmap one feature block at a time to save RAM
    current_col = 0
    for k in FLAT_KEYS:
        # pop() so the dict stops referencing the array as soon as it is taken.
        data = features_train.pop(k)
        dim = data.shape[1]
        print(f"  Writing {k} ({dim} cols) to mmap...")
        X_mmap[:, current_col:current_col+dim] = data
        current_col += dim
        # Free memory immediately: the local name must go too, otherwise the array
        # stays alive until the next iteration (and the last one indefinitely).
        del data
        gc.collect()

    del features_train
    X_mmap.flush()
    del X_mmap
    gc.collect()
    log_mem("Created X mmap")

    # Load X in read-only mmap mode
    X = np.load(X_path, mmap_mode='r')
    print(f"Loaded X from mmap: {X.shape}")
    
    # ------------------------------------------
    # A. Logistic Regression (Baseline)
    # ------------------------------------------
    print("\n--- Training Logistic Regression ---")
    from sklearn.linear_model import LogisticRegression, SGDClassifier
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.preprocessing import StandardScaler
    
    try:
        import cuml
        from cuml.linear_model import LogisticRegression as cuLogReg
        from cuml.multiclass import OneVsRestClassifier as cuOVR
        import cupy as cp
        HAS_RAPIDS = True
        print("[AUDITOR] RAPIDS (cuml) detected.")
    except ImportError:
        HAS_RAPIDS = False
        print("[AUDITOR] RAPIDS NOT detected.")

    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    oof_preds_logreg = np.zeros(Y.shape, dtype=np.float32)

    def _to_host(arr):
        """Return `arr` as a host-side array when it exposes a known conversion.

        Objects with `.to_numpy()` are converted through it, objects with only
        `.get()` through that; anything else is returned unchanged.
        `to_numpy` is tried first because dataframe-like objects also expose a
        `get(key)` accessor that cannot be called without arguments.
        """
        if hasattr(arr, 'to_numpy'):
            return arr.to_numpy()
        if hasattr(arr, 'get'):
            return arr.get()
        return arr

    for fold, (idx_tr, idx_val) in enumerate(kf.split(X)):
        print(f"LogReg Fold {fold+1}/5")

        # Load data slices (RAM usage increases here)
        # X is mmap, so X[idx_tr] reads from disk into RAM.
        # X_tr size ~ 25GB.
        log_mem("Before loading X_tr")
        X_tr = X[idx_tr] # Copy to RAM
        X_val = X[idx_val] # Copy to RAM
        log_mem("Loaded X_tr/X_val")

        Y_tr, Y_val = Y[idx_tr], Y[idx_val]

        # SCALING: statistics are fitted on the training rows of this fold only.
        scaler = StandardScaler()
        X_tr = scaler.fit_transform(X_tr)
        X_val = scaler.transform(X_val)

        if HAS_RAPIDS:
            # OPTIMIZATION: Enable RMM managed memory (best effort).
            # NOTE(review): this re-initialises RMM once per fold - confirm that is intended.
            try:
                import rmm
                rmm.reinitialize(managed_memory=True)
            except Exception as exc:
                # Keep going without managed memory, but say why instead of hiding it.
                print(f"[AUDITOR] RMM managed memory not enabled: {exc}")

            # Move to GPU and delete CPU copy
            X_tr_gpu = cp.asarray(X_tr)
            X_val_gpu = cp.asarray(X_val)
            del X_tr, X_val
            gc.collect()

            # Fit the label columns in chunks so only `chunk_size` targets are trained at a time.
            n_targets = Y.shape[1]
            chunk_size = 2000
            val_probs = np.zeros((Y_val.shape[0], n_targets), dtype=np.float32)
            all_coefs = []
            all_intercepts = []

            for start in range(0, n_targets, chunk_size):
                end = min(start + chunk_size, n_targets)
                print(f"    Chunk {start}-{end}...")
                Y_tr_chunk = cp.asarray(Y_tr[:, start:end])
                clf_chunk = cuOVR(cuLogReg(solver='qn', penalty='l2', C=1.0, max_iter=1000, tol=1e-3))
                clf_chunk.fit(X_tr_gpu, Y_tr_chunk)

                p_chunk = _to_host(clf_chunk.predict_proba(X_val_gpu))
                val_probs[:, start:end] = p_chunk

                # Collect per-label weights on the host so np.vstack/np.hstack below
                # never receive device arrays.
                for est in clf_chunk.estimators_:
                    all_coefs.append(_to_host(est.coef_))
                    all_intercepts.append(_to_host(est.intercept_))

                del Y_tr_chunk, clf_chunk, p_chunk
                cp.get_default_memory_pool().free_all_blocks()

            del X_tr_gpu, X_val_gpu
            cp.get_default_memory_pool().free_all_blocks()

            # Save weights (plain arrays; no estimator objects are pickled on this path)
            model_data = {'coef': np.vstack(all_coefs), 'intercept': np.hstack(all_intercepts)}
            joblib.dump(model_data, WORK_ROOT / 'features' / f'level1_logreg_weights_fold{fold}.pkl')

        else:
            # CPU Fallback
            clf_logreg = OneVsRestClassifier(SGDClassifier(loss='log_loss', penalty='l2', alpha=0.0001, max_iter=1000, tol=1e-3, n_jobs=4))
            clf_logreg.fit(X_tr, Y_tr)
            val_probs = clf_logreg.predict_proba(X_val)
            joblib.dump(clf_logreg, WORK_ROOT / 'features' / f'level1_logreg_fold{fold}.pkl')

        oof_preds_logreg[idx_val] = val_probs

        # Diagnostics: scan 20 thresholds in [0.01, 0.20] and keep the best micro-F1.
        best_f1 = 0.0
        best_thr = 0.0
        for thr in np.linspace(0.01, 0.20, 20):
            vp = (val_probs > thr).astype(int)
            score = f1_score(Y_val, vp, average='micro')
            if score > best_f1: best_f1, best_thr = score, thr

        ia_f1 = ia_weighted_f1(Y_val, val_probs, thr=best_thr)
        print(f"  >> Fold {fold+1} Best F1: {best_f1:.4f} at Thr {best_thr:.2f}")
        print(f"  >> Fold {fold+1} IA-F1: ALL={ia_f1['ALL']:.4f}")

        # The scaler is saved per fold so test-time features can be scaled the same way.
        joblib.dump(scaler, WORK_ROOT / 'features' / f'level1_logreg_scaler_fold{fold}.pkl')
        del Y_tr, Y_val, scaler, val_probs
        # X_tr / X_val were already deleted on the GPU path.
        if 'X_tr' in locals(): del X_tr
        if 'X_val' in locals(): del X_val
        gc.collect()

    np.save(WORK_ROOT / 'features' / 'oof_pred_logreg.npy', oof_preds_logreg)
    print("LogReg OOF saved.")

In [None]:
# CELL 8c - Phase 2a: Py-Boost (GBDT)
if TRAIN_LEVEL1:
    # Ensure X is available
    if 'X' not in locals():
        X_path = WORK_ROOT / 'features' / 'X_train_mmap.npy'
        if X_path.exists():
            X = np.load(X_path, mmap_mode='r')
            print(f"Reloaded X from mmap: {X.shape}")
        else:
            raise RuntimeError("X mmap not found. Run Cell 8b first.")

    # Ensure the fold splitter is available too. The settings are the ones the
    # LogReg cell uses (5 folds, shuffled, seed 42), so fold membership matches
    # and the OOF rows of both models line up.
    if 'kf' not in locals():
        from sklearn.model_selection import KFold
        kf = KFold(n_splits=5, shuffle=True, random_state=42)
        print("Rebuilt KFold splitter (5 folds, seed 42).")

    # ------------------------------------------
    # B. Py-Boost (GBDT)
    # ------------------------------------------
    try:
        from py_boost import GradientBoosting
        HAS_PYBOOST = True
    except ImportError as e:
        raise RuntimeError("CRITICAL: py_boost is missing. GBDT is a mandatory component. Install with 'pip install py-boost'.") from e

    # HAS_PYBOOST is always True past this point (the import failure raises);
    # the flag is kept because the test-prediction cell reads it.
    if HAS_PYBOOST:
        print("\n--- Training Py-Boost GBDT ---")
        oof_preds_gbdt = np.zeros(Y.shape, dtype=np.float32)

        for fold, (idx_tr, idx_val) in enumerate(kf.split(X)):
            print(f"GBDT Fold {fold+1}/5")

            # Load data slices (fancy indexing on the memory map copies the rows into RAM)
            X_tr = X[idx_tr]
            X_val = X[idx_val]
            Y_tr, Y_val = Y[idx_tr], Y[idx_val]

            model = GradientBoosting(
                loss='bce', ntrees=1000, lr=0.05, max_depth=6,
                verbose=100, es=50, gpu_id=0
            )
            # The validation fold is passed as the eval set and is also the fold
            # the OOF predictions below are produced for.
            model.fit(X_tr, Y_tr, eval_sets=[{'X': X_val, 'y': Y_val}])
            val_probs = model.predict(X_val)
            oof_preds_gbdt[idx_val] = val_probs

            # Per-fold diagnostics at a fixed 0.3 threshold.
            val_preds = (val_probs > 0.3).astype(int)
            f1 = f1_score(Y_val, val_preds, average='micro')
            ia_f1 = ia_weighted_f1(Y_val, val_probs, thr=0.3)
            print(f"  >> Fold {fold+1} micro-F1@0.30: {f1:.4f}")
            print(f"  >> Fold {fold+1} IA-F1: ALL={ia_f1['ALL']:.4f}")

            model.save(str(WORK_ROOT / 'features' / f'level1_gbdt_fold{fold}.json'))
            del model, X_tr, X_val, Y_tr, Y_val
            gc.collect()

        np.save(WORK_ROOT / 'features' / 'oof_pred_gbdt.npy', oof_preds_gbdt)
        print("GBDT OOF saved.")

    # Cleanup X mmap to free file handle
    del X
    gc.collect()

In [None]:
# CELL 8d - Phase 2a Post-Processing: Generate Test Predictions
if TRAIN_LEVEL1:
    # Everything in this cell runs only when Level-1 training is enabled.
    print("\n=== Phase 2a Post-Processing: Generate Test Predictions ===")
    # Memory is logged on both sides of the load so its cost is visible.
    log_mem("Before loading Test Features")
    # Indexed below as feat_dict[key][i:i+batch_size], i.e. used as a mapping of
    # per-modality 2-D arrays - presumably one array per feature key; verify
    # against load_features_dict.
    features_test = load_features_dict(split='test')
    log_mem("Loaded Test Features")
    
    # Helper for batched prediction with scaling
    def predict_proba_batched_scaled(model, feat_dict, keys, scaler, batch_size=5000):
        """Predict probabilities for the test set in row batches, scaling each batch first.

        Args:
            model: fitted estimator exposing `predict_proba`.
            feat_dict: mapping from feature key to a 2-D array; all arrays are
                sliced with the same row range and concatenated column-wise.
            keys: feature keys to concatenate, in column order.
            scaler: fitted transformer exposing `transform`.
            batch_size: number of rows per batch.

        Returns:
            The per-batch `predict_proba` outputs stacked row-wise with np.vstack.
        """
        n_test = feat_dict[keys[0]].shape[0]
        preds = []

        # A model counts as cuML when "cuml" appears in its type (or in the type
        # of a wrapped `.estimator`). CuPy is imported only in that case, and
        # only a failed import disables the GPU path - other errors surface.
        is_cuml = False
        cp = None
        type_names = str(type(model)) + str(type(getattr(model, 'estimator', None)))
        if 'cuml' in type_names:
            try:
                import cupy as cp
                is_cuml = True
            except ImportError:
                is_cuml = False

        for i in range(0, n_test, batch_size):
            X_batch = np.hstack([feat_dict[k][i:i+batch_size] for k in keys]).astype(np.float32)
            X_batch = scaler.transform(X_batch)
            if is_cuml: X_batch = cp.asarray(X_batch)
            p = model.predict_proba(X_batch)
            if is_cuml:
                # Bring device results back to the host before stacking.
                if hasattr(p, 'get'): p = p.get()
                elif hasattr(p, 'to_numpy'): p = p.to_numpy()
            preds.append(p)
            del X_batch
        return np.vstack(preds)

    # Helper for batched prediction from weights (LogReg)
    def predict_proba_from_weights(weights_path, feat_dict, keys, scaler, batch_size=5000):
        """Compute sigmoid(X @ coef.T + intercept) for the test set in row batches.

        Args:
            weights_path: path of a joblib file holding a dict with 'coef' and
                'intercept' entries.
            feat_dict: mapping from feature key to a 2-D array; arrays are sliced
                with the same row range and concatenated column-wise.
            keys: feature keys to concatenate, in column order.
            scaler: fitted transformer exposing `transform`.
            batch_size: number of rows per batch.

        Returns:
            Host array of probabilities, one row per test row.
        """
        model_data = joblib.load(weights_path)
        coef = model_data['coef'] # (n_classes, n_features)
        intercept = model_data['intercept'] # (n_classes,)

        # Use GPU if available. Both arrays are transferred before either name is
        # rebound, so a failed transfer leaves the CPU copies fully usable.
        # `Exception` keeps this best-effort without swallowing KeyboardInterrupt.
        try:
            import cupy as cp
            coef_dev, intercept_dev = cp.asarray(coef), cp.asarray(intercept)
            coef, intercept = coef_dev, intercept_dev
            del coef_dev, intercept_dev
            use_gpu = True
        except Exception:
            use_gpu = False

        n_test = feat_dict[keys[0]].shape[0]
        preds = []

        for i in range(0, n_test, batch_size):
            X_batch = np.hstack([feat_dict[k][i:i+batch_size] for k in keys]).astype(np.float32)
            X_batch = scaler.transform(X_batch)

            if use_gpu:
                X_batch = cp.asarray(X_batch)
                # LogReg: sigmoid(X @ W.T + b)
                logits = cp.dot(X_batch, coef.T) + intercept
                p = 1 / (1 + cp.exp(-logits))
                p = p.get()
            else:
                logits = np.dot(X_batch, coef.T) + intercept
                p = 1 / (1 + np.exp(-logits))

            preds.append(p)
            del X_batch

        if use_gpu:
            # Drop the device copies of the weights and return pooled blocks.
            del coef, intercept
            cp.get_default_memory_pool().free_all_blocks()

        return np.vstack(preds)

    # A. LogReg Test Preds
    # Average the per-fold test probabilities. Folds without a saved model are
    # skipped and the mean is taken over the folds actually loaded, so a missing
    # fold no longer drags every probability towards zero.
    print("Generating LogReg Test Predictions...")
    test_preds_logreg = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)
    n_folds_used = 0
    for fold in range(5):
        print(f"  Loading LogReg Fold {fold+1}...")

        # The GPU path saves raw weights, the CPU path a pickled estimator.
        weights_path = WORK_ROOT / 'features' / f'level1_logreg_weights_fold{fold}.pkl'
        model_path = WORK_ROOT / 'features' / f'level1_logreg_fold{fold}.pkl'

        if not weights_path.exists() and not model_path.exists():
            print(f"  [WARNING] No model found for fold {fold}")
            continue

        scaler = joblib.load(WORK_ROOT / 'features' / f'level1_logreg_scaler_fold{fold}.pkl')

        if weights_path.exists():
            probs = predict_proba_from_weights(weights_path, features_test, FLAT_KEYS, scaler)
        else:
            clf = joblib.load(model_path)
            probs = predict_proba_batched_scaled(clf, features_test, FLAT_KEYS, scaler)
            del clf

        test_preds_logreg += probs
        n_folds_used += 1
        del scaler, probs
        gc.collect()

    if n_folds_used == 0:
        raise RuntimeError("CRITICAL: No LogReg fold models found. Cannot generate LogReg test predictions.")
    test_preds_logreg /= n_folds_used

    np.save(WORK_ROOT / 'features' / 'test_pred_logreg.npy', test_preds_logreg)
    del test_preds_logreg
    gc.collect()
    print("LogReg Test Preds Saved.")

    # B. GBDT Test Preds
    # Same fold-averaging scheme as the LogReg block, without feature scaling.
    if HAS_PYBOOST:
        print("Generating GBDT Test Predictions...")
        test_preds_gbdt = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)

        # Helper for GBDT batched
        def predict_proba_batched_gbdt(model, feat_dict, keys, batch_size=5000):
            """Run `model.predict` over the test rows in batches and stack the results."""
            total_rows = feat_dict[keys[0]].shape[0]
            chunks = []
            for lo in range(0, total_rows, batch_size):
                hi = lo + batch_size
                parts = [feat_dict[name][lo:hi] for name in keys]
                block = np.hstack(parts).astype(np.float32)
                chunks.append(model.predict(block))
                del block
            return np.vstack(chunks)

        for fold in range(5):
            print(f"  Loading GBDT Fold {fold+1}...")
            model = GradientBoosting.load(
                str(WORK_ROOT / 'features' / f'level1_gbdt_fold{fold}.json')
            )
            probs = predict_proba_batched_gbdt(model, features_test, FLAT_KEYS)
            # Each fold contributes one fifth of the final average.
            test_preds_gbdt += probs / 5.0
            del model, probs
            gc.collect()

        np.save(WORK_ROOT / 'features' / 'test_pred_gbdt.npy', test_preds_gbdt)
        del test_preds_gbdt
        gc.collect()
        print("GBDT Test Preds Saved.")

    # Cleanup Test Features
    del features_test
    gc.collect()


In [None]:
# CELL 8e - Phase 2b: Deep Learning (DNN)
if TRAIN_LEVEL1:
    print("\n=== Phase 2b: Deep Learning (DNN) ===")
    
    # 1. Reload dicts
    # Called without arguments here and unpacked into a (train, test) pair.
    features_train, features_test = load_features_dict()
    
    # ------------------------------------------
    # C. DNN Ensemble (PyTorch, IA-weighted, multimodal, multi-state)
    # ------------------------------------------
    print("\n--- Training DNN Ensemble (IA-weighted, multimodal, multi-state) ---")
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    # Use the GPU when torch can see one, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Build a stable per-label IA weight vector for the current TOP_K targets
    # NOTE(review): `weights` comes from an earlier cell - presumably one IA value per target column; confirm.
    ia_w = weights.copy()
    # Non-finite or non-positive weights are replaced by a neutral 1.0.
    ia_w = np.where(np.isfinite(ia_w) & (ia_w > 0), ia_w, 1.0).astype(np.float32)
    # Normalise to mean 1 ...
    ia_w = ia_w / float(np.mean(ia_w))
    # ... then bound each weight to [0.5, 5.0]; after clipping the mean is no longer exactly 1.
    ia_w = np.clip(ia_w, 0.5, 5.0)
    # Shape (1, n_labels) so it broadcasts across the batch dimension of the loss.
    ia_w_t = torch.tensor(ia_w, dtype=torch.float32, device=device).view(1, -1)
    
    # Optional: include other model predictions as an input stream (PB OOFs analogue)
    USE_BASE_OOFS_IN_DNN = True
    # Only the OOF file is checked for existence; the matching test_pred file is
    # loaded unconditionally, so np.load raises if it is absent.
    if USE_BASE_OOFS_IN_DNN and (WORK_ROOT / 'features' / 'oof_pred_logreg.npy').exists():
        oof_stream = [np.load(WORK_ROOT / 'features' / 'oof_pred_logreg.npy').astype(np.float32)]
        test_stream = [np.load(WORK_ROOT / 'features' / 'test_pred_logreg.npy').astype(np.float32)]
        # GBDT predictions are appended only when LogReg ones were found first.
        if (WORK_ROOT / 'features' / 'oof_pred_gbdt.npy').exists():
            oof_stream.append(np.load(WORK_ROOT / 'features' / 'oof_pred_gbdt.npy').astype(np.float32))
            test_stream.append(np.load(WORK_ROOT / 'features' / 'test_pred_gbdt.npy').astype(np.float32))
        # Column-wise concatenation: one block of columns per base model.
        base_oof = np.hstack(oof_stream)
        base_test = np.hstack(test_stream)
        # Train rows get out-of-fold predictions, test rows get the test predictions.
        features_train['base_oof'] = base_oof
        features_test['base_oof'] = base_test
        print(f"Base OOF stream: train={base_oof.shape} test={base_test.shape}")
        
    # Select modality keys for the DNN (towers)
    # Only keys present in features_train are kept, in this fixed order.
    DNN_KEYS = [k for k in ['t5', 'esm2_650m', 'esm2_3b', 'taxa', 'text', 'base_oof'] if k in features_train]
    print(f"DNN modality keys={DNN_KEYS}")
    
    class Tower(nn.Module):
        """Per-modality encoder: Linear -> ReLU -> Dropout -> Linear -> ReLU.

        Args:
            in_dim: width of the modality's input features.
            out_dim: width of the encoded representation.
            dropout: dropout probability applied after the first activation.
        """

        def __init__(self, in_dim, out_dim=512, dropout=0.1):
            super().__init__()
            hidden_dim = 1024
            # Layer order (and therefore the Sequential indices used in saved
            # state dicts) is fixed: 0=Linear, 1=ReLU, 2=Dropout, 3=Linear, 4=ReLU.
            stages = [
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, out_dim),
                nn.ReLU(),
            ]
            self.net = nn.Sequential(*stages)

        def forward(self, x):
            encoded = self.net(x)
            return encoded
            
    class ColossalMultiModalDNN(nn.Module):
        """Multi-modal classifier: one Tower per modality, concatenated, then an MLP head.

        Args:
            dims: mapping from modality key to that modality's input width. The
                key order fixes the concatenation order of the tower outputs.
            output_dim: number of output logits.
            tower_dim: output width of every tower. Defaults to 512, the value
                that was previously hard-coded when sizing the head input.

        Raises:
            ValueError: if `dims` is empty (the head would have no inputs).
        """

        def __init__(self, dims: dict, output_dim: int, tower_dim: int = 512):
            super().__init__()
            if not dims:
                raise ValueError("ColossalMultiModalDNN needs at least one modality in `dims`.")
            self.keys = list(dims.keys())
            # The tower width is passed explicitly so the head input size below
            # cannot drift away from what the towers actually emit.
            self.towers = nn.ModuleDict({k: Tower(dims[k], out_dim=tower_dim) for k in self.keys})
            fused_dim = tower_dim * len(self.keys)
            self.head = nn.Sequential(
                nn.Linear(fused_dim, 2048),
                nn.BatchNorm1d(2048),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(2048, 1024),
                nn.BatchNorm1d(1024),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(1024, output_dim),
            )

        def forward(self, batch: dict):
            """Encode each modality in `batch` (keyed like `dims`) and return raw logits."""
            hs = [self.towers[k](batch[k]) for k in self.keys]
            h = torch.cat(hs, dim=1)
            return self.head(h)
            
    # Prepare torch tensors per modality
    # Every selected modality is converted to a float32 tensor created directly
    # on `device`, for the train and the test features alike. The training and
    # inference code below indexes these dicts instead of the original arrays.
    train_t = {k: torch.tensor(features_train[k], dtype=torch.float32, device=device) for k in DNN_KEYS}
    test_t = {k: torch.tensor(features_test[k], dtype=torch.float32, device=device) for k in DNN_KEYS}
    
    def _batch_dict(tensors: dict, idx):
        return {k: v[idx] for k, v in tensors.items()}
        
    # Multi-state ensembling
    DNN_SEEDS = [42, 43, 44, 45, 46]
    DNN_EPOCHS = 10
    BATCH_SIZE = 256
    oof_sum = np.zeros(Y.shape, dtype=np.float32)
    test_sum = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)
    n_states = len(DNN_SEEDS)
    
    for state_i, seed in enumerate(DNN_SEEDS, 1):
        print(f"\n[DNN] Random state {state_i}/{n_states}: seed={seed}")
        torch.manual_seed(seed)
        np.random.seed(seed)
        kf_state = KFold(n_splits=5, shuffle=True, random_state=seed)
        oof_state = np.zeros(Y.shape, dtype=np.float32)
        test_state = np.zeros((len(test_ids), Y.shape[1]), dtype=np.float32)
        dims = {k: int(features_train[k].shape[1]) for k in DNN_KEYS}
        
        for fold, (idx_tr, idx_val) in enumerate(kf_state.split(train_ids)):
            print(f"DNN Fold {fold+1}/5")
            model = ColossalMultiModalDNN(dims=dims, output_dim=Y.shape[1]).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            Y_full_t = torch.tensor(Y, dtype=torch.float32, device=device)
            n_samples = len(idx_tr)
            model.train()
            idx_tr_t = torch.tensor(idx_tr, dtype=torch.long, device=device)
            
            for _epoch in range(DNN_EPOCHS):
                perm = idx_tr_t[torch.randperm(n_samples, device=device)]
                for i in range(0, n_samples, BATCH_SIZE):
                    b = perm[i:i + BATCH_SIZE]
                    optimizer.zero_grad()
                    logits = model(_batch_dict(train_t, b))
                    yb = Y_full_t[b]
                    loss_el = F.binary_cross_entropy_with_logits(logits, yb, reduction='none')
                    loss = (loss_el * ia_w_t).mean()
                    loss.backward()
                    optimizer.step()
                    
            model.eval()
            with torch.no_grad():
                idx_val_t = torch.tensor(idx_val, dtype=torch.long, device=device)
                val_probs = torch.sigmoid(model(_batch_dict(train_t, idx_val_t))).cpu().numpy()
                oof_state[idx_val] = val_probs
                test_probs = torch.sigmoid(model(test_t)).cpu().numpy()
                test_state += test_probs / kf_state.get_n_splits()
                
            val_preds = (val_probs > 0.3).astype(int)
            f1 = f1_score(Y[idx_val], val_preds, average='micro')
            ia_f1 = ia_weighted_f1(Y[idx_val], val_probs, thr=0.3)
            print(f"  >> Fold {fold+1} micro-F1@0.30: {f1:.4f}")
            print(f"  >> Fold {fold+1} IA-F1: ALL={ia_f1['ALL']:.4f}")
            
            torch.save(model.state_dict(), WORK_ROOT / 'features' / f'level1_dnn_seed{seed}_fold{fold}.pth')
            
        oof_sum += oof_state
        test_sum += test_state
        
    oof_preds_dnn = (oof_sum / n_states).astype(np.float32)
    test_preds_dnn = (test_sum / n_states).astype(np.float32)
    np.save(WORK_ROOT / 'features' / 'oof_pred_dnn.npy', oof_preds_dnn)
    np.save(WORK_ROOT / 'features' / 'test_pred_dnn.npy', test_preds_dnn)
    print("DNN OOF + test preds saved (multi-state averaged).")
    
    # Persist term list
    with open(WORK_ROOT / 'features' / 'top_terms_1500.json', 'w') as f:
        json.dump(top_terms, f)
        
    if 'STORE' in globals() and STORE is not None:
        req = [WORK_ROOT / 'features' / 'top_terms_1500.json', WORK_ROOT / 'features' / 'oof_pred_logreg.npy', WORK_ROOT / 'features' / 'test_pred_logreg.npy', WORK_ROOT / 'features' / 'oof_pred_dnn.npy', WORK_ROOT / 'features' / 'test_pred_dnn.npy']
        req += [WORK_ROOT / 'features' / 'oof_pred_gbdt.npy', WORK_ROOT / 'features' / 'test_pred_gbdt.npy']
        STORE.push('stage_07_level1_preds', req, note='Level-1 OOF + test preds (LR/GBDT/DNN)')
        
    print("Phase 2 Complete.")
else:
    print("Skipping Phase 2.")

In [None]:
# CELL 9 - Phase 3: Hierarchy-Aware Stacking (GCN)
# =============================================================================
# PHASE 3: GCN STACKER (Hierarchy-Aware Refinement)
# =============================================================================
# Strategy:
# 1. Input: OOF predictions from Level 1 models (LogReg, GBDT, DNN) -> Shape (N, K, M), where M is the number of models whose OOF file exists (at most 3)
# 2. Graph: GO Ontology Adjacency Matrix (K, K)
# 3. Model: GCN that refines predictions based on graph structure
# 4. Output: Refined probabilities (N, K)

TRAIN_STACKER = True

if TRAIN_STACKER and TRAIN_LEVEL1:
    print("\n=== Phase 3: Hierarchy-Aware Stacking (GCN) ===")
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import networkx as nx

    # Defined here so this cell does not depend on a `device` left behind by
    # another cell; same selection rule as the DNN cell.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 1. Load OOF Predictions
    # Files that do not exist are skipped, so the stack holds only the models
    # that were actually trained, in the order listed in `oof_files`.
    print("Loading Level 1 OOFs...")
    oof_files = ['oof_pred_logreg.npy', 'oof_pred_gbdt.npy', 'oof_pred_dnn.npy']
    oofs = []
    for f in oof_files:
        p = WORK_ROOT / 'features' / f
        if p.exists():
            oofs.append(np.load(p).astype(np.float32))
            print(f"  Loaded {f}")

    if not oofs:
        raise RuntimeError("CRITICAL: No Level 1 OOF predictions found (LogReg/GBDT/DNN). Cannot train Stacker. Check Phase 2 execution.")
    else:
        # Stack: (N, K, M) where M is number of models
        X_stack = np.stack(oofs, axis=2)
        print(f"Stacker Input Shape: {X_stack.shape} (Samples, Terms, Models)")

        # 2. Build Adjacency Matrix
        print("Building GO Graph Adjacency Matrix...")
        try:
            import obonet
            obo_path = WORK_ROOT.parent / 'go-basic.obo'
            if not obo_path.exists(): obo_path = Path('go-basic.obo')
            graph = obonet.read_obo(obo_path)

            # Map term to its column index in the prediction matrices.
            term_to_idx = {t: i for i, t in enumerate(top_terms)}

            # Symmetric adjacency with self-loops (standard GCN form): an edge
            # between a term and its parent lets information flow both ways.
            adj = np.eye(len(top_terms), dtype=np.float32) # Self-loops

            # Fill edges; only edges with both endpoints in `top_terms` are kept.
            # NOTE(review): successors() follows every edge obonet loaded, whatever
            # its relationship type - confirm non-is_a edges are wanted here.
            edges_count = 0
            for child in top_terms:
                if child in graph:
                    for parent in graph.successors(child): # 'is_a' points to parent in obonet/networkx
                        if parent in term_to_idx:
                            i, j = term_to_idx[child], term_to_idx[parent]
                            adj[i, j] = 1.0
                            adj[j, i] = 1.0 # Symmetric
                            edges_count += 1

            print(f"Graph built. Nodes: {len(top_terms)}, Edges (in subset): {edges_count}")

            # Normalize Adjacency: D^{-1/2} A D^{-1/2}
            # Scaling rows and columns by broadcasting gives the same matrix as
            # two products with diag(D^{-1/2}) without the dense K x K matmuls.
            D = np.sum(adj, axis=1)
            D_inv_sqrt = np.power(D, -0.5)
            D_inv_sqrt[np.isinf(D_inv_sqrt)] = 0.
            A_norm = adj * D_inv_sqrt[:, None] * D_inv_sqrt[None, :]
            A_norm = torch.tensor(A_norm, dtype=torch.float32, device=device)

        except Exception as e:
            # Best effort: with an identity adjacency the GCN mixes the model
            # channels per term but performs no graph smoothing.
            print(f"[WARNING] Failed to build graph: {e}. Using Identity (No-op GCN).")
            A_norm = torch.eye(len(top_terms), device=device)

        # 3. Define GCN Model
        class GCNStacker(nn.Module):
            """Two-layer graph convolution over the term graph, applied per sample.

            Each node (term) carries one feature per base model; the network
            mixes those features and propagates them along `adj`, producing a
            single logit per node.
            """

            def __init__(self, n_models, n_hidden=16):
                super().__init__()
                # Node-wise linear maps; graph propagation happens in forward().
                self.gc1 = nn.Linear(n_models, n_hidden)
                self.gc2 = nn.Linear(n_hidden, 1)  # one score per node
                self.relu = nn.ReLU()

            def forward(self, x, adj):
                """Return logits of shape (Batch, Nodes).

                x:   (Batch, Nodes, Models)
                adj: (Nodes, Nodes), shared by every sample in the batch
                """
                # Layer 1: H1 = ReLU(A (X W1)). matmul broadcasts the 2-D
                # adjacency over the batch dimension of the 3-D operand.
                hidden = self.relu(torch.matmul(adj, self.gc1(x)))    # (B, N, H)

                # Layer 2: A (H1 W2)
                scores = torch.matmul(adj, self.gc2(hidden))          # (B, N, 1)

                return scores.squeeze(-1)                             # (B, N)

        # 4. Train Stacker
        print("Training GCN Stacker...")
        # We split the OOFs into train/val for the stacker?
        # Actually, OOFs are already "test-like" for the whole train set.
        # We can train on the whole train set (since OOFs were generated via CV).
        
        # Targets
        Y_tensor = torch.tensor(Y, dtype=torch.float32, device=device)
        X_stack_tensor = torch.tensor(X_stack, dtype=torch.float32, device=device)
        
        model = GCNStacker(n_models=X_stack.shape[2]).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        loss_fn = nn.BCEWithLogitsLoss()
        
        BATCH_SIZE = 128
        N_EPOCHS = 20
        
        dataset = torch.utils.data.TensorDataset(X_stack_tensor, Y_tensor)
        loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
        
        model.train()
        for epoch in range(N_EPOCHS):
            total_loss = 0
            for xb, yb in loader:
                optimizer.zero_grad()
                logits = model(xb, A_norm)
                loss = loss_fn(logits, yb)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            if (epoch+1) % 5 == 0:
                print(f"  Epoch {epoch+1}/{N_EPOCHS} Loss: {total_loss/len(loader):.4f}")
                
        # 5. Generate Final Predictions (on Test)
        print("Generating Final Stacker Predictions on Test...")
        # Load Test Preds
        test_files = ['test_pred_logreg.npy', 'test_pred_gbdt.npy', 'test_pred_dnn.npy']
        tests = []
        for f in test_files:
            p = WORK_ROOT / 'features' / f
            if p.exists():
                tests.append(np.load(p).astype(np.float32))
        
        if tests:
            X_test_stack = np.stack(tests, axis=2)
            X_test_tensor = torch.tensor(X_test_stack, dtype=torch.float32, device=device)
            
            model.eval()
            final_preds = []
            with torch.no_grad():
                # Process in chunks to avoid OOM
                for i in range(0, len(X_test_tensor), BATCH_SIZE):
                    batch = X_test_tensor[i:i+BATCH_SIZE]
                    logits = model(batch, A_norm)
                    probs = torch.sigmoid(logits).cpu().numpy()
                    final_preds.append(probs)
            
            final_preds = np.vstack(final_preds)
            np.save(WORK_ROOT / 'features' / 'final_pred_gcn.npy', final_preds)
            print(f"Final GCN Predictions Saved: {final_preds.shape}")
            
            # Clean up
            del X_stack_tensor, Y_tensor, X_test_tensor, model, A_norm
            torch.cuda.empty_cache()
        else:
            raise RuntimeError("CRITICAL: No Level 1 Test predictions found. Cannot generate final submission.")

else:
    print("Skipping Phase 3 (Stacker).")

In [None]:
# CELL 10 - Phase 4: Strict Post-Processing & Submission
# =============================================================================
# PHASE 4: STRICT POST-PROCESSING (Hierarchy Enforcement)
# =============================================================================
# 1. Max Propagation (Parent Rule): P(parent) >= P(child)
# 2. Min Propagation (Child Rule): P(child) <= P(parent)
# 3. Top-1500 Selection
# 4. Submission Generation

APPLY_POSTPROC = True

if APPLY_POSTPROC:
    print("\n=== Phase 4: Strict Post-Processing ===")
    
    # Load predictions (prefer GCN, fallback to DNN/LogReg average)
    # Selection is by file existence only: the first file found in the order
    # GCN -> DNN -> LogReg is used as-is, without averaging several models.
    if (WORK_ROOT / 'features' / 'final_pred_gcn.npy').exists():
        print("Loading GCN predictions...")
        preds = np.load(WORK_ROOT / 'features' / 'final_pred_gcn.npy')
    elif (WORK_ROOT / 'features' / 'test_pred_dnn.npy').exists():
        print("Loading DNN predictions (GCN missing)...")
        preds = np.load(WORK_ROOT / 'features' / 'test_pred_dnn.npy')
    else:
        # No existence check on this last branch: np.load raises if the file is missing.
        print("Loading LogReg predictions (Fallback)...")
        preds = np.load(WORK_ROOT / 'features' / 'test_pred_logreg.npy')
        
    print(f"Raw Predictions Shape: {preds.shape}")
    
    # Load Graph for Propagation
    # Enforces the hierarchy on `preds` in place: first every parent is raised
    # to at least its children's score, then every child is capped by its parents.
    try:
        import obonet
        import networkx as nx
        obo_path = WORK_ROOT.parent / 'go-basic.obo'
        if not obo_path.exists(): obo_path = Path('go-basic.obo')
        graph = obonet.read_obo(obo_path)

        # Map term to its column index in `preds`.
        term_to_idx = {t: i for i, t in enumerate(top_terms)}

        # Index-based edge maps restricted to the predicted terms.
        # Edges run child -> parent, so successors are parents and predecessors
        # are children.
        # NOTE(review): successors()/predecessors() follow every edge obonet
        # loaded, whatever its relationship type - confirm that propagating
        # across non-is_a edges is intended.
        child_to_parents = {}
        parent_to_children = {}

        for term in top_terms:
            if term not in graph:
                continue
            t_idx = term_to_idx[term]
            parent_idx = [term_to_idx[p] for p in graph.successors(term) if p in term_to_idx]
            if parent_idx:
                child_to_parents[t_idx] = parent_idx
            child_idx = [term_to_idx[c] for c in graph.predecessors(term) if c in term_to_idx]
            if child_idx:
                parent_to_children[t_idx] = child_idx

        print(f"Hierarchy constraints: {len(child_to_parents)} terms have parents in subset.")

        # A topological order lists a node before the nodes its edges point to.
        # With child -> parent edges that means children come BEFORE parents.
        subgraph = graph.subgraph(top_terms)
        try:
            topo_order = list(nx.topological_sort(subgraph))
        except nx.NetworkXUnfeasible:
            topo_order = None
            print("  [WARNING] Cycle detected or graph issue. Skipping topological prop.")

        if topo_order is not None:
            # 1. Max Propagation (Child -> Parent): ensure Parent >= Child.
            # Children are visited first, so a score raised on a parent is
            # already final when that parent is used to update its own parents,
            # and one pass reaches the roots.
            print("Applying Max Propagation (Child -> Parent)...")
            print("  Optimized Max Prop (1 pass)...")
            for term in topo_order:
                c_idx = term_to_idx.get(term)
                if c_idx is None: continue
                p_indices = child_to_parents.get(c_idx)
                if not p_indices: continue
                # Column slice keeps a (N, 1) shape so it broadcasts over all parents.
                child_val = preds[:, c_idx:c_idx+1]
                preds[:, p_indices] = np.maximum(preds[:, p_indices], child_val)

            # 2. Min Propagation (Parent -> Child): ensure Child <= Parent.
            # Parents must be final before their children are capped, hence the
            # reversed order (parents first). Once the max pass above has run
            # every parent is already >= its children, so this pass only acts
            # as a safety net.
            print("Applying Min Propagation (Parent -> Child)...")
            for term in reversed(topo_order):
                p_idx = term_to_idx.get(term)
                if p_idx is None: continue
                c_indices = parent_to_children.get(p_idx)
                if not c_indices: continue
                parent_val = preds[:, p_idx:p_idx+1]
                preds[:, c_indices] = np.minimum(preds[:, c_indices], parent_val)

    except Exception as e:
        # Best effort: the submission is still written from the unpropagated scores.
        print(f"[WARNING] Post-processing failed: {e}")
        
    # 3. Submission Formatting
    print("Formatting Submission...")
    # Clip to [0, 1]. Non-positive scores are dropped below, so every written
    # score ends up in (0, 1].
    preds = np.clip(preds, 0.0, 1.0)

    # Create submission rows.
    # Format: ProteinID, TermID, Score
    # Melting the dense matrix would create one row per (protein, term) pair,
    # so instead keep at most 1500 terms per protein and build sparse rows.
    submission_rows = []
    test_ids_list = test_ids.tolist()

    print("Selecting Top-1500 per protein...")
    for i, pid in enumerate(test_ids_list):
        # Get scores for this protein
        scores = preds[i]
        # argpartition finds the 1500 largest entries without a full sort
        if len(scores) > 1500:
            top_indices = np.argpartition(scores, -1500)[-1500:]
        else:
            top_indices = np.arange(len(scores))

        for idx in top_indices:
            score = scores[idx]
            if score > 0:  # Only positive
                score_txt = f"{score:.3f}"
                # A tiny positive score would be written as "0.000", which is
                # outside the allowed (0, 1] range; skip it.
                if score_txt == "0.000":
                    continue
                term = top_terms[idx]
                submission_rows.append((pid, term, score_txt))

        if (i+1) % 1000 == 0:
            print(f"  Processed {i+1}/{len(test_ids_list)} proteins...")

    # Write to file: tab-separated, no header row.
    print("Writing submission.tsv...")
    df_sub = pd.DataFrame(submission_rows, columns=['Protein Id', 'GO Term Id', 'Prediction'])
    df_sub.to_csv('submission.tsv', sep='\t', index=False, header=False)
    print(f"Submission saved: {len(df_sub)} rows.")

    # Zip it with the standard library so a missing `zip` binary or a failed
    # write raises instead of being silently ignored.
    import zipfile
    with zipfile.ZipFile('submission.zip', 'w', compression=zipfile.ZIP_DEFLATED) as zf:
        zf.write('submission.tsv')
    print("Zipped to submission.zip")

else:
    print("Skipping Post-Processing.")


In [None]:
# CELL 13E - KNN (cosine; ESM2-3B)
# Notes:
# - Uses in-memory features from CELL 13 (features_train/features_test, Y, top_terms).
# - Checkpoint pushing is controlled globally by CAFA_CHECKPOINT_PUSH (default 0 in this notebook).
# - Produces oof_pred_knn / test_pred_knn, either loaded from disk or computed below.
if not TRAIN_LEVEL1:
    print('Skipping KNN (TRAIN_LEVEL1=False).')
else:
    import os
    import json
    from sklearn.neighbors import NearestNeighbors
    from sklearn.model_selection import KFold

    # Set CAFA_FORCE_REBUILD=1 to ignore cached prediction files and recompute.
    FORCE_REBUILD = (os.getenv('CAFA_FORCE_REBUILD', '0').strip() == '1')

    PRED_DIR = WORK_ROOT / 'features' / 'level1_preds'
    PRED_DIR.mkdir(parents=True, exist_ok=True)

    knn_oof_path = PRED_DIR / 'oof_pred_knn.npy'
    knn_test_path = PRED_DIR / 'test_pred_knn.npy'

    # Backwards-compatible copies (some downstream code loads from WORK_ROOT/features)
    knn_oof_compat = WORK_ROOT / 'features' / 'oof_pred_knn.npy'
    knn_test_compat = WORK_ROOT / 'features' / 'test_pred_knn.npy'

    # Cache hit: both prediction files exist and no rebuild was requested.
    if knn_oof_path.exists() and knn_test_path.exists() and (not FORCE_REBUILD):
        print('KNN preds exist; skipping training (set CAFA_FORCE_REBUILD=1 to force).')
        oof_pred_knn = np.load(knn_oof_path)
        test_pred_knn = np.load(knn_test_path)
        # Max-similarity values are not persisted, so the histogram diagnostic is skipped on a cache hit.
        oof_max_sim = None
    else:
        if 'features_train' not in globals() or 'features_test' not in globals():
            raise RuntimeError('Missing `features_train`/`features_test`. Run CELL 13 first.')
        if 'esm2_3b' not in features_train:
            raise FileNotFoundError("Missing required modality 'esm2_3b' in features_train. Ensure features/train_embeds_esm2_3b.npy exists.")

        # float32 copies of the ESM2-3B embeddings for train and test.
        X_knn = features_train['esm2_3b'].astype(np.float32)
        X_knn_test = features_test['esm2_3b'].astype(np.float32)

        # Enforce TOP_K alignment using the persisted term list.
        # Only the number of columns is compared; term identity/order is not checked here.
        # NOTE(review): `Y` is read from notebook globals without an existence check,
        # unlike features_train/top_terms - confirm CELL 13 always defines it.
        top_terms_path = WORK_ROOT / 'features' / 'top_terms_1500.json'
        if top_terms_path.exists():
            top_terms_knn = json.loads(top_terms_path.read_text())
            if Y.shape[1] != len(top_terms_knn):
                raise ValueError(f'KNN shape mismatch: Y has {Y.shape[1]} cols but top_terms_1500.json has {len(top_terms_knn)} terms.')
        else:
            # Fall back to the in-memory term list when the JSON file is absent.
            if 'top_terms' not in globals():
                raise RuntimeError('Missing top_terms_1500.json and in-memory top_terms. Run CELL 13 first.')
            top_terms_knn = list(top_terms)
            if Y.shape[1] != len(top_terms_knn):
                raise ValueError(f'KNN shape mismatch: Y has {Y.shape[1]} cols but top_terms has {len(top_terms_knn)} terms.')

        # KNN needs binary targets (presence/absence), not counts.
        Y_knn = (Y > 0).astype(np.float32)

        def _l2_norm(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
            n = np.linalg.norm(x, axis=1, keepdims=True)
            return x / np.maximum(n, eps)

        # Cosine distance is best-behaved on L2-normalised vectors
        X_knn = _l2_norm(X_knn)
        X_knn_test = _l2_norm(X_knn_test)

        KNN_K = int(os.getenv('CAFA_KNN_K', '50'))
        KNN_BATCH = int(os.getenv('CAFA_KNN_BATCH', '256'))
        if KNN_K < 1 or KNN_BATCH < 1:
            raise ValueError(f'CAFA_KNN_K and CAFA_KNN_BATCH must be >= 1 (got K={KNN_K}, batch={KNN_BATCH}).')

        n_splits = 5
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

        oof_pred_knn = np.zeros((X_knn.shape[0], Y_knn.shape[1]), dtype=np.float32)
        test_pred_knn = np.zeros((X_knn_test.shape[0], Y_knn.shape[1]), dtype=np.float32)
        oof_max_sim = np.zeros((X_knn.shape[0],), dtype=np.float32)

        def _knn_vote(sims_all: np.ndarray, neigh_all: np.ndarray, out: np.ndarray, out_rows=None) -> None:
            """Write similarity-weighted neighbour label averages into ``out``.

            Args:
                sims_all: (N, K) similarities between each query and its neighbours.
                neigh_all: (N, K) row indices of those neighbours into ``Y_knn``.
                out: Destination array; one row of label scores per query.
                out_rows: Optional destination row index per query. When omitted,
                    query ``i`` is written to ``out[i]``.

            Queries are processed ``KNN_BATCH`` at a time so the gathered
            (batch, K, n_labels) label block stays small.
            """
            n_queries = sims_all.shape[0]
            for i in range(0, n_queries, KNN_BATCH):
                j = min(i + KNN_BATCH, n_queries)
                sims_b = sims_all[i:j]
                # Floor the weight sum so queries with all-zero similarity score 0, not NaN.
                denom = np.maximum(sims_b.sum(axis=1, keepdims=True), 1e-8)
                Y_nei = Y_knn[neigh_all[i:j]]  # (B, K, L)
                scores = (np.einsum('bk,bkl->bl', sims_b, Y_nei) / denom).astype(np.float32)
                if out_rows is None:
                    out[i:j] = scores
                else:
                    out[out_rows[i:j]] = scores

        for fold, (tr_idx, va_idx) in enumerate(kf.split(X_knn), start=1):
            print(f'Fold {fold}/{n_splits} (KNN)')
            # Never request more neighbours than the fold has fitted rows
            # (kneighbors raises in that case).
            knn = NearestNeighbors(n_neighbors=min(KNN_K, len(tr_idx)), metric='cosine', n_jobs=4)
            knn.fit(X_knn[tr_idx])

            dists, neigh = knn.kneighbors(X_knn[va_idx], return_distance=True)
            # Cosine similarity = 1 - cosine distance; negative similarities are clipped to 0.
            sims = np.clip((1.0 - dists).astype(np.float32), 0.0, 1.0)
            oof_max_sim[va_idx] = sims.max(axis=1)
            neigh_global = tr_idx[neigh]  # map to global row indices into Y_knn
            _knn_vote(sims, neigh_global, oof_pred_knn, out_rows=va_idx)

            if 'ia_weighted_f1' in globals():
                print('  IA-F1:', ia_weighted_f1(Y_knn[va_idx], oof_pred_knn[va_idx], thr=0.3))

        # Final model on full train -> test
        knn_final = NearestNeighbors(n_neighbors=min(KNN_K, X_knn.shape[0]), metric='cosine', n_jobs=4)
        knn_final.fit(X_knn)
        dists_te, neigh_te = knn_final.kneighbors(X_knn_test, return_distance=True)
        sims_te = np.clip((1.0 - dists_te).astype(np.float32), 0.0, 1.0)
        _knn_vote(sims_te, neigh_te, test_pred_knn)

        # Persist to the canonical location and to the backwards-compatible one.
        np.save(knn_oof_path, oof_pred_knn)
        np.save(knn_test_path, test_pred_knn)
        np.save(knn_oof_compat, oof_pred_knn)
        np.save(knn_test_compat, test_pred_knn)
        print('Saved:', knn_oof_path)
        print('Saved:', knn_test_path)

    # Checkpoint push is controlled by CAFA_CHECKPOINT_PUSH inside STORE.push; no extra guard needed here.
    if 'STORE' in globals() and STORE is not None:
        STORE.push(
            stage='stage_07d_level1_knn',
            required_paths=[
                str((WORK_ROOT / 'features' / 'top_terms_1500.json').as_posix()),
                str(knn_oof_path.as_posix()),
                str(knn_test_path.as_posix()),
            ],
            note='Level-1 KNN (cosine) predictions using ESM2-3B embeddings (OOF + test).',
        )

    # Diagnostics: similarity distribution + IA-F1 vs threshold.
    # Best-effort only: any failure here is reported and must not fail the cell.
    try:
        import matplotlib.pyplot as plt

        # Only available when predictions were computed in this run (not on a cache hit).
        if oof_max_sim is not None:
            plt.figure(figsize=(10, 4))
            plt.hist(oof_max_sim, bins=50)
            plt.title('KNN OOF diagnostic: max cosine similarity to neighbours (per protein)')
            plt.xlabel('max similarity')
            plt.ylabel('count')
            plt.grid(True, alpha=0.3)
            plt.show()

        if 'ia_weighted_f1' in globals():
            # Score against presence/absence labels, the same target form the
            # per-fold IA-F1 uses, rather than the raw (possibly count-valued) Y.
            Y_diag = (Y > 0).astype(np.float32)
            thrs = np.linspace(0.05, 0.60, 23)
            curves = {k: [] for k in ['ALL', 'MF', 'BP', 'CC']}
            for thr in thrs:
                s = ia_weighted_f1(Y_diag, oof_pred_knn, thr=float(thr))
                for k in curves.keys():
                    curves[k].append(s[k])
            del Y_diag  # release the temporary label copy
            plt.figure(figsize=(10, 3))
            for k in ['ALL', 'MF', 'BP', 'CC']:
                plt.plot(thrs, curves[k], label=k)
            plt.title('KNN OOF: IA-weighted F1 vs threshold')
            plt.xlabel('threshold')
            plt.ylabel('IA-F1')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.show()
    except Exception as e:
        print('KNN diagnostics skipped:', repr(e))

In [None]:
# CELL 15 - Solution: 5b. PHASE 3: HIERARCHY-AWARE STACKING (GRAPH SMOOTHING GCN)
# 5b. PHASE 3: HIERARCHY-AWARE STACKING (GRAPH SMOOTHING GCN)
# =========================================================
# Option B strictness: if PROCESS_EXTERNAL=True, we REQUIRE the propagated prior files to exist.
PROCESS_EXTERNAL = globals().get('PROCESS_EXTERNAL', True)
import numpy as np
import pandas as pd
import json
from pathlib import Path
from sklearn.metrics import f1_score
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
TOP_K = 1500

def _load_level1_pred(fname: str):
    """Return the Level-1 prediction array stored under ``fname`` as float32.

    ``features/level1_preds`` is checked first, then ``features``; the first
    existing file wins. Returns ``None`` when neither location has the file.
    """
    features_root = WORK_ROOT / 'features'
    for folder in (features_root / 'level1_preds', features_root):
        candidate = folder / fname
        if candidate.exists():
            return np.load(candidate).astype(np.float32)
    return None

# 1. Load Level-1 predictions
# Every OOF file that is present contributes equally; missing models are skipped.
train_feats = [
    arr
    for arr in map(
        _load_level1_pred,
        ['oof_pred_logreg.npy', 'oof_pred_gbdt.npy', 'oof_pred_dnn.npy', 'oof_pred_knn.npy'],
    )
    if arr is not None
]
if len(train_feats) == 0:
    raise FileNotFoundError('No Level-1 OOF predictions found. Run Phase 2 first.')
# Stacker input: element-wise mean over the available models.
X_stack = np.mean(train_feats, axis=0).astype(np.float32)

# 2. Build Y label matrix for top-K terms
train_terms = pd.read_parquet(WORK_ROOT / 'parsed' / 'train_terms.parquet')
train_ids = pd.read_feather(WORK_ROOT / 'parsed' / 'train_seq.feather')['id'].astype(str)
# The TOP_K most frequent terms, most frequent first; this order defines the label columns.
top_terms = train_terms['term'].value_counts().head(TOP_K).index.tolist()
train_terms_top = train_terms[train_terms['term'].isin(top_terms)]
Y_df = train_terms_top.pivot_table(index='EntryID', columns='term', aggfunc='size', fill_value=0)
# pivot_table returns its columns sorted by label, not in `top_terms` order.
# Reindex both axes so rows follow train_ids and columns follow top_terms,
# matching the prior matrix built below and the term list saved to top_terms_1500.json.
Y_df = Y_df.reindex(index=train_ids, columns=top_terms, fill_value=0)
Y = Y_df.values.astype(np.float32)
assert Y.shape == (len(train_ids), len(top_terms)), 'label matrix is not aligned with train_ids/top_terms'

# 2b. External priors (Phase 1 Step 4 outputs) -> inject as *conservative* extra signal
# The prior is down-weighted and merged with an element-wise max, so it can
# only raise a stacked score, never lower it.
EXTERNAL_PRIOR_WEIGHT = 0.25
ext_dir = WORK_ROOT / 'external'
prior_train_path = ext_dir / 'prop_train_no_kaggle.tsv.gz'
if PROCESS_EXTERNAL:
    if not prior_train_path.exists():
        raise FileNotFoundError(
            f'Option B requires external priors, but missing: {prior_train_path}. '
            'Run Phase 1 Step 4 propagation or ensure your checkpoint dataset contains these files (run setup: STORE.pull()).',
        )
    prior_train = pd.read_csv(prior_train_path, sep='\t')
    prior_train = prior_train[prior_train['term'].isin(top_terms)]
    # Dense (protein x term) prior laid out exactly like the stack:
    # rows in train_ids order, columns in top_terms order, gaps filled with 0.
    prior_mat = (
        prior_train
        .pivot_table(index='EntryID', columns='term', values='score', aggfunc='max', fill_value=0.0)
        .reindex(index=train_ids.tolist(), columns=top_terms, fill_value=0.0)
    )
    prior_np = prior_mat.values.astype(np.float32)
    X_stack = np.maximum(X_stack, EXTERNAL_PRIOR_WEIGHT * prior_np)
    print(f'Injected external IEA prior into train stack (weight={EXTERNAL_PRIOR_WEIGHT}).')
# Persist the term order so later phases label prediction columns identically.
(WORK_ROOT / 'features').mkdir(parents=True, exist_ok=True)
(WORK_ROOT / 'features' / 'top_terms_1500.json').write_text(json.dumps(top_terms))
print('Saved: top_terms_1500.json')

# 3. Graph adjacency from go-basic.obo (reload if needed)
if 'go_parents' not in locals() or 'go_namespaces' not in locals():
    print('Reloading GO graph (parse_obo)...')
    def parse_obo(path: Path):
        """Read GO ``is_a`` parents and namespaces from an OBO file.

        Only ``[Term]`` stanzas are parsed. Lines that belong to any other
        stanza (for example ``[Typedef]``) are ignored, so their ``is_a:`` and
        ``namespace:`` tags cannot be attributed to the preceding GO term.

        Args:
            path: Location of the OBO file (read as UTF-8).

        Returns:
            ``(parents, namespaces)`` where ``parents`` maps a term id to the
            set of its direct ``is_a`` parent ids and ``namespaces`` maps a
            term id to its namespace string.
        """
        parents = {}
        namespaces = {}
        cur_id, cur_ns = None, None
        in_term = False
        with path.open('r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('[') and line.endswith(']'):
                    # Stanza header: store the finished term, then start fresh.
                    if cur_id and cur_ns:
                        namespaces[cur_id] = cur_ns
                    cur_id, cur_ns = None, None
                    in_term = (line == '[Term]')
                elif not in_term:
                    # File header or a non-term stanza.
                    continue
                elif line.startswith('id: GO:'):
                    cur_id = line.split('id: ', 1)[1]
                elif line.startswith('namespace:'):
                    cur_ns = line.split(':', 1)[1].strip()
                elif line.startswith('is_a:') and cur_id:
                    # Keep only the id token: drop the "! name" comment and any
                    # trailing "{...}" modifier block.
                    target = line.split(':', 1)[1].split('!', 1)[0].split()
                    if target:
                        parents.setdefault(cur_id, set()).add(target[0])
            # The last stanza has no following header to trigger the store.
            if cur_id and cur_ns:
                namespaces[cur_id] = cur_ns
        return parents, namespaces
    go_parents, go_namespaces = parse_obo(PATH_GO_OBO)
def build_adjacency(terms_list, parents_dict):
    """Build a sparse term-by-term adjacency matrix with self-loops.

    For every term in ``terms_list`` and each of its parents that is also in
    ``terms_list``, both (child, parent) and (parent, child) entries are added,
    and every term gets a diagonal entry. All entries carry weight 1 before
    coalescing. The coalesced sparse COO tensor is moved to ``device``.
    """
    position = {term: pos for pos, term in enumerate(terms_list)}
    size = len(terms_list)
    # (child position, parent position) for every in-vocabulary is_a link.
    linked = [
        (position[child], position[parent])
        for child in terms_list
        for parent in (parents_dict.get(child) or ())
        if parent in position
    ]
    diagonal = list(range(size))
    # Each link is stored in both directions, followed by the self-loops;
    # entry order does not matter because the tensor is coalesced.
    rows = [c for c, _ in linked] + [p for _, p in linked] + diagonal
    cols = [p for _, p in linked] + [c for c, _ in linked] + diagonal
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(rows), dtype=torch.float32)
    adjacency = torch.sparse_coo_tensor(indices, values, (size, size))
    return adjacency.coalesce().to(device)
class SimpleGCN(nn.Module):
    """Two-layer MLP whose per-term outputs are mixed through a fixed sparse adjacency.

    ``adj_matrix`` is stored as a plain attribute (not a parameter or buffer)
    and is applied to the transposed output of the second linear layer before
    the final sigmoid.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, adj_matrix):
        super().__init__()
        self.adj = adj_matrix
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return sigmoid scores with the same leading dimension as ``x``."""
        hidden = self.dropout(self.relu(self.fc1(x)))
        per_term = self.fc2(hidden)
        # (terms x terms) @ (terms x batch), transposed back to (batch x terms).
        mixed = torch.sparse.mm(self.adj, per_term.t()).t()
        return torch.sigmoid(mixed)

# 4. Build test stack once (and inject external priors once), then split by ontology
print('\nPreparing test stack...')
test_feats = []
unpaired_models = []
for model_name in ['logreg', 'gbdt', 'dnn', 'knn']:
    arr = _load_level1_pred(f'test_pred_{model_name}.npy')
    # The train stack averages whichever OOF files exist. Check (by path only,
    # without loading) that the same model is present on both sides.
    oof_name = f'oof_pred_{model_name}.npy'
    has_oof = any(
        (folder / oof_name).exists()
        for folder in (WORK_ROOT / 'features' / 'level1_preds', WORK_ROOT / 'features')
    )
    if (arr is not None) != has_oof:
        unpaired_models.append(model_name)
    if arr is not None:
        test_feats.append(arr)
if not test_feats:
    raise FileNotFoundError('No Level-1 test predictions found. Run Phase 2 first.')
if unpaired_models:
    # Averaging different model sets for train and test shifts the stacker's
    # input distribution at inference time; surface it instead of hiding it.
    print(f'[WARNING] Level-1 models with only one of OOF/test predictions: {unpaired_models}. '
          'Train and test stacks will average different model sets.')
X_test_stack = np.mean(test_feats, axis=0).astype(np.float32)
prior_test_path = ext_dir / 'prop_test_no_kaggle.tsv.gz'
if PROCESS_EXTERNAL:
    if not prior_test_path.exists():
        raise FileNotFoundError(
            f'Option B requires external priors, but missing: {prior_test_path}. '
            'Run Phase 1 Step 4 propagation or ensure your checkpoint dataset contains these files (run setup: STORE.pull()).',
        )
    test_ids = pd.read_feather(WORK_ROOT / 'parsed' / 'test_seq.feather')['id'].astype(str)
    prior_test = pd.read_csv(prior_test_path, sep='\t')
    prior_test = prior_test[prior_test['term'].isin(top_terms)]
    # Same layout as the test stack: rows in test_ids order, columns in top_terms order.
    prior_t = prior_test.pivot_table(index='EntryID', columns='term', values='score', aggfunc='max', fill_value=0.0)
    prior_t = prior_t.reindex(test_ids.tolist(), fill_value=0.0)
    prior_t = prior_t.reindex(columns=top_terms, fill_value=0.0)
    prior_test_np = prior_t.values.astype(np.float32)
    # Element-wise max: the down-weighted prior can only raise a score.
    X_test_stack = np.maximum(X_test_stack, EXTERNAL_PRIOR_WEIGHT * prior_test_np)
    print(f'Injected external IEA prior into test stack (weight={EXTERNAL_PRIOR_WEIGHT}).')

# 5. Ontology split (BP/MF/CC)
# Translate each term's GO namespace to its aspect code. A term whose namespace
# is missing or unrecognised is assigned to 'BP'.
ns_to_aspect = {
    'molecular_function': 'MF',
    'biological_process': 'BP',
    'cellular_component': 'CC',
}
aspects = np.array([ns_to_aspect.get(go_namespaces.get(t, ''), 'BP') for t in top_terms])
# Column positions (into top_terms) belonging to each aspect.
aspect_to_idx = {code: np.where(aspects == code)[0].tolist() for code in ('BP', 'MF', 'CC')}
for k in ['BP', 'MF', 'CC']:
    print(f'Terms[{k}]={len(aspect_to_idx[k])}')

# 6. Train 3 specialised GCNs and stitch outputs back
# Output buffer with the same shape as the test stack, initialised to zeros.
test_pred_gcn = np.zeros_like(X_test_stack, dtype=np.float32)
# Full train/test matrices are placed on `device` once; per-aspect code slices columns from them.
X_tensor_full = torch.tensor(X_stack, dtype=torch.float32, device=device)
Y_tensor_full = torch.tensor(Y, dtype=torch.float32, device=device)
X_test_full = torch.tensor(X_test_stack, dtype=torch.float32, device=device)
def train_one(aspect_name: str, idx_cols: list[int]):
    """Train one SimpleGCN stacker on the columns of a single ontology aspect.

    Args:
        aspect_name: Label used in the log lines (e.g. 'BP').
        idx_cols: Column positions, into ``top_terms`` and the stacked
            train matrices, that belong to this aspect.

    Returns:
        The trained model in eval mode (dropout disabled), or ``None`` when
        ``idx_cols`` is empty.
    """
    if not idx_cols:
        print(f'[{aspect_name}] No terms; skipping.')
        return None
    terms_sub = [top_terms[i] for i in idx_cols]
    adj = build_adjacency(terms_sub, go_parents)
    model = SimpleGCN(input_dim=len(idx_cols), hidden_dim=1024, output_dim=len(idx_cols), adj_matrix=adj).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCELoss()
    X_t = X_tensor_full[:, idx_cols]
    Y_t = Y_tensor_full[:, idx_cols]
    n_samples = X_t.shape[0]
    BS = 256
    EPOCHS = 5
    print(f'\n=== Training GCN[{aspect_name}] terms={len(idx_cols)} ===')
    for epoch in range(EPOCHS):
        # Re-enable dropout; the monitoring step below switches it off.
        model.train()
        total_loss = 0.0
        perm = torch.randperm(n_samples, device=device)
        for i in range(0, n_samples, BS):
            b = perm[i:i + BS]
            optimizer.zero_grad()
            out = model(X_t[b])
            loss = criterion(out, Y_t[b])
            loss.backward()
            optimizer.step()
            total_loss += float(loss.item())
        # Monitoring metric on the first 2000 training rows. Dropout must be
        # off here, otherwise the reported F1 is computed on randomly masked
        # activations and varies from call to call.
        model.eval()
        with torch.no_grad():
            pred = (model(X_t[:2000]) > 0.3).float().cpu().numpy()
            f1 = f1_score(Y_t[:2000].cpu().numpy(), pred, average='micro')
        # total_loss is the sum of per-batch losses over the epoch.
        print(f'Epoch {epoch+1}/{EPOCHS} Loss={total_loss:.4f} micro-F1@0.30={f1:.4f}')
    # Hand back an inference-ready model so later predictions are deterministic.
    model.eval()
    return model

In [None]:
# CELL 16 - Solution: 6. PHASE 4: POST-PROCESSING & SUBMISSION
# 6. PHASE 4: POST-PROCESSING & SUBMISSION
# ========================================
# HARDWARE: CPU / GPU
# ========================================
# This phase applies the "Strict Post-Processing" rules (Max/Min Propagation)
# and generates the final submission file.
import json
from pathlib import Path
import numpy as np
import pandas as pd

# Check if submission already exists
# The whole phase is skipped when WORK_ROOT/submission.tsv is present; delete that file to regenerate it.
if (WORK_ROOT / 'submission.tsv').exists():
    print("submission.tsv already exists. Skipping Phase 4.")
else:
    print("Starting Phase 4: Post-processing & submission...")
    # Ensure go_parents is available (from Phase 1)
    # Both names must already be defined in the notebook session; otherwise the OBO file is parsed again.
    if 'go_parents' not in locals() or 'go_namespaces' not in locals():
        print("Reloading GO graph (parse_obo)...")
        def parse_obo(path: Path):
            """Read GO ``is_a`` parents and namespaces from an OBO file.

            Only ``[Term]`` stanzas are parsed. Lines that belong to any other
            stanza (for example ``[Typedef]``) are ignored, so their ``is_a:``
            and ``namespace:`` tags cannot be attributed to the preceding GO term.

            Args:
                path: Location of the OBO file (read as UTF-8).

            Returns:
                ``(parents, namespaces)`` where ``parents`` maps a term id to
                the set of its direct ``is_a`` parent ids and ``namespaces``
                maps a term id to its namespace string.
            """
            parents = {}
            namespaces = {}
            cur_id, cur_ns = None, None
            in_term = False
            with path.open('r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line.startswith('[') and line.endswith(']'):
                        # Stanza header: store the finished term, then start fresh.
                        if cur_id and cur_ns:
                            namespaces[cur_id] = cur_ns
                        cur_id, cur_ns = None, None
                        in_term = (line == '[Term]')
                    elif not in_term:
                        # File header or a non-term stanza.
                        continue
                    elif line.startswith('id: GO:'):
                        cur_id = line.split('id: ', 1)[1]
                    elif line.startswith('namespace:'):
                        cur_ns = line.split(':', 1)[1].strip()
                    elif line.startswith('is_a:') and cur_id:
                        # Keep only the id token: drop the "! name" comment and
                        # any trailing "{...}" modifier block.
                        target = line.split(':', 1)[1].split('!', 1)[0].split()
                        if target:
                            parents.setdefault(cur_id, set()).add(target[0])
                # The last stanza has no following header to trigger the store.
                if cur_id and cur_ns:
                    namespaces[cur_id] = cur_ns
            return parents, namespaces
        go_parents, go_namespaces = parse_obo(PATH_GO_OBO)
    # Test protein ids, in the row order of the prediction matrix
    test_ids = pd.read_feather(WORK_ROOT / 'parsed' / 'test_seq.feather')['id']
    # Stacker output written by Phase 3
    pred_path = WORK_ROOT / 'features' / 'test_pred_gcn.npy'
    if pred_path.exists():
        preds = np.load(pred_path)
    else:
        raise FileNotFoundError("Missing `test_pred_gcn.npy`. Run Phase 3 (GCN stacker) first.")
    # Column labels for `preds` (must match Phase 3)
    terms_path = WORK_ROOT / 'features' / 'top_terms_1500.json'
    if not terms_path.exists():
        # Fallback: recompute the most frequent terms, one per prediction column.
        print("Warning: top_terms_1500.json missing; rebuilding from train_terms counts (may mismatch Phase 3).")
        train_terms = pd.read_parquet(WORK_ROOT / 'parsed' / 'train_terms.parquet')
        top_terms = train_terms['term'].value_counts().head(preds.shape[1]).index.tolist()
    else:
        top_terms = json.loads(terms_path.read_text())
    if len(top_terms) != preds.shape[1]:
        raise ValueError(f"Shape mismatch: preds has {preds.shape[1]} terms, top_terms has {len(top_terms)}.")
    # ------------------------------------------
    # Strict post-processing (Max/Min Propagation)
    # ------------------------------------------
    print(f"Applying hierarchy rules on {len(top_terms)} terms...")
    df_pred = pd.DataFrame(preds, columns=top_terms)
    term_set = set(top_terms)
    # Direct is_a links restricted to the predicted terms, in both directions.
    term_to_parents = {}
    term_to_children = {}
    for term in top_terms:
        parents = go_parents.get(term, set())
        if not parents:
            continue
        parents = parents.intersection(term_set)
        if not parents:
            continue
        term_to_parents[term] = list(parents)
        for p in parents:
            term_to_children.setdefault(p, []).append(term)
    # A fixed number of sweeps in dict order only pushes a score a limited number
    # of levels up the hierarchy, which can leave a distant ancestor below its
    # descendant. Sweep until nothing changes instead; the longest possible
    # chain has len(top_terms) links, which bounds the number of sweeps.
    max_sweeps = max(len(top_terms), 1)
    # Max Propagation (Child -> Parent): every parent ends up >= each of its children.
    for _ in range(max_sweeps):
        changed = False
        for child, parents in term_to_parents.items():
            child_scores = df_pred[child].values
            for parent in parents:
                parent_scores = df_pred[parent].values
                if (parent_scores < child_scores).any():
                    df_pred[parent] = np.maximum(parent_scores, child_scores)
                    changed = True
        if not changed:
            break
    # Min Propagation (Parent -> Child): every child ends up <= each of its parents.
    # After a converged max propagation this normally changes nothing; it is
    # kept as a safety net.
    for _ in range(max_sweeps):
        changed = False
        for parent, children in term_to_children.items():
            parent_scores = df_pred[parent].values
            for child in children:
                child_scores = df_pred[child].values
                if (child_scores > parent_scores).any():
                    df_pred[child] = np.minimum(child_scores, parent_scores)
                    changed = True
        if not changed:
            break
    # ------------------------------------------
    # Submission formatting (CAFA rules)
    # - tab-separated, no header
    # - score in (0, 1.000]
    # - up to 3 significant figures
    # - <= 1500 terms per target (MF/BP/CC combined)
    # ------------------------------------------
    df_pred['EntryID'] = test_ids.values
    submission = (
        df_pred
        # Long format: one row per (protein, term) pair
        .melt(id_vars='EntryID', var_name='term', value_name='score')
        # Enforce the score range
        .assign(score=lambda frame: frame['score'].clip(lower=0.0, upper=1.0))
        # Drop zeros and prune very small scores in one step (>= 0.001 implies > 0)
        .loc[lambda frame: frame['score'] >= 0.001]
        # Highest scores first within each protein, then keep the top 1500 (rule)
        .sort_values(['EntryID', 'score'], ascending=[True, False])
        .groupby('EntryID', sort=False)
        .head(1500)
    )
    # Write with <= 3 significant figures
    submission.to_csv(
        WORK_ROOT / 'submission.tsv',
        sep='\t',
        index=False,
        header=False,
        float_format='%.3g',
    )
    print(f"Done! Submission saved to {WORK_ROOT / 'submission.tsv'}")
    if 'STORE' in globals() and STORE is not None:
        STORE.push('stage_09_submission', [WORK_ROOT / 'submission.tsv'], note='final submission')