## Import Libraries

# K-mer TF-IDF with Linear Classifiers

This notebook implements a protein function prediction model using k-mer TF-IDF features with linear classifiers (SGD/Logistic Regression).

In [None]:
import os, gc
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict, Counter
from typing import Dict, List, Set, Tuple, Optional

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier, LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
import random
import joblib
import scipy.sparse as sp

## Configuration

In [None]:
class Config:
    """Centralized configuration for the CAFA-6 k-mer TF-IDF pipeline.

    Every setting is a class attribute, so values can be read from the class
    itself or from an instance (``Config()``). ``set_seed`` seeds the RNGs.
    """
    # Paths
    DATA_DIR = Path("/kaggle/input/cafa-6-protein-function-prediction")
    TRAIN_DIR = DATA_DIR / "Train"
    TEST_DIR = DATA_DIR / "Test"
    WORK_DIR = Path("/kaggle/working")

    # Data files
    TRAIN_FASTA = TRAIN_DIR / "train_sequences.fasta"
    TRAIN_TERMS = TRAIN_DIR / "train_terms.tsv"
    TRAIN_TAXONOMY = TRAIN_DIR / "train_taxonomy.tsv"
    GO_OBO = TRAIN_DIR / "go-basic.obo"
    IA_FILE = DATA_DIR / "IA.tsv"
    TEST_FASTA = TEST_DIR / "testsuperset.fasta"
    SAMPLE_SUBMISSION = DATA_DIR / "sample_submission.tsv"
    OUTPUT_FILE = WORK_DIR / "submission.tsv"

    # TF-IDF k-mer settings (character n-grams over the sequence string)
    KMER_NGRAM_RANGE = (3, 3)   # (min_n, max_n) k-mer lengths
    KMER_MAX_FEATURES = 50000   # cap on vocabulary size
    KMER_MIN_DF = 2             # drop k-mers occurring in fewer sequences than this
    KMER_SUBLINEAR_TF = True    # sublinear (log-scaled) term frequency

    # Linear model: "logreg" -> LogisticRegression, anything else -> SGD
    MODEL_TYPE = "sgd"

    # SGD parameters
    SGD_ALPHA = 2e-6
    SGD_MAX_ITER = 30
    SGD_TOL = 1e-3
    SGD_N_JOBS = 4              # OneVsRest worker count for the SGD model

    # Logistic Regression parameters
    LOGREG_C = 4.0
    LOGREG_MAX_ITER = 200
    LOGREG_N_JOBS = -1          # OneVsRest worker count (-1 = all cores)

    # Model parameters
    RANDOM_SEED = 42
    TOP_K_LABELS = 3000         # keep only the most frequent GO terms as labels
    # NOTE(review): the settings below are not read by any code visible in
    # this notebook (they look left over from an MLP variant).
    HIDDEN_UNITS = [1024, 512]
    DROPOUT = 0.5
    LEARNING_RATE = 3e-4
    BATCH_SIZE = 64
    EPOCHS = 50
    PATIENCE = 5

    # Prediction parameters
    TOP_K_PER_PROTEIN = 200     # max GO terms written per test protein

    # Validation threshold sweep: 0.002 steps up to ~0.05, then 0.01 steps to 0.20
    THRESHOLD_SEARCH = True
    THRESHOLD_GRID = np.concatenate([
        np.arange(0.001, 0.051, 0.002),
        np.arange(0.05, 0.201, 0.01)
    ])

    # GO propagation
    PROPAGATE_TRAIN = True      # add ancestor terms to training annotations
    PROPAGATE_PRED = True       # push predicted scores up to parent terms
    PROPAGATE_ITERATIONS = 3    # maximum propagation passes

    @classmethod
    def set_seed(cls):
        """Seed Python's ``random`` and NumPy's global RNG with ``RANDOM_SEED``.

        ``PYTHONHASHSEED`` is also exported, but the running interpreter reads
        it only at startup, so it affects child processes started afterwards,
        not this process's string hashing.
        """
        np.random.seed(cls.RANDOM_SEED)
        random.seed(cls.RANDOM_SEED)
        os.environ["PYTHONHASHSEED"] = str(cls.RANDOM_SEED)

## Timing Utilities

In [None]:
import time, os

# Optional memory reporting: if psutil is missing (or the process handle
# cannot be created), Timer falls back to printing elapsed time only.
try:
    import psutil
    _PROC = psutil.Process(os.getpid())
except Exception:
    _PROC = None

class Timer:
    """Print wall-clock time between successive checkpoints.

    Each ``hit`` reports the time since construction or the previous ``hit``,
    plus the process RSS in GB when psutil is available.
    """

    def __init__(self):
        self.t = time.perf_counter()  # timestamp of the last checkpoint

    def hit(self, msg: str):
        """Report elapsed time since the last checkpoint, then start a new one."""
        previous = self.t
        self.t = time.perf_counter()
        dt = self.t - previous
        if _PROC is None:
            print(f"[TIMER] {msg}: {dt:.2f}s")
            return
        rss_gb = _PROC.memory_info().rss / (1024**3)
        print(f"[TIMER] {msg}: {dt:.2f}s | RSS={rss_gb:.2f} GB")

## Data Loading Functions

In [None]:
class DataLoader:
    """Handle all data loading operations (FASTA, GO annotations, IA weights)."""

    @staticmethod
    def read_fasta(path: Path) -> Dict[str, str]:
        """Read a FASTA file into a ``{protein_id: sequence}`` dict.

        The protein ID is the first whitespace-separated token of each header
        line; if that token contains ``|`` (e.g. ``sp|P12345|NAME``), the second
        ``|``-separated field is used. Sequence lines are stripped and
        concatenated. Records whose ID is empty (including a bare ``>``
        header) are skipped; a repeated ID overwrites the earlier record.

        Args:
            path: FASTA file to read.

        Returns:
            Mapping from protein ID to sequence string.
        """
        sequences: Dict[str, str] = {}
        protein_id = None
        seq_parts: List[str] = []

        with open(path) as f:
            for line in f:
                line = line.strip()
                if line.startswith(">"):
                    if protein_id:
                        sequences[protein_id] = "".join(seq_parts)

                    fields = line[1:].split()
                    # A bare ">" has no token; treat it as an empty (skipped) ID
                    header = fields[0] if fields else ""
                    protein_id = header.split("|")[1] if "|" in header else header
                    seq_parts = []
                else:
                    seq_parts.append(line)

        if protein_id:
            sequences[protein_id] = "".join(seq_parts)

        print(f"Loaded {len(sequences):,} sequences from {path.name}")
        return sequences

    @staticmethod
    def read_annotations(path: Path) -> Dict[str, List[str]]:
        """Read protein -> GO-term annotations from a 3-column TSV.

        Columns are read positionally as (protein, go_term, ontology); the
        ontology column is ignored. Terms keep file order per protein.
        NOTE(review): ``header=None`` means a header line, if the file has
        one, is loaded as an ordinary row — confirm the file format.

        Returns:
            Mapping from protein ID to its list of GO terms.
        """
        df = pd.read_csv(path, sep="\t", header=None,
                         names=["protein", "go_term", "ontology"])

        annotations: Dict[str, List[str]] = defaultdict(list)
        # Column-wise zip yields the same values as iterrows() without
        # building a pandas Series for every row.
        for protein, go_term in zip(df["protein"], df["go_term"]):
            annotations[protein].append(go_term)

        print(f"Loaded annotations for {len(annotations):,} proteins")
        return dict(annotations)

    @staticmethod
    def read_ia_weights(path: Path) -> Dict[str, float]:
        """Read Information Accretion weights from a 2-column TSV.

        Decimal commas (``"0,5"``) are accepted. A value that cannot be parsed
        as a float is recorded as ``0.0``. A missing file yields ``{}``.

        Returns:
            Mapping from GO term to IA weight.
        """
        if not path.exists():
            print("Warning: IA weights file not found")
            return {}

        df = pd.read_csv(path, sep="\t", header=None, names=["go_term", "ia"])
        weights: Dict[str, float] = {}

        for go_term, raw in zip(df["go_term"], df["ia"]):
            try:
                weights[go_term] = float(str(raw).replace(",", "."))
            except (TypeError, ValueError):
                # Only unparseable values fall back to 0; interrupts propagate
                weights[go_term] = 0.0

        print(f"Loaded IA weights for {len(weights):,} GO terms")
        return weights

## GO Ontology Parser

In [None]:
class GOGraph:
    """Gene Ontology graph: OBO parsing, ancestor lookup, label propagation."""

    def __init__(self, obo_path: Path):
        # parents[t]: direct is_a / part_of parents of t; children is the inverse
        self.parents, self.children = self._parse_obo(obo_path)

    def _parse_obo(self, path: Path) -> Tuple[Dict, Dict]:
        """Parse an OBO file into parent and child adjacency maps.

        Only ``is_a:`` and ``relationship: part_of`` lines inside ``[Term]``
        stanzas are used. Other stanza types (e.g. ``[Typedef]``, which can
        carry their own ``is_a`` lines between relation names) are skipped so
        they do not leak into the term graph.

        Returns:
            ``(parents, children)``: dicts mapping a term ID to a set of IDs.
            Both are empty if the file does not exist.
        """
        parents = defaultdict(set)
        children = defaultdict(set)

        if not path.exists():
            print("Warning: OBO file not found")
            return dict(parents), dict(children)

        with open(path) as f:
            current_id = None
            in_term = False

            for line in f:
                line = line.strip()

                if line.startswith("[") and line.endswith("]"):
                    # New stanza: only [Term] stanzas contribute edges
                    in_term = line == "[Term]"
                    current_id = None
                elif not in_term:
                    continue
                elif line.startswith("id: "):
                    current_id = line[len("id: "):]
                elif line.startswith("is_a: ") and current_id:
                    # "is_a: GO:0000001 ! name" -> "GO:0000001"
                    parent_id = line.split()[1]
                    parents[current_id].add(parent_id)
                    children[parent_id].add(current_id)
                elif line.startswith("relationship: part_of ") and current_id:
                    parts = line.split()
                    if len(parts) >= 3:
                        parent_id = parts[2]
                        parents[current_id].add(parent_id)
                        children[parent_id].add(current_id)

        print(f"Parsed GO graph: {len(parents):,} terms with parents")
        return dict(parents), dict(children)

    def get_ancestors(self, go_term: str) -> Set[str]:
        """Return every transitive ancestor of ``go_term`` (excluding itself).

        Iterative DFS; the visited set makes it terminate even on cycles.
        """
        ancestors = set()
        stack = [go_term]

        while stack:
            current = stack.pop()
            for parent in self.parents.get(current, []):
                if parent not in ancestors:
                    ancestors.add(parent)
                    stack.append(parent)

        return ancestors

    def propagate_labels(self, annotations: Dict[str, List[str]]) -> Dict[str, List[str]]:
        """Add all ancestors of each annotated term.

        Returns:
            New mapping from protein to its sorted, de-duplicated term list
            (original terms plus ancestors). The input is not modified.
        """
        print("Propagating labels up GO hierarchy...")
        propagated = {}

        for protein, terms in annotations.items():
            expanded = set(terms)
            for term in terms:
                expanded.update(self.get_ancestors(term))
            propagated[protein] = sorted(expanded)

        original_count = sum(len(v) for v in annotations.values())
        new_count = sum(len(v) for v in propagated.values())
        print(f"  {original_count:,} -> {new_count:,} annotations")

        return propagated

## Cached Ancestor Lookup

In [None]:
# === SPEED FIX: cache ancestors ===
from functools import lru_cache

def _attach_cached_ancestors(go_graph: GOGraph):
    """Replace ``go_graph.get_ancestors`` with a memoized equivalent.

    The ancestor closure of each term is computed once and cached; every call
    still returns a fresh ``set``, so callers may mutate the result freely.
    ``go_graph.parents`` is read lazily, on the first lookup of each term.
    """
    @lru_cache(maxsize=None)
    def _closure(root: str):
        seen = set()
        frontier = [root]
        while frontier:
            node = frontier.pop()
            for parent in go_graph.parents.get(node, ()):
                if parent in seen:
                    continue
                seen.add(parent)
                frontier.append(parent)
        # Cache an immutable tuple so no caller can alter the cached value
        return tuple(seen)

    def _cached_get_ancestors(term):
        return set(_closure(term))

    # Instance attribute shadows the GOGraph.get_ancestors method
    go_graph.get_ancestors = _cached_get_ancestors

# Call right after creating the GO graph:
# pipeline.go_graph = GOGraph(...)
# _attach_cached_ancestors(pipeline.go_graph)

## K-mer Feature Extraction

In [None]:
class EmbeddingHandler:
    """Turn protein sequences into sparse character k-mer TF-IDF matrices."""

    @staticmethod
    def build_kmer_tfidf(
        sequences: List[str],
        ngram_range: Tuple[int, int],
        max_features: int,
        min_df: int = 1,
        sublinear_tf: bool = True,
    ):
        """Fit a character n-gram TF-IDF vectorizer on ``sequences`` and encode them.

        Args:
            sequences: raw sequence strings, one document each.
            ngram_range: ``(min_n, max_n)`` k-mer lengths.
            max_features: maximum vocabulary size.
            min_df: minimum number of sequences a k-mer must appear in.
            sublinear_tf: use sublinear (log-scaled) term frequency.

        Returns:
            ``(X, vectorizer)``: the float32, L2-normalised TF-IDF matrix (one
            row per sequence) and the fitted vectorizer.
        """
        print(f"Building TF-IDF ngram_range={ngram_range} (max_features={max_features}, min_df={min_df})...")

        tfidf_options = {
            "analyzer": "char",            # tokens are character n-grams
            "ngram_range": ngram_range,
            "lowercase": False,            # keep sequence letters as written
            "max_features": max_features,
            "min_df": min_df,
            "sublinear_tf": sublinear_tf,
            "dtype": np.float32,
            "norm": "l2",
            "use_idf": True,
        }
        vectorizer = TfidfVectorizer(**tfidf_options)
        X = vectorizer.fit_transform(sequences)
        print(f"Built TF-IDF: X={X.shape}, vocab_size={len(vectorizer.vocabulary_):,}")
        return X, vectorizer

    @staticmethod
    def transform_kmer_tfidf(
        sequences: List[str],
        vectorizer: TfidfVectorizer,
    ):
        """Encode ``sequences`` with an already-fitted ``vectorizer`` (no refit)."""
        X = vectorizer.transform(sequences)
        print(f"Transformed sequences with TF-IDF: {X.shape}")
        return X

## Model Builder

In [None]:
class ModelBuilder:
    """Build sklearn one-vs-rest linear multi-label models."""

    @staticmethod
    def build_linear_model(config: Config):
        """Return an unfitted ``OneVsRestClassifier`` configured from ``config``.

        ``config.MODEL_TYPE == "logreg"`` selects ``LogisticRegression`` (saga
        solver); any other value selects ``SGDClassifier`` with log loss, so
        ``predict_proba`` is available either way.

        Both base estimators receive ``random_state=config.RANDOM_SEED``. They
        shuffle the training data, and with ``random_state=None`` they draw
        from NumPy's global RNG. That is not reproducible when per-label fits
        run in parallel.
        """
        if config.MODEL_TYPE == "logreg":
            base = LogisticRegression(
                C=config.LOGREG_C,
                solver="saga",
                max_iter=config.LOGREG_MAX_ITER,
                random_state=config.RANDOM_SEED,
                verbose=0,
            )
            # Parallelism is handled by OneVsRestClassifier; each sub-problem
            # is binary, so LogisticRegression's own n_jobs would go unused.
            return OneVsRestClassifier(base, n_jobs=config.LOGREG_N_JOBS)

        # default: SGD
        base = SGDClassifier(
            loss="log_loss",
            alpha=config.SGD_ALPHA,
            max_iter=config.SGD_MAX_ITER,
            tol=config.SGD_TOL,
            # No held-out validation split: training stops once the training
            # loss fails to improve by `tol` for n_iter_no_change epochs.
            early_stopping=False,
            n_iter_no_change=3,
            average=True,
            random_state=config.RANDOM_SEED,
        )
        return OneVsRestClassifier(base, n_jobs=config.SGD_N_JOBS)


## Evaluation Functions

In [None]:
class Evaluator:
    """Protein-centric Fmax and IA-weighted Fmax computed by a threshold sweep.

    Args:
        ia_weights: GO term -> information-accretion weight. Terms absent from
            the mapping get weight 1.0.
        go_terms: label vocabulary; column ``j`` of every ``y_true`` /
            ``y_prob`` matrix passed in is assumed to be ``go_terms[j]``.
    """

    def __init__(self, ia_weights: Dict[str, float], go_terms: List[str]):
        self.go_terms = go_terms
        self.weights = np.array([ia_weights.get(t, 1.0) for t in go_terms], dtype=np.float32)

    def _dense_bool(self, y_true):
        """Return ``y_true`` as a dense boolean array.

        Sparse input: every stored nonzero is positive. Dense input: entries
        ``> 0.5`` are positive.
        """
        if sp.issparse(y_true):
            return y_true.astype(np.bool_).toarray()
        return (y_true > 0.5)

    def _f_from_prec_rec(self, p: float, r: float, eps: float = 1e-12) -> float:
        """Harmonic mean of precision and recall; ``eps`` avoids 0/0."""
        return (2.0 * p * r) / (p + r + eps)

    def _sweep(self, y_true, y_prob: np.ndarray, thresholds: np.ndarray,
               weights: Optional[np.ndarray] = None) -> Tuple[float, float]:
        """Shared threshold sweep behind ``fmax`` and ``ia_fmax``.

        At threshold ``t`` a term is predicted when ``y_prob >= t``. Per-row
        TP/FP/FN are counts, or weighted sums when ``weights`` (broadcastable
        to ``y_prob``) is given. Precision is averaged over rows with nonzero
        predicted mass; recall over rows with nonzero true mass.

        Returns:
            ``(best_threshold, best_f)``. The earliest threshold achieving the
            highest F wins. With no thresholds the result is ``(0.5, -1.0)``.
        """
        y_true_b = self._dense_bool(y_true)
        y_false_b = ~y_true_b  # invariant across thresholds

        def _row_mass(mask):
            if weights is not None:
                mask = mask * weights
            return mask.sum(axis=1).astype(np.float32)

        best_t, best_f = 0.5, -1.0
        for t in thresholds:
            y_pred_b = (y_prob >= t)

            tp = _row_mass(np.logical_and(y_true_b, y_pred_b))
            fp = _row_mass(np.logical_and(y_false_b, y_pred_b))
            fn = _row_mass(np.logical_and(y_true_b, ~y_pred_b))

            has_pred = (tp + fp) > 0
            has_true = (tp + fn) > 0

            p = (tp[has_pred] / (tp[has_pred] + fp[has_pred] + 1e-12)).mean() if has_pred.any() else 0.0
            r = (tp[has_true] / (tp[has_true] + fn[has_true] + 1e-12)).mean() if has_true.any() else 0.0

            f = self._f_from_prec_rec(float(p), float(r))
            if f > best_f:
                best_f, best_t = f, float(t)

        return best_t, best_f

    def fmax(self, y_true, y_prob: np.ndarray, thresholds: np.ndarray) -> Tuple[float, float]:
        """Unweighted Fmax over ``thresholds``; returns ``(best_threshold, best_f)``."""
        return self._sweep(y_true, y_prob, thresholds)

    def ia_fmax(self, y_true, y_prob: np.ndarray, thresholds: np.ndarray) -> Tuple[float, float]:
        """IA-weighted Fmax over ``thresholds``; returns ``(best_threshold, best_f)``."""
        return self._sweep(y_true, y_prob, thresholds, weights=self.weights[None, :])

## Prediction Propagation

In [None]:
class PredictionPropagator:
    """Propagate predicted scores up the GO hierarchy (max-propagation).

    Only parent links whose parent is also in ``go_terms`` are used. After
    propagation, every term's score is at least that of each descendant
    reachable through such links.
    """

    def __init__(self, go_graph: "GOGraph", go_terms: List[str]):
        self.go_terms = go_terms
        self.term_to_idx = {t: i for i, t in enumerate(go_terms)}

        # Restrict parents to terms in our vocabulary
        self.parents = {
            term: {p for p in go_graph.parents.get(term, []) if p in self.term_to_idx}
            for term in go_terms
        }

        # (child column, parent columns), children before parents: one pass
        # then carries a score all the way up, however deep the hierarchy.
        self._schedule = [
            (self.term_to_idx[child],
             [self.term_to_idx[p] for p in sorted(self.parents[child])])
            for child in self._children_first_order()
            if self.parents[child]
        ]

    def _children_first_order(self) -> List[str]:
        """Order vocabulary terms so each term precedes all of its parents.

        Kahn's algorithm on the restricted child->parent graph. Terms that
        cannot be ordered (only possible if the graph has a cycle) are
        appended in vocabulary order, so every term is still visited.
        """
        remaining_children = {term: 0 for term in self.parents}
        for parent_set in self.parents.values():
            for parent in parent_set:
                remaining_children[parent] += 1

        ready = [t for t, n in remaining_children.items() if n == 0]
        order: List[str] = []
        while ready:
            term = ready.pop()
            order.append(term)
            for parent in sorted(self.parents[term]):
                remaining_children[parent] -= 1
                if remaining_children[parent] == 0:
                    ready.append(parent)

        if len(order) < len(remaining_children):
            placed = set(order)
            order.extend(t for t in self.parents if t not in placed)
        return order

    def propagate(self, predictions: np.ndarray,
                  iterations: int = 3) -> np.ndarray:
        """Return a copy of ``predictions`` with scores pushed up to parents.

        Args:
            predictions: 2-D scores; column ``j`` is assumed to be
                ``go_terms[j]``.
            iterations: maximum number of passes. Terms are visited
                children-first, so for an acyclic hierarchy the first pass
                already reaches the fixed point and the next pass stops early.

        Returns:
            New array (column-major); the input is not modified.
        """
        # Column-major copy: the loop reads and writes whole columns
        pred_copy = predictions.copy(order="F")

        for _ in range(iterations):
            changed = False

            for child_idx, parent_idxs in self._schedule:
                child_scores = pred_copy[:, child_idx]

                for parent_idx in parent_idxs:
                    parent_scores = pred_copy[:, parent_idx]  # view into pred_copy
                    # Update parent where child score is higher
                    mask = child_scores > parent_scores
                    if mask.any():
                        parent_scores[mask] = child_scores[mask]
                        changed = True

            if not changed:
                break

        return pred_copy

## Main Pipeline

In [None]:
class CAFA6Pipeline:
    """Main training and prediction pipeline.

    ``__init__`` loads sequences, annotations, IA weights and the GO graph,
    optionally propagates training labels, and builds the label matrix.
    ``run`` then executes prepare_data -> build_and_train -> evaluate ->
    predict_and_submit.
    """

    def __init__(self, config: Config):
        self.config = config
        config.set_seed()

        print("\n" + "="*70)
        print("CAFA-6 PROTEIN FUNCTION PREDICTION PIPELINE")
        print("="*70 + "\n")

        print("Loading data...")
        self.loader = DataLoader()
        self.train_seqs = self.loader.read_fasta(config.TRAIN_FASTA)
        self.test_seqs = self.loader.read_fasta(config.TEST_FASTA)
        self.annotations = self.loader.read_annotations(config.TRAIN_TERMS)
        self.ia_weights = self.loader.read_ia_weights(config.IA_FILE)

        print("\nLoading GO ontology...")
        self.go_graph = GOGraph(config.GO_OBO)
        # Memoize ancestor lookups before label propagation uses them heavily
        _attach_cached_ancestors(self.go_graph)

        if config.PROPAGATE_TRAIN:
            self.annotations = self.go_graph.propagate_labels(self.annotations)

        self._prepare_labels()

        print("\nPreparing K-mer TF-IDF embeddings...")
        self.embedding_handler = EmbeddingHandler()
        self.vectorizer = None   # fitted in prepare_data()
        self._propagator = None  # built on first use by _get_propagator()

    def _get_propagator(self) -> PredictionPropagator:
        """Return the PredictionPropagator over the label vocabulary, built once."""
        if self._propagator is None:
            self._propagator = PredictionPropagator(self.go_graph, list(self.mlb.classes_))
        return self._propagator

    def _prepare_labels(self):
        """Restrict labels to the most frequent GO terms and binarize them.

        Sets ``train_proteins`` (annotated proteins that have a training
        sequence), ``chosen_terms`` (top ``TOP_K_LABELS`` terms by count),
        ``mlb`` and the sparse label matrix ``y`` (rows follow
        ``train_proteins``, columns follow ``mlb.classes_``). The annotation
        lists of training proteins are filtered in place to ``chosen_terms``.
        """
        print("\nPreparing labels...")

        self.train_proteins = [p for p in self.annotations.keys()
                               if p in self.train_seqs]
        print(f"  {len(self.train_proteins):,} training proteins")

        term_counts = Counter()
        for protein in self.train_proteins:
            term_counts.update(self.annotations[protein])

        top_terms = [t for t, _ in term_counts.most_common(self.config.TOP_K_LABELS)]
        self.chosen_terms = set(top_terms)
        print(f"  Using top {len(self.chosen_terms):,} GO terms")

        for protein in self.train_proteins:
            self.annotations[protein] = [t for t in self.annotations[protein]
                                         if t in self.chosen_terms]

        labels_list = [self.annotations[p] for p in self.train_proteins]
        self.mlb = MultiLabelBinarizer(classes=sorted(self.chosen_terms), sparse_output=True)
        self.y = self.mlb.fit_transform(labels_list)

        print(f"  Label matrix shape: {self.y.shape}")

    def prepare_data(self):
        """Fit k-mer TF-IDF on training sequences and split off 15% for validation."""
        print("\nPreparing training data (K-mer TF-IDF)...")
        t = Timer()

        train_texts = [self.train_seqs[p] for p in self.train_proteins]

        self.X, self.vectorizer = self.embedding_handler.build_kmer_tfidf(
            train_texts,
            ngram_range=self.config.KMER_NGRAM_RANGE,
            max_features=self.config.KMER_MAX_FEATURES,
            min_df=self.config.KMER_MIN_DF,
            sublinear_tf=self.config.KMER_SUBLINEAR_TF,
        )

        self.X_train, self.X_val, self.y_train, self.y_val = train_test_split(
            self.X, self.y, test_size=0.15, random_state=self.config.RANDOM_SEED
        )
        t.hit("TF-IDF fit_transform")

        print(f"  Train: {self.X_train.shape}")
        print(f"  Val:   {self.X_val.shape}")

    def build_and_train(self):
        """Build the configured one-vs-rest linear model and fit it on the train split."""
        print("\nBuilding linear model (OneVsRest)...")
        self.model = ModelBuilder.build_linear_model(self.config)

        t = Timer()
        print("Training (sklearn)...")
        with joblib.parallel_backend("threading"):
            self.model.fit(self.X_train, self.y_train)

        t.hit("model.fit")

    def evaluate(self):
        """Score the validation split and sweep decision thresholds.

        Sets ``evaluator`` and ``best_threshold``. With ``THRESHOLD_SEARCH``
        it also sets ``best_f1`` (Fmax), ``best_threshold_ia`` and
        ``best_ia_fmax``. Otherwise a fixed 0.05 threshold is used and the
        metric attributes are ``None``. Metrics cover only the label-matrix
        terms. ``best_threshold`` is informational: ``predict_and_submit``
        writes top-k scores without thresholding.
        """
        print("\nEvaluating (Fmax sweep)...")

        self.evaluator = Evaluator(self.ia_weights, list(self.mlb.classes_))

        y_val_prob = self.model.predict_proba(self.X_val).astype(np.float32)

        if self.config.PROPAGATE_PRED:
            print("  Propagating VAL predictions...")
            y_val_prob = self._get_propagator().propagate(y_val_prob, self.config.PROPAGATE_ITERATIONS)

        if self.config.THRESHOLD_SEARCH:
            t_best, f_best = self.evaluator.fmax(self.y_val, y_val_prob, self.config.THRESHOLD_GRID)
            t_best_ia, f_best_ia = self.evaluator.ia_fmax(self.y_val, y_val_prob, self.config.THRESHOLD_GRID)

            self.best_threshold = t_best
            self.best_f1 = f_best
            self.best_threshold_ia = t_best_ia
            self.best_ia_fmax = f_best_ia
            print(f"  Best threshold (Fmax): {t_best:.4f} | Fmax: {f_best:.4f}")
            print(f"  Best threshold (IA-Fmax): {t_best_ia:.4f} | IA-Fmax: {f_best_ia:.4f}")
        else:
            self.best_threshold = 0.05
            self.best_f1 = None
            self.best_threshold_ia = self.best_threshold
            self.best_ia_fmax = None
            print(f"  Using fixed threshold: {self.best_threshold:.4f}")

    @staticmethod
    def _submission_lines(protein_id: str, probs: np.ndarray, classes, top_k: int):
        """Yield submission lines for one protein's score vector.

        Emits the ``top_k`` highest-scoring terms (all terms when ``top_k`` is
        not smaller than the vocabulary) in descending score order, as
        ``protein_id<TAB>go_term<TAB>score`` with 6 decimals. Scores that are
        NaN, <= 1e-8, or that would print as ``0.000000`` are skipped.
        """
        if top_k < probs.shape[0]:
            # argpartition selects the top_k without a full sort
            idx = np.argpartition(probs, -top_k)[-top_k:]
            idx = idx[np.argsort(probs[idx])[::-1]]
        else:
            idx = np.argsort(probs)[::-1]

        for j in idx:
            score = float(probs[j])
            score_text = f"{score:.6f}"
            # `score > 1e-8` is False for NaN; the text check drops scores that
            # would otherwise appear in the file as exactly zero.
            if score > 1e-8 and score_text != "0.000000":
                yield f"{protein_id}\t{classes[j]}\t{score_text}\n"

    def predict_and_submit(self):
        """Predict test proteins in batches and stream a TSV submission file.

        Raises:
            RuntimeError: if ``prepare_data`` has not fitted the vectorizer.
        """
        print("\nGenerating predictions (batch + streaming write)...")

        if self.vectorizer is None:
            raise RuntimeError("TF-IDF vectorizer is None. Did you call prepare_data() before predict_and_submit()?")

        test_ids = list(self.test_seqs.keys())
        test_texts = [self.test_seqs[p] for p in test_ids]

        propagator = self._get_propagator() if self.config.PROPAGATE_PRED else None

        top_k = self.config.TOP_K_PER_PROTEIN
        batch_size = 4096
        classes = self.mlb.classes_

        print(f"\nWriting submission to {self.config.OUTPUT_FILE}...")
        with open(self.config.OUTPUT_FILE, "w") as f:
            for start in range(0, len(test_ids), batch_size):
                end = min(start + batch_size, len(test_ids))
                batch_ids = test_ids[start:end]
                batch_texts = test_texts[start:end]

                X_batch = self.embedding_handler.transform_kmer_tfidf(batch_texts, self.vectorizer)
                y_prob = self.model.predict_proba(X_batch).astype(np.float32)

                if propagator is not None:
                    y_prob = propagator.propagate(y_prob, self.config.PROPAGATE_ITERATIONS)

                for protein_id, probs in zip(batch_ids, y_prob):
                    # One writelines per protein instead of one write per term
                    f.writelines(self._submission_lines(protein_id, probs, classes, top_k))

        print("Done!")

    def run(self):
        """Run every pipeline stage in order, timing each; errors are printed and re-raised."""
        t = Timer()
        try:
            t.hit("start run()")

            self.prepare_data()
            t.hit("prepare_data() done")

            self.build_and_train()
            t.hit("build_and_train() done")

            self.evaluate()
            t.hit("evaluate() done")

            self.predict_and_submit()
            t.hit("predict_and_submit() done")

            print("\n" + "="*70)
            print("PIPELINE COMPLETED SUCCESSFULLY")
            print("="*70)

        except Exception as e:
            print(f"\nError: {e}")
            raise

In [None]:
# Build the configuration, hand it to the pipeline, and run every stage.
config = Config()
pipeline = CAFA6Pipeline(config=config)
pipeline.run()