In [1]:
import json
import torch
import random
from tqdm import tqdm
from transformers import (BertTokenizer,
                          BertForNextSentencePrediction)

In [2]:
# --- Paths -----------------------------------------------------------------
# Plain-text corpus; read line by line and each line is treated as a paragraph.
DATASET_PATH = "../../data/meditations.txt"
# Directory the fine-tuned NSP model is saved to (and later reloaded from).
PRE_TRAINED_MODEL_PATH = "../../models/meditations_nsp_bert_base_uncased_continual_pretraining"
# Checkpoint name used for both the tokenizer and the initial model weights.
MODEL_NAME = "bert-base-uncased"

# --- Training hyper-parameters ---------------------------------------------
NUM_EPOCHS = 2    # full passes over the DataLoader
BATCH_SIZE = 16   # DataLoader batch size
MAX_LENGTH = 256  # tokenizer max_length; pairs are truncated/padded to this
SHUFFLE = True    # whether the DataLoader shuffles the examples

# --- Reproducibility -------------------------------------------------------
# Pair sampling uses `random` and the DataLoader shuffle uses torch's RNG, so
# both are seeded here to make a fresh "Restart & Run All" repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
def _split_sentences(paragraph):
    """Split `paragraph` on '.' and return its stripped, non-blank pieces."""
    return [piece.strip() for piece in paragraph.split('.') if piece.strip()]


def generate_nsp_dataset(text, rng=None):
    """Build sentence pairs for next-sentence-prediction (NSP) training.

    Each paragraph with at least two sentences contributes exactly one pair.
    With probability ~0.5 the pair is (sentence, its real successor) with
    label 0; otherwise it is (sentence, a sentence drawn from the whole
    corpus) with label 1.

    Label convention:
        0 -> sentence B is the continuation of sentence A (IsNextSentence)
        1 -> sentence B is a randomly drawn sentence (NotNextSentence)

    Args:
        text: iterable of paragraph strings; sentences are delimited by '.'.
        rng: optional object exposing `randint` and `random` (for example a
            `random.Random` instance). Defaults to the global `random` module.

    Returns:
        A tuple `(sentence_a, sentence_b, label)` of three equally long lists.
    """
    rng = random if rng is None else rng

    # Split every paragraph once so the corpus-wide bag and the per-paragraph
    # sentence lists are guaranteed to line up position by position.
    paragraphs = [_split_sentences(paragraph) for paragraph in text]
    bag = [sentence for sentences in paragraphs for sentence in sentences]
    bag_size = len(bag)

    sentence_a = []
    sentence_b = []
    label = []

    offset = 0  # position in `bag` of the current paragraph's first sentence
    for sentences in paragraphs:
        num_sentences = len(sentences)
        if num_sentences > 1:
            start = rng.randint(0, num_sentences - 2)
            sentence_a.append(sentences[start])

            if rng.random() >= 0.5:
                # IsNextSentence: pair with the real successor.
                sentence_b.append(sentences[start + 1])
                label.append(0)
            else:
                # NotNextSentence: draw uniformly from every bag position
                # except the real successor, so a true continuation is never
                # labelled as random. bag_size >= 2 here because this
                # paragraph alone contributes at least two sentences.
                true_next = offset + start + 1
                index = rng.randint(0, bag_size - 2)
                if index >= true_next:
                    index += 1
                sentence_b.append(bag[index])
                label.append(1)
        offset += num_sentences

    return sentence_a, sentence_b, label

In [4]:
# Read the corpus with an explicit encoding so decoding does not depend on the
# platform's locale default. Each line of the file is treated as one paragraph.
with open(DATASET_PATH, "r", encoding="utf-8") as fin:
    text = fin.read().split('\n')

# Build the NSP training pairs from the paragraphs.
sentence_a, sentence_b, label = generate_nsp_dataset(text)
# Label convention:
#   0 -> sentence B is the continuation of sentence A
#   1 -> sentence B is a random sentence

In [14]:
class NSPDataset(torch.utils.data.Dataset):
    """Map-style dataset over a mapping of equally long, index-aligned tensors.

    `encodings` is any mapping (a tokenizer's batch output or a plain dict)
    that has an "input_ids" entry and whose values are tensors indexed by
    example along their first dimension.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # Key lookup (rather than attribute access) also works for plain dicts.
        return len(self.encodings["input_ids"])

    def __getitem__(self, index):
        # The values are already tensors, so copy with clone().detach();
        # re-wrapping a tensor in torch.tensor() copies again and makes
        # PyTorch emit a copy-construct UserWarning.
        return {
            key: value[index].clone().detach()
            for key, value in self.encodings.items()
        }

In [15]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Tokenizer and NSP-headed model, both initialised from the same checkpoint.
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForNextSentencePrediction.from_pretrained(MODEL_NAME)

# Encode every (A, B) pair as one fixed-width row: pairs longer than
# MAX_LENGTH are truncated and shorter ones are padded up to MAX_LENGTH.
inputs = tokenizer(
    sentence_a,
    sentence_b,
    return_tensors='pt',
    max_length=MAX_LENGTH,
    truncation=True,
    padding='max_length',
)

# Store the labels as an (N, 1) int64 column next to the encoded inputs.
inputs['labels'] = torch.tensor(label, dtype=torch.long).unsqueeze(1)

# Wrap the encodings in a Dataset and batch them for training.
dataset = NSPDataset(inputs)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the ret

In [16]:
# Move the model to the selected device and switch it to training mode.
model.to(device)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

for epoch in range(NUM_EPOCHS):
    # One progress bar per epoch; leave=True keeps finished bars on screen.
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # PyTorch accumulates gradients across backward passes, so clear the
        # previous step's gradients before computing new ones.
        optimizer.zero_grad()

        # Pull the four tensors out of the batch and move them to the device.
        input_ids, attention_mask, token_type_ids, labels = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'token_type_ids', 'labels')
        )

        # Forward pass; passing `labels` makes the model return the loss too.
        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )
        loss = outputs.loss

        # Backward pass, then one optimizer step with the fresh gradients.
        loss.backward()
        optimizer.step()

        # Show the epoch number and the latest batch loss on the progress bar.
        loop.set_description("Epoch {}".format(epoch))
        loop.set_postfix(loss=loss.item())

# Persist the fine-tuned weights and config so they can be reloaded later.
model.save_pretrained(PRE_TRAINED_MODEL_PATH)

Epoch 0: 100%|██████████| 20/20 [00:27<00:00,  1.38s/it, loss=1.38] 
Epoch 1: 100%|██████████| 20/20 [00:26<00:00,  1.34s/it, loss=0.684]


In [17]:
# Reload the model that was saved to PRE_TRAINED_MODEL_PATH; the bare `model`
# expression on the last line displays its module structure.
model = BertForNextSentencePrediction.from_pretrained(
    PRE_TRAINED_MODEL_PATH
)
model

BertForNextSentencePrediction(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element