In [None]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import os

# TORCH MODULES
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics.classification import MultilabelF1Score

# --- 1. SETUP & CONFIGURATION ---

In [None]:
# Every constant and hyperparameter lives here so the notebook is tuned from one place.
class config:
    # Root of the competition's input data.
    MAIN_DIR = "/kaggle/input/cafa-6-protein-function-prediction"

    # Model / optimisation hyperparameters.
    num_labels = 500   # number of GO-term outputs (width of the target vectors)
    n_epochs = 8
    batch_size = 128
    lr = 0.001         # Adam learning rate; kept low for a more stable start

    # Prefer the GPU whenever torch can see one.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


print(f"Using device: {config.device}")

# One row per embedding source:
#   (key used throughout this notebook, dataset folder under /kaggle/input, embedding width)
_EMBEDDING_SOURCES = (
    ("T5", "t5embeds", 1024),
    ("ProtBERT", "protbert-embeddings-for-cafa5", 1024),
    ("EMS2", "cafa-5-ems-2-embeddings-numpy", 1280),
)

embeds_map = {key: folder for key, folder, _width in _EMBEDDING_SOURCES}
embeds_dim = {key: width for key, _folder, width in _EMBEDDING_SOURCES}

# --- 2. DATA LOADING: PROTEIN SEQUENCE DATASET ---

In [None]:
# This PyTorch Dataset class is responsible for loading the protein embeddings and 
# their corresponding labels. It handles both training data (embeddings + labels) 
# and test data (embeddings + protein IDs).
class ProteinSequenceDataset(Dataset):
    """Per-protein embeddings, plus multi-hot GO-term targets for the training split.

    Args:
        datatype: split name used in the file names (``<datatype>_ids.npy`` etc.).
            ``"train"`` additionally loads the top-``config.num_labels`` target matrix;
            any other value (the notebook uses ``"test"``) yields protein IDs instead
            of targets.
        embeddings_source: key of ``embeds_map`` selecting the dataset folder under
            ``/kaggle/input``.

    Each item is ``(embedding, targets)`` (both float32 tensors) for ``"train"`` and
    ``(embedding, protein_id)`` otherwise. ``self.df`` holds the columns ``EntryID``,
    ``embed`` and, for training, ``labels_vect``.

    Raises:
        ValueError: if the embedding, ID and label files disagree on the number of rows.
    """

    def __init__(self, datatype, embeddings_source):
        super().__init__()
        self.datatype = datatype

        base_path = f"/kaggle/input/{embeds_map[embeddings_source]}/"
        # The T5 dataset names its matrix `<datatype>_embeds.npy`; the others use
        # `<datatype>_embeddings.npy`.
        embeds_stem = "embeds" if embeddings_source == "T5" else "embeddings"
        embeds = np.load(os.path.join(base_path, f"{datatype}_{embeds_stem}.npy"))
        ids = np.load(os.path.join(base_path, f"{datatype}_ids.npy"))
        if len(embeds) != len(ids):
            raise ValueError(
                f"{embeddings_source}/{datatype}: {len(embeds)} embeddings but {len(ids)} IDs"
            )

        columns = {"EntryID": ids, "embed": list(embeds)}
        labels = None
        if datatype == "train":
            labels_path = f"/kaggle/input/train-targets-top{config.num_labels}/train_targets_top{config.num_labels}.npy"
            labels = np.load(labels_path)
            # Label rows are aligned with `ids` purely by row position.
            if len(labels) != len(ids):
                raise ValueError(
                    f"{labels_path}: {len(labels)} label rows but {len(ids)} IDs"
                )
            columns["labels_vect"] = list(labels)
        self.df = pd.DataFrame(columns)

        # Raw arrays back __getitem__: indexing them is far cheaper than materialising
        # a pandas row with `df.iloc[index]` for every sample the DataLoader requests.
        self._embeds = embeds
        self._labels = labels
        # `tolist()` turns numpy string scalars into plain Python objects.
        self._ids = ids.tolist()

    def __len__(self):
        return len(self._ids)

    def __getitem__(self, index):
        embed = torch.tensor(self._embeds[index], dtype=torch.float32)
        if self.datatype == "train":
            targets = torch.tensor(self._labels[index], dtype=torch.float32)
            return embed, targets
        # Any non-"train" split (the notebook uses "test"): return the protein ID.
        return embed, self._ids[index]

# --- 3. MODEL ARCHITECTURE: 1D CONVOLUTIONAL NEURAL NETWORK (CNN) ---

In [None]:
# The model is a 1D CNN designed to find patterns in the sequence embeddings. Key features include:
# - `Conv1d` layers for feature extraction.
# - `BatchNorm1d` to stabilize and speed up training.
# - `ReLU` as the activation function.
# - `MaxPool1d` for down-sampling.
# - `Dropout` in the final fully-connected block to prevent overfitting.
class CNN1D(nn.Module):
    """1D CNN that treats one embedding vector as a single-channel sequence.

    Args:
        input_dim: length of each input embedding vector (must be >= 4 so that both
            pooling stages have something to pool).
        num_classes: number of output logits.
        hidden_dim: width of the hidden fully connected layer (default 1024).
        dropout: dropout probability before the output layer (default 0.4).

    Input shape is ``(batch, input_dim)``; output is ``(batch, num_classes)`` raw logits.

    Raises:
        ValueError: if ``input_dim < 4``.
    """

    def __init__(self, input_dim, num_classes, hidden_dim=1024, dropout=0.4):
        super().__init__()
        if input_dim < 4:
            raise ValueError(f"input_dim must be >= 4, got {input_dim}")

        # Both convolutions are length-preserving (padding = kernel_size // 2);
        # each MaxPool1d(2, 2) halves the length, rounding down.
        self.conv_block1 = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )
        self.conv_block2 = nn.Sequential(
            nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )

        # Integer division mirrors MaxPool1d's floor, so the Linear input size is
        # correct even when input_dim is not a multiple of 4.
        pooled_len = (input_dim // 2) // 2
        flattened_size = 64 * pooled_len

        self.fc_block = nn.Sequential(
            nn.Linear(in_features=flattened_size, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),  # regularises the large dense layer
            nn.Linear(in_features=hidden_dim, out_features=num_classes)
        )

    def forward(self, x):
        # (batch_size, embed_size) -> (batch_size, 1, embed_size)
        x = x.unsqueeze(1)
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = torch.flatten(x, 1)
        return self.fc_block(x)

# --- 4. MODEL TRAINING ---

In [None]:
# This function encapsulates the entire training and validation loop. Key improvements include:
# - Loss Function: Using `BCEWithLogitsLoss`, which is essential for multi-label classification.
# - Metric-driven Learning: The learning rate scheduler adjusts based on the validation F1-score.
# - Optimal Threshold Finding: During validation, the code iterates through different thresholds 
#   to find the one that maximizes the F1-score for the current epoch.
def _train_one_epoch(model, dataloader, optimizer, loss_fn):
    """Run one optimisation pass over `dataloader`; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for embed, targets in tqdm(dataloader, desc="Training"):
        embed, targets = embed.to(config.device), targets.to(config.device)

        optimizer.zero_grad()
        loss = loss_fn(model(embed), targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


def _collect_val_probs(model, dataloader):
    """Return (sigmoid probabilities, targets) for every sample in `dataloader`, concatenated."""
    model.eval()
    probs, all_targets = [], []
    with torch.no_grad():
        for embed, targets in tqdm(dataloader, desc="Validation"):
            embed, targets = embed.to(config.device), targets.to(config.device)
            probs.append(torch.sigmoid(model(embed)))
            all_targets.append(targets)
    return torch.cat(probs), torch.cat(all_targets)


def _search_threshold(probs, targets, num_labels, thresholds):
    """Return (best macro F1, threshold achieving it) over the candidate `thresholds`.

    The first candidate wins ties and is returned even if every F1 is 0.
    """
    from torchmetrics.functional.classification import multilabel_f1_score

    int_targets = targets.int()
    best_f1, best_thresh = -1.0, float(thresholds[0])
    for thresh in thresholds:
        f1 = multilabel_f1_score(
            probs, int_targets, num_labels=num_labels,
            threshold=float(thresh), average="macro",
        ).item()
        if f1 > best_f1:
            best_f1, best_thresh = f1, float(thresh)
    return best_f1, best_thresh


def train_model(embeddings_source, model_type="convolutional", train_size=0.9, seed=None):
    """Train a CNN1D on one embedding source and return its best validation checkpoint.

    The training split is divided randomly into train / validation parts. After every
    epoch, macro F1 on the validation part is evaluated at thresholds 0.10 to 0.50 in
    steps of 0.05. The best (F1, threshold) pair drives the LR scheduler and
    checkpointing.

    Args:
        embeddings_source: key of `embeds_map` / `embeds_dim` (e.g. "EMS2").
        model_type: only "convolutional" is supported.
        train_size: fraction of samples used for training; the remainder validates.
            Must lie strictly between 0 and 1.
        seed: optional seed for torch's global RNG (split, weight init, shuffling,
            dropout). None (default) leaves the RNG untouched.

    Returns:
        (model, best_threshold). The model carries the weights of its best validation
        epoch, which are also saved to "best_model.pth". best_threshold is the threshold
        that produced that F1. If no epochs run, the freshly initialised model and 0.5
        are returned.

    Raises:
        ValueError: unsupported `model_type`, `train_size` outside (0, 1), or a split
            that leaves the train or validation part empty.
    """
    # Validate cheap arguments before the expensive embedding load.
    if model_type != "convolutional":
        raise ValueError("Unsupported model type")
    if not 0.0 < train_size < 1.0:
        raise ValueError(f"train_size must be strictly between 0 and 1, got {train_size}")
    if seed is not None:
        torch.manual_seed(seed)

    train_dataset = ProteinSequenceDataset(datatype="train", embeddings_source=embeddings_source)
    n_train = int(len(train_dataset) * train_size)
    n_val = len(train_dataset) - n_train
    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"train_size={train_size} on {len(train_dataset)} samples leaves an empty split"
        )
    train_set, val_set = random_split(train_dataset, lengths=[n_train, n_val])

    train_dataloader = DataLoader(train_set, batch_size=config.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_set, batch_size=config.batch_size, shuffle=False)

    model = CNN1D(input_dim=embeds_dim[embeddings_source], num_classes=config.num_labels).to(config.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    # mode='max': the monitored quantity is validation F1 (higher is better).
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=1)
    # Multi-label objective: one independent sigmoid/BCE term per label.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    thresholds = np.arange(0.1, 0.51, 0.05)

    print("STARTING TRAINING...")

    # Any real F1 (>= 0) beats -1, so the first completed epoch always checkpoints.
    best_val_f1 = -1.0
    best_threshold = 0.5
    best_state = None

    for epoch in range(config.n_epochs):
        print(f"EPOCH {epoch+1}/{config.n_epochs}")

        avg_train_loss = _train_one_epoch(model, train_dataloader, optimizer, loss_fn)
        print(f"Average Training Loss: {avg_train_loss:.4f}")

        val_probs, val_targets = _collect_val_probs(model, val_dataloader)
        epoch_f1, epoch_thresh = _search_threshold(
            val_probs, val_targets, config.num_labels, thresholds
        )
        print(f"Average Validation F1-Score: {epoch_f1:.4f} (at best threshold: {epoch_thresh:.2f})")

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(epoch_f1)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

        if epoch_f1 > best_val_f1:
            best_val_f1, best_threshold = epoch_f1, epoch_thresh
            # Detached copies so later optimizer steps do not alter the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, "best_model.pth")
            print(f"New best model saved! F1: {best_val_f1:.4f}")

    print("\nTRAINING FINISHED")
    if best_state is None:
        print("No epochs were run; returning the untrained model.")
    else:
        print(f"Highest Validation F1-Score: {best_val_f1:.4f}")
        print(f"Best threshold for this score: {best_threshold:.2f}")
        # Restore the best-scoring weights from memory (identical to best_model.pth).
        model.load_state_dict(best_state)

    return model, best_threshold

# Train the 1D CNN on the "EMS2" embeddings; keep the best checkpoint and its tuned threshold.
ems2_model, best_threshold = train_model("EMS2", model_type="convolutional")

# --- 5. GENERATING PREDICTIONS ---

In [None]:
# After training, the best model is used to generate predictions on the test set. 
# The saved optimal threshold from the validation phase is used to convert model 
# probabilities into final binary predictions.
def predict(model, embeddings_source, threshold):
    """Score the test split and keep every (protein, GO term) above `threshold`.

    Args:
        model: trained network mapping a batch of embeddings to per-term logits.
        embeddings_source: key of `embeds_map` for the test embeddings.
        threshold: probability cut-off; a term is kept when sigmoid(logit) > threshold.

    Returns:
        DataFrame with columns "Id", "GO term" and "Confidence", one row per kept
        prediction. Rows are ordered by protein, then by term index. When nothing
        passes the threshold, the frame is empty but still has those columns, so
        downstream merges on "Id" / "GO term" keep working.

    Raises:
        ValueError: if the model's output width differs from the number of term names
            recovered from train_terms.tsv.
    """
    test_dataset = ProteinSequenceDataset(datatype="test", embeddings_source=embeddings_source)
    test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    model.eval()

    # Output column i is mapped to the i-th most frequent term in train_terms.tsv.
    # NOTE(review): this must reproduce the column order of the
    # train_targets_top{N}.npy file used for training (including how ties are broken);
    # confirm against the script that generated that file.
    labels_df = pd.read_csv(os.path.join(config.MAIN_DIR, "Train/train_terms.tsv"), sep="\t")
    top_terms = labels_df.groupby("term")["EntryID"].count().sort_values(ascending=False)
    labels_names = top_terms.head(config.num_labels).index.values

    print("\nGENERATING PREDICTIONS FOR THE TEST SET...")

    results = []
    with torch.no_grad():
        for embed, ids in tqdm(test_dataloader, desc="Predicting"):
            preds_probs = torch.sigmoid(model(embed.to(config.device))).cpu().numpy()
            if preds_probs.shape[1] != len(labels_names):
                raise ValueError(
                    f"model outputs {preds_probs.shape[1]} scores per protein but "
                    f"{len(labels_names)} term names were recovered"
                )

            # np.nonzero yields (row, col) pairs in row-major order: protein by protein,
            # term indices ascending.
            rows, cols = np.nonzero(preds_probs > threshold)
            results.extend(
                {"Id": ids[r], "GO term": labels_names[c], "Confidence": preds_probs[r, c]}
                for r, c in zip(rows, cols)
            )

    submission_df = pd.DataFrame(results, columns=["Id", "GO term", "Confidence"])
    print("PREDICTIONS COMPLETE.")
    return submission_df

submission_df = predict(model=ems2_model, embeddings_source="EMS2", threshold=best_threshold)

# --- 6. SUBMISSION FILE GENERATION ---

In [None]:
# The final step is to create the `submission.tsv` file by merging our model's
# predictions with an external, pre-existing submission file. When both files score
# the same (Id, GO term) pair, the external file's confidence wins; `.fillna()` uses
# our model's confidence only for pairs the external file does not cover.
print("\nMerging submission files...")

# Headerless TSV of (Id, GO term, confidence) rows from an external pipeline.
submission2 = pd.read_csv(
    '/kaggle/input/blast-quick-sprof-zero-pred/submission.tsv',
    sep='\t',
    header=None,
    names=['Id', 'GO term', 'Confidence2'],
)

# Outer join: every (Id, GO term) pair present in either submission survives.
subs = submission_df.merge(submission2, how='outer', on=['Id', 'GO term'])

# External confidence takes precedence; ours fills only the gaps it leaves.
subs['Confidence_combined'] = subs['Confidence2'].fillna(subs['Confidence'])

# Keep the three submission columns and write them without header or index.
final_submission = subs.loc[:, ['Id', 'GO term', 'Confidence_combined']]
final_submission.to_csv('submission.tsv', sep='\t', header=False, index=False)

print("Submission file 'submission.tsv' created successfully!")
print(f"It contains {len(final_submission)} predictions in total.")