In [None]:
# Input and output locations used by the rest of the notebook.
# Every input path points into the competition dataset directory; the only
# path written to is OUTPUT_TSV.

# TSV FILES (tab-separated tables)
SAMPLE_SUBMISSION_TSV = "/kaggle/input/cafa-6-protein-function-prediction/sample_submission.tsv"
IA_TSV = "/kaggle/input/cafa-6-protein-function-prediction/IA.tsv"
TESTSUPERSET_TAXON_LIST_TSV = "/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset-taxon-list.tsv"
TRAIN_TERMS_TSV = "/kaggle/input/cafa-6-protein-function-prediction/Train/train_terms.tsv"
TRAIN_TAXONOMY_TSV = "/kaggle/input/cafa-6-protein-function-prediction/Train/train_taxonomy.tsv"

# FASTA FILES (protein sequences for the train and test sets)
TESTSUPERSET_FASTA = "/kaggle/input/cafa-6-protein-function-prediction/Test/testsuperset.fasta"
TRAIN_SEQUENCES_FASTA = "/kaggle/input/cafa-6-protein-function-prediction/Train/train_sequences.fasta"

# OBO FILE (GO graph; parsed by parse_obo for is_a / part_of edges)
GO_BASIC_OBO = "/kaggle/input/cafa-6-protein-function-prediction/Train/go-basic.obo"

# OUTPUT FILE (submission written by the streaming prediction cell)
OUTPUT_TSV = "/kaggle/working/submission.tsv"

print("Files are listed!!!")

In [None]:
# ------------------------------------------------------------
# CONFIG
# Single dictionary holding every tunable of the pipeline. It must be defined
# before the imports cell, which reads CONFIG["RANDOM_SEED"].
# ------------------------------------------------------------
CONFIG = {
    # Input / output file locations (see the path constants cell)
    "TRAIN_FASTA": TRAIN_SEQUENCES_FASTA,
    "TRAIN_TERMS": TRAIN_TERMS_TSV,
    "TRAIN_TAXONOMY": TRAIN_TAXONOMY_TSV,
    "GO_OBO": GO_BASIC_OBO,
    "IA_FILE": IA_TSV,
    "TEST_FASTA": TESTSUPERSET_FASTA,
    "TEST_TAXON_LIST": TESTSUPERSET_TAXON_LIST_TSV,
    "SAMPLE_SUBMISSION": SAMPLE_SUBMISSION_TSV,
    "OUTPUT_SUBMISSION": OUTPUT_TSV,
    # Embedding settings:
    # True -> protein language model embeddings (also requires TORCH_AVAILABLE);
    # False -> TF-IDF 3-mer fallback.
    "USE_PLM_MODEL": False,
    # Local path of the PLM checkpoint. embed_with_plm_to_memmap raises
    # FileNotFoundError if this path does not exist.
    "PLM_MODEL_NAME_OR_PATH": "/kaggle/input/esm-2/keras/esm2_t6_8m/1",
    # NOTE(review): not referenced anywhere else in the visible notebook.
    "PLM_BATCH_SIZE": 8,
    # Memory & batch sizes for streaming
    "EMBED_BATCH_SIZE": 8,          # batch size used when embedding (train & test)
    "PREDICT_BATCH_SIZE": 64,       # buffered embeddings needed before a predict/flush at test time
    # Label limitation to keep model small: keep only the K most frequent GO terms
    # (None keeps every term).
    "TOP_K_LABELS": 3000,
    # Model training hyperparams
    "RANDOM_SEED": 42,
    "BATCH_SIZE": 32,
    "EPOCHS": 10,
    "LEARNING_RATE": 1e-3,
    "HIDDEN_UNITS": 512,
    "DROPOUT": 0.2,
    # Submission postprocessing
    # Max number of terms written per protein. The searched threshold is only
    # applied at test time when this is None.
    "TOP_K_PER_PROTEIN": 200,
    "GLOBAL_THRESHOLD_SEARCH": True,
    "THRESHOLD_GRID": [i/100 for i in range(1, 51)],  # 0.01, 0.02, ..., 0.50
    # Propagation of labels / scores from child GO terms to their ancestors
    "PROPAGATE_TRAIN_LABELS": True,
    "PROPAGATE_PREDICTIONS": True,

    # On-disk paths for memmaps/embeddings (train embedding cache and its shape sidecar)
    "TRAIN_EMB_MEMMAP": "/kaggle/working/train_embs.memmap",
    "TRAIN_EMB_SHAPE_FILE": "/kaggle/working/train_embs_shape.npy",
}


print("config...")

In [None]:
# ------------------------------------------------------------
# imports
# ------------------------------------------------------------
import os, gc, math, random, sys, time
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Set
import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers, losses, metrics, regularizers
import esm   # after installing from GitHub
import torch
import numpy as np
import gc
from pathlib import Path

# optional PLM imports
# TORCH_AVAILABLE gates the PLM embedding path further down.
# NOTE(review): `esm` and `torch` are already imported unconditionally in the
# import block above, so a missing esm/torch install fails there before this
# guard runs; in practice this try/except only protects against a missing
# `transformers` package.
try:
    import torch
    from transformers import AutoTokenizer, AutoModel
    TORCH_AVAILABLE = True
except Exception as e:
    TORCH_AVAILABLE = False

# deterministic seeds for Python, NumPy and TensorFlow (torch is not seeded here)
random.seed(CONFIG["RANDOM_SEED"])
np.random.seed(CONFIG["RANDOM_SEED"])
tf.random.set_seed(CONFIG["RANDOM_SEED"])

# Global Keras mixed-precision policy. The model builders below force their
# final output layer to float32 via dtype='float32'.
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)


print("import done!!!")

In [None]:
# ------------------------------------------------------------
# tiny utils for FASTA/TSV/OBO parsing 
# ------------------------------------------------------------
def read_fasta(path: str) -> Dict[str, str]:
    """Read a FASTA file into a ``{protein_id: sequence}`` dictionary.

    The id is taken from the first whitespace-delimited token of each header.
    For pipe-delimited headers (``db|ACCESSION|NAME``) the second field is
    used; otherwise the whole token is the id. Sequence lines are stripped and
    concatenated. A later record with the same id overwrites an earlier one.

    Records whose header yields an empty id (for example a bare ``>`` line)
    are skipped instead of raising ``IndexError``.

    Args:
        path: Location of the FASTA file.

    Returns:
        Mapping from protein id to its full sequence string.
    """
    seqs = {}
    pid = None
    seq_parts = []
    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith(">"):
                # Close the previous record before starting a new one.
                if pid:
                    seqs[pid] = "".join(seq_parts)
                fields = line[1:].split()
                # A header with no text after '>' has no fields; treat it as
                # an unnamed record (empty id) so it is skipped below.
                header = fields[0] if fields else ""
                if "|" in header:
                    parts = header.split("|")
                    pid = parts[1] if len(parts) >= 2 else header
                else:
                    pid = header
                seq_parts = []
            else:
                seq_parts.append(line)
        # Flush the final record.
        if pid:
            seqs[pid] = "".join(seq_parts)
    print(f"[io] Read {len(seqs)} sequences from {path}")
    return seqs

def read_train_terms(path: str) -> Dict[str, List[str]]:
    """Load protein -> GO term annotations from a three-column TSV.

    The file is read without a header row and the columns are named
    ``protein``, ``go`` and ``ont``; only the first two are used. Terms are
    kept in file order and duplicates are not removed.

    NOTE(review): because ``header=None`` is used, a header line in the file
    (if there is one) is read as an ordinary annotation row - confirm against
    the actual file.

    Args:
        path: Location of the TSV file.

    Returns:
        ``defaultdict(list)`` mapping each protein id to its list of GO ids.
    """
    mapping = defaultdict(list)
    df = pd.read_csv(path, sep="\t", header=None, names=["protein","go","ont"], dtype=str)
    # Iterate the two columns directly; iterrows() would build a Series per row.
    for protein, go in zip(df["protein"], df["go"]):
        mapping[protein].append(go)
    print(f"[io] Read training annotations for {len(mapping)} proteins from {path}")
    return mapping

def parse_obo(go_obo_path: str) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """Parse ``is_a`` and ``part_of`` edges from an OBO file.

    Only ``[Term]`` stanzas contribute edges. Other stanzas (for example
    ``[Typedef]``, which also carries ``id:`` and ``is_a:`` lines) are ignored
    so that relation names do not end up as nodes in the graph.

    Args:
        go_obo_path: Location of the OBO file. If it does not exist, two
            empty mappings are returned.

    Returns:
        ``(parents, children)``: two ``defaultdict(set)`` objects mapping a
        term id to its direct parents and to its direct children.
    """
    parents = defaultdict(set)
    children = defaultdict(set)
    if not os.path.exists(go_obo_path):
        return parents, children

    def _link(child, parent):
        # Record one child -> parent edge in both directions.
        parents[child].add(parent)
        children[parent].add(child)

    with open(go_obo_path, "r") as f:
        cur_id = None
        in_term = False
        for line in f:
            line = line.strip()
            if line.startswith("[") and line.endswith("]"):
                # Any stanza header resets the current id; only [Term] is parsed.
                in_term = (line == "[Term]")
                cur_id = None
            elif not in_term:
                continue
            elif line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif line.startswith("is_a: "):
                fields = line.split()
                if cur_id and len(fields) >= 2:
                    _link(cur_id, fields[1].strip())
            elif line.startswith("relationship: part_of "):
                fields = line.split()
                if cur_id and len(fields) >= 3:
                    _link(cur_id, fields[2].strip())
    print(f"[io] Parsed OBO: {len(parents)} nodes with parents")
    return parents, children

def get_ancestors(go_id: str, parents: Dict[str, Set[str]]) -> Set[str]:
    """Return every term reachable from ``go_id`` by following parent edges.

    ``go_id`` itself is only part of the result if a cycle leads back to it.
    Terms without an entry in ``parents`` are treated as roots.
    """
    seen = set()
    pending = [go_id]
    while pending:
        node = pending.pop()
        for parent in parents.get(node, []):
            if parent in seen:
                continue
            seen.add(parent)
            pending.append(parent)
    return seen

In [None]:
# ------------------------------------------------------------
# Load the data
# ------------------------------------------------------------
# Raw inputs: sequences, annotations and the GO graph.
train_seqs = read_fasta(CONFIG["TRAIN_FASTA"])
train_terms = read_train_terms(CONFIG["TRAIN_TERMS"])
parents_map, children_map = parse_obo(CONFIG["GO_OBO"])
test_seqs = read_fasta(CONFIG["TEST_FASTA"])

# Only proteins that have both annotations and a sequence can be trained on.
train_proteins = [pid for pid in train_terms if pid in train_seqs]
print(f"[io] {len(train_proteins)} train proteins with sequences available")

# Optionally close each protein's label set under the ancestor relation.
if CONFIG["PROPAGATE_TRAIN_LABELS"] and parents_map:
    print("[prep] Propagating train labels up GO graph")
    propagated = {}
    for pid in train_proteins:
        closure = set(train_terms[pid])
        for term in tuple(closure):
            closure |= get_ancestors(term, parents_map)
        propagated[pid] = sorted(closure)
    train_terms = propagated

# Rank GO terms by frequency and keep the most common ones as targets.
all_term_counts = Counter()
for pid in train_proteins:
    all_term_counts.update(train_terms[pid])
all_terms_sorted = [term for term, _count in all_term_counts.most_common()]
if CONFIG["TOP_K_LABELS"] is None:
    chosen_terms = set(all_terms_sorted)
else:
    chosen_terms = set(all_terms_sorted[:CONFIG["TOP_K_LABELS"]])
    print(f"[prep] Restricting to top-{CONFIG['TOP_K_LABELS']} GO terms")
print(f"[prep] Using {len(chosen_terms)} target GO terms")

# Drop labels outside the target vocabulary.
for pid in train_proteins:
    train_terms[pid] = [term for term in train_terms[pid] if term in chosen_terms]

# Build the multi-hot label matrix; columns follow the sorted term order.
X_proteins = train_proteins
y_labels = [train_terms[pid] for pid in X_proteins]

mlb = MultiLabelBinarizer(classes=sorted(chosen_terms))
Y = mlb.fit_transform(y_labels).astype(np.float32)
print("[prep] Label matrix shape:", Y.shape)

In [None]:
# ------------------------------------------------------------
# PLM embedding helpers (ESM2); memory-conscious: produce numpy arrays per-batch
# ------------------------------------------------------------
def seqs_for_plm_input_esm(seqs: List[str]) -> List[str]:
    """Upper-case each sequence and map the letters U, O, B and Z to X."""
    to_unknown = str.maketrans("UOBZ", "XXXX")
    return [raw.upper().translate(to_unknown) for raw in seqs]

def embed_with_plm_to_memmap(all_seq_ids: List[str],
                             seqs_dict: Dict[str,str],
                             memmap_path: str,
                             batch_size:int=8,
                             model_name_or_path: str = CONFIG["PLM_MODEL_NAME_OR_PATH"]):
    """
    Embed sequences with an ESM model and write one mean-pooled vector per
    sequence into a float32 ``np.memmap`` on disk.

    Args:
        all_seq_ids: Ids to embed; row ``k`` of the memmap belongs to
            ``all_seq_ids[k]``. Must be non-empty (the first id is used to
            probe the embedding dimension).
        seqs_dict: Mapping from id to sequence string.
        memmap_path: File to create (opened with mode ``"w+"``).
        batch_size: Number of sequences per forward pass.
        model_name_or_path: Local checkpoint location. The default is read
            from CONFIG once, when this function is defined.

    Returns:
        ``(mem, emb_dim)``: the flushed memmap of shape ``(N, emb_dim)`` and
        the embedding dimension as an int.

    Raises:
        RuntimeError: if TORCH_AVAILABLE is False or no ESM model could be loaded.
        FileNotFoundError: if ``model_name_or_path`` does not exist.

    NOTE(review): sequences are passed to the model as-is; seqs_for_plm_input_esm
    is not applied here although the test-time embedding helper applies it.
    """
    if not TORCH_AVAILABLE:
        raise RuntimeError("Torch not available; cannot load ESM model.")

    model_dir = str(model_name_or_path)
    if not Path(model_dir).exists():
        raise FileNotFoundError(f"ESM model path not found: {model_dir}")

    # First attempt: the esm loader for local checkpoints.
    try:
        # NOTE(review): which local layouts (directory vs. .pt file) this loader
        # accepts is not visible here - confirm against the esm package.
        print(f"[esm] Loading local ESM model from: {model_dir}")
        model, alphabet = esm.pretrained.load_model_and_alphabet_local(model_dir)
    except Exception as e:
        # Second attempt: a named convenience constructor.
        try:
            # NOTE(review): this always requests esm2_t33_650M_UR50D, regardless of
            # the configured path, so the fallback may load a different model.
            print(f"[esm] load_model_and_alphabet_local failed, attempting esm2 convenience loader...")
            model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()  # fallback: only works if model files are accessible
        except Exception as e2:
            # The error from the first (local) attempt is reported as the cause.
            raise RuntimeError("Failed to load ESM model via esm.pretrained. Ensure the directory contains a valid ESM checkpoint.") from e

    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    batch_converter = alphabet.get_batch_converter()
    N = len(all_seq_ids)

    # Determine embedding dimension by running the first sequence through the model
    sample_seq = seqs_dict[all_seq_ids[0]]
    _, _, sample_tokens = batch_converter([(all_seq_ids[0], sample_seq)])
    sample_tokens = sample_tokens.to(device)
    with torch.no_grad():
        results = model(sample_tokens, repr_layers=[model.num_layers], return_contacts=False)
        # pick the highest repr layer in results
        repr_keys = sorted(results["representations"].keys())
        last_layer_key = repr_keys[-1]
        emb_dim = results["representations"][last_layer_key].shape[-1]
    # create memmap file on disk (float32), one row per sequence id
    mem = np.memmap(memmap_path, dtype=np.float32, mode="w+", shape=(N, int(emb_dim)))

    idx = 0  # row offset of the current batch inside the memmap
    for i in range(0, N, batch_size):
        batch_ids = all_seq_ids[i:i+batch_size]
        # Prepare list of (id, seq) tuples
        batch_pairs = [(pid, seqs_dict[pid]) for pid in batch_ids]
        labels, sequences, tokens = batch_converter(batch_pairs)  # tokens: (B, L)
        tokens = tokens.to(device)
        with torch.no_grad():
            results = model(tokens, repr_layers=[model.num_layers], return_contacts=False)
            repr_keys = sorted(results["representations"].keys())
            last_layer_key = repr_keys[-1]
            repr_tensor = results["representations"][last_layer_key].cpu()   # (B, L, C)
        # For each sequence, keep only its own residue positions and mean-pool them
        for j, seq in enumerate(sequences):
            seq_len = len(seq)
            # Position 0 is skipped and exactly seq_len positions are kept, which
            # also excludes padding (assumes one leading special token - TODO confirm).
            seq_repr = repr_tensor[j, 1:seq_len+1, :]   # (seq_len, C)
            seq_embed = seq_repr.mean(axis=0).numpy().astype(np.float32)
            mem[idx + j, :] = seq_embed
        idx += len(batch_ids)

        # free intermediate tensors and empty CUDA cache
        del tokens, results, repr_tensor, seq_repr, seq_embed
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        print(f"[esm] Wrote embeddings {i}..{i+len(batch_ids)} / {N}")

    mem.flush()
    print(f"[esm] Finished writing memmap to {memmap_path} with dim {emb_dim}")
    return mem, int(emb_dim)

In [None]:
# ------------------------------------------------------------
# Embedding training set (memmap) OR fallback TF-IDF (also memory-friendly)
# ------------------------------------------------------------
# The PLM path needs both the config switch and a successful optional import.
USE_PLM = CONFIG["USE_PLM_MODEL"] and TORCH_AVAILABLE
if USE_PLM:
    # compute training embeddings to a disk memmap, or reuse an existing one
    train_ids = X_proteins
    train_memmap_path = CONFIG["TRAIN_EMB_MEMMAP"]
    if not os.path.exists(train_memmap_path):
        print("[plm] Computing train embeddings to memmap. This may take time but keeps RAM low.")
        train_mem, D = embed_with_plm_to_memmap(train_ids, train_seqs, train_memmap_path,
                                                batch_size=CONFIG["EMBED_BATCH_SIZE"],
                                                model_name_or_path=CONFIG["PLM_MODEL_NAME_OR_PATH"])
        # Save shape info for later reopening
        np.save(CONFIG["TRAIN_EMB_SHAPE_FILE"], np.array([len(train_ids), D], dtype=np.int64))
        # np.array(...) copies the whole memmap into RAM as a regular ndarray
        # (N x D float32), so this branch does not keep memory low.
        emb_train = np.array(train_mem)
        # To be safe, copy to np.float32 if necessary
        if emb_train.dtype != np.float32: emb_train = emb_train.astype(np.float32)
        del train_mem; gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()
    else:
        # memmap already exists: load shape and open read-only (stays disk-backed)
        # NOTE(review): the cached shape is not checked against len(train_ids); a
        # cache produced with different proteins or settings would be reused as-is.
        shape = np.load(CONFIG["TRAIN_EMB_SHAPE_FILE"])
        Nshape, D = int(shape[0]), int(shape[1])
        emb_train = np.memmap(train_memmap_path, dtype=np.float32, mode="r", shape=(Nshape, D))
    embedding_method = "plm"
else:
    # TF-IDF fallback: overlapping 3-mers as "words", densified with .toarray()
    # (a dense N x up-to-20000 float32 matrix held in RAM).
    print("[fallback] Using TF-IDF k-mer embeddings for train (memory-friendly for moderate sizes)")
    def kmers(seq, k=3):
        # Space-separated overlapping k-mers; empty string when len(seq) < k.
        return " ".join([seq[i:i+k] for i in range(len(seq)-k+1)])
    train_texts = [kmers(train_seqs[p], k=3) for p in X_proteins]
    # The fitted vectorizer is reused at test time by embed_batch_return_np.
    tfidf = TfidfVectorizer(analyzer="word", token_pattern=r"(?u)\b\w+\b", max_features=20000)
    emb_train = tfidf.fit_transform(train_texts).astype(np.float32).toarray()
    embedding_method = "tfidf"

print(f"[embed] Train embeddings: method={embedding_method}, shape={emb_train.shape}, dtype={emb_train.dtype}")


In [None]:
# ------------------------------------------------------------
# Train / validation split (we load train embeddings into memory here - if it's too big, we can train with memmap directly)
# ------------------------------------------------------------
# If embeddings are memmap and too big, we can load a subset or use a generator; here we assume emb_train fits for training.
# Features, labels and protein ids are split together so their rows stay aligned.
X_emb, y = emb_train, Y
split_parts = train_test_split(
    X_emb,
    y,
    X_proteins,
    test_size=0.15,
    random_state=CONFIG["RANDOM_SEED"],
)
X_tr, X_val, y_tr, y_val, prot_tr, prot_val = split_parts
print("[train] shapes:", X_tr.shape, X_val.shape, y_tr.shape, y_val.shape)


In [None]:
def focal_loss(gamma=2., alpha=0.25):
    """Build a binary focal loss for multi-label sigmoid outputs.

    Args:
        gamma: Focusing exponent applied to ``(1 - pt)``.
        alpha: Weight for positive labels; negatives get ``1 - alpha``.

    Returns:
        ``loss_fn(y_true, y_pred)``, which sums the per-label focal loss over
        the last axis and averages the result over the batch.
    """
    def loss_fn(y_true, y_pred):
        # Compute in float32 regardless of the incoming dtypes.
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)
        # 1e-7 is used because 1 - 1e-8 rounds to exactly 1.0 in float32, which
        # would leave the upper clip ineffective and allow log(0) = -inf for a
        # negative label whose prediction saturates at 1.0.
        eps = 1e-7
        p = tf.clip_by_value(y_pred, eps, 1.0 - eps)
        is_pos = tf.equal(y_true, 1.0)
        pt = tf.where(is_pos, p, 1.0 - p)          # probability of the true class
        w = tf.where(is_pos, alpha, 1.0 - alpha)   # class-balance weight
        loss = - w * ((1.0 - pt) ** gamma) * tf.math.log(pt)
        return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
    return loss_fn


In [None]:
def se_block_keras(x, ratio=16):
    """Gate a feature vector with a learned sigmoid mask (squeeze-and-excite style).

    Args:
        x: Keras tensor of shape (batch, dim).
        ratio: Reduction factor; the hidden gate layer has max(dim // ratio, 8) units.

    Returns:
        Tensor of shape (batch, dim): ``x`` multiplied elementwise by the gate.

    Note: the "squeeze" step averages over all ``dim`` features, so the gate is
    computed from a single scalar per example.
    """
    dim = int(x.shape[-1])
    # GlobalAveragePooling1D expects (batch, steps, channels), so each feature
    # is turned into one step with a single channel.
    reshaped = layers.Reshape((dim, 1))(x)                  # (batch, dim, 1)
    se = layers.GlobalAveragePooling1D()(reshaped)          # (batch, 1): mean over the dim features
    se = layers.Dense(max(dim // ratio, 8), activation='relu')(se)
    se = layers.Dense(dim, activation='sigmoid')(se)       # (batch, dim)
    se = layers.Reshape((dim,))(se)                        # shape unchanged; kept as an explicit guard
    return layers.Multiply()([x, se])                      # elementwise scale

def attention_pooling_keras(x, hidden=128):
    """
    Learned attention pooling across feature dimensions.

    The input vector is treated as a pseudo-sequence of length ``dim`` with one
    channel. Two kernel-size-1 Conv1D layers score each position, a softmax
    over positions turns the scores into weights, and the weighted sum of the
    features is returned.

    Args:
        x: Keras tensor of shape (batch, dim).
        hidden: Number of filters in the first Conv1D.

    Returns:
        Tensor of shape (batch, 1) - one pooled scalar per example.
    """
    dim = int(x.shape[-1])
    seq = layers.Reshape((dim, 1))(x)                      # (batch, dim, 1)
    att = layers.Conv1D(hidden, kernel_size=1, activation='tanh')(seq)  # (batch, dim, hidden)
    att = layers.Conv1D(1, kernel_size=1)(att)             # (batch, dim, 1)
    att = layers.Reshape((dim,))(att)                      # (batch, dim)
    att = layers.Activation('softmax')(att)                # (batch, dim), normalized over features
    att = layers.Reshape((dim, 1))(att)                    # (batch, dim, 1)
    pooled = layers.Multiply()([seq, att])                 # (batch, dim, 1)
    pooled = layers.Lambda(lambda z: tf.reduce_sum(z, axis=1))(pooled)  # (batch, 1)
    pooled = layers.Reshape((1,))(pooled)                  # (batch, 1)
    return pooled

def build_model_A(input_dim, output_dim, hidden_units=512, dropout=0.4, lr=1e-3, l2=1e-6):
    """Build and compile the residual / squeeze-excite / attention-pooling MLP.

    Args:
        input_dim: Size of the input embedding vector.
        output_dim: Number of labels (sigmoid outputs).
        hidden_units: Width of the first dense stage; the second stage uses half.
        dropout: Dropout rate used after each dense stage and in the head.
        lr: Adam learning rate.
        l2: L2 kernel regularization factor for the regularized dense layers.

    Returns:
        A compiled Keras model (binary cross-entropy, precision and recall metrics).

    NOTE(review): the classifier head only receives two scalars per example
    (the mean and the attention-weighted sum of the features) - confirm this
    bottleneck is intended.
    """
    inp = layers.Input(shape=(input_dim,), dtype=tf.float32)
    x = layers.Dense(hidden_units, activation='relu', kernel_regularizer=regularizers.l2(l2))(inp)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout)(x)

    # Residual block 1 (same width, so the skip connection is a plain add)
    r = layers.Dense(hidden_units, activation='relu', kernel_regularizer=regularizers.l2(l2))(x)
    r = layers.BatchNormalization()(r)
    r = layers.Dropout(dropout)(r)
    x = layers.Add()([x, r])

    # Squeeze-and-Excite
    x = se_block_keras(x, ratio=16)

    # Residual block 2: width is halved, so the skip path is linearly projected first
    r = layers.Dense(hidden_units//2, activation='relu', kernel_regularizer=regularizers.l2(l2))(x)
    r = layers.BatchNormalization()(r)
    r = layers.Dropout(dropout)(r)
    x_proj = layers.Dense(hidden_units//2, activation=None)(x)
    x = layers.Add()([x_proj, r])
    x = layers.Activation('relu')(x)

    # attention pooling across features
    pooled = attention_pooling_keras(x, hidden=128)   # (batch, 1)

    # combine pooled with global average
    gap = layers.GlobalAveragePooling1D()(layers.Reshape((int(x.shape[-1]), 1))(x))  # (batch, 1)
    combined = layers.Concatenate()([gap, pooled])    # (batch, 2)
    combined = layers.Flatten()(layers.Reshape((2,1))(combined))  # still (batch, 2)

    # Simple classifier head
    x = layers.Dense(512, activation='relu')(combined)
    x = layers.Dropout(dropout)(x)
    out = layers.Dense(output_dim, activation='sigmoid', dtype='float32')(x)  # keep outputs float32

    model = models.Model(inputs=inp, outputs=out)
    model.compile(optimizer=optimizers.Adam(learning_rate=lr),
                  loss=losses.BinaryCrossentropy(),
                  metrics=[metrics.Precision(), metrics.Recall()])
    return model


In [None]:
class LabelEmbeddingLayer( layers.Layer ):
    """Score every label as a dot product with a learned label embedding.

    For an input of shape (batch, k) the layer returns logits of shape
    (batch, num_labels), computed as ``inputs @ L^T + b``. No activation is
    applied here.

    Args:
        num_labels: Number of output labels (rows of the embedding matrix).
        k: Dimension of the projected input and of each label embedding.
    """

    def __init__(self, num_labels, k, **kwargs):
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.k = k

    def build(self, input_shape):
        # label embeddings matrix (M, k) and bias (M,)
        self.L = self.add_weight(name='label_emb', shape=(self.num_labels, self.k),
                                 initializer='glorot_uniform', trainable=True)
        self.b = self.add_weight(name='label_bias', shape=(self.num_labels,),
                                 initializer='zeros', trainable=True)
        super().build(input_shape)

    def call(self, inputs):
        # inputs: (batch, k)
        # logits = inputs @ L^T + b
        logits = tf.matmul(inputs, self.L, transpose_b=True)  # (batch, M)
        logits = logits + self.b
        return logits  # the caller applies the sigmoid

    def get_config(self):
        # Expose the constructor arguments so that a model containing this layer
        # can be serialized (e.g. by ModelCheckpoint saving a full model) and
        # later rebuilt.
        config = super().get_config()
        config.update({"num_labels": self.num_labels, "k": self.k})
        return config

def build_model_B(input_dim, output_dim, proj_dim=256, hidden_units=512, dropout=0.4, lr=1e-3, l2=1e-6):
    """Build and compile the two-layer MLP with a label-embedding output head.

    Args:
        input_dim: Size of the input embedding vector.
        output_dim: Number of labels (sigmoid outputs).
        proj_dim: Dimension of the shared projection / label-embedding space.
        hidden_units: Width of the first dense layer; the second uses half.
        dropout: Dropout rate applied after the first dense stage.
        lr: Adam learning rate.
        l2: L2 kernel regularization factor for the two hidden dense layers.

    Returns:
        A compiled Keras model (binary cross-entropy, precision and recall metrics).
    """
    inp = layers.Input(shape=(input_dim,), dtype=tf.float32)
    x = layers.Dense(hidden_units, activation='gelu', kernel_regularizer=regularizers.l2(l2))(inp)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout)(x)
    x = layers.Dense(hidden_units//2, activation='gelu', kernel_regularizer=regularizers.l2(l2))(x)
    x = layers.BatchNormalization()(x)

    # Linear projection into the space shared with the label embeddings.
    proj = layers.Dense(proj_dim, activation=None, name="proj")(x)  # (batch, k)
    logits = LabelEmbeddingLayer(output_dim, proj_dim)(proj)       # (batch, M)
    # Sigmoid is computed in float32 even under the mixed-precision policy.
    out = layers.Activation('sigmoid', dtype='float32')(logits)    # (batch, M)

    model = models.Model(inputs=inp, outputs=out)
    model.compile(optimizer=optimizers.Adam(learning_rate=lr),
                  loss=losses.BinaryCrossentropy(),
                  metrics=[metrics.Precision(), metrics.Recall()])
    return model


In [None]:
# choose and build model

D = X_tr.shape[1]   # input feature dimension
M = y_tr.shape[1]   # number of target GO terms


# Alternative architecture (disabled):
# model = build_model_A(D, M, hidden_units=1024, dropout=0.4)
model = build_model_B(D, M, proj_dim=256, hidden_units=CONFIG["HIDDEN_UNITS"], dropout=CONFIG["DROPOUT"])

# Focal loss is ACTIVE: this recompile replaces the binary cross-entropy set up
# inside build_model_B and applies CONFIG["LEARNING_RATE"]. Remove this call to
# train with the builder's default BCE loss instead.
model.compile(optimizer=optimizers.Adam(learning_rate=CONFIG["LEARNING_RATE"]),
              loss=focal_loss(gamma=2.0, alpha=0.25),
              metrics=[metrics.Precision(), metrics.Recall()])

# Stop after 3 epochs without val_loss improvement and restore the best weights;
# also checkpoint the best model to disk.
es = callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True, verbose=1)
mc = callbacks.ModelCheckpoint("/kaggle/working/best_model.h5", monitor="val_loss", save_best_only=True, verbose=1)

history = model.fit(X_tr, y_tr, validation_data=(X_val, y_val),
                    epochs=CONFIG["EPOCHS"], batch_size=CONFIG["BATCH_SIZE"],
                    callbacks=[es, mc], verbose=2)


In [None]:
# ------------------------------------------------------------
# Evaluate & select global threshold
# ------------------------------------------------------------
def weighted_precision_recall_f1(y_true, y_pred_bin, ia_map, mlb_obj):
    """Compute per-label precision, recall and F1, averaged with label weights.

    Args:
        y_true: Binary ground-truth matrix of shape (n_samples, n_labels).
        y_pred_bin: Binary prediction matrix of the same shape.
        ia_map: Mapping from label name to weight. Labels missing from the
            mapping get weight 1.0; an empty mapping or ``None`` weights every
            label equally.
        mlb_obj: Object whose ``classes_`` attribute lists the label names in
            column order.

    Returns:
        ``(weighted_precision, weighted_recall, weighted_f1)`` as floats.
    """
    y_true = np.asarray(y_true)
    y_pred_bin = np.asarray(y_pred_bin)
    pred_pos = (y_pred_bin == 1)
    true_pos = (y_true == 1)
    tp = (true_pos & pred_pos).sum(axis=0).astype(float)
    fp = ((y_true == 0) & pred_pos).sum(axis=0).astype(float)
    fn = (true_pos & (y_pred_bin == 0)).sum(axis=0).astype(float)
    eps = 1e-12  # avoids 0/0 for labels with no positives or no predictions
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    cls = mlb_obj.classes_
    # Use the weights that were passed in rather than a module-level global.
    if ia_map:
        weights = np.array([ia_map.get(c, 1.0) for c in cls], dtype=float)
    else:
        weights = np.ones(len(cls))
    total = weights.sum() + eps
    weighted_f1 = (f1 * weights).sum() / total
    weighted_prec = (prec * weights).sum() / total
    weighted_rec = (rec * weights).sum() / total
    return weighted_prec, weighted_rec, weighted_f1

# load IA weights if available (safe)
def read_IA_safe(path):
    """Read a two-column TSV of ``term<TAB>weight`` into a dict, tolerantly.

    Values that do not parse as a float are retried with ``,`` replaced by
    ``.``; if that also fails the weight is stored as 0.0.

    Args:
        path: Location of the TSV file (read without a header row).

    Returns:
        ``{term: weight}``; an empty dict if ``path`` does not exist.
    """
    if not os.path.exists(path):
        return {}
    df = pd.read_csv(path, sep="\t", header=None, names=["go","ia"], dtype=str)
    d = {}
    for go, raw in zip(df["go"], df["ia"]):
        try:
            d[go] = float(raw)
        except (TypeError, ValueError):
            # Only conversion failures are handled; other exceptions (for
            # example KeyboardInterrupt) are no longer swallowed.
            try:
                d[go] = float(raw.replace(",", "."))
            except (AttributeError, TypeError, ValueError):
                d[go] = 0.0
    return d

ia_weights = read_IA_safe(CONFIG["IA_FILE"])

# Score the validation split once, then sweep the decision threshold.
y_val_prob = model.predict(X_val, batch_size=CONFIG["BATCH_SIZE"], verbose=0)
best_thresh, best_score = 0.5, -1.0
if CONFIG["GLOBAL_THRESHOLD_SEARCH"]:
    for t in CONFIG["THRESHOLD_GRID"]:
        y_pred_bin = (y_val_prob >= t).astype(int)
        f1 = weighted_precision_recall_f1(y_val, y_pred_bin, ia_weights, mlb)[2]
        # Strict comparison: the earliest threshold wins on ties.
        if f1 > best_score:
            best_score, best_thresh = f1, t
print(f"[eval] best_thresh {best_thresh} best IA-weighted F1 {best_score}")


In [None]:
# ------------------------------------------------------------
# Streaming test-time: embed test sequences batchwise, predict, propagate, write submission lines immediately
# ------------------------------------------------------------
# Precompute helpful mappings for propagation
term_to_idx = {term: col for col, term in enumerate(mlb.classes_)}
idx_to_term = {col: term for term, col in term_to_idx.items()}

# Parent sets limited to terms the model actually predicts, so propagation
# never has to touch a term without an output column.
restricted_parents = {
    term: {parent for parent in parents_map.get(term, set()) if parent in term_to_idx}
    for term in mlb.classes_
}

# small propagation routine operating on a batch_of_probs (N_batch, M)
def propagate_batch(pred_batch: np.ndarray, parents_map_local: Dict[str, Set[str]], classes_list: List[str], iterations=3):
    """Raise each parent's score to at least the score of its children.

    The array is modified in place and also returned. Each pass walks the
    child -> parent edges once; passes repeat until nothing changes or
    ``iterations`` passes have run, so chains deeper than ``iterations`` may
    not be fully propagated.

    Args:
        pred_batch: Float array of shape (B, M); column ``i`` belongs to
            ``classes_list[i]``.
        parents_map_local: Mapping from term to its direct parents. Parents
            that are not in ``classes_list`` are ignored.
        classes_list: Term name for each column.
        iterations: Maximum number of passes.

    Returns:
        ``pred_batch`` (the same object) after propagation.
    """
    Mloc = pred_batch.shape[1]
    term_to_idx_local = {classes_list[i]: i for i in range(Mloc)}

    # Resolve the (child column, parent column) pairs once instead of on every pass.
    edges = []
    for child_idx in range(Mloc):
        for pterm in parents_map_local.get(classes_list[child_idx], ()):
            pidx = term_to_idx_local.get(pterm)
            if pidx is not None:  # skip parents that have no output column
                edges.append((child_idx, pidx))

    for _ in range(iterations):
        changed = False
        for child_idx, pidx in edges:
            child_scores = pred_batch[:, child_idx]
            # update parent where child's score exceeds parent
            mask = child_scores > pred_batch[:, pidx]
            if mask.any():
                pred_batch[mask, pidx] = child_scores[mask]
                changed = True
        if not changed:
            break
    return pred_batch

# Open submission file for streaming write
out_fpath = CONFIG["OUTPUT_SUBMISSION"]
open(out_fpath, "w").close()  # truncate any previous submission
# Opened in append mode and left open; the streaming loop below closes it.
out_f = open(out_fpath, "a")

# Create chunked iterator of test sequence IDs
test_ids = list(test_seqs.keys())
N_test = len(test_ids)
print(f"[test] Streaming {N_test} test sequences in batches of {CONFIG['EMBED_BATCH_SIZE']} (embed) / predict {CONFIG['PREDICT_BATCH_SIZE']}")

# If using PLM, prepare tokenizer & model once (on CPU/GPU)
# NOTE(review): test-time embeddings use transformers AutoTokenizer/AutoModel,
# while embed_with_plm_to_memmap builds the training embeddings through the
# `esm` package from the same path. Confirm that both loaders can read this
# checkpoint and yield comparable features.
if USE_PLM:
    tokenizer = AutoTokenizer.from_pretrained(CONFIG["PLM_MODEL_NAME_OR_PATH"], do_lower_case=False)
    plm_model = AutoModel.from_pretrained(CONFIG["PLM_MODEL_NAME_OR_PATH"])
    plm_model.eval()
    if torch.cuda.is_available(): plm_model.to(torch.device("cuda"))

# Helper to embed a list of sequences and return numpy array float32 of shape (len(seq_list), D)
def embed_batch_return_np(seq_list: List[str]):
    """Embed a list of sequences and return a float32 array with one row per sequence.

    Uses the module-level ``tokenizer`` / ``plm_model`` when ``USE_PLM`` is
    true, otherwise the module-level fitted ``tfidf`` vectorizer.

    NOTE(review): in the PLM branch the mean is taken over every position the
    attention mask marks as valid, whereas embed_with_plm_to_memmap averages
    only positions 1..seq_len. The input is also sanitized with
    seqs_for_plm_input_esm and tokenized with truncation here but not at train
    time. Confirm that train and test features are meant to differ this way.
    """
    if USE_PLM:
        proc = seqs_for_plm_input_esm(seq_list)
        with torch.no_grad():
            inputs = tokenizer(proc, return_tensors="pt", padding=True, truncation=True)
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            inputs = {k:v.to(device) for k,v in inputs.items()}
            out = plm_model(**inputs)
            last_hidden = out.last_hidden_state  # (B, L, dim)
            mask = inputs.get("attention_mask", None)
            if mask is not None:
                # Masked mean: padded positions contribute zero to the sum and
                # are excluded from the count.
                mask = mask.unsqueeze(-1)
                summed = (last_hidden * mask).sum(1)
                counts = mask.sum(1).clamp(min=1)  # guard against division by zero
                mean_pooled = (summed / counts).cpu().numpy().astype(np.float32)
            else:
                # No mask available: plain mean over all positions.
                mean_pooled = last_hidden.mean(dim=1).cpu().numpy().astype(np.float32)
        # free GPU memory
        del inputs, out, last_hidden
        if torch.cuda.is_available(): torch.cuda.empty_cache()
        gc.collect()
        return mean_pooled
    else:
        # TF-IDF fallback: same overlapping 3-mer text as used for training,
        # transformed with the already-fitted vectorizer and densified.
        texts = [" ".join([seq[i:i+3] for i in range(len(seq)-3+1)]) for seq in seq_list]
        arr = tfidf.transform(texts).astype(np.float32).toarray()
        return arr

# We'll process test samples in small "embed" batches and then accumulate a list of embedded rows
# until we have PREDICT_BATCH_SIZE embeddings ready for model.predict(...) then predict and write submission lines.
# Buffers: embedded mini-batches and the ids of the rows they contain (same order).
embed_batch = []
embed_ids = []

# try/finally guarantees that the submission file is closed even if embedding
# or prediction fails part-way through.
try:
    for i in range(0, N_test, CONFIG["EMBED_BATCH_SIZE"]):
        batch_ids = test_ids[i:i+CONFIG["EMBED_BATCH_SIZE"]]
        seqs_batch = [test_seqs[pid] for pid in batch_ids]
        # compute embeddings for this mini-batch
        emb_mini = embed_batch_return_np(seqs_batch)  # shape (Bmini, D)
        # append to buffer
        embed_batch.append(emb_mini)
        embed_ids.extend(batch_ids)
        # if enough buffered to predict, or we're at the end, flush to prediction
        buffered_examples = sum(arr.shape[0] for arr in embed_batch)
        if buffered_examples >= CONFIG["PREDICT_BATCH_SIZE"] or (i+CONFIG["EMBED_BATCH_SIZE"] >= N_test):
            # stack buffered embeddings (should be moderate size)
            X_buffer = np.vstack(embed_batch).astype(np.float32)  # shape (Bbuf, D)
            # predict in one shot for this buffer
            y_buffer_prob = model.predict(X_buffer, batch_size=min(128, X_buffer.shape[0]), verbose=0)
            # propagate child scores to parents (if desired)
            if CONFIG["PROPAGATE_PREDICTIONS"] and parents_map:
                y_buffer_prob = propagate_batch(y_buffer_prob, restricted_parents, list(mlb.classes_), iterations=3)
            top_k = CONFIG["TOP_K_PER_PROTEIN"]
            # for each row, write the selected terms, highest score first
            for ridx, pid in enumerate(embed_ids):
                probs = y_buffer_prob[ridx]
                if top_k is None:
                    # no per-protein cap: keep everything above the searched threshold
                    idxs = np.where(probs >= best_thresh)[0]
                else:
                    # keep the top_k highest scores, dropping near-zero ones
                    idxs = np.argsort(probs)[-top_k:]
                    idxs = [int(x) for x in idxs if probs[x] > 1e-6]
                idxs = sorted(idxs, key=lambda x: probs[x], reverse=True)
                for idx in idxs:
                    # Decide on the value that is actually written: a score that
                    # rounds to "0.000" at three decimals would be emitted as a
                    # zero score, so it is skipped.
                    score_txt = f"{float(probs[idx]):.3f}"
                    if float(score_txt) <= 0.0:
                        continue
                    go_id = mlb.classes_[idx]
                    out_f.write(f"{pid}\t{go_id}\t{score_txt}\n")
            out_f.flush()
            # free buffer
            del X_buffer, y_buffer_prob
            embed_batch = []
            embed_ids = []
            gc.collect()
            if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.empty_cache()
        # small progress print
        if (i // CONFIG["EMBED_BATCH_SIZE"]) % 50 == 0:
            print(f"[stream] processed {i} / {N_test}")
finally:
    out_f.close()

print(f"[done] Submission written to {CONFIG['OUTPUT_SUBMISSION']}")

In [None]:
# # ------------------------------------------------------------
# # Save model / artifacts if desired (lightweight)
# # ------------------------------------------------------------
# model.save("/kaggle/working/cafa6_baseline_model")
# np.save("/kaggle/working/mlb_classes.npy", np.array(mlb.classes_, dtype=object))
# print("[done] saved model and classes; notebook finished.")