Fine Tune ESM

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score

sys.path.insert(0, '../dlp')
from data_process import simple_data_to_tensor_batch

# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

# --- Run configuration -------------------------------------------------------
epochs = 10_000       # number of training iterations (the training cell draws one batch per iteration)
val_epoch = 500       # validation / logging period, in training iterations
num_val = 1000        # number of batches drawn by each validation pass
batch_size = 64
dataset_name = "corpus_1000_random"
lr = 0.001
model_name = "Fine Tune ESM"

# Root directory holding the exported datasets. It can be overridden with the
# TAXSEQ_DATA_ROOT environment variable so the notebook is not tied to a single
# machine; the default is the original hard-coded location.
data_root = os.environ.get(
    "TAXSEQ_DATA_ROOT", "/home/aac/Alireza/datasets/export_pqt_4_taxseq"
)

# Project-local import, kept after the sys.path tweak made with the other imports
# (presumably `data_access` lives in that directory -- verify).
from data_access import PQDataAccess
da = PQDataAccess(f"{data_root}/{dataset_name}", batch_size)

# Directory reserved for model checkpoints. `exist_ok=True` makes the creation
# idempotent and avoids the check-then-create race of `os.path.exists`.
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "ESM + FNN",
        "dataset": dataset_name,
        "epochs": epochs,
        "batch_size": batch_size,
    }
)

cuda:0
../checkpoints/Fine Tune ESM_checkpoints


In [12]:
from transformers import EsmModel

class ESM1b(nn.Module):
    """Pretrained ESM-1b encoder with a small feed-forward classification head.

    The encoder output is pooled over the sequence dimension into a single
    1280-dimensional vector per sequence, then passed through
    ``Linear(1280, 512) -> ReLU -> Linear(512, 4)`` to produce 4 logits.
    """

    def __init__(self):
        super().__init__()
        # Loads the pretrained weights by name (downloaded or read from the local
        # Hugging Face cache).
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
        self.layer1 = nn.Linear(1280, 512)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(512, 4)

    def forward(self, x, attention_mask=None):
        """Return classification logits for a batch of tokenised sequences.

        Args:
            x: token ids forwarded as the first positional argument of the
                encoder (assumed shape ``(batch, seq_len)`` -- confirm with the
                tokenisation code).
            attention_mask: optional mask with the same leading shape as ``x``,
                non-zero for real tokens and 0 for padding. When given, padded
                positions are excluded from the pooled representation.

        Returns:
            Tensor whose last dimension has size 4 (one logit per class).
        """
        hidden = self.esm(x, attention_mask=attention_mask).last_hidden_state

        if attention_mask is None:
            # No mask available: plain mean over every position.
            pooled = hidden.mean(dim=1)
        else:
            # Masked mean: padded positions must not dilute the average,
            # otherwise the embedding depends on how much padding the batch has.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            token_count = mask.sum(dim=1).clamp(min=1.0)  # guard against all-zero masks
            pooled = (hidden * mask).sum(dim=1) / token_count

        outputs = self.layer1(pooled)
        outputs = self.relu(outputs)
        outputs = self.layer2(outputs)
        return outputs

In [13]:
# Build the classifier and move its weights to the selected device.
model = ESM1b().to(device)

# Report the total parameter count (encoder + head) in millions.
n_params_millions = sum(p.numel() for p in model.parameters()) / 1e6
print(f"model: {n_params_millions} M parameters")

# Adam over every parameter, i.e. the encoder is fine-tuned together with the head.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Targets equal to 0 are skipped by the loss (ignore_index=0).
# NOTE(review): presumably label 0 marks an unknown/padding class -- confirm.
criterion = nn.CrossEntropyLoss(ignore_index=0)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model: 653.014425 M parameters


In [14]:
def evaluate():
    """Run a validation pass over ``num_val`` batches and summarise the results.

    Batches are drawn from the module-level ``da`` accessor (the same accessor
    the training cell reads from). The model is left in eval mode on return.

    Returns:
        Tuple ``(avg_loss, accuracy, f1_micro, f1_macro, conf_matrix)`` where
        ``avg_loss`` is the mean of the per-batch losses and the remaining
        metrics are computed over the predictions of all batches together.
    """
    model.eval()

    total_loss = 0.0
    pred_chunks = []
    label_chunks = []

    # No gradients are needed anywhere in the validation pass.
    with torch.no_grad():
        for _ in range(num_val):
            batch = simple_data_to_tensor_batch(da.get_batch())
            batch.gpu(device)

            targets = batch.taxes["begining root"]
            logits = model(**batch.seq_ids)

            total_loss += criterion(logits, targets).item()

            # Predicted class = index of the largest logit.
            pred_chunks.append(logits.argmax(dim=1).cpu())
            label_chunks.append(targets.cpu())

    # Flatten the per-batch results into one array each for scikit-learn.
    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()

    accuracy = accuracy_score(y_true, y_pred)
    # Single-label multi-class F1, macro- and micro-averaged.
    f1_macro = f1_score(y_true, y_pred, average='macro')
    f1_micro = f1_score(y_true, y_pred, average='micro')
    conf_matrix = confusion_matrix(y_true, y_pred)
    avg_loss = total_loss / num_val

    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

In [15]:
# Sum of the training losses observed since the last validation/logging point.
running_loss = 0.0

for epoch in range(epochs):
    # evaluate() leaves the model in eval mode, so training mode is re-enabled
    # at the start of every iteration.
    model.train()

    tensor_batch = simple_data_to_tensor_batch(da.get_batch())
    tensor_batch.gpu(device)

    labels = tensor_batch.taxes["begining root"]

    outputs = model(**tensor_batch.seq_ids)

    # Calculate the loss
    loss = criterion(outputs, labels)

    # Backpropagation: zero the gradients, compute the backward pass, and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Accumulate across iterations; the accumulator is only reset after logging,
    # so the reported value is a true average over the logging window.
    running_loss += loss.item()

    if (epoch + 1) % val_epoch == 0:
        # Mean training loss over the last `val_epoch` iterations.
        epoch_loss = running_loss / val_epoch
        running_loss = 0.0
        print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {epoch_loss:.4f}")

        # Evaluate the model and report the validation metrics
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, cm = evaluate()
        print(cm)
        print(f"val Loss: {val_loss:.4f}, val Accuracy: {val_accuracy:.4f}, val F1 Score (micro): {val_f1_micro:.4f}, , val F1 Score (macro): {val_f1_macro:.4f}")

        wandb.log({"train loss": epoch_loss, "val acc": val_accuracy, "val loss": val_loss})

wandb.finish()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


AttributeError: 'list' object has no attribute 'items'