In [1]:
import torch
import pandas as pd
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import numpy as np
import random
from tqdm import tqdm
import sys
import os

sys.path.append(os.path.abspath('../'))

In [2]:
from models.RNA_transformer.model import build_transformer
from models.RNA_transformer.config import get_config, get_weights_file_path, latest_weights_file_path
from models.RNA_transformer.vocabulary import Vocabulary

config = get_config()

d_model = EMBED_DIM = config["d_model"]
VOCAB_SRC_LENGTH = config["vocab_src_length"]
VOCAB_TGT_LENGTH = config["vocab_tgt_length"]
SRC_SEQ_LEN = TGT_SEQ_LEN = config["seq_len"]
CUSTOM_EMB_PERCENTAGE = config["custom_emb_percentage"]
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

vocab = Vocabulary()
tokenizer_src = vocab.get_tokenizer_src()
tokenizer_tgt = tokenizer_src
tokenizer_tgt_inv = vocab.get_inverted_tokenizer_tgt()

cuda


In [3]:
from models.RNA_transformer.dataset import BilingualDataset

# Read data (RNA sequences extracted from RNAcentral database)
# see ./scripts/rna_data_extract.py for extraction details

# path_data = "../data/raw/data_rna_central/rna_sequences_8952.csv"
# path_data = "../data/raw/data_rna_central/rna_sequences_100_7939.csv"
# 100 000 RNA sequences, add general_rnafm_2000_4000.csv otherwise
path_data = "../data/raw/data_rna_central/general_dataframes/general_rnafm_0_2000.csv"

def get_ds(path_data, batch_size=config["batch_size"], 
           RNA_seq_len_max=config["RNA_seq_len_max"],
            tokenizer_src=tokenizer_src,
            tokenizer_tgt=tokenizer_tgt  ):
    # Read dataframe
    # ds_raw = pd.read_csv(path_data)[:1000]
    ds_raw = pd.read_pickle(path_data)
    ds_raw_filtered = ds_raw [ds_raw['sequence'].apply(len) < RNA_seq_len_max].reset_index(drop=True) # [:70000]
    train_ds_raw, val_ds_raw = train_test_split(ds_raw_filtered, test_size=0.1, random_state=RANDOM_SEED)
    train_ds_raw, val_ds_raw = train_ds_raw.reset_index(drop=True), val_ds_raw.reset_index(drop=True)

    train_ds = BilingualDataset(train_ds_raw, seq_len=SRC_SEQ_LEN,
                                tokenizer_src=tokenizer_src, 
                                tokenizer_tgt=tokenizer_tgt,
                                apply_masking=config["apply_masking"])
    val_ds = BilingualDataset(val_ds_raw, seq_len=SRC_SEQ_LEN,
                              tokenizer_src=tokenizer_src, 
                              tokenizer_tgt=tokenizer_tgt,
                              apply_masking=config["apply_masking"])
    
    max_len_src = ds_raw_filtered['sequence'].apply(len).max() #len(ds_raw_filtered)
    max_len_tgt = max_len_src #max(max_len_tgt, len(tgt_ids))
    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')
    print("Nlines (N rna molecules) in train dataset", len(train_ds_raw))
    print("Average rna len in raw dataset", ds_raw['sequence'].apply(len).mean())
    print("Average rna len for train", train_ds_raw['sequence'].apply(len).mean())

    train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    vocab = Vocabulary()
    tokenizer_src = vocab.get_tokenizer_src()
    tokenizer_tgt = tokenizer_src

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt



train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt  = get_ds(path_data)


Max length of source sentence: 204
Max length of target sentence: 204
Nlines (N rna molecules) in train dataset 62983
Average rna len in raw dataset 319.7815807903952
Average rna len for train 86.50246892018481


In [4]:
################## TRAIN #################

In [5]:
# Define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
print("Using device:", device)
if (device == 'cuda'):
    print(f"Device name: {torch.cuda.get_device_name(device.index)}")
    print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
elif (device == 'mps'):
    print(f"Device name: <mps>")
else:
    print("NOTE: If you have a GPU, consider using it for training.")
    print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
    print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
device = torch.device(device)

Using device: cuda
Device name: NVIDIA GeForce RTX 4060 Laptop GPU
Device memory: 7.99560546875 GB


In [6]:
# Create output folder
from pathlib import Path

# Make sure the weights folder exists
Path(f"{config['datasource']}//{config['model_folder']}").mkdir(parents=True, exist_ok=True)
# Make sure the inference folder exists
Path(f"{config['datasource']}//{config['inference_folder']}").mkdir(parents=True, exist_ok=True)


In [7]:
# Define model, optimizer, loss

#import torchmetrics
from torch.utils.tensorboard import SummaryWriter

train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(path_data)
model = build_transformer(src_vocab_size=VOCAB_SRC_LENGTH, tgt_vocab_size=VOCAB_TGT_LENGTH, src_seq_len=SRC_SEQ_LEN, 
                          tgt_seq_len=TGT_SEQ_LEN, d_model=d_model, 
                          custom_emb_percentage=CUSTOM_EMB_PERCENTAGE).to(device)
# Tensorboard
writer = SummaryWriter(config['experiment_name']) # ('runs/simple_test')
# tensorboard --logdir=runs --port=6006

autoencoder_cond = config["autoencoder_vanilla"] or config["autoencoder_bert"]
if config["trainable_classification_weight"] and autoencoder_cond:

    #alpha = nn.Parameter(torch.tensor(5.0))  # Weight for loss_1
    alpha = torch.tensor(5.0, requires_grad=True)  # Initial coefficient, Weight for loss_1
    optimizer = torch.optim.Adam(
        list(model.parameters()) + [alpha],  # Include alpha as parameter to be optimized
        lr=config['lr'],
        eps=1e-9
    )
else:
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    alpha = torch.tensor(5.0, requires_grad=False) #5
    print("--> trainable_classification_weight set to const value of 0.5!")

# loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)
loss_fn_vanilla = nn.CrossEntropyLoss(ignore_index=3, label_smoothing=0.1).to(device)
loss_fn_bert = nn.CrossEntropyLoss(label_smoothing=0.1).to(device)
loss_cl = nn.CrossEntropyLoss(label_smoothing=0.1).to(device) 

Max length of source sentence: 204
Max length of target sentence: 204
Nlines (N rna molecules) in train dataset 62983
Average rna len in raw dataset 319.7815807903952
Average rna len for train 86.50246892018481
--> trainable_classification_weight set to const value of 0.5!


In [8]:
# Run model training
# Parameters are in config file: models/config.py

from models.RNA_transformer.validation import run_validation

initial_epoch = 0
global_step = 0

for epoch in range(initial_epoch, config['num_epochs']):
    torch.cuda.empty_cache()
    model.train()
    loss_epoch = torch.Tensor([0])
    batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
    #loss = torch.Tensor([0], require).to(device)

    for batch in batch_iterator:

        encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
        encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
        decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
        decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

        # Run the tensors through the encoder, decoder and the projection layer
        encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
        decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
        proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)
        proj_output_bert = model.project_bert(encoder_output) # (B, seq_len, vocab_size)

        # VANILLA autoencoder
        # Compute reconstruction loss using a simple cross entropy
        label = batch['label'].to(device) # (B, seq_len)
        loss_vanilla = loss_fn_vanilla(proj_output.view(-1, VOCAB_SRC_LENGTH), label.view(-1))

        # BERT autoencoder
        # Compute reconstruction loss using a simple cross entropy
        label_bert = batch['label_bert'].to(device) # (B, seq_len)
        #print(label_bert)
        loss_bert = loss_fn_bert(proj_output_bert.view(-1, VOCAB_SRC_LENGTH), 
                            label_bert.view(-1))

        # CLASSIFICATION
        # Extract a latent vector (e.g., from the last token or mean pooling)
        cls_token_rep = encoder_output[:, 0, :] # the first token or [:, -1, :]
        proj_output_cl = model.project_cl(cls_token_rep) # logits 
        cl = batch['class'].to(device) # (B, seq_len)
        loss_classification = loss_cl(proj_output_cl, cl)

        # Calculate composite loss
        # The weight 0.5=alpha/10 (const or trainable) 
        # -- is related to classification loss
        loss = loss_vanilla * int(config["autoencoder_vanilla"]) + \
            loss_bert * int(config["autoencoder_bert"]) + \
                loss_classification * int(config["classification"]) * alpha/10
            
        if loss == torch.Tensor([0]).to(device):
            print("define loss in config in a proper way!")
            raise SystemExit
        
        # Calculate loss for whole epoch
        loss_epoch = loss_epoch + loss.item()

        # Print some info during training
        if not config["trainable_classification_weight"]:
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}" })
        else:
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}",
                                        "alpha": f"{alpha.item():6.3f}"})



        # Add to tesorboard to see it later    
        writer.add_scalar('train loss', loss.item()/config["batch_size"], global_step)
        #writer.add_scalar('alpha', alpha.item(), global_step)
        writer.flush()

        # Backpropagate the loss
        loss.backward()

        # Update the weights
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # Clip the alpha/10 coefficient to ensure it's between 0 and 1
        if config["trainable_classification_weight"]:
            with torch.no_grad():
                alpha.data = torch.clamp(alpha.data, 1.0, 9.0)  # Ensuring the value is between 0 and 1

        global_step += 1

    Nlines = len(batch_iterator) * config["batch_size"]
    writer.add_scalar('train loss epoch', loss_epoch.item()/Nlines, epoch)
    if config["autoencoder_vanilla"]:
        writer.add_scalar('loss_vanilla', loss_vanilla.item(), epoch)
    # writer.add_scalar('alpha', alpha.item()/10.0, epoch)
    writer.flush()

    # Run validation at the end of every epoch
    # for num_examples=3 sequences
    run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), 
                   epoch, writer, num_examples=3)


    # Save the model at the end of every epoch
    model_filename = get_weights_file_path(config, f"{epoch:02d}")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(), # weights of the model
        'optimizer_state_dict': optimizer.state_dict(), # where to move
        'global_step': global_step
    }, model_filename)

writer.close()

Processing Epoch 00: 100%|██████████| 985/985 [05:50<00:00,  2.81it/s, loss=1.449]


Character error rate (CER) for 3 sequences: 0.05371900647878647


Processing Epoch 01: 100%|██████████| 985/985 [05:54<00:00,  2.77it/s, loss=0.978]


Character error rate (CER) for 3 sequences: 0.07865168899297714


Processing Epoch 02: 100%|██████████| 985/985 [06:13<00:00,  2.63it/s, loss=0.942]


Character error rate (CER) for 3 sequences: 0.04089219495654106




In [9]:
############ end training ############

In [10]:
# Read the current model

model_filename = "../data/processed/RNA_transformer/weights/tmodel_02.pt"
state = torch.load(model_filename)
model.load_state_dict(state['model_state_dict'])
initial_epoch = state['epoch'] + 1
optimizer.load_state_dict(state['optimizer_state_dict'])
global_step = state['global_step']

In [11]:
# Compare the current model with rnafm using dataframes
# num_examples is max number of RNA sequences to keep from validation dataset

from models.RNA_transformer.comparison import comparison

# where to save dataframes
path_file = f"{config['datasource']}//{config['inference_folder']}"

df_my, df_rnafm = comparison(model, val_dataloader, device, 
                             num_examples=5, path_file=path_file)

Each dataframe contains 5 lines


In [13]:
# Run validation outside training loop, if needed
# num_examples is max number of RNA sequences to keep from validation dataset
# see stdout and tensorboard

from models.RNA_transformer.validation import run_validation

epoch=11
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), 
               epoch, writer=0, verbose=True, num_examples=50)
#writer.close()