In [2]:
import torch
from torch import nn
from datasets import load_dataset, load_from_disk
from transformers import BertTokenizerFast, BertTokenizer
import functools

import collections
import numpy as np

from transformers import default_data_collator, DataCollatorForWholeWordMask
from torch.utils.data import DataLoader

from bert.model import BertConfig,BertMLM
from bert.train import TrainingConfig

In [3]:
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the cell output.
%matplotlib inline
from IPython.display import display, HTML
# Inject CSS so elements with class "container" take the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

# Load IPython's autoreload extension and select mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
def wrap_tokenize_and_chunk(tokenizer, max_length=32):
    """Build a batch tokenization function bound to ``tokenizer``.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, **options)``. Its
            result must support ``pop("overflow_to_sample_mapping")``.
        max_length: Value forwarded as the tokenizer's ``max_length`` option.

    Returns:
        A function that takes a batch ``examples`` (indexed with ``"text"``)
        and returns the tokenizer's result with the
        ``"overflow_to_sample_mapping"`` entry removed.
    """
    # Padding/truncation options are explained at
    # https://huggingface.co/docs/transformers/main/en/pad_truncation
    tokenizer_options = {
        "truncation": True,
        "max_length": max_length,
        "padding": "max_length",
        "return_overflowing_tokens": True,
    }

    def tokenize_and_chunk(examples):
        encoded = tokenizer(examples["text"], **tokenizer_options)
        # Discard the overflow-to-sample mapping so it is not part of the
        # returned result.
        encoded.pop("overflow_to_sample_mapping")
        return encoded

    return tokenize_and_chunk

#         if tokenizer.is_fast:
#             # word_ids maps each token to the index of the word in the source sentence that it came from
#             result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]

#         # Create a new labels column
#         result["labels"] = result["input_ids"].copy()

#         # Extract mapping between new and old indices
#         sample_map = result.pop("overflow_to_sample_mapping")
#         for key, values in examples.items():
#             result[key] = [values[i] for i in sample_map]
        
#         import ipdb
#         ipdb.set_trace()
#         return result

# Produces the normalized and pre-tokenized list of words that "word_ids" will index into
def pretokenizer(tokenizer, text):
    """Normalize ``text`` and split it with the tokenizer's backend pre-tokenizer.

    Args:
        tokenizer: Object exposing ``backend_tokenizer.normalizer`` and
            ``backend_tokenizer.pre_tokenizer``.
        text: The string to process.

    Returns:
        Whatever ``pre_tokenize_str`` returns for the normalized text.
    """
    backend = tokenizer.backend_tokenizer
    normalized_text = backend.normalizer.normalize_str(text)
    return backend.pre_tokenizer.pre_tokenize_str(normalized_text)

## Given input_ids, can convert those to string tokens, e.g. ##word
# tokenizer.convert_ids_to_tokens(tokenized_dataset[index]["input_ids"])
## Given input_ids, can decode those to the original text (e.g. concat tokens)
# tokenizer.decode(tokenized_dataset[index]["input_ids"])
## Special tokens
# tokenizer.all_special_tokens = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
# tokenizer.all_special_ids = [100, 102, 0, 101, 103]


class TrainingCollator:
    """Collate tokenized examples into a whole-word-masked training batch.

    Combines two collators:

    * ``default_data_collator`` batches the pass-through fields
      (``token_type_ids``, ``attention_mask``) together with an untouched copy
      of each example's ``input_ids`` stored under ``original_input_ids``.
    * ``DataCollatorForWholeWordMask`` produces the masked batch. In the
      notebook's recorded run the merged batch has the keys
      ``token_type_ids``, ``attention_mask``, ``original_input_ids``,
      ``input_ids`` and ``labels``.
    """

    def __init__(self, tokenizer, train_config):
        """Set up the masking collator.

        Args:
            tokenizer: Tokenizer handed to ``DataCollatorForWholeWordMask``.
            train_config: Object whose ``mask_lm_prob`` attribute is used as
                the collator's ``mlm_probability``.
        """
        # Fields copied from every example into the batch unchanged.
        self.pass_through_keys = ["token_type_ids", "attention_mask"]
        self.collator = DataCollatorForWholeWordMask(
            tokenizer,
            mlm=True,
            mlm_probability=train_config.mask_lm_prob,
            return_tensors="pt",
        )

    def __call__(self, examples):
        """Build one batch from a list of example mappings.

        Args:
            examples: Sequence of mappings, each providing ``input_ids`` (an
                object with a ``copy()`` method) and every key listed in
                ``self.pass_through_keys``.

        Returns:
            A dict merging the pass-through batch with the masking collator's
            output; on a key collision the masking collator's value wins.
        """
        pass_through_examples = []
        for example in examples:
            pass_through = {key: example[key] for key in self.pass_through_keys}
            # Keep an independent copy of the ids so the pre-masking sequence
            # stays available alongside the masked one.
            pass_through["original_input_ids"] = example["input_ids"].copy()
            pass_through_examples.append(pass_through)

        # The masking collator receives the full examples; the previous
        # per-example "input_ids"-only list was built but never used.
        return {
            **default_data_collator(pass_through_examples, return_tensors="pt"),
            **self.collator(examples),
        }

In [5]:
# Model configuration built with BertConfig's default arguments.
config = BertConfig()
# Training configuration; "output" is the only argument supplied
# (presumably an output directory — confirm against bert.train.TrainingConfig).
train_config = TrainingConfig("output")

In [6]:
# Load the "train" split of BookCorpus and keep rows 100000..199999 (100k examples).
dataset = load_dataset("bookcorpus", split="train").select(range(100000,200000))
# Disabled experiment: add a per-example character length column and sort the
# dataset by it, longest first.
# dataset = dataset.map(lambda samples: {"text_length": [len(text) for text in samples["text"]]}, batched=True)
# dataset = dataset.sort("text_length", reverse=True)
# print(dataset)

In [7]:
# Pretrained uncased BERT tokenizer.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# The model configuration must agree with the tokenizer's vocabulary size.
assert tokenizer.vocab_size == config.vocab_size

# Tokenize in batches and drop the original columns, so only the tokenizer's
# output fields remain in the dataset.
tokenized_dataset = dataset.map(
    wrap_tokenize_and_chunk(tokenizer, train_config.initial_sequence_length),
    batched=True,
    remove_columns=dataset.column_names,
)
print(tokenized_dataset)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 100000
})


In [8]:
# Shuffled loader over the tokenized dataset, 32 examples per batch.
# TrainingCollator turns each list of examples into the batch dict
# (pass-through fields plus the masking collator's output).
train_dataloader = DataLoader(
    tokenized_dataset,
    shuffle=True,
    batch_size=32,
    collate_fn=TrainingCollator(tokenizer, train_config),
)

In [9]:
# Pull a single batch from a fresh iterator over the loader for inspection.
batch = next(iter(train_dataloader))

In [10]:
# Instantiate the project's BertMLM model (from bert.model) with `config`.
model = BertMLM(config)

In [11]:
# Show which fields the collated batch contains.
batch.keys()

dict_keys(['token_type_ids', 'attention_mask', 'original_input_ids', 'input_ids', 'labels'])

In [12]:
# Forward pass: every field of the batch dict is passed to the model as a
# keyword argument.
output = model(**batch)

In [13]:
# Inspect the output shape. The first dimension of the recorded result matches
# batch_size=32; the other two are presumably sequence length and vocabulary
# size — confirm against BertMLM.
output.shape

torch.Size([32, 128, 30522])