# KNN Standalone Notebook (Baseline-Aligned)

This notebook implements KNN using **baseline methodology** that achieved F1 ~0.216.

## Key Alignment with Baseline

Based on auditor feedback, this notebook fixes the performance regression in the e2e pipeline:

1. **Full Term Vocabulary** ✓ - Uses ALL terms from training data (~25,000+ terms), not filtered to 13,500
2. **Pure Similarity Aggregation** ✓ - No IA weighting during prediction (IA only used in CAFA evaluation)
3. **Per-Protein Max Normalization** ✓ - Normalizes each protein's scores to [0, 1] range (matches baseline)
4. **No GO Hierarchy Propagation** ✓ - CAFA labels are already pre-propagated
5. **Appropriate Thresholds** ✓ - Uses thresholds (0.1-0.8) optimized for max-normalized scores

## What This Fixes

The e2e pipeline had multiple issues causing F1 ~0.072 vs baseline F1 ~0.216:
- **Term filtering** → Severe recall loss (~12,000+ missing terms)
- **IA-weighted aggregation** → Double-weighting (IA in both prediction & evaluation)
- **Wrong thresholds** → Too low for normalized scores
- **Redundant hierarchy propagation** → Distorts probability distributions

## Requirements
- Pre-computed embeddings (ESM2-3B)
- Parsed training data
- IA weights (for evaluation only)

## Expected Performance

With these fixes, F1 should improve from ~0.072 to ~0.20-0.25 (matching baseline methodology with better 3B embeddings).


In [None]:
# CELL 1 - Setup (NO REPO)
import os
import sys
import ctypes
from pathlib import Path

# CUDA loader fix (PyTorch/RAPIDS coexistence): preload the venv's nvjitlink so we
# don't pick /usr/local/cuda/lib64. This is deliberately best-effort: any failure is
# reported and the notebook carries on.
try:
    _venv_root = Path(sys.executable).resolve().parent.parent
    _nvjit_dir = (
        _venv_root
        / "lib"
        / f"python{sys.version_info.major}.{sys.version_info.minor}"
        / "site-packages"
        / "nvidia"
        / "nvjitlink"
        / "lib"
    )
    _nvjit_so = _nvjit_dir / "libnvJitLink.so.12"
    if _nvjit_so.exists():
        # RTLD_GLOBAL makes the library's symbols visible to libraries loaded later.
        ctypes.CDLL(str(_nvjit_so), mode=ctypes.RTLD_GLOBAL)
        # Only append the previous value when there is one: a trailing ':' would add an
        # empty entry to the search path.
        _prev_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
        os.environ["LD_LIBRARY_PATH"] = (
            f"{_nvjit_dir}:{_prev_ld_path}" if _prev_ld_path else str(_nvjit_dir)
        )
        print(f"[ENV] Preloaded nvjitlink: {_nvjit_so}")
except Exception as _e:
    print(f"[ENV] nvjitlink preload skipped: {_e}")

# Always run from a simple writable location; never cd into a repo.
if os.path.exists('/content'):
    os.chdir('/content')
RUNTIME_ROOT = Path.cwd()
DATA_ROOT = (RUNTIME_ROOT / 'cafa6_data')
DATA_ROOT.mkdir(parents=True, exist_ok=True)
TRAIN_LEVEL1 = True
print(f'CWD: {Path.cwd()}')
print(f'DATA_ROOT: {DATA_ROOT.resolve()}')

In [None]:
# CELL 2 - Simplified Setup & Config (KNN Standalone)
# This is a stripped-down version without HuggingFace download/upload

import json
import os
from pathlib import Path
import numpy as np


# Environment Detection
def _detect_kaggle() -> bool:
    """Return True when any Kaggle marker environment variable has a non-empty value."""
    kaggle_markers = (
        'KAGGLE_KERNEL_RUN_TYPE',
        'KAGGLE_URL_BASE',
        'KAGGLE_DATA_PROXY_URL',
    )
    return any(os.environ.get(marker) for marker in kaggle_markers)


def _detect_colab() -> bool:
    """Return True when any Colab marker environment variable has a non-empty value."""
    for marker in ('COLAB_RELEASE_TAG', 'COLAB_GPU', 'COLAB_TPU_ADDR'):
        if os.environ.get(marker):
            return True
    return False


# Kaggle takes precedence: Colab is only reported when no Kaggle marker is present.
IS_KAGGLE = _detect_kaggle()
IS_COLAB = (not IS_KAGGLE) and _detect_colab()

if IS_KAGGLE:
    print('Environment: Kaggle')
    WORKING_ROOT = Path('/kaggle/working')
elif IS_COLAB:
    print('Environment: Colab')
    WORKING_ROOT = Path('/content')
else:
    print('Environment: Local')
    WORKING_ROOT = Path.cwd()

# Setup WORK_ROOT
# A DATA_ROOT left in the namespace by an earlier cell wins over the
# environment-derived default, and WORKING_ROOT is re-derived from it.
if 'DATA_ROOT' in globals():
    WORK_ROOT = Path(DATA_ROOT)
    WORKING_ROOT = WORK_ROOT.parent
else:
    WORK_ROOT = WORKING_ROOT / 'cafa6_data'

WORK_ROOT.mkdir(parents=True, exist_ok=True)
for _d in ['parsed', 'features', 'external', 'Train', 'Test']:
    (WORK_ROOT / _d).mkdir(parents=True, exist_ok=True)

print(f'WORK_ROOT: {WORK_ROOT}')

# Training flag
# Accept integers (the original contract: any non-zero value is True) as well as the
# usual textual spellings, instead of failing with a bare int() error on e.g. "true".
_train_level1_raw = os.getenv('CAFA_TRAIN_LEVEL1', '1').strip().lower()
if _train_level1_raw in ('true', 'yes', 'on'):
    TRAIN_LEVEL1 = True
elif _train_level1_raw in ('', 'false', 'no', 'off'):
    TRAIN_LEVEL1 = False
else:
    try:
        TRAIN_LEVEL1 = bool(int(_train_level1_raw))
    except ValueError:
        raise ValueError(
            f'CAFA_TRAIN_LEVEL1 must be an integer or true/false, got {_train_level1_raw!r}'
        ) from None
print(f'TRAIN_LEVEL1: {TRAIN_LEVEL1}')

# Stub CheckpointStore (no HuggingFace operations)
class CheckpointStore:
    """No-op checkpoint store for standalone runs (no HuggingFace integration).

    Both operations only log that they were skipped and report False.
    """

    def __init__(self, work_root: Path):
        self.work_root = work_root

    def _skip(self, stage: str, action: str) -> bool:
        """Log the skipped action for `stage` and return False."""
        print(f'[CHECKPOINT] {stage}: {action} skipped (standalone mode)')
        return False

    def maybe_pull(self, stage: str, required_files: list[str] = None, note: str = '') -> bool:
        """Stub pull: `required_files` and `note` are accepted but ignored."""
        return self._skip(stage, 'pull')

    def maybe_push(self, stage: str, required_paths: list[str] = None, note: str = '') -> bool:
        """Stub push: `required_paths` and `note` are accepted but ignored."""
        return self._skip(stage, 'push')


# Module-level store instance rooted at WORK_ROOT; in this notebook it is the stub
# defined above, so pull/push calls only log and return False.
STORE = CheckpointStore(work_root=WORK_ROOT)
print('CheckpointStore initialized (stub mode - no HF operations)')


In [None]:
# CELL 3 - Data Loading (Phase 2 canonical)
# =====================================================
# This cell loads the training data and embeddings
# for the KNN model using FULL term vocabulary
# =====================================================

import json
import pandas as pd
import numpy as np
from pathlib import Path

# Define data directories
DATA_ROOT = WORK_ROOT  # WORK_ROOT is defined in Cell 2
FEAT_DIR = WORK_ROOT / 'features'

# Load training data
print("[DATA] Loading training data...")
train_terms = pd.read_csv(DATA_ROOT / 'Train' / 'train_terms.tsv', sep='\t')
print(f"  Loaded {len(train_terms)} training annotations")

# Load training embeddings (ESM2-3B)
print("[DATA] Loading ESM2-3B embeddings...")
train_emb = np.load(FEAT_DIR / 'train_embeds_esm2_3b.npy')
print(f"  Train embeddings shape: {train_emb.shape}")

# Load test embeddings
test_emb = np.load(FEAT_DIR / 'test_embeds_esm2_3b.npy')
print(f"  Test embeddings shape: {test_emb.shape}")

# BASELINE ALIGNMENT: Use FULL term vocabulary (no 13,500 filtering)
# This was one of the 5 critical bugs - term filtering caused severe recall loss
print("[BASELINE MODE] Using FULL term vocabulary (no 13,500 filtering)")
top_terms = train_terms["term"].value_counts().index.tolist()
print(f"  Total unique terms in training data: {len(top_terms)}")

# Create term-to-index mapping
term_to_idx = {term: i for i, term in enumerate(top_terms)}
print(f"  Term vocabulary size: {len(top_terms)} (full vocabulary)")

# Persist the vocabulary: the submission cell reads features/top_terms_full.json to
# label the prediction columns, so it must be written alongside the features.
FEAT_DIR.mkdir(parents=True, exist_ok=True)
terms_json_path = FEAT_DIR / 'top_terms_full.json'
terms_json_path.write_text(json.dumps(top_terms), encoding='utf-8')
print(f"  Saved term vocabulary to: {terms_json_path}")

# Build label matrix Y
print("[DATA] Building label matrix...")
n_proteins = train_terms['EntryID'].nunique()
n_terms = len(top_terms)
protein_ids = train_terms['EntryID'].unique()
protein_to_idx = {pid: i for i, pid in enumerate(protein_ids)}

# NOTE(review): rows of train_emb are assumed to follow the order of protein_ids
# (first appearance in train_terms.tsv) - confirm against the embedding export.
if train_emb.shape[0] != n_proteins:
    print(
        f"  [WARN] train_emb has {train_emb.shape[0]} rows but train_terms has "
        f"{n_proteins} proteins; labels and embeddings may be misaligned"
    )

# Create sparse label matrix
# Vectorized lookup of row/column indices (replaces a per-row iterrows() loop);
# annotations whose protein or term cannot be mapped are dropped, as before.
from scipy import sparse
row_codes = train_terms['EntryID'].map(protein_to_idx)
col_codes = train_terms['term'].map(term_to_idx)
known = row_codes.notna() & col_codes.notna()
rows = row_codes[known].to_numpy(dtype=np.int64)
cols = col_codes[known].to_numpy(dtype=np.int64)

Y_train = sparse.csr_matrix(
    (np.ones(len(rows), dtype=np.float32), (rows, cols)),
    shape=(n_proteins, n_terms)
)
# Repeated (protein, term) rows are summed by the COO->CSR conversion; force the
# matrix back to a strict 0/1 indicator so the metrics see binary labels.
Y_train.sum_duplicates()
Y_train.data[:] = 1.0
print(f"  Label matrix shape: {Y_train.shape}")
print(f"  Total annotations: {Y_train.nnz}")

# Load GO ontology for aspect mapping
print("[DATA] Loading GO ontology...")
go_path = DATA_ROOT / 'Train' / 'go-basic.obo'
# Simple parser for GO aspects
aspect_map = {
    'biological_process': 'BP',
    'molecular_function': 'MF',
    'cellular_component': 'CC'
}
go_aspects = {}
current_term = None
with open(go_path, 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if line.startswith('['):
            # New stanza: forget the previous term so that a namespace line in a
            # non-term stanza (e.g. [Typedef]) cannot overwrite its aspect.
            current_term = None
        elif line.startswith('id: GO:'):
            current_term = line[4:]
        elif line.startswith('namespace:') and current_term:
            namespace = line.partition(':')[2].strip()
            go_aspects[current_term] = aspect_map.get(namespace, 'UNK')

print(f"  Loaded aspects for {len(go_aspects)} GO terms")

# Map top_terms to aspects
term_aspects = {term: go_aspects.get(term, 'UNK') for term in top_terms}

print("[DATA] Data loading complete")
print(f"  Training proteins: {n_proteins}")
print(f"  Full term vocabulary: {len(top_terms)}")
print(f"  Embeddings dimension: {train_emb.shape[1]}")

# Create variables expected by Cell 4 (KNN Helper Functions)
# NOTE: this materialises the full n_proteins x n_terms matrix in memory.
Y = Y_train.toarray()  # Convert sparse to dense for KNN
features_test = {'esm2_3b': test_emb}  # Dictionary format expected by Cell 4

print(f"[DATA] Variables prepared for KNN:")
print(f"  Y shape: {Y.shape}")
print(f"  features_test keys: {list(features_test.keys())}")


In [None]:
# CELL 4 - KNN Helper Functions & Variable Setup
# This cell defines the helper functions and variables needed by the KNN cell

import numpy as np
import pandas as pd
import json
from pathlib import Path


# Helper function: L2 normalization
def _l2_norm(X):
    """Scale every row of X to unit Euclidean length.

    With unit-length rows, cosine similarity reduces to a plain dot product.
    Row norms are floored at 1e-12, so an all-zero row stays all-zero instead
    of triggering a division by zero.
    """
    row_lengths = np.linalg.norm(X, axis=1, keepdims=True)
    safe_lengths = np.clip(row_lengths, 1e-12, None)
    return X / safe_lengths


# Aggregation in this notebook is pure similarity: IA weights are deliberately left
# out here because CAFA evaluation already applies them (using them in prediction as
# well would weight every term twice).
#
# BASELINE ALIGNMENT: the FULL term vocabulary built in Cell 3 (Data Loading) is used
# as-is, with uniform per-term weights.


def _ensure_cell3_defined(name):
    """Raise RuntimeError when `name` is not a notebook-level variable."""
    if name not in globals():
        raise RuntimeError(f'{name} not defined. Ensure the Data Loading cell (Cell 3) ran successfully.')


_ensure_cell3_defined('top_terms')

weights_full = np.ones(len(top_terms), dtype=np.float32)  # one uniform weight per term
print(f'[KNN Setup] Using {len(top_terms)} terms with uniform weights (no IA, baseline mode)')

# Target labels: reuse the matrix Cell 3 exposes as Y.
print('[KNN Setup] Preparing target labels (Y_knn)...')
_ensure_cell3_defined('Y')

Y_knn = Y
print(f'  Y_knn shape: {Y_knn.shape}')

# Test features: the ESM2-3B modality, cast to float32.
print('[KNN Setup] Preparing test features (X_knn_test)...')
_ensure_cell3_defined('features_test')

if 'esm2_3b' not in features_test:
    raise FileNotFoundError("Missing required modality 'esm2_3b' in features_test.")

X_knn_test = features_test['esm2_3b'].astype(np.float32)
print(f'  X_knn_test shape: {X_knn_test.shape}')

print('[KNN Setup] All helper functions and variables ready! ✓')
print('[BASELINE MODE] No term filtering - using FULL vocabulary from training data')


In [None]:
# CELL 5 - KNN Training (Baseline-Aligned + Memory Optimized)
# =====================================================
# Pure similarity aggregation (NO IA weighting)
# Vectorized per-protein max normalization
# Batch processing to avoid OOM
# k hyperparameter sweep: [10, 20, 30]
# Full vocabulary (~25,000+ terms)

from sklearn.model_selection import StratifiedKFold
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import f1_score, precision_score, recall_score
import numpy as np

# OPTIMIZATION 3: k Hyperparameter sweep
from scipy import sparse  # imported here so this cell does not rely on another cell's import

K_VALUES = [10, 20, 30]  # Test different k values
BATCH_SIZE = 500  # Batch size for memory-safe processing
N_FOLDS = 5  # Number of cross-validation folds


def _add_neighbor_scores(out, out_rows, neighbor_sims, neighbor_idx, labels, batch_size, divisor=1.0):
    """Add similarity-weighted neighbour label averages into `out`, batch by batch.

    For each query i the score vector is
        sum_k sims[i, k] * labels[idx[i, k]] / sum_k sims[i, k]
    (pure similarity aggregation, no IA weighting), divided by `divisor`.

    Args:
        out: dense (n_rows, n_terms) float32 array that is updated in place.
        out_rows: row positions in `out` for each query, or None to use the
            query's own position.
        neighbor_sims: (n_queries, k) similarities.
        neighbor_idx: (n_queries, k) row indices into `labels`.
        labels: sparse (n_train, n_terms) label matrix.
        batch_size: number of queries processed per step.
        divisor: extra divisor applied to the scores (used for fold averaging).
    """
    n_queries, n_neighbors = neighbor_idx.shape
    for start in range(0, n_queries, batch_size):
        stop = min(start + batch_size, n_queries)
        sims = neighbor_sims[start:stop]
        denom = sims.sum(axis=1, keepdims=True)
        denom[denom == 0] = 1.0  # all-zero similarities: leave the scores at zero

        # A sparse (batch x n_train) weight matrix times the sparse labels gives the
        # weighted sums directly, without building a dense (batch, k, n_terms) tensor.
        query_rows = np.repeat(np.arange(stop - start), n_neighbors)
        weights = sparse.csr_matrix(
            (sims.ravel(), (query_rows, neighbor_idx[start:stop].ravel())),
            shape=(stop - start, labels.shape[0]),
        )
        scores = (weights @ labels).toarray() / (denom * divisor)

        target = slice(start, stop) if out_rows is None else out_rows[start:stop]
        out[target] += scores.astype(np.float32)


if train_emb.shape[0] != Y_train.shape[0]:
    raise ValueError(
        f"train_emb has {train_emb.shape[0]} rows but Y_train has {Y_train.shape[0]}; "
        "embeddings and labels must be row-aligned"
    )

# Stratification target (number of labels per protein, capped at 10) and the dense
# ground truth are the same for every k, so build them once.
stratify_target = np.minimum(np.asarray(Y_train.sum(axis=1)).ravel(), 10)
Y_true = Y_train.toarray()

k_results = []
best_oof_f1 = -1.0  # below any achievable F1, so the first k is always recorded
best_k = K_VALUES[0]
oof_pred_knn = None
test_pred_knn = None

print(f"\n[KNN] Starting hyperparameter sweep over k={K_VALUES}...")

for KNN_K in K_VALUES:
    print(f"\n{'='*60}")
    print(f"[KNN] Testing k={KNN_K}")
    print(f"{'='*60}")

    # Initialize OOF / test predictions for this k
    oof_pred_k = np.zeros((len(train_emb), len(top_terms)), dtype=np.float32)
    test_pred_k = np.zeros((len(test_emb), len(top_terms)), dtype=np.float32)

    # 5-fold cross-validation (same seed for every k, so folds are identical)
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

    for fold_num, (train_idx, val_idx) in enumerate(skf.split(train_emb, stratify_target), 1):
        print(f"  Fold {fold_num}/{N_FOLDS}...")

        Y_train_fold = Y_train[train_idx]

        # Train KNN
        print(f"    Training KNN with k={KNN_K}...")
        knn = NearestNeighbors(n_neighbors=KNN_K, metric='cosine', n_jobs=-1)
        knn.fit(train_emb[train_idx])

        # Validation predictions. Similarities are clipped to [0, 1] so negative
        # cosine similarities cannot flip or cancel the normalising sum.
        print("    Computing validation predictions (batched)...")
        distances, indices = knn.kneighbors(train_emb[val_idx])
        similarities = np.clip(1.0 - distances, 0.0, 1.0).astype(np.float32)
        # Every validation row is visited once across folds, so adding into the
        # zero-initialised buffer is equivalent to assigning.
        _add_neighbor_scores(oof_pred_k, val_idx, similarities, indices, Y_train_fold, BATCH_SIZE)

        # Test predictions, averaged over folds by dividing each fold's scores by N_FOLDS
        print("    Computing test predictions (batched)...")
        test_distances, test_indices = knn.kneighbors(test_emb)
        test_similarities = np.clip(1.0 - test_distances, 0.0, 1.0).astype(np.float32)
        _add_neighbor_scores(
            test_pred_k, None, test_similarities, test_indices, Y_train_fold, BATCH_SIZE,
            divisor=N_FOLDS,
        )

    # OPTIMIZATION 2: Vectorized per-protein max normalization
    print(f"  Applying vectorized per-protein max normalization...")
    row_max = oof_pred_k.max(axis=1, keepdims=True)
    row_max[row_max == 0] = 1.0
    oof_pred_k /= row_max

    test_row_max = test_pred_k.max(axis=1, keepdims=True)
    test_row_max[test_row_max == 0] = 1.0
    test_pred_k /= test_row_max

    # Evaluate this k value
    print(f"  Evaluating k={KNN_K} on OOF predictions...")

    # Find best threshold for this k
    thresholds = [0.1, 0.2, 0.3, 0.4, 0.5]
    best_f1_k = 0
    best_threshold_k = 0.3

    for threshold in thresholds:
        Y_pred = (oof_pred_k >= threshold).astype(int)
        f1 = f1_score(Y_true, Y_pred, average='samples', zero_division=0)
        if f1 > best_f1_k:
            best_f1_k = f1
            best_threshold_k = threshold

    print(f"  k={KNN_K}: Best F1={best_f1_k:.4f} @ threshold={best_threshold_k:.2f}")

    # Store results
    k_results.append({
        'k': KNN_K,
        'oof_f1': best_f1_k,
        'best_threshold': best_threshold_k
    })

    # Track best k. The buffers are freshly allocated for every k, so keeping a
    # reference (instead of a copy) is safe and avoids duplicating large arrays.
    if best_f1_k > best_oof_f1:
        best_oof_f1 = best_f1_k
        best_k = KNN_K
        oof_pred_knn = oof_pred_k
        test_pred_knn = test_pred_k

# Display hyperparameter sweep results
print(f"\n{'='*60}")
print("[KNN] Hyperparameter Sweep Results")
print(f"{'='*60}")
print(f"{'k':<6} {'OOF F1':<10} {'Best Threshold':<15}")
print("-" * 40)
for result in k_results:
    marker = " *** BEST" if result['k'] == best_k else ""
    print(f"{result['k']:<6} {result['oof_f1']:<10.4f} {result['best_threshold']:<15.2f}{marker}")

print(f"\n[KNN] Selected k={best_k} with OOF F1={best_oof_f1:.4f}")

# Save predictions (using best k)
print(f"\n[KNN] Saving predictions from best k={best_k}...")
knn_pred_path = WORK_ROOT / 'oof_knn.npy'
test_pred_path = WORK_ROOT / 'test_knn.npy'
np.save(knn_pred_path, oof_pred_knn)
np.save(test_pred_path, test_pred_knn)
print(f"  Saved to: {knn_pred_path}")

# Final evaluation on validation set
print("\n[KNN] Final evaluation on validation set...")

# Test multiple thresholds
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
best_f1 = 0
best_threshold = 0.3
print("\nThreshold   F1      Precision  Recall")
print("-" * 50)
for threshold in thresholds:
    Y_pred = (oof_pred_knn >= threshold).astype(int)
    f1 = f1_score(Y_true, Y_pred, average='samples', zero_division=0)
    precision = precision_score(Y_true, Y_pred, average='samples', zero_division=0)
    recall = recall_score(Y_true, Y_pred, average='samples', zero_division=0)

    # A row is flagged when it beats every threshold printed before it.
    marker = "  <- BEST" if f1 > best_f1 else ""
    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

    print(f"{threshold:.2f}       {f1:.4f}  {precision:.4f}     {recall:.4f}{marker}")

print(f"\n[KNN] Best threshold: {best_threshold:.2f}, F1: {best_f1:.4f}")
print("[KNN] Training complete")

In [None]:
# CELL 5b - KNN Hyperparameter Sweep (k + Thresholds)
# =====================================================
# Optimize k (number of neighbors) and aspect-specific thresholds
# Tests k in [5, 10, 15, 20, 25, 50] with threshold search for each

RUN_HYPERPARAM_SWEEP = False  # Set to True to run sweep (time-consuming)

if RUN_HYPERPARAM_SWEEP:
    import itertools
    import time

    from scipy import sparse
    from sklearn.model_selection import StratifiedKFold

    print("\n" + "="*80)
    print("[HYPERPARAMETER SWEEP] Testing k and aspect-specific thresholds")
    print("="*80)

    # Use cuML if available, else sklearn. Resolved once, under a private alias so the
    # NearestNeighbors name used by the KNN training cell is not rebound.
    try:
        from cuml.neighbors import NearestNeighbors as _SweepNearestNeighbors
        using_cuml = True
    except ImportError:
        from sklearn.neighbors import NearestNeighbors as _SweepNearestNeighbors
        using_cuml = False
    # n_jobs is an sklearn parameter; it is only passed to the sklearn estimator.
    _sweep_knn_kwargs = {} if using_cuml else {'n_jobs': -1}

    # Inputs: this cell works on the variables the earlier cells of this notebook
    # define (train_emb, Y_train, term_aspects, top_terms).
    X_sweep = train_emb
    Y_sweep = sparse.csr_matrix(Y_train)
    sweep_n_folds = globals().get('N_FOLDS', 5)
    sweep_batch = globals().get('BATCH_SIZE', 500)
    n_sweep_rows = X_sweep.shape[0]

    # Same fold construction as the KNN training cell (label count capped at 10,
    # seed 42), so sweep results are comparable with it.
    _sweep_strat = np.minimum(np.asarray(Y_sweep.sum(axis=1)).ravel(), 10)
    fold_splits = list(
        StratifiedKFold(n_splits=sweep_n_folds, shuffle=True, random_state=42).split(X_sweep, _sweep_strat)
    )

    # Test range for k (number of neighbors)
    K_VALUES = [5, 10, 15, 20, 25, 50]

    # Threshold search grid per aspect
    BP_THRESHOLDS = [0.15, 0.20, 0.25, 0.30, 0.35]
    MF_THRESHOLDS = [0.30, 0.40, 0.50, 0.60, 0.70]
    CC_THRESHOLDS = [0.20, 0.25, 0.30, 0.35, 0.40]
    FALLBACK_THRESHOLD = 0.3  # used for terms whose aspect is not BP/MF/CC
    aspect_grids = {
        'BP': BP_THRESHOLDS,
        'MF': MF_THRESHOLDS,
        'CC': CC_THRESHOLDS,
        'UNK': [FALLBACK_THRESHOLD],
    }

    # Aspect of every term column; anything outside BP/MF/CC is treated as 'UNK'.
    col_aspect = np.array([
        aspect if aspect in aspect_grids else 'UNK'
        for aspect in (term_aspects.get(term, 'UNK') for term in top_terms)
    ])
    aspect_cols = {aspect: np.flatnonzero(col_aspect == aspect) for aspect in aspect_grids}

    # Coordinates of the positive labels. True positives are then just the positives
    # whose score clears the threshold, so the dense label matrix is never needed.
    pos_rows, pos_cols = Y_sweep.nonzero()
    pos_aspect = col_aspect[pos_cols]
    pos_masks = {aspect: pos_aspect == aspect for aspect in aspect_grids}
    n_positive = pos_rows.size
    COUNT_CHUNK = 4096  # rows per step when counting predictions

    # Store results
    sweep_results = []
    best_f1 = 0.0
    best_config = {}

    sweep_start = time.time()

    for k_val in K_VALUES:
        print(f"\n{'='*60}")
        print(f"Testing k = {k_val}")
        print(f"{'='*60}")

        # Cross-validation predictions with this k
        oof_pred_k = np.zeros((n_sweep_rows, len(top_terms)), dtype=np.float32)

        for tr_idx, va_idx in fold_splits:
            Y_tr = Y_sweep[tr_idx]

            # Fit KNN with k_val neighbors
            knn = _SweepNearestNeighbors(n_neighbors=k_val, metric='cosine', **_sweep_knn_kwargs)
            knn.fit(X_sweep[tr_idx])

            # Predict on validation fold
            dists_va, neigh_va = knn.kneighbors(X_sweep[va_idx], return_distance=True)
            sims_va = np.clip(1.0 - np.asarray(dists_va), 0.0, 1.0).astype(np.float32)
            neigh_va = np.asarray(neigh_va)

            # Pure similarity aggregation (no IA weighting), as a sparse product:
            # (batch x n_train weights) @ (n_train x n_terms labels).
            for i in range(0, len(va_idx), sweep_batch):
                end_i = min(i + sweep_batch, len(va_idx))
                sims_b = sims_va[i:end_i]
                denom = sims_b.sum(axis=1, keepdims=True)
                denom[denom == 0] = 1.0

                query_rows = np.repeat(np.arange(end_i - i), neigh_va.shape[1])
                weights = sparse.csr_matrix(
                    (sims_b.ravel(), (query_rows, neigh_va[i:end_i].ravel())),
                    shape=(end_i - i, Y_tr.shape[0]),
                )
                scores = ((weights @ Y_tr).toarray() / denom).astype(np.float32)
                oof_pred_k[va_idx[i:end_i]] = scores

        # Per-protein max normalization (rows that are all zero are left untouched)
        row_max = oof_pred_k.max(axis=1, keepdims=True)
        row_max[row_max <= 0] = 1.0
        oof_pred_k /= row_max

        print(f"  Completed KNN training with k={k_val}")

        # Threshold search for this k
        print(f"  Searching optimal thresholds...")

        # Micro-averaged counts decompose by aspect, so compute (true positives,
        # predicted positives) once per aspect and threshold, then combine the
        # per-aspect numbers for every grid point.
        pos_scores = oof_pred_k[pos_rows, pos_cols]
        aspect_counts = {}
        for aspect, grid in aspect_grids.items():
            cols_a = aspect_cols[aspect]
            pos_a = pos_scores[pos_masks[aspect]]
            predicted = dict.fromkeys(grid, 0)
            for r0 in range(0, n_sweep_rows, COUNT_CHUNK):
                chunk_scores = oof_pred_k[r0:r0 + COUNT_CHUNK][:, cols_a]
                for thr in grid:
                    predicted[thr] += int(np.count_nonzero(chunk_scores >= thr))
            aspect_counts[aspect] = {
                thr: (int(np.count_nonzero(pos_a >= thr)), predicted[thr]) for thr in grid
            }
        unk_tp, unk_pred = aspect_counts['UNK'][FALLBACK_THRESHOLD]

        best_f1_k = -1.0  # below any achievable F1, so best_thr_k is always populated
        best_thr_k = {}

        for bp_thr, mf_thr, cc_thr in itertools.product(BP_THRESHOLDS, MF_THRESHOLDS, CC_THRESHOLDS):
            bp_tp, bp_pred = aspect_counts['BP'][bp_thr]
            mf_tp, mf_pred = aspect_counts['MF'][mf_thr]
            cc_tp, cc_pred = aspect_counts['CC'][cc_thr]

            tp = bp_tp + mf_tp + cc_tp + unk_tp
            fp = (bp_pred + mf_pred + cc_pred + unk_pred) - tp
            fn = n_positive - tp

            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

            if f1 > best_f1_k:
                best_f1_k = f1
                best_thr_k = {'BP': bp_thr, 'MF': mf_thr, 'CC': cc_thr}

        print(f"  Best F1 for k={k_val}: {best_f1_k:.4f} | Thresholds: BP={best_thr_k['BP']:.2f}, MF={best_thr_k['MF']:.2f}, CC={best_thr_k['CC']:.2f}")

        # Store result
        sweep_results.append({
            'k': k_val,
            'f1': best_f1_k,
            'bp_threshold': best_thr_k['BP'],
            'mf_threshold': best_thr_k['MF'],
            'cc_threshold': best_thr_k['CC']
        })

        if best_f1_k > best_f1 or not best_config:
            best_f1 = best_f1_k
            best_config = {'k': k_val, **best_thr_k}

    sweep_elapsed = time.time() - sweep_start

    # Display results
    print("\n" + "="*80)
    print("[HYPERPARAMETER SWEEP RESULTS]")
    print("="*80)
    print("\nPerformance Heatmap:")
    print(f"{'k':<6} {'F1':<8} {'BP Thr':<8} {'MF Thr':<8} {'CC Thr':<8}")
    print("-" * 50)
    for result in sweep_results:
        marker = " *** BEST" if result['k'] == best_config['k'] else ""
        print(f"{result['k']:<6} {result['f1']:<8.4f} {result['bp_threshold']:<8.2f} {result['mf_threshold']:<8.2f} {result['cc_threshold']:<8.2f}{marker}")

    print("\n" + "="*80)
    print("[OPTIMAL CONFIGURATION]")
    print("="*80)
    print(f"  Best k: {best_config['k']}")
    print(f"  Best F1: {best_f1:.4f}")
    print(f"  Optimal Thresholds:")
    print(f"    BP (Biological Process): {best_config['BP']:.2f}")
    print(f"    MF (Molecular Function): {best_config['MF']:.2f}")
    print(f"    CC (Cellular Component): {best_config['CC']:.2f}")
    print(f"\n  Total sweep time: {sweep_elapsed/60:.1f} minutes")
    print("="*80)

    # Save results
    import pandas as pd
    results_df = pd.DataFrame(sweep_results)
    results_path = WORK_ROOT / 'knn_hyperparam_sweep_results.csv'
    results_df.to_csv(results_path, index=False)
    print(f"\n  Results saved to: {results_path}")

    # Visualize if matplotlib available
    try:
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        ax.plot([r['k'] for r in sweep_results], [r['f1'] for r in sweep_results], 'o-', linewidth=2, markersize=8)
        ax.set_xlabel('k (Number of Neighbors)', fontsize=12)
        ax.set_ylabel('F1 Score', fontsize=12)
        ax.set_title('KNN Performance vs k (with optimized thresholds per k)', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)
        ax.axhline(y=best_f1, color='r', linestyle='--', alpha=0.5, label=f'Best F1: {best_f1:.4f} @ k={best_config["k"]}')
        ax.legend()
        plt.tight_layout()

        plot_path = WORK_ROOT / 'knn_hyperparam_sweep.png'
        plt.savefig(plot_path, dpi=150, bbox_inches='tight')
        print(f"  Plot saved to: {plot_path}")
        plt.show()
    except ImportError:
        print("  (matplotlib not available for visualization)")

else:
    print("\n[HYPERPARAMETER SWEEP] Skipped (RUN_HYPERPARAM_SWEEP=False)")
    print("  Set RUN_HYPERPARAM_SWEEP=True to run k optimization and threshold search")
    print("  This will test k in [5, 10, 15, 20, 25, 50] with aspect-specific threshold grid search")
    print("  Expected time: 30-60 minutes depending on hardware")


In [None]:
# CELL 6 - Generate Submission File from KNN Predictions
# This cell creates a submission.tsv file from the KNN test predictions

import json
from pathlib import Path
import numpy as np
import pandas as pd

print('[SUBMISSION] Generating submission file from KNN predictions...')


def _first_existing(paths):
    """Return the first path in `paths` that exists, or None when none does."""
    return next((p for p in paths if p.exists()), None)


def _read_go_namespaces(obo_path):
    """Map every GO id (and alt_id) in the [Term] stanzas of an OBO file to its namespace.

    Standard-library parser, so no ontology package has to be installed. Lines
    outside [Term] stanzas are ignored.
    """
    namespaces = {}
    stanza_ids = []
    stanza_ns = None
    in_term = False
    with open(obo_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if line.startswith('['):
                # New stanza: reset per-stanza state.
                in_term = line == '[Term]'
                stanza_ids = []
                stanza_ns = None
            elif not in_term:
                continue
            elif line.startswith(('id: ', 'alt_id: ')):
                go_id = line.partition(': ')[2]
                stanza_ids.append(go_id)
                # Covers ids that appear after the namespace line.
                if stanza_ns is not None:
                    namespaces[go_id] = stanza_ns
            elif line.startswith('namespace: '):
                stanza_ns = line.partition(': ')[2]
                for go_id in stanza_ids:
                    namespaces[go_id] = stanza_ns
    return namespaces


# Check if submission already exists
submission_path = WORK_ROOT / 'submission.tsv'
if submission_path.exists():
    print(f'  submission.tsv already exists at {submission_path}')
    print('  To regenerate, delete it first.')
else:
    # Load KNN test predictions. The first two locations are the original lookup
    # order; the last one is where the KNN training cell of this notebook saves.
    pred_candidates = [
        WORK_ROOT / 'features' / 'level1_preds' / 'test_pred_knn.npy',
        WORK_ROOT / 'features' / 'test_pred_knn.npy',
        WORK_ROOT / 'test_knn.npy',
    ]
    knn_test_path = _first_existing(pred_candidates)
    if knn_test_path is None:
        expected = ', '.join(str(p) for p in pred_candidates)
        raise FileNotFoundError(
            f'Missing KNN test predictions. Expected at one of: {expected}. Run the KNN training cell first.'
        )

    preds = np.load(knn_test_path).astype(np.float32, copy=False)
    print(f'  Loaded KNN predictions: {preds.shape} from {knn_test_path}')

    # Load test IDs
    test_seq_path = WORK_ROOT / 'parsed' / 'test_seq.feather'
    if not test_seq_path.exists():
        raise FileNotFoundError(f'Missing {test_seq_path}. Run the data loading cell first.')

    test_ids = pd.read_feather(test_seq_path)['id'].astype(str)
    # Extract UniProt IDs from FASTA format (e.g., >sp|P12345|NAME)
    test_ids = test_ids.str.extract(r'\|(.+?)\|', expand=False).fillna(test_ids)
    print(f'  Loaded {len(test_ids)} test IDs')

    if len(test_ids) != preds.shape[0]:
        raise ValueError(f'Shape mismatch: preds has {preds.shape[0]} rows, but there are {len(test_ids)} test IDs')

    # Load term list: saved vocabulary first (new name, then old name), otherwise the
    # vocabulary still held in memory from the data loading cell.
    terms_path = _first_existing([
        WORK_ROOT / 'features' / 'top_terms_full.json',
        WORK_ROOT / 'features' / 'top_terms_13500.json',
    ])
    if terms_path is not None:
        top_terms = json.loads(terms_path.read_text(encoding='utf-8'))
        print(f'  Loaded {len(top_terms)} terms')
    elif 'top_terms' in globals():
        top_terms = list(top_terms)
        print(f'  Using {len(top_terms)} in-memory terms (no saved term list found)')
    else:
        missing_terms_path = WORK_ROOT / 'features' / 'top_terms_full.json'
        raise FileNotFoundError(f'Missing {missing_terms_path}. Run the data loading cell first.')

    if preds.shape[1] != len(top_terms):
        raise ValueError(f'Shape mismatch: preds has {preds.shape[1]} terms, top_terms has {len(top_terms)}')

    # BASELINE ALIGNMENT: Skip GO hierarchy propagation
    # The auditor identified that CAFA labels are already pre-propagated
    # Explicit propagation is "redundant/distorting" and hurts performance
    print('  [BASELINE MODE] Skipping GO hierarchy propagation (labels already propagated)')

    # ============================================================================
    # BASELINE THRESHOLD OVERRIDE
    # ============================================================================
    # Set this to True to use baseline-proven thresholds for the first GPU run
    # Set to False to use learned thresholds from aspect_thresholds.json
    USE_BASELINE_THRESHOLDS = True

    if USE_BASELINE_THRESHOLDS:
        print('  [BASELINE MODE] Using baseline-proven thresholds')
        ASPECT_THRESHOLDS = {
            'BP': 0.25,   # Biological Process
            'MF': 0.50,   # Molecular Function
            'CC': 0.35,   # Cellular Component
            'ALL': 0.30   # Fallback for unknown aspects
        }
        print(f'    MF=0.50, BP=0.25, CC=0.35, ALL=0.30')
    else:
        # Fallback to learned thresholds
        thr_path = WORK_ROOT / 'features' / 'aspect_thresholds.json'
        if thr_path.exists():
            ASPECT_THRESHOLDS = json.loads(thr_path.read_text(encoding='utf-8'))
            print(f'  Using learned thresholds: {ASPECT_THRESHOLDS}')
        else:
            print('  [WARN] aspect_thresholds.json not found; using default 0.3')
            ASPECT_THRESHOLDS = {'ALL': 0.3}

    # Map GO terms to aspects. The ontology is looked up in parsed/ first (original
    # location) and then in Train/, which is where the data loading cell reads it.
    go_obo_path = _first_existing([
        WORK_ROOT / 'parsed' / 'go-basic.obo',
        WORK_ROOT / 'Train' / 'go-basic.obo',
    ])
    ns_to_aspect = {
        'molecular_function': 'MF',
        'biological_process': 'BP',
        'cellular_component': 'CC',
    }

    if go_obo_path is not None:
        go_namespaces = _read_go_namespaces(go_obo_path)

        aspects = np.array([
            ns_to_aspect.get(go_namespaces.get(t, 'unknown'), 'ALL')
            for t in top_terms
        ], dtype='<U3')
    else:
        # Fallback: use ALL threshold for everything
        print('  [WARN] go-basic.obo not found; using ALL threshold for all terms')
        aspects = np.array(['ALL'] * len(top_terms), dtype='<U3')

    # Build threshold vector
    thr_vec = np.array([
        float(ASPECT_THRESHOLDS.get(a, ASPECT_THRESHOLDS.get('ALL', 0.3)))
        for a in aspects
    ], dtype=np.float32)

    # Format submission (CAFA rules)
    # Only the entries that survive thresholding are turned into rows, one row chunk
    # at a time, instead of melting the full proteins x terms matrix.
    print('  Formatting submission...')
    id_values = test_ids.to_numpy()
    term_values = np.asarray(top_terms, dtype=object)
    ROW_CHUNK = 2048
    parts = []
    for start in range(0, preds.shape[0], ROW_CHUNK):
        chunk = preds[start:start + ROW_CHUNK]
        # Keep scores at or above the per-term threshold; zeros are never emitted.
        keep = (chunk >= thr_vec[None, :]) & (chunk > 0.0)
        r, c = np.nonzero(keep)
        parts.append(pd.DataFrame({
            'EntryID': id_values[start + r],
            'term': term_values[c],
            # Enforce score range
            'score': np.clip(chunk[r, c], 0.0, 1.0),
        }))

    if parts:
        submission = pd.concat(parts, ignore_index=True)
    else:
        submission = pd.DataFrame({'EntryID': [], 'term': [], 'score': []})

    # Keep top 1500 per protein (CAFA rule)
    submission = submission.sort_values(['EntryID', 'score'], ascending=[True, False])
    submission = submission.groupby('EntryID', sort=False).head(1500)

    # Write with <= 3 significant figures
    submission.to_csv(
        submission_path,
        sep='\t',
        index=False,
        header=False,
        float_format='%.3g',
    )

    n_submitted_proteins = submission["EntryID"].nunique()
    print(f'\n✅ Submission saved to {submission_path}')
    print(f'   Total predictions: {len(submission):,}')
    print(f'   Proteins: {n_submitted_proteins:,}')
    print(f'   Avg predictions/protein: {len(submission) / max(n_submitted_proteins, 1):.1f}')


In [None]:
# CELL 7 - Install Kaggle CLI
# This cell installs the Kaggle CLI if not already installed
# The leading `!` is an IPython/Jupyter shell escape, so this cell only runs inside a
# notebook kernel; `-q` keeps pip's output quiet.

!pip install kaggle -q


In [None]:
# CELL 8 - Submit to Kaggle Competition
# This cell submits the generated submission.tsv to the CAFA-6 competition

import os
import subprocess
from pathlib import Path

# Load Kaggle credentials from environment (including .env if present)
def _load_kaggle_credentials():
    """Load Kaggle credentials from a .env file and/or the process environment.

    Looks for a ``.env`` file in the parent of the current working directory.
    Every ``KAGGLE_USERNAME`` / ``KAGGLE_KEY`` assignment found there with a
    non-empty value is copied into ``os.environ`` (overriding any value that
    is already set). Lines may be written as ``KEY=value`` or
    ``export KEY=value``; surrounding whitespace and quote characters around
    the value are removed. Blank lines, ``#`` comments and unrelated keys are
    ignored.

    A ``.env`` that cannot be read or decoded is reported with a warning and
    otherwise ignored, so credentials already present in the environment
    still count.

    Returns:
        bool: True when both ``KAGGLE_USERNAME`` and ``KAGGLE_KEY`` are set to
        non-blank values after the optional ``.env`` load, False otherwise.
    """
    wanted_keys = ('KAGGLE_USERNAME', 'KAGGLE_KEY')

    # Try to load from .env file
    env_path = Path.cwd().parent / '.env'
    if env_path.exists():
        print('[KAGGLE] Loading credentials from .env file...')
        try:
            # Explicit encoding keeps parsing independent of the platform
            # locale; 'utf-8-sig' also drops a leading BOM, which would
            # otherwise become part of the first key and prevent a match.
            with open(env_path, 'r', encoding='utf-8-sig') as f:
                for line in f:
                    line = line.strip()
                    if not line or line.startswith('#') or '=' not in line:
                        continue
                    key, value = line.split('=', 1)
                    key = key.strip()
                    # Accept the shell-style `export KEY=value` form as well.
                    key_parts = key.split()
                    if len(key_parts) == 2 and key_parts[0] == 'export':
                        key = key_parts[1]
                    value = value.strip().strip('"').strip("'")
                    if key in wanted_keys and value:
                        os.environ[key] = value
        except (OSError, ValueError) as e:
            # OSError: unreadable file. ValueError: undecodable bytes
            # (UnicodeDecodeError) or a value os.environ refuses to store.
            print(f'  [WARN] Failed to load .env: {e}')

    # Check if credentials are set
    username = os.environ.get('KAGGLE_USERNAME', '').strip()
    key = os.environ.get('KAGGLE_KEY', '').strip()

    if username and key:
        # Only the username is echoed; the API key is never printed.
        print(f'  ✓ Kaggle credentials loaded: {username}')
        return True

    print('  ✗ Kaggle credentials not found in environment or .env')
    return False

# Submission target: competition slug and the message attached to the upload
COMPETITION_NAME = 'cafa-6-protein-function-prediction'
SUBMISSION_MESSAGE = 'KNN standalone submission (ESM2-3B embeddings)'

# Precondition 1: the submission TSV must already exist under WORK_ROOT
submission_path = WORK_ROOT / 'submission.tsv'
if not submission_path.exists():
    raise FileNotFoundError(
        f'Submission file not found at {submission_path}. '
        'Run the submission generation cell first.'
    )

print(f'[KAGGLE] Preparing to submit to {COMPETITION_NAME}...')

# Precondition 2: credentials must be available; otherwise print setup
# instructions (one item per output line) and abort the cell.
if not _load_kaggle_credentials():
    print(
        '\n' + '=' * 60,
        'ERROR: Kaggle credentials not configured',
        '=' * 60,
        '\nPlease set KAGGLE_USERNAME and KAGGLE_KEY:',
        '\n1. Create a .env file in the project root with:',
        '   KAGGLE_USERNAME=your_username',
        '   KAGGLE_KEY=your_api_key',
        '\n2. Or set environment variables:',
        '   export KAGGLE_USERNAME=your_username',
        '   export KAGGLE_KEY=your_api_key',
        '\nGet your API key from: https://www.kaggle.com/settings',
        '=' * 60,
        sep='\n',
    )
    raise RuntimeError('Kaggle credentials not configured')

# Check if Kaggle CLI is installed.
# A missing executable surfaces as FileNotFoundError from subprocess.run and is
# converted into a RuntimeError with installation hints. Output is decoded as
# UTF-8 with replacement so unexpected bytes cannot abort the check.
try:
    result = subprocess.run(
        ['kaggle', '--version'],
        capture_output=True,
        text=True,
        encoding='utf-8',
        errors='replace',
    )
except FileNotFoundError as exc:
    print('\n' + '='*60)
    print('ERROR: Kaggle CLI not installed')
    print('='*60)
    print('\nPlease install the Kaggle CLI:')
    print('  pip install kaggle')
    print('\nOr if using conda:')
    print('  conda install -c conda-forge kaggle')
    print('='*60)
    raise RuntimeError('Kaggle CLI not installed') from exc
else:
    if result.returncode == 0:
        print(f'  ✓ Kaggle CLI version: {result.stdout.strip()}')
    else:
        # The executable exists but the version probe failed. Warn rather than
        # claim success; the submit call below runs with check=True and will
        # report any real failure in full.
        print(f'  [WARN] `kaggle --version` exited with code {result.returncode}')
        if result.stderr:
            print(result.stderr.strip())

# Announce the upload, then assemble the CLI invocation as an argument list
# (no shell is involved, so the message needs no quoting).
print(f'\n[KAGGLE] Submitting {submission_path}...')
print(f'  Message: "{SUBMISSION_MESSAGE}"')

cmd = ['kaggle', 'competitions', 'submit']
cmd.extend(['-c', COMPETITION_NAME])
cmd.extend(['-f', str(submission_path)])
cmd.extend(['-m', SUBMISSION_MESSAGE])

# Run the submission. check=True turns a non-zero exit status into
# CalledProcessError, which is reported below and then re-raised unchanged.
try:
    result = subprocess.run(
        cmd,
        check=True,
        capture_output=True,
        text=True,
        encoding='utf-8',
        errors='replace',
    )
except subprocess.CalledProcessError as err:
    print('\n' + '=' * 60, '❌ SUBMISSION FAILED', '=' * 60, sep='\n')
    if err.stderr:
        print('\nError output:', err.stderr, sep='\n')
    if err.stdout:
        print('\nStdout:', err.stdout, sep='\n')
    print(
        '\nCommon issues:',
        '  - Competition rules not accepted',
        '  - Invalid submission format',
        '  - Daily submission limit reached',
        '  - Incorrect credentials',
        '=' * 60,
        sep='\n',
    )
    raise
else:
    print('\n' + '=' * 60, '✅ SUBMISSION SUCCESSFUL!', '=' * 60, sep='\n')
    if result.stdout:
        print('\nKaggle output:', result.stdout, sep='\n')
    if result.stderr:
        print('\nAdditional info:', result.stderr, sep='\n')
    print(
        '\nCheck your submission at:',
        f'https://www.kaggle.com/competitions/{COMPETITION_NAME}/submissions',
        '=' * 60,
        sep='\n',
    )
