In [None]:
import numpy as np
import pandas as pd
import os
import gc
import sys
import subprocess
from pathlib import Path
from collections import defaultdict, Counter
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# ============================================================================
# CONFIG
# ============================================================================
# Central settings for the whole notebook.
# NOTE(review): these keys are not read anywhere in this notebook, so changing
# them has no effect unless the corresponding code is wired up:
# EMBEDDING_SOURCES, EMBEDDING_FUSION, IA_SAMPLING_RATIO, IA_CLIP_MIN,
# IA_CLIP_MAX, IA_TRANSFORM, COOCCUR_TOP_K, USE_BATCH_NORM,
# MAX_PREDS_PER_PROTEIN, BASE_BLAST_WEIGHT, BASE_DL_WEIGHT, BASE_FREQ_WEIGHT.
CONFIG = {
    # Data sampling
    'MAX_TRAIN_SAMPLES': 120000,
    'TOP_K_FREQUENT': 1000,  # raised from 650; the score improved
    'TOP_K_RARE': 1200,  # raised from 800; the score improved
    'MIN_FREQ_COMMON': 20,  # annotation count at/above which a term is "frequent"
    'MIN_FREQ_RARE': 3,  # minimum annotation count for a rare-term candidate
    
    # Focal Loss parameters
    'FOCAL_LOSS_GAMMA': 2.5,
    'FOCAL_LOSS_ALPHA': 0.3,
    'USE_IA_WEIGHTS': True,
    
    # Multi-embedding ensemble
    'USE_MULTI_EMBEDDING': True,
    'EMBEDDING_SOURCES': ['esm2', 'protbert', 't5'],
    'EMBEDDING_FUSION': 'attention',
    
    # High-IA term prioritization
    'IA_SAMPLING_RATIO': 0.4,
    'HIGH_IA_THRESHOLD': 2.5,
    'IA_CLIP_MIN': None,        # meant to "remove clipping", but the IA weight vector
    'IA_CLIP_MAX': None,        # is still clipped to [0.1, 10.0] where it is built
    'IA_TRANSFORM': 'log1p',    # log1p is hard-coded where the IA vector is built
    
    # Co-occurrence modeling
    'USE_COOCCURRENCE': True,
    'COOCCUR_TOP_K': 50,
    
    # Model architecture
    'HIDDEN_DIMS': [512, 256],
    'DROPOUT_RATE': 0.3,
    'USE_BATCH_NORM': True,
    
    # Training
    'EPOCHS': 35,
    'BATCH_SIZE': 64,
    'LEARNING_RATE': 1e-3,
    'WEIGHT_DECAY': 1e-4,
    'GRAD_CLIP': 1.0,
    'WARMUP_EPOCHS': 5,
    
    # Prediction
    'MAX_PREDS_PER_PROTEIN': 2000,
    'MIN_CONFIDENCE': 0.005,  # previously also 0.005; only gates ontology calibration
    'TEMPERATURE': 0.8,  # inference uses sigmoid(logits / TEMPERATURE)
    
    # GO propagation
    'USE_GO_PROPAGATION': True,
    'PROPAGATION_DECAY': 0.70,  # ancestors get at least DECAY * child score
    
    # Ontology calibration
    'ONTOLOGY_CALIBRATION': {
        'MFO': 1.10,
        'BPO': 1.00,
        'CCO': 1.05
    },
    
    # Ensemble weights
    'BASE_BLAST_WEIGHT': 0.55, # best found so far (merge step uses hard-coded weights)
    'BASE_DL_WEIGHT': 0.35, # best found so far
    'BASE_FREQ_WEIGHT': 0.10, # best found so far
    
    # Paths
    'BASE_PATH': '/kaggle/input/cafa-6-protein-function-prediction',
    'ESM2_PATH': 'cafa-5-ems-2-embeddings-numpy',
    'PROTBERT_PATH': 'protbert-embeddings-for-cafa5',
    'T5_PATH': 't5embeds',
    'BLAST_PATH': '/kaggle/input/blast-quick-sprof-zero-pred/submission.tsv',
    'RANDOM_SEED': 42,
}

print("="*80)
print("üöÄ CAFA 6 - FIXED VERSION")
print("="*80)

# ============================================================================
# SETUP
# ============================================================================
def install(package):
    """Install *package* with pip into the interpreter running this notebook.

    Uses ``sys.executable -m pip`` so the package lands in the kernel's own
    environment. Raises ``subprocess.CalledProcessError`` if pip fails.
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", package])

# Import heavy dependencies, installing them on first use. Only ImportError
# is caught: a bare ``except:`` would also swallow KeyboardInterrupt,
# SystemExit and unrelated errors raised while importing, and then try to
# reinstall the package.
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except ImportError:
    install('torch')
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

try:
    import networkx as nx
    import obonet
except ImportError:
    install('networkx')
    install('obonet')
    import networkx as nx
    import obonet

# Seed NumPy and torch (CPU and, when present, the current CUDA device).
np.random.seed(CONFIG['RANDOM_SEED'])
torch.manual_seed(CONFIG['RANDOM_SEED'])
if torch.cuda.is_available():
    torch.cuda.manual_seed(CONFIG['RANDOM_SEED'])
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üéØ Device: {device}")

# ============================================================================
# FOCAL LOSS
# ============================================================================
class FocalLoss(nn.Module):
    """Binary focal loss computed on probabilities (not logits).

    Per element::

        loss = alpha_t * (1 - p_t) ** gamma * BCE(p, y)

    where ``p_t`` is the probability given to the true label and ``alpha_t``
    is ``alpha`` for positive targets and ``1 - alpha`` for negatives.

    Args:
        gamma: focusing exponent; 0 gives (alpha-weighted) plain BCE.
        alpha: weight of positive targets, or None to disable class weighting.
        reduction: 'mean', 'sum', or any other value for the unreduced tensor.
    """

    def __init__(self, gamma=2.0, alpha=0.25, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        bce = F.binary_cross_entropy(inputs, targets, reduction='none')

        # Probability the model assigned to each element's correct label.
        prob_true = targets * inputs + (1 - targets) * (1 - inputs)
        weight = (1 - prob_true) ** self.gamma
        if self.alpha is not None:
            weight = (self.alpha * targets + (1 - self.alpha) * (1 - targets)) * weight
        loss = weight * bce

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss

# ============================================================================
# LOAD GO ONTOLOGY
# ============================================================================
print("\n[1/12] Loading GO ontology...")
BASE = Path(CONFIG['BASE_PATH'])
TRAIN_DIR = BASE / 'Train'

go_graph = None
ancestor_cache = {}


def get_ancestors(term_id, graph):
    """Return the set of nodes reachable from *term_id* along edge direction.

    obonet documents its edges as pointing from child term to parent term,
    so networkx "descendants" are the GO ancestors of *term_id*, over every
    relationship type present in *graph*. A term that is not in the graph
    yields an empty set. Defined unconditionally, so it exists even when the
    ontology fails to load.
    """
    try:
        return set(nx.descendants(graph, term_id))
    except nx.NetworkXError:
        # nx.descendants raises this when term_id is not a node of graph.
        return set()


if CONFIG['USE_GO_PROPAGATION']:
    try:
        go_graph = obonet.read_obo(TRAIN_DIR / 'go-basic.obo')
        print(f"   ‚úì Loaded {len(go_graph):,} GO terms")
    except Exception as e:
        # Best-effort: without the ontology the pipeline continues with
        # propagation disabled.
        print(f"   ‚ö†Ô∏è Could not load GO graph: {e}")
        CONFIG['USE_GO_PROPAGATION'] = False

# ============================================================================
# LOAD IA WEIGHTS
# ============================================================================
print("\n[2/12] Loading IA weights...")
ia_weights_dict = {}

try:
    # IA.txt is the preferred file name; IA.tsv is used when it is absent.
    ia_file = BASE / ('IA.txt' if (BASE / 'IA.txt').exists() else 'IA.tsv')

    # Headerless, tab-separated: GO term id, information-accretion weight.
    ia_df = pd.read_csv(ia_file, sep='\t', header=None, names=['term', 'ia_weight'])
    ia_weights_dict = dict(zip(ia_df['term'], ia_df['ia_weight']))

    ia_values = list(ia_weights_dict.values())
    print(f"   Loaded IA weights for {len(ia_weights_dict):,} terms")
    print(f"   Mean IA: {np.mean(ia_values):.2f}")
    print(f"   Max IA: {np.max(ia_values):.2f}")

    # Terms whose IA reaches the threshold count as "high-IA".
    high_ia_terms = [
        term
        for term, weight in ia_weights_dict.items()
        if weight >= CONFIG['HIGH_IA_THRESHOLD']
    ]
    print(f"   High-IA terms: {len(high_ia_terms):,}")

except Exception as e:
    # Any failure (missing file, parse error, empty table -> np.max raises)
    # falls back to uniform per-term weights and skips IA-based selection.
    print(f"   Could not load IA weights: {e}")
    print(f"   ‚Üí Using uniform weights")
    CONFIG['USE_IA_WEIGHTS'] = False
    high_ia_terms = []

# ============================================================================
# TWO-STAGE TERM SELECTION
# ============================================================================
print("\n[3/12] Loading annotations with two-stage selection...")
# NOTE(review): header=None means a header row, if the file has one, is read
# as a data row; its pseudo-term then occurs once and is dropped by the
# MIN_FREQ_RARE filter below.
train_terms_df = pd.read_csv(TRAIN_DIR / 'train_terms.tsv', sep='\t', 
                              header=None, names=['protein', 'term', 'aspect'])

term_freq = train_terms_df['term'].value_counts()
print(f"   Total unique terms: {len(term_freq):,}")

# Stage 1: the most frequently annotated terms (value_counts sorts descending).
frequent_terms = term_freq[term_freq >= CONFIG['MIN_FREQ_COMMON']].index.tolist()
frequent_terms = frequent_terms[:CONFIG['TOP_K_FREQUENT']]
print(f"   üìä Stage 1 - Frequent terms: {len(frequent_terms):,}")

# Stage 2: rarer terms, ranked by information accretion (IA).
rare_candidates = term_freq[
    (term_freq >= CONFIG['MIN_FREQ_RARE']) & 
    (term_freq < CONFIG['MIN_FREQ_COMMON'])
].index.tolist()

if CONFIG['USE_IA_WEIGHTS'] and high_ia_terms:
    rare_with_ia = [(t, ia_weights_dict[t]) for t in rare_candidates
                    if t in ia_weights_dict]
    rare_with_ia.sort(key=lambda x: x[1], reverse=True)
    rare_high_ia = [t for t, _ in rare_with_ia[:CONFIG['TOP_K_RARE']]]
    print(f" Stage 2 - High-IA rare terms: {len(rare_high_ia):,}")
else:
    rare_high_ia = []
    print(f" Stage 2 - High-IA rare terms: 0 (no IA data)")

# Combine, dropping duplicates while keeping a deterministic order. The old
# list(set(...)) order depended on str hashing, which Python randomises per
# process (PYTHONHASHSEED), so term indices -- and therefore the model's
# output columns -- changed between runs despite the fixed seeds.
top_terms = list(dict.fromkeys(frequent_terms + rare_high_ia))

print(f" Total selected terms: {len(top_terms):,}")
# Build the ancestor cache now that the selected terms are known.
if CONFIG['USE_GO_PROPAGATION'] and go_graph:
    print("   Building ancestor cache for selected terms...")
    for term in tqdm(top_terms, desc="   Caching", leave=False):
        if term in go_graph:
            ancestor_cache[term] = get_ancestors(term, go_graph)
    print(f"   ‚úì Cached {len(ancestor_cache):,}/{len(top_terms):,} term ancestors")

# Keep only annotations of the selected terms.
train_terms_df = train_terms_df[train_terms_df['term'].isin(top_terms)]

# Create mappings
protein_to_terms = train_terms_df.groupby('protein')['term'].apply(list).to_dict()
# If a term appears under several aspects, the last row wins.
term_to_aspect = dict(zip(train_terms_df['term'], train_terms_df['aspect']))
term_to_idx = {term: idx for idx, term in enumerate(top_terms)}
idx_to_term = {idx: term for term, idx in term_to_idx.items()}

print(f" {len(protein_to_terms):,} proteins with annotations")

# Per-term loss weights: log1p(IA), normalised to mean 1, clipped to
# [0.1, 10]. Terms without an IA entry use IA = 1.0.
ia_weight_vector = np.ones(len(top_terms))

if CONFIG['USE_IA_WEIGHTS']:
    raw_ia = np.array([ia_weights_dict.get(term, 1.0) for term in top_terms], dtype=float)
    log_ia = np.log1p(raw_ia)
    mean_log_ia = log_ia.mean() if log_ia.size else 0.0
    if mean_log_ia > 0:
        ia_weight_vector = log_ia / mean_log_ia
    else:
        # All-zero (or non-positive mean) IA would make the normalisation
        # produce NaN/inf weights; fall back to uniform weights instead.
        ia_weight_vector = np.ones(len(top_terms))
    ia_weight_vector = np.clip(ia_weight_vector, 0.1, 10.0)
    
    print(f"IA weights (log-scaled): mean={ia_weight_vector.mean():.2f}, "
          f"range=[{ia_weight_vector.min():.2f}, {ia_weight_vector.max():.2f}]")

# ============================================================================
# CO-OCCURRENCE MATRIX
# ============================================================================
print("\n[4/12] Building co-occurrence matrix...")

# Row-normalised term co-occurrence counts over annotated training proteins.
# NOTE(review): the matrix is later passed to the model as `cooccur_matrix`,
# but AdvancedProteinClassifier.forward does not use that argument.
if CONFIG['USE_COOCCURRENCE']:
    n_selected = len(top_terms)
    cooccur_matrix = np.zeros((n_selected, n_selected))
    
    for terms in tqdm(protein_to_terms.values(), desc="   Computing", leave=False):
        term_indices = np.fromiter(
            (term_to_idx[t] for t in terms if t in term_to_idx), dtype=np.intp
        )
        if term_indices.size < 2:
            continue
        # Count every ordered pair of this protein's terms. np.add.at (unlike
        # fancy-indexed +=) accumulates repeated index pairs correctly.
        np.add.at(cooccur_matrix, np.ix_(term_indices, term_indices), 1)
    
    # A term does not co-occur with itself.
    np.fill_diagonal(cooccur_matrix, 0)
    
    row_sums = cooccur_matrix.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1
    cooccur_matrix = cooccur_matrix / row_sums
    
    print(f"  Co-occurrence matrix: {cooccur_matrix.shape}")
    # Share of non-zero entries, i.e. density (it was mislabelled "Sparsity").
    print(f"  Density (non-zero): {(cooccur_matrix > 0).mean() * 100:.1f}%")
else:
    cooccur_matrix = None

# ============================================================================
# LOAD EMBEDDINGS WITH FALLBACK
# ============================================================================
print("\n[5/12] Loading embeddings...")

# embedding_dicts[source] = {'train': {id: vector}, 'test': {id: vector}};
# embedding_dims lists widths in the same insertion order, which is the order
# the model's fusion projections are built and fed in.
embedding_dicts = {}
embedding_dims = []


def _load_embedding_source(emb_dir, suffix):
    """Load one embedding dataset from *emb_dir*.

    Reads ``train_ids.npy``, ``train_<suffix>.npy``, ``test_ids.npy`` and
    ``test_<suffix>.npy`` (in that order). Ids are converted to ``str``.
    Returns ``(split_maps, dim)`` where ``dim`` is the train embedding width.
    Any load error propagates to the caller.

    NOTE: ``allow_pickle=True`` lets np.load unpickle objects; only use it on
    trusted datasets.
    """
    split_maps = {}
    dim = None
    for split in ('train', 'test'):
        ids = np.load(f"{emb_dir}/{split}_ids.npy", allow_pickle=True)
        embeds = np.load(f"{emb_dir}/{split}_{suffix}.npy")
        split_maps[split] = {str(pid): emb for pid, emb in zip(ids, embeds)}
        if split == 'train':
            dim = embeds.shape[1]
    return split_maps, dim


# ESM2 is mandatory: any failure here stops the notebook.
esm2_maps, esm2_dim = _load_embedding_source(
    f"/kaggle/input/{CONFIG['ESM2_PATH']}", 'embeddings'
)
embedding_dicts['esm2'] = esm2_maps
embedding_dims.append(esm2_dim)
print(f" ESM2: dim={esm2_dim}")

del esm2_maps
gc.collect()

# Optional sources are best-effort: a failure is reported and the source skipped.
if CONFIG['USE_MULTI_EMBEDDING']:
    for emb_name, path_key in [('protbert', 'PROTBERT_PATH'), ('t5', 'T5_PATH')]:
        # The T5 dataset names its arrays *_embeds.npy; the others *_embeddings.npy.
        suffix = 'embeds' if emb_name == 't5' else 'embeddings'
        try:
            split_maps, dim = _load_embedding_source(
                f"/kaggle/input/{CONFIG[path_key]}", suffix
            )
        except Exception as e:
            print(f"   ‚ö†Ô∏è Could not load {emb_name}: {e}")
            continue

        embedding_dicts[emb_name] = split_maps
        embedding_dims.append(dim)
        print(f" {emb_name.upper()}: dim={dim}")

        del split_maps
        gc.collect()

print(f"\n Total sources: {len(embedding_dicts)}")
print(f"  Embedding dims: {embedding_dims}")

# ============================================================================
# PREPARE TRAINING DATA
# ============================================================================
print("\n[6/12] Preparing training data...")

# Annotated proteins that have a training embedding in every loaded source,
# capped at MAX_TRAIN_SAMPLES (first ones in protein_to_terms order).
valid_proteins = [
    protein
    for protein in protein_to_terms
    if all(protein in embedding_dicts[src]['train'] for src in embedding_dicts)
][:CONFIG['MAX_TRAIN_SAMPLES']]
print(f" Valid proteins: {len(valid_proteins):,}")

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import StratifiedShuffleSplit

# Multi-hot targets: one column per selected term, in term_to_idx order.
mlb = MultiLabelBinarizer(classes=range(len(top_terms)))
y_labels = [
    [term_to_idx[t] for t in protein_to_terms.get(protein, []) if t in term_to_idx]
    for protein in valid_proteins
]
y_encoded = mlb.fit_transform(y_labels).astype(float)

# Stratify the split by annotation richness: <5 sparse, <15 medium, else rich.
protein_categories = [
    'sparse' if n_annot < 5 else ('medium' if n_annot < 15 else 'rich')
    for n_annot in (len(protein_to_terms.get(protein, [])) for protein in valid_proteins)
]

splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.15,
                                  random_state=CONFIG['RANDOM_SEED'])
train_idx, val_idx = next(splitter.split(valid_proteins, protein_categories))

train_proteins = [valid_proteins[i] for i in train_idx]
val_proteins = [valid_proteins[i] for i in val_idx]
y_train, y_val = y_encoded[train_idx], y_encoded[val_idx]

print(f"   Train: {len(train_proteins):,} proteins")
print(f"   Val: {len(val_proteins):,} proteins")

print("\n[7/12] Building model...")

class MultiEmbeddingFusion(nn.Module):
    """Project each embedding source to a shared width and fuse them.

    Every source has its own Linear projection to ``output_dim``. With one
    source, that projection is returned directly. With several, a learned
    scalar score per source is softmax-normalised across sources and used as
    the weights of a sum over the projections.

    Args:
        embedding_dims: input width of each source, in the order the
            embeddings are passed to ``forward``.
        output_dim: width of the fused representation.
    """

    def __init__(self, embedding_dims, output_dim):
        super().__init__()
        self.n_sources = len(embedding_dims)
        self.projections = nn.ModuleList(
            nn.Linear(in_width, output_dim) for in_width in embedding_dims
        )
        # The attention scorer (and its parameters) exists only when there
        # is more than one source to weigh.
        if self.n_sources > 1:
            self.attention = nn.Linear(output_dim, 1)

    def forward(self, embeddings):
        # A bare tensor is treated as a one-element list of sources.
        sources = embeddings if isinstance(embeddings, list) else [embeddings]
        projected = [project(x) for project, x in zip(self.projections, sources)]

        if len(projected) == 1:
            return projected[0]

        # For 2-D inputs: (batch, n_sources, output_dim).
        stacked = torch.stack(projected, dim=1)
        weights = F.softmax(self.attention(stacked).squeeze(-1), dim=1)
        return (stacked * weights.unsqueeze(-1)).sum(dim=1)

class AdvancedProteinClassifier(nn.Module):
    """Multi-label GO-term classifier over fused protein embeddings.

    Pipeline: MultiEmbeddingFusion (to ``hidden_dims[0]``) -> for each
    hidden width a Linear/BatchNorm1d/ReLU/Dropout stage -> a Linear head
    with ``num_terms`` outputs. All Linear layers (including the fusion
    projections) get Xavier-uniform weights and zero biases.

    ``use_cooccurrence`` and the ``cooccur_matrix`` forward argument are
    accepted for interface compatibility but are not used.
    """

    def __init__(self, embedding_dims, num_terms, hidden_dims,
                 dropout, use_cooccurrence=False):
        super().__init__()

        # Fused width equals the first hidden width.
        self.fusion = MultiEmbeddingFusion(embedding_dims, hidden_dims[0])

        widths = [hidden_dims[0], *hidden_dims]
        stages = []
        for width_in, width_out in zip(widths, widths[1:]):
            stages += [
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
        self.encoder = nn.Sequential(*stages)
        self.output = nn.Linear(hidden_dims[-1], num_terms)

        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, embeddings, cooccur_matrix=None, return_logits=False):
        logits = self.output(self.encoder(self.fusion(embeddings)))
        return logits if return_logits else torch.sigmoid(logits)

model = AdvancedProteinClassifier(
    embedding_dims=embedding_dims,
    num_terms=len(top_terms),
    hidden_dims=CONFIG['HIDDEN_DIMS'],
    dropout=CONFIG['DROPOUT_RATE'],
    use_cooccurrence=CONFIG['USE_COOCCURRENCE'],
)
model = model.to(device)

# Report model size and the number of embedding sources feeding it.
total_params = sum(map(torch.numel, model.parameters()))
print(f"   Parameters: {total_params:,}")
print(f"   Embedding sources: {len(embedding_dims)}")

# ============================================================================
# TRAINING
# ============================================================================
print("\n" + "="*80)
print(f"TRAINING ({CONFIG['EPOCHS']} EPOCHS)")
print("="*80)

# Element-wise focal loss; IA weighting and averaging are applied below.
criterion = FocalLoss(
    gamma=CONFIG['FOCAL_LOSS_GAMMA'],
    alpha=CONFIG['FOCAL_LOSS_ALPHA'],
    reduction='none'
)

# One weight per term column, broadcast over the batch dimension.
ia_weight_tensor = torch.FloatTensor(ia_weight_vector).to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['LEARNING_RATE'],
    weight_decay=CONFIG['WEIGHT_DECAY']
)

def get_lr_lambda(epoch):
    """LR multiplier: linear warm-up over WARMUP_EPOCHS, then cosine decay."""
    warmup = CONFIG['WARMUP_EPOCHS']
    if epoch < warmup:
        return (epoch + 1) / warmup
    # max(..., 1) avoids ZeroDivisionError when EPOCHS == WARMUP_EPOCHS
    # (LambdaLR evaluates epoch == EPOCHS after the final scheduler.step()).
    decay_epochs = max(CONFIG['EPOCHS'] - warmup, 1)
    progress = (epoch - warmup) / decay_epochs
    return 0.5 * (1 + np.cos(np.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr_lambda)

best_val_loss = float('inf')
best_state = None
cooccur_tensor = torch.FloatTensor(cooccur_matrix).to(device) if cooccur_matrix is not None else None


def _embedding_batch(proteins, split):
    """Return model input for *proteins* from embedding_dicts[...][split].

    One tensor when a single source is loaded, otherwise a list of tensors in
    embedding_dicts order (the order the fusion projections were built in).
    np.stack avoids the very slow torch.FloatTensor(list-of-ndarrays) path.
    """
    per_source = [
        torch.from_numpy(np.stack([embedding_dicts[src][split][p] for p in proteins]))
        .float().to(device)
        for src in embedding_dicts
    ]
    return per_source[0] if len(per_source) == 1 else per_source


for epoch in range(CONFIG['EPOCHS']):
    # Training
    model.train()
    indices = np.random.permutation(len(train_proteins))
    epoch_loss = 0.0
    n_batches = 0
    
    for start in range(0, len(indices), CONFIG['BATCH_SIZE']):
        batch_idx = indices[start:start + CONFIG['BATCH_SIZE']]
        # BatchNorm1d cannot normalise a single sample in train mode.
        if len(batch_idx) < 2:
            continue
        embeddings = _embedding_batch([train_proteins[j] for j in batch_idx], 'train')
        y_batch = torch.FloatTensor(y_train[batch_idx]).to(device)
        
        optimizer.zero_grad()
        outputs = model(embeddings, cooccur_tensor)
        weighted_loss = (criterion(outputs, y_batch) * ia_weight_tensor).mean()
        
        weighted_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), CONFIG['GRAD_CLIP'])
        optimizer.step()
        
        epoch_loss += weighted_loss.item()
        n_batches += 1
    
    # Validation (validation proteins come from the train embedding split).
    model.eval()
    val_loss = 0.0
    val_batches = 0
    
    with torch.no_grad():
        for start in range(0, len(val_proteins), CONFIG['BATCH_SIZE']):
            stop = start + CONFIG['BATCH_SIZE']
            embeddings = _embedding_batch(val_proteins[start:stop], 'train')
            y_batch = torch.FloatTensor(y_val[start:stop]).to(device)
            
            outputs = model(embeddings, cooccur_tensor)
            val_loss += (criterion(outputs, y_batch) * ia_weight_tensor).mean().item()
            val_batches += 1
    
    train_loss_avg = epoch_loss / max(n_batches, 1)
    # With no validation data, never treat an epoch as "best".
    val_loss_avg = val_loss / val_batches if val_batches else float('inf')
    current_lr = optimizer.param_groups[0]['lr']
    
    scheduler.step()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    
    if val_loss_avg < best_val_loss:
        best_val_loss = val_loss_avg
        # Snapshot the weights so inference uses the best epoch rather than
        # the last one (previously the best loss was only printed).
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        marker = "‚≠ê"
    else:
        marker = ""
    
    if (epoch + 1) % 5 == 0 or marker:
        print(f"Epoch {epoch+1:2d}: Train={train_loss_avg:.4f}, Val={val_loss_avg:.4f}, "
              f"LR={current_lr:.6f} {marker}")
    
    gc.collect()

if best_state is not None:
    model.load_state_dict(best_state)
    del best_state

print(f"\n‚úÖ Best Val Loss: {best_val_loss:.4f}")

# ============================================================================
# PREDICTIONS WITH ALL IMPROVEMENTS
# ============================================================================
print("\n" + "="*80)
print("GENERATING PREDICTIONS")
print("="*80)

model.eval()

# Test proteins must have an embedding in every loaded source, because the
# model consumes all of them. Previously ids taken from the first source
# alone raised KeyError mid-inference if another source lacked them.
source_names = list(embedding_dicts)
test_protein_ids = [
    pid for pid in embedding_dicts[source_names[0]]['test']
    if all(pid in embedding_dicts[src]['test'] for src in source_names[1:])
]
print(f"Total test proteins: {len(test_protein_ids):,}")

# [Step 1/5] Generate DL predictions with multi-embedding
dl_predictions = {}

print("\n[Step 1/5] Generating multi-embedding DL predictions...")
with torch.no_grad():
    batch_starts = range(0, len(test_protein_ids), CONFIG['BATCH_SIZE'])
    for batch_no, start in enumerate(tqdm(batch_starts, desc="DL Inference")):
        batch_ids = test_protein_ids[start:start + CONFIG['BATCH_SIZE']]
        
        # Same input layout as training: a single tensor for one source,
        # otherwise a list in embedding_dicts order.
        per_source = [
            torch.from_numpy(np.stack([embedding_dicts[src]['test'][p] for p in batch_ids]))
            .float().to(device)
            for src in source_names
        ]
        embeddings = per_source[0] if len(per_source) == 1 else per_source
        
        logits = model(embeddings, cooccur_tensor, return_logits=True)
        # Temperature-scaled sigmoid: T < 1 pushes scores toward 0/1.
        outputs = torch.sigmoid(logits / CONFIG['TEMPERATURE']).cpu().numpy()
        
        dl_predictions.update(zip(batch_ids, outputs))
        
        del embeddings, per_source, outputs, logits
        # Periodic GC. The old `start % 5000 == 0` test only fired every
        # 40,000 ids with a batch size of 64.
        if batch_no % 100 == 0:
            gc.collect()

print(f"‚úì Generated DL predictions for {len(dl_predictions):,} proteins")

# [Step 2/5] Load BLAST predictions
# blast_dict[protein_id] -> dense score vector over top_terms (max confidence
# per term, floored at 0), only for proteins with >=1 row for a selected term.
blast_dict = {}

print("\n[Step 2/5] Loading BLAST predictions...")
if os.path.exists(CONFIG['BLAST_PATH']):
    # Filter each chunk to selected terms with vectorised pandas operations
    # instead of DataFrame.iterrows(), which is very slow on large files.
    kept_chunks = []
    for chunk in pd.read_csv(CONFIG['BLAST_PATH'], sep='\t', header=None,
                             names=['Id', 'GO term', 'Confidence'], 
                             chunksize=100000):
        term_idx = chunk['GO term'].map(term_to_idx)
        mask = term_idx.notna()
        if mask.any():
            kept_chunks.append(pd.DataFrame({
                'Id': chunk.loc[mask, 'Id'],
                'idx': term_idx[mask].astype(np.int64),
                'conf': chunk.loc[mask, 'Confidence'].astype(float),
            }))
    
    if kept_chunks:
        blast_rows = pd.concat(kept_chunks, ignore_index=True)
        # groupby-max skips NaN; np.fmax then maps all-NaN groups and
        # negative scores to 0, exactly like the old running max from 0.
        best = blast_rows.groupby(['Id', 'idx'], sort=False)['conf'].max()
        for pid, per_term in best.groupby(level=0, sort=False):
            vec = np.zeros(len(top_terms))
            vec[per_term.index.get_level_values(1).to_numpy()] = np.fmax(
                per_term.to_numpy(), 0.0
            )
            blast_dict[pid] = vec
        del blast_rows, best
    del kept_chunks
    print(f"  ‚úì Loaded BLAST for {len(blast_dict):,} proteins")
else:
    print("  ‚ö†Ô∏è BLAST file not found")

# [Step 3/5] Compute frequency baseline
print("\n[Step 3/5] Computing frequency baseline...")
total_proteins = len(protein_to_terms)

# Prior score per selected term: its annotation count divided by the number
# of annotated proteins, capped at 0.5 so common terms cannot dominate.
term_frequencies = np.array(
    [min(term_freq.get(term, 0) / total_proteins, 0.5) for term in top_terms],
    dtype=float,
)

print(f"‚úì Frequency baseline computed")

# [Step 4/5] Adaptive ensemble merging
print("\n[Step 4/5] Adaptive ensemble merging...")

# Aspect codes may be single letters or three-letter ontology names.
_ASPECT_TO_ONTOLOGY = {'F': 'MFO', 'P': 'BPO', 'C': 'CCO'}

def get_knowledge_level(protein_id, train_terms_dict, term_to_aspect_dict):
    """Classify how much training annotation *protein_id* already has.

    Counts distinct sub-ontologies among the protein's training terms; 'F'
    and 'MFO' (and likewise P/BPO, C/CCO) count as the same ontology. Terms
    with no known aspect are ignored (they used to add a phantom '' aspect).

    Returns:
        'no'      -- protein absent, or none of its terms has a known aspect
        'limited' -- annotated in one or two sub-ontologies
        'partial' -- annotated in three (or more) distinct aspect codes
    """
    terms = train_terms_dict.get(protein_id)
    if not terms:
        return 'no'
    
    ontologies = set()
    for term in terms:
        aspect = term_to_aspect_dict.get(term)
        if isinstance(aspect, str) and aspect:
            ontologies.add(_ASPECT_TO_ONTOLOGY.get(aspect, aspect))
    
    if not ontologies:
        return 'no'
    return 'limited' if len(ontologies) < 3 else 'partial'

merged_predictions = {}
knowledge_stats = Counter()

# (dl, freq, blast) weights per knowledge level; any level other than 'no'
# or 'limited' uses the 'partial' weights.
_MERGE_WEIGHTS = {
    'no': (0.35, 0.15, 0.50),
    'limited': (0.40, 0.10, 0.50),
    'partial': (0.50, 0.05, 0.45),
}
_no_blast_scores = np.zeros(len(top_terms))

# DL-predicted proteins first (original order), then proteins that only have
# BLAST hits (no embeddings). The latter used to be dropped from the
# submission entirely; they now get a BLAST + frequency blend renormalised to
# the same total weight as the three-way blend.
blast_only_ids = [pid for pid in blast_dict if pid not in dl_predictions]

for pid in tqdm(list(dl_predictions) + blast_only_ids, desc="Merging"):
    knowledge = get_knowledge_level(pid, protein_to_terms, term_to_aspect)
    knowledge_stats[knowledge] += 1
    
    dl_w, freq_w, blast_w = _MERGE_WEIGHTS.get(knowledge, _MERGE_WEIGHTS['partial'])
    blast_probs = blast_dict.get(pid, _no_blast_scores)
    dl_probs = dl_predictions.get(pid)
    
    if dl_probs is None:
        merged = (term_frequencies * freq_w + blast_probs * blast_w) / (freq_w + blast_w)
    else:
        merged = (dl_probs * dl_w + 
                  term_frequencies * freq_w + 
                  blast_probs * blast_w)
    
    merged_predictions[pid] = merged

print(f"‚úì Merged predictions for {len(merged_predictions):,} proteins")
if blast_only_ids:
    print(f"  (includes {len(blast_only_ids):,} BLAST-only proteins without embeddings)")
print("\n  Knowledge distribution:")
n_merged = max(len(merged_predictions), 1)
for level, count in sorted(knowledge_stats.items()):
    pct = count / n_merged * 100
    print(f"    {level:8s}: {count:6,} ({pct:5.1f}%)")

# [Step 5/5] GO propagation + Ontology calibration
print("\n[Step 5/5] GO propagation + Ontology calibration...")

# Only predictions scoring above this are propagated to ancestors (was 0.3).
PROPAGATION_MIN_CONF = 0.1

# Ancestor column indices per term index, computed once instead of once per
# protein. Ancestors outside the selected term set are ignored.
ancestor_idx_by_term = {}
if CONFIG['USE_GO_PROPAGATION'] and go_graph and ancestor_cache:
    for term_idx, term in idx_to_term.items():
        anc = [term_to_idx[a] for a in ancestor_cache.get(term, ()) if a in term_to_idx]
        if anc:
            ancestor_idx_by_term[term_idx] = np.array(anc, dtype=np.intp)

# Calibration factor per term by sub-ontology. Letter codes map to names;
# names such as 'BPO' are now used as-is (the old lookup sent them to MFO).
# Terms with no aspect, or an unknown one, use the MFO factor as before.
_aspect_to_ontology = {'F': 'MFO', 'P': 'BPO', 'C': 'CCO'}
calibration_vector = np.ones(len(top_terms))
for term_idx, term in idx_to_term.items():
    aspect = term_to_aspect.get(term, 'F')
    ontology = _aspect_to_ontology.get(aspect, aspect)
    if ontology not in CONFIG['ONTOLOGY_CALIBRATION']:
        ontology = 'MFO'
    calibration_vector[term_idx] = CONFIG['ONTOLOGY_CALIBRATION'][ontology]

final_predictions = {}
decay = CONFIG['PROPAGATION_DECAY']
calibration_min_conf = CONFIG['MIN_CONFIDENCE']

for pid, probs in tqdm(merged_predictions.items(), desc="Processing"):
    new_probs = probs.copy()
    
    # GO propagation: every ancestor of a confident term gets at least
    # decay * score. Selection and scores use the un-propagated `probs`.
    if ancestor_idx_by_term:
        for term_idx in np.flatnonzero(probs > PROPAGATION_MIN_CONF):
            anc = ancestor_idx_by_term.get(term_idx)
            if anc is not None:
                new_probs[anc] = np.maximum(new_probs[anc], probs[term_idx] * decay)
    
    # Ontology-specific calibration, applied only above MIN_CONFIDENCE.
    above = new_probs > calibration_min_conf
    new_probs[above] *= calibration_vector[above]
    
    final_predictions[pid] = new_probs

print(f"‚úì Final predictions ready for {len(final_predictions):,} proteins")




# ============================================================================
# ADAPTIVE CUTOFF HELPERS - added just before the WRITE SUBMISSION section
# ============================================================================

def get_adaptive_thresholds(protein_id, knowledge_level, blast_coverage, probs):
    """Pick ``(min_conf, max_preds)`` for one protein's submission rows.

    The baseline depends on the knowledge level. For 'no', it also depends
    on whether ``blast_coverage`` exceeds 50. The baseline is then adjusted
    by how many scores exceed 0.1: more than 100 widens the cap by 500
    (max 4000); fewer than 20 caps it at 2000 and raises min_conf to at
    least 0.015.

    Args:
        protein_id: unused; kept for call-site compatibility.
        knowledge_level: 'no', 'limited', or anything else (treated as partial).
        blast_coverage: count of terms with BLAST support (computed by caller).
        probs: array of per-term scores.

    Returns:
        Tuple ``(min_conf, max_preds)``.
    """
    if knowledge_level == 'no':
        broad_blast = blast_coverage > 50
        min_conf, max_preds = (0.008, 3000) if broad_blast else (0.015, 1500)
    elif knowledge_level == 'limited':
        min_conf, max_preds = 0.010, 2500
    else:  # partial
        min_conf, max_preds = 0.008, 3000

    # Adjust by how many predictions are already confident.
    n_confident = (probs > 0.1).sum()
    if n_confident > 100:
        max_preds = min(max_preds + 500, 4000)
    elif n_confident < 20:
        max_preds = min(max_preds, 2000)
        min_conf = max(min_conf, 0.015)

    return min_conf, max_preds

def should_include_prediction(score, position, prev_score, min_conf, max_preds):
    """Decide whether the prediction at rank *position* is written out.

    It is rejected when any of the following holds:

    * it is below ``min_conf``, or at/after rank ``max_preds``;
    * it is weak and deep in the ranking (rank > 2000 with score < 0.02, or
      rank > 3000 with score < 0.05);
    * past rank 500, it drops below 30% of ``prev_score`` (a score "cliff").
    """
    if score < min_conf or position >= max_preds:
        return False

    deep_and_weak = (position > 2000 and score < 0.02) or (position > 3000 and score < 0.05)
    if deep_and_weak:
        return False

    is_cliff = position > 500 and prev_score > 0 and score < prev_score * 0.3
    return not is_cliff

# ============================================================================
# WRITE SUBMISSION WITH ADAPTIVE CUTOFF
# ============================================================================
print("\n" + "="*80)
print("WRITING SUBMISSION WITH ADAPTIVE CUTOFF")
print("="*80)

n_predictions = 0
protein_pred_counts = []
# Why each protein's ranked list stopped (one count per protein that stopped).
cutoff_stats = {'min_conf': 0, 'max_preds': 0, 'quality': 0}
_empty_blast_scores = np.zeros(len(top_terms))

# Rows: protein id, GO term, score (capped at 0.999, 3 decimals).
with open('submission.tsv', 'w', encoding='utf-8', newline='\n') as f:
    for pid, probs in tqdm(final_predictions.items(), desc="Writing"):
        knowledge = get_knowledge_level(pid, protein_to_terms, term_to_aspect)
        blast_coverage = (blast_dict.get(pid, _empty_blast_scores) > 0.01).sum()
        min_conf, max_preds = get_adaptive_thresholds(
            pid, knowledge, blast_coverage, probs
        )
        
        # Walk terms from highest to lowest score; stop at the first rejection.
        rows = []
        prev_score = 1.0
        for position, idx in enumerate(np.argsort(probs)[::-1]):
            score = probs[idx]
            if not should_include_prediction(score, position, prev_score,
                                             min_conf, max_preds):
                if score < min_conf:
                    cutoff_stats['min_conf'] += 1
                elif position >= max_preds:
                    cutoff_stats['max_preds'] += 1
                else:
                    cutoff_stats['quality'] += 1
                break
            rows.append(f"{pid}\t{idx_to_term[idx]}\t{min(score, 0.999):.3f}\n")
            prev_score = score
        
        # One buffered write per protein instead of one per row.
        f.writelines(rows)
        n_predictions += len(rows)
        if rows:
            protein_pred_counts.append(len(rows))

print(f"\n‚úÖ Adaptive cutoff statistics:")
print(f"  Stopped by min_confidence: {cutoff_stats['min_conf']:,}")
print(f"  Stopped by max_preds: {cutoff_stats['max_preds']:,}")
print(f"  Stopped by quality filter: {cutoff_stats['quality']:,}")
# These totals were accumulated before but never reported.
print(f"  Total predictions written: {n_predictions:,}")
if protein_pred_counts:
    print(f"  Proteins with predictions: {len(protein_pred_counts):,} "
          f"(mean {np.mean(protein_pred_counts):.1f} per protein)")


print("\n‚úÖ ALL DONE!")

🚀 CAFA 6 - FIXED VERSION
🎯 Device: cuda

[1/12] Loading GO ontology...
   ✓ Loaded 40,122 GO terms

[2/12] Loading IA weights...
   ✓ Loaded IA weights for 40,122 terms
   📊 Mean IA: 2.65
   📊 Max IA: 15.88
   🔥 High-IA terms: 15,614

[3/12] Loading annotations with two-stage selection...
   Total unique terms: 26,126
   📊 Stage 1 - Frequent terms: 1,000
   📊 Stage 2 - High-IA rare terms: 1,200
   ✅ Total selected terms: 2,200
   Building ancestor cache for selected terms...


   Caching:   0%|          | 0/2200 [00:00<?, ?it/s]

   ✓ Cached 2,200/2,200 term ancestors
   ✓ 75,281 proteins with annotations
   ✓ IA weights (log-scaled): mean=1.01, range=[0.10, 1.80]

[4/12] Building co-occurrence matrix...


   Computing:   0%|          | 0/75281 [00:00<?, ?it/s]

   ✓ Co-occurrence matrix: (2200, 2200)
   📊 Sparsity: 9.5%

[5/12] Loading embeddings...
   ✓ ESM2: dim=1280
   ✓ PROTBERT: dim=1024
   ✓ T5: dim=1024

   ✅ Total sources: 3
   📊 Embedding dims: [1280, 1024, 1024]

[6/12] Preparing training data...
   ✓ Valid proteins: 72,804
   ✓ Train: 61,883 proteins
   ✓ Val: 10,921 proteins

[7/12] Building model...
   ✓ Parameters: 2,666,905
   ✓ Embedding sources: 3

TRAINING (35 EPOCHS)
Epoch  1: Train=0.0040, Val=0.0005, LR=0.000200 ⭐
Epoch  2: Train=0.0005, Val=0.0004, LR=0.000400 ⭐
Epoch  3: Train=0.0004, Val=0.0004, LR=0.000600 ⭐
Epoch  4: Train=0.0004, Val=0.0003, LR=0.000800 ⭐
Epoch  5: Train=0.0004, Val=0.0003, LR=0.001000 ⭐
Epoch  6: Train=0.0003, Val=0.0003, LR=0.001000 ⭐
Epoch  7: Train=0.0003, Val=0.0003, LR=0.000997 ⭐
Epoch  8: Train=0.0003, Val=0.0003, LR=0.000989 ⭐
Epoch  9: Train=0.0003, Val=0.0003, LR=0.000976 ⭐
Epoch 10: Train=0.0003, Val=0.0003, LR=0.000957 ⭐
Epoch 11: Train=0.000

DL Inference:   0%|          | 0/2217 [00:00<?, ?it/s]

✓ Generated DL predictions for 141,864 proteins

[Step 2/5] Loading BLAST predictions...
  ✓ Loaded BLAST for 200,797 proteins

[Step 3/5] Computing frequency baseline...
✓ Frequency baseline computed

[Step 4/5] Adaptive ensemble merging...


Merging:   0%|          | 0/141864 [00:00<?, ?it/s]

✓ Merged predictions for 141,864 proteins

  Knowledge distribution:
    limited : 48,452 ( 34.2%)
    no      : 71,517 ( 50.4%)
    partial : 21,895 ( 15.4%)

[Step 5/5] GO propagation + Ontology calibration...


Processing:   0%|          | 0/141864 [00:00<?, ?it/s]

✓ Final predictions ready for 141,864 proteins

WRITING SUBMISSION WITH ADAPTIVE CUTOFF


Writing:   0%|          | 0/141864 [00:00<?, ?it/s]


✅ Adaptive cutoff statistics:
  Stopped by min_confidence: 141,864
  Stopped by max_preds: 0
  Stopped by quality filter: 0

✅ ALL DONE!
