In [None]:
# Environment configuration.
# Auto-detected: on Kaggle the competition data is mounted under /kaggle/input,
# anywhere else the notebook runs in 'local' mode. Assign ENVIRONMENT by hand
# below to force a mode.
import os

ENVIRONMENT = 'kaggle' if os.path.isdir('/kaggle/input') else 'local'

In [None]:
# NOTE(review): package versions are unpinned, so a re-run may pull different
# releases than the ones earlier results were produced with. Pin them
# (e.g. `torch==X.Y.Z`) once a working set of versions is known.
%pip install torch transformers pandas numpy scikit-learn tqdm biopython -q

In [None]:
import numpy as np
import pandas as pd
import torch
from pathlib import Path
from tqdm.auto import tqdm
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split
from collections import Counter

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Imports successful | Device: {device}")

In [None]:
# Set base directory: the Kaggle dataset mount on Kaggle, otherwise the parent
# of the notebook's current working directory.
if ENVIRONMENT == 'kaggle':
    base_dir = Path("/kaggle/input/cafa-6-dataset")
else:
    base_dir = Path.cwd().parent

print(f"📁 Base directory: {base_dir}")

## 1. Load Data

In [None]:
def _fasta_entry_id(record_id):
    """Map a FASTA record id to the identifier used as ``EntryID``.

    UniProt-style ids such as ``sp|P12345|NAME_HUMAN`` are reduced to their
    middle (accession) field; ids without that ``|``-delimited form are
    returned unchanged. NOTE(review): assumes train_terms.tsv stores bare
    accessions - confirm against the data.
    """
    fields = record_id.split('|')
    return fields[1] if len(fields) >= 3 else record_id


# Load sequences
print("Loading sequences...")
sequences = {}
for record in SeqIO.parse(base_dir / "Train" / "train_sequences.fasta", "fasta"):
    sequences[_fasta_entry_id(record.id)] = str(record.seq)

print(f"Loaded {len(sequences)} sequences")

# Load annotations
print("\nLoading annotations...")
train_terms = pd.read_csv(base_dir / "Train" / "train_terms.tsv", sep='\t')
print(f"Total annotations: {len(train_terms)}")

# Fail fast if FASTA ids and EntryIDs don't line up; otherwise the split below
# would receive an empty protein list and fail with an unrelated error.
assert train_terms['EntryID'].isin(set(sequences)).any(), (
    "No EntryID in train_terms.tsv matches a FASTA record id - check header parsing"
)

# Load IA weights (term -> information-accretion weight used by the evaluation)
print("\nLoading IA weights...")
ia_df = pd.read_csv(base_dir / "IA.tsv", sep='\t')
ia_weights = dict(zip(ia_df['term'], ia_df['IA']))
print(f"IA weights: {len(ia_weights)}")

## 2. Train/Val Split

In [None]:
# Hold out 20% of the annotated proteins that also have a sequence; the fixed
# random_state keeps the split reproducible between runs.
all_proteins = [pid for pid in train_terms['EntryID'].unique() if pid in sequences]
train_proteins, val_proteins = train_test_split(all_proteins, test_size=0.2, random_state=42)

print(f"Train proteins: {len(train_proteins)}")
print(f"Val proteins: {len(val_proteins)}")

# Ground-truth annotation rows for the held-out proteins
val_data = train_terms.loc[train_terms['EntryID'].isin(val_proteins)]
print(f"Val annotations: {len(val_data)}")

## 3. Generate ESM-2 Embeddings

In [None]:
# Load the ESM-2 tokenizer and encoder onto the selected device; eval() turns
# off dropout so repeated embedding runs give the same vectors.
print("Loading ESM-2 model...")
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()

print("✅ Model loaded")

In [None]:
def embed_sequences(protein_ids, sequences_dict, batch_size=8, max_length=512):
    """Generate mean-pooled embeddings for a list of proteins.

    Runs the module-level ``tokenizer`` and ``model`` on ``device`` in batches,
    without gradient tracking.

    Args:
        protein_ids: Iterable of protein ids; each must be a key of
            ``sequences_dict``. Output rows follow this order.
        sequences_dict: Mapping from protein id to amino-acid sequence string.
        batch_size: Number of sequences per forward pass.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated. NOTE(review): ESM-2 checkpoints may accept longer inputs
            (check ``model.config.max_position_embeddings``); raising this keeps
            more of long proteins at a higher memory cost.

    Returns:
        ``np.ndarray`` of shape ``(len(protein_ids), model.config.hidden_size)``.
        Padding positions are excluded from the mean. Any special tokens the
        tokenizer adds (presumably ESM's start/end tokens - confirm) are
        included. An empty ``protein_ids`` yields a ``(0, hidden_size)`` array.
    """
    protein_ids = list(protein_ids)
    if not protein_ids:
        # np.vstack([]) raises ValueError, so return a correctly shaped empty array.
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    embeddings = []

    with torch.no_grad():
        for start in tqdm(range(0, len(protein_ids), batch_size), desc="Embedding"):
            batch_ids = protein_ids[start:start + batch_size]
            batch_seqs = [sequences_dict[pid] for pid in batch_ids]

            inputs = tokenizer(
                batch_seqs,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(device)

            token_embeddings = model(**inputs).last_hidden_state

            # Masked mean: zero out padding positions, then divide by the number
            # of real tokens (clamped so an all-padding row cannot divide by 0).
            mask = inputs['attention_mask'].unsqueeze(-1).float()
            summed = (token_embeddings * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            embeddings.append((summed / counts).cpu().numpy())

    return np.vstack(embeddings)

print("✅ Embedding function defined")

In [None]:
# Embed both splits with the same encoder so neighbours are searched in one space.
print("Generating train embeddings...")
train_embeddings = embed_sequences(protein_ids=train_proteins, sequences_dict=sequences)
print(f"Train embeddings shape: {train_embeddings.shape}")

print("\nGenerating val embeddings...")
val_embeddings = embed_sequences(protein_ids=val_proteins, sequences_dict=sequences)
print(f"Val embeddings shape: {val_embeddings.shape}")

## 4. Build K-NN Index

In [None]:
# Build a K-NN index over the training embeddings using cosine distance
# (1 - cosine similarity).
print("Building K-NN index...")
k = 10  # number of neighbours
knn = NearestNeighbors(n_neighbors=k, metric='cosine', n_jobs=-1)
knn.fit(train_embeddings)

print(f"✅ K-NN index built with k={k}")

## 5. Generate Predictions

In [None]:
# Query the index: row i of both arrays describes the k closest training
# proteins of val_proteins[i].
print("Finding nearest neighbours...")
distances, indices = knn.kneighbors(X=val_embeddings, return_distance=True)

print(f"Distances shape: {distances.shape}")
print(f"Indices shape: {indices.shape}")

In [None]:
# Map each training protein to its annotated terms. Only training proteins are
# included, so validation labels never take part in neighbour votes.
train_annotations = train_terms.loc[train_terms['EntryID'].isin(train_proteins)]
protein_to_terms = {
    entry_id: term_series.tolist()
    for entry_id, term_series in train_annotations.groupby('EntryID')['term']
}

print(f"Train proteins with annotations: {len(protein_to_terms)}")

In [None]:
# Generate predictions by aggregating neighbour annotations: every neighbour
# votes for each of its terms with weight = cosine similarity, then scores are
# scaled so each protein's best-supported term gets probability 1.0.
print("Aggregating neighbour annotations...")

predictions = []

for val_idx, val_protein in enumerate(tqdm(val_proteins, desc="Predicting")):
    neighbour_indices = indices[val_idx]
    neighbour_distances = distances[val_idx]

    term_scores = Counter()

    for nei_idx, distance in zip(neighbour_indices, neighbour_distances):
        # Cosine distance lies in [0, 2]. Clamp similarity at 0 so anti-correlated
        # neighbours add nothing instead of subtracting from a term's score
        # (which could also make max_score <= 0 below).
        similarity = max(0.0, 1.0 - float(distance))
        if similarity == 0.0:
            continue

        nei_protein = train_proteins[nei_idx]
        for term in protein_to_terms.get(nei_protein, []):
            term_scores[term] += similarity

    if not term_scores:
        continue

    # Every stored score is > 0 here, so this division is safe.
    max_score = max(term_scores.values())
    for term, score in term_scores.items():
        predictions.append({
            'EntryID': val_protein,
            'term': term,
            'probability': score / max_score
        })

# Explicit columns keep the frame usable downstream even with zero predictions.
predictions_df = pd.DataFrame(predictions, columns=['EntryID', 'term', 'probability'])
print(f"\nTotal predictions: {len(predictions_df)}")
print(f"Avg predictions per protein: {len(predictions_df) / max(len(val_proteins), 1):.1f}")

## 6. Evaluation

In [None]:
def evaluate_predictions(predictions_df, ground_truth_df, ia_weights_dict, threshold=0.01):
    """Evaluate thresholded predictions against ground-truth annotations.

    Args:
        predictions_df: DataFrame with ``EntryID``, ``term`` and ``probability``.
        ground_truth_df: DataFrame with ``EntryID`` and ``term`` (one row per
            true annotation).
        ia_weights_dict: Mapping term -> IA weight; missing terms weigh 1.0.
        threshold: Predictions with ``probability >= threshold`` are kept.

    Returns:
        dict with
        - ``f1``, ``precision``, ``recall``: IA-weighted per-protein scores,
          averaged over every ground-truth protein. Proteins with no kept
          prediction count as 0.
        - ``coverage``: fraction of ground-truth proteins with at least one
          kept prediction.
        All values are NaN when the ground truth contains no proteins.

    NOTE(review): per-protein F1 averaging, and counting uncovered proteins in
    the precision mean, differ from the official CAFA weighted F-measure (which
    also propagates terms up the GO graph). Confirm before comparing with
    leaderboard numbers.
    """
    # Iterating the groupby (rather than .apply(set)) stays well-defined for empty frames.
    true_grouped = {pid: set(terms) for pid, terms in ground_truth_df.groupby('EntryID')['term']}
    if not true_grouped:
        # Nothing to score against; avoid ZeroDivisionError and mean-of-empty warnings.
        nan = float('nan')
        return {'f1': nan, 'precision': nan, 'recall': nan, 'coverage': nan}

    kept = predictions_df[predictions_df['probability'] >= threshold]
    pred_grouped = {pid: set(terms) for pid, terms in kept.groupby('EntryID')['term']}

    def ia_sum(terms):
        return sum(ia_weights_dict.get(t, 1.0) for t in terms)

    f1_scores, precisions, recalls = [], [], []
    n_covered = 0

    for protein, true_terms in true_grouped.items():
        pred_terms = pred_grouped.get(protein, set())

        if not pred_terms:
            f1_scores.append(0.0)
            precisions.append(0.0)
            recalls.append(0.0)
            continue

        n_covered += 1
        tp_weight = ia_sum(true_terms & pred_terms)
        fp_weight = ia_sum(pred_terms - true_terms)
        fn_weight = ia_sum(true_terms - pred_terms)

        precision = tp_weight / (tp_weight + fp_weight) if (tp_weight + fp_weight) > 0 else 0.0
        recall = tp_weight / (tp_weight + fn_weight) if (tp_weight + fn_weight) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

        f1_scores.append(f1)
        precisions.append(precision)
        recalls.append(recall)

    return {
        'f1': np.mean(f1_scores),
        'precision': np.mean(precisions),
        'recall': np.mean(recalls),
        # Only ground-truth proteins are counted, so predictions for proteins
        # outside the ground truth can no longer push coverage above 1.
        'coverage': n_covered / len(true_grouped),
    }

print("✅ Evaluation function defined")

In [None]:
# Sweep probability thresholds on the validation split and keep the best F1.
# Note: the threshold is chosen on the same split it is reported on.
print("Testing different thresholds...\n")

thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
results = []

for thr in thresholds:
    metrics = evaluate_predictions(predictions_df, val_data, ia_weights, threshold=thr)
    results.append({
        'threshold': thr,
        **metrics
    })
    print(f"Threshold {thr:.2f}: F1={metrics['f1']:.4f}, P={metrics['precision']:.4f}, "
          f"R={metrics['recall']:.4f}, Coverage={metrics['coverage']:.2%}")

# Find best threshold
best_result = max(results, key=lambda x: x['f1'])
print(f"\n🏆 Best F1: {best_result['f1']:.4f} at threshold {best_result['threshold']:.2f}")

## 7. Save Results

In [None]:
# Save the threshold sweep as a CSV in the current working directory.
results_df = pd.DataFrame(results)
output_path = Path("02_knn_baseline_results.csv")
results_df.to_csv(output_path, index=False)

print(f"✅ Results saved to {output_path}")
print("\n📊 Results:")
print(results_df.to_string(index=False))

## Summary

**K-NN Baseline Performance:**
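- Uses mean-pooled ESM-2 (`esm2_t6_8M_UR50D`) embeddings for sequence similarity
- Transfers annotations from the k=10 nearest training proteins (cosine distance)
- Weights each transferred term by neighbour cosine similarity, then rescales scores so each protein's top-scoring term has probability 1.0
- Expected F1: ~0.18 (from previous experiments)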

**Next:** 03_model_esm_finetuned.ipynb - Train classifier on embeddings