# Install Dependencies & External Tools

In [1]:
# 1. Install Python libraries
!pip install -q biopython obonet networkx transformers torch tqdm

# 2. Install MMseqs2 (Static binary for Linux)
# We need this for the clustering logic in data_splits.py
!mkdir -p /content/mmseqs
!wget -q https://mmseqs.com/latest/mmseqs-linux-avx2.tar.gz -O /content/mmseqs/mmseqs.tar.gz
!tar xvfz /content/mmseqs/mmseqs.tar.gz -C /content/mmseqs
!ln -sf /content/mmseqs/mmseqs/bin/mmseqs /usr/local/bin/mmseqs

print("Dependencies and MMseqs2 installed.")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m49.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hmmseqs/
mmseqs/userguide.pdf
mmseqs/examples/
mmseqs/examples/DB.fasta
mmseqs/examples/QUERY.fasta
mmseqs/LICENSE.md
mmseqs/README.md
mmseqs/matrices/
mmseqs/matrices/PAM120.out
mmseqs/matrices/PAM190.out
mmseqs/matrices/PAM130.out
mmseqs/matrices/blosum62.out
mmseqs/matrices/blosum80.out
mmseqs/matrices/blosum95.out
mmseqs/matrices/PAM20.out
mmseqs/matrices/blosum75.out
mmseqs/matrices/blosum100.out
mmseqs/matrices/PAM40.out
mmseqs/matrices/PAM100.out
mmseqs/matrices/VTML200.out
mmseqs/matrices/blosum35.out
mmseqs/matrices/PAM70.out
mmseqs/matrices/PAM170.out
mmseqs/matrices/blosum30.out
mmseqs/matrices/VTML40.out
mmseqs/matrices/blosum55.out
mmseqs/matrices/PAM30.out
mmseqs/matrices/PAM10.out
mmseqs/matrices/blosum65.out
mmseqs/matrices/blosum90.out
mmseqs/matrices/PAM160.out
mmseqs/matrices/PAM60.out
mmseqs/matrices/PAM50.out
mmseqs/matrices/PA

# Write Source Code modules

In [2]:
import os
os.makedirs('src', exist_ok=True)

# Add src to python path for immediate imports if needed
import sys
sys.path.append('/content/src')

In [3]:
%%writefile src/go_labeler.py
"""
GO Labeler: Handles Gene Ontology logic for protein function prediction.
"""
import numpy as np
import obonet
import networkx as nx
from collections import defaultdict
from typing import List, Tuple, Dict, Set, Optional

class GOLabeler:
    def __init__(self, obo_path: str, annotations: List[Tuple[str, str]]):
        self.obo_path = obo_path
        self.annotations = annotations
        print(f"Loading GO graph from {obo_path}...")
        self.go_graph = obonet.read_obo(obo_path)
        print(f"Loaded GO graph with {self.go_graph.number_of_nodes()} terms.")

        self.term_to_index: Dict[str, int] = {}
        self.index_to_term: Dict[int, str] = {}
        self.valid_terms: List[str] = []
        self._term_frequencies: Dict[str, int] = {}

    def _get_ancestors(self, term_id: str) -> Set[str]:
        ancestors = set()
        if term_id not in self.go_graph:
            return ancestors
        ancestors.add(term_id)
        try:
            ancestors.update(nx.descendants(self.go_graph, term_id))
        except nx.NetworkXError:
            pass
        return ancestors

    def _propagate_terms(self, terms: List[str]) -> Set[str]:
        all_terms = set()
        for term in terms:
            all_terms.update(self._get_ancestors(term))
        return all_terms

    def build_label_vocabulary(self, min_frequency: int = 50) -> None:
        print(f"Building label vocabulary with min_frequency={min_frequency}...")
        protein_to_terms: Dict[str, List[str]] = defaultdict(list)
        for protein_id, term_id in self.annotations:
            protein_to_terms[protein_id].append(term_id)

        term_counts: Dict[str, int] = defaultdict(int)
        for protein_id, terms in protein_to_terms.items():
            propagated_terms = self._propagate_terms(terms)
            for term in propagated_terms:
                term_counts[term] += 1

        self._term_frequencies = dict(term_counts)
        self.valid_terms = sorted([
            term for term, count in term_counts.items()
            if count >= min_frequency
        ])

        print(f"Terms meeting min_frequency threshold: {len(self.valid_terms)}")
        self.term_to_index = {term: idx for idx, term in enumerate(self.valid_terms)}
        self.index_to_term = {idx: term for idx, term in enumerate(self.valid_terms)}

    def get_vector(self, protein_terms: List[str]) -> np.ndarray:
        if not self.valid_terms:
            raise ValueError("Vocabulary not built.")
        vector = np.zeros(len(self.valid_terms), dtype=np.float32)
        propagated_terms = self._propagate_terms(protein_terms)
        for term in propagated_terms:
            if term in self.term_to_index:
                idx = self.term_to_index[term]
                vector[idx] = 1.0
        return vector

    def vocabulary_size(self) -> int:
        return len(self.valid_terms)

Overwriting src/go_labeler.py


In [4]:
%%writefile src/data_splits.py
"""
Data Splitting Utilities for Protein Function Prediction.
"""
import random
from collections import defaultdict
from typing import Dict, List, Set, Tuple, Optional

def create_splits(
    cluster_file: str,
    val_ratio: float = 0.2,
    random_seed: Optional[int] = 42
) -> Tuple[Set[str], Set[str]]:
    if random_seed is not None:
        random.seed(random_seed)

    clusters: Dict[str, List[str]] = defaultdict(list)
    print(f"Reading cluster file: {cluster_file}")
    with open(cluster_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line: continue
            parts = line.split('\t')
            if len(parts) < 2: continue
            cluster_rep, sequence_member = parts[0], parts[1]
            clusters[cluster_rep].append(sequence_member)

    cluster_reps = list(clusters.keys())
    n_clusters = len(cluster_reps)
    total_sequences = sum(len(members) for members in clusters.values())
    print(f"Found {n_clusters} clusters containing {total_sequences} sequences")

    random.shuffle(cluster_reps)
    n_val_clusters = int(n_clusters * val_ratio)

    val_cluster_reps = set(cluster_reps[:n_val_clusters])

    train_protein_ids: Set[str] = set()
    val_protein_ids: Set[str] = set()

    for cluster_rep, members in clusters.items():
        if cluster_rep in val_cluster_reps:
            val_protein_ids.update(members)
        else:
            train_protein_ids.update(members)

    print(f"Split complete: Train {len(train_protein_ids)}, Val {len(val_protein_ids)}")
    return train_protein_ids, val_protein_ids

Overwriting src/data_splits.py


In [5]:
%%writefile src/dataset.py
"""
PyTorch Dataset for Protein Function Prediction.
"""
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import torch
from torch.utils.data import Dataset
import numpy as np

class ProteinGODataset(Dataset):
    def __init__(
        self,
        protein_ids: List[str],
        labeler,
        embedding_dir: str,
        annotations: Dict[str, List[str]],
        check_exists: bool = True
    ):
        self.labeler = labeler
        self.embedding_dir = Path(embedding_dir)
        self.annotations = annotations

        if check_exists:
            self.protein_ids = []
            for pid in protein_ids:
                embedding_path = self.embedding_dir / f"{pid}.pt"
                if embedding_path.exists():
                    self.protein_ids.append(pid)
        else:
            self.protein_ids = list(protein_ids)

        print(f"Dataset initialized with {len(self.protein_ids)} proteins")

    def __len__(self) -> int:
        return len(self.protein_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        protein_id = self.protein_ids[idx]
        embedding_path = self.embedding_dir / f"{protein_id}.pt"
        embedding = torch.load(embedding_path, weights_only=True).to(torch.float32)

        explicit_terms = self.annotations.get(protein_id, [])
        label_numpy = self.labeler.get_vector(explicit_terms)
        label = torch.from_numpy(label_numpy).to(torch.float32)

        return embedding, label

Overwriting src/dataset.py


In [6]:
%%writefile src/model.py
"""
Neural Network Model and Loss Components.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional

class ProteinMLP(nn.Module):
    def __init__(self, num_classes: int, input_dim: int = 1280, hidden_dim: int = 512, dropout: float = 0.3):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_classes = num_classes

        self.hidden = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        self.output = nn.Linear(hidden_dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.hidden(x)
        return self.output(x)

def calculate_pos_weights(train_dataset, num_workers: int = 0, batch_size: int = 256) -> torch.Tensor:
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    _, first_label = train_dataset[0]
    num_classes = first_label.shape[0]

    positive_counts = torch.zeros(num_classes, dtype=torch.float64)
    total_samples = 0

    print("Calculating positive class weights...")
    for _, labels in tqdm(loader, desc="Computing weights"):
        positive_counts += labels.sum(dim=0).to(torch.float64)
        total_samples += labels.shape[0]

    epsilon = 1e-7
    negative_counts = total_samples - positive_counts
    pos_weights = negative_counts / (positive_counts + epsilon)
    pos_weights = torch.clamp(pos_weights, min=1.0, max=100.0)

    return pos_weights.to(torch.float32)

def create_loss_function(pos_weights: Optional[torch.Tensor] = None, device: str = "cuda") -> nn.BCEWithLogitsLoss:
    if pos_weights is not None:
        pos_weights = pos_weights.to(device)
        return nn.BCEWithLogitsLoss(pos_weight=pos_weights)
    return nn.BCEWithLogitsLoss()

Overwriting src/model.py


In [7]:
%%writefile src/threshold_optimizer.py
"""
Threshold Optimization for Multi-Label Classification.
"""
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import autocast
from sklearn.metrics import precision_recall_curve, f1_score
from tqdm import tqdm
from typing import Tuple, Optional

def collect_predictions(model, data_loader, device, use_amp=True):
    model.eval()
    use_amp = use_amp and device == "cuda"
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for embeddings, labels in tqdm(data_loader, desc="Collecting predictions"):
            embeddings = embeddings.to(device)
            with autocast(device_type=device, enabled=use_amp):
                logits = model(embeddings)
            probs = torch.sigmoid(logits).cpu().numpy()
            all_probs.append(probs)
            all_labels.append(labels.numpy())

    return np.concatenate(all_probs, axis=0), np.concatenate(all_labels, axis=0)

def optimize_thresholds(model, val_loader, device, use_amp=True) -> np.ndarray:
    print("Optimizing Per-Class Thresholds")
    y_probs, y_true = collect_predictions(model, val_loader, device, use_amp)
    n_classes = y_probs.shape[1]
    optimal_thresholds = np.full(n_classes, 0.5)

    for class_idx in tqdm(range(n_classes), desc="Optimizing"):
        y_true_class = y_true[:, class_idx]
        if y_true_class.sum() == 0 or y_true_class.sum() == len(y_true_class):
            continue

        precision, recall, thresholds = precision_recall_curve(y_true_class, y_probs[:, class_idx])
        with np.errstate(divide='ignore', invalid='ignore'):
            f1_scores = 2 * (precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1])
            f1_scores = np.nan_to_num(f1_scores, nan=0.0)

        if len(f1_scores) > 0 and f1_scores.max() > 0:
            optimal_thresholds[class_idx] = thresholds[np.argmax(f1_scores)]

    return optimal_thresholds

def evaluate_with_thresholds(model, data_loader, device, thresholds, use_amp=True):
    y_probs, y_true = collect_predictions(model, data_loader, device, use_amp)
    y_pred = (y_probs >= thresholds).astype(np.float32)
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(f"Evaluation - Micro F1: {micro_f1:.4f}, Macro F1: {macro_f1:.4f}")
    return {'micro_f1': micro_f1, 'macro_f1': macro_f1}

Overwriting src/threshold_optimizer.py


In [8]:
%%writefile src/trainer.py
import time
from pathlib import Path
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    epochs: int,
    patience: int = 5,
    checkpoint_dir: str = "models",
    checkpoint_name: str = "best_model.pt",
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    use_amp: bool = True
) -> Dict[str, Any]:
    
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    best_model_path = Path(checkpoint_dir) / checkpoint_name
    
    # Only use AMP if requested AND device is cuda
    use_amp = use_amp and (device == "cuda")
    scaler = GradScaler(enabled=use_amp)
    
    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0
    stopped_early = False
    
    # Note: Model is assumed to be moved to device (or wrapped in DataParallel) by main.py
    
    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0.0
        batches = 0
        
        # Training Loop
        for embeddings, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Train", leave=False):
            # If DataParallel, input is split automatically
            embeddings, labels = embeddings.to(device), labels.to(device)
            optimizer.zero_grad()
            
            with autocast(device_type=device, enabled=use_amp):
                logits = model(embeddings)
                loss = loss_fn(logits, labels)
            
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            
            train_loss_sum += loss.item()
            batches += 1
            
        avg_train_loss = train_loss_sum / batches
        train_losses.append(avg_train_loss)
        
        # Validation Loop
        model.eval()
        val_loss_sum = 0.0
        val_batches = 0
        with torch.no_grad():
            for embeddings, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} Val", leave=False):
                embeddings, labels = embeddings.to(device), labels.to(device)
                with autocast(device_type=device, enabled=use_amp):
                    logits = model(embeddings)
                    loss = loss_fn(logits, labels)
                val_loss_sum += loss.item()
                val_batches += 1
        
        avg_val_loss = val_loss_sum / val_batches
        val_losses.append(avg_val_loss)
        
        # Scheduler Step
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()
                
        print(f"Epoch {epoch+1} - Train: {avg_train_loss:.4f}, Val: {avg_val_loss:.4f}")
        
        # Early Stopping & Checkpointing
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            patience_counter = 0
            
            # Handle DataParallel: save the underlying module, not the wrapper
            if isinstance(model, nn.DataParallel):
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
                
            torch.save({
                'model_state_dict': state_dict,
                'val_loss': best_val_loss,
                'config': {
                    'input_dim': model.module.input_dim if isinstance(model, nn.DataParallel) else model.input_dim,
                    'num_classes': model.module.num_classes if isinstance(model, nn.DataParallel) else model.num_classes
                }
            }, best_model_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                stopped_early = True
                break
                
    # Load best model for return
    # We load into the original model architecture
    print("Loading best model weights...")
    checkpoint = torch.load(best_model_path)
    
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint['model_state_dict'])
    
    return {'train_losses': train_losses, 'val_losses': val_losses, 'best_epoch': best_epoch, 'best_val_loss': best_val_loss}

def load_checkpoint(model, path, device="cpu"):
    """
    Load a checkpoint. Handles mismatch between DataParallel and single model.
    """
    checkpoint = torch.load(path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    
    # If loading into a DataParallel model but checkpoint is standard
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(state_dict)
    else:
        # If checkpoint has 'module.' prefix (from raw DP save) but model is standard
        # (This handles cases where simple torch.save(model.state_dict()) was used on DP)
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[7:]] = v
            else:
                new_state_dict[k] = v
        model.load_state_dict(new_state_dict)
        
    return model, {}

Overwriting src/trainer.py


In [9]:
%%writefile src/extract_embeddings.py
import os
import argparse
import math
from pathlib import Path
import torch
import torch.multiprocessing as mp
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
from tqdm import tqdm

# Configuration
MAX_SEQUENCE_LENGTH = 1022

def get_device(gpu_id):
    return f"cuda:{gpu_id}"

def worker_process(gpu_id, all_records, output_dir, model_name):
    """
    Worker function to run on a specific GPU.
    """
    device = get_device(gpu_id)
    print(f"[GPU {gpu_id}] Initializing model {model_name}...")
    
    # 1. Calculate the slice of data for this GPU
    total_records = len(all_records)
    n_gpus = torch.cuda.device_count()
    chunk_size = math.ceil(total_records / n_gpus)
    
    start_idx = gpu_id * chunk_size
    end_idx = min(start_idx + chunk_size, total_records)
    
    my_records = all_records[start_idx:end_idx]
    
    print(f"[GPU {gpu_id}] Processing indices {start_idx} to {end_idx} ({len(my_records)} sequences)")
    
    # 2. Load model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()
    
    output_path = Path(output_dir)
    processed_count = 0
    
    # 3. Process assigned records
    for record in tqdm(my_records, desc=f"GPU {gpu_id}", position=gpu_id):
        out_file = output_path / f"{record.id}.pt"
        
        # Skip if already exists
        if out_file.exists():
            continue
            
        sequence = str(record.seq)
        
        # Truncate
        if len(sequence) > MAX_SEQUENCE_LENGTH:
            sequence = sequence[:MAX_SEQUENCE_LENGTH]
            
        try:
            # Tokenize
            inputs = tokenizer(
                sequence, 
                return_tensors="pt", 
                padding=False, 
                truncation=True, 
                max_length=MAX_SEQUENCE_LENGTH + 2
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}
            
            # Inference
            with torch.no_grad():
                outputs = model(**inputs)
            
            # Mean pooling (exclude CLS and EOS)
            embedding = outputs.last_hidden_state[:, 1:-1, :].mean(dim=1).squeeze(0)
            
            # Save to CPU
            torch.save(embedding.cpu(), out_file)
            processed_count += 1
            
        except Exception as e:
            print(f"[GPU {gpu_id}] Error processing {record.id}: {e}")

    print(f"[GPU {gpu_id}] Finished. Processed {processed_count} sequences.")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fasta", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--model", default="facebook/esm2_t36_3B_UR50D")
    args = parser.parse_args()
    
    Path(args.output).mkdir(parents=True, exist_ok=True)
    
    print(f"Reading FASTA file: {args.fasta}")
    all_records = list(SeqIO.parse(args.fasta, "fasta"))
    print(f"Total sequences: {len(all_records)}")
    
    n_gpus = torch.cuda.device_count()
    if n_gpus < 1:
        print("No GPUs found! Using CPU (single process).")
        worker_process(0, all_records, args.output, args.model)
        return

    print(f"Found {n_gpus} GPUs. Spawning workers...")
    
    # Pass the FULL list to all workers; they will slice it themselves based on their gpu_id
    mp.spawn(
        worker_process,
        args=(all_records, args.output, args.model),
        nprocs=n_gpus,
        join=True
    )
    print("All extraction processes completed.")

if __name__ == "__main__":
    main()

Overwriting src/extract_embeddings.py


In [10]:
%%writefile src/main.py
import os
import sys
import argparse
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from go_labeler import GOLabeler
from data_splits import create_splits
from dataset import ProteinGODataset
from model import ProteinMLP, calculate_pos_weights, create_loss_function
from trainer import train_model, load_checkpoint
from threshold_optimizer import optimize_thresholds, evaluate_with_thresholds

def load_annotations(path):
    ann_list, ann_dict = [], defaultdict(list)
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line: continue
            parts = line.split('\t')
            if len(parts) >= 2:
                ann_list.append((parts[0], parts[1]))
                ann_dict[parts[0]].append(parts[1])
    return ann_list, dict(ann_dict)

def main(args):
    # Device Setup
    if torch.cuda.is_available():
        device = "cuda"
        n_gpus = torch.cuda.device_count()
        print(f"Using {n_gpus} GPUs!")
    else:
        device = "cpu"
        print("Using CPU")
    
    # 1. Load Data
    annotations_list, annotations_dict = load_annotations(args.annotations_path)
    labeler = GOLabeler(args.obo_path, annotations_list)
    labeler.build_label_vocabulary(min_frequency=args.min_frequency)
    
    # 2. Splits
    train_ids, val_ids = create_splits(args.cluster_path, args.val_ratio, args.seed)
    train_ids = train_ids & set(annotations_dict.keys())
    val_ids = val_ids & set(annotations_dict.keys())
    
    # 3. Datasets
    train_dataset = ProteinGODataset(list(train_ids), labeler, args.embedding_dir, annotations_dict)
    val_dataset = ProteinGODataset(list(val_ids), labeler, args.embedding_dir, annotations_dict)
    
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    
    # 4. Model Setup
    pos_weights = calculate_pos_weights(train_dataset) if args.use_pos_weights else None
    
    model = ProteinMLP(
        num_classes=labeler.vocabulary_size(), 
        input_dim=args.input_dim, 
        hidden_dim=args.hidden_dim, 
        dropout=args.dropout
    )
    
    # Wrap model for Multi-GPU
    if torch.cuda.device_count() > 1:
        print(f"Wrapping model in DataParallel (Batch size {args.batch_size} will be split across GPUs)")
        model = nn.DataParallel(model)
    
    model = model.to(device)
    
    loss_fn = create_loss_function(pos_weights, device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    
    # 5. Train
    train_model(
        model, train_loader, val_loader, loss_fn, optimizer, device, 
        args.epochs, args.patience, args.output_dir, "best_model.pt", scheduler,
        use_amp=args.use_amp # Passed from args
    )
    
    # 6. Optimization
    print("Loading best model for optimization...")
    # Load best model (Trainer handles un-wrapping, so we just reload into the object)
    # But we need to handle if the current 'model' object is wrapped or not
    checkpoint_path = str(Path(args.output_dir)/"best_model.pt")
    model, _ = load_checkpoint(model, checkpoint_path, device)
    
    thresholds = optimize_thresholds(model, val_loader, device)
    evaluate_with_thresholds(model, val_loader, device, thresholds)
    
    # Save artifacts
    np.save(Path(args.output_dir)/"thresholds.npy", thresholds)
    np.savez(Path(args.output_dir)/"vocabulary.npz", valid_terms=np.array(labeler.valid_terms))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--obo_path", default="/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo")
    parser.add_argument("--annotations_path", default="/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv")
    parser.add_argument("--cluster_path", default="data/splits/train_cluster.tsv")
    parser.add_argument("--embedding_dir", default="data/embeddings/train")
    parser.add_argument("--output_dir", default="results/experiment_001")
    
    # Updated defaults for 3B model
    parser.add_argument("--input_dim", type=int, default=1280)
    
    parser.add_argument("--hidden_dim", type=int, default=512)
    parser.add_argument("--dropout", type=float, default=0.3)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--weight_decay", type=float, default=1e-5)
    parser.add_argument("--min_frequency", type=int, default=50)
    parser.add_argument("--val_ratio", type=float, default=0.2)
    parser.add_argument("--use_pos_weights", action="store_true")
    
    # Default to FALSE for AMP if user wants strict f32, or TRUE if they just want speed
    # Since prompt said "no need to cast into f16", we can default this to False to be safe
    # or keep it True for training speed (Trainer handles it safely). 
    # I'll default to False to strictly follow the user's "f32 works" sentiment.
    parser.add_argument("--use_amp", action="store_true", help="Enable Mixed Precision Training")
    
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    main(args)

Overwriting src/main.py


# Data Processing - Clustering

In [11]:
# Generate clusters (30% identity)
# This creates data/splits/train_cluster.tsv
!mmseqs easy-linclust \
    /kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta \
    ./data/splits/train \
    tmp \
    --min-seq-id 0.3 \
    --cov-mode 1 \
    -c 0.8

# Rename output to match what src/data_splits.py expects
!mv data/splits/train_cluster.tsv data/splits/train_cluster.tsv.bak
# Filter out headers if any and ensure tab separated
!awk '{print $1"\t"$2}' data/splits/train_cluster.tsv.bak > data/splits/train_cluster.tsv
!head -n 5 data/splits/train_cluster.tsv

easy-linclust /kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta ./data/splits/train tmp --min-seq-id 0.3 --cov-mode 1 -c 0.8 

MMseqs Version:                     	bd01c2229f027d8d8e61947f44d11ef1a7669212
Cluster mode                        	0
Max connected component depth       	1000
Similarity type                     	2
Threads                             	4
Compressed                          	0
Verbosity                           	3
Weight file name                    	
Cluster Weight threshold            	0.9
Set mode                            	false
Substitution matrix                 	aa:blosum62.out,nucl:nucleotide.out
Add backtrace                       	false
Alignment mode                      	0
Alignment mode                      	0
Allow wrapped scoring               	false
E-value threshold                   	0.001
Seq. id. threshold                  	0.3
Min alignment length                	0
Seq. id. mode                       	0
Alternativ

# Data Processing - Embedding Extraction

In [12]:
# Run embedding extraction
# This takes time! (Approx 3-5 hours for full CAFA train set on T4)
!python src/extract_embeddings.py \
    --fasta /kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta \
    --output ./data/embeddings/train \
    --model facebook/esm2_t33_650M_UR50D

Reading FASTA file: /kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta
Total sequences: 82404
Found 2 GPUs. Spawning workers...
[GPU 0] Initializing model facebook/esm2_t33_650M_UR50D...
[GPU 0] Processing indices 0 to 41202 (41202 sequences)
tokenizer_config.json: 100%|██████████████████| 95.0/95.0 [00:00<00:00, 597kB/s]
vocab.txt: 100%|██████████████████████████████| 93.0/93.0 [00:00<00:00, 508kB/s]
special_tokens_map.json: 100%|██████████████████| 125/125 [00:00<00:00, 697kB/s]
config.json: 100%|█████████████████████████████| 724/724 [00:00<00:00, 3.56MB/s]
2026-01-25 08:03:09.024728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769328189.206618     219 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769328189

# Zip the Results

In [14]:
# --- Configuration ---
# Ensure these match where you downloaded/unzipped the data
OUTPUT_DIR = "data/embeddings/"
ZIP_NAME = "esm2_650m_train_embeddings.zip"

if os.path.exists(OUTPUT_DIR) and len(os.listdir(OUTPUT_DIR)) > 0:
    print(f"Zipping files to {ZIP_NAME}...")
    # -q for quiet, -r for recursive, -j to junk paths (optional, but -r preserves structure)
    !zip -q -r {ZIP_NAME} {OUTPUT_DIR}
    
    print(f"Success! '{ZIP_NAME}' is ready in /kaggle/working/")
    print("To keep this file permanently: Click 'Save Version' -> 'Save & Run All' or 'Quick Save'.")
    
    # Optional: Clean up raw .pt files to save disk space if you are running low
    # !rm -rf {OUTPUT_DIR}
else:
    print("Error: Output directory is empty or does not exist.")

Zipping files to esm2_650m_train_embeddings.zip...
Success! 'esm2_650m_train_embeddings.zip' is ready in /kaggle/working/
To keep this file permanently: Click 'Save Version' -> 'Save & Run All' or 'Quick Save'.


# Renaming the embeddings

In [8]:
import os
from pathlib import Path
from tqdm import tqdm

# Paths
EMBED_DIR = Path("data/embeddings/train")
CLUSTER_FILE = Path("data/splits/train_cluster.tsv")
CLUSTER_FILE_NEW = Path("data/splits/train_cluster_fixed.tsv")

def clean_id(header_str):
    """
    Parses 'sp|Q497K5|ARRD5_MOUSE' -> 'Q497K5'
    Parses 'tr|A0A024RBG1|...' -> 'A0A024RBG1'
    Returns original if no pipes found.
    """
    header_str = str(header_str).strip()
    if header_str.count('|') >= 2:
        return header_str.split('|')[1]
    return header_str

print("1. Normalizing Embedding Filenames...")
files = list(EMBED_DIR.glob("*.pt"))
renamed_count = 0

for file_path in tqdm(files):
    old_name = file_path.stem # e.g., sp|Q497K5|ARRD5_MOUSE
    new_id = clean_id(old_name)
    
    if new_id != old_name:
        new_path = file_path.with_name(f"{new_id}.pt")
        # Rename (overwrite if exists, though unlikely in this batch)
        file_path.rename(new_path)
        renamed_count += 1

print(f"Renamed {renamed_count} embedding files.")

print("\n2. Normalizing Cluster File...")
if CLUSTER_FILE.exists():
    with open(CLUSTER_FILE, 'r') as f_in, open(CLUSTER_FILE_NEW, 'w') as f_out:
        for line in f_in:
            parts = line.strip().split('\t')
            if len(parts) >= 2:
                # Clean both representative and member
                rep = clean_id(parts[0])
                member = clean_id(parts[1])
                f_out.write(f"{rep}\t{member}\n")
    
    # Replace old cluster file
    os.replace(CLUSTER_FILE_NEW, CLUSTER_FILE)
    print(f"Fixed IDs in {CLUSTER_FILE}")
else:
    print(f"Warning: {CLUSTER_FILE} not found. Run the mmseqs step first?")

print("\nID Normalization Complete. You can now run main.py.")

1. Normalizing Embedding Filenames...


100%|██████████| 82404/82404 [00:01<00:00, 49981.22it/s]


Renamed 82404 embedding files.

2. Normalizing Cluster File...
Fixed IDs in data/splits/train_cluster.tsv

ID Normalization Complete. You can now run main.py.


# Train the model

In [10]:
!python src/main.py \
    --obo_path /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo \
    --annotations_path /kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv \
    --cluster_path /kaggle/working/data/splits/train_cluster.tsv \
    --embedding_dir /kaggle/working/data/embeddings/train \
    --output_dir ./results/final_run \
    --input_dim 1280 \
    --batch_size 512 \
    --epochs 75 \
    --use_pos_weights \
    --use_amp 

Using 2 GPUs!
Loading GO graph from /kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo...
Loaded GO graph with 40122 terms.
Building label vocabulary with min_frequency=50...
Terms meeting min_frequency threshold: 5604
Reading cluster file: /kaggle/working/data/splits/train_cluster.tsv
Found 45294 clusters containing 82404 sequences
Split complete: Train 66033, Val 16371
Dataset initialized with 66033 proteins
Dataset initialized with 16371 proteins
Calculating positive class weights...
Computing weights: 100%|██████████████████████| 258/258 [00:41<00:00,  6.21it/s]
Wrapping model in DataParallel (Batch size 512 will be split across GPUs)
Epoch 1 - Train: 0.6989, Val: 0.5726                                            
Epoch 2 - Train: 0.5442, Val: 0.5068                                            
Epoch 3 - Train: 0.4945, Val: 0.4730                                            
Epoch 4 - Train: 0.4649, Val: 0.4517                                            
Epoch 5 - Tr

### 

# Inference

In [4]:
%%writefile src/inference.py
import os
import argparse
import math
import torch
import torch.nn as nn
import torch.multiprocessing as mp
import numpy as np
import gc
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
from tqdm import tqdm

# Import your model definition
from model import ProteinMLP

# Configuration
MAX_SEQUENCE_LENGTH = 1022
THRESHOLD = 0.01
BATCH_SIZE = 100  # Number of sequences to process before saving to disk

def get_device(gpu_id):
    return f"cuda:{gpu_id}"

def clean_id(header_str):
    header_str = str(header_str).strip()
    if header_str.count('|') >= 2:
        return header_str.split('|')[1]
    return header_str.split()[0]

def load_artifacts(model_dir):
    """Load vocabulary and model config"""
    model_dir = Path(model_dir)
    
    # Load Vocabulary
    vocab_path = model_dir / "vocabulary.npz"
    data = np.load(vocab_path, allow_pickle=True)
    valid_terms = data['valid_terms']
    index_to_term = {i: term for i, term in enumerate(valid_terms)}
    
    # Load Checkpoint to get config/weights
    checkpoint_path = model_dir / "best_model.pt"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    config = checkpoint.get('config', {})
    
    return index_to_term, config, checkpoint['model_state_dict']

def worker_process(gpu_id, all_records, model_dir, esm_model_name, hidden_dim, output_dir):
    device = get_device(gpu_id)
    print(f"[GPU {gpu_id}] Worker started.")
    
    # 1. Calculate Slice for this GPU
    total_records = len(all_records)
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0: n_gpus = 1
    
    chunk_size = math.ceil(total_records / n_gpus)
    start_idx = gpu_id * chunk_size
    end_idx = min(start_idx + chunk_size, total_records)
    
    my_records = all_records[start_idx:end_idx]
    
    # 2. Setup Batches
    num_batches = math.ceil(len(my_records) / BATCH_SIZE)
    print(f"[GPU {gpu_id}] Processing {len(my_records)} seqs in {num_batches} batches.")
    
    # 3. Load Artifacts & Model
    # We load strictly inside the worker to avoid pickling issues
    index_to_term, config, state_dict = load_artifacts(model_dir)
    input_dim = config.get('input_dim', 1280) 
    num_classes = config.get('num_classes', len(index_to_term))
    
    tokenizer = AutoTokenizer.from_pretrained(esm_model_name)
    esm_model = AutoModel.from_pretrained(esm_model_name).to(device)
    esm_model.eval()
    
    mlp = ProteinMLP(num_classes=num_classes, input_dim=input_dim, hidden_dim=hidden_dim)
    
    # Clean state dict keys
    new_state_dict = {}
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    mlp.load_state_dict(new_state_dict)
    mlp = mlp.to(device)
    mlp.eval()
    
    # 4. Batch Processing Loop
    for i in range(num_batches):
        batch_filename = os.path.join(output_dir, f"batch_gpu{gpu_id}_{i}.tsv")
        
        # RESUME CAPABILITY: Skip if file exists and is not empty
        if os.path.exists(batch_filename) and os.path.getsize(batch_filename) > 0:
            # Only print every 10th skipped batch to reduce clutter
            if i % 10 == 0:
                print(f"[GPU {gpu_id}] Skipping batch {i}/{num_batches} (already exists)")
            continue
            
        # Get batch sequences
        b_start = i * BATCH_SIZE
        b_end = min(b_start + BATCH_SIZE, len(my_records))
        batch_records = my_records[b_start:b_end]
        
        batch_lines = []
        
        # Process individual sequences in this batch
        # (We could batch-tokenize here for more speed, but let's keep it safe for VRAM)
        for record in batch_records:
            protein_id = clean_id(record.id)
            sequence = str(record.seq)
            if len(sequence) > MAX_SEQUENCE_LENGTH:
                sequence = sequence[:MAX_SEQUENCE_LENGTH]
                
            try:
                # Embed
                inputs = tokenizer(sequence, return_tensors="pt", padding=False, truncation=True, max_length=MAX_SEQUENCE_LENGTH+2)
                inputs = {k: v.to(device) for k, v in inputs.items()}
                
                with torch.no_grad():
                    esm_out = esm_model(**inputs)
                    embedding = esm_out.last_hidden_state[:, 1:-1, :].mean(dim=1)
                    
                    # Predict
                    logits = mlp(embedding)
                    probs = torch.sigmoid(logits).cpu().numpy()[0]
                
                # Filter
                mask = probs >= THRESHOLD
                indices = np.where(mask)[0]
                
                for idx in indices:
                    batch_lines.append(f"{protein_id}\t{index_to_term[idx]}\t{probs[idx]:.3f}\n")
                
                # Explicit cleanup to prevent memory leaks
                del inputs, esm_out, embedding, logits, probs
                
            except Exception as e:
                print(f"[GPU {gpu_id}] Error on {protein_id}: {e}")
        
        # Write Batch to Disk
        with open(batch_filename, 'w') as f_out:
            if batch_lines:
                f_out.writelines(batch_lines)
            else:
                # Create empty file to mark progress even if no preds found
                f_out.write("") 
        
        # Aggressive Garbage Collection
        if i % 5 == 0:
            gc.collect()
            torch.cuda.empty_cache()
            
        print(f"[GPU {gpu_id}] Completed batch {i}/{num_batches}")

    print(f"[GPU {gpu_id}] Finished.")

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fasta", required=True)
    parser.add_argument("--model_dir", required=True)
    parser.add_argument("--output", default="submission.tsv")
    parser.add_argument("--esm_model", default="facebook/esm2_t33_650M_UR50D")
    parser.add_argument("--hidden_dim", type=int, default=1024)
    args = parser.parse_args()
    
    # Setup Temporary Directory for Batches
    temp_dir = "temp_inference_batches"
    os.makedirs(temp_dir, exist_ok=True)
    
    print(f"Reading {args.fasta}...")
    all_records = list(SeqIO.parse(args.fasta, "fasta"))
    n_gpus = torch.cuda.device_count()
    
    if n_gpus == 0:
        print("No GPUs found. Running on CPU.")
        worker_process(0, all_records, args.model_dir, args.esm_model, args.hidden_dim, temp_dir)
    else:
        mp.spawn(
            worker_process,
            args=(all_records, args.model_dir, args.esm_model, args.hidden_dim, temp_dir),
            nprocs=n_gpus,
            join=True
        )
    
    # Optional cleanup (commented out to be safe)
    # import shutil
    # shutil.rmtree(temp_dir)
                
    print(f"Submission saved to {args.output}")

if __name__ == "__main__":
    main()

Overwriting src/inference.py


In [2]:
!python src/inference.py \
    --fasta /kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta \
    --model_dir results/final_run \
    --output submission.tsv \
    --esm_model facebook/esm2_t33_650M_UR50D \
    --hidden_dim 512

Reading /kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta...
[GPU 0] Worker started.
[GPU 0] Processing 112155 seqs in 1122 batches.
tokenizer_config.json: 100%|██████████████████| 95.0/95.0 [00:00<00:00, 473kB/s]
vocab.txt: 100%|██████████████████████████████| 93.0/93.0 [00:00<00:00, 654kB/s]
special_tokens_map.json: 100%|██████████████████| 125/125 [00:00<00:00, 798kB/s]
config.json: 100%|█████████████████████████████| 724/724 [00:00<00:00, 4.72MB/s]
2026-01-28 07:58:50.135012: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769587130.392058     128 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769587130.463670     128 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin c

# Merging

In [3]:
import os
import zipfile
from tqdm import tqdm

# --- Configuration ---
TEMP_DIR = "temp_inference_batches"
OUTPUT_ZIP = "submission.zip"
INTERNAL_FILENAME = "submission.tsv"

def merge_compress_cleanup():
    # 1. Check if temp dir exists
    if not os.path.exists(TEMP_DIR):
        print(f"Error: {TEMP_DIR} not found. Did the inference run?")
        return

    # 2. Get list of files
    batch_files = sorted([f for f in os.listdir(TEMP_DIR) if f.startswith("batch_")])
    print(f"Found {len(batch_files)} batch files to merge.")
    
    if len(batch_files) == 0:
        print("❌ No batch files found! They were likely deleted during the failed merge.")
        print("   Please RE-RUN the inference step (Step 2).")
        print("   It will regenerate the missing batches.")
        return

    # 3. Clean up previous corrupted zip if exists
    if os.path.exists(OUTPUT_ZIP):
        print(f"Removing existing {OUTPUT_ZIP}...")
        os.remove(OUTPUT_ZIP)

    print(f"Streaming data directly to {OUTPUT_ZIP} and deleting sources...")
    
    # 4. Stream-Zip-Delete Loop
    with zipfile.ZipFile(OUTPUT_ZIP, 'w', compression=zipfile.ZIP_DEFLATED, allowZip64=True) as zf:
        
        # FIX: force_zip64=True is required for files >4GB written via stream
        with zf.open(INTERNAL_FILENAME, 'w', force_zip64=True) as dest:
            
            for fname in tqdm(batch_files):
                fpath = os.path.join(TEMP_DIR, fname)
                
                try:
                    with open(fpath, 'rb') as src:
                        while True:
                            chunk = src.read(1024 * 1024) # 1MB chunks
                            if not chunk:
                                break
                            dest.write(chunk)
                    
                    # Delete file immediately to free disk space
                    os.remove(fpath)
                    
                except Exception as e:
                    print(f"Error processing {fname}: {e}")
                    raise e

    print("\n----------------------------------------------")
    print("Merge Complete!")
    print(f"1. Raw batch files deleted.")
    print(f"2. Submission saved to: {OUTPUT_ZIP}")
    print(f"3. File size: {os.path.getsize(OUTPUT_ZIP) / (1024*1024):.2f} MB")
    print("----------------------------------------------")

# Run it
merge_compress_cleanup()

Found 2244 batch files to merge.
Streaming data directly to submission.zip and deleting sources...


100%|██████████| 2244/2244 [11:49<00:00,  3.16it/s]


----------------------------------------------
Merge Complete!
1. Raw batch files deleted.
2. Submission saved to: submission.zip
3. File size: 3275.72 MB
----------------------------------------------





In [5]:
from IPython.display import FileLink
FileLink('submission.zip')