In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
import pandas as pd
import numpy as np
from tqdm import tqdm
import sys, os, math

# sys.path.insert(0, '../dlp')
# from data_process import *

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

epochs= 10_000
val_epoch = 100
num_val = 25

model_name = "ESM_T5"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

  from .autonotebook import tqdm as notebook_tqdm


cuda:0
../checkpoints/ESM_T5_checkpoints


In [15]:
class embedding_Transformer(nn.Module):
    def __init__(
        self,
        output_dim,
        max_seq_len=1000,
        max_tax_len=150,
        d_model=512
    ):
        super().__init__()
        
        self.encoder_linear = nn.Linear(320, d_model)
        self.decoder_embedding = AutoModelForMaskedLM.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
        
        self.d_model = d_model
        self.transformer = nn.Transformer(
            d_model = d_model,
            nhead=4,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True,
            norm_first=True,
        )

        self.lm_head = nn.Linear(d_model, output_dim, bias=False)

    def forward(self, src, tgt, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        src = self.encoder_linear(src)
        tgt = self.decoder_embedding(tgt)
        
        output = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_padding_mask=src_padding_mask,
            tgt_padding_mask=tgt_padding_mask
        )
        
        return self.lm_head(output)


In [16]:
model = embedding_Transformer(output_dim=len_tokenizer).to(device)
print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0)

Some weights of the model checkpoint at microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


model: 137.93081 M parameters


In [20]:
sys.path.insert(0, '../dlp')
from batch import Batch

# tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext")
len_tokenizer = len(tokenizer.vocab)
print(len_tokenizer)


def encode_lineage_tokenizer(tax_lineage):
    encoded = tokenizer.encode(tax_lineage.split(", "), add_special_tokens=True, padding='max_length', truncation=True, max_length=max_tax_len, is_split_into_words= True)
    # print(encoded)
    return encoded

def embedding_batch(split, i):
    sequences, lineage_str = torch.load(f'../embeddings/esm_embeddings/{split}/{i}.pt')
    tax_ids = [encode_lineage_tokenizer(s) for s in lineage_str]
    
    return Batch(sequences, torch.LongTensor(tax_ids))

30522
[2, 3689, 7373, 4802, 22266, 7034, 9218, 1021, 10296, 20351, 3955, 25192, 10210, 10210, 5178, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[2, 6241, 25139, 6012, 5806, 3026, 2126, 2234, 3317, 3827, 2029, 3026, 3336, 14984, 3026, 5670, 3857, 10268, 14127, 2927, 14652, 10268, 20165, 1918, 20978, 10647, 3026, 10828, 12, 18273, 1956, 2815, 2447, 17, 13750, 18103, 13, 2037, 1928, 11074, 29562, 2828, 2037, 1928, 11074, 16389, 10503, 4735, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

<batch.Batch at 0x7e7827827df0>

In [21]:
model_name = "esm2_t6_8M_UR50D"
_model = None
_tokenizer = None
_device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

def get_esm_model():
    global _model, _tokenizer
    if _model is None:
        _tokenizer = AutoTokenizer.from_pretrained(f"facebook/{model_name}")
        _model = EsmModel.from_pretrained(f"facebook/{model_name}").to(_device)
    return _model, _tokenizer

def esm_embedding_sequence(sequences):
    model, tokenizer = get_esm_model()
    inputs = tokenizer(
        sequences,
        return_tensors="pt", 
        padding='max_length', 
        truncation=True, 
        max_length=max_seq_len
    ).to(_device)
    
    with torch.no_grad():  # Add this to reduce memory usage
        # Forward pass through the model
        outputs = model(**inputs).last_hidden_state
        # Compute mean across the sequence dimension (or any other pooling method)
        output_embeddings = outputs.cpu()  # Move back to CPU
    return output_embeddings

In [22]:
def train_step(epoch):
    # Zero the gradients
    optimizer.zero_grad()
    
    # Get batch and convert to tensor
    tensor_batch = embedding_batch('train', epoch)
    tensor_batch.gpu(device)
    
    src = tensor_batch.seq_ids
    tgt = tensor_batch.taxes
    
    # Create masks
    tgt_mask = model.transformer.generate_square_subsequent_mask(tgt.size(1)).to(device)

    output = model(
        src,
        tgt,
        tgt_mask=tgt_mask,
        tgt_padding_mask=(tgt == 0)
    )

    # Calculate loss
    loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
    
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    return loss.item()

In [23]:
def generate(max_len=50, start_token=2, end_token=3, device='cuda'):    
    test_protein_sequence = ["MKTAYIAKQRQISFVKSHFSRQDIL"]
    src = esm_embedding_sequence(test_protein_sequence).to(device)
    
    # Initialize target sequence with start token
    # Shape: (batch_size, 1)
    tgt = torch.ones(src.size(0), 1).long().to(device) * start_token
    
    with torch.no_grad():
        for _ in range(max_len):
            # Generate next token probabilities
            output = model(src, tgt)
            # Get the next token prediction
            # output shape: (batch_size, seq_len, vocab_size)
            # We only need the last token prediction
            next_token_logits = output[:, -1, :]
            # print(_, next_token_logits)
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            
            # Append the predicted token to target sequence
            tgt = torch.cat([tgt, next_token], dim=1)
            
            # Check if end token is generated
            if (next_token == end_token).all():
                break

    print(tokenizer.decode(tgt[0]))

In [7]:
# You might also want to add beam search for better generation
def generate_with_beam_search(beam_width=5, max_len=100, start_token=101, end_token=102, device='cuda'):
    test_protein_sequence = ["MKTAYIAKQRQISFVKSHFSRQDIL"]
    src = torch.LongTensor([encode_sequence(*test_protein_sequence)]).to(device)

    model.eval()
    batch_size = src.size(0)
    
    # Initialize beams with start tokens
    beams = [(torch.ones(batch_size, 1).long().to(device) * start_token, 0.0)]
    
    with torch.no_grad():
        for _ in range(max_len):
            candidates = []
            
            # Expand each beam
            for sequence, score in beams:
                # tgt_mask = model.transformer.generate_square_subsequent_mask(sequence.size(1)).to(device)
                # output = model(src, sequence, tgt_mask=tgt_mask)
                output = model(src, sequence)
                
                next_token_logits = output[:, -1, :]
                
                # Get top k next tokens for each beam
                probs = nn.Softmax(dim=-1)(next_token_logits)
                top_probs, top_tokens = probs.topk(beam_width)
                
                for prob, token in zip(top_probs[0], top_tokens[0]):
                    new_sequence = torch.cat([sequence, token.unsqueeze(0).unsqueeze(1)], dim=1)
                    new_score = score - torch.log(prob).item()  # Convert to log probability
                    candidates.append((new_sequence, new_score))
            
            # Select top beam_width candidates
            candidates.sort(key=lambda x: x[1])  # Sort by score
            beams = candidates[:beam_width]
            
            # Check if best beam ended
            if beams[0][0][:, -1].item() == end_token:
                break

    print(tokenizer.decode(beams[0][0][0]))
    # Return the best beam
    return beams[0][0]

In [8]:
import wandb

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="ESM_T5",
    config={
    "learning_rate": 0.001,
    "architecture": "Transformer",
    "dataset": "seqtext_1000",
    "epochs": epochs,
    } 
)

model.train()

train_losses = []

for epoch in range(epochs):
    train_loss = train_step(epoch)
    train_losses.append(train_loss)
    
    if (epoch + 1) % val_epoch == 0:
        mean_train_loss = sum(train_losses[-val_epoch:]) / val_epoch
        print(f"Epoch {epoch+1}, Train Loss: {mean_train_loss:.4f}")
        generate()
        # log metrics to wandb
        wandb.log({"train loss": train_loss})
        # wandb.log_artifact(model)

wandb.finish()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33malirezanor[0m ([33malirezanor-310-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 100, Train Loss: 6.3186
[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS]
Epoch 200, Train Loss: 2.3871
[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS]
Epoch 300, Train Loss: 1.3643
[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS

0,1
train loss,█▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train loss,0.00353


In [None]:
generate_with_beam_search()