In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
import json
from transformers import EsmModel, AutoTokenizer

sys.path.insert(0, '../dlp')
from batch import Batch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 200
num_val = 20
batch_size = 16
virus_dataset_name = "corpus_1000_Viruses"
cellular_dataset_name = "corpus_1000_cellular"
lr = 0.001
model_name = "Freeze ESM All head"
max_seq_len = 1000

from data_access import PQDataAccess
virus_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{virus_dataset_name}", batch_size)
cellular_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{cellular_dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/Freeze ESM All head_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malireza_noroozi[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
import random
index2name_file = "../data/taxonomy_index.json"

# Check if the file exists
if os.path.exists(index2name_file):
    with open(index2name_file, "rb") as f:
        index2name = json.load(f)

tax_vocab_sizes = {
    int(k): len(v) for k,v in index2name.items()
}

# print(tax_vocab_sizes)
# # Print tax_vocab_sizes sorted by value (number of taxa per rank)
# sorted_sizes = dict(sorted(tax_vocab_sizes.items(), key=lambda x: x[1], reverse=True))
# print("\nTaxonomic ranks sorted by number of taxa:")
# for rank, size in sorted_sizes.items():
#     print(f"{rank}: {size}")

level_encoder = {
    int(k): {name: idx + 1 for idx, name in enumerate(v)} for k,v in index2name.items()
}

level_decoder = {
    int(k): {idx + 1: name for idx, name in enumerate(v)} for k,v in index2name.items()
}

for k, v in level_decoder.items():
    level_decoder[k][0] = "NOT DEFINED"


def encode_lineage(lineage_str):
    taxes_str = lineage_str.split(", ")

    encoded = {int(k): 0 for k in index2name.keys()}
    
    for i, tax_str in enumerate(taxes_str):
        encoded[i] = level_encoder[i][tax_str]

    return encoded

tokenizer_ = AutoTokenizer.from_pretrained(f"facebook/esm1b_t33_650M_UR50S")

def mix_data_to_tensor_batch(b_virues, b_cellular, max_seq_len=max_seq_len, partition=0.25):
    split_point = int(len(b_virues) * partition)
    b = b_virues[:split_point] + b_cellular[-len(b_virues) + split_point:]
    random.shuffle(b)  # In-place shuffle
    
    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    )

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    tensor_encoded = {k: torch.LongTensor(v) for k,v in combined_dict.items()}
    return Batch(inputs, tensor_encoded)

In [3]:
class ESM1b(nn.Module):
    def __init__(self, num_classes_dict):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
        
        # Freeze ESM parameters
        for param in self.esm.parameters():
            param.requires_grad = False
        
        self.heads = nn.ModuleDict()
        self.attentions = nn.ModuleDict()
        
        for index_name, num_classes in num_classes_dict.items():
            self.heads[str(index_name)] = nn.Sequential(
                nn.Linear(1280, 512),
                nn.LayerNorm(512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, num_classes)
            )

            self.attentions[str(index_name)] = nn.Sequential(
                nn.Linear(1280, 256),
                nn.Tanh(),
                nn.Linear(256, 1)
            )
        

    def attention_pooling(self, x, index):
        # x shape: (batch_size, seq_length, embedding_dim)
        attention_weights = self.attentions[index](x)  # (batch_size, seq_length, 1)
        attention_weights = torch.softmax(attention_weights.squeeze(-1), dim=1)  # (batch_size, seq_length)
        attention_weights = attention_weights.unsqueeze(-1)  # (batch_size, seq_length, 1)
        pooled = torch.sum(x * attention_weights, dim=1)  # (batch_size, embedding_dim)
        return pooled

    
    def forward(self, x, attention_mask=None, index=None):
        embeddings = self.esm(x, attention_mask=attention_mask)
        outputs = embeddings.last_hidden_state
        
        return {index: head(self.attention_pooling(outputs, index)) for index, head in self.heads.items()}


In [4]:
model = ESM1b(tax_vocab_sizes).to(device)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
print(model)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 766.336998 M
Total parameters: 1418.693499 M
ESM1b(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inpl

In [5]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

val_dir = f"val_results/{model_name}"
if not os.path.exists(val_dir):
    os.makedirs(val_dir)
    
val_batches = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

input_sequences = [e['Sequence'] for b in val_batches for e in b]
labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches for e in b]

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        i : {
            "sequence": [],
            "label": [],
            "pred": [],
            "loss": []
        } for i in tax_vocab_sizes.keys()
    }

    metrics = {
        i : {
            "loss": 0,
            "accuracy": 0,
            "f1 macro": 0,
            "f1 micro": 0
        } for i in tax_vocab_sizes.keys()
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        for k in tax_vocab_sizes.keys():
            pred = output[str(k)].argmax(dim=-1).cpu().item()
            loss = criterion(output[str(k)], torch.tensor([label[k]]).to(device))
            df[k]["sequence"].append(sequence)
            df[k]["label"].append(level_decoder[k][label[k]])
            df[k]["pred"].append(level_decoder[k][pred])
            df[k]["loss"].append(round(loss.cpu().item(), 4))

    for k in tax_vocab_sizes.keys():
        # Convert to DataFrame
        new_df = pd.DataFrame(df[k])
        new_df['is_incorrect'] = new_df['label'] != new_df['pred']
        new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
        new_df.to_csv(f'val_results/{model_name}/classification_results_{k}.csv', index=False)

        metrics[k]["loss"] = round(np.array(df[k]["loss"]).mean(), 4)
        metrics[k]["accuracy"] = accuracy_score(np.array(df[k]["label"]), np.array(df[k]["pred"]))
        metrics[k]["f1 macro"] = f1_score(np.array(df[k]["label"]), np.array(df[k]["pred"]), average='macro')  # F1-score for multi-label classification
        metrics[k]["f1 micro"] = f1_score(np.array(df[k]["label"]), np.array(df[k]["pred"]), average='micro') 
    
    return metrics

In [6]:
# evaluate(model)

In [7]:
def get_partition_ratio(epoch, decay_epochs=100000):
    """
    Calculate partition ratio that decreases from 8/16 to 1/16 in steps
    """
    # Calculate how many epochs before each step down
    epochs_per_step = decay_epochs // 7  # 7 steps from 8/16 down to 1/16
    
    # Calculate current step based on epoch
    step = min(epoch // epochs_per_step, 7)  # Max 7 steps down from 8
    
    # Map step to fraction
    fraction = (8 - step) / 16
    
    return fraction

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(epochs)):
    model.train()

    tensor_batch = mix_data_to_tensor_batch(virus_da.get_batch(), cellular_da.get_batch(), partition = get_partition_ratio(epoch+1))
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes
    output = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

    batch_loss = 0
    for index in tax_vocab_sizes.keys():
        optimizer.zero_grad()
        loss = criterion(output[str(index)], labels[int(index)])
        loss.backward()
        optimizer.step()
        batch_loss += loss
    
    running_loss += batch_loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        val_metrics = evaluate(model)
        val_losses = {k: v["loss"] for k, v in val_metrics.items()} 
        val_loss = sum([entry['loss'] for entry in val_metrics.values()]) 
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(val_losses)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "epoch": epoch + 1,
            "lr": current_lr,
            "partition": get_partition_ratio(epoch+1)
        }
        for k, v in val_losses.items():
            metrics[f"val loss head {k}"] = v

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + batch_loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0
        
wandb.finish()

  0%|          | 199/100000 [04:20<36:15:46,  1.31s/it]

Epoch [200/100000]
Train Loss: 57.2345
Val Loss: 45.5851
{0: 2.8985, 1: 0.4674, 2: 0.9905, 3: 1.5665, 4: 2.0715, 5: 2.6607, 6: 3.2483, 7: 4.4959, 8: 5.6275, 9: 4.3652, 10: 2.5532, 11: 2.0157, 12: 0.8898, 13: 0.852, 14: 0.8094, 15: 0.808, 16: 0.6794, 17: 0.7341, 18: 0.7761, 19: 0.8093, 20: 0.7883, 21: 0.8302, 22: 0.7332, 23: 0.5961, 24: 0.5347, 25: 0.5451, 26: 0.449, 27: 0.4818, 28: 0.3968, 29: 0.3775, 30: 0.2898, 31: 0.1423, 32: 0.0963, 33: 0.0039, 34: 0.0009, 35: 0.0001, 36: 0.0001}


  0%|          | 399/100000 [09:51<36:01:13,  1.30s/it] 

Epoch [400/100000]
Train Loss: 41.3255
Val Loss: 43.3288
{0: 3.8725, 1: 0.4358, 2: 0.8479, 3: 1.3272, 4: 1.8291, 5: 2.3748, 6: 2.9434, 7: 4.1658, 8: 5.2943, 9: 4.2358, 10: 2.4412, 11: 1.9265, 12: 0.8063, 13: 0.7607, 14: 0.7486, 15: 0.7592, 16: 0.6334, 17: 0.6453, 18: 0.6986, 19: 0.7177, 20: 0.6982, 21: 0.7574, 22: 0.6856, 23: 0.5541, 24: 0.5082, 25: 0.5095, 26: 0.4214, 27: 0.4566, 28: 0.3809, 29: 0.367, 30: 0.2894, 31: 0.1416, 32: 0.092, 33: 0.0021, 34: 0.0006, 35: 0.0001, 36: 0.0}


  1%|          | 599/100000 [15:22<36:03:42,  1.31s/it] 

Epoch [600/100000]
Train Loss: 39.2807
Val Loss: 41.8149
{0: 4.1397, 1: 0.4336, 2: 0.8546, 3: 1.3674, 4: 1.8533, 5: 2.305, 6: 2.8843, 7: 4.029, 8: 5.0816, 9: 4.1286, 10: 2.3739, 11: 1.8479, 12: 0.7536, 13: 0.706, 14: 0.698, 15: 0.7226, 16: 0.5622, 17: 0.5717, 18: 0.6549, 19: 0.6509, 20: 0.6329, 21: 0.6576, 22: 0.5895, 23: 0.4668, 24: 0.4096, 25: 0.4401, 26: 0.3679, 27: 0.406, 28: 0.363, 29: 0.3566, 30: 0.2785, 31: 0.1336, 32: 0.0901, 33: 0.0024, 34: 0.0012, 35: 0.0003, 36: 0.0}


  1%|          | 799/100000 [20:54<35:59:56,  1.31s/it] 

Epoch [800/100000]
Train Loss: 38.6702
Val Loss: 42.4349
{0: 5.4021, 1: 0.4523, 2: 0.8887, 3: 1.3269, 4: 1.7663, 5: 2.2407, 6: 2.7953, 7: 3.9614, 8: 5.0403, 9: 4.1548, 10: 2.333, 11: 1.7782, 12: 0.7186, 13: 0.6844, 14: 0.6652, 15: 0.7184, 16: 0.5487, 17: 0.5524, 18: 0.6123, 19: 0.5984, 20: 0.6291, 21: 0.6727, 22: 0.5986, 23: 0.475, 24: 0.4093, 25: 0.442, 26: 0.3862, 27: 0.4172, 28: 0.3414, 29: 0.3432, 30: 0.2669, 31: 0.1187, 32: 0.0887, 33: 0.0057, 34: 0.0013, 35: 0.0005, 36: 0.0}


  1%|          | 999/100000 [26:17<36:20:08,  1.32s/it] 

Epoch [1000/100000]
Train Loss: 36.6787
Val Loss: 38.2178
{0: 3.1063, 1: 0.4172, 2: 0.8363, 3: 1.2727, 4: 1.7062, 5: 2.1648, 6: 2.7158, 7: 3.8015, 8: 4.8759, 9: 4.0554, 10: 2.2155, 11: 1.7142, 12: 0.6322, 13: 0.6207, 14: 0.6159, 15: 0.6266, 16: 0.4941, 17: 0.4998, 18: 0.6219, 19: 0.5569, 20: 0.5622, 21: 0.6145, 22: 0.5404, 23: 0.415, 24: 0.3667, 25: 0.4036, 26: 0.3058, 27: 0.3707, 28: 0.3083, 29: 0.3137, 30: 0.2573, 31: 0.1174, 32: 0.0918, 33: 0.0003, 34: 0.0002, 35: 0.0, 36: 0.0}


  1%|          | 1199/100000 [31:40<35:50:41,  1.31s/it] 

Epoch [1200/100000]
Train Loss: 35.3446
Val Loss: 37.6746
{0: 3.644, 1: 0.3878, 2: 0.7922, 3: 1.1989, 4: 1.5885, 5: 2.0473, 6: 2.6033, 7: 3.7307, 8: 4.8197, 9: 4.0141, 10: 2.1574, 11: 1.6996, 12: 0.6093, 13: 0.5959, 14: 0.5897, 15: 0.614, 16: 0.4683, 17: 0.4869, 18: 0.5739, 19: 0.5375, 20: 0.5428, 21: 0.5736, 22: 0.5174, 23: 0.3953, 24: 0.348, 25: 0.3807, 26: 0.3025, 27: 0.3569, 28: 0.3011, 29: 0.3188, 30: 0.2686, 31: 0.1168, 32: 0.0899, 33: 0.0026, 34: 0.0006, 35: 0.0, 36: 0.0}


  1%|▏         | 1341/100000 [36:08<35:46:19,  1.31s/it] 