In [6]:
import random
import torch
from tqdm import tqdm
from transformers import BertTokenizer, BertForPreTraining

In [2]:
DATASET_PATH = "../../data/meditations.txt"
PRE_TRAINED_MODEL_PATH = "../../models/meditations_mlm_nsp_bert_base_uncased_continual_pretraining"
MODEL_NAME = "bert-base-uncased"
NUM_EPOCHS = 2
BATCH_SIZE = 2
MAX_LENGTH = 256
SHUFFLE    = True

In [8]:
def generate_nsp_dataset(text):
    # bag of sentences
    bag = [item for sentence in text for item in sentence.split('.') if item != '']
    bag_size = len(bag)

    sentence_a = []
    sentence_b = []
    label = []

    for paragraph in text:
        sentences = [
            sentence for sentence in paragraph.split('.') if sentence != ''
        ]
        num_sentences = len(sentences)
        if num_sentences > 1:
            start = random.randint(0, num_sentences-2)
            # 50/50 whether is IsNextSentence or NotNextSentence
            # In labels,
                # 0 indicates sequence B is a continuation of sequence A,
                # 1 indicates sequence B is a random sequence.
            
            if random.random() >= 0.5:
                # this is IsNextSentence
                sentence_a.append(sentences[start])
                sentence_b.append(sentences[start+1])
                label.append(0)
            else:
                index = random.randint(0, bag_size-1)
                # this is NotNextSentence
                sentence_a.append(sentences[start])
                sentence_b.append(bag[index])
                label.append(1)
    
    return sentence_a, sentence_b, label

In [9]:
def generate_mlm_dataset(tokenizer, inputs):
    
    # create labels tensor by cloning the input_ids tensor
    inputs['labels'] = inputs.input_ids.detach().clone()
    
    # create random array of floats with equal dimensions to input_ids tensor
    rand = torch.rand(inputs.input_ids.shape)
    
    # create mask array except for special tokens [CLS], [SEP], [PAD]
    mask_arr = (rand < 0.15) \
             * (inputs.input_ids != 101) \
             * (inputs.input_ids != 102) \
             * (inputs.input_ids != 0)
    
    selection = []
    for i in range(inputs.input_ids.shape[0]):
        selection.append(
            torch.flatten(mask_arr[i].nonzero()).tolist()
        )
    
    # For each row in the input_ids, assign 103 [MASK] token to the selection indices
    for i in range(inputs.input_ids.shape[0]):
        inputs.input_ids[i, selection[i]] = 103
    
    return inputs

In [12]:
class BertDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    
    def __len__(self):
        return len(self.encodings.input_ids)
    
    def __getitem__(self, index):
        return {key: torch.tensor(value[index].clone().detach()) for key, value in self.encodings.items()}

In [13]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForPreTraining.from_pretrained(MODEL_NAME)

with open(DATASET_PATH, 'r') as fp:
    text = fp.read().split('\n')
    
# Generate NSP Dataset
sentence_a, sentence_b, label = generate_nsp_dataset(text)

inputs = tokenizer(sentence_a,
                   sentence_b,
                   return_tensors='pt',
                   max_length=MAX_LENGTH,
                   truncation=True,
                   padding='max_length')

inputs['next_sentence_label'] = torch.LongTensor([label]).T

# Generate MLM Dataset
inputs = generate_mlm_dataset(tokenizer, inputs)

dataset = BertDataset(inputs)
loader  = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE)

Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


In [14]:
model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for epoch in range(NUM_EPOCHS):
    # this just creates blank progress bars for both our epochs
    loop = tqdm(loader, leave=True) #leave=True enables us to see the progress bars for each epoch
    for batch in loop:
        # zero_grad sets the gradients of all optimized tensors to zero
        # we need to set the gradients to zero b/f starting backprop b/c pytoch accumulates the gradients on subsequent backward pass (this is convenient while training RNNs)
        # 
        optimizer.zero_grad() # stops the gradient calculations from the previous set being carried over to the next set
        
        # batch in loop contains the 4 inputs (ie. input_ids, token_ids, attention_masks and labels)
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        next_sentence_label = batch['next_sentence_label'].to(device)
        labels = batch['labels'].to(device)
        
        # forward pass, feeding input data through all the neurons in the network from first to last layer
        outputs = model(input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                        next_sentence_label=next_sentence_label,
                        labels=labels)
        loss = outputs.loss

        # Backward propagation, compute the gradient
        loss.backward()
        
        # update model parameters and take a step using the computerd gradient
        optimizer.step()
        loop.set_description("Epoch {}".format(epoch))
        loop.set_postfix(loss=loss.item())
        
model.save_pretrained(PRE_TRAINED_MODEL_PATH)

Epoch 0: 100%|██████████| 159/159 [00:49<00:00,  3.23it/s, loss=0.537]
Epoch 1: 100%|██████████| 159/159 [00:49<00:00,  3.22it/s, loss=0.266]
