In [None]:
import os
from torch.utils.data import DataLoader, random_split, Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
from torchmetrics.classification import MultilabelF1Score
import matplotlib.pyplot as plt
from collections import defaultdict

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the global RNGs so weight init, DataLoader shuffling and random_split
# are reproducible on Restart & Run All (some GPU kernels may still be
# nondeterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# On Kaggle every input is mounted under /kaggle/input; anywhere else we expect
# a local mirror under ./data (and ./models for checkpoints).
IS_LOCAL = not os.path.isdir("/kaggle/input")

if not IS_LOCAL:
    CHECKPOINT_PATH = "/kaggle/input/checkpoint-310-cafa6/pytorch/default/1"
    TRAIN_TERMS_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv'
    TRAIN_SEQUENCES_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
    TEST_SEQUENCES_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta'
    PROT_EMBEDS = '/kaggle/input/protein-embeddings/protein_embeddings.npy'
    PIDS = '/kaggle/input/protein-embeddings/protein_id.csv'
    OBO_PATH = '/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo'
    GOA_PATH = '/kaggle/input/protein-go-annotations/goa_uniprot_all.csv'
else:
    CHECKPOINT_PATH = "./models"
    TRAIN_TERMS_PATH = './data/cafa-6-protein-function-prediction/Train/train_terms.tsv'
    TRAIN_SEQUENCES_PATH = './data/cafa-6-protein-function-prediction/Train/train_sequences.fasta'
    TEST_SEQUENCES_PATH = './data/cafa-6-protein-function-prediction/Test/testsuperset.fasta'
    PIDS = "./data/cafa-6-protein-function-prediction/protein-embeddings/protein_id.csv"
    PROT_EMBEDS = "./data/cafa-6-protein-function-prediction/protein-embeddings/protein_embeddings.npy"
    OBO_PATH = './data/cafa-6-protein-function-prediction/Train/go-basic.obo'
    GOA_PATH = './data/cafa-6-protein-function-prediction/goa_uniprot_all.csv'

In [None]:
# Row i of the embedding matrix is paired with the i-th id in the CSV.
protein_ids = pd.read_csv(PIDS)["protein_id"].tolist()
embeddings = np.load(PROT_EMBEDS)
# zip() would silently drop the unmatched tail on a length mismatch and hide a
# misaligned export, so fail loudly instead.
if len(protein_ids) != embeddings.shape[0]:
    raise ValueError(
        f"{len(protein_ids)} protein ids but {embeddings.shape[0]} embedding rows"
    )
embeddings_dict = dict(zip(protein_ids, embeddings))
print(f"Loaded {len(protein_ids)} embeddings of dimension {embeddings.shape[1]}")

In [None]:
def parse_fasta(fasta_file) -> dict[str, str]:
    """Read a FASTA file into a ``{protein_id: sequence}`` mapping.

    The id is taken from the header line:
      * ``>db|ACCESSION|...`` -> ``ACCESSION`` (second ``|``-separated field)
      * anything else        -> the first whitespace-separated token
    Sequence lines are stripped and concatenated. A record whose id is empty
    is not stored.
    """
    records = {}
    header_id = None
    chunks = []
    with open(fasta_file, 'r') as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith('>'):
                chunks.append(text)
                continue
            # A new header closes the previous record.
            if header_id:
                records[header_id] = ''.join(chunks)
            fields = text[1:].split('|')
            header_id = fields[1] if len(fields) >= 2 else text[1:].split()[0]
            chunks = []
        if header_id:
            records[header_id] = ''.join(chunks)
    return records

# Sequences for both splits, then the per-protein GO annotations (TSV).
train_sequences = parse_fasta(TRAIN_SEQUENCES_PATH)
test_sequences = parse_fasta(TEST_SEQUENCES_PATH)
train_terms_df = pd.read_csv(TRAIN_TERMS_PATH, sep='\t')

print(f"Train size: {len(train_sequences)}, Test size: {len(test_sequences)}")
print(f"Total annotations: {len(train_terms_df)}")

In [None]:
# Thanks to https://www.kaggle.com/code/seddiktrk/cafa-6-blend-goa-negative-propagation
def parse_obo(go_obo_path):
    """Parse GO ``is_a`` and ``part_of`` edges from an OBO file.

    Only ``[Term]`` stanzas contribute edges. ``[Typedef]`` (relation
    definitions, which may carry their own ``is_a`` lines) and any other
    stanza types are skipped, so they do not leak non-GO nodes into the graph.

    Args:
        go_obo_path: path to an OBO ontology file (e.g. go-basic.obo).

    Returns:
        (parents, children): two ``defaultdict(set)`` maps, child -> parents
        and parent -> children. Both are empty if the file does not exist.
    """
    parents = defaultdict(set)
    children = defaultdict(set)

    if not os.path.exists(go_obo_path):
        # Keep the best-effort empty graph, but say so: without it, every
        # ancestor/descendant propagation step silently becomes a no-op.
        print(f"[WARNING] OBO file not found at {go_obo_path}; GO hierarchy is empty")
        return parents, children

    def add_edge(child, parent):
        parents[child].add(parent)
        children[parent].add(child)

    with open(go_obo_path, "r") as f:
        cur_id = None
        in_term = False
        for line in f:
            line = line.strip()
            if line.startswith("["):
                # Every stanza header resets state; only [Term] is parsed.
                in_term = line == "[Term]"
                cur_id = None
                continue
            if not in_term:
                continue
            if line.startswith("id: "):
                cur_id = line.split("id: ")[1].strip()
            elif not cur_id:
                continue
            elif line.startswith("is_a: "):
                add_edge(cur_id, line.split()[1].strip())
            elif line.startswith("relationship: part_of "):
                parts = line.split()
                if len(parts) >= 3:
                    add_edge(cur_id, parts[2].strip())
    print(f"[io] Parsed OBO: {len(parents)} nodes with parents")
    return parents, children

def get_all_ancestors(term, go_parents, cache=None):
    """Return the set of all transitive parents of ``term``.

    ``term`` itself is only included if it reaches itself through a cycle.
    When ``cache`` is given, results are memoised in it and the cached set is
    returned as-is, so callers must not mutate the returned set.
    """
    memo = {} if cache is None else cache
    if term in memo:
        return memo[term]

    seen = set()
    frontier = [term]
    while frontier:
        node = frontier.pop()
        fresh = [p for p in go_parents.get(node, []) if p not in seen]
        seen.update(fresh)
        frontier.extend(fresh)

    memo[term] = seen
    return seen

def get_all_descendants(term, go_children, cache=None):
    """Return the set of all transitive children of ``term``.

    Mirrors ``get_all_ancestors`` on the parent -> children map. When
    ``cache`` is given, results are memoised there; the cached set is shared,
    so callers must not mutate it.
    """
    memo = cache if cache is not None else {}
    if term in memo:
        return memo[term]

    found = set()
    pending = [term]
    while pending:
        for kid in go_children.get(pending.pop(), ()):
            if kid in found:
                continue
            found.add(kid)
            pending.append(kid)

    memo[term] = found
    return found

# Load the GO hierarchy once: child -> parents (upward score propagation) and
# parent -> children (downward propagation of NOT annotations).
(go_parents, go_children) = parse_obo(OBO_PATH)

In [None]:
# ===== GOA NEGATIVE PROPAGATION =====
def load_goa_and_build_negative_keys(goa_path, go_children):
    """Split GOA annotations into negative keys and positive ground-truth rows.

    Args:
        goa_path: CSV with at least ``protein_id``, ``go_term`` and
            ``qualifier`` columns.
        go_children: parent -> children map used to push NOT annotations
            down to every descendant term.

    Returns:
        (negative_keys, positive_annots):
          * negative_keys: set of ``"<protein>_<term>"`` strings for each NOT
            annotation and every descendant of its term (pairs to REMOVE).
          * positive_annots: DataFrame of the remaining protein/term pairs with
            ``score`` = 1.0 and a ``pred_key`` column, excluding any pair in
            negative_keys (pairs to ADD).
        Returns ``(set(), None)`` when the file does not exist.
    """
    if not os.path.exists(goa_path):
        print(f"[WARNING] GOA file not found at {goa_path}, skipping negative propagation")
        return set(), None

    print(f"Loading GOA annotations from {goa_path}...")
    goa_df = pd.read_csv(goa_path).drop_duplicates()
    print(f"Loaded {len(goa_df)} GOA annotations")

    # A qualifier containing 'NOT' marks negative evidence; a missing qualifier counts as positive.
    is_not = goa_df['qualifier'].str.contains('NOT', na=False)
    pair_cols = ['protein_id', 'go_term']

    # --- 1. negatives: each NOT term plus all of its descendants ---
    print("Extracting negative annotations...")
    not_pairs = goa_df.loc[is_not, pair_cols].drop_duplicates()
    not_terms_by_protein = not_pairs.groupby('protein_id')['go_term'].apply(list).to_dict()

    print("Propagating negative annotations to descendants...")
    descendant_cache = {}
    negative_keys = set()
    for protein, seed_terms in tqdm(not_terms_by_protein.items(), desc="Propagating negatives"):
        excluded = set(seed_terms)
        for seed in seed_terms:
            excluded.update(get_all_descendants(seed, go_children, descendant_cache))
        negative_keys.update(f"{protein}_{term}" for term in excluded)

    print(f"Total unique negative protein-GO pairs: {len(negative_keys)}")

    # --- 2. positives: remaining pairs at score 1.0, minus anything negated ---
    print("Extracting positive GOA ground truth...")
    positive_annots = goa_df.loc[~is_not, pair_cols].drop_duplicates()
    positive_annots = positive_annots.assign(
        score=1.0,
        pred_key=positive_annots['protein_id'].astype(str) + '_' + positive_annots['go_term'].astype(str),
    )
    positive_annots = positive_annots[~positive_annots['pred_key'].isin(negative_keys)]

    print(f"Total positive GOA ground truth pairs: {len(positive_annots)}")

    return negative_keys, positive_annots


def propagate_predictions(predictions_df, go_parents):
    """Apply the GO true-path rule to per-protein predictions.

    For every protein, each predicted term and all of its ancestors receive
    the maximum score among the predicted terms that reach them.

    Args:
        predictions_df: DataFrame with ``pid``, ``term`` and ``p`` columns.
        go_parents: child -> parents map (e.g. from ``parse_obo``).

    Returns:
        A new DataFrame with columns ``pid``, ``term``, ``p``, grouped by
        ``pid`` in groupby order. The columns are present even when the input
        is empty.
    """
    print("Propagating predictions to ancestor GO terms...")

    # Ancestor sets depend only on the term, so one cache serves all proteins.
    ancestor_cache = {}
    propagated_rows = []

    for pid, group in tqdm(predictions_df.groupby('pid'), desc="Propagating"):
        term_scores = {}
        # zip over columns instead of iterrows(): iterrows builds a Series per
        # row, which dominates runtime on multi-million-row prediction frames.
        for term, score in zip(group['term'], group['p']):
            # The term itself first, then every ancestor: keep the max score.
            for target in (term, *get_all_ancestors(term, go_parents, ancestor_cache)):
                if target not in term_scores or score > term_scores[target]:
                    term_scores[target] = score

        propagated_rows.extend(
            {'pid': pid, 'term': t, 'p': s} for t, s in term_scores.items()
        )

    # Explicit columns keep the schema stable for downstream steps on empty input.
    propagated_df = pd.DataFrame(propagated_rows, columns=['pid', 'term', 'p'])
    print(f"Before propagation: {len(predictions_df)} predictions")
    print(f"After propagation: {len(propagated_df)} predictions")

    return propagated_df

def apply_negative_propagation(predictions_df, negative_keys):
    """Drop predictions whose ``"<pid>_<term>"`` key appears in ``negative_keys``.

    Args:
        predictions_df: DataFrame with ``pid`` and ``term`` columns. It is not
            modified.
        negative_keys: set of ``"<protein>_<term>"`` strings to remove.

    Returns:
        The filtered rows (original index and columns preserved), or
        ``predictions_df`` itself when ``negative_keys`` is empty.
    """
    if not negative_keys:
        return predictions_df

    # Build the lookup key as a standalone Series: assigning a helper column
    # would mutate the caller's DataFrame and leave a stray 'pred_key' on it.
    pred_keys = predictions_df['pid'].astype(str) + '_' + predictions_df['term'].astype(str)
    filtered = predictions_df[~pred_keys.isin(negative_keys)]

    before_count = len(predictions_df)
    after_count = len(filtered)
    print(f"Removed {before_count - after_count} negative predictions ({before_count} -> {after_count})")
    return filtered

def add_goa_ground_truth(predictions_df, goa_positive_df):
    """Merge known-positive GOA annotations into the predictions.

    Only proteins already present in ``predictions_df`` receive GOA rows. When
    a (pid, term) pair occurs in both inputs, the larger score wins.

    Args:
        predictions_df: DataFrame with ``pid``, ``term``, ``p`` columns.
        goa_positive_df: DataFrame with ``protein_id``, ``go_term``, ``score``
            columns, or None.

    Returns:
        DataFrame with columns ``pid``, ``term``, ``p`` (one row per pair,
        ordered by pid then term), or ``predictions_df`` unchanged when there
        is no GOA data.
    """
    if goa_positive_df is None or not len(goa_positive_df):
        return predictions_df

    known_pids = set(predictions_df['pid'].unique())
    in_predictions = goa_positive_df['protein_id'].isin(known_pids)
    goa_rows = (
        goa_positive_df.loc[in_predictions, ['protein_id', 'go_term', 'score']]
        .rename(columns={'protein_id': 'pid', 'go_term': 'term', 'score': 'p'})
    )

    print(f"Found {len(goa_rows)} GOA annotations for test proteins")

    merged = (
        pd.concat([predictions_df, goa_rows], ignore_index=True)
        .groupby(['pid', 'term'], as_index=False)['p']
        .max()
    )

    print(f"After adding GOA: {len(predictions_df)} -> {len(merged)} predictions")
    return merged

In [None]:
# NOT-qualified pairs (propagated to descendants) to remove, and GOA positives to merge in.
negative_keys, goa_positive_df = load_goa_and_build_negative_keys(
    goa_path=GOA_PATH,
    go_children=go_children,
)

In [None]:
class ProtDataset(Dataset):
    """Training dataset yielding ``(embedding, multi-hot labels)`` float32 tensors.

    Args:
        pids: protein ids; position ``i`` pairs with row ``i`` of ``labels``.
        labels: row-indexable labels, either a sparse matrix (anything exposing
            ``toarray``) or a dense array / list of rows.
        embeddings_dict: protein id -> numpy embedding vector.
    """

    def __init__(self, pids, labels, embeddings_dict):
        self.pids = pids
        self.labels = labels
        self.embeddings_dict = embeddings_dict

    def __len__(self):
        return len(self.pids)

    def _label_row(self, idx):
        # Sparse matrices return a 1 x C matrix per row; flatten it to a vector.
        row = self.labels[idx]
        return row.toarray().flatten() if hasattr(self.labels, 'toarray') else row

    def __getitem__(self, idx):
        features = torch.from_numpy(self.embeddings_dict[self.pids[idx]]).float()
        target = torch.tensor(self._label_row(idx), dtype=torch.float32)
        return features, target

class TestDataSet(Dataset):
    """Inference dataset yielding one float32 embedding tensor per protein id.

    Args:
        pids: protein ids, in the order batches will be produced.
        embeddings_dict: protein id -> numpy embedding vector.
    """

    def __init__(self, pids, embeddings_dict):
        self.pids = pids
        self.embeddings_dict = embeddings_dict

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        vector = self.embeddings_dict[self.pids[idx]]
        return torch.from_numpy(vector).float()

In [None]:
class SimpleMLP(nn.Module):
    """Feed-forward multi-label classifier over fixed-size protein embeddings.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout. The final
    Linear layer emits one raw logit per class (apply sigmoid for probabilities).

    Args:
        input_dim: embedding width.
        num_classes: number of output labels.
        hidden_dims: hidden layer widths, in order.
        dropout: dropout probability after each hidden layer.

    With the defaults, the layer sequence, and therefore the ``state_dict``
    keys ``network.0`` ... ``network.8``, is the same as the original fixed
    1024 -> 512 design, so existing checkpoints still load.
    """

    def __init__(self, input_dim=1280, num_classes=3000, hidden_dims=(1024, 512), dropout=0.3):
        super().__init__()
        layers = []
        in_features = input_dim
        for width in hidden_dims:
            layers += [
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_features = width
        layers.append(nn.Linear(in_features, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        # x: (batch, input_dim) -> logits: (batch, num_classes)
        return self.network(x)

def get_improved_model(num_classes, input_dim=1280):
    """Build the per-aspect MLP.

    Args:
        num_classes: number of GO terms (output logits) for the aspect.
        input_dim: embedding width. Defaults to the previous fixed 1280; pass
            ``embeddings.shape[1]`` when the loaded embeddings have another size.
    """
    return SimpleMLP(input_dim=input_dim, num_classes=num_classes)

In [None]:
# ===== VISUALIZATION =====
def plot_losses_and_scores(train_losses, train_scores, val_scores, aspect_name):
    """Plot per-epoch loss and F1 curves for one aspect, save them, and show them.

    Left panel: training loss. Right panel: training / validation F1. A series
    with no values is skipped instead of being drawn as an empty curve. The
    figure is saved as ``<aspect_name>_training_curves.png`` (150 dpi) and
    closed after display, so repeated calls do not accumulate open figures.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(train_losses, 'r-', marker='o', label='Train Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title(f'{aspect_name} - Training Loss')
    ax1.legend()
    ax1.grid(True)

    f1_series = [
        (train_scores, 'b-', 'o', 'Train F1'),
        (val_scores, 'g-', 's', 'Val F1'),
    ]
    for values, fmt, marker, label in f1_series:
        if len(values):
            ax2.plot(values, fmt, marker=marker, label=label)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('F1 Score')
    ax2.set_title(f'{aspect_name} - F1 Scores')
    # legend() warns when nothing was plotted, so only add it if there are lines.
    if ax2.get_legend_handles_labels()[0]:
        ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    fig.savefig(f'{aspect_name}_training_curves.png', dpi=150)
    plt.show()
    plt.close(fig)

def train(aspect_name, model, train_loader, valid_loader, num_epochs, num_classes, lr, verbose=False):
    """Train ``model`` with BCE-with-logits and return it after the last epoch.

    Optimizer: AdamW (weight decay 1e-4). The learning rate is scaled by 0.5
    via ReduceLROnPlateau (patience=2) on the validation loss. Micro-averaged
    F1 at a 0.05 probability threshold is tracked on the training batches (in
    train mode, i.e. with dropout active) and on the validation set. The
    curves are plotted once training ends.

    Args:
        aspect_name: label used in logs and in plot titles / filenames.
        model: network mapping a batch of embeddings to ``num_classes`` logits.
        train_loader, valid_loader: yield ``(features, multi-hot labels)`` batches.
        num_epochs: number of passes over ``train_loader``.
        num_classes: number of labels (required by the F1 metric).
        lr: initial learning rate.
        verbose: print one summary line per epoch.

    Returns:
        The trained model (on ``device``) with its final-epoch weights.
    """
    model = model.to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    # Stateful metrics fed batch by batch (update) and read once per epoch
    # (compute). This avoids stacking the full N x num_classes prediction
    # matrix on the device just to score an epoch.
    train_f1_metric = MultilabelF1Score(num_labels=num_classes, threshold=0.05, average='micro').to(device=device)
    val_f1_metric = MultilabelF1Score(num_labels=num_classes, threshold=0.05, average='micro').to(device=device)

    train_losses = []
    train_scores = []
    val_scores = []

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_f1_metric.reset()
        train_loss = 0.0
        train_seen = 0

        for X, y in train_loader:
            X = X.to(device)
            y = y.to(device)
            optimizer.zero_grad()

            outputs = model(X)
            loss = loss_fn(outputs, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * X.size(0)
            train_seen += X.size(0)
            train_f1_metric.update(torch.sigmoid(outputs.detach()), y)

        # Average over samples actually processed; stays correct if the loader
        # drops an incomplete final batch.
        avg_train_loss = train_loss / max(train_seen, 1)
        train_losses.append(avg_train_loss)
        train_scores.append(train_f1_metric.compute().item())

        # Validation
        model.eval()
        val_f1_metric.reset()
        val_loss = 0.0
        val_seen = 0
        with torch.no_grad():
            for valid_seqs, valid_labels in valid_loader:
                valid_seqs = valid_seqs.to(device)
                valid_labels = valid_labels.to(device)
                outputs = model(valid_seqs)
                val_loss += loss_fn(outputs, valid_labels).item() * valid_seqs.size(0)
                val_seen += valid_seqs.size(0)
                val_f1_metric.update(torch.sigmoid(outputs), valid_labels)

        val_f1 = val_f1_metric.compute().item()
        val_scores.append(val_f1)

        avg_val_loss = val_loss / max(val_seen, 1)
        scheduler.step(avg_val_loss)

        if verbose:
            print(f"[{aspect_name}] Epoch {epoch}/{num_epochs} | Train Loss: {avg_train_loss:.6f} | "
                  f"Val Loss: {avg_val_loss:.6f} | Train F1: {train_scores[-1]:.4f} | Val F1: {val_f1:.4f}")

    plot_losses_and_scores(train_losses, train_scores, val_scores, aspect_name)

    # Drop optimizer/scheduler state (AdamW keeps per-parameter moment buffers)
    # before releasing cached GPU memory for the next aspect.
    del optimizer, scheduler
    torch.cuda.empty_cache()
    return model

In [None]:
# Aspect codes as they appear in train_terms_df['aspect']:
#   C = Cellular Component (CCO), F = Molecular Function (MFO), P = Biological Process (BPO)
ASPECT_NAMES = {
    'C': 'Cellular Component (CCO)',
    'F': 'Molecular Function (MFO)',
    'P': 'Biological Process (BPO)',
}
ASPECTS = list(ASPECT_NAMES)  # ['C', 'F', 'P'] - training / prediction order

# Filled per aspect by the training and checkpoint-loading cells.
aspect_models = dict()
aspect_mlbs = dict()

def train_aspect_model(aspect_name, batch_size=128, epochs=30, lr=1e-3, num_workers=4):
    """Train and register the MLP for one GO aspect.

    Targets are multi-hot vectors over every term annotated in this aspect.
    They are built for the training proteins that have an embedding and at
    least one term in the aspect. A random 90/10 split provides validation
    data. The fitted MultiLabelBinarizer and the trained model are stored in
    ``aspect_mlbs`` / ``aspect_models`` under ``aspect_name``.

    Args:
        aspect_name: aspect code matching ``train_terms_df['aspect']``.
        batch_size: mini-batch size for both loaders.
        epochs: number of training epochs.
        lr: initial learning rate passed to ``train``.
        num_workers: DataLoader worker processes.

    Returns:
        (model, mlb) for this aspect.
    """
    aspect_df = train_terms_df[train_terms_df['aspect'] == aspect_name]
    print(f"Training model for {aspect_name}")

    # Group terms by protein for this aspect
    protein_2_terms = aspect_df.groupby('EntryID')['term'].apply(list).to_dict()

    # Keep proteins with (1) an embedding and (2) at least one term in this aspect.
    pid_train = [pid for pid in train_sequences.keys()
                 if pid in embeddings_dict and pid in protein_2_terms]
    labels_list = [protein_2_terms[pid] for pid in pid_train]

    mlb = MultiLabelBinarizer(sparse_output=True)
    y_train_labels = mlb.fit_transform(labels_list)
    aspect_mlbs[aspect_name] = mlb

    num_classes = len(mlb.classes_)
    print(f"{aspect_name} - Number of proteins: {y_train_labels.shape[0]}, GO terms: {num_classes}")

    train_dataset = ProtDataset(pid_train, y_train_labels, embeddings_dict)
    train_size = int(0.9 * len(train_dataset))
    valid_size = len(train_dataset) - train_size
    train_part, valid_part = random_split(train_dataset, [train_size, valid_size])

    # BatchNorm1d raises on a training batch with a single sample, so drop the
    # final batch exactly when it would contain one sample.
    drop_last = train_size % batch_size == 1
    train_loader = DataLoader(train_part, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=drop_last)
    valid_loader = DataLoader(valid_part, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)

    model = train(
        aspect_name,
        get_improved_model(num_classes),
        train_loader, valid_loader, epochs,
        num_classes, lr,
    )
    aspect_models[aspect_name] = model

    # Release this aspect's data before the next aspect is trained.
    del train_loader, valid_loader, train_dataset, train_part, valid_part
    del y_train_labels
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return model, mlb

In [None]:
# Train one model per aspect and write its weights to checkpoint-<aspect>.pth
# in the current working directory.
for aspect in ASPECTS:
    train_aspect_model(aspect)
    torch.save(
        {"model": aspect_models[aspect].state_dict()},
        f"checkpoint-{aspect}.pth",
    )

In [None]:
def predict(threshold=0.02, min_predictions=25, batch_size=128, num_workers=3):
    """Score every test protein that has an embedding with each aspect model.

    For each protein and aspect, all terms with probability >= ``threshold``
    are kept. If fewer than ``min_predictions`` pass, the top
    ``min_predictions`` terms (highest first) are kept instead.

    Args:
        threshold: minimum sigmoid probability for a term to be kept.
        min_predictions: floor on terms kept per protein and aspect.
        batch_size: inference batch size.
        num_workers: DataLoader worker processes.

    Returns:
        DataFrame with columns ``pid``, ``term``, ``p`` covering all aspects.
        The columns are present even if there are no predictions.
    """
    pid_test = [pid for pid in test_sequences.keys() if pid in embeddings_dict]

    test_dataset = TestDataSet(pid_test, embeddings_dict)
    # shuffle=False is required: output rows are mapped back to pid_test by position.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    submission_list = []

    for aspect in ASPECTS:
        model = aspect_models[aspect]
        classes = aspect_mlbs[aspect].classes_
        model.to(device)
        model.eval()

        batch_start_idx = 0
        since_cache_flush = 0

        for test_embeds in tqdm(test_loader, desc=f"Predicting ({aspect})"):
            test_embeds = test_embeds.to(device)
            n_rows = test_embeds.shape[0]

            with torch.no_grad():
                preds = torch.sigmoid(model(test_embeds)).cpu().numpy()

            for i, pred_scores in enumerate(preds):
                protein_id = pid_test[batch_start_idx + i]
                pred_term_ids = np.flatnonzero(pred_scores >= threshold)
                if len(pred_term_ids) < min_predictions:
                    # Too few confident terms: fall back to the top-k, highest first.
                    pred_term_ids = np.argsort(pred_scores)[-min_predictions:][::-1]
                for term_idx in pred_term_ids:
                    submission_list.append({
                        'pid': protein_id,
                        'term': classes[term_idx],
                        'p': pred_scores[term_idx].item(),
                    })

            batch_start_idx += n_rows
            del test_embeds, preds

            # Release cached GPU blocks roughly every 5000 proteins. Counting
            # processed proteins is needed because batch_start_idx moves in
            # steps of the batch size and rarely hits an exact multiple of 5000.
            since_cache_flush += n_rows
            if since_cache_flush >= 5000:
                torch.cuda.empty_cache()
                since_cache_flush = 0

    submission_df = pd.DataFrame(submission_list, columns=['pid', 'term', 'p'])
    print(f"\nTotal predictions: {len(submission_df)}")

    return submission_df

In [None]:
def get_aspect_cls(aspect_name):
    """Rebuild and register the label binarizer for ``aspect_name``; return its class count.

    Uses the same protein filter as ``train_aspect_model``. MultiLabelBinarizer
    stores ``classes_`` sorted, so the column order matches the trained model
    and its checkpoint can be loaded. Side effect: stores the fitted binarizer
    in ``aspect_mlbs[aspect_name]``.
    """
    aspect_df = train_terms_df[train_terms_df['aspect'] == aspect_name]
    # Group terms by protein for this aspect
    protein_2_terms = aspect_df.groupby('EntryID')['term'].apply(list).to_dict()

    # Keep proteins with (1) an embedding and (2) at least one term in this aspect.
    pid_train = [pid for pid in train_sequences.keys()
                 if pid in embeddings_dict and pid in protein_2_terms]
    labels_list = [protein_2_terms[pid] for pid in pid_train]

    # Only the class vocabulary is needed here, so fit() alone, without
    # materializing the (proteins x terms) matrix that fit_transform() returns.
    mlb = MultiLabelBinarizer(sparse_output=True)
    mlb.fit(labels_list)
    aspect_mlbs[aspect_name] = mlb

    return len(mlb.classes_)

In [None]:
# Rebuild each aspect model from the checkpoint the training cell saved.
# Load from that same relative path rather than a hard-coded /kaggle/working/
# path, so this also works in the IS_LOCAL setup.
for aspect in ASPECTS:
    checkpoint = torch.load(f"checkpoint-{aspect}.pth", map_location="cpu")
    num_classes = get_aspect_cls(aspect)
    aspect_models[aspect] = get_improved_model(num_classes)
    aspect_models[aspect].load_state_dict(checkpoint["model"])

# Inference -> true-path propagation -> drop NOT-annotated pairs -> merge GOA positives.
combined_submission_df = predict(threshold=0.02)
propagated_df = propagate_predictions(combined_submission_df, go_parents)
cleaned_df = apply_negative_propagation(propagated_df, negative_keys)
final_df = add_goa_ground_truth(cleaned_df, goa_positive_df)
final_df.to_csv('submission.tsv', sep='\t', index=False, header=False)