In [None]:
!pip install biopython > dev null

In [None]:
import os
import gc
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchmetrics.classification import MultilabelF1Score
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
from collections import Counter, defaultdict
from tqdm.auto import tqdm
from Bio import SeqIO
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# ==========================================
# 1. CONFIGURATION
# ==========================================
class CFG:
    """Static configuration: input file locations and prediction settings."""

    # Shared root of the competition dataset; removed from the class at the end
    # of the body so only the concrete paths remain as attributes.
    _ROOT = '/kaggle/input/cafa-6-protein-function-prediction'

    # Training inputs
    TRAIN_TERMS = f'{_ROOT}/Train/train_terms.tsv'
    TRAIN_SEQS = f'{_ROOT}/Train/train_sequences.fasta'
    TRAIN_TAX = f'{_ROOT}/Train/train_taxonomy.tsv'

    # Test inputs
    TEST_SEQS = f'{_ROOT}/Test/testsuperset.fasta'
    TEST_TAX = f'{_ROOT}/Test/testsuperset-taxon-list.tsv'

    # GO ontology
    OBO_FILE = f'{_ROOT}/Train/go-basic.obo'

    # External datasets: precomputed embeddings and GOA annotations
    EMBEDDINGS_PATH = '/kaggle/input/cafa6-protein-embeddings-esm2/protein_embeddings.npy'
    EMBEDDINGS_IDS = '/kaggle/input/cafa6-protein-embeddings-esm2/protein_ids.csv'
    GOA_PATH = '/kaggle/input/goa-uniprot-all/selection_ids_goa_uniprot_all.csv'

    # Runtime settings
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Per-aspect score cut-off for keeping a predicted term.
    THRESHOLDS = {'C': 0.08, 'F': 0.08, 'P': 0.01}
    # Minimum number of terms kept per protein per aspect.
    MIN_PREDS = 25

    del _ROOT

# --- ENSEMBLE CONFIGURATIONS ---
# One dict per ensemble member: batch size, epochs, learning rate, the two
# hidden-layer widths (D1, D2) and dropout rate. Only widths/dropout vary.
ensemble_configs = [
    dict(BS=32, EPOCHS=40, LR=1e-3, D1=1024, D2=512, Drop=0.3),   # 1. Aggressive
    dict(BS=32, EPOCHS=40, LR=1e-3, D1=2048, D2=1024, Drop=0.4),  # 2. Wide & Robust
    dict(BS=32, EPOCHS=40, LR=1e-3, D1=1024, D2=768, Drop=0.3),   # 3. Balanced Middle
]

# Report the compute device and the ensemble size (single write, same text).
print(f"Using device: {CFG.DEVICE}\n\nEnsembling {len(ensemble_configs)} models for each aspect")

# ==========================================
# 2. HELPER FUNCTIONS
# ==========================================
def plot_history(history, aspect_name, model_name):
    """Show training curves for one model: losses on the left, validation F1 on the right.

    Args:
        history: dict with per-epoch lists under 'train_loss', 'val_loss' and
            'val_f1' (the x-axis length is taken from 'train_loss').
        aspect_name: GO aspect label used in the subplot titles.
        model_name: short model tag used in the subplot titles.
    """
    xs = range(1, len(history['train_loss']) + 1)
    prefix = f'{aspect_name} [{model_name}]'
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    loss_ax, f1_ax = axes

    loss_ax.plot(xs, history['train_loss'], 'r-o', label='Train Loss')
    loss_ax.plot(xs, history['val_loss'], 'b-s', label='Val Loss')
    loss_ax.set_title(f'{prefix} - Loss')

    f1_ax.plot(xs, history['val_f1'], 'g-^', label='Val F1')
    f1_ax.set_title(f'{prefix} - F1 Score')

    for ax in (loss_ax, f1_ax):
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

def parse_obo(obo_file):
    """Parse an OBO file into a map of GO term -> direct parent GO terms.

    Parents are read from ``is_a:`` and ``relationship: part_of`` lines. Only
    stanzas whose ``id:`` is a ``GO:`` identifier are recorded, and every stanza
    header (``[Term]``, ``[Typedef]``, ...) ends the current term. Without that
    reset, ``is_a:`` lines inside a trailing ``[Typedef]`` stanza would be
    attached to the last GO term in the file.

    Args:
        obo_file: path to an OBO-format ontology file.

    Returns:
        ``defaultdict(list)`` keyed by GO id; terms without parents are absent.
    """
    print(f"\nParsing GO hierarchy from {os.path.basename(obo_file)}...")
    go_parents = defaultdict(list)
    current_term = None
    with open(obo_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('['):
                # New stanza: nothing is attributed until its id is read.
                current_term = None
                continue
            if line.startswith('id: '):
                term_id = line[len('id: '):].strip()
                current_term = term_id if term_id.startswith('GO:') else None
                continue
            if current_term is None:
                continue
            # Token-based parsing: "is_a: GO:x ! name" and
            # "relationship: part_of GO:x ! name".
            tokens = line.split()
            if tokens[:1] == ['is_a:'] and len(tokens) > 1:
                go_parents[current_term].append(tokens[1])
            elif tokens[:2] == ['relationship:', 'part_of'] and len(tokens) > 2:
                go_parents[current_term].append(tokens[2])
    print(f"✅ Loaded {len(go_parents):,} GO terms with parents")
    return go_parents

def parse_obo_children(obo_file):
    """Parse an OBO file into a map of GO term -> direct child GO terms.

    Edges come from ``is_a:`` and ``relationship: part_of`` lines, the same edge
    types the upward (parent) map follows. Including ``part_of`` is deliberate.
    If scores propagate from a child to a ``part_of`` parent, then a NOT
    annotation on the parent must also rule out that child.

    Only stanzas with a ``GO:`` id are read. Every stanza header resets the
    current term, so ``[Typedef]`` entries never appear as children.

    Args:
        obo_file: path to an OBO-format ontology file.

    Returns:
        dict mapping a parent GO id to the list of its direct child GO ids,
        in file order.
    """
    go_children = {}
    current_term = None
    with open(obo_file, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('['):
                current_term = None
                continue
            if line.startswith('id: '):
                term_id = line[len('id: '):].strip()
                current_term = term_id if term_id.startswith('GO:') else None
                continue
            if current_term is None:
                continue
            # Split on whitespace so the parent id carries no leading space
            # (slicing after "is_a:" kept one, breaking every later lookup).
            tokens = line.split()
            parent = None
            if tokens[:1] == ['is_a:'] and len(tokens) > 1:
                parent = tokens[1]
            elif tokens[:2] == ['relationship:', 'part_of'] and len(tokens) > 2:
                parent = tokens[2]
            if parent is not None:
                go_children.setdefault(parent, []).append(current_term)
    return go_children

def get_all_ancestors(term, go_parents, cache):
    """Return the set of all (transitive) ancestors of ``term``.

    Recurses through ``go_parents`` (term -> list of direct parents) and
    memoises every visited term's ancestor set in ``cache``; the returned set
    is the cached object itself. Assumes the parent graph is acyclic.
    """
    if term not in cache:
        found = set()
        for parent in go_parents.get(term, []):
            found.add(parent)
            found |= get_all_ancestors(parent, go_parents, cache)
        cache[term] = found
    return cache[term]

def get_descendants(term, children_map, cache=None):
    """Return every term reachable downward from ``term`` in ``children_map``.

    Args:
        term: starting GO term (not itself included unless reachable via a cycle).
        children_map: mapping term -> iterable of direct children.
        cache: optional dict; when given, a previously stored result for
            ``term`` is returned, and a freshly computed one is stored.

    Returns:
        set of descendant terms.
    """
    if cache is not None and term in cache:
        return cache[term]
    seen = set()
    frontier = [term]
    # Iterative depth-first walk; `seen` guards against revisits and cycles.
    while frontier:
        node = frontier.pop()
        if node not in children_map:
            continue
        for child in children_map[node]:
            if child not in seen:
                seen.add(child)
                frontier.append(child)
    if cache is not None:
        cache[term] = seen
    return seen

def propagate_predictions(df, go_parents):
    """Propagate each protein's term scores to all GO ancestors (max rule).

    Duplicate (pid, term) rows first collapse to their maximum score. Then
    each predicted term raises every ancestor (from ``get_all_ancestors``) to
    at least its own score.

    Args:
        df: DataFrame with columns 'pid', 'term', 'p'.
        go_parents: mapping GO term -> list of direct parent terms.

    Returns:
        DataFrame with columns 'pid', 'term', 'p'. It is still correctly
        columned when ``df`` is empty, so callers can index 'pid' safely.
    """
    print("\nPropagating predictions to ancestors...")
    ancestor_cache = {}
    out_pids, out_terms, out_scores = [], [], []
    for pid, group in tqdm(df.groupby('pid'), desc="Propagating"):
        term_scores = {}
        # Zip over the columns instead of iterrows(): same values, far less
        # per-row overhead on multi-million-row inputs.
        for term, score in zip(group['term'], group['p']):
            term_scores[term] = max(term_scores.get(term, 0), score)
        # Snapshot the keys: ancestors are inserted while we iterate.
        for term in list(term_scores):
            score = term_scores[term]
            for ancestor in get_all_ancestors(term, go_parents, ancestor_cache):
                term_scores[ancestor] = max(term_scores.get(ancestor, 0), score)
        out_pids.extend([pid] * len(term_scores))
        out_terms.extend(term_scores.keys())
        out_scores.extend(term_scores.values())
    return pd.DataFrame({'pid': out_pids, 'term': out_terms, 'p': out_scores})

class SequenceFeatureExtractor:
    """Hand-crafted composition features for a single protein sequence.

    ``extract`` returns an 85-value float32 vector laid out as:

    * [0:20)  residue frequencies over ``STANDARD_AA``
    * [20:24) log1p(length), hydrophobic (AILMFWYV) fraction, charged (DEKR)
              fraction, log1p(sum of per-residue weights from ``AA_WEIGHTS``)
    * [24:30) fraction of residues in each of the six ``GROUPS``
    * [30:50) frequencies of the ``TOP_DI`` dipeptides
    * [50:70) frequencies of the ``TOP_TRI`` tripeptides
    * [70:80) frequencies of the ``EXTRA_DI`` dipeptides
    * [80:85) aromatic (FWY) fraction, mean of the hydrophobic and charged
              fractions, mean character code of the first 100 residues / 255,
              and two zero pads
    """

    N_FEATURES = 85
    STANDARD_AA = 'ACDEFGHIKLMNPQRSTVWY'
    AA_WEIGHTS = {'A': 89, 'C': 121, 'D': 133, 'E': 147, 'F': 165, 'G': 75, 'H': 155, 'I': 131, 'K': 146, 'L': 131,
                  'M': 149, 'N': 132, 'P': 115, 'Q': 146, 'R': 174, 'S': 105, 'T': 119, 'V': 117, 'W': 204, 'Y': 181}
    GROUPS = {'hydro': 'AILMFWYV', 'polar': 'STNQ', 'pos': 'RK', 'neg': 'DE', 'arom': 'FWY', 'aliph': 'ILV'}
    TOP_DI = ['AL', 'LA', 'LE', 'EA', 'AA', 'AS', 'SA', 'EL', 'LL', 'AE',
              'SE', 'ES', 'GA', 'AG', 'VA', 'AV', 'LV', 'VL', 'LS', 'SL']
    # NOTE: 'VLA' appears twice; kept so the 85-value layout is unchanged.
    TOP_TRI = ['ALA', 'LEA', 'EAL', 'LAL', 'AAA', 'LLE', 'ELE', 'ALE', 'GAL', 'ASA',
               'VLA', 'LAV', 'SLS', 'LSL', 'GLA', 'LAG', 'AVL', 'VLA', 'SLE', 'LES']
    EXTRA_DI = ['RE', 'ER', 'VE', 'EV', 'TE', 'ET', 'AV', 'VA', 'GL', 'LG']

    @staticmethod
    def extract(seq):
        """Return the 85-value float32 feature vector for ``seq``.

        An empty sequence, or any unexpected error while computing features,
        yields an all-zero vector. That way a single malformed record cannot
        abort a whole FASTA pass.
        """
        cls = SequenceFeatureExtractor
        if not seq:
            return np.zeros(cls.N_FEATURES, dtype=np.float32)
        try:
            length = len(seq)
            aa_counts = Counter(seq)
            aa_freq = np.array([aa_counts.get(aa, 0) / length for aa in cls.STANDARD_AA], dtype=np.float32)

            hydrophobic = sum(aa_counts.get(aa, 0) for aa in 'AILMFWYV') / length
            charged = sum(aa_counts.get(aa, 0) for aa in 'DEKR') / length
            mol_weight = sum(n * cls.AA_WEIGHTS.get(aa, 0) for aa, n in aa_counts.items())
            props = np.array([np.log1p(length), hydrophobic, charged, np.log1p(mol_weight)], dtype=np.float32)

            # Derived from the residue counts; equivalent to rescanning `seq` per group.
            group_feats = np.array(
                [sum(aa_counts.get(aa, 0) for aa in chars) / length for chars in cls.GROUPS.values()],
                dtype=np.float32)

            di_freq = np.zeros(20, dtype=np.float32)
            extra_freq = np.zeros(10, dtype=np.float32)
            if length > 1:
                di_counts = Counter(seq[i:i + 2] for i in range(length - 1))
                di_freq = np.array([di_counts.get(d, 0) / (length - 1) for d in cls.TOP_DI], dtype=np.float32)
                extra_freq = np.array([di_counts.get(d, 0) / (length - 1) for d in cls.EXTRA_DI], dtype=np.float32)

            tri_freq = np.zeros(20, dtype=np.float32)
            if length > 2:
                tri_counts = Counter(seq[i:i + 3] for i in range(length - 2))
                # Each tripeptide is looked up by its own loop variable.
                tri_freq = np.array([tri_counts.get(t, 0) / (length - 2) for t in cls.TOP_TRI], dtype=np.float32)

            arom = sum(aa_counts.get(aa, 0) for aa in 'FWY') / length
            instab = (hydrophobic + charged) / 2
            iso = np.mean([ord(aa) for aa in seq[:100]]) / 255
            misc = np.array([arom, instab, iso, 0, 0], dtype=np.float32)
            return np.concatenate([aa_freq, props, group_feats, di_freq, tri_freq, extra_freq, misc])
        except Exception:
            return np.zeros(cls.N_FEATURES, dtype=np.float32)

def parse_fasta(path):
    """Read a FASTA file into per-protein sequence and feature dictionaries.

    The protein ID is the second '|'-separated field of the record ID when one
    is present (``sp|P12345|NAME`` -> ``P12345``). Otherwise it is the first
    whitespace-separated token of the record ID.

    Returns:
        (seqs, feats): dicts keyed by protein ID, holding the raw sequence
        string and ``SequenceFeatureExtractor.extract`` of that sequence.
    """
    print(f"\nParsing {os.path.basename(path)}...")
    seqs = {}
    feats = {}
    for record in SeqIO.parse(path, "fasta"):
        rid = record.id
        if '|' in rid:
            pid = rid.split('|')[1]
        else:
            pid = rid.split()[0]
        sequence = str(record.seq)
        seqs[pid] = sequence
        feats[pid] = SequenceFeatureExtractor.extract(sequence)
    return seqs, feats

# ==========================================
# 3. DATA LOADING
# ==========================================
print("\nLoading ESM2 Embeddings...")
emb_ids = pd.read_csv(CFG.EMBEDDINGS_IDS)['protein_id'].tolist()
emb_data = np.load(CFG.EMBEDDINGS_PATH)
# Row k of the matrix belongs to emb_ids[k]. Assuming a 2-D (proteins, dim)
# array, each value is a row *view*, so the `del` below only drops the name;
# the matrix stays in memory while embed_dict references it.
embed_dict = {protein_id: emb_data[row] for row, protein_id in enumerate(emb_ids)}
del emb_data
gc.collect()
print(f"✅ Loaded {len(embed_dict):,} embeddings")

# Sequences plus raw (unscaled) hand-crafted features for both splits.
train_seqs, train_feats = parse_fasta(CFG.TRAIN_SEQS)
test_seqs, test_feats = parse_fasta(CFG.TEST_SEQS)

print("\nLoading Taxonomy...")
# Both files are read as header-less (pid, taxid) TSVs.
# NOTE(review): the test file's name ("taxon-list") suggests it may list taxa
# rather than map proteins to taxa; if so, the test taxonomy one-hots will be
# all zeros. Confirm the file's columns.
train_tax_df, test_tax_df = (
    pd.read_csv(path, sep='\t', header=None, names=['pid', 'taxid'])
    for path in (CFG.TRAIN_TAX, CFG.TEST_TAX)
)
# The 20 most frequent training taxa define the one-hot vocabulary.
top_taxa = train_tax_df['taxid'].value_counts().index[:20].tolist()

def get_tax_vector(pid, tax_map):
    """One-hot encode the taxon of ``pid`` over the module-level ``top_taxa`` list.

    Returns a float32 vector of length ``len(top_taxa)``. It is all zeros when
    ``pid`` is absent from ``tax_map`` or its taxon is not one of ``top_taxa``.
    """
    onehot = np.zeros(len(top_taxa), dtype=np.float32)
    if pid not in tax_map:
        return onehot
    taxid = tax_map[pid]
    if taxid in top_taxa:
        onehot[top_taxa.index(taxid)] = 1.0
    return onehot

# pid -> taxid lookups (duplicate pids keep their last taxid).
train_tax_map = train_tax_df.set_index('pid')['taxid'].to_dict()
test_tax_map = test_tax_df.set_index('pid')['taxid'].to_dict()

print("\nScaling Features...")
all_pids_train = list(train_feats)
all_pids_test = list(test_feats)
X_manual_train = np.array([train_feats[p] for p in all_pids_train])
X_manual_test = np.array([test_feats[p] for p in all_pids_test])

# Fit the scaler on training features only, then apply it to both splits.
scaler = StandardScaler()
X_manual_train = scaler.fit_transform(X_manual_train)
X_manual_test = scaler.transform(X_manual_test)

# Final per-protein feature = scaled hand-crafted features + taxonomy one-hot.
for pid, scaled_row in zip(all_pids_train, X_manual_train):
    train_feats[pid] = np.concatenate([scaled_row, get_tax_vector(pid, train_tax_map)])
for pid, scaled_row in zip(all_pids_test, X_manual_test):
    test_feats[pid] = np.concatenate([scaled_row, get_tax_vector(pid, test_tax_map)])
print("\n✅ Data prepared successfully")

# ==========================================
# 4. MODEL & DATASET
# ==========================================
class MLP(nn.Module):
    """Two-hidden-layer perceptron producing raw logits.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout(drop), with
    widths ``h1`` then ``h2``; a final Linear maps to ``num_classes`` outputs.
    No output activation is applied.
    """

    def __init__(self, input_dim, num_classes, h1, h2, drop):
        super().__init__()
        layers = []
        width_in = input_dim
        for width_out in (h1, h2):
            layers += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(drop),
            ]
            width_in = width_out
        layers.append(nn.Linear(width_in, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class HybridDataset(Dataset):
    """Dataset pairing each protein's embedding with its extra feature vector.

    Items are ``(x, y)`` float32 tensors when ``labels`` is given, else just
    ``x``. ``x`` is ``emb_dict[pid]`` concatenated with ``feat_dict[pid]``.
    ``labels[idx]`` may be a sparse row (anything with ``toarray``), which is
    densified and flattened.
    """

    def __init__(self, pids, labels, emb_dict, feat_dict):
        self.pids = pids
        self.labels = labels
        self.emb = emb_dict
        self.feat = feat_dict

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx):
        pid = self.pids[idx]
        features = np.concatenate([self.emb[pid], self.feat[pid]]).astype(np.float32)
        x = torch.from_numpy(features)
        if self.labels is None:
            return x
        target = self.labels[idx]
        if hasattr(target, 'toarray'):
            target = target.toarray().flatten()
        return x, torch.tensor(target, dtype=torch.float32)

# ==========================================
# 5. TRAINING LOOP
# ==========================================
train_terms = pd.read_csv(CFG.TRAIN_TERMS, sep='\t')
aspects = {'Biological Process': 'P', 'Molecular Function': 'F', 'Cellular Component': 'C'}
submission_dfs = []

# Model input = embedding vector followed by the per-protein feature vector.
# Both widths are read from the data: a hard-coded embedding size (1280)
# matches only some ESM2 checkpoints.
sample_pid = next(iter(train_feats))
emb_dim = len(next(iter(embed_dict.values())))
input_dim = emb_dim + len(train_feats[sample_pid])

def _train_member(conf, ds, num_classes, aspect_name, tag):
    """Train one ensemble member described by ``conf`` on a fresh 90/10 split of ``ds``.

    Uses AdamW with per-epoch cosine LR annealing and BCE-with-logits loss. It
    records per-epoch train/val loss and validation micro-F1 (threshold 0.1),
    plots that history, and returns the trained model.
    """
    train_len = int(0.9 * len(ds))
    train_ds, val_ds = random_split(ds, [train_len, len(ds) - train_len])
    # BatchNorm1d raises on a training batch of size 1, so drop the last batch
    # exactly when it would contain a single sample.
    drop_last = len(train_ds) % conf['BS'] == 1
    train_loader = DataLoader(train_ds, batch_size=conf['BS'], shuffle=True,
                              num_workers=4, drop_last=drop_last)
    val_loader = DataLoader(val_ds, batch_size=conf['BS'], num_workers=4)

    model = MLP(input_dim, num_classes, conf['D1'], conf['D2'], conf['Drop']).to(CFG.DEVICE)
    opt = optim.AdamW(model.parameters(), lr=conf['LR'], weight_decay=1e-4)
    crit = nn.BCEWithLogitsLoss()
    sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=conf['EPOCHS'], eta_min=1e-6)
    f1_metric = MultilabelF1Score(num_labels=num_classes, threshold=0.1, average='micro').to(CFG.DEVICE)

    history = {'train_loss': [], 'val_loss': [], 'val_f1': []}
    for epoch in range(conf['EPOCHS']):
        model.train()
        t_loss = 0.0
        for x, y in train_loader:
            x, y = x.to(CFG.DEVICE), y.to(CFG.DEVICE)
            opt.zero_grad()
            loss = crit(model(x), y)
            loss.backward()
            opt.step()
            t_loss += loss.item()

        model.eval()
        v_loss = 0.0
        f1_metric.reset()
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(CFG.DEVICE), y.to(CFG.DEVICE)
                out = model(x)
                v_loss += crit(out, y).item()
                # Accumulate over the whole split: a mean of per-batch F1 values is
                # not the split's micro-F1. Probabilities are passed explicitly so
                # the 0.1 threshold applies to sigmoid outputs.
                f1_metric.update(torch.sigmoid(out), y)
        v_f1 = f1_metric.compute().item() if len(val_loader) else 0.0

        t_loss /= max(len(train_loader), 1)
        v_loss /= max(len(val_loader), 1)

        sched.step()  # one cosine step per epoch (T_max = EPOCHS)
        current_lr = opt.param_groups[0]['lr']

        history['train_loss'].append(t_loss)
        history['val_loss'].append(v_loss)
        history['val_f1'].append(v_f1)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            print(f"Ep {epoch+1}: Train Loss {t_loss:.4f} | Val Loss {v_loss:.4f} | Val F1 {v_f1:.4f} | LR {current_lr:.2e}")

    plot_history(history, aspect_name, tag)
    return model


def _predict_aspect(models, classes, threshold, pids, loader):
    """Average the sigmoid outputs of ``models`` and emit {'pid','term','p'} rows.

    For each protein, every term scoring >= ``threshold`` is kept. If fewer than
    ``CFG.MIN_PREDS`` terms pass, the top ``CFG.MIN_PREDS`` terms are kept
    instead. ``loader`` must yield batches in the same order as ``pids``.
    """
    rows = []
    for m in models:
        m.eval()
    offset = 0
    with torch.no_grad():
        for x in tqdm(loader):
            x = x.to(CFG.DEVICE)
            avg_out = torch.stack([torch.sigmoid(m(x)) for m in models]).mean(dim=0).cpu().numpy()
            for j, scores in enumerate(avg_out):
                pid = pids[offset + j]
                idxs = np.flatnonzero(scores >= threshold)
                if len(idxs) < CFG.MIN_PREDS:
                    idxs = np.argsort(scores)[-CFG.MIN_PREDS:]
                rows.extend({'pid': pid, 'term': classes[k], 'p': scores[k]} for k in idxs)
            offset += len(x)
    return rows


# The test set is the same for every aspect, so build its loader once.
test_pids = [p for p in test_seqs if p in embed_dict and p in test_feats]
test_loader = DataLoader(HybridDataset(test_pids, None, embed_dict, test_feats),
                         batch_size=64, num_workers=4)

for long_name, short_code in aspects.items():
    print(f"\n{'='*40}\nTraining Aspect: {long_name} ({short_code})\n{'='*40}")

    aspect_df = train_terms[train_terms['aspect'] == short_code]
    prot_terms = aspect_df.groupby('EntryID')['term'].apply(list).to_dict()
    valid_pids = [p for p in prot_terms if p in embed_dict and p in train_feats]
    if not valid_pids:
        continue

    mlb = MultiLabelBinarizer(sparse_output=True)
    y_train = mlb.fit_transform([prot_terms[p] for p in valid_pids])
    num_classes = len(mlb.classes_)
    ds = HybridDataset(valid_pids, y_train, embed_dict, train_feats)

    # Train every ensemble member (each gets its own random train/val split).
    models_list = []
    for i, conf in enumerate(ensemble_configs):
        print(f"\nModel {i+1}/{len(ensemble_configs)}: BS={conf['BS']}, Dim={conf['D1']}, Drop={conf['Drop']}")
        models_list.append(_train_member(conf, ds, num_classes, long_name, f"M{i+1}"))

    # Ensemble prediction
    print(f"Predicting {short_code} with Ensemble...")
    temp_preds = _predict_aspect(models_list, mlb.classes_, CFG.THRESHOLDS[short_code],
                                 test_pids, test_loader)

    print(f"✅ Generated {len(temp_preds):,} predictions for {long_name}")
    # Explicit columns keep the frame well-formed even with zero predictions.
    submission_dfs.append(pd.DataFrame(temp_preds, columns=['pid', 'term', 'p']))
    del models_list, ds
    gc.collect()
    torch.cuda.empty_cache()

# ==========================================
# 6. POST-PROCESSING
# ==========================================
# Combine per-aspect predictions and report the starting size.
raw_df = pd.concat(submission_dfs, ignore_index=True)
initial_count = len(raw_df)
print(f"Total raw predictions: {initial_count:,}")
print("\nPost processing...")

# GO graph in both directions: parents for upward score propagation, children
# for expanding NOT annotations downward.
go_parents = parse_obo(CFG.OBO_FILE)
go_children = parse_obo_children(CFG.OBO_FILE)

# Positive propagation: each predicted term lifts its ancestors to at least its score.
print("\nPropagating predictions...")
propagated_df = propagate_predictions(raw_df, go_parents)
after_prop_count = len(propagated_df)
print(f"Predictions after propagation: {after_prop_count:,} (Added {after_prop_count - initial_count:,})")

# GOA Corrections
if os.path.exists(CFG.GOA_PATH):
    print("\nApplying GOA Ground Truth corrections...")
    goa = pd.read_csv(CFG.GOA_PATH).rename(columns={'db_object_id': 'pid', 'go_id': 'term'})
    # GOA rows for proteins outside the submission can never affect it; drop them
    # up front rather than expanding every NOT annotation in the full dump.
    goa = goa[goa['pid'].isin(set(propagated_df['pid'].unique()))]

    has_qualifier = 'qualifier' in goa.columns
    if has_qualifier:
        is_not = goa['qualifier'].str.contains('NOT', na=False)
    else:
        is_not = pd.Series(False, index=goa.index)

    if has_qualifier:
        print("\nProcessing Negative Annotations...")
        # A NOT annotation on a term also rules out every descendant of that term.
        negs = goa[is_not]
        descendant_cache = {}
        neg_set = set()
        for pid, term in zip(negs['pid'], negs['term']):
            neg_set.add((pid, term))
            neg_set.update((pid, kid) for kid in get_descendants(term, go_children, descendant_cache))

        # Direct set membership per row (no temporary tuple column on the frame).
        count_before_filter = len(propagated_df)
        keep = np.fromiter(
            ((p, t) not in neg_set for p, t in zip(propagated_df['pid'], propagated_df['term'])),
            dtype=bool, count=count_before_filter)
        propagated_df = propagated_df[keep]

        removed_count = count_before_filter - len(propagated_df)
        print(f"✅ Removed {removed_count:,} negative annotations")

    print("\nProcessing Positive Annotations...")
    pos = goa[~is_not]
    # Only keep pids that still have predictions after negative filtering.
    pos = pos[pos['pid'].isin(set(propagated_df['pid'].unique()))]
    pos_df = pd.DataFrame({'pid': pos['pid'], 'term': pos['term'], 'p': 1.0})
    # NOTE(review): known positives are inserted as-is, without raising their GO
    # ancestors to 1.0; consider propagating them if consistency matters.

    count_before_merge = len(propagated_df)
    final_df = (pd.concat([propagated_df, pos_df], ignore_index=True)
                .groupby(['pid', 'term'], as_index=False)['p'].max())

    added_count = len(final_df) - count_before_merge
    print(f"✅ Applied {len(pos_df):,} known positive annotations ({added_count:,} new rows)")

else:
    final_df = propagated_df
    print("\nNo GOA file found, skipping corrections")

# Save
print(f"\n✅ Final submission contains {len(final_df):,} predictions")
# Group rows by protein, highest-scoring terms first; write header-less TSV
# lines of the form "pid<TAB>term<TAB>p".
# NOTE(review): nothing here caps the number of terms per protein; confirm
# whether the competition format imposes a limit.
final_df = final_df.sort_values(by=['pid', 'p'], ascending=[True, False])
final_df.to_csv('submission.tsv', sep='\t', header=False, index=False)
print("\nSaved submission.tsv")

In [None]:
from IPython.display import FileLink

# Render a clickable download link for the written submission file.
FileLink('submission.tsv')