In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
import json
from transformers import EsmModel, AutoTokenizer

sys.path.insert(0, '../dlp')
from batch import Batch

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 100
num_val = 10
batch_size = 8
dataset_name = "new_corpus"
lr = 0.001
model_name = "Finetune_ESM_FNN"
max_seq_len = 500

from data_access import PQDataAccess
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

cuda:1
../checkpoints/Finetune_ESM_FNN_checkpoints


In [26]:
tax_ids_file = "../data/tax_ids.csv"

tax_ids = pd.read_csv(tax_ids_file)
# print(tax_ids)
num_classes = len(tax_ids) + 1
print(num_classes)
id_encoder = {name: idx + 1 for idx, name in enumerate(tax_ids['Taxonomic_lineage_IDs'].values)}

id_decoder = {idx + 1: name for idx, name in enumerate(tax_ids['Taxonomic_lineage_IDs'].values)}
id_decoder[0] = "NOT DEFINED"

# Character vocabulary for protein sequences (20 amino acids + 1 padding)
vocab = "ACDEFGHIKLMNPQRSTVWY"
char_to_idx = {char: idx + 1 for idx, char in enumerate(vocab)}  # Start index from 1 for padding
# Sequence encoder: Convert the protein sequence into integers
def encode_sequence(sequence):
    return [char_to_idx.get(char, 0) for char in sequence] + [0 for _ in range(max_seq_len - len(sequence))]  # 0 for unknown characters or padding 

def data_to_tensor_batch(b):
    inputs = torch.LongTensor([encode_sequence(e['sequence']) for e in b])
    tax_ids = torch.LongTensor([id_encoder.get(e['Taxonomic_lineage_IDs'], 0) for e in b])

    return Batch(inputs, tax_ids)

67486


In [27]:
import torch
import torch.nn as nn

class ESM2(nn.Module):
    def __init__(self, num_classes, embedding_dim=1280, vocab_size=21):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.layer1 = nn.Linear(embedding_dim, 512)
        self.layer_norm = nn.LayerNorm(512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.layer2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # Input `x` is of shape (batch_size, sequence_length)
        embeddings = self.embedding(x)  # Shape: (batch_size, sequence_length, embedding_dim)
        outputs = self.layer1(embeddings)  # Shape: (batch_size, sequence_length, 512)
        outputs = self.layer_norm(outputs)  # Normalize across feature dimensions
        outputs = self.relu(outputs)
        outputs = self.dropout(outputs)

        # Pooling across the sequence dimension to get a single vector per sequence
        pooled_output = torch.mean(outputs, dim=1)  # Shape: (batch_size, 512)

        # Final classification layer
        logits = self.layer2(pooled_output)  # Shape: (batch_size, num_classes)
        return logits

In [28]:
model = ESM2(num_classes).to(device)

total = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total/ 1e6} M')
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

Total parameters: 35.304094 M
ESM2(
  (embedding): Embedding(21, 1280)
  (layer1): Linear(in_features=1280, out_features=512, bias=True)
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.1, inplace=False)
  (layer2): Linear(in_features=512, out_features=67486, bias=True)
)


In [29]:
val_dir = f"val_results/{model_name}"
if not os.path.exists(val_dir):
    os.makedirs(val_dir)
    
val_batches = [da.get_batch() for _ in range(num_val)]

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = data_to_tensor_batch(val_batches[epoch])
            tensor_batch.gpu(device)
        
            labels = tensor_batch.taxes
            outputs = model(tensor_batch.seq_ids)
            
            # Calculate the loss
            loss = criterion(outputs, labels)
    
            running_loss += loss.item()
                
            preds = torch.argmax(outputs, dim=1)
    
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    print(all_labels)
    # Compute evaluation metrics (example: accuracy, F1 score)
    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')  # F1-score for multi-label classification
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')  # F1-score for multi-label classification
    # conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
    avg_loss = running_loss / num_val
    
    return avg_loss, accuracy, f1_micro, f1_macro

evaluate(model)

tensor([62847,  4563, 59752, 62848, 15365, 41312, 10661, 50520, 58202, 29003,
        47514, 33664, 56666, 21453, 30594, 64470, 64471, 58203, 65961, 21454,
            1,  9151, 32139,  6074, 35209, 16872, 44417, 15366, 59753, 53560,
        44418, 59754, 16873, 16874,  4564, 59755, 18436,  6075, 27445, 62847,
        59756, 53561, 24425, 39812, 62849, 13836, 61322,     2, 29004,  3028,
        35210,  1509, 65962, 53562,  9152, 27446, 61323, 50521, 47515, 55106,
        61324,  4564, 27447, 41313, 50522,  1510, 10662, 33665, 41314, 12229,
        42858, 49064, 18437, 32140, 33666, 58204, 32141, 64472, 33667, 32142])


RuntimeError: Numpy is not available

In [6]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    # print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

# model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)
latest_epoch = 0

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(latest_epoch + 1, latest_epoch + epochs + 1)):
    model.train()
    
    tensor_batch = data_to_tensor_batch(da.get_batch())
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    loss = criterion(outputs, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        # Evaluate on validation set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro = evaluate(model)
        
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"test Loss: {val_loss:.4f}, test Accuracy: {val_accuracy:.4f}")
        print(f"test F1 (micro): {val_f1_micro:.4f}, test F1 (macro): {val_f1_macro:.4f}")
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "test_loss": val_loss,
            "test_accuracy": val_accuracy,
            "test_f1_micro": val_f1_micro,
            "test_f1_macro": val_f1_macro,
            "epoch": epoch + 1,
            "lr": current_lr
        }

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0

wandb.finish()

  0%|          | 98/100000 [01:23<23:50:12,  1.16it/s]

Epoch [100/100000]
Train Loss: 10.9839
test Loss: 11.0489, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  0%|          | 198/100000 [03:19<23:45:52,  1.17it/s] 

Epoch [200/100000]
Train Loss: 10.7905
test Loss: 10.5840, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  0%|          | 298/100000 [05:11<23:42:55,  1.17it/s] 

Epoch [300/100000]
Train Loss: 10.6935
test Loss: 10.3726, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  0%|          | 398/100000 [06:56<23:38:15,  1.17it/s] 

Epoch [400/100000]
Train Loss: 10.6458
test Loss: 10.3708, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  0%|          | 498/100000 [08:43<23:42:00,  1.17it/s] 

Epoch [500/100000]
Train Loss: 10.5916
test Loss: 10.2280, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|          | 698/100000 [12:21<23:36:50,  1.17it/s] 

Epoch [700/100000]
Train Loss: 10.5770
test Loss: 10.2566, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|          | 798/100000 [14:08<23:33:19,  1.17it/s] 

Epoch [800/100000]
Train Loss: 10.5264
test Loss: 10.2511, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|          | 898/100000 [15:59<23:31:48,  1.17it/s] 

Epoch [900/100000]
Train Loss: 10.5826
test Loss: 10.1532, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|          | 998/100000 [17:43<23:38:26,  1.16it/s] 

Epoch [1000/100000]
Train Loss: 10.4992
test Loss: 10.1603, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|          | 1098/100000 [19:32<23:21:34,  1.18it/s] 

Epoch [1100/100000]
Train Loss: 10.4366
test Loss: 10.1345, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|          | 1198/100000 [21:18<23:31:32,  1.17it/s] 

Epoch [1200/100000]
Train Loss: 10.4988
test Loss: 10.1922, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|▏         | 1298/100000 [23:05<23:31:24,  1.17it/s] 

Epoch [1300/100000]
Train Loss: 10.4969
test Loss: 10.1520, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  1%|▏         | 1398/100000 [24:53<23:14:23,  1.18it/s] 

Epoch [1400/100000]
Train Loss: 10.3082
test Loss: 10.1369, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  2%|▏         | 1598/100000 [28:36<23:20:35,  1.17it/s] 

Epoch [1600/100000]
Train Loss: 10.2984
test Loss: 10.0398, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  2%|▏         | 2498/100000 [44:58<23:08:20,  1.17it/s]

Epoch [2500/100000]
Train Loss: 10.2772
test Loss: 9.9119, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 2598/100000 [46:41<23:08:41,  1.17it/s] 

Epoch [2600/100000]
Train Loss: 10.3002
test Loss: 9.9121, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 2698/100000 [48:23<23:11:16,  1.17it/s] 

Epoch [2700/100000]
Train Loss: 10.2658
test Loss: 9.9387, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 2798/100000 [50:07<23:10:33,  1.17it/s] 

Epoch [2800/100000]
Train Loss: 10.2264
test Loss: 9.9310, test Accuracy: 0.0125
test F1 (micro): 0.0125, test F1 (macro): 0.0003


  3%|▎         | 2898/100000 [51:52<23:05:58,  1.17it/s] 

Epoch [2900/100000]
Train Loss: 10.3502
test Loss: 9.9088, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 2998/100000 [53:43<22:59:48,  1.17it/s] 

Epoch [3000/100000]
Train Loss: 10.3410
test Loss: 9.8685, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 3098/100000 [55:31<22:52:33,  1.18it/s] 

Epoch [3100/100000]
Train Loss: 10.2105
test Loss: 9.9348, test Accuracy: 0.0125
test F1 (micro): 0.0125, test F1 (macro): 0.0003


  3%|▎         | 3198/100000 [57:22<22:56:18,  1.17it/s] 

Epoch [3200/100000]
Train Loss: 10.2321
test Loss: 9.8927, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 3298/100000 [59:10<22:59:41,  1.17it/s] 

Epoch [3300/100000]
Train Loss: 10.3060
test Loss: 9.9776, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 3398/100000 [1:00:57<22:55:42,  1.17it/s]

Epoch [3400/100000]
Train Loss: 10.1084
test Loss: 9.8358, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  3%|▎         | 3498/100000 [1:02:46<22:59:27,  1.17it/s] 

Epoch [3500/100000]
Train Loss: 10.2265
test Loss: 9.9052, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▎         | 3598/100000 [1:04:45<22:59:22,  1.16it/s] 

Epoch [3600/100000]
Train Loss: 10.3402
test Loss: 9.9393, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▎         | 3698/100000 [1:06:34<22:53:06,  1.17it/s] 

Epoch [3700/100000]
Train Loss: 10.1987
test Loss: 9.8702, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 3798/100000 [1:08:21<22:52:52,  1.17it/s] 

Epoch [3800/100000]
Train Loss: 10.2467
test Loss: 9.9193, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 3898/100000 [1:10:06<22:54:20,  1.17it/s] 

Epoch [3900/100000]
Train Loss: 10.3373
test Loss: 9.9886, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 3998/100000 [1:11:54<22:40:42,  1.18it/s] 

Epoch [4000/100000]
Train Loss: 10.2399
test Loss: 9.8428, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 4098/100000 [1:13:41<22:46:31,  1.17it/s] 

Epoch [4100/100000]
Train Loss: 10.2230
test Loss: 9.8572, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 4198/100000 [1:15:25<22:36:26,  1.18it/s] 

Epoch [4200/100000]
Train Loss: 10.2443
test Loss: 9.8650, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 4298/100000 [1:17:11<23:02:30,  1.15it/s] 

Epoch [4300/100000]
Train Loss: 10.2075
test Loss: 9.8768, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 4398/100000 [1:19:01<22:45:16,  1.17it/s] 

Epoch [4400/100000]
Train Loss: 10.2353
test Loss: 9.8379, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  4%|▍         | 4498/100000 [1:20:46<22:39:32,  1.17it/s] 

Epoch [4500/100000]
Train Loss: 10.2091
test Loss: 9.8507, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▍         | 4598/100000 [1:22:37<22:46:02,  1.16it/s] 

Epoch [4600/100000]
Train Loss: 10.1832
test Loss: 9.8413, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▍         | 4698/100000 [1:24:30<22:40:11,  1.17it/s] 

Epoch [4700/100000]
Train Loss: 10.2155
test Loss: 9.8500, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▍         | 4798/100000 [1:26:24<22:40:32,  1.17it/s] 

Epoch [4800/100000]
Train Loss: 10.1529
test Loss: 9.8459, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▍         | 4898/100000 [1:28:09<22:31:48,  1.17it/s] 

Epoch [4900/100000]
Train Loss: 10.2520
test Loss: 9.8512, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▍         | 4998/100000 [1:30:03<22:25:56,  1.18it/s] 

Epoch [5000/100000]
Train Loss: 10.2844
test Loss: 9.8563, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▌         | 5098/100000 [1:31:51<22:37:43,  1.16it/s] 

Epoch [5100/100000]
Train Loss: 10.1657
test Loss: 9.8549, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▌         | 5198/100000 [1:33:38<22:31:31,  1.17it/s] 

Epoch [5200/100000]
Train Loss: 10.2654
test Loss: 9.8551, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▌         | 5298/100000 [1:35:25<22:29:26,  1.17it/s] 

Epoch [5300/100000]
Train Loss: 10.2141
test Loss: 9.8808, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▌         | 5398/100000 [1:37:12<22:23:32,  1.17it/s] 

Epoch [5400/100000]
Train Loss: 10.2586
test Loss: 9.9591, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  5%|▌         | 5498/100000 [1:38:58<22:24:38,  1.17it/s] 

Epoch [5500/100000]
Train Loss: 10.2305
test Loss: 9.8920, test Accuracy: 0.0250
test F1 (micro): 0.0250, test F1 (macro): 0.0006


  6%|▌         | 5598/100000 [1:40:47<22:24:21,  1.17it/s] 

Epoch [5600/100000]
Train Loss: 10.2416
test Loss: 9.8798, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▌         | 5698/100000 [1:42:35<22:26:40,  1.17it/s] 

Epoch [5700/100000]
Train Loss: 10.0899
test Loss: 9.8406, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▌         | 5798/100000 [1:44:23<22:36:22,  1.16it/s] 

Epoch [5800/100000]
Train Loss: 10.2330
test Loss: 9.9097, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▌         | 5898/100000 [1:46:12<22:24:46,  1.17it/s] 

Epoch [5900/100000]
Train Loss: 10.1574
test Loss: 9.9145, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▌         | 5998/100000 [1:48:06<22:15:16,  1.17it/s] 

Epoch [6000/100000]
Train Loss: 10.2457
test Loss: 9.8572, test Accuracy: 0.0250
test F1 (micro): 0.0250, test F1 (macro): 0.0006


  6%|▌         | 6098/100000 [1:49:55<22:18:36,  1.17it/s] 

Epoch [6100/100000]
Train Loss: 10.0967
test Loss: 9.8741, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▌         | 6198/100000 [1:51:42<22:16:54,  1.17it/s] 

Epoch [6200/100000]
Train Loss: 10.2032
test Loss: 9.9327, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▋         | 6298/100000 [1:53:31<22:17:31,  1.17it/s] 

Epoch [6300/100000]
Train Loss: 10.2815
test Loss: 9.9015, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▋         | 6398/100000 [1:55:18<22:11:22,  1.17it/s] 

Epoch [6400/100000]
Train Loss: 10.2280
test Loss: 9.9432, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  6%|▋         | 6498/100000 [1:57:05<22:17:43,  1.16it/s] 

Epoch [6500/100000]
Train Loss: 10.1315
test Loss: 9.8989, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  7%|▋         | 6598/100000 [1:58:56<22:15:26,  1.17it/s] 

Epoch [6600/100000]
Train Loss: 10.2961
test Loss: 9.9240, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  7%|▋         | 6698/100000 [2:00:42<22:10:21,  1.17it/s] 

Epoch [6700/100000]
Train Loss: 10.2215
test Loss: 9.9514, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  7%|▋         | 6798/100000 [2:02:32<22:04:41,  1.17it/s] 

Epoch [6800/100000]
Train Loss: 10.2397
test Loss: 9.9354, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  7%|▋         | 6898/100000 [2:04:25<22:05:18,  1.17it/s] 

Epoch [6900/100000]
Train Loss: 10.1755
test Loss: 9.9303, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  7%|▋         | 6998/100000 [2:06:09<22:04:58,  1.17it/s] 

Epoch [7000/100000]
Train Loss: 10.1580
test Loss: 9.9841, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


  7%|▋         | 7098/100000 [2:07:57<22:04:53,  1.17it/s] 

Epoch [7100/100000]
Train Loss: 10.1968
test Loss: 9.9468, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


 11%|█▏        | 11398/100000 [3:25:49<21:03:39,  1.17it/s]

Epoch [11400/100000]
Train Loss: 10.1088
test Loss: 10.0872, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


 11%|█▏        | 11498/100000 [3:27:36<21:05:29,  1.17it/s] 

Epoch [11500/100000]
Train Loss: 10.1869
test Loss: 10.1001, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


 12%|█▏        | 11598/100000 [3:29:22<20:58:33,  1.17it/s] 

Epoch [11600/100000]
Train Loss: 10.2427
test Loss: 10.1338, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


 12%|█▏        | 11698/100000 [3:31:09<21:04:51,  1.16it/s] 

Epoch [11700/100000]
Train Loss: 10.0887
test Loss: 10.0980, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


 12%|█▏        | 11798/100000 [3:32:54<21:10:32,  1.16it/s] 

Epoch [11800/100000]
Train Loss: 10.2298
test Loss: 10.0627, test Accuracy: 0.0000
test F1 (micro): 0.0000, test F1 (macro): 0.0000


 12%|█▏        | 11809/100000 [3:33:27<25:45:53,  1.05s/it] 

In [None]:
model, _, _, latest_epoch, metrics = load_checkpoint(model)

val_batches_ = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

# input_sequences_ = [e['Sequence'] for b in val_batches_ for e in b]
# labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches_ for e in b]

input_sequences_ = ["ACACAD"]
labels_ = [{0: 1}]

def evaluate_df(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        "sequence": [],
        "label": [],
        "pred": [],
        "loss": []
    }

    metrics = {
        "loss": 0,
        "accuracy": 0,
        "f1 macro": 0,
        "f1 micro": 0
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences_, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        pred = output.argmax(dim=-1).cpu().item()
        loss = criterion(output, torch.tensor([label[0]]).to(device))
        df["sequence"].append(sequence)
        df["label"].append(level_decoder[0][label[0]])
        df["pred"].append(level_decoder[0][pred])
        df["loss"].append(round(loss.cpu().item(), 4))

    # Convert to DataFrame
    new_df = pd.DataFrame(df)
    new_df['is_incorrect'] = new_df['label'] != new_df['pred']
    new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
    new_df.to_csv(f'classification_results__new_att.csv', index=False)

    metrics["loss"] = np.array(df["loss"]).mean()
    metrics["accuracy"] = accuracy_score(np.array(df["label"]), np.array(df["pred"]))
    metrics["f1 macro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='macro')  # F1-score for multi-label classification
    metrics["f1 micro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='micro') 
    print(metrics)

evaluate_df(model)

[1;34mwandb[0m: 🚀 View run [33mlight-terrain-11[0m at: [34mhttps://wandb.ai/alireza_noroozi/Finetune_ESM/runs/nvokgtnu[0m
[1;34mwandb[0m: Find logs at: [1;35mwandb/run-20241205_144714-nvokgtnu/logs[0m
