In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, get_cosine_schedule_with_warmup

### Tokenizer

In [None]:
tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")

print("Special Tokens in Roberta Tokenizer")
special_tokens = tokenizer.special_tokens_map
print(special_tokens)

# Map every special-token string (e.g. "<mask>", "<pad>") to its vocabulary id.
# A string listed under several roles in special_tokens_map keeps a single key.
special_token_idx = {
    token: tokenizer.encode(token, add_special_tokens=False)[0]
    for token in special_tokens.values()
}
print(f"Special Token Index: {special_token_idx}")

# Split the vocabulary into special and regular ids; the regular ids are the
# pool used later for random-token replacement during masking.
all_tokens_idx = list(range(tokenizer.vocab_size))
all_special_tokens_idx = sorted(special_token_idx.values())
special_id_set = set(all_special_tokens_idx)
all_non_special_tokens_idx = [idx for idx in all_tokens_idx if idx not in special_id_set]

### Prepare Dataset

In [None]:
path_to_data = "./data/harry_potter_txt/"

# os.listdir returns entries in arbitrary order; sort so the concatenated corpus
# (and therefore the chunk boundaries below) is reproducible across machines.
text_files = sorted(os.listdir(path_to_data))

book_texts = []
for book in text_files:
    # Explicit encoding instead of the platform default.
    # NOTE(review): assumes the .txt files are UTF-8 -- adjust if they are not.
    with open(os.path.join(path_to_data, book), "r", encoding="utf-8") as f:
        # Drop every line containing "Page" (page headers/footers).
        lines = [line for line in f if "Page" not in line]
    # Collapse each run of whitespace (newlines, tabs, \r, repeated spaces)
    # into a single space.
    book_texts.append(" ".join(" ".join(lines).split()))

# Join books with a space so the last word of one book is not glued onto the
# first word of the next, then split into sentence-like pieces on ".".
all_text = " ".join(book_texts).split(".")

# Group consecutive pieces into chunks of SENTENCES_PER_CHUNK sentences.
SENTENCES_PER_CHUNK = 5
all_text_chunked = [
    ".".join(all_text[i:i + SENTENCES_PER_CHUNK])
    for i in range(0, len(all_text), SENTENCES_PER_CHUNK)
]
# Drop empty/whitespace-only chunks (e.g. the tail after a final "."), which
# would otherwise become samples made only of special tokens.
all_text_chunked = [chunk for chunk in all_text_chunked if chunk.strip()]

tokenized_text = [tokenizer.encode(text) for text in all_text_chunked]

In [None]:
class MaskedLMLoader(Dataset):
    """Map-style dataset producing masked-language-modelling samples.

    Each item is one pre-tokenized chunk, randomly cropped to at most
    ``max_seq_len`` tokens. Every non-special position is selected for
    prediction with probability ``masking_ratio``. Of the selected positions,
    ~80% are replaced by the mask token, ~10% by a random regular token, and
    ~10% are left unchanged.

    Args:
        tokenized_data: sequence of token-id lists, one per chunk.
        max_seq_len: crop length applied to chunks longer than this.
        masking_ratio: per-position selection probability for non-special tokens.
        mask_token_id: id written at "masked" positions. Defaults to the
            notebook-level ``special_token_idx["<mask>"]``.
        special_token_ids: ids that are never selected. Defaults to
            ``tokenizer.all_special_ids``.
        replacement_token_ids: pool that random replacements are drawn from
            (without replacement within one sample). Defaults to the
            notebook-level ``all_non_special_tokens_idx``.

    Items are ``(tokens, labels)``, two 1-D long tensors of equal length.
    ``labels`` holds the original id at selected positions and -100 elsewhere.
    """

    def __init__(self, tokenized_data, max_seq_len=100, masking_ratio=0.15,
                 mask_token_id=None, special_token_ids=None, replacement_token_ids=None):
        self.data = tokenized_data
        self.mask_ratio = masking_ratio
        self.max_seq_len = max_seq_len

        # Fall back to the notebook-level tokenizer globals so existing calls
        # such as MaskedLMLoader(tokenized_text) keep working unchanged.
        if mask_token_id is None:
            mask_token_id = special_token_idx["<mask>"]
        if special_token_ids is None:
            special_token_ids = tokenizer.all_special_ids
        if replacement_token_ids is None:
            replacement_token_ids = all_non_special_tokens_idx

        self.mask_token_id = mask_token_id
        self.special_token_ids = torch.tensor(list(special_token_ids), dtype=torch.long)
        self.replacement_token_ids = replacement_token_ids

    def __len__(self):
        return len(self.data)

    def _random_mask_text(self, tokens):
        """Return ``(masked_tokens, labels)`` for a 1-D id tensor; ``tokens`` is not modified."""
        tokens = tokens.clone()

        # Special tokens (<s>, </s>, <pad>, ...) are never selected for prediction.
        is_special = torch.isin(tokens, self.special_token_ids.to(tokens.dtype))
        selected = (torch.rand(tokens.shape) < self.mask_ratio) & ~is_special

        # -100 is the conventional "ignore" index for the loss.
        labels = torch.full(tokens.shape, -100, dtype=torch.long)
        labels[selected] = tokens[selected]

        # 80% of selected positions -> mask token; of the remaining 20%, half
        # get a random regular token and half stay unchanged.
        use_mask = selected & (torch.rand(tokens.shape) < 0.8)
        use_random = selected & ~use_mask & (torch.rand(tokens.shape) < 0.5)

        tokens[use_mask] = self.mask_token_id

        n_random = int(use_random.sum())
        if n_random > 0:
            tokens[use_random] = torch.tensor(
                random.sample(self.replacement_token_ids, n_random), dtype=tokens.dtype
            )

        return tokens, labels

    def __getitem__(self, idx):
        # dtype=long keeps an empty chunk from becoming a float tensor.
        data = torch.tensor(self.data[idx], dtype=torch.long)

        overflow = len(data) - self.max_seq_len
        if overflow > 0:
            # randint is inclusive on both ends, so every window -- including
            # the one that ends on the last token -- can be chosen.
            start = random.randint(0, overflow)
            data = data[start:start + self.max_seq_len]

        return self._random_mask_text(data)

mlm = MaskedLMLoader(tokenized_text)

# Sanity check: show the first (masked tokens, labels) pair from the dataset.
for masked_tokens, labels in mlm:
    print(masked_tokens, labels, sep="\n")
    break

In [None]:
def collate_fn(batch, pad_token_id=None):
    """Right-pad a list of ``(input_ids, labels)`` pairs into one batch dict.

    Args:
        batch: list of ``(tokens, labels)`` pairs of equal-length 1-D integer tensors.
        pad_token_id: id used to pad ``input_ids``. Defaults to the
            notebook-level ``special_token_idx["<pad>"]``.

    Returns:
        dict with
          - ``input_ids``: (B, L) tensor padded with ``pad_token_id``;
          - ``labels``: (B, L) tensor padded with -100;
          - ``attention_mask``: (B, L) bool tensor, True at *padding* positions.
        L is the longest sample in the batch.

    NOTE(review): True-means-padding matches the ``key_padding_mask`` convention
    of ``torch.nn.MultiheadAttention``. It is the inverse of the Hugging Face
    ``attention_mask`` convention (1 = attend); confirm against the model that
    consumes this batch.
    """
    if pad_token_id is None:
        pad_token_id = special_token_idx["<pad>"]

    token_samples = [tokens for tokens, _ in batch]
    label_samples = [labels for _, labels in batch]

    input_ids = torch.nn.utils.rnn.pad_sequence(
        token_samples, batch_first=True, padding_value=pad_token_id
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        label_samples, batch_first=True, padding_value=-100
    )

    # Mark padding by position rather than by comparing against the pad id, so
    # a genuine token that happens to equal the pad id is never flagged, and
    # every row gets the same bool dtype.
    lengths = torch.tensor([len(tokens) for tokens in token_samples])
    positions = torch.arange(input_ids.shape[1])
    padding_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)

    assert input_ids.shape == labels.shape == padding_mask.shape

    return {"input_ids": input_ids,
            "labels": labels,
            "attention_mask": padding_mask}
    
dataloader = DataLoader(mlm, batch_size=16, collate_fn=collate_fn)

# Sanity check: print each field of the first collated batch.
for batch in dataloader:
    for field in ("input_ids", "labels", "attention_mask"):
        print(batch[field])
    break