In [2]:
import requests
from transformers import BertTokenizer, BertForPreTraining
import torch




In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')


In [5]:
with open('./data/clean.txt','r') as fp:
    text = fp.read().split('\n')

In [6]:
bag = [sentence for para in text for sentence in para.split('.') if sentence !='']

In [7]:
bag_size = len(bag)

In [8]:
import random

sentence_a = []
sentence_b = []
label =[]
for paragraph in text:
    sentences =[
        sentence for sentence in paragraph.split('.') if sentence !=''
    ]
    num_sentences = len(sentences)
    if num_sentences >1:
        start = random.randint(0,num_sentences-2)
        sentence_a.append(sentences[start])
        if random.random() > 0.5:
            sentence_b.append(bag[random.randint(0,bag_size-1)])
            label.append(1)
        else:
            sentence_b.append(sentences[start+1])
            label.append(0)

In [9]:
inputs = tokenizer(sentence_a,sentence_b,return_tensors='pt',max_length=512,truncation=True,padding='max_length')

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


In [10]:
inputs['next_sentence_labels'] = torch.LongTensor(label).unsqueeze(0).T

In [11]:
inputs['labels'] = inputs.input_ids.detach().clone()

In [12]:
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'next_sentence_labels', 'labels'])

In [19]:
rand = torch.rand(inputs.input_ids.shape)

In [20]:
mask_arr = (rand<0.15) *(inputs.input_ids!=101) *(inputs.input_ids!=102) *(inputs.input_ids != 0)
mask_arr

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False,  True,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])

In [21]:
inputs.input_ids.shape[0]

317

In [23]:
for i in range(inputs.input_ids.shape[0]):
    selection = torch.flatten(mask_arr[i].nonzero()).tolist()
    inputs.input_ids[i,selection] = 103

In [24]:
class MeditationDataset(torch.utils.data.Dataset):
    def __init__(self,encodings):
        self.encodings = encodings
        
    def __getitem__(self,idx):
        return{key: torch.tensor(val[idx]) for key,val in self.encodings.items()}
    def __len__(self):
        return len(self.encodings.input_ids)

In [25]:
dataset = MeditationDataset(inputs)

In [29]:
dataloader = torch.utils.data.DataLoader(dataset,batch_size=2,shuffle=True)

In [27]:
model.train()

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [28]:
from transformers import AdamW
optim = AdamW(model.parameters(),lr=1e-5)



In [31]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [33]:
from tqdm import tqdm

epochs =2 
for epoch in range(epochs):
    loop = tqdm(dataloader,leave=True)
    for batch in loop:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        next_sentence_labels = batch['next_sentence_labels'].to(device)
        labels = batch['labels'].to(device)
        
        outputs = model(input_ids,token_type_ids=token_type_ids,
                        attention_mask = attention_mask,
                        next_sentence_label=next_sentence_labels,
                        labels=labels)
        loss = outputs.loss
        loss.backward()
        optim.step()
        
        loop.set_description(f'epoch {epoch}')
        loop.set_postfix(loss=loss.item())

  return{key: torch.tensor(val[idx]) for key,val in self.encodings.items()}
epoch 0: 100%|██████████| 159/159 [1:02:22<00:00, 23.54s/it, loss=1.12] 
epoch 1: 100%|██████████| 159/159 [40:00<00:00, 15.10s/it, loss=0.154]
