# In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
CAFA 6 Protein Function Prediction - Complete Implementation
This code predicts Gene Ontology (GO) terms for proteins based on their sequences
"""

import pandas as pd
import numpy as np
from tqdm import tqdm
import time
import os
import warnings
warnings.filterwarnings('ignore')

# PyTorch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score

# Visualization
import matplotlib.pyplot as plt
plt.style.use('ggplot')

# Set random seeds for reproducibility.
# NumPy and PyTorch keep separate random number generators, so each one is
# seeded explicitly with the same constant.
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    # Only touch the CUDA generator when a GPU is actually present.
    torch.cuda.manual_seed(42)

# ====================== CONFIGURATION ======================
class Config:
    """Paths, hyper-parameters and embedding metadata shared by the pipeline."""

    def __init__(self):
        # Competition files: every path below hangs off this directory
        root = "/kaggle/input/cafa-6-protein-function-prediction"
        self.main_dir = root
        self.train_sequences_path = f"{root}/Train/train_sequences.fasta"
        self.train_labels_path = f"{root}/Train/train_terms.tsv"
        self.test_sequences_path = f"{root}/Test/testsuperset.fasta"
        self.ia_path = f"{root}/IA.tsv"

        # Training hyper-parameters
        self.num_labels = 600  # number of most frequent GO terms used as targets
        self.n_epochs = 15
        self.batch_size = 64
        self.lr = 0.001
        self.dropout_rate = 0.3
        self.patience = 3  # passed to the LR scheduler as its patience

        # Run on the GPU when one is visible, otherwise fall back to the CPU
        gpu_available = torch.cuda.is_available()
        self.device = torch.device('cuda' if gpu_available else 'cpu')

        # Embedding source -> name of the input dataset directory holding it
        self.embeds_map = dict(
            ESM2="cafa-5-ems-2-embeddings-numpy",
            ProtBERT="protbert-embeddings-for-cafa5",
            T5="t5embeds",
        )

        # Embedding source -> length of one embedding vector
        self.embeds_dim = dict(ESM2=1280, ProtBERT=1024, T5=1024)

# Module-level configuration instance read by the training and prediction
# functions below; the selected device is printed once at start-up.
config = Config()
print(f"Using device: {config.device}")

# ====================== DATA LOADING ======================
class ProteinDataset(Dataset):
    """Pre-computed protein embeddings, paired with GO label vectors for training.

    Embeddings and their IDs are loaded from ``.npy`` files into ``self.df``
    (columns ``EntryID`` and ``embed``). When ``datatype == "train"`` a
    ``labels_vect`` column is added, holding one multi-hot vector per protein.

    Args:
        datatype: File-name prefix of the split to load; ``"train"`` also
            triggers label loading.
        embeddings_source: Key into ``config.embeds_map`` ("ESM2", "ProtBERT"
            or "T5").
        config: Object providing ``embeds_map``, ``num_labels`` and
            ``train_labels_path``.
    """

    # Embedding source -> stem of the embedding matrix file name
    # (``<datatype>_<stem>.npy``). T5 uses a different stem than the others.
    _EMBED_FILE_STEM = {
        "ESM2": "embeddings",
        "ProtBERT": "embeddings",
        "T5": "embeds",
    }

    def __init__(self, datatype, embeddings_source, config):
        super().__init__()
        self.datatype = datatype
        self.config = config
        self.embeddings_source = embeddings_source

        # Load embeddings
        self._load_embeddings()

        # Load labels for training data
        if self.datatype == "train":
            self._load_labels()

    def _load_embeddings(self):
        """Load the embedding matrix and matching IDs into ``self.df``.

        Raises:
            KeyError: If the source is missing from ``config.embeds_map``.
            ValueError: If the source is mapped but has no known file layout.
        """
        embed_dir = f"/kaggle/input/{self.config.embeds_map[self.embeddings_source]}"

        stem = self._EMBED_FILE_STEM.get(self.embeddings_source)
        if stem is None:
            # Previously this fell through every branch and failed later with
            # an unbound-variable error; fail with an explicit message instead.
            raise ValueError(
                f"Unsupported embeddings source: {self.embeddings_source}"
            )

        embeds = np.load(f"{embed_dir}/{self.datatype}_{stem}.npy")
        ids = np.load(f"{embed_dir}/{self.datatype}_ids.npy")

        # One row per protein; each "embed" cell holds that protein's vector
        self.df = pd.DataFrame({
            "EntryID": ids,
            "embed": list(embeds)
        })

    def _load_labels(self):
        """Attach ``labels_vect`` to ``self.df``.

        A pre-computed label matrix is used when its file exists; otherwise
        the labels are derived from the training TSV.

        Raises:
            ValueError: If the pre-computed matrix does not have one row per
                loaded protein.
        """
        # Load pre-processed top labels if available
        label_file = f"/kaggle/input/train-targets-top{self.config.num_labels}/train_targets_top{self.config.num_labels}.npy"

        if os.path.exists(label_file):
            np_labels = np.load(label_file)
            if len(np_labels) != len(self.df):
                raise ValueError(
                    f"Label matrix has {len(np_labels)} rows but "
                    f"{len(self.df)} embeddings were loaded"
                )
            # Rows are paired with proteins by position (assumes the matrix
            # was saved in the same order as the IDs file - TODO confirm).
            # Assigning directly avoids a self-merge on EntryID, which would
            # duplicate rows whenever an ID occurs more than once.
            self.df['labels_vect'] = list(np_labels)
        else:
            # Process labels from scratch
            self._process_labels_from_tsv()

    def _process_labels_from_tsv(self):
        """Build multi-hot label vectors from the training-terms TSV.

        The ``config.num_labels`` most frequent terms become the label
        columns (stored in ``self.top_terms``). Each protein in ``self.df``
        gets a list of 0/1 ints in that column order; proteins without any
        annotation among those terms get an all-zero vector.
        """
        labels_df = pd.read_csv(self.config.train_labels_path, sep="\t", names=["EntryID", "term", "aspect"])

        # Get top terms (same expression as used at prediction time so the
        # label order stays aligned)
        top_terms = labels_df.groupby("term")["EntryID"].count().sort_values(ascending=False)
        self.top_terms = top_terms[:self.config.num_labels].index.tolist()

        # Single pass over the annotations: protein -> set of label columns.
        # This replaces a per-protein scan of the whole annotation table.
        column_of = {term: col for col, term in enumerate(self.top_terms)}
        columns_by_entry = {}
        for entry_id, term in zip(labels_df['EntryID'], labels_df['term']):
            col = column_of.get(term)
            if col is not None:
                columns_by_entry.setdefault(entry_id, set()).add(col)

        # Create label vectors
        n_terms = len(self.top_terms)
        label_vectors = []
        for entry_id in self.df['EntryID']:
            vector = [0] * n_terms
            for col in columns_by_entry.get(entry_id, ()):
                vector[col] = 1
            label_vectors.append(vector)

        self.df['labels_vect'] = label_vectors

    def __len__(self):
        """Return the number of proteins in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(embedding, labels)`` for training data, else ``(embedding, EntryID)``."""
        row = self.df.iloc[idx]
        embed = torch.tensor(row["embed"], dtype=torch.float32)

        if self.datatype == "train":
            labels = torch.tensor(row["labels_vect"], dtype=torch.float32)
            return embed, labels
        else:
            return embed, row["EntryID"]

# ====================== MODEL ARCHITECTURES ======================
class ImprovedMLP(nn.Module):
    """Feed-forward classifier: three Linear/BatchNorm/ReLU/Dropout blocks plus a linear head.

    Hidden widths are 1024, 512 and 256. The head returns raw logits
    (no activation is applied to the output).
    """

    def __init__(self, input_dim, num_classes, dropout_rate=0.3):
        super().__init__()

        # Register fc{i} / bn{i} / dropout{i} for i = 1..3, in that order
        widths = (input_dim, 1024, 512, 256)
        for idx, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"fc{idx}", nn.Linear(n_in, n_out))
            setattr(self, f"bn{idx}", nn.BatchNorm1d(n_out))
            setattr(self, f"dropout{idx}", nn.Dropout(dropout_rate))

        # Output layer
        self.fc4 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a batch of feature vectors to a batch of logits."""
        hidden_blocks = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )

        hidden = x
        for linear, norm, drop in hidden_blocks:
            hidden = drop(F.relu(norm(linear(hidden))))

        return self.fc4(hidden)

class AttentionCNN(nn.Module):
    """1D convolutional stack followed by multi-head self-attention and an MLP head.

    The input feature vector is treated as a single-channel signal. Three
    convolution blocks produce 256 channels, self-attention runs over the
    remaining positions, and the position-wise mean feeds a two-layer head
    that returns raw logits.

    Note: ``input_dim`` is accepted but not used by any layer, and ``pool3``
    is registered but never called in ``forward``.
    """

    def __init__(self, input_dim, num_classes, dropout_rate=0.3):
        super().__init__()

        # Block 1: 1 -> 64 channels, length halved by pooling
        self.conv1 = nn.Conv1d(1, 64, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(64)
        self.pool1 = nn.MaxPool1d(2)

        # Block 2: 64 -> 128 channels, length halved again
        self.conv2 = nn.Conv1d(64, 128, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(128)
        self.pool2 = nn.MaxPool1d(2)

        # Block 3: 128 -> 256 channels
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)
        self.pool3 = nn.AdaptiveAvgPool1d(1)

        # Self-attention over the 256-dimensional position features
        self.attention = nn.MultiheadAttention(256, num_heads=8, dropout=dropout_rate, batch_first=True)

        # Classification head
        self.fc1 = nn.Linear(256, 512)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of feature vectors to a batch of logits."""
        # Add a channel axis so Conv1d sees (batch, 1, features)
        feats = x.unsqueeze(1)

        feats = self.pool1(F.relu(self.bn1(self.conv1(feats))))
        feats = self.pool2(F.relu(self.bn2(self.conv2(feats))))
        feats = F.relu(self.bn3(self.conv3(feats)))

        # Attention is batch-first: move channels last -> (batch, positions, 256)
        tokens = feats.transpose(1, 2)
        attended, _ = self.attention(tokens, tokens, tokens)

        # Average over positions to get one vector per sample
        pooled = attended.mean(dim=1)

        hidden = self.dropout(F.relu(self.fc1(pooled)))
        return self.fc2(hidden)

class HybridModel(nn.Module):
    """Two-branch classifier that fuses convolutional and bidirectional-LSTM features.

    Both branches see the input as ``(batch, 1, input_dim)``: the CNN reads it
    as one channel of length ``input_dim``, the LSTM as a one-step sequence
    with ``input_dim`` features. Their outputs are concatenated and passed
    through a three-layer head that returns raw logits.
    """

    def __init__(self, input_dim, num_classes, dropout_rate=0.3):
        super().__init__()

        # Convolutional branch, pooled to a fixed length of 128 positions
        self.conv1 = nn.Conv1d(1, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveMaxPool1d(128)

        # Recurrent branch: 2-layer bidirectional LSTM, 256 hidden units per direction
        self.lstm = nn.LSTM(input_dim, 256, num_layers=2,
                            bidirectional=True, dropout=dropout_rate, batch_first=True)

        # Head: 64 channels * 128 positions from the CNN + 2 * 256 from the LSTM
        self.fc1 = nn.Linear(64 * 128 + 512, 512)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map a batch of feature vectors to a batch of logits."""
        n_samples = x.shape[0]
        # Same (batch, 1, input_dim) view is fed to both branches
        expanded = x.unsqueeze(1)

        # CNN branch, flattened to one vector per sample
        conv_feats = F.relu(self.conv2(F.relu(self.conv1(expanded))))
        conv_feats = self.pool(conv_feats).view(n_samples, -1)

        # LSTM branch: keep the output of the final time step
        recurrent_out, _ = self.lstm(expanded)
        last_step = recurrent_out[:, -1, :]

        fused = torch.cat([conv_feats, last_step], dim=1)

        hidden = self.dropout(F.relu(self.fc1(fused)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden)

# ====================== TRAINING FUNCTIONS ======================
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Returns:
        The unweighted mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []

    for batch_inputs, batch_targets in tqdm(dataloader, desc="Training"):
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # Standard step: clear gradients, forward, backward, update
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())

    return np.mean(batch_losses)

def validate_epoch(model, dataloader, criterion, metric, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    The metric is accumulated over all batches and computed once at the end,
    so the returned score describes the whole validation set. Averaging
    per-batch scores would instead give a short final batch the same weight
    as a full one.

    Args:
        model: Network returning logits for a batch of embeddings.
        dataloader: Iterable yielding ``(embeddings, labels)`` batches.
        criterion: Loss taking ``(logits, labels)``.
        metric: Object exposing ``update(preds, target)``, ``compute()`` and
            ``reset()`` (the torchmetrics ``Metric`` interface).
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, score)``; both are NaN when the loader is empty.
    """
    model.eval()
    losses = []

    # Drop anything accumulated by an earlier epoch or caller
    metric.reset()

    with torch.no_grad():
        for embeddings, labels in dataloader:
            embeddings, labels = embeddings.to(device), labels.to(device)

            outputs = model(embeddings)
            losses.append(criterion(outputs, labels).item())

            # Accumulate statistics only; the score is computed after the loop
            metric.update(torch.sigmoid(outputs), labels.int())

    if not losses:
        return float("nan"), float("nan")

    score = float(metric.compute())
    # Leave the metric clean for the next call
    metric.reset()

    return np.mean(losses), score

def train_model(embeddings_source="ESM2", model_type="hybrid", train_ratio=0.9):
    """Train one model on the training embeddings and return it with its history.

    The data is split randomly into training and validation parts. After every
    epoch the model is validated, and the weights with the highest validation
    F1 are restored before returning.

    Args:
        embeddings_source: Key of the embedding set to load ("ESM2",
            "ProtBERT" or "T5").
        model_type: Architecture to build: "mlp", "cnn" or "hybrid".
        train_ratio: Fraction of the dataset used for training; the rest is
            used for validation.

    Returns:
        ``(model, history)`` where ``history`` has the keys ``train_losses``,
        ``val_losses``, ``val_scores`` (one entry per epoch) and
        ``best_score``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    print(f"\nTraining {model_type} model with {embeddings_source} embeddings...")

    # Create dataset
    dataset = ProteinDataset("train", embeddings_source, config)

    # Split into train and validation
    train_size = int(len(dataset) * train_ratio)
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=2)

    # Initialize model
    input_dim = config.embeds_dim[embeddings_source]

    if model_type == "mlp":
        model = ImprovedMLP(input_dim, config.num_labels, config.dropout_rate)
    elif model_type == "cnn":
        model = AttentionCNN(input_dim, config.num_labels, config.dropout_rate)
    elif model_type == "hybrid":
        model = HybridModel(input_dim, config.num_labels, config.dropout_rate)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model = model.to(config.device)

    # Loss and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=config.patience)
    metric = MultilabelF1Score(num_labels=config.num_labels, average='micro').to(config.device)

    # Training loop
    best_val_score = 0
    best_model_state = None
    train_losses, val_losses, val_scores = [], [], []

    for epoch in range(config.n_epochs):
        print(f"\nEpoch {epoch + 1}/{config.n_epochs}")

        # Train
        train_loss = train_epoch(model, train_loader, criterion, optimizer, config.device)
        train_losses.append(train_loss)

        # Validate
        val_loss, val_score = validate_epoch(model, val_loader, criterion, metric, config.device)
        val_losses.append(val_loss)
        val_scores.append(val_score)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model
        if val_score > best_val_score:
            best_val_score = val_score
            # state_dict() returns tensors that share storage with the live
            # parameters, and dict.copy() is shallow, so later optimizer steps
            # would overwrite the "best" weights. Clone each tensor to take a
            # real snapshot.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val F1: {val_score:.4f}")

    # Load best model (stays at the final weights if no epoch scored above 0)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_scores': val_scores,
        'best_score': best_val_score
    }

# ====================== PREDICTION ======================
def predict(model, embeddings_source="ESM2"):
    """Generate GO-term predictions for the test set.

    The model is run in batches. Every (protein, term) pair whose sigmoid
    confidence exceeds 0.01 is kept, with the confidence capped at 0.999.

    Args:
        model: Trained network returning one logit per label.
        embeddings_source: Key of the embedding set to load for the test data.

    Returns:
        DataFrame with columns ``Id``, ``GO term`` and ``Confidence``. Rows
        are ordered by protein, then by label index. The columns are present
        even when no prediction passes the threshold.
    """

    print("\nGenerating predictions...")

    # Load test dataset; batching avoids one forward pass per protein
    test_dataset = ProteinDataset("test", embeddings_source, config)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    # Get label names (same ranking expression as used when building labels)
    labels_df = pd.read_csv(config.train_labels_path, sep="\t", names=["EntryID", "term", "aspect"])
    top_terms = labels_df.groupby("term")["EntryID"].count().sort_values(ascending=False)
    label_names = np.asarray(top_terms[:config.num_labels].index.tolist(), dtype=object)

    # Generate predictions
    model.eval()
    id_chunks, term_chunks, conf_chunks = [], [], []

    with torch.no_grad():
        for embeddings, protein_ids in tqdm(test_loader, desc="Predicting"):
            embeddings = embeddings.to(config.device)
            probs = torch.sigmoid(model(embeddings)).cpu().numpy()
            # Keep an explicit (batch, labels) shape, also for single-item
            # batches or a single label
            probs = probs.reshape(len(protein_ids), -1)

            # Only include predictions above threshold; nonzero() walks the
            # matrix row by row, i.e. protein by protein
            rows, cols = np.nonzero(probs > 0.01)
            batch_ids = np.asarray(list(protein_ids), dtype=object)

            id_chunks.append(batch_ids[rows])
            term_chunks.append(label_names[cols])
            # Cap at 0.999
            conf_chunks.append(np.minimum(probs[rows, cols].astype(np.float64), 0.999))

    columns = ['Id', 'GO term', 'Confidence']
    if not id_chunks:
        return pd.DataFrame(columns=columns)

    return pd.DataFrame({
        'Id': np.concatenate(id_chunks),
        'GO term': np.concatenate(term_chunks),
        'Confidence': np.concatenate(conf_chunks),
    }, columns=columns)

# ====================== ENSEMBLE ======================
def ensemble_predictions(predictions_list, weights=None):
    """Combine several prediction tables by weighted averaging.

    For every (Id, GO term) pair the confidences of the models that predicted
    it are averaged, each weighted by its own model's weight. Models that did
    not predict a pair do not contribute to that pair's average.

    Args:
        predictions_list: DataFrames with columns ``Id``, ``GO term`` and
            ``Confidence``, one per model.
        weights: One weight per DataFrame, in the same order. Defaults to
            equal weights.

    Returns:
        DataFrame with columns ``Id``, ``GO term`` and ``Confidence``, one row
        per distinct pair, sorted by ``Id`` then ``GO term``.

    Raises:
        ValueError: If ``predictions_list`` is empty or the number of weights
            does not match the number of DataFrames.
    """
    if len(predictions_list) == 0:
        raise ValueError("predictions_list must contain at least one DataFrame")

    if weights is None:
        weights = [1 / len(predictions_list)] * len(predictions_list)
    elif len(weights) != len(predictions_list):
        raise ValueError(
            f"Expected {len(predictions_list)} weights, got {len(weights)}"
        )

    columns = ['Id', 'GO term', 'Confidence']

    # Tag every row with the weight of the model it came from, so the weight
    # stays attached to the right model once the tables are concatenated.
    tagged = [
        preds[columns].assign(_weight=float(weight))
        for preds, weight in zip(predictions_list, weights)
        if len(preds) > 0
    ]
    if not tagged:
        return pd.DataFrame(columns=columns)

    combined = pd.concat(tagged, ignore_index=True)
    combined['_weighted'] = combined['Confidence'] * combined['_weight']

    # Weighted mean per pair = sum(weight * confidence) / sum(weight)
    sums = combined.groupby(['Id', 'GO term'], as_index=False)[['_weighted', '_weight']].sum()
    sums['Confidence'] = sums['_weighted'] / sums['_weight']

    return sums[columns]

# ====================== MAIN EXECUTION ======================
def main():
    """Train, predict, optionally merge with an existing submission, and save.

    Steps:
        1. Train a hybrid model on the ESM2 embeddings if they are present.
        2. Predict the test set with every trained model and ensemble the
           results when there is more than one.
        3. If an existing submission file is found, merge it in by taking the
           larger confidence for each (Id, GO term) pair.
        4. Write ``submission.tsv`` and plot the validation curves to
           ``training_results.png``.

    Raises:
        RuntimeError: If no model could be trained because the embeddings
            were not found.
    """

    print("Starting CAFA 6 Protein Function Prediction")
    print("=" * 60)

    # Train models with different architectures and embeddings
    models_and_results = []

    # Train with ESM2 embeddings
    esm2_train_file = f"/kaggle/input/{config.embeds_map['ESM2']}/train_embeddings.npy"
    if os.path.exists(esm2_train_file):
        model_esm2, results_esm2 = train_model("ESM2", "hybrid")
        models_and_results.append(("ESM2", model_esm2, results_esm2))

    # Without a trained model there is nothing to predict; stop with a clear
    # message instead of failing later on an empty list.
    if not models_and_results:
        raise RuntimeError(
            f"No model was trained: embeddings file not found at {esm2_train_file}"
        )

    # Generate predictions
    all_predictions = []

    for embed_source, model, results in models_and_results:
        print(f"\nBest validation F1 for {embed_source}: {results['best_score']:.4f}")
        preds = predict(model, embed_source)
        all_predictions.append(preds)

    # Combine predictions if multiple models
    if len(all_predictions) > 1:
        final_predictions = ensemble_predictions(all_predictions)
    else:
        final_predictions = all_predictions[0]

    # Load and merge with existing predictions if available
    existing_pred_path = '/kaggle/input/blast-quick-sprof-zero-pred/submission.tsv'
    if os.path.exists(existing_pred_path):
        print("\nMerging with existing predictions...")
        existing = pd.read_csv(existing_pred_path, sep='\t', header=None,
                              names=['Id', 'GO term', 'Confidence'])

        # Outer merge keeps pairs that appear in only one of the two sources
        merged = pd.merge(existing, final_predictions, on=['Id', 'GO term'],
                         how='outer', suffixes=('_existing', '_new'))

        # Take the larger confidence; max() skips the NaN of a missing side
        merged['Confidence'] = merged[['Confidence_existing', 'Confidence_new']].max(axis=1)
        final_predictions = merged[['Id', 'GO term', 'Confidence']]

    # Save submission
    print("\nSaving submission file...")
    final_predictions.to_csv('submission.tsv', sep='\t', header=False, index=False)
    print(f"Submission saved with {len(final_predictions)} predictions")

    # Visualize training results: validation loss (left) and F1 (right)
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    for embed_source, _, results in models_and_results:
        plt.plot(results['val_losses'], label=f"{embed_source}")
    plt.title('Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    for embed_source, _, results in models_and_results:
        plt.plot(results['val_scores'], label=f"{embed_source}")
    plt.title('Validation F1 Score')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_results.png')
    plt.show()

    print("\nDone! Submission file created: submission.tsv")

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()