# **STAGE 0: INSTALLS AND IMPORTS**

In [None]:

!pip install -q transformers obonet biopython accelerate datasets
!pip install -q biopython obonet transformers
!pip install -q pyarrow==21.0.0 --quiet

import optuna

import logging
import warnings
warnings.filterwarnings("ignore")

import os, random, time
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy import sparse

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader, Subset

import obonet
from Bio import SeqIO
from torch.cuda.amp import GradScaler, autocast
from transformers import AutoTokenizer, AutoModel


from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score


# Environment switches used to quiet noisy libraries and tune Hugging Face Hub access.
# NOTE(review): these are assigned after torch/transformers were imported above;
# confirm each library still reads its variable at this point.
_ENV_OVERRIDES = {
    "TF_CPP_MIN_LOG_LEVEL": "3",               # TensorFlow/XLA logs (0=all, 3=errors only)
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "CUDA_LAUNCH_BLOCKING": "0",
    "TRANSFORMERS_NO_ADVISORY_WARNINGS": "1",
    "HF_HUB_DISABLE_SYMLINKS_WARNING": "1",
    "HF_HUB_REQUEST_TIMEOUT": "120",           # Hugging Face Hub settings
    "HF_HUB_CONNECT_RETRIES": "10",
    "PYTHONWARNINGS": "ignore",
}
for _env_name, _env_value in _ENV_OVERRIDES.items():
    os.environ[_env_name] = _env_value

# Only errors from these loggers are shown.
for _logger_name in ("pip", "transformers", "torch"):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

print("Imports OK. Torch version:", torch.__version__)


# **STAGE 1:  PATHS AND SETTINGS**

In [None]:
# Competition data layout.
BASE = "/kaggle/input/cafa-6-protein-function-prediction"
TRAIN_DIR = f"{BASE}/Train"
TEST_DIR = f"{BASE}/Test"

TRAIN_FASTA = f"{TRAIN_DIR}/train_sequences.fasta"
TRAIN_TERMS = f"{TRAIN_DIR}/train_terms.tsv"
TRAIN_TAX = f"{TRAIN_DIR}/train_taxonomy.tsv"
GO_OBO = f"{TRAIN_DIR}/go-basic.obo"

TEST_FASTA = f"{TEST_DIR}/testsuperset.fasta"
TEST_TAXON = f"{TEST_DIR}/testsuperset-taxon-list.tsv"

IA_PATH = f"{BASE}/IA.tsv"
SAMPLE_SUB = f"{BASE}/sample_submission.tsv"

OUT_DIR = "/kaggle/working/cafa_mlp_output"
os.makedirs(OUT_DIR, exist_ok=True)

# Pre-computed ESM2 embeddings are expected in an attached dataset.
_EMBED_CACHE_DIR = "/kaggle/input/embeddings-cache/embeddings_cache"


def _resolve_embed_cache(filename):
    """Return the attached cache file if it exists, otherwise a path under OUT_DIR.

    When the attached cache is missing, the embedding step has to write a new
    parquet file; that cannot go under /kaggle/input (assumed read-only on
    Kaggle - confirm), so the writable output directory is used instead.
    """
    attached = f"{_EMBED_CACHE_DIR}/{filename}"
    return attached if os.path.exists(attached) else os.path.join(OUT_DIR, filename)


EMBED_CACHE_TRAIN = _resolve_embed_cache("train_esm2_embeddings.parquet")
EMBED_CACHE_TEST = _resolve_embed_cache("test_esm2_embeddings.parquet")

# Experiment switches (tune for debugging / full run)
SEED = 42
EMBED_MODEL = "facebook/esm2_t6_8M_UR50D"   # balanced speed & quality
BATCH_SIZE_EMBED = 4    # embedding batch size (reduce if OOM)
BATCH_SIZE_TRAIN = 128   # training batch size (reduce if OOM)
MAX_SEQS_TO_EMBED = None # None => embed all train sequences. Set small for debugging.
MAX_LABELS = None        # None => use all GO terms. Use int (e.g., 5000) to limit labels while debugging.

# NOTE(review): the two HPO fractions below are only printed in the visible
# notebook; nothing visible subsamples the data with them.
HPO_TRAIN_FRAC = 0.15  # use 15% of training data for speed
HPO_VAL_FRAC = 0.25    # use 25% of validation data

EPOCHS = 20
LR = 1e-4
WEIGHT_DECAY = 1e-5
PATIENCE = 3

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
print("Paths set. Output dir:", OUT_DIR)

# Echo the experiment configuration so it is recorded alongside the run output.
print("🔧 Experiment Settings")
print("-"*40)
print(f"Random seed: {SEED}")
print(f"ESM2 model: {EMBED_MODEL}")
print(f"Embedding batch size: {BATCH_SIZE_EMBED}")
print(f"Training batch size: {BATCH_SIZE_TRAIN}")
print(f"Max train sequences to embed: {MAX_SEQS_TO_EMBED}")
print(f"Max GO terms (labels): {MAX_LABELS}")
print(f"Training epochs: {EPOCHS}")
print(f"Learning rate: {LR}")
print(f"Weight decay: {WEIGHT_DECAY}")
print(f"Early stopping patience: {PATIENCE}")
# These two values are the HPO subsampling fractions, not the train/val split.
print(f"HPO training fraction: {HPO_TRAIN_FRAC}")
print(f"HPO validation fraction: {HPO_VAL_FRAC}")

# **STAGE 2: LOAD METADATA AND SEQUENCES**

In [None]:
# Load dataset files (terms, taxonomy, IA, sequences)
def seed_everything(seed=SEED):
    """Seed Python's `random`, NumPy and PyTorch; also every CUDA device when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


seed_everything()

# Annotation table: every column is read as text and named explicitly.
_terms_raw = pd.read_csv(
    TRAIN_TERMS,
    sep="\t",
    header=None,
    names=["protein_id", "go_term", "ontology"],
    dtype=str,
)
# Rows whose protein_id contains "Entry" are dropped (presumably the file's own
# header line, since the file is read with header=None - confirm).
_looks_like_header = _terms_raw["protein_id"].str.contains("Entry", na=False)
train_terms = _terms_raw.loc[~_looks_like_header].reset_index(drop=True)

# Protein -> taxon table, read as text.
train_tax = pd.read_csv(
    TRAIN_TAX,
    sep="\t",
    header=None,
    names=["protein_id", "taxon_id"],
    dtype=str,
)

# Per-term weight table (IA.tsv), weight parsed as float.
ia = pd.read_csv(
    IA_PATH,
    sep="\t",
    header=None,
    names=["go_term", "ia_weight"],
    dtype={"go_term": str, "ia_weight": float},
)

# GO ontology parsed by obonet.
go_graph = obonet.read_obo(GO_OBO)

# Load FASTA sequences 
def load_fasta_dict(path, max_records=None):
    """Parse a FASTA file into a {protein_id: sequence} dict.

    The id is the second '|'-separated field of the record id when the id
    contains '|', otherwise the record id itself. When `max_records` is truthy,
    parsing stops after that many records have been read.
    """
    sequences = {}
    for n_read, record in enumerate(SeqIO.parse(path, "fasta"), start=1):
        header = record.id
        key = header.split("|")[1] if "|" in header else header
        sequences[key] = str(record.seq)
        if max_records and n_read >= max_records:
            break
    return sequences

# Training sequences honour the MAX_SEQS_TO_EMBED cap; test sequences are always read in full.
train_seqs = load_fasta_dict(TRAIN_FASTA, max_records=MAX_SEQS_TO_EMBED)
test_seqs = load_fasta_dict(TEST_FASTA, max_records=None)

# Merge annotations with taxonomy (left join keeps every annotation row)
train_df = train_terms.merge(train_tax, on="protein_id", how="left")

# -------------------- Dataset summary --------------------
print("\n📊 Dataset Overview\n" + "-"*40)
print(f"Total GO ontology terms: {len(go_graph)}")
print(f"Number of training sequences: {len(train_seqs)}")
print(f"Number of test sequences: {len(test_seqs)}")
print(f"Training DataFrame shape: {train_df.shape}\n")

print("Summary of key datasets:")
print(f"• train_terms dataframe: {train_terms.shape}")
print(f"• train_taxonomy dataframe: {train_tax.shape}")
print(f"• Information content (IA) table: {ia.shape}")
print(f"• Number of training sequences: {len(train_seqs)}")
print(f"• Number of test sequences: {len(test_seqs)}")
print(f"• Total GO terms in ontology graph: {len(go_graph)}")

# **STAGE 3:  EMBEDDING EXTRACTION (ESM2) WITH CACHING**

In [None]:
def load_fasta_sequences(file_path):
    """Read every record of a FASTA file into a {protein_id: sequence} dict.

    The id is the second '|'-separated field of the record id when present,
    otherwise the whole record id. A repeated id keeps its last sequence.
    """
    def _accession(header):
        return header.split("|")[1] if "|" in header else header

    return {_accession(rec.id): str(rec.seq) for rec in SeqIO.parse(file_path, "fasta")}

# Reuse the dicts parsed in Stage 2 rather than reading both FASTA files again.
# `train_seqs` was loaded with max_records=MAX_SEQS_TO_EMBED, so the debugging
# cap now limits what gets embedded; with the default (None) the content is the
# same as a full re-parse because both loaders derive ids identically.
train_sequences = train_seqs
test_sequences = test_seqs

def extract_embeddings(seqs_dict, cache_path, model_name=EMBED_MODEL, batch_size=BATCH_SIZE_EMBED, device=DEVICE):
    """Return one mean-pooled embedding per protein, loading from `cache_path` when it exists.

    Args:
        seqs_dict: mapping protein_id -> amino-acid sequence string.
        cache_path: parquet file used as cache; read if present, written otherwise.
        model_name: Hugging Face model id passed to AutoTokenizer / AutoModel.
        batch_size: number of sequences per forward pass.
        device: torch device string the model and inputs are moved to.

    Returns:
        DataFrame with a `protein_id` column followed by one column per
        embedding dimension.

    Raises:
        ValueError: if no cache exists and `seqs_dict` is empty.
    """
    cache = Path(cache_path)
    if cache.exists():
        print("Loading cached embeddings:", cache_path)
        return pd.read_parquet(cache_path)

    if not seqs_dict:
        raise ValueError(f"No cached embeddings at {cache_path} and no sequences to embed.")

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    model = AutoModel.from_pretrained(model_name).to(device)
    model.eval()

    prot_ids = list(seqs_dict.keys())
    embeddings = []

    for i in tqdm(range(0, len(prot_ids), batch_size), desc="Embedding batches"):
        batch_ids = prot_ids[i:i+batch_size]
        batch_seqs = [seqs_dict[x] for x in batch_ids]
        inputs = tokenizer(batch_seqs, return_tensors="pt", padding=True, truncation=True, max_length=1022)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
            # Average only over positions the tokenizer marked as real tokens.
            # A plain mean over dim=1 would also average the padding positions,
            # making a protein's embedding depend on the other sequences that
            # happen to share its batch.
            mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1)
            emb = ((hidden * mask).sum(dim=1) / token_counts).cpu().numpy()

        embeddings.append(emb)
        torch.cuda.empty_cache()

    emb_mat = np.vstack(embeddings)
    # String column names so a freshly computed frame has the same kind of
    # column labels as one read back from parquet.
    df = pd.DataFrame(emb_mat, columns=[str(c) for c in range(emb_mat.shape[1])])
    df.insert(0, "protein_id", prot_ids)

    # A failed cache write must not throw away the computed embeddings.
    try:
        cache.parent.mkdir(parents=True, exist_ok=True)
        df.to_parquet(cache_path, index=False)
        print("Saved embeddings to:", cache_path, "shape:", df.shape)
    except OSError as err:
        print("WARNING: could not write embedding cache to", cache_path, "-", err)
    return df

train_emb_df = extract_embeddings(train_sequences, EMBED_CACHE_TRAIN)
test_emb_df = extract_embeddings(test_sequences, EMBED_CACHE_TEST)

# **STAGE 4: PREPARE MULTI-LABEL TARGETS**

In [None]:
# Label vocabulary: every distinct GO term in the annotations, in sorted order.
# When MAX_LABELS is set, only the first MAX_LABELS terms of that sorted list are kept.
_distinct_terms = train_terms['go_term'].unique().tolist()
all_go_terms = sorted(_distinct_terms)
if MAX_LABELS:
    print("Limiting labels to:", MAX_LABELS)
    all_go_terms = all_go_terms[:MAX_LABELS]
print("Using #labels:", len(all_go_terms))

# Per-label IA weight, aligned with all_go_terms; terms absent from the IA table get 0.0.
ia_map = dict(zip(ia['go_term'], ia['ia_weight']))
ia_weights = np.fromiter(
    (ia_map.get(term, 0.0) for term in all_go_terms),
    dtype=float,
    count=len(all_go_terms),
)

# Build grouped labels per protein (only terms that are in the label vocabulary)
tt = train_terms[train_terms['go_term'].isin(all_go_terms)]
labels_grouped = tt.groupby('protein_id')['go_term'].apply(list).reset_index()

# Keep only proteins with embeddings
available_train_ids = set(train_emb_df['protein_id'].tolist())
labels_grouped = labels_grouped[labels_grouped['protein_id'].isin(available_train_ids)].reset_index(drop=True)
print("Proteins with labels & embeddings:", labels_grouped.shape[0])

# Make label matrix.
# The indicator matrix is built sparse and densified exactly once as float32 -
# the dtype ProteinDataset converts to anyway via `.float()` - instead of first
# materialising a dense integer matrix of shape (n_proteins, n_labels).
mlb = MultiLabelBinarizer(classes=all_go_terms, sparse_output=True)
Y = mlb.fit_transform(labels_grouped['go_term']).astype(np.float32).toarray()
print("Label matrix shape:", Y.shape)

# Build X aligned to labels: row i of X is the embedding of labels_grouped['protein_id'][i].
emb_map = train_emb_df.set_index('protein_id')
# A repeated protein_id would make `.loc` return several rows for one label row
# and silently shift X against Y, so only the first embedding per id is kept.
emb_map = emb_map[~emb_map.index.duplicated(keep='first')]
# One vectorised lookup instead of a Python loop of per-row `.loc` calls.
X = emb_map.loc[labels_grouped['protein_id']].to_numpy()
assert X.shape[0] == Y.shape[0], f"X rows ({X.shape[0]}) and Y rows ({Y.shape[0]}) are misaligned"
print("Final X shape:", X.shape)

# -----------------------------------------
# Normalize embeddings
# Step 1: L2 normalization per protein vector.
# The norm is clipped away from zero so an all-zero embedding stays all-zero
# instead of turning into NaN through a 0/0 division.
row_norms = np.linalg.norm(X, axis=1, keepdims=True)
X = X / np.clip(row_norms, 1e-12, None)

# Step 2: StandardScaler normalization across features.
# NOTE(review): the scaler is fitted on all rows of X, i.e. before the
# train/validation split performed later in the notebook.
scaler = StandardScaler()
X = scaler.fit_transform(X)

print("✅ Embeddings normalized.")
print("Mean (overall):", np.mean(X))
print("Std (overall):", np.std(X))
print("Final X shape after normalization:", X.shape)

# **STAGE 5: DATASET AND DATALOADER**

In [None]:
class ProteinDataset(Dataset):
    """Pairs a feature array with a label array, both held as float32 tensors."""

    def __init__(self, X_np, Y_np):
        # Both inputs must be NumPy arrays (torch.from_numpy is used).
        features = torch.from_numpy(X_np)
        targets = torch.from_numpy(Y_np)
        self.X = features.to(torch.float32)
        self.Y = targets.to(torch.float32)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample = (self.X[idx], self.Y[idx])
        return sample

# Hold out 15% of the rows for validation (seeded shuffle), then wrap both parts in loaders.
_row_indices = np.arange(len(X))
train_idx, val_idx = train_test_split(
    _row_indices, test_size=0.15, random_state=SEED, shuffle=True
)

train_ds = ProteinDataset(X[train_idx], Y[train_idx])
val_ds = ProteinDataset(X[val_idx], Y[val_idx])

# Settings shared by both loaders; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE_TRAIN, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

print("Train/Val sizes:", len(train_ds), len(val_ds))


# **STAGE 6: MODEL, LOSS, EVALUATE, HPO**

In [None]:
# Deep MLP Model
class DeepMLP(nn.Module):
    """Feed-forward classifier: repeated (Linear -> BatchNorm1d -> ReLU -> Dropout) blocks plus an output Linear.

    Args:
        input_dim: number of input features.
        num_labels: number of output logits.
        hidden_dims: width of each hidden block, in order. The default is an
            immutable tuple so it cannot be shared and mutated between
            instances; lists are accepted as well.
        dropout: dropout probability used after every hidden block.

    `forward` returns raw logits (no sigmoid is applied).
    """

    def __init__(self, input_dim, num_labels, hidden_dims=(1024, 512, 256), dropout=0.3):
        super().__init__()
        layers = []
        prev = input_dim
        for h in hidden_dims:
            layers.extend([
                nn.Linear(prev, h),
                nn.BatchNorm1d(h),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            prev = h
        layers.append(nn.Linear(prev, num_labels))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


# Focal Loss
class FocalLoss(nn.Module):
    """Binary focal loss on logits for multi-label targets.

    Each element's BCE-with-logits loss is scaled by
    alpha * ((1 - p)^gamma * t + p^gamma * (1 - t)), where p = sigmoid(logit)
    and t is the target.

    Args:
        alpha: constant multiplier applied to every element's weight.
        gamma: focusing exponent.
        reduction: 'mean', 'sum' or 'none' (element-wise loss, same shape as
            the logits).

    Raises:
        ValueError: if `reduction` is not one of the three supported values.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=1.0, gamma=2.0, reduction='mean'):
        super().__init__()
        # Reject unknown modes up front; otherwise a typo would silently be
        # treated as a different reduction.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.bce = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, logits, targets):
        bce_loss = self.bce(logits, targets)
        probs = torch.sigmoid(logits)
        focal_weight = self.alpha * (
            (1 - probs) ** self.gamma * targets + (probs ** self.gamma) * (1 - targets)
        )
        loss = focal_weight * bce_loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


# Faster Evaluation Function (vectorized)
def evaluate(loader, model, threshold=0.25):
    """Compute micro-averaged F1, precision and recall of `model` over `loader`.

    Predictions are sigmoid(model(x)) >= threshold. True-positive, false-positive
    and false-negative counts are accumulated batch by batch, so the full
    prediction and label matrices are never held in memory at once.

    Args:
        loader: iterable of (features, targets) batches; targets are 0/1 indicators.
        model: module returning logits; switched to eval mode.
        threshold: probability cut-off for a positive prediction.

    Returns:
        (f1, precision, recall) as Python floats; a metric whose denominator is
        zero is reported as 0.0.
    """
    model.eval()
    # Inputs are moved to wherever the model's parameters live; a model without
    # parameters receives the batch unchanged.
    first_param = next(model.parameters(), None)
    tp = fp = fn = 0
    with torch.no_grad():
        for xb, yb in loader:
            if first_param is not None:
                xb = xb.to(first_param.device)
            preds = torch.sigmoid(model(xb)).cpu() >= threshold
            truth = yb.cpu() > 0
            tp += int((preds & truth).sum())
            fp += int((preds & ~truth).sum())
            fn += int((~preds & truth).sum())

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1_denominator = 2 * tp + fp + fn
    f1 = 2 * tp / f1_denominator if f1_denominator else 0.0
    return f1, precision, recall


# HPO Optimization Function
def objective(trial):
    """Optuna objective: train a two-hidden-layer DeepMLP for 2 epochs and return its best validation micro-F1.

    Sampled per trial: hidden widths "h1" and "h2", "dropout" and a
    log-uniform "lr". The validation F1 is reported after every epoch so the
    study's pruner can stop the trial early (raises optuna.exceptions.TrialPruned).
    """
    # Search space (sampled in this fixed order).
    width_first = trial.suggest_categorical("h1", [512, 1024, 2048])
    width_second = trial.suggest_categorical("h2", [256, 512, 1024])
    drop_prob = trial.suggest_float("dropout", 0.1, 0.5)
    step_size = trial.suggest_float("lr", 1e-5, 1e-3, log=True)

    # Model, optimiser and loss for this trial.
    net = DeepMLP(
        input_dim=X.shape[1],
        num_labels=Y.shape[1],
        hidden_dims=[width_first, width_second],
        dropout=drop_prob,
    ).to(DEVICE)
    opt = torch.optim.AdamW(net.parameters(), lr=step_size, weight_decay=WEIGHT_DECAY)
    loss_fn = FocalLoss(alpha=1.0, gamma=2.0)

    # Short training run to keep each trial cheap.
    top_f1 = 0.0
    n_epochs = 2
    for epoch in range(n_epochs):
        net.train()
        for features, labels in train_loader:
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            batch_loss = loss_fn(net(features), labels)
            batch_loss.backward()
            opt.step()

        # Validation score drives both pruning and the returned value.
        val_f1 = evaluate(val_loader, net)[0]
        trial.report(val_f1, step=epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

        if val_f1 > top_f1:
            top_f1 = val_f1

    return top_f1


# Optuna study setup
#study = optuna.create_study( study_name="protein_mlp_hpo_fast",  direction="maximize",   pruner=optuna.pruners.MedianPruner(n_startup_trials=2, n_warmup_steps=1, interval_steps=1))
# set n_trials 10/20 before execution

# Reduce number of trials for faster optimization
#study.optimize(objective, n_trials=0, show_progress_bar=True)

#print("Best hyperparameters found:")
#print(study.best_params)



# **STAGE 7: TRAINING (EARLY STOPPING)**

In [None]:
# [I 2025-10-20 17:24:12,504] Trial 1 finished with value: 0.16273737993294707 and parameters: {'h1': 2048, 'h2': 256, 'dropout': 0.1889064287439644, 'lr': 0.00011241516730960752}. Best is trial 1 with value: 0.16273737993294707.
# Hyperparameters copied from the Optuna trial logged above
# best_params = study.best_params
hidden_dims = [2048, 256]
dropout = 0.1889064287439644
lr = 0.00011241516730960752

# Build final model
model = DeepMLP(input_dim=X.shape[1], num_labels=Y.shape[1], hidden_dims=hidden_dims, dropout=dropout).to(DEVICE)
criterion = FocalLoss(alpha=1.0, gamma=2.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)

# Mixed precision is only enabled on CUDA. The loss is scaled before backward so
# that half-precision gradients do not underflow to zero; with AMP disabled both
# autocast and the scaler are pass-throughs.
use_amp = str(DEVICE).startswith("cuda")
grad_scaler = GradScaler(enabled=use_amp)

best_ckpt_path = os.path.join(OUT_DIR, "best_mlp.pt")
checkpoint_saved = False

# Train full model with early stopping on validation micro-F1
best_f1, patience = 0, 0
for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for xb, yb in tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}"):
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            loss = criterion(model(xb), yb)
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    val_f1, val_prec, val_rec = evaluate(val_loader, model)
    print(f"Epoch {epoch}: loss={train_loss:.6f} val_f1={val_f1:.4f} prec={val_prec:.4f} rec={val_rec:.4f}")

    if val_f1 > best_f1:
        best_f1 = val_f1
        patience = 0
        torch.save(model.state_dict(), best_ckpt_path)
        checkpoint_saved = True
    else:
        patience += 1
    if patience >= PATIENCE:
        print("Early stopping triggered.")
        break

print("✅ Best val F1:", best_f1)

# After early stopping `model` holds the weights of the last epoch, which is
# PATIENCE epochs past the best one. Restore the best checkpoint written during
# this run so later cells predict with the best weights.
if checkpoint_saved:
    model.load_state_dict(torch.load(best_ckpt_path, map_location=DEVICE))
    model.eval()


# **STAGE 8: TESTING AND PREPARE SUBMISSION**

In [None]:
# Build test matrix aligned with training embedding columns
test_ids = test_emb_df['protein_id'].tolist()
X_test = test_emb_df.drop(columns=['protein_id']).values

# Apply the same preprocessing the model was trained on in Stage 4: per-row L2
# normalisation followed by the StandardScaler fitted on the training
# embeddings (`scaler`). Without it the model would see features on a
# different scale than during training.
X_test = X_test / np.clip(np.linalg.norm(X_test, axis=1, keepdims=True), 1e-12, None)
X_test = scaler.transform(X_test)

# Features only: inference needs no labels, so no (n_test, n_labels) dummy
# target matrix is allocated.
test_ds = torch.utils.data.TensorDataset(torch.from_numpy(X_test).float())
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE_TRAIN, shuffle=False, num_workers=2)

all_probs = []
model.eval()
with torch.no_grad():
    for (xb,) in tqdm(test_loader, desc="Predict test"):
        xb = xb.to(DEVICE)
        logits = model(xb)
        probs = torch.sigmoid(logits).cpu().numpy()
        all_probs.append(probs)
# NOTE(review): this is a dense (n_test, n_labels) float array; confirm it fits in memory for the full test set.
all_probs = np.vstack(all_probs)
print("Test probs shape:", all_probs.shape)

# Convert to CAFA TSV: (protein_id, go_term, score)
THRESH = 0.1   # adjust threshold
TOP_K = 1500   # competition cap per protein overall
rows = []
go_terms = list(mlb.classes_)
for i, pid in enumerate(test_ids):
    probs = all_probs[i]
    idxs = np.flatnonzero(probs > THRESH)
    if idxs.size == 0:
        # fallback: top 1. Kept as a NumPy array - a plain Python list cannot be
        # indexed with the argsort result below and would raise a TypeError.
        idxs = np.array([int(np.argmax(probs))])
    # sort by prob desc and cap
    idxs = idxs[np.argsort(probs[idxs])[::-1]]
    idxs = idxs[:TOP_K]
    for j in idxs:
        rows.append((pid, go_terms[j], float(probs[j])))

sub_df = pd.DataFrame(rows, columns=["protein_id","go_term","score"])
out_path = os.path.join(OUT_DIR, "submission_mlp_raw.tsv")
sub_df.to_csv(out_path, sep="\t", header=False, index=False)
print("Raw submission saved:", out_path, "rows:", len(sub_df))

# Optional: propagate to ancestors using GO graph (recommended)
print("Loading GO graph and propagating predictions ...")
go_graph = obonet.read_obo(GO_OBO)

# obonet builds a MultiDiGraph whose edges point from the child term to its
# parent, keyed by relationship type. A term's parents are therefore the
# targets of its *outgoing* edges; `predecessors` would return its children and
# push scores down to more specific terms.
# Only is_a / part_of edges are followed.
# NOTE(review): confirm this relation set matches the propagation rules used by the evaluation.
PROPAGATE_RELATIONS = ("is_a", "part_of")
parents = {
    node: sorted({
        parent
        for _, parent, relation in go_graph.out_edges(node, keys=True)
        if relation in PROPAGATE_RELATIONS
    })
    for node in go_graph.nodes()
}

def propagate(term_scores, parents_map):
    """Return a copy of `term_scores` in which each score has been pushed up through `parents_map`.

    Starting from every (term, score) pair of the input, the terms reachable
    via `parents_map` are visited; a visited term takes `score` whenever its
    current value (0.0 if absent) is lower, and only then are its own parents
    explored. The input dict is not modified.
    """
    result = dict(term_scores)
    parents_of = parents_map.get
    for start_term, start_score in list(term_scores.items()):
        pending = [start_term]
        while pending:
            current = pending.pop()
            for parent in parents_of(current, []):
                current_best = result.get(parent, 0.0)
                if current_best < start_score:
                    result[parent] = start_score
                    pending.append(parent)
    return result

from collections import defaultdict

# Collapse the raw rows into {protein_id: {go_term: best score}}.
pred_by_prot = defaultdict(dict)
for pid, term, score in rows:
    term_scores = pred_by_prot[pid]
    term_scores[term] = max(term_scores.get(term, 0.0), score)

# Propagate each protein's scores through `parents`, then keep its TOP_K highest-scoring terms.
prop_rows = []
for pid, term_scores in tqdm(pred_by_prot.items(), desc="Propagate"):
    propagated = propagate(term_scores, parents)
    ranked = sorted(propagated.items(), key=lambda item: item[1], reverse=True)
    prop_rows.extend((pid, go_id, value) for go_id, value in ranked[:TOP_K])

# Write the propagated predictions as a headerless TSV.
prop_out = os.path.join(OUT_DIR, "submission_mlp_propagated.tsv")
prop_df = pd.DataFrame(prop_rows, columns=["protein_id","go_term","score"])
prop_df.to_csv(prop_out, sep="\t", header=False, index=False)
print("Propagated submission saved:", prop_out, "rows:", len(prop_df))
