In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score

sys.path.insert(0, '../dlp')
from data_process import mix_data_to_tensor_batch, simple_data_to_tensor_batch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 1000
num_val = 50
batch_size = 16
dataset_name = "corpus_1000_random"
virus_dataset_name = "corpus_200_500_Viruses_random"
non_virus_dataset_name = "corpus_200_500_Non_Viruses_random"
lr = 0.001
model_name = "Fine Tune ESM uniform sampling"
max_seq_len = 500

from data_access import PQDataAccess
virus_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq/{virus_dataset_name}", batch_size)
non_virus_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq/{non_virus_dataset_name}", batch_size)
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq/{dataset_name}", 2 * batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "ESM + FNN",
        "dataset": virus_dataset_name + ", " + non_virus_dataset_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.
cuda:0
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/Fine Tune ESM uniform sampling_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malirezanor-310-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
from transformers import EsmModel

class ESM1b(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
        self.layer1 = nn.Linear(1280, 512)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(512, 4)
        
    def forward(self, x, attention_mask=None):
        outputs = self.esm(x, attention_mask=attention_mask)
        # Mean pooling over all residues
        outputs = outputs.last_hidden_state.mean(dim=1)
        outputs = self.layer1(outputs)
        outputs = self.relu(outputs)
        outputs = self.layer2(outputs)
        return outputs

In [3]:
model = ESM1b().to(device)
print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model: 653.014425 M parameters
ESM1b(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
           

In [4]:
val_batches = [da.get_batch() for _ in range(num_val)]

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = simple_data_to_tensor_batch(val_batches[epoch], max_seq_len)
            tensor_batch.gpu(device)
        
            labels = tensor_batch.taxes["begining root"]
            outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

            # Calculate the loss
            loss = criterion(outputs, labels)
    
            running_loss += loss.item()
                
            preds = torch.argmax(outputs, dim=1)
    
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute evaluation metrics (example: accuracy, F1 score)
    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')  # F1-score for multi-label classification
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')  # F1-score for multi-label classification
    conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy(), labels= [0, 1, 2, 3])
    avg_loss = running_loss / num_val
    
    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

In [None]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)

In [5]:
latest_epoch = 0

In [6]:
def get_partition_ratio(epoch, decay_epochs=50000):
    """
    Calculate partition ratio that decreases from 8/16 to 1/16 in steps
    """
    # Calculate how many epochs before each step down
    epochs_per_step = decay_epochs // 7  # 7 steps from 8/16 down to 1/16
    
    # Calculate current step based on epoch
    step = min(epoch // epochs_per_step, 7)  # Max 7 steps down from 8
    
    # Map step to fraction
    fraction = (8 - step) / 16
    
    return fraction

In [None]:
running_loss = 0
train_preds = []
train_labels = []

for epoch in tqdm(range(latest_epoch, latest_epoch + epochs)):
    model.train()

    current_partition = get_partition_ratio(epoch)
    
    tensor_batch = mix_data_to_tensor_batch(
        virus_da.get_batch(),
        non_virus_da.get_batch(),
        max_seq_len,
        partition=current_partition
    )
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes["begining root"]
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    loss = criterion(outputs, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()


    
    preds = torch.argmax(outputs, dim=1)
    train_preds.append(preds.cpu())
    train_labels.append(labels.cpu())
    
    if (epoch + 1) % val_epoch == 0:
        # Calculate training metrics
        all_train_preds = torch.cat(train_preds)
        all_train_labels = torch.cat(train_labels)
        
        train_accuracy = accuracy_score(all_train_labels.numpy(), all_train_preds.numpy())
        train_f1_micro = f1_score(all_train_labels.numpy(), all_train_preds.numpy(), average='micro')
        train_f1_macro = f1_score(all_train_labels.numpy(), all_train_preds.numpy(), average='macro')
        train_cm = confusion_matrix(all_train_labels.numpy(), all_train_preds.numpy(), labels=[0, 1, 2, 3])
        train_loss = running_loss / val_epoch
        
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Train F1 (micro): {train_f1_micro:.4f}, Train F1 (macro): {train_f1_macro:.4f}")
        print("Train Confusion Matrix:")
        print(train_cm)
        
        # Evaluate on validation set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, val_cm = evaluate(model)
        
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        print(f"Val F1 (micro): {val_f1_micro:.4f}, Val F1 (macro): {val_f1_macro:.4f}")
        print("Val Confusion Matrix:")
        print(val_cm)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "train_accuracy": train_accuracy,
            "train_f1_micro": train_f1_micro,
            "train_f1_macro": train_f1_macro,
            "train_confusion_matrix": train_cm,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "val_f1_micro": val_f1_micro,
            "val_f1_macro": val_f1_macro,
            "val_confusion_matrix": val_cm,
            "epoch": epoch + 1,
            "current_portion": current_partition,
            "lr": current_lr
        }

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0
        train_preds = []
        train_labels = []

wandb.finish()

  1%|          | 999/100000 [25:08<41:40:39,  1.52s/it]

Epoch [1000/100000]
Train Loss: 0.6482, Train Accuracy: 0.6358
Train F1 (micro): 0.6358, Train F1 (macro): 0.4241
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 5218 2782    0]
 [   0 3020 4954    0]
 [   0    8   18    0]]
Val Loss: 0.5250, Val Accuracy: 0.7475
Val F1 (micro): 0.7475, Val F1 (macro): 0.3141
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   19    6    0]
 [   0  386 1177    0]
 [   0    4    8    0]]


  2%|▏         | 1999/100000 [51:28<41:04:17,  1.51s/it] 

Epoch [2000/100000]
Train Loss: 0.5404, Train Accuracy: 0.7392
Train F1 (micro): 0.7392, Train F1 (macro): 0.4930
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 6025 1975    0]
 [   0 2181 5802    0]
 [   0   10    7    0]]
Val Loss: 0.9491, Val Accuracy: 0.5569
Val F1 (micro): 0.5569, Val F1 (macro): 0.2570
Val Confusion Matrix:
[[  0   0   0   0]
 [  0  22   3   0]
 [  0 694 869   0]
 [  0   7   5   0]]


  3%|▎         | 2999/100000 [1:17:45<40:36:17,  1.51s/it]

Epoch [3000/100000]
Train Loss: 0.5092, Train Accuracy: 0.7689
Train F1 (micro): 0.7689, Train F1 (macro): 0.5657
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 6135 1828   37]
 [   0 1590 6143   79]
 [   0   29  135   24]]
Val Loss: 0.5578, Val Accuracy: 0.7581
Val F1 (micro): 0.7581, Val F1 (macro): 0.3218
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   22    3    0]
 [   0  372 1191    0]
 [   0    4    8    0]]


  4%|▍         | 3999/100000 [1:44:07<40:18:38,  1.51s/it] 

Epoch [4000/100000]
Train Loss: 0.4300, Train Accuracy: 0.8061
Train F1 (micro): 0.8061, Train F1 (macro): 0.5940
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 6239 1758    3]
 [   0 1275 6651   18]
 [   0   33   16    7]]
Val Loss: 0.8752, Val Accuracy: 0.6562
Val F1 (micro): 0.6562, Val F1 (macro): 0.2885
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   22    3    0]
 [   0  535 1028    0]
 [   0    6    6    0]]


  5%|▍         | 4999/100000 [2:10:27<39:44:58,  1.51s/it] 

Epoch [5000/100000]
Train Loss: 0.4214, Train Accuracy: 0.8185
Train F1 (micro): 0.8185, Train F1 (macro): 0.5468
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 6299 1701    0]
 [   0 1130 6797    0]
 [   0   26   47    0]]
Val Loss: 0.6682, Val Accuracy: 0.7388
Val F1 (micro): 0.7388, Val F1 (macro): 0.3139
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   21    4    0]
 [   0  402 1161    0]
 [   0    5    7    0]]


  5%|▌         | 5174/100000 [2:15:58<39:47:16,  1.51s/it] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

  9%|▉         | 8999/100000 [3:55:42<38:10:05,  1.51s/it]

Epoch [9000/100000]
Train Loss: 0.4452, Train Accuracy: 0.8066
Train F1 (micro): 0.8066, Train F1 (macro): 0.6205
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 5329 1638   33]
 [   0 1139 7527   92]
 [   0   50  142   50]]
Val Loss: 0.6621, Val Accuracy: 0.7156
Val F1 (micro): 0.7156, Val F1 (macro): 0.3063
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   21    4    0]
 [   0  439 1124    0]
 [   0    5    7    0]]


 10%|▉         | 9999/100000 [4:22:05<37:50:58,  1.51s/it] 

Epoch [10000/100000]
Train Loss: 0.3717, Train Accuracy: 0.8433
Train F1 (micro): 0.8433, Train F1 (macro): 0.5610
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 5544 1456    0]
 [   0  988 7949    0]
 [   0    7   56    0]]
Val Loss: 0.5433, Val Accuracy: 0.7619
Val F1 (micro): 0.7619, Val F1 (macro): 0.3204
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   20    5    0]
 [   0  364 1199    0]
 [   0    5    7    0]]


 11%|█         | 10999/100000 [4:48:16<37:16:34,  1.51s/it] 

Epoch [11000/100000]
Train Loss: 0.4054, Train Accuracy: 0.8195
Train F1 (micro): 0.8195, Train F1 (macro): 0.5841
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 5290 1704    6]
 [   0 1094 7816   14]
 [   0   21   49    6]]
Val Loss: 0.3899, Val Accuracy: 0.8550
Val F1 (micro): 0.8550, Val F1 (macro): 0.3491
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   16    9    0]
 [   0  211 1352    0]
 [   0    3    9    0]]


 12%|█▏        | 11999/100000 [5:14:29<36:43:44,  1.50s/it] 

Epoch [12000/100000]
Train Loss: 0.6197, Train Accuracy: 0.6613
Train F1 (micro): 0.6613, Train F1 (macro): 0.4213
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 3029 3971    0]
 [   0 1442 7551    0]
 [   0    0    7    0]]
Val Loss: 0.5118, Val Accuracy: 0.9762
Val F1 (micro): 0.9762, Val F1 (macro): 0.3293
Val Confusion Matrix:
[[   0    0    0    0]
 [   0    0   25    0]
 [   0    1 1562    0]
 [   0    0   12    0]]


 13%|█▎        | 12999/100000 [5:40:35<36:22:23,  1.51s/it] 

Epoch [13000/100000]
Train Loss: 0.5739, Train Accuracy: 0.7114
Train F1 (micro): 0.7114, Train F1 (macro): 0.4650
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 3962 3038    0]
 [   0 1568 7421    0]
 [   0    2    9    0]]
Val Loss: 0.5634, Val Accuracy: 0.7863
Val F1 (micro): 0.7863, Val F1 (macro): 0.3209
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   15   10    0]
 [   0  320 1243    0]
 [   0    5    7    0]]


 14%|█▍        | 13999/100000 [6:07:28<35:50:33,  1.50s/it] 

Epoch [14000/100000]
Train Loss: 0.4995, Train Accuracy: 0.7711
Train F1 (micro): 0.7711, Train F1 (macro): 0.5089
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 4609 2391    0]
 [   0 1225 7729    7]
 [   0    6   33    0]]
Val Loss: 0.6428, Val Accuracy: 0.7600
Val F1 (micro): 0.7600, Val F1 (macro): 0.3169
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   18    7    0]
 [   0  365 1198    0]
 [   0    6    6    0]]


 15%|█▍        | 14999/100000 [6:34:35<38:57:50,  1.65s/it] 

Epoch [15000/100000]
Train Loss: 0.5756, Train Accuracy: 0.7282
Train F1 (micro): 0.7282, Train F1 (macro): 0.4994
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 4033 2140  111]
 [   0 1559 7605  292]
 [   0   98  149   13]]
Val Loss: 0.4167, Val Accuracy: 0.8788
Val F1 (micro): 0.8788, Val F1 (macro): 0.3474
Val Confusion Matrix:
[[   0    0    0    0]
 [   0   11   14    0]
 [   0  168 1395    0]
 [   0    3    9    0]]


 16%|█▌        | 15999/100000 [7:00:51<34:59:25,  1.50s/it] 

Epoch [16000/100000]
Train Loss: 0.4419, Train Accuracy: 0.8162
Train F1 (micro): 0.8162, Train F1 (macro): 0.5314
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 3948 2052    0]
 [   0  820 9111    0]
 [   0    1   68    0]]
Val Loss: 0.3129, Val Accuracy: 0.9213
Val F1 (micro): 0.9213, Val F1 (macro): 0.3641
Val Confusion Matrix:
[[   0    0    0    0]
 [   0    9   16    0]
 [   0   98 1465    0]
 [   0    4    8    0]]


 17%|█▋        | 16999/100000 [7:27:15<35:40:03,  1.55s/it] 

Epoch [17000/100000]
Train Loss: 0.4157, Train Accuracy: 0.8217
Train F1 (micro): 0.8217, Train F1 (macro): 0.5354
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 4053 1947    0]
 [   0  867 9094    0]
 [   0   19   20    0]]
Val Loss: 0.2660, Val Accuracy: 0.9250
Val F1 (micro): 0.9250, Val F1 (macro): 0.3628
Val Confusion Matrix:
[[   0    0    0    0]
 [   0    8   17    0]
 [   0   91 1472    0]
 [   0    2   10    0]]


 18%|█▊        | 17512/100000 [7:41:55<36:01:09,  1.57s/it] 