In [None]:
! pip install -q py-tree
! python -m py_tree /kaggle/input/

# CAFA 6 PROTEIN FUNCTION PREDICTION

In [None]:
# ⚙️ CONFIGURATION - ADJUST THIS FOR SPEED VS ACCURACY
# SAMPLE_PERCENT: percentage in (0, 100] of training proteins to keep; values
# below 100 enable random subsampling further down in this cell.
SAMPLE_PERCENT = 100  # Use 100% of data
# QUICK_MODE: in the code visible here this flag is only echoed in the banner.
# NOTE(review): its intended effect (e.g. skipping expensive features) is not
# implemented in this cell — confirm before relying on it.
QUICK_MODE = True

# Fail fast on an invalid percentage instead of erroring inside pandas' sample(frac=...).
if not 0 < SAMPLE_PERCENT <= 100:
    raise ValueError(f"SAMPLE_PERCENT must be in (0, 100], got {SAMPLE_PERCENT}")

# Package Installation
import subprocess
import sys

def install(package):
    """Install *package* quietly into the running interpreter's environment with pip.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "-q", package]
    subprocess.run(pip_command, check=True)

print("Installing required packages...")
# Import each optional dependency, installing it with pip only when it is missing.
# Catch ImportError specifically: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and unrelated errors raised during import.
try:
    import obonet
except ImportError:
    install('obonet')
    import obonet

try:
    from Bio import SeqIO
except ImportError:
    install('biopython')
    from Bio import SeqIO

# Core Imports
import pandas as pd
import numpy as np
from pathlib import Path
from collections import defaultdict, Counter
import warnings
warnings.filterwarnings('ignore')

import networkx as nx
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression

# Visualization imports
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import matplotlib.patches as mpatches
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

# Run banner. The emoji were previously stored mis-encoded (UTF-8 read as Mac Roman)
# and printed as garbage; they are written here as the intended characters.
print("="*80)
print("CAFA 6 PROTEIN FUNCTION PREDICTION - ENHANCED STARTER")
print(f"📊 SAMPLE MODE: {SAMPLE_PERCENT}% of data")
print(f"⚡ QUICK MODE: {'ON' if QUICK_MODE else 'OFF'}")
print("="*80)

# ============================================================================
# 1. DEFINE PATHS
# ============================================================================
BASE = Path('/kaggle/input/cafa-6-protein-function-prediction')
TRAIN_DIR = BASE / 'Train'
TEST_DIR = BASE / 'Test'

# ============================================================================
# 2. LOAD GO ONTOLOGY (WITH HIERARCHY ANALYSIS)
# ============================================================================
print("\n[1/9] Loading GO ontology...")
go_graph = obonet.read_obo(TRAIN_DIR / 'go-basic.obo')
print(f"   ✓ Loaded {len(go_graph)} GO terms")

# GO namespace -> ontology code used throughout this notebook.
NAMESPACE_TO_ONT = {
    'biological_process': 'BPO',
    'cellular_component': 'CCO',
    'molecular_function': 'MFO',
}

# Map each term to its ontology code and record its human-readable name.
# Terms with no namespace, or an unrecognised one, are left out of term_to_ont.
term_to_ont = {}
term_names = {}
for term_id, term_data in go_graph.nodes(data=True):
    ont = NAMESPACE_TO_ONT.get(term_data.get('namespace'))
    if ont is not None:
        term_to_ont[term_id] = ont
    if 'name' in term_data:
        term_names[term_id] = term_data['name']

ont_counts = pd.Series(term_to_ont).value_counts()
print(f"   ✓ Ontology breakdown: MF={ont_counts.get('MFO',0)}, BP={ont_counts.get('BPO',0)}, CC={ont_counts.get('CCO',0)}")

# Analyze GO hierarchy depth (sample for speed)
def get_term_depth(graph, term_id, roots=('GO:0008150', 'GO:0005575', 'GO:0003674')):
    """Return the depth of *term_id* in the GO hierarchy.

    Depth is the shortest-path length (in edges) from ``term_id`` to a root,
    following edge direction. This assumes edges point child -> parent, so paths
    run upwards from the term. If several roots are reachable, the largest of
    their shortest distances is returned.

    Args:
        graph: networkx directed graph of GO terms.
        term_id: GO identifier whose depth is requested.
        roots: root term IDs to measure against; defaults to the biological
            process, cellular component and molecular function roots.

    Returns:
        The depth as an int, or 0 when no root is reachable or ``term_id`` is
        not in the graph.
    """
    depths = []
    for root in roots:
        try:
            # One BFS per root; NetworkXNoPath replaces a separate has_path() pass.
            depths.append(nx.shortest_path_length(graph, term_id, root))
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # Skip only this root, so one unreachable or missing root does not
            # zero out the depth for every term.
            continue
    return max(depths, default=0)

print("   Computing GO hierarchy depths...")
# Depths are computed only for the first 1000 mapped terms, to keep this step fast.
sample_terms_for_depth = list(term_to_ont.keys())[:1000]
term_depths = {term: get_term_depth(go_graph, term) for term in sample_terms_for_depth}

# Fixed ontology order shared by the bar chart, the network node colours and the legend.
ONT_ORDER = ['MFO', 'BPO', 'CCO']
ONT_LABELS = ['Molecular Function', 'Biological Process', 'Cellular Component']
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1']  # one colour per entry of ONT_ORDER

# Visualize ontology with enhanced graphics
fig = plt.figure(figsize=(18, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Main ontology distribution. value_counts() orders ontologies by frequency, so
# re-index into ONT_ORDER first; otherwise the fixed tick labels and colours
# end up attached to the wrong ontology.
ax1 = fig.add_subplot(gs[0, :2])
ont_plot = ont_counts.reindex(ONT_ORDER, fill_value=0)
total_terms = ont_plot.sum()
bars = ax1.bar(range(len(ont_plot)), ont_plot.values, color=colors,
               edgecolor='black', linewidth=2, alpha=0.8)
ax1.set_xticks(range(len(ont_plot)))
ax1.set_xticklabels(ONT_LABELS, rotation=0, fontsize=11, fontweight='bold')
ax1.set_title('GO Term Distribution by Ontology', fontsize=14, fontweight='bold', pad=20)
ax1.set_ylabel('Number of Terms', fontsize=12, fontweight='bold')
ax1.grid(axis='y', alpha=0.3)
for v, bar in zip(ont_plot.values, bars):
    share = v / total_terms * 100 if total_terms else 0.0
    ax1.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
             f'{v:,}\n({share:.1f}%)',
             ha='center', va='bottom', fontweight='bold', fontsize=11)

# Hierarchy depth distribution
ax2 = fig.add_subplot(gs[0, 2])
depth_values = list(term_depths.values())
ax2.hist(depth_values, bins=20, color='#A8E6CF', edgecolor='black', alpha=0.7)
ax2.set_title('GO Term Depth\nDistribution', fontsize=11, fontweight='bold')
ax2.set_xlabel('Hierarchy Depth', fontsize=10)
ax2.set_ylabel('Count', fontsize=10)
ax2.axvline(np.mean(depth_values), color='red', linestyle='--', linewidth=2,
            label=f'Mean: {np.mean(depth_values):.1f}')
ax2.legend(fontsize=9)

# Network visualization (sample of GO graph)
ax3 = fig.add_subplot(gs[1:, :])
sample_terms = list(term_to_ont.keys())[:50]
subgraph = go_graph.subgraph(sample_terms)
pos = nx.spring_layout(subgraph, k=0.5, iterations=50, seed=42)
node_colors = [colors[ONT_ORDER.index(term_to_ont.get(node, 'MFO'))]
               for node in subgraph.nodes()]
nx.draw_networkx_nodes(subgraph, pos, node_color=node_colors,
                       node_size=300, alpha=0.7, ax=ax3)
nx.draw_networkx_edges(subgraph, pos, alpha=0.2, arrows=True,
                       arrowsize=10, ax=ax3, edge_color='gray')
ax3.set_title('GO Ontology Network Structure (Sample of 50 terms)',
              fontsize=13, fontweight='bold', pad=15)
ax3.axis('off')

# Legend (same colour/label pairing as the bar chart)
legend_elements = [mpatches.Patch(facecolor=color, label=label)
                   for color, label in zip(colors, ONT_LABELS)]
ax3.legend(handles=legend_elements, loc='upper right', fontsize=10, framealpha=0.9)

plt.suptitle('Gene Ontology Analysis', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

# ============================================================================
# 3. LOAD IA WEIGHTS (WITH ANALYSIS)
# ============================================================================
print("\n[2/9] Loading IA weights...")
ia_df = pd.read_csv(BASE / 'IA.tsv', sep='\t', header=None, names=['term', 'ia'])

# IA weights form a per-term lookup table, not training samples, so they are
# intentionally NOT subsampled by SAMPLE_PERCENT: dropping rows would silently
# leave arbitrary GO terms without a weight while saving no meaningful work.
ia_dict = dict(zip(ia_df['term'], ia_df['ia']))
print(f"   ✓ Loaded {len(ia_dict)} IA weights")

# Enhanced IA visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# IA weights tagged with their ontology; dropna() removes terms not mapped from
# the GO graph (and any row with a missing IA value).
ia_by_ont = ia_df.copy()
ia_by_ont['ontology'] = ia_by_ont['term'].map(term_to_ont)
ia_by_ont = ia_by_ont.dropna()

ia_ont_codes = ['MFO', 'BPO', 'CCO']
ia_ont_short = ['MF', 'BP', 'CC']

axes[0, 0].hist(ia_df['ia'], bins=50, color='#95E1D3', edgecolor='black', alpha=0.7)
axes[0, 0].set_title('Overall IA Distribution', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('IA Weight', fontsize=10)
axes[0, 0].set_ylabel('Frequency', fontsize=10)
axes[0, 0].axvline(ia_df['ia'].mean(), color='red', linestyle='--', linewidth=2,
                   label=f'Mean: {ia_df["ia"].mean():.2f}')
axes[0, 0].legend()

# Box plot by ontology
ont_data = [ia_by_ont.loc[ia_by_ont['ontology'] == ont, 'ia'].values
            for ont in ia_ont_codes]
bp = axes[0, 1].boxplot(ont_data, patch_artist=True)
# Tick labels are set separately: boxplot's `labels=` keyword was renamed
# `tick_labels` in Matplotlib 3.9 (old name deprecated); this works on all versions.
axes[0, 1].set_xticks([1, 2, 3])
axes[0, 1].set_xticklabels(ia_ont_short)
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)
axes[0, 1].set_title('IA Weights by Ontology', fontsize=12, fontweight='bold')
axes[0, 1].set_ylabel('IA Weight', fontsize=10)
axes[0, 1].grid(axis='y', alpha=0.3)

# Violin plot
parts = axes[0, 2].violinplot(ont_data, positions=[1, 2, 3], showmeans=True, showmedians=True)
for pc, color in zip(parts['bodies'], colors):
    pc.set_facecolor(color)
    pc.set_alpha(0.7)
axes[0, 2].set_xticks([1, 2, 3])
axes[0, 2].set_xticklabels(ia_ont_short)
axes[0, 2].set_title('IA Distribution Density', fontsize=12, fontweight='bold')
axes[0, 2].set_ylabel('IA Weight', fontsize=10)

# Cumulative distribution: share of total IA mass carried by weights <= x
sorted_ia = np.sort(ia_df['ia'].values)
cumsum = np.cumsum(sorted_ia) / np.sum(sorted_ia)
axes[1, 0].plot(sorted_ia, cumsum, linewidth=2, color='#6C5CE7')
axes[1, 0].set_title('Cumulative IA Distribution', fontsize=12, fontweight='bold')
axes[1, 0].set_xlabel('IA Weight', fontsize=10)
axes[1, 0].set_ylabel('Cumulative Proportion', fontsize=10)
axes[1, 0].grid(alpha=0.3)
axes[1, 0].axhline(0.5, color='red', linestyle='--', alpha=0.5, label='50%')
axes[1, 0].legend()

# Top terms by IA
top_ia = ia_df.nlargest(15, 'ia')
axes[1, 1].barh(range(len(top_ia)), top_ia['ia'].values, color='#FF7675', edgecolor='black')
axes[1, 1].set_yticks(range(len(top_ia)))
axes[1, 1].set_yticklabels([f"{t[:15]}..." if len(t) > 15 else t
                            for t in top_ia['term'].values], fontsize=8)
axes[1, 1].set_title('Top 15 Terms by IA Weight', fontsize=12, fontweight='bold')
axes[1, 1].set_xlabel('IA Weight', fontsize=10)
axes[1, 1].invert_yaxis()

# Statistics summary (bullet and plus-minus characters were previously mis-encoded)
axes[1, 2].axis('off')
ia_stats = f"""
IA WEIGHT STATISTICS

Total terms: {len(ia_df):,}

Overall:
  • Mean: {ia_df['ia'].mean():.3f}
  • Median: {ia_df['ia'].median():.3f}
  • Std Dev: {ia_df['ia'].std():.3f}
  • Range: [{ia_df['ia'].min():.3f}, {ia_df['ia'].max():.3f}]

By Ontology (Mean ± Std):
  • MF: {ia_by_ont[ia_by_ont['ontology']=='MFO']['ia'].mean():.3f} ± {ia_by_ont[ia_by_ont['ontology']=='MFO']['ia'].std():.3f}
  • BP: {ia_by_ont[ia_by_ont['ontology']=='BPO']['ia'].mean():.3f} ± {ia_by_ont[ia_by_ont['ontology']=='BPO']['ia'].std():.3f}
  • CC: {ia_by_ont[ia_by_ont['ontology']=='CCO']['ia'].mean():.3f} ± {ia_by_ont[ia_by_ont['ontology']=='CCO']['ia'].std():.3f}
"""
axes[1, 2].text(0.05, 0.5, ia_stats, fontsize=10, family='monospace',
                verticalalignment='center')

plt.suptitle('Information Accretion (IA) Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# ============================================================================
# 4. LOAD TRAINING DATA (WITH COMPREHENSIVE ANALYSIS) - FIXED
# ============================================================================
print("\n[3/9] Loading training data...")

train_terms = pd.read_csv(TRAIN_DIR / 'train_terms.tsv', sep='\t',
                          names=['protein', 'term', 'ontology'])
train_taxonomy = pd.read_csv(TRAIN_DIR / 'train_taxonomy.tsv', sep='\t',
                             names=['protein', 'taxon'])

# With `names=`, pandas reads every line as data, so a header line (if the file has
# one) becomes a bogus record, e.g. an extra ontology code in the plots below.
# Keep only rows that look like real records; this is a no-op for header-less files.
train_terms = train_terms[train_terms['term'].astype(str).str.startswith('GO:')]
train_taxonomy = train_taxonomy[pd.to_numeric(train_taxonomy['taxon'], errors='coerce').notna()]

print(f"   ✓ Full dataset: {len(train_terms)} annotations, {train_terms['protein'].nunique()} proteins")

# SAMPLE proteins for faster iteration
if SAMPLE_PERCENT < 100:
    sample_proteins = train_terms['protein'].drop_duplicates().sample(
        frac=SAMPLE_PERCENT/100, random_state=42
    ).tolist()
    train_terms = train_terms[train_terms['protein'].isin(sample_proteins)]
    train_taxonomy = train_taxonomy[train_taxonomy['protein'].isin(sample_proteins)]
    print(f"   ✓ Sampled to {SAMPLE_PERCENT}%: {len(train_terms)} annotations, {len(sample_proteins)} proteins")

# Print ontology distribution
print("\n   Ontology distribution:")
print(train_terms['ontology'].value_counts())

# Comprehensive training data visualization
fig = plt.figure(figsize=(20, 12))
gs = fig.add_gridspec(3, 4, hspace=0.35, wspace=0.35)

# 1. Ontology distribution (handles F/P/C plus any unexpected codes)
ax1 = fig.add_subplot(gs[0, 0])
ont_dist = train_terms['ontology'].value_counts()

colors_ont_map = {'F': '#FF6B6B', 'P': '#4ECDC4', 'C': '#45B7D1'}
ont_names_map = {'F': 'MF', 'P': 'BP', 'C': 'CC'}

# Unknown codes get a grey bar and keep their raw code as the label.
colors_list = [colors_ont_map.get(k, '#CCCCCC') for k in ont_dist.index]
labels_list = [ont_names_map.get(k, k) for k in ont_dist.index]

bars = ax1.bar(range(len(ont_dist)), ont_dist.values, color=colors_list,
               edgecolor='black', linewidth=1.5)
ax1.set_xticks(range(len(ont_dist)))
ax1.set_xticklabels(labels_list)
ax1.set_title('Annotations by Ontology', fontsize=11, fontweight='bold')
ax1.set_ylabel('Count', fontsize=9)
for v, bar in zip(ont_dist.values, bars):
    ax1.text(bar.get_x() + bar.get_width()/2., v, f'{v:,}',
             ha='center', va='bottom', fontweight='bold', fontsize=9)

# 2. Top terms
ax2 = fig.add_subplot(gs[0, 1:3])
top_terms = train_terms['term'].value_counts().head(20)
ax2.barh(range(len(top_terms)), top_terms.values, color='#A8E6CF', edgecolor='black')
ax2.set_yticks(range(len(top_terms)))
ax2.set_yticklabels([f"{term_names.get(t, t)[:30]}..." if len(term_names.get(t, t)) > 30
                     else term_names.get(t, t) for t in top_terms.index], fontsize=8)
ax2.set_title('Top 20 Most Frequent GO Terms', fontsize=11, fontweight='bold')
ax2.set_xlabel('Count', fontsize=9)
ax2.invert_yaxis()

# 3. Terms per protein
ax3 = fig.add_subplot(gs[0, 3])
terms_per_protein = train_terms.groupby('protein').size()
ax3.hist(terms_per_protein, bins=50, color='#FFD93D', edgecolor='black', alpha=0.7)
ax3.set_title('Terms per Protein', fontsize=11, fontweight='bold')
ax3.set_xlabel('# Terms', fontsize=9)
ax3.set_ylabel('Frequency', fontsize=9)
ax3.axvline(terms_per_protein.mean(), color='red', linestyle='--', linewidth=2)

# 4. Proteins per term
ax4 = fig.add_subplot(gs[1, 0])
proteins_per_term = train_terms.groupby('term').size()
ax4.hist(proteins_per_term, bins=50, color='#FFEAA7', edgecolor='black', alpha=0.7, log=True)
ax4.set_title('Proteins per Term (log)', fontsize=11, fontweight='bold')
ax4.set_xlabel('# Proteins', fontsize=9)
ax4.set_ylabel('# Terms (log)', fontsize=9)

# 5. Taxonomy distribution
ax5 = fig.add_subplot(gs[1, 1])
top_taxa = train_taxonomy['taxon'].value_counts().head(10)
ax5.bar(range(len(top_taxa)), top_taxa.values, color='#74B9FF', edgecolor='black')
ax5.set_xticks(range(len(top_taxa)))
ax5.set_xticklabels([str(t)[:8] for t in top_taxa.index], rotation=45, ha='right', fontsize=8)
ax5.set_title('Top 10 Species', fontsize=11, fontweight='bold')
ax5.set_ylabel('# Proteins', fontsize=9)

# 6. Term co-occurrence heatmap (number of proteins sharing each pair of top terms)
ax6 = fig.add_subplot(gs[1, 2:])
top_10_terms = train_terms['term'].value_counts().head(10).index
n_top = len(top_10_terms)  # may be < 10 on small samples
# Build each term's protein set once (n_top scans) instead of re-filtering the
# whole annotation table twice for every pair.
term_proteins = {t: set(train_terms.loc[train_terms['term'] == t, 'protein'])
                 for t in top_10_terms}
cooc_matrix = np.zeros((n_top, n_top))
for i, t1 in enumerate(top_10_terms):
    for j, t2 in enumerate(top_10_terms):
        if i != j:
            cooc_matrix[i, j] = len(term_proteins[t1] & term_proteins[t2])
im = ax6.imshow(cooc_matrix, cmap='YlOrRd', aspect='auto')
ax6.set_xticks(range(n_top))
ax6.set_yticks(range(n_top))
ax6.set_xticklabels([term_names.get(t, t)[:10] for t in top_10_terms],
                     rotation=45, ha='right', fontsize=7)
ax6.set_yticklabels([term_names.get(t, t)[:10] for t in top_10_terms], fontsize=7)
ax6.set_title('Term Co-occurrence Matrix', fontsize=11, fontweight='bold')
plt.colorbar(im, ax=ax6, label='# Shared Proteins')

# 7. Annotation density. The top edge is unbounded: a finite cap would turn terms
# with more proteins than the cap into NaN and silently drop them from the chart.
ax7 = fig.add_subplot(gs[2, :2])
term_freq_bins = pd.cut(proteins_per_term, bins=[0, 10, 50, 100, 500, np.inf],
                        labels=['<10', '10-50', '50-100', '100-500', '>500'])
freq_dist = term_freq_bins.value_counts().sort_index()
ax7.bar(range(len(freq_dist)), freq_dist.values, color='#E17055', edgecolor='black', alpha=0.7)
ax7.set_xticks(range(len(freq_dist)))
ax7.set_xticklabels(freq_dist.index, rotation=0)
ax7.set_title('GO Term Frequency Distribution', fontsize=11, fontweight='bold')
ax7.set_xlabel('# Proteins with Term', fontsize=9)
ax7.set_ylabel('# Terms', fontsize=9)
for i, v in enumerate(freq_dist.values):
    ax7.text(i, v, str(v), ha='center', va='bottom', fontweight='bold')

# 8. Summary statistics
ax8 = fig.add_subplot(gs[2, 2:])
ax8.axis('off')
summary_text = f"""
TRAINING DATA COMPREHENSIVE SUMMARY

Dataset Size:
  • Total Annotations: {len(train_terms):,}
  • Unique Proteins: {train_terms['protein'].nunique():,}
  • Unique GO Terms: {train_terms['term'].nunique():,}
  • Species: {train_taxonomy['taxon'].nunique()}

Ontology Distribution:
  • Molecular Function: {ont_dist.get('F', 0):,} ({ont_dist.get('F', 0)/len(train_terms)*100:.1f}%)
  • Biological Process: {ont_dist.get('P', 0):,} ({ont_dist.get('P', 0)/len(train_terms)*100:.1f}%)
  • Cellular Component: {ont_dist.get('C', 0):,} ({ont_dist.get('C', 0)/len(train_terms)*100:.1f}%)

Annotation Statistics:
  • Mean terms/protein: {terms_per_protein.mean():.1f}
  • Median terms/protein: {terms_per_protein.median():.0f}
  • Max terms/protein: {terms_per_protein.max()}
  • Mean proteins/term: {proteins_per_term.mean():.1f}
  • Median proteins/term: {proteins_per_term.median():.0f}
"""
ax8.text(0.05, 0.5, summary_text, fontsize=10, family='monospace',
         verticalalignment='center', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

plt.suptitle('Training Data Comprehensive Analysis', fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

# Continue with the rest of the code (sequences, features, training, etc.)
print("\n   Loading sequences (this may take a while for 100% of data)...")
print(f"   Expected proteins: {train_terms['protein'].nunique():,}")

train_seqs = {}
loaded_count = 0  # accepted FASTA records; a repeated ID is counted each time
target_proteins = set(train_terms['protein'].unique())

for rec in SeqIO.parse(TRAIN_DIR / 'train_sequences.fasta', 'fasta'):
    # For pipe-delimited IDs (UniProt style, e.g. "sp|P12345|NAME") field 1 is the
    # accession; otherwise the whole record ID is used.
    pid = rec.id.split('|')[1] if '|' in rec.id else rec.id
    if pid in target_proteins:
        train_seqs[pid] = str(rec.seq)
        loaded_count += 1

        # Progress indicator
        if loaded_count % 10000 == 0:
            print(f"      Loaded {loaded_count:,} sequences...")

    # Stop once every target protein has a sequence. Compare distinct proteins
    # (len(train_seqs)), not records: a repeated ID would otherwise end the scan early.
    if len(train_seqs) >= len(target_proteins):
        break

print(f"   ✓ Loaded {len(train_seqs):,} training sequences")
missing = len(target_proteins) - len(train_seqs)
if missing:
    print(f"   ⚠ {missing:,} annotated proteins have no sequence in the FASTA file")

# Enhanced sequence analysis (guarded: min()/max() raise on an empty list)
seq_lengths = [len(s) for s in train_seqs.values()]
if seq_lengths:
    print(f"   ✓ Sequence length: mean={np.mean(seq_lengths):.0f}, "
          f"median={np.median(seq_lengths):.0f}, range=[{min(seq_lengths)}-{max(seq_lengths)}]")
else:
    print("   ⚠ No sequences matched the annotated protein IDs - check the FASTA header format")

print("\n✅ Data loading complete! Ready for feature extraction and training.")
print(f"   Total proteins: {len(train_seqs):,}")
print(f"   Total annotations: {len(train_terms):,}")
print(f"   Total GO terms: {train_terms['term'].nunique():,}")

# PACKAGE INSTALLATION

In [None]:
# Install any missing packages
!pip install -q obonet biopython

# PART 1: IMPORTS AND SETUP

In [None]:
import os
import sys
import gc
import time
import warnings
import subprocess
from pathlib import Path
from collections import defaultdict, Counter
import traceback

# Section banner for the embedding-based solution (three lines: rule, title, rule).
print("\n".join(["="*80, "CAFA 6 PROTEIN FUNCTION PREDICTION - COMPLETE SOLUTION", "="*80]))

# Core imports
import numpy as np
import pandas as pd
from tqdm import tqdm

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# Import after installation
import obonet
import networkx as nx

def _seed_everything(seed):
    """Seed NumPy's global RNG and PyTorch (CPU, plus every CUDA device when present)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# Fixed seed so random sampling and weight initialisation are reproducible.
_seed_everything(42)

# PART 2: CONFIGURATION

In [None]:
class Config:
    """Central settings for the embedding-based pipeline: paths, model sizes, training and output."""

    # --- Data locations (Kaggle input mounts) ---
    BASE_DIR = "/kaggle/input"
    MAIN_DIR = BASE_DIR + "/cafa-6-protein-function-prediction"
    TRAIN_DIR = MAIN_DIR + "/Train"
    TEST_DIR = MAIN_DIR + "/Test"

    # --- Compute device: CUDA when available, otherwise CPU ---
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- ESM2 embeddings: input width and number of GO label columns ---
    ESM2_PATH = BASE_DIR + "/cafa-5-ems-2-embeddings-numpy"
    ESM2_DIM = 1280
    ESM2_LABELS = 300

    # --- ProtBERT embeddings: input width and number of GO label columns ---
    PROTBERT_PATH = BASE_DIR + "/protbert-embeddings-for-cafa5"
    PROTBERT_DIM = 1024
    PROTBERT_LABELS = 500

    # --- Optimisation ---
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    NUM_EPOCHS = 3
    TRAIN_SAMPLES = 30000  # kept small to limit memory use

    # --- Prediction output ---
    CONFIDENCE_THRESHOLD = 0.01
    MAX_PREDICTIONS_PER_PROTEIN = 1500


# Shared instance read by the dataset, training and prediction code.
config = Config()
print(f"\n[CONFIG] Device: {config.DEVICE}")
print(f"[CONFIG] ESM2 labels: {config.ESM2_LABELS}")
print(f"[CONFIG] ProtBERT labels: {config.PROTBERT_LABELS}")

# PART 3: DATASET CLASSES

In [None]:
class ProteinDataset(Dataset):
    """Dataset of precomputed per-protein embeddings, with multi-hot GO labels for training.

    Items are ``(embedding, label)`` float32 tensors when ``datatype == "train"``,
    and ``(embedding, protein_id)`` otherwise.

    Args:
        datatype: file prefix such as ``"train"`` or ``"test"``. The dataset reads
            ``{datatype}_ids.npy`` and ``{datatype}_embeddings.npy`` (falling back
            to ``{datatype}_embeds.npy``) from ``embed_path``.
        embed_path: directory containing the ``.npy`` files.
        embed_dim: embedding width; only used to shape dummy data when loading fails.
        num_labels: number of most frequent GO terms used as label columns.
        sample_size: if given and smaller than the number of proteins, keep a
            random subset of that size (drawn from NumPy's global RNG).

    On any load error, the dataset prints the error and falls back to random dummy
    embeddings (and, for labels, random targets) so the notebook keeps running.
    Check the logged messages.
    """

    def __init__(self, datatype, embed_path, embed_dim, num_labels, sample_size=None):
        self.datatype = datatype
        self.embed_dim = embed_dim
        self.num_labels = num_labels
        # Positions of the kept rows in the on-disk arrays. None means all rows in
        # file order (or dummy data). Used to select the matching label rows.
        self.sample_indices = None

        # Load embeddings
        print(f"\n[DATA] Loading {datatype} embeddings...")

        try:
            ids_file = os.path.join(embed_path, f"{datatype}_ids.npy")
            embeds_file = os.path.join(embed_path, f"{datatype}_embeddings.npy")

            # Some embedding datasets name this file "{datatype}_embeds.npy" instead.
            if not os.path.exists(embeds_file):
                embeds_file = os.path.join(embed_path, f"{datatype}_embeds.npy")

            # NOTE(security): allow_pickle=True can execute code embedded in the
            # file; only load id arrays from trusted datasets.
            self.ids = np.load(ids_file, allow_pickle=True)
            self.embeds = np.load(embeds_file)

            # Random subset without replacement; remember which rows were kept.
            if sample_size and sample_size < len(self.ids):
                indices = np.random.choice(len(self.ids), sample_size, replace=False)
                self.ids = self.ids[indices]
                self.embeds = self.embeds[indices]
                self.sample_indices = indices

            print(f"[DATA] Loaded {len(self.ids)} samples")

        except Exception as e:
            print(f"[DATA] Error loading embeddings: {e}")
            # Best-effort fallback: random dummy data so downstream cells still run.
            size = sample_size or 1000
            self.ids = np.array([f"DUMMY_{i}" for i in range(size)])
            self.embeds = np.random.randn(size, embed_dim).astype(np.float32)
            self.sample_indices = None
            print(f"[DATA] Using dummy data: {len(self.ids)} samples")

        # Initialize labels
        self.labels = None
        self.top_terms = None

        if datatype == "train":
            self._load_labels()

    @staticmethod
    def _read_train_terms():
        """Read the full train_terms.tsv with columns EntryID / term / aspect."""
        return pd.read_csv(
            f"{config.TRAIN_DIR}/train_terms.tsv",
            sep="\t",
            names=["EntryID", "term", "aspect"]
        )

    def _build_label_matrix(self, terms_df):
        """Return a float32 (len(ids), num_labels) multi-hot matrix over ``self.top_terms``.

        Annotations are grouped by protein in one pass. This replaces scanning the
        whole annotation table once per protein, which was quadratic in practice.
        """
        labels = np.zeros((len(self.ids), self.num_labels), dtype=np.float32)
        term_to_idx = {term: i for i, term in enumerate(self.top_terms)}

        # Keep only annotations of label terms, then map protein -> column indices.
        hits = terms_df[terms_df['term'].isin(self.top_terms)]
        cols_by_protein = (
            hits['term'].map(term_to_idx).groupby(hits['EntryID']).agg(list).to_dict()
        )

        for i, protein_id in enumerate(self.ids):
            cols = cols_by_protein.get(protein_id)
            if cols:
                labels[i, cols] = 1.0
        return labels

    def _load_labels(self):
        """Fill ``self.labels`` (len(ids) x num_labels, multi-hot) and ``self.top_terms``.

        Uses the preprocessed top-500 target matrix when it exists and
        ``num_labels <= 500``; otherwise builds the matrix from train_terms.tsv.
        On any error, falls back to random labels (the error is printed, not raised).
        """
        print(f"[DATA] Loading labels for top {self.num_labels} terms...")

        try:
            # Try preprocessed labels first
            preprocessed_file = f"{config.BASE_DIR}/train-targets-top500/train_targets_top500.npy"

            if os.path.exists(preprocessed_file) and self.num_labels <= 500:
                all_labels = np.load(preprocessed_file)
                # Assumes row k of the preprocessed matrix belongs to id k of the full
                # ids file — TODO confirm. Select the same rows that were sampled for
                # the embeddings; the first len(ids) rows would not match a random sample.
                if self.sample_indices is not None:
                    all_labels = all_labels[self.sample_indices]
                self.labels = all_labels[:len(self.ids), :self.num_labels]

                # Column names. Assumes the preprocessed columns are the most frequent
                # terms of the full annotation file, in frequency order — TODO confirm.
                # Rank over the whole file: a 100k-row prefix is not representative.
                terms_df = self._read_train_terms()
                self.top_terms = terms_df['term'].value_counts().head(self.num_labels).index.tolist()

            else:
                # Create labels from scratch
                terms_df = self._read_train_terms()
                self.top_terms = terms_df['term'].value_counts().head(self.num_labels).index.tolist()
                self.labels = self._build_label_matrix(terms_df)

        except Exception as e:
            print(f"[DATA] Error creating labels: {e}")
            # Use random labels as fallback
            self.labels = np.random.randint(0, 2, (len(self.ids), self.num_labels)).astype(np.float32)
            self.top_terms = [f"GO:{i:07d}" for i in range(self.num_labels)]

        print(f"[DATA] Labels shape: {self.labels.shape}")
        print(f"[DATA] Positive labels: {self.labels.sum():.0f} ({self.labels.mean()*100:.2f}% density)")

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` for training data, else ``(embedding, protein_id)``."""
        embedding = torch.tensor(self.embeds[idx], dtype=torch.float32)

        if self.datatype == "train":
            label = torch.tensor(self.labels[idx], dtype=torch.float32)
            return embedding, label
        return embedding, self.ids[idx]

# PART 4: MODEL ARCHITECTURES

In [None]:
class SimpleModel(nn.Module):
    """Three-layer MLP head (input -> 512 -> 256 -> output) used on ESM2 embeddings.

    Each hidden layer is followed by ReLU and dropout. ``forward`` returns raw
    logits (no sigmoid).
    """

    def __init__(self, input_dim, output_dim, dropout=0.3):
        super().__init__()
        # Attribute names become state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(input_dim, 512)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(256, output_dim)

    def forward(self, x):
        hidden = x
        for linear, drop in ((self.fc1, self.dropout1), (self.fc2, self.dropout2)):
            hidden = drop(F.relu(linear(hidden)))
        return self.fc3(hidden)

class MLPModel(nn.Module):
    """Plain three-layer MLP head (input -> 864 -> 712 -> output) used on ProtBERT embeddings.

    ReLU follows each hidden layer, with no dropout. ``forward`` returns raw logits.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 864)
        self.fc2 = nn.Linear(864, 712)
        self.fc3 = nn.Linear(712, output_dim)

    def forward(self, x):
        first = F.relu(self.fc1(x))
        second = F.relu(self.fc2(first))
        return self.fc3(second)

# PART 5: TRAINING FUNCTION

In [None]:
def train_model(model, train_loader, val_loader, num_epochs, model_name):
    """Train *model* with BCE-with-logits loss and Adam, then return it with its best weights.

    After every epoch the mean validation loss is computed. Whenever it improves,
    the weights are written to ``{model_name}_best.pth`` in the working directory,
    and that checkpoint is reloaded at the end.

    Args:
        model: ``nn.Module`` producing one logit per label.
        train_loader: iterable of ``(embeddings, labels)`` training batches.
        val_loader: iterable of ``(embeddings, labels)`` validation batches.
        num_epochs: number of passes over ``train_loader``.
        model_name: used in log messages and as the checkpoint file prefix.

    Returns:
        The model, moved to ``config.DEVICE``. With no validation batches, the
        best epoch is chosen by training loss. If no checkpoint was written in this
        run (e.g. every loss was NaN), the last-epoch weights are returned.
    """
    print(f"\n[TRAIN] Training {model_name}...")

    model = model.to(config.DEVICE)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    checkpoint_path = f'{model_name}_best.pth'
    best_loss = float('inf')
    saved_checkpoint = False  # becomes True once this run writes checkpoint_path

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_losses = []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        for embeddings, labels in pbar:
            embeddings = embeddings.to(config.DEVICE)
            labels = labels.to(config.DEVICE)

            optimizer.zero_grad()
            outputs = model(embeddings)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Validation
        model.eval()
        val_losses = []

        with torch.no_grad():
            for embeddings, labels in val_loader:
                embeddings = embeddings.to(config.DEVICE)
                labels = labels.to(config.DEVICE)
                outputs = model(embeddings)
                loss = criterion(outputs, labels)
                val_losses.append(loss.item())

        # np.mean([]) warns and returns NaN; handle empty loaders explicitly.
        avg_train_loss = np.mean(train_losses) if train_losses else float('nan')
        avg_val_loss = np.mean(val_losses) if val_losses else float('nan')

        print(f"[TRAIN] Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

        # Select on validation loss. Without validation batches, fall back to training
        # loss so a checkpoint is still produced. NaN/inf never counts as "best".
        if val_losses:
            selection_name, selection_loss = 'val_loss', avg_val_loss
        else:
            selection_name, selection_loss = 'train_loss', avg_train_loss
        if np.isfinite(selection_loss) and selection_loss < best_loss:
            best_loss = selection_loss
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True
            print(f"[TRAIN] Saved best model ({selection_name}={selection_loss:.4f})")

    if saved_checkpoint:
        # map_location loads the tensors directly onto the training device.
        model.load_state_dict(torch.load(checkpoint_path, map_location=config.DEVICE))
    else:
        # Nothing saved in this run: keep the current weights rather than failing on a
        # missing file or loading a stale checkpoint left over from an earlier run.
        print(f"[TRAIN] No checkpoint saved for {model_name}; keeping last-epoch weights")

    return model

# PART 6: PREDICTION FUNCTION

In [None]:
def generate_predictions(model, test_loader, top_terms, model_name, max_terms=100):
    """Run ``model`` over ``test_loader`` and collect per-protein GO-term scores.

    Args:
        model: network whose sigmoid outputs are read column-wise against
            ``top_terms``. Output columns beyond ``len(top_terms)`` are ignored.
        test_loader: iterable of ``(embeddings, protein_ids)`` batches.
        top_terms: sequence mapping output column index -> GO term.
        model_name: label used in log output.
        max_terms: keep at most this many highest-scoring terms per protein
            (non-negative; default 100).

    Returns:
        dict ``{protein_id: {go_term: probability}}``. Only terms whose
        sigmoid score exceeds ``config.CONFIDENCE_THRESHOLD`` are kept.
    """
    print(f"\n[PREDICT] Generating {model_name} predictions...")

    model.eval()
    all_predictions = {}
    n_terms = len(top_terms)

    with torch.no_grad():
        for embeddings, protein_ids in tqdm(test_loader, desc=f"{model_name} inference"):
            embeddings = embeddings.to(config.DEVICE)
            outputs = torch.sigmoid(model(embeddings)).cpu().numpy()

            # Process each protein
            for i, pid in enumerate(protein_ids):
                # DataLoader collates numeric ids into tensors. Unwrap scalars
                # so the dict is keyed by value, not by tensor object.
                if isinstance(pid, (np.ndarray, np.generic, torch.Tensor)):
                    pid = pid.item()

                # Drop columns with no corresponding term *before* ranking,
                # so they cannot use up the max_terms budget.
                scores = outputs[i][:n_terms]
                confident_indices = np.where(scores > config.CONFIDENCE_THRESHOLD)[0]

                # Limit predictions: argsort is ascending, so the tail holds
                # the highest scores.
                if len(confident_indices) > max_terms:
                    keep = np.argsort(scores[confident_indices])[len(confident_indices) - max_terms:]
                    confident_indices = confident_indices[keep]

                all_predictions[pid] = {
                    top_terms[idx]: float(scores[idx]) for idx in confident_indices
                }

    return all_predictions

# PART 7: MAIN EXECUTION

In [None]:
# Pipeline driver. It runs at notebook top level; the former ``main()``
# wrapper is disabled. Each strategy below merges its scores into
# ``final_predictions``.
print("\n" + "\n".join(["=" * 80, "STARTING CAFA 6 PREDICTION PIPELINE", "=" * 80]))

# Accumulator shared by every strategy: protein id -> {GO term: score}.
# The defaultdict hands out a fresh inner dict for unseen proteins.
final_predictions = defaultdict(dict)

# ============================================================
# STRATEGY 1: ESM2 MODEL
# ============================================================

print("\n" + "="*60)
print("STRATEGY 1: ESM2 Model")
print("="*60)

try:
    if os.path.exists(config.ESM2_PATH):
        # Load dataset
        train_dataset = ProteinDataset(
            "train",
            config.ESM2_PATH,
            config.ESM2_DIM,
            config.ESM2_LABELS,
            config.TRAIN_SAMPLES
        )

        # Split data 90% train / 10% validation
        train_size = int(len(train_dataset) * 0.9)
        val_size = len(train_dataset) - train_size
        train_set, val_set = random_split(train_dataset, [train_size, val_size])

        # Create loaders
        train_loader = DataLoader(train_set, batch_size=config.BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_set, batch_size=config.BATCH_SIZE, shuffle=False)

        # Create and train model
        esm2_model = SimpleModel(config.ESM2_DIM, config.ESM2_LABELS)
        esm2_model = train_model(esm2_model, train_loader, val_loader, config.NUM_EPOCHS, "esm2")

        # Generate predictions
        test_dataset = ProteinDataset(
            "test",
            config.ESM2_PATH,
            config.ESM2_DIM,
            config.ESM2_LABELS
        )
        test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

        # Output columns are mapped through the training dataset's term list
        esm2_predictions = generate_predictions(esm2_model, test_loader, train_dataset.top_terms, "ESM2")

        # Merge predictions with weight 0.5. Scores for the same
        # (protein, term) accumulate additively across strategies.
        for pid, preds in esm2_predictions.items():
            for term, score in preds.items():
                final_predictions[pid][term] = final_predictions[pid].get(term, 0) + score * 0.5

        print(f"[ESM2] Added predictions for {len(esm2_predictions)} proteins")

        # Cleanup. random_split's Subsets and the DataLoaders keep references
        # to the datasets, so they must be released too. Otherwise deleting
        # train_dataset/test_dataset alone frees nothing.
        del esm2_model, esm2_predictions
        del train_loader, val_loader, train_set, val_set, train_dataset
        del test_loader, test_dataset
        gc.collect()
        torch.cuda.empty_cache()

except Exception as e:
    print(f"[ERROR] ESM2 strategy failed: {e}")
    traceback.print_exc()

# ============================================================
# STRATEGY 2: PROTBERT MODEL
# ============================================================

print("\n" + "="*60)
print("STRATEGY 2: ProtBERT Model")
print("="*60)

try:
    if os.path.exists(config.PROTBERT_PATH):
        # Load dataset
        train_dataset = ProteinDataset(
            "train",
            config.PROTBERT_PATH,
            config.PROTBERT_DIM,
            config.PROTBERT_LABELS,
            config.TRAIN_SAMPLES
        )

        # Split data 90% train / 10% validation
        train_size = int(len(train_dataset) * 0.9)
        val_size = len(train_dataset) - train_size
        train_set, val_set = random_split(train_dataset, [train_size, val_size])

        # Create loaders
        train_loader = DataLoader(train_set, batch_size=config.BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_set, batch_size=config.BATCH_SIZE, shuffle=False)

        # Create and train model
        protbert_model = MLPModel(config.PROTBERT_DIM, config.PROTBERT_LABELS)
        protbert_model = train_model(protbert_model, train_loader, val_loader, config.NUM_EPOCHS, "protbert")

        # Generate predictions
        test_dataset = ProteinDataset(
            "test",
            config.PROTBERT_PATH,
            config.PROTBERT_DIM,
            config.PROTBERT_LABELS
        )
        test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

        # Output columns are mapped through the training dataset's term list
        protbert_predictions = generate_predictions(protbert_model, test_loader, train_dataset.top_terms, "ProtBERT")

        # Merge predictions with weight 0.5. Scores for the same
        # (protein, term) accumulate additively across strategies.
        for pid, preds in protbert_predictions.items():
            for term, score in preds.items():
                final_predictions[pid][term] = final_predictions[pid].get(term, 0) + score * 0.5

        print(f"[ProtBERT] Added predictions for {len(protbert_predictions)} proteins")

        # Cleanup. random_split's Subsets and the DataLoaders keep references
        # to the datasets, so they must be released too. Otherwise deleting
        # train_dataset/test_dataset alone frees nothing.
        del protbert_model, protbert_predictions
        del train_loader, val_loader, train_set, val_set, train_dataset
        del test_loader, test_dataset
        gc.collect()
        torch.cuda.empty_cache()

except Exception as e:
    print(f"[ERROR] ProtBERT strategy failed: {e}")
    traceback.print_exc()

# ============================================================
# LOAD EXISTING PREDICTIONS (OPTIONAL)
# ============================================================

print("\n" + "="*60)
print("Loading Existing Predictions")
print("="*60)

existing_file = f"{config.BASE_DIR}/blast-quick-sprof-zero-pred/submission.tsv"

try:
    if os.path.exists(existing_file):
        print("[EXISTING] Loading existing predictions...")

        chunk_count = 0
        for chunk in pd.read_csv(existing_file, sep='\t', header=None,
                                names=['Id', 'GO_term', 'Confidence'],
                                chunksize=1000000):
            chunk_count += 1

            # Iterate the columns directly. DataFrame.iterrows() builds a
            # Series per row and is far too slow for million-row chunks.
            for pid, term, conf in zip(chunk['Id'], chunk['GO_term'], chunk['Confidence']):
                preds = final_predictions[pid]
                # External scores are down-weighted to 20%. Taking the max
                # means they never lower a score already present.
                preds[term] = max(preds.get(term, 0), conf * 0.2)

        print(f"[EXISTING] Loaded {chunk_count} chunks of existing predictions")

except Exception as e:
    print(f"[ERROR] Could not load existing predictions: {e}")
    traceback.print_exc()

# ============================================================
# WRITE FINAL SUBMISSION
# ============================================================

print("\n" + "="*60)
print("Writing Final Submission")
print("="*60)

total_written = 0
proteins_written = 0

if final_predictions:
    print("[SUBMIT] Writing submission file...")

    with open('submission.tsv', 'w') as f:
        for protein_id, protein_preds in tqdm(final_predictions.items(), desc="Writing"):
            # Highest-scoring terms first, capped per protein
            top_preds = sorted(protein_preds.items(), key=lambda x: x[1], reverse=True)
            top_preds = top_preds[:config.MAX_PREDICTIONS_PER_PROTEIN]

            # Keep only scores above the threshold, clamped to [0.001, 1.0]
            # so they never round to 0.000 in the 3-decimal output
            lines = [
                f"{protein_id}\t{term}\t{min(max(score, 0.001), 1.000):.3f}\n"
                for term, score in top_preds
                if score > config.CONFIDENCE_THRESHOLD
            ]
            if lines:
                f.writelines(lines)
                total_written += len(lines)
                proteins_written += 1

    print(f"[SUBMIT] Total predictions written: {total_written:,}")
    print(f"[SUBMIT] Proteins with predictions: {proteins_written:,}")
    if proteins_written:
        print(f"[SUBMIT] Average predictions per protein: {total_written/proteins_written:.1f}")

if total_written == 0:
    # Create a minimal submission whenever nothing was written. That covers
    # no predictions at all and also no score clearing the threshold, so
    # the file is never left empty.
    print("[WARNING] No predictions generated, creating minimal submission...")
    with open('submission.tsv', 'w') as f:
        f.write("A0A0C5B5G6\tGO:0005515\t0.100\n")

# ============================================================
# VALIDATION
# ============================================================

print("\n" + "="*60)
print("Validation")
print("="*60)

if os.path.exists('submission.tsv'):
    # Count lines, keeping the first one for a real format check
    with open('submission.tsv', 'r') as f:
        first_line = f.readline()
        num_lines = (1 if first_line else 0) + sum(1 for _ in f)

    print(f"[VALID] Submission file has {num_lines:,} predictions")

    # pd.read_csv raises EmptyDataError on an empty file, so only sample
    # when there is something to read
    sample = None
    if num_lines > 0:
        sample = pd.read_csv('submission.tsv', sep='\t', nrows=10,
                            header=None, names=['ProteinID', 'GO_term', 'Confidence'])

        print("\n[VALID] First 10 predictions:")
        print(sample)

    # Expected row layout: protein id, GO term, score (three tab-separated fields)
    tab_separated = first_line.rstrip('\r\n').count('\t') == 2

    # Validate format
    print("\n[VALID] Format checks:")
    print("  ✓ File exists: True")
    print(f"  ✓ Has predictions: {num_lines > 0}")
    print(f"  ✓ Tab-separated: {tab_separated}")

    if sample is not None and len(sample) > 0:
        go_format_ok = sample['GO_term'].astype(str).str.startswith('GO:').all()
        conf_range_ok = ((sample['Confidence'] > 0) & (sample['Confidence'] <= 1)).all()
        print(f"  ✓ GO term format: {go_format_ok}")
        print(f"  ✓ Confidence range: {conf_range_ok}")

# Final banner. The emoji were previously stored as mis-decoded
# (mojibake) text and printed as garbage.
print("\n" + "="*80)
print("✅ PIPELINE COMPLETE!")
print("="*80)
print("Submission file: submission.tsv")
print("Good luck with the competition! 🚀")

In [None]:
# try:
#     main()
# except Exception as e:
#     print(f"\n[FATAL ERROR] {e}")
#     traceback.print_exc()
    
#     # Create emergency submission
#     print("\n[EMERGENCY] Creating minimal submission...")
#     with open('submission.tsv', 'w') as f:
#         f.write("A0A0C5B5G6\tGO:0005515\t0.100\n")
#     print("[EMERGENCY] Minimal submission created: submission.tsv")