In [None]:
# Environment configuration
# Selects which filesystem layout the path-setup cells further down use:
#   'kaggle' -> fixed /kaggle/input and /kaggle/working locations
#   'local'  -> locations derived from the current working directory
ENVIRONMENT = 'local'  # Change to 'kaggle' when running on Kaggle

In [None]:
# Install the notebook's third-party dependencies into the kernel's environment (-q = quiet).
# NOTE(review): versions are unpinned, so a fresh install may pull different releases
# than the ones this notebook was developed with — consider pinning them.
%pip install torch transformers pandas numpy scikit-learn tqdm biopython -q

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from tqdm.auto import tqdm
from Bio import SeqIO
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from collections import Counter
import json

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"‚úÖ Imports successful | Device: {device}")

In [None]:
# Set base directory: the Kaggle input mount when ENVIRONMENT is 'kaggle',
# otherwise the parent of the current working directory.
base_dir = (
    Path("/kaggle/input/cafa-6-dataset")
    if ENVIRONMENT == 'kaggle'
    else Path.cwd().parent
)

print(f"üìÅ Base directory: {base_dir}")

## 1. Configuration

In [None]:
# Hyperparameters
BATCH_SIZE = 8              # samples per DataLoader batch
GRADIENT_ACCUMULATION = 4   # batches whose gradients are accumulated per optimizer step
LEARNING_RATE = 2e-5        # AdamW learning rate
NUM_EPOCHS = 10             # maximum number of training epochs
VOCAB_SIZE = 5000           # maximum number of GO terms kept as labels
MIN_COUNT = 10              # minimum annotation count for a GO term to be eligible
PATIENCE = 3                # epochs without a new best validation F1 before early stopping
DROPOUT = 0.1               # dropout probability applied before the classification layer

# Model
MODEL_NAME = "facebook/esm2_t6_8M_UR50D"

# Paths
# Use the same predicate as the base-directory cell: only 'kaggle' selects the Kaggle
# working directory and every other ENVIRONMENT value is treated as local. Previously
# this tested `== 'local'`, so any other value paired a local base_dir with a Kaggle
# SAVE_DIR.
if ENVIRONMENT == 'kaggle':
    SAVE_DIR = Path("/kaggle/working/models/esm_finetuned")
else:
    SAVE_DIR = base_dir.parent / "models" / "esm_finetuned"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

print(f"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}")
print(f"Vocabulary: Top {VOCAB_SIZE} GO terms")
print(f"Save directory: {SAVE_DIR}")

## 2. Load Data and Build Vocabulary

In [None]:
# Load sequences into a {record id: sequence string} mapping.
# NOTE(review): keys are the FASTA record ids exactly as Biopython parses them; they
# must equal the EntryID values in train_terms.tsv for later lookups to succeed —
# confirm against this dataset's header format.
print("Loading sequences...")
sequences = {
    record.id: str(record.seq)
    for record in SeqIO.parse(base_dir / "Train" / "train_sequences.fasta", "fasta")
}

print(f"Loaded {len(sequences)} sequences")

# Load annotations (tab-separated table; later cells read its 'EntryID' and 'term' columns)
print("Loading annotations...")
train_terms = pd.read_csv(base_dir / "Train" / "train_terms.tsv", sep='\t')
print(f"Total annotations: {len(train_terms)}")

In [None]:
# Build vocabulary: top N most frequent GO terms
print("Building vocabulary...")
term_counts = Counter(train_terms['term'])

# Terms annotated fewer than MIN_COUNT times are not eligible for the vocabulary.
filtered_terms = {term: count for term, count in term_counts.items() if count >= MIN_COUNT}

# most_common() orders by descending count and keeps first-seen order among ties,
# so filtering after sorting yields the same ranking as sorting the filtered terms.
vocab_terms = [
    term for term, count in term_counts.most_common() if count >= MIN_COUNT
][:VOCAB_SIZE]

# Each term's label index is its rank in the list above.
vocab = dict(zip(vocab_terms, range(len(vocab_terms))))

print(f"Vocabulary size: {len(vocab)}")
print(f"Coverage: {sum(filtered_terms[t] for t in vocab_terms) / len(train_terms):.2%} of annotations")

# Save vocabulary
(SAVE_DIR / "vocab.json").write_text(json.dumps(vocab, indent=2))
print(f"‚úÖ Vocabulary saved")

## 3. Create Dataset Class

In [None]:
class ProteinDataset(Dataset):
    """Dataset of tokenized protein sequences paired with multi-hot GO-term labels.

    Args:
        protein_ids: Sequence of protein identifiers; dataset item ``i`` corresponds
            to ``protein_ids[i]``.
        sequences_dict: Mapping from protein identifier to its sequence string.
        annotations_df: DataFrame with ``EntryID`` and ``term`` columns.
        vocab_dict: Mapping from term to its index in the label vector.
        tokenizer: Callable invoked as ``tokenizer(sequence, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')`` whose result
            provides ``input_ids`` and ``attention_mask`` tensors.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, protein_ids, sequences_dict, annotations_df, vocab_dict, tokenizer, max_length=512):
        self.protein_ids = protein_ids
        self.sequences = sequences_dict
        self.vocab = vocab_dict
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Build protein -> in-vocabulary terms mapping with a single pass over the
        # annotation rows. Filtering the whole DataFrame once per protein would cost
        # O(num_proteins * num_annotations); this pass yields the same lists, in the
        # same row order, in O(num_proteins + num_annotations).
        protein_terms = {protein_id: [] for protein_id in protein_ids}
        for entry_id, term in zip(annotations_df['EntryID'], annotations_df['term']):
            if entry_id in protein_terms and term in vocab_dict:
                protein_terms[entry_id].append(term)
        self.protein_terms = protein_terms

    def __len__(self):
        """Return the number of proteins in the dataset."""
        return len(self.protein_ids)

    def __getitem__(self, idx):
        """Return a dict with ``input_ids``, ``attention_mask`` and ``labels`` for item ``idx``."""
        protein_id = self.protein_ids[idx]
        sequence = self.sequences[protein_id]

        # Tokenize sequence (padded/truncated to max_length by the tokenizer)
        tokens = self.tokenizer(
            sequence,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Create multi-hot label vector: 1.0 at the index of every annotated vocabulary term
        labels = torch.zeros(len(self.vocab), dtype=torch.float32)
        for term in self.protein_terms[protein_id]:
            labels[self.vocab[term]] = 1.0

        return {
            # squeeze(0) drops the leading dimension of the tokenizer's 'pt' output
            'input_ids': tokens['input_ids'].squeeze(0),
            'attention_mask': tokens['attention_mask'].squeeze(0),
            'labels': labels
        }

print("‚úÖ Dataset class defined")

## 4. Create Model Class

In [None]:
class ESMForGOPrediction(nn.Module):
    """Pretrained encoder followed by dropout and a linear multi-label head.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_labels: Number of output logits.
        dropout: Dropout probability applied to the pooled embedding.
    """

    def __init__(self, model_name, num_labels, dropout=0.1):
        super().__init__()
        self.esm = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        hidden_size = self.esm.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits, one row per input sequence."""
        # Per-token embeddings from the encoder
        encoder_out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_out.last_hidden_state

        # Masked mean over dim 1: positions whose mask is 0 contribute nothing, and the
        # divisor is the number of unmasked positions (clamped to avoid division by zero).
        mask = attention_mask.unsqueeze(-1).float()
        masked_sum = (hidden_states * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1e-9)
        mean_embedding = masked_sum / token_count

        # Classification head
        return self.classifier(self.dropout(mean_embedding))

print("‚úÖ Model class defined")

## 5. Create Train/Val Split

In [None]:
# Split proteins: keep annotated proteins that also have a sequence,
# then hold out 20% of them for validation with a fixed seed.
all_proteins = [
    protein_id
    for protein_id in train_terms['EntryID'].unique()
    if protein_id in sequences
]
train_proteins, val_proteins = train_test_split(
    all_proteins,
    random_state=42,
    test_size=0.2,
)

print(f"Train proteins: {len(train_proteins)}")
print(f"Val proteins: {len(val_proteins)}")

In [None]:
# Load the tokenizer that belongs to the pretrained checkpoint
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Create datasets: both splits share the sequences, annotations, vocabulary and tokenizer
print("Creating datasets...")
train_dataset = ProteinDataset(
    protein_ids=train_proteins,
    sequences_dict=sequences,
    annotations_df=train_terms,
    vocab_dict=vocab,
    tokenizer=tokenizer,
)
val_dataset = ProteinDataset(
    protein_ids=val_proteins,
    sequences_dict=sequences,
    annotations_df=train_terms,
    vocab_dict=vocab,
    tokenizer=tokenizer,
)

# Create dataloaders: only the training loader shuffles; both load in the main process
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
)

print(f"‚úÖ Train batches: {len(train_loader)}")
print(f"‚úÖ Val batches: {len(val_loader)}")

## 6. Initialize Model and Optimizer

In [None]:
# Seed PyTorch's random number generators before any stochastic step so that the
# classifier's initial weights, dropout masks and training shuffle order are
# repeatable across runs (same seed value as the train/val split's random_state).
torch.manual_seed(42)

# Initialize model
print("Initializing model...")
model = ESMForGOPrediction(MODEL_NAME, len(vocab), dropout=DROPOUT).to(device)

# Loss function (BCEWithLogitsLoss for multi-label): takes raw logits and float targets
criterion = nn.BCEWithLogitsLoss()

# Optimizer: all parameters (encoder and classifier) are trained at the same learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler: linear warmup over the first 10% of steps, then linear decay.
# NOTE(review): total_steps is floor(batches * epochs / accumulation); confirm it matches
# the number of optimizer steps the training loop actually performs.
total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),
    num_training_steps=total_steps
)

print(f"‚úÖ Model initialized | Parameters: {sum(p.numel() for p in model.parameters()):,}")

## 7. Training Loop

In [None]:
def evaluate_model(model, dataloader, criterion, device, threshold=0.5):
    """Evaluate model on validation set.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning logits.
        dataloader: Sized iterable of batch dicts with ``'input_ids'``,
            ``'attention_mask'`` and ``'labels'`` tensors.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to before the forward pass.
        threshold: A label is predicted when its sigmoid probability exceeds this value.

    Returns:
        Tuple ``(mean_loss, mean_f1)``: the loss averaged over batches and the F1
        score averaged over samples.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)

            total_loss += loss.item()

            preds = (torch.sigmoid(logits) > threshold).cpu().numpy()
            all_preds.append(preds)
            all_labels.append(labels.cpu().numpy())

    if not all_preds:
        # np.vstack would raise ValueError on an empty list anyway; say why.
        raise ValueError("evaluate_model received a dataloader that yielded no batches")

    all_preds = np.vstack(all_preds).astype(bool)
    all_labels = np.vstack(all_labels).astype(bool)

    # Compute F1 (sample-averaged), vectorised over all samples at once.
    # Per sample, F1 = 2*TP / (#predicted + #true); a sample with no predicted and no
    # true labels (zero denominator) scores 0, as does any sample with no true positive.
    true_positives = np.logical_and(all_preds, all_labels).sum(axis=1)
    denominator = all_preds.sum(axis=1) + all_labels.sum(axis=1)
    f1_samples = np.zeros(denominator.shape[0], dtype=np.float64)
    np.divide(2.0 * true_positives, denominator, out=f1_samples, where=denominator > 0)

    return total_loss / len(dataloader), np.mean(f1_samples)

print("‚úÖ Evaluation function defined")

In [None]:
print("Starting training...\n")

best_f1 = 0            # best validation F1 seen so far
patience_counter = 0   # consecutive epochs without a new best F1
history = []           # one metrics dict per completed epoch

for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*60}")
    
    # Training
    model.train()
    train_loss = 0
    optimizer.zero_grad()
    num_batches = len(train_loader)
    
    progress = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(progress):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)
        # Scale so the gradients summed over one accumulation window average the batches
        loss = loss / GRADIENT_ACCUMULATION
        
        # Backward pass
        loss.backward()
        
        # Update weights every GRADIENT_ACCUMULATION batches, and also on the final
        # batch of the epoch. Without the final-batch step, gradients from a trailing
        # partial window are never applied and are discarded by zero_grad() at the
        # start of the next epoch.
        is_window_end = (batch_idx + 1) % GRADIENT_ACCUMULATION == 0
        is_last_batch = (batch_idx + 1) == num_batches
        if is_window_end or is_last_batch:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        
        # Undo the accumulation scaling so the reported loss is the per-batch loss
        train_loss += loss.item() * GRADIENT_ACCUMULATION
        progress.set_postfix({'loss': train_loss / (batch_idx + 1)})
    
    avg_train_loss = train_loss / len(train_loader)
    
    # Validation
    val_loss, val_f1 = evaluate_model(model, val_loader, criterion, device)
    
    print(f"\nTrain Loss: {avg_train_loss:.4f}")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val F1: {val_f1:.4f}")
    
    history.append({
        'epoch': epoch + 1,
        'train_loss': avg_train_loss,
        'val_loss': val_loss,
        'val_f1': val_f1
    })
    
    # Save best model (checkpoint selection is driven by validation F1, not loss)
    if val_f1 > best_f1:
        best_f1 = val_f1
        patience_counter = 0
        
        best_model_dir = SAVE_DIR / "best_model"
        best_model_dir.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), best_model_dir / "pytorch_model.bin")
        
        # Save config: the values needed to rebuild the model around the saved weights
        config = {
            'model_name': MODEL_NAME,
            'num_labels': len(vocab),
            'dropout': DROPOUT,
            'best_f1': float(best_f1)
        }
        with open(best_model_dir / "config.json", "w") as f:
            json.dump(config, f, indent=2)
        
        print(f"üèÜ New best model saved! F1: {best_f1:.4f}")
    else:
        patience_counter += 1
        print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")
    
    # Early stopping
    if patience_counter >= PATIENCE:
        print(f"\n‚èπÔ∏è Early stopping triggered after {epoch+1} epochs")
        break

print(f"\n‚úÖ Training complete! Best F1: {best_f1:.4f}")

## 8. Save Training History

In [None]:
# Save history
# One row per completed epoch, with columns epoch, train_loss, val_loss and val_f1.
history_df = pd.DataFrame(history)
history_df.to_csv(SAVE_DIR / "training_history.csv", index=False)

# Echo the same table as plain text in the notebook output
print("\nüìä Training History:")
print(history_df.to_string(index=False))

print(f"\n‚úÖ Model and history saved to {SAVE_DIR}")

## Summary

**ESM-2 Fine-Tuned Classifier:**
- Fine-tuned ESM-2 8M model with classification head
- Trained on the 5000 most frequent GO terms (only terms with at least 10 annotations are eligible)
- Multi-label classification with BCE loss
- Expected validation F1 (sample-averaged, threshold 0.5): ~0.23

**Next:** 04_label_propagation.ipynb - Apply graph-based propagation to improve predictions