In [1]:
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset
from src.tkns.tokenizer import RNASequenceTokenizer

In [2]:
download_data = False
local_dir = "./data"

if download_data:
    dataset = load_dataset("multimolecule/rnacentral.1024", cache_dir=local_dir)

______________________________
## I. RNACentral data inspection and tokenizer

In [3]:
search_unique_nucleotides = False

# Seearch unique nucleotides to build a vocabulary for the tokenizer.
if search_unique_nucleotides:
    chunk_size = 10_000
    unique_nucleotides = set()

    for chunk in pd.read_csv(f"{local_dir}/rna_central.1024.csv", chunksize=chunk_size):
        sequences = chunk['sequence']

        for seq in sequences:
            unique_nucleotides.update(seq)
        
    unique_nucleotides = set([e.upper() for e in list(unique_nucleotides)])
    print("Unique nucleotides:", unique_nucleotides)

    # Add special tokens to the vocabulary and populate the dictionary with unique nucleotides.
    nucleotide2id = {'[PAD]': 0, '[MASK]': 1}

    for i, w in enumerate(unique_nucleotides):
        nucleotide2id[w] = i + len(nucleotide2id)

    # Save vocabulary dict for tokenizer as json
    json.dump(nucleotide2id, open(f"{local_dir}/nucleotide2id.json", "w"))

In [4]:
# Load and test tokenizer
tokenizer = RNASequenceTokenizer()

# Encoding and decoding example
sequence = "AAAFCG" # sequences[0]
encoded = tokenizer.encode(sequence)
decoded = tokenizer.decode(encoded)

print("Encoded:", encoded)
print("Decoded:", decoded)

print("Encoding / decoding: ", sequence == decoded)


Encoded: [12, 12, 12, 2, 24, 30]
Decoded: AAAFCG
Encoding / decoding:  True


______________________________
## II. Init dataset, collate function and dataloader. Inspect inputs, masked inputs and targets.

In [5]:
from functools import partial
from torch.utils.data import DataLoader, Dataset
from src.datasets.masked_lm import MLMDataset, collate_fn

In [17]:
# Configuration
config = {
        "mask_prob": 0.15
    }
sequences = ["ACGTACGCGTAT", "TTGACAAAATTTGCGTA", "CGTACGTA", "ACGTACGT", "TTGACGTA", "CGTACGTA"]

tokenizer = RNASequenceTokenizer()
dataset = MLMDataset(sequences, tokenizer, max_length=10)

In [18]:
# DataLoader setup with partial function for collate_fn
custom_collate_fn = partial(collate_fn,
                            mask_token_id=tokenizer.vocabulary["[MASK]"],
                            mask_prob=config["mask_prob"],
                            pad_token_id=tokenizer.vocabulary["[PAD]"])

dataloader = DataLoader(dataset, batch_size=1, collate_fn=custom_collate_fn)

In [23]:
# Generate a batch and demonstrate masking
for batch_idx, (masked_input_ids, masked_labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}")
    print("Input sequences:")
    
    for seq_idx, (masked_sequence, pristine_sequence) in enumerate(zip(masked_input_ids, dataset)):
        token_ids, _ = pristine_sequence
        original_sequence = ' '.join(map(str, token_ids.tolist()))
        masked_sequence_str = ' '.join(
            f"\033[31m{token_id}\033[0m" if token_id == tokenizer.vocabulary.get("[MASK]", 2) else str(token_id)
            for token_id in masked_sequence.tolist()
        )
        print(f"\tOriginal Sequence {seq_idx + 1}: {original_sequence}")
        print(f"\tMasked Sequence   {seq_idx + 1}: {masked_sequence_str}")
    
    print("\nTarget sequences:")
    for seq_idx, sequence in enumerate(masked_labels):
        print(f"\tSequence {seq_idx + 1}:      {' '.join(map(str, sequence.tolist()))}")
    
    print("\n")
    

Batch 1
Input sequences:
	Original Sequence 1: 12 24 30 16 12 24 30 24 30 16
	Masked Sequence   1: 12 [31m1[0m 30 16 12 24 30 24 30 [31m1[0m

Target sequences:
	Sequence 1:      -100 24 -100 -100 -100 -100 -100 -100 -100 16


Batch 2
Input sequences:
	Original Sequence 1: 12 24 30 16 12 24 30 24 30 16
	Masked Sequence   1: 16 16 30 12 24 12 12 12 12 16

Target sequences:
	Sequence 1:      -100 -100 -100 -100 -100 -100 -100 -100 -100 -100


Batch 3
Input sequences:
	Original Sequence 1: 12 24 30 16 12 24 30 24 30 16
	Masked Sequence   1: 24 30 16 [31m1[0m 24 30 16 12 0 0

Target sequences:
	Sequence 1:      -100 -100 -100 12 -100 -100 -100 -100 -100 -100


Batch 4
Input sequences:
	Original Sequence 1: 12 24 30 16 12 24 30 24 30 16
	Masked Sequence   1: 12 24 [31m1[0m [31m1[0m 12 24 30 16 0 0

Target sequences:
	Sequence 1:      -100 -100 30 16 -100 -100 -100 -100 -100 -100


Batch 5
Input sequences:
	Original Sequence 1: 12 24 30 16 12 24 30 24 30 16
	Masked Sequence   1: 16 

_________________________________________________
## Model initialization and training

In [6]:
# Load the dataset
data = pd.read_csv(f"{local_dir}/rna_central.1024.csv")
data.head(3)

print("Number of sequences: ", len(data))
print("RNA types: ", len(data.type.unique()))

Number of sequences:  32524827
RNA types:  31


In [7]:
sequences = data.sequence.to_list()
rna_types = data.type.to_list()

In [None]:
config = {
    'dim': 256,
    'n_heads': 8,
    'attn_dropout': 0.1,
    'mlp_dropout': 0.1,
    'depth': 6,
    'vocab_size': 8192,
    'max_len': 128,
    'pad_token_id': 1,
    'mask_token_id': 2
}

In [None]:
from src.models.bert import BERT

In [None]:
import torch
from typing import List

In [None]:
model = BERT(config).to('cuda')
print('trainable:', sum([p.numel() for p in model.parameters() if p.requires_grad]) / 1_000_000, 'M')

In [None]:
class MLMRNACentral:
    def __init__(self, sequences: List[str], tokenizer: Tokenizer):
        self.sequences = sequences
        self.tokenizer = tokenizer

    def __len__(self,):
        return len(self.sequences)

    def __getitem__(self,idx):
        seq = self.sequences[idx]
        ids = self.tokenizer.encode(seq)
        labels = ids.copy()
        return ids, labels

In [None]:
def collate_fn(batch):
    input_ids = [torch.tensor(i[0]) for i in batch]
    labels = [torch.tensor(i[1]) for i in batch]

    input_ids = torch.stack(input_ids)
    labels = torch.stack(labels)

    # mask 15% of text leaving [PAD]
    mlm_mask = torch.rand(input_ids.size()) < 0.15 * (input_ids!=1)
    masked_tokens = input_ids * mlm_mask
    labels[masked_tokens==0]=-100 # set all tokens except masked tokens to -100
    input_ids[masked_tokens!=0]=2 # MASK TOKEN
    return input_ids, labels


In [None]:
vocab_size = 10 # including [mask] and [pad]
max_len = 5
num_seq = 5

def gen_sample_data(vocab_size, max_len, num_seq):
    """generate a list of text with variable lengths
    """
    # minus 2 for [0: padding ,1: mask]
    gen_single_sequence = lambda : torch.randint(2, vocab_size-3, size=(torch.randint(1, max_len, size=(1,)),))
    return [gen_single_sequence() for _ in range(num_seq)]

seqs = gen_sample_data(vocab_size, max_len, num_seq)

def batch_data(data):
    """Generate batched_data with padding
    """
    num_samples = len(data)
    full_data = torch.zeros(num_samples, max_len)
    for i, sent in enumerate(data):
        min_length = min(len(sent), max_len)
        full_data[i, :min_length] = sent[:min_length]
    return full_data.long()

batch_data = batch_data(seqs)
batch_data

In [None]:
masking_prob = 0.15 
full_mask = torch.randn(batch_data.shape) < masking_prob
full_mask