In [None]:
import pandas as pd
from IPython.display import clear_output
import pickle
import gc
import pandas as pd
import pickle
import torch
from transformers import EsmForSequenceClassification, AdamW, AutoTokenizer, AutoModel
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import numpy as np
from torch.optim.lr_scheduler import StepLR
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import os

In [None]:
# Load the sequence table and keep only the first row for each distinct
# value of the 'Sequence' column; the trailing bare `df` displays the frame.
df = (
    pd.read_csv("/home/aarya/uniref50.231020.80gaps.v2.csv", sep=",")
    .drop_duplicates(subset=['Sequence'], keep='first')
)
df

In [None]:
# Model inputs: upper-cased sequence strings. Targets: the 'Family' column.
# Both are plain Python lists in the same row order as `df`.
x = [*df["Sequence"].str.upper()]
y = [*df["Family"]]

In [None]:
# Modifiable parameters

# Name of the ESM2 model; must be one listed on: https://github.com/facebookresearch/esm
esm_model_name = "facebook/esm2_t12_35M_UR50D"
# Suffix added to all files
file_naming = "35M_t12_famv2"
lr=3e-4
weight_decay = 1e-5 # Regularization strength for L2
l1_lambda = 1e-5  # Regularization strength for L1
batch_size = 32
num_epochs = 160
step_size = 15 # epoch step size for LR scheduler at which point LR is multiplied by gamma to produce a smoother learning rate
gamma = 0.6
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Remove every sample whose label occurs exactly once in `y`
# (presumably so the stratified split in the next cell has at least two
# samples per class -- confirm).
y_series = pd.Series(y)
class_counts = y_series.value_counts()
single_occurrence_classes = class_counts.index[class_counts == 1]
mask = ~y_series.isin(single_occurrence_classes)
y_filtered = y_series[mask].tolist()
# `mask` is positionally aligned with `x`, so walk both together.
x_filtered = [sequence for sequence, keep in zip(x, mask) if keep]

In [None]:
from sklearn.utils.class_weight import compute_class_weight

# First, split the data into 70% training and 30% temp (stratified by label)
x_train, x_temp, y_train, y_temp = train_test_split(x_filtered, y_filtered, test_size=0.3, shuffle=True, stratify=y_filtered, random_state=42)

# Then, split the temp data into 50% validation and 50% testing
# (this second split is not stratified)
x_val, x_test, y_val, y_test = train_test_split(x_temp, y_temp, test_size=0.5, shuffle=True, random_state=42)

# Tokenizer for sequences and label encoder for labels
tokenizer_esm = AutoTokenizer.from_pretrained(esm_model_name)
label_encoder = LabelEncoder()
# Fit on the entire label set 'y' (including classes removed by the
# single-occurrence filter) so that the encoded ids span every label in `y`.
label_encoder.fit(y)

# Transform the labels for train, val, and test sets
encoded_labels_train = label_encoder.transform(y_train)
encoded_labels_val = label_encoder.transform(y_val)
encoded_labels_test = label_encoder.transform(y_test)

# Convert to PyTorch tensors
labels_train = torch.tensor(encoded_labels_train)
labels_val = torch.tensor(encoded_labels_val)
labels_test = torch.tensor(encoded_labels_test)


# Balanced class weights, indexed by *encoded label id*.
# The weight vector must have one entry per class known to the label encoder,
# because the encoder was fit on all of `y` while `y_train` may lack classes
# that were filtered out. Computing weights only over the classes present in
# `y_train` would yield a shorter, misaligned vector. Classes absent from the
# training set never occur as training targets, so they get a neutral 1.0.
train_class_ids = np.unique(encoded_labels_train)
balanced_weights = compute_class_weight('balanced', classes=train_class_ids, y=encoded_labels_train)
full_class_weights = np.ones(len(label_encoder.classes_), dtype=np.float64)
full_class_weights[train_class_ids] = balanced_weights
class_weights = torch.tensor(full_class_weights, dtype=torch.float)


# Save the labels so they can be reconstructed in other files
# NOTE(review): this path says "famv1" while `file_naming` says "famv2";
# it is left unchanged because a later cell reads the same path -- confirm.
with open('/home/aarya/35M_t12_famv1_labels.pkl', 'wb') as file:
    pickle.dump(label_encoder, file)

In [None]:
def _encode_split(sequences, labels):
    """Tokenize one split and encode its labels.

    Returns the tokenizer output (padded / truncated to at most 450 tokens,
    as PyTorch tensors) and the integer label ids from `label_encoder`.
    """
    tokenized = tokenizer_esm(sequences, padding=True, truncation=True, max_length=450, return_tensors="pt")
    label_ids = label_encoder.transform(labels)
    return tokenized, label_ids


# Train set -> TensorDataset(input_ids, attention_mask, labels)
encoded_train_inputs, encoded_train_labels = _encode_split(x_train, y_train)
train_input_ids = encoded_train_inputs["input_ids"]
train_attention_mask = encoded_train_inputs["attention_mask"]
train_labels = torch.tensor(encoded_train_labels)
train_dataset = TensorDataset(train_input_ids, train_attention_mask, train_labels)

# Validation set
encoded_val_inputs, encoded_val_labels = _encode_split(x_val, y_val)
val_input_ids = encoded_val_inputs["input_ids"]
val_attention_mask = encoded_val_inputs["attention_mask"]
val_labels = torch.tensor(encoded_val_labels)
val_dataset = TensorDataset(val_input_ids, val_attention_mask, val_labels)

# Test set
encoded_test_inputs, encoded_test_labels = _encode_split(x_test, y_test)
test_input_ids = encoded_test_inputs["input_ids"]
test_attention_mask = encoded_test_inputs["attention_mask"]
test_labels = torch.tensor(encoded_test_labels)
test_dataset = TensorDataset(test_input_ids, test_attention_mask, test_labels)

In [None]:
import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    """Element-wise focal loss built on binary cross-entropy with logits.

    Each element's BCE term is scaled by ``alpha * (1 - p_t) ** gamma`` with
    ``p_t = exp(-BCE)``, so elements with a small BCE contribute less.

    Args:
        alpha: Constant factor applied to every element's loss.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced, element-wise loss tensor.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of raw logits `inputs` against `targets`.

        `targets` must have the same shape as `inputs`; it may be of any
        numeric dtype and is cast to the dtype of `inputs`.
        """
        # binary_cross_entropy_with_logits rejects integer targets, so cast
        # them to the logits' dtype (a no-op when they already match).
        targets = targets.to(dtype=inputs.dtype)
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduction == 'mean':
            return torch.mean(F_loss)
        elif self.reduction == 'sum':
            return torch.sum(F_loss)
        else:
            return F_loss

# Focal loss is defined here for reference. Use focal loss for multi-label classification problems (e.g. one sequence may be categorized under multiple labels).
# Use cross-entropy loss (the default) when each sequence has exactly one label (e.g. a sequence belongs to family X only).

# A future consideration may be the implementation of focal loss for family prediction, where GTs may belong to an overarching family (GT2) but may also be classified into subfamilies (GT2-exo, GT2_bact, etc.).
# Developing such a dataset is not difficult, but the success of a simple single-label cross-entropy classification approach serves as the foundation for that implementation.

In [None]:
# Distinct labels in the full (unfiltered) label list `y`.
unique_labels = set(y)
num_unique_labels = len(unique_labels)
print("Unique labels in dataset:", unique_labels)
print("Number of unique labels:", num_unique_labels)

In [None]:
import csv
# CSV function to track various aspects about model training 
def write_to_csv(epoch, train_accuracy, avg_train_loss, val_accuracy, avg_val_loss, filename='training_progress.csv'):
    """Append one epoch's metrics as a single row to a CSV file.

    The row is written in the order: epoch, training accuracy, training
    loss, validation accuracy, validation loss. The file is opened in append
    mode, so existing rows (and any header) are preserved.
    """
    metrics_row = (epoch, train_accuracy, avg_train_loss, val_accuracy, avg_val_loss)
    with open(filename, mode='a', newline='') as log_file:
        csv.writer(log_file).writerow(metrics_row)

In [None]:
# Classifier head sized to the number of distinct labels in the full label list `y`.
model = EsmForSequenceClassification.from_pretrained(esm_model_name, num_labels=len(set(y)))
model = model.to(device)

# Can be changed if needed (AdamW has worked well but others may be better)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)  # Multiplies the lr by `gamma` every `step_size` epochs
# Create DataLoaders
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize variables for best validation accuracy and corresponding state dict
best_val_acc = 0
best_state_dict = None
# Per-epoch histories read by the plotting cells.
train_losses = []
test_losses = []  # per-epoch held-out (validation) loss; name kept for the loss-plot cell
train_accuracies = []
val_accuracies = []
best_test_acc = 0
patience = 10  # Number of epochs with no improvement after which training will be stopped
no_improve = 0  # Track epochs with no improvement
prev_val_loss = float('inf')  # Lowest validation loss seen so far
class_weights = class_weights.to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights) # balances class weights for labels

# Initially, create the file and write headers
with open('training_progress.csv', 'w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['Epoch', 'Training Accuracy', 'Training Loss', 'Validation Accuracy', 'Validation Loss'])

for epoch in range(num_epochs):

    model.train()
    progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch+1}/{num_epochs}", position=0, leave=False)

    train_loss = 0
    num_correct_train = 0
    total_samples_train = 0

    # Train loop, trains each batch
    for batch in train_dataloader:
        batch_input_ids, batch_attention_mask, batch_labels = batch
        batch_labels = batch_labels.to(device)

        # Ensure gradients are zeroed
        optimizer.zero_grad()

        # Forward pass
        outputs = model(batch_input_ids.to(device), attention_mask=batch_attention_mask.to(device))

        # Calculate loss with the class-weighted criterion
        loss = criterion(outputs.logits, batch_labels)

        # Add L1 regularization over every model parameter
        l1_loss = 0
        for param in model.parameters():
            l1_loss += param.abs().sum()
        loss += l1_lambda * l1_loss

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update loss/accuracy tracking
        train_loss += loss.item()
        _, predicted_labels_train = torch.max(outputs.logits, dim=1)
        num_correct_train += (predicted_labels_train == batch_labels).sum().item()
        total_samples_train += len(batch_labels)

        progress_bar.update(1)
        progress_bar.set_postfix({"Loss": loss.item()})

    # Release the bar so finished epochs do not leave open tqdm instances behind.
    progress_bar.close()

    train_accuracy = num_correct_train / total_samples_train
    avg_train_loss = train_loss / len(train_dataloader)

    # Validation Loop
    # Note: the validation loss is the loss the model returns when given
    # `labels`, not `criterion` + L1, so it is not on the same scale as the
    # training loss above.
    model.eval()
    total_val_loss = 0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for batch in val_dataloader:
            batch_input_ids, batch_attention_mask, batch_labels = batch
            batch_labels = batch_labels.to(device)
            outputs = model(batch_input_ids.to(device), attention_mask=batch_attention_mask.to(device), labels=batch_labels)
            loss = outputs.loss
            logits = outputs.logits

            total_val_loss += loss.item()
            _, predicted_labels = torch.max(logits, dim=1)
            correct_val += (predicted_labels == batch_labels).sum().item()
            total_val += len(batch_labels)

    val_accuracy = correct_val / total_val
    avg_val_loss = total_val_loss / len(val_dataloader)
    train_losses.append(avg_train_loss)
    test_losses.append(avg_val_loss)
    train_accuracies.append(train_accuracy)
    val_accuracies.append(val_accuracy)

    # Check if this epoch resulted in a better validation accuracy
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        # state_dict() returns tensors that share storage with the live
        # parameters, so take an independent copy; otherwise the "best"
        # snapshot would silently track every later optimizer step.
        best_state_dict = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
        torch.save(best_state_dict, f"best_model_{file_naming}.pth")


    scheduler.step()
    # Print statistics after the epoch, write to csv, and save model
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f} - Train Accuracy: {train_accuracy:.4f} - Validation Loss: {avg_val_loss:.4f} - Validation Accuracy: {val_accuracy:.4f}")
    write_to_csv(epoch+1, train_accuracy, avg_train_loss, val_accuracy, avg_val_loss)
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}_{file_naming}.pth") # One checkpoint per epoch (earlier epochs are not overwritten)

    # Early stopping based on validation loss
    if avg_val_loss < prev_val_loss:
        prev_val_loss = avg_val_loss
        no_improve = 0
    else:
        no_improve += 1

    if no_improve == patience:
        print(f"Early stopping after {patience} epochs with no improvement.")
        break

    # Stop if learning rate is too small
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < 1e-6:  # This threshold can be adjusted
        print(f"Stopping due to learning rate becoming too small: {current_lr}")
        break

### TESTING LOOP

In [None]:

# Evaluate the in-memory `model` on the held-out test split, batch by batch.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Running totals accumulated over all test batches
test_loss = 0
num_correct_test = 0
total_samples_test = 0

model.eval()  # Switch model to evaluation mode
with torch.no_grad():
    for batch_input_ids, batch_attention_mask, batch_labels in test_dataloader:
        labels_on_device = batch_labels.to(device)

        # Forward pass; passing `labels` makes the model return a loss as well
        outputs = model(
            batch_input_ids.to(device),
            attention_mask=batch_attention_mask.to(device),
            labels=labels_on_device,
        )
        loss, logits = outputs.loss, outputs.logits

        # Accumulate the batch loss
        test_loss += loss.item()

        # Predicted class = index of the largest logit in each row
        predicted_labels_test = logits.max(dim=1).indices
        num_correct_test += (predicted_labels_test == labels_on_device).sum().item()
        total_samples_test += len(batch_labels)

# Mean of the per-batch losses, and overall fraction of correct predictions
avg_test_loss = test_loss / len(test_dataloader)
test_accuracy = num_correct_test / total_samples_test

print(f"Test Loss: {avg_test_loss}, Test Accuracy: {test_accuracy}")


### CONFUSION MATRIX

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# SECURITY NOTE: pickle.load can execute arbitrary code from the file; only
# load label-encoder pickles that you produced yourself.
with open('/home/aarya/35M_t12_famv1_labels.pkl', 'rb') as file:
    label_encoder = pickle.load(file)


# Define the model architecture (make sure it's the same as the one you trained)
model = EsmForSequenceClassification.from_pretrained(esm_model_name, num_labels=len(set(y)))
model = model.to(device)

# Load the saved state dictionary. `map_location` places the checkpoint
# tensors on `device` regardless of the device they were saved from.
# NOTE(review): this exact path is not written by the training cell in this
# notebook -- confirm the checkpoint exists and matches the model above.
model.load_state_dict(torch.load('/home/aarya/best_model_35M_t12_famv1.pth', map_location=device))

# Initialize variables for tracking test loss and accuracy
test_loss = 0
num_correct_test = 0
total_samples_test = 0

# Lists to store true and predicted labels (as encoded integer ids)
true_labels_int = []
predicted_labels_int = []

# Testing loop
model.eval()  # Switch model to evaluation mode
with torch.no_grad():
    for batch in test_dataloader:
        batch_input_ids, batch_attention_mask, batch_labels = batch
        labels_on_device = batch_labels.to(device)

        # Forward pass
        outputs = model(batch_input_ids.to(device),
                        attention_mask=batch_attention_mask.to(device),
                        labels=labels_on_device)

        loss = outputs.loss
        logits = outputs.logits

        # Update test loss
        test_loss += loss.item()

        # Get predictions and update the number of correct predictions and total samples
        _, predicted_labels_test = torch.max(logits, dim=1)
        num_correct_test += (predicted_labels_test == labels_on_device).sum().item()
        total_samples_test += len(batch_labels)

        # Store true and predicted labels
        true_labels_int.extend(batch_labels.tolist())
        predicted_labels_int.extend(predicted_labels_test.tolist())

# Calculate average test loss and test accuracy
avg_test_loss = test_loss / len(test_dataloader)
test_accuracy = num_correct_test / total_samples_test

print(f"Test Loss: {avg_test_loss}, Test Accuracy: {test_accuracy}")

# Convert integer labels back to original string labels
true_labels_str = label_encoder.inverse_transform(true_labels_int)
predicted_labels_str = label_encoder.inverse_transform(predicted_labels_int)

# Generate the confusion matrix (rows: true labels, columns: predicted labels)
cm = confusion_matrix(true_labels_str, predicted_labels_str, labels=label_encoder.classes_)

# Normalize the confusion matrix by row (i.e by the number of samples in each class).
# Rows with no test samples stay all-zero instead of triggering a
# divide-by-zero warning and producing NaNs that then need patching.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

# Convert to percentages
cm_percentage = cm_norm * 100

# Plot the normalized confusion matrix
plt.figure(figsize=(100, 100))
sns.heatmap(cm_percentage, annot=True, fmt='.1f', cmap='Blues', xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

In [None]:
# Loss curves: one point per entry of `train_losses` / `test_losses`.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses, label="Train")
loss_ax.plot(test_losses, label="Test")
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
plt.show()

In [None]:
# Accuracy curves, one point per completed epoch.
# The training cell records held-out accuracy in `val_accuracies`; no
# `test_accuracies` list is ever created, so plotting it raised NameError.
plt.plot(train_accuracies, label = "Train")
plt.plot(val_accuracies, label = "Validation")
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()