In [1]:
import warnings
warnings.filterwarnings('ignore')


from pathlib import Path
import os, re, shutil, sys
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from scipy import sparse

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction import DictVectorizer
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

In [None]:
TRAIN_FASTA = "train_sequences.fasta"
TRAIN_TERMS = "train_terms.tsv"
TRAIN_TAXONOMY = "train_taxonomy.tsv"
TEST_FASTA  = "testsuperset.fasta"
TEST_TAXON  = "testsuperset-taxon-list.tsv"
IA_TSV      = "IA.tsv"
TRAIN_EMB= "train_embeddings.npy"
TEST_EMB = "test_embeddings.npy"
OUT_SUBMISSION = "submission.tsv"

<p align="center">
Data Loading
</p>


In [3]:
# Label space control (important for runtime)
MIN_COUNT     = 50       # keep GO terms appearing >= 2 times
MAX_TERMS     = 1200    # cap number of GO classes (raise if your session allows)

# Emission control (be generous; evaluator picks the threshold)
MIN_PROB      = 1e-6
TOP_K         = 200
CAP_PER_PROT  = 1500



# Training loop controls
MIN_POSITIVES_PER_CLASS   = 5            # skip classes with too few positives
VAL_SIZE                  = 0.1          # class-wise validation split for early stopping
RANDOM_STATE              = 42


# Read protein IDs and amino acid sequences from a FASTA file
def read_fasta_ids_and_seqs(path):
    
    
    ids, seqs, cur_id, cur_seq = [], [], None, []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith(">"):
                if cur_id is not None:
                    ids.append(cur_id); seqs.append("".join(cur_seq))
                header = line[1:].strip()
                cur_id = header.split("|")[1].split()[0] if "|" in header else header.split()[0]
                cur_seq = []
            else:
                cur_seq.append(line.strip())
        if cur_id is not None:
            ids.append(cur_id); seqs.append("".join(cur_seq))
    return ids, seqs

# Load training GO annotations and standardize column names
def load_train_terms(path):
    df0 = pd.read_csv(path, sep="\t", dtype=str)
    cols = [c.lower() for c in df0.columns]; df0.columns = cols
    if "entryid" in cols and "term" in cols:
        df = df0.rename(columns={"entryid":"protein_id","term":"go_id"})[["protein_id","go_id"]]
    else:
        df = df0.iloc[:, :2].copy()
        df.columns = ["protein_id","go_id"]
    return df.dropna().drop_duplicates()

# Load protein taxonomy information as a dictionary (protein_id -> taxon)
def load_taxonomy(path):
    df = pd.read_csv(path, sep="\t", header=None, dtype=str)
    if df.shape[1] < 2:
        raise ValueError("taxonomy TSV must have at least two columns.")
    df = df.iloc[:, :2].copy()
    df.columns = ["protein_id","taxon"]
    return dict(zip(df["protein_id"], df["taxon"]))

# Build a mapping from protein IDs to their associated GO term sets
def build_label_sets(df_terms):
    mp = defaultdict(set)
    for p,g in df_terms[["protein_id","go_id"]].itertuples(index=False):
        mp[p].add(g)
    return mp

# Filter GO terms by frequency and optionally limit the total number of terms
def filter_terms_by_freq(p2t, min_count=2, max_terms=None):
    c = Counter()
    for ts in p2t.values(): c.update(ts)
    items = [(g,n) for g,n in c.items() if n >= min_count]
    items.sort(key=lambda x: x[1], reverse=True)
    if max_terms: items = items[:max_terms]
    return set(g for g,_ in items)

# Load the whitelist of allowed GO terms from an IA TSV file
def load_allowed_go_terms(ia_path):
    df = pd.read_csv(ia_path, sep="\t", dtype=str)
    go_cols = [c for c in df.columns if "go" in c.lower() or "term" in c.lower()]
    if go_cols:
        vals = df[go_cols[0]].dropna().astype(str).tolist()
        return {v for v in vals if v.startswith("GO:")}
    s = set()
    for col in df.columns:
        s |= set(df[col].dropna().astype(str))
    return {x for x in s if x.startswith("GO:")}

# Format prediction scores to a fixed number of significant digits
def round_sig(x: float, sig=3) -> str:
    if x <= 0: return None
    s = f"{x:.{sig}g}"
    return s if float(s) > 0 else "0.001"



def get_xgb_device_params():
    # Use GPU if available; else CPU
    try:
        import subprocess
        has_gpu = False
        try:
            out = subprocess.run(["nvidia-smi"], capture_output=True, text=True)
            has_gpu = (out.returncode == 0)
        except Exception:
            has_gpu = False
        if has_gpu:
            return dict(tree_method="gpu_hist", predictor="gpu_predictor")
    except Exception:
        pass
    return dict(tree_method="hist", predictor="auto")





print("Loading training FASTA...")
train_ids, train_seqs = read_fasta_ids_and_seqs(TRAIN_FASTA)
print("Loading training terms...")
terms_df = load_train_terms(TRAIN_TERMS)
prot2terms_all = build_label_sets(terms_df)

# Keep only proteins with labels and present in FASTA
X_text, y_terms, used_ids = [], [], []
for pid, seq in zip(train_ids, train_seqs):
    t = prot2terms_all.get(pid)
    if t:
        used_ids.append(pid); X_text.append(seq); y_terms.append(sorted(t))
print("Training instances used:", len(used_ids))



# Term filtering (frequency) and whitelist (IA.tsv)
kept_freq = filter_terms_by_freq({pid:set(ts) for pid,ts in zip(used_ids, y_terms)},
                                 min_count=MIN_COUNT, max_terms=MAX_TERMS)
allowed_terms = load_allowed_go_terms(IA_TSV)
print("Kept by freq:", len(kept_freq), "| Allowed (IA):", len(allowed_terms))
kept_terms = kept_freq & allowed_terms
print("Kept after whitelist:", len(kept_terms))
y_filtered = [[t for t in ts if t in kept_terms] for ts in y_terms]




Loading training FASTA...
Loading training terms...
Training instances used: 82404
Kept by freq: 1200 | Allowed (IA): 40121
Kept after whitelist: 1200


In [None]:
#Loading Embedding Features
X_train = np.load(TRAIN_EMB)  # shape (N_train, D)
X_test  = np.load(TEST_EMB)   # shape (N_test, D)
print("X_test shape:", X_test.shape)

Loading precomputed embeddings...
X_test shape: (224309, 1280)


In [5]:
from sklearn.preprocessing import MultiLabelBinarizer

all_terms_sorted = sorted(list(kept_terms))
mlb = MultiLabelBinarizer(classes=all_terms_sorted)
Y = mlb.fit_transform(y_filtered)
num_classes = Y.shape[1]


print("Embedding shape:", X_train.shape)
print("Label matrix shape:", Y.shape)
print("Num GO classes:", len(all_terms_sorted))

Embedding shape: (82404, 1280)
Label matrix shape: (82404, 1200)
Num GO classes: 1200


<p align="center">
Training
</p>

In [6]:
from sklearn.model_selection import train_test_split

X_tr, X_val, y_tr, y_val = train_test_split(
    X_train, Y, test_size=VAL_SIZE,
    random_state=RANDOM_STATE, stratify=(Y.sum(axis=1) > 0)
)

In [7]:
import xgboost as xgb
from tqdm import tqdm
device_params = {
    "tree_method": "gpu_hist",
    "predictor": "gpu_predictor"
}
models = []

print("Training XGBoost for each GO term...")

for i in tqdm(range(num_classes), desc="Training XGB"):
    y_i = y_tr[:, i]

    # Skip if class has no positives
    if y_i.sum() == 0:
        models.append(None)
        continue

    dtrain = xgb.DMatrix(X_tr, label=y_i)

    params = {
        "objective": "binary:logistic",
        "eval_metric": "logloss",
        "max_depth": 6,
        "eta": 0.1,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        **device_params
    }

    bst = xgb.train(
        params,
        dtrain,
        num_boost_round=200
    )

    models.append(bst)


Training XGBoost for each GO term...


Training XGB:   0%|          | 1/1200 [00:04<1:36:28,  4.83s/it]


KeyboardInterrupt: 

<p align="center">
Hierarchy Constraints</p>

In [None]:
# Parse GO hierarchy from an OBO file and build a mapping from each GO term to its parent terms
def load_go_parents(obo_file):
    parents_map = {}
    current_term = None

    with open(obo_file, "r") as f:
        for line in f:
            line = line.strip()

            if line == "[Term]":
                current_term = None

            elif line.startswith("id: GO:"):
                current_term = line.split("id: ")[1]
                parents_map[current_term] = set()

            elif line.startswith("is_a:") and current_term:
                parent = line.split("is_a: ")[1].split(" !")[0]
                parents_map[current_term].add(parent)

    return parents_map
parents_map = load_go_parents("go-basic.obo")

# Restrict GO parent relationships to the set of GO terms used by the model
def restrict_go_parents(parents_map, classes):
    term_to_idx = {t: i for i, t in enumerate(classes)}

    restricted_parents = {
        t: {p for p in parents_map.get(t, set()) if p in term_to_idx}
        for t in classes
    }
    return restricted_parents
restricted_parents = restrict_go_parents(
    parents_map,
    mlb.classes_
)

# Propagate GO prediction scores along the hierarchy to enforce the true-path rule
def propagate_batch(pred_batch, parents_map, classes, iterations=3):
    

    term_to_idx = {t: i for i, t in enumerate(classes)}

    for _ in range(iterations):
        changed = False

        for child, parents in parents_map.items():
            cidx = term_to_idx[child]
            child_scores = pred_batch[:, cidx]

            for p in parents:
                pidx = term_to_idx[p]
                mask = child_scores > pred_batch[:, pidx]
                if mask.any():
                    pred_batch[mask, pidx] = child_scores[mask]
                    changed = True

        if not changed:
            break

    return pred_batch


<p align="center">
Model Evaluation on Validation Set
</p>

In [None]:
from sklearn.metrics import f1_score, precision_score, recall_score, average_precision_score
THRESHOLD = 0.3
print("Predicting validation...")
val_proba = np.zeros((X_val.shape[0], num_classes))

for i in tqdm(range(num_classes), desc="XGB val predict"):
    model = models[i]
    if model is None:
        continue
    dval = xgb.DMatrix(X_val)
    val_proba[:, i] = model.predict(dval)
val_pred = (val_proba >= THRESHOLD).astype(int)
# Micro metrics 
micro_f1 = f1_score(y_val, val_pred, average="micro")
micro_precision = precision_score(y_val, val_pred, average="micro", zero_division=0)
micro_recall = recall_score(y_val, val_pred, average="micro", zero_division=0)

# Macro F1 
macro_f1 = f1_score(y_val, val_pred, average="macro", zero_division=0)

# PR-AUC (micro) 
pr_auc_micro = average_precision_score(y_val, val_proba, average="micro")
avg_labels_per_protein = val_pred.sum(axis=1).mean()
coverage = (val_pred.sum(axis=1) > 0).mean()
print("\n===== VALIDATION METRICS (Logistic Regression) =====")
print(f"Micro-F1        : {micro_f1:.4f}")
print(f"Macro-F1        : {macro_f1:.4f}")
print(f"Micro-Precision : {micro_precision:.4f}")
print(f"Micro-Recall    : {micro_recall:.4f}")



Predicting validation...


XGB val predict: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1200/1200 [04:26<00:00,  4.51it/s]



===== VALIDATION METRICS (Logistic Regression) =====
Micro-F1        : 0.1727
Macro-F1        : 0.0023
Micro-Precision : 0.4166
Micro-Recall    : 0.1089
Micro PR-AUC    : 0.1094
Avg labels/prot : 1.04


<p align="center">
Predict Test Set</p>


In [None]:
def xgb_predict_in_batches(models, X, batch_size=10000):
    n = X.shape[0]
    C = len(models)
    proba = np.zeros((n, C), dtype=np.float32)

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        print(f"Predicting batch {start}:{end}")

        dmat = xgb.DMatrix(X[start:end])

        for i, model in enumerate(models):
            if model is not None:
                proba[start:end, i] = model.predict(dmat)

        del dmat   # ðŸ”¥ giáº£i phÃ³ng RAM ngay

    return proba

print("Predicting test in batches...")
test_proba = xgb_predict_in_batches(models, X_test, batch_size=8000)
test_proba = propagate_batch(
    test_proba,
    restricted_parents,
    mlb.classes_,
    iterations=2
)




Predicting test in batches...
Predicting batch 0:8000
Predicting batch 8000:16000
Predicting batch 16000:24000
Predicting batch 24000:32000
Predicting batch 32000:40000
Predicting batch 40000:48000
Predicting batch 48000:56000
Predicting batch 56000:64000
Predicting batch 64000:72000
Predicting batch 72000:80000
Predicting batch 80000:88000
Predicting batch 88000:96000
Predicting batch 96000:104000
Predicting batch 104000:112000
Predicting batch 112000:120000
Predicting batch 120000:128000
Predicting batch 128000:136000
Predicting batch 136000:144000
Predicting batch 144000:152000
Predicting batch 152000:160000
Predicting batch 160000:168000
Predicting batch 168000:176000
Predicting batch 176000:184000
Predicting batch 184000:192000
Predicting batch 192000:200000
Predicting batch 200000:208000
Predicting batch 208000:216000
Predicting batch 216000:224000
Predicting batch 224000:224309


<p align="center">
Generate Submission File</p>


In [None]:
test_ids, test_seqs = read_fasta_ids_and_seqs(TEST_FASTA)

out_path = OUT_SUBMISSION
with open(out_path, "w", encoding="utf-8") as f:
    f.write("EntryID\tGO_ID\tConfidence\n")
    
    for i, pid in enumerate(test_ids):
        pr = test_proba[i]
        idx = np.argsort(-pr)[:TOP_K]  

        for j in idx:
            score = pr[j]
            if score < MIN_PROB:
                continue
            s = round_sig(score, 3)
            if s:
                f.write(f"{pid}\t{all_terms_sorted[j]}\t{s}\n")

print("Saved:", out_path)



Saved: submission2.tsv
