In [1]:
import torch
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import sys, os, math

sys.path.insert(0, '../dlp')
from data_access import PQDataAccess
from data_process import *

pd.set_option('future.no_silent_downcasting', True)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

batch_size = 64
da = PQDataAccess("/home/aac/Alireza/datasets/taxseq/corpus_1000", batch_size)
epochs= 10_000
val_epoch = 100
num_val = 25

model_name = "Local_T5"
checkpoint_dir = f"../checkpoints/{model_name}_checkpoints"

if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
print(checkpoint_dir)

 WORLD_SIZE=1 , LOCAL_WORLD_SIZE=1,RANK =0,LOCAL_RANK = 0 


  from .autonotebook import tqdm as notebook_tqdm


Loaded dictionary.


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda:0
../checkpoints/ESM_T5_checkpoints


In [2]:
import torch.nn as nn
import torch

class new_Transformer(nn.Module):
    def __init__(
        self,
        output_dim,
        max_seq_len=1000,
        max_tax_len=100,
        vocab_size=21,
        d_model=512
    ):
        super().__init__()
        
        # Sequence embedding layers
        self.encoder_embedding = nn.Embedding(vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(output_dim, d_model)
        self.input_pos_encoder = PositionalEncoding(d_model, dropout=0.1, max_len=max_seq_len)
        self.output_pos_encoder = PositionalEncoding(d_model, dropout=0.1, max_len=max_tax_len)
        
        self.d_model = d_model
        self.transformer = nn.Transformer(
            d_model = d_model,
            nhead=4,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=512,
            dropout=0.1,
            batch_first=True,
            norm_first=True,
        )

        self.lm_head = nn.Linear(d_model, output_dim, bias=False)

        self.apply(self._init_weights)

    def forward(self, src, tgt, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # Embedding and positional encoding
        src = self.encoder_embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        tgt = self.decoder_embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        
        src = self.input_pos_encoder(src)
        tgt = self.output_pos_encoder(tgt)

        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        
        output = self.lm_head(output)
        
        return output

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [3]:
# from models.T5Model import new_Transformer
model = new_Transformer(output_dim=len_tokenizer).to(device)
# print(model)
print("model:", sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=0)



model: 42.481152 M parameters


In [4]:
def train_step(model, optimizer, da, device):
    # Zero the gradients
    optimizer.zero_grad()
    
    # Get batch and convert to tensor
    tensor_batch = T5_data_to_tensor_batch(da.get_batch())
    tensor_batch.gpu(device)
    
    src = tensor_batch.seq_ids
    tgt = tensor_batch.taxes
    
    # Create masks
    tgt_mask = model.transformer.generate_square_subsequent_mask(tgt.size(1)).to(device)

    output = model(
        src,
        tgt,
        tgt_mask=tgt_mask,
        src_padding_mask=(src == 0),
        tgt_padding_mask=(tgt == 0)
    )

    # Calculate loss
    loss = criterion(output.view(-1, output.size(-1)), tgt.view(-1))
    
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    optimizer.step()
    return loss.item()

In [11]:
def generate(max_len=50, start_token=101, end_token=102, device='cuda'):    
    test_protein_sequence = ["MKTAYIAKQRQISFVKSHFSRQDIL"]
    src = torch.LongTensor([encode_sequence(*test_protein_sequence)]).to(device)
    
    # Initialize target sequence with start token
    # Shape: (batch_size, 1)
    tgt = torch.ones(src.size(0), 1).long().to(device) * start_token
    
    with torch.no_grad():
        for _ in range(max_len):
            # Generate next token probabilities
            output = model(src, tgt)
            # Get the next token prediction
            # output shape: (batch_size, seq_len, vocab_size)
            # We only need the last token prediction
            next_token_logits = output[:, -1, :]
            # print(_, next_token_logits)
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            
            # Append the predicted token to target sequence
            tgt = torch.cat([tgt, next_token], dim=1)
            
            # Check if end token is generated
            if (next_token == end_token).all():
                break

    print(tokenizer.decode(tgt[0]))

In [12]:
# You might also want to add beam search for better generation
def generate_with_beam_search(beam_width=5, max_len=100, start_token=101, end_token=102, device='cuda'):
    test_protein_sequence = ["MKTAYIAKQRQISFVKSHFSRQDIL"]
    src = torch.LongTensor([encode_sequence(*test_protein_sequence)]).to(device)

    model.eval()
    batch_size = src.size(0)
    
    # Initialize beams with start tokens
    beams = [(torch.ones(batch_size, 1).long().to(device) * start_token, 0.0)]
    
    with torch.no_grad():
        for _ in range(max_len):
            candidates = []
            
            # Expand each beam
            for sequence, score in beams:
                # tgt_mask = model.transformer.generate_square_subsequent_mask(sequence.size(1)).to(device)
                # output = model(src, sequence, tgt_mask=tgt_mask)
                output = model(src, sequence)
                
                next_token_logits = output[:, -1, :]
                
                # Get top k next tokens for each beam
                probs = nn.Softmax(dim=-1)(next_token_logits)
                top_probs, top_tokens = probs.topk(beam_width)
                
                for prob, token in zip(top_probs[0], top_tokens[0]):
                    new_sequence = torch.cat([sequence, token.unsqueeze(0).unsqueeze(1)], dim=1)
                    new_score = score - torch.log(prob).item()  # Convert to log probability
                    candidates.append((new_sequence, new_score))
            
            # Select top beam_width candidates
            candidates.sort(key=lambda x: x[1])  # Sort by score
            beams = candidates[:beam_width]
            
            # Check if best beam ended
            if beams[0][0][:, -1].item() == end_token:
                break

    print(tokenizer.decode(beams[0][0][0]))
    # Return the best beam
    return beams[0][0]

In [14]:
import wandb

# start a new wandb run to track this script
wandb.init(
    # set the wandb project where this run will be logged
    project="Local_T5",
)

model.train()

train_losses = []

for epoch in range(epochs):
    train_loss = train_step(model, optimizer, da, device)
    train_losses.append(train_loss)
    
    if (epoch + 1) % val_epoch == 0:
        mean_train_loss = sum(train_losses[-val_epoch:]) / val_epoch
        print(f"Epoch {epoch+1}, Train Loss: {mean_train_loss:.4f}")
        generate()
        # log metrics to wandb
        wandb.log({"train loss": train_loss})

wandb.finish()

Epoch 100, Train Loss: 6.4264
[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS]
Epoch 200, Train Loss: 2.2150
[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS]
Epoch 300, Train Loss: 0.9074
[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS

0,1
train loss,█▃▄▂▄▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train loss,0.0094


In [15]:
generate_with_beam_search()

[CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS] [CLS]


tensor([[101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101, 101,
         101, 101, 101]], device='cuda:0')