## This code cell is used to import the modules and libraries required for the project, including data processing, deep learning model training, machine learning algorithm application, and bioinformatics-related tools.

In [None]:
# Import libraries for data processing and analysis
import pandas as pd  # For handling tabular data
import numpy as np  # Provides efficient array operations and mathematical tools
import json  # For loading and saving data in JSON format
import re  # Provides regular expression functions for string matching and replacement

# Import progress bar tool
from tqdm import tqdm  # For displaying a progress bar during loops, improving code interactivity

# Import PyTorch and related tools
import torch  # Deep learning framework for building and training neural networks
from torch.utils.data import DataLoader, TensorDataset, random_split  # Tools for data loading and processing
import torch.nn as nn  # Provides neural network modules
import torch.optim as optim  # Provides optimization algorithms
import esm  # Toolkit for protein sequence modeling

# Import machine learning related tools
from sklearn.model_selection import train_test_split  # For splitting data into training and test sets
from sklearn.ensemble import ExtraTreesClassifier  # Provides an extended implementation of the random forest algorithm
from sklearn.preprocessing import LabelEncoder  # For converting labels into numeric encoding

# Import general tools
from functools import partial  # For creating functions with partially applied arguments
import multiprocessing  # Provides parallel computing functionality
import pickle  # For saving and loading binary objects

# Import bioinformatics tools
from Bio import pairwise2  # Provides sequence alignment functions

## Enzyme Functional Annotation
### Predicting EC Numbers Based on Protein Sequences
### Today's exercise focuses on training and analyzing an EC number prediction model, addressing the following aspects:

- 1. Data Preprocessing: Reviewing EC numbers and generating protein/enzyme embeddings.
- 2. Machine Learning Model Training and Analysis: Developing a machine learning model for enzyme EC number prediction.
- 3. Deep Learning Model Training and Analysis: Constructing a deep learning model for enzyme EC number prediction.
- 4. Prediction Using Published Models: Utilizing published deep learning models (CLEAN and DeepECTransformer) to predict EC numbers for E. coli sequences.
- 5. Result Analysis and Comparison: Analyzing and comparing the prediction results from CLEAN and DeepECTransformer.

## 1. Data Preprocessing

## 1.1 Introduction to EC Numbers
- The EC number (Enzyme Commission number) is a standardized numerical scheme for classifying enzymes by the reactions they catalyze. It consists of four numbers separated by periods, representing the enzyme's class, subclass, sub-subclass, and serial number (the specific reaction/enzyme).
- EC 1.1.1.1 represents alcohol dehydrogenase, with the following specific structure:
- 1: Oxidoreductases
- 1.1: Oxidoreductases acting on alcohol groups
- 1.1.1: Acting on NAD⁺ or NADP⁺ as electron acceptors
- 1.1.1.1: Alcohol dehydrogenase (the specific enzyme)

### This code cell is used to load and clean the UniProt dataset, processing and encoding the EC numbers of enzymes. The primary goal of the cleaned data is to provide high-quality training samples for subsequent models, while also preserving the EC number label mapping for future use.

In [None]:
# Define a function to validate EC numbers
def is_valid_ec_number(ec_number):
    """
    Validate that an EC number is a complete, four-level EC number.

    A valid EC number consists of exactly four groups of ASCII digits separated
    by '.', e.g. '1.1.1.1' or '2.7.1.12'. Partial numbers ('1.1.1'), numbers with
    placeholders or letters ('1.1.1.-', '1.1.1.a'), and strings with extra
    characters such as a trailing newline are rejected.

    Parameters:
    - ec_number (str): The EC number to validate.

    Returns:
    - bool: True if valid, False otherwise (including for non-string input).
    """
    # Non-string values (e.g. NaN coming from a pandas column) are never valid EC numbers.
    if not isinstance(ec_number, str):
        return False
    # re.fullmatch anchors both ends without the `$`-matches-before-trailing-newline
    # loophole of re.match(r"...$"); [0-9] avoids matching non-ASCII Unicode digits.
    return re.fullmatch(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+", ec_number) is not None

# Demonstrate the validator on a mix of well-formed and malformed EC numbers
examples = ["1.1.1.1", "1.1.1", "1.1.1.a", "2.7.1.12"]
example_results = list(map(is_valid_ec_number, examples))
validation_summary = dict(zip(examples, example_results))  # EC string -> bool
print(f"Validation results: {validation_summary}")

In [None]:
# Load UniProt data
uniprot_data_file = './uniprotkb_organism_id_9606_2024_11_21.tsv'  # Path to the UniProt data file
uniprot_data = pd.read_csv(uniprot_data_file, sep='\t')  # Read the TSV file into a DataFrame

# Data cleaning
# Keep only records that have both an EC number and a sequence. `.copy()` makes the filtered
# frame independent of the raw one, so the column assignment below modifies this frame
# without triggering pandas' SettingWithCopyWarning.
uniprot_data = uniprot_data[uniprot_data['EC number'].notna() & uniprot_data['Sequence'].notna()].copy()

# Split the '; '-separated EC field into a list, keeping only complete 4-level EC numbers
uniprot_data['EC number'] = uniprot_data['EC number'].apply(
    lambda x: [ec for ec in x.split('; ') if is_valid_ec_number(ec)]
)

# Keep only sequences shorter than 1000 residues (longer ones are excluded to avoid
# interfering with model training)
uniprot_data = uniprot_data[uniprot_data['Sequence'].str.len() < 1000]

# One row per (protein, EC number) pair. Proteins whose EC list became empty explode
# to NaN and are dropped here; then reset the index.
uniprot_data = uniprot_data.explode('EC number').dropna(subset=['EC number']).reset_index(drop=True)

# For convenience, retain only the first 2000 records
uniprot_data = uniprot_data.head(2000)
uniprot_data

## This cell performs label encoding for the EC numbers (Enzyme Commission numbers) in the UniProt dataset. The EC numbers are first extracted from the dataset and then converted into integer labels using LabelEncoder.

In [None]:
# Encode each EC number string (one per row after the explode step) as an integer class id
ec_numbers = uniprot_data['EC number']
label_encoder = LabelEncoder()
ecnumber_labels = label_encoder.fit_transform(ec_numbers)  # integer labels straight from the encoder

# Bidirectional lookup tables between EC strings and class ids (ids follow classes_ order)
label_to_index = {ec: idx for idx, ec in enumerate(label_encoder.classes_)}
index_to_label = {idx: ec for ec, idx in label_to_index.items()}

# Persist both lookup tables so predicted ids can be decoded back to EC numbers later
for mapping, path in ((label_to_index, './label_to_index.pkl'),
                      (index_to_label, './index_to_label.pkl')):
    with open(path, 'wb') as f:
        pickle.dump(mapping, f)

# Integer label array (row-aligned with uniprot_data) used by the downstream models
ecnumber_label = np.array([label_to_index[ec] for ec in ec_numbers])

label_to_index

## 1.2 Protein Sequence Feature Extraction (Knowledge review)
- **Definition**: Protein language models, such as ESM2 (Evolutionary Scale Modeling 2), are pre-trained deep learning models that can encode protein sequences into numerical representations. These models capture the contextual information and structural properties of amino acids in the sequence.
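- ESM2 comes in several variants with parameter counts ranging from 8M to 15B. The commonly used `esm2_t33_650M_UR50D` model (33 layers) has about 650 million parameters. To keep computation fast, this notebook uses the smallest variant, `esm2_t6_8M_UR50D` (6 layers, about 8 million parameters).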
- ESM2 was pretrained on UniRef50, which is UniProt clustered at 50% sequence identity to give a non-redundant set. Training sequences were sampled from the UniRef90 members of each UniRef50 cluster.
- ESM2 takes protein sequences as input, represented as one-dimensional strings of amino acids. It generates an embedding vector for each position, and these vectors can be used for a variety of downstream tasks. For example, below they are averaged into one vector per sequence.
- **Example**:

   - If the download is slow, you can download the model manually from https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt and place it in /data/home/{yourusername}/.cache/torch/hub/checkpoints/ (the torch hub checkpoint cache).


## This cell defines a function `batch_esm_embedding`, which is used to extract ESM model embeddings for a given set of sequences. 
## The function processes sequences in batches, feeding them into the ESM2 model, and generates sequence-level embeddings which are stored in dictionary format.

In [None]:
# Define a function for batch extraction of ESM embeddings
def batch_esm_embedding(sequences, batch_size=4):
    """
    Extract mean-pooled ESM2 embeddings for a list of protein sequences, in batches.

    Main steps:
    1. Load the pre-trained ESM2 model (esm2_t6_8M_UR50D) and its alphabet.
    2. Process sequences in batches to bound memory use.
    3. For each sequence, average the layer-6 token representations over its real
       residues only. The BOS/EOS special tokens and the padding added to shorter
       sequences in a batch are excluded, so a sequence's embedding does not depend
       on which other sequences happen to share its batch.
    4. Return a mapping of sequences to their embeddings.

    Parameters:
    - sequences (list of str): List of input protein sequences.
    - batch_size (int, optional): Number of sequences per batch, default is 4.

    Returns:
    - embeddings_dict (dict): Maps each sequence string to its embedding (numpy array).
      Duplicate sequences map to a single entry.
    """
    # Check if GPU is available, prefer using GPU if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pre-trained ESM2 model and its alphabet
    model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
    model = model.to(device).eval()  # Switch the model to evaluation mode
    batch_converter = alphabet.get_batch_converter()  # Get the batch converter
    repr_layer = 6  # last transformer layer of the 6-layer esm2_t6_8M_UR50D model

    embeddings_dict = {}

    for i in tqdm(range(0, len(sequences), batch_size), desc="Processing batches"):
        batch_seqs = sequences[i:i + batch_size]

        # Convert the sequences into the (label, sequence) format required by the converter
        data = [(idx, seq) for idx, seq in enumerate(batch_seqs)]
        _, _, batch_tokens = batch_converter(data)
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[repr_layer])
            token_reprs = results["representations"][repr_layer]  # (batch, tokens, dim)

            # 1.0 for real residues, 0.0 for padding and the BOS/EOS special tokens
            residue_mask = (
                (batch_tokens != alphabet.padding_idx)
                & (batch_tokens != alphabet.cls_idx)
                & (batch_tokens != alphabet.eos_idx)
            ).unsqueeze(-1).to(token_reprs.dtype)
            # clamp guards against division by zero for an empty sequence
            residue_counts = residue_mask.sum(dim=1).clamp(min=1)
            sequence_embeddings = (
                (token_reprs * residue_mask).sum(dim=1) / residue_counts
            ).cpu().numpy()

        # Update the dictionary with the sequences and their embeddings
        embeddings_dict.update({seq: emb for seq, emb in zip(batch_seqs, sequence_embeddings)})

    return embeddings_dict

## This cell uses the `batch_esm_embedding` function to extract embeddings for the protein sequences and saves the embeddings dictionary.
## Convert the extracted embeddings into NumPy arrays and prepare the input data for further analysis.

In [None]:
# Protein sequences in DataFrame row order (row-aligned with the EC labels)
sequence = uniprot_data['Sequence'].to_list()

# sequence -> ESM2 embedding vector
esm_embeddings_dict = batch_esm_embedding(sequence)

# Cache the embeddings on disk so they can be reused without recomputing
with open('./esm_embeddings_dict.pkl', 'wb') as f:
    pickle.dump(esm_embeddings_dict, f)

# One embedding row per entry of `sequence`, stacked into a 2-D feature matrix
per_row_embeddings = [esm_embeddings_dict[seq] for seq in sequence]
esm_embeddings = np.array(per_row_embeddings)
# Embedding width; sizes the input layer of the neural network later on
seq_shape = esm_embeddings.shape[1]

# Show the embedding dictionary as the cell output
esm_embeddings_dict

## 2.Machine Learning Model Training

## This cell mainly performs the following tasks: 
- Ensure that the lengths of labels and features are consistent to maintain the integrity of the training data. 
- Randomly split the embedding data, labels, and sequences into training and test sets.
- Train a classification model using ExtraTreesClassifier and make predictions on the test set.
- Compare the predicted results with the actual results to calculate and output the model's accuracy.

In [None]:
# Labels and features must describe the same rows; otherwise the split below would mispair them
assert len(ecnumber_label) == len(esm_embeddings), "The lengths of labels and features do not match"

# Randomly split features, labels and raw sequences together (80% training, 20% test).
# NOTE(review): a protein annotated with several EC numbers appears as several rows with the
# same sequence/embedding (see the explode step), so identical inputs can land in both splits
# with different labels. A split grouped by sequence would avoid this leakage.
Train_data, Test_data, Train_label, Test_label, Train_seq, Test_seq = train_test_split(
    esm_embeddings,  # Input feature embeddings
    ecnumber_label,  # Labels
    sequence,        # Sequences (kept for the identity analysis below)
    test_size=0.2,   # Test set ratio
    random_state=42  # Fixed random seed for reproducibility
)

# Extremely randomized trees (a random-forest variant). Seeding the classifier as well as
# the split makes the reported accuracy reproducible between runs.
model = ExtraTreesClassifier(random_state=42)
model.fit(Train_data, Train_label)  # Train the model using the training data

# Output the shape of the data and labels to check if they match expectations
print("Training data feature shape:", Train_data.shape)
print("Training data label shape:", Train_label.shape)
print("Test data feature shape:", Test_data.shape)
print("Test data label shape:", Test_label.shape)

# Use the trained model to predict on the test set
Test_label_pred = model.predict(Test_data)

# Compare the predicted results with the actual labels and display the first 20 samples (predicted vs actual)
print("First 20 predicted labels (raw values):", Test_label_pred[:20])
print("First 20 actual labels (raw values):", Test_label[:20])

# Fraction of test samples whose predicted class id equals the true class id
accuracy = np.mean(Test_label_pred == Test_label)
print("Model accuracy:", accuracy)


### Question: Map the first label in the predicted labels back to the corresponding EC number.

## This cell mainly performs the following tasks: 
- Sequence similarity calculation: Implements the function calculate_identity, which calculates the similarity (match score) between two sequences based on global sequence alignment.
- Maximum similarity between test and training sets: For each sequence in the test set, calculates the maximum similarity with all sequences in the training set.
- Parallel acceleration: Uses multiprocessing.Pool and tqdm to calculate the maximum similarity for each sequence in the test set in parallel, and displays a progress bar.
- Results output: Prints the highest similarity (as a percentage) between each test sequence and sequences in the training set.

In [None]:
# Define a function to calculate the similarity (match percentage) between two sequences
def calculate_identity(seq1, seq2):
    """
    Calculate the percent identity between two sequences from a global alignment.

    The alignment uses pairwise2.align.globalxx (match = +1, mismatch = 0, no gap
    penalties), so the alignment score equals the number of identical aligned
    positions. Identity is that count divided by the alignment length, times 100.

    Parameters:
    - seq1, seq2 (str): Sequences to compare.

    Returns:
    - float: Percent identity in [0, 100]; 0.0 if no alignment is produced or the
      alignment has zero length.
    """
    # Only one optimal alignment is needed. Without one_alignment_only=True, pairwise2
    # enumerates *every* co-optimal alignment, and that number can grow combinatorially
    # for long, divergent proteins (very slow and memory hungry).
    alignments = pairwise2.align.globalxx(seq1, seq2, one_alignment_only=True)
    if not alignments:
        return 0.0
    best_alignment = alignments[0]

    matches = best_alignment.score  # with globalxx the score is the number of matches
    total = best_alignment.end - best_alignment.start  # alignment length (start is 0 for global)
    if total == 0:
        return 0.0

    return (matches / total) * 100

# Example pair: seq2 is seq1 with an internal stretch of residues removed
seq1 = 'MTEITAAMVKELRESTGAGMMDCKNALSETQHEWFAAKRQGKLSPWITGRKTGQDEHILLMNDGWQ'
seq2 = 'MTEITAAMVKELRESTGAGMMDCKNALSETQHEWFAALLMNDGWQ'

# Percent identity of the example pair, shown as the cell output
identity = calculate_identity(seq1, seq2)
identity

In [None]:
# Define a function to calculate the maximum match percentage between a test sequence and all training sequences
def calculate_max_identity_for_test(test_seq, train_seqs):
    """
    Return the highest percent identity between one test sequence and any training sequence.

    Parameters:
    - test_seq (str): The query sequence.
    - train_seqs (iterable of str): Training sequences to compare against.

    Returns:
    - float: Maximum of calculate_identity(test_seq, s) over train_seqs, or 0.0 when
      train_seqs is empty (instead of max() raising ValueError).
    """
    # A generator avoids materializing one score per training sequence
    return max(
        (calculate_identity(test_seq, train_seq) for train_seq in train_seqs),
        default=0.0,
    )

import os  # only needed to size the worker pool

# Bind the training sequences so each worker call only needs a test sequence
calculate_max_identity_fn = partial(calculate_max_identity_for_test, train_seqs=Train_seq)

# Use up to 60 workers, but never more than the available CPU cores
n_workers = min(60, os.cpu_count() or 1)

with multiprocessing.Pool(processes=n_workers) as pool:
    # imap yields results in input order as they complete, so tqdm shows real progress.
    # pool.map would block until all work is done and the bar would jump straight to 100%.
    results = list(tqdm(pool.imap(calculate_max_identity_fn, Test_seq), total=len(Test_seq)))

# Print the maximum match percentage for each test sequence
for i, test_seq in enumerate(Test_seq):
    print(test_seq)
    print(f"Test sequence {i + 1}: Max identity = {results[i]:.2f}%")

In [None]:
# Dump the three per-test-sample collections (identity, true label, predicted label) with their sizes
for caption, values in (
    ('Test sequences maximum similarity with training set:', results),
    ('True labels:', Test_label),
    ('Predicted labels:', Test_label_pred),
):
    print(caption, len(values), values)

## This cell visualizes the prediction accuracy and the data count for each similarity interval using a dual-axis plot.
- The cell computes the accuracy of predictions and the count of samples within specific similarity intervals, then plots these values on a dual Y-axis chart.
- The left Y-axis shows the data count for each interval, while the right Y-axis displays the corresponding prediction accuracy.

In [None]:
# Importing the plotting and numerical calculation libraries
import matplotlib.pyplot as plt
import numpy as np

# Bin the test samples by their maximum identity to the training set (0-10, ..., 90-100)
# and compute the ExtraTrees accuracy inside each bin.
intervals = np.arange(0, 101, 10)  # bin edges 0, 10, ..., 100
bin_width = intervals[1] - intervals[0]
bin_centers = intervals[:-1] + bin_width / 2
accuracy_per_interval = []  # accuracy per bin (NaN for empty bins)
count_per_interval = []  # number of test samples per bin

identities = np.asarray(results, dtype=float)
y_true_all = np.asarray(Test_label)
y_pred_all = np.asarray(Test_label_pred)

for i in range(len(intervals) - 1):
    lower_bound = intervals[i]
    upper_bound = intervals[i + 1]

    # Bins are half-open [lower, upper), except the last one, which also includes 100 so
    # that test sequences identical to a training sequence (identity == 100%) are counted.
    if i == len(intervals) - 2:
        in_range = (identities >= lower_bound) & (identities <= upper_bound)
    else:
        in_range = (identities >= lower_bound) & (identities < upper_bound)
    in_range_indices = np.where(in_range)[0]

    count_per_interval.append(len(in_range_indices))
    if len(in_range_indices) > 0:
        accuracy = np.mean(y_true_all[in_range_indices] == y_pred_all[in_range_indices])
        accuracy_per_interval.append(accuracy)
    else:
        # An empty bin has no defined accuracy; NaN leaves a gap in the line instead of
        # plotting a misleading 0% point.
        accuracy_per_interval.append(np.nan)

# Creating a dual-axis plot
fig, ax1 = plt.subplots(figsize=(10, 6))

# Left Y-axis: sample count per bin (bars centred on each bin)
ax1.bar(bin_centers, count_per_interval, width=bin_width * 0.8, color='lightblue', label='Data Count')
ax1.set_xlabel('Identity (%)')
ax1.set_ylabel('Data Count', color='b')
ax1.tick_params(axis='y', labelcolor='b')
ax1.set_xticks(intervals)

# Right Y-axis: accuracy per bin, drawn at the same bin centres as the bars
ax2 = ax1.twinx()
ax2.plot(bin_centers, accuracy_per_interval, marker='o', linestyle='-', color='r', label='Accuracy')
ax2.set_ylabel('Accuracy', color='r')
ax2.tick_params(axis='y', labelcolor='r')
ax2.set_ylim(0, 1.05)  # accuracy is a fraction in [0, 1]

# Adding title and grid
plt.title('Prediction Accuracy and Data Count per Similarity Interval')
ax1.grid(True)

# Combine the handles of both axes into a single legend
bar_handles, bar_labels = ax1.get_legend_handles_labels()
line_handles, line_labels = ax2.get_legend_handles_labels()
ax1.legend(bar_handles + line_handles, bar_labels + line_labels, loc='upper left')

fig.tight_layout()  # Adjusting layout to prevent label overlap
plt.show()


## Question: What conclusions can you draw from this figure?

## 3.Deep Learning Model Training

## This code handles the preparation of data for training and testing a model, including splitting datasets, saving sequence data, and setting up data loaders for training and validation.
- Data Conversion: Converts esm_embeddings and ecnumber_label into PyTorch tensors.
- Dataset Splitting: Splits the data into training and testing sets (80/20 split).
- Sequence Saving: Saves the training and testing sequences into pickle files.
- Dataset Creation: Creates training and validation datasets from the training data.
- Data Loaders: Sets up data loaders for the training and validation datasets.

In [None]:
# Convert features and labels to tensors. torch.as_tensor accepts both the original numpy
# arrays and tensors from a previous run, so re-running this cell is safe.
esm_embeddings = torch.as_tensor(esm_embeddings, dtype=torch.float32)
ecnumber_label = torch.as_tensor(ecnumber_label, dtype=torch.long)

# 80/20 split with the same inputs and random_state as the ExtraTrees experiment, so both
# models are evaluated on the same test sequences.
Train_data, Test_data, Train_label, Test_label, Train_seq, Test_seq = train_test_split(
    esm_embeddings.numpy(), ecnumber_label.numpy(), sequence, test_size=0.2, random_state=42
)

# Convert the numpy arrays back to PyTorch tensors
Train_data = torch.as_tensor(Train_data, dtype=torch.float32)
Train_label = torch.as_tensor(Train_label, dtype=torch.long)
Test_data = torch.as_tensor(Test_data, dtype=torch.float32)
Test_label = torch.as_tensor(Test_label, dtype=torch.long)

# Save the training and testing sequences into separate pickle files
with open('./train_sequences.pkl', 'wb') as f:
    pickle.dump(Train_seq, f)
with open('./test_sequences.pkl', 'wb') as f:
    pickle.dump(Test_seq, f)

# Create a dataset from the training data and labels
full_train_dataset = TensorDataset(Train_data, Train_label)

# 90/10 train/validation split; the seeded generator makes it reproducible between runs
train_size = int(0.9 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Define batch size for loading data
batch_size = 128

# BatchNorm1d in training mode cannot normalize a batch containing a single sample, so
# drop the final training batch only when it would hold exactly one sample.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=(train_size % batch_size == 1)
)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

## This cell defines a neural network model, DeepEC, which is designed for predicting EC numbers from protein sequences. It uses a fully connected architecture with hidden layers, activation functions, batch normalization, and dropout for regularization.

In [None]:
# Define the model
class DeepEC(nn.Module):
    """
    Fully connected classifier mapping a fixed-size sequence embedding to EC-number logits.

    Architecture: Linear -> ReLU -> BatchNorm1d -> Linear -> ReLU -> Dropout ->
    Linear -> ReLU, followed by a final Linear layer producing one logit per class.

    Parameters:
    - seq_shape (int): Input embedding dimension.
    - num_classes (int): Number of EC-number classes.
    - dropout (float): Dropout probability after the second hidden layer.
    - hidden_dims (sequence of 3 ints): Widths of the three hidden layers.
    """

    def __init__(self, seq_shape, num_classes, dropout=0.5, hidden_dims=(256, 128, 64)):
        super().__init__()
        if len(hidden_dims) != 3:
            raise ValueError("hidden_dims must contain exactly three layer widths")
        self.hidden_dims = list(hidden_dims)
        self.network = nn.Sequential(
            nn.Linear(seq_shape, self.hidden_dims[0]),
            nn.ReLU(), nn.BatchNorm1d(self.hidden_dims[0]),
            nn.Linear(self.hidden_dims[0], self.hidden_dims[1]),
            nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(self.hidden_dims[1], self.hidden_dims[2]),
            # Without this activation, the last hidden Linear and `fc` compose into a single
            # linear map, so the third hidden layer would add parameters but no capacity.
            nn.ReLU(),
        )
        self.fc = nn.Linear(self.hidden_dims[-1], num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes) for x of shape (batch, seq_shape)."""
        x = self.network(x)
        return self.fc(x)

## This cell sets up the training loop for the DeepEC model, which is trained on a dataset to predict EC numbers. It includes model initialization, loss calculation, optimization, and evaluation metrics (accuracy, training and validation loss) over multiple epochs.
- Model Setup: Defines the device for training (GPU if available), initializes the model, loss function (CrossEntropyLoss), optimizer (Adam), and learning rate scheduler.
- Training Loop: Runs for 200 epochs, performing forward passes, loss calculation, backpropagation, and optimization.
- Evaluation: After each epoch, calculates training loss, validation loss, and accuracy, updating the learning rate if necessary based on validation loss.
- Plotting: Visualizes the training/validation loss and accuracy across epochs for model evaluation.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DeepEC(seq_shape=seq_shape, num_classes=len(label_to_index)).to(device)
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-6)
# Reduce the learning rate (by the default factor of 0.1) when the validation loss has not
# improved for 10 epochs. The `verbose` flag is deprecated in recent PyTorch releases, so the
# current learning rate is printed in the per-epoch log below instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=10)

# Training loop
num_epochs = 200
train_losses, val_losses, accuracy_values = [], [], []

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    epoch_train_loss = []
    for data, label in train_loader:
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        epoch_train_loss.append(loss.item())

    # ---- Validation pass (dropout off, BatchNorm uses running statistics) ----
    model.eval()
    epoch_val_loss, y_true, y_pred = [], [], []
    with torch.no_grad():
        for val_data, val_labels in val_loader:
            val_data, val_labels = val_data.to(device), val_labels.to(device)
            outputs = model(val_data)
            val_loss = criterion(outputs, val_labels)
            epoch_val_loss.append(val_loss.item())
            y_true.append(val_labels)
            y_pred.append(outputs.argmax(dim=1))

    # Mean of the per-batch losses (each batch weighted equally)
    avg_train_loss = np.mean(epoch_train_loss)
    avg_val_loss = np.mean(epoch_val_loss)
    accuracy_value = (torch.cat(y_true) == torch.cat(y_pred)).float().mean().item()

    train_losses.append(avg_train_loss)
    val_losses.append(avg_val_loss)
    accuracy_values.append(accuracy_value)
    scheduler.step(avg_val_loss)
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Accuracy: {accuracy_value:.4f}, LR: {current_lr:.2e}")

# Learning curves: losses (left) and validation accuracy (right)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.legend()
plt.title('Loss over Epochs')
plt.subplot(1, 2, 2)
plt.plot(accuracy_values, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.title('Validation Accuracy over Epochs')
plt.tight_layout()
plt.show()


## This code evaluates the performance of the trained model on the test set by predicting class labels and calculating the accuracy.
- Evaluation Mode: Switches the model to evaluation mode to disable dropout and batch normalization updates.
- Inference: Iterates over the test set in batches, performs forward passes to obtain predictions, and stores the true and predicted labels.
- Accuracy Calculation: Compares the predicted labels with the true labels and calculates the test accuracy.
- Output: Prints the calculated test accuracy.

In [None]:
### Predicting results for the test set
model.eval()  # inference mode: dropout disabled, BatchNorm uses its running statistics

test_loader = DataLoader(TensorDataset(Test_data, Test_label), batch_size=batch_size)
collected_true, collected_pred = [], []

with torch.no_grad():  # no autograd bookkeeping needed during inference
    for test_data, test_label in test_loader:
        test_data = test_data.to(device)
        test_label = test_label.to(device)
        outputs = model(test_data)
        preds = outputs.argmax(dim=1)  # most likely class per sample
        collected_true.append(test_label)
        collected_pred.append(preds)

# Flatten the per-batch results into 1-D tensors, in the same order as Test_seq
y_true_test = torch.cat(collected_true)
y_pred_test = torch.cat(collected_pred)

test_accuracy = (y_true_test == y_pred_test).float().mean().item()
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

### Question: What factors are limiting the accuracy here?

### This cell calculates the highest similarity between each sequence in the test set and the sequences in the training set. The function is defined earlier.

In [None]:
# Use the partial function to fix the train_seqs parameter
calculate_max_identity_fn = partial(calculate_max_identity_for_test, train_seqs=Train_seq)

# Use a multiprocessing pool to parallelize the computation and display progress
with multiprocessing.Pool(processes=60) as pool:
    results = list(tqdm(pool.map(calculate_max_identity_fn, Test_seq), total=len(Test_seq)))

# Print the results for each test sequence
for i, test_seq in enumerate(Test_seq):
    print(test_seq)
    print(f"Test sequence {i + 1}: Max identity = {results[i]:.2f}%")

In [None]:
# Print the length and contents of the 'results' list
print(len(results), results)

# Print the length and contents of the 'y_true_test' list (true labels)
print(len(y_true_test), y_true_test)

# Print the length and contents of the 'y_pred_test' list (predicted labels)
print(len(y_pred_test), y_pred_test)

## This cell visualizes the prediction accuracy and the data count for each similarity interval using a dual-axis plot.
- The cell computes the accuracy of predictions and the count of samples within specific similarity intervals, then plots these values on a dual Y-axis chart.
- The left Y-axis shows the data count for each interval, while the right Y-axis displays the corresponding prediction accuracy.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Bin edges (%) for each test sequence's max identity to the training set:
# [0, 10), [10, 20), ..., [80, 90), [90, 100]. The last bin is closed on the
# right; otherwise sequences identical to a training sequence (100%) would be
# dropped from every bin.
intervals = np.arange(0, 101, 10)

# Convert once, outside the loop
identities = np.asarray(results, dtype=float)
y_true_arr = np.asarray(y_true_test.cpu())
y_pred_arr = np.asarray(y_pred_test.cpu())
if not (len(identities) == len(y_true_arr) == len(y_pred_arr)):
    raise ValueError(
        f"Length mismatch: {len(identities)} identities, "
        f"{len(y_true_arr)} true labels, {len(y_pred_arr)} predictions"
    )

accuracy_per_interval = []
count_per_interval = []

last_bin = len(intervals) - 2
for i, (lower_bound, upper_bound) in enumerate(zip(intervals[:-1], intervals[1:])):
    below_upper = identities <= upper_bound if i == last_bin else identities < upper_bound
    in_range = (identities >= lower_bound) & below_upper

    n_in_range = int(in_range.sum())
    count_per_interval.append(n_in_range)

    if n_in_range > 0:
        # Fraction of sequences in this bin whose predicted label equals the true label
        accuracy_per_interval.append(float(np.mean(y_true_arr[in_range] == y_pred_arr[in_range])))
    else:
        # NaN instead of 0: matplotlib leaves a gap rather than plotting a
        # misleading 0% accuracy for an empty bin
        accuracy_per_interval.append(np.nan)

bar_width = 8
fig, ax1 = plt.subplots(figsize=(10, 6))

# Left Y-axis: number of test sequences per identity bin (bar chart)
ax1.bar(intervals[:-1], count_per_interval, width=bar_width, color='lightblue', label='Data Count', align='edge')
ax1.set_xticks(intervals)
ax1.set_xlabel('Identity (%)')
ax1.set_ylabel('Data Count', color='b')
ax1.tick_params(axis='y', labelcolor='b')

# Right Y-axis: prediction accuracy per bin, drawn at the centre of each bar
ax2 = ax1.twinx()
ax2.plot(intervals[:-1] + bar_width / 2, accuracy_per_interval, marker='o', linestyle='-', color='r', label='Accuracy')
ax2.set_ylim(0, 1.05)
ax2.set_ylabel('Accuracy', color='r')
ax2.tick_params(axis='y', labelcolor='r')

# Add title and grid
plt.title('Prediction Accuracy and Data Count per Similarity Interval')
ax1.grid(True)

# Merge the handles from both axes into one legend (drawn on ax2 so it sits above the line)
bar_handles, bar_labels = ax1.get_legend_handles_labels()
line_handles, line_labels = ax2.get_legend_handles_labels()
ax2.legend(bar_handles + line_handles, bar_labels + line_labels, loc='upper left')

fig.tight_layout()
plt.show()


## 4. Prediction Using Published Models

## 4.1 CLEAN
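- You can directly use ./CLEAN/results/ecoli_genome_clean_result.json, which contains the EC number predictions generated by the CLEAN model for the Escherichia coli genome. Note that the analysis cell below reads this file from `./ecoli_genome_clean_result.json`; copy it there or update `clean_result_file`.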
- If you want to perform EC number prediction using the CLEAN model, you can go to ./CLEAN/BIO.00.CLEAN_genome.ipynb

## 4.2 DeepECtransformer
- `./ecoli_genome_deepprozyme_result.json` is the result file generated by the pre-trained DeepECtransformer model for EC number prediction on the Escherichia coli genome.

## 5. Result Analysis and Comparison

## The goal of this cell is to compare the prediction results of CLEAN and DeepECTransformer with the data from the UniProt database.
- First, load the UniProt data and clean it: fill missing values, extract the relevant columns, and preprocess the EC numbers and gene names.
- The data is then formatted for the comparison steps that follow.

In [None]:
# Input files for comparing CLEAN / DeepECTransformer predictions against UniProt.

# UniProt export (tab-separated)
uniprot_file = './uniprotkb_taxonomy_id_83333_2024_02_19.tsv'

# CLEAN predictions (JSON)
clean_result_file = './ecoli_genome_clean_result.json'

# DeepECTransformer predictions (JSON)
deepectransformer_result_file = './ecoli_genome_deepprozyme_result.json'

# Gene name -> UniProt entry mapping (JSON)
gene2entry_mapping_file = './iML1515_gene2entry_mapping.json'

In [None]:
# Read the UniProt table (tab-separated) and keep only the columns used later:
# UniProt entry, gene names, and EC number(s). Missing values become ''.
# .copy() makes this an independent DataFrame: selecting columns with a list
# marks the result as a possible copy of the original frame, and the column
# assignments below would otherwise raise pandas' SettingWithCopyWarning.
uniprot_data = pd.read_csv(uniprot_file, sep='\t')
uniprot_data = uniprot_data[['Entry', 'Gene Names', 'EC number']].fillna('').copy()

# Gene Names: whitespace-separated string -> list of names ('' -> [])
uniprot_data['Gene Names'] = uniprot_data['Gene Names'].apply(lambda names: names.split())

# EC number: ';'-separated string -> list of EC numbers, with blank items removed
uniprot_data['EC number'] = uniprot_data['EC number'].apply(
    lambda ecs: [ec.strip() for ec in ecs.split(';') if ec.strip()]
)

# View the processed UniProt data
uniprot_data

## The purpose of this code is to load the prediction result files from DeepECTransformer and CLEAN, and check the number of results contained in each file. 

In [None]:
def _read_json(path):
    """Return the object decoded from the JSON file at *path*."""
    with open(path, 'r') as handle:
        return json.load(handle)


# DeepECTransformer predictions: report how many entries were loaded, then dump them
deepectransformer_result = _read_json(deepectransformer_result_file)
print(len(deepectransformer_result))
print(deepectransformer_result)

# CLEAN predictions: same report
clean_result = _read_json(clean_result_file)
print(len(clean_result))
print(clean_result)

### Question: What are the prediction results for b3045 using DeepECTransformer and CLEAN, respectively?

In [None]:
# Define a function to retrieve the EC number corresponding to a given UniProt entry
def entry2ECnumber(entry, uniprot_data):
    """Return the ``'EC number'`` value UniProt lists for a UniProt entry.

    Parameters
    ----------
    entry : str
        UniProt accession to look up in the ``'Entry'`` column.
    uniprot_data : pandas.DataFrame
        Table with ``'Entry'`` and ``'EC number'`` columns. In this notebook
        ``'EC number'`` holds a list of EC strings per row.

    Returns
    -------
    The ``'EC number'`` value of the first row whose ``'Entry'`` equals *entry*.

    Raises
    ------
    KeyError
        If *entry* does not occur in ``uniprot_data['Entry']``.
    """
    matches = uniprot_data.loc[uniprot_data['Entry'] == entry, 'EC number']
    # Fail with a clear message instead of an opaque IndexError from [0] on an empty list
    if matches.empty:
        raise KeyError(f"UniProt entry {entry!r} not found in uniprot_data")
    return matches.iloc[0]

# Define a function to retrieve the EC number corresponding to a given gene name
def gene2ECnumber(gene, gene2entry_mapping, uniprot_data):
    """Look up UniProt's EC annotation for a gene via its mapped UniProt entry.

    Returns the placeholder string '-' when *gene* is not a key of
    *gene2entry_mapping*; otherwise returns whatever ``entry2ECnumber`` yields
    for the mapped entry.
    """
    if gene not in gene2entry_mapping:
        return '-'  # unmapped gene: there is no UniProt entry to consult
    return entry2ECnumber(gene2entry_mapping[gene], uniprot_data)

# Read the gene -> UniProt entry mapping (gene names as keys, UniProt entries as values)
with open(gene2entry_mapping_file, 'r') as f:
    gene2entry_mapping = json.load(f)

# Jupyter only displays a cell's LAST bare expression, so a bare
# `gene2entry_mapping` here would show nothing; print it explicitly instead.
print(gene2entry_mapping)

# All genes that have a UniProt entry in the mapping
model_gene_lst = list(gene2entry_mapping.keys())
print(model_gene_lst)

## This code iterates through the gene list 'model_gene_lst' and collects relevant information for each gene, including the UniProt entry, EC number, and the prediction results from CLEAN and DeepECTransformer. 
- If any relevant information for a gene is missing, the missing value is represented by '-'. 
- Finally, all the collected data is stored in a dictionary called 'result', which contains the gene name, UniProt entry, UniProt EC number, and the prediction results from CLEAN and DeepECTransformer.

In [None]:
# Collect one record per model gene: its UniProt entry and EC annotation plus
# the CLEAN and DeepECtransformer predictions; '-' marks a missing value.
result = {
    'gene': [],
    'UniProt_entry': [],
    'UniProt_EC': [],
    'CLEAN': [],
    'DeepECtransformer': [],
}

for gene in tqdm(model_gene_lst):
    result['gene'].append(gene)
    result['UniProt_entry'].append(gene2entry_mapping.get(gene, '-'))
    # gene2ECnumber itself returns '-' for genes absent from the mapping
    result['UniProt_EC'].append(gene2ECnumber(gene, gene2entry_mapping, uniprot_data))
    result['CLEAN'].append(clean_result.get(gene, '-'))
    result['DeepECtransformer'].append(deepectransformer_result.get(gene, '-'))

# Convert the result dictionary to a DataFrame
result_df = pd.DataFrame(result)

# Drop genes without a UniProt entry. .copy() makes result_df an independent
# frame, so adding columns to it later (e.g. the *_res match columns) does not
# raise SettingWithCopyWarning.
result_df = result_df[result_df['UniProt_entry'] != '-'].copy()

result_df.head(5)  # Display the first 5 rows of the resulting DataFrame

## This code evaluates the performance of the CLEAN and DeepECTransformer models in predicting EC numbers for given protein sequences and compares their predictions with a database to assess their ability to distinguish enzymes from non-enzymes. It calculates and visualizes the True Positives (TP), True Negatives (TN), False Positives (FP), and False Negatives (FN) for both models, along with the corresponding confusion matrices.

- Calculates TP, TN, FP, FN for DeepECTransformer and plots the confusion matrix.
- Calculates TP, TN, FP, FN for CLEAN and plots the confusion matrix.

## DeepECtransformer

In [None]:
# Enzyme vs non-enzyme confusion counts for DeepECtransformer.
# Actual "enzyme"    = UniProt lists at least one EC number for the gene.
# Predicted "enzyme" = DeepECtransformer gave something other than the '-' placeholder.
annotated_enzyme = result_df['UniProt_EC'].apply(len) > 0
predicted_enzyme = result_df['DeepECtransformer'] != '-'

TP = len(result_df[annotated_enzyme & predicted_enzyme])    # enzyme, predicted enzyme
FP = len(result_df[~annotated_enzyme & predicted_enzyme])   # non-enzyme, predicted enzyme
TN = len(result_df[~annotated_enzyme & ~predicted_enzyme])  # non-enzyme, predicted non-enzyme
FN = len(result_df[annotated_enzyme & ~predicted_enzyme])   # enzyme, predicted non-enzyme

# One value per line, in the order TP, FP, TN, FN
for count in (TP, FP, TN, FN):
    print(count)

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MaxNLocator


def _outline(axis):
    """Draw a thin black border on every spine of *axis*."""
    for spine in axis.spines.values():
        spine.set_edgecolor('black')
        spine.set_linewidth(0.2)


# Rows: predicted enzyme / predicted non-enzyme; columns: UniProt enzyme / non-enzyme
confusion_matrix1 = np.array([[TP, FP], [FN, TN]])

# Small high-resolution figure: a wide heatmap panel next to a thin colour-bar panel
fig = plt.figure(figsize=(2, 2), dpi=300)
gs = GridSpec(1, 2, width_ratios=[4, 0.1])

# Heatmap with a fixed 1-1400 colour range (the CLEAN plot uses the same range)
ax1 = plt.subplot(gs[0])
im1 = ax1.imshow(confusion_matrix1, cmap='PuBu', interpolation='nearest', vmin=1, vmax=1400, aspect='auto')
ax1.set_xticks([])
ax1.set_yticks([])
ax1.grid(False)

# Label each cell with its name and count; keys are imshow (x, y) = (column, row)
cell_tags = {(0, 0): 'TP', (1, 0): 'FP', (0, 1): 'FN', (1, 1): 'TN'}
for (col, row), tag in cell_tags.items():
    ax1.text(col, row, f"{tag}\n{confusion_matrix1[row, col]}", ha='center', va='center', fontsize=6, color='black')

# Enzyme / non-enzyme classification accuracy computed from the four counts
Accuracy = (TP + TN) / (TP + FP + FN + TN)
ax1.set_xlabel(f"DeepECtransformer ACC = {Accuracy:.2f}", fontsize=6, labelpad=1)
_outline(ax1)

# Colour bar with small tick labels, MaxNLocator(nbins=6) ticks, and a thin outline
cbar_ax = plt.subplot(gs[1])
cbar = fig.colorbar(im1, cax=cbar_ax)
cbar.ax.tick_params(labelsize=6)
cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=6))
_outline(cbar_ax)

plt.subplots_adjust(wspace=0.02)
plt.show()


## CLEAN

In [None]:
# Enzyme vs non-enzyme confusion counts for CLEAN.
# Actual "enzyme"    = UniProt lists at least one EC number for the gene.
# Predicted "enzyme" = CLEAN gave something other than the '-' placeholder.
annotated_enzyme = result_df['UniProt_EC'].apply(len) > 0
predicted_enzyme = result_df['CLEAN'] != '-'

TP = len(result_df[annotated_enzyme & predicted_enzyme])    # enzyme, predicted enzyme
FP = len(result_df[~annotated_enzyme & predicted_enzyme])   # non-enzyme, predicted enzyme
TN = len(result_df[~annotated_enzyme & ~predicted_enzyme])  # non-enzyme, predicted non-enzyme
FN = len(result_df[annotated_enzyme & ~predicted_enzyme])   # enzyme, predicted non-enzyme

# One value per line, in the order TP, FP, TN, FN
for count in (TP, FP, TN, FN):
    print(count)

In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MaxNLocator


def _outline(axis):
    """Draw a thin black border on every spine of *axis*."""
    for spine in axis.spines.values():
        spine.set_edgecolor('black')
        spine.set_linewidth(0.2)


# Rows: predicted enzyme / predicted non-enzyme; columns: UniProt enzyme / non-enzyme
confusion_matrix1 = np.array([[TP, FP], [FN, TN]])

# Small high-resolution figure: a wide heatmap panel next to a thin colour-bar panel
fig = plt.figure(figsize=(2, 2), dpi=300)
gs = GridSpec(1, 2, width_ratios=[4, 0.1])

# Heatmap with a fixed 1-1400 colour range (the DeepECtransformer plot uses the same range)
ax1 = plt.subplot(gs[0])
im1 = ax1.imshow(confusion_matrix1, cmap='PuBu', interpolation='nearest', vmin=1, vmax=1400, aspect='auto')
ax1.set_xticks([])
ax1.set_yticks([])
ax1.grid(False)

# Label each cell with its name and count; keys are imshow (x, y) = (column, row)
cell_tags = {(0, 0): 'TP', (1, 0): 'FP', (0, 1): 'FN', (1, 1): 'TN'}
for (col, row), tag in cell_tags.items():
    ax1.text(col, row, f"{tag}\n{confusion_matrix1[row, col]}", ha='center', va='center', fontsize=6, color='black')

# Enzyme / non-enzyme classification accuracy computed from the four counts
Accuracy = (TP + TN) / (TP + FP + FN + TN)
ax1.set_xlabel(f"CLEAN ACC = {Accuracy:.2f}", fontsize=6, labelpad=1)
_outline(ax1)

# Colour bar with small tick labels, MaxNLocator(nbins=6) ticks, and a thin outline
cbar_ax = plt.subplot(gs[1])
cbar = fig.colorbar(im1, cax=cbar_ax)
cbar.ax.tick_params(labelsize=6)
cbar.ax.yaxis.set_major_locator(MaxNLocator(nbins=6))
_outline(cbar_ax)

plt.subplots_adjust(wspace=0.02)
plt.show()

### Question: The CLEAN method predicts an EC number for all protein sequences. Why is the False Negative (FN) count in the confusion matrix for CLEAN equal to 1 instead of 0?

### Question: The CLEAN method predicts an EC number for all protein sequences. What issues could this cause?

## This code matches the prediction results with the actual data and evaluates their accuracy by comparing the predictions of CLEAN and DeepECTransformer row by row.

In [None]:
def _prediction_matches(predicted, annotated):
    """Return True if any predicted EC number is among the annotated ones.

    *predicted* is one gene's prediction entry (or the '-' placeholder used
    when the model returned nothing); *annotated* is that gene's UniProt EC list.
    """
    # '-' means "no prediction": treat it as empty rather than iterating over
    # its single character and testing whether '-' is an annotated EC number.
    if isinstance(predicted, str) and predicted == '-':
        return False
    return any(ec in annotated for ec in predicted)


# A row matches when at least one predicted EC number appears in UniProt's list
result_df['CLEAN_res'] = result_df.apply(
    lambda row: _prediction_matches(row['CLEAN'], row['UniProt_EC']), axis=1
)
result_df['DeepECtransformer_res'] = result_df.apply(
    lambda row: _prediction_matches(row['DeepECtransformer'], row['UniProt_EC']), axis=1
)

# Fraction of ALL rows that match (mean of a boolean column).
# NOTE: the denominator includes genes with no UniProt EC number, which can
# never match -- so a correct "no EC" call ('-') still counts as a miss here.
print('CLEAN:', result_df['CLEAN_res'].mean())
print('DeepECtransformer:', result_df['DeepECtransformer_res'].mean())

## Question: If only the first number, the first two numbers, or the first three numbers (levels) of the EC number are considered, what are the prediction accuracies of CLEAN and DeepECTransformer, respectively?

In [None]:
### Type your code here