In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
import json
from transformers import EsmModel, AutoTokenizer

sys.path.insert(0, '../dlp')
from batch import Batch

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 10
num_val = 10
batch_size = 64
virus_dataset_name = "corpus_1000_Viruses"
cellular_dataset_name = "corpus_1000_cellular"
lr = 0.001
model_name = "OnlyFirst"
max_seq_len = 1000

from data_access import PQDataAccess
virus_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{virus_dataset_name}", batch_size)
cellular_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{cellular_dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


cuda:1
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/OnlyFirst_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malireza_noroozi[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
class ESM1b(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
        
        # Freeze ESM parameters
        for param in self.esm.parameters():
            param.requires_grad = False
        
        
        self.attention = nn.Sequential(
            nn.Linear(1280, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )
        
        self.layer1 = nn.Linear(1280, 512)
        self.layer_norm = nn.LayerNorm(512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.layer2 = nn.Linear(512, 3)

    def attention_pooling(self, x):
        # x shape: (batch_size, seq_length, embedding_dim)
        attention_weights = self.attention(x)  # (batch_size, seq_length, 1)
        attention_weights = torch.softmax(attention_weights.squeeze(-1), dim=1)  # (batch_size, seq_length)
        attention_weights = attention_weights.unsqueeze(-1)  # (batch_size, seq_length, 1)
        pooled = torch.sum(x * attention_weights, dim=1)  # (batch_size, embedding_dim)
        return pooled

    def forward(self, x, attention_mask=None):
        outputs = self.esm(x, attention_mask=attention_mask).last_hidden_state
        outputs = self.attention_pooling(outputs)
        
        outputs = self.layer1(outputs)
        outputs = self.layer_norm(outputs)
        outputs = self.relu(outputs)
        outputs = self.dropout(outputs)
        outputs = self.layer2(outputs)
        return outputs

In [3]:
model = ESM1b().to(device)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
print(model)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 0.986628 M
Total parameters: 653.343129 M
ESM1b(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace

In [4]:
import random
index2name_file = "../data/taxonomy_index.json"

# Check if the file exists
if os.path.exists(index2name_file):
    with open(index2name_file, "rb") as f:
        index2name = json.load(f)

tax_vocab_sizes = {
    int(k): len(v) for k,v in index2name.items()
}

# print(tax_vocab_sizes)
# # Print tax_vocab_sizes sorted by value (number of taxa per rank)
# sorted_sizes = dict(sorted(tax_vocab_sizes.items(), key=lambda x: x[1], reverse=True))
# print("\nTaxonomic ranks sorted by number of taxa:")
# for rank, size in sorted_sizes.items():
#     print(f"{rank}: {size}")

level_encoder = {
    int(k): {name: idx + 1 for idx, name in enumerate(v)} for k,v in index2name.items()
}

level_decoder = {
    int(k): {idx + 1: name for idx, name in enumerate(v)} for k,v in index2name.items()
}

for k, v in level_decoder.items():
    level_decoder[k][0] = "NOT DEFINED"


def encode_lineage(lineage_str):
    taxes_str = lineage_str.split(", ")

    encoded = {int(k): 0 for k in index2name.keys()}
    
    for i, tax_str in enumerate(taxes_str):
        encoded[i] = level_encoder[i][tax_str]

    return encoded

tokenizer_ = AutoTokenizer.from_pretrained(f"facebook/esm1b_t33_650M_UR50S")

def mix_data_to_tensor_batch(b_virues, b_cellular, max_seq_len=max_seq_len, partition=0.25):
    if partition == -1:
        b = b_virues + b_cellular
    else:
        split_point = int(len(b_virues) * partition)
        b = b_virues[:split_point] + b_cellular[-len(b_virues) + split_point:]
        random.shuffle(b)  # In-place shuffle
    
    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    )

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    tensor_encoded = {k: torch.LongTensor(v) for k,v in combined_dict.items()}
    return Batch(inputs, tensor_encoded)

In [5]:
val_dir = f"val_results/{model_name}"
if not os.path.exists(val_dir):
    os.makedirs(val_dir)
    
val_batches = [virus_da.get_batch() for _ in range(num_val)]
cell_val_batches = [cellular_da.get_batch() for _ in range(num_val)]


def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = mix_data_to_tensor_batch(val_batches[epoch], cell_val_batches[epoch], max_seq_len, partition=-1)
            tensor_batch.gpu(device)
        
            labels = tensor_batch.taxes[0]
            outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

            # Calculate the loss
            loss = criterion(outputs, labels)
    
            running_loss += loss.item()
                
            preds = torch.argmax(outputs, dim=1)
    
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    
    # Concatenate all batches into single tensors
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Compute evaluation metrics (example: accuracy, F1 score)
    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')  # F1-score for multi-label classification
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')  # F1-score for multi-label classification
    conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
    avg_loss = running_loss / num_val
    
    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

In [6]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    # print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)

Successfully loaded checkpoint from epoch 229


In [7]:
def get_partition_ratio(epoch, decay_epochs=100000):
    """
    Calculate partition ratio that decreases from 8/16 to 1/16 in steps
    """
    # Calculate how many epochs before each step down
    epochs_per_step = decay_epochs // 7  # 7 steps from 8/16 down to 1/16
    
    # Calculate current step based on epoch
    step = min(epoch // epochs_per_step, 7)  # Max 7 steps down from 8
    
    # Map step to fraction
    fraction = (8 - step) / 16
    
    return fraction

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(latest_epoch, latest_epoch + epochs)):
    model.train()

    current_partition = get_partition_ratio(epoch)
    
    tensor_batch = mix_data_to_tensor_batch(
        virus_da.get_batch(),
        cellular_da.get_batch(),
        max_seq_len,
        partition=current_partition
    )
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes[0]
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    loss = criterion(outputs, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        # Evaluate on validation set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, val_cm = evaluate(model)
        
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        print(f"Val F1 (micro): {val_f1_micro:.4f}, Val F1 (macro): {val_f1_macro:.4f}")
        print("Val Confusion Matrix:")
        print(val_cm)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_accuracy": val_accuracy,
            "val_f1_micro": val_f1_micro,
            "val_f1_macro": val_f1_macro,
            "epoch": epoch + 1,
            "current_portion": current_partition,
            "lr": current_lr
        }

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0

wandb.finish()

  0%|          | 0/100000 [00:00<?, ?it/s]

Epoch [230/100000]
Train Loss: 0.0212
Val Loss: 0.2465, Val Accuracy: 0.9047
Val F1 (micro): 0.9047, Val F1 (macro): 0.9041
Val Confusion Matrix:
[[530 110]
 [ 12 628]]


  0%|          | 10/100000 [02:11<149:37:16,  5.39s/it]

Epoch [240/100000]
Train Loss: 0.2307
Val Loss: 0.2489, Val Accuracy: 0.9023
Val F1 (micro): 0.9023, Val F1 (macro): 0.9022
Val Confusion Matrix:
[[601  39]
 [ 86 554]]


  0%|          | 20/100000 [04:25<150:15:38,  5.41s/it]

Epoch [250/100000]
Train Loss: 0.2191
Val Loss: 0.2130, Val Accuracy: 0.9125
Val F1 (micro): 0.9125, Val F1 (macro): 0.9123
Val Confusion Matrix:
[[553  87]
 [ 25 615]]


  0%|          | 30/100000 [06:38<150:16:08,  5.41s/it]

Epoch [260/100000]
Train Loss: 0.2123
Val Loss: 0.2108, Val Accuracy: 0.9164
Val F1 (micro): 0.9164, Val F1 (macro): 0.9164
Val Confusion Matrix:
[[576  64]
 [ 43 597]]


  0%|          | 40/100000 [08:51<150:20:20,  5.41s/it]

Epoch [270/100000]
Train Loss: 0.2065
Val Loss: 0.2086, Val Accuracy: 0.9109
Val F1 (micro): 0.9109, Val F1 (macro): 0.9109
Val Confusion Matrix:
[[565  75]
 [ 39 601]]


  0%|          | 50/100000 [11:04<149:47:03,  5.39s/it]

Epoch [280/100000]
Train Loss: 0.1871
Val Loss: 0.2079, Val Accuracy: 0.9117
Val F1 (micro): 0.9117, Val F1 (macro): 0.9117
Val Confusion Matrix:
[[566  74]
 [ 39 601]]


  0%|          | 60/100000 [13:17<150:12:24,  5.41s/it]

Epoch [290/100000]
Train Loss: 0.1854
Val Loss: 0.2083, Val Accuracy: 0.9180
Val F1 (micro): 0.9180, Val F1 (macro): 0.9179
Val Confusion Matrix:
[[574  66]
 [ 39 601]]


  0%|          | 70/100000 [15:29<149:36:07,  5.39s/it]

Epoch [300/100000]
Train Loss: 0.1911
Val Loss: 0.2081, Val Accuracy: 0.9180
Val F1 (micro): 0.9180, Val F1 (macro): 0.9179
Val Confusion Matrix:
[[574  66]
 [ 39 601]]


  0%|          | 80/100000 [17:41<149:41:06,  5.39s/it]

Epoch [310/100000]
Train Loss: 0.2336
Val Loss: 0.2077, Val Accuracy: 0.9164
Val F1 (micro): 0.9164, Val F1 (macro): 0.9164
Val Confusion Matrix:
[[572  68]
 [ 39 601]]


  0%|          | 90/100000 [19:53<149:35:03,  5.39s/it]

Epoch [320/100000]
Train Loss: 0.2065
Val Loss: 0.2077, Val Accuracy: 0.9164
Val F1 (micro): 0.9164, Val F1 (macro): 0.9164
Val Confusion Matrix:
[[572  68]
 [ 39 601]]


  0%|          | 100/100000 [22:06<149:53:30,  5.40s/it]

Epoch [330/100000]
Train Loss: 0.1917
Val Loss: 0.2120, Val Accuracy: 0.9141
Val F1 (micro): 0.9141, Val F1 (macro): 0.9141
Val Confusion Matrix:
[[580  60]
 [ 50 590]]


  0%|          | 110/100000 [24:17<149:31:45,  5.39s/it]

Epoch [340/100000]
Train Loss: 0.1793
Val Loss: 0.2095, Val Accuracy: 0.9133
Val F1 (micro): 0.9133, Val F1 (macro): 0.9131
Val Confusion Matrix:
[[556  84]
 [ 27 613]]


  0%|          | 120/100000 [26:29<149:25:33,  5.39s/it]

Epoch [350/100000]
Train Loss: 0.2192
Val Loss: 0.1974, Val Accuracy: 0.9211
Val F1 (micro): 0.9211, Val F1 (macro): 0.9210
Val Confusion Matrix:
[[573  67]
 [ 34 606]]


  0%|          | 130/100000 [28:41<149:34:25,  5.39s/it]

Epoch [360/100000]
Train Loss: 0.1879
Val Loss: 0.2057, Val Accuracy: 0.9125
Val F1 (micro): 0.9125, Val F1 (macro): 0.9125
Val Confusion Matrix:
[[580  60]
 [ 52 588]]


  0%|          | 140/100000 [30:55<150:03:44,  5.41s/it]

Epoch [370/100000]
Train Loss: 0.1932
Val Loss: 0.2109, Val Accuracy: 0.9141
Val F1 (micro): 0.9141, Val F1 (macro): 0.9139
Val Confusion Matrix:
[[561  79]
 [ 31 609]]


  0%|          | 150/100000 [33:08<149:46:04,  5.40s/it]

Epoch [380/100000]
Train Loss: 0.1995
Val Loss: 0.2022, Val Accuracy: 0.9164
Val F1 (micro): 0.9164, Val F1 (macro): 0.9163
Val Confusion Matrix:
[[561  79]
 [ 28 612]]


  0%|          | 160/100000 [35:20<149:47:00,  5.40s/it]

Epoch [390/100000]
Train Loss: 0.2479
Val Loss: 0.1961, Val Accuracy: 0.9148
Val F1 (micro): 0.9148, Val F1 (macro): 0.9148
Val Confusion Matrix:
[[566  74]
 [ 35 605]]


  0%|          | 170/100000 [37:32<149:19:38,  5.38s/it]

Epoch [400/100000]
Train Loss: 0.1799
Val Loss: 0.2058, Val Accuracy: 0.9172
Val F1 (micro): 0.9172, Val F1 (macro): 0.9172
Val Confusion Matrix:
[[592  48]
 [ 58 582]]


  0%|          | 180/100000 [39:44<149:18:21,  5.38s/it]

Epoch [410/100000]
Train Loss: 0.2095
Val Loss: 0.1956, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[571  69]
 [ 35 605]]


  0%|          | 190/100000 [41:55<149:10:22,  5.38s/it]

Epoch [420/100000]
Train Loss: 0.1721
Val Loss: 0.2040, Val Accuracy: 0.9164
Val F1 (micro): 0.9164, Val F1 (macro): 0.9162
Val Confusion Matrix:
[[557  83]
 [ 24 616]]


  0%|          | 200/100000 [44:08<149:38:25,  5.40s/it]

Epoch [430/100000]
Train Loss: 0.2186
Val Loss: 0.2191, Val Accuracy: 0.9055
Val F1 (micro): 0.9055, Val F1 (macro): 0.9053
Val Confusion Matrix:
[[603  37]
 [ 84 556]]


  0%|          | 210/100000 [46:20<149:35:37,  5.40s/it]

Epoch [440/100000]
Train Loss: 0.2218
Val Loss: 0.1941, Val Accuracy: 0.9164
Val F1 (micro): 0.9164, Val F1 (macro): 0.9164
Val Confusion Matrix:
[[575  65]
 [ 42 598]]


  0%|          | 220/100000 [48:33<149:38:39,  5.40s/it]

Epoch [450/100000]
Train Loss: 0.1920
Val Loss: 0.1942, Val Accuracy: 0.9148
Val F1 (micro): 0.9148, Val F1 (macro): 0.9147
Val Confusion Matrix:
[[561  79]
 [ 30 610]]


  0%|          | 230/100000 [50:45<149:56:46,  5.41s/it]

Epoch [460/100000]
Train Loss: 0.1882
Val Loss: 0.1947, Val Accuracy: 0.9148
Val F1 (micro): 0.9148, Val F1 (macro): 0.9148
Val Confusion Matrix:
[[570  70]
 [ 39 601]]


  0%|          | 240/100000 [52:58<149:39:27,  5.40s/it]

Epoch [470/100000]
Train Loss: 0.2246
Val Loss: 0.1968, Val Accuracy: 0.9141
Val F1 (micro): 0.9141, Val F1 (macro): 0.9141
Val Confusion Matrix:
[[579  61]
 [ 49 591]]


  0%|          | 250/100000 [55:09<149:17:58,  5.39s/it]

Epoch [480/100000]
Train Loss: 0.1634
Val Loss: 0.1913, Val Accuracy: 0.9195
Val F1 (micro): 0.9195, Val F1 (macro): 0.9195
Val Confusion Matrix:
[[569  71]
 [ 32 608]]


  0%|          | 260/100000 [57:21<149:10:50,  5.38s/it]

Epoch [490/100000]
Train Loss: 0.1883
Val Loss: 0.1969, Val Accuracy: 0.9156
Val F1 (micro): 0.9156, Val F1 (macro): 0.9156
Val Confusion Matrix:
[[582  58]
 [ 50 590]]


  0%|          | 270/100000 [59:33<149:21:51,  5.39s/it]

Epoch [500/100000]
Train Loss: 0.2055
Val Loss: 0.1903, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9186
Val Confusion Matrix:
[[565  75]
 [ 29 611]]


  0%|          | 280/100000 [1:01:46<150:20:19,  5.43s/it]

Epoch [510/100000]
Train Loss: 0.2053
Val Loss: 0.1886, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[578  62]
 [ 42 598]]


  0%|          | 290/100000 [1:03:58<149:59:42,  5.42s/it]

Epoch [520/100000]
Train Loss: 0.1999
Val Loss: 0.1882, Val Accuracy: 0.9227
Val F1 (micro): 0.9227, Val F1 (macro): 0.9227
Val Confusion Matrix:
[[587  53]
 [ 46 594]]


  0%|          | 300/100000 [1:06:13<153:11:03,  5.53s/it]

Epoch [530/100000]
Train Loss: 0.1656
Val Loss: 0.1863, Val Accuracy: 0.9195
Val F1 (micro): 0.9195, Val F1 (macro): 0.9195
Val Confusion Matrix:
[[571  69]
 [ 34 606]]


  0%|          | 310/100000 [1:08:27<154:18:43,  5.57s/it]

Epoch [540/100000]
Train Loss: 0.2138
Val Loss: 0.1856, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[576  64]
 [ 40 600]]


  0%|          | 320/100000 [1:10:41<153:51:47,  5.56s/it]

Epoch [550/100000]
Train Loss: 0.1943
Val Loss: 0.1860, Val Accuracy: 0.9203
Val F1 (micro): 0.9203, Val F1 (macro): 0.9203
Val Confusion Matrix:
[[573  67]
 [ 35 605]]


  0%|          | 330/100000 [1:12:56<151:24:33,  5.47s/it]

Epoch [560/100000]
Train Loss: 0.2218
Val Loss: 0.1858, Val Accuracy: 0.9211
Val F1 (micro): 0.9211, Val F1 (macro): 0.9211
Val Confusion Matrix:
[[576  64]
 [ 37 603]]


  0%|          | 340/100000 [1:15:09<149:28:26,  5.40s/it]

Epoch [570/100000]
Train Loss: 0.1779
Val Loss: 0.1858, Val Accuracy: 0.9219
Val F1 (micro): 0.9219, Val F1 (macro): 0.9218
Val Confusion Matrix:
[[578  62]
 [ 38 602]]


  0%|          | 350/100000 [1:17:23<154:28:13,  5.58s/it]

Epoch [580/100000]
Train Loss: 0.1814
Val Loss: 0.1856, Val Accuracy: 0.9203
Val F1 (micro): 0.9203, Val F1 (macro): 0.9203
Val Confusion Matrix:
[[574  66]
 [ 36 604]]


  0%|          | 360/100000 [1:19:37<149:47:51,  5.41s/it]

Epoch [590/100000]
Train Loss: 0.1917
Val Loss: 0.1854, Val Accuracy: 0.9195
Val F1 (micro): 0.9195, Val F1 (macro): 0.9195
Val Confusion Matrix:
[[572  68]
 [ 35 605]]


  0%|          | 370/100000 [1:21:49<150:23:42,  5.43s/it]

Epoch [600/100000]
Train Loss: 0.1648
Val Loss: 0.1854, Val Accuracy: 0.9195
Val F1 (micro): 0.9195, Val F1 (macro): 0.9195
Val Confusion Matrix:
[[572  68]
 [ 35 605]]


  0%|          | 380/100000 [1:24:03<154:08:51,  5.57s/it]

Epoch [610/100000]
Train Loss: 0.2176
Val Loss: 0.1853, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[571  69]
 [ 35 605]]


  0%|          | 390/100000 [1:26:19<154:40:44,  5.59s/it]

Epoch [620/100000]
Train Loss: 0.1812
Val Loss: 0.1853, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[571  69]
 [ 35 605]]


  0%|          | 400/100000 [1:28:37<155:23:53,  5.62s/it]

Epoch [630/100000]
Train Loss: 0.1906
Val Loss: 0.1853, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[571  69]
 [ 35 605]]


  0%|          | 410/100000 [1:30:51<154:28:00,  5.58s/it]

Epoch [640/100000]
Train Loss: 0.1710
Val Loss: 0.1853, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[571  69]
 [ 35 605]]


  0%|          | 420/100000 [1:33:05<150:11:05,  5.43s/it]

Epoch [650/100000]
Train Loss: 0.1703
Val Loss: 0.1928, Val Accuracy: 0.9211
Val F1 (micro): 0.9211, Val F1 (macro): 0.9211
Val Confusion Matrix:
[[583  57]
 [ 44 596]]


  0%|          | 430/100000 [1:35:20<154:08:54,  5.57s/it]

Epoch [660/100000]
Train Loss: 0.2660
Val Loss: 0.2073, Val Accuracy: 0.9172
Val F1 (micro): 0.9172, Val F1 (macro): 0.9172
Val Confusion Matrix:
[[594  46]
 [ 60 580]]


  0%|          | 440/100000 [1:37:36<154:54:08,  5.60s/it]

Epoch [670/100000]
Train Loss: 0.2146
Val Loss: 0.2118, Val Accuracy: 0.9148
Val F1 (micro): 0.9148, Val F1 (macro): 0.9148
Val Confusion Matrix:
[[594  46]
 [ 63 577]]


  0%|          | 450/100000 [1:39:51<154:20:33,  5.58s/it]

Epoch [680/100000]
Train Loss: 0.2118
Val Loss: 0.1884, Val Accuracy: 0.9203
Val F1 (micro): 0.9203, Val F1 (macro): 0.9202
Val Confusion Matrix:
[[570  70]
 [ 32 608]]


  0%|          | 460/100000 [1:42:07<154:53:54,  5.60s/it]

Epoch [690/100000]
Train Loss: 0.1840
Val Loss: 0.1874, Val Accuracy: 0.9234
Val F1 (micro): 0.9234, Val F1 (macro): 0.9233
Val Confusion Matrix:
[[567  73]
 [ 25 615]]


  0%|          | 470/100000 [1:44:22<150:51:05,  5.46s/it]

Epoch [700/100000]
Train Loss: 0.1994
Val Loss: 0.1881, Val Accuracy: 0.9187
Val F1 (micro): 0.9187, Val F1 (macro): 0.9187
Val Confusion Matrix:
[[585  55]
 [ 49 591]]


  0%|          | 480/100000 [1:46:38<154:37:20,  5.59s/it]

Epoch [710/100000]
Train Loss: 0.1812
Val Loss: 0.2006, Val Accuracy: 0.9203
Val F1 (micro): 0.9203, Val F1 (macro): 0.9203
Val Confusion Matrix:
[[594  46]
 [ 56 584]]


  0%|          | 490/100000 [1:48:53<152:20:59,  5.51s/it]

Epoch [720/100000]
Train Loss: 0.2150
Val Loss: 0.1843, Val Accuracy: 0.9266
Val F1 (micro): 0.9266, Val F1 (macro): 0.9265
Val Confusion Matrix:
[[583  57]
 [ 37 603]]


  0%|          | 500/100000 [1:51:08<154:20:17,  5.58s/it]

Epoch [730/100000]
Train Loss: 0.2007
Val Loss: 0.1814, Val Accuracy: 0.9258
Val F1 (micro): 0.9258, Val F1 (macro): 0.9258
Val Confusion Matrix:
[[585  55]
 [ 40 600]]


  1%|          | 510/100000 [1:53:23<154:27:34,  5.59s/it]

Epoch [740/100000]
Train Loss: 0.1980
Val Loss: 0.1816, Val Accuracy: 0.9234
Val F1 (micro): 0.9234, Val F1 (macro): 0.9234
Val Confusion Matrix:
[[572  68]
 [ 30 610]]


  1%|          | 520/100000 [1:55:36<152:04:13,  5.50s/it]

Epoch [750/100000]
Train Loss: 0.1685
Val Loss: 0.1878, Val Accuracy: 0.9258
Val F1 (micro): 0.9258, Val F1 (macro): 0.9257
Val Confusion Matrix:
[[566  74]
 [ 21 619]]


  1%|          | 530/100000 [1:57:52<154:33:06,  5.59s/it]

Epoch [760/100000]
Train Loss: 0.1876
Val Loss: 0.1892, Val Accuracy: 0.9234
Val F1 (micro): 0.9234, Val F1 (macro): 0.9234
Val Confusion Matrix:
[[573  67]
 [ 31 609]]


  1%|          | 540/100000 [2:00:07<151:42:07,  5.49s/it]

In [None]:
model, _, _, latest_epoch, metrics = load_checkpoint(model)

val_batches_ = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

# input_sequences_ = [e['Sequence'] for b in val_batches_ for e in b]
# labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches_ for e in b]

input_sequences_ = ["ACACAD"]
labels_ = [{0: 1}]

def evaluate_df(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        "sequence": [],
        "label": [],
        "pred": [],
        "loss": []
    }

    metrics = {
        "loss": 0,
        "accuracy": 0,
        "f1 macro": 0,
        "f1 micro": 0
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences_, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        pred = output.argmax(dim=-1).cpu().item()
        loss = criterion(output, torch.tensor([label[0]]).to(device))
        df["sequence"].append(sequence)
        df["label"].append(level_decoder[0][label[0]])
        df["pred"].append(level_decoder[0][pred])
        df["loss"].append(round(loss.cpu().item(), 4))

    # Convert to DataFrame
    new_df = pd.DataFrame(df)
    new_df['is_incorrect'] = new_df['label'] != new_df['pred']
    new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
    new_df.to_csv(f'classification_results__new_att.csv', index=False)

    metrics["loss"] = np.array(df["loss"]).mean()
    metrics["accuracy"] = accuracy_score(np.array(df["label"]), np.array(df["pred"]))
    metrics["f1 macro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='macro')  # F1-score for multi-label classification
    metrics["f1 micro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='micro') 
    print(metrics)

evaluate_df(model)