In [2]:
import random
from tqdm import tqdm
from transformers import BertTokenizer, BertForPreTraining
import torch
from torch.optim import AdamW

In [3]:
# Load the pretrained tokenizer and the BertForPreTraining model, both from the
# same 'bert-base-uncased' checkpoint so that token ids and embeddings agree.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

Some weights of BertForPreTraining were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
# Read the raw corpus and split it on newlines; each resulting entry is treated
# as one paragraph by the cells below.
# The encoding is given explicitly so decoding does not depend on the platform's
# locale default (assumes the file is UTF-8 -- adjust if it was saved otherwise).
with open('All_Quiet_on_the_Western_Front.txt', 'r', encoding='utf-8') as fp:
    text = fp.read().split('\n')

In [44]:
# Flat pool of every non-empty '.'-delimited fragment in the corpus; random
# (non-consecutive) second sentences are drawn from this pool.
bag = []
for line in text:
    bag.extend(fragment for fragment in line.split('.') if fragment != '')
bag_size = len(bag)

In [45]:
# Build next-sentence-prediction pairs: one (A, B, label) triple per paragraph
# that contains at least two sentences.
#   label 0 -> B is the sentence that directly follows A
#   label 1 -> B is a fragment drawn at random from the whole corpus
scentence_a, scentence_b, label = [], [], []

for paragraph in text:
    sentences = [part for part in paragraph.split('.') if part != '']
    num_scen = len(sentences)
    if num_scen <= 1:
        # A lone sentence has no successor, so it cannot form a pair.
        continue

    # Pick A so that a following sentence always exists.
    start = random.randint(0, num_scen - 2)
    scentence_a.append(sentences[start])

    if random.random() >= 0.5:
        scentence_b.append(sentences[start + 1])
        label.append(0)
    else:
        index = random.randint(0, bag_size - 1)
        scentence_b.append(bag[index])
        label.append(1)

In [46]:
# Encode the sentence pairs together; every example is truncated or padded to
# exactly 512 tokens and the result is returned as PyTorch tensors.
inputs = tokenizer(
    scentence_a,
    scentence_b,
    padding='max_length',
    truncation=True,
    max_length=512,
    return_tensors='pt',
)
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [47]:
# Store the NSP targets as an int64 column vector: one row per sentence pair.
inputs['next_sentence_label'] = torch.tensor(label, dtype=torch.long).unsqueeze(1)

In [48]:
# Snapshot the token ids before any masking is applied; this independent copy
# serves as the masked-language-modelling target.
inputs['labels'] = inputs['input_ids'].clone().detach()

In [49]:
# One uniform draw in [0, 1) per token position.
rand = torch.rand(inputs.input_ids.shape)

# Positions holding ids 101, 102 or 0 must never be masked (CLS / SEP / PAD for
# bert-base-uncased -- verify if the checkpoint changes).
is_special = (inputs.input_ids == 101) | (inputs.input_ids == 102) | (inputs.input_ids == 0)

# Select roughly 15% of the remaining positions.
mask_arr = (rand < 0.15) & ~is_special

In [50]:
# For every example, list the column indices that were chosen for masking.
selection = [row.nonzero().flatten().tolist() for row in mask_arr]

In [51]:
# Overwrite every selected position with token id 103 (the [MASK] id for
# bert-base-uncased -- verify if the checkpoint changes). Boolean indexing with
# mask_arr touches exactly the positions listed per row in `selection`.
inputs['input_ids'][mask_arr] = 103

# Only masked positions should contribute to the MLM loss. Label -100 is the
# ignore index of the cross-entropy loss used by the model, so every position
# that was NOT masked (including all padding) is excluded; otherwise the loss is
# dominated by trivially "predicting" visible and [PAD] tokens.
inputs['labels'][~mask_arr] = -100

In [52]:
# Sanity check: display the fields the encoding now carries.
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'next_sentence_label', 'labels'])

In [53]:
class MlmNspDataset(torch.utils.data.Dataset):
    """Map-style dataset over a dict-like collection of encoded examples.

    Args:
        encodings: Mapping from field name to a per-example indexable container
            (tensor or nested list). Must contain an ``'input_ids'`` entry, whose
            length defines the dataset size, and must provide ``.items()``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        """Return example ``idx`` as a ``{field: tensor}`` dict of independent copies."""
        return {key: self._copy_as_tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        """Return the number of examples (length of the ``'input_ids'`` field)."""
        # Key lookup works for plain dicts as well as attribute-style containers.
        return len(self.encodings['input_ids'])

    @staticmethod
    def _copy_as_tensor(value):
        """Copy ``value`` into a new tensor without the copy-construct warning."""
        if isinstance(value, torch.Tensor):
            # torch.tensor(tensor) emits a UserWarning; clone().detach() is the
            # recommended way to obtain the same independent copy.
            return value.clone().detach()
        return torch.tensor(value)

In [54]:
# Wrap the prepared encodings so they can be served example by example.
dataset = MlmNspDataset(encodings=inputs)

In [55]:
# Serve the dataset in reshuffled mini-batches of 16 examples.
loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=16,
)

In [56]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model.to(device)

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [57]:
# Optimise every model parameter with AdamW at a learning rate of 5e-5, then
# switch the model into training mode.
optim = AdamW(params=model.parameters(), lr=5e-5)
model.train()

BertForPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

In [None]:
# Two passes over the data; each step feeds the model the masked token ids
# together with both the MLM and NSP targets.
for epoch in range(2):
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # Move every field the model consumes onto the training device.
        tensors = {
            name: batch[name].to(device)
            for name in (
                'input_ids',
                'token_type_ids',
                'attention_mask',
                'next_sentence_label',
                'labels',
            )
        }

        # Clear gradients left over from the previous step.
        optim.zero_grad()

        output = model(**tensors)

        batch_loss = output.loss
        batch_loss.backward()
        optim.step()

        # Show the epoch and the current loss on the progress bar.
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=batch_loss.item())