In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
import json
from transformers import EsmModel, AutoTokenizer

sys.path.insert(0, '../dlp')
from batch import Batch

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 1_000
num_val = 20
batch_size = 16
virus_dataset_name = "corpus_1000_Viruses"
cellular_dataset_name = "corpus_1000_cellular"
lr = 0.001
model_name = "Freeze ESM All head"
max_seq_len = 1000

from data_access import PQDataAccess
virus_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{virus_dataset_name}", batch_size)
cellular_da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{cellular_dataset_name}", batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": model_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0
 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 
../checkpoints/Freeze ESM All head_checkpoints


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malireza_noroozi[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
import random
index2name_file = "../data/taxonomy_index.json"

# Check if the file exists
if os.path.exists(index2name_file):
    with open(index2name_file, "rb") as f:
        index2name = json.load(f)

tax_vocab_sizes = {
    int(k): len(v) for k,v in index2name.items()
}

print(tax_vocab_sizes)
# Print tax_vocab_sizes sorted by value (number of taxa per rank)
sorted_sizes = dict(sorted(tax_vocab_sizes.items(), key=lambda x: x[1], reverse=True))
print("\nTaxonomic ranks sorted by number of taxa:")
for rank, size in sorted_sizes.items():
    print(f"{rank}: {size}")

level_encoder = {
    int(k): {name: idx + 1 for idx, name in enumerate(v)} for k,v in index2name.items()
}

level_decoder = {
    int(k): {idx + 1: name for idx, name in enumerate(v)} for k,v in index2name.items()
}

for k, v in level_decoder.items():
    level_decoder[k][0] = "NOT DEFINED"


def encode_lineage(lineage_str):
    taxes_str = lineage_str.split(", ")

    encoded = {int(k): 0 for k in index2name.keys()}
    
    for i, tax_str in enumerate(taxes_str):
        encoded[i] = level_encoder[i][tax_str]

    return encoded

tokenizer_ = AutoTokenizer.from_pretrained(f"facebook/esm1b_t33_650M_UR50S")

def mix_data_to_tensor_batch(b_virues, b_cellular, max_seq_len=max_seq_len, partition=0.25):
    split_point = int(len(b_virues) * partition)
    b = b_virues[:split_point] + b_cellular[-len(b_virues) + split_point:]
    random.shuffle(b)  # In-place shuffle
    
    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    )

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).append(value)

    tensor_encoded = {k: torch.LongTensor(v) for k,v in combined_dict.items()}
    return Batch(inputs, tensor_encoded)

{0: 2, 1: 33, 2: 1298, 3: 5633, 4: 3733, 5: 5181, 6: 5444, 7: 15383, 8: 65228, 9: 59006, 10: 53852, 11: 141571, 12: 20369, 13: 13615, 14: 17388, 15: 32162, 16: 35564, 17: 34333, 18: 33100, 19: 43499, 20: 61507, 21: 57902, 22: 31211, 23: 32827, 24: 107510, 25: 82470, 26: 96454, 27: 90631, 28: 85202, 29: 70506, 30: 71899, 31: 26726, 32: 11716, 33: 6444, 34: 2510, 35: 872, 36: 4}

Taxonomic ranks sorted by number of taxa:
11: 141571
24: 107510
26: 96454
27: 90631
28: 85202
25: 82470
30: 71899
29: 70506
8: 65228
20: 61507
9: 59006
21: 57902
10: 53852
19: 43499
16: 35564
17: 34333
18: 33100
23: 32827
15: 32162
22: 31211
31: 26726
12: 20369
14: 17388
7: 15383
13: 13615
32: 11716
33: 6444
3: 5633
6: 5444
5: 5181
4: 3733
34: 2510
2: 1298
35: 872
1: 33
36: 4
0: 2


In [4]:
class ESM1b(nn.Module):
    def __init__(self, num_classes_dict):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S")
        
        # Freeze ESM parameters
        for param in self.esm.parameters():
            param.requires_grad = False
        
        self.heads = nn.ModuleDict()
        for index_name, num_classes in num_classes_dict.items():
            self.heads[str(index_name)] = nn.Sequential(
                nn.Linear(1280, 512),
                nn.ReLU(),
                nn.Linear(512, num_classes)
            )
    
    def get_optimizers(self, base_lr=1e-4):
        optimizers = {}
        for index, head in self.heads.items():
            optimizers[index] = torch.optim.Adam(head.parameters(), lr=base_lr)
        return optimizers
    
    def forward(self, x, attention_mask=None, index=None):
        embeddings = self.esm(x, attention_mask=attention_mask)
        outputs = embeddings.last_hidden_state.mean(dim=1)
        
        if index is not None:
            if index not in self.heads:
                raise ValueError(f"Task {index} not found in model heads")
            return self.heads[index](outputs)
        
        return {index: head(outputs) for index, head in self.heads.items()}

# Usage example
model = ESM1b(num_classes_dict={0: 10, 1: 20})
optimizers = model.get_optimizers()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
model = ESM1b(tax_vocab_sizes).to(device)
optimizers = model.get_optimizers()


trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
print(model)

# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizers['0'],
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm1b_t33_650M_UR50S and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 754.155969 M
Total parameters: 1406.51247 M
ESM1b(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inpla

In [8]:
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score

val_dir = f"val_results/{model_name}"
if not os.path.exists(val_dir):
    os.makedirs(val_dir)
    
val_batches = [virus_da.get_batch() for _ in range(num_val // 2)] + [cellular_da.get_batch() for _ in range(num_val // 2)]

input_sequences = [e['Sequence'] for b in val_batches for e in b]
labels_ = [encode_lineage(e['Taxonomic_lineage__ALL_'])  for b in val_batches for e in b]

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    df = {
        i : {
            "sequence": [],
            "label": [],
            "pred": [],
            "loss": []
        } for i in tax_vocab_sizes.keys()
    }

    metrics = {
        i : {
            "loss": 0,
            "accuracy": 0,
            "f1 macro": 0,
            "f1 micro": 0
        } for i in tax_vocab_sizes.keys()
    }
    
    # Process each sequence
    for sequence, label in zip(input_sequences, labels_):
        inputs = tokenizer_(
            [sequence],
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=max_seq_len
        ).to(device)
    
        # Get model output
        with torch.no_grad():
            output = model(inputs['input_ids'], inputs['attention_mask'])

        for k in tax_vocab_sizes.keys():
            pred = output[str(k)].argmax(dim=-1).cpu().item()
            loss = criterion(output[str(k)], torch.tensor([label[k]]).to(device))
            df[k]["sequence"].append(sequence)
            df[k]["label"].append(level_decoder[k][label[k]])
            df[k]["pred"].append(level_decoder[k][pred])
            df[k]["loss"].append(round(loss.cpu().item(), 4))

    for k in tax_vocab_sizes.keys():
        # Convert to DataFrame
        new_df = pd.DataFrame(df[k])
        new_df['is_incorrect'] = new_df['label'] != new_df['pred']
        new_df = new_df.sort_values(['is_incorrect', 'loss'], ascending=[False, False])
        new_df.to_csv(f'val_results/{model_name}/classification_results_{k}.csv', index=False)

        metrics[k]["loss"] = np.array(df[k]["loss"]).mean()
        metrics[k]["accuracy"] = accuracy_score(np.array(df[k]["label"]), np.array(df[k]["pred"]))
        metrics[k]["f1 macro"] = f1_score(np.array(df[k]["label"]), np.array(df[k]["pred"]), average='macro')  # F1-score for multi-label classification
        metrics[k]["f1 micro"] = f1_score(np.array(df[k]["label"]), np.array(df[k]["pred"]), average='micro') 
    
    return metrics

In [9]:
evaluate(model)

{0: {'loss': 0.5130515625000001,
  'accuracy': 0.5,
  'f1 macro': 0.3333333333333333,
  'f1 micro': 0.5},
 1: {'loss': 3.52185875, 'accuracy': 0.0, 'f1 macro': 0.0, 'f1 micro': 0.0},
 2: {'loss': 7.239747812500001,
  'accuracy': 0.0,
  'f1 macro': 0.0,
  'f1 micro': 0.0},
 3: {'loss': 8.6302459375, 'accuracy': 0.0, 'f1 macro': 0.0, 'f1 micro': 0.0},
 4: {'loss': 8.2487228125, 'accuracy': 0.0, 'f1 macro': 0.0, 'f1 micro': 0.0},
 5: {'loss': 8.5836009375, 'accuracy': 0.0, 'f1 macro': 0.0, 'f1 micro': 0.0},
 6: {'loss': 8.657818125, 'accuracy': 0.0, 'f1 macro': 0.0, 'f1 micro': 0.0},
 7: {'loss': 9.632887812500002,
  'accuracy': 0.0,
  'f1 macro': 0.0,
  'f1 micro': 0.0},
 8: {'loss': 11.102313125, 'accuracy': 0.0, 'f1 macro': 0.0, 'f1 micro': 0.0},
 9: {'loss': 10.909422187500002,
  'accuracy': 0.0,
  'f1 macro': 0.0,
  'f1 micro': 0.0},
 10: {'loss': 10.911146562499999,
  'accuracy': 0.0,
  'f1 macro': 0.0,
  'f1 micro': 0.0},
 11: {'loss': 12.0065621875,
  'accuracy': 0.0,
  'f1 macro'

In [10]:
def get_partition_ratio(epoch, decay_epochs=100000):
    """
    Calculate partition ratio that decreases from 8/16 to 1/16 in steps
    """
    # Calculate how many epochs before each step down
    epochs_per_step = decay_epochs // 7  # 7 steps from 8/16 down to 1/16
    
    # Calculate current step based on epoch
    step = min(epoch // epochs_per_step, 7)  # Max 7 steps down from 8
    
    # Map step to fraction
    fraction = (8 - step) / 16
    
    return fraction

In [None]:
running_loss = 0
current_lr = lr

for epoch in tqdm(range(epochs)):
    model.train()

    tensor_batch = mix_data_to_tensor_batch(virus_da.get_batch(), cellular_da.get_batch(), partition = get_partition_ratio(epoch+1))
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes

    batch_loss = 0
    for index, optimizer in optimizers.items():
        optimizer.zero_grad()
        output = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'], index=index)
        loss = criterion(output, labels[int(index)])
        loss.backward()
        optimizer.step()
        batch_loss += loss
    
    running_loss += batch_loss.item()
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        val_metrics = evaluate(model)
        val_losses = {k: v["loss"] for k, v in val_metrics.items()} 
        val_loss = sum([entry['loss'] for entry in val_metrics.values()]) 
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(val_losses)
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "epoch": epoch + 1,
            "lr": current_lr,
            "partition": get_partition_ratio(epoch+1)
        }
        metrics.update(val_losses)
            
        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizers,
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + batch_loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0
wandb.finish()

  0%|          | 457/100000 [5:00:38<1090:48:51, 39.45s/it]