In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import sys, os, math
import wandb
import json
from transformers import EsmModel, AutoTokenizer


sys.path.insert(0, '../dlp')
from batch import Batch
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs = 100_000
val_epoch = 1000
num_val = 500
batch_size = 16
dataset_name = "corpus_200_500_random"
lr = 0.001
model_name = "Freeze ESM Multiple head"
max_seq_len = 500

from data_access import PQDataAccess
da = PQDataAccess(f"/home/aac/Alireza/datasets/export_pqt_4_taxseq_new/{dataset_name}", 2 * batch_size)

checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

wandb.init(
    # set the wandb project where this run will be logged
    project=model_name,

    # track hyperparameters and run metadata
    config={
        "learning_rate": lr,
        "architecture": "Freeze ESM Multiple Heads",
        "dataset": dataset_name,
        "epochs": epochs,
        "batch_szie": batch_size,
        "max_seq_len": max_seq_len
    }
)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0
../checkpoints/Freeze ESM Multiple head_checkpoints


In [6]:
index2name_file = "../data/index2name.json"

# Check if the file exists
if os.path.exists(index2name_file):
    with open(index2name_file, "rb") as f:
        index2name = json.load(f)

tax_vocab_sizes = {
    k: len(v) for k,v in index2name.items()
}

print(tax_vocab_sizes)
# Print tax_vocab_sizes sorted by value (number of taxa per rank)
sorted_sizes = dict(sorted(tax_vocab_sizes.items(), key=lambda x: x[1], reverse=True))
print("\nTaxonomic ranks sorted by number of taxa:")
for rank, size in sorted_sizes.items():
    print(f"{rank}: {size}")

level_encoder = {
    k: {name: idx + 1 for idx, name in enumerate(v)} for k,v in index2name.items()
}

level_decoder = {
    k: {idx + 1: name for idx, name in enumerate(v)} for k,v in index2name.items()
}

def encode_lineage(lineage_str):
    taxes_str = lineage_str.split(", ")
    # print(taxes_str)

    encoded = {k: [0] for k in rank2name.keys()}
    
    for i, tax_str in enumerate(taxes_str):
        encoded[i][0] = level_encoder[i].get(tax_str, 0)

    return encoded

tokenizer_ = AutoTokenizer.from_pretrained(f"facebook/esm1b_t33_650M_UR50S")

def data_to_tensor_batch(b):
    inputs = tokenizer_(
        [e['Sequence'] for e in b],
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    )

    tax_ids = [encode_lineage(e['Taxonomic_lineage__ALL_']) for e in b]
    combined_dict = {}
    for d in tax_ids:
        for key, value in d.items():
            combined_dict.setdefault(key, []).extend(value)

    tensor_encoded = {k: torch.LongTensor(v) for k,v in combined_dict.items()}
    return Batch(inputs, tensor_encoded)

{'0': 4, '1': 35, '2': 1368, '3': 6024, '4': 4330, '5': 5265, '6': 5453, '7': 15592, '8': 65895, '9': 59786, '10': 54221, '11': 141660, '12': 20431, '13': 13635, '14': 17392, '15': 32172, '16': 35581, '17': 34356, '18': 33115, '19': 43504, '20': 61510, '21': 57903, '22': 31214, '23': 32835, '24': 107520, '25': 82492, '26': 96471, '27': 90666, '28': 85207, '29': 70508, '30': 71902, '31': 26728, '32': 11716, '33': 6444, '34': 2510, '35': 872, '36': 4}

Taxonomic ranks sorted by number of taxa:
11: 141660
24: 107520
26: 96471
27: 90666
28: 85207
25: 82492
30: 71902
29: 70508
8: 65895
20: 61510
9: 59786
21: 57903
10: 54221
19: 43504
16: 35581
17: 34356
18: 33115
23: 32835
15: 32172
22: 31214
31: 26728
12: 20431
14: 17392
7: 15592
13: 13635
32: 11716
33: 6444
3: 6024
6: 5453
5: 5265
4: 4330
34: 2510
2: 1368
35: 872
1: 35
0: 4
36: 4


In [9]:
class ESM2(nn.Module):
    def __init__(self, num_classes_dict):
        super().__init__()
        self.esm = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        
        # Freeze ESM parameters
        for param in self.esm.parameters():
            param.requires_grad = False
        
        # Create separate classification heads for each task
        self.heads = nn.ModuleDict()
        for index_name, num_classes in num_classes_dict.items():
            self.heads[index_name] = nn.Sequential(
                nn.Linear(1280, 512),
                nn.ReLU(),
                nn.Linear(512, num_classes)
            )
    
    def forward(self, x, attention_mask=None, index=None):
        # Get ESM embeddings
        embeddings = self.esm(x, attention_mask=attention_mask).pooler_output
        
        # If specific task requested, return only that output
        if index is not None:
            if index not in self.heads:
                raise ValueError(f"Task {index} not found in model heads")
            return self.heads[index](embeddings)
        
        # Otherwise return all task outputs
        return {index: head(embeddings) for index, head in self.heads.items()}

In [8]:
model = ESM2(tax_vocab_sizes).to(device)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {trainable/ 1e6} M')
print(f'Total parameters: {total/ 1e6} M')
print(model)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
criterion = nn.CrossEntropyLoss()
# Cosine annealing with warm restarts
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Initial restart interval
    T_mult=2,  # Multiply interval by 2 after each restart
    eta_min=1e-6  # Minimum learning rate
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Trainable parameters: 755.969937 M
Total parameters: 1408.323878 M
ESM2(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 1280, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            

In [6]:
val_batches = [da.get_batch() for _ in range(num_val)]

def evaluate(model):
    model.eval()  # Set model to evaluation mode
    
    running_loss = 0.0
    losses = {k: 0.0 for k in tax_vocab_sizes.keys()}
    
    for epoch in range(num_val):
        with torch.no_grad():  # Disable gradient computation during evaluation
            tensor_batch = data_to_tensor_batch(val_batches[epoch])
            tensor_batch.gpu(device)
        
            labels = tensor_batch.taxes
            outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])

            # Calculate individual task losses
            batch_loss = 0
            for k in tax_vocab_sizes.keys():
                task_loss = criterion(outputs[k], labels[k])
                losses[k] += task_loss.item()
                batch_loss += task_loss
            
            running_loss += batch_loss.item()
    
    avg_loss = running_loss / num_val
    avg_index_losses = {k: losses[k]/num_val for k in losses}
    
    return avg_index_losses, avg_loss

In [7]:
import glob
def load_checkpoint(model, optimizer=None, scheduler=None):
    checkpoints = glob.glob(os.path.join("new data Fine Tune ESM uniform sampling_checkpoints", 'checkpoint_epoch_*.pt'))        
    # Extract epoch numbers and find latest
    latest_checkpoint = max(checkpoints, key=lambda x: int(x.split('_')[-1].split('.')[0]))
    checkpoint = torch.load(latest_checkpoint)
    
    # Load model state
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Cosine annealing with warm restarts
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,  # Initial restart interval
        T_mult=2,  # Multiply interval by 2 after each restart
        eta_min=1e-6  # Minimum learning rate
    )
    
    # Load optimizer state if provided (for training)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Move optimizer state to GPU if necessary
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    # Get training metadata
    epoch = checkpoint['epoch']
    metrics = checkpoint['metrics']
    
    print(f"Successfully loaded checkpoint from epoch {epoch}")
    print("Metrics at checkpoint:", metrics)
    
    return model, optimizer, scheduler, epoch, metrics
        

# model, optimizer, scheduler, latest_epoch, metrics = load_checkpoint(model, optimizer, scheduler)

In [None]:
running_loss = 0
current_lr = lr
losses = {k: 0.0 for k in tax_vocab_sizes.keys()}


for epoch in tqdm(range(epochs)):
    model.train()

    tensor_batch = data_to_tensor_batch(da.get_batch())
    tensor_batch.gpu(device)
    
    labels = tensor_batch.taxes
    outputs = model(tensor_batch.seq_ids['input_ids'], tensor_batch.seq_ids['attention_mask'])
    
    batch_loss = 0
    for k in tax_vocab_sizes.keys():
        task_loss = criterion(outputs[k], labels[k])
        losses[k] += task_loss.item()
        batch_loss += task_loss
    
    running_loss += batch_loss.item()
    
    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()
    
    
    if (epoch + 1) % val_epoch == 0:
        train_loss = running_loss / val_epoch
        val_losses, val_loss = evaluate(model)
        print(f"Epoch [{epoch + 1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        
        # Create metrics dictionary for saving
        metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "epoch": epoch + 1,
            "lr": current_lr
        }
        for k in tax_vocab_sizes.keys():
            metrics[f"val_loss_{k}"] = val_losses[k]
            metrics[f"train_loss_{k}"] = losses[k]
            

        # Save periodic checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'metrics': metrics
        }, checkpoint_path)
        
        # Log to wandb
        wandb.log(metrics)

        # Step the scheduler
        scheduler.step(epoch + loss.item())
        current_lr = scheduler.get_last_lr()[0]
        
        # Reset training metrics
        running_loss = 0
wandb.finish()

  1%|          | 999/100000 [08:28<13:55:31,  1.97it/s]

Epoch [1000/100000]
Train Loss: 0.1842, Train Accuracy: 0.9389
Train F1 (micro): 0.9389, Train F1 (macro): 0.6274
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 7618  382    0]
 [   0  520 7404    0]
 [   0   14   62    0]]
Val Loss: 0.1754, Val Accuracy: 0.9381
Val F1 (micro): 0.9381, Val F1 (macro): 0.4357
Val Confusion Matrix:
[[    0     0     0     0]
 [    0   230     8     0]
 [    0   868 14780     0]
 [    0    22    92     0]]


  2%|▏         | 1999/100000 [25:13<13:47:04,  1.97it/s]   

Epoch [2000/100000]
Train Loss: 0.1410, Train Accuracy: 0.9526
Train F1 (micro): 0.9526, Train F1 (macro): 0.6364
Train Confusion Matrix:
[[   0    0    0    0]
 [   0 7719  281    0]
 [   0  408 7523    0]
 [   0   10   59    0]]
