In [None]:
import os
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MultiLabelBinarizer
from torch.utils.data import DataLoader, random_split, Dataset
from torchmetrics.classification import MultilabelF1Score
from tqdm import tqdm

# Run on the GPU whenever PyTorch can see one; models and batches are moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# KAGGLE PATH CONFIG
# Every input lives under the read-only Kaggle dataset mount.
BASE_INPUT = "/kaggle/input"
_COMPETITION_DIR = f"{BASE_INPUT}/cafa-6-protein-function-prediction"
_EMBEDDINGS_DIR = f"{BASE_INPUT}/protein-embeddings"

# NOTE(review): not referenced by the checkpoint-loading cell, which reads
# /kaggle/working/checkpoint-<aspect>.pth instead.
CHECKPOINT_PATH = f"{BASE_INPUT}/checkpoint-310-cafa6/pytorch/default/1"

# Competition files.
TRAIN_TERMS_PATH = f"{_COMPETITION_DIR}/Train/train_terms.tsv"
TRAIN_SEQUENCES_PATH = f"{_COMPETITION_DIR}/Train/train_sequences.fasta"
TEST_SEQUENCES_PATH = f"{_COMPETITION_DIR}/Test/testsuperset.fasta"
OBO_PATH = f"{_COMPETITION_DIR}/Train/go-basic.obo"

# Precomputed per-protein embeddings and their id list.
PROT_EMBEDS = f"{_EMBEDDINGS_DIR}/protein_embeddings.npy"
PIDS = f"{_EMBEDDINGS_DIR}/protein_id.csv"

# External GOA annotation dump.
GOA_PATH = f"{BASE_INPUT}/protein-go-annotations/goa_uniprot_all.csv"

In [None]:
# Protein ids (CSV column "protein_id") and the embedding matrix are loaded
# side by side; row i of the matrix is assumed to belong to protein_ids[i]
# (only the counts are checked below). mmap_mode="r" keeps the matrix on disk.
protein_ids = pd.read_csv(PIDS)["protein_id"].tolist()
embeddings = np.load(PROT_EMBEDS, mmap_mode="r")

n_embedded = embeddings.shape[0]
assert len(protein_ids) == n_embedded, \
    "protein_id count does not match embedding rows"

print(f"Loaded {n_embedded} embeddings of dimension {embeddings.shape[1]}")

In [None]:
def parse_fasta(fasta_file):
    """Read a FASTA file into a ``{protein_id: sequence}`` dict.

    The id is taken from the header (text after ``>``): the second
    ``|``-separated field when the header contains ``|`` (UniProt style
    ``sp|ACC|NAME``), otherwise the first whitespace-separated token.
    Sequence lines are stripped and concatenated; blank lines are ignored,
    and lines before the first header are discarded. A repeated id keeps the
    last record.
    """
    records = {}
    key = None
    chunks = []

    def _flush():
        # Store the record collected so far, if a header has been seen.
        if key is not None:
            records[key] = "".join(chunks)

    with open(fasta_file) as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            if not text.startswith(">"):
                chunks.append(text)
                continue

            _flush()
            header = text[1:]
            key = header.split("|")[1] if "|" in header else header.split()[0]
            chunks = []

    _flush()
    return records


In [None]:
def parse_obo(go_obo_path, relationships=("part_of",)):
    """Parse an OBO ontology file into parent/child adjacency maps.

    Only ``[Term]`` stanzas contribute edges. ``[Typedef]`` (relation
    definitions) and the file header are skipped, so their ``is_a`` lines do
    not leak non-GO nodes into the graph.

    Parameters
    ----------
    go_obo_path : str
        Path to the ``.obo`` file (e.g. go-basic.obo).
    relationships : iterable of str
        ``relationship:`` types followed in addition to ``is_a``.
        Defaults to ``("part_of",)``.

    Returns
    -------
    (parents, children) : tuple of defaultdict(set)
        ``parents[t]`` holds the direct parents of term ``t``;
        ``children[p]`` holds the direct children of ``p``. Entries keyed by
        obsolete terms are removed.
    """
    parents = defaultdict(set)
    children = defaultdict(set)
    obsolete = set()
    wanted_relationships = set(relationships)

    def _link(child, parent):
        parents[child].add(parent)
        children[parent].add(child)

    cur_id = None
    in_term = False
    with open(go_obo_path) as f:
        for raw in f:
            line = raw.strip()

            if line.startswith("["):
                # A new stanza begins; only [Term] stanzas describe GO terms.
                in_term = line == "[Term]"
                cur_id = None
                continue
            if not in_term:
                continue

            if line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif cur_id is None:
                continue
            elif line.startswith("is_obsolete: true"):
                obsolete.add(cur_id)
            elif line.startswith("is_a: "):
                # "is_a: GO:0000001 ! name" -> target is the second token
                _link(cur_id, line.split()[1])
            elif line.startswith("relationship: "):
                # "relationship: part_of GO:0000001 ! name"
                fields = line.split()
                if len(fields) >= 3 and fields[1] in wanted_relationships:
                    _link(cur_id, fields[2])

    for o in obsolete:
        parents.pop(o, None)
        children.pop(o, None)

    return parents, children
def get_all_ancestors(term, go_parents, cache=None):
    """Return the set of terms reachable from ``term`` by following parent links.

    ``term`` itself is included only if the graph loops back to it.

    Parameters
    ----------
    term : hashable
        Starting node.
    go_parents : mapping
        Node -> iterable of direct parents (missing nodes have none).
    cache : dict, optional
        Memo shared across calls. The computed set is stored under ``term``,
        and a cached entry is returned as-is (same object).
    """
    memo = {} if cache is None else cache
    if term in memo:
        return memo[term]

    found = set()
    frontier = [term]
    while frontier:
        node = frontier.pop()
        fresh = [p for p in go_parents.get(node, []) if p not in found]
        found.update(fresh)
        frontier.extend(fresh)

    memo[term] = found
    return found

def get_all_descendants(term, go_children, cache=None):
    """Return the set of terms reachable from ``term`` by following child links.

    ``term`` itself is included only if the graph loops back to it.

    Parameters
    ----------
    term : hashable
        Starting node.
    go_children : mapping
        Node -> iterable of direct children (missing nodes have none).
    cache : dict, optional
        Memo shared across calls. The computed set is stored under ``term``,
        and a cached entry is returned as-is (same object).
    """
    memo = {} if cache is None else cache
    if term in memo:
        return memo[term]

    reached = set()
    pending = [term]
    while pending:
        node = pending.pop()
        unseen = [c for c in go_children.get(node, []) if c not in reached]
        reached.update(unseen)
        pending.extend(unseen)

    memo[term] = reached
    return reached

# Build the GO graph and load the raw CAFA inputs that the training and
# prediction cells read as module-level globals (train_terms_df,
# train_sequences, test_sequences). Without these assignments those names
# are undefined on a fresh kernel.
go_parents, go_children = parse_obo(OBO_PATH)

train_terms_df = pd.read_csv(TRAIN_TERMS_PATH, sep="\t")
# The training cells filter on "aspect" and group "term" by "EntryID".
_missing_cols = {"EntryID", "term", "aspect"} - set(train_terms_df.columns)
if _missing_cols:
    raise ValueError(
        f"{TRAIN_TERMS_PATH} is missing expected columns: {sorted(_missing_cols)}"
    )

train_sequences = parse_fasta(TRAIN_SEQUENCES_PATH)
test_sequences = parse_fasta(TEST_SEQUENCES_PATH)

print(
    f"Train terms: {len(train_terms_df)} rows | "
    f"train sequences: {len(train_sequences)} | "
    f"test sequences: {len(test_sequences)}"
)


In [None]:
# GOA NEGATIVE PROPAGATION
def load_goa_and_build_negative_keys(goa_path, go_children):
    """Load GOA annotations and split them into negatives and positives.

    Parameters
    ----------
    goa_path : str
        CSV with at least ``protein_id``, ``go_term`` and ``qualifier`` columns.
    go_children : mapping
        GO term -> direct children (as returned by ``parse_obo``).

    Returns
    -------
    negative_keys : set[str]
        ``"<protein_id>_<go_term>"`` pairs to REMOVE: every annotation whose
        qualifier contains ``NOT``, plus all descendants of that term for the
        same protein.
    goa_ground_truth : pandas.DataFrame or None
        Non-NOT annotations not contradicted by ``negative_keys``, with columns
        ``protein_id``, ``go_term`` and ``score`` (always 1.0), to ADD to the
        predictions. ``None`` when ``goa_path`` does not exist.
    """
    if not os.path.exists(goa_path):
        print(f"[WARNING] GOA file not found at {goa_path}, skipping negative propagation")
        return set(), None

    print(f"Loading GOA annotations from {goa_path}...")
    goa_df = (
        pd.read_csv(goa_path)
        .drop_duplicates(subset=['protein_id', 'go_term', 'qualifier'])
    )
    print(f"Loaded {len(goa_df)} GOA annotations")

    # One NOT mask drives both halves; missing qualifiers count as positive.
    is_not = goa_df['qualifier'].str.contains('NOT', na=False)

    # NEGATIVE ANNOTATIONS
    print("Extracting negative annotations (NOT)...")
    negative_by_protein = (
        goa_df.loc[is_not]
        .groupby('protein_id')['go_term']
        .apply(set)
        .to_dict()
    )

    print("Propagating negative annotations to descendants...")
    desc_cache = {}
    negative_keys = set()

    for protein, terms in tqdm(negative_by_protein.items(), desc="Propagating negatives"):
        # A protein lacking term T also lacks every more specific term below T.
        all_terms = set(terms)
        for term in terms:
            all_terms.update(get_all_descendants(term, go_children, desc_cache))
        negative_keys.update(f"{protein}_{term}" for term in all_terms)

    print(f"Total unique negative protein-GO pairs: {len(negative_keys)}")

    # POSITIVE ANNOTATIONS
    print("Extracting positive GOA ground truth...")
    pos_pairs = goa_df.loc[~is_not, ['protein_id', 'go_term']].drop_duplicates()
    pos_keys = pos_pairs['protein_id'].astype(str) + '_' + pos_pairs['go_term'].astype(str)

    # .assign builds a new frame, avoiding column assignment on a filtered slice.
    pos_df = pos_pairs.loc[~pos_keys.isin(negative_keys)].assign(score=1.0)

    print(f"Total positive GOA ground truth pairs: {len(pos_df)}")

    return negative_keys, pos_df


def propagate_predictions(predictions_df, go_parents):
    """Push every predicted score up to all ancestor GO terms.

    For each protein, each predicted term's score is copied to the term
    itself and to all of its ancestors. The maximum is kept per
    (protein, term), so a parent never scores below any predicted descendant.

    Parameters
    ----------
    predictions_df : pandas.DataFrame
        Columns ``pid``, ``term``, ``p``.
    go_parents : mapping
        GO term -> direct parents (as returned by ``parse_obo``).

    Returns
    -------
    pandas.DataFrame
        One row per (pid, term) with columns ``pid``, ``term``, ``p``. The
        columns are present even when there are no predictions.
    """
    print("Propagating predictions to ancestor GO terms...")

    ancestor_cache = {}
    propagated_rows = []

    for pid, group in tqdm(predictions_df.groupby('pid'), desc="Propagating"):
        term_scores = {}

        for row in group.itertuples(index=False):
            term, score = row.term, row.p

            # Original term
            term_scores[term] = max(term_scores.get(term, 0.0), score)

            # Ancestors
            for ancestor in get_all_ancestors(term, go_parents, ancestor_cache):
                term_scores[ancestor] = max(term_scores.get(ancestor, 0.0), score)

        propagated_rows.extend(
            {'pid': pid, 'term': t, 'p': s}
            for t, s in term_scores.items()
        )

    # Explicit columns: an empty row list would otherwise give a frame with no
    # 'pid'/'term' columns and break the filtering steps that follow.
    propagated_df = pd.DataFrame(propagated_rows, columns=['pid', 'term', 'p'])

    print(f"Before propagation: {len(predictions_df)} predictions")
    print(f"After propagation: {len(propagated_df)} predictions")

    return propagated_df


def apply_negative_propagation(predictions_df, negative_keys):
    """Drop predictions whose ``"<pid>_<term>"`` key is a known negative.

    Parameters
    ----------
    predictions_df : pandas.DataFrame
        Columns ``pid`` and ``term`` (plus any others, which are kept).
    negative_keys : set[str]
        Keys built as ``f"{pid}_{term}"``.

    Returns
    -------
    pandas.DataFrame
        A filtered copy. If ``negative_keys`` is empty, the input object itself
        is returned unchanged.
    """
    if not negative_keys:
        return predictions_df

    n_before = len(predictions_df)

    row_keys = predictions_df['pid'].astype(str).str.cat(
        predictions_df['term'].astype(str), sep='_'
    )
    kept = predictions_df[~row_keys.isin(negative_keys)].copy()

    n_after = len(kept)
    print(f"Removed {n_before - n_after} negative predictions ({n_before} -> {n_after})")

    return kept

def add_goa_ground_truth(predictions_df, goa_positive_df):
    """Add GOA ground truth annotations with score 1.0.

    Only GOA rows whose ``protein_id`` already appears in
    ``predictions_df['pid']`` are added. Duplicate (pid, term) pairs keep their
    maximum ``p``.

    Parameters
    ----------
    predictions_df : pandas.DataFrame
        Columns ``pid``, ``term``, ``p``.
    goa_positive_df : pandas.DataFrame or None
        Columns ``protein_id``, ``go_term``, ``score``.

    Returns
    -------
    pandas.DataFrame
        Columns ``pid``, ``term``, ``p``, grouped by (pid, term). The input is
        returned as-is when there is nothing to add.
    """
    nothing_to_add = goa_positive_df is None or goa_positive_df.empty
    if nothing_to_add:
        return predictions_df

    wanted_pids = set(predictions_df['pid'].unique())
    in_test = goa_positive_df['protein_id'].isin(wanted_pids)

    column_map = {'protein_id': 'pid', 'go_term': 'term', 'score': 'p'}
    goa_rows = (
        goa_positive_df.loc[in_test, list(column_map)]
        .rename(columns=column_map)
    )

    print(f"Found {len(goa_rows)} GOA annotations for test proteins")

    merged = (
        pd.concat([predictions_df, goa_rows], ignore_index=True)
        .groupby(['pid', 'term'], as_index=False)['p']
        .max()
    )

    print(f"After adding GOA: {len(predictions_df)} -> {len(merged)} predictions")

    return merged


In [None]:
# NOT-derived negative keys ("<pid>_<term>") and GOA positives, consumed by
# the post-processing steps of the submission cell.
negative_keys, goa_positive_df = load_goa_and_build_negative_keys(
    goa_path=GOA_PATH,
    go_children=go_children,
)

In [None]:
  class ProtDataset(Dataset):
    def __init__(self, pids, labels, embeddings, pid_to_index):
        """
        pids: list[str]               - protein ids
        labels: np.ndarray or sparse  - multi-hot labels
        embeddings: np.ndarray/mmap   - (N, D)
        pid_to_index: dict[str, int]
        """
        self.pids = pids
        self.labels = labels
        self.embeddings = embeddings
        self.pid_to_index = pid_to_index

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        pid = self.pids[idx]

        # embedding
        emb_idx = self.pid_to_index.get(pid)
        if emb_idx is None:
            raise KeyError(f"Embedding not found for protein {pid}")

        embed = self.embeddings[emb_idx]

        # label
        if hasattr(self.labels, "toarray"):  # sparse matrix
            label = self.labels[idx].toarray().ravel()
        else:
            label = self.labels[idx]

        return (
            torch.from_numpy(embed).float(),
            torch.from_numpy(label).float()
        )

class TestDataset(Dataset):
    """Inference dataset yielding one embedding tensor per protein id.

    Item ``idx`` is row ``pid_to_index[pids[idx]]`` of ``embeddings``, returned
    as a freshly allocated float32 tensor.
    """

    def __init__(self, pids, embeddings, pid_to_index):
        """
        pids: list[str]               - protein ids to score, in output order
        embeddings: np.ndarray/mmap   - (N, D)
        pid_to_index: dict[str, int]  - protein id -> row in ``embeddings``
        """
        self.pids = pids
        self.embeddings = embeddings
        self.pid_to_index = pid_to_index

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        pid = self.pids[idx]
        emb_idx = self.pid_to_index.get(pid)

        if emb_idx is None:
            raise KeyError(f"Embedding not found for protein {pid}")

        # Copy out of the (possibly read-only, memory-mapped) matrix so the
        # tensor does not alias non-writable memory.
        embed = np.array(self.embeddings[emb_idx], dtype=np.float32)
        return torch.from_numpy(embed)

In [None]:
class SimpleMLP(nn.Module):
    """MLP with two hidden stages mapping an embedding vector to per-term logits.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout(0.3), with
    widths 1024 then 512. A final Linear produces ``num_classes`` raw logits
    with no output activation.

    The layers sit in ``self.network`` (an ``nn.Sequential``) in exactly this
    order, so ``state_dict`` keys are ``network.<i>.*`` for i = 0..8.
    """

    def __init__(self, input_dim=1280, num_classes=3000):
        super().__init__()
        layers = []
        width_in = input_dim
        for width_out in (1024, 512):
            layers.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ])
            width_in = width_out
        layers.append(nn.Linear(width_in, num_classes))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits; for a 2-D input (batch, input_dim) the shape is (batch, num_classes)."""
        return self.network(x)

def get_improved_model(num_classes, input_dim=1280):
    """Build the per-aspect classifier.

    ``input_dim`` must equal the embedding width (``embeddings.shape[1]``).
    It defaults to the previously hard-coded 1280, so existing calls and
    checkpoints are unaffected.
    """
    return SimpleMLP(input_dim=input_dim, num_classes=num_classes)


In [None]:
#VISUALIZATION
def plot_losses_and_scores(train_losses, val_losses, train_scores, val_scores, aspect_name):
    """Draw per-epoch loss and micro-F1 curves side by side for one aspect.

    The figure is saved as ``<aspect_name>_training_curves.png`` (150 dpi)
    and then shown.
    """
    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(15, 5))

    panels = (
        (loss_ax, train_losses, val_losses, 'Train Loss', 'Val Loss',
         f'{aspect_name} - Loss', 'Loss'),
        (f1_ax, train_scores, val_scores, 'Train F1', 'Val F1',
         f'{aspect_name} - Micro F1', 'F1'),
    )
    for ax, train_curve, val_curve, train_label, val_label, title, ylabel in panels:
        ax.plot(train_curve, label=train_label)
        ax.plot(val_curve, label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig(f'{aspect_name}_training_curves.png', dpi=150)
    plt.show()

def _run_epoch(model, loader, loss_fn, f1_metric, optimizer=None, grad_clip=1.0):
    """One pass over ``loader``: trains when ``optimizer`` is given, otherwise evaluates.

    Returns ``(mean per-sample loss, micro-F1)``. F1 statistics are accumulated
    batch by batch with ``update``, so the whole epoch's predictions are never
    held in memory at once. Both values are NaN if the loader yields nothing.
    """
    training = optimizer is not None
    model.train(training)
    f1_metric.reset()
    loss_sum, n_seen = 0.0, 0

    with torch.set_grad_enabled(training):
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            loss = loss_fn(logits, y)

            if training:
                optimizer.zero_grad()
                loss.backward()
                if grad_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                optimizer.step()

            batch_n = X.size(0)
            loss_sum += loss.item() * batch_n
            n_seen += batch_n
            # torchmetrics expects integer targets; probabilities are thresholded internally.
            f1_metric.update(torch.sigmoid(logits.detach()), y.int())

    if n_seen == 0:
        return float("nan"), float("nan")

    # Average over samples actually seen (robust to drop_last).
    mean_loss = loss_sum / n_seen
    f1 = f1_metric.compute().item()
    f1_metric.reset()
    return mean_loss, f1


def train(
    aspect_name,
    model,
    train_loader,
    valid_loader,
    num_epochs,
    num_classes,
    lr,
    grad_clip=1.0,
    f1_threshold=0.05,
    weight_decay=1e-4,
):
    """Train ``model`` with BCE-with-logits and report micro-F1 every epoch.

    Parameters
    ----------
    aspect_name : str
        Label used in logs and in the saved training-curve figure.
    model : nn.Module
        Produces ``num_classes`` logits per sample; moved to ``device``.
    train_loader, valid_loader : DataLoader
        Yield ``(X, y)`` batches with multi-hot float targets ``y``.
    num_epochs : int
    num_classes : int
        Number of labels for the F1 metric.
    lr : float
        AdamW learning rate; halved by ReduceLROnPlateau after 2 epochs
        without validation-loss improvement.
    grad_clip : float or None
        Max gradient norm; ``None`` disables clipping.
    f1_threshold : float
        Probability threshold for the reported F1 (default 0.05).
    weight_decay : float
        AdamW weight decay (default 1e-4).

    Returns
    -------
    nn.Module
        The trained model, on ``device``.
    """
    model = model.to(device)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2
    )

    f1_metric = MultilabelF1Score(
        num_labels=num_classes,
        threshold=f1_threshold,
        average='micro'
    ).to(device)

    train_losses, val_losses = [], []
    train_scores, val_scores = [], []

    for epoch in range(1, num_epochs + 1):
        train_loss, train_f1 = _run_epoch(
            model, train_loader, loss_fn, f1_metric,
            optimizer=optimizer, grad_clip=grad_clip,
        )
        val_loss, val_f1 = _run_epoch(model, valid_loader, loss_fn, f1_metric)

        train_losses.append(train_loss)
        train_scores.append(train_f1)
        val_losses.append(val_loss)
        val_scores.append(val_f1)

        scheduler.step(val_loss)

        print(
            f"[{aspect_name}] Epoch {epoch:02d}/{num_epochs} | "
            f"Train Loss: {train_loss:.4f} | Train F1: {train_f1:.4f} | "
            f"Val Loss: {val_loss:.4f} | Val F1: {val_f1:.4f}"
        )

    plot_losses_and_scores(
        train_losses,
        val_losses,
        train_scores,
        val_scores,
        aspect_name
    )

    torch.cuda.empty_cache()
    return model

In [None]:
# GO sub-ontologies, keyed by the one-letter code used in the "aspect" column.
ASPECTS = list("CFP")
ASPECT_NAMES = dict(
    C='Cellular Component (CCO)',
    F='Molecular Function (MFO)',
    P='Biological Process (BPO)',
)
# Filled per aspect by training or checkpoint loading.
aspect_models, aspect_mlbs = {}, {}
def train_aspect_model(aspect_name, seed=42, batch_size=128, num_epochs=30, lr=1e-3):
    """Train the classifier for one GO aspect and register it globally.

    Parameters
    ----------
    aspect_name : str
        Value of the ``aspect`` column in ``train_terms_df`` (one of ``ASPECTS``).
    seed : int
        Seeds torch, numpy and ``random``, and the 90/10 train/valid split.
    batch_size, num_epochs, lr :
        Training hyper-parameters. The defaults are the previously hard-coded values.

    Reads module-level ``train_terms_df`` (columns ``EntryID``, ``term``,
    ``aspect``), ``train_sequences`` (iterated for protein ids),
    ``protein_ids`` and ``embeddings``.

    Side effects: stores the model in ``aspect_models[aspect_name]`` and the
    fitted ``MultiLabelBinarizer`` in ``aspect_mlbs[aspect_name]``.

    Returns
    -------
    (model, mlb)
    """
    print(f"\n===== Training model for aspect {aspect_name} =====")

    # Reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Row lookup into the embedding matrix (row i belongs to protein_ids[i]).
    pid_to_index = {pid: i for i, pid in enumerate(protein_ids)}

    aspect_df = train_terms_df[train_terms_df['aspect'] == aspect_name]

    # Group GO terms by protein
    protein_2_terms = (
        aspect_df.groupby('EntryID')['term']
        .apply(list)
        .to_dict()
    )

    # Keep proteins that have both an embedding and at least one term in this aspect.
    pid_train = [
        pid for pid in train_sequences
        if pid in pid_to_index and pid in protein_2_terms
    ]

    labels_list = [protein_2_terms[pid] for pid in pid_train]

    # Multi-label binarizer
    mlb = MultiLabelBinarizer(sparse_output=True)
    y_train_labels = mlb.fit_transform(labels_list)

    aspect_mlbs[aspect_name] = mlb
    num_classes = len(mlb.classes_)

    print(
        f"{aspect_name} | Proteins: {len(pid_train)} | "
        f"GO terms: {num_classes}"
    )

    # Dataset
    dataset = ProtDataset(pid_train, y_train_labels, embeddings, pid_to_index)

    train_size = int(0.9 * len(dataset))
    valid_size = len(dataset) - train_size

    train_part, valid_part = random_split(
        dataset,
        [train_size, valid_size],
        generator=torch.Generator().manual_seed(seed)
    )

    pin = torch.cuda.is_available()  # pinning only helps host->GPU copies

    train_loader = DataLoader(
        train_part,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=pin,
        # BatchNorm1d raises on a training batch of one sample; drop the final
        # batch only when it would be that size.
        drop_last=(len(train_part) % batch_size == 1),
    )

    valid_loader = DataLoader(
        valid_part,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=pin,
    )

    model = train(
        aspect_name=aspect_name,
        model=get_improved_model(num_classes),
        train_loader=train_loader,
        valid_loader=valid_loader,
        num_epochs=num_epochs,
        num_classes=num_classes,
        lr=lr
    )

    aspect_models[aspect_name] = model

    # Cleanup
    del dataset, train_part, valid_part, y_train_labels
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return model, mlb

In [None]:
# Train one model per aspect and checkpoint it. The label vocabulary (the
# binarizer's classes_, as plain strings) is stored next to the weights, so a
# later session can map output columns back to GO terms without retraining.
for aspect in ASPECTS:
    trained_model, trained_mlb = train_aspect_model(aspect)
    torch.save({
        "model": trained_model.state_dict(),
        "classes": [str(c) for c in trained_mlb.classes_],
    }, f"checkpoint-{aspect}.pth")

In [None]:
def predict(threshold=0.02, min_predictions=25, batch_size=128):
    """Score every embedded test protein with each aspect model.

    For each protein and aspect, terms whose sigmoid score is >= ``threshold``
    are kept. If fewer than ``min_predictions`` pass, the ``min_predictions``
    highest-scoring terms are kept instead.

    Reads module-level ``test_sequences``, ``protein_ids``, ``embeddings``,
    ``aspect_models`` and ``aspect_mlbs``; the binarizer's ``classes_`` maps
    output columns to GO terms.

    Returns
    -------
    pandas.DataFrame
        Columns ``pid``, ``term``, ``p``. The columns are present even when empty.
    """
    # Row lookup into the embedding matrix (row i belongs to protein_ids[i]).
    pid_to_index = {pid: i for i, pid in enumerate(protein_ids)}
    pid_test = [pid for pid in test_sequences if pid in pid_to_index]

    test_loader = DataLoader(
        TestDataset(pid_test, embeddings, pid_to_index),
        batch_size=batch_size,
        shuffle=False,  # output order must match pid_test
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )

    submission_list = []

    for aspect in ASPECTS:
        print(f"\n=== Predicting aspect {aspect} ===")

        model = aspect_models[aspect].to(device)
        model.eval()
        classes = aspect_mlbs[aspect].classes_

        pid_offset = 0
        with torch.no_grad():
            for batch_embeds in tqdm(test_loader, desc=f"Predicting ({aspect})"):
                logits = model(batch_embeds.to(device, non_blocking=True))
                probs = torch.sigmoid(logits).cpu().numpy()

                for i, scores in enumerate(probs):
                    protein_id = pid_test[pid_offset + i]

                    # Threshold filtering
                    term_indices = np.flatnonzero(scores >= threshold)

                    # Ensure minimum predictions (highest score first)
                    if len(term_indices) < min_predictions:
                        term_indices = np.argsort(scores)[-min_predictions:][::-1]

                    submission_list.extend(
                        {'pid': protein_id, 'term': classes[idx], 'p': float(scores[idx])}
                        for idx in term_indices
                    )

                pid_offset += len(probs)

        torch.cuda.empty_cache()

    submission_df = pd.DataFrame(submission_list, columns=['pid', 'term', 'p'])
    print(f"\nTotal predictions: {len(submission_df)}")

    return submission_df

In [None]:
def get_aspect_num_classes(aspect_name):
    """Return the number of distinct GO terms in the training labels of one aspect.

    Only proteins that appear in ``train_sequences``, have an embedding row
    (are in ``protein_ids``), and have at least one ``aspect_name``
    annotation in ``train_terms_df`` are counted. A ``MultiLabelBinarizer``
    fitted on those label lists has exactly one class per distinct term, so
    this equals its ``len(classes_)``.
    """
    known_pids = set(protein_ids)

    aspect_df = train_terms_df[train_terms_df['aspect'] == aspect_name]

    protein_2_terms = (
        aspect_df.groupby('EntryID')['term']
        .apply(list)
        .to_dict()
    )

    distinct_terms = set()
    for pid in train_sequences:
        if pid in known_pids and pid in protein_2_terms:
            distinct_terms.update(protein_2_terms[pid])

    return len(distinct_terms)

In [None]:
# Restore the per-aspect models (and their label vocabularies) from the
# checkpoints, then build the submission:
# raw predictions -> ancestor propagation -> NOT filtering -> GOA positives.
for aspect in ASPECTS:
    print(f"Loading model for aspect {aspect}")

    checkpoint_path = f"/kaggle/working/checkpoint-{aspect}.pth"
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    saved_classes = checkpoint.get("classes")
    if saved_classes is not None:
        # Rebuild the binarizer from the saved vocabulary so predict() can
        # decode output columns even when training was skipped this session.
        aspect_mlbs[aspect] = MultiLabelBinarizer(classes=list(saved_classes)).fit([])
        num_classes = len(saved_classes)
    elif aspect in aspect_mlbs:
        # Older checkpoint without a vocabulary: fall back to the binarizer
        # fitted by the training cell in this session.
        num_classes = len(aspect_mlbs[aspect].classes_)
    else:
        raise KeyError(
            f"No label vocabulary for aspect {aspect}: {checkpoint_path} has no "
            "'classes' entry and no binarizer was fitted in this session. "
            "Run the training cell first."
        )

    model = get_improved_model(num_classes)
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    model.eval()

    aspect_models[aspect] = model

combined_submission_df = predict(threshold=0.02)

propagated_df = propagate_predictions(
    combined_submission_df,
    go_parents
)

cleaned_df = apply_negative_propagation(
    propagated_df,
    negative_keys
)

final_df = add_goa_ground_truth(
    cleaned_df,
    goa_positive_df
)

# CAFA submission format: tab-separated pid, term, score with no header row.
final_df.to_csv(
    "submission.tsv",
    sep="\t",
    index=False,
    header=False
)