In [None]:
import torch
from pathlib import Path
from transformers import RobertaTokenizer
from transformers import RobertaConfig
from transformers import RobertaForMaskedLM
from transformers import AdamW
from tqdm.auto import tqdm
import gc
import pickle


torch.cuda.empty_cache()

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a dict of tensors.

    ``encodings`` maps field names to tensors that are indexed along their
    first dimension; the ``'input_ids'`` entry determines the dataset length.
    """

    def __init__(self, encodings):
        # keep a reference to the encodings dict (it is not copied)
        self.encodings = encodings

    def __len__(self):
        # sample count == size of the first dimension of 'input_ids'
        input_ids = self.encodings['input_ids']
        return input_ids.shape[0]

    def __getitem__(self, i):
        # build one sample: the i-th slice of every stored tensor, same keys
        sample = {}
        for key, tensor in self.encodings.items():
            sample[key] = tensor[i]
        return sample
    
def mlm(input_ids, mask_prob=.15, mask_token_id=4, special_token_ids=(0, 1, 2)):
    """Randomly replace token ids with the mask token id, in place.

    Each position is selected independently with probability ``mask_prob``;
    positions holding any id in ``special_token_ids`` are never selected.

    Args:
        input_ids: integer tensor of token ids (any shape). It is MODIFIED
            IN PLACE, so pass a clone if the original ids are still needed
            (e.g. as labels).
        mask_prob: probability of masking each eligible position.
        mask_token_id: id written into the selected positions.
        special_token_ids: ids that must never be masked. The defaults
            (0, 1, 2) and the mask id 4 are hard-coded assumptions about the
            tokenizer's vocabulary -- NOTE(review): confirm they match the
            tokenizer actually loaded.

    Returns:
        The same ``input_ids`` tensor object, after masking.
    """
    # one uniform [0, 1) draw per token, created on the same device as the ids
    rand = torch.rand(input_ids.shape, device=input_ids.device)
    mask_arr = rand < mask_prob
    # never mask protected (special) tokens
    for special_id in special_token_ids:
        mask_arr &= input_ids != special_id
    # boolean-mask assignment handles every row at once; no per-row loop needed
    input_ids[mask_arr] = mask_token_id
    return input_ids

# Recursively collect every .txt file under the tokenized dataset directory,
# stored as plain string paths.
paths = list(map(str, Path('/userdirs/piyumal/roberta_sinhala/content/sinhala-dataset-creation/datasets/tokenized/').glob('**/*.txt')))

# Load the tokenizer saved under 'Roberta_tokenizer'; max_len=512 is forwarded
# to the tokenizer (presumably the maximum sequence length -- confirm).
tokenizer = RobertaTokenizer.from_pretrained('Roberta_tokenizer', max_len=512)

# Read the entire corpus into memory and split it on newlines; each resulting
# string is treated as one training sample further down.
with open('Sinhala_all_data.txt', mode='r', encoding='utf-8') as fp:
    lines = fp.read().split('\n')
    
# Architecture hyper-parameters for the RoBERTa model built below.
config = RobertaConfig(
    # vocabulary / embedding tables
    vocab_size=52_000,  # intended to match the tokenizer's vocab size
    type_vocab_size=1,
    # presumably 512 tokens plus 2 positions RoBERTa reserves -- confirm
    max_position_embeddings=514,
    # transformer encoder dimensions
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
)

# Build the masked-LM model directly from the config (no pretrained weights
# are loaded here).
model = RobertaForMaskedLM(config)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# Put the model in training mode before optimisation starts.
model.train()

# AdamW over all model parameters with a fixed learning rate of 1e-4.
optim = AdamW(model.parameters(), lr=1e-4)

# Masked-language-model training loop.
#
# The corpus is too large to tokenize at once, so each epoch walks over it in
# fixed-size chunks of lines: tokenize a chunk, mask it, train on it in
# mini-batches, save a checkpoint, then release the chunk before the next one.
print("##############################################################################################")
print("Training started..............................................................................")
epochs = 2
loss_values = []
# number of corpus lines tokenized and held in memory at a time
LINES_PER_CHUNK = 5_00_000
for epoch in range(epochs):
    for i in range(0, len(lines), LINES_PER_CHUNK):
        # slicing past the end of a list is safe, so the final (shorter)
        # chunk needs no special case
        batch_data = tokenizer(lines[i:i + LINES_PER_CHUNK], max_length=512, padding='max_length', truncation=True)
        labels_all = torch.tensor(batch_data.input_ids)
        mask_all = torch.tensor(batch_data.attention_mask)
        # mask a copy so the labels keep the original (unmasked) token ids;
        # masking is re-drawn on every epoch because this runs per chunk
        input_ids_all = mlm(labels_all.detach().clone())
        # Exclude padding from the loss: the model's loss skips label -100,
        # and with padding='max_length' most positions are padding that would
        # otherwise dominate it. This must run AFTER the clone above so the
        # model inputs keep real token ids.
        labels_all[mask_all == 0] = -100
        encodings = {'input_ids': input_ids_all, 'attention_mask': mask_all, 'labels': labels_all}
        dataset = Dataset(encodings)
        loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
        # setup loop with TQDM and dataloader
        loop = tqdm(loader, leave=True)
        # the description only depends on epoch/chunk, so set it once per chunk
        loop.set_description(f'Epoch: {epoch} Text Loading Iter: {i // LINES_PER_CHUNK}')
        for batch in loop:
            # reset gradients accumulated in the previous step
            optim.zero_grad()
            # move the tensors required for this step to the training device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            # forward pass; passing labels makes the model return the MLM loss
            outputs = model(input_ids, attention_mask=attention_mask,
                            labels=labels)
            loss = outputs.loss
            # back-propagate and update parameters
            loss.backward()
            optim.step()
            # read the scalar once (each .item() call synchronises with the device)
            loss_value = loss.item()
            loop.set_postfix(loss=loss_value)
            loss_values.append(loss_value)

        # checkpoint after every chunk (was a for/else with no break, so it
        # already ran unconditionally)
        model.save_pretrained(f'./Roberta_models/roberta_model_{epoch}_{i}')
        # Drop EVERY reference to the chunk tensors before tokenizing the next
        # chunk: encodings/dataset/loader/loop all keep them alive, so
        # deleting only the *_all names would free nothing.
        del loop
        del loader
        del dataset
        del encodings
        del batch_data
        del labels_all
        del mask_all
        del input_ids_all
        gc.collect()
        # persist the running loss history (overwritten after each chunk)
        with open('loss_value_file', 'wb') as fp:
            pickle.dump(loss_values, fp)
        print("############################################################################################")
        print(f'batch completed for epoch {epoch}_{i}')
                

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'BertTokenizer'. 
The class this function is called from is 'RobertaTokenizer'.


##############################################################################################
Training started..............................................................................


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_0


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_1000000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_1500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_2000000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_2500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_3000000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_3500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_4000000


  0%|          | 0/31250 [00:00<?, ?it/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



############################################################################################
batch completed for epoch 0_4500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_5000000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_5500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_6000000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_6500000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_7000000


  0%|          | 0/31250 [00:00<?, ?it/s]

############################################################################################
batch completed for epoch 0_7500000


  0%|          | 0/31250 [00:00<?, ?it/s]