In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
import json
from transformers import EsmModel, AutoTokenizer, AutoModel

sys.path.insert(0, '../dlp')
from batch import Batch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 100
num_val = 10
batch_size = 64
dataset_name = "corpus_1000_Viruses_cellular"
lr = 0.001
model_name = "Guess_M"
max_seq_len = 1000

from data_access import PQDataAccess
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

cuda:0
../checkpoints/Guess_M_checkpoints


0,1
epoch,▁▅█
lr,██▁
train_loss,█▁▁
val_accuracy,▁█▆
val_f1_macro,▁█▆
val_f1_micro,▁█▆
val_loss,█▁▁

0,1
epoch,30.0
lr,0.0001
train_loss,0.25196
val_accuracy,0.86094
val_f1_macro,0.57485
val_f1_micro,0.86094
val_loss,0.29141


In [23]:
class ESM1b(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        
        for param in self.esm.parameters():
            param.requires_grad = False
        
        self.attention = nn.Sequential(
            nn.Linear(1280, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )
        
        self.layer1 = nn.Linear(1280, 512)
        self.layer_norm = nn.LayerNorm(512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.layer2 = nn.Linear(512, 2)

    def attention_pooling(self, x):
        # x shape: (batch_size, seq_length, embedding_dim)
        attention_weights = self.attention(x)  # (batch_size, seq_length, 1)
        attention_weights = torch.softmax(attention_weights.squeeze(-1), dim=1)  # (batch_size, seq_length)
        attention_weights = attention_weights.unsqueeze(-1)  # (batch_size, seq_length, 1)
        pooled = torch.sum(x * attention_weights, dim=1)  # (batch_size, embedding_dim)
        return pooled

    def forward(self, x, attention_mask=None):
        outputs = self.esm(x, attention_mask=attention_mask).last_hidden_state
        outputs = self.attention_pooling(outputs)
        
        outputs = self.layer1(outputs)
        outputs = self.layer_norm(outputs)
        outputs = self.relu(outputs)
        outputs = self.dropout(outputs)
        outputs = self.layer2(outputs)
        return outputs

In [24]:
model = ESM1b().to(device)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
# print(model)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,  # Period of learning rate decay
    gamma=0.1  # Multiplicative factor of decay
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 0.986115 M
Total parameters: 653.340056 M


In [25]:
tokenizer_ = AutoTokenizer.from_pretrained(f"facebook/esm1b_t33_650M_UR50S")

def data_to_tensor_batch(b, max_seq_len=max_seq_len):
    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    )

    labels = torch.LongTensor([
        1 if e['Sequence'].startswith('M') else 0 for e in b
    ])
    
    return Batch(inputs, labels)

In [26]:
val_batches = [da.get_batch() for _ in range(num_val)]


def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = data_to_tensor_batch(val_batches[epoch], max_seq_len)
            tensor_batch.gpu(device)
            labels = tensor_batch.taxes
            
        outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
        loss = criterion(outputs, labels)
        preds = torch.argmax(outputs, dim=1)

        running_loss += loss.item()
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
    
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)


    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')
    conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
    avg_loss = running_loss / num_val

    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix


# evaluate(model)

In [27]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    # print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

# model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)
latest_epoch = 0

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(latest_epoch, latest_epoch + epochs)):
    model.train()
    
    tensor_batch = data_to_tensor_batch(
        da.get_batch(),
        max_seq_len,
    )
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    loss = criterion(outputs, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        # Evaluate on validation set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, val_cm = evaluate(model)
        
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        print(f"Val F1 (micro): {val_f1_micro:.4f}, Val F1 (macro): {val_f1_macro:.4f}")
        print("Val Confusion Matrix:")
        print(val_cm)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss":val_loss,
            "val_accuracy": val_accuracy,
            "val_f1_micro": val_f1_micro,
            "val_f1_macro": val_f1_macro,
            "epoch": epoch + 1,
            "lr": current_lr
        }

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0

wandb.finish()

  0%|          | 99/100000 [07:30<126:36:26,  4.56s/it]

Epoch [100/100000]
Train Loss: 0.2389
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  0%|          | 199/100000 [15:58<126:09:14,  4.55s/it]

Epoch [200/100000]
Train Loss: 0.1061
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  0%|          | 299/100000 [24:24<126:01:39,  4.55s/it]

Epoch [300/100000]
Train Loss: 0.0798
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  0%|          | 399/100000 [32:52<125:51:07,  4.55s/it]

Epoch [400/100000]
Train Loss: 0.0496
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  0%|          | 499/100000 [41:19<126:18:55,  4.57s/it]

Epoch [500/100000]
Train Loss: 0.0775
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  1%|          | 599/100000 [49:45<125:47:32,  4.56s/it]

Epoch [600/100000]
Train Loss: 0.1134
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  1%|          | 699/100000 [58:12<125:29:01,  4.55s/it]

Epoch [700/100000]
Train Loss: 0.0635
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  1%|          | 799/100000 [1:06:39<125:34:47,  4.56s/it]

Epoch [800/100000]
Train Loss: 0.0588
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  2%|▏         | 1899/100000 [2:39:40<124:13:42,  4.56s/it]

Epoch [1900/100000]
Train Loss: 0.2148
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  2%|▏         | 1999/100000 [2:48:07<124:03:02,  4.56s/it]

Epoch [2000/100000]
Train Loss: 0.2671
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  2%|▏         | 2099/100000 [2:56:35<124:26:31,  4.58s/it]

Epoch [2100/100000]
Train Loss: 0.3298
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  3%|▎         | 3199/100000 [4:29:35<122:20:27,  4.55s/it]

Epoch [3200/100000]
Train Loss: 0.0735
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  3%|▎         | 3299/100000 [4:38:02<122:13:03,  4.55s/it]

Epoch [3300/100000]
Train Loss: 0.0778
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  3%|▎         | 3399/100000 [4:46:29<122:09:06,  4.55s/it]

Epoch [3400/100000]
Train Loss: 0.0636
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  3%|▎         | 3499/100000 [4:54:56<121:58:27,  4.55s/it]

Epoch [3500/100000]
Train Loss: 0.1437
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▎         | 3599/100000 [5:03:23<121:51:57,  4.55s/it]

Epoch [3600/100000]
Train Loss: 0.0488
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▎         | 3699/100000 [5:11:49<121:44:34,  4.55s/it]

Epoch [3700/100000]
Train Loss: 0.0687
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 3799/100000 [5:20:16<121:32:09,  4.55s/it]

Epoch [3800/100000]
Train Loss: 0.0670
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 3899/100000 [5:28:43<121:23:35,  4.55s/it]

Epoch [3900/100000]
Train Loss: 0.0839
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 3999/100000 [5:37:11<121:23:31,  4.55s/it]

Epoch [4000/100000]
Train Loss: 0.1054
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 4099/100000 [5:45:38<121:20:48,  4.56s/it]

Epoch [4100/100000]
Train Loss: 0.2341
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 4199/100000 [5:54:06<121:23:18,  4.56s/it]

Epoch [4200/100000]
Train Loss: 0.2160
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 4299/100000 [6:02:32<121:11:11,  4.56s/it]

Epoch [4300/100000]
Train Loss: 0.2734
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 4399/100000 [6:10:59<121:15:55,  4.57s/it]

Epoch [4400/100000]
Train Loss: 0.3621
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  4%|▍         | 4499/100000 [6:19:26<120:57:37,  4.56s/it]

Epoch [4500/100000]
Train Loss: 0.4306
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▍         | 4599/100000 [6:27:55<120:49:39,  4.56s/it]

Epoch [4600/100000]
Train Loss: 0.5240
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▍         | 4699/100000 [6:36:21<120:29:35,  4.55s/it]

Epoch [4700/100000]
Train Loss: 0.1181
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▍         | 4799/100000 [6:44:48<120:24:49,  4.55s/it]

Epoch [4800/100000]
Train Loss: 0.1191
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▍         | 4999/100000 [7:01:42<119:58:49,  4.55s/it]

Epoch [5000/100000]
Train Loss: 0.0499
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▌         | 5099/100000 [7:10:09<119:52:41,  4.55s/it]

Epoch [5100/100000]
Train Loss: 0.0821
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▌         | 5199/100000 [7:18:36<119:50:37,  4.55s/it]

Epoch [5200/100000]
Train Loss: 0.1068
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▌         | 5299/100000 [7:27:03<119:53:29,  4.56s/it]

Epoch [5300/100000]
Train Loss: 0.0637
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▌         | 5399/100000 [7:35:30<119:41:43,  4.55s/it]

Epoch [5400/100000]
Train Loss: 0.0601
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  5%|▌         | 5499/100000 [7:43:58<119:18:11,  4.54s/it]

Epoch [5500/100000]
Train Loss: 0.0771
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▌         | 5599/100000 [7:52:26<119:20:08,  4.55s/it]

Epoch [5600/100000]
Train Loss: 0.0757
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▌         | 5699/100000 [8:00:54<119:13:11,  4.55s/it]

Epoch [5700/100000]
Train Loss: 0.0653
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▌         | 5799/100000 [8:09:21<119:03:29,  4.55s/it]

Epoch [5800/100000]
Train Loss: 0.1327
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▌         | 5899/100000 [8:17:47<119:04:26,  4.56s/it]

Epoch [5900/100000]
Train Loss: 0.0500
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▌         | 5999/100000 [8:26:14<119:03:15,  4.56s/it]

Epoch [6000/100000]
Train Loss: 0.0636
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▌         | 6199/100000 [8:43:09<118:36:14,  4.55s/it]

Epoch [6200/100000]
Train Loss: 0.0998
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▋         | 6299/100000 [8:51:35<118:33:39,  4.56s/it]

Epoch [6300/100000]
Train Loss: 0.1246
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▋         | 6399/100000 [9:00:04<118:57:58,  4.58s/it]

Epoch [6400/100000]
Train Loss: 0.2217
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  6%|▋         | 6499/100000 [9:08:31<118:18:18,  4.56s/it]

Epoch [6500/100000]
Train Loss: 0.2282
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 6599/100000 [9:16:58<118:03:20,  4.55s/it]

Epoch [6600/100000]
Train Loss: 0.2776
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 6699/100000 [9:25:25<118:03:28,  4.56s/it]

Epoch [6700/100000]
Train Loss: 0.3428
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 6799/100000 [9:33:53<118:10:33,  4.56s/it]

Epoch [6800/100000]
Train Loss: 0.4709
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 6899/100000 [9:42:20<117:31:09,  4.54s/it]

Epoch [6900/100000]
Train Loss: 0.4413
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 6999/100000 [9:50:48<117:22:00,  4.54s/it]

Epoch [7000/100000]
Train Loss: 0.1147
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 7099/100000 [9:59:14<117:29:32,  4.55s/it]

Epoch [7100/100000]
Train Loss: 0.1220
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 7199/100000 [10:07:41<117:26:14,  4.56s/it]

Epoch [7200/100000]
Train Loss: 0.0767
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 7299/100000 [10:16:08<117:17:45,  4.56s/it]

Epoch [7300/100000]
Train Loss: 0.0525
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 7399/100000 [10:24:36<117:14:58,  4.56s/it]

Epoch [7400/100000]
Train Loss: 0.0913
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  7%|▋         | 7499/100000 [10:33:03<117:03:09,  4.56s/it]

Epoch [7500/100000]
Train Loss: 0.1063
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 7599/100000 [10:41:29<116:57:20,  4.56s/it]

Epoch [7600/100000]
Train Loss: 0.0578
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 7699/100000 [10:49:56<116:43:32,  4.55s/it]

Epoch [7700/100000]
Train Loss: 0.0637
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 7799/100000 [10:58:24<116:32:17,  4.55s/it]

Epoch [7800/100000]
Train Loss: 0.0718
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 7899/100000 [11:06:51<116:30:22,  4.55s/it]

Epoch [7900/100000]
Train Loss: 0.0764
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 7999/100000 [11:15:18<116:32:43,  4.56s/it]

Epoch [8000/100000]
Train Loss: 0.0745
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 8099/100000 [11:23:45<116:11:37,  4.55s/it]

Epoch [8100/100000]
Train Loss: 0.1376
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 8199/100000 [11:32:12<116:03:57,  4.55s/it]

Epoch [8200/100000]
Train Loss: 0.0512
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 8299/100000 [11:40:38<116:05:33,  4.56s/it]

Epoch [8300/100000]
Train Loss: 0.0737
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 8399/100000 [11:49:05<116:12:55,  4.57s/it]

Epoch [8400/100000]
Train Loss: 0.0612
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  8%|▊         | 8499/100000 [11:57:31<115:37:06,  4.55s/it]

Epoch [8500/100000]
Train Loss: 0.1019
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▊         | 8599/100000 [12:05:58<115:40:03,  4.56s/it]

Epoch [8600/100000]
Train Loss: 0.1159
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▊         | 8699/100000 [12:14:27<115:40:51,  4.56s/it]

Epoch [8700/100000]
Train Loss: 0.2085
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 8799/100000 [12:22:55<115:29:09,  4.56s/it]

Epoch [8800/100000]
Train Loss: 0.2131
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 8899/100000 [12:31:22<115:16:18,  4.56s/it]

Epoch [8900/100000]
Train Loss: 0.2776
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 8999/100000 [12:39:49<115:16:07,  4.56s/it]

Epoch [9000/100000]
Train Loss: 0.3656
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 9099/100000 [12:48:17<115:08:36,  4.56s/it]

Epoch [9100/100000]
Train Loss: 0.4931
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 9199/100000 [12:56:45<114:50:42,  4.55s/it]

Epoch [9200/100000]
Train Loss: 0.4427
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 9299/100000 [13:05:12<114:28:15,  4.54s/it]

Epoch [9300/100000]
Train Loss: 0.1148
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 9399/100000 [13:13:38<114:28:47,  4.55s/it]

Epoch [9400/100000]
Train Loss: 0.1113
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


  9%|▉         | 9499/100000 [13:22:05<114:27:14,  4.55s/it]

Epoch [9500/100000]
Train Loss: 0.0702
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|▉         | 9599/100000 [13:30:31<114:16:26,  4.55s/it]

Epoch [9600/100000]
Train Loss: 0.0498
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|▉         | 9699/100000 [13:38:59<114:20:19,  4.56s/it]

Epoch [9700/100000]
Train Loss: 0.0871
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|▉         | 9799/100000 [13:47:25<113:55:13,  4.55s/it]

Epoch [9800/100000]
Train Loss: 0.1039
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|▉         | 9899/100000 [13:55:53<113:52:07,  4.55s/it]

Epoch [9900/100000]
Train Loss: 0.0659
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|▉         | 9999/100000 [14:04:19<114:12:05,  4.57s/it]

Epoch [10000/100000]
Train Loss: 0.0576
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|█         | 10099/100000 [14:12:46<115:49:37,  4.64s/it]

Epoch [10100/100000]
Train Loss: 0.0743
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|█         | 10199/100000 [14:21:13<113:28:19,  4.55s/it]

Epoch [10200/100000]
Train Loss: 0.0750
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|█         | 10399/100000 [14:38:07<113:31:34,  4.56s/it]

Epoch [10400/100000]
Train Loss: 0.1350
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 10%|█         | 10499/100000 [14:46:34<113:08:05,  4.55s/it]

Epoch [10500/100000]
Train Loss: 0.0468
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 11%|█         | 10599/100000 [14:55:00<113:02:19,  4.55s/it]

Epoch [10600/100000]
Train Loss: 0.0707
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 11%|█         | 10799/100000 [15:11:53<112:44:34,  4.55s/it]

Epoch [10800/100000]
Train Loss: 0.1098
Val Loss: 0.1451, Val Accuracy: 0.9437
Val F1 (micro): 0.9437, Val F1 (macro): 0.9005
Val Confusion Matrix:
[[ 91  10]
 [ 26 513]]


 11%|█         | 10899/100000 [15:20:21<112:53:07,  4.56s/it]

In [None]:
model, _, _, latest_epoch, metrics = load_checkpoint(model)

val_batches_ = [da.get_batch() for _ in range(num_val)]


input_sequences_ = ["ACACAD"]
labels_ = [{0: 1}]

def evaluate_df(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        "sequence": [],
        "label": [],
        "pred": [],
        "loss": []
    }

    metrics = {
        "loss": 0,
        "accuracy": 0,
        "f1 macro": 0,
        "f1 micro": 0
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences_, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        pred = output.argmax(dim=-1).cpu().item()
        loss = criterion(output, torch.tensor([label[0]]).to(device))
        df["sequence"].append(sequence)
        df["label"].append(level_decoder[0][label[0]])
        df["pred"].append(level_decoder[0][pred])
        df["loss"].append(round(loss.cpu().item(), 4))

    # Convert to DataFrame
    new_df = pd.DataFrame(df)
    new_df['is_incorrect'] = new_df['label'] != new_df['pred']
    new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
    new_df.to_csv(f'classification_results__new_att.csv', index=False)

    metrics["loss"] = np.array(df["loss"]).mean()
    metrics["accuracy"] = accuracy_score(np.array(df["label"]), np.array(df["pred"]))
    metrics["f1 macro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='macro')  # F1-score for multi-label classification
    metrics["f1 micro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='micro') 
    print(metrics)

evaluate_df(model)