In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score
import json
from transformers import EsmModel, AutoTokenizer

sys.path.insert(0, '../dlp')
from batch import Batch

# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

# Run identity: names the wandb project and the checkpoint / validation folders.
model_name = "OnlyFirst_Flat"

# Dataset folder names (joined onto the parquet export root when the readers are built).
virus_dataset_name = "corpus_1000_Viruses"
cellular_dataset_name = "corpus_1000_cellular"

# Optimisation and evaluation settings.
lr = 1e-3            # initial learning rate handed to Adam
batch_size = 64      # batch size passed to each PQDataAccess reader
epochs = 100_000     # number of training iterations; every "epoch" consumes one mixed batch
val_epoch = 100      # validate, checkpoint and log once every this many iterations
num_val = 10         # validation batches drawn from each data source
max_seq_len = 1000   # the tokenizer pads / truncates every sequence to this many tokens

from data_access import PQDataAccess
# Both corpora sit under the same parquet export root; build one reader per corpus,
# each configured with the shared batch size.
_export_root = "/home/aac/Alireza/datasets/export_pqt_4_taxseq_new"
virus_da = PQDataAccess(f"{_export_root}/{virus_dataset_name}", batch_size)
cellular_da = PQDataAccess(f"{_export_root}/{cellular_dataset_name}", batch_size)

# Folder that receives the periodic training checkpoints for this model.
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
# exist_ok=True creates the folder (and any missing parents) and is a no-op when it
# already exists, which avoids the check-then-create race of os.path.exists().
os.makedirs(checkpoint_dir, exist_ok=True)
print(checkpoint_dir)

wandb.init(
    # wandb project that receives this run
    project=model_name,

    # hyperparameters and run metadata recorded with the run
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        # Key was previously misspelled "batch_szie", so the batch size could not be
        # found under its expected name when filtering / grouping runs.
        "batch_size": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/OnlyFirst_Flat_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malireza_noroozi[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
class ESM1b(nn.Module):
    """Frozen ESM-1b encoder followed by a small trainable classification head.

    The encoder's per-token hidden states are flattened into one vector per
    sequence (``seq_len * hidden_size`` features) and passed through
    Linear -> LayerNorm -> ReLU -> Dropout -> Linear. Because of the flattening,
    every input batch must be padded / truncated to exactly ``max_seq_len`` tokens.

    Args:
        max_seq_len: Token length every input is padded / truncated to. Determines
            the input width of the first linear layer. Defaults to 1000.
        hidden_dim: Width of the hidden layer of the head. Defaults to 512.
        num_classes: Number of output logits. Defaults to 3.
        dropout: Dropout probability applied inside the head. Defaults to 0.1.
    """

    def __init__(self, max_seq_len=1000, hidden_dim=512, num_classes=3, dropout=0.1):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")

        # Freeze ESM parameters: only the head below is trained.
        for param in self.esm.parameters():
            param.requires_grad = False

        # Remember the expected token length so forward() can validate its input.
        self.max_seq_len = max_seq_len
        # Read the embedding width from the loaded encoder instead of hard-coding it.
        embed_dim = self.esm.config.hidden_size

        # Attribute names are kept so that existing checkpoints still load.
        self.layer1 = nn.Linear(embed_dim * max_seq_len, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.layer2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, attention_mask=None):
        """Return class logits for a batch of token ids.

        Args:
            x: Token-id tensor whose second dimension must equal ``max_seq_len``.
            attention_mask: Optional mask forwarded to the ESM encoder.

        Returns:
            Tensor of logits with one row per input sequence.

        Raises:
            ValueError: If the token dimension of ``x`` differs from ``max_seq_len``.
        """
        # Without this check a wrong length only surfaces as a matmul shape error
        # inside layer1, which is hard to trace back to the tokenizer settings.
        if x.shape[1] != self.max_seq_len:
            raise ValueError(
                f"Expected inputs padded to {self.max_seq_len} tokens, got {x.shape[1]}"
            )

        outputs = self.esm(x, attention_mask=attention_mask).last_hidden_state
        # Flatten all token embeddings of a sequence into a single feature vector.
        outputs = outputs.reshape(x.shape[0], -1)
        outputs = self.layer1(outputs)
        outputs = self.layer_norm(outputs)
        outputs = self.relu(outputs)
        outputs = self.dropout(outputs)
        outputs = self.layer2(outputs)
        return outputs

In [3]:
# Build the network and move it to the selected device.
model = ESM1b()
model = model.to(device)

# Parameter census: everything vs. the subset that still receives gradients.
total = 0
trainable = 0
for weight in model.parameters():
    total += weight.numel()
    if weight.requires_grad:
        trainable += weight.numel()
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
# print(model)

# Only parameters that require gradients are handed to the optimizer.
head_params = [weight for weight in model.parameters() if weight.requires_grad]
optimizer = torch.optim.Adam(head_params, lr=lr)

criterion = nn.CrossEntropyLoss()

# Cosine annealing with warm restarts: the first cycle lasts T_0 scheduler steps,
# each following cycle is T_mult times longer, and the rate never drops below eta_min.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-6,
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 655.363075 M
Total parameters: 1307.719576 M


In [4]:
import random
index2name_file = "../data/taxonomy_index.json"

# Everything below needs index2name, so fail immediately with a clear message when
# the file is missing instead of silently skipping the load and hitting a NameError
# on index2name a few lines later.
if not os.path.exists(index2name_file):
    raise FileNotFoundError(f"Taxonomy index file not found: {index2name_file}")

# json.load accepts a binary file handle.
with open(index2name_file, "rb") as f:
    index2name = json.load(f)

# Per-rank lookup tables, all keyed by the integer form of the rank key in index2name:
#   tax_vocab_sizes[rank] -> number of names listed for that rank
#   level_encoder[rank]   -> {name: code}, codes start at 1
#   level_decoder[rank]   -> {code: name}, plus code 0 -> "NOT DEFINED"
tax_vocab_sizes = {}
level_encoder = {}
level_decoder = {}

for rank_key, rank_names in index2name.items():
    rank = int(rank_key)
    tax_vocab_sizes[rank] = len(rank_names)
    level_encoder[rank] = {name: code for code, name in enumerate(rank_names, start=1)}

    decoder = {code: name for code, name in enumerate(rank_names, start=1)}
    # Code 0 is never produced from a name; it decodes to the placeholder label.
    decoder[0] = "NOT DEFINED"
    level_decoder[rank] = decoder

# Debug aid (kept commented out): list ranks by number of taxa.
# sorted_sizes = dict(sorted(tax_vocab_sizes.items(), key=lambda x: x[1], reverse=True))
# print("\nTaxonomic ranks sorted by number of taxa:")
# for rank, size in sorted_sizes.items():
#     print(f"{rank}: {size}")


def encode_lineage(lineage_str):
    """Turn a ``", "``-separated lineage string into per-rank integer codes.

    The i-th name in the string is looked up in ``level_encoder[i]``. Ranks that
    the string does not reach keep the code 0.

    Args:
        lineage_str: Lineage names joined by ``", "``.

    Returns:
        dict mapping each rank index to its integer code.

    Raises:
        KeyError: If a name is not present in the encoder of its rank.
    """
    codes = dict.fromkeys((int(rank) for rank in index2name), 0)

    for level, taxon in enumerate(lineage_str.split(", ")):
        codes[level] = level_encoder[level][taxon]

    return codes

# Tokenizer that belongs to the same pretrained checkpoint as the ESM encoder.
tokenizer_ = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")

def mix_data_to_tensor_batch(b_virues, b_cellular, max_seq_len=max_seq_len, partition=0.25):
    """Combine a virus batch and a cellular batch into one tokenized ``Batch``.

    Args:
        b_virues: Sequence of virus examples; each example is indexed with the keys
            ``'Sequence'`` and ``'Taxonomic_lineage__ALL_'``.
            (assumes a list, since it is sliced and concatenated with ``+``)
        b_cellular: Sequence of cellular examples with the same keys.
        max_seq_len: Length every sequence is padded / truncated to by the tokenizer.
        partition: Fraction of the mixed batch taken from the virus examples. The
            mixed batch has ``len(b_virues)`` examples when enough cellular examples
            are available and is shuffled. The special value ``-1`` concatenates
            both batches in full, without shuffling.

    Returns:
        ``Batch`` built from the tokenizer output and a dict that maps each rank
        index to a ``LongTensor`` of label codes (one entry per example).
    """
    if partition == -1:
        b = b_virues + b_cellular
    else:
        split_point = int(len(b_virues) * partition)
        num_cellular = len(b_virues) - split_point
        # A bare b_cellular[-num_cellular:] returns the WHOLE list when num_cellular
        # is 0 (because -0 == 0), which would double the batch for partition=1.0.
        cellular_part = b_cellular[-num_cellular:] if num_cellular > 0 else []
        b = b_virues[:split_point] + cellular_part
        random.shuffle(b)  # In-place shuffle of the freshly built list

    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_seq_len
    )

    # Transpose the per-example {rank: code} dicts into {rank: [code, code, ...]}.
    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    tensor_encoded = {k: torch.LongTensor(v) for k, v in combined_dict.items()}
    return Batch(inputs, tensor_encoded)

In [5]:
# Folder reserved for validation artefacts of this model.
val_dir = f"val_results/{model_name}"
# exist_ok=True makes this safe to re-run and avoids the exists()/makedirs() race.
os.makedirs(val_dir, exist_ok=True)

# Draw the validation batches once so every call to evaluate() scores the same data.
val_batches = [virus_da.get_batch() for _ in range(num_val)]
cell_val_batches = [cellular_da.get_batch() for _ in range(num_val)]


def evaluate(model):
    """Score ``model`` on the pre-drawn validation batches.

    For each of the ``num_val`` batch pairs the virus and cellular batches are
    concatenated in full (``partition=-1``), and the model predicts the rank-0
    label of every example.

    Args:
        model: Network called as ``model(input_ids, attention_mask)``.

    Returns:
        Tuple ``(avg_loss, accuracy, f1_micro, f1_macro, conf_matrix)`` where
        ``avg_loss`` is the mean of the per-batch losses.
    """
    model.eval()  # switch to evaluation mode

    total_loss = 0.0
    pred_chunks = []
    label_chunks = []

    # No gradients are needed anywhere in the evaluation pass.
    with torch.no_grad():
        for idx in range(num_val):
            tensor_batch = mix_data_to_tensor_batch(
                val_batches[idx], cell_val_batches[idx], max_seq_len, partition=-1
            )
            tensor_batch.gpu(device)

            labels = tensor_batch.taxes[0]
            seq_ids = tensor_batch.seq_ids
            outputs = model(seq_ids['input_ids'], seq_ids['attention_mask'])

            total_loss += criterion(outputs, labels).item()

            # Predicted class = index of the largest logit.
            pred_chunks.append(torch.argmax(outputs, dim=1).cpu())
            label_chunks.append(labels.cpu())

    # Stitch the per-batch results together and hand them to scikit-learn as arrays.
    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()

    accuracy = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average='macro')  # unweighted mean of per-class F1
    f1_micro = f1_score(y_true, y_pred, average='micro')  # F1 over pooled decisions
    conf_matrix = confusion_matrix(y_true, y_pred)
    avg_loss = total_loss / num_val

    return avg_loss, accuracy, f1_micro, f1_macro, conf_matrix

# evaluate(model)

In [8]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    """Restore the newest ``checkpoint_epoch_<N>.pt`` found in ``checkpoint_dir``.

    Args:
        model: Module whose weights are overwritten from the checkpoint.
        optimizer: Optional optimizer to restore (needed to resume training).
        scheduler: Optional LR scheduler to restore.

    Returns:
        Tuple ``(model, optimizer, scheduler, epoch, metrics)`` where ``epoch`` and
        ``metrics`` are the values stored in the checkpoint.

    Raises:
        FileNotFoundError: If no checkpoint file exists in ``checkpoint_dir``.
    """
    checkpoints = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
    if not checkpoints:
        # max() on an empty list would raise an opaque ValueError instead.
        raise FileNotFoundError(
            f"No 'checkpoint_epoch_*.pt' files found in {checkpoint_dir}"
        )

    def _epoch_of(path):
        # 'checkpoint_epoch_123.pt' -> 123
        return int(os.path.basename(path).split('_')[-1].split('.')[0])

    latest_checkpoint = max(checkpoints, key=_epoch_of)

    # map_location lets a checkpoint written on a GPU be opened on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # PyTorch >= 2.6 defaults to weights_only=True, which may reject non-tensor
    # objects stored under 'metrics' - confirm against the installed version.
    checkpoint = torch.load(latest_checkpoint, map_location=device)

    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Make sure every optimizer state tensor lives on the training device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # Training metadata stored next to the weights
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']

    print(f"Successfully loaded checkpoint from epoch {epoch}")
    # print("Metrics at checkpoint:", metrics)

    return model, optimizer, scheduler, epoch, metrics
        

# Uncomment the next line to resume from the newest checkpoint instead of starting fresh.
# model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)
# Fresh run: the training loop starts counting at latest_epoch + 1.
latest_epoch = 0

In [9]:
def get_partition_ratio(epoch, decay_epochs=100000):
    """Return the stepwise-decaying partition ratio for a training iteration.

    The ratio starts at 8/16 and drops by 1/16 every ``decay_epochs // 7``
    iterations until it reaches 1/16, where it stays.

    Args:
        epoch: Current training iteration. Negative values are treated as 0.
        decay_epochs: Number of iterations over which the seven step-downs are
            spread. Values below 7 are handled by stepping down every iteration.

    Returns:
        float in ``{8/16, 7/16, ..., 1/16}``.
    """
    # At least one iteration per step; avoids a ZeroDivisionError for decay_epochs < 7.
    epochs_per_step = max(1, decay_epochs // 7)  # 7 steps from 8/16 down to 1/16

    # Clamp to [0, 7] so the ratio never leaves the documented 1/16 .. 8/16 range.
    step = min(max(epoch // epochs_per_step, 0), 7)

    return (8 - step) / 16

In [None]:
# Training loop. One "epoch" here is a single optimisation step on one mixed batch.
# Every val_epoch iterations the model is validated, checkpointed and logged to wandb.
running_loss = 0
steps_in_window = 0   # optimisation steps accumulated since the last validation
current_lr = lr
training_completed = False

try:
    for epoch in tqdm(range(latest_epoch + 1, latest_epoch + epochs + 1)):
        model.train()

        # Share of virus examples in the mixed batch; decays in steps as training goes on.
        current_partition = get_partition_ratio(epoch)

        tensor_batch = mix_data_to_tensor_batch(
            virus_da.get_batch(),
            cellular_da.get_batch(),
            max_seq_len,
            partition=current_partition
        )
        tensor_batch.gpu(device)

        labels = tensor_batch.taxes[0]
        outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        steps_in_window += 1

        if (epoch + 1) % val_epoch == 0:
            # Divide by the number of steps actually taken: the first window of a run
            # is one step shorter than val_epoch because counting starts at
            # latest_epoch + 1.
            train_loss = running_loss / steps_in_window
            # Evaluate on validation set
            val_loss, val_accuracy, val_f1_micro, val_f1_macro, val_cm = evaluate(model)

            print(f"Epoch [{epoch + 1}/{epochs}]")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"test Loss: {val_loss:.4f}, test Accuracy: {val_accuracy:.4f}")
            print(f"test F1 (micro): {val_f1_micro:.4f}, test F1 (macro): {val_f1_macro:.4f}")
            print("test Confusion Matrix:")
            print(val_cm)

            # Plain Python floats keep the checkpoint free of numpy scalar objects.
            metrics = {
                "train_loss": float(train_loss),
                "test_loss": float(val_loss),
                "test_accuracy": float(val_accuracy),
                "test_f1_micro": float(val_f1_micro),
                "test_f1_macro": float(val_f1_macro),
                "epoch": epoch + 1,
                "current_portion": current_partition,
                "lr": current_lr
            }

            # Save periodic checkpoint
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'metrics': metrics
            }, checkpoint_path)

            # Log to wandb
            wandb.log(metrics)

            # Step the scheduler on the iteration index only. The previous
            # `epoch + loss.item()` mixed the loss value into the schedule's time
            # axis, shifting the cosine position by an arbitrary amount.
            scheduler.step(epoch)
            current_lr = scheduler.get_last_lr()[0]

            # Reset training metrics
            running_loss = 0
            steps_in_window = 0

    training_completed = True
finally:
    # Always close the wandb run, also when the loop is interrupted or fails;
    # a non-zero exit code marks the run as not completed.
    wandb.finish(exit_code=0 if training_completed else 1)

  0%|          | 98/100000 [07:11<122:08:18,  4.40s/it]

Epoch [100/100000]
Train Loss: 0.4323
test Loss: 0.3391, test Accuracy: 0.8531
test F1 (micro): 0.8531, test F1 (macro): 0.8529
test Confusion Matrix:
[[521 119]
 [ 69 571]]


  0%|          | 198/100000 [16:21<121:53:44,  4.40s/it]

Epoch [200/100000]
Train Loss: 0.2958
test Loss: 0.2749, test Accuracy: 0.8891
test F1 (micro): 0.8891, test F1 (macro): 0.8890
test Confusion Matrix:
[[583  57]
 [ 85 555]]


  0%|          | 298/100000 [25:32<121:38:06,  4.39s/it] 

Epoch [300/100000]
Train Loss: 0.2556
test Loss: 0.2286, test Accuracy: 0.9148
test F1 (micro): 0.9148, test F1 (macro): 0.9148
test Confusion Matrix:
[[580  60]
 [ 49 591]]


  0%|          | 398/100000 [34:51<121:40:23,  4.40s/it] 

Epoch [400/100000]
Train Loss: 0.2392
test Loss: 0.2277, test Accuracy: 0.9125
test F1 (micro): 0.9125, test F1 (macro): 0.9125
test Confusion Matrix:
[[571  69]
 [ 43 597]]


  0%|          | 498/100000 [44:12<121:21:46,  4.39s/it] 

Epoch [500/100000]
Train Loss: 0.2390
test Loss: 0.2091, test Accuracy: 0.9203
test F1 (micro): 0.9203, test F1 (macro): 0.9203
test Confusion Matrix:
[[580  60]
 [ 42 598]]


  1%|          | 598/100000 [53:25<121:22:49,  4.40s/it] 

Epoch [600/100000]
Train Loss: 0.2201
test Loss: 0.2139, test Accuracy: 0.9156
test F1 (micro): 0.9156, test F1 (macro): 0.9154
test Confusion Matrix:
[[552  88]
 [ 20 620]]


  1%|          | 698/100000 [1:02:39<121:08:53,  4.39s/it]

Epoch [700/100000]
Train Loss: 0.2148
test Loss: 0.2038, test Accuracy: 0.9219
test F1 (micro): 0.9219, test F1 (macro): 0.9218
test Confusion Matrix:
[[574  66]
 [ 34 606]]


  1%|          | 798/100000 [1:11:53<121:07:25,  4.40s/it] 

Epoch [800/100000]
Train Loss: 0.2218
test Loss: 0.2181, test Accuracy: 0.9164
test F1 (micro): 0.9164, test F1 (macro): 0.9162
test Confusion Matrix:
[[552  88]
 [ 19 621]]


  1%|          | 898/100000 [1:21:03<120:59:48,  4.40s/it] 

Epoch [900/100000]
Train Loss: 0.2104
test Loss: 0.2050, test Accuracy: 0.9187
test F1 (micro): 0.9187, test F1 (macro): 0.9187
test Confusion Matrix:
[[594  46]
 [ 58 582]]


  1%|          | 998/100000 [1:30:17<120:56:13,  4.40s/it] 

Epoch [1000/100000]
Train Loss: 0.2001
test Loss: 0.1969, test Accuracy: 0.9172
test F1 (micro): 0.9172, test F1 (macro): 0.9171
test Confusion Matrix:
[[567  73]
 [ 33 607]]


  1%|          | 1098/100000 [1:39:31<120:47:39,  4.40s/it]

Epoch [1100/100000]
Train Loss: 0.1937
test Loss: 0.1930, test Accuracy: 0.9203
test F1 (micro): 0.9203, test F1 (macro): 0.9203
test Confusion Matrix:
[[585  55]
 [ 47 593]]


  1%|          | 1198/100000 [1:48:43<120:38:22,  4.40s/it] 

Epoch [1200/100000]
Train Loss: 0.1924
test Loss: 0.1905, test Accuracy: 0.9203
test F1 (micro): 0.9203, test F1 (macro): 0.9203
test Confusion Matrix:
[[573  67]
 [ 35 605]]


  1%|▏         | 1298/100000 [1:57:59<120:30:03,  4.40s/it] 

Epoch [1300/100000]
Train Loss: 0.1875
test Loss: 0.1904, test Accuracy: 0.9211
test F1 (micro): 0.9211, test F1 (macro): 0.9210
test Confusion Matrix:
[[574  66]
 [ 35 605]]


  1%|▏         | 1398/100000 [2:07:09<120:20:17,  4.39s/it] 

Epoch [1400/100000]
Train Loss: 0.2080
test Loss: 0.1931, test Accuracy: 0.9211
test F1 (micro): 0.9211, test F1 (macro): 0.9210
test Confusion Matrix:
[[572  68]
 [ 33 607]]


  1%|▏         | 1402/100000 [2:09:21<441:43:59, 16.13s/it] 

In [None]:
# Reload the newest checkpoint into the model; optimizer and scheduler are not restored.
model, _, _, latest_epoch, metrics = load_checkpoint(model)

# Draw num_val // 2 batches from the virus reader, then the same number from the
# cellular reader.
half = num_val // 2
val_batches_ = []
for source in (virus_da, cellular_da):
    val_batches_.extend(source.get_batch() for _ in range(half))

# input_sequences_ = [e['Sequence'] for b in val_batches_ for e in b]
# labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches_ for e in b]

# Hand-written single example currently used in place of the drawn batches above.
input_sequences_ = ["ACACAD"]
labels_ = [{0: 1}]

def evaluate_df(model):
    """Classify ``input_sequences_`` one by one and report the results.

    Each sequence is tokenized on its own, scored against the rank-0 code in the
    matching entry of ``labels_``, and recorded. The per-sequence table is written
    to ``classification_results__new_att.csv`` (misclassified rows first, then by
    descending loss) and the aggregate metrics are printed.

    Args:
        model: Network called as ``model(input_ids, attention_mask)``.
    """
    model.eval()  # switch to evaluation mode

    sequences = []
    true_names = []
    pred_names = []
    losses = []

    for sequence, label in zip(input_sequences_, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)

        # Forward pass without gradient tracking.
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        target = torch.tensor([label[0]]).to(device)
        loss = criterion(output, target)
        pred = torch.argmax(output, dim=-1).item()

        sequences.append(sequence)
        true_names.append(level_decoder[0][label[0]])
        pred_names.append(level_decoder[0][pred])
        losses.append(round(loss.item(), 4))

    # Per-sequence table, hardest mistakes on top.
    new_df = pd.DataFrame({
        "sequence": sequences,
        "label": true_names,
        "pred": pred_names,
        "loss": losses
    })
    new_df['is_incorrect'] = new_df['label'] != new_df['pred']
    new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
    new_df.to_csv('classification_results__new_att.csv', index=False)

    y_true = np.array(true_names)
    y_pred = np.array(pred_names)
    metrics = {
        "loss": np.array(losses).mean(),  # mean of the rounded per-sequence losses
        "accuracy": accuracy_score(y_true, y_pred),
        "f1 macro": f1_score(y_true, y_pred, average='macro'),
        "f1 micro": f1_score(y_true, y_pred, average='micro')
    }
    print(metrics)

# Run the per-sequence evaluation with the checkpointed model; prints the metrics
# and writes the results CSV.
evaluate_df(model)