# Install Dependencies & External Tools

In [None]:
# 1. Install Python libraries
!pip install -q biopython obonet networkx transformers torch tqdm

# 2. Install MMseqs2 (Static binary for Linux)
# We need this for the clustering logic in data_splits.py
!mkdir -p /content/mmseqs
!wget -q https://mmseqs.com/latest/mmseqs-linux-avx2.tar.gz -O /content/mmseqs/mmseqs.tar.gz
!tar xvfz /content/mmseqs/mmseqs.tar.gz -C /content/mmseqs
!ln -sf /content/mmseqs/mmseqs/bin/mmseqs /usr/local/bin/mmseqs

print("Dependencies and MMseqs2 installed.")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/3.2 MB[0m [31m32.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[?25hmmseqs/
mmseqs/userguide.pdf
mmseqs/examples/
mmseqs/examples/DB.fasta
mmseqs/examples/QUERY.fasta
mmseqs/LICENSE.md
mmseqs/README.md
mmseqs/matrices/
mmseqs/matrices/PAM120.out
mmseqs/matrices/PAM190.out
mmseqs/matrices/PAM130.out
mmseqs/matrices/blosum62.out
mmseqs/matrices/blosum80.out
mmseqs/matrices/blosum95.out
mmseqs/matrices/PAM20.out
mmseqs/matrices/blosum75.out
mmseqs/matrices/blosum100.out
mmseqs/matrices/PAM40.out
mmseqs/matrices/PAM100.out
mmseqs/matrices/VTML200.out
mmseqs/matrices/blosum35.out
mmseqs/matrices/PAM70.out
mmseqs/matrices/PAM170.out
mmseqs/matrices/blosum30.out
mmseqs/matrices/VTML40.out
mmseqs/m

# Kaggle Data Setup (Upload kaggle.json)

In [None]:
import os
import shutil
from google.colab import files

# 1. Upload kaggle.json
if not os.path.exists('/root/.kaggle/kaggle.json'):
    print("Please upload your kaggle.json file:")
    uploaded = files.upload()

    !mkdir -p ~/.kaggle
    !cp kaggle.json ~/.kaggle/
    !chmod 600 ~/.kaggle/kaggle.json

# 2. Download Competition Data
if not os.path.exists('cafa-6-protein-function-prediction.zip'):
    print("Downloading dataset...")
    !kaggle competitions download -c cafa-6-protein-function-prediction

# 3. Organize Directory Structure
# Your code expects: data/raw/Train, data/embeddings, data/splits
print("Organizing directory structure...")

!mkdir -p data/raw
!mkdir -p data/embeddings/train
!mkdir -p data/splits
!mkdir -p results

# Unzip and move
!unzip -q cafa-6-protein-function-prediction.zip -d temp_data

# Move files to match your project tree
# Note: Adjusting based on standard Kaggle unzip structure
if os.path.exists('temp_data/Train'):
    !mv temp_data/Train data/raw/
    !mv temp_data/Test data/raw/
    !mv temp_data/IA.tsv data/raw/
    !mv temp_data/sample_submission.tsv data/raw/
else:
    # Fallback if zip structure is flat
    !mkdir -p data/raw/Train data/raw/Test
    !mv temp_data/train_* data/raw/Train/ 2>/dev/null || true
    !mv temp_data/go-basic.obo data/raw/Train/
    !mv temp_data/testsuperset* data/raw/Test/ 2>/dev/null || true

!rm -rf temp_data
print("Data setup complete. Structure:")

Please upload your kaggle.json file:


Saving kaggle.json to kaggle.json
Downloading dataset...
Downloading cafa-6-protein-function-prediction.zip to /content
  0% 0.00/91.3M [00:00<?, ?B/s]
100% 91.3M/91.3M [00:00<00:00, 1.28GB/s]
Organizing directory structure...
Data setup complete. Structure:
/bin/bash: line 1: tree: command not found


# Write Source Code modules

In [None]:
import os
import sys

# Directory that the following %%writefile cells populate with the project modules.
# Resolved against the current working directory so the path added to sys.path is
# always the directory that was actually created (on Colab this is /content/src).
SRC_DIR = os.path.abspath('src')
os.makedirs(SRC_DIR, exist_ok=True)

# Make the modules importable from the notebook kernel; guard against piling up
# duplicate entries when the cell is re-run.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [None]:
%%writefile src/go_labeler.py
"""
GO Labeler: Handles Gene Ontology logic for protein function prediction.
"""
import numpy as np
import obonet
import networkx as nx
from collections import defaultdict
from typing import List, Tuple, Dict, Set, Optional

class GOLabeler:
    """Builds a GO-term vocabulary and multi-hot label vectors for proteins.

    The ontology is read from an OBO file with ``obonet``. Every annotated term is
    expanded to itself plus all nodes reachable from it in the ontology graph before
    it is counted or encoded.

    NOTE(review): the expansion follows every outgoing edge of the graph, whatever its
    relationship type, and relies on edges pointing from a term towards its parents
    (presumably obonet's convention — confirm against the obonet documentation).
    """

    def __init__(self, obo_path: str, annotations: List[Tuple[str, str]]):
        """Load the ontology graph.

        Args:
            obo_path: Path to the OBO ontology file.
            annotations: ``(protein_id, term_id)`` pairs.
        """
        self.obo_path = obo_path
        self.annotations = annotations
        print(f"Loading GO graph from {obo_path}...")
        self.go_graph = obonet.read_obo(obo_path)
        print(f"Loaded GO graph with {self.go_graph.number_of_nodes()} terms.")

        self.term_to_index: Dict[str, int] = {}
        self.index_to_term: Dict[int, str] = {}
        self.valid_terms: List[str] = []
        self._term_frequencies: Dict[str, int] = {}
        # term_id -> frozen closure (term + reachable nodes). get_vector() runs once per
        # sample per epoch, so without this cache the same traversals are repeated
        # over and over.
        self._ancestor_cache: Dict[str, frozenset] = {}

    def _get_ancestors(self, term_id: str) -> Set[str]:
        """Return ``term_id`` plus every node reachable from it; empty set if unknown.

        The result is a fresh set on every call, so callers may mutate it freely.
        """
        closure = self._ancestor_cache.get(term_id)
        if closure is None:
            if term_id in self.go_graph:
                # Membership was just checked, so nx.descendants cannot raise here.
                closure = frozenset(nx.descendants(self.go_graph, term_id)) | {term_id}
            else:
                closure = frozenset()
            self._ancestor_cache[term_id] = closure
        return set(closure)

    def _propagate_terms(self, terms: List[str]) -> Set[str]:
        """Union of the expanded closures of all ``terms``."""
        all_terms: Set[str] = set()
        for term in terms:
            all_terms.update(self._get_ancestors(term))
        return all_terms

    def build_label_vocabulary(self, min_frequency: int = 50) -> None:
        """Select the terms annotated (after propagation) on enough proteins.

        Populates ``valid_terms`` (sorted), ``term_to_index``, ``index_to_term`` and
        ``_term_frequencies``.

        Args:
            min_frequency: Minimum number of proteins a term must cover to be kept.
        """
        print(f"Building label vocabulary with min_frequency={min_frequency}...")
        protein_to_terms: Dict[str, List[str]] = defaultdict(list)
        for protein_id, term_id in self.annotations:
            protein_to_terms[protein_id].append(term_id)

        # Each protein contributes at most once per term because propagation
        # returns a set.
        term_counts: Dict[str, int] = defaultdict(int)
        for terms in protein_to_terms.values():
            for term in self._propagate_terms(terms):
                term_counts[term] += 1

        self._term_frequencies = dict(term_counts)
        self.valid_terms = sorted(
            term for term, count in term_counts.items()
            if count >= min_frequency
        )

        print(f"Terms meeting min_frequency threshold: {len(self.valid_terms)}")
        self.term_to_index = {term: idx for idx, term in enumerate(self.valid_terms)}
        self.index_to_term = {idx: term for idx, term in enumerate(self.valid_terms)}

    def get_vector(self, protein_terms: List[str]) -> np.ndarray:
        """Encode ``protein_terms`` as a float32 multi-hot vector over the vocabulary.

        Raises:
            ValueError: If the vocabulary is empty (not built yet, or no term met
                the frequency threshold).
        """
        if not self.valid_terms:
            raise ValueError("Vocabulary not built.")
        vector = np.zeros(len(self.valid_terms), dtype=np.float32)
        for term in self._propagate_terms(protein_terms):
            idx = self.term_to_index.get(term)
            if idx is not None:
                vector[idx] = 1.0
        return vector

    def vocabulary_size(self) -> int:
        """Number of terms in the vocabulary."""
        return len(self.valid_terms)

Writing src/go_labeler.py


In [None]:
%%writefile src/data_splits.py
"""
Data Splitting Utilities for Protein Function Prediction.
"""
import random
from collections import defaultdict
from typing import Dict, List, Set, Tuple, Optional

def create_splits(
    cluster_file: str,
    val_ratio: float = 0.2,
    random_seed: Optional[int] = 42
) -> Tuple[Set[str], Set[str]]:
    """Split proteins into train/validation sets along cluster boundaries.

    The cluster file is read as tab-separated lines whose first column is the cluster
    representative and whose second column is a member id. Blank lines and lines with
    fewer than two columns are ignored. Whole clusters are assigned to one side, so
    members of the same cluster never end up in both sets.

    Args:
        cluster_file: Path to the tab-separated cluster file.
        val_ratio: Fraction of *clusters* (not sequences) assigned to validation;
            the count is truncated towards zero.
        random_seed: Seed for the shuffle. With a seed, a private generator is used so
            the process-wide ``random`` state is left untouched (the split itself is
            the same as seeding the global generator). With ``None`` the global
            ``random`` state drives the shuffle.

    Returns:
        ``(train_protein_ids, val_protein_ids)`` as two disjoint sets.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio}")

    rng = random.Random(random_seed) if random_seed is not None else random

    clusters: Dict[str, List[str]] = defaultdict(list)
    print(f"Reading cluster file: {cluster_file}")
    with open(cluster_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            parts = line.split('\t')
            if len(parts) < 2:
                continue
            cluster_rep, sequence_member = parts[0], parts[1]
            clusters[cluster_rep].append(sequence_member)

    # Representatives are kept in file order, so the shuffle is reproducible.
    cluster_reps = list(clusters.keys())
    n_clusters = len(cluster_reps)
    total_sequences = sum(len(members) for members in clusters.values())
    print(f"Found {n_clusters} clusters containing {total_sequences} sequences")

    rng.shuffle(cluster_reps)
    n_val_clusters = int(n_clusters * val_ratio)

    val_cluster_reps = set(cluster_reps[:n_val_clusters])

    train_protein_ids: Set[str] = set()
    val_protein_ids: Set[str] = set()

    for cluster_rep, members in clusters.items():
        target = val_protein_ids if cluster_rep in val_cluster_reps else train_protein_ids
        target.update(members)

    print(f"Split complete: Train {len(train_protein_ids)}, Val {len(val_protein_ids)}")
    return train_protein_ids, val_protein_ids

Writing src/data_splits.py


In [None]:
%%writefile src/dataset.py
"""
PyTorch Dataset for Protein Function Prediction.
"""
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import torch
from torch.utils.data import Dataset
import numpy as np

class ProteinGODataset(Dataset):
    """Map-style dataset yielding ``(embedding, label)`` tensor pairs per protein.

    Embeddings are read lazily from ``<embedding_dir>/<protein_id>.pt``; labels are
    produced on the fly by ``labeler.get_vector`` from the protein's annotated terms.
    """

    def __init__(
        self,
        protein_ids: List[str],
        labeler,
        embedding_dir: str,
        annotations: Dict[str, List[str]],
        check_exists: bool = True
    ):
        """
        Args:
            protein_ids: Candidate protein identifiers.
            labeler: Object exposing ``get_vector(terms)`` that returns a NumPy array.
            embedding_dir: Directory containing one ``.pt`` file per protein.
            annotations: Mapping from protein id to its annotated terms.
            check_exists: When true, silently drop ids without an embedding file.
        """
        self.labeler = labeler
        self.embedding_dir = Path(embedding_dir)
        self.annotations = annotations

        kept = list(protein_ids)
        if check_exists:
            kept = [pid for pid in kept if self._embedding_file(pid).exists()]
        self.protein_ids = kept

        print(f"Dataset initialized with {len(self.protein_ids)} proteins")

    def _embedding_file(self, protein_id: str) -> Path:
        """Location of the stored embedding for ``protein_id``."""
        return self.embedding_dir / f"{protein_id}.pt"

    def __len__(self) -> int:
        return len(self.protein_ids)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load the embedding and build the label vector, both as float32 tensors."""
        pid = self.protein_ids[idx]

        stored = torch.load(self._embedding_file(pid), weights_only=True)
        embedding = stored.to(torch.float32)

        # Proteins without an annotation entry get an empty term list.
        terms = self.annotations.get(pid, [])
        label = torch.from_numpy(self.labeler.get_vector(terms)).to(torch.float32)

        return embedding, label

Writing src/dataset.py


In [None]:
%%writefile src/model.py
"""
Neural Network Model and Loss Components.
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional

class ProteinMLP(nn.Module):
    """One-hidden-layer MLP mapping an embedding vector to per-class logits.

    Layout: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear. The output is raw
    logits (no sigmoid is applied here).
    """

    def __init__(self, num_classes: int, input_dim: int = 2560, hidden_dim: int = 1024, dropout: float = 0.3):
        """
        Args:
            num_classes: Number of output logits.
            input_dim: Size of the input embedding.
            hidden_dim: Width of the hidden layer.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        self.input_dim, self.hidden_dim, self.num_classes = input_dim, hidden_dim, num_classes

        hidden_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        ]
        self.hidden = nn.Sequential(*hidden_layers)
        self.output = nn.Linear(hidden_dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every linear layer."""
        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits for a batch of embeddings."""
        return self.output(self.hidden(x))

def calculate_pos_weights(train_dataset, num_workers: int = 0, batch_size: int = 256) -> torch.Tensor:
    """Compute per-class ``negatives / positives`` ratios, clamped to ``[1, 100]``.

    Iterates the whole dataset once (only the labels are used) and returns a float32
    tensor with one weight per class, suitable as ``pos_weight`` for
    ``BCEWithLogitsLoss``.

    Args:
        train_dataset: Dataset yielding ``(embedding, label)`` pairs.
        num_workers: Worker processes for the internal ``DataLoader``.
        batch_size: Batch size for the internal ``DataLoader``.
    """
    # The first sample fixes the number of classes.
    n_classes = train_dataset[0][1].shape[0]
    batches = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    pos_per_class = torch.zeros(n_classes, dtype=torch.float64)
    n_seen = 0

    print("Calculating positive class weights...")
    for _, batch_labels in tqdm(batches, desc="Computing weights"):
        n_seen += batch_labels.shape[0]
        pos_per_class += batch_labels.sum(dim=0).to(torch.float64)

    # The small constant keeps the division finite for classes with no positives;
    # such classes end up at the upper clamp.
    ratio = (n_seen - pos_per_class) / (pos_per_class + 1e-7)
    return torch.clamp(ratio, min=1.0, max=100.0).to(torch.float32)

def create_loss_function(pos_weights: Optional[torch.Tensor] = None, device: str = "cuda") -> nn.BCEWithLogitsLoss:
    """Build the multi-label loss.

    Args:
        pos_weights: Optional per-class positive weights; moved to ``device`` first.
        device: Target device for ``pos_weights`` (unused when no weights are given).

    Returns:
        ``BCEWithLogitsLoss``, weighted when ``pos_weights`` is provided.
    """
    if pos_weights is None:
        return nn.BCEWithLogitsLoss()
    return nn.BCEWithLogitsLoss(pos_weight=pos_weights.to(device))

Writing src/model.py


In [None]:
%%writefile src/threshold_optimizer.py
"""
Threshold Optimization for Multi-Label Classification.
"""
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import autocast
from sklearn.metrics import precision_recall_curve, f1_score
from tqdm import tqdm
from typing import Tuple, Optional

def collect_predictions(model, data_loader, device, use_amp=True):
    """Run ``model`` over ``data_loader`` and gather probabilities and labels.

    Args:
        model: Module returning logits for a batch of embeddings.
        data_loader: Iterable of ``(embeddings, labels)`` batches.
        device: Device string (e.g. ``"cuda"``, ``"cuda:0"``, ``"cpu"``) or
            ``torch.device`` the embeddings are moved to.
        use_amp: Enable autocast; only takes effect on CUDA devices.

    Returns:
        ``(probabilities, labels)`` as NumPy arrays stacked along the batch axis.
        Probabilities are float32 regardless of autocast.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    # autocast needs the bare device type ("cuda"/"cpu"), not "cuda:0" or a
    # torch.device object, so normalise once up front.
    device_type = torch.device(device).type
    amp_enabled = bool(use_amp) and device_type == "cuda"

    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for embeddings, labels in tqdm(data_loader, desc="Collecting predictions"):
            embeddings = embeddings.to(device)
            with autocast(device_type=device_type, enabled=amp_enabled):
                logits = model(embeddings)
            # Upcast before the sigmoid: half-precision probabilities saturate to
            # exactly 1.0 for moderately large logits, which collapses their ranking.
            probs = torch.sigmoid(logits.float()).cpu().numpy()
            all_probs.append(probs)
            all_labels.append(labels.cpu().numpy())

    if not all_probs:
        raise ValueError("data_loader yielded no batches; cannot collect predictions.")

    return np.concatenate(all_probs, axis=0), np.concatenate(all_labels, axis=0)

def optimize_thresholds(model, val_loader, device, use_amp=True) -> np.ndarray:
    """Pick, for every class, the probability threshold that maximises F1.

    Classes whose validation labels are all negative or all positive, and classes for
    which no threshold achieves a positive F1, keep the default threshold of 0.5.

    Returns:
        Array with one threshold per class.
    """
    print("Optimizing Per-Class Thresholds")
    y_probs, y_true = collect_predictions(model, val_loader, device, use_amp)
    n_classes = y_probs.shape[1]
    optimal_thresholds = np.full(n_classes, 0.5)

    for class_idx in tqdm(range(n_classes), desc="Optimizing"):
        targets = y_true[:, class_idx]
        n_positive = targets.sum()
        if n_positive == 0 or n_positive == len(targets):
            continue

        precision, recall, thresholds = precision_recall_curve(targets, y_probs[:, class_idx])
        # The curve's final point has no matching threshold, so it is dropped.
        prec, rec = precision[:-1], recall[:-1]
        with np.errstate(divide='ignore', invalid='ignore'):
            f1_scores = np.nan_to_num(2 * (prec * rec) / (prec + rec), nan=0.0)

        if len(f1_scores) == 0 or f1_scores.max() <= 0:
            continue
        optimal_thresholds[class_idx] = thresholds[np.argmax(f1_scores)]

    return optimal_thresholds

def evaluate_with_thresholds(model, data_loader, device, thresholds, use_amp=True):
    """Binarise predictions with ``thresholds`` and report micro/macro F1.

    Returns:
        Dict with keys ``'micro_f1'`` and ``'macro_f1'``.
    """
    y_probs, y_true = collect_predictions(model, data_loader, device, use_amp)
    # A prediction is positive when its probability reaches the class threshold.
    y_pred = np.greater_equal(y_probs, thresholds).astype(np.float32)

    micro_f1, macro_f1 = (
        f1_score(y_true, y_pred, average=mode, zero_division=0)
        for mode in ('micro', 'macro')
    )

    print(f"Evaluation - Micro F1: {micro_f1:.4f}, Macro F1: {macro_f1:.4f}")
    return {'micro_f1': micro_f1, 'macro_f1': macro_f1}

Writing src/threshold_optimizer.py


In [None]:
%%writefile src/trainer.py
"""
Training Utilities.
"""
import time
from pathlib import Path
from typing import Optional, Dict, Any
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: str,
    epochs: int,
    patience: int = 5,
    checkpoint_dir: str = "models",
    checkpoint_name: str = "best_model.pt",
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    use_amp: bool = True
) -> Dict[str, Any]:

    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    best_model_path = Path(checkpoint_dir) / checkpoint_name

    use_amp = use_amp and device == "cuda"
    scaler = GradScaler(enabled=use_amp)

    train_losses, val_losses = [], []
    best_val_loss = float('inf')
    best_epoch = 0
    patience_counter = 0
    stopped_early = False

    model = model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss_sum = 0.0
        batches = 0

        for embeddings, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Train", leave=False):
            embeddings, labels = embeddings.to(device), labels.to(device)
            optimizer.zero_grad()

            with autocast(device_type=device, enabled=use_amp):
                logits = model(embeddings)
                loss = loss_fn(logits, labels)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            train_loss_sum += loss.item()
            batches += 1

        avg_train_loss = train_loss_sum / batches
        train_losses.append(avg_train_loss)

        # Validation
        model.eval()
        val_loss_sum = 0.0
        val_batches = 0
        with torch.no_grad():
            for embeddings, labels in tqdm(val_loader, desc=f"Epoch {epoch+1} Val", leave=False):
                embeddings, labels = embeddings.to(device), labels.to(device)
                with autocast(device_type=device, enabled=use_amp):
                    logits = model(embeddings)
                    loss = loss_fn(logits, labels)
                val_loss_sum += loss.item()
                val_batches += 1

        avg_val_loss = val_loss_sum / val_batches
        val_losses.append(avg_val_loss)

        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        print(f"Epoch {epoch+1} - Train: {avg_train_loss:.4f}, Val: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch + 1
            patience_counter = 0
            torch.save({
                'model_state_dict': model.state_dict(),
                'val_loss': best_val_loss
            }, best_model_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping")
                stopped_early = True
                break

    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])

    return {'train_losses': train_losses, 'val_losses': val_losses, 'best_epoch': best_epoch, 'best_val_loss': best_val_loss}

def load_checkpoint(model, path, device="cpu"):
    """Load weights saved by ``train_model`` into ``model``.

    Args:
        model: Module whose architecture matches the stored state dict.
        path: Checkpoint file containing a ``'model_state_dict'`` entry.
        device: Device the stored tensors are mapped to while loading.

    Returns:
        ``(model, metadata)`` where ``metadata`` holds every checkpoint entry other
        than the state dict (e.g. ``'val_loss'``).
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    metadata = {key: value for key, value in checkpoint.items() if key != 'model_state_dict'}
    return model, metadata

Writing src/trainer.py


In [None]:
%%writefile src/extract_embeddings.py
import os
import argparse
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModel
from Bio import SeqIO
from tqdm import tqdm

# Maximum number of residues embedded per sequence; the tokenizer is given two extra
# positions (max_length = MAX_SEQUENCE_LENGTH + 2) for the CLS/EOS tokens.
MAX_SEQUENCE_LENGTH = 1022

def process_sequence(sequence, model, tokenizer, device):
    """Embed one protein sequence as the mean of its per-residue hidden states.

    Sequences longer than ``MAX_SEQUENCE_LENGTH`` are truncated. The first and last
    token positions (CLS/EOS) are excluded from the mean.

    Args:
        sequence: Amino-acid string.
        model: Callable returning an object with ``last_hidden_state``.
        tokenizer: Callable returning a mapping of input tensors.
        device: Device the inputs are moved to.

    Returns:
        1-D CPU tensor holding the mean-pooled embedding.

    Raises:
        ValueError: If there are no residue positions to pool. The mean over an empty
            slice is NaN, and a saved NaN embedding would poison training later.
    """
    if not sequence:
        raise ValueError("Cannot embed an empty sequence.")
    sequence = sequence[:MAX_SEQUENCE_LENGTH]

    inputs = tokenizer(sequence, return_tensors="pt", padding=False, truncation=True, max_length=MAX_SEQUENCE_LENGTH + 2)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling (excluding CLS/EOS)
    residue_states = outputs.last_hidden_state[:, 1:-1, :]
    if residue_states.shape[1] == 0:
        raise ValueError("Tokenizer produced no residue tokens to pool.")
    return residue_states.mean(dim=1).squeeze(0).cpu()

def _accession_from_record_id(record_id):
    """Reduce a UniProt-style id (``sp|ACC|NAME`` / ``tr|ACC|NAME``) to ``ACC``.

    Any other id is returned unchanged.
    NOTE(review): assumes the annotation and cluster files key proteins by the bare
    accession — confirm against train_terms.tsv and the cluster TSV.
    """
    fields = record_id.split("|")
    if len(fields) == 3 and fields[0] in ("sp", "tr") and fields[1]:
        return fields[1]
    return record_id


def process_fasta(fasta_path, output_dir, model_name, device):
    """Embed every record of a FASTA file into ``<output_dir>/<protein_id>.pt``.

    Records whose output file already exists are skipped, so an interrupted run can be
    resumed. Each embedding is written to a temporary file and renamed into place, so
    an interruption cannot leave a truncated ``.pt`` that a later run would skip.
    Per-record failures are reported and do not abort the run.

    Args:
        fasta_path: Input FASTA file.
        output_dir: Directory receiving one ``.pt`` file per record (created if needed).
        model_name: Model identifier passed to ``from_pretrained``.
        device: Device the model runs on.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()

    print(f"Extracting embeddings using {model_name}")

    # First pass only counts records so the progress bar has a total.
    total = sum(1 for _ in SeqIO.parse(fasta_path, "fasta"))
    n_failed = 0

    with tqdm(total=total) as pbar:
        for record in SeqIO.parse(fasta_path, "fasta"):
            protein_id = _accession_from_record_id(record.id)
            out_file = output_path / f"{protein_id}.pt"
            if not out_file.exists():
                tmp_file = output_path / f"{protein_id}.pt.tmp"
                try:
                    emb = process_sequence(str(record.seq), model, tokenizer, device)
                    torch.save(emb, tmp_file)
                    # Rename only after a complete write.
                    tmp_file.replace(out_file)
                except Exception as e:
                    # Best effort: report and continue with the next record.
                    n_failed += 1
                    print(f"Error {record.id}: {e}")
            pbar.update(1)

    if n_failed:
        print(f"Finished with {n_failed} failed record(s); they have no embedding file.")

if __name__ == "__main__":
    # Command-line entry point: embed every record of --fasta into --output.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser()
    for required_flag in ("--fasta", "--output"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--model", default="facebook/esm2_t36_3B_UR50D")
    cli.add_argument("--device", default=fallback_device)
    opts = cli.parse_args()

    process_fasta(opts.fasta, opts.output, opts.model, opts.device)

Writing src/extract_embeddings.py


In [None]:
%%writefile src/main.py
import os
import sys
import argparse
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
from torch.utils.data import DataLoader

from go_labeler import GOLabeler
from data_splits import create_splits
from dataset import ProteinGODataset
from model import ProteinMLP, calculate_pos_weights, create_loss_function
from trainer import train_model, load_checkpoint
from threshold_optimizer import optimize_thresholds, evaluate_with_thresholds

def load_annotations(path):
    """Read a tab-separated annotation file.

    Only the first two columns of each line are used (protein id, term id). Blank
    lines and lines with fewer than two columns are skipped; no header handling is
    performed.

    Returns:
        ``(pairs, by_protein)``: the list of ``(protein_id, term_id)`` tuples in file
        order, and a plain dict mapping each protein id to its list of term ids.
    """
    pairs = []
    by_protein = defaultdict(list)
    with open(path, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            if len(fields) < 2:
                continue
            protein_id, term_id = fields[0], fields[1]
            pairs.append((protein_id, term_id))
            by_protein[protein_id].append(term_id)
    return pairs, dict(by_protein)

def main(args):
    """Run the full pipeline: vocabulary, split, training, threshold tuning, export.

    Args:
        args: Parsed command-line namespace (see the argument parser in this file).

    Raises:
        RuntimeError: If the train or validation dataset ends up empty.
        ValueError: If the stored embedding size differs from ``args.input_dim``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on {device}")

    # --seed also has to cover weight initialisation and batch shuffling, not only
    # the cluster split.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    # 1. Load Data
    annotations_list, annotations_dict = load_annotations(args.annotations_path)
    labeler = GOLabeler(args.obo_path, annotations_list)
    labeler.build_label_vocabulary(min_frequency=args.min_frequency)

    # 2. Splits
    train_ids, val_ids = create_splits(args.cluster_path, args.val_ratio, args.seed)
    annotated_ids = set(annotations_dict.keys())
    # Sorted so dataset order does not depend on set iteration order, which varies
    # between interpreter runs for strings.
    train_ids = sorted(train_ids & annotated_ids)
    val_ids = sorted(val_ids & annotated_ids)

    # 3. Datasets
    train_dataset = ProteinGODataset(train_ids, labeler, args.embedding_dir, annotations_dict)
    val_dataset = ProteinGODataset(val_ids, labeler, args.embedding_dir, annotations_dict)

    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise RuntimeError(
            f"Empty dataset (train={len(train_dataset)}, val={len(val_dataset)}). "
            f"Check that the ids in {args.cluster_path} and {args.annotations_path} "
            f"match the '<id>.pt' file names in {args.embedding_dir}."
        )

    # Fail early with a readable message instead of a shape error inside the model.
    sample_embedding, _ = train_dataset[0]
    if sample_embedding.shape[-1] != args.input_dim:
        raise ValueError(
            f"Embedding size {sample_embedding.shape[-1]} does not match "
            f"--input_dim {args.input_dim}."
        )

    # BatchNorm1d cannot process a batch of one sample in training mode, so drop the
    # trailing batch in exactly that case.
    drop_last = (
        len(train_dataset) > args.batch_size
        and len(train_dataset) % args.batch_size == 1
    )
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # 4. Model Setup
    # The weight pass reads every training sample once, so give it the same loader
    # settings as training.
    pos_weights = (
        calculate_pos_weights(train_dataset, num_workers=args.num_workers, batch_size=args.batch_size)
        if args.use_pos_weights else None
    )
    model = ProteinMLP(num_classes=labeler.vocabulary_size(), input_dim=args.input_dim, hidden_dim=args.hidden_dim, dropout=args.dropout)
    loss_fn = create_loss_function(pos_weights, device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

    # 5. Train
    train_model(model, train_loader, val_loader, loss_fn, optimizer, device, args.epochs, args.patience, args.output_dir, "best_model.pt", scheduler)

    # 6. Optimization
    print("Loading best model for optimization...")
    model, _ = load_checkpoint(model, str(Path(args.output_dir)/"best_model.pt"), device)
    thresholds = optimize_thresholds(model, val_loader, device)
    evaluate_with_thresholds(model, val_loader, device, thresholds)

    # Save artifacts
    np.save(Path(args.output_dir)/"thresholds.npy", thresholds)
    np.savez(Path(args.output_dir)/"vocabulary.npz", valid_terms=np.array(labeler.valid_terms))

if __name__ == "__main__":
    cli = argparse.ArgumentParser()

    # Input/output locations (plain string options).
    for flag, default in (
        ("--obo_path", "data/raw/Train/go-basic.obo"),
        ("--annotations_path", "data/raw/Train/train_terms.tsv"),
        ("--cluster_path", "data/splits/train_cluster.tsv"),
        ("--embedding_dir", "data/embeddings/train"),
        ("--output_dir", "results/experiment_001"),
    ):
        cli.add_argument(flag, default=default)

    # Model and optimisation hyper-parameters.
    # --input_dim defaults to 1280 (changed to match esm2_t33_650M); pass the
    # embedding size explicitly when a different model produced the embeddings.
    for flag, kind, default in (
        ("--input_dim", int, 1280),
        ("--hidden_dim", int, 1024),
        ("--dropout", float, 0.3),
        ("--batch_size", int, 256),
        ("--epochs", int, 50),
        ("--patience", int, 5),
        ("--learning_rate", float, 1e-4),
        ("--weight_decay", float, 1e-5),
        ("--min_frequency", int, 50),
        ("--val_ratio", float, 0.2),
    ):
        cli.add_argument(flag, type=kind, default=default)

    cli.add_argument("--use_pos_weights", action="store_true")

    for flag, default in (("--num_workers", 2), ("--seed", 42)):
        cli.add_argument(flag, type=int, default=default)

    main(cli.parse_args())

Writing src/main.py


# Data Processing - Clustering

In [None]:
# Cluster the training sequences with MMseqs2 linclust (--min-seq-id 0.3, i.e. 30% identity).
# Output prefix data/splits/train -> data/splits/train_cluster.tsv, the file passed to
# src/main.py as --cluster_path. `tmp` is MMseqs2's scratch directory.
!mmseqs easy-linclust \
    data/raw/Train/train_sequences.fasta \
    data/splits/train \
    tmp \
    --min-seq-id 0.3 \
    --cov-mode 1 \
    -c 0.8

# Keep the raw MMseqs2 output as a .bak copy ...
!mv data/splits/train_cluster.tsv data/splits/train_cluster.tsv.bak
# ... and rewrite it as exactly two tab-separated columns (representative, member),
# the layout create_splits() parses. Note: this does not drop any rows; awk only
# re-emits the first two whitespace-separated fields of every line.
!awk '{print $1"\t"$2}' data/splits/train_cluster.tsv.bak > data/splits/train_cluster.tsv
# Quick sanity check of the result.
!head -n 5 data/splits/train_cluster.tsv

Create directory tmp
easy-linclust data/raw/Train/train_sequences.fasta data/splits/train tmp --min-seq-id 0.3 --cov-mode 1 -c 0.8 

MMseqs Version:                     	bd01c2229f027d8d8e61947f44d11ef1a7669212
Cluster mode                        	0
Max connected component depth       	1000
Similarity type                     	2
Threads                             	2
Compressed                          	0
Verbosity                           	3
Weight file name                    	
Cluster Weight threshold            	0.9
Set mode                            	false
Substitution matrix                 	aa:blosum62.out,nucl:nucleotide.out
Add backtrace                       	false
Alignment mode                      	0
Alignment mode                      	0
Allow wrapped scoring               	false
E-value threshold                   	0.001
Seq. id. threshold                  	0.3
Min alignment length                	0
Seq. id. mode                       	0
Alternative alignments         

# Data Processing - Embedding Extraction

In [None]:
# Run embedding extraction: one <record id>.pt file per FASTA record is written to
# data/embeddings/train. Records whose output file already exists are skipped, so the
# cell can be re-run to resume an interrupted extraction.
# This takes time! (Approx 3-5 hours for full CAFA train set on T4)
# NOTE(review): the training cell passes --input_dim 2560; presumably that is the
# hidden size of this model — confirm if the --model value is changed.
!python src/extract_embeddings.py \
    --fasta data/raw/Train/train_sequences.fasta \
    --output data/embeddings/train \
    --model facebook/esm2_t36_3B_UR50D

tokenizer_config.json: 100% 95.0/95.0 [00:00<00:00, 664kB/s]
vocab.txt: 100% 93.0/93.0 [00:00<00:00, 881kB/s]
special_tokens_map.json: 100% 125/125 [00:00<00:00, 1.44MB/s]
config.json: 100% 779/779 [00:00<00:00, 6.61MB/s]
2026-01-21 05:37:24.293847: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768973844.313478    1289 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768973844.319338    1289 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768973844.334255    1289 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768973844.334279    1289 co

# Train the model

In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): none of the paths passed to src/main.py in the next cell point into
# /content/drive, so embeddings and results are not written to Drive unless they are
# copied there — confirm whether that is intended.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Train the MLP, tune per-class thresholds on the validation split, and write
# best_model.pt, thresholds.npy and vocabulary.npz to results/final_run.
# --input_dim overrides the script default (1280) and must equal the size of the
# stored embeddings; --use_pos_weights enables per-class positive weighting of the loss.
!python src/main.py \
    --obo_path data/raw/Train/go-basic.obo \
    --annotations_path data/raw/Train/train_terms.tsv \
    --cluster_path data/splits/train_cluster.tsv \
    --embedding_dir data/embeddings/train \
    --output_dir results/final_run \
    --input_dim 2560 \
    --batch_size 128 \
    --epochs 20 \
    --use_pos_weights