# CAFA-6 Protein Function Prediction: The Ensemble Era 

Welcome to this comprehensive pipeline for the **CAFA-6 Competition**!

In this notebook, we will build a powerful, **multi-model ensemble** system to predict protein functions (GO terms) from amino acid sequences. We will combine the strengths of two state-of-the-art Protein Language Models (pLMs):

1.  **ProtBERT**  (Rostlab)
2.  **ESM-2**  (Meta AI)

###  Key Features
-   **Visual EDA**: Beautiful plots to understand our data.
-   **GPU Optimization**: Efficient embedding generation.
-   **Ensemble Learning**: Averaging predictions for robustness.
-   **Clean Code**: Structured, commented, and easy to follow.

Let's dive in! 

In [None]:
# Biopython provides Bio.SeqIO, used below to parse the FASTA files. It is
# installed explicitly because importing it failed in this environment.
!pip install biopython --quiet

In [None]:
# Imports & Setup
import os
import sys
import gc
import json
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, T5EncoderModel
from Bio import SeqIO
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import sparse
from torch.utils.data import TensorDataset, DataLoader
from multiprocessing import Pool, cpu_count
import glob
# TPU support (optional): try to import torch_xla. USE_TPU becomes True only
# if every import below succeeds; otherwise the notebook falls back to GPU/CPU.
# NOTE(review): pl, met and xu are not referenced anywhere in the visible code.
USE_TPU = False
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.debug.metrics as met
    import torch_xla.utils.utils as xu
    USE_TPU = True
except ImportError:
    USE_TPU = False

#  Plotting Style
sns.set_theme(style="whitegrid", palette="viridis")
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 12

#  Paths
# INPUT_DIR holds the competition data. Outputs of this notebook (label
# matrices, GO maps, ancestor graph) are written under WORKING_DIR.
# NOTE(review): EMBEDDINGS_DIR and MODELS_DIR are created here but nothing in
# the visible code writes to them.
INPUT_DIR = "/kaggle/input/cafa-6-protein-function-prediction/"
WORKING_DIR = "/kaggle/working/"
OUTPUT_LABELS_DIR = os.path.join(WORKING_DIR, "outputs_labels")
EMBEDDINGS_DIR = os.path.join(WORKING_DIR, "embeddings")
MODELS_DIR = os.path.join(WORKING_DIR, "models")
os.makedirs(OUTPUT_LABELS_DIR, exist_ok=True)
os.makedirs(EMBEDDINGS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)

# Initial device choice: the XLA device when torch_xla imported successfully,
# otherwise CUDA if available, else CPU.
# NOTE(review): the device-selection cell before training re-assigns `device`
# and prefers CUDA over TPU when both are present.
if USE_TPU:
    device = torch_xla.device()
    print(" TPU detected. Using device:", device)
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(" Using device:", device)

## Part 1: Exploratory Data Analysis (EDA)
Before we model, we must understand. Let's load the data and visualize the sequence lengths and GO term distributions.

In [None]:
#  Load Data
# Sequences are keyed by the FASTA record id (Biopython's `rec.id`, the first
# whitespace-delimited token of the header). Train ids are reduced to the bare
# UniProt accession later with extract_uniprot_id().
print(" Loading sequences...")
train_seqs = {rec.id: str(rec.seq) for rec in SeqIO.parse(os.path.join(INPUT_DIR, "Train/train_sequences.fasta"), "fasta")}
test_seqs = {rec.id: str(rec.seq) for rec in SeqIO.parse(os.path.join(INPUT_DIR, "Test/testsuperset.fasta"), "fasta")}
# Annotation table; later cells use its "EntryID" and "term" columns.
train_terms_df = pd.read_csv(os.path.join(INPUT_DIR, "Train/train_terms.tsv"), sep="\t")

print(f" Train Sequences: {len(train_seqs):,}")
print(f" Test Sequences: {len(test_seqs):,}")
print(f" Train Annotations: {len(train_terms_df):,}")

# Plot Sequence Length Distribution
train_lens = [len(s) for s in train_seqs.values()]
test_lens = [len(s) for s in test_seqs.values()]

plt.figure(figsize=(14, 6))
sns.histplot(train_lens, color="blue", label="Train", kde=True, alpha=0.5, log_scale=True)
sns.histplot(test_lens, color="orange", label="Test", kde=True, alpha=0.5, log_scale=True)
plt.title(" Sequence Length Distribution (Log Scale)")
plt.xlabel("Sequence Length")
plt.legend()
plt.show()

#  Plot Top GO Terms
# value_counts() sorts descending, so head(20) gives the 20 most frequent terms.
top_terms = train_terms_df['term'].value_counts().head(20)
plt.figure(figsize=(14, 8))
# NOTE(review): recent seaborn versions warn when `palette` is passed without
# `hue`; check the installed version before changing this call.
sns.barplot(x=top_terms.values, y=top_terms.index, palette="viridis")
plt.title(" Top 20 Most Frequent GO Terms")
plt.xlabel("Count")
plt.show()

##  Part 2: The Knowledge Graph (Gene Ontology)
Protein functions are related! If a protein does X, and X is a subclass of Y, it also does Y. We need to parse the **Gene Ontology (GO)** graph to enforce these rules.

In [None]:
#  Parse OBO
def parse_obo(obo_file):
    """Parse the ``[Term]`` stanzas of an OBO file into a dictionary.

    Only ``[Term]`` stanzas are collected. Any other stanza type (for example
    the ``[Typedef]`` stanzas at the end of go-basic.obo) is skipped, so its
    ``id:`` / ``namespace:`` lines cannot overwrite the last parsed term.
    Header lines before the first stanza are ignored as well.

    Args:
        obo_file: path to an OBO-format file.

    Returns:
        dict mapping term id -> {'id': str, 'namespace': str, 'is_a': list[str]},
        where 'is_a' holds the parent ids with any trailing " ! name" removed.
    """
    terms = {}
    current_term = None  # dict for the [Term] being read, or None elsewhere

    with open(obo_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()

            if line.startswith('[') and line.endswith(']'):
                # A new stanza begins: store the finished term (if any), then
                # start collecting only if the new stanza is a [Term].
                if current_term is not None and 'id' in current_term:
                    terms[current_term['id']] = current_term
                current_term = {'is_a': [], 'namespace': ''} if line == "[Term]" else None
                continue

            if current_term is None:
                # File header or a non-Term stanza.
                continue

            if line.startswith("id: "):
                current_term['id'] = line[4:]
            elif line.startswith("namespace: "):
                current_term['namespace'] = line[11:]
            elif line.startswith("is_a: "):
                current_term['is_a'].append(line[6:].split(' ! ')[0])

    if current_term is not None and 'id' in current_term:
        terms[current_term['id']] = current_term
    return terms

go_terms = parse_obo(os.path.join(INPUT_DIR, "Train/go-basic.obo"))
print(f" Parsed {len(go_terms):,} GO terms.")

# Build Ancestors (Transitive Closure): for every GO term, collect the term
# itself plus every term reachable by repeatedly following is_a parents.
# NOTE(review): only is_a edges are followed (parse_obo does not read
# `relationship: part_of` lines); confirm this matches the evaluation's
# propagation rules.
from collections import deque

ancestors = {}
for term_id in tqdm(go_terms, desc=" Building Graph"):
    queue = deque([term_id])  # deque.popleft() is O(1); list.pop(0) is O(n)
    visited = set()
    while queue:
        curr = queue.popleft()
        if curr in visited:
            continue
        visited.add(curr)
        if curr in go_terms:
            queue.extend(go_terms[curr]['is_a'])
    # Sorted so ancestors.json (and downstream tie ordering) is reproducible.
    ancestors[term_id] = sorted(visited)
with open(os.path.join(OUTPUT_LABELS_DIR, "ancestors.json"), "w") as f:
    json.dump(ancestors, f)
print(" Ancestor graph saved.")

In [None]:
def extract_uniprot_id(header):
    """Return the UniProt accession embedded in a FASTA record id.

    Ids shaped like ``sp|P12345|NAME`` or ``tr|Q9XYZ1|NAME`` (two lowercase
    letters, then ``|accession|``) yield the accession. Any other string is
    returned unchanged as a fallback.
    """
    found = re.match(r"^[a-z]{2}\|([^|]+)\|", header)
    return header if found is None else found.group(1)

# Build mapping: extract the UniProt accession for every FASTA record. Row i
# of every label matrix is the i-th record of train_seqs, i.e. the same order
# in which the train embeddings are generated from train_seqs.
train_proteins_raw = list(train_seqs.keys())
train_proteins = [extract_uniprot_id(h) for h in train_proteins_raw]

prot_map = {pid: i for i, pid in enumerate(train_proteins)}

# Keep annotations whose protein appears in the FASTA file, and drop repeated
# (protein, term) rows: coo_matrix(...).tocsr() sums duplicate entries, which
# would yield label values > 1 (and could overflow int8).
valid_annots = (
    train_terms_df[train_terms_df["EntryID"].isin(prot_map)]
    .drop_duplicates(subset=["EntryID", "term"])
)

print("Proteins in FASTA:", len(train_seqs))
print("Proteins in annotations:", train_terms_df["EntryID"].nunique())
print("Overlap:", valid_annots["EntryID"].nunique())
print("Annotations kept:", len(valid_annots))

# Sorted so the column order (and maps.json) is reproducible across runs;
# iterating a set of strings is not, because of string hash randomization.
used_terms = sorted(set(train_terms_df["term"].unique()))

# Split terms by GO namespace; terms missing from the parsed ontology are dropped.
terms_MF = [t for t in used_terms if t in go_terms and go_terms[t]['namespace'] == 'molecular_function']
terms_BP = [t for t in used_terms if t in go_terms and go_terms[t]['namespace'] == 'biological_process']
terms_CC = [t for t in used_terms if t in go_terms and go_terms[t]['namespace'] == 'cellular_component']

# Column index (as a string, for JSON) -> GO id, per namespace.
term_maps = {
    'MF': {str(i): t for i, t in enumerate(terms_MF)},
    'BP': {str(i): t for i, t in enumerate(terms_BP)},
    'CC': {str(i): t for i, t in enumerate(terms_CC)}
}

with open(os.path.join(OUTPUT_LABELS_DIR, "maps.json"), "w") as f:
    json.dump(term_maps, f)

# One binary (proteins x terms) CSR matrix per namespace.
for ns, terms_list in [('MF', terms_MF), ('BP', terms_BP), ('CC', terms_CC)]:

    go_to_idx = {t: i for i, t in enumerate(terms_list)}

    df_ns = valid_annots[valid_annots["term"].isin(terms_list)]

    p_idx = df_ns["EntryID"].map(prot_map).values
    t_idx = df_ns["term"].map(go_to_idx).values

    mat = sparse.coo_matrix(
        (np.ones(len(p_idx), dtype=np.int8), (p_idx, t_idx)),
        shape=(len(train_proteins), len(terms_list)),
        dtype=np.int8
    ).tocsr()

    sparse.save_npz(os.path.join(OUTPUT_LABELS_DIR, f"labels_{ns}.npz"), mat)

    print(f"{ns}: proteins={mat.shape[0]}, terms={mat.shape[1]}, nonzero={mat.nnz}")


In [None]:
# Sanity check: reload each saved label matrix and print its number of stored
# positive labels and its (proteins, terms) shape.
for ns in ["MF", "BP", "CC"]:
    mat = sparse.load_npz(os.path.join(OUTPUT_LABELS_DIR, f"labels_{ns}.npz"))
    print(ns, "nonzero:", mat.count_nonzero(), "shape:", mat.shape)


##  Part 3: Feature Engineering (Embeddings)
Here we extract rich features from protein sequences using **three** different models (ProtBERT, ESM-2, and ProtT5). This is the secret sauce!
We define a generic embedding function for ProtBERT and ESM-2; ProtT5 gets its own embedding function further below.

In [None]:
def embed_sequences(model_name, seq_dict, batch_size=16, max_len=1024):
    """Embed protein sequences with a Hugging Face encoder using masked mean pooling.

    Args:
        model_name: Hugging Face model id (e.g. "Rostlab/prot_bert",
            "facebook/esm2_t30_150M_UR50D").
        seq_dict: mapping of protein id -> amino-acid sequence. Output rows
            follow the dict's iteration order.
        batch_size: number of sequences per forward pass.
        max_len: tokenizer truncation length in tokens.

    Returns:
        (embeddings, ids): a float32 array with one pooled row per sequence,
        and the list of ids in the same row order. Empty input gives a
        (0, hidden_size) array.

    Raises:
        RuntimeError: if no CUDA device is available (CPU embedding is refused).
    """
    # GPU only: embedding on CPU is far too slow for the full dataset.
    if not torch.cuda.is_available():
        raise RuntimeError(" GPU NOT FOUND — Cannot embed on CPU. Aborting as requested.")

    device = torch.device("cuda")
    print(" FORCING GPU:", device)

    name = model_name.lower()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)
    model = model.to(device)
    model.eval()
    hidden_size = model.config.hidden_size

    ids = list(seq_dict.keys())
    seqs = list(seq_dict.values())

    # Rostlab checkpoints (ProtBERT / ProtT5) expect space-separated residues;
    # ESM tokenizers take the raw sequence string.
    # NOTE(review): AutoModel on a T5 checkpoint loads the full encoder-decoder,
    # whose forward() needs decoder inputs; use prot_t5_embed for ProtT5.
    # NOTE(review): Rostlab's ProtBERT usage maps rare residues U/Z/O/B to X.
    # That is not done here, to stay consistent with the precomputed embeddings.
    space_separate = "t5" in name or "bert" in name

    embeddings = []
    try:
        for start in tqdm(range(0, len(seqs), batch_size)):
            batch = seqs[start:start + batch_size]
            batch_tok = [" ".join(s) for s in batch] if space_separate else batch

            inputs = tokenizer(
                batch_tok,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_len
            ).to(device)

            with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
                hidden = model(**inputs).last_hidden_state
                # float32 mask promotes the pooling arithmetic to float32.
                mask = inputs["attention_mask"].unsqueeze(-1).float()
                # Masked mean pooling over the token axis.
                emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

            embeddings.append(emb.float().cpu().numpy())
    finally:
        # Release the model's GPU memory once, instead of flushing the cache
        # and running the garbage collector after every batch.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    if not embeddings:
        return np.empty((0, hidden_size), dtype=np.float32), ids
    return np.concatenate(embeddings), ids


**Note: I computed the embeddings locally; you can find them in the attached dataset.**

In [None]:
# emb_train_bert, _ = embed_sequences("Rostlab/prot_bert", train_seqs, batch_size=128)
# np.save("train_protbert.npy", emb_train_bert)
# del emb_train_bert

# emb_test_bert, test_ids = embed_sequences("Rostlab/prot_bert", test_seqs, batch_size=128)
# np.save("test_protbert.npy", emb_test_bert)
# np.save("test_ids.npy", test_ids)
# del emb_test_bert
# NOTE: the embedding code above is commented out. The embeddings were computed
# offline and are loaded later from /kaggle/input/embeddings-for-cafa/, so this
# message does not mean anything was written in this session.
print("Protbert embeddings saved") 

In [None]:
# emb_train_esm, _ = embed_sequences("facebook/esm2_t30_150M_UR50D", train_seqs, batch_size=8)
# np.save("train_esm2.npy", emb_train_esm)
# del emb_train_esm

# emb_test_esm, _ = embed_sequences("facebook/esm2_t30_150M_UR50D", test_seqs, batch_size=8)
# np.save("test_esm2.npy", emb_test_esm)
# del emb_test_esm
# NOTE: the embedding code above is commented out (embeddings were computed
# offline), so this message does not mean anything was written in this session.
print("ESM-2(150M) embeddings saved")

**I am using a separate embedding function for ProtT5 because I had difficulty making a single unified function work for it as well.**

**We freeze ProtT5 and generate contextual protein embeddings, then train lightweight downstream heads.**

In [None]:
# model_name = "Rostlab/prot_t5_xl_uniref50"
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, do_lower_case=False)

# def preprocess_for_t5(seq):
#     return " ".join(list(seq))

# def tokenize_all(seqs, max_len=1024):
#     tok_inputs = []
#     for s in tqdm(seqs, desc="Pre-tokenizing"):
#         t = tokenizer(
#             preprocess_for_t5(s),
#             truncation=True,
#             max_length=max_len,
#             padding="max_length",   # <- CRITICAL for speed
#             return_tensors="np"
#         )
#         tok_inputs.append((t["input_ids"][0], t["attention_mask"][0]))
#     return np.array([x[0] for x in tok_inputs]), np.array([x[1] for x in tok_inputs])

# train_ids_np, train_mask_np = tokenize_all(list(train_seqs.values()))
# np.save("train_ids.npy", train_ids_np)
# np.save("train_mask.npy", train_mask_np)
# NOTE: the tokenization code above is commented out. The pre-tokenized arrays
# are loaded from /kaggle/input/embeddings-for-cafa/ in a later cell, so
# nothing is saved in this session.
print("Saved: train_ids.npy, train_mask.npy")


In [None]:
# def tokenize_all(seqs, max_len=1024):
#     tok_inputs = []
#     for s in tqdm(seqs, desc="Pre-tokenizing"):
#         t = tokenizer(
#             preprocess_for_t5(s),
#             truncation=True,
#             max_length=max_len,
#             padding="max_length",
#             return_tensors="np"
#         )
#         tok_inputs.append((t["input_ids"][0], t["attention_mask"][0]))
#     return np.array([x[0] for x in tok_inputs]), np.array([x[1] for x in tok_inputs])

# test_ids_np, test_mask_np = tokenize_all(list(test_seqs.values()))

# np.save("test_ids.npy", test_ids_np)
# np.save("test_mask.npy", test_mask_np)

# NOTE: the tokenization code above is commented out; the pre-tokenized
# test arrays are loaded from the attached dataset in the next cell.
print("Saved: test_ids.npy, test_mask.npy")


In [None]:
# Load the pre-tokenized ProtT5 inputs (token ids + attention masks) from the
# attached dataset. They are used by the (commented-out) prot_t5_embed calls below.
# NOTE(review): `test_ids` holds *token ids* here; the submission cell later
# rebinds `test_ids` to the test protein ids.
train_ids = np.load("/kaggle/input/embeddings-for-cafa/train_ids_T5.npy")
train_mask = np.load("/kaggle/input/embeddings-for-cafa/train_mask.npy")
test_ids = np.load("/kaggle/input/embeddings-for-cafa/test_ids_T5.npy")
test_mask = np.load("/kaggle/input/embeddings-for-cafa/test_mask.npy")
print('loaded the npy files necessary to embed T5')

In [None]:
def prot_t5_embed(
    model_name: str,
    ids_np: np.ndarray,
    mask_np: np.ndarray,
    batch_size: int = 32,
    layer_idx: int = 12,
):
    """Mean-pool one hidden layer of a frozen T5 encoder over pre-tokenized inputs.

    Args:
        model_name: Hugging Face id of a T5 encoder checkpoint
            (e.g. "Rostlab/prot_t5_xl_uniref50").
        ids_np: token ids, one row per protein (assumed 2-D and padded to a
            common length — TODO confirm for new inputs).
        mask_np: attention mask with the same shape as ``ids_np``.
        batch_size: proteins per forward pass.
        layer_idx: index into ``outputs.hidden_states``. Per the Transformers
            docs, entry 0 is the embedding output, so ``layer_idx=k`` selects
            the output of encoder block k.

    Returns:
        float32 array of shape (len(ids_np), d_model) with one pooled row per
        protein, in input order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"

    # Load model (fp16 weights on GPU, fp32 on CPU)
    model = T5EncoderModel.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if on_cuda else torch.float32
    ).to(device)
    model.eval()

    n_rows = ids_np.shape[0]
    hidden_dim = model.config.d_model

    # Preallocate output
    embeddings = np.zeros((n_rows, hidden_dim), dtype=np.float32)

    for start in tqdm(range(0, n_rows, batch_size), desc="ProtT5 Embedding"):
        end = min(start + batch_size, n_rows)
        ids_batch = ids_np[start:end]
        mask_batch = mask_np[start:end]

        # Inputs are padded to a fixed length. Columns after the last one that
        # is unmasked in *any* row of this batch are pure padding; they are
        # masked out of attention and pooling, so dropping them saves compute
        # without changing the result (up to floating-point noise).
        used_cols = np.flatnonzero(np.asarray(mask_batch).any(axis=0))
        width = int(used_cols[-1]) + 1 if used_cols.size else mask_batch.shape[1]

        ids = torch.tensor(ids_batch[:, :width], device=device)
        mask = torch.tensor(mask_batch[:, :width], device=device)

        with torch.no_grad(), torch.autocast(device_type=device.type, enabled=on_cuda):
            outputs = model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True
            )

            # Pool in float32: an fp16 sum over many tokens can overflow, and
            # clamp(min=1e-9) is only meaningful for a floating-point mask.
            h = outputs.hidden_states[layer_idx].float()
            mask_f = mask.unsqueeze(-1).float()
            pooled = (h * mask_f).sum(1) / mask_f.sum(1).clamp(min=1e-9)

        embeddings[start:end] = pooled.cpu().numpy()

        del ids, mask, outputs, h, pooled

    # Free the encoder's memory before returning.
    del model
    if on_cuda:
        torch.cuda.empty_cache()

    return embeddings


In [None]:
# emb_t5_train = prot_t5_embed(
#     "Rostlab/prot_t5_xl_uniref50",
#     train_ids, train_mask,
#     batch_size=4
# )

# np.save("/kaggle/working/train_prott5.npy", emb_t5_train)
# del emb_t5_train

# NOTE: the ProtT5 embedding call above is commented out, so no ProtT5 train
# embeddings are produced in this session.
print("Training prott5 completed")

**The test set is large, so I embed it in chunks and then combine the chunks.**

In [None]:
# MODEL_NAME = "Rostlab/prot_t5_xl_uniref50"
# MAX_LEN = 512
# LAYER_IDX = 12
# EMB_DIM = 1024

# MICRO_BATCH = 4                 # per TPU core (SAFE)
# NUM_CORES = len(xm.get_xla_supported_devices())  # ‚úÖ ALWAYS WORKS
# GLOBAL_BATCH = MICRO_BATCH * NUM_CORES            # 32

# CHUNK_SIZE = 4096               # proteins per chunk (~2‚Äì3 min)
# OUT_DIR = "/kaggle/working/prot_t5_test_chunks"
# os.makedirs(OUT_DIR, exist_ok=True)

# print(" TPU devices:", xm.get_xla_supported_devices())
# print(" TPU cores:", NUM_CORES)
# print(" Global batch:", GLOBAL_BATCH)

# # ======================================================
# # DATA
# # ======================================================
# # test_ids, test_mask must already exist
# N = test_ids.shape[0]
# print("Total test proteins:", N)

# # ======================================================
# # MODEL (TPU)
# # ======================================================
# device = xm.xla_device()

# model = T5EncoderModel.from_pretrained(
#     MODEL_NAME,
#     torch_dtype=torch.bfloat16
# ).to(device)

# model.eval()

# # ======================================================
# # RESUME LOGIC
# # ======================================================
# def get_next_chunk_id(out_dir):
#     files = [f for f in os.listdir(out_dir) if f.startswith("chunk_")]
#     if not files:
#         return 0
#     ids = sorted(int(f.split("_")[1].split(".")[0]) for f in files)
#     return ids[-1] + 1

# start_chunk = get_next_chunk_id(OUT_DIR)
# print(" Resuming from chunk:", start_chunk)

# # ======================================================
# # MAIN LOOP
# # ======================================================
# start_time = time.time()
# num_chunks = math.ceil(N / CHUNK_SIZE)

# for chunk_id in range(start_chunk, num_chunks):
#     c_start = chunk_id * CHUNK_SIZE
#     c_end = min(N, c_start + CHUNK_SIZE)
#     cur_N = c_end - c_start

#     print(f"\n Chunk {chunk_id} | proteins {c_start}:{c_end}")

#     out_chunk = np.zeros((cur_N, EMB_DIM), dtype=np.float32)
#     write_ptr = 0

#     for i in tqdm(
#         range(c_start, c_end, GLOBAL_BATCH),
#         desc=f"TPU Chunk {chunk_id}",
#         leave=False
#     ):
#         j = min(i + GLOBAL_BATCH, c_end)
#         bs = j - i

#         ids = test_ids[i:j, :MAX_LEN]
#         mask = test_mask[i:j, :MAX_LEN]

#         # Pad to GLOBAL_BATCH
#         if bs < GLOBAL_BATCH:
#             pad = GLOBAL_BATCH - bs
#             ids = np.pad(ids, ((0, pad), (0, 0)))
#             mask = np.pad(mask, ((0, pad), (0, 0)))

#         ids_t = torch.tensor(ids, device=device)
#         mask_t = torch.tensor(mask, device=device)

#         with torch.no_grad():
#             out = model(
#                 input_ids=ids_t,
#                 attention_mask=mask_t,
#                 output_hidden_states=True,
#                 return_dict=True
#             )

#             h = out.hidden_states[LAYER_IDX]     # (B, L, 1024)
#             mask_f = mask_t.unsqueeze(-1).to(h.dtype)
#             pooled = (h * mask_f).sum(1) / mask_f.sum(1).clamp(min=1e-9)

#         pooled = pooled[:bs].float().cpu().numpy()
#         out_chunk[write_ptr:write_ptr + bs] = pooled
#         write_ptr += bs

#         xm.mark_step()   # TPU sync
#         del ids_t, mask_t, h, pooled
#         gc.collect()

#     # SAVE CHUNK
#     chunk_path = os.path.join(OUT_DIR, f"chunk_{chunk_id:03d}.npy")
#     np.save(chunk_path, out_chunk)
#     print(f" Saved {chunk_path}")

# elapsed = (time.time() - start_time) / 60
# NOTE: the chunked TPU embedding code above is commented out, so nothing is
# computed in this session. If re-enabled, it also needs `import time` and
# `import math` (used for start_time and num_chunks); the setup cell imports
# neither.
print(f"\n DONE with ProtT5 Embeddings")
# print(f" Total time: {elapsed:.1f} minutes")


In [None]:
# files = sorted(glob.glob("/kaggle/working/prot_t5_test_chunks/chunk_*.npy"))
# X = np.concatenate([np.load(f) for f in files], axis=0)

# np.save("/kaggle/working/prot_t5_test.npy", X)
# NOTE: the chunk-concatenation code above is commented out, so this message
# does not mean a file was written in this session.
print(" Final test embeddings saved")


##  Part 4: Training the Single-Head Classifier

We will train a separate single-head **Multi-Layer Perceptron (MLP)** for each embedding type and each GO namespace: Molecular Function (MF), Biological Process (BP), and Cellular Component (CC).

I first tried a multi-head classifier, but I kept getting a `MemoryError`, so I switched to single-head models.

In [None]:
# Precomputed ProtBERT / ESM2 embeddings (attached dataset) and the label
# matrices written in Part 2.
BASE_EMB = "/kaggle/input/embeddings-for-cafa/"         # ProtBERT & ESM2 embedding files
BASE_LABELS = "/kaggle/working/outputs_labels/"  # labels_MF.npz, labels_BP.npz, labels_CC.npz

# One sparse (proteins x terms) label matrix per GO namespace.
labels_MF = sparse.load_npz(BASE_LABELS + "labels_MF.npz")
labels_BP = sparse.load_npz(BASE_LABELS + "labels_BP.npz")
labels_CC = sparse.load_npz(BASE_LABELS + "labels_CC.npz")

labels_by_ns = {
    "MF": labels_MF,
    "BP": labels_BP,
    "CC": labels_CC
}

# Number of GO terms (= output units of the classifier) per namespace.
out_dims = {
    ns: mat.shape[1] for ns, mat in labels_by_ns.items()
}

print(" Loaded labels. Output dimensions:", out_dims)


In [None]:
from sklearn.model_selection import train_test_split

# All protein indices (rows of the label matrices)
all_ids = np.arange(labels_MF.shape[0])

# 10% validation split (fixed seed, so the split is identical across runs)
train_idx, val_idx = train_test_split(all_ids, test_size=0.1, random_state=42)

def split_labels(mat):
    """Return (train rows, validation rows) of `mat` using the global split indices."""
    return mat[train_idx], mat[val_idx]

labels_train = {}
labels_val = {}

# Split MF/BP/CC label matrices
for ns, mat in labels_by_ns.items():
    labels_train[ns], labels_val[ns] = split_labels(mat)

# Load embeddings (memory-mapped; row i must match label row i)
train_emb_prot = np.load(BASE_EMB + "train_protbert.npy", mmap_mode="r")
train_emb_esm  = np.load(BASE_EMB + "train_esm2.npy", mmap_mode="r")

# Split embeddings. Fancy indexing a memmap with an index array copies the
# selected rows into an in-memory array.
prot_train = train_emb_prot[train_idx]
prot_val   = train_emb_prot[val_idx]

esm_train  = train_emb_esm[train_idx]
esm_val    = train_emb_esm[val_idx]

print("Train/Val Split Complete")
print("ProtBERT train:", prot_train.shape, "val:", prot_val.shape)
print("ESM2 train:", esm_train.shape, "val:", esm_val.shape)


In [None]:
# NOTE(review): PyTorch reads this setting when the CUDA caching allocator
# initialises, so it only takes effect if no CUDA memory has been allocated yet
# in this kernel — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# confirming device detection
# USING_TPU is True only if an XLA device could actually be created.
# NOTE(review): this flag is separate from USE_TPU (setup cell), which only
# records that torch_xla imported.
USING_TPU = False
try:
    import torch_xla.core.xla_model as xm
    device_tpu = xm.xla_device()
    USING_TPU = True
except Exception:
    device_tpu = None

# Preference order: CUDA, then TPU, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(" GPU detected. Executing on CUDA.")
elif USING_TPU:
    device = device_tpu
    print(" TPU detected. Executing on TPU.")
else:
    device = torch.device("cpu")
    print(" CPU fallback:", device)


In [None]:
class SingleHeadMLP(nn.Module):
    """Two-block residual MLP mapping one embedding vector to per-term logits.

    Data flow: LayerNorm(input) -> block1 -> block2; the output layer receives
    block1's output plus block2's output (a residual connection).

    Args:
        input_dim: size of the input embedding.
        proj_dim: width of both hidden blocks.
        out_dim: number of output logits (GO terms in one namespace).
    """

    def __init__(self, input_dim, proj_dim, out_dim):
        super().__init__()

        def _block(in_features):
            # Linear -> LayerNorm -> GELU -> Dropout(0.3), output width proj_dim.
            return nn.Sequential(
                nn.Linear(in_features, proj_dim),
                nn.LayerNorm(proj_dim),
                nn.GELU(),
                nn.Dropout(0.3),
            )

        self.ln_in = nn.LayerNorm(input_dim)
        self.block1 = _block(input_dim)
        self.block2 = _block(proj_dim)
        self.out_linear = nn.Linear(proj_dim, out_dim)

    def forward(self, x):
        first = self.block1(self.ln_in(x))
        second = self.block2(first)
        return self.out_linear(second + first)
        
class ProteinDatasetNS(Dataset):
    """Pair each embedding row with its dense multi-hot label row for one namespace.

    Args:
        emb_memmap: 2-D array-like of embeddings (NumPy array or memmap), row-indexed.
        label_mat: sparse label matrix with one row per embedding row.
    """

    def __init__(self, emb_memmap, label_mat):
        self.emb = emb_memmap
        # Despite its name, this holds a single namespace's label matrix.
        self.labels_by_ns = label_mat

    def __len__(self):
        n_rows, *_ = self.emb.shape
        return n_rows

    def __getitem__(self, idx):
        features = torch.tensor(self.emb[idx], dtype=torch.float32)
        dense_row = self.labels_by_ns[idx].toarray().ravel()
        return features, torch.tensor(dense_row, dtype=torch.float32)

In [None]:
def _bce_pos_weight(label_mat, max_weight=50.0):
    """Per-term BCE pos_weight = #negatives / #positives, clipped to [1, max_weight]."""
    label_counts = np.array(label_mat.sum(axis=0)).ravel()
    neg_counts = label_mat.shape[0] - label_counts
    # Terms with no positives get ~1e6 before clipping, i.e. max_weight.
    pos_weight_np = np.clip(neg_counts / (label_counts + 1e-6), 1.0, max_weight)  # you can try 20.0–50.0 range
    return torch.tensor(pos_weight_np, dtype=torch.float32)


def _validation_loss(model, val_emb, val_labels, criterion, batch_size):
    """Mean per-batch `criterion` loss of `model` on held-out data (eval mode, no grad)."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for start in range(0, val_emb.shape[0], batch_size):
            stop = start + batch_size
            emb = torch.tensor(np.asarray(val_emb[start:stop]), dtype=torch.float32).to(device)
            lab = torch.tensor(val_labels[start:stop].toarray(), dtype=torch.float32).to(device)
            total += criterion(model(emb), lab).item()
            n_batches += 1
    return total / max(1, n_batches)


def _run_diagnostics(model, model_name, ns, input_dim):
    """Print sigmoid-probability statistics of `model` on the first 4 test proteins.

    The test embedding file is picked by embedding width (1024 -> ProtBERT,
    640 -> ESM2-150M). Other widths are skipped with a message instead of
    raising, so an already-trained model is not lost.
    """
    print(f"\n[{model_name} | {ns}] Running inference diagnostics...")

    if input_dim == 1024:
        test_path = BASE_EMB + "test_protbert.npy"
    elif input_dim == 640:
        test_path = BASE_EMB + "test_esm2.npy"
    else:
        print(f"[{model_name} | {ns}] Skipping diagnostics: unknown embedding dim {input_dim}.")
        return

    test_emb_small = np.load(test_path, mmap_mode="r")[:4]
    test_emb_small = torch.tensor(test_emb_small, dtype=torch.float32).to(device)

    with torch.no_grad():
        raw_logits = model(test_emb_small)
        probs = torch.sigmoid(raw_logits)

    print("\nSample probabilities (first protein, first 20 terms):")
    print(probs[0][:20].cpu().numpy())

    print("\nMax prob:", probs.max().item())
    print("Min prob:", probs.min().item())
    print("Mean prob:", probs.mean().item())

    if probs.max().item() < 0.001:
        print("\nWARNING: Model predictions are EXTREMELY SMALL — empty submission risk.")
    else:
        print("\nDiagnostics OK — model produces real signals.")

    del test_emb_small, raw_logits, probs
    gc.collect()


def train_ns_only(model_name, ns, train_emb_path, label_mat,
                  proj_dim=768, batch_size=16, epochs=12,
                  train_emb_tensor=None, val_emb=None, val_labels=None,
                  diagnose=True):
    """Train one SingleHeadMLP for a single (embedding model, GO namespace) pair.

    Args:
        model_name: label used in log messages (e.g. "ProtBERT").
        ns: namespace key used in log messages ("MF", "BP" or "CC").
        train_emb_path: .npy file of training embeddings; only read
            (memory-mapped) when `train_emb_tensor` is None.
        label_mat: sparse (proteins x terms) labels aligned with the embeddings.
        proj_dim: hidden width of the MLP.
        batch_size: training batch size.
        epochs: number of epochs (also the cosine-annealing period).
        train_emb_tensor: optional in-memory training embeddings; takes
            precedence over `train_emb_path`.
        val_emb, val_labels: optional held-out embeddings and sparse labels.
            When both are given, a validation loss is printed every epoch.
        diagnose: if True, print probability statistics on 4 test proteins
            after training.

    Returns:
        The trained model, on the global `device` and in eval mode.
    """
    print(f"\n[{model_name} | {ns}] TRAINING...")

    if train_emb_tensor is not None:
        train_emb = train_emb_tensor
    else:
        train_emb = np.load(train_emb_path, mmap_mode="r")

    input_dim = train_emb.shape[1]
    out_dim = label_mat.shape[1]

    dataset = ProteinDatasetNS(train_emb, label_mat)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)

    model = SingleHeadMLP(input_dim, proj_dim, out_dim).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=epochs,
        eta_min=3e-5
    )

    # Up-weight positives to counter the extreme label sparsity.
    criterion = nn.BCEWithLogitsLoss(pos_weight=_bce_pos_weight(label_mat).to(device))

    # Choose the execution path from the device the model actually lives on.
    # The module-level TPU flags only say torch_xla is importable; when a GPU
    # is also present `device` is CUDA and XLA optimizer calls would be wrong.
    device_type = getattr(device, "type", str(device))
    on_tpu = device_type == "xla"
    use_amp = device_type == "cuda"  # mixed precision on CUDA only
    scaler = torch.amp.GradScaler(enabled=use_amp)

    has_val = val_emb is not None and val_labels is not None

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for emb, lab in tqdm(dataloader, desc=f"{model_name} {ns} Epoch {epoch+1}", leave=False):
            emb, lab = emb.to(device), lab.to(device)
            optimizer.zero_grad(set_to_none=True)

            if on_tpu:
                out = model(emb)
                loss = criterion(out, lab)
                loss.backward()
                xm.optimizer_step(optimizer)
                xm.mark_step()
            elif use_amp:
                with torch.amp.autocast("cuda", dtype=torch.float16):
                    out = model(emb)
                    loss = criterion(out, lab)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out = model(emb)
                loss = criterion(out, lab)
                loss.backward()
                optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / max(1, len(dataloader))
        message = f"[{model_name} | {ns}] Epoch {epoch+1}/{epochs} - Loss = {avg_loss:.4f}"
        if has_val:
            val_loss = _validation_loss(model, val_emb, val_labels, criterion, max(batch_size, 256))
            message += f" - Val Loss = {val_loss:.4f}"
        print(message)

        scheduler.step()

    # Cleanup (train_emb only drops this function's reference)
    del dataset, dataloader, train_emb
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model.eval()
    print(f"[{model_name} | {ns}] TRAINING COMPLETE.")

    if diagnose:
        _run_diagnostics(model, model_name, ns, input_dim)

    return model


In [None]:
# Re-bind the full per-namespace label matrices (training itself uses the
# labels_train / labels_val splits).
labels_MF = labels_by_ns["MF"]
labels_BP = labels_by_ns["BP"]
labels_CC = labels_by_ns["CC"]

namespaces = ["MF", "BP", "CC"]

# namespace -> trained SingleHeadMLP, one dict per embedding model.
protbert_models = {}
esm2_models = {}

for ns in namespaces:
    print(f"\n========== NAMESPACE: {ns} ==========")

    # Train ProtBERT head on the in-memory train split
    protbert_models[ns] = train_ns_only(
        model_name="ProtBERT",
        ns=ns,
        train_emb_path=os.path.join(BASE_EMB, "train_protbert.npy"),
        label_mat=labels_train[ns],
        train_emb_tensor=prot_train,     # takes precedence over train_emb_path
        val_emb=prot_val,
        val_labels=labels_val[ns]
    )



    # Train ESM2 head on the in-memory train split
    esm2_models[ns] = train_ns_only(
        model_name="ESM2-150M",
        ns=ns,
        train_emb_path=os.path.join(BASE_EMB, "train_esm2.npy"),
        label_mat=labels_train[ns],
        train_emb_tensor=esm_train,       # takes precedence over train_emb_path
        val_emb=esm_val,
        val_labels=labels_val[ns]
    )



print("\nAll ProtBERT and ESM2 models trained")


##  Part 5: Submission

After completing training for all namespaces (MF, BP, CC) with both the ProtBERT and ESM2-150M models, we generate predictions for the full CAFA-6 test set.
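Instead of saving massive `.npy` files (which are easily corrupted and can exceed Kaggle’s storage limits), we run a streamed ensemble pipeline that scores the test set chunk by chunk and writes the predictions directly to `submission.tsv`.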

In [None]:
# Load the column-index -> GO-id maps and the ancestor closure written in Part 2.
with open(os.path.join(OUTPUT_LABELS_DIR, "maps.json")) as f:
    maps = json.load(f)

with open(os.path.join(OUTPUT_LABELS_DIR, "ancestors.json")) as f:
    ancestors = json.load(f)

IMPORTED_EMBEDDINGS_DIR = BASE_EMB

# Protein ids (row order of the test embeddings) and both embedding sets,
# memory-mapped so only one chunk is materialised at a time.
test_ids = np.load(os.path.join(IMPORTED_EMBEDDINGS_DIR, "test_ids.npy"))
test_prot = np.load(os.path.join(IMPORTED_EMBEDDINGS_DIR, "test_protbert.npy"), mmap_mode="r")
test_esm  = np.load(os.path.join(IMPORTED_EMBEDDINGS_DIR, "test_esm2.npy"), mmap_mode="r")

N_test = test_prot.shape[0]
# Explicit checks rather than `assert` (stripped under `python -O`): all three
# arrays must describe the same proteins in the same order.
if len(test_ids) != N_test:
    raise ValueError("Mismatch between test_ids and test_prot size.")
if test_esm.shape[0] != N_test:
    raise ValueError("Mismatch between test_esm and test_prot size.")

print(f"Test proteins: {N_test}")
print("Starting streamed ensemble + submission generation...")

# Per-namespace (ProtBERT weight, ESM2 weight) for averaging logits.
ensemble_weights = {
    "MF": (0.4, 0.6),
    "BP": (0.3, 0.7),
    "CC": (0.5, 0.5)
}

# Logits are divided by T before the sigmoid.
temperatures = {
    "MF": 0.9,
    "BP": 0.8,
    "CC": 1.0
}

# Minimum probability for a directly predicted term to be kept.
thresholds = {
    "MF": 1e-3,
    "BP": 2e-3,
    "CC": 5e-4
}

# Column index -> GO id as an object array, so a whole index array can be mapped at once.
maps_idx_to_go = {
    ns: np.array([maps[ns][str(i)] for i in range(len(maps[ns]))], dtype=object)
    for ns in namespaces
}

ancestors_fast = ancestors

TOPK_RAW = 3000                # max directly-predicted terms per namespace
MAX_TERMS_PER_PROTEIN = 1500   # submission limit per protein

CHUNK_SIZE = 512
submission_path = os.path.join(WORKING_DIR, "submission.tsv")


def _propagate_max(term_scores, ancestor_map):
    """Copy `term_scores` and give every ancestor the max score of its predicted descendants."""
    prop = dict(term_scores)
    for go_id, score in term_scores.items():
        for anc in ancestor_map.get(go_id, ()):
            if anc not in prop or score > prop[anc]:
                prop[anc] = score
    return prop


with open(submission_path, "w") as fout:
    for start in tqdm(range(0, N_test, CHUNK_SIZE), desc="Batches"):
        end = min(N_test, start + CHUNK_SIZE)

        # Load embeddings for this batch
        prot_batch_np = test_prot[start:end]
        esm_batch_np  = test_esm[start:end]

        prot_batch = torch.tensor(prot_batch_np, dtype=torch.float32).to(device)
        esm_batch  = torch.tensor(esm_batch_np, dtype=torch.float32).to(device)

        # Weighted logit average -> temperature -> sigmoid, per namespace.
        ns_scores = {}
        with torch.no_grad():
            for ns in namespaces:
                w_pb, w_es = ensemble_weights[ns]
                logits = (w_pb * protbert_models[ns](prot_batch)) + (w_es * esm2_models[ns](esm_batch))
                ns_scores[ns] = torch.sigmoid(logits / temperatures[ns]).cpu().numpy()

        batch_ids = test_ids[start:end]
        chunk_lines = []

        for local_idx, pid in enumerate(batch_ids):
            protein_lines = []

            for ns in namespaces:
                scores = ns_scores[ns][local_idx]
                idxs = np.where(scores > thresholds[ns])[0]
                if idxs.size == 0:
                    continue

                if idxs.size > TOPK_RAW:
                    topk_idx = np.argpartition(scores[idxs], -TOPK_RAW)[-TOPK_RAW:]
                    idxs = idxs[topk_idx]

                # Named batch_go_ids (not go_terms) so the parsed-ontology
                # global `go_terms` is not overwritten.
                batch_go_ids = maps_idx_to_go[ns][idxs]
                term_scores = {go_id: float(v) for go_id, v in zip(batch_go_ids, scores[idxs])}

                protein_lines.extend(_propagate_max(term_scores, ancestors_fast).items())

            if protein_lines:
                protein_lines.sort(key=lambda x: x[1], reverse=True)
                chunk_lines.extend(
                    f"{pid}\t{go_id}\t{score:.3f}\n"
                    for go_id, score in protein_lines[:MAX_TERMS_PER_PROTEIN]
                )

        if chunk_lines:
            fout.writelines(chunk_lines)

        del prot_batch, esm_batch, prot_batch_np, esm_batch_np, ns_scores, chunk_lines
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

print(f"Submission written to: {submission_path}")


In [None]:
# Sanity-check the written submission.tsv before uploading.

# Paths
WORK_DIR = "/kaggle/working/"
sub_path = os.path.join(WORK_DIR, "submission.tsv")
ids_path = "/kaggle/input/embeddings-for-cafa/test_ids.npy"
maps_path = os.path.join(WORK_DIR, "outputs_labels/maps.json")

print("🔍 Loading data...")

SUBMISSION_COLUMNS = ["protein", "goterm", "score"]

# Load submission. If read_csv raises EmptyDataError (nothing to parse), use an
# empty frame so the EMPTY FILE CHECK below reports it. Other read/parse
# failures are re-raised as ValueError with the original cause attached.
try:
    df = pd.read_csv(sub_path, sep="\t", header=None, names=SUBMISSION_COLUMNS)
except pd.errors.EmptyDataError:
    df = pd.DataFrame(columns=SUBMISSION_COLUMNS)
except (OSError, ValueError) as err:
    raise ValueError("❌ FAILED TO READ submission.tsv. Check file format.") from err

# Load test IDs
test_ids = np.load(ids_path)

# Load GO maps
with open(maps_path) as f:
    maps = json.load(f)

# Build GO term set (every term the classifiers can output directly)
all_go_terms = {go_id for ns_key in ["MF", "BP", "CC"] for go_id in maps[ns_key].values()}

print("✅ Loaded submission, test IDs, and GO term maps.")


# ------------------------
# 1. EMPTY FILE CHECK
# ------------------------
print("\n=== EMPTY FILE CHECK ===")
if df.shape[0] == 0:
    print("❌ ERROR: submission.tsv is EMPTY.")
    print("Possible causes:")
    print("- You used an incorrect final_preds path.")
    print("- Your submission loop never appended predictions.")
    print("- Maps.json or ancestors.json mismatched.")
    raise SystemExit()

print(f"✅ submission.tsv contains {df.shape[0]} rows.")


# ------------------------
# 2. BASIC STRUCTURE CHECK
# ------------------------
print("\n=== BASIC STRUCTURE CHECK ===")

if df.shape[1] != 3:
    print("❌ ERROR: submission must have 3 columns: protein, goterm, score")
    print("You have:", df.shape[1])
    raise SystemExit()

print("✅ Correct number of columns.")


# ------------------------
# 3. CHECK IF ALL PROTEINS APPEAR
# ------------------------
print("\n=== PROTEIN COUNT CHECK ===")

unique_proteins = df["protein"].nunique()
expected = len(test_ids)

print("Proteins in submission:", unique_proteins)
print("Proteins expected:", expected)

if unique_proteins == 0:
    print("❌ ERROR: No proteins found in submission.")
    raise SystemExit()

if unique_proteins < expected:
    print("⚠️ WARNING: Missing proteins from submission.")
else:
    print("✅ All proteins present (or extra rows exist).")


# ------------------------
# 4. SCORE VALIDITY CHECK
# ------------------------
print("\n=== SCORE CHECK ===")

bad_scores = df[~df["score"].between(0, 1)]

if len(bad_scores) > 0:
    print(f"❌ ERROR: {len(bad_scores)} invalid scores found.")
else:
    print("✅ All scores are between 0 and 1.")


# ------------------------
# 5. GO TERM VALIDITY CHECK
# ------------------------
print("\n=== GO TERM VALIDATION ===")

# NOTE: ancestors added by score propagation that never occur in the training
# annotations are absent from maps.json, so this can warn on valid rows.
invalid_go = df[~df["goterm"].isin(all_go_terms)]

if len(invalid_go) > 0:
    print(f"⚠️ WARNING: {len(invalid_go)} GO terms not found in maps.json.")
else:
    print("✅ All GO terms are known and valid.")


# ------------------------
# 6. DUPLICATE CHECK
# ------------------------
print("\n=== DUPLICATE PER-PROTEIN CHECK ===")

# keep=False marks every row of a duplicated (protein, goterm) pair.
dupes = df[df.duplicated(subset=["protein", "goterm"], keep=False)]
if len(dupes) > 0:
    print(f"⚠️ WARNING: Found {len(dupes)} duplicate GO annotations.")
else:
    print("✅ No duplicates detected.")


# ------------------------
# 7. TOP 1500 LIMIT CHECK
# ------------------------
print("\n=== PREDICTION LIMIT CHECK (1500 per protein) ===")

counts = df.groupby("protein").size()
over = counts[counts > 1500]

if len(over) > 0:
    print(f"❌ ERROR: {len(over)} proteins exceed 1500 GO terms.")
else:
    print("✅ All proteins obey the 1500-term rule.")


# ------------------------
# 8. SAFE SAMPLE PREVIEW
# ------------------------
print("\n=== SAMPLE PREVIEW ===")

print(df.head(10))

first_protein = df["protein"].iloc[0]   # safe: the empty case exited above
print(f"\nSample GO terms for protein {first_protein}:")
print(df[df["protein"] == first_protein].head(20))


print("\n VALIDATION COMPLETE — FILE IS STRUCTURALLY SOUND.")
