In [31]:
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
import torch
from transformers import PreTrainedTokenizerBase, BertTokenizer, BatchEncoding, BertModel, BertConfig
from copy import copy
import pdb
from datasets import load_dataset
from tokenizers import BertWordPieceTokenizer

In [15]:
corpus = load_dataset('bookcorpus',split='train')

Reusing dataset bookcorpus (/mounts/data/corp/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700)


In [32]:
corpus[0]['text']

'the half-ling book one in the fall of igneeria series kaylee soderburg copyright 2013 kaylee soderburg all rights reserved .'

In [17]:
def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None):
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]

    # Check if padding is necessary.
    length_of_first = examples[0].size(0)
    are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
    if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
        return torch.stack(examples, dim=0)

    # If yes, check if we have a `pad_token`.
    if tokenizer._pad_token is None:
        raise ValueError(
            "You are attempting to pad samples but the tokenizer you are using"
            f" ({tokenizer.__class__.__name__}) does not have a pad token."
        )

    # Creating the full tensor and filling it with our data.
    max_length = max(x.size(0) for x in examples)
    if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
        max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
    result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
    for i, example in enumerate(examples):
        if tokenizer.padding_side == "right":
            result[i, : example.shape[0]] = example
        else:
            result[i, -example.shape[0] :] = example
    return result

In [18]:
@dataclass
class DataCollatorForLanguageModeling:
    """
    Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
    are not all of the same length.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to use masked language modeling. If set to :obj:`False`, the labels are the same as the
            inputs with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for
            non-masked tokens and the value to predict for the masked token.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which to (randomly) mask tokens in the input, when :obj:`mlm` is set to :obj:`True`.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.

    .. note::

        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # Handle dict or lists with proper padding and conversion to tensor.
        if isinstance(examples[0], (dict, BatchEncoding)):
#             pdb.set_trace()
            batch = self.tokenizer.pad(examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of)
        else:
            batch = {"input_ids": _collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)}
#         pdb.set_trace()
        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)
        batch_1 = copy(batch)
        batches=[]
        

        batch["input_ids"], masked_indices_1 = self.mask_tokens(batch["input_ids"], special_tokens_mask=special_tokens_mask)
        batches.append(batch)
        batch_1["input_ids"], masked_indices_2 = self.mask_tokens(batch_1["input_ids"], special_tokens_mask=special_tokens_mask)
        batches.append(batch_1)
        
        batches.append({'masked_indices': masked_indices_1 | masked_indices_2})
            
        return batches

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
#         labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
#         pdb.set_trace()
        return inputs, masked_indices

In [19]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', padding=True,truncation=True)

In [20]:
tokenizer2 = BertWordPieceTokenizer('bert-base-uncased-vocab.txt')

In [21]:
collator = DataCollatorForLanguageModeling(tokenizer,mlm_probability=0.8)

In [22]:
inputs = [tokenizer(corpus[i]['text'],return_tensors="pt",padding='max_length',max_length=32,return_special_tokens_mask=True,return_token_type_ids=False) for i in range(4) ]

In [23]:
tokenizer2.encode("The capital of France is Paris.",add_special_tokens=True)

Encoding(num_tokens=9, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [24]:
input_1 = tokenizer("The capital of France is Paris.", return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)
input_2 = tokenizer("I am going to Berlin.", return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)
input_3 = tokenizer("It's a beautiful day today", return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)
input_4 = tokenizer("Where are you going so early in the morning?", return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)

In [28]:
tokenizer.pad([input_1,input_2,input_3])

{'input_ids': tensor([[[ 101, 1996, 3007, 1997, 2605, 2003, 3000, 1012,  102,    0,    0,
             0,    0,    0,    0,    0]],

        [[ 101, 1045, 2572, 2183, 2000, 4068, 1012,  102,    0,    0,    0,
             0,    0,    0,    0,    0]],

        [[ 101, 2009, 1005, 1055, 1037, 3376, 2154, 2651,  102,    0,    0,
             0,    0,    0,    0,    0]]]), 'special_tokens_mask': tensor([[[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]],

        [[1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]],

        [[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]]), 'attention_mask': tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],

        [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]],

        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]]])}

In [26]:
# input_1 = tokenizer(corpus[0]['text'], return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)
# input_2 = tokenizer(corpus[1]['text'], return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)
# input_3 = tokenizer(corpus[2]['text'], return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)
# input_4 = tokenizer(corpus[3]['text'], return_tensors="pt",padding='max_length',max_length=16,return_special_tokens_mask=True,return_token_type_ids=False)

In [25]:
input_1

{'input_ids': tensor([[ 101, 1996, 3007, 1997, 2605, 2003, 3000, 1012,  102,    0,    0,    0,
            0,    0,    0,    0]]), 'special_tokens_mask': tensor([[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]])}

In [26]:
[input_1,input_2]

[{'input_ids': tensor([[ 101, 1996, 3007, 1997, 2605, 2003, 3000, 1012,  102,    0,    0,    0,
             0,    0,    0,    0]]), 'special_tokens_mask': tensor([[1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]])},
 {'input_ids': tensor([[ 101, 1045, 2572, 2183, 2000, 4068, 1012,  102,    0,    0,    0,    0,
             0,    0,    0,    0]]), 'special_tokens_mask': tensor([[1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]])}]

In [29]:
inp_collated = collator([input_1,input_2])

In [30]:
inp_collated

[{'input_ids': tensor([[[  101,   103, 16156,   103,   103,   103,   103,   103,   102,     0,
               0,     0,     0,     0,     0,     0]],
 
         [[  101,   103,    50,   103,   103,   103,   103,   102,     0,     0,
               0,     0,     0,     0,     0,     0]]]), 'attention_mask': tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
 
         [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]])},
 {'input_ids': tensor([[[  101,   103, 16156,   103,   103,   103,   103,   103,   102,     0,
               0,     0,     0,     0,     0,     0]],
 
         [[  101,   103,    50,   103,   103,   103,   103,   102,     0,     0,
               0,     0,     0,     0,     0,     0]]]), 'attention_mask': tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]],
 
         [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]])},
 {'masked_indices': tensor([[[False,  True,  True,  True,  True,  True,  True,  True, False, False,
            False, False, False,

In [29]:
inp_collated[0].keys()

dict_keys(['input_ids', 'attention_mask'])

In [30]:
inp_collated[1]['input_ids']

tensor([[[101, 103, 103, 103, 103, 103, 103, 103, 102,   0,   0,   0,   0,   0,
            0,   0]],

        [[101, 103, 103, 103, 103, 103, 103, 102,   0,   0,   0,   0,   0,   0,
            0,   0]]])

In [31]:
inp_collated[0]['input_ids']

tensor([[[101, 103, 103, 103, 103, 103, 103, 103, 102,   0,   0,   0,   0,   0,
            0,   0]],

        [[101, 103, 103, 103, 103, 103, 103, 102,   0,   0,   0,   0,   0,   0,
            0,   0]]])

In [32]:
inp_collated[2]['masked_indices']

tensor([[[False,  True,  True,  True,  True,  True,  True,  True, False, False,
          False, False, False, False, False, False]],

        [[False,  True,  True,  True,  True,  True,  True, False, False, False,
          False, False, False, False, False, False]]])

## BERT Model

In [33]:
config = BertConfig()

In [34]:
bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [None]:
config.update(bert_tiny)

In [None]:
bert_model = BertModel(config, add_pooling_layer=False)

In [None]:
inp = inp_collated[0]
out1 = bert_model(input_ids = inp['input_ids'].squeeze(),
                  attention_mask = inp['attention_mask'].squeeze())

In [None]:
inp = inp_collated[1]
out2 = bert_model(input_ids = inp['input_ids'].squeeze(),
                  attention_mask = inp['attention_mask'].squeeze())

In [None]:
out1.last_hidden_state.shape

In [None]:
out2.last_hidden_state.shape

In [None]:
masked_indices = inp_collated[2]['masked_indices']

In [None]:
first = out1.last_hidden_state.view(-1,128)[masked_indices.flatten()]

In [None]:
second = out2.last_hidden_state.view(-1,128)[masked_indices.flatten()]

In [None]:
(first.T@second).shape

In [None]:
bert_model