In [1]:
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
import torch
from transformers import PreTrainedTokenizerBase, BertTokenizer, BatchEncoding, BertModel, BertConfig
import pdb
from copy import deepcopy
from datasets import load_dataset
from tokenizers import BertWordPieceTokenizer

In [2]:
# Load the 'train' split of the 'bookcorpus' dataset; rows expose a single 'text' field.
corpus = load_dataset('bookcorpus',split='train')

Reusing dataset bookcorpus (/mounts/data/corp/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700)


In [3]:
corpus

Dataset({
    features: ['text'],
    num_rows: 74004228
})

In [4]:
corpus[0]['text']

'the half-ling book one in the fall of igneeria series kaylee soderburg copyright 2013 kaylee soderburg all rights reserved .'

In [12]:
@dataclass
class DataCollatorForBarlowBertWithMLM:
    """
    Data collator that produces two independently masked MLM views of the same batch.

    The examples are padded into one batch with ``tokenizer.pad``; the batch is then duplicated and
    :meth:`mask_tokens` is applied to each copy separately, so the two returned views share
    ``input_ids`` but carry different ``mlm_input_ids`` / ``mlm_labels``.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        mlm (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Only checked in ``__post_init__``: when :obj:`True` the tokenizer must define a mask token.
            Masking is always applied by ``__call__`` regardless of this flag.
        mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
            The probability with which each non-special token is selected for masking.
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
    .. note::
        For best performance, this data collator should be used with a dataset having items that are dictionaries or
        BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
        :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
        argument :obj:`return_special_tokens_mask=True`. When the key is absent the mask is recomputed from the
        token ids.
    """

    tokenizer: PreTrainedTokenizerBase
    mlm: bool = True
    mlm_probability: float = 0.15
    pad_to_multiple_of: Optional[int] = None

    def __post_init__(self):
        """Fail early when MLM is requested but the tokenizer has no mask token."""
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary for masked language modeling. "
                "You should pass `mlm=False` to train on causal language modeling instead."
            )

    def __call__(
        self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Pad ``examples`` into a batch and return two independently masked views of it.

        Args:
            examples: list of dicts / ``BatchEncoding`` objects as produced by the tokenizer.

        Returns:
            A list of two batches. Each holds the padded tokenizer outputs (without
            ``special_tokens_mask``) plus ``mlm_input_ids`` and ``mlm_labels``.

        Raises:
            TypeError: if the examples are not dicts or ``BatchEncoding`` objects.
        """
        # Fail loudly instead of printing and then crashing on an undefined `batch`.
        if not isinstance(examples[0], (dict, BatchEncoding)):
            raise TypeError("This collator only works with dicts or BatchEncoded inputs")

        batch = self.tokenizer.pad(
            examples,
            return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )

        # If special token mask has been preprocessed, pop it from the dict.
        special_tokens_mask = batch.pop("special_tokens_mask", None)

        # Copy before adding the MLM keys so both views start from the same clean batch.
        second_view = deepcopy(batch)
        views = [batch, second_view]
        for view in views:
            view["mlm_input_ids"], view["mlm_labels"] = self.mask_tokens(
                view["input_ids"], special_tokens_mask=special_tokens_mask
            )
        return views

    def mask_tokens(
        self, inputs: torch.Tensor, special_tokens_mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.

        Args:
            inputs: token ids; the tensor passed in is not modified (a clone is masked).
            special_tokens_mask: optional tensor with the same shape as ``inputs`` that is truthy at
                positions which must never be selected. Recomputed from the token ids when ``None``.

        Returns:
            ``(masked_inputs, labels)`` where ``labels`` is -100 everywhere except at the selected
            positions, which hold the original token id.
        """
        labels = inputs.clone()
        inputs = inputs.clone()

        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(inputs.shape, self.mlm_probability)
        if special_tokens_mask is None:
            # Flatten leading dimensions so the tokenizer always receives flat lists of ids,
            # then restore the original shape.
            flat_ids = inputs.reshape(-1, inputs.shape[-1]).tolist()
            special_tokens_mask = torch.tensor(
                [
                    self.tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
                    for ids in flat_ids
                ],
                dtype=torch.bool,
            ).reshape(inputs.shape)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = mask_token_id

        # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

In [6]:
# Tokenizer for 'bert-base-uncased'; padding/truncation are passed per call further down.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [26]:
# Collator that selects tokens with probability 0.1 (overrides the class default of 0.15).
collator = DataCollatorForBarlowBertWithMLM(tokenizer,mlm_probability=0.1)

In [8]:
# Encode the first four corpus sentences, each padded/truncated to 32 tokens.
# The special-tokens mask is requested because the collator pops and uses it.
inputs = [
    tokenizer(
        corpus[idx]['text'],
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=32,
        return_special_tokens_mask=True,
        return_token_type_ids=False,
    )
    for idx in range(4)
]

In [9]:
inputs

[{'input_ids': tensor([[  101,  1996,  2431,  1011, 17002,  2338,  2028,  1999,  1996,  2991,
           1997, 16270, 11510,  2401,  2186, 10905, 10559,  2061,  4063,  4645,
           9385,  2286, 10905, 10559,  2061,  4063,  4645,  2035,  2916,  9235,
           1012,   102]]), 'special_tokens_mask': tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[  101,  3175,  1024, 17332, 24594, 17134,  2581, 21486,  3175,  1011,
           2410,  1024,  4891,  1011, 17332, 24594, 17134,  2581, 22394,  2005,
           2026,  2155,  1010,  2040,  6628,  2033,  2000,  2196,  2644,  3554,
           2005,   102]]), 'special_tokens_mask': tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 1]]), 'attention_mask': t

In [10]:
tokenizer.decode(inputs[0]['input_ids'][0].squeeze().data.numpy(),skip_special_tokens=False)

'[CLS] the half - ling book one in the fall of igneeria series kaylee soderburg copyright 2013 kaylee soderburg all rights reserved. [SEP]'

In [27]:
# Collate the encoded sentences; the collator returns a list of two separately masked views.
out = collator(inputs)

In [15]:
len(out)

2

In [16]:
out[0]['input_ids']

tensor([[[  101,  1996,  2431,  1011, 17002,  2338,  2028,  1999,  1996,  2991,
           1997, 16270, 11510,  2401,  2186, 10905, 10559,  2061,  4063,  4645,
           9385,  2286, 10905, 10559,  2061,  4063,  4645,  2035,  2916,  9235,
           1012,   102]],

        [[  101,  3175,  1024, 17332, 24594, 17134,  2581, 21486,  3175,  1011,
           2410,  1024,  4891,  1011, 17332, 24594, 17134,  2581, 22394,  2005,
           2026,  2155,  1010,  2040,  6628,  2033,  2000,  2196,  2644,  3554,
           2005,   102]],

        [[  101,  1045,  4299,  1045,  2018,  1037,  2488,  3437,  2000,  2008,
           3160,  1012,   102,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0]],

        [[  101,  2732, 11227,  1010,  2047,  2259,  2003,  2025,  1996,  2173,
           2017,  2094,  5987,  2172,  2000,  4148,  1012,   102,     0,     0,
              0,     0,     0,     0,  

In [17]:
out[0]['mlm_input_ids']

tensor([[[  101,   103,   103,   103, 17002,   103,   103,   103,   103,   103,
            103,  4957,   103,  4856,  2186,   103,   103,   103,   103,   103,
            103,   103,   103, 25434,   103, 20353,   103,   103,   103,   103,
           1012,   102]],

        [[  101,   103,   103,   103,   103, 17134,   103,   103,  3175,   392,
            103,   103,   103,  1011,   103,   103,   103,   103,   103,  2005,
           1191,   103,   103, 28837,   103,   103,   103,   103,   103, 19245,
            103,   102]],

        [[  101,   103,   103,   103,   103,   103,   103,  3437,   103,   103,
            103,   103,   102,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0]],

        [[  101,   103,   103,   103,   103,   103,   103,   103,   103,   103,
            103,  2094,   103,   103,   103,   103,   103,   102,     0,     0,
              0,     0,     0,     0,  

In [18]:
tokenizer.decode(out[0]['input_ids'][0].squeeze().data.numpy(),skip_special_tokens=False)

'[CLS] the half - ling book one in the fall of igneeria series kaylee soderburg copyright 2013 kaylee soderburg all rights reserved. [SEP]'

In [19]:
tokenizer.decode(out[0]['mlm_input_ids'][0].squeeze().data.numpy(),skip_special_tokens=False)

'[CLS] [MASK] [MASK] [MASK] ling [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] link [MASK] 1950s series [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]national [MASK] biotechnology [MASK] [MASK] [MASK] [MASK]. [SEP]'

In [20]:
tokenizer.decode(out[1]['mlm_input_ids'][0].squeeze().data.numpy())

'[CLS] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK]„Å£ [MASK]gra [MASK] [MASK] [MASK] ready ethanol [MASK] [MASK] so [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [SEP]'

In [30]:
# Fraction of positions in the first example whose id is unchanged by masking (view 0).
# `.float().mean()` avoids hard-coding the sequence length.
(out[0]['input_ids'][0] == out[0]['mlm_input_ids'][0]).float().mean()

tensor(0.9375)

In [29]:
# Fraction of positions in the first example replaced by [MASK], for each of the two views.
# Uses the tokenizer's mask id and a mean instead of the literals 103 and 32.
(
    (out[0]['mlm_input_ids'][0] == tokenizer.mask_token_id).float().mean(),
    (out[1]['mlm_input_ids'][0] == tokenizer.mask_token_id).float().mean(),
)

(tensor(0.0312), tensor(0.1875))

In [28]:
# Fraction of positions in the first example selected for prediction (label != -100), view 0.
# `.float().mean()` avoids hard-coding the sequence length.
(out[0]['mlm_labels'][0] != -100).float().mean()

tensor(0.0625)

## Dropout BERT

In [164]:
# Start from a default-constructed BertConfig; the cells below override fields on it in place.
config = BertConfig()

In [165]:
# Architecture overrides for a tiny BERT: 128-wide hidden state, 2 layers,
# one attention head per 64 hidden units, and a 4x feed-forward expansion.
bert_tiny = dict(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=128 // 64,
    intermediate_size=128 * 4,
)

In [181]:
# Config overrides that set both dropout probabilities to zero.
no_dropout = dict(
    attention_probs_dropout_prob=0.0,
    hidden_dropout_prob=0.0,
)

In [182]:
# Apply the tiny-architecture overrides to `config`.
config.update(bert_tiny)

In [183]:
# Apply the zero-dropout overrides to `config`.
config.update(no_dropout)

In [185]:
config

BertConfig {
  "attention_probs_dropout_prob": 0.0,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.8.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [186]:
# Build the model directly from `config` (no pretrained checkpoint is loaded here), with the pooling layer.
# NOTE(review): `.eval()` is never called, so the model presumably stays in training mode - confirm.
bert_model = BertModel(config, add_pooling_layer=True)

In [170]:
# Re-encode the same four sentences for a direct model call: token type ids are
# requested and the special-tokens mask is dropped. This rebinds `inputs`.
inputs = [
    tokenizer(
        corpus[idx]['text'],
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=32,
        return_special_tokens_mask=False,
        return_token_type_ids=True,
    )
    for idx in range(4)
]

In [187]:
# First forward pass over the first encoded sentence.
out1 = bert_model(**inputs[0])

In [188]:
# Second forward pass over the identical input, for comparison with `out1`.
out2 = bert_model(**inputs[0])

In [189]:
out1.pooler_output

tensor([[-0.2360,  0.3188,  0.0112,  0.2215,  0.2356, -0.0209,  0.0389, -0.1634,
          0.0339,  0.2450,  0.0752, -0.3388,  0.2422, -0.3530, -0.1269,  0.1579,
         -0.1193,  0.1439, -0.0313, -0.0605,  0.0805,  0.1103,  0.3938,  0.1160,
          0.3036, -0.2942,  0.2283,  0.1398,  0.0811,  0.1837, -0.0860, -0.1693,
         -0.4958,  0.1911,  0.3638, -0.3315, -0.0873, -0.1776, -0.1431, -0.2119,
          0.3633, -0.0500,  0.0549, -0.0465, -0.0073, -0.0145, -0.1697, -0.1467,
         -0.0921,  0.0206, -0.0330, -0.1368, -0.0939, -0.0871, -0.1688, -0.1531,
         -0.1137,  0.0058, -0.1541, -0.3338, -0.0227, -0.0152,  0.0033, -0.0960,
          0.0118, -0.2984,  0.0507,  0.3106, -0.2967,  0.2001, -0.0014, -0.1241,
         -0.1981, -0.0143, -0.1187,  0.2613, -0.1896,  0.1479, -0.2526, -0.3897,
          0.1129,  0.3664, -0.3909, -0.2956, -0.3465, -0.2942, -0.0788,  0.0301,
          0.0533, -0.2214, -0.4570,  0.1595, -0.2155, -0.0009,  0.3324,  0.1567,
          0.0501, -0.0744,  

In [190]:
out2.pooler_output

tensor([[-0.2360,  0.3188,  0.0112,  0.2215,  0.2356, -0.0209,  0.0389, -0.1634,
          0.0339,  0.2450,  0.0752, -0.3388,  0.2422, -0.3530, -0.1269,  0.1579,
         -0.1193,  0.1439, -0.0313, -0.0605,  0.0805,  0.1103,  0.3938,  0.1160,
          0.3036, -0.2942,  0.2283,  0.1398,  0.0811,  0.1837, -0.0860, -0.1693,
         -0.4958,  0.1911,  0.3638, -0.3315, -0.0873, -0.1776, -0.1431, -0.2119,
          0.3633, -0.0500,  0.0549, -0.0465, -0.0073, -0.0145, -0.1697, -0.1467,
         -0.0921,  0.0206, -0.0330, -0.1368, -0.0939, -0.0871, -0.1688, -0.1531,
         -0.1137,  0.0058, -0.1541, -0.3338, -0.0227, -0.0152,  0.0033, -0.0960,
          0.0118, -0.2984,  0.0507,  0.3106, -0.2967,  0.2001, -0.0014, -0.1241,
         -0.1981, -0.0143, -0.1187,  0.2613, -0.1896,  0.1479, -0.2526, -0.3897,
          0.1129,  0.3664, -0.3909, -0.2956, -0.3465, -0.2942, -0.0788,  0.0301,
          0.0533, -0.2214, -0.4570,  0.1595, -0.2155, -0.0009,  0.3324,  0.1567,
          0.0501, -0.0744,  

In [180]:
out1.pooler_output.shape

torch.Size([1, 128])

In [191]:
# Element-wise difference of the pooled outputs from the two passes; all zeros means they are identical.
out1.pooler_output-out2.pooler_output

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<SubBackward0>)

### With dropout set to 0.0, passing the same sentence through the model twice yields identical representations. Hence any difference between two passes of the same input must come from dropout alone.