# ProtBERT and K-mer TF-IDF Fusion

This notebook implements a 2-tower fusion model combining ProtBERT embeddings and k-mer TF-IDF features for protein function prediction.

In [None]:
!pip -q install "protobuf==5.29.4"

### Import Libraries

In [None]:
import os, gc, random
from collections import defaultdict, Counter
from typing import Dict, List, Set, Tuple

import numpy as np
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy import sparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from transformers import Trainer, TrainingArguments, EarlyStoppingCallback

### File Paths

In [None]:
# All competition inputs live under one dataset root; build every path from it.
_CAFA6_ROOT = "/kaggle/input/cafa-6-protein-function-prediction"
_CAFA6_TRAIN_DIR = _CAFA6_ROOT + "/Train"
_CAFA6_TEST_DIR = _CAFA6_ROOT + "/Test"

# Files at the dataset root.
SAMPLE_SUBMISSION_TSV = _CAFA6_ROOT + "/sample_submission.tsv"
IA_TSV = _CAFA6_ROOT + "/IA.tsv"

# Tab-separated tables.
TESTSUPERSET_TAXON_LIST_TSV = _CAFA6_TEST_DIR + "/testsuperset-taxon-list.tsv"
TRAIN_TERMS_TSV = _CAFA6_TRAIN_DIR + "/train_terms.tsv"
TRAIN_TAXONOMY_TSV = _CAFA6_TRAIN_DIR + "/train_taxonomy.tsv"

# Protein sequences (FASTA).
TESTSUPERSET_FASTA = _CAFA6_TEST_DIR + "/testsuperset.fasta"
TRAIN_SEQUENCES_FASTA = _CAFA6_TRAIN_DIR + "/train_sequences.fasta"

# Ontology definition and the submission file this notebook writes.
GO_BASIC_OBO = _CAFA6_TRAIN_DIR + "/go-basic.obo"
OUTPUT_TSV = "/kaggle/working/submission.tsv"

### Configuration

In [None]:
# Single source of truth for every tunable used by the cells below.
CONFIG = {
    # --- Input / output locations ---
    "TRAIN_FASTA": TRAIN_SEQUENCES_FASTA,
    "TRAIN_TERMS": TRAIN_TERMS_TSV,
    "GO_OBO": GO_BASIC_OBO,
    "IA_FILE": IA_TSV,
    "TEST_FASTA": TESTSUPERSET_FASTA,
    "OUTPUT_SUBMISSION": OUTPUT_TSV,

    # --- Precomputed ProtBERT embedding matrices (loaded with np.load) ---
    "PROTBERT_TRAIN_EMB": "/kaggle/input/nnn-cafa6-protbert-embedding/train_embeddings.npy",
    "PROTBERT_TEST_EMB":  "/kaggle/input/nnn-cafa6-protbert-embedding/test_embeddings.npy",

    # Number of most frequent GO terms kept as prediction targets.
    "TOP_K_LABELS": 5000, 

    # --- Character n-gram TF-IDF ---
    # TFIDF_K is only a fallback n-gram size: it is consulted when
    # TFIDF_NGRAM_RANGE is absent from this dict.
    "TFIDF_K": 3, 
    "TFIDF_NGRAM_RANGE": (3, 4),
    "TFIDF_MAX_FEATURES": 20000, 
    "TFIDF_MIN_DF": 2, 
    "TFIDF_SUBLINEAR_TF": True,

    # --- Optimisation ---
    # RANDOM_SEED seeds the RNGs and the train/validation split.
    "RANDOM_SEED": 42,
    "TRAIN_BS": 64,
    "EVAL_BS": 128,
    "EPOCHS_TORCH": 20,
    "LR_TORCH": 5e-4,
    # Early-stopping patience, counted in evaluations (one per epoch).
    "PATIENCE": 3,
    "WEIGHT_DECAY": 1e-2,

    # --- Fusion network sizes ---
    # PROJ_DIM: width of each input tower; FUSION_HID: width of the fused layer.
    "PROJ_DIM": 512,      
    "FUSION_HID": 512,
    # Dropout probability shared by both towers and the fusion head.
    "DROPOUT_FUSION": 0.3,
    "FP16": True,

    # --- Decision-threshold search on the validation split ---
    # Coarse grid 0.01 .. 0.30, then a fine grid of +/- SPAN around the
    # coarse optimum in steps of STEP.
    "THRESH_COARSE": [i/100 for i in range(1, 31)],
    "THRESH_REFINE_STEP": 0.002,
    "THRESH_REFINE_SPAN": 0.03,

    # Maximum number of GO terms written per test protein.
    "TOP_K_PER_PROTEIN": 200,

    # --- GO-graph propagation ---
    # Train labels are expanded with all ancestors; predicted scores are
    # pushed from child to parent for at most PROPAGATE_ITERATIONS sweeps.
    "PROPAGATE_TRAIN_LABELS": True,
    "PROPAGATE_PREDICTIONS": True,
    "PROPAGATE_ITERATIONS": 3,

    # --- Test-time streaming ---
    # EMBED_BATCH_SIZE: test sequences featurised per loop step.
    # PREDICT_BATCH_SIZE: buffered rows that trigger one forward pass.
    "EMBED_BATCH_SIZE": 1024,
    "PREDICT_BATCH_SIZE": 2048,  
}

In [None]:
def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: Value handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Seed every RNG once, up front, from the configured seed.
set_seed(CONFIG["RANDOM_SEED"])
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[env] device:", device)

## Utility Functions

In [None]:
def read_fasta(path: str) -> Dict[str, str]:
    """Read a FASTA file into a ``{protein_id: sequence}`` dictionary.

    The protein id is the first whitespace-delimited token of the header
    line.  When that token contains ``|`` (e.g. ``sp|P12345|NAME``), the
    second ``|``-separated field is used instead.  All sequence lines of a
    record are concatenated; if an id occurs more than once, the last record
    wins.

    Records whose id is empty -- including a bare ``>`` header, which used to
    raise ``IndexError`` -- are skipped together with their sequence lines.
    Blank lines are ignored.

    Args:
        path: Location of the FASTA file.

    Returns:
        Mapping from protein id to its full sequence string.
    """
    seqs: Dict[str, str] = {}
    pid = None
    seq_parts: List[str] = []
    with open(path, "r") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            if not line.startswith(">"):
                seq_parts.append(line)
                continue

            # A new header closes the record collected so far.
            if pid:
                seqs[pid] = "".join(seq_parts)
            seq_parts = []

            tokens = line[1:].split()
            if not tokens:
                # Header without any id: drop this record instead of crashing.
                pid = None
                continue
            header = tokens[0]
            fields = header.split("|")
            pid = fields[1] if len(fields) >= 2 else header

    # Flush the final record (no trailing header to trigger it).
    if pid:
        seqs[pid] = "".join(seq_parts)
    print(f"[io] Read {len(seqs)} sequences from {path}")
    return seqs


## Training Annotation Reader

In [None]:
def read_train_terms(path: str) -> Dict[str, List[str]]:
    """Read the protein -> GO-term annotation table.

    The file is parsed as a header-less, tab-separated table whose first
    three columns are protein id, GO term and ontology; only the first two
    are used.  Terms are kept in file order and duplicates are not removed.

    NOTE: because ``header=None`` is used, a column-title row at the top of
    the file (if there is one) is read as an ordinary annotation row.

    Args:
        path: Location of the TSV file.

    Returns:
        A ``defaultdict(list)`` mapping each protein id to its GO terms.
    """
    df = pd.read_csv(path, sep="\t", header=None, names=["protein","go","ont"], dtype=str)
    mapping = defaultdict(list)
    # Zip the two columns instead of DataFrame.iterrows(): same pairs in the
    # same order, without building a Series object for every row.
    for protein, go_term in zip(df["protein"], df["go"]):
        mapping[protein].append(go_term)
    print(f"[io] Read training annotations for {len(mapping)} proteins from {path}")
    return mapping

## GO Ontology Parser

In [None]:
def parse_obo(go_obo_path: str) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """Parse ``is_a`` and ``part_of`` edges from an OBO ontology file.

    Only ``[Term]`` stanzas contribute edges.  Other stanzas (for example
    ``[Typedef]``, which also carries ``id:`` and ``is_a:`` lines) are
    skipped so that relation names never end up in the term graph.

    Args:
        go_obo_path: Location of the OBO file.

    Returns:
        ``(parents, children)``: two ``defaultdict(set)`` objects mapping a
        term id to its direct parents and to its direct children.  Both are
        empty when the file does not exist.
    """
    parents = defaultdict(set)
    children = defaultdict(set)
    if not os.path.exists(go_obo_path):
        return parents, children

    def _link(child: str, parent: str) -> None:
        parents[child].add(parent)
        children[parent].add(child)

    with open(go_obo_path, "r") as f:
        cur_id = None
        in_term = False
        for raw in f:
            line = raw.strip()
            if line.startswith("[") and line.endswith("]"):
                # Stanza header: forget the previous id and remember whether
                # the stanza that starts here describes a term.
                in_term = line == "[Term]"
                cur_id = None
                continue
            if not in_term:
                continue
            if line.startswith("id: "):
                cur_id = line[len("id: "):].strip()
            elif cur_id and line.startswith("is_a: "):
                # "is_a: GO:xxxxxxx ! name" -> second token is the parent id.
                _link(cur_id, line.split()[1])
            elif cur_id and line.startswith("relationship: part_of "):
                parts = line.split()
                if len(parts) >= 3:
                    _link(cur_id, parts[2])
    print(f"[io] Parsed OBO: {len(parents)} nodes with parents")
    return parents, children

## GO Ancestor Lookup

In [None]:
def get_ancestors(go_id: str, parents: Dict[str, Set[str]]) -> Set[str]:
    """Collect every term reachable from ``go_id`` by following parent links.

    Args:
        go_id: Term whose ancestors are wanted.
        parents: Mapping from a term to its direct parents; terms missing
            from the mapping are treated as having none.

    Returns:
        The set of all ancestors.  ``go_id`` itself is included only when
        the graph contains a cycle leading back to it.
    """
    found = set()
    frontier = [go_id]
    while frontier:
        node = frontier.pop()
        # Only parents not seen before need to be expanded later.
        fresh = [par for par in parents.get(node, ()) if par not in found]
        found.update(fresh)
        frontier.extend(fresh)
    return found

## Evaluation Metrics

In [None]:
from scipy import sparse as sp

class Evaluator:
    """Threshold search for protein-centric F-measures over a label matrix.

    Both metrics sweep a list of thresholds.  For each threshold they compute
    per-row (per-protein) precision and recall, average precision over rows
    that have at least one prediction and recall over rows that have at least
    one true label, and keep the threshold with the highest F1.  ``ia_fmax``
    does the same with every label weighted by ``self.weights``.
    """

    def __init__(self, ia_weights: Dict[str, float], go_terms: List[str]):
        """Store the label order and build the aligned weight vector.

        Args:
            ia_weights: Weight per GO term; terms missing from the mapping
                get weight 1.0.
            go_terms: GO terms in label-matrix column order.
        """
        self.go_terms = go_terms
        self.weights = np.array([ia_weights.get(t, 1.0) for t in go_terms], dtype=np.float32)

    def _dense_bool(self, y_true):
        """Return ``y_true`` as a dense boolean array (sparse input is densified)."""
        if sp.issparse(y_true):
            return y_true.astype(np.bool_).toarray()
        return (y_true > 0.5)

    def _f_from_prec_rec(self, p: float, r: float, eps: float = 1e-12) -> float:
        """Harmonic mean of precision and recall; ``eps`` guards against 0/0."""
        return (2.0 * p * r) / (p + r + eps)

    @staticmethod
    def _row_totals(mask: np.ndarray, w) -> np.ndarray:
        """Per-row count of ``mask`` (``w is None``) or per-row weighted sum."""
        if w is None:
            return mask.sum(axis=1).astype(np.float32)
        return (mask * w).sum(axis=1).astype(np.float32)

    def _best_threshold(self, y_true, y_prob: np.ndarray, thresholds, w) -> Tuple[float, float]:
        """Shared sweep behind ``fmax`` and ``ia_fmax``.

        Args:
            y_true: Dense or scipy-sparse ground-truth matrix.
            y_prob: Dense score matrix with the same shape as ``y_true``.
            thresholds: Iterable of cut-offs; a label is predicted when its
                score is ``>=`` the cut-off.
            w: ``None`` for plain counts, or an array broadcastable against
                ``y_true`` holding per-label weights.

        Returns:
            ``(best_threshold, best_f)``; ``(0.5, -1.0)`` when ``thresholds``
            is empty.
        """
        y_true_b = self._dense_bool(y_true)
        y_false_b = ~y_true_b  # does not depend on the threshold
        best_t, best_f = 0.5, -1.0

        for t in thresholds:
            y_pred_b = (y_prob >= t)
            tp = self._row_totals(np.logical_and(y_true_b, y_pred_b), w)
            fp = self._row_totals(np.logical_and(y_false_b, y_pred_b), w)
            fn = self._row_totals(np.logical_and(y_true_b, ~y_pred_b), w)

            # Rows without any (positively weighted) prediction are left out
            # of the precision average; rows without truth out of recall.
            has_pred = (tp + fp) > 0
            has_true = (tp + fn) > 0

            p = (tp[has_pred] / (tp[has_pred] + fp[has_pred] + 1e-12)).mean() if has_pred.any() else 0.0
            r = (tp[has_true] / (tp[has_true] + fn[has_true] + 1e-12)).mean() if has_true.any() else 0.0

            f = self._f_from_prec_rec(float(p), float(r))
            if f > best_f:
                best_f, best_t = f, float(t)

        return best_t, best_f

    def fmax(self, y_true, y_prob: np.ndarray, thresholds: np.ndarray) -> Tuple[float, float]:
        """Unweighted Fmax: return ``(best_threshold, best_f)`` over ``thresholds``."""
        return self._best_threshold(y_true, y_prob, thresholds, None)

    def ia_fmax(self, y_true, y_prob: np.ndarray, thresholds: np.ndarray) -> Tuple[float, float]:
        """Weighted Fmax: like ``fmax`` but each label counts ``self.weights[label]``."""
        return self._best_threshold(y_true, y_prob, thresholds, self.weights[None, :])

In [None]:
def _parse_ia_value(raw) -> float:
    """Convert one weight field to ``float``.

    A decimal comma is accepted as a second attempt; anything that still
    cannot be parsed yields 0.0.  Only conversion errors are caught, so
    unrelated failures (e.g. ``KeyboardInterrupt``) are no longer swallowed.
    """
    try:
        return float(raw)
    except (TypeError, ValueError):
        pass
    try:
        return float(raw.replace(",", "."))
    except (AttributeError, TypeError, ValueError):
        return 0.0


def read_IA_safe(path):
    """Read a two-column, header-less TSV of ``GO term <tab> weight``.

    Args:
        path: Location of the TSV file.

    Returns:
        ``{go_term: weight}``; an empty dict when ``path`` does not exist.
        If a term appears more than once, the last row wins.
    """
    if not os.path.exists(path):
        return {}
    df = pd.read_csv(path, sep="\t", header=None, names=["go","ia"], dtype=str)
    return {go: _parse_ia_value(ia) for go, ia in zip(df["go"], df["ia"])}

# Per-GO-term weights keyed by term id; an empty dict when the IA file is absent.
ia_weights = read_IA_safe(CONFIG["IA_FILE"])

## Load Data and Build Label Matrix

In [None]:
# Load sequences, annotations and the GO graph.
train_seqs = read_fasta(CONFIG["TRAIN_FASTA"])
train_terms = read_train_terms(CONFIG["TRAIN_TERMS"])
parents_map, children_map = parse_obo(CONFIG["GO_OBO"])
test_seqs  = read_fasta(CONFIG["TEST_FASTA"])

# Keep annotated proteins that also have a sequence.  The order is the key
# order of train_terms (first appearance in the annotation file), not the
# FASTA order.
train_proteins = [p for p in train_terms.keys() if p in train_seqs]
print(f"[io] {len(train_proteins)} train proteins with sequences available")

# Optionally add every ancestor of each annotated term.  Skipped when the
# ontology could not be parsed (parents_map is empty).
if CONFIG["PROPAGATE_TRAIN_LABELS"] and parents_map:
    print("[prep] Propagating train labels up GO graph")
    propagated={}
    for p in train_proteins:
        terms=set(train_terms[p])
        extra=set()
        for t in list(terms):
            extra |= get_ancestors(t, parents_map)
        propagated[p]=sorted(terms|extra)
    # From here on train_terms only holds the proteins in train_proteins.
    train_terms = propagated

# Term frequency over the training proteins.
all_term_counts = Counter()
for p in train_proteins:
    all_term_counts.update(train_terms[p])

# The TOP_K_LABELS most frequent terms become the prediction targets.
all_terms_sorted = [t for t,_ in all_term_counts.most_common()]
chosen_terms = set(all_terms_sorted[:CONFIG["TOP_K_LABELS"]])
print(f"[prep] Restricting to top-{CONFIG['TOP_K_LABELS']} GO terms")
print(f"[prep] Using {len(chosen_terms)} target GO terms")

# Drop annotations outside the target set (mutates train_terms in place).
for p in train_proteins:
    train_terms[p] = [t for t in train_terms[p] if t in chosen_terms]

X_proteins = train_proteins
y_labels = [train_terms[p] for p in X_proteins]

# Fixed, sorted class list -> label-matrix columns follow sorted(chosen_terms).
mlb = MultiLabelBinarizer(classes=sorted(chosen_terms), sparse_output=True)
Y_csr = mlb.fit_transform(y_labels).astype(np.float32)  # sparse to save RAM
print("[prep] Label matrix shape (sparse):", Y_csr.shape)

## Feature Extraction (ProtBERT Embeddings and K-mer TF-IDF)

In [None]:
print("[emb] Loading ProtBERT embeddings...")
prot_train = np.load(CONFIG["PROTBERT_TRAIN_EMB"]).astype(np.float32)
prot_test  = np.load(CONFIG["PROTBERT_TEST_EMB"]).astype(np.float32)

# NOTE(review): row i of prot_train is assumed to belong to X_proteins[i], and
# (further down the notebook) row i of prot_test to the i-th test FASTA record.
# Nothing in this notebook can verify that order -- confirm it against the
# notebook that produced the .npy files.  What CAN be checked is the row
# count, so fail here instead of with an IndexError in the middle of training.
if prot_train.shape[0] < len(X_proteins):
    raise ValueError(
        f"ProtBERT train embeddings have {prot_train.shape[0]} rows but "
        f"{len(X_proteins)} training proteins were selected"
    )
if prot_train.shape[0] > len(X_proteins):
    print(
        f"[emb][warn] ProtBERT train embeddings have {prot_train.shape[0]} rows for "
        f"{len(X_proteins)} proteins; keeping the first {len(X_proteins)} rows"
    )
if prot_test.shape[0] < len(test_seqs):
    raise ValueError(
        f"ProtBERT test embeddings have {prot_test.shape[0]} rows but "
        f"{len(test_seqs)} test sequences were read"
    )
if prot_test.shape[0] > len(test_seqs):
    print(
        f"[emb][warn] ProtBERT test embeddings have {prot_test.shape[0]} rows for "
        f"{len(test_seqs)} test sequences; only the first {len(test_seqs)} rows are used"
    )

# Assume aligned; slice to match train_proteins length
prot_train = prot_train[:len(X_proteins)]
print("[emb] ProtBERT train:", prot_train.shape, prot_train.dtype)
print("[emb] ProtBERT test :", prot_test.shape, prot_test.dtype)

print("[emb] Building TF-IDF ...")
# Character n-grams over the raw amino-acid string; case is preserved.
tfidf = TfidfVectorizer(
    analyzer="char",
    # Falls back to a single n-gram size (TFIDF_K) only when no range is configured.
    ngram_range=CONFIG.get("TFIDF_NGRAM_RANGE", (CONFIG["TFIDF_K"], CONFIG["TFIDF_K"])),
    lowercase=False,
    max_features=CONFIG["TFIDF_MAX_FEATURES"],
    min_df=CONFIG["TFIDF_MIN_DF"],
    sublinear_tf=CONFIG.get("TFIDF_SUBLINEAR_TF", True),
    dtype=np.float32,
    norm="l2",
)
# Same row order as X_proteins so features and labels line up.
train_seqs_list = [train_seqs[pid] for pid in X_proteins]
X_tfidf_train = tfidf.fit_transform(train_seqs_list).astype(np.float32)
print("[emb] TF-IDF train:", X_tfidf_train.shape, X_tfidf_train.dtype)

## Train/Validation Split

In [None]:
# Split row indices (not the data itself) 85/15 with the configured seed, so
# the feature matrices and the label matrix can all be sliced with the same
# index arrays.
idx_all = np.arange(len(X_proteins))
idx_tr, idx_val = train_test_split(idx_all, test_size=0.15, random_state=CONFIG["RANDOM_SEED"])

# Dense copy of the validation labels.
# NOTE(review): y_val is not referenced again in the visible part of this
# notebook (the evaluation cell re-slices Y_csr[idx_val]); it may be removable.
y_val = Y_csr[idx_val].toarray().astype(np.float32)

## PyTorch Dataset and Collator

In [None]:
class IndexDataset(Dataset):
    """Dataset whose items are row indices wrapped as ``{"idx": int}``.

    No feature data is stored here; items only carry the position of a row.
    """

    def __init__(self, indices: np.ndarray):
        # Normalise whatever index container was passed to an int64 array.
        self.indices = np.asarray(indices, dtype=np.int64)

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        row = self.indices[i]
        # Plain Python int rather than a NumPy scalar.
        return {"idx": int(row)}

def fusion_collator(features):
    """Turn a list of ``{"idx": int}`` items into one batch of tensors.

    The indices select rows from the module-level ``prot_train``,
    ``X_tfidf_train`` and ``Y_csr``; the two sparse matrices are densified
    for the selected rows only.

    Returns:
        Dict with float32 tensors under ``"prot"``, ``"tfidf"`` and ``"labels"``.
    """
    rows = np.fromiter(
        (item["idx"] for item in features), dtype=np.int64, count=len(features)
    )

    def dense_rows(sparse_matrix):
        # Slice first, densify second: only the batch rows become dense.
        return torch.from_numpy(sparse_matrix[rows].toarray().astype(np.float32))

    return {
        "prot": torch.from_numpy(prot_train[rows]).to(torch.float32),  # (B, d_prot)
        "tfidf": dense_rows(X_tfidf_train),
        "labels": dense_rows(Y_csr),
    }

# Index-only datasets over the two halves of the split.
train_ds = IndexDataset(idx_tr)
val_ds   = IndexDataset(idx_val)

## Main Pipeline

### Model Architecture

In [None]:
class FusionMLP(nn.Module):
    """Two-tower MLP: project each input separately, concatenate, classify.

    Each tower and the first fusion layer share one layout:
    Linear -> BatchNorm1d -> ReLU -> Dropout.  The final Linear layer emits
    one raw logit per label (no sigmoid is applied here).
    """

    def __init__(self, d_prot, d_tfidf, n_labels, proj_dim=512, hid=512, dropout=0.25):
        """
        Args:
            d_prot: Width of the ``prot`` input.
            d_tfidf: Width of the ``tfidf`` input.
            n_labels: Number of output logits.
            proj_dim: Output width of each tower.
            hid: Width of the hidden fusion layer.
            dropout: Dropout probability used after every hidden layer.
        """
        super().__init__()

        def dense_block(n_in, n_out):
            # One hidden layer in the layout described in the class docstring.
            return [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]

        # Creation order (prot tower, tfidf tower, fusion head) is kept so
        # that seeded initialisation yields the same weights as before.
        self.prot_proj = nn.Sequential(*dense_block(d_prot, proj_dim))
        self.tfidf_proj = nn.Sequential(*dense_block(d_tfidf, proj_dim))
        self.fuse = nn.Sequential(
            *dense_block(2 * proj_dim, hid),
            nn.Linear(hid, n_labels),  # logits
        )

    def forward(self, prot=None, tfidf=None, labels=None):
        """Return ``{"logits": tensor}``; ``labels`` is accepted but not used."""
        towers = (self.prot_proj(prot), self.tfidf_proj(tfidf))
        return {"logits": self.fuse(torch.cat(towers, dim=1))}

# Input/output widths are read from the data so the network always matches
# the feature matrices and the label matrix built above.
model = FusionMLP(
    d_prot=prot_train.shape[1],
    d_tfidf=X_tfidf_train.shape[1],
    n_labels=Y_csr.shape[1],
    proj_dim=CONFIG["PROJ_DIM"],
    hid=CONFIG["FUSION_HID"],
    dropout=CONFIG["DROPOUT_FUSION"],
).to(device)

### Build and Train Model

In [None]:
class MultilabelTrainer(Trainer):
    """``Trainer`` whose loss is binary cross-entropy computed on raw logits."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the base initialiser so that this value is the one
        # left on the instance.
        # NOTE(review): presumably stops the base Trainer from forwarding
        # loss kwargs to the model -- confirm for the installed transformers.
        self.model_accepts_loss_kwargs = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None, **kwargs):
        """Compute the multi-label loss for one batch.

        ``"labels"`` is removed from ``inputs`` before the model is called.
        ``num_items_in_batch`` and extra keyword arguments are accepted but
        ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        # The model may answer with a plain dict or an object exposing .logits.
        if isinstance(outputs, dict):
            logits = outputs["logits"]
        else:
            logits = outputs.logits
        loss = F.binary_cross_entropy_with_logits(logits, targets)
        if return_outputs:
            return loss, outputs
        return loss

args = TrainingArguments(
    output_dir="/kaggle/working/fusion_ckpt",
    do_train=True,
    do_eval=True,

    per_device_train_batch_size=CONFIG["TRAIN_BS"],
    per_device_eval_batch_size=CONFIG["EVAL_BS"],
    num_train_epochs=CONFIG["EPOCHS_TORCH"],
    learning_rate=CONFIG["LR_TORCH"],
    weight_decay=CONFIG["WEIGHT_DECAY"],

    # Request mixed precision only when a CUDA device is actually present, so
    # that the CPU fallback chosen for `device` does not trip over an fp16
    # request it cannot honour.
    fp16=bool(CONFIG["FP16"]) and torch.cuda.is_available(),
    logging_steps=50,

    # Evaluate and checkpoint once per epoch; the checkpoint with the lowest
    # validation loss is restored when training ends.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,

    # The datasets only yield {"idx": ...}; keep that column for the collator.
    remove_unused_columns=False,
    report_to="none",
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=1.0,
)

# Wire model, data and collator together.  Early stopping ends training after
# PATIENCE consecutive evaluations without improvement of the tracked metric.
trainer = MultilabelTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=fusion_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=CONFIG["PATIENCE"])],
)

trainer.train()

In [None]:
pred = trainer.predict(val_ds)
# Make the dtype explicit before any arithmetic.
# NOTE(review): predictions may come back in half precision when fp16
# evaluation is active -- the cast keeps the probabilities in float32 either way.
val_logits = np.asarray(pred.predictions, dtype=np.float32)  # numpy
# sigmoid(x) = exp(-log(1 + exp(-x))).  logaddexp evaluates the inner
# log(exp(0) + exp(-x)) without overflowing for strongly negative logits,
# which the direct 1 / (1 + exp(-x)) form does.
y_val_prob = np.exp(-np.logaddexp(0.0, -val_logits))  # sigmoid

In [None]:
# Label order of the binarizer, and the reverse lookup term -> column.
classes_list = list(mlb.classes_)
term_to_idx = {term: col for col, term in enumerate(classes_list)}

# Direct parents of every label, limited to parents that are labels themselves.
restricted_parents = {
    term: {par for par in parents_map.get(term, set()) if par in term_to_idx}
    for term in classes_list
}

# Same information as column indices: entry c lists the parent columns of column c.
parent_idx_list = [
    [term_to_idx[par] for par in restricted_parents.get(child_term, [])]
    for child_term in classes_list
]

def propagate_batch(pred_batch: np.ndarray, parent_idx_list, iterations=3):
    """Raise parent scores so that no parent scores below any of its children.

    Columns are visited in index order, so a score may need several sweeps
    to travel up a multi-level chain; sweeping stops early once a full sweep
    changes nothing.

    Args:
        pred_batch: 2-D score array (rows x labels); modified in place.
        parent_idx_list: ``parent_idx_list[c]`` holds the parent columns of
            column ``c``.
        iterations: Maximum number of sweeps over all columns.

    Returns:
        ``pred_batch`` itself (the same object that was passed in).
    """
    n_terms = pred_batch.shape[1]
    for _sweep in range(iterations):
        any_raised = False
        for child in range(n_terms):
            child_col = pred_batch[:, child]
            for parent in parent_idx_list[child]:
                needs_lift = child_col > pred_batch[:, parent]
                if not needs_lift.any():
                    continue
                pred_batch[needs_lift, parent] = child_col[needs_lift]
                any_raised = True
        if not any_raised:
            break
    return pred_batch

### Validation Threshold Search

In [None]:
y_true_val = Y_csr[idx_val]  # (N_val, L) sparse

# Evaluate on propagated scores when propagation is enabled.  The copy keeps
# y_val_prob untouched because propagate_batch writes into its argument.
y_eval_prob = y_val_prob
if CONFIG["PROPAGATE_PREDICTIONS"] and parents_map:
    y_eval_prob = propagate_batch(
        y_eval_prob.copy(),
        parent_idx_list,
        iterations=CONFIG["PROPAGATE_ITERATIONS"]
    )

evaluator = Evaluator(ia_weights, list(mlb.classes_))

# Stage 1: coarse grid for the IA-weighted metric.
thr_coarse = np.array(CONFIG["THRESH_COARSE"], dtype=np.float32)
best_thresh, best_score = evaluator.ia_fmax(y_true_val, y_eval_prob, thr_coarse)

# Stage 2: fine grid of +/- THRESH_REFINE_SPAN around the coarse optimum,
# clipped to [0, 0.999].  The tiny epsilon makes the upper bound inclusive.
lo = max(0.0, best_thresh - CONFIG["THRESH_REFINE_SPAN"])
hi = min(0.999, best_thresh + CONFIG["THRESH_REFINE_SPAN"])
thr_refine = np.arange(lo, hi + 1e-12, CONFIG["THRESH_REFINE_STEP"], dtype=np.float32)
best_thresh, best_score = evaluator.ia_fmax(y_true_val, y_eval_prob, thr_refine)

print(f"[eval] best_thresh {best_thresh:.3f} best IA-weighted Fmax {best_score:.6f}")

# Unweighted Fmax is searched over the same fine grid, i.e. only in the
# neighbourhood of the IA-weighted optimum.
t_f, f_f = evaluator.fmax(y_true_val, y_eval_prob, thr_refine)
print(f"[eval] best_thresh(Fmax) {t_f:.3f} best Fmax {f_f:.6f}")

### Write Submission File

In [None]:
def tfidf_transform_batch(seq_list: List[str]) -> np.ndarray:
    """Vectorise sequences with the fitted TF-IDF model; returns a dense float32 array."""
    return tfidf.transform(seq_list).toarray().astype(np.float32)


out_fpath = CONFIG["OUTPUT_SUBMISSION"]

test_ids = list(test_seqs.keys())
N_test = len(test_ids)
print(f"[test] Streaming {N_test} test sequences in batches of {CONFIG['EMBED_BATCH_SIZE']} / predict {CONFIG['PREDICT_BATCH_SIZE']}")

# Features are buffered across several small batches and pushed through the
# model once enough rows have accumulated.
embed_plm_buf = []
embed_tfidf_buf = []
embed_ids_buf = []

trainer.model.to(device)
trainer.model.eval()

# Mode "w" truncates any previous submission.  The context manager closes the
# file even when inference raises part-way through.
with open(out_fpath, "w") as out_f:
    for i in range(0, N_test, CONFIG["EMBED_BATCH_SIZE"]):
        batch_ids = test_ids[i:i+CONFIG["EMBED_BATCH_SIZE"]]
        seqs_batch = [test_seqs[pid] for pid in batch_ids]

        # Embedding rows are taken by position: row i belongs to test_ids[i].
        plm_mini = prot_test[i:i+len(batch_ids)]
        tfidf_mini = tfidf_transform_batch(seqs_batch)

        embed_plm_buf.append(plm_mini)
        embed_tfidf_buf.append(tfidf_mini)
        embed_ids_buf.extend(batch_ids)

        buffered = sum(a.shape[0] for a in embed_plm_buf)
        # Predict when the buffer is full, or when this was the last batch.
        if buffered >= CONFIG["PREDICT_BATCH_SIZE"] or (i+CONFIG["EMBED_BATCH_SIZE"] >= N_test):
            Xp = np.vstack(embed_plm_buf).astype(np.float32)
            Xt = np.vstack(embed_tfidf_buf).astype(np.float32)

            # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
            # mixed precision stays limited to CUDA, as before.
            with torch.no_grad(), torch.autocast(device_type=device.type, enabled=(device.type == "cuda")):
                xb1 = torch.from_numpy(Xp).to(device)
                xb2 = torch.from_numpy(Xt).to(device)
                logits = trainer.model(prot=xb1, tfidf=xb2)["logits"]
                probs = torch.sigmoid(logits).float().cpu().numpy()

            if CONFIG["PROPAGATE_PREDICTIONS"] and parents_map:
                probs = propagate_batch(probs, parent_idx_list, iterations=CONFIG["PROPAGATE_ITERATIONS"])

            top_k = CONFIG["TOP_K_PER_PROTEIN"]
            for ridx, pid in enumerate(embed_ids_buf):
                row = probs[ridx]
                # Best top_k columns, minus near-zero scores, highest first.
                # The > 1e-6 filter already guarantees a positive score.
                idxs = np.argsort(row)[-top_k:]
                idxs = [int(x) for x in idxs if row[x] > 1e-6]
                idxs = sorted(idxs, key=lambda x: row[x], reverse=True)
                out_f.writelines(
                    f"{pid}\t{mlb.classes_[idx]}\t{float(row[idx]):.6f}\n" for idx in idxs
                )

            out_f.flush()
            # Release the batch before the next one is assembled.
            embed_plm_buf, embed_tfidf_buf, embed_ids_buf = [], [], []
            del Xp, Xt, probs, logits, xb1, xb2
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        if (i // CONFIG["EMBED_BATCH_SIZE"]) % 50 == 0:
            print(f"[stream] processed {i} / {N_test}")

print(f"[done] Submission written to {CONFIG['OUTPUT_SUBMISSION']}")