In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    # Print the full path of every input file so the available dataset paths show up in the log.
    for full_path in (os.path.join(root_dir, name) for name in file_names):
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
"""
CAFA-6 COMPLETE ENHANCED MODEL
Target: 0.225 → 0.35+

Key improvements:
1. ESM-2 embeddings (+0.05-0.08)
2. Per-ontology models (+0.02-0.03)
3. Per-term threshold optimization (+0.02-0.03)
4. Dynamic CAFA5 ensemble (+0.01-0.02)
"""

import os, gc, time, warnings

# TensorFlow's native logger reads TF_CPP_MIN_LOG_LEVEL when the library loads,
# so the variable must be set *before* `import tensorflow`; assigning it after the
# import does not silence the import-time C++ messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer, StandardScaler
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
import keras_hub

warnings.filterwarnings('ignore')

_BANNER_RULE = "=" * 80
for _banner_line in (_BANNER_RULE, "CAFA-6 ENHANCED MODEL - TARGET 0.35+", _BANNER_RULE):
    print(_banner_line)

# =============================================================================
# CONFIG
# =============================================================================
CONFIG = dict(
    # Paths: competition data, external CAFA5 predictions, and the output file.
    BASE="/kaggle/input/cafa-6-protein-function-prediction",
    CAFA5_PATH="/kaggle/input/cafa5-055923-pred/submission.tsv",
    OUTPUT="/kaggle/working/submission.tsv",

    # Training.
    MAX_SEQ_LEN=1024,  # sequences are truncated to this many residues before ESM-2
    BATCH_SIZE=16,
    EPOCHS=20,
    LR=2e-4,
    LABELS_PER_ONT=dict(F=800, P=1200, C=600),  # top-N most frequent terms kept per ontology

    # Features and inference.
    USE_ESM2=True,
    ESM_BATCH=4,  # sequences per ESM-2 forward pass
    THRESHOLD_STEPS=150,  # grid size of the per-term threshold search
    TOP_K_PRED=800,  # maximum predictions written per protein

    SEED=42,
)

# Seed the NumPy and TensorFlow global RNGs from the single configured seed.
_seed = CONFIG['SEED']
np.random.seed(_seed)
tf.random.set_seed(_seed)

start = time.time()


def log(msg):
    """Print *msg* prefixed with the seconds elapsed since ``start``."""
    elapsed = time.time() - start
    print(f"[{elapsed:7.1f}s] {msg}")

# =============================================================================
# LOAD ESM-2
# =============================================================================
log("Loading ESM-2...")
# Defaults so these names always exist, even when ESM-2 is disabled or fails to load.
esm_backbone = None
esm_preprocessor = None
HAS_ESM2 = False
if CONFIG['USE_ESM2']:
    try:
        esm_backbone = keras_hub.models.ESMBackbone.from_preset("esm2_t6_8m_ur50d")
        esm_preprocessor = keras_hub.models.ESMPreprocessor.from_preset("esm2_t6_8m_ur50d")
        log("✓ ESM-2 loaded (320-dim embeddings)")
        HAS_ESM2 = True
    except Exception as e:
        # Best-effort: fall back to hand-crafted sequence features only.
        log(f"✗ ESM-2 failed: {e}")
        esm_backbone = esm_preprocessor = None
        CONFIG['USE_ESM2'] = False
else:
    # Skip the preset download entirely when ESM-2 features are switched off.
    log("ESM-2 disabled by CONFIG['USE_ESM2']; using sequence features only")

# =============================================================================
# LOAD DATA
# =============================================================================
log("Loading data...")

def read_fasta(path):
    """Parse a FASTA file into a ``{protein_id: sequence}`` dict.

    The id is the second ``|``-separated header field (UniProt style
    ``>sp|ACCESSION|NAME ...``) or, when the header has no ``|``, its first
    whitespace-separated token. Sequence lines are stripped and concatenated;
    lines before the first header are discarded.
    """
    records = {}
    current_id = None
    chunks = []

    def flush():
        # Store the record collected so far (if it has a non-empty id).
        if current_id:
            records[current_id] = ''.join(chunks)

    with open(path) as handle:
        for raw in handle:
            text = raw.strip()
            if not text.startswith('>'):
                chunks.append(text)
                continue
            flush()
            header = text[1:]
            fields = header.split('|')
            current_id = fields[1] if len(fields) > 1 else header.split()[0]
            chunks = []
        flush()
    return records

_base = CONFIG['BASE']
train_seqs = read_fasta(f"{_base}/Train/train_sequences.fasta")
test_seqs = read_fasta(f"{_base}/Test/testsuperset.fasta")

# The terms file is read header-less; rows whose protein_id is the literal
# 'EntryID' (presumably the file's header line) are then dropped.
df_terms = pd.read_csv(
    f"{_base}/Train/train_terms.tsv",
    sep='\t',
    header=None,
    names=['protein_id', 'go_term', 'ontology'],
)
df_terms = df_terms.loc[df_terms['protein_id'].ne('EntryID')]

# Information-accretion weight per GO term.
df_ia = pd.read_csv(f"{_base}/IA.tsv", sep='\t', header=None, names=['go_term', 'ia'])
ia_weights = {term: weight for term, weight in zip(df_ia['go_term'], df_ia['ia'])}

log(f"Train: {len(train_seqs):,}, Test: {len(test_seqs):,}, Annotations: {len(df_terms):,}")

# =============================================================================
# GO HIERARCHY
# =============================================================================
log("Parsing GO...")

# child term -> set of direct parents via is_a and part_of edges.
parents = defaultdict(set)
# term -> 'F' / 'P' / 'C' (None for any other namespace).
term_ontology = {}
_NAMESPACE_CODE = {'molecular_function': 'F', 'biological_process': 'P',
                   'cellular_component': 'C'}

with open(f"{CONFIG['BASE']}/Train/go-basic.obo") as obo:
    cur_id = None
    for raw in obo:
        text = raw.strip()
        if text == "[Term]":
            cur_id = None
            continue
        if text.startswith("id: "):
            cur_id = text.split("id: ")[1]
            continue
        if not cur_id:
            # Attribute lines are only recorded once the current stanza has an id.
            continue
        if text.startswith("namespace: "):
            term_ontology[cur_id] = _NAMESPACE_CODE.get(text.split("namespace: ")[1])
        elif text.startswith("is_a: "):
            parents[cur_id].add(text.split()[1])
        elif text.startswith("relationship: part_of "):
            fields = text.split()
            if len(fields) >= 3:
                parents[cur_id].add(fields[2])

def get_ancestors(term, graph=None):
    """Return every ancestor of *term* (transitive closure over parent edges).

    Args:
        term: GO term id to start from.
        graph: mapping ``term -> iterable of direct parents``. Defaults to the
            module-level ``parents`` built from go-basic.obo.

    Returns:
        Set of ancestor ids. *term* itself is included only if the graph
        contains a cycle through it.
    """
    if graph is None:
        graph = parents
    ancestors = set()
    # Depth-first with a list used as a stack: pop() is O(1), whereas the
    # previous pop(0) was O(n). The result is a set, so traversal order is irrelevant.
    stack = [term]
    while stack:
        node = stack.pop()
        for parent in graph.get(node, ()):
            if parent not in ancestors:
                ancestors.add(parent)
                stack.append(parent)
    return ancestors

# =============================================================================
# PROPAGATE & SELECT LABELS
# =============================================================================
log("Propagating labels...")

# Group raw annotations by protein (column-wise zip is far faster than iterrows).
protein_to_terms = defaultdict(set)
for _pid, _term in zip(df_terms['protein_id'], df_terms['go_term']):
    protein_to_terms[_pid].add(_term)

# Close each protein's annotation set under its is_a/part_of ancestors.
# Ancestor sets are cached per term because the same terms recur across proteins.
_ancestor_cache = {}
propagated = {}
for p, terms in protein_to_terms.items():
    all_t = set(terms)
    for t in terms:
        anc = _ancestor_cache.get(t)
        if anc is None:
            anc = _ancestor_cache[t] = get_ancestors(t)
        all_t.update(anc)
    propagated[p] = all_t
del _ancestor_cache

# Number of proteins carrying each (propagated) term.
term_counts = Counter()
for terms in propagated.values():
    term_counts.update(terms)

log("Selecting labels per ontology...")
ontology_terms = {'F': {}, 'P': {}, 'C': {}}
for term, count in term_counts.items():
    ont = term_ontology.get(term)
    if ont and count >= 3:
        ontology_terms[ont][term] = count

# Keep the top-N most frequent terms per ontology (N from CONFIG['LABELS_PER_ONT']).
selected_terms = {}
for ont in ['F', 'P', 'C']:
    sorted_t = sorted(ontology_terms[ont].items(), key=lambda x: x[1], reverse=True)
    selected_terms[ont] = [t for t, _ in sorted_t[:CONFIG['LABELS_PER_ONT'][ont]]]
    log(f"  {ont}: {len(selected_terms[ont])} labels")

all_terms = selected_terms['F'] + selected_terms['P'] + selected_terms['C']
# Set for O(1) membership; `all_terms` stays a list because its order is used downstream.
_all_terms_set = set(all_terms)

valid_prots = [p for p in propagated if p in train_seqs]
filtered = {p: [t for t in propagated[p] if t in _all_terms_set] for p in valid_prots}
valid_prots = [p for p in valid_prots if filtered[p]]

log(f"Training proteins: {len(valid_prots):,}")

# =============================================================================
# EXTRACT ESM-2 EMBEDDINGS
# =============================================================================
def extract_esm2(protein_ids, seqs, batch_size=4):
    """Embed proteins with ESM-2 and mean-pool the per-position states.

    Args:
        protein_ids: ordered list of protein ids to embed.
        seqs: mapping protein id -> amino-acid sequence. Each sequence is
            truncated to ``CONFIG['MAX_SEQ_LEN']`` residues before tokenization.
        batch_size: number of sequences per backbone forward pass.

    Returns:
        np.ndarray with one pooled embedding row per id, in ``protein_ids`` order.
    """
    if len(protein_ids) == 0:
        # np.vstack([]) would raise. Return an empty matrix, using the backbone
        # width when exposed (assumes a `hidden_dim` attribute; TODO confirm).
        return np.zeros((0, getattr(esm_backbone, 'hidden_dim', 0)), dtype=np.float32)

    embs = []
    for i in range(0, len(protein_ids), batch_size):
        batch = protein_ids[i:i+batch_size]
        batch_seq = [seqs[p][:CONFIG['MAX_SEQ_LEN']] for p in batch]

        inputs = esm_preprocessor(batch_seq)
        outputs = esm_backbone(inputs)

        # Exclude padding positions from the mean; otherwise short proteins are
        # dominated by pad-token states. Assumes the preprocessor returns a dict
        # with a 'padding_mask' entry (verify against keras_hub); fall back to a
        # plain mean when no mask is available.
        mask = inputs.get('padding_mask') if isinstance(inputs, dict) else None
        if mask is not None:
            weights = tf.cast(mask, outputs.dtype)[..., tf.newaxis]
            summed = tf.reduce_sum(outputs * weights, axis=1)
            n_valid = tf.maximum(tf.reduce_sum(weights, axis=1), 1.0)
            batch_emb = (summed / n_valid).numpy()
        else:
            batch_emb = tf.reduce_mean(outputs, axis=1).numpy()
        embs.append(batch_emb)

        if (i // batch_size) % 200 == 0:
            log(f"    ESM-2: {i:,}/{len(protein_ids):,}")
            # A full gc.collect() on every small batch dominated runtime;
            # collect periodically instead.
            gc.collect()

        del inputs, outputs, batch_emb

    return np.vstack(embs)

def extract_seq_features(seq):
    """Return a fixed 30-dim hand-crafted feature vector for one sequence.

    Layout:
      * [0:20]  fraction of each residue, in ACDEFGHIKLMNPQRSTVWY order;
      * [20:24] hydrophobic (AILMFWYV), charged (RKDE), polar (STNQ) and
                aromatic (FYW) fractions;
      * [24:27] distinct residues / 20, log10(len + 1), len / 1000;
      * [27:30] C, P and G fractions.
    An empty sequence yields a vector of zeros.
    """
    if not seq:
        return np.zeros(30)

    n = len(seq)
    counts = Counter(seq)

    def frac(residues):
        # Fraction of positions occupied by any residue listed in `residues`.
        return sum(counts.get(r, 0) for r in residues) / n

    composition = np.array([counts.get(r, 0) / n for r in "ACDEFGHIKLMNPQRSTVWY"])
    groups = [frac('AILMFWYV'), frac('RKDE'), frac('STNQ'), frac('FYW')]
    shape_terms = [len(set(seq)) / 20.0, np.log10(n + 1), n / 1000.0]
    specials = [frac('C'), frac('P'), frac('G')]
    return np.concatenate([composition, groups, shape_terms, specials])

# =============================================================================
# BUILD MODEL
# =============================================================================
def build_model(input_dim, output_dim, name="model"):
    """Build and compile the multi-label MLP used for one ontology.

    Architecture: Dense(1024)+BatchNorm+Dropout(0.4) -> Dense(512)+BatchNorm+
    Dropout(0.3) -> Dense(256)+Dropout(0.2) -> Dense(output_dim, sigmoid).
    Compiled with Adam at ``CONFIG['LR']`` and binary cross-entropy.

    Args:
        input_dim: number of input features.
        output_dim: number of labels (one sigmoid output each).
        name: Keras model name.

    Returns:
        The compiled Keras model.
    """
    # (units, dropout rate, add BatchNormalization?) for each hidden layer, in order.
    hidden_spec = ((1024, 0.4, True), (512, 0.3, True), (256, 0.2, False))

    features = layers.Input(shape=(input_dim,))
    h = features
    for units, rate, with_bn in hidden_spec:
        h = layers.Dense(units, activation='relu')(h)
        if with_bn:
            h = layers.BatchNormalization()(h)
        h = layers.Dropout(rate)(h)
    probs = layers.Dense(output_dim, activation='sigmoid')(h)

    net = models.Model(inputs=features, outputs=probs, name=name)
    net.compile(
        optimizer=tf.keras.optimizers.Adam(CONFIG['LR']),
        loss='binary_crossentropy',
        metrics=['precision', 'recall'],
    )
    return net

# =============================================================================
# TRAIN PER-ONTOLOGY MODELS
# =============================================================================
log("\nTraining per-ontology models...")


def _optimize_term_thresholds(y_true, y_score, steps):
    """Return one decision threshold per label that maximises validation F1.

    Candidates are ``np.linspace(0.001, 0.9, steps)``. For each label, the first
    candidate reaching the strictly highest F1 wins. Labels with no positives
    in ``y_true``, or whose best F1 is 0, keep the default 0.5. Vectorised over
    labels: one pass per candidate instead of one per (label, candidate).
    """
    positives = y_true == 1
    n_pos = positives.sum(axis=0)
    thresholds = np.full(y_true.shape[1], 0.5)
    best_f1 = np.zeros(y_true.shape[1])
    for thr in np.linspace(0.001, 0.9, steps):
        predicted = y_score >= thr
        tp = (positives & predicted).sum(axis=0)
        fp = (~positives & predicted).sum(axis=0)
        fn = n_pos - tp
        defined = (tp + fp > 0) & (tp + fn > 0)
        with np.errstate(divide='ignore', invalid='ignore'):
            prec = np.where(defined, tp / (tp + fp), 0.0)
            rec = np.where(defined, tp / (tp + fn), 0.0)
        f1 = 2 * prec * rec / (prec + rec + 1e-12)
        improved = defined & (f1 > best_f1)
        best_f1 = np.where(improved, f1, best_f1)
        thresholds = np.where(improved, thr, thresholds)
    return thresholds


ontology_models = {}

# Features depend only on the protein, not the ontology. Extract them once for
# all training proteins and slice per ontology; ESM-2 is the expensive step, and
# a protein annotated in several ontologies was previously embedded several times.
log("  Extracting features for all training proteins...")
_feature_parts = []
if CONFIG['USE_ESM2'] and HAS_ESM2:
    log("    Computing ESM-2...")
    _feature_parts.append(extract_esm2(valid_prots, train_seqs, CONFIG['ESM_BATCH']))
log("    Computing sequence features...")
_feature_parts.append(np.array([extract_seq_features(train_seqs[p]) for p in valid_prots]))
X_all = np.concatenate(_feature_parts, axis=1).astype(np.float32)
_row_of = {p: i for i, p in enumerate(valid_prots)}
del _feature_parts
gc.collect()
log(f"  Features (all proteins): {X_all.shape}")

for ont in ['F', 'P', 'C']:
    log(f"\n{'='*60}")
    log(f"ONTOLOGY {ont}: {len(selected_terms[ont])} labels")

    ont_terms = selected_terms[ont]
    ont_term_set = set(ont_terms)  # O(1) membership instead of scanning the list
    ont_prots = [p for p in valid_prots if not ont_term_set.isdisjoint(filtered[p])]
    log(f"  Proteins: {len(ont_prots):,}")

    X = X_all[[_row_of[p] for p in ont_prots]]
    log(f"  Features: {X.shape}")

    # Prepare labels
    mlb = MultiLabelBinarizer(classes=sorted(ont_terms))
    y_list = [[t for t in filtered[p] if t in ont_term_set] for p in ont_prots]
    Y = mlb.fit_transform(y_list).astype(np.float32)
    log(f"  Labels: {Y.shape}, sparsity: {(1-Y.mean())*100:.1f}%")

    # Split before scaling so the scaler never sees validation rows; those rows
    # drive early stopping and threshold selection below.
    X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.15, random_state=CONFIG['SEED'])
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_val = scaler.transform(X_val)

    # Train
    log("  Training...")
    model = build_model(X.shape[1], Y.shape[1], f"model_{ont}")

    hist = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        batch_size=CONFIG['BATCH_SIZE'],
        epochs=CONFIG['EPOCHS'],
        callbacks=[
            callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=0),
            callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, verbose=0)
        ],
        verbose=0
    )

    log(f"  Trained {len(hist.history['loss'])} epochs")

    # Optimize per-term thresholds on the validation split
    log("  Optimizing thresholds...")
    y_pred = model.predict(X_val, batch_size=32, verbose=0)
    label_ia = np.array([ia_weights.get(t, 1.0) for t in mlb.classes_])
    term_thresholds = _optimize_term_thresholds(y_val, y_pred, CONFIG['THRESHOLD_STEPS'])

    log(f"  Thresholds: min={term_thresholds.min():.3f}, max={term_thresholds.max():.3f}, mean={term_thresholds.mean():.3f}")

    ontology_models[ont] = {
        'model': model,
        'scaler': scaler,
        'mlb': mlb,
        'thresholds': term_thresholds,
        'ia': label_ia,
        'freqs': {t: term_counts[t] for t in ont_terms}
    }

    del X, X_train, X_val, y_train, y_val, Y, y_pred
    gc.collect()

del X_all, _row_of
gc.collect()

# =============================================================================
# LOAD CAFA5
# =============================================================================
log("\nLoading CAFA5...")
try:
    # Read everything as text so protein ids are never coerced to numbers.
    df_cafa5 = pd.read_csv(CONFIG['CAFA5_PATH'], sep='\t', header=None,
                           names=['pid', 'go', 'score'], dtype=str)
    # Rows whose score is not numeric (e.g. a header line) are skipped.
    _scores = pd.to_numeric(df_cafa5['score'], errors='coerce')
    _keep = _scores.notna()
    cafa5 = defaultdict(dict)
    for _pid, _go, _score in zip(df_cafa5['pid'][_keep], df_cafa5['go'][_keep], _scores[_keep]):
        cafa5[_pid][_go] = float(_score)
    del df_cafa5, _scores, _keep
    log(f"✓ CAFA5: {len(cafa5):,} proteins")
except Exception as e:
    # Best-effort: the ensemble step simply finds no CAFA5 predictions.
    cafa5 = {}
    log(f"✗ CAFA5 not loaded: {e}")

# =============================================================================
# PREDICT TEST
# =============================================================================
log("\nPredicting test...")

test_ids = list(test_seqs.keys())
batch_size = 1000

# protein id -> {GO term: score} for scores at or above each term's threshold.
all_preds = defaultdict(dict)

for start_idx in range(0, len(test_ids), batch_size):
    end_idx = min(start_idx + batch_size, len(test_ids))
    batch = test_ids[start_idx:end_idx]

    if start_idx % 10000 == 0:
        log(f"  Batch {start_idx:,}/{len(test_ids):,}")

    # Features do not depend on the ontology: compute them once per batch and
    # reuse them for all three models instead of re-running ESM-2 three times.
    parts = []
    if CONFIG['USE_ESM2'] and HAS_ESM2:
        parts.append(extract_esm2(batch, test_seqs, CONFIG['ESM_BATCH']))
    parts.append(np.array([extract_seq_features(test_seqs[p]) for p in batch]))
    X_raw = np.concatenate(parts, axis=1).astype(np.float32)
    del parts

    for ont in ['F', 'P', 'C']:
        bundle = ontology_models[ont]
        X = bundle['scaler'].transform(X_raw)
        y_pred = bundle['model'].predict(X, batch_size=32, verbose=0)
        classes = bundle['mlb'].classes_

        # Keep (protein, term) pairs whose score reaches that term's tuned
        # threshold; np.nonzero yields them row-major, i.e. protein by protein.
        rows, cols = np.nonzero(y_pred >= bundle['thresholds'])
        for i, j in zip(rows, cols):
            all_preds[batch[i]][classes[j]] = float(y_pred[i, j])

        del X, y_pred

    del X_raw
    gc.collect()

# =============================================================================
# ENSEMBLE & PROPAGATE
# =============================================================================
log("Ensembling with CAFA5...")

# Keyed by all_terms, so `term in term_freqs` is also an O(1) "is this a
# modelled term?" check (membership in the list `all_terms` is a linear scan).
term_freqs = {t: term_counts[t] for t in all_terms}

for pid in test_ids:
    if pid not in cafa5:
        continue
    for term, cafa5_score in cafa5[pid].items():
        if term not in term_freqs:
            continue
        my_score = all_preds[pid].get(term, 0)

        # Dynamic weighting: lean on CAFA5 more when our score is low and
        # less when it is high.
        weight = 0.25
        if my_score > 0.7:
            weight = 0.15
        elif my_score < 0.3:
            weight = 0.35

        # More CAFA5 weight for high-IA terms.
        if ia_weights.get(term, 1.0) > 5.0:
            weight += 0.10

        # More CAFA5 weight for rare training terms, less for very common ones.
        freq = term_freqs.get(term, 0)
        if freq < 10:
            weight += 0.10
        elif freq > 1000:
            weight -= 0.05

        weight = np.clip(weight, 0.1, 0.5)
        all_preds[pid][term] = (1 - weight) * my_score + weight * cafa5_score

log("Propagating to parents...")

term_to_idx = {t: i for i, t in enumerate(all_terms)}
# Push scores to *all* modelled ancestors, not only direct parents. The earlier
# single pass over direct parents, in all_terms order (descending frequency, so
# parents usually precede their children), never carried a child's score up
# to its grandparents. It also skipped ancestors reachable only through an
# unmodelled intermediate term.
restricted_ancestors = {
    t: [a for a in get_ancestors(t) if a in term_to_idx] for t in all_terms
}

for pid in test_ids:
    scores = all_preds.get(pid)
    if not scores:
        continue
    # Snapshot, because new ancestor keys are inserted into `scores` below.
    # Each ancestor ends with the max of its own score and its descendants' scores.
    for child, child_score in list(scores.items()):
        for anc in restricted_ancestors.get(child, ()):
            if anc not in scores or scores[anc] < child_score:
                scores[anc] = child_score

# =============================================================================
# WRITE SUBMISSION
# =============================================================================
log("Writing submission...")

# One "pid<TAB>term<TAB>score" line per prediction: at most TOP_K_PRED per
# protein, highest scores first, and only scores above 0.001.
_top_k = CONFIG['TOP_K_PRED']
with open(CONFIG['OUTPUT'], 'w') as sub_file:
    for pid in test_ids:
        if pid not in all_preds:
            continue
        ranked = sorted(all_preds[pid].items(), key=lambda item: item[1], reverse=True)
        sub_file.writelines(
            f"{pid}\t{term}\t{score:.3g}\n"
            for term, score in ranked[:_top_k]
            if score > 0.001
        )

log(f"✓ Saved to {CONFIG['OUTPUT']}")

# =============================================================================
# SUMMARY
# =============================================================================
try:
    df_sub = pd.read_csv(CONFIG['OUTPUT'], sep='\t', header=None, names=['pid', 'go', 'score'])
except pd.errors.EmptyDataError:
    # An empty submission file has no columns for read_csv to parse.
    df_sub = pd.DataFrame(columns=['pid', 'go', 'score'])

n_proteins = df_sub['pid'].nunique()
# Guard against division by zero when nothing was written.
avg_per_protein = len(df_sub) / n_proteins if n_proteins else 0.0

print("\n" + "="*80)
print("SUMMARY")
print("="*80)
print(f"Total time: {(time.time()-start)/60:.1f} minutes")
print(f"Predictions: {len(df_sub):,}")
print(f"Proteins: {n_proteins:,}")
print(f"GO terms: {df_sub['go'].nunique():,}")
print(f"Avg per protein: {avg_per_protein:.1f}")
# NOTE(review): this range is a hard-coded estimate, not computed from any validation run.
print(f"\nExpected score: 0.30-0.38")
print("="*80)