# Fully Convolutional PyTorch Model to Predict Protein Secondary Structure

### Overview

Your overall goal is to write a fully convolutional PyTorch model that takes protein sequence data (often called the protein primary structure), optionally together with PSSM profiles, and predicts the protein secondary structure of each residue (H = Helix, E = Extended Sheet, C = Coil).

The PDB Database contains the structures of over 200,000 proteins. Each has a unique PDB_ID code such as 1A0S (the first one in the training data), which is the structure shown above: the sucrose-specific porin of Salmonella, a protein that transfers sucrose across the cell membrane of Salmonella bacteria, which cause food poisoning. Its 3D structure shows that most of this protein is extended beta sheet (flat arrows) and coil (random lines).

The Data Tab on Kaggle lets you browse the data available for training; use it to get a feel for the data before you start. You will find a seqs_train.csv file, which gives the PDB_ID (unique identifier) and the SEQUENCE of each protein. You will also find a train.zip file, which contains a large collection of 'PDB_ID'_train.csv files holding the residue number, amino acid and PSSM profile for each residue in that particular protein. The labels_train.csv file contains the secondary structure labels for the training proteins (given as H = Helix, E = Extended Sheet, C = Coil symbols). The seqs_test.csv and test.zip files contain similar data for the test sequences, for which you need to predict the secondary structure.

IN ADDITION - you will also need to submit your Jupyter Notebook that produces these outputs via the Moodle web page.

Please see the Moodle course site for further details about this coursework.
<br>
### Evaluation
The evaluation metric is the "Q3 Accuracy" which is used for assessing the three states within a protein structure prediction (H = Helix, E = Extended Sheet, C = Coil). <br>
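
As a rough illustration, Q3 accuracy can be computed as the fraction of residues whose predicted state matches the true state. The sketch below assumes the score is pooled over residues; the function name `q3_accuracy` and the toy label strings are illustrative only, and the exact aggregation used by the competition scorer is not stated here.

```python
def q3_accuracy(true_labels, predicted_labels):
    """Fraction of residues whose predicted state (H, E or C) matches the true state."""
    if len(true_labels) != len(predicted_labels):
        raise ValueError("Label strings must have the same length.")
    if not true_labels:
        raise ValueError("Cannot score an empty sequence.")
    correct = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return correct / len(true_labels)


# Toy example (made-up labels, not taken from the dataset)
example_true = "CCHHHHEEEC"
example_pred = "CCHHHCEEEC"  # one residue differs from the true labels
example_q3 = q3_accuracy(example_true, example_pred)
print(f"Q3 = {example_q3:.2f}")
```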

### Submission File
For each PDB_ID in the test set, you must predict the secondary structure of each residue in that protein. The file should contain a header and have the following format:

(The ID column consists of the 'PDB_ID', an underscore, and the residue number; the STRUCTURE column gives the predicted secondary structure label of that residue.)

ID,STRUCTURE <br>
2AIO_1_A_1, C <br>
2AIO_1_A_2, C <br>
2AIO_1_A_3, C <br>
2AIO_1_A_4, H <br>
2AIO_1_A_5, H <br>
etc. <br>
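
A minimal sketch of how such rows could be assembled, assuming residues are numbered from 1 as in the example above. The helper `build_submission_rows` is hypothetical, and the exact formatting (for instance whether a space follows the comma) should be checked against the sample file provided with the data.

```python
import csv
import io


def build_submission_rows(pdb_id, predicted_structure):
    """Return (ID, STRUCTURE) pairs, numbering residues from 1."""
    return [
        (f"{pdb_id}_{position}", label)
        for position, label in enumerate(predicted_structure, start=1)
    ]


rows = build_submission_rows("2AIO_1_A", "CCCHH")

# Write to an in-memory buffer here; a real submission would write to a file
buffer = io.StringIO()
writer = csv.writer(buffer, lineterminator="\n")
writer.writerow(["ID", "STRUCTURE"])
writer.writerows(rows)
submission_text = buffer.getvalue()
print(submission_text)
```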

## Import necessary libraries and store file paths

- The code imports necessary libraries for this project.
- It checks for the availability of a CUDA-enabled GPU and sets the device accordingly.
- It defines file paths for data files (for local use).

In [None]:
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, random_split

from sklearn.feature_selection import f_classif
from sklearn.preprocessing import LabelEncoder

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.tutorials.cnn_utils import evaluate, load_mnist, train
from ax.utils.notebook.plotting import init_notebook_plotting, render
init_notebook_plotting(offline=True)

# # Define dtype
# dtype = torch.float32  # or torch.float64 depending on your preference

# Check if a CUDA-enabled GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("CUDA is available. Using GPU for accelerated computations.")
    print("Number of available CUDA devices:", torch.cuda.device_count())
    print("GPU device name:", torch.cuda.get_device_name(0))  # Assuming for only one GPU, index 0
else:
    device = torch.device("cpu")
    print("CUDA is not available. Using CPU for computations.")
# Display the device being used
print(f"Using device: {device}")


# Define data file paths
DATA_PATH = "./data/"
labels_train_path = DATA_PATH + "labels_train.csv"
sample_path = DATA_PATH + "sample.csv"
seqs_test_path = DATA_PATH + "seqs_test.csv"
seqs_train_path = DATA_PATH + "seqs_train.csv"
train_path = DATA_PATH + "train"
test_path = DATA_PATH + "test"

## Define a mapping from amino acid characters to integers

To enable model training and assessment, the categorical data must be converted into a numerical representation. Two mappings are defined for this:
- `sec_struct_mapping`: A dictionary mapping secondary structure labels ('H' for Helix, 'E' for Extended Sheet, 'C' for Coil) to integer labels (0, 1, 2 respectively). Additional mappings can be added if there are more labels.
- `amino_acid_mapping`: A dictionary mapping amino acid characters to integer labels. Each amino acid is assigned a unique integer, with additional mappings provided for special cases such as unknown amino acids ('X'), ambiguous cases ('B', 'Z', 'J'), and gap or padding ('-').




In [None]:
# Define mappings for secondary structure and amino acids
sec_struct_mapping = {
    'H': 0,  # Helix
    'E': 1,  # Strand
    'C': 2   # Coil
}

amino_acid_mapping = {
    'A': 0,   # Alanine
    'C': 1,   # Cysteine
    'D': 2,   # Aspartic acid
    'E': 3,   # Glutamic acid
    'F': 4,   # Phenylalanine
    'G': 5,   # Glycine
    'H': 6,   # Histidine
    'I': 7,   # Isoleucine
    'K': 8,   # Lysine
    'L': 9,   # Leucine
    'M': 10,  # Methionine
    'N': 11,  # Asparagine
    'P': 12,  # Proline
    'Q': 13,  # Glutamine
    'R': 14,  # Arginine
    'S': 15,  # Serine
    'T': 16,  # Threonine
    'V': 17,  # Valine
    'W': 18,  # Tryptophan
    'Y': 19,  # Tyrosine
    'X': 20,  # Unknown amino acid
    'B': 21,  # Asparagine or Aspartic acid
    'Z': 22,  # Glutamine or Glutamic acid
    'J': 23,  # Leucine or Isoleucine
    '-': 24   # Gap or padding
}

## Building the Dataset for Protein Sequences and PSSM Data

Using a custom `Dataset` class allows for efficient management of protein data specifically tailored to your CNN model's needs. It encapsulates data loading, preprocessing, and access logic in a reusable way.
### Separate Data Sources:
- The code assumes protein sequences are stored in a separate CSV file (`seqs_train.csv` or `seqs_test.csv`) from individual protein data files (likely PSSM profiles) located in a directory (`train`). 
- This separation can improve code organization and potentially simplify data management if sequences and PSSM data have different update frequencies.
### Optional Label Loading: 
- The code allows loading secondary structure labels from a separate CSV file (`labels_train.csv`) if provided. 
- This flexibility lets the same class serve labeled training data and unlabeled data such as the test set, for which no labels are available.
### One-Hot Encoding for Sequences:
- Converting protein sequences to one-hot encoded representations is a common practice for CNNs. 
- This transforms categorical amino acids into numerical vectors suitable for the model's computations.
### PSSM Normalization:
- Normalizing PSSM data can improve model performance by scaling the feature values to a similar range. 
- The code supports both 'min-max' and 'z-score' normalization methods, allowing you to experiment with different approaches.
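
The two preprocessing steps described above can be sketched on toy data as follows. This is an illustrative standalone version, not the class implementation: the names `ALPHABET`, `toy_mapping`, `one_hot`, `min_max`, and `z_score` and the tiny two-column matrix are assumptions made for the example, and normalization is applied per column within a single protein.

```python
import numpy as np

# Same symbol order as the amino_acid_mapping defined earlier (25 symbols)
ALPHABET = "ACDEFGHIKLMNPQRSTVWYXBZJ-"
toy_mapping = {symbol: index for index, symbol in enumerate(ALPHABET)}


def one_hot(sequence):
    """Return a [len(sequence), 25] matrix with a single 1 per row."""
    encoded = np.zeros((len(sequence), len(toy_mapping)), dtype=np.float32)
    for row, symbol in enumerate(sequence):
        # Symbols missing from the mapping fall back to 'X' (unknown)
        encoded[row, toy_mapping.get(symbol, toy_mapping["X"])] = 1.0
    return encoded


def min_max(matrix):
    """Scale each column to [0, 1]; constant columns map to 0."""
    low, high = matrix.min(axis=0), matrix.max(axis=0)
    span = np.where(high - low == 0, 1, high - low)
    return (matrix - low) / span


def z_score(matrix):
    """Shift each column to zero mean and scale to unit standard deviation."""
    mean, std = matrix.mean(axis=0), matrix.std(axis=0)
    std = np.where(std == 0, 1, std)
    return (matrix - mean) / std


encoded = one_hot("ACU")  # 'U' is not in the mapping, so it is encoded as 'X'
toy_pssm = np.array([[-2.0, 4.0], [0.0, 4.0], [2.0, 4.0]])
scaled = min_max(toy_pssm)
standardized = z_score(toy_pssm)
```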



In [None]:
class ProteinDataset(Dataset):
    """
    Dataset class for handling protein data.
    """

    def __init__(self, csv_file, train_dir, label_file=None, normalize_method='min-max'):
        """
        Initialize the ProteinDataset.

        Parameters:
        - csv_file (str): Path to the CSV file containing sequences.
        - train_dir (str): Directory containing protein data.
        - label_file (str, optional): Path to the CSV file containing labels.
        - normalize_method (str, optional): Normalization method for PSSM data.
        """

        # Load the sequences
        self.seqs = pd.read_csv(csv_file)

        # Load the protein data from the directory
        self.protein_data = {}  # Store the protein data in a dictionary
        for filename in os.listdir(train_dir):
            if filename.endswith(".csv"):
                protein_id = re.split(r'_train|_test', filename)[0]  # Split the filename to get the protein ID
                self.protein_data[protein_id] = pd.read_csv(os.path.join(train_dir, filename))

        # Load the labels, if provided
        self.labels = pd.read_csv(label_file) if label_file else None

        # Amino acid mapping
        self.amino_acid_mapping = amino_acid_mapping
        self.normalize_method = normalize_method

    def __len__(self):
        """
        Get the number of sequences in the dataset.

        Returns:
        - int: Number of sequences.
        """
        return len(self.seqs)

    def __getitem__(self, idx):
        """
        Get an item from the dataset.

        Parameters:
        - idx (int): Index of the item.

        Returns:
        - tuple: Tuple containing protein ID, sequence, PSSM, and labels (if available).
        """

        protein_id = self.seqs.iloc[idx]['PDB_ID']  # Get the protein ID
        sequence = self.seqs.iloc[idx]['SEQUENCE']  # Get the sequence
        encoded_sequence = self.encode_sequence(sequence)  # Encode the sequence
        pssm = self.protein_data[protein_id].values  # Raw per-residue table for this protein (residue number, amino acid, PSSM columns)
        normalized_pssm = self.normalize_pssm(pssm)  # Drop the non-numeric columns and normalize the PSSM values

        if self.labels is not None:
            label_seq = self.labels.iloc[idx]['SEC_STRUCT']  # Assuming the label is in the same order as the sequences
            label_numeric = [sec_struct_mapping[char] for char in label_seq]  # Convert the label to numeric format
            label_tensor = torch.tensor(label_numeric, dtype=torch.long)  # Convert the label to a tensor
            # Return protein ID, sequence, PSSM, and label tensor
            return (
                protein_id,
                torch.tensor(encoded_sequence, dtype=torch.float32),
                torch.tensor(normalized_pssm, dtype=torch.float32),
                label_tensor
            )
        # Return protein ID, sequence, and PSSM
        return (
            protein_id,
            torch.tensor(encoded_sequence, dtype=torch.float32),
            torch.tensor(normalized_pssm, dtype=torch.float32)
        )

    def encode_sequence(self, sequence):
        """
        Encode a sequence into a one-hot encoded vector.

        Parameters:
        - sequence (str): Sequence to encode.

        Returns:
        - numpy.ndarray: One-hot encoded sequence.
        """

        encoded_sequence = np.zeros((len(sequence), len(self.amino_acid_mapping)), dtype=int)
        for i, amino_acid in enumerate(sequence):
            # Default to 'X' for unknown amino acids
            index = self.amino_acid_mapping.get(amino_acid, self.amino_acid_mapping['X'])
            encoded_sequence[i, index] = 1
        return encoded_sequence

    def normalize_pssm(self, pssm):
        """
        Normalize the Position-Specific Scoring Matrix (PSSM).

        Parameters:
        - pssm (numpy.ndarray): The PSSM data to be normalized.

        Returns:
        - numpy.ndarray: The normalized PSSM data.
        """

        # Assuming the first two columns are non-numeric
        numeric_columns = pssm[:, 2:]  # Adjust this if your numeric data starts from a different column

        # Convert to floats
        try:
            pssm_numeric = numeric_columns.astype(np.float32)  # Float32 is usually sufficient
        except ValueError as e:
            # Handle or log the error if needed
            raise ValueError(f"Error converting PSSM to float: {e}") from e

        if self.normalize_method == 'min-max':
            # Min-Max normalization
            pssm_min = pssm_numeric.min(axis=0)
            pssm_max = pssm_numeric.max(axis=0)
            # Ensure no division by zero
            pssm_range = np.where(pssm_max - pssm_min == 0, 1, pssm_max - pssm_min)
            normalized_pssm = (pssm_numeric - pssm_min) / pssm_range
        elif self.normalize_method == 'z-score':
            # Z-Score normalization
            pssm_mean = pssm_numeric.mean(axis=0)
            pssm_std = pssm_numeric.std(axis=0)
            # Avoid division by zero
            pssm_std = np.where(pssm_std == 0, 1, pssm_std)
            normalized_pssm = (pssm_numeric - pssm_mean) / pssm_std
        else:
            # If no normalization method provided, return the original PSSM
            normalized_pssm = pssm_numeric

        return normalized_pssm

## Design Choices and Rationale for the ProteinModel Class
This class implements a fully convolutional network (FCN) that classifies every residue of a protein into one of the secondary structure states. It is a fundamental component of the project pipeline for predicting protein secondary structures.

### Purpose
The purpose of this class is to define the architecture of the FCN. It consists of a stack of convolutional layers followed by a final convolutional layer that maps the learned features at each sequence position to one score per class.


### Convolutional Architecture:

- **Leveraging Local Dependencies:** Protein secondary structures often exhibit local dependencies between neighbouring amino acids. Convolutional layers are well suited to capture these because each filter slides along the sequence and computes features from a local window of residues, applying the same weights at every position.

### Specific Layer Choices:

- **Multiple Convolutional Layers:** Stacking convolutional layers allows the model to learn increasingly complex hierarchical features. The initial layers capture low-level features such as local sequence patterns, while deeper layers combine them over a wider context: with stride 1, each kernel-size-3 layer widens the receptive field by two residues, so the four such layers used here see a window of 9 residues. The context therefore remains local rather than global.
- **Increasing Channel Depths:** The number of output channels doubles from layer to layer (64, 128, 256, 512). This gives deeper layers more capacity to represent the feature combinations that discriminate between secondary structure classes.
- **ReLU Activations:** The ReLU (Rectified Linear Unit) activation function is a popular choice for CNNs due to its efficiency and ability to introduce non-linearity. It helps the model learn non-linear relationships between the features extracted by the convolutional layers.
- **Final Layer and No Activation:** The final convolutional layer has an output channel dimension equal to the number of predicted secondary structure classes. It uses a kernel size of 1, so it acts as a linear classifier applied independently at each sequence position on top of the features from the previous layer. No activation follows it: `CrossEntropyLoss` expects raw logits and applies log-softmax internally.
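
To make the size of the context window concrete, the receptive field of a stack of stride-1, undilated convolutions can be computed from the kernel sizes alone. The helper `receptive_field` below is a hypothetical illustration; the kernel sizes are those of the `ProteinModel` layers defined in the code cell that follows.

```python
def receptive_field(kernel_sizes):
    """Receptive field of stacked stride-1, undilated 1D convolutions."""
    field = 1
    for kernel_size in kernel_sizes:
        # Each layer extends the window by (kernel_size - 1) positions
        field += kernel_size - 1
    return field


# Four kernel-size-3 layers followed by the final kernel-size-1 layer
model_kernel_sizes = [3, 3, 3, 3, 1]
model_field = receptive_field(model_kernel_sizes)
print(f"Each prediction sees {model_field} residues")
```

Under these assumptions, each per-residue prediction depends on the residue itself and four neighbours on either side.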

### Output Transposition:

- **Reshaping for Prediction:** After the final convolutional layer, the output is transposed from `[batch_size, num_classes, sequence_length]` to `[batch_size, sequence_length, num_classes]`, giving one row of class scores per residue. These scores are raw logits rather than probabilities; applying a softmax over the last dimension turns them into a probability distribution over the class labels for each position.
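
The tensor shapes involved can be checked with a small standalone sketch. The layer sizes and the random input below are placeholders chosen for illustration, not the full model.

```python
import torch
from torch import nn

torch.manual_seed(0)
batch_size, num_features, seq_len, num_classes = 2, 20, 50, 3

# Conv1d expects input shaped [batch, channels, length]
x = torch.randn(batch_size, num_features, seq_len)

body = nn.Conv1d(num_features, 64, kernel_size=3, padding=1)  # padding=1 keeps the length
head = nn.Conv1d(64, num_classes, kernel_size=1)              # per-position classifier

logits = head(torch.relu(body(x)))           # [2, 3, 50]
per_residue_logits = logits.transpose(1, 2)  # [2, 50, 3]

# Softmax over the class dimension turns logits into probabilities
probabilities = torch.softmax(per_residue_logits, dim=-1)
predicted_classes = probabilities.argmax(dim=-1)  # [2, 50]
```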

### Integration with Training Pipeline:

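- **Initializing the Model:** An instance of `ProteinModel` would be created within your training script, specifying the number of classes (e.g., 3 for Helix, Extended Sheet, Coil) and the number of input channels. The latter must match the feature dimension of whatever is fed to the model: the training code below passes the PSSM, and the default of 20 corresponds to a 20-column PSSM, whereas the one-hot encoding defined earlier has 25 channels (one per entry of `amino_acid_mapping`).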
- **Training Process:** The model would be integrated into your training loop.  The forward pass would be used to compute predictions for protein sequences in each training batch. The loss function (e.g., CrossEntropyLoss) would compare these predictions with the ground truth labels to calculate the error. The optimizer (e.g., Adam) would then use the calculated errors to update the model's weights and biases iteratively during training.






In [None]:
class ProteinModel(nn.Module):
    """
    Convolutional neural network model for protein classification.
    """

    def __init__(self, num_classes=3, input_channels=20):
        """
        Initialize the ProteinModel.

        Parameters:
        - num_classes (int): Number of classes for classification.
        - input_channels (int): Number of input channels, i.e. features per residue (e.g., 20 for a 20-column PSSM).
        """

        super().__init__()
        # Four kernel-size-3 convolutional layers followed by a final kernel-size-1 layer
        # Define the initial convolutional layer
        self.initial_conv = nn.Conv1d(in_channels=input_channels, out_channels=64, kernel_size=3, padding=1)

        # Define three further convolutional layers
        self.conv1 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=512, kernel_size=3, padding=1)

        # Final layer that maps to the number of classes
        self.final_conv = nn.Conv1d(in_channels=512, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """
        Forward pass of the model.

        Parameters:
        - x (torch.Tensor): Input tensor.

        Returns:
        - torch.Tensor: Output tensor after passing through the network.
        """

        # Apply convolutional layers with activation functions
        x = F.relu(self.initial_conv(x))
        
        # Apply three convolutional layers to the input 'x', 
        # each followed by a ReLU activation function to introduce non-linearity.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Apply final convolutional layer - no activation, as CrossEntropyLoss applies log-softmax to the raw logits
        x = self.final_conv(x)

        # Transpose the output to match [batch_size, sequence_length, num_classes]
        x = x.transpose(1, 2)

        return x

## Batch Processing with Custom Collate Functions


### Challenges with Variable-Length Data:

- Proteins differ in length, so their encoded sequences and PSSMs have different numbers of rows, which prevents stacking them directly into a single batch tensor.
- Fixed-size data structures are required for efficient vectorized operations within the model.

### Solution: Padding and Batching Strategy

- This challenge is addressed by defining two custom collate functions:
    * `collate_fn`: Used for batches without labels (e.g., when predicting on the test set).
    * `collate_fn_labels`: Used for batches with labels (e.g., for supervised training).
- Both functions perform the following key operations:
    * **Unpack Batch Data:** They unpack the input batch, which is a list of tuples containing protein IDs, sequences, and PSSMs (and labels for the function with labels).
    * **Pad Sequences and PSSMs:** They use the `pad_sequence` function from PyTorch to pad sequences and PSSMs within the batch to the length of the longest protein in that batch.
        * This ensures all data points in a batch have the same shape, enabling efficient processing by the CNN model.
        * `pad_sequence` pads with zeros by default. Zero padding alone does not tell the model which positions are real, so a separate mask (described below) records the valid positions.
    * **Copying Tensors:** Sequences and PSSMs are copied with `clone().detach()` before padding. The tensors returned by the dataset are built from NumPy arrays and do not track gradients, so this is a defensive copy rather than a memory optimization.
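
The padding step can be illustrated on two toy PSSM tensors of different lengths. The tensor names and sizes are made up for the example.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy "proteins" with 4 and 7 residues and 20 features per residue
pssm_short = torch.ones(4, 20)
pssm_long = torch.ones(7, 20)

# batch_first=True gives [batch, max_length, features]; padding defaults to 0
padded = pad_sequence([pssm_short, pssm_long], batch_first=True)
print(padded.shape)  # torch.Size([2, 7, 20])

# Rows 4..6 of the shorter protein are padding and contain only zeros
padding_rows = padded[0, 4:]
```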

### Addressing Labels (if present):

- The `collate_fn_labels` function specifically handles label data (secondary structure sequences).
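    * It checks if labels are present in the batch.
    * If labels exist, it pads them with `pad_sequence` in the same way as sequences and PSSMs. Because the default padding value is 0, which is also the integer assigned to Helix in `sec_struct_mapping`, padded label positions can only be told apart from real Helix residues by using the mask.
    * If labels are missing, the function sets the `labels_padded` output to `None`.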

### Creating a Mask:

- The `collate_fn_labels` function creates a mask tensor based on the original (unpadded) length of each protein; `collate_fn` does not return a mask.
    * The mask is a binary tensor where 1 indicates a valid sequence position and 0 indicates padding.
    * This mask allows the loss and accuracy computations to consider only real residues and ignore padding elements.
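
One way to build and use such a mask is sketched below. It derives a boolean mask from the sequence lengths by comparison with a position index, which differs slightly from the padded-ones approach in the code cell that follows; the lengths and label values are toy data.

```python
import torch

lengths = torch.tensor([4, 7])   # true number of residues per protein
max_len = int(lengths.max())

# positions has shape [1, max_len]; comparing with lengths[:, None] broadcasts to [batch, max_len]
positions = torch.arange(max_len).unsqueeze(0)
mask = positions < lengths.unsqueeze(1)  # True for real residues, False for padding

# Toy padded labels: the trailing zeros in the first row are padding, not Helix
labels_padded = torch.tensor([
    [2, 2, 0, 1, 0, 0, 0],
    [0, 0, 1, 1, 2, 2, 2],
])

# Boolean indexing keeps only the real residues, flattened into one dimension
valid_labels = labels_padded[mask]
```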

### Integration with DataLoaders:

- These custom `collate_fn` functions are passed as arguments when creating `DataLoader` objects. 
- The `DataLoader` then utilizes these functions to process individual data points into mini-batches during training.
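
A self-contained sketch of this wiring, using a plain Python list as a stand-in dataset. The names `toy_items` and `toy_collate` and the protein IDs are hypothetical and exist only for this example.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Stand-in dataset: (protein ID, PSSM tensor) pairs with varying lengths
toy_items = [(f"P{i}", torch.ones(n, 20)) for i, n in enumerate([3, 5, 4, 6])]


def toy_collate(batch):
    """Pad the PSSMs in a batch to the length of the longest one."""
    ids, pssms = zip(*batch)
    return ids, pad_sequence(list(pssms), batch_first=True)


loader = DataLoader(toy_items, batch_size=2, shuffle=False, collate_fn=toy_collate)

# Each batch is padded only to the longest protein it contains
batch_shapes = [tuple(padded.shape) for _, padded in loader]
print(batch_shapes)
```

Because padding happens per batch, the padded length can differ from one batch to the next.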



In [None]:
def collate_fn(batch):
    """
    Collate function for processing batches without labels.

    Parameters:
    - batch (list): List of tuples containing ID, sequences, and PSSMs.

    Returns:
    - tuple: Tuple containing ID, padded sequences, and padded PSSMs.
    """

    ids, sequences, pssms = zip(*batch)  # Unzip the batch (avoid shadowing the built-in `id`)

    # Pad sequences and PSSMs
    sequences_padded = pad_sequence([seq.clone().detach() for seq in sequences], batch_first=True)
    pssms_padded = pad_sequence([pssm.clone().detach() for pssm in pssms], batch_first=True)

    return ids, sequences_padded, pssms_padded


def collate_fn_labels(batch):
    """
    Collate function for processing batches with labels.

    Parameters:
    - batch (list): List of tuples containing ID, sequences, PSSMs, and labels.

    Returns:
    - tuple: Tuple containing padded sequences, padded PSSMs, padded labels, and mask.
    """

    _, sequences, pssms, labels_list = zip(*batch)  # Unzip the batch

    # Pad sequences and PSSMs
    sequences_padded = pad_sequence([seq.clone().detach() for seq in sequences], batch_first=True)
    pssms_padded = pad_sequence([pssm.clone().detach() for pssm in pssms], batch_first=True)

    # Handling labels correctly
    if labels_list[0] is not None:  # Check if labels exist
        labels_padded = pad_sequence([label.clone().detach() for label in labels_list], batch_first=True)
    else:
        labels_padded = None

    # Create a mask based on the original sequence lengths (works even when labels are absent)
    mask = [torch.ones(len(seq), dtype=torch.uint8) for seq in sequences]
    mask_padded = pad_sequence(mask, batch_first=True, padding_value=0)  # 0 marks padded positions
    return sequences_padded, pssms_padded, labels_padded, mask_padded

## Training using Hyperparameter Tuning with Ax

### Hyperparameters Optimization:

- **Integration with Ax:** 
    * Ax passes the training function a dictionary (`ax_params`) containing the hyperparameters for the current trial, such as the learning rate (`lr`) and the number of epochs (`num_epochs`).
- **Flexibility with Defaults:** 
    * Each value is read with `dict.get`, so Ax can explore different values during optimization while a default (0.001 for the learning rate, 10 epochs) is used for any entry that is missing.

### Model, Loss, and Optimizer Setup:

- **Defining the Model:** An instance of the `ProteinModel` class is created, with the input channel dimension matching the features fed to the model (here 20, for the PSSM columns that the training loop passes in).
- **Loss Function:** The `nn.CrossEntropyLoss` function is chosen as it's commonly used for multi-class classification tasks like predicting protein secondary structures. Moving it to the device ensures compatibility with the model.
- **Optimizer Selection:** The `torch.optim.Adam` optimizer is a popular choice due to its efficiency and effectiveness. It's configured with the model's parameters and the specified learning rate.

### Structured Training Loop:

- The core training happens within a loop iterating over a pre-defined number of epochs (`num_epochs`). This allows the model to progressively learn from the data over multiple passes.
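- **Training Mode:** Inside the loop, the model is set to training mode using `model.train()`. This switches on training-specific behavior in layers such as dropout and batch normalization; `ProteinModel` as defined above contains neither, so the call has no effect here but keeps the loop correct if such layers are added.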
- **Tracking Training Statistics:** Initialize variables to track the running loss, correct predictions, and total predictions during the epoch. These metrics provide insights into the training progress.
- **Iterating Through Batches:** The function iterates over a data loader named `dataloader_hyper`, which is defined outside the function and is expected to yield mini-batches in the format produced by `collate_fn_labels`.
    * **Processing Each Batch:** For each batch containing padded sequences, PSSMs, labels, and a mask:
        * PSSMs are permuted from `[batch, length, features]` to `[batch, features, length]`, the layout `nn.Conv1d` expects, and moved to the device.
        * Labels are also moved to the device to enable loss calculation on the appropriate hardware.
        * Gradients accumulated from previous iterations are cleared using `optimizer.zero_grad()`.
        * A forward pass through the model (`model(inputs)`) generates model predictions.
        * The model's outputs are flattened to shape `[batch_size * seq_length, num_classes]`, the format `CrossEntropyLoss` expects for per-position classification.
        * Labels are similarly flattened to a single dimension of length `batch_size * seq_length`.
        * As a safeguard against a mismatch between the number of flattened predictions and labels, both are truncated to the shorter of the two lengths before the loss is calculated.
        * The loss is calculated using the chosen criterion and propagated back with `.backward()` to compute gradients for weight updates.
        * The optimizer takes a step (`optimizer.step()`) to update the model's weights based on the calculated gradients, effectively learning from the current batch.
        * Training statistics are updated for the current batch (running loss, correct predictions).
- **Evaluating Epoch Performance:** After iterating through all batches in an epoch:
    * The average epoch loss is calculated by dividing the running loss (each batch's mean loss weighted by the number of proteins in the batch) by the number of proteins in the dataset. This provides a measure of the model's overall performance on the training data for that epoch.
    * Training accuracy is calculated by dividing the total number of correct predictions by the total number of labels. Padded positions are included in both counts, so this figure is not identical to the Q3 accuracy over real residues.
    * Training progress is printed, showing the current epoch number, epoch loss, and training accuracy. This visualization aids in monitoring the training process.
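
The loop described above includes padded positions in both the loss and the accuracy. One possible alternative, sketched here and not part of the original code, is to pad labels with a sentinel value and let `CrossEntropyLoss` skip it through its `ignore_index` argument. The constant `PAD_LABEL` and the toy tensors are assumptions for illustration.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
PAD_LABEL = -100  # hypothetical sentinel; -100 is the default ignore_index in PyTorch

# Toy labels for two proteins with 3 and 5 residues, padded with the sentinel
labels = pad_sequence(
    [torch.tensor([0, 1, 2]), torch.tensor([2, 2, 0, 1, 0])],
    batch_first=True,
    padding_value=PAD_LABEL,
)                                  # [2, 5]
logits = torch.randn(2, 5, 3)      # stand-in model output [batch, seq_len, classes]

# Positions labeled PAD_LABEL contribute nothing to the loss or its average
criterion = nn.CrossEntropyLoss(ignore_index=PAD_LABEL)
loss = criterion(logits.reshape(-1, 3), labels.reshape(-1))

# Accuracy restricted to real residues
valid = labels != PAD_LABEL
predicted = logits.argmax(dim=-1)
accuracy = (predicted[valid] == labels[valid]).float().mean().item()
```

With this approach, a padded position can no longer be mistaken for a Helix label (class 0).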

### Returning Evaluation Metrics:

- After completing all epochs, the function returns a dictionary containing the calculated evaluation metrics (e.g., loss and accuracy). 
- These metrics are fed back to Ax for hyperparameter optimization, allowing it to identify parameter combinations that lead to better model performance.



In [None]:
def train_protein_model(ax_params):
    """
    Train the protein classification model using the given hyperparameters.

    Parameters:
    - ax_params (dict): Dictionary containing hyperparameters.

    Returns:
    - dict: Dictionary containing evaluation metrics (loss and accuracy).
    """

    # Extract hyperparameters from Ax parameterization
    learning_rate = ax_params.get("lr", 0.001)
    num_epochs = ax_params.get("num_epochs", 10)

    # Define your model, criterion, optimizer
    model = ProteinModel(input_channels=20).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0

        for sequences_padded, pssms, labels, mask_padded in dataloader_hyper:
            inputs = pssms.permute(0, 2, 1).to(device)  # Move input to device
            labels = labels.to(device)  # Move labels to device

            optimizer.zero_grad()
            outputs = model(inputs)
            outputs = outputs.reshape(-1, outputs.size(-1))  # Model already returns [batch, seq_length, num_classes]; flatten to (batch_size * seq_length, num_classes)
            labels = labels.view(-1)
            min_batch_size = min(outputs.size(0), labels.size(0))
            outputs = outputs[:min_batch_size]
            labels = labels[:min_batch_size]
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)

            # Calculate training accuracy
            _, predicted = torch.max(outputs, 1)  # Get the index of the largest logit (the predicted class)
            correct_preds += (predicted == labels).sum().item()
            total_preds += labels.numel()

        epoch_loss = running_loss / len(dataloader_hyper.dataset)
        epoch_acc = correct_preds / total_preds
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Return the evaluation metric for Ax to optimize (e.g., loss or accuracy)
    return {"loss": epoch_loss, "accuracy": epoch_acc}


## Training, Validation, Testing, and Interpretation

### Encapsulation for Reusability:

- The `ProteinTrainer` class encapsulates the training pipeline, promoting reusability across different model architectures. It can train various models for protein secondary structure prediction by simply swapping the model instance passed to the `ProteinTrainer`.

### Modular Components:

- The constructor takes the model, criterion (loss function), optimizer, and datasets (train, validation, test) as input. This separation of concerns allows for easy modification of individual components without affecting the entire training pipeline.
- Data loaders are created for each dataset using `DataLoader`. These loaders handle batching, shuffling (for training), and potentially applying custom collation functions for specific data processing needs.

### Structured Training with Clear Goals:

- The `train_model` function orchestrates the training process for a specified number of epochs.
    * It maintains separate lists for training and validation metrics (loss and accuracy) to track performance during training.
- Inside the training loop:
    * The model is set to training mode so that layers such as dropout or batch normalization, if the model has any, behave as intended during training.
    * Training statistics are initialized for each epoch (loss, correct predictions, total predictions) to monitor progress.
    * The function iterates over the training data loader, processing batches containing sequences, PSSMs, labels, and masks.
        * PSSMs and labels are moved to the appropriate device (CPU or GPU) for efficient computations.
        * Gradients are cleared before each batch using `optimizer.zero_grad()` to avoid accumulating gradients across multiple iterations.
        * A forward pass through the model generates predictions.
        * The loss is calculated using the criterion and propagated back with `.backward()` for weight updates.
        * The optimizer takes a step (`optimizer.step()`) to update the model's weights based on the calculated gradients.
        * Training statistics are updated for the current batch.
    * After each epoch:
        * Average epoch loss and accuracy are calculated for the training data, providing insights into model performance on the training set.
        * The model is evaluated on the validation set using `validate_model()`. Validation performance helps assess generalization and detect overfitting.
        * Training progress is printed, showing the current epoch number, training and validation loss/accuracy, allowing you to monitor the training process.

### Validation for Generalization Assessment:

- The `validate_model` function evaluates the model's performance on the held-out validation set. This shows how well the model generalizes to data it was not trained on and helps detect overfitting to the training set.
- It sets the model to evaluation mode, which disables training-only behavior such as dropout in models that use it.
- Similar to training, it iterates over the validation data loader, accumulating loss and accuracy statistics.
- It calculates and returns the average validation loss and accuracy, providing a measure of model generalization on unseen data.
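
A minimal sketch of such an evaluation pass, assuming batches of `(inputs, labels, mask)` and a model that returns `[batch, seq_len, classes]` logits. The function `evaluate_batches`, the `TinyModel` stand-in, and the random toy batch are hypothetical; unlike the description above, this version restricts loss and accuracy to unmasked residues.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Stand-in for the real model: one 1x1 convolution, output [batch, seq_len, classes]."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(20, 3, kernel_size=1)

    def forward(self, x):
        return self.conv(x).transpose(1, 2)


def evaluate_batches(model, batches, criterion):
    """Return (mean loss per residue, accuracy) over the unmasked residues."""
    model.eval()                      # switch off training-only layer behavior
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():             # no gradients are needed for evaluation
        for inputs, labels, mask in batches:
            logits = model(inputs)
            valid = mask.bool()
            count = int(valid.sum())
            loss = criterion(logits[valid], labels[valid])
            total_loss += loss.item() * count
            correct += (logits.argmax(dim=-1)[valid] == labels[valid]).sum().item()
            total += count
    return total_loss / total, correct / total


torch.manual_seed(0)
toy_model = TinyModel()
toy_batch = (
    torch.randn(2, 20, 6),                                 # inputs [batch, features, seq_len]
    torch.randint(0, 3, (2, 6)),                           # labels [batch, seq_len]
    torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]),  # mask: 1 = real residue
)
val_loss, val_acc = evaluate_batches(toy_model, [toy_batch], nn.CrossEntropyLoss())
```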

### Testing for Final Performance Evaluation:

- The `test_model` function allows generating predictions on the unseen test set for final evaluation.
- It sets the model to evaluation mode.
- The function iterates over the test dataset by index rather than through the test data loader, so each protein is processed on its own.
    * For each data point (expected to contain the PDB ID, sequence, and PSSM):
        * The PSSM is given a batch dimension, permuted so that the feature dimension comes before the sequence dimension, and moved to the device.
        * A prediction is made using a forward pass through the model.
        * Each residue's predicted class index is mapped to its secondary structure label (`H`, `E`, or `C`).
        * The prediction (residue ID and predicted structure) is appended to a list.
- Finally, the predictions are saved to a CSV file using Pandas, ready for further analysis or submission to Kaggle.
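
The mapping from class indices to submission rows can be sketched in isolation. The helper name `prediction_rows` and the example predictions are hypothetical; the class order `['H', 'E', 'C']` and the `ID`/`STRUCTURE` columns follow the `test_model` code below, and the order must match whatever encoding was used when the training labels were built.

```python
import pandas as pd

# Class order as used in test_model; assumed to match the label encoding.
LABELS = ['H', 'E', 'C']

def prediction_rows(pdb_id, class_indices):
    """Turn per-residue class indices into (ID, STRUCTURE) rows, numbering residues from 1."""
    return [(f"{pdb_id}_{pos}", LABELS[idx])
            for pos, idx in enumerate(class_indices, start=1)]

# Hypothetical predictions for the first four residues of 1A0S.
rows = prediction_rows("1A0S", [2, 1, 1, 0])
submission = pd.DataFrame(rows, columns=['ID', 'STRUCTURE'])
csv_text = submission.to_csv(index=False)  # returns a string when no path is given
```

Building the rows in a small helper like this makes it easy to check the ID format against the sample submission before writing the full file.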

### Optional Model Interpretation:

- The `interpret_model` function demonstrates using Integrated Gradients from Captum for model interpretation. This is an optional functionality that can provide insights into how the model arrives at its predictions.
- It iterates over the data loader (assuming it provides data suitable for interpretation).
- For the first batch:
    * Model predictions and attributions using Integrated Gradients are calculated.
    * The attribution scores for the first sample are printed. These scores highlight which parts of the input PSSM contribute most to the model's predictions, aiding in understanding the model's decision-making process.
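
To make the idea behind Integrated Gradients concrete, here is a hand-rolled Riemann-sum approximation in plain PyTorch. It is a sketch of the technique, not Captum's implementation. The single `Conv1d` layer is a hypothetical stand-in for the model, and the zero baseline, step count, residue position, and target class are arbitrary assumptions.

```python
import torch
import torch.nn as nn

def integrated_gradients(forward_fn, inputs, baseline=None, steps=50):
    """Riemann-sum approximation of Integrated Gradients.

    forward_fn maps an input tensor to one score per example.
    """
    if baseline is None:
        baseline = torch.zeros_like(inputs)  # assumed all-zero baseline
    total_grad = torch.zeros_like(inputs)
    for k in range(1, steps + 1):
        # Point on the straight line from the baseline to the input
        point = (baseline + (k / steps) * (inputs - baseline)).detach().requires_grad_(True)
        (grad,) = torch.autograd.grad(forward_fn(point).sum(), point)
        total_grad += grad
    # Average gradient along the path, scaled by the input difference
    return (inputs - baseline) * total_grad / steps

torch.manual_seed(0)
# Hypothetical stand-in model: outputs [batch, 3, seq_len] class scores.
model = nn.Conv1d(20, 3, kernel_size=3, padding=1)
model.eval()
x = torch.randn(1, 20, 30)  # one protein, 20 PSSM columns, 30 residues

position, target_class = 10, 2

def score(inp):
    # Score of one class at one residue position
    return model(inp)[:, target_class, position]

attributions = integrated_gradients(score, x)  # same shape as x
```

The attributions sum to approximately the difference between the score at the input and the score at the baseline. With this kernel size of 3, only residues 9 to 11 can receive nonzero attribution for position 10, which illustrates how the scores localise the PSSM columns that drive a prediction.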



In [None]:
class ProteinTrainer:
    """
    Class for training, validating, testing, and interpreting protein classification models.
    """

    def __init__(self, model, criterion, optimizer, train_dataset, val_dataset, test_dataset, batch_size):
        """
        Initialize the ProteinTrainer.

        Parameters:
        - model (nn.Module): The neural network model.
        - criterion: Loss function.
        - optimizer: Optimization algorithm.
        - train_dataset: Dataset for training.
        - val_dataset: Dataset for validation.
        - test_dataset: Dataset for testing.
        - batch_size (int): Batch size for training.
        """
        self.model = model.to(device)  # Move the model to the device
        self.criterion = criterion.to(device)  # Move the loss function to the device
        self.optimizer = optimizer
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size

        # Create data loaders
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn_labels)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn_labels)
        self.test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    def train_model(self, num_epochs):
        """
        Train the model for a specified number of epochs.

        Parameters:
        - num_epochs (int): Number of epochs for training.
        """
        train_losses = []
        train_accuracies = []
        val_losses = []
        val_accuracies = []

        for epoch in range(num_epochs):
            self.model.train()  # Set model to training mode
            running_loss = 0.0
            correct_preds = 0
            total_preds = 0

            for sequences, pssms, labels, _ in self.train_loader:
                inputs = pssms.permute(0, 2, 1).to(device)  # Move input to device
                labels = labels.to(device)  # Move labels to device

                self.optimizer.zero_grad() # Clear the gradients

                outputs = self.model(inputs) # Forward pass
                loss = self.criterion(outputs.transpose(1, 2), labels) # Calculate the loss

                loss.backward() # Backward pass
                self.optimizer.step() # Update weights

                running_loss += loss.item() * inputs.size(0) # Accumulate the loss for the batch

                # Calculate training accuracy
                _, predicted = torch.max(outputs, 2)  # Index of the highest-scoring class per residue
                correct_preds += (predicted == labels).sum().item() # Count the number of correct predictions
                total_preds += labels.numel() # Count the total number of predictions

            epoch_loss = running_loss / len(self.train_loader.dataset) # Calculate the average loss for the epoch
            epoch_acc = correct_preds / total_preds # Calculate the accuracy

            # Append training loss and accuracy
            train_losses.append(epoch_loss)
            train_accuracies.append(epoch_acc)

            # Evaluate on validation set
            val_loss, val_acc = self.validate_model()
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)

            print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_acc:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}')

        # Plot loss and accuracy curves
        plt.figure(figsize=(12, 8))
        plt.subplot(2, 1, 1)
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()

        plt.subplot(2, 1, 2)
        plt.plot(train_accuracies, label='Train Accuracy')
        plt.plot(val_accuracies, label='Validation Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.show()

    def validate_model(self):
        """
        Validate the model on the validation dataset.

        Returns:
        - tuple: Validation loss and accuracy.
        """
        self.model.eval()  # Set model to evaluation mode
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0

        with torch.no_grad():
            for sequences, pssms, labels, _ in self.val_loader:
                inputs = pssms.permute(0, 2, 1).to(device)  # Move input to device
                labels = labels.to(device)  # Move labels to device

                outputs = self.model(inputs) # Forward pass
                loss = self.criterion(outputs.transpose(1, 2), labels) # Calculate the loss

                running_loss += loss.item() * inputs.size(0) # Accumulate the loss for the batch

                # Calculate accuracy
                _, predicted = torch.max(outputs, 2)  # Index of the highest-scoring class per residue
                correct_preds += (predicted == labels).sum().item() # Count the number of correct predictions
                total_preds += labels.numel() # Count the total number of predictions

        val_loss = running_loss / len(self.val_loader.dataset) # Calculate the average loss for the validation set
        val_acc = correct_preds / total_preds # Calculate the accuracy
        return val_loss, val_acc

    def test_model(self, output_file='./data/submission.csv'):
        """
        Test the model on the test dataset and save predictions to a CSV file.

        Parameters:
        - output_file (str): Path to save the CSV file with predictions.
        """
        self.model.eval()  # Set the model to evaluation mode
        predictions = []

        with torch.no_grad():
            for i in range(len(self.test_dataset)):  # Iterate directly over the dataset
                pdb_id, _, pssm = self.test_dataset[i]  # Assuming the dataset returns PDB_ID, sequence, and PSSM

                # Prepare the input tensor; add an extra batch dimension using unsqueeze
                input_pssm = pssm.unsqueeze(0).permute(0, 2, 1).to(device)  # Move input to device

                # Make a prediction
                outputs = self.model(input_pssm) # Forward pass
                _, predicted = torch.max(outputs, 2)  # Index of the highest-scoring class per residue

                # Process the predictions
                seq_len = pssm.shape[0]  # pssm is [seq_len, features] before the permute above
                for j in range(seq_len):
                    residue_id = f"{pdb_id}_{j + 1}"  # Construct the ID
                    structure_label = ['H', 'E', 'C'][predicted[0, j].item()]  # Map numeric predictions to labels
                    predictions.append([residue_id, structure_label]) # Append the prediction

        # Write predictions to CSV
        pd.DataFrame(predictions, columns=['ID', 'STRUCTURE']).to_csv(output_file, index=False)
        print(f'Submission file saved to {output_file}')

    def interpret_model(self, data_loader):
        """
        Interpret the model predictions using integrated gradients.

        Parameters:
        - data_loader: Data loader for the dataset.

        This method prints the attribution scores for the first sample in the data loader.
        """
        
        from captum.attr import IntegratedGradients

        ig = IntegratedGradients(self.model)

        for sequences, pssms, labels, _ in data_loader:
            inputs = pssms.permute(0, 2, 1).to(device)  # Move input to device
            labels = labels.to(device)  # Move labels to device

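            # Get model predictions
            outputs = self.model(inputs)
            _, predicted = torch.max(outputs, 2)

            # Compute integrated gradients for the first residue of the first sample.
            # For outputs with more than two dimensions, Captum expects `target` to be a
            # tuple of indices (here: position, class), not a [batch, seq_len] tensor.
            target = (0, predicted[0, 0].item())
            attributions, delta = ig.attribute(inputs[:1], target=target, return_convergence_delta=True)

            # Print the attribution scores for the first sample
            print("Attribution scores for the first sample:")
            print(attributions[0])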

            break  # Break after processing the first batch (for demonstration purposes)

In [None]:
# Create datasets
dataset = ProteinDataset(csv_file=seqs_train_path, train_dir=train_path, label_file=labels_train_path)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_subset, val_subset = random_split(dataset, [train_size, val_size])

test_dataset = ProteinDataset(csv_file=seqs_test_path, train_dir=test_path)

# Create data loaders
dataloader_hyper = DataLoader(train_subset, batch_size=4, shuffle=True, collate_fn=collate_fn_labels)
val_dataloader_hyper = DataLoader(val_subset, batch_size=2, shuffle=False, collate_fn=collate_fn_labels)

# Create Ax client and experiment
ax_client = AxClient()
ax_client.create_experiment(
    name="tune_protein_model",
    parameters=[
        {
            "name": "lr",
            "type": "range",
            "bounds": [1e-6, 0.4],
            "value_type": "float",
            "log_scale": True,
        },
        {
            "name": "num_epochs",
            "type": "range",
            "bounds": [10, 100],
            "value_type": "int",
        },
    ],
    objectives={"accuracy": ObjectiveProperties(minimize=False)},
)

# Attach the initial trial
ax_client.attach_trial(parameters={"lr": 0.01, "num_epochs": 20})

# Get the parameters and run the initial trial
baseline_parameters = ax_client.get_trial_parameters(trial_index=0)
print(baseline_parameters)
ax_client.complete_trial(trial_index=0, raw_data=train_protein_model(baseline_parameters))

# Run additional trials and optimize hyperparameters
for i in range(2):
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=train_protein_model(parameters))

# Get the best hyperparameters
best_parameters, values = ax_client.get_best_parameters()
print("Best hyperparameters:", best_parameters)

# Create the final model, trainer, and evaluate
model = ProteinModel(input_channels=20).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=best_parameters["lr"])

trainer = ProteinTrainer(model, criterion, optimizer, train_subset, val_subset, test_dataset, batch_size=32)

# Train and evaluate the final model
num_epochs = best_parameters["num_epochs"]
trainer.train_model(num_epochs)

# # Interpret the model using Captum
# trainer.interpret_model(trainer.val_loader)

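# Report the final validation metrics, then predict on the test set and save the submission
val_loss, val_acc = trainer.validate_model()
print(f'Final Val Loss: {val_loss:.4f}, Final Val Accuracy: {val_acc:.4f}')
trainer.test_model()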

In [None]:
# # Create datasets
# dataset = ProteinDataset(csv_file=seqs_train_path, train_dir=train_path, label_file=labels_train_path)
# train_size = int(0.8 * len(dataset))
# val_size = len(dataset) - train_size
# train_subset, val_subset = random_split(dataset, [train_size, val_size])

# test_dataset = ProteinDataset(csv_file=seqs_test_path, train_dir=test_path)

# # Create model, loss function, and optimizer
# model = ProteinModel()
# criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, weight_decay=0.0)

# # Create trainer instance
# trainer = ProteinTrainer(model, criterion, optimizer, train_subset, val_subset, test_dataset, batch_size=64)

# # Train and evaluate the model
# num_epochs = 50
# trainer.train_model(num_epochs)
# trainer.validate_model()
# trainer.test_model()