# Protbert, CNN, LSTM Classification of Protein Sequences by Tropism

In [None]:
import torch
import matplotlib.pyplot as plt
from torch import nn
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, mean_squared_error, roc_curve, auc
import seaborn as sns
import pandas as pd
from Bio import SeqIO
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
import random

### Define functions and models for Protbert, CNN, and LSTM

In [8]:
# Set device to GPU if available, else fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Step 1: Read your FASTA file and prepare data
def read_fasta(file_path):
    """Reads a FASTA file and returns protein sequences and labels."""
    sequences = []
    labels = []
    
    for record in SeqIO.parse(file_path, "fasta"):
        sequences.append(str(record.seq))
        label = get_tropism_label(record.id)
        labels.append(label)
    
    return sequences, labels

def get_tropism_label(record_id):
    """Assign tropism labels based on sequence ID or other logic."""
    if "TrophismTypeA" in record_id:
        return 0
    elif "TrophismTypeB" in record_id:
        return 1
    else:
        return 2  # Default case, adjust as needed

# Step 2: Create a custom dataset class
class ProteinDataset(Dataset):
    def __init__(self, sequences, labels, tokenizer, max_length=512):
        self.sequences = sequences
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        label = self.labels[idx]
        
        # Tokenize the protein sequence
        encoding = self.tokenizer(
            sequence,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        item = {key: encoding[key].squeeze(0) for key in encoding}
        item['labels'] = torch.tensor(label, dtype=torch.long)
        return item

# Step 3: Define the CNN and LSTM models separately
class CNN_Model(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=256, dropout=0.5):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(2)
        self.fc1 = nn.Linear(hidden_dim, 128)
        self.fc2 = nn.Linear(128, 3)  # Change 3 to match number of classes
    
    def forward(self, x):
        x = x.permute(0, 2, 1)  # Change to (batch_size, channels, seq_len)
        x = self.conv1(x)
        x = self.pool(x)
        x = x.permute(0, 2, 1)  # Change to (batch_size, seq_len, channels)
        x = x.mean(dim=1)  # Global average pooling
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

class LSTM_Model(nn.Module):
    def __init__(self, input_dim=768, hidden_dim=256, dropout=0.5):
        super(LSTM_Model, self).__init__()
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        self.fc1 = nn.Linear(hidden_dim, 128)
        self.fc2 = nn.Linear(128, 3)  # Change 3 to match number of classes
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x, (hn, cn) = self.lstm(x)
        x = self.dropout(hn[-1])  # Use the last hidden state
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# Step 4: Define the ProtBERT model for fine-tuning
class ProtBERT_Model(nn.Module):
    def __init__(self, transformer_model):
        super(ProtBERT_Model, self).__init__()
        self.transformer_model = transformer_model
        self.fc = nn.Linear(768, 3)  # Change 3 to match number of classes
    
    def forward(self, input_ids, attention_mask):
        transformer_outputs = self.transformer_model(input_ids, attention_mask=attention_mask)
        hidden_states = transformer_outputs.last_hidden_state  # Shape: (batch_size, seq_len, 768)
        x = hidden_states.mean(dim=1)  # Global average pooling
        x = self.fc(x)
        return x

# Step 5: Train the ProtBERT model separately
def train_protbert(model, train_dataloader, val_dataloader, optimizer, criterion, num_epochs=10):
    protbert_train_losses = []
    protbert_val_losses = []
    
    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode
        train_loss = 0.0
        for batch in train_dataloader:
            optimizer.zero_grad()
            inputs = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(inputs, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()

        # Record the average training loss for this epoch
        protbert_train_losses.append(train_loss / len(train_dataloader))

        # Validation phase
        model.eval()  # Set the model to evalua

        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                inputs = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(inputs, labels=labels)
                loss = outputs.loss

                val_loss += loss.item()

        # Record the average validation loss for this epoch
        protbert_val_losses.append(val_loss / len(val_dataloader))

    return protbert_train_losses, protbert_val_losses

# Step 6: Train the CNN/LSTM model
def train_cnn_lstm(cnn_model, lstm_model, train_dataloader, val_dataloader, optimizer_cnn, optimizer_lstm, criterion, num_epochs=10):
    cnn_train_losses = []
    cnn_val_losses = []
    lstm_train_losses = []
    lstm_val_losses = []
    
    for epoch in range(num_epochs):
        # Train CNN
        cnn_model.train()  # Set CNN model to training mode
        cnn_train_loss = 0.0
        for batch in train_dataloader:
            optimizer_cnn.zero_grad()
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs_cnn = cnn_model(inputs)
            loss_cnn = criterion(outputs_cnn, labels)
            loss_cnn.backward()
            optimizer_cnn.step()

            cnn_train_loss += loss_cnn.item()

        # Record the average CNN training loss
        cnn_train_losses.append(cnn_train_loss / len(train_dataloader))

        # Train LSTM
        lstm_model.train()  # Set LSTM model to training mode
        lstm_train_loss = 0.0
        for batch in train_dataloader:
            optimizer_lstm.zero_grad()
            inputs, labels = batch
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs_lstm = lstm_model(inputs)
            loss_lstm = criterion(outputs_lstm, labels)
            loss_lstm.backward()
            optimizer_lstm.step()

            lstm_train_loss += loss_lstm.item()

        # Record the average LSTM training loss
        lstm_train_losses.append(lstm_train_loss / len(train_dataloader))

        # Validation phase for CNN
        cnn_model.eval()
        cnn_val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                outputs_cnn = cnn_model(inputs)
                loss_cnn = criterion(outputs_cnn, labels)
                cnn_val_loss += loss_cnn.item()

        # Record the average CNN validation loss
        cnn_val_losses.append(cnn_val_loss / len(val_dataloader))

        # Validation phase for LSTM
        lstm_model.eval()
        lstm_val_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                inputs, labels = batch
                inputs, labels = inputs.to(device), labels.to(device)

                outputs_lstm = lstm_model(inputs)
                loss_lstm = criterion(outputs_lstm, labels)
                lstm_val_loss += loss_lstm.item()

        # Record the average LSTM validation loss
        lstm_val_losses.append(lstm_val_loss / len(val_dataloader))

    return cnn_train_losses, cnn_val_losses, lstm_train_losses, lstm_val_losses

# Step 7: Evaluate the model
def evaluate_model(model, eval_dataloader, device):
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = {key: value.to(device) for key, value in batch.items()}
            outputs = model(**batch)
            preds = torch.argmax(outputs, dim=1)
            
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(batch['labels'].cpu().numpy())
    
    report = classification_report(true_labels, predictions, target_names=["Type A", "Type B", "Type C"])
    print(report)

class WeightedLoss(nn.Module):
    def __init__(self, weights):
        super(WeightedLoss, self).__init__()
        self.weights = weights
    
    def forward(self, outputs, targets):
        loss = nn.CrossEntropyLoss(weight=self.weights)
        return loss(outputs, targets)

def augment_sequence(sequence):
    # Introduce random mutations (simple example)
    mutated_sequence = list(sequence)
    for i in range(len(sequence)):
        if random.random() < 0.1:  # 10% chance to mutate
            mutated_sequence[i] = random.choice('ACGT')
    return ''.join(mutated_sequence)

def plot_loss_curve(protbert_losses, cnn_losses, lstm_losses,
                    protbert_val_losses, cnn_val_losses, lstm_val_losses):
    plt.figure(figsize=(10, 6))
    
    # Plot the losses for ProtBERT
    plt.plot(protbert_losses, label='ProtBERT Training Loss', color='blue')
    plt.plot(protbert_val_losses, label='ProtBERT Validation Loss', color='blue', linestyle='--')

    # Plot the losses for CNN
    plt.plot(cnn_losses, label='CNN Training Loss', color='green')
    plt.plot(cnn_val_losses, label='CNN Validation Loss', color='green', linestyle='--')

    # Plot the losses for LSTM
    plt.plot(lstm_losses, label='LSTM Training Loss', color='red')
    plt.plot(lstm_val_losses, label='LSTM Validation Loss', color='red', linestyle='--')
    
    plt.title('Loss vs Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # Assuming you have true labels and predictions for each model
def plot_confusion_matrix(true_labels, protbert_preds, cnn_preds, lstm_preds):
    # ProtBERT Confusion Matrix
    protbert_cm = confusion_matrix(true_labels, protbert_preds)
    plt.figure(figsize=(6, 6))
    sns.heatmap(protbert_cm, annot=True, fmt="d", cmap="Blues", 
                xticklabels=['Class 0', 'Class 1', 'Class 2'], 
                yticklabels=['Class 0', 'Class 1', 'Class 2'])
    plt.title('ProtBERT Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    # CNN Confusion Matrix
    cnn_cm = confusion_matrix(true_labels, cnn_preds)
    plt.figure(figsize=(6, 6))
    sns.heatmap(cnn_cm, annot=True, fmt="d", cmap="Blues", 
                xticklabels=['Class 0', 'Class 1', 'Class 2'], 
                yticklabels=['Class 0', 'Class 1', 'Class 2'])
    plt.title('CNN Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

    # LSTM Confusion Matrix
    lstm_cm = confusion_matrix(true_labels, lstm_preds)
    plt.figure(figsize=(6, 6))
    sns.heatmap(lstm_cm, annot=True, fmt="d", cmap="Blues", 
                xticklabels=['Class 0', 'Class 1', 'Class 2'], 
                yticklabels=['Class 0', 'Class 1', 'Class 2'])
    plt.title('LSTM Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()

def plot_metrics(true_labels, protbert_preds, cnn_preds, lstm_preds):
    # Generate classification report
    protbert_report = classification_report(true_labels, protbert_preds, output_dict=True)
    cnn_report = classification_report(true_labels, cnn_preds, output_dict=True)
    lstm_report = classification_report(true_labels, lstm_preds, output_dict=True)

    # Extract metrics
    metrics = ['accuracy', 'precision', 'recall', 'f1-score']
    protbert_scores = [protbert_report['accuracy'], protbert_report['macro avg']['precision'], 
                       protbert_report['macro avg']['recall'], protbert_report['macro avg']['f1-score']]
    cnn_scores = [cnn_report['accuracy'], cnn_report['macro avg']['precision'], 
                  cnn_report['macro avg']['recall'], cnn_report['macro avg']['f1-score']]
    lstm_scores = [lstm_report['accuracy'], lstm_report['macro avg']['precision'], 
                   lstm_report['macro avg']['recall'], lstm_report['macro avg']['f1-score']]

    # Plot bar chart
    x = np.arange(len(metrics))
    width = 0.2  # Bar width
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.bar(x - width, protbert_scores, width, label='ProtBERT')
    ax.bar(x, cnn_scores, width, label='CNN')
    ax.bar(x + width, lstm_scores, width, label='LSTM')

    ax.set_ylabel('Scores')
    ax.set_title('Model Comparison: Accuracy, Precision, Recall, F1 Score')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.legend()
    plt.show()

def plot_roc_curve(true_labels, protbert_preds, cnn_preds, lstm_preds):
    fpr_protbert, tpr_protbert, _ = roc_curve(true_labels, protbert_preds)
    fpr_cnn, tpr_cnn, _ = roc_curve(true_labels, cnn_preds)
    fpr_lstm, tpr_lstm, _ = roc_curve(true_labels, lstm_preds)

    roc_auc_protbert = auc(fpr_protbert, tpr_protbert)
    roc_auc_cnn = auc(fpr_cnn, tpr_cnn)
    roc_auc_lstm = auc(fpr_lstm, tpr_lstm)

    plt.figure(figsize=(10, 6))
    plt.plot(fpr_protbert, tpr_protbert, color='blue', label=f'ProtBERT AUC = {roc_auc_protbert:.2f}')
    plt.plot(fpr_cnn, tpr_cnn, color='green', label=f'CNN AUC = {roc_auc_cnn:.2f}')
    plt.plot(fpr_lstm, tpr_lstm, color='red', label=f'LSTM AUC = {roc_auc_lstm:.2f}')
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve')
    plt.legend(loc='lower right')
    plt.show()

Using device: cpu


## Data Preparation

In [None]:
# Step 8: Load and prepare data
tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_bert_bfd")
sequences, labels = read_fasta("your_protein_sequences.fasta")

# Split the data into training and evaluation sets
#need to create validation sequences and labels
train_sequences, eval_sequences, train_labels, eval_labels = train_test_split(sequences, labels, test_size=0.2)

# Create datasets and dataloaders
train_dataset = ProteinDataset(train_sequences, train_labels, tokenizer)
val_dataset = ProteinDataset(val_sequences, val_labels, tokenizer)
eval_dataset = ProteinDataset(eval_sequences, eval_labels, tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False)
eval_dataloader = DataLoader(eval_dataset, batch_size=8, shuffle=False)

criterion = WeightedLoss(weights=torch.tensor([1.0, 1.0, 1.0]).to(device))
# Initialize optimizers for CNN and LSTM models
optimizer_cnn = torch.optim.Adam(CNN_Model.parameters(), lr=1e-3)
optimizer_lstm = torch.optim.Adam(LSTM_Model.parameters(), lr=1e-3)

NameError: name 'val_dataset' is not defined

## Implement ProtBERT

In [None]:
# Step 9: Initialize the ProtBERT model for fine-tuning
transformer_model = AutoModel.from_pretrained("Rostlab/prot_bert_bfd")
protbert_model = ProtBERT_Model(transformer_model)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
protbert_model.to(device)

# Step 10: Set up optimizer for ProtBERT
optimizer = torch.optim.Adam(protbert_model.parameters(), lr=1e-5)

# Step 11: Fine-tune ProtBERT model
train_protbert(ProtBERT_Model, train_dataloader, val_dataloader, optimizer, criterion)

In [None]:
protbert_losses = []  # List of training losses for ProtBERT
cnn_losses = []       # List of training losses for CNN
lstm_losses = []      # List of training losses for LSTM

# Similarly, track validation losses
protbert_val_losses = []
cnn_val_losses = []
lstm_val_losses = []

# Assuming `test_labels` are the true labels for the test set:
protbert_preds = protbert_model.predict(eval_dataset)
cnn_preds = CNN_Model.predict(eval_dataset)
lstm_preds = LSTM_Model.predict(eval_dataset)

# True labels for the test set (ensure this matches the order of your data)
true_labels = eval_labels

# Assuming `cnn_model` and `lstm_model` are already defined, and `train_dataloader` and `val_dataloader` are prepared
# Also assuming optimizers for both CNN and LSTM models are defined (`optimizer_cnn` and `optimizer_lstm`)

cnn_train_losses, cnn_val_losses, lstm_train_losses, lstm_val_losses = train_cnn_lstm(
    CNN_Model, LSTM_Model, train_dataloader, val_dataloader, optimizer_cnn, optimizer_lstm, criterion, num_epochs=10
)

## Implement CNN

In [None]:
model_choice = 'cnn' 

if model_choice == 'cnn':
    cnn_model = CNN_Model()
    cnn_model.train() # Set model to training mode
    cnn_model.to(device)
    optimizer = torch.optim.Adam(cnn_model.parameters(), lr=1e-5)
    train_cnn_lstm(cnn_model, train_dataloader, optimizer, device, num_epochs=3)

# Set model to evaluation mode
cnn_model.eval()

evaluate_model(cnn_model if model_choice == 'cnn' else LSTM_Model)

## Implement LSTM

In [None]:
model_choice = 'LSTM' 

if model_choice == 'LSTM':
    lstm_model = LSTM_Model()
    lstm_model.train() # Set model to training mode
    lstm_model.to(device)
    optimizer = torch.optim.Adam(lstm_model.parameters(), lr=1e-5)
    train_cnn_lstm(lstm_model, train_dataloader, optimizer, device, num_epochs=3)

# Set model to evaluation mode
lstm_model.eval()

# Step 13: Evaluate the final model (CNN or LSTM)
evaluate_model(cnn_model if model_choice == 'cnn' else lstm_model)

## Visualize the results

In [None]:
# Call the loss vs epoch plot function
plot_loss_curve(protbert_losses, cnn_losses, lstm_losses,
                protbert_val_losses, cnn_val_losses, lstm_val_losses)

# Call the confusion matrix function
plot_confusion_matrix(true_labels, protbert_preds, cnn_preds, lstm_preds)

# Call the metrics comparison function
plot_metrics(true_labels, protbert_preds, cnn_preds, lstm_preds)

# Calculate ROC curve and AUC during validation
true_labels = []  # To store the true labels
preds = []  # To store the predicted scores

cnn_model.eval()  # Set model to evaluation mode
with torch.no_grad():
    for batch in val_dataloader:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = cnn_model(inputs)

        # Store true labels and predicted scores
        true_labels.extend(labels.cpu().numpy())
        preds.extend(outputs.cpu().numpy())

# Calculate ROC and AUC
fpr, tpr, _ = roc_curve(true_labels, preds)
roc_auc = auc(fpr, tpr)

print(f"ROC AUC: {roc_auc}")

lstm_model.eval()  # Set model to evaluation mode
with torch.no_grad():
    for batch in val_dataloader:
        inputs, labels = batch
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = lstm_model(inputs)

        # Store true labels and predicted scores
        true_labels.extend(labels.cpu().numpy())
        preds.extend(outputs.cpu().numpy())

# Calculate ROC and AUC
fpr, tpr, _ = roc_curve(true_labels, preds)
roc_auc = auc(fpr, tpr)

print(f"ROC AUC: {roc_auc}")

# Call the ROC curve function
plot_roc_curve(true_labels, protbert_preds, cnn_preds, lstm_preds)