# Hierarchy-Calibrated ProtBERT

This notebook implements a ProtBERT-based model with GO hierarchy calibration to ensure predictions are consistent with the Gene Ontology graph structure.

## Import Libraries

In [None]:
import gc
import os
import random
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

## Configuration

In [None]:
# Paths (Kaggle input / working directories)
DATA_DIR = Path("/kaggle/input/cafa-6-protein-function-prediction")
TRAIN_DIR = DATA_DIR / "Train"
TEST_DIR = DATA_DIR / "Test"
WORK_DIR = Path("/kaggle/working")

TRAIN_FASTA = TRAIN_DIR / "train_sequences.fasta"
TRAIN_TERMS = TRAIN_DIR / "train_terms.tsv"
GO_OBO = TRAIN_DIR / "go-basic.obo"
IA_FILE = DATA_DIR / "IA.tsv"
TEST_FASTA = TEST_DIR / "testsuperset.fasta"
OUTPUT_FILE = WORK_DIR / "submission.tsv"

# ProtBERT embeddings (precomputed .npy arrays, one row per protein)
# NOTE(review): row order presumably matches the corresponding FASTA file — confirm
PROTBERT_TRAIN_EMB = Path("/kaggle/input/nnn-cafa6-protbert-embedding/train_embeddings.npy")
PROTBERT_TEST_EMB = Path("/kaggle/input/nnn-cafa6-protbert-embedding/test_embeddings.npy")

# Validation split: fraction of labelled proteins held out (train_test_split test_size)
TEST_PROPORTION = 0.2

# Model parameters
RANDOM_SEED = 42
TOP_K_LABELS = 3000              # keep only the most frequent GO terms as output labels
HIDDEN_DIMS = [768, 512, 256]    # widths of the MLP hidden layers
DROPOUT = 0.2                    # dropout after every hidden layer
LEARNING_RATE = 1e-5
BATCH_SIZE = 32                  # training batch size (evaluation uses twice this)
EPOCHS = 20                      # maximum number of training epochs
PATIENCE = 5                     # epochs without val-loss improvement before early stopping
WEIGHT_DECAY = 1e-5              # weight decay intended for the AdamW optimizer

# Prediction parameters
TOP_K_PER_PROTEIN = 200          # at most this many GO terms written per test protein
THRESHOLD_SEARCH = True          # grid-search the cutoff on the validation set
THRESHOLD_GRID = np.arange(0.01, 0.51, 0.01)  # candidate cutoffs 0.01 .. 0.50

# GO propagation
PROPAGATE_TRAIN = True           # add ancestor terms to training annotations
PROPAGATE_PRED = True            # max-propagate test scores to ancestor terms
PROPAGATE_ITERATIONS = 3         # passes allowed for prediction propagation

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Seed Python, NumPy and PyTorch (CPU and, when present, all GPUs)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

## Data Helper Functions

In [None]:
def read_fasta(path: Path) -> Dict[str, str]:
    """Parse a FASTA file into a ``{protein_id: sequence}`` mapping.

    The identifier is the first whitespace-separated token of each header
    line. If that token contains ``|`` (UniProt style, e.g.
    ``sp|P12345|NAME``), the field after the first ``|`` is used instead.
    Records whose identifier comes out empty are not stored.
    """
    sequences: Dict[str, str] = {}
    current_id = None
    chunks: List[str] = []

    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith(">"):
                chunks.append(text)
                continue

            # A new header closes the previous record.
            if current_id:
                sequences[current_id] = "".join(chunks)

            token = text[1:].split()[0]
            current_id = token.split("|")[1] if "|" in token else token
            chunks = []

    # Flush the final record.
    if current_id:
        sequences[current_id] = "".join(chunks)

    print(f"Loaded {len(sequences):,} sequences from {path.name}")
    return sequences


def read_annotations(path: Path) -> Dict[str, List[str]]:
    """Group a headerless ``protein<TAB>go_term<TAB>ontology`` file by protein.

    Returns ``{protein_id: [go_term, ...]}``, with terms kept in file order
    (duplicates included).
    """
    table = pd.read_csv(path, sep="\t", header=None,
                        names=["protein", "go_term", "ontology"])

    # Column-wise zip avoids building a Series per row as iterrows() does.
    by_protein: Dict[str, List[str]] = defaultdict(list)
    for protein_id, term in zip(table["protein"], table["go_term"]):
        by_protein[protein_id].append(term)

    print(f"Loaded annotations for {len(by_protein):,} proteins")
    return dict(by_protein)


def read_ia_weights(path: Path) -> Dict[str, float]:
    """Read Information Accretion (IA) weights from a headerless two-column TSV.

    Each row is ``go_term`` TAB ``ia``. Decimal commas are accepted
    (``"1,5"`` -> 1.5). Values that cannot be parsed as a float, and
    missing/NaN values, are stored as 0.0.

    Returns an empty dict (after printing a warning) when *path* does not exist.
    """
    if not path.exists():
        print("Warning: IA weights file not found")
        return {}

    df = pd.read_csv(path, sep="\t", header=None, names=["go_term", "ia"])
    weights = {}

    for go_term, raw in zip(df["go_term"], df["ia"]):
        try:
            value = float(str(raw).replace(",", "."))
        except ValueError:
            value = 0.0
        # Empty cells arrive as NaN, and float("nan") parses without error.
        # A NaN weight would make every IA-weighted sum (and the threshold
        # search built on it) NaN, so treat it like an unparseable value.
        if np.isnan(value):
            value = 0.0
        weights[go_term] = value

    print(f"Loaded IA weights for {len(weights):,} GO terms")
    return weights

## GO Ontology Parser

In [None]:
def parse_obo(path: Path) -> Tuple[Dict, Dict]:
    """Parse an OBO ontology file into parent/child adjacency maps.

    Only ``[Term]`` stanzas are read. Inside a term, every ``is_a:`` line and
    every ``relationship: part_of`` line adds an edge child -> parent.
    Other stanza types (e.g. ``[Typedef]``) are skipped, so their own
    ``id:``/``is_a:`` lines are not mistaken for GO-term edges.

    Returns:
        ``(parents, children)``: plain dicts mapping a term id to the set of
        its direct parents / direct children. Both are empty when *path*
        does not exist.
    """
    parents = defaultdict(set)
    children = defaultdict(set)

    if not path.exists():
        print("Warning: OBO file not found")
        return dict(parents), dict(children)

    with open(path) as f:
        current_id = None
        in_term = False

        for line in f:
            line = line.strip()

            if line.startswith("[") and line.endswith("]"):
                # Any stanza header ends the previous stanza.
                in_term = line == "[Term]"
                current_id = None
                continue
            if not in_term:
                # File header lines and non-Term stanzas carry no GO edges.
                continue

            parent_id = None
            if line.startswith("id: "):
                current_id = line.split("id: ")[1]
            elif line.startswith("is_a: ") and current_id:
                parent_id = line.split()[1]
            elif line.startswith("relationship: part_of ") and current_id:
                parts = line.split()
                if len(parts) >= 3:
                    parent_id = parts[2]

            if parent_id is not None:
                parents[current_id].add(parent_id)
                children[parent_id].add(current_id)

    print(f"Parsed GO graph: {len(parents):,} terms with parents")
    return dict(parents), dict(children)


def get_ancestors(go_term: str, parents: Dict) -> Set[str]:
    """Return every term reachable from *go_term* by repeatedly following parents.

    *go_term* itself is only included if the parent links lead back to it.
    Terms absent from *parents* are treated as having no parents.
    """
    found: Set[str] = set()
    frontier = list(parents.get(go_term, []))

    while frontier:
        node = frontier.pop()
        if node in found:
            continue
        found.add(node)
        frontier.extend(parents.get(node, []))

    return found


def propagate_labels(annotations: Dict[str, List[str]],
                    parents: Dict) -> Dict[str, List[str]]:
    """Propagate labels up the GO graph.

    Every protein's term list is expanded with all ancestors of each of its
    terms (as returned by ``get_ancestors``). The result is a sorted,
    de-duplicated list per protein. Annotation counts before and after are
    printed.

    Ancestor sets are cached per GO term. The same term can annotate many
    proteins, and without the cache the graph would be re-traversed for
    every occurrence.
    """
    print("Propagating labels up GO hierarchy...")
    ancestor_cache: Dict[str, Set[str]] = {}
    propagated = {}

    for protein, terms in tqdm(annotations.items(), desc="Propagating"):
        expanded = set(terms)
        for term in terms:
            ancestors = ancestor_cache.get(term)
            if ancestors is None:
                ancestors = ancestor_cache[term] = get_ancestors(term, parents)
            expanded.update(ancestors)
        propagated[protein] = sorted(expanded)

    original_count = sum(len(v) for v in annotations.values())
    new_count = sum(len(v) for v in propagated.values())
    print(f"  {original_count:,} -> {new_count:,} annotations")

    return propagated

## PyTorch Dataset

In [None]:
class ProteinDataset(Dataset):
    """Map-style dataset yielding ``(embedding, label_vector)`` float32 tensor pairs."""

    def __init__(self, embeddings: np.ndarray, labels: np.ndarray):
        # Convert both arrays to float32 tensors once, so indexing is cheap.
        self.embeddings = torch.as_tensor(embeddings, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.float32)

    def __len__(self):
        # One sample per row of the embedding tensor.
        return len(self.embeddings)

    def __getitem__(self, idx):
        embedding = self.embeddings[idx]
        target = self.labels[idx]
        return embedding, target

## Model Architecture

In [None]:
class ProteinFunctionPredictor(nn.Module):
    """Multi-layer perceptron mapping a protein embedding to per-label logits.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout. A final
    Linear layer emits raw logits (no activation), meant to be paired with
    ``BCEWithLogitsLoss`` during training and ``torch.sigmoid`` at inference.

    Args:
        input_dim: length of the input embedding vector.
        output_dim: number of output labels.
        hidden_dims: widths of the hidden layers, in order. ``None`` (the
            default) means ``[512, 256]``.
        dropout: dropout probability after every hidden layer.
    """

    def __init__(self, input_dim: int, output_dim: int,
                 hidden_dims: Optional[List[int]] = None,
                 dropout: float = 0.3):
        super().__init__()
        # None sentinel instead of a shared mutable list default.
        if hidden_dims is None:
            hidden_dims = [512, 256]

        layers: List[nn.Module] = []
        prev_dim = input_dim

        # Hidden layers
        for hidden_dim in hidden_dims:
            layers += [
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            prev_dim = hidden_dim

        # Output layer (no activation - will use BCEWithLogitsLoss)
        layers.append(nn.Linear(prev_dim, output_dim))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return logits; for a 2-D input ``(batch, input_dim)`` the output is ``(batch, output_dim)``."""
        return self.network(x)

## Training and Evaluation Functions

In [None]:
def compute_simple_metrics(y_true, y_pred):
    """Micro-averaged precision and recall at a fixed 0.5 probability cutoff.

    Every label counts equally (Keras-style, no IA weighting).
    Returns ``(precision, recall)``.
    """
    hits = (y_pred >= 0.5).astype(np.float32)
    misses = 1 - hits
    eps = 1e-7

    true_pos = np.sum(y_true * hits)
    false_pos = np.sum((1 - y_true) * hits)
    false_neg = np.sum(y_true * misses)

    return (true_pos / (true_pos + false_pos + eps),
            true_pos / (true_pos + false_neg + eps))


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over *dataloader*.

    Gradients are clipped to a global L2 norm of 1.0 before every optimizer
    step. Returns ``(mean_batch_loss, precision, recall)``. Precision and
    recall come from ``compute_simple_metrics`` over all samples seen this
    epoch.
    """
    model.train()
    running_loss = 0.0
    prob_batches, target_batches = [], []

    for batch_x, batch_y in dataloader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        loss = criterion(logits, batch_y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()

        # Keep gradient-free copies for the epoch-level metrics.
        with torch.no_grad():
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            target_batches.append(batch_y.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    precision, recall = compute_simple_metrics(np.vstack(target_batches),
                                               np.vstack(prob_batches))
    return mean_loss, precision, recall


def evaluate(model, dataloader, criterion, device):
    """Score *model* on *dataloader* without updating it.

    Returns ``(mean_batch_loss, probabilities, labels, precision, recall)``.
    ``probabilities`` (sigmoid of the logits) and ``labels`` are NumPy arrays
    stacked in loader order. Precision and recall come from
    ``compute_simple_metrics``.
    """
    model.eval()
    running_loss = 0.0
    prob_batches, target_batches = [], []

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            running_loss += criterion(logits, batch_y).item()

            prob_batches.append(torch.sigmoid(logits).cpu().numpy())
            target_batches.append(batch_y.cpu().numpy())

    mean_loss = running_loss / len(dataloader)
    predictions = np.vstack(prob_batches)
    labels = np.vstack(target_batches)

    precision, recall = compute_simple_metrics(labels, predictions)
    return mean_loss, predictions, labels, precision, recall


def compute_metrics(y_true: np.ndarray, y_pred: np.ndarray,
                   weights: np.ndarray) -> Dict:
    """Weighted average of per-column precision, recall and F1.

    ``y_true`` and ``y_pred`` are compared element-wise against 1 and 0.
    TP/FP/FN are counted per column (axis 0), presumably one column per GO
    term. The per-column scores are then averaged with ``weights`` (IA
    weights in this notebook).
    """
    eps = 1e-12
    predicted_pos = y_pred == 1
    predicted_neg = y_pred == 0
    actual_pos = y_true == 1
    actual_neg = y_true == 0

    def per_column(mask):
        return mask.sum(axis=0).astype(float)

    tp = per_column(actual_pos & predicted_pos)
    fp = per_column(actual_neg & predicted_pos)
    fn = per_column(actual_pos & predicted_neg)

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)

    norm = weights.sum() + eps
    scores = (("precision", precision), ("recall", recall), ("f1", f1))
    return {name: (values * weights).sum() / norm for name, values in scores}


def find_best_threshold(y_true: np.ndarray, y_pred_prob: np.ndarray,
                       weights: np.ndarray, thresholds: np.ndarray) -> Tuple[float, float]:
    """Grid-search the probability cutoff that maximises weighted F1.

    Each threshold ``t`` binarises ``y_pred_prob >= t``, and the result is
    scored with ``compute_metrics``. Ties keep the earliest threshold in
    *thresholds*.

    Returns ``(best_threshold, best_f1)``. If no threshold scores above -1
    (e.g. *thresholds* is empty or every F1 is NaN), returns ``(0.5, -1.0)``.
    """
    # Message previously contained mojibake ("â†’") for the arrow character.
    print("→ Searching for optimal threshold...")

    best_threshold = 0.5
    best_f1 = -1.0

    for thresh in tqdm(thresholds, desc="Threshold search"):
        y_pred = (y_pred_prob >= thresh).astype(int)
        f1 = compute_metrics(y_true, y_pred, weights)["f1"]

        # Strict ">" keeps the first threshold on ties; NaN never wins.
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = thresh

    print(f"  Best threshold: {best_threshold:.3f} (F1: {best_f1:.4f})")
    return best_threshold, best_f1

## Hierarchy Calibration Functions

In [None]:
def _children_first_order(go_terms: List[str],
                          restricted_parents: Dict[str, List[str]]) -> List[str]:
    """Order *go_terms* so each term comes before all of its (restricted) parents.

    Uses Kahn's algorithm on child -> parent edges. Terms left over because
    they lie on a cycle are appended in their original *go_terms* order.
    """
    # Number of not-yet-visited children per term.
    pending = dict.fromkeys(go_terms, 0)
    for term in go_terms:
        for parent in restricted_parents[term]:
            pending[parent] += 1

    ready = [t for t in go_terms if pending[t] == 0]  # terms with no children
    order = []
    while ready:
        term = ready.pop()
        order.append(term)
        for parent in restricted_parents[term]:
            pending[parent] -= 1
            if pending[parent] == 0:
                ready.append(parent)

    if len(order) < len(go_terms):
        placed = set(order)
        order.extend(t for t in go_terms if t not in placed)
    return order


def propagate_predictions(predictions: np.ndarray,
                         parents: Dict,
                         go_terms: List[str],
                         iterations: int = 3) -> np.ndarray:
    """Max-propagate predictions up the GO hierarchy.

    Afterwards every in-vocabulary parent scores at least as high as each of
    its in-vocabulary children. A term's score is therefore the max over the
    term and all its in-vocabulary descendants. Parents not in *go_terms*
    are ignored.

    Terms are visited children-first (a topological order of the
    vocabulary-restricted graph), so a single pass is exact on an acyclic
    graph. Order matters: in an arbitrary order a score climbs only one
    level per pass, and a small *iterations* cap would leave deep
    descendants' scores short of distant ancestors.

    Args:
        predictions: score matrix with one column per entry of *go_terms*
            (rows presumably proteins). Not modified.
        parents: map from term to an iterable of its direct parent terms.
        go_terms: column labels of *predictions*.
        iterations: maximum number of passes. Passes beyond the first only
            matter when the restricted graph contains a cycle.

    Returns:
        A propagated copy of *predictions*.
    """
    # Message previously contained mojibake ("â†’") for the arrow character.
    print("→ Propagating predictions...")

    term_to_idx = {t: i for i, t in enumerate(go_terms)}
    pred_copy = predictions.copy()

    # Restrict parents to terms in vocabulary (sorted for a deterministic visit order)
    restricted_parents = {
        term: sorted(p for p in parents.get(term, []) if p in term_to_idx)
        for term in go_terms
    }
    visit_order = _children_first_order(go_terms, restricted_parents)

    for iteration in range(iterations):
        changed = False

        for child_term in visit_order:
            child_scores = pred_copy[:, term_to_idx[child_term]]

            for parent_term in restricted_parents[child_term]:
                parent_idx = term_to_idx[parent_term]

                # Raise the parent wherever the child scores higher
                mask = child_scores > pred_copy[:, parent_idx]
                if mask.any():
                    pred_copy[mask, parent_idx] = child_scores[mask]
                    changed = True

        if not changed:
            print(f"  Converged after {iteration + 1} iterations")
            break

    return pred_copy

## Main Pipeline

### Load Data

In [None]:
# Read the raw competition inputs: sequences, training annotations, IA weights.
print("Loading data...")
train_seqs = read_fasta(TRAIN_FASTA)          # {protein_id: sequence}
test_seqs = read_fasta(TEST_FASTA)            # {protein_id: sequence}
annotations = read_annotations(TRAIN_TERMS)   # {protein_id: [GO term, ...]}
ia_weights_dict = read_ia_weights(IA_FILE)    # {GO term: IA weight}; {} if the file is missing

print("\nLoading GO ontology...")
# Direct parent / child sets per GO term (is_a and part_of edges).
# children_map is not used by any later cell visible here.
parents_map, children_map = parse_obo(GO_OBO)

In [None]:
# Optionally expand each protein's annotations with all of their GO ancestors.
annotations = (
    propagate_labels(annotations, parents_map) if PROPAGATE_TRAIN else annotations
)

### Prepare Labels and Embeddings

In [None]:
print("\nPreparing labels...")

# Proteins with both annotations and a training sequence, in annotation order.
train_proteins = [protein for protein in annotations if protein in train_seqs]
print(f"  {len(train_proteins):,} training proteins")

# GO-term frequencies over those proteins.
term_counts = Counter(
    term for protein in train_proteins for term in annotations[protein]
)

# The TOP_K_LABELS most frequent terms form the label vocabulary.
top_terms = [term for term, _count in term_counts.most_common(TOP_K_LABELS)]
chosen_terms = set(top_terms)
print(f"  Using top {len(chosen_terms):,} GO terms")

# Drop out-of-vocabulary terms (updates `annotations` in place).
for protein in train_proteins:
    annotations[protein] = [
        term for term in annotations[protein] if term in chosen_terms
    ]

# Binary label matrix: one row per train protein, one column per sorted term.
mlb = MultiLabelBinarizer(classes=sorted(chosen_terms))
labels_list = [annotations[protein] for protein in train_proteins]
y = mlb.fit_transform(labels_list).astype(np.float32)

print(f"  Label matrix shape: {y.shape}")

# IA weight per label column; terms missing from the IA file default to 1.0.
ia_weights = np.array([ia_weights_dict.get(term, 1.0) for term in mlb.classes_])

In [None]:
print("\nLoading ProtBERT embeddings...")

train_emb = np.load(PROTBERT_TRAIN_EMB).astype(np.float32)
print(f"  Train embeddings shape: {train_emb.shape}")

# Align embedding rows with the rows of the label matrix `y`.
# `train_proteins` follows annotation-file order, not FASTA order, so taking
# the first len(train_proteins) embedding rows pairs labels with the wrong
# proteins.
# ASSUMPTION (review): row i of the embedding file belongs to the i-th record
# of TRAIN_FASTA (the submission cell pairs test rows with FASTA order the
# same way) — confirm against the embedding notebook.
train_fasta_ids = list(train_seqs.keys())
if len(train_emb) != len(train_fasta_ids):
    raise ValueError(
        f"Train embeddings have {len(train_emb):,} rows but {TRAIN_FASTA.name} "
        f"has {len(train_fasta_ids):,} sequences; cannot align rows to proteins."
    )
fasta_row = {protein_id: row for row, protein_id in enumerate(train_fasta_ids)}
X = train_emb[[fasta_row[protein_id] for protein_id in train_proteins]]
print(f"  Aligned feature matrix shape: {X.shape}")

### Split Data, Create DataLoaders, and Build the Model

In [None]:
# Train/validation split
print("\n📊 Splitting data...")
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=TEST_PROPORTION, random_state=RANDOM_SEED
)

print(f"  Train: {X_train.shape}")
print(f"  Val:   {X_val.shape}")

# Create datasets and dataloaders
train_dataset = ProteinDataset(X_train, y_train)
val_dataset = ProteinDataset(X_val, y_val)

# BatchNorm1d raises on a single-sample batch in training mode. Drop the last
# shuffled batch only when it would hold exactly one sample.
drop_singleton_batch = len(train_dataset) % BATCH_SIZE == 1
# Pinned host memory only helps host-to-GPU copies.
use_pinned_memory = DEVICE.type == "cuda"

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=0,
                          pin_memory=use_pinned_memory,
                          drop_last=drop_singleton_batch)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2,
                        shuffle=False, num_workers=0,
                        pin_memory=use_pinned_memory)

In [None]:
print("\nBuilding model...")
model = ProteinFunctionPredictor(
    input_dim=X_train.shape[1],
    output_dim=y_train.shape[1],
    hidden_dims=HIDDEN_DIMS,
    dropout=DROPOUT
).to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

# The model emits raw logits, so the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss()
# Use the configured WEIGHT_DECAY rather than a hard-coded value.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Halve the LR when val loss plateaus. This patience must be shorter than the
# early-stopping PATIENCE, or training stops before the LR is ever reduced.
# `verbose` is not passed: it is deprecated in recent PyTorch releases, and the
# training loop prints the current LR every epoch.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5,
                              patience=max(1, PATIENCE // 2))

### Train the Model

In [None]:
print("\nTraining...")
best_val_loss = float('inf')
patience_counter = 0
best_model_path = WORK_DIR / "best_model.pt"
# Set once this run writes a checkpoint, so a stale best_model.pt left in
# WORK_DIR by an earlier run is never reloaded by mistake.
checkpoint_saved = False

history = {
    'train_loss': [],
    'val_loss': [],
    'val_f1': []
}

for epoch in range(EPOCHS):
    train_loss, train_prec, train_rec = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_preds, val_labels, val_prec, val_rec = evaluate(model, val_loader, criterion, DEVICE)

    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    # Unweighted F1 from evaluate()'s precision/recall at the 0.5 cutoff.
    history['val_f1'].append(2 * val_prec * val_rec / (val_prec + val_rec + 1e-7))

    current_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"  loss: {train_loss:.4f} - precision: {train_prec:.4f} - recall: {train_rec:.4f}")
    print(f"  val_loss: {val_loss:.4f} - val_precision: {val_prec:.4f} - val_recall: {val_rec:.4f}")
    print(f"  learning_rate: {current_lr:.4e}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), best_model_path)
        checkpoint_saved = True
        print(f"  Model improved, saved")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print(f"\n  Early stopping triggered after {epoch+1} epochs")
            break

    print()

if checkpoint_saved:
    print("Loading best model...")
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
else:
    # val_loss never dropped below +inf (e.g. NaN every epoch, or EPOCHS == 0).
    print("Warning: no checkpoint saved in this run; keeping the current weights.")

In [None]:
print("\nEvaluating on validation set...")
_, val_preds, val_labels, _, _ = evaluate(model, val_loader, criterion, DEVICE)

if not THRESHOLD_SEARCH:
    # Fixed cutoff: only report the weighted F1 at 0.5.
    best_threshold = 0.5
    val_pred_binary = (val_preds >= best_threshold).astype(int)
    val_metrics = compute_metrics(val_labels, val_pred_binary, ia_weights)
    best_f1 = val_metrics["f1"]
    print(f"  F1 at threshold 0.5: {best_f1:.4f}")
else:
    # Pick the cutoff from THRESHOLD_GRID that maximises weighted F1.
    best_threshold, best_f1 = find_best_threshold(
        val_labels, val_preds, ia_weights, THRESHOLD_GRID
    )

## Test Predictions

### Generate Predictions

In [None]:
import matplotlib.pyplot as plt

# Left panel: train/val loss; right panel: validation F1 per epoch.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
_loss_ax, _f1_ax = axes
epochs_range = range(1, len(history['train_loss']) + 1)

_loss_ax.plot(epochs_range, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
_loss_ax.plot(epochs_range, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
_f1_ax.plot(epochs_range, history['val_f1'], 'g-', label='Val F1 Score', linewidth=2)

# Shared decoration for both panels.
for _ax, _ylabel, _title in (
    (_loss_ax, 'Loss', 'Training and Validation Loss'),
    (_f1_ax, 'F1 Score', 'Validation F1 Score'),
):
    _ax.set_xlabel('Epoch', fontsize=12)
    _ax.set_ylabel(_ylabel, fontsize=12)
    _ax.set_title(_title, fontsize=14, fontweight='bold')
    _ax.legend(fontsize=11)
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTraining Summary:")
_lowest_val_loss = min(history['val_loss'])
print(f"  Best Val Loss: {_lowest_val_loss:.4f} at epoch {history['val_loss'].index(_lowest_val_loss) + 1}")
_highest_val_f1 = max(history['val_f1'])
print(f"  Best Val F1: {_highest_val_f1:.4f} at epoch {history['val_f1'].index(_highest_val_f1) + 1}")
print(f"  Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"  Final Val Loss: {history['val_loss'][-1]:.4f}")

In [None]:
print("\nGenerating test predictions...")

test_emb = np.load(PROTBERT_TEST_EMB).astype(np.float32)
print(f"  Test embeddings shape: {test_emb.shape}")

# Labels are ignored at inference. A single float32 dummy column satisfies
# ProteinDataset without allocating an (n_test x n_labels) float64 matrix
# (n_test * n_labels * 8 bytes, gigabytes for thousands of labels).
test_dataset = ProteinDataset(test_emb, np.zeros((len(test_emb), 1), dtype=np.float32))
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE * 2,
                         shuffle=False, num_workers=0,
                         pin_memory=DEVICE.type == "cuda")

model.eval()
all_test_preds = []

with torch.no_grad():
    for embeddings, _ in tqdm(test_loader, desc="Predicting"):
        embeddings = embeddings.to(DEVICE)
        logits = model(embeddings)
        outputs = torch.sigmoid(logits)
        all_test_preds.append(outputs.cpu().numpy())

test_preds = np.vstack(all_test_preds)
print(f"  Predictions shape: {test_preds.shape}")

In [None]:
# Optionally make test scores hierarchy-consistent (parents >= children).
if PROPAGATE_PRED:
    test_preds = propagate_predictions(
        test_preds,
        parents=parents_map,
        go_terms=list(mlb.classes_),
        iterations=PROPAGATE_ITERATIONS,
    )

### Write Submission File

In [None]:
print(f"\nWriting submission to {OUTPUT_FILE}...")

# ASSUMPTION (review): prediction rows are in TEST_FASTA record order.
test_ids = list(test_seqs.keys())
if len(test_ids) != len(test_preds):
    # Truncating to the shorter list would silently pair IDs with the wrong rows.
    raise ValueError(
        f"{len(test_ids):,} test sequences but {len(test_preds):,} prediction rows; "
        "cannot pair protein IDs with predictions."
    )

with open(OUTPUT_FILE, "w") as f:
    for protein_id, probs in zip(tqdm(test_ids, desc="Writing"), test_preds):
        # Indices of the TOP_K_PER_PROTEIN highest scores, best first.
        top_indices = np.argsort(probs)[-TOP_K_PER_PROTEIN:][::-1]
        protein_lines = []
        for idx in top_indices:
            score = round(float(probs[idx]), 3)
            # Skip entries that would be written as "0.000".
            if score > 0:
                go_term = mlb.classes_[idx]
                protein_lines.append(f"{protein_id}\t{go_term}\t{score:.3f}\n")
        # One write per protein instead of one per line.
        f.write("".join(protein_lines))

print("\n" + "="*70)
print("Pipeline completed successfully")
print("="*70)

del model, train_loader, val_loader, test_loader
torch.cuda.empty_cache()
gc.collect()