In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
import json
from transformers import EsmModel, AutoTokenizer, AutoModel

sys.path.insert(0, '../dlp')
from batch import Batch

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 100
num_val = 10
batch_size = 64
dataset_name = "corpus_1000_Viruses_cellular"
lr = 0.001
model_name = "TwoLevelLoss"
max_seq_len = 1000

from data_access import PQDataAccess
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


cuda:1
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/TwoLevelLoss_checkpoints


git root error: Cmd('git') failed due to: exit code(128)
  cmdline: git rev-parse --show-toplevel
  stderr: 'fatal: detected dubious ownership in repository at '/home/aac/Alireza'
To add an exception for this directory, call:

	git config --global --add safe.directory /home/aac/Alireza'
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malireza_noroozi[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
index2name_file = "../data/taxonomy_index.json"

if os.path.exists(index2name_file):
    with open(index2name_file, "rb") as f:
        index2name = json.load(f)

index2name = {k:v for k, v in index2name.items() if k in ["0", "1"]}

tax_vocab_sizes = {
    int(k): len(v) + 1 for k,v in index2name.items()
}
# print(tax_vocab_sizes)

level_encoder = {
    int(k): {name: idx + 1 for idx, name in enumerate(v)} for k,v in index2name.items()
}

level_decoder = {
    int(k): {idx + 1: name for idx, name in enumerate(v)} for k,v in index2name.items()
}

for k, v in level_decoder.items():
    level_decoder[k][0] = "NOT DEFINED"


def encode_lineage(lineage_str):
    taxes_str = lineage_str.split(", ")

    encoded = {int(k): 0 for k in index2name.keys()}
    
    for i, tax_str in enumerate(taxes_str):
        if i <= 1:
            encoded[i] = level_encoder[i][tax_str]

    return encoded

In [3]:
class ESMHead(nn.Module):
    def __init__(self, num_classes_dict):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        for param in self.esm.parameters():
            param.requires_grad = False
        
        self.heads = nn.ModuleDict()
        self.attentions = nn.ModuleDict()

        prev_class = 0
        for index_name, num_classes in num_classes_dict.items():
            self.heads[str(index_name)] = nn.Sequential(
                nn.Linear(1280 + prev_class, 512),
                nn.LayerNorm(512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, num_classes)
            )

            self.attentions[str(index_name)] = nn.Sequential(
                nn.Linear(1280, 256),
                nn.Tanh(),
                nn.Linear(256, 1)
            )

            prev_class = num_classes
        
    def attention_pooling(self, x, index):
        # x shape: (batch_size, seq_length, embedding_dim)
        attention_weights = self.attentions[index](x)  # (batch_size, seq_length, 1)
        attention_weights = torch.softmax(attention_weights.squeeze(-1), dim=1)  # (batch_size, seq_length)
        attention_weights = attention_weights.unsqueeze(-1)  # (batch_size, seq_length, 1)
        pooled = torch.sum(x * attention_weights, dim=1)  # (batch_size, embedding_dim)
        return pooled

    def forward(self, x, attention_mask=None):
        x = self.esm(x, attention_mask=attention_mask).last_hidden_state
        for index, head in self.heads.items():
            pooled = self.attention_pooling(x, index)  # Apply attention pooling
            if index != "0":
                current_pooled = torch.cat([pooled, output], dim=-1)
            else:
                current_pooled = pooled
            output = head(current_pooled)
        return output


In [4]:
model = ESMHead(tax_vocab_sizes).to(device)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

criterion = nn.CrossEntropyLoss()

# scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer,
#     step_size=10,  # Period of learning rate decay
#     gamma=0.1  # Multiplicative factor of decay
# )

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 1.990695 M
Total parameters: 654.344636 M


In [5]:
# class BatchReader:
#     def __init__(self, batch_size):
#         self.batch_size = batch_size
#         self.current_index = 0 
#         self.data = pd.read_csv("../embeddings/taxonomic_data.csv")

#     def get_batch(self):
#         if self.current_index >= len(self.data):
#             print("No more data available.")
#             return []

#         batch_data = self.data.iloc[self.current_index:self.current_index + self.batch_size]
#         self.current_index += self.batch_size

#         inputs = []
#         tax_ids = []
#         for _, row in batch_data.iterrows():
#             inputs.append(torch.load(f"../embeddings/tax_esm_embeddings/{row['index']}.pt"))
#             tax_ids.append(encode_lineage(row['Taxonomic_lineage']))
    
#         inputs = torch.stack(inputs)
    
#         combined_dict = {}
#         for d in tax_ids:
#             for key, value in d.items():
#                 combined_dict.setdefault(key, []).append(value)

#         tensor_encoded = {k: torch.LongTensor(v) for k,v in combined_dict.items()}
#         return Batch(inputs, tensor_encoded)


# da = BatchReader(batch_size=batch_size)
# import time
# start_time = time.time()
# val_batches = [da.get_batch() for _ in range(num_val)]
# end_time = time.time()
# elapsed_time = end_time - start_time
# print(f"Execution time: {elapsed_time:.6f} seconds")
tokenizer_ = AutoTokenizer.from_pretrained(f"facebook/esm1b_t33_650M_UR50S")

def data_to_tensor_batch(b, max_seq_len=max_seq_len):

    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    )

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]

    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    tensor_encoded = {k: torch.LongTensor(v) for k,v in combined_dict.items()}
    return Batch(inputs, tensor_encoded)

In [6]:
val_batches = [da.get_batch() for _ in range(num_val)]

def evaluate(model):
    model.eval()
    
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for epoch in range(num_val):
        with torch.no_grad():
            tensor_batch = data_to_tensor_batch(val_batches[epoch], max_seq_len)
            tensor_batch.gpu(device)
            labels = tensor_batch.taxes
            labels = labels[list(labels.keys())[-1]]
            
        outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
        loss = criterion(outputs, labels)
        preds = torch.argmax(outputs, dim=1)

        running_loss += loss.item()
        all_preds.append(preds.cpu())
        all_labels.append(labels.cpu())
    
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    accuracy = accuracy_score(all_labels.numpy(), all_preds.numpy())
    f1_macro = f1_score(all_labels.numpy(), all_preds.numpy(), average='macro')
    f1_micro = f1_score(all_labels.numpy(), all_preds.numpy(), average='micro')
    conf_matrix = confusion_matrix(all_labels.numpy(), all_preds.numpy())
    avg_loss = running_loss / num_val

    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

# evaluate(model)

In [7]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    # print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

# model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)
latest_epoch = 0

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(latest_epoch, latest_epoch + epochs)):
    model.train()
    
    tensor_batch = da.get_batch()
    tensor_batch = data_to_tensor_batch(tensor_batch, max_seq_len)
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes
    labels = labels[list(labels.keys())[-1]]
    
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    loss = criterion(outputs, labels)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    running_loss += loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        # Evaluate on validation set
        val_loss, val_accuracy, val_f1_micro, val_f1_macro, val_cm = evaluate(model)
        
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")
        print(f"Val F1 (micro): {val_f1_micro:.4f}, Val F1 (macro): {val_f1_macro:.4f}")
        print("Val Confusion Matrix:")
        print(val_cm)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss":val_loss,
            "val_accuracy": val_accuracy,
            "val_f1_micro": val_f1_micro,
            "val_f1_macro": val_f1_macro,
            "epoch": epoch + 1,
            "lr": current_lr
        }

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # 'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        # scheduler.step(epoch + loss.item())
        # current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0

wandb.finish()

  0%|          | 99/100000 [07:32<127:05:32,  4.58s/it]

Epoch [100/100000]
Train Loss: 0.6424
Val Loss: 0.8306, Val Accuracy: 0.7719
Val F1 (micro): 0.7719, Val F1 (macro): 0.1825
Val Confusion Matrix:
[[ 39   0   0   0  80   0   0]
 [  3   0   0   0  11   0   0]
 [  2   0   0   0  13   0   0]
 [  0   0   0   0   2   0   0]
 [ 26   0   0   0 455   0   0]
 [  1   0   0   0   1   0   0]
 [  0   0   0   0   7   0   0]]


  0%|          | 199/100000 [16:01<126:44:08,  4.57s/it]

Epoch [200/100000]
Train Loss: 0.2962
Val Loss: 1.0210, Val Accuracy: 0.7656
Val F1 (micro): 0.7656, Val F1 (macro): 0.1538
Val Confusion Matrix:
[[ 15   0   0   0 104   0   0]
 [  0   0   0   0  14   0   0]
 [  0   0   0   0  15   0   0]
 [  0   0   0   0   2   0   0]
 [  6   0   0   0 475   0   0]
 [  2   0   0   0   0   0   0]
 [  0   0   0   0   7   0   0]]


  0%|          | 299/100000 [24:30<126:30:14,  4.57s/it]

Epoch [300/100000]
Train Loss: 0.1406
Val Loss: 1.2323, Val Accuracy: 0.7688
Val F1 (micro): 0.7688, Val F1 (macro): 0.1317
Val Confusion Matrix:
[[ 12   0   0   0   7 100   0   0]
 [  0   0   0   0   1  13   0   0]
 [  0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0]
 [  1   0   0   0   0 480   0   0]
 [  2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   7   0   0]]


  0%|          | 399/100000 [32:59<126:22:50,  4.57s/it]

Epoch [400/100000]
Train Loss: 0.1229
Val Loss: 1.2201, Val Accuracy: 0.7625
Val F1 (micro): 0.7625, Val F1 (macro): 0.1242
Val Confusion Matrix:
[[  8   0   0   0   6 105   0   0]
 [  0   0   0   0   1  13   0   0]
 [  0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0]
 [  1   0   0   0   0 480   0   0]
 [  1   0   0   0   0   1   0   0]
 [  0   0   0   0   0   7   0   0]]


  0%|          | 499/100000 [41:29<126:11:20,  4.57s/it]

Epoch [500/100000]
Train Loss: 0.1300
Val Loss: 1.0158, Val Accuracy: 0.7688
Val F1 (micro): 0.7688, Val F1 (macro): 0.1422
Val Confusion Matrix:
[[ 19   0   0   0  15  85   0   0]
 [  0   0   0   0   3  11   0   0]
 [  0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0]
 [  8   0   0   0   0 473   0   0]
 [  2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   7   0   0]]


  1%|          | 599/100000 [49:59<126:03:50,  4.57s/it]

Epoch [600/100000]
Train Loss: 0.2430
Val Loss: 1.3656, Val Accuracy: 0.7672
Val F1 (micro): 0.7672, Val F1 (macro): 0.1178
Val Confusion Matrix:
[[ 12   0   0   0   1  12  94   0   0]
 [  0   0   0   0   1   3  10   0   0]
 [  0   0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  1   0   0   0   0   1 479   0   0]
 [  1   0   0   0   0   0   1   0   0]
 [  0   0   0   0   0   0   7   0   0]]


  1%|          | 699/100000 [58:28<125:58:02,  4.57s/it]

Epoch [700/100000]
Train Loss: 0.1236
Val Loss: 1.3164, Val Accuracy: 0.7688
Val F1 (micro): 0.7688, Val F1 (macro): 0.1349
Val Confusion Matrix:
[[ 13   0   0   0  18  88   0   0]
 [  0   0   0   0   5   9   0   0]
 [  0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0]
 [  1   0   0   0   1 479   0   0]
 [  1   0   0   0   0   1   0   0]
 [  0   0   0   0   0   7   0   0]]


  1%|          | 799/100000 [1:06:57<125:48:12,  4.57s/it]

Epoch [800/100000]
Train Loss: 0.0627
Val Loss: 1.2201, Val Accuracy: 0.7719
Val F1 (micro): 0.7719, Val F1 (macro): 0.1444
Val Confusion Matrix:
[[ 21   0   0   0   8  90   0   0]
 [  0   0   0   0   1  13   0   0]
 [  0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0]
 [  8   0   0   0   0 473   0   0]
 [  2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   7   0   0]]


  1%|          | 899/100000 [1:15:26<125:39:22,  4.56s/it]

Epoch [900/100000]
Train Loss: 0.1093
Val Loss: 1.1847, Val Accuracy: 0.7625
Val F1 (micro): 0.7625, Val F1 (macro): 0.1146
Val Confusion Matrix:
[[ 34   3   0   2   0   0  13  67   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0]
 [  0   1   0   4   0   0   1   7   0   1   0]
 [  0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0  14   0   1   0]
 [  0   0   0   1   0   0   0   1   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0]
 [ 24   0   0   0   0   0   1 454   0   2   0]
 [  2   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   7   0   0   0]]


  1%|          | 999/100000 [1:23:54<125:39:28,  4.57s/it]

Epoch [1000/100000]
Train Loss: 0.1206
Val Loss: 1.3157, Val Accuracy: 0.7672
Val F1 (micro): 0.7672, Val F1 (macro): 0.1300
Val Confusion Matrix:
[[ 11   0   0   0   8 100   0   0]
 [  0   0   0   0   1  13   0   0]
 [  0   0   0   0   0  15   0   0]
 [  0   0   0   0   0   2   0   0]
 [  0   0   0   0   0   0   0   0]
 [  1   0   0   0   0 480   0   0]
 [  1   0   0   0   0   1   0   0]
 [  0   0   0   0   0   7   0   0]]


  1%|          | 1099/100000 [1:32:21<125:24:40,  4.56s/it]

Epoch [1100/100000]
Train Loss: 0.0672
Val Loss: 1.2068, Val Accuracy: 0.7703
Val F1 (micro): 0.7703, Val F1 (macro): 0.1136
Val Confusion Matrix:
[[ 19   0   0   0   0   9  91   0   0   0]
 [  0   0   0   0   0   1  13   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  15   0   0   0]
 [  0   0   1   0   0   0   1   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  6   0   0   0   0   0 474   0   1   0]
 [  2   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   7   0   0   0]]


  1%|          | 1199/100000 [1:40:50<125:39:15,  4.58s/it]

Epoch [1200/100000]
Train Loss: 0.2755
Val Loss: 1.2789, Val Accuracy: 0.7609
Val F1 (micro): 0.7609, Val F1 (macro): 0.1316
Val Confusion Matrix:
[[ 22   4   0   0   0  23  70   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  0   4   0   0   0   4   6   0   0]
 [  0   0   0   0   0   0  15   0   0]
 [  0   1   0   0   0   0   1   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  7   2   0   0   0   7 465   0   0]
 [  2   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   1   6   0   0]]


  1%|▏         | 1299/100000 [1:49:19<125:20:40,  4.57s/it]

Epoch [1300/100000]
Train Loss: 0.1012
Val Loss: 1.3074, Val Accuracy: 0.7594
Val F1 (micro): 0.7594, Val F1 (macro): 0.1294
Val Confusion Matrix:
[[ 22   0   0   0   0  21  76   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  0   1   0   0   0   5   8   0   0]
 [  0   0   0   0   0   0  15   0   0]
 [  0   1   0   0   0   0   1   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [ 13   0   0   0   0   4 464   0   0]
 [  2   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   1   6   0   0]]


  1%|▏         | 1399/100000 [1:57:49<124:59:01,  4.56s/it]

Epoch [1400/100000]
Train Loss: 0.2339
Val Loss: 1.2792, Val Accuracy: 0.7672
Val F1 (micro): 0.7672, Val F1 (macro): 0.1205
Val Confusion Matrix:
[[ 25   0   0   0   0  13  76   0   5   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   3   6   0   5   0]
 [  0   0   0   0   0   0  14   0   1   0]
 [  0   1   0   0   0   0   1   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [ 11   0   0   0   0   1 466   0   3   0]
 [  2   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   7   0   0   0]]


  1%|▏         | 1499/100000 [2:06:17<124:59:26,  4.57s/it]

Epoch [1500/100000]
Train Loss: 0.0691
Val Loss: 1.3962, Val Accuracy: 0.7719
Val F1 (micro): 0.7719, Val F1 (macro): 0.1221
Val Confusion Matrix:
[[ 15   0   0   0   0  10  94   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   3  11   0   0]
 [  0   0   0   0   0   0  15   0   0]
 [  0   1   0   0   0   0   1   0   0]
 [  0   0   0   0   0   0   0   0   0]
 [  2   0   0   0   0   0 479   0   0]
 [  1   0   0   0   0   0   1   0   0]
 [  0   0   0   0   0   0   7   0   0]]


  2%|▏         | 1599/100000 [2:14:46<124:45:00,  4.56s/it]

Epoch [1600/100000]
Train Loss: 0.0998
Val Loss: 1.6517, Val Accuracy: 0.6375
Val F1 (micro): 0.6375, Val F1 (macro): 0.1153
Val Confusion Matrix:
[[ 26   5   0   0   0  48  39   0   1   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   3   0   0   0   7   4   0   0   0]
 [  0   0   0   0   0   1  12   0   2   0]
 [  0   1   0   0   0   1   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [ 14   2   0   0   0  81 382   0   2   0]
 [  2   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   5   2   0   0   0]]


  2%|▏         | 1699/100000 [2:23:15<124:59:51,  4.58s/it]

Epoch [1700/100000]
Train Loss: 0.2225
Val Loss: 2.0615, Val Accuracy: 0.6469
Val F1 (micro): 0.6469, Val F1 (macro): 0.0927
Val Confusion Matrix:
[[  5   0   0   0   0  60  54   0   0   0]
 [  0   0   1   0   0   9   4   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  14   0   1   0]
 [  0   0   1   0   0   1   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0  72 409   0   0   0]
 [  0   0   0   0   0   1   1   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   4   3   0   0   0]]


  2%|▏         | 1774/100000 [2:29:50<124:45:31,  4.57s/it]

In [None]:
model, _, _, latest_epoch, metrics = load_checkpoint(model)

val_batches_ = [da.get_batch() for _ in range(num_val)]


input_sequences_ = ["ACACAD"]
labels_ = [{0: 1}]

def evaluate_df(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        "sequence": [],
        "label": [],
        "pred": [],
        "loss": []
    }

    metrics = {
        "loss": 0,
        "accuracy": 0,
        "f1 macro": 0,
        "f1 micro": 0
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences_, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        pred = output.argmax(dim=-1).cpu().item()
        loss = criterion(output, torch.tensor([label[0]]).to(device))
        df["sequence"].append(sequence)
        df["label"].append(level_decoder[0][label[0]])
        df["pred"].append(level_decoder[0][pred])
        df["loss"].append(round(loss.cpu().item(), 4))

    # Convert to DataFrame
    new_df = pd.DataFrame(df)
    new_df['is_incorrect'] = new_df['label'] != new_df['pred']
    new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
    new_df.to_csv(f'classification_results__new_att.csv', index=False)

    metrics["loss"] = np.array(df["loss"]).mean()
    metrics["accuracy"] = accuracy_score(np.array(df["label"]), np.array(df["pred"]))
    metrics["f1 macro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='macro')  # F1-score for multi-label classification
    metrics["f1 micro"] = f1_score(np.array(df["label"]), np.array(df["pred"]), average='micro') 
    print(metrics)

evaluate_df(model)