# Multi-Head ProtBERT

This notebook trains a multi-task model on precomputed ProtBERT embeddings: a shared MLP trunk feeds separate prediction heads for the BP, MF, and CC Gene Ontology branches.

In [None]:
!pip install protobuf==4.25.3

### Import Libraries

In [None]:
import os
import gc
import random
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict, Counter
from typing import Optional
from typing import Dict, List, Set, Tuple
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split

## Configuration

In [None]:
# Paths
# Competition data is read from the Kaggle input mount; outputs go to WORK_DIR.
DATA_DIR = Path("/kaggle/input/cafa-6-protein-function-prediction")
TRAIN_DIR = DATA_DIR / "Train"
TEST_DIR = DATA_DIR / "Test"
WORK_DIR = Path("/kaggle/working")

TRAIN_FASTA = TRAIN_DIR / "train_sequences.fasta"
TRAIN_TERMS = TRAIN_DIR / "train_terms.tsv"
GO_OBO = TRAIN_DIR / "go-basic.obo"
IA_FILE = DATA_DIR / "IA.tsv"
TEST_FASTA = TEST_DIR / "testsuperset.fasta"
OUTPUT_FILE = WORK_DIR / "submission.tsv"

# Precomputed ProtBERT embeddings (.npy). Later cells consume their rows by
# position — presumably in the same order as the matching FASTA file;
# TODO confirm against the dataset that produced them.
PROTBERT_TRAIN_EMB = Path("/kaggle/input/nnn-cafa6-protbert-embedding/train_embeddings.npy")
PROTBERT_TEST_EMB = Path("/kaggle/input/nnn-cafa6-protbert-embedding/test_embeddings.npy")

# Test split
# Fraction of the annotated training proteins held out for validation
# (used as test_size for train_test_split); not the competition test set.
TEST_PROPORTION = 0.15

# Model parameters
RANDOM_SEED = 42
TOP_K_LABELS = 3000  # most frequent GO terms kept as labels, all ontologies combined
HIDDEN_DIMS = [1024, 512]  # widths of the shared MLP trunk
DROPOUT = 0.4
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
EPOCHS = 25
PATIENCE = 5  # epochs without validation-loss improvement before early stopping

# Prediction parameters
TOP_K_PER_PROTEIN = 200  # highest-scoring terms written per test protein
# The threshold search runs on the validation set for reporting only; the
# submission writes raw scores and applies no threshold.
THRESHOLD_SEARCH = True
THRESHOLD_GRID = np.arange(0.01, 0.51, 0.01)

# GO propagation
PROPAGATE_TRAIN = True  # add all GO ancestors to each training annotation set
PROPAGATE_PRED = True  # raise parent scores to at least their children's scores
PROPAGATE_ITERATIONS = 3  # maximum passes made by propagate_predictions

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Set seeds
# Seeds Python, NumPy and torch (CPU and all CUDA devices).
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

## Data Helper Functions

In [None]:
def read_fasta(path: Path) -> Dict[str, str]:
    """Parse a FASTA file into a ``{protein_id: sequence}`` mapping.

    The protein ID is the first whitespace-delimited token of a header line.
    When that token is pipe-delimited (e.g. ``sp|P12345|NAME``), its second
    field is used instead. Sequence lines are stripped and concatenated.
    """
    records: Dict[str, str] = {}
    current: Optional[str] = None
    chunks: List[str] = []

    def flush() -> None:
        # Store the record that is being accumulated, if an ID has been seen.
        if current:
            records[current] = "".join(chunks)

    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith(">"):
                chunks.append(text)
                continue

            flush()
            token = text[1:].split()[0]
            current = token.split("|")[1] if "|" in token else token
            chunks = []

    flush()
    print(f"Loaded {len(records):,} sequences from {path.name}")
    return records


def read_annotations(path: Path) -> Dict[str, List[str]]:
    """Read a protein -> GO-term annotation table.

    The file is parsed as headerless, tab-separated rows of
    ``(protein, go_term, ontology)``. The ontology column is read but unused.
    Terms are grouped per protein in file order, and proteins appear in the
    order they are first seen.

    NOTE(review): with ``header=None``, a header row (if the file has one) is
    kept as an ordinary entry — confirm the file format.

    Returns:
        A plain dict mapping each protein ID to its list of GO terms.
    """
    df = pd.read_csv(path, sep="\t", header=None,
                     names=["protein", "go_term", "ontology"])

    annotations = defaultdict(list)
    # Zipping the two columns avoids building a pandas Series per row, which
    # DataFrame.iterrows does and which dominates runtime on large tables.
    for protein, go_term in zip(df["protein"], df["go_term"]):
        annotations[protein].append(go_term)

    print(f"Loaded annotations for {len(annotations):,} proteins")
    return dict(annotations)


def read_ia_weights(path: Path) -> Dict[str, float]:
    """Read Information Accretion (IA) weights from a headerless two-column TSV.

    Values may use a decimal comma (``1,5``). Unparseable or non-finite values
    (including empty cells, which pandas reads as NaN) become 0.0.

    Returns:
        ``{go_term: ia_weight}``, or an empty dict (with a warning) when the
        file does not exist.
    """
    if not path.exists():
        print("Warning: IA weights file not found")
        return {}

    df = pd.read_csv(path, sep="\t", header=None, names=["go_term", "ia"])
    weights: Dict[str, float] = {}

    for go_term, raw_ia in zip(df["go_term"], df["ia"]):
        try:
            value = float(str(raw_ia).replace(",", "."))
        except ValueError:
            # Only a malformed number is expected here; anything else should
            # surface instead of being silently swallowed.
            value = 0.0
        # float("nan") parses without error, but a NaN weight would turn
        # every IA-weighted average downstream into NaN.
        weights[go_term] = value if np.isfinite(value) else 0.0

    print(f"Loaded IA weights for {len(weights):,} GO terms")
    return weights

## GO Ontology Parser

In [None]:
def parse_obo(path: Path):
    """Parse an OBO file into GO parent/child maps and per-term namespaces.

    Only ``[Term]`` stanzas are read. ``is_a`` and ``relationship: part_of``
    lines both become parent edges. Any other stanza type (e.g. ``[Typedef]``,
    which declares relation types such as ``id: part_of``) is skipped, so its
    lines cannot leak into the term graph.

    Returns:
        ``(parents, children, go_namespace)`` as plain dicts:
        ``parents[term]`` is the set of direct parents, ``children`` is the
        inverse mapping, and ``go_namespace[term]`` is ``"BP"``, ``"MF"`` or
        ``"CC"``. All three are empty (with a warning) when the file is missing.
    """
    parents = defaultdict(set)
    children = defaultdict(set)
    go_namespace = {}

    if not path.exists():
        print("Warning: OBO file not found")
        return dict(parents), dict(children), go_namespace

    namespace_codes = {
        "biological_process": "BP",
        "molecular_function": "MF",
        "cellular_component": "CC",
    }

    with open(path) as f:
        current_id = None
        in_term = False

        for line in f:
            line = line.strip()

            if line.startswith("[") and line.endswith("]"):
                # Every stanza header ends the previous stanza; only [Term]
                # stanzas describe GO terms.
                in_term = line == "[Term]"
                current_id = None
                continue
            if not in_term:
                # File header or a non-Term stanza.
                continue

            if line.startswith("id: "):
                current_id = line.split("id: ")[1]
                continue
            if not current_id:
                continue

            parent_id = None
            if line.startswith("namespace:"):
                code = namespace_codes.get(line.split("namespace:")[1].strip())
                if code is not None:
                    go_namespace[current_id] = code
            elif line.startswith("is_a: "):
                parent_id = line.split()[1]
            elif line.startswith("relationship: part_of "):
                parts = line.split()
                if len(parts) >= 3:
                    parent_id = parts[2]

            if parent_id is not None:
                parents[current_id].add(parent_id)
                children[parent_id].add(current_id)

    print(f"Parsed GO graph: {len(parents):,} terms with parents")
    print(f"Parsed namespaces for {len(go_namespace):,} GO terms")
    return dict(parents), dict(children), go_namespace


def get_ancestors(go_term: str, parents: Dict) -> Set[str]:
    """Return every term reachable from ``go_term`` by following parent links.

    The starting term itself is included only if it lies on a cycle.
    ``parents`` maps a term to an iterable of its direct parents; unknown
    terms have no parents.
    """
    seen: Set[str] = set()
    frontier = list(parents.get(go_term, []))

    while frontier:
        node = frontier.pop()
        if node in seen:
            continue
        seen.add(node)
        frontier.extend(parents.get(node, []))

    return seen


def propagate_labels(annotations: Dict[str, List[str]], 
                    parents: Dict) -> Dict[str, List[str]]:
    """Add every GO ancestor to each protein's annotation set.

    Args:
        annotations: protein -> list of annotated GO terms.
        parents: term -> iterable of direct parents (as from parse_obo).

    Returns:
        protein -> sorted list of the annotated terms plus all their ancestors.
    """
    print("Propagating labels up GO hierarchy...")
    propagated = {}
    # The same term is annotated on many proteins; walk the DAG once per term.
    ancestor_cache: Dict[str, Set[str]] = {}

    for protein, terms in tqdm(annotations.items(), desc="Propagating"):
        expanded = set(terms)
        for term in terms:
            ancestors = ancestor_cache.get(term)
            if ancestors is None:
                ancestors = ancestor_cache[term] = get_ancestors(term, parents)
            # In-place union on `expanded` leaves the cached set untouched.
            expanded |= ancestors
        propagated[protein] = sorted(expanded)

    original_count = sum(len(v) for v in annotations.values())
    new_count = sum(len(v) for v in propagated.values())
    print(f"  {original_count:,} -> {new_count:,} annotations")

    return propagated

## PyTorch Dataset

In [None]:
class ProteinDataset(Dataset):
    """Map-style dataset pairing each embedding row with its BP, MF and CC label rows.

    All four inputs are converted to float32 tensors once, at construction.
    They are expected to share their first dimension (not checked here).
    """

    def __init__(self, embeddings, y_bp, y_mf, y_cc):
        to_tensor = torch.FloatTensor
        self.embeddings = to_tensor(embeddings)
        self.y_bp, self.y_mf, self.y_cc = (to_tensor(y) for y in (y_bp, y_mf, y_cc))

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        fields = (self.embeddings, self.y_bp, self.y_mf, self.y_cc)
        return tuple(field[idx] for field in fields)

## Model Architecture

In [None]:
class MultiTaskProteinFunctionPredictor(nn.Module):
    """Multi-task MLP for BP / MF / CC prediction.

    A shared trunk of ``Linear -> BatchNorm1d -> ReLU -> Dropout`` blocks feeds
    three independent linear heads. Each head emits one logit per GO term of
    its ontology.
    """

    def __init__(
        self,
        input_dim: int,
        bp_dim: int,
        mf_dim: int,
        cc_dim: int,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.3,
    ):
        """Build the trunk and the three heads.

        Args:
            input_dim: width of each input embedding vector.
            bp_dim, mf_dim, cc_dim: number of output logits of each head.
            hidden_dims: trunk layer widths; ``None`` means ``[512, 256]``.
                An empty list makes the trunk an identity mapping.
            dropout: dropout probability after each trunk layer.
        """
        super().__init__()

        if hidden_dims is None:
            # A list literal as the default would be one object shared by
            # every instantiation.
            hidden_dims = [512, 256]

        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers += [
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            prev_dim = hidden_dim

        self.trunk = nn.Sequential(*layers)

        self.bp_head = nn.Linear(prev_dim, bp_dim)
        self.mf_head = nn.Linear(prev_dim, mf_dim)
        self.cc_head = nn.Linear(prev_dim, cc_dim)

    def forward(self, x):
        """Return ``(bp_logits, mf_logits, cc_logits)`` for a batch ``x`` of shape (batch, input_dim)."""
        h = self.trunk(x)
        return self.bp_head(h), self.mf_head(h), self.cc_head(h)

## Training and Evaluation Functions

In [None]:
def train_epoch(
    model,
    dataloader,
    criterion_bp,
    criterion_mf,
    criterion_cc,
    optimizer,
    device,
    w_bp=1.0,
    w_mf=1.0,
    w_cc=1.0,
):
    """Run one optimisation pass over ``dataloader``.

    Each batch is ``(embeddings, y_bp, y_mf, y_cc)``. The optimised loss is
    ``w_bp * BP + w_mf * MF + w_cc * CC`` over the three head criteria.

    Returns:
        The mean of the per-batch weighted losses.
    """
    model.train()
    loss_sum = 0.0

    for embeddings, y_bp, y_mf, y_cc in dataloader:
        inputs = embeddings.to(device)
        t_bp, t_mf, t_cc = (y.to(device) for y in (y_bp, y_mf, y_cc))

        optimizer.zero_grad()
        out_bp, out_mf, out_cc = model(inputs)
        batch_loss = (
            w_bp * criterion_bp(out_bp, t_bp)
            + w_mf * criterion_mf(out_mf, t_mf)
            + w_cc * criterion_cc(out_cc, t_cc)
        )
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    return loss_sum / len(dataloader)


def evaluate(
    model,
    dataloader,
    criterion_bp,
    criterion_mf,
    criterion_cc,
    device,
    w_bp=1.0,
    w_mf=1.0,
    w_cc=1.0,
):
    """Score ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        A 7-tuple ``(avg_loss, preds_bp, preds_mf, preds_cc, labels_bp,
        labels_mf, labels_cc)``. ``avg_loss`` is the mean per-batch weighted
        loss. The remaining entries are NumPy arrays stacked over all batches:
        sigmoid probabilities per head, then the matching targets.
    """
    model.eval()
    criteria = (criterion_bp, criterion_mf, criterion_cc)
    loss_sum = 0.0
    prob_batches = [[], [], []]
    label_batches = [[], [], []]

    with torch.no_grad():
        for embeddings, y_bp, y_mf, y_cc in dataloader:
            inputs = embeddings.to(device)
            targets = [y.to(device) for y in (y_bp, y_mf, y_cc)]

            logits = model(inputs)
            head_losses = [crit(out, tgt) for crit, out, tgt in zip(criteria, logits, targets)]
            weighted = w_bp * head_losses[0] + w_mf * head_losses[1] + w_cc * head_losses[2]
            loss_sum += weighted.item()

            for slot, (out, tgt) in enumerate(zip(logits, targets)):
                prob_batches[slot].append(torch.sigmoid(out).cpu().numpy())
                label_batches[slot].append(tgt.cpu().numpy())

    avg_loss = loss_sum / len(dataloader)

    preds_bp, preds_mf, preds_cc = (np.vstack(chunks) for chunks in prob_batches)
    labels_bp, labels_mf, labels_cc = (np.vstack(chunks) for chunks in label_batches)

    return avg_loss, preds_bp, preds_mf, preds_cc, labels_bp, labels_mf, labels_cc


def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    weights: Optional[np.ndarray] = None,
) -> Dict:
    """Average per-term precision, recall and F1, optionally IA-weighted.

    ``y_true`` and ``y_pred`` hold 0/1 values with one column per GO term.
    Counts are taken down axis 0, so each column gets its own precision,
    recall and F1; a term with no predicted positives scores precision 0.
    The per-term values are then averaged with ``weights`` (one per column),
    or with equal weights when ``weights`` is None.

    Returns:
        ``{"precision": ..., "recall": ..., "f1": ...}``
    """
    eps = 1e-12
    actual_pos = y_true == 1
    actual_neg = y_true == 0
    called_pos = y_pred == 1
    called_neg = y_pred == 0

    tp = (actual_pos & called_pos).sum(axis=0).astype(float)
    fp = (actual_neg & called_pos).sum(axis=0).astype(float)
    fn = (actual_pos & called_neg).sum(axis=0).astype(float)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    w = np.ones_like(precision) if weights is None else weights
    total_w = w.sum() + eps
    per_term = {"precision": precision, "recall": recall, "f1": f1}
    return {name: (values * w).sum() / total_w for name, values in per_term.items()}


def find_best_threshold(y_true: np.ndarray, y_pred_prob: np.ndarray,
                       weights: Optional[np.ndarray], thresholds: np.ndarray) -> Tuple[float, float]:
    """Grid-search the decision threshold that maximises averaged F1.

    Args:
        y_true: binary label matrix.
        y_pred_prob: predicted probabilities, same shape as ``y_true``.
        weights: per-term weights forwarded to ``compute_metrics``; ``None``
            gives an unweighted average (callers do pass None).
        thresholds: candidate thresholds; a prediction counts as positive
            when ``prob >= threshold``.

    Returns:
        ``(best_threshold, best_f1)``. On ties the earliest threshold in
        ``thresholds`` wins. With an empty grid the result is ``(0.5, -1.0)``.
    """
    print("Searching for optimal threshold...")

    best_threshold = 0.5
    best_f1 = -1.0

    for thresh in tqdm(thresholds, desc="Threshold search"):
        y_pred = (y_pred_prob >= thresh).astype(int)
        metrics = compute_metrics(y_true, y_pred, weights)

        # Strict ">" keeps the first (lowest) threshold among equal scores.
        if metrics["f1"] > best_f1:
            best_f1 = metrics["f1"]
            best_threshold = thresh

    print(f"  Best threshold: {best_threshold:.3f} (F1: {best_f1:.4f})")
    return best_threshold, best_f1

## Prediction Propagation

In [None]:
def _children_first_order(restricted_parents: Dict[str, Set[str]]) -> List[str]:
    """Order terms so every term precedes all of its (in-vocabulary) parents."""
    pending_children = dict.fromkeys(restricted_parents, 0)
    for parent_set in restricted_parents.values():
        for parent in parent_set:
            pending_children[parent] += 1

    ready = [t for t in restricted_parents if pending_children[t] == 0]
    order: List[str] = []
    while ready:
        term = ready.pop()
        order.append(term)
        for parent in restricted_parents[term]:
            pending_children[parent] -= 1
            if pending_children[parent] == 0:
                ready.append(parent)

    if len(order) < len(restricted_parents):
        # A cycle (not expected in GO) leaves terms unscheduled; visit them last.
        placed = set(order)
        order.extend(t for t in restricted_parents if t not in placed)
    return order


def propagate_predictions(predictions: np.ndarray, 
                         parents: Dict,
                         go_terms: List[str],
                         iterations: int = 3) -> np.ndarray:
    """Raise parent-term scores to at least the scores of their descendants.

    Only edges whose two ends are both in ``go_terms`` are used. Terms are
    visited children-before-parents, so a single pass reaches the fixed point
    on a DAG however deep it is. Further passes only confirm convergence.

    Args:
        predictions: (n_proteins, n_terms) scores; column j belongs to
            ``go_terms[j]``. Not modified.
        parents: term -> iterable of direct parent terms.
        go_terms: column labels of ``predictions``.
        iterations: maximum number of passes (0 returns an unmodified copy).

    Returns:
        A propagated copy of ``predictions``.
    """
    print("Propagating predictions...")

    term_to_idx = {t: i for i, t in enumerate(go_terms)}
    pred_copy = predictions.copy()

    # Restrict parents to terms in vocabulary
    restricted_parents = {
        term: {p for p in parents.get(term, []) if p in term_to_idx}
        for term in go_terms
    }
    # Visiting in arbitrary (e.g. sorted-ID) order would need up to depth-many
    # passes to push a leaf score to the root.
    order = _children_first_order(restricted_parents)

    for iteration in range(iterations):
        changed = False

        for child_term in order:
            child_scores = pred_copy[:, term_to_idx[child_term]]

            for parent_term in restricted_parents[child_term]:
                parent_idx = term_to_idx[parent_term]

                # Update parent where child score is higher
                mask = child_scores > pred_copy[:, parent_idx]
                if mask.any():
                    pred_copy[mask, parent_idx] = child_scores[mask]
                    changed = True

        if not changed:
            print(f"  Converged after {iteration + 1} iterations")
            break

    return pred_copy

## Main Pipeline

### Load Data

In [None]:
# 1. Load Data
# Sequences keyed by protein ID, raw (not yet propagated) annotations, and IA weights.
print("Loading data...")
train_seqs = read_fasta(TRAIN_FASTA)
test_seqs = read_fasta(TEST_FASTA)
annotations = read_annotations(TRAIN_TERMS)
ia_weights_dict = read_ia_weights(IA_FILE)

print("\n Loading GO ontology...")
# is_a / part_of parent edges, their inverse, and each term's BP/MF/CC namespace.
# children_map is not used by the later cells in this notebook.
parents_map, children_map, go_namespace = parse_obo(GO_OBO)

In [None]:
# 2. Propagate Training Labels
# Replace each protein's term list with the sorted union of its terms and all
# of their GO ancestors, so training targets respect the hierarchy.
if PROPAGATE_TRAIN:
    annotations = propagate_labels(annotations, parents_map)

### Prepare Labels and Split Data

In [None]:
print("\nPreparing labels...")

# Keep only annotated proteins that also have a training sequence. This list
# fixes the row order of every label matrix built below.
train_proteins = [p for p in annotations.keys() if p in train_seqs]
print(f"  {len(train_proteins):,} training proteins")

# Count term frequencies
term_counts = Counter()
for protein in train_proteins:
    term_counts.update(annotations[protein])

# Select top-K terms
# Frequency ranking spans all three ontologies combined, so the BP/MF/CC
# vocabulary sizes below are not fixed in advance.
top_terms = [t for t, _ in term_counts.most_common(TOP_K_LABELS)]
chosen_terms = set(top_terms)
print(f"  Using top {len(chosen_terms):,} GO terms (all ontologies combined)")

# Filter annotations to keep only top-K terms
# (mutates `annotations` in place for the training proteins)
for protein in train_proteins:
    annotations[protein] = [t for t in annotations[protein] if t in chosen_terms]

# Split terms by ontology using go_namespace from parse_obo
# Terms without a namespace entry (e.g. IDs absent from the OBO file) fall
# into none of the lists and are dropped silently; compare the "Total" line
# below with the top-K count to see how many were lost.
bp_terms = [t for t in top_terms if go_namespace.get(t) == "BP"]
mf_terms = [t for t in top_terms if go_namespace.get(t) == "MF"]
cc_terms = [t for t in top_terms if go_namespace.get(t) == "CC"]

bp_term_set = set(bp_terms)
mf_term_set = set(mf_terms)
cc_term_set = set(cc_terms)

print(f"  BP terms: {len(bp_terms)}")
print(f"  MF terms: {len(mf_terms)}")
print(f"  CC terms: {len(cc_terms)}")
print(f"  Total (BP+MF+CC): {len(bp_terms) + len(mf_terms) + len(cc_terms)}")

# Create label lists for each ontology
labels_bp = [[t for t in annotations[p] if t in bp_term_set] for p in train_proteins]
labels_mf = [[t for t in annotations[p] if t in mf_term_set] for p in train_proteins]
labels_cc = [[t for t in annotations[p] if t in cc_term_set] for p in train_proteins]

# Explicit sorted class lists give each ontology a stable column order.
mlb_bp = MultiLabelBinarizer(classes=sorted(bp_term_set))
mlb_mf = MultiLabelBinarizer(classes=sorted(mf_term_set))
mlb_cc = MultiLabelBinarizer(classes=sorted(cc_term_set))

y_bp = mlb_bp.fit_transform(labels_bp).astype(np.float32)
y_mf = mlb_mf.fit_transform(labels_mf).astype(np.float32)
y_cc = mlb_cc.fit_transform(labels_cc).astype(np.float32)

print(f"  y_bp shape: {y_bp.shape}")
print(f"  y_mf shape: {y_mf.shape}")
print(f"  y_cc shape: {y_cc.shape}")

# IA weights for each ontology
# Aligned with each binarizer's column order; terms missing from the IA file
# default to weight 1.0.
ia_weights_bp = np.array([ia_weights_dict.get(t, 1.0) for t in mlb_bp.classes_])
ia_weights_mf = np.array([ia_weights_dict.get(t, 1.0) for t in mlb_mf.classes_])
ia_weights_cc = np.array([ia_weights_dict.get(t, 1.0) for t in mlb_cc.classes_])

In [None]:
# Balance loss weights for 3 ontologies
# Each head's loss weight is proportional to 1/sqrt(number of terms in that
# head), normalised so that the three weights sum to 1.
# NOTE(review): an ontology with zero selected terms makes math.sqrt(0) a
# zero divisor here.
n_bp = y_bp.shape[1]
n_mf = y_mf.shape[1]
n_cc = y_cc.shape[1]

import math
inv_sqrt = np.array([
    1.0 / math.sqrt(n_bp),
    1.0 / math.sqrt(n_mf),
    1.0 / math.sqrt(n_cc),
])
inv_sqrt = inv_sqrt / inv_sqrt.sum()
W_BP, W_MF, W_CC = inv_sqrt.tolist()

print("Loss weights:",
      f"W_BP={W_BP:.3f}, W_MF={W_MF:.3f}, W_CC={W_CC:.3f}")

### Load Embeddings and Build Model

In [None]:
print("\nLoading ProtBERT embeddings...")

train_emb = np.load(PROTBERT_TRAIN_EMB).astype(np.float32)
print(f"  Train embeddings shape: {train_emb.shape}")

# train_proteins follows annotation-file order, so taking the first N
# embedding rows pairs proteins with the wrong vectors. Gather rows by each
# protein's position in TRAIN_FASTA instead.
# NOTE(review): assumes the embedding rows are in TRAIN_FASTA order — the
# same assumption the submission cell makes for test rows vs. TEST_FASTA.
if train_emb.shape[0] == len(train_seqs):
    fasta_row = {pid: row for row, pid in enumerate(train_seqs)}
    X = train_emb[[fasta_row[p] for p in train_proteins]]
else:
    print("  Warning: embedding row count does not match TRAIN_FASTA; "
          "falling back to positional alignment")
    X = train_emb[:len(train_proteins)]

In [None]:
print("\nSplitting data (multi-task)...")
# With train_test_split's default shuffling, one call splits the embeddings
# and all three label matrices using the same row assignment, keeping them aligned.
X_train, X_val, y_bp_train, y_bp_val, y_mf_train, y_mf_val, y_cc_train, y_cc_val = train_test_split(
    X,
    y_bp,
    y_mf,
    y_cc,
    test_size=TEST_PROPORTION,
    random_state=RANDOM_SEED,
)

print(f"  Train: {X_train.shape}")
print(f"  Val:   {X_val.shape}")

train_dataset = ProteinDataset(X_train, y_bp_train, y_mf_train, y_cc_train)
val_dataset = ProteinDataset(X_val, y_bp_val, y_mf_val, y_cc_val)

# Training batches are reshuffled every epoch; validation order stays fixed
# and uses a larger batch size.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=2, pin_memory=True
)

In [None]:
print("\nBuilding multi-task model...")

model = MultiTaskProteinFunctionPredictor(
    input_dim=X.shape[1],
    bp_dim=y_bp.shape[1],
    mf_dim=y_mf.shape[1],
    cc_dim=y_cc.shape[1],
    hidden_dims=HIDDEN_DIMS,
    dropout=DROPOUT,
).to(DEVICE)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# One independent multi-label BCE loss per ontology head.
criterion_bp = nn.BCEWithLogitsLoss()
criterion_mf = nn.BCEWithLogitsLoss()
criterion_cc = nn.BCEWithLogitsLoss()

optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the LR once validation loss has not improved for more than 2 epochs.
# The scheduler's `verbose` flag is deprecated since PyTorch 2.2 and slated
# for removal; read optimizer.param_groups[0]["lr"] to log the current LR.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

### Train Model (with Early Stopping)

In [None]:
print("\nTraining...")
# Early stopping and checkpointing both track the weighted validation loss.
# The lowest-loss weights are saved to best_model_path and reloaded after the loop.
best_val_loss = float('inf')
patience_counter = 0
best_model_path = WORK_DIR / "best_model.pt"

history = {
    "train_loss": [],
    "val_loss": [],
    "val_f1_bp": [],
    "val_f1_mf": [],
    "val_f1_cc": [],
}

for epoch in range(EPOCHS):
    train_loss = train_epoch(
        model,
        train_loader,
        criterion_bp,
        criterion_mf,
        criterion_cc,
        optimizer,
        DEVICE,
        w_bp=W_BP,
        w_mf=W_MF,
        w_cc=W_CC,
    )

    (
        val_loss,
        val_bp_preds,
        val_mf_preds,
        val_cc_preds,
        val_bp_labels,
        val_mf_labels,
        val_cc_labels,
    ) = evaluate(
        model,
        val_loader,
        criterion_bp,
        criterion_mf,
        criterion_cc,
        DEVICE,
        w_bp=W_BP,
        w_mf=W_MF,
        w_cc=W_CC,
    )

    # Per-epoch metrics use a fixed 0.5 cut-off. They are logged only and
    # do not influence checkpointing or early stopping.
    val_bp_bin = (val_bp_preds >= 0.5).astype(int)
    val_mf_bin = (val_mf_preds >= 0.5).astype(int)
    val_cc_bin = (val_cc_preds >= 0.5).astype(int)

    bp_unw = compute_metrics(val_bp_labels, val_bp_bin, weights=None)
    mf_unw = compute_metrics(val_mf_labels, val_mf_bin, weights=None)
    cc_unw = compute_metrics(val_cc_labels, val_cc_bin, weights=None)

    bp_ia = compute_metrics(val_bp_labels, val_bp_bin, ia_weights_bp)
    mf_ia = compute_metrics(val_mf_labels, val_mf_bin, ia_weights_mf)
    cc_ia = compute_metrics(val_cc_labels, val_cc_bin, ia_weights_cc)

    # The LR schedule is driven by validation loss.
    scheduler.step(val_loss)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_f1_bp"].append(bp_unw["f1"])
    history["val_f1_mf"].append(mf_unw["f1"])
    history["val_f1_cc"].append(cc_unw["f1"])

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")

    print(
        "  [Unweighted] "
        f"BP P/R/F1: {bp_unw['precision']:.4f}/{bp_unw['recall']:.4f}/{bp_unw['f1']:.4f} | "
        f"MF P/R/F1: {mf_unw['precision']:.4f}/{mf_unw['recall']:.4f}/{mf_unw['f1']:.4f} | "
        f"CC P/R/F1: {cc_unw['precision']:.4f}/{cc_unw['recall']:.4f}/{cc_unw['f1']:.4f}"
    )

    print(
        "  [IA-weighted F1] "
        f"BP: {bp_ia['f1']:.6f}  MF: {mf_ia['f1']:.6f}  CC: {cc_ia['f1']:.6f}"
    )

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        print("  Saved best model")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n  Early stopping triggered after {epoch+1} epochs")
            break

    print()

print("Loading best model...")
# NOTE(review): if val_loss never drops below inf (e.g. it is NaN every
# epoch), no checkpoint is written this run. This load then fails, or it
# silently picks up a stale file left by an earlier run.
model.load_state_dict(torch.load(best_model_path))

In [None]:
print("\nEvaluating on validation set (multi-task)...")
(
    _,
    val_bp_preds,
    val_mf_preds,
    val_cc_preds,
    val_bp_labels,
    val_mf_labels,
    val_cc_labels,
) = evaluate(
    model,
    val_loader,
    criterion_bp,
    criterion_mf,
    criterion_cc,
    DEVICE,
    w_bp=W_BP,
    w_mf=W_MF,
    w_cc=W_CC,
)

# Per-ontology threshold search (IA-weighted and unweighted) on the validation
# set. Results are reported and kept in `best_thresholds`. The submission
# cell writes raw top-K scores and does not apply them.
best_thresholds = {}
if THRESHOLD_SEARCH:
    for ont, ont_labels, ont_preds, ont_ia in (
        ("BP", val_bp_labels, val_bp_preds, ia_weights_bp),
        ("MF", val_mf_labels, val_mf_preds, ia_weights_mf),
        ("CC", val_cc_labels, val_cc_preds, ia_weights_cc),
    ):
        thr_ia, f1_ia = find_best_threshold(ont_labels, ont_preds, ont_ia, THRESHOLD_GRID)
        thr_unw, f1_unw = find_best_threshold(ont_labels, ont_preds, None, THRESHOLD_GRID)
        best_thresholds[ont] = {"ia": (thr_ia, f1_ia), "unweighted": (thr_unw, f1_unw)}
        print(
            f"{ont} threshold IA: {thr_ia:.2f}, F1_IA: {f1_ia:.6f} | "
            f"{ont} threshold unweighted: {thr_unw:.2f}, F1_unw: {f1_unw:.4f}"
        )

    # BP results under the names this cell has always defined.
    best_thr_bp_ia, best_f1_bp_ia = best_thresholds["BP"]["ia"]
    best_thr_bp_unw, best_f1_bp_unw = best_thresholds["BP"]["unweighted"]

### Test Predictions

In [None]:
print("\nGenerating test predictions...")

test_emb = np.load(PROTBERT_TEST_EMB).astype(np.float32)
print(f"  Test embeddings shape: {test_emb.shape}")

# Inference needs no targets. Going through ProteinDataset would allocate
# dense all-zero float64 label matrices (n_test x n_terms, several GB for a
# large test superset) only to discard them.
from torch.utils.data import TensorDataset

test_dataset = TensorDataset(torch.from_numpy(test_emb))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, 
                        shuffle=False, num_workers=2, pin_memory=True)

model.eval()
all_bp_preds, all_mf_preds, all_cc_preds = [], [], []

with torch.no_grad():
    for (embeddings,) in tqdm(test_loader, desc="Predicting"):
        embeddings = embeddings.to(DEVICE)
        logits_bp, logits_mf, logits_cc = model(embeddings)

        prob_bp = torch.sigmoid(logits_bp)
        prob_mf = torch.sigmoid(logits_mf)
        prob_cc = torch.sigmoid(logits_cc)

        all_bp_preds.append(prob_bp.cpu().numpy())
        all_mf_preds.append(prob_mf.cpu().numpy())
        all_cc_preds.append(prob_cc.cpu().numpy())

test_bp_preds = np.vstack(all_bp_preds)
test_mf_preds = np.vstack(all_mf_preds)
test_cc_preds = np.vstack(all_cc_preds)

print(f"  BP preds shape: {test_bp_preds.shape}")
print(f"  MF preds shape: {test_mf_preds.shape}")
print(f"  CC preds shape: {test_cc_preds.shape}")

In [None]:
# Propagation runs separately per ontology, so only parent edges whose two
# ends both lie in the same head's vocabulary are used.
if PROPAGATE_PRED:
    test_bp_preds = propagate_predictions(
        test_bp_preds, parents_map, list(mlb_bp.classes_), PROPAGATE_ITERATIONS
    )
    test_mf_preds = propagate_predictions(
        test_mf_preds, parents_map, list(mlb_mf.classes_), PROPAGATE_ITERATIONS
    )
    test_cc_preds = propagate_predictions(
        test_cc_preds, parents_map, list(mlb_cc.classes_), PROPAGATE_ITERATIONS
    )

# Concatenate predictions in order: BP -> MF -> CC
# all_terms holds the column labels of all_test_preds in that same order.
all_test_preds = np.concatenate([test_bp_preds, test_mf_preds, test_cc_preds], axis=1)
all_terms = np.concatenate([mlb_bp.classes_, mlb_mf.classes_, mlb_cc.classes_])

### Write Submission File

In [None]:
print(f"\nWriting submission to {OUTPUT_FILE}...")

# NOTE(review): prediction rows are paired with TEST_FASTA IDs by position,
# which assumes the test embeddings were written in FASTA order.
n_pred_rows = all_test_preds.shape[0]
if n_pred_rows != len(test_seqs):
    print(f"  Warning: {n_pred_rows:,} prediction rows vs {len(test_seqs):,} test sequences; "
          f"only the first {min(n_pred_rows, len(test_seqs)):,} proteins are written")
test_ids = list(test_seqs.keys())[:n_pred_rows]

with open(OUTPUT_FILE, "w") as f:
    for i, protein_id in enumerate(tqdm(test_ids, desc="Writing")):
        probs = all_test_preds[i]
        # Highest scores first. Slicing after the reversal keeps
        # TOP_K_PER_PROTEIN == 0 from selecting every term (probs[-0:]).
        top_indices = np.argsort(probs)[::-1][:TOP_K_PER_PROTEIN]

        rows = []
        for idx in top_indices:
            # Filter on the value actually written: a score that formats as
            # 0.000 is not a positive confidence and only bloats the file.
            score_text = f"{float(probs[idx]):.3f}"
            if float(score_text) > 0:
                rows.append(f"{protein_id}\t{all_terms[idx]}\t{score_text}\n")
        f.writelines(rows)

print("\n" + "="*70)
print("Pipeline completed successfully")
print("="*70)

del model, train_loader, val_loader, test_loader
torch.cuda.empty_cache()
gc.collect()