In [None]:
import json
import os

import h5py
import numpy as np
import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

from src.tkns.tokenizer import RNASequenceTokenizer

______________________________
## I. RNACentral data inspection and tokenization of the data for training.

In [None]:
# Switches and constants for the dataset inspection pass.
process_data = True    # when False, the dataset is neither loaded nor scanned
local_dir = "./data"   # cache directory for `load_dataset`; also reused for output files
seq_length = 512       # selects the `multimolecule/rnacentral.<seq_length>` dataset
chunk_size = 10_000    # number of rows read from the train split per slice


if process_data:
    dataset = load_dataset(f"multimolecule/rnacentral.{seq_length}", cache_dir=local_dir)
    # dataset["train"].to_csv(f"{local_dir}/rnacentral.{seq_length}.csv", index=False)

    train_data = dataset["train"]
    unique_nucleotides = set()

    # Walk the train split slice by slice and collect every distinct character.
    for start_idx in range(0, len(train_data), chunk_size):
        end_idx = min(start_idx + chunk_size, len(train_data))
        sequences = train_data[start_idx:end_idx]['sequence']
        # `set.update` takes several iterables at once: one per sequence string.
        unique_nucleotides.update(*sequences)

    # Fold lower-case symbols into their upper-case form.
    unique_nucleotides = {symbol.upper() for symbol in unique_nucleotides}
    print("Unique nucleotides:", unique_nucleotides)

In [None]:
# Extend the observed symbols with the two special tokens, in place.
unique_nucleotides |= {"[PAD]", "[MASK]"}

# Fix a deterministic ordering so that token ids are stable between runs.
list_nucleotides = sorted(unique_nucleotides)

# Map every token to its position in the sorted list.
vocabulary = dict(zip(list_nucleotides, range(len(list_nucleotides))))

# Persist the mapping next to the cached data.
with open(f"{local_dir}/vocabulary.json", "w") as vocab_file:
    json.dump(vocabulary, vocab_file)

In [None]:
# Tokenize and store sequences in an H5 file for faster training.
tokenize = False

if tokenize:
    tokenizer = RNASequenceTokenizer()

    sequences = dataset["train"]["sequence"]
    rna_types = dataset["train"]["type"]

    print("Number of sequences: ", len(sequences))
    print("RNA types: ", len(rna_types))

    h5_file_path = f"{local_dir}/tokenized_sequences.h5"

    with h5py.File(h5_file_path, 'w') as h5f:

        # Each row has a fixed width of `seq_length`. Cells that are never
        # written hold the dataset's fill value, so it is set to the
        # tokenizer's [PAD] id explicitly instead of relying on HDF5's
        # default of 0 happening to be the padding id.
        tokenized_seqs_ds = h5f.create_dataset(
            'tokenized_sequences',
            (len(sequences), seq_length),
            dtype='int8',
            fillvalue=tokenizer.vocabulary["[PAD]"],
        )
        seq_types_ds = h5f.create_dataset('sequence_types', (len(sequences),), dtype=h5py.string_dtype())

        # Tokenize sequences and store them in the H5 file
        for idx, seq in enumerate(tqdm(sequences, desc="Tokenizing sequences: ")):
            # Clip to the row width so an over-long encoding cannot raise a
            # shape mismatch on assignment.
            tokenized_seq = tokenizer.encode(seq)[:seq_length]
            tokenized_seqs_ds[idx, :len(tokenized_seq)] = tokenized_seq
            seq_types_ds[idx] = rna_types[idx]

    print(f"Tokenized sequences and sequence types have been stored in {h5_file_path}")

In [None]:
# Split the data into train and test sets
split_data = False
split_seed = 42        # seed of the dedicated generator used for the shuffle
train_fraction = 0.9   # share of the sequences that goes to the train file

if split_data:
    # Same location the tokenization cell writes to.
    h5_file_path = f"{local_dir}/tokenized_sequences.h5"

    # Load the original H5 file fully into memory
    with h5py.File(h5_file_path, 'r') as h5f:
        tokenized_sequences = h5f['tokenized_sequences'][:]
        sequence_types = h5f['sequence_types'][:]

    # Determine the split index
    num_sequences = len(tokenized_sequences)
    split_index = int(train_fraction * num_sequences)

    # Shuffle the indices with a seeded generator so that re-running the cell
    # produces the same train/test partition.
    indices = np.random.default_rng(split_seed).permutation(num_sequences)

    # Split the indices into train and test sets
    train_indices = indices[:split_index]
    test_indices = indices[split_index:]

    # Create train and test H5 files
    train_h5_file_path = f"{local_dir}/tokenized_sequences_train.h5"
    test_h5_file_path = f"{local_dir}/tokenized_sequences_test.h5"

    with h5py.File(train_h5_file_path, 'w') as train_h5f, h5py.File(test_h5_file_path, 'w') as test_h5f:
        # Create datasets for train and test sets
        train_h5f.create_dataset('tokenized_sequences', data=tokenized_sequences[train_indices])
        train_h5f.create_dataset('sequence_types', data=sequence_types[train_indices])

        test_h5f.create_dataset('tokenized_sequences', data=tokenized_sequences[test_indices])
        test_h5f.create_dataset('sequence_types', data=sequence_types[test_indices])

    print(f"Train and test H5 files have been created: {train_h5_file_path}, {test_h5_file_path}")

In [None]:
# Round-trip a short example through the tokenizer.
tokenizer = RNASequenceTokenizer()

sequence = "AAAFCG" # sequences[0]
encoded = tokenizer.encode(sequence)
decoded = tokenizer.decode(encoded)

# Show both directions of the conversion.
for label, value in (("Encoded:", encoded), ("Decoded:", decoded)):
    print(label, value)

# True when decoding the encoded ids reproduces the input string.
round_trip_ok = sequence == decoded
print("Encoding / decoding: ", round_trip_ok)

______________________________
## II. Init dataset, collate function and dataloader. Inspect inputs, masked inputs and targets.

In [None]:
from functools import partial
from torch.utils.data import DataLoader
from src.datasets.rna_central import RNACentral
from src.datasets.masked_lm import collate_fn_mlm

In [None]:
# Configuration for the masking demo.
# NOTE(review): the "n_tokens" entry is not read below; the collate function
# receives len(tokenizer.vocabulary) instead — confirm which one is intended.
config = {"mask_prob": 0.30, "no_mask_tokens": [], "randomize_prob": 0.1, "no_change_prob": 0.1, "max_length": 10, "batch_size": 1, "n_tokens": 15}

sequences = ["ACGTACGCGTATATTTGGGA", "TTAAACCCGGTAACAAAATTTGCGTA", "CGTACGTA", "ACGTACGT", "TTGACGTA", "CGTACGTA"]

tokenizer = RNASequenceTokenizer()

# Read max_length / batch_size from `config` so the dict is the single place
# to tune the demo (both were previously repeated as literals).
dataset = RNACentral(lines=sequences, tokenizer=tokenizer, max_length=config["max_length"])

# Bind the masking hyper-parameters so the DataLoader can call the collate
# function with the batch as its only argument.
custom_collate_fn = partial(collate_fn_mlm,
                            pad_token_id=tokenizer.vocabulary["[PAD]"],
                            mask_token_id=tokenizer.vocabulary["[MASK]"],
                            mask_prob=config["mask_prob"],
                            no_mask_tokens=config["no_mask_tokens"],
                            n_tokens=len(tokenizer.vocabulary),
                            randomize_prob=config["randomize_prob"],
                            no_change_prob=config["no_change_prob"],)

dataloader = DataLoader(dataset, batch_size=config["batch_size"], collate_fn=custom_collate_fn)

In [None]:
# Generate batches and demonstrate masking
mask_token_id = tokenizer.vocabulary.get("[MASK]", 2)
demo_batch_size = dataloader.batch_size

for batch_idx, (masked_input_ids, masked_labels) in enumerate(dataloader):
    print(f"Batch {batch_idx + 1}")
    print("Input sequences:")

    for seq_idx, masked_sequence in enumerate(masked_input_ids):
        # The loader is not shuffled, so row `seq_idx` of batch `batch_idx`
        # comes from this dataset position. (Zipping the batch with `dataset`
        # restarted at item 0 for every batch and showed the wrong original.)
        token_ids = dataset[batch_idx * demo_batch_size + seq_idx]

        original_sequence = ' '.join(map(str, token_ids.tolist()))
        # Print [MASK] ids in red (ANSI escape codes) so they stand out.
        masked_sequence_str = ' '.join(
            f"\033[31m{token_id}\033[0m" if token_id == mask_token_id else str(token_id)
            for token_id in masked_sequence.tolist()
        )
        print(f"\tOriginal Sequence {seq_idx + 1}: {original_sequence}")
        print(f"\tMasked Sequence   {seq_idx + 1}: {masked_sequence_str}")

    print("\nTarget sequences:")
    for seq_idx, label_row in enumerate(masked_labels):
        print(f"\tSequence {seq_idx + 1}:      {' '.join(map(str, label_row.tolist()))}")

    print("\n")

_________________________________________________
## III. Model initialization, training & testing

In [None]:

import random
import pandas as pd
from tqdm import tqdm
from typing import List
from functools import partial
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score

from src.tkns.tokenizer import RNASequenceTokenizer
from src.models.bert.bert import BERT
from src.models.bert.cgf import BERTConfig, TrainingConfig
from src.datasets.masked_lm import collate_fn_mlm
from src.datasets.rna_central import RNACentral

In [None]:
class Trainer:
    """Runs training and validation epochs for a model that returns a dict with "logits".

    Holds the model, data loaders, loss, optimizer, LR scheduler, config and
    scalar writer, plus a `global_step` counter that is incremented once per
    training batch.
    """

    def __init__(self, 
                 model: nn.Module, 
                 train_dataloader: DataLoader, 
                 val_dataloader: DataLoader,
                 tokenizer,
                 criterion: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 scheduler: torch.optim.lr_scheduler._LRScheduler,
                 config: "TrainingConfig",
                 writer):
        """Store the training components.

        Args:
            model: Called as `model(input_ids)`; the result is indexed with `["logits"]`.
            train_dataloader: Iterable of `(input_ids, labels)` training batches.
            val_dataloader: Iterable of `(input_ids, labels)` validation batches.
            tokenizer: Kept on the instance; not used by the methods of this class.
            criterion: Loss, called as `criterion(logits.transpose(1, 2), labels)`.
            optimizer: Stepped after every training batch.
            scheduler: Stepped after every training batch.
            config: Supplies `device`, `gradient_clip`, `log_steps`, `save_steps` and `n_epochs`.
            writer: Object with an `add_scalar(tag, value, step)` method.
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.tokenizer = tokenizer
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.writer = writer
        self.global_step = 0

    def train_epoch(self, epoch: int):
        """Train for one pass over `train_dataloader`.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.

        Returns:
            Mean of the per-batch losses over the epoch.
        """
        self.model.train()
        total_loss = 0.0

        for i, (input_ids, labels) in enumerate(tqdm(self.train_dataloader, desc=f"Epoch {epoch + 1} Training")):
            input_ids, labels = input_ids.to(self.config.device), labels.to(self.config.device)

            # Forward pass
            logits = self.model(input_ids)["logits"]

            # Compute loss. Dims 1 and 2 are swapped before the loss call;
            # presumably logits are (batch, seq, vocab) and the criterion
            # wants the class dimension second — TODO confirm.
            loss = self.criterion(logits.transpose(1, 2), labels)
            total_loss += loss.item()

            # Backward pass
            loss.backward()

            # Clip gradients to avoid exploding gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)

            # Optimizer step, then clear gradients for the next batch
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Scheduler step (once per batch)
            self.scheduler.step()

            # Log the running mean loss of this epoch periodically
            if self.global_step % self.config.log_steps == 0:
                avg_loss = total_loss / (i + 1)
                self.writer.add_scalar("Loss/Train", avg_loss, self.global_step)
                print(f"Step {self.global_step} | Training Loss: {avg_loss:.4f}")

            # Save the model periodically (this condition is also true at step 0)
            if self.global_step % self.config.save_steps == 0:
                model_path = f"model_step_{self.global_step}.pt"
                torch.save(self.model.state_dict(), model_path)
                print(f"Model saved at step {self.global_step}: {model_path}")

            self.global_step += 1

        avg_loss = total_loss / len(self.train_dataloader)
        return avg_loss

    def validate_epoch(self, epoch: int):
        """Evaluate on `val_dataloader` with gradients disabled.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.

        Returns:
            Mean of the per-batch losses over the validation loader.
        """
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            # `enumerate` provides the batch index `i`; without it the
            # `(input_ids, labels)` batch itself was unpacked into
            # `i, (input_ids, labels)` and the loop failed.
            for i, (input_ids, labels) in enumerate(tqdm(self.val_dataloader, desc=f"Epoch {epoch + 1} Validation")):
                input_ids, labels = input_ids.to(self.config.device), labels.to(self.config.device)

                # Forward pass
                logits = self.model(input_ids)["logits"]

                # Compute loss
                loss = self.criterion(logits.transpose(1, 2), labels)
                total_loss += loss.item()

                # Log the running mean periodically; the x-axis value is the
                # training `global_step`, which does not change during validation.
                if i % self.config.log_steps == 0:
                    avg_loss = total_loss / (i + 1)
                    self.writer.add_scalar("Loss/Validation", avg_loss, self.global_step)
                    print(f"Step {self.global_step} | Validation Loss: {avg_loss:.4f}")

        avg_loss = total_loss / len(self.val_dataloader)
        return avg_loss

    def train(self):
        """Run `config.n_epochs` epochs of training followed by validation.

        After each epoch the model weights are written to "best_model.pt"
        whenever the validation loss improves on the best value seen so far.
        """
        best_val_loss = float("inf")

        for epoch in range(self.config.n_epochs):
            # Train and validate
            train_loss = self.train_epoch(epoch)
            val_loss = self.validate_epoch(epoch)

            print(f"Epoch {epoch + 1} | Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f}")

            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), "best_model.pt")
                print("Best model saved.")


In [None]:
# Set the seed for reproducibility: Python's `random`, NumPy and PyTorch
# each keep their own generator, so all three are seeded.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
tokenizer = RNASequenceTokenizer()



# Dataclass for training configurations.
# NOTE: this definition replaces the `TrainingConfig` name imported from
# `src.models.bert.cgf` in the import cell above.
@dataclass
class TrainingConfig:
    """Hyper-parameters for a training run.

    The first five fields are required; the rest have defaults. The masking
    fields are the values the notebook passes on to the collate function.
    """
    batch_size: int
    lr: float
    n_epochs: int
    max_seq_length: int
    device: str
    gradient_clip: float = 1.0   # max gradient norm used for clipping
    log_steps: int = 500         # log every N steps
    save_steps: int = 100000     # checkpoint every N steps
    pad_token_id: int = 0
    mask_token_id: int = 1
    mask_prob: float = 0.15
    # `None` is only a sentinel for "not given": a literal `[]` default would be
    # shared between instances, so `__post_init__` swaps in a fresh list.
    no_mask_tokens: List[int] = None
    n_tokens: int = 0
    randomize_prob: float = 0.1
    no_change_prob: float = 0.1

    def __post_init__(self):
        """Replace a missing `no_mask_tokens` with a new empty list."""
        if self.no_mask_tokens is None:
            self.no_mask_tokens = []

In [None]:
# Size of the model's vocabulary; checked against the tokenizer below.
model_vocab_size = 21

# Every id the tokenizer (or random token replacement with
# n_tokens=len(vocabulary)) can produce must be a valid model index.
assert len(tokenizer.vocabulary) <= model_vocab_size, (
    f"tokenizer has {len(tokenizer.vocabulary)} tokens but the model vocab_size is {model_vocab_size}"
)

# Special-token ids are read from the tokenizer in both configs so the model
# and the data pipeline cannot disagree on them.
bert_config = BERTConfig(dim=256, n_heads=8, attn_dropout=0.1, mlp_dropout=0.1, depth=6, 
                         vocab_size=model_vocab_size, max_len=512,
                         pad_token_id=tokenizer.vocabulary["[PAD]"],
                         mask_token_id=tokenizer.vocabulary["[MASK]"])

train_config = TrainingConfig(
    batch_size=32, 
    lr=1e-4, 
    n_epochs=10, 
    max_seq_length=512, 
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device="cuda" if torch.cuda.is_available() else "cpu", 
    log_steps=500, 
    save_steps=100000,
    pad_token_id=tokenizer.vocabulary["[PAD]"],
    mask_token_id=tokenizer.vocabulary["[MASK]"],
    mask_prob=0.15,
    no_mask_tokens=[],
    n_tokens=len(tokenizer.vocabulary),
    randomize_prob=0.1,
    no_change_prob=0.1
)

In [None]:
writer = SummaryWriter(log_dir="./logs/mlm_training")

# Initialize tokenizer
tokenizer = RNASequenceTokenizer()

# Create datasets from the files written by the split cell. The paths are
# built from `local_dir` (as in that cell) rather than a machine-specific
# absolute path.
dataset_train = RNACentral(h5_file_path=f"{local_dir}/tokenized_sequences_train.h5", tokenizer=tokenizer, max_length=train_config.max_seq_length)
dataset_test = RNACentral(h5_file_path=f"{local_dir}/tokenized_sequences_test.h5", tokenizer=tokenizer, max_length=train_config.max_seq_length)

# Create custom collate function: bind the masking hyper-parameters so the
# DataLoader can call it with the batch as its only argument.
custom_collate_fn = partial(collate_fn_mlm,
                            pad_token_id=tokenizer.vocabulary["[PAD]"],
                            mask_token_id=tokenizer.vocabulary["[MASK]"],
                            mask_prob=train_config.mask_prob,
                            no_mask_tokens=train_config.no_mask_tokens,
                            n_tokens=train_config.n_tokens,
                            randomize_prob=train_config.randomize_prob,
                            no_change_prob=train_config.no_change_prob)

# Create data loaders
dataloader_train = DataLoader(dataset_train, batch_size=train_config.batch_size, collate_fn=custom_collate_fn)
dataloader_test = DataLoader(dataset_test, batch_size=train_config.batch_size, collate_fn=custom_collate_fn)

In [None]:
# Build the network and move it to the configured device.
model = BERT(bert_config)
model = model.to(train_config.device)

# Adam plus a one-cycle schedule spanning every batch of every epoch.
optimizer = Adam(model.parameters(), lr=train_config.lr)
steps_per_epoch = len(dataloader_train)
scheduler = OneCycleLR(
    optimizer,
    max_lr=train_config.lr,
    steps_per_epoch=steps_per_epoch,
    epochs=train_config.n_epochs,
)

# Cross-entropy that skips targets equal to the [PAD] id.
pad_id = tokenizer.vocabulary["[PAD]"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

# Collect the pieces and hand them to the Trainer.
trainer_kwargs = {
    "model": model,
    "train_dataloader": dataloader_train,
    "val_dataloader": dataloader_test,
    "tokenizer": tokenizer,
    "criterion": criterion,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "config": train_config,
    "writer": writer,
}
trainer = Trainer(**trainer_kwargs)

In [None]:

# Train the model: runs `train_config.n_epochs` epochs of training followed by
# validation; `Trainer.train` writes "best_model.pt" whenever the validation
# loss improves.
trainer.train()

In [None]:
class Trainer:
    """Runs training and validation epochs for a model that returns logits directly.

    NOTE: this cell redefines `Trainer` and replaces the class of the same
    name defined earlier in the notebook. The difference is that this version
    uses the output of `model(input_ids)` as the logits, whereas the earlier
    one indexes it with `["logits"]`.
    """

    def __init__(self, 
                 model: nn.Module, 
                 train_dataloader: DataLoader, 
                 val_dataloader: DataLoader,
                 tokenizer,
                 criterion: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 scheduler: torch.optim.lr_scheduler._LRScheduler,
                 config: "TrainingConfig",
                 writer):
        """Store the training components.

        Args:
            model: Called as `model(input_ids)`; the result is used as the logits.
            train_dataloader: Iterable of `(input_ids, labels)` training batches.
            val_dataloader: Iterable of `(input_ids, labels)` validation batches.
            tokenizer: Kept on the instance; not used by the methods of this class.
            criterion: Loss, called as `criterion(logits.transpose(1, 2), labels)`.
            optimizer: Stepped after every training batch.
            scheduler: Stepped after every training batch.
            config: Supplies `device`, `gradient_clip`, `log_steps`, `save_steps` and `n_epochs`.
            writer: Object with an `add_scalar(tag, value, step)` method.
        """
        self.model = model
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.tokenizer = tokenizer
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.config = config
        self.writer = writer
        self.global_step = 0

    def train_epoch(self, epoch: int):
        """Train for one pass over `train_dataloader`.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.

        Returns:
            Mean of the per-batch losses over the epoch.
        """
        self.model.train()
        total_loss = 0.0

        for i, (input_ids, labels) in enumerate(tqdm(self.train_dataloader, desc=f"Epoch {epoch + 1} Training")):
            input_ids, labels = input_ids.to(self.config.device), labels.to(self.config.device)

            # Forward pass
            logits = self.model(input_ids)

            # Compute loss. Dims 1 and 2 are swapped before the loss call;
            # presumably logits are (batch, seq, vocab) and the criterion
            # wants the class dimension second — TODO confirm.
            loss = self.criterion(logits.transpose(1, 2), labels)
            total_loss += loss.item()

            # Backward pass
            loss.backward()

            # Clip gradients to avoid exploding gradients
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip)

            # Optimizer step, then clear gradients for the next batch
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Scheduler step (once per batch)
            self.scheduler.step()

            # Log the running mean loss of this epoch periodically
            if self.global_step % self.config.log_steps == 0:
                avg_loss = total_loss / (i + 1)
                self.writer.add_scalar("Loss/Train", avg_loss, self.global_step)
                print(f"Step {self.global_step} | Training Loss: {avg_loss:.4f}")

            # Save the model periodically (this condition is also true at step 0)
            if self.global_step % self.config.save_steps == 0:
                model_path = f"model_step_{self.global_step}.pt"
                torch.save(self.model.state_dict(), model_path)
                print(f"Model saved at step {self.global_step}: {model_path}")

            self.global_step += 1

        avg_loss = total_loss / len(self.train_dataloader)
        return avg_loss

    def validate_epoch(self, epoch: int):
        """Evaluate on `val_dataloader` with gradients disabled.

        Args:
            epoch: Zero-based epoch index, used only for the progress-bar label.

        Returns:
            Mean of the per-batch losses over the validation loader.
        """
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            # `enumerate` provides the batch index `i`; without it the
            # `(input_ids, labels)` batch itself was unpacked into
            # `i, (input_ids, labels)` and the loop failed.
            for i, (input_ids, labels) in enumerate(tqdm(self.val_dataloader, desc=f"Epoch {epoch + 1} Validation")):
                input_ids, labels = input_ids.to(self.config.device), labels.to(self.config.device)

                # Forward pass
                logits = self.model(input_ids)

                # Compute loss
                loss = self.criterion(logits.transpose(1, 2), labels)
                total_loss += loss.item()

                # Log the running mean periodically; the x-axis value is the
                # training `global_step`, which does not change during validation.
                if i % self.config.log_steps == 0:
                    avg_loss = total_loss / (i + 1)
                    self.writer.add_scalar("Loss/Validation", avg_loss, self.global_step)
                    print(f"Step {self.global_step} | Validation Loss: {avg_loss:.4f}")

        avg_loss = total_loss / len(self.val_dataloader)
        return avg_loss

    def train(self):
        """Run `config.n_epochs` epochs of training followed by validation.

        After each epoch the model weights are written to "best_model.pt"
        whenever the validation loss improves on the best value seen so far.
        """
        best_val_loss = float("inf")

        for epoch in range(self.config.n_epochs):
            # Train and validate
            train_loss = self.train_epoch(epoch)
            val_loss = self.validate_epoch(epoch)

            print(f"Epoch {epoch + 1} | Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f}")

            # Save the best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), "best_model.pt")
                print("Best model saved.")

In [None]:
tokenizer = RNASequenceTokenizer()

# Create datasets and data loaders. Paths are built from `local_dir`, where
# the split cell writes its files, rather than a machine-specific absolute path.
dataset_train = RNACentral(h5_file_path=f"{local_dir}/tokenized_sequences_train.h5", tokenizer=tokenizer, max_length=train_config.max_seq_length)
dataset_test  = RNACentral(h5_file_path=f"{local_dir}/tokenized_sequences_test.h5",  tokenizer=tokenizer, max_length=train_config.max_seq_length)

# Masking hyper-parameters come from `train_config`, matching the earlier
# loader cell (the values are the same as the literals used here before).
custom_collate_fn = partial(collate_fn_mlm,
                            pad_token_id=tokenizer.vocabulary["[PAD]"],
                            mask_token_id=tokenizer.vocabulary["[MASK]"],
                            mask_prob=train_config.mask_prob,
                            no_mask_tokens=train_config.no_mask_tokens,
                            n_tokens=len(tokenizer.vocabulary),
                            randomize_prob=train_config.randomize_prob,
                            no_change_prob=train_config.no_change_prob)

# Batch size is read from `train_config`; `config_train` is only defined in a
# later cell, so referencing it here failed on a fresh top-to-bottom run.
dataloader_train = DataLoader(dataset_train, batch_size=train_config.batch_size, collate_fn=custom_collate_fn)
dataloader_test = DataLoader(dataset_test, batch_size=train_config.batch_size, collate_fn=custom_collate_fn)

In [None]:
# Scalar logger for this run.
writer = SummaryWriter(log_dir="./logs/mlm_training")

# Build the network and move it to the configured device.
model = BERT(bert_config)
model = model.to(train_config.device)

# Adam plus a one-cycle schedule spanning every batch of every epoch.
optimizer = Adam(model.parameters(), lr=train_config.lr)
steps_per_epoch = len(dataloader_train)
scheduler = OneCycleLR(
    optimizer,
    max_lr=train_config.lr,
    steps_per_epoch=steps_per_epoch,
    epochs=train_config.n_epochs,
)

# Cross-entropy that skips targets equal to the [PAD] id.
pad_id = tokenizer.vocabulary["[PAD]"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)

# Collect the pieces and hand them to the Trainer.
trainer_kwargs = {
    "model": model,
    "train_dataloader": dataloader_train,
    "val_dataloader": dataloader_test,
    "tokenizer": tokenizer,
    "criterion": criterion,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "config": train_config,
    "writer": writer,
}
trainer = Trainer(**trainer_kwargs)

In [None]:



# Settings for the CSV-based training cells below.
config_train = dict(
    seed=11,
    train_test_split=0.8,   # fraction of sequences used for training
    max_seq_length=1000,
    batch_size=32,
    n_epochs=5,
    lr=6e-4,
    warmup_steps=10000,
    log_interval=100,
    mask_prob=0.15,
)

In [None]:
# Load the dataset.
# NOTE: in this notebook the CSV is only produced by the commented-out
# `to_csv` line in the data-processing cell, so it must exist already.
data = pd.read_csv(f"{local_dir}/rnacentral.{seq_length}.csv")

print("Number of sequences: ", len(data))
print("RNA types: ", len(data.type.unique()))

# A notebook only renders the last expression of a cell, so the preview goes
# at the end (in the middle of the cell it was evaluated and discarded).
data.head(3)

In [None]:
# Pull the two columns of interest out of the frame as plain Python lists.
sequences = data["sequence"].tolist()

# Optional: not used in the current LM implementation
rna_types = data["type"].tolist()

In [None]:
# Print the length of (at most) the first 100 sequences. Slicing instead of
# indexing with range(100) avoids an IndexError on shorter lists.
for seq in sequences[:100]:
    print(len(seq))

In [None]:
# Randomly shuffle and split into train and test sets.
# A dedicated generator seeded from `config_train["seed"]` makes the split
# reproducible, and sampling a permutation leaves `sequences` itself untouched,
# so re-running this cell always gives the same result.
shuffled_sequences = random.Random(config_train["seed"]).sample(sequences, k=len(sequences))
split_index = int(config_train["train_test_split"] * len(shuffled_sequences))
train_sequences = shuffled_sequences[:split_index]
test_sequences = shuffled_sequences[split_index:]

In [None]:
# NOTE(review): `MLMDataset` and `collate_fn` are not imported or defined in
# any visible cell of this notebook (the import cells bring in `RNACentral`
# and `collate_fn_mlm`), so this cell raises NameError on a fresh kernel
# unless they are provided elsewhere — confirm and update or remove.
tokenizer = RNASequenceTokenizer()

# Create datasets and data loaders
dataset_train = MLMDataset(train_sequences, tokenizer, max_length=config_train["max_seq_length"])
dataset_test  = MLMDataset(test_sequences,  tokenizer, max_length=config_train["max_seq_length"])

# Bind the masking settings so the DataLoader can call the collate function
# with the batch as its only argument.
custom_collate_fn = partial(collate_fn,
                            mask_token_id=tokenizer.vocabulary["[MASK]"],
                            mask_prob=config_train["mask_prob"],
                            pad_token_id=tokenizer.vocabulary["[PAD]"])

# These assignments rebind `dataloader_train` / `dataloader_test` created by
# earlier cells.
dataloader_train = DataLoader(dataset_train, batch_size=config_train["batch_size"], collate_fn=custom_collate_fn)
dataloader_test = DataLoader(dataset_test, batch_size=config_train["batch_size"], collate_fn=custom_collate_fn)

In [None]:
# Initialize the model from `bert_config`. (`config` in this notebook is the
# dict of masking settings for the collate demo, not a model configuration.)
# NOTE(review): `bert_config` uses max_len=512 while `config_train` asks for
# max_seq_length=1000 — confirm which limit this run should use.
model = BERT(bert_config).to('cuda')
print('Trainable parameters:', sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000, 'M')

In [None]:
# Optimizer: Adam, created with 1/25 of the peak learning rate.
peak_lr = config_train["lr"]
optim = torch.optim.Adam(model.parameters(), lr=peak_lr / 25.)

# Scheduler: one-cycle policy over all batches of all epochs, peaking at `peak_lr`.
sched = torch.optim.lr_scheduler.OneCycleLR(
    optim,
    max_lr=peak_lr,
    steps_per_epoch=len(dataloader_train),
    epochs=config_train["n_epochs"],
)

# Loss: cross-entropy that skips target positions equal to -100.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

In [None]:
# TensorBoard writer
writer = SummaryWriter(log_dir="./logs/mlm_training")

# Make sure the checkpoint directory exists before training starts, so the
# first `torch.save` cannot fail on a missing folder after a full epoch.
best_model_path = './mlm-baby-bert/best_model.pt'
os.makedirs(os.path.dirname(best_model_path), exist_ok=True)

# Per-epoch mean losses and the best validation loss seen so far
train_losses = []
valid_losses = []
best_val_loss = float('inf')

try:
    for ep in range(config_train["n_epochs"]):
        # Training phase
        model.train()
        train_loss = 0.0
        for i, (input_ids, labels) in enumerate(tqdm(dataloader_train, desc=f"Epoch {ep + 1} Training")):
            input_ids, labels = input_ids.to('cuda'), labels.to('cuda')

            logits_lm = model(input_ids)
            # Dims 1 and 2 are swapped before the loss call; presumably the
            # logits are (batch, seq, vocab) — TODO confirm.
            loss = criterion(logits_lm.transpose(1, 2), labels)

            loss.backward()
            optim.step()
            optim.zero_grad()
            sched.step()  # the schedule advances once per batch

            train_loss += loss.item()

        train_loss /= len(dataloader_train)
        train_losses.append(train_loss)
        writer.add_scalar("Loss/Train", train_loss, ep)

        # Validation phase
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for i, (input_ids, labels) in enumerate(tqdm(dataloader_test, desc=f"Epoch {ep + 1} Validation")):
                input_ids, labels = input_ids.to('cuda'), labels.to('cuda')

                logits_lm = model(input_ids)
                loss = criterion(logits_lm.transpose(1, 2), labels)

                valid_loss += loss.item()

        valid_loss /= len(dataloader_test)
        valid_losses.append(valid_loss)
        writer.add_scalar("Loss/Validation", valid_loss, ep)

        print(f"Epoch {ep + 1} | Train Loss: {train_loss:.4f} | Validation Loss: {valid_loss:.4f}")

        # Save the best model
        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            torch.save(model.state_dict(), best_model_path)
            print("Best model saved.")
finally:
    # Close the TensorBoard writer even if training is interrupted or fails
    writer.close()

In [None]:
# Dimensions of the toy example; vocab_size includes [mask] and [pad].
vocab_size, max_len, num_seq = 10, 5, 5

def gen_sample_data(vocab_size, max_len, num_seq):
    """Return `num_seq` random 1-D token tensors of varying length.

    For each sequence a length is drawn with `torch.randint(1, max_len, ...)`
    and then that many ids with `torch.randint(2, vocab_size - 3, ...)`; ids
    start at 2 because 0 and 1 stand for padding and mask.
    NOTE(review): the exclusive upper bound `vocab_size - 3` leaves the highest
    ids unused — confirm that this is intended.
    """
    samples = []
    for _ in range(num_seq):
        # Draw the length first, then the ids (same order of random draws
        # as a single nested call).
        length = torch.randint(1, max_len, size=(1,))
        samples.append(torch.randint(2, vocab_size-3, size=(length,)))
    return samples

# Draw the toy sequences: a list of `num_seq` 1-D tensors of varying length.
seqs = gen_sample_data(vocab_size, max_len, num_seq)

def batch_data(data):
    """Stack variable-length sequences into one zero-padded integer tensor.

    Each sequence is cut to the module-level `max_len` and copied to the start
    of its row; the remaining cells stay 0. The result has shape
    (len(data), max_len) and integer (long) dtype.
    """
    padded = torch.zeros(len(data), max_len)
    # Iterating over `padded` yields row views, so writes go into `padded`.
    for row, sent in zip(padded, data):
        keep = min(len(sent), max_len)
        row[:keep] = sent[:keep]
    return padded.long()

# NOTE: this rebinds the name `batch_data` from the function defined above to
# its result; from here on `batch_data` is the padded tensor (the next cell
# reads `batch_data.shape`).
batch_data = batch_data(seqs)
batch_data

In [None]:
masking_prob = 0.15
# Uniform samples in [0, 1) fall below `masking_prob` with exactly that
# probability. (`torch.randn` draws from a standard normal, for which
# `< 0.15` is true for more than half of the positions.)
full_mask = torch.rand(batch_data.shape) < masking_prob
full_mask