In [1]:
import torch
from bert.data import prepare_pretraining_dataset, TrainingCollator
from transformers import BertTokenizer, BertTokenizerFast, default_data_collator, DataCollatorForWholeWordMask

import random

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, HTML
# Inject a CSS rule that forces elements with class "container" to 100% width
# (presumably widens the notebook page in the classic UI — the class name is UI-specific).
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [107]:
import random
from typing import List, Optional, Any, Tuple
import numpy as np

class DataCollatorForWholeWordMaskDeterministic(DataCollatorForWholeWordMask):
    """Whole-word-mask collator whose output is reproducible across runs.

    Two things are changed relative to the parent collator:

    * ``__call__`` seeds Python's ``random`` module with
      ``random_seed + call_counter`` before delegating, so the n-th call of any
      instance created with the same seed sees the same ``random`` stream
      (presumably the stream the parent's whole-word selection draws from —
      confirm against the installed transformers version).
    * ``torch_mask_tokens`` performs no sampling at all: every position selected
      by ``mask_labels`` is replaced by the mask token.
    """

    def __init__(self, *args, random_seed: int = 0, **kwargs):
        """Create the collator.

        Args:
            *args: forwarded unchanged to ``DataCollatorForWholeWordMask``.
            random_seed: base seed; the seed used for a given call is this value
                plus the number of calls made so far on this instance.
            **kwargs: forwarded unchanged to ``DataCollatorForWholeWordMask``.
        """
        super().__init__(*args, **kwargs)
        self.random_seed = random_seed
        # Number of batches collated so far; gives each batch its own seed.
        self.call_counter = 0

    def __call__(self, features, return_tensors=None):
        """Collate ``features`` under a seed derived from the call index.

        The global ``random`` state is saved before seeding and restored
        afterwards, so collating a batch does not reset the random stream of
        unrelated code running in the same process.

        Args:
            features: the samples to collate, passed through to the parent.
            return_tensors: passed through to the parent.

        Returns:
            Whatever the parent collator returns for this batch.
        """
        seed = self.random_seed + self.call_counter
        self.call_counter += 1

        saved_state = random.getstate()
        random.seed(seed)
        try:
            return super().__call__(features, return_tensors)
        finally:
            # Restore even if collation raised, so a failed batch cannot leave
            # the global generator in the seeded state.
            random.setstate(saved_state)

    def torch_mask_tokens(self, inputs: Any, mask_labels: Any) -> Tuple[Any, Any]:
        """
        Build masked-LM inputs/labels from a precomputed whole-word mask.

        Every position that is non-zero in ``mask_labels`` and is neither a special
        token nor padding is replaced by the tokenizer's mask token. No random
        replacement and no "keep original" case is applied here, which is what
        makes the result a pure function of ``inputs`` and ``mask_labels``.

        Args:
            inputs: tensor of token ids; modified in place and returned.
            mask_labels: tensor whose non-zero entries mark positions to mask;
                modified in place (special-token and padding positions are zeroed).

        Returns:
            ``(inputs, labels)`` where ``labels`` holds the original token id at
            masked positions and -100 everywhere else.

        Raises:
            ValueError: if the tokenizer has no mask token.
        """
        import torch

        if self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
                " --mlm flag if you want to use this tokenizer."
            )
        labels = inputs.clone()
        # The positions to mask were already chosen by the caller; here they are
        # only filtered so that special tokens and padding are never masked.

        probability_matrix = mask_labels

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]

        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        if self.tokenizer.pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)

        masked_indices = probability_matrix.bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        # Always substitute the mask token: no random-token or keep-original branch.
        inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        return inputs, labels

In [4]:
# Load both tokenizer implementations of the same checkpoint: the fast one is
# handed to dataset preparation, the slow one to the masking collator.
tokenizer_fast = BertTokenizerFast.from_pretrained("bert-base-uncased")
tokenizer_slow = BertTokenizer.from_pretrained("bert-base-uncased")

In [5]:
# Build the pretraining dataset with the fast tokenizer.
# NOTE(review): sample_limit presumably caps the number of samples — confirm in bert.data.
dataset = prepare_pretraining_dataset(tokenizer_fast, sample_limit=10000)

Loading dataset shards:   0%|          | 0/41 [00:00<?, ?it/s]

In [108]:
def get_batches(num_batches=5, batch_size=5, random_seed=0):
    """Collate the first batches of the test split with a fresh deterministic collator.

    A new DataLoader and a new collator are created on every call, so the
    collator's call counter starts at zero each time; two calls with the same
    arguments are therefore expected to return identical batches.

    Args:
        num_batches: maximum number of batches to return.
        batch_size: number of samples per batch.
        random_seed: base seed handed to the collator.

    Returns:
        A list of collated batches. It is shorter than ``num_batches`` when the
        split does not contain that many batches.
    """
    from itertools import islice

    collator = DataCollatorForWholeWordMaskDeterministic(
        tokenizer_slow, mlm=True, mlm_probability=0.15,
        return_tensors="pt", random_seed=random_seed,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset["test"],
        batch_size=batch_size,
        shuffle=False,  # fixed sample order, otherwise the runs are not comparable
        num_workers=0,
        collate_fn=collator,
    )
    # islice stops early on a short split instead of leaking StopIteration.
    return list(islice(dataloader, num_batches))

In [109]:
# Collate the same batches twice, each time with a freshly built collator, so
# the two runs can be compared for determinism.
batches1 = get_batches()
print()  # presumably just separates any output emitted by the two runs
batches2 = get_batches()




In [110]:
# Determinism check: both runs must agree batch for batch and tensor for tensor.
# The length check matters because zip() would silently ignore extra batches.
assert len(batches1) == len(batches2), "runs produced a different number of batches"
for batch_idx, (b1, b2) in enumerate(zip(batches1, batches2)):
    assert b1.keys() == b2.keys(), f"batch {batch_idx}: key sets differ"
    # Compare every tensor in the batch (masked inputs and labels alike).
    for key in b1:
        assert torch.equal(b1[key], b2[key]), f"batch {batch_idx}: '{key}' differs between runs"

In [111]:
# Stand-alone loader over the test split with the deterministic collator,
# kept in a variable so its collator and tokenizer can be inspected.
dataloader = torch.utils.data.DataLoader(
    dataset["test"],
    batch_size=5,
    shuffle=False,
    num_workers=0,
    collate_fn=DataCollatorForWholeWordMaskDeterministic(
        tokenizer_slow,
        mlm=True,
        mlm_probability=0.15,
        return_tensors="pt",
        random_seed=0,
    ),
)

In [112]:
# Display the collator attached to the loader; its repr lists the tokenizer and the masking settings.
dataloader.collate_fn

DataCollatorForWholeWordMaskDeterministic(tokenizer=BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
), mlm=True, mlm_probability=0.15, mask_replace_prob=0.8, random_replace_prob=0.

In [113]:
# Display the tokenizer held by the collator.
dataloader.collate_fn.tokenizer

BertTokenizer(name_or_path='bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)