In [1]:
import torch
from torch import nn
from datasets import load_dataset
from transformers import BertTokenizerFast
import functools

import collections
import numpy as np

from transformers import default_data_collator, DataCollatorForWholeWordMask
from torch.utils.data import DataLoader

from bert.model import BertConfig, BertEmbeddings
from bert.train import TrainingConfig

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [68]:
class Head(nn.Module):
    """One head of causal (lower-triangular masked) scaled dot-product self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
        embed_dim: Width of the last dimension of the input. When omitted, the
            notebook-level ``n_embd`` is used.
        max_len: Side length of the pre-built causal mask, i.e. the longest
            sequence ``forward`` can handle. When omitted, the notebook-level
            ``block_size`` is used.
        dropout_p: Dropout probability applied to the attention weights. When
            omitted, the notebook-level ``dropout`` is used.
    """

    def __init__(self, head_size, embed_dim=None, max_len=None, dropout_p=None):
        super().__init__()
        # Explicit arguments win; the globals are only looked up when an argument
        # is omitted, so the original `Head(head_size)` call keeps working.
        embed_dim = n_embd if embed_dim is None else embed_dim
        max_len = block_size if max_len is None else max_len
        dropout_p = dropout if dropout_p is None else dropout_p

        self.key = nn.Linear(embed_dim, head_size, bias=False)
        self.query = nn.Linear(embed_dim, head_size, bias=False)
        self.value = nn.Linear(embed_dim, head_size, bias=False)
        # Registered as a buffer so it follows the module across devices without
        # being treated as a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(max_len, max_len)))

        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Apply masked self-attention to ``x`` of shape (B, T, C).

        Returns a tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Attention scores ("affinities"). The scale is 1/sqrt(d_k) with d_k the
        # key width (head_size), not the input embedding width C.
        scale = k.shape[-1] ** -0.5
        wei = q @ k.transpose(-2, -1) * scale  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # Position t may only attend to positions <= t.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        # torch.softmax is used because `F` is not imported in this notebook.
        wei = torch.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # Weighted aggregation of the values.
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

def wrap_tokenize_and_chunk(tokenizer, max_length=32):
    """Build a callback that tokenizes the ``"text"`` field of a batch of examples.

    Args:
        tokenizer: Callable tokenizer; it is invoked with the list of texts plus
            truncation/padding keyword arguments.
        max_length: Length each returned sequence is truncated/padded to.

    Returns:
        A function taking a batch dict with a ``"text"`` entry and returning the
        tokenizer's output with ``"overflow_to_sample_mapping"`` removed.
    """
    # https://huggingface.co/docs/transformers/main/en/pad_truncation
    tokenizer_options = dict(
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_overflowing_tokens=True,
    )

    def tokenize_and_chunk(examples):
        encoded = tokenizer(examples["text"], **tokenizer_options)
        # The chunk -> source-row mapping is discarded rather than returned.
        encoded.pop("overflow_to_sample_mapping")
        return encoded

    return tokenize_and_chunk

#         if tokenizer.is_fast:
#             # word_ids maps each token to the index of the word in the source sentence that it came from
#             result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]

        # Create a new labels column
#         result["labels"] = result["input_ids"].copy()

#         # Extract mapping between new and old indices
#         sample_map = result.pop("overflow_to_sample_mapping")
#         for key, values in examples.items():
#             result[key] = [values[i] for i in sample_map]
        
#         import ipdb
#         ipdb.set_trace()
#         return result

# To get the normalized and pre-tokenized list of words that "word_ids" will index into
def pretokenizer(tokenizer, text):
    """Normalize ``text`` and split it with the tokenizer's backend pre-tokenizer.

    Args:
        tokenizer: Object exposing ``backend_tokenizer`` with ``normalizer`` and
            ``pre_tokenizer`` attributes.
        text: Raw input string.

    Returns:
        Whatever ``pre_tokenize_str`` returns for the normalized text.
    """
    backend = tokenizer.backend_tokenizer
    normalized_text = backend.normalizer.normalize_str(text)
    return backend.pre_tokenizer.pre_tokenize_str(normalized_text)

## Given input_ids, can convert those to string tokens, e.g. ##word
# tokenizer.convert_ids_to_tokens(tokenized_dataset[index]["input_ids"])
## Given input_ids, can decode those to the original text (e.g. concat tokens)
# tokenizer.decode(tokenized_dataset[index]["input_ids"])
## Special tokens
# tokenizer.all_special_tokens = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
# tokenizer.all_special_ids = [100, 102, 0, 101, 103]


class TrainingCollator:
    """Collate tokenized examples into one training batch for masked-LM training.

    The batch merges two collations of the same examples:

    * ``default_data_collator`` over the pass-through fields plus an untouched
      copy of each example's ``input_ids`` (stored as ``original_input_ids``);
    * ``DataCollatorForWholeWordMask`` over the full examples (presumably
      yielding masked ``input_ids`` and ``labels`` -- confirm against the
      transformers documentation).

    Args:
        tokenizer: Tokenizer handed to the whole-word-mask collator.
        train_config: Object whose ``mask_lm_prob`` attribute sets the masking
            probability.
    """

    def __init__(self, tokenizer, train_config):
        # Per-example fields that are batched as-is.
        self.pass_through_keys = ["token_type_ids", "attention_mask"]
        self.collator = DataCollatorForWholeWordMask(
            tokenizer,
            mlm=True,
            mlm_probability=train_config.mask_lm_prob,
            return_tensors="pt",
        )

    def __call__(self, examples):
        """Return the merged batch dict for ``examples`` (a list of per-row dicts)."""
        pass_through_examples = []
        for example in examples:
            pass_through = {key: example[key] for key in self.pass_through_keys}
            # Copy the ids so the preserved sequence does not share storage with
            # the list handed to the masking collator.
            pass_through["original_input_ids"] = example["input_ids"].copy()
            pass_through_examples.append(pass_through)

        # On a key clash the masking collator's entries override the pass-through ones.
        return {
            **default_data_collator(pass_through_examples, return_tensors="pt"),
            **self.collator(examples),
        }

In [69]:
# Model and training hyperparameters, both constructed with their default values.
config = BertConfig()
train_config = TrainingConfig()

In [70]:
# Load the "bookcorpus" train split and keep only the 100,000 rows at indices 100000..199999.
dataset = load_dataset("bookcorpus", split="train").select(range(100000,200000))
# dataset = dataset.map(lambda samples: {"text_length": [len(text) for text in samples["text"]]}, batched=True)
# dataset = dataset.sort("text_length", reverse=True)
# print(dataset)

In [71]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Fail fast if the pretrained vocabulary does not match the size the model config expects.
assert tokenizer.vocab_size == config.vocab_size
# Tokenize in batches to train_config.initial_sequence_length; the source columns are
# removed so only the tokenizer's outputs remain as dataset features.
tokenized_dataset = dataset.map(wrap_tokenize_and_chunk(tokenizer, train_config.initial_sequence_length), batched=True, remove_columns=dataset.column_names)
print(tokenized_dataset)

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 100000
})


In [72]:
# Batches are drawn in dataset order (shuffle=False) and assembled by TrainingCollator,
# which applies whole-word masking at collate time.
train_dataloader = DataLoader(
    tokenized_dataset,
    shuffle=False,
    batch_size=train_config.batch_size,
    collate_fn=TrainingCollator(tokenizer, train_config),
)

In [73]:
# Pull the first collated batch for interactive inspection.
batch = next(iter(train_dataloader))

In [97]:
# Embedding module built from the model config.
bert_embeddings = BertEmbeddings(config)

In [98]:
# Embed the batch's input ids together with their token type ids.
x = bert_embeddings(batch["input_ids"], batch["token_type_ids"])

In [99]:
# Sanity-check the embedding output shape (the recorded output is [32, 128, 768]).
x.shape

torch.Size([32, 128, 768])

In [87]:
# NOTE(review): `word_emb` is not defined in any visible cell, and this cell's execution
# count (87) is lower than the cells above it -- it likely relies on leftover kernel state
# and would fail under Restart & Run All. Confirm and define it or remove the cell.
word_emb.shape

torch.Size([32, 128, 768])

In [88]:
# Inspect the collated input ids. Per the special-token notes earlier in this notebook,
# 101 is [CLS], 103 is [MASK] and 0 is [PAD].
batch["input_ids"]

tensor([[  101,  2002,  9826,  ...,     0,     0,     0],
        [  101,  2002,   103,  ...,     0,     0,     0],
        [  101,  2009,  2001,  ...,     0,     0,     0],
        ...,
        [  101,  1036,   103,  ...,     0,     0,     0],
        [  101, 10882,  2721,  ...,     0,     0,     0],
        [  101, 10021,  4841,  ...,     0,     0,     0]])

In [95]:
# NOTE(review): neither `pos_emb` nor `word_emb` is defined in any visible cell, so this
# appears to depend on leftover kernel state. It also dumps a full tensor as output;
# consider displaying `.shape` or a small slice instead -- TODO confirm intent.
pos_emb + word_emb

tensor([[[-1.3491e+00,  1.2063e-01,  9.9072e-01,  ...,  3.0361e+00,
          -1.9730e-01, -6.8039e-01],
         [ 1.3265e-01, -1.0415e-01,  5.6475e-01,  ..., -3.9019e-01,
           1.1565e+00,  1.3714e+00],
         [-1.6426e+00, -1.5759e+00,  1.1122e+00,  ...,  1.4033e+00,
          -3.3764e-01, -2.3891e-01],
         ...,
         [ 8.7195e-01, -8.6577e-01,  8.3622e-01,  ..., -5.8911e-02,
           5.4693e-01, -5.4651e-01],
         [-4.0395e-01,  1.6512e-01, -2.9788e-01,  ...,  6.4924e-01,
           4.7402e-02, -1.7359e+00],
         [-1.1760e+00,  8.9313e-02, -4.4666e-01,  ..., -1.0645e+00,
           1.2450e+00,  5.3799e-01]],

        [[-1.3491e+00,  1.2063e-01,  9.9072e-01,  ...,  3.0361e+00,
          -1.9730e-01, -6.8039e-01],
         [ 1.3265e-01, -1.0415e-01,  5.6475e-01,  ..., -3.9019e-01,
           1.1565e+00,  1.3714e+00],
         [-1.3423e+00, -1.6372e-01, -1.8336e-01,  ...,  4.3507e-01,
           5.5937e-01, -1.0800e+00],
         ...,
         [ 8.7195e-01, -8