# 04: Label Propagation with ESM-2

**Graph-based label propagation**: Use GO ontology structure to propagate predictions to ancestor terms, improving consistency.

Expected improvement: +0.02-0.04 F1 over base model

**Prerequisites:** Trained model from `03_model_esm_finetuned.ipynb`

In [None]:
# Install dependencies
%pip install torch obonet biopython transformers scikit-learn -q

In [1]:
# Standard imports
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from pathlib import Path
from typing import Dict, List, Set, Optional, Union

# Bio imports
import networkx as nx
import obonet
from Bio import SeqIO

print("‚úÖ Imports successful")

‚úÖ Imports successful


In [5]:
# ========================================
# ENVIRONMENT CONFIGURATION
# ========================================
# Selects which data layout the notebook uses: 'local' or 'kaggle'.
ENVIRONMENT = 'local'  # Change to 'kaggle' when running on Kaggle

# Fail fast on typos (e.g. 'Kaggle'): the path-setup cell treats every value
# other than 'kaggle' as local, so a misspelling would silently pick local paths.
_VALID_ENVIRONMENTS = ('local', 'kaggle')
if ENVIRONMENT not in _VALID_ENVIRONMENTS:
    raise ValueError(
        f"ENVIRONMENT must be one of {_VALID_ENVIRONMENTS}, got {ENVIRONMENT!r}"
    )

print(f"üîß Environment: {ENVIRONMENT.upper()}")

üîß Environment: LOCAL


## 1. Helper Classes (Data Loaders)

Embedded versions of OntologyLoader, SequenceLoader, and LabelLoader.

In [6]:
class OntologyLoader:
    """Load a Gene Ontology (GO) OBO file and answer hierarchy queries on it.

    The file is parsed with ``obonet.read_obo`` into a NetworkX graph stored in
    ``self.graph``. obonet draws each relationship as an edge from the more
    specific term to the more general one (child -> parent), so moving "up"
    the ontology means following edges forwards.

    NOTE(review): the traversals below follow every edge present in the graph;
    if the OBO file carries relationship types other than ``is_a``, confirm
    that propagating along all of them is intended.
    """

    def __init__(self, obo_path: Union[str, Path]):
        """Remember ``obo_path`` and load the graph immediately.

        Raises:
            FileNotFoundError: if ``obo_path`` does not exist.
        """
        self.obo_path = Path(obo_path)
        self.graph = None
        self._load()

    def _load(self):
        """Parse the OBO file into ``self.graph`` and print a short summary."""
        if not self.obo_path.exists():
            raise FileNotFoundError(f"OBO file not found: {self.obo_path}")
        print(f"Loading ontology from {self.obo_path}...")
        self.graph = obonet.read_obo(self.obo_path)
        print(f"Loaded {len(self.graph)} terms.")

    def get_ancestors(self, term: str) -> Set[str]:
        """Return every ancestor of ``term`` (parents, grandparents, ...).

        Returns an empty set when ``term`` is not in the graph.
        """
        if term not in self.graph:
            return set()
        # Edges point child -> parent, so the ontology ancestors are the nodes
        # reachable FROM the term, i.e. its graph descendants. nx.ancestors
        # would return the more specific subterms instead.
        return nx.descendants(self.graph, term)

    def get_parents(self, term: str) -> Set[str]:
        """Return only the immediate parents of ``term`` (empty set if unknown)."""
        if term not in self.graph:
            return set()
        return set(self.graph.successors(term))

    def get_namespace(self, term: str) -> Optional[str]:
        """Return the ``namespace`` attribute stored on the term's node.

        Returns ``None`` when the term is unknown or the node has no such attribute.
        """
        if term in self.graph:
            return self.graph.nodes[term].get('namespace')
        return None


class SequenceLoader:
    """Read a FASTA file into an in-memory ``{protein_id: sequence}`` mapping."""

    def __init__(self, fasta_path: Union[str, Path]):
        """Store the path and parse the file right away.

        Raises:
            FileNotFoundError: if ``fasta_path`` does not exist.
        """
        self.fasta_path = Path(fasta_path)
        self.sequences = {}
        self._load()

    @staticmethod
    def _clean_id(header: str) -> str:
        """Reduce a FASTA record id to the key used in ``self.sequences``.

        Pipe-delimited ids (``db|ACCESSION|name``) yield their second field;
        anything else yields its first whitespace-separated token.
        """
        if "|" in header:
            # A header containing "|" always splits into at least two parts.
            return header.split("|")[1]
        return header.split()[0]

    def _load(self):
        """Populate ``self.sequences`` from the FASTA file and print a summary."""
        if not self.fasta_path.exists():
            raise FileNotFoundError(f"FASTA file not found: {self.fasta_path}")

        print(f"Loading sequences from {self.fasta_path}...")
        for record in SeqIO.parse(self.fasta_path, "fasta"):
            # Later records with the same cleaned id overwrite earlier ones.
            self.sequences[self._clean_id(record.id)] = str(record.seq)
        print(f"Loaded {len(self.sequences)} sequences.")

    def get_sequence(self, protein_id: str) -> Optional[str]:
        """Return the sequence stored for ``protein_id``, or ``None`` if absent."""
        return self.sequences.get(protein_id)

    def get_all_ids(self) -> List[str]:
        """Return all loaded protein ids in insertion order."""
        return list(self.sequences)


class LabelLoader:
    """Load ground-truth GO annotations from a tab-separated file.

    The file must provide an ``EntryID`` column (protein id) and a ``term``
    column (GO term id); annotations are grouped into one set per protein.
    """

    def __init__(self, tsv_path: Union[str, Path], ontology_loader: Optional["OntologyLoader"] = None):
        """Store the inputs and load the annotations immediately.

        Args:
            tsv_path: Path to the tab-separated annotation file.
            ontology_loader: Optional ontology used by ``get_terms`` to add
                ancestor terms. Without it, ``propagate=True`` has no effect.

        Raises:
            FileNotFoundError: if ``tsv_path`` does not exist.
        """
        self.tsv_path = Path(tsv_path)
        self.ontology = ontology_loader
        self.df = None
        self.protein_to_terms = {}
        self._load()

    def _load(self):
        """Read the TSV into ``self.df`` and build ``self.protein_to_terms``."""
        if not self.tsv_path.exists():
            raise FileNotFoundError(f"Label file not found: {self.tsv_path}")

        print(f"Loading labels from {self.tsv_path}...")
        self.df = pd.read_csv(self.tsv_path, sep='\t')
        self.protein_to_terms = self.df.groupby('EntryID')['term'].apply(set).to_dict()
        print(f"Loaded annotations for {len(self.protein_to_terms)} proteins.")

    def get_terms(self, protein_id: str, propagate: bool = False) -> Set[str]:
        """Return the GO terms annotated to ``protein_id``.

        Args:
            protein_id: Protein identifier as it appears in ``EntryID``.
            propagate: When true and an ontology was supplied, also include
                every ancestor of each annotated term.

        Returns:
            A new set on every call (empty for unknown proteins), so callers
            may modify the result without altering the loader's cached labels.
        """
        # Copy: handing out the cached set would let callers corrupt it.
        terms = set(self.protein_to_terms.get(protein_id, ()))

        if propagate and self.ontology is not None:
            # Iterate over a snapshot because `terms` grows inside the loop.
            for term in tuple(terms):
                terms.update(self.ontology.get_ancestors(term))
        return terms


print("‚úÖ Data loader classes defined")

‚úÖ Data loader classes defined


## 2. Label Propagation Functions

Core propagation logic using GO hierarchy.

In [7]:
def propagate_predictions(
    predictions: Union[np.ndarray, torch.Tensor],
    term_list: List[str],
    ontology_loader: "OntologyLoader",
    strategy: str = 'max'
) -> Union[np.ndarray, torch.Tensor]:
    """
    Propagate predictions to ancestor terms using GO hierarchy.

    For every term in ``term_list``, its score is pushed to each of its
    ancestors (as reported by ``ontology_loader.get_ancestors``) that is also
    present in ``term_list``. Ancestors outside the vocabulary are ignored.

    Args:
        predictions: (N, K) array or tensor of probabilities for K GO terms.
            The input is never modified.
        term_list: List of K GO term IDs corresponding to columns
        ontology_loader: Object exposing ``get_ancestors(term) -> set of term IDs``
        strategy: One of
            'max'       - ancestor = max(ancestor, child)
            'copy'      - overwrite the ancestor where the child is strictly higher
            'threshold' - set the ancestor to 1.0 where the child is > 0.5

    Returns:
        Propagated predictions with same shape as input. A tensor input yields
        a tensor on the original device; an array input yields an array.

    Raises:
        ValueError: if ``strategy`` is unknown, or ``predictions`` is not 2-D
            with one column per entry of ``term_list``.
    """
    valid_strategies = ('max', 'copy', 'threshold')
    if strategy not in valid_strategies:
        # Previously an unknown strategy silently returned the input unchanged.
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of {valid_strategies}"
        )

    # Convert to numpy for processing
    is_torch = isinstance(predictions, torch.Tensor)
    if is_torch:
        device = predictions.device
        # detach() so tensors that still require grad can be converted.
        predictions = predictions.detach().cpu().numpy()

    if predictions.ndim != 2 or predictions.shape[1] != len(term_list):
        raise ValueError(
            f"predictions must have shape (N, {len(term_list)}) to match term_list, "
            f"got {predictions.shape}"
        )

    propagated = predictions.copy()
    term_to_idx = {term: idx for idx, term in enumerate(term_list)}

    # For each term, propagate to all of its in-vocabulary ancestors at once
    for child_idx, child_term in enumerate(term_list):
        ancestor_indices = [
            term_to_idx[anc]
            for anc in ontology_loader.get_ancestors(child_term)
            if anc in term_to_idx
        ]
        if not ancestor_indices:
            continue

        # Read from `propagated`, so a child already raised by its own
        # descendants passes the raised value on.
        child_probs = propagated[:, child_idx]
        ancestor_block = propagated[:, ancestor_indices]  # fancy index -> copy

        if strategy == 'max':
            ancestor_block = np.maximum(ancestor_block, child_probs[:, None])
        elif strategy == 'copy':
            child_col = child_probs[:, None]
            ancestor_block = np.where(ancestor_block < child_col, child_col, ancestor_block)
        else:  # 'threshold'
            # Binary: ancestors become 1.0 for samples where the child is predicted
            ancestor_block[child_probs > 0.5] = 1.0

        propagated[:, ancestor_indices] = ancestor_block

    # Convert back to torch if needed
    if is_torch:
        propagated = torch.from_numpy(propagated).to(device)

    return propagated


def get_propagated_terms(
    predicted_terms: Set[str],
    ontology_loader: OntologyLoader
) -> Set[str]:
    """Return ``predicted_terms`` plus every ancestor the ontology reports for them.

    The input set is left untouched; a new set is returned.
    """
    ancestor_sets = (ontology_loader.get_ancestors(term) for term in predicted_terms)
    return set(predicted_terms).union(*ancestor_sets)


print("‚úÖ Propagation functions defined")

‚úÖ Propagation functions defined


## 3. Load GO Ontology

**📁 Update path to your GO ontology file:**

In [None]:
# Resolve the dataset root: the Kaggle input mount, or (for any other
# ENVIRONMENT value) the parent of the current working directory.
# NOTE(review): the recorded local run failed with FileNotFoundError for
# Train/go-basic.obo under this base_dir - confirm the local folder layout.
base_dir = (
    Path("/kaggle/input/cafa-6-dataset")
    if ENVIRONMENT == 'kaggle'
    else Path.cwd().parent
)

# All input files, expressed relative to the dataset root
train_dir = base_dir / 'Train'
TRAIN_SEQ = train_dir / 'train_sequences.fasta'
TRAIN_TERMS = train_dir / 'train_terms.tsv'
GO_OBO = train_dir / 'go-basic.obo'
IA_TSV = base_dir / 'IA.tsv'
TEST_FASTA = base_dir / 'Test' / 'testsuperset.fasta'

print(f"üìÅ Base directory: {base_dir}")
print(f"üìÑ Data files:")
print(f"  - GO ontology: {GO_OBO.name}")
print(f"  - Training sequences: {TRAIN_SEQ.name}")
print(f"  - Training terms: {TRAIN_TERMS.name}")
print(f"  - IA weights: {IA_TSV.name}")

# Parse the ontology; raises FileNotFoundError when GO_OBO is missing.
print("\nLoading GO ontology...")
ontology = OntologyLoader(GO_OBO)

go_graph = ontology.graph
print(f"Total GO terms: {len(ontology.graph)}")
print(f"Nodes: {ontology.graph.number_of_nodes()}")
print(f"Edges: {ontology.graph.number_of_edges()}")

üìÅ Base directory: c:\Users\Olale\Documents\Codebase\Science

Loading GO ontology...


FileNotFoundError: OBO file not found: c:\Users\Olale\Documents\Codebase\Science\Train\go-basic.obo

## 4. Test Propagation with Example

Verify propagation works: if we predict "nuclease activity", ancestors get boosted.

In [None]:
# Worked example: inspect the ancestors of one molecular-function term
nuclease = 'GO:0004518'  # nuclease activity
ancestors = ontology.get_ancestors(nuclease)

print(f"Term: {nuclease}")
term_is_known = nuclease in ontology.graph
if term_is_known:
    print(f"Name: {ontology.graph.nodes[nuclease].get('name', 'N/A')}")
    print(f"Namespace: {ontology.get_namespace(nuclease)}")

print(f"\nNumber of ancestors: {len(ancestors)}")
print(f"\nFirst 10 ancestors:")
# `ancestors` is a set, so which ten are shown is arbitrary.
preview = list(ancestors)[:10]
for anc in preview:
    name = ontology.graph.nodes[anc].get('name', 'N/A')
    print(f"  {anc}: {name}")

In [None]:
# Test array propagation
test_terms = [
    'GO:0003674',  # molecular_function (root)
    'GO:0016787',  # hydrolase activity
    'GO:0004518',  # nuclease activity (child of hydrolase)
    'GO:0008150',  # biological_process (root)
    'GO:0006281',  # DNA repair
]

# High confidence only for nuclease (0.9)
preds_before = np.array([[0.05, 0.10, 0.90, 0.03, 0.15]])

print("Before propagation:")
for term, prob in zip(test_terms, preds_before[0]):
    name = ontology.graph.nodes[term].get('name', 'N/A')
    print(f"  {term} ({name}): {prob:.3f}")

# Apply propagation
preds_after = propagate_predictions(preds_before, test_terms, ontology, strategy='max')

print("\nAfter propagation:")
for term, prob in zip(test_terms, preds_after[0]):
    name = ontology.graph.nodes[term].get('name', 'N/A')
    change = '‚úÖ' if prob > preds_before[0][test_terms.index(term)] else ''
    print(f"  {term} ({name}): {prob:.3f} {change}")

print(f"\nüéØ Hydrolase boosted: {preds_before[0][1]:.3f} ‚Üí {preds_after[0][1]:.3f}")

## 5. Load Training Data

**📁 Update paths to your training data:**

In [None]:
# Protein sequences keyed by cleaned FASTA id
print("Loading sequences...")
seq_loader = SequenceLoader(TRAIN_SEQ)

# Ground-truth annotations. No ontology is passed here, so
# label_loader.get_terms(..., propagate=True) returns direct annotations only.
print("Loading labels...")
label_loader = LabelLoader(TRAIN_TERMS)

# Information-accretion weight per GO term.
# NOTE(review): this read expects a header row naming the columns 'term' and
# 'IA' - confirm that IA.tsv actually ships with one.
print("Loading IA weights...")
ia_df = pd.read_csv(IA_TSV, sep='\t')
ia_weights = {term: weight for term, weight in zip(ia_df['term'], ia_df['IA'])}

print(f"\nTotal proteins: {len(seq_loader.sequences)}")
print(f"Total annotations: {len(label_loader.df)}")
print(f"IA weights available: {len(ia_weights)}")

## 6. Evaluation Helper Function

In [None]:
def _safe_ratio(numerator, denominator):
    """Element-wise numerator / denominator, yielding 0.0 where denominator <= 0."""
    valid = denominator > 0
    return np.where(valid, numerator / np.where(valid, denominator, 1.0), 0.0)


def compute_f1_with_threshold(y_true, y_pred, threshold, ia_weights_dict, term_list):
    """
    Compute sample-averaged precision, recall and F1 using IA weights.

    Each term counts with its IA weight instead of 1. Precision, recall and F1
    are computed per sample and then averaged over samples; a ratio with a
    zero denominator counts as 0 for that sample.

    Args:
        y_true: (N, K) binary labels (entries equal to 1 are positives)
        y_pred: (N, K) probabilities
        threshold: float, a term is predicted when its probability >= threshold
        ia_weights_dict: dict mapping GO terms to IA weights; terms missing
            from the dict get weight 1.0
        term_list: list of K GO term IDs, in column order

    Returns:
        dict with float entries 'f1', 'precision' and 'recall'. All three are
        0.0 when there are no samples.
    """
    actual_pos = np.asarray(y_true) == 1
    pred_pos = np.asarray(y_pred) >= threshold

    # np.mean over zero samples would give NaN plus a RuntimeWarning.
    if actual_pos.shape[0] == 0:
        return {'f1': 0.0, 'precision': 0.0, 'recall': 0.0}

    # Get IA weights for vocabulary
    weights = np.array(
        [ia_weights_dict.get(term, 1.0) for term in term_list], dtype=float
    )

    # Weighted counts per sample, computed for all samples at once
    tp = ((actual_pos & pred_pos) * weights).sum(axis=1)
    fp = ((pred_pos & ~actual_pos) * weights).sum(axis=1)
    fn = ((actual_pos & ~pred_pos) * weights).sum(axis=1)

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)

    return {
        'f1': float(f1.mean()),
        'precision': float(precision.mean()),
        'recall': float(recall.mean())
    }

print("‚úÖ Evaluation function defined")

## 7. Load Your Model and Generate Predictions

**⚠️ CRITICAL: Replace this section with your actual model loading and prediction code.**

This is a placeholder. You need to:
1. Load your trained model (ESM-2, CNN, etc.)
2. Create your validation dataset
3. Generate predictions in shape (N_samples, N_terms)
4. Generate true labels in same shape

In [None]:
from transformers import EsmModel, AutoConfig, AutoTokenizer
from torch.utils.data import Dataset, DataLoader, Subset
from collections import Counter

class ESMForGOPrediction(torch.nn.Module):
    """ESM-2 encoder followed by a dropout + linear head that scores GO terms.

    ``forward`` returns raw logits; apply a sigmoid to obtain probabilities.
    """

    def __init__(self, model_name: str = "facebook/esm2_t6_8M_UR50D",
                 num_labels: int = 5000, dropout: float = 0.3):
        """Build the encoder and the classification head.

        Args:
            model_name: Identifier passed to ``EsmModel.from_pretrained``.
            num_labels: Number of output logits (size of the GO vocabulary).
            dropout: Dropout probability applied before the linear layer.
        """
        super().__init__()
        self.model_name = model_name
        self.num_labels = num_labels

        # Pre-trained ESM-2 encoder
        self.esm = EsmModel.from_pretrained(model_name)
        # The head's input width is the encoder's hidden size
        self.hidden_dim = AutoConfig.from_pretrained(model_name).hidden_size

        # Classification head: dropout, then one logit per label
        head_layers = [
            torch.nn.Dropout(dropout),
            torch.nn.Linear(self.hidden_dim, num_labels),
        ]
        self.classifier = torch.nn.Sequential(*head_layers)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Encode the batch, mean-pool over non-padding tokens, return logits."""
        encoder_out = self.esm(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_out.last_hidden_state

        # Broadcast the attention mask over the hidden dimension so padded
        # positions contribute nothing to the sum.
        token_mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        token_sum = torch.sum(hidden_states * token_mask, dim=1)
        # Clamp so an all-padding row divides by 1e-9 rather than by zero
        token_count = torch.clamp(token_mask.sum(dim=1), min=1e-9)

        return self.classifier(token_sum / token_count)

    @classmethod
    def from_pretrained(cls, load_directory: str):
        """Rebuild a model from ``config.json`` and ``pytorch_model.bin``.

        ``config.json`` must contain ``model_name`` and ``num_labels``; the
        weights are loaded onto the CPU. Dropout keeps its default value.
        """
        import json
        checkpoint_dir = Path(load_directory)

        with open(checkpoint_dir / "config.json", "r") as f:
            saved_config = json.load(f)

        model = cls(
            model_name=saved_config["model_name"],
            num_labels=saved_config["num_labels"],
        )
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a source you trust.
        weights = torch.load(checkpoint_dir / "pytorch_model.bin", map_location="cpu")
        model.load_state_dict(weights)

        return model


class FineTuneDataset(Dataset):
    """Pairs tokenized protein sequences with multi-hot GO label vectors.

    Only proteins that have a sequence and at least one label are indexed.
    """

    def __init__(self, sequences: dict, labels: dict, vocab: list,
                 model_name: str = "facebook/esm2_t6_8M_UR50D", max_length: int = 512):
        """Store the inputs, load the tokenizer and select usable proteins.

        Args:
            sequences: Mapping of protein id to amino-acid sequence.
            labels: Mapping of protein id to a collection of GO term ids.
            vocab: GO term ids; position in this list is the label column.
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            max_length: Every sequence is padded or truncated to this length.
        """
        self.sequences = sequences
        self.labels = labels
        self.vocab = vocab
        self.max_length = max_length

        self.term_to_idx = dict(zip(vocab, range(len(vocab))))
        self.num_classes = len(vocab)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Keep sequence order; drop proteins without any annotation
        self.protein_ids = []
        for pid in sequences.keys():
            if pid in labels and len(labels[pid]) > 0:
                self.protein_ids.append(pid)

    def __len__(self):
        """Number of proteins that have both a sequence and labels."""
        return len(self.protein_ids)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels`` for one protein."""
        protein_id = self.protein_ids[idx]
        sequence = self.sequences[protein_id]
        go_terms = self.labels[protein_id]

        # Tokenize to a fixed length; the tokenizer returns a batch of one
        encoded = self.tokenizer(sequence, max_length=self.max_length,
                                 padding='max_length', truncation=True,
                                 return_tensors='pt')
        # Drop that leading batch dimension
        encoded = {key: tensor.squeeze(0) for key, tensor in encoded.items()}

        # Multi-hot target; terms outside the vocabulary are skipped
        label_vector = torch.zeros(self.num_classes, dtype=torch.float32)
        hit_columns = [self.term_to_idx[term] for term in go_terms
                       if term in self.term_to_idx]
        for column in hit_columns:
            label_vector[column] = 1.0

        return {'input_ids': encoded['input_ids'],
                'attention_mask': encoded['attention_mask'],
                'labels': label_vector}


print("‚úÖ Model and dataset classes defined")

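## 8. Build Vocabulary and Create Validation Dataset

> **Note:** the cell below is still a placeholder that only generates dummy predictions. It does not yet build the vocabulary or `val_dataset`, both of which real inference in section 9 requires.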

In [None]:
# ============================================
# PLACEHOLDER - REPLACE WITH YOUR MODEL CODE
# ============================================
#
# Template to adapt (kept commented out on purpose):
#
# from transformers import AutoModel, AutoTokenizer
# import torch.nn as nn
#
# # 1. Load your model
# model_path = base_dir / "models" / "best_model"
# model = YourModelClass.from_pretrained(model_path)
# model.eval()
#
# # 2. Create validation dataset
# # ... your dataset code ...
#
# # 3. Generate predictions
# all_preds = []  # Shape: (N_samples, N_terms)
# all_labels = []  # Shape: (N_samples, N_terms)
# term_list = [...]  # List of GO term IDs corresponding to columns
#
# with torch.no_grad():
#     for batch in dataloader:
#         outputs = model(batch)
#         probs = torch.sigmoid(outputs.logits)
#         all_preds.append(probs.cpu().numpy())
#         all_labels.append(batch['labels'].cpu().numpy())
#
# all_preds = np.vstack(all_preds)
# all_labels = np.vstack(all_labels)

# ============================================
# FOR DEMO: Create dummy predictions
# ============================================
print("‚ö†Ô∏è Using DUMMY predictions for demonstration")
print("Replace this with your actual model predictions!\n")

# Demo dimensions: 500 samples scored against 100 terms
n_samples, n_terms = 500, 100

# Consecutive 7-digit GO ids starting at GO:0003674
_first_id = 3674
term_list = [f"GO:{number:07d}" for number in range(_first_id, _first_id + n_terms)]

# Fixed seed so the demo numbers are reproducible
np.random.seed(42)
# Low-confidence scores, uniform in [0, 0.3)
all_preds = 0.3 * np.random.rand(n_samples, n_terms)
# Sparse labels: each entry is positive with probability ~5%
all_labels = (np.random.rand(n_samples, n_terms) > 0.95).astype(int)

print(f"Predictions shape: {all_preds.shape}")
print(f"Labels shape: {all_labels.shape}")
print(f"Vocabulary size: {len(term_list)}")

## 9. Load Fine-Tuned Model and Generate Predictions

In [None]:
# Load model
model_path = base_dir / "models" / "esm_finetuned" / "best_model"

print(f"Loading model from {model_path}...")

# `val_dataset` is not created by any earlier cell of this notebook, so look
# it up defensively instead of failing with a bare NameError.
_val_ds = globals().get('val_dataset')

if not model_path.exists():
    print("‚ùå Model not found!")
    print("‚ö†Ô∏è Using DUMMY predictions for demonstration\n")

    if _val_ds is not None:
        # Random scores, but real labels taken from the validation dataset
        np.random.seed(42)
        n_samples = len(_val_ds)
        n_terms = len(term_list)
        all_preds = np.random.rand(n_samples, n_terms) * 0.3
        all_labels = np.zeros((n_samples, n_terms))
        for i in range(n_samples):
            all_labels[i] = _val_ds[i]['labels'].numpy()
    elif 'all_preds' not in globals() or 'all_labels' not in globals():
        raise RuntimeError(
            "Neither a trained model, a val_dataset, nor dummy predictions are "
            "available; run the placeholder cell above or build val_dataset."
        )
    # Otherwise keep the dummy all_preds / all_labels / term_list created by
    # the placeholder cell.

    print(f"Predictions shape: {all_preds.shape}")
    print(f"Labels shape: {all_labels.shape}")
else:
    if _val_ds is None:
        raise RuntimeError(
            "val_dataset is not defined; build the validation FineTuneDataset "
            "(with the vocabulary the model was trained on) before running inference."
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    model = ESMForGOPrediction.from_pretrained(str(model_path))
    model.to(device)
    model.eval()

    print(f"‚úÖ Model loaded")
    print(f"Output dimension: {model.num_labels}")

    # The evaluation cells index columns by term_list, so widths must agree.
    if model.num_labels != len(term_list):
        raise ValueError(
            f"Model outputs {model.num_labels} labels but term_list has "
            f"{len(term_list)} terms; use the vocabulary the model was trained on."
        )

    # Generate predictions (no shuffling, so rows stay aligned with val_dataset)
    val_loader = DataLoader(_val_ds, batch_size=8, shuffle=False)

    print("\nGenerating predictions...")
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Inference"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].cpu().numpy()

            # Forward pass returns logits; sigmoid gives per-term probabilities
            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.sigmoid(logits).cpu().numpy()

            all_preds.append(probs)
            all_labels.append(labels)

    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)

    print(f"\nPredictions shape: {all_preds.shape}")
    print(f"Labels shape: {all_labels.shape}")

## 10. Evaluate WITHOUT Propagation (Baseline)

In [None]:
print("Evaluating WITHOUT propagation...")
print("Testing thresholds...\n")

# Decision thresholds swept for the raw (un-propagated) predictions
thresholds = [0.01, 0.05, 0.10, 0.15, 0.20, 0.30, 0.40, 0.50]
results_baseline = []

for thr in thresholds:
    metrics = compute_f1_with_threshold(all_labels, all_preds, thr, ia_weights, term_list)
    print(f"Thr={thr:.2f}: F1={metrics['f1']:.4f}, P={metrics['precision']:.4f}, R={metrics['recall']:.4f}")
    # One row per threshold, reused later for plotting and the CSV export
    results_baseline.append(dict(
        threshold=thr,
        f1=metrics['f1'],
        precision=metrics['precision'],
        recall=metrics['recall'],
    ))

# Highest F1 wins; on ties the lowest threshold is kept
best_baseline = max(results_baseline, key=lambda row: row['f1'])
print(f"\nüèÜ Best WITHOUT propagation:")
print(f"   F1 = {best_baseline['f1']:.4f} at threshold {best_baseline['threshold']}")

## 11. Apply Propagation

In [None]:
print("Applying label propagation...")

# Round-trip through torch: tensor in, tensor out, then back to NumPy
all_preds_torch = torch.from_numpy(all_preds)
all_preds_propagated = propagate_predictions(
    all_preds_torch, term_list, ontology, strategy='max'
).numpy()

print(f"‚úÖ Propagation complete")
print(f"Shape unchanged: {all_preds_propagated.shape}")

# Impact check: how many individual scores were raised by propagation
total = all_preds.size
increased = (all_preds_propagated > all_preds).sum()
print(f"Probabilities boosted: {increased:,} / {total:,} ({100*increased/total:.2f}%)")

## 12. Evaluate WITH Propagation

In [None]:
print("Evaluating WITH propagation...")
print("Testing thresholds...\n")

# Same threshold sweep as the baseline, on the propagated scores
results_propagated = []

for thr in thresholds:
    metrics = compute_f1_with_threshold(all_labels, all_preds_propagated, thr, ia_weights, term_list)
    print(f"Thr={thr:.2f}: F1={metrics['f1']:.4f}, P={metrics['precision']:.4f}, R={metrics['recall']:.4f}")
    results_propagated.append(dict(
        threshold=thr,
        f1=metrics['f1'],
        precision=metrics['precision'],
        recall=metrics['recall'],
    ))

# Highest F1 wins; on ties the lowest threshold is kept
best_propagated = max(results_propagated, key=lambda row: row['f1'])
print(f"\nüèÜ Best WITH propagation:")
print(f"   F1 = {best_propagated['f1']:.4f} at threshold {best_propagated['threshold']}")

## 13. Compare Results

In [None]:
# Side-by-side report of the best operating point of each run
separator = "=" * 60
print(separator)
print("üìä COMPARISON")
print(separator)

print(f"\nBaseline (no propagation):")
print(f"  F1:        {best_baseline['f1']:.4f}")
print(f"  Precision: {best_baseline['precision']:.4f}")
print(f"  Recall:    {best_baseline['recall']:.4f}")
print(f"  Threshold: {best_baseline['threshold']}")

print(f"\nWith Propagation:")
print(f"  F1:        {best_propagated['f1']:.4f}")
print(f"  Precision: {best_propagated['precision']:.4f}")
print(f"  Recall:    {best_propagated['recall']:.4f}")
print(f"  Threshold: {best_propagated['threshold']}")

# Absolute gain, plus the relative gain when the baseline F1 is non-zero
improvement = best_propagated['f1'] - best_baseline['f1']
if best_baseline['f1'] > 0:
    pct_improvement = 100 * improvement / best_baseline['f1']
else:
    pct_improvement = 0

print(f"\n{'üéâ' if improvement > 0 else '‚ö†Ô∏è'} Improvement:")
print(f"  Œî F1:      {improvement:+.4f} ({pct_improvement:+.2f}%)")

# Project goal: a best F1 of at least 0.25
target_missed = best_propagated['f1'] < 0.25
if not target_missed:
    print(f"\n‚úÖ TARGET REACHED! F1 ‚â• 0.25")
else:
    print(f"\n‚ö†Ô∏è Target not reached (goal: 0.25, got: {best_propagated['f1']:.4f})")

print(separator)

## 14. Visualize Results

In [None]:
# Two panels: F1 against threshold (left) and precision against recall (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax1, ax2 = axes[0], axes[1]

# Per-threshold series pulled out of the two result tables
baseline_f1s = [r['f1'] for r in results_baseline]
propagated_f1s = [r['f1'] for r in results_propagated]
baseline_prec = [r['precision'] for r in results_baseline]
baseline_rec = [r['recall'] for r in results_baseline]
propagated_prec = [r['precision'] for r in results_propagated]
propagated_rec = [r['recall'] for r in results_propagated]

# Left: F1 curves, with the 0.25 target as a dashed reference line
ax1.plot(thresholds, baseline_f1s, 'o-', label='Without Propagation', linewidth=2)
ax1.plot(thresholds, propagated_f1s, 's-', label='With Propagation', linewidth=2)
ax1.axhline(y=0.25, color='red', linestyle='--', alpha=0.5, label='Target (0.25)')
ax1.set_xlabel('Threshold', fontsize=12)
ax1.set_ylabel('F1 Score', fontsize=12)
ax1.set_title('F1 Score vs Threshold', fontsize=14, fontweight='bold')

# Right: Precision-Recall, one point per threshold
ax2.plot(baseline_rec, baseline_prec, 'o-', label='Without Propagation', linewidth=2)
ax2.plot(propagated_rec, propagated_prec, 's-', label='With Propagation', linewidth=2)
ax2.set_xlabel('Recall', fontsize=12)
ax2.set_ylabel('Precision', fontsize=12)
ax2.set_title('Precision-Recall Curve', fontsize=14, fontweight='bold')

# Shared cosmetics
for panel in (ax1, ax2):
    panel.legend()
    panel.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("üìà Plots generated")

## 15. Save Results

In [None]:
# Save comparison to CSV
results_df = pd.DataFrame({
    'threshold': thresholds,
    'f1_baseline': [r['f1'] for r in results_baseline],
    'f1_propagated': [r['f1'] for r in results_propagated],
    'precision_baseline': [r['precision'] for r in results_baseline],
    'precision_propagated': [r['precision'] for r in results_propagated],
    'recall_baseline': [r['recall'] for r in results_baseline],
    'recall_propagated': [r['recall'] for r in results_propagated],
})

output_path = Path("propagation_comparison.csv")
results_df.to_csv(output_path, index=False)

print(f"‚úÖ Results saved to {output_path}")
print("\nüìä Results preview:")
print(results_df)

## 16. Summary

**What we did:**
1. ✅ Implemented label propagation using GO hierarchy
2. ✅ Evaluated with/without propagation
3. ✅ Measured F1 improvement

**Key takeaway:** Propagation ensures ontological consistency — if you predict a specific term, all of its ancestors should also be predicted. This typically boosts F1 by 0.02–0.04 (absolute).

**Next steps:**
- Per-aspect thresholds (MF/BP/CC separate optimization)
- Simple ensemble (KNN + ESM)
- Larger backbone (ESM-2 35M or 150M)