In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score

sys.path.insert(0, '../dlp')
from data_process import simple_data_to_tensor_batch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 10_000
val_epoch = 500
num_val = 100
batch_size = 16
dataset_name = "corpus_1000_random"
lr = 0.001
model_name = "Fine Tune ESM"
max_seq_len = 500

from data_access import PQDataAccess
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq/{dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "ESM + FNN",
        "dataset": dataset_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
30522
cuda:0
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/Fine Tune ESM_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malirezanor-310-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from transformers import EsmModel

class ESM1b(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
        self.layer1 = nn.Linear(1280, 512)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(512, 4)
        
    def forward(self, x, attention_mask=None):
        outputs = self.esm(x, attention_mask=attention_mask)
        # Mean pooling over all residues
        outputs = outputs.last_hidden_state.mean(dim=1)
        outputs = self.layer1(outputs)
        outputs = self.relu(outputs)
        outputs = self.layer2(outputs)
        return outputs

In [3]:
model = ESM1b().to(device)
print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss(ignore_index=0)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model: 653.014425 M parameters
ESM1b(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
           

In [4]:
def evaluate():
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = simple_data_to_tensor_batch(da.get_batch(), max_seq_len)
            tensor_batch.gpu(device)
        
            labels = tensor_batch.taxes["begining root"]
            
            outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

            # Calculate the loss
            loss = criterion(outputs, labels)
    
            running_loss += loss.item()
                
            preds = torch.argmax(outputs, dim=1)
    
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute evaluation metrics (example: accuracy, F1 score)
    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')  # F1-score for multi-label classification
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')  # F1-score for multi-label classification
    conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
    avg_loss = running_loss / num_val
    
    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

In [5]:
running_loss = 0

for epoch in tqdm(range(epochs)):
    model.train()  # Set model to training mode
    running_loss = 0.0
    
    tensor_batch = simple_data_to_tensor_batch(da.get_batch(), max_seq_len)
    tensor_batch.gpu(device)

    labels = tensor_batch.taxes["begining root"]
    
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

    # Calculate the loss
    loss = criterion(outputs, labels)

    # Backpropagation: Zero the gradients, compute the backward pass, and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Track the loss
    running_loss += loss.item()

    if (epoch + 1) % val_epoch == 0:
        # Print loss for this epoch
        epoch_loss = running_loss / (epoch + 1)
        print(f"Epoch [{epoch + 1}/{epochs}], Train Loss: {epoch_loss:.4f}")
        
        # Evaluate the model on the test set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, cm = evaluate()
        print(cm)
        print(f"val Loss: {val_loss:.4f}, val Accuracy: {val_accuracy:.4f}, val F1 Score (micro): {val_f1_micro:.4f}, , val F1 Score (macro): {val_f1_macro:.4f}")

        wandb.log({
            "train loss": epoch_loss,
            "val acc": val_accuracy,
            "val loss": val_loss,
            "val F1 (micro)": val_f1_micro,
            "val F1 (macro)": val_f1_macro,
            "confusion matrix": cm
        })

wandb.finish()

  5%|▍         | 499/10000 [12:35<3:59:13,  1.51s/it]

Epoch [500/10000], Train Loss: 0.0001


  5%|▌         | 500/10000 [13:27<44:00:56, 16.68s/it]

[[   0   16    0]
 [   0 1572    0]
 [   0   12    0]]
val Loss: 0.1142, val Accuracy: 0.9825, val F1 Score (micro): 0.9825, , val F1 Score (macro): 0.3304


 10%|▉         | 999/10000 [26:00<3:45:51,  1.51s/it] 

Epoch [1000/10000], Train Loss: 0.0001


 10%|█         | 1000/10000 [26:52<41:39:14, 16.66s/it]

[[   0   14    0]
 [   0 1561    0]
 [   0   25    0]]
val Loss: 0.1758, val Accuracy: 0.9756, val F1 Score (micro): 0.9756, , val F1 Score (macro): 0.3292


 15%|█▍        | 1499/10000 [39:26<3:34:12,  1.51s/it] 

Epoch [1500/10000], Train Loss: 0.0005


 15%|█▌        | 1500/10000 [40:18<39:21:38, 16.67s/it]

[[   0   23    0]
 [   0 1554    0]
 [   0   23    0]]
val Loss: 0.1630, val Accuracy: 0.9712, val F1 Score (micro): 0.9712, , val F1 Score (macro): 0.3285


 20%|█▉        | 1999/10000 [52:51<3:20:58,  1.51s/it] 

Epoch [2000/10000], Train Loss: 0.0000


 20%|██        | 2000/10000 [53:43<37:01:58, 16.66s/it]

[[   0   20    0]
 [   0 1561    0]
 [   0   19    0]]
val Loss: 0.1319, val Accuracy: 0.9756, val F1 Score (micro): 0.9756, , val F1 Score (macro): 0.3292


 25%|██▍       | 2499/10000 [1:06:15<3:08:43,  1.51s/it]

Epoch [2500/10000], Train Loss: 0.0000


 25%|██▌       | 2500/10000 [1:07:07<34:46:48, 16.69s/it]

[[   0   28    0]
 [   0 1557    0]
 [   0   15    0]]
val Loss: 0.1469, val Accuracy: 0.9731, val F1 Score (micro): 0.9731, , val F1 Score (macro): 0.3288


 30%|██▉       | 2999/10000 [1:19:41<2:56:01,  1.51s/it] 

Epoch [3000/10000], Train Loss: 0.0002


 30%|███       | 3000/10000 [1:20:33<32:25:54, 16.68s/it]

[[   0   21    0]
 [   0 1566    0]
 [   0   13    0]]
val Loss: 0.1185, val Accuracy: 0.9788, val F1 Score (micro): 0.9788, , val F1 Score (macro): 0.3298


 35%|███▍      | 3499/10000 [1:33:07<2:43:27,  1.51s/it] 

Epoch [3500/10000], Train Loss: 0.0001


 35%|███▌      | 3500/10000 [1:34:00<30:14:30, 16.75s/it]

[[   0   14    0]
 [   0 1566    0]
 [   0   20    0]]
val Loss: 0.1182, val Accuracy: 0.9788, val F1 Score (micro): 0.9788, , val F1 Score (macro): 0.3298


 40%|███▉      | 3999/10000 [1:46:34<2:31:04,  1.51s/it] 

Epoch [4000/10000], Train Loss: 0.0000


 40%|████      | 4000/10000 [1:47:26<27:50:10, 16.70s/it]

[[   0   22    0]
 [   0 1560    0]
 [   0   18    0]]
val Loss: 0.1361, val Accuracy: 0.9750, val F1 Score (micro): 0.9750, , val F1 Score (macro): 0.3291


 45%|████▍     | 4499/10000 [2:00:00<2:19:20,  1.52s/it] 

Epoch [4500/10000], Train Loss: 0.0000


 45%|████▌     | 4500/10000 [2:00:52<25:38:21, 16.78s/it]

[[   0   20    0]
 [   0 1568    0]
 [   0   12    0]]
val Loss: 0.1140, val Accuracy: 0.9800, val F1 Score (micro): 0.9800, , val F1 Score (macro): 0.3300


 50%|████▉     | 4999/10000 [2:13:26<2:06:00,  1.51s/it] 

Epoch [5000/10000], Train Loss: 0.0000


 50%|█████     | 5000/10000 [2:14:19<23:12:25, 16.71s/it]

[[   0   31    0]
 [   0 1557    0]
 [   0   12    0]]
val Loss: 0.1389, val Accuracy: 0.9731, val F1 Score (micro): 0.9731, , val F1 Score (macro): 0.3288


 55%|█████▍    | 5499/10000 [2:26:54<1:53:36,  1.51s/it] 

Epoch [5500/10000], Train Loss: 0.0000


 55%|█████▌    | 5500/10000 [2:27:46<20:52:34, 16.70s/it]

[[   0   21    0]
 [   0 1558    0]
 [   0   21    0]]
val Loss: 0.1365, val Accuracy: 0.9738, val F1 Score (micro): 0.9738, , val F1 Score (macro): 0.3289


 60%|█████▉    | 5999/10000 [2:40:20<1:41:32,  1.52s/it] 

Epoch [6000/10000], Train Loss: 0.0000


 60%|██████    | 6000/10000 [2:41:12<18:33:40, 16.71s/it]

[[   0   25    0]
 [   0 1559    0]
 [   0   16    0]]
val Loss: 0.1453, val Accuracy: 0.9744, val F1 Score (micro): 0.9744, , val F1 Score (macro): 0.3290


 63%|██████▎   | 6288/10000 [2:48:28<1:33:33,  1.51s/it] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

