# Multi-Modal Taxonomy and Text-Enhanced ProtBERT

This notebook implements a knowledge-enhanced protein function prediction model that combines:
- ProtBERT embeddings for protein sequence representation
- Taxonomy embeddings for species-specific information
- GO term text embeddings (using BiomedBERT) for semantic matching
- GO hierarchy propagation for consistent predictions

In [None]:
# Notebook shell magic: install a pinned protobuf release in the first cell.
# NOTE(review): the reason for pinning 4.25.3 is not stated in this notebook — confirm before changing it.
!pip install protobuf==4.25.3

## Environment Setup

In [None]:
import os
import gc
import random
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict, Counter
from typing import Optional
from typing import Dict, List, Set, Tuple
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split

### Import Libraries

## Configuration

In [None]:
# ---------------------------------------------------------------------------
# Filesystem layout: competition inputs and the writable working directory
# ---------------------------------------------------------------------------
DATA_DIR = Path("/kaggle/input/cafa-6-protein-function-prediction")
WORK_DIR = Path("/kaggle/working")
TRAIN_DIR = DATA_DIR / "Train"
TEST_DIR = DATA_DIR / "Test"

TRAIN_FASTA = TRAIN_DIR / "train_sequences.fasta"
TRAIN_TERMS = TRAIN_DIR / "train_terms.tsv"
GO_OBO = TRAIN_DIR / "go-basic.obo"
TEST_FASTA = TEST_DIR / "testsuperset.fasta"
IA_FILE = DATA_DIR / "IA.tsv"
OUTPUT_FILE = WORK_DIR / "submission.tsv"

# Precomputed ProtBERT embedding matrices (one .npy per split)
PROTBERT_TRAIN_EMB = Path("/kaggle/input/nnn-cafa6-protbert-embedding/train_embeddings.npy")
PROTBERT_TEST_EMB = Path("/kaggle/input/nnn-cafa6-protbert-embedding/test_embeddings.npy")

# ---------------------------------------------------------------------------
# Feature switches
# ---------------------------------------------------------------------------
# Taxonomy input. TAXONOMY_EMBEDDING_DIM is not referenced elsewhere in the
# visible notebook (the model is built with the ProtBERT width instead).
TRAIN_TAXONOMY = Path("/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv")
USE_TAXONOMY = True
TAXONOMY_EMBEDDING_DIM = 32

# GO term text encoder. GO_TEXT_MAX_LENGTH is not referenced elsewhere in the
# visible notebook.
USE_GO_TEXT = True
GO_TEXT_MODEL = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract"
GO_TEXT_BATCH_SIZE = 32
GO_TEXT_MAX_LENGTH = 128

# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------
RANDOM_SEED = 42
TEST_PROPORTION = 0.15   # fraction of training proteins held out for validation
TOP_K_LABELS = 3000      # number of most frequent GO terms kept as targets
HIDDEN_DIMS = [1024, 512]
DROPOUT = 0.3
LEARNING_RATE = 1e-3
BATCH_SIZE = 32
EPOCHS = 25
PATIENCE = 5             # early-stopping patience, in epochs

# ---------------------------------------------------------------------------
# Inference / post-processing
# ---------------------------------------------------------------------------
TOP_K_PER_PROTEIN = 200
THRESHOLD_SEARCH = True
THRESHOLD_GRID = np.arange(0.01, 0.51, 0.01)

PROPAGATE_TRAIN = True
PROPAGATE_PRED = True
PROPAGATE_ITERATIONS = 3

# ---------------------------------------------------------------------------
# Runtime: device selection and reproducibility
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Seed every RNG the notebook draws from, in the same order as before.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)
del _seed_fn

## Data Helper Functions

In [None]:
def read_fasta(path: Path) -> Dict[str, str]:
    """Parse a FASTA file into a ``{protein_id: sequence}`` mapping.

    The identifier is the first whitespace-delimited token of the header
    line; when that token contains ``|`` the second ``|``-separated field
    is used instead. Sequence lines are stripped and concatenated. A
    record whose identifier is empty is dropped, and a repeated identifier
    keeps the last sequence seen.
    """
    records: Dict[str, str] = {}
    current_id = None
    chunks: List[str] = []

    with open(path) as handle:
        for raw in handle:
            text = raw.strip()

            if not text.startswith(">"):
                chunks.append(text)
                continue

            # A new header closes the record collected so far.
            if current_id:
                records[current_id] = "".join(chunks)

            token = text[1:].split()[0]
            fields = token.split("|")
            current_id = fields[1] if len(fields) > 1 else token
            chunks = []

    # The final record has no following header to flush it.
    if current_id:
        records[current_id] = "".join(chunks)

    print(f"Loaded {len(records):,} sequences from {path.name}")
    return records


def read_annotations(path: Path) -> Dict[str, List[str]]:
    """Read protein -> GO term annotations from a tab-separated file.

    The file is read as three positional columns (protein, GO term,
    ontology); only the first two are used. Terms are kept in file order
    and duplicates are not removed.

    NOTE(review): ``header=None`` means a header line, if the file has
    one, is read as an ordinary data row — confirm the input has none (or
    that the bogus key is filtered out downstream).

    Args:
        path: Location of the annotation TSV.

    Returns:
        Plain dict mapping each protein ID to its list of GO term IDs.
    """
    df = pd.read_csv(path, sep="\t", header=None,
                     names=["protein", "go_term", "ontology"])

    annotations = defaultdict(list)
    # Zip the two columns instead of DataFrame.iterrows(): iterrows builds
    # a Series per row, which is very slow on large annotation tables.
    for protein, go_term in zip(df["protein"], df["go_term"]):
        annotations[protein].append(go_term)

    print(f"Loaded annotations for {len(annotations):,} proteins")
    return dict(annotations)


def read_ia_weights(path: Path) -> Dict[str, float]:
    """Read Information Accretion (IA) weights per GO term.

    The file is read as two tab-separated columns (GO term, weight).
    Weights written with a decimal comma are accepted; a value that
    cannot be parsed as a number is stored as ``0.0``.

    Args:
        path: Location of the IA TSV.

    Returns:
        Dict mapping GO term ID to weight, or an empty dict (after
        printing a warning) when the file does not exist.
    """
    if not path.exists():
        print("Warning: IA weights file not found")
        return {}

    df = pd.read_csv(path, sep="\t", header=None, names=["go_term", "ia"])
    weights = {}

    # Column zip instead of iterrows(): same values, no per-row Series.
    for go_term, raw_value in zip(df["go_term"], df["ia"]):
        try:
            weights[go_term] = float(str(raw_value).replace(",", "."))
        except ValueError:
            # Only a malformed number is expected here; a bare ``except``
            # would also hide KeyboardInterrupt and real bugs.
            weights[go_term] = 0.0

    print(f"Loaded IA weights for {len(weights):,} GO terms")
    return weights


def read_taxonomy(path: Path) -> Dict[str, int]:
    """Load the ``protein_id -> taxonomy_id`` table from a two-column TSV.

    Returns an empty dict (after printing a warning) when the file does
    not exist. When a protein appears more than once, the last row wins.
    """
    if path.exists():
        table = pd.read_csv(path, sep="\t", header=None, names=["protein", "taxonomy_id"])
        taxonomy_map = {
            protein: taxon
            for protein, taxon in zip(table["protein"], table["taxonomy_id"])
        }

        n_taxa = len({*taxonomy_map.values()})
        print(f"Loaded taxonomy for {len(taxonomy_map):,} proteins ({n_taxa:,} unique taxa)")
        return taxonomy_map

    print("Warning: Taxonomy file not found")
    return {}

## GO Ontology Helper Functions

In [None]:
def parse_obo(path: Path):
    """Parse a GO OBO file into hierarchy, namespace and text lookups.

    Only ``[Term]`` stanzas are read. Any other stanza (for example
    ``[Typedef]``, whose ids are relationship names such as ``part_of``)
    is skipped, so relationship definitions are not recorded as GO terms
    or as hierarchy edges.

    Both ``is_a`` and ``relationship: part_of`` lines are treated as
    child -> parent edges.

    Args:
        path: Location of the OBO file.

    Returns:
        Tuple ``(parents, children, go_namespace, go_names, go_defs)``:
        ``parents`` / ``children`` are ``defaultdict(set)`` keyed by term
        ID; ``go_namespace`` maps term ID to ``"BP"``, ``"MF"`` or
        ``"CC"``; ``go_names`` maps term ID to its name; ``go_defs`` maps
        term ID to the text inside the first pair of double quotes of its
        ``def:`` line. All five are empty (after printing a warning) when
        the file does not exist.
    """
    parents = defaultdict(set)
    children = defaultdict(set)
    go_namespace = {}
    go_names = {}
    go_defs = {}

    if not path.exists():
        print("Warning: OBO file not found")
        return parents, children, go_namespace, go_names, go_defs

    namespace_codes = {
        "biological_process": "BP",
        "molecular_function": "MF",
        "cellular_component": "CC",
    }

    def add_edge(child_id, parent_id):
        parents[child_id].add(parent_id)
        children[parent_id].add(child_id)

    with open(path) as f:
        current_id = None
        in_term = False

        for line in f:
            line = line.strip()

            if line.startswith("[") and line.endswith("]"):
                # Every stanza header resets the state; only [Term] is parsed.
                in_term = line == "[Term]"
                current_id = None
            elif not in_term:
                continue
            elif line.startswith("id: "):
                current_id = line[len("id: "):]
            elif not current_id:
                continue
            elif line.startswith("name: "):
                # Slice off the tag rather than split on it, so a name that
                # happens to contain "name: " is kept whole.
                go_names[current_id] = line[len("name: "):]
            elif line.startswith("def: "):
                def_text = line[len("def: "):]
                if '"' in def_text:
                    go_defs[current_id] = def_text.split('"')[1]
            elif line.startswith("namespace:"):
                code = namespace_codes.get(line[len("namespace:"):].strip())
                if code is not None:
                    go_namespace[current_id] = code
            elif line.startswith("is_a: "):
                add_edge(current_id, line.split()[1])
            elif line.startswith("relationship: part_of "):
                parts = line.split()
                if len(parts) >= 3:
                    add_edge(current_id, parts[2])

    ns_counts = Counter(go_namespace.values())
    print(f"Parsed {len(go_names):,} GO terms")
    print(f"  - {ns_counts['BP']} BP terms")
    print(f"  - {ns_counts['MF']} MF terms")
    print(f"  - {ns_counts['CC']} CC terms")
    print(f"  - {len(go_defs):,} definitions")

    return parents, children, go_namespace, go_names, go_defs


def get_ancestors(go_term: str, parents: Dict) -> Set[str]:
    """Return every term reachable from ``go_term`` by following ``parents``.

    ``parents`` maps a term to an iterable of its direct parents; terms
    missing from the mapping are treated as having none. The starting
    term itself is only included if the graph leads back to it.
    """
    seen: Set[str] = set()
    frontier = [go_term]

    while frontier:
        node = frontier.pop()
        # Keep only parents not reached before, so each term is expanded once.
        fresh = [p for p in parents.get(node, []) if p not in seen]
        seen.update(fresh)
        frontier.extend(fresh)

    return seen


def propagate_labels(annotations: Dict[str, List[str]], 
                    parents: Dict) -> Dict[str, List[str]]:
    """Close every protein's annotation set under the GO ancestor relation.

    For each protein the result holds its original terms plus all
    ancestors of those terms (via ``get_ancestors``), de-duplicated and
    sorted. The input mapping is not modified.
    """
    print("Propagating labels up GO hierarchy...")
    propagated = {}

    for protein, terms in tqdm(annotations.items(), desc="Propagating"):
        closure = set(terms).union(*(get_ancestors(term, parents) for term in terms))
        propagated[protein] = sorted(closure)

    before = sum(map(len, annotations.values()))
    after = sum(map(len, propagated.values()))
    print(f"  {before:,} -> {after:,} annotations")

    return propagated

## GO Text Embedding Helpers

In [None]:
def print_go_term_info(go_id: str, go_names: Dict, go_defs: Dict, go_namespace: Dict):
    """Print a short debugging summary of one GO term.

    Shows the ID, name, namespace and definition, followed by a blank
    line. Missing entries fall back to placeholder text, and definitions
    longer than 200 characters are cut to 200 with a trailing ``...``.
    """
    name = go_names.get(go_id, "Unknown")
    namespace = go_namespace.get(go_id, "Unknown")
    definition = go_defs.get(go_id, "No definition")

    if len(definition) > 200:
        definition_line = f"  Definition: {definition[:200]}..."
    else:
        definition_line = f"  Definition: {definition}"

    report = [
        f"GO Term: {go_id}",
        f"  Name: {name}",
        f"  Namespace: {namespace}",
        definition_line,
        "",
    ]
    for row in report:
        print(row)


def create_go_term_embeddings(
    go_terms: List[str],
    go_names: Dict,
    go_defs: Dict,
    model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract",
    batch_size: int = 32,
    max_length: int = 512,
) -> np.ndarray:
    """
    Encode GO terms as text embeddings with a pretrained transformer.

    Each term is rendered as ``"<name>. <definition>"`` (name only when
    there is no definition, ``"GO term <id>"`` when there is neither) and
    passed through the encoder; the hidden state of the first token of
    each sequence is kept as the term's embedding.

    Args:
        go_terms: List of GO IDs to embed; output rows follow this order.
        go_names: Dict mapping GO ID -> name
        go_defs: Dict mapping GO ID -> definition
        model_name: Hugging Face model identifier of the text encoder.
        batch_size: Number of terms encoded per forward pass.
        max_length: Token limit per text; longer texts are truncated.

    Returns:
        np.ndarray: (n_go_terms, hidden_size) first-token embeddings, where
        hidden_size is the encoder's hidden width. An empty ``go_terms``
        yields a (0, hidden_size) array instead of raising.
    """
    from transformers import AutoTokenizer, AutoModel
    import torch

    print(f"Creating GO term embeddings using {model_name}...")
    print(f"  Embedding {len(go_terms):,} GO terms from ontology...")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    # Needed to shape the result when there is nothing to encode.
    hidden_size = model.config.hidden_size

    all_embeddings = []

    with torch.no_grad():
        for i in tqdm(range(0, len(go_terms), batch_size), desc="Encoding GO terms"):
            batch_texts = []

            for go_id in go_terms[i:i + batch_size]:
                name = go_names.get(go_id, "")
                definition = go_defs.get(go_id, "")
                text = f"{name}. {definition}" if definition else name
                if not text:
                    # Never feed an empty string to the tokenizer.
                    text = f"GO term {go_id}"
                batch_texts.append(text)

            inputs = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(**inputs)
            # Position 0 of every sequence is the first ([CLS]) token.
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            all_embeddings.append(cls_embeddings)

    if all_embeddings:
        go_embeddings = np.vstack(all_embeddings)
    else:
        # np.vstack([]) raises; return a well-shaped empty matrix instead.
        go_embeddings = np.zeros((0, hidden_size), dtype=np.float32)

    print(f"  Created GO term embeddings: {go_embeddings.shape}")

    # Drop the encoder before returning so its memory can be reclaimed.
    del model, tokenizer
    torch.cuda.empty_cache()

    return go_embeddings

## PyTorch Dataset

In [None]:
class ProteinDataset(Dataset):
    """Tensor-backed dataset yielding ``(embedding, taxonomy_id, labels)``.

    ``embeddings`` and ``y_labels`` are converted to float tensors and
    ``taxonomy_ids`` (optional) to a long tensor. ``go_prototypes`` is
    only stored on the instance; it is not part of the returned items.
    """

    def __init__(self, embeddings, y_labels, taxonomy_ids=None, go_prototypes=None):
        self.embeddings = torch.FloatTensor(embeddings)
        self.y_labels = torch.FloatTensor(y_labels)
        self.taxonomy_ids = (
            None if taxonomy_ids is None else torch.LongTensor(taxonomy_ids)
        )
        self.go_prototypes = go_prototypes

    def __len__(self):
        return self.embeddings.shape[0]

    def __getitem__(self, idx):
        # Without taxonomy data every sample gets the scalar index 0.
        if self.taxonomy_ids is None:
            tax_id = torch.tensor(0, dtype=torch.long)
        else:
            tax_id = self.taxonomy_ids[idx]

        return self.embeddings[idx], tax_id, self.y_labels[idx]

## Model Architecture

In [None]:
class SimpleProteinFunctionPredictor(nn.Module):
    """
    MLP trunk with a single prototype-matching output head.

    The protein embedding (optionally shifted by a scaled taxonomy
    embedding) goes through the MLP trunk and is projected into the GO
    prototype space. Each logit is the cosine similarity between the
    projection and one GO prototype, multiplied by a learned scalar and
    offset by a learned per-term bias.

    Args:
        protbert_dim: Width of the input protein embedding.
        num_go_terms: Number of GO terms (output logits).
        hidden_dims: Trunk layer widths; defaults to ``[512, 256]``.
        dropout: Dropout probability after every trunk layer.
        use_taxonomy: Enable the taxonomy embedding.
        num_taxonomy: Number of known taxa; index 0 is reserved as padding.
        taxonomy_emb_dim: Accepted for interface compatibility but not
            used: the taxonomy embedding is added element-wise to the
            protein embedding, so it always has width ``protbert_dim``.
        go_prototypes: Optional ``(num_go_terms, proto_dim)`` array of
            fixed GO term embeddings, stored as a non-trainable buffer.
            When omitted, a trainable ``(num_go_terms, 768)`` prototype
            matrix is learned instead.

    Raises:
        ValueError: If ``go_prototypes`` is not 2-D or its row count
            differs from ``num_go_terms``.
    """

    def __init__(
        self,
        protbert_dim: int,
        num_go_terms: int,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.3,
        use_taxonomy: bool = False,
        num_taxonomy: int = 0,
        taxonomy_emb_dim: int = 32,
        go_prototypes: Optional[np.ndarray] = None,
    ):
        super().__init__()

        if hidden_dims is None:
            hidden_dims = [512, 256]

        self.use_taxonomy = use_taxonomy

        # padding_idx=0 keeps the "unknown taxon" row at zero, so unknown
        # taxa leave the protein embedding unchanged.
        if use_taxonomy and num_taxonomy > 0:
            self.taxonomy_embedding = nn.Embedding(num_taxonomy + 1, protbert_dim, padding_idx=0)
            self.taxonomy_weight = 0.1
        else:
            self.taxonomy_embedding = None

        # Shared trunk
        layers = []
        prev_dim = protbert_dim
        for hidden_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            prev_dim = hidden_dim

        self.trunk = nn.Sequential(*layers)

        # GO prototypes: fixed buffer when supplied, trainable otherwise.
        if go_prototypes is not None:
            prototypes = torch.FloatTensor(go_prototypes)
            if prototypes.dim() != 2 or prototypes.size(0) != num_go_terms:
                raise ValueError(
                    f"go_prototypes must have shape (num_go_terms={num_go_terms}, dim), "
                    f"got {tuple(prototypes.shape)}"
                )
            proto_dim = prototypes.size(1)
            self.register_buffer('go_prototypes', prototypes)
        else:
            # Without text embeddings there is nothing to match against, so
            # learn the prototypes; this keeps the output at
            # (batch, num_go_terms) and the model trainable.
            proto_dim = 768
            self.go_prototypes = nn.Parameter(torch.randn(num_go_terms, proto_dim))

        # Projection into the prototype space (width follows the prototypes).
        self.projection = nn.Linear(prev_dim, proto_dim)

        # Single scale and bias
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.bias = nn.Parameter(torch.zeros(num_go_terms))

    def forward(self, protbert_emb, taxonomy_ids=None):
        """
        Args:
            protbert_emb: (batch, protbert_dim)
            taxonomy_ids: (batch,) integer taxon indices, or None

        Returns:
            logits: (batch, num_go_terms) - ALL GO terms in one output
        """
        x = protbert_emb

        # Taxonomy is mixed in additively, down-weighted by taxonomy_weight.
        if self.use_taxonomy and self.taxonomy_embedding is not None and taxonomy_ids is not None:
            x = self.taxonomy_weight * self.taxonomy_embedding(taxonomy_ids) + x

        hidden = self.trunk(x)

        # Unit-normalise both sides so the dot product is a cosine similarity.
        projected = F.normalize(self.projection(hidden), p=2, dim=1)
        proto_norm = F.normalize(self.go_prototypes, p=2, dim=1)

        return torch.matmul(projected, proto_norm.t()) * self.scale + self.bias

## Training and Evaluation Functions

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is an ``(embeddings, taxonomy_ids, labels)`` triple. The
    model is switched to training mode, and one optimiser step is taken
    per batch. Returns the mean of the per-batch losses.
    """
    model.train()
    running_loss = 0.0

    for embeddings, taxonomy_ids, y_labels in dataloader:
        embeddings, taxonomy_ids, y_labels = (
            tensor.to(device) for tensor in (embeddings, taxonomy_ids, y_labels)
        )

        optimizer.zero_grad()
        loss = criterion(model(embeddings, taxonomy_ids), y_labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(dataloader)


def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` without updating it.

    The model is switched to eval mode and gradients are disabled.
    Returns ``(mean_batch_loss, probabilities, labels)`` where the two
    arrays stack the sigmoid outputs and the targets of every batch in
    iteration order.
    """
    model.eval()
    running_loss = 0.0
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for embeddings, taxonomy_ids, y_labels in dataloader:
            embeddings, taxonomy_ids, y_labels = (
                tensor.to(device) for tensor in (embeddings, taxonomy_ids, y_labels)
            )

            logits = model(embeddings, taxonomy_ids)
            running_loss += criterion(logits, y_labels).item()

            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            label_chunks.append(y_labels.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    return mean_loss, np.vstack(prob_chunks), np.vstack(label_chunks)

def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray, 
                   weights: Optional[np.ndarray] = None) -> Dict:
    """Per-term precision / recall / F1, averaged across terms.

    Counts are taken column-wise (one column per term) from binary
    ``y_true`` and ``y_pred``. The per-term scores are combined as a
    weighted mean using ``weights``; with ``weights=None`` every term
    counts equally. A small epsilon keeps every division defined.
    """
    eps = 1e-12

    actual_pos = y_true == 1
    predicted_pos = y_pred == 1

    tp = (actual_pos & predicted_pos).sum(axis=0).astype(float)
    fp = ((y_true == 0) & predicted_pos).sum(axis=0).astype(float)
    fn = (actual_pos & (y_pred == 0)).sum(axis=0).astype(float)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    if weights is None:
        weights = np.ones_like(precision)
    weight_sum = weights.sum() + eps

    per_term = {"precision": precision, "recall": recall, "f1": f1}
    return {
        metric: (values * weights).sum() / weight_sum
        for metric, values in per_term.items()
    }

def find_best_threshold(y_true: np.ndarray, y_pred_prob: np.ndarray,
                       weights: np.ndarray, thresholds: np.ndarray) -> Tuple[float, float]:
    """Grid-search the probability cut-off that maximises F1.

    Every candidate in ``thresholds`` binarises ``y_pred_prob`` with
    ``>=`` and is scored with ``compute_metrics`` (``weights`` is passed
    through unchanged). The first candidate reaching the highest F1 wins.
    Returns ``(best_threshold, best_f1)``; if no candidate is evaluated
    the result is ``(0.5, -1.0)``.
    """
    print("→ Searching for optimal threshold...")

    best_threshold, best_f1 = 0.5, -1.0

    for thresh in tqdm(thresholds, desc="Threshold search"):
        binarised = (y_pred_prob >= thresh).astype(int)
        candidate_f1 = compute_metrics(y_true, binarised, weights)["f1"]

        # Strict comparison: on ties the earlier threshold is kept.
        if candidate_f1 > best_f1:
            best_threshold, best_f1 = thresh, candidate_f1

    print(f"  Best threshold: {best_threshold:.3f} (F1: {best_f1:.4f})")
    return best_threshold, best_f1

## Prediction Propagation

In [None]:
def propagate_predictions(predictions: np.ndarray, 
                         parents: Dict,
                         go_terms: List[str],
                         iterations: int = 3) -> np.ndarray:
    """Raise parent scores so that no parent scores below its children.

    ``predictions`` has one column per entry of ``go_terms``. Only direct
    parent links whose parent is itself in ``go_terms`` are used. Each
    sweep visits the terms in ``go_terms`` order and lifts a parent's
    score to the child's wherever the child is higher. At most
    ``iterations`` sweeps are run, stopping early once a sweep changes
    nothing. The input array is left untouched; a modified copy is
    returned.
    """
    print("→ Propagating predictions...")

    term_to_idx = {term: col for col, term in enumerate(go_terms)}
    scores = predictions.copy()

    # Flatten the in-vocabulary hierarchy into (child column, parent column)
    # pairs once, in the order the sweeps will visit them.
    edges = []
    for child_col, child_term in enumerate(go_terms):
        in_vocab_parents = {p for p in parents.get(child_term, [])
                            if p in term_to_idx}
        for parent_term in in_vocab_parents:
            edges.append((child_col, term_to_idx[parent_term]))

    for sweep in range(iterations):
        changed = False

        for child_col, parent_col in edges:
            child_scores = scores[:, child_col]
            lift = child_scores > scores[:, parent_col]
            if lift.any():
                scores[lift, parent_col] = child_scores[lift]
                changed = True

        if not changed:
            print(f"  Converged after {sweep + 1} iterations")
            break

    return scores

## Main Pipeline

### Load Data

In [None]:
# Read the raw inputs: train/test sequences, GO annotations and IA term weights.
print("Loading data...")
train_seqs = read_fasta(TRAIN_FASTA)
test_seqs = read_fasta(TEST_FASTA)
annotations = read_annotations(TRAIN_TERMS)
ia_weights_dict = read_ia_weights(IA_FILE)

if USE_TAXONOMY:
    print("\nLoading taxonomy...")
    # Only the training taxonomy table is loaded here.
    taxonomy_map = read_taxonomy(TRAIN_TAXONOMY)
    unique_taxa = sorted(set(taxonomy_map.values()))
    # Indices start at 1; 0 is left free for "unknown taxon" lookups in later cells.
    taxa_to_idx = {taxa: idx + 1 for idx, taxa in enumerate(unique_taxa)}
    num_taxonomy = len(unique_taxa)
    print(f"  Mapped {num_taxonomy:,} unique taxonomy IDs")
else:
    # Empty placeholders so later cells can reference these names unconditionally.
    taxonomy_map = {}
    taxa_to_idx = {}
    num_taxonomy = 0

print("\nLoading GO ontology...")
# parents_map / children_map hold the GO hierarchy; the other three dicts are keyed by GO ID.
parents_map, children_map, go_namespace, go_names, go_defs = parse_obo(GO_OBO)

In [None]:
# Expand every protein's annotation list with all ancestors of its GO terms,
# replacing the raw `annotations` mapping for the cells that follow.
if PROPAGATE_TRAIN:
    annotations = propagate_labels(annotations, parents_map)

### Prepare Labels

### Propagate Training Labels

In [None]:
# Build the multi-label target matrix over the TOP_K_LABELS most frequent GO terms.
print("\nPreparing labels (single output)...")

# Keep only annotated proteins that also have a sequence. The order of this list
# (annotation-dict order) defines the row order of y_labels and is relied on by
# later cells.
train_proteins = [p for p in annotations.keys() if p in train_seqs]
print(f"  {len(train_proteins):,} training proteins")

# Count term frequencies
term_counts = Counter()
for protein in train_proteins:
    term_counts.update(annotations[protein])

# Select top-K terms
top_terms = [t for t, _ in term_counts.most_common(TOP_K_LABELS)]
chosen_terms = set(top_terms)
print(f"  Using top {len(chosen_terms):,} GO terms (all ontologies)")

# Filter annotations (mutates `annotations` in place for the training proteins)
for protein in train_proteins:
    annotations[protein] = [t for t in annotations[protein] if t in chosen_terms]

# Create labels (the membership test repeats the filter applied just above)
labels = [[t for t in annotations[p] if t in chosen_terms] for p in train_proteins]

# Columns follow the sorted term IDs; mlb.classes_ is the column -> GO ID lookup
# used by every later cell.
mlb = MultiLabelBinarizer(classes=sorted(chosen_terms))
y_labels = mlb.fit_transform(labels).astype(np.float32)

print(f"  y_labels shape: {y_labels.shape}")

# IA weights, aligned with the label columns; terms without a weight default to 1.0
ia_weights = np.array([ia_weights_dict.get(t, 1.0) for t in mlb.classes_])

In [None]:
print("\nLoading ProtBERT embeddings...")

train_emb = np.load(PROTBERT_TRAIN_EMB).astype(np.float32)
print(f"  Train embeddings shape: {train_emb.shape}")

# Embedding rows are paired with proteins purely by position: row i is used as
# the features of train_proteins[i].
# NOTE(review): this is only correct if the .npy file was written in exactly
# the order of `train_proteins` (annotation order restricted to proteins with a
# sequence). No ID list accompanies the matrix here, so that cannot be checked
# in this cell — verify against how the embeddings were generated.
if len(train_emb) < len(train_proteins):
    # Fail here with a clear message instead of a sample-count mismatch later
    # when the data is split.
    raise ValueError(
        f"Embedding matrix has {len(train_emb):,} rows but there are "
        f"{len(train_proteins):,} training proteins"
    )
if len(train_emb) > len(train_proteins):
    print(f"  Warning: using only the first {len(train_proteins):,} of "
          f"{len(train_emb):,} embedding rows")

X = train_emb[:len(train_proteins)]

# Prepare taxonomy features: protein -> taxon ID -> embedding index, with 0
# standing for "protein or taxon unknown".
if USE_TAXONOMY:
    taxonomy_ids = np.array([
        taxa_to_idx.get(taxonomy_map.get(protein, 0), 0)
        for protein in train_proteins
    ])
    print(f"  Taxonomy IDs shape: {taxonomy_ids.shape}")
    print(f"  Unknown taxonomy: {(taxonomy_ids == 0).sum()} / {len(taxonomy_ids)}")
else:
    taxonomy_ids = None

### Load ProtBERT Embeddings

In [None]:
# Encode the selected GO terms as text embeddings. Rows follow mlb.classes_, so
# row j of go_embeddings belongs to label column j.
if USE_GO_TEXT:
    print("\nCreating GO term embeddings (single prototype matrix)...")

    go_embeddings = create_go_term_embeddings(
        go_terms=list(mlb.classes_),
        go_names=go_names,
        go_defs=go_defs,
        model_name=GO_TEXT_MODEL,
        batch_size=GO_TEXT_BATCH_SIZE
    )

    print(f"  GO prototypes: {go_embeddings.shape}")
    print(f"  These prototypes will be used for BOTH train AND test (no domain shift!)")
else:
    # Later cells test USE_GO_TEXT before touching go_embeddings.
    go_embeddings = None

### Create GO Term Embeddings

In [None]:
print("\nSplitting data...")

# Hold out TEST_PROPORTION of the training proteins for validation. Taxonomy
# indices are split together with the features so rows stay aligned.
if USE_TAXONOMY and taxonomy_ids is not None:
    X_train, X_val, y_train, y_val, tax_train, tax_val = train_test_split(
        X, y_labels, taxonomy_ids,
        test_size=TEST_PROPORTION,
        random_state=RANDOM_SEED,
    )
else:
    X_train, X_val, y_train, y_val = train_test_split(
        X, y_labels,
        test_size=TEST_PROPORTION,
        random_state=RANDOM_SEED,
    )
    tax_train = None
    tax_val = None

print(f"  Train: {X_train.shape}")
print(f"  Val:   {X_val.shape}")

# Create datasets
train_dataset = ProteinDataset(X_train, y_train, tax_train, go_embeddings)
val_dataset = ProteinDataset(X_val, y_val, tax_val, go_embeddings)

# The model trunk uses BatchNorm1d, which raises in training mode when a batch
# holds a single sample. Drop the trailing batch only in the one case where it
# would contain exactly one sample; every other dataset size is untouched.
drop_trailing_batch = len(train_dataset) % BATCH_SIZE == 1

train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True,
    drop_last=drop_trailing_batch,
)
# Validation runs in eval mode, so its last batch may have any size.
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False, num_workers=2, pin_memory=True
)

### Split Data and Create DataLoaders

In [None]:
# Instantiate the model, loss, optimiser and LR scheduler used by the training cell.
print("\nBuilding simple MLP model...")

model = SimpleProteinFunctionPredictor(
    protbert_dim=X.shape[1],
    num_go_terms=y_labels.shape[1],
    hidden_dims=HIDDEN_DIMS,
    dropout=DROPOUT,
    use_taxonomy=USE_TAXONOMY,
    num_taxonomy=num_taxonomy,
    # The constructor sizes the taxonomy embedding from protbert_dim and does
    # not read this argument; TAXONOMY_EMBEDDING_DIM is likewise unused here.
    taxonomy_emb_dim=X.shape[1],
    go_prototypes=go_embeddings if USE_GO_TEXT else None,
).to(DEVICE)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
if USE_TAXONOMY:
    print(f"  Taxonomy embedding: {num_taxonomy} taxa -> {X.shape[1]} dims")
if USE_GO_TEXT:
    print(f"  GO prototypes: {go_embeddings.shape}")
    print(f"  Using knowledge-enhanced prediction (prototype matching)")

# Unweighted multi-label BCE on the raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate when the monitored value (validation loss, passed in
# by the training cell) has not improved for 2 epochs.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch releases —
# confirm it is still accepted by the installed version.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2, verbose=True)

### Build Model

In [None]:
print("\nTesting forward pass...")

# Run the smoke test in eval mode. In training mode this single batch would
# update the BatchNorm running statistics and the printed numbers would include
# dropout noise. The previous mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        dummy_batch = next(iter(train_loader))
        dummy_emb, dummy_tax, dummy_labels = dummy_batch
        dummy_emb = dummy_emb.to(DEVICE)
        dummy_tax = dummy_tax.to(DEVICE)

        logits = model(dummy_emb, dummy_tax)

        print(f"  Input shape: {dummy_emb.shape}")
        print(f"  Output shape: {logits.shape} (should be {dummy_emb.shape[0]} x {y_labels.shape[1]})")
        print(f"\n  Logit statistics:")
        print(f"    Logits: min={logits.min().item():.2f}, max={logits.max().item():.2f}, mean={logits.mean().item():.2f}")

        probs = torch.sigmoid(logits)
        print(f"\n  Probability statistics (after sigmoid):")
        print(f"    Probs: min={probs.min().item():.4f}, max={probs.max().item():.4f}, mean={probs.mean().item():.4f}")

        print(f"\n  Forward pass successful!")
finally:
    model.train(was_training)

### Test Forward Pass

In [None]:
print("\nTraining...")
best_val_loss = float('inf')
patience_counter = 0
best_model_path = WORK_DIR / "best_model.pt"
# Tracks whether THIS run wrote a checkpoint, so a stale file left in WORK_DIR
# by an earlier run is never loaded by mistake.
checkpoint_saved = False

history = {
    "train_loss": [],
    "val_loss": [],
    "val_f1": [],
}

for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_preds, val_labels = evaluate(model, val_loader, criterion, DEVICE)

    # Monitoring metrics at a fixed 0.5 cut-off (the tuned threshold is
    # searched separately after training).
    val_bin = (val_preds >= 0.5).astype(int)
    metrics_unw = compute_metrics(val_labels, val_bin, weights=None)
    metrics_ia = compute_metrics(val_labels, val_bin, ia_weights)

    # The scheduler watches validation loss.
    scheduler.step(val_loss)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["val_f1"].append(metrics_unw["f1"])

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  [Unweighted] P/R/F1: {metrics_unw['precision']:.4f}/{metrics_unw['recall']:.4f}/{metrics_unw['f1']:.4f}")
    print(f"  [IA-weighted F1] {metrics_ia['f1']:.6f}")

    # Debug info
    pos_preds = val_bin.sum()
    print(f"  [Debug] Positive preds: {pos_preds} | Scale: {model.scale.item():.2f}")

    # Early stopping on validation loss; a NaN loss never compares as smaller
    # and therefore counts as "no improvement".
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        checkpoint_saved = True
        print("  Saved best model")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n  Early stopping triggered after {epoch+1} epochs")
            break

    print()

if checkpoint_saved:
    print("Loading best model...")
    # map_location keeps the load working if the active device differs from
    # the one the tensors were saved from.
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
else:
    print("Warning: no checkpoint was saved during this run; keeping the current model weights")

### Train Model

In [None]:
# Re-score the validation split with the restored model and grid-search the
# decision threshold, once with IA weights and once unweighted.
print("\nEvaluating on validation set...")
_, val_preds, val_labels = evaluate(model, val_loader, criterion, DEVICE)

if THRESHOLD_SEARCH:
    best_thr_ia, best_f1_ia = find_best_threshold(
        val_labels, val_preds, ia_weights, THRESHOLD_GRID
    )
    # weights=None makes compute_metrics weight every term equally.
    best_thr_unw, best_f1_unw = find_best_threshold(
        val_labels, val_preds, None, THRESHOLD_GRID
    )
    # The thresholds are reported only; no later cell in the visible notebook
    # reads best_thr_ia or best_thr_unw.
    print(
        f"Best threshold IA: {best_thr_ia:.2f}, F1_IA: {best_f1_ia:.6f} | "
        f"Unweighted: {best_thr_unw:.2f}, F1_unw: {best_f1_unw:.4f}"
    )

### Evaluate on Validation Set

In [None]:
print("\nGenerating test predictions...")

test_emb = np.load(PROTBERT_TEST_EMB).astype(np.float32)
print(f"  Test embeddings shape: {test_emb.shape}")

# Embedding rows are paired with test proteins by position (FASTA order).
# NOTE(review): as with the training matrix, no ID list accompanies the .npy
# file, so the ordering cannot be verified here — confirm against how the
# embeddings were generated.
if len(test_emb) != len(test_seqs):
    print(f"  Warning: {len(test_emb):,} embedding rows vs {len(test_seqs):,} test sequences")

# Prepare test taxonomy
if USE_TAXONOMY:
    test_protein_list = list(test_seqs.keys())[:len(test_emb)]
    # NOTE(review): taxonomy_map is built from the training taxonomy file only,
    # so a test protein absent from it gets index 0 (the zero padding row) —
    # confirm whether a test-set taxonomy source should be loaded as well.
    test_taxonomy_ids = np.array([
        taxa_to_idx.get(taxonomy_map.get(protein, 0), 0)
        for protein in test_protein_list
    ])
    print(f"  Test taxonomy IDs shape: {test_taxonomy_ids.shape}")
else:
    test_taxonomy_ids = None

if USE_GO_TEXT:
    print(f"  Using same GO prototypes as training (no domain shift!)")
    print(f"     GO prototypes: {go_embeddings.shape}")

# Create test dataset. The dataset requires a label array, but labels are
# discarded at inference time, so a single float32 column stands in for the
# full (n_test, n_terms) float64 zero matrix that was previously allocated.
test_dataset = ProteinDataset(
    test_emb,
    np.zeros((len(test_emb), 1), dtype=np.float32),
    test_taxonomy_ids,
    go_embeddings if USE_GO_TEXT else None,
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2, 
                        shuffle=False, num_workers=2, pin_memory=True)

# Predict
model.eval()
all_preds = []

with torch.no_grad():
    # Batch-local names: the training-time `taxonomy_ids` array is left intact.
    for batch_emb, batch_tax, _ in tqdm(test_loader, desc="Predicting"):
        batch_emb = batch_emb.to(DEVICE)
        batch_tax = batch_tax.to(DEVICE)

        logits = model(batch_emb, batch_tax)
        probs = torch.sigmoid(logits)

        all_preds.append(probs.cpu().numpy())

test_preds = np.vstack(all_preds)
print(f"  Predictions shape: {test_preds.shape}")

### Generate Test Predictions

In [None]:
# Make test scores consistent with the GO hierarchy: a parent term is lifted to
# at least the score of its children (restricted to the selected label terms).
if PROPAGATE_PRED:
    test_preds = propagate_predictions(
        test_preds, parents_map, list(mlb.classes_), PROPAGATE_ITERATIONS
    )

# Column index -> GO ID lookup used when writing the submission.
all_terms = mlb.classes_

### Apply Hierarchy Propagation

In [None]:
print(f"\nWriting submission to {OUTPUT_FILE}...")

# Prediction rows are matched to test protein IDs by position (FASTA order).
test_ids = list(test_seqs.keys())[:test_preds.shape[0]]

with open(OUTPUT_FILE, "w") as f:
    for i, protein_id in enumerate(tqdm(test_ids, desc="Writing")):
        probs = test_preds[i]
        # Highest-scoring TOP_K_PER_PROTEIN terms, best first.
        top_indices = np.argsort(probs)[-TOP_K_PER_PROTEIN:][::-1]

        for idx in top_indices:
            score = float(probs[idx])
            if not score > 1e-6:
                continue

            # Scores are written with three decimals. Anything that rounds to
            # 0.000 would appear in the file as a zero score, which falls
            # outside the (0, 1] range expected for predictions, so skip it.
            score_text = f"{score:.3f}"
            if float(score_text) == 0.0:
                continue

            go_term = all_terms[idx]
            f.write(f"{protein_id}\t{go_term}\t{score_text}\n")


print("\n" + "="*70)
print("Pipeline completed successfully")
print("="*70)

# Release the model and loaders, then return cached GPU memory.
del model, train_loader, val_loader, test_loader
torch.cuda.empty_cache()
gc.collect()