In [1]:
from transformers import CONFIG_MAPPING, BertForMaskedLM, BertModel, BertPreTrainedModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

In [2]:
num_param = lambda model : sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6

In [3]:
ACT2FN = {
    "relu": F.relu,
    "silu": F.silu,
    "swish": F.silu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
#     "gelu_new": gelu_new,
#     "gelu_fast": gelu_fast,
#     "mish": mish,
#     "linear": linear_act,
    "sigmoid": torch.sigmoid,
}

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
bert_small = {
    "hidden_size" : 512 ,
    "num_hidden_layers" : 4,
    "num_attention_heads": int(512/64),
    "intermediate_size" : int(512*4)
}

bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [6]:
config = CONFIG_MAPPING['bert']()
config.update(bert_tiny)
config.projector='256-256-256'
config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "projector": "256-256-256",
  "transformers_version": "4.5.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [8]:
bert_model = BertModel(config=config)
bert_lm = BertForMaskedLM(config=config)

In [9]:
tokenizer('i am going to berlin')

{'input_ids': [101, 1045, 2572, 2183, 2000, 4068, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

In [10]:
tokenizer.decode(tokenizer('i am going to berlin')['input_ids'])

'[CLS] i am going to berlin [SEP]'

In [14]:
class Projector(nn.Module):
    def __init__(self, config,input_size,output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size,bias=False)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(output_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

In [15]:
class BertForBarlowTwins(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    
    def _init_weights(self, module):
        """Initialize the weights"""
#         pdb.set_trace()
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
            
        self.bert = BertModel(config, add_pooling_layer=False)
        
        sizes = [config.hidden_size] + list(map(int, config.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(Projector(config,sizes[i], sizes[i + 1]))
        layers.append(Projector(config,sizes[-2], sizes[-1]))
        self.projector = nn.Sequential(*layers)
        
        self.apply(self._init_weights)
        
#     def get_output_embeddings(self):
#         return self.cls.decoder

#     def set_output_embeddings(self, new_embeddings):
#         self.cls.decoder = new_embeddings
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        return self.projector(outputs.last_hidden_state)

In [18]:
projector = Projector(config,128,128)

In [19]:
bert = BertForBarlowTwins(config)

In [20]:
bert

BertForBarlowTwins(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine

In [41]:
input_1 = tokenizer("The capital of France is [MASK].", return_tensors="pt",padding='max_length',max_length=32)
input_2 = tokenizer("I am going to Berlin.", return_tensors="pt",padding='max_length',max_length=32)

In [42]:
output1 = bert(**input_1)
output2 = bert(**input_2)

In [43]:
output1.shape, output2.shape

(torch.Size([1, 32, 256]), torch.Size([1, 32, 256]))

In [95]:
output1.T @ output2

RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

In [96]:
output1.T.shape

torch.Size([128, 9, 1])

In [30]:
bert_lm = BertForMaskedLM(config=config)

In [97]:
torch.rand(9,1)@torch.rand(1,9)

tensor([[0.1294, 0.2320, 0.1160, 0.2871, 0.2027, 0.1145, 0.0425, 0.1529, 0.2508],
        [0.3112, 0.5580, 0.2789, 0.6905, 0.4875, 0.2755, 0.1022, 0.3677, 0.6032],
        [0.0291, 0.0523, 0.0261, 0.0647, 0.0457, 0.0258, 0.0096, 0.0344, 0.0565],
        [0.0956, 0.1714, 0.0857, 0.2122, 0.1498, 0.0847, 0.0314, 0.1130, 0.1853],
        [0.1932, 0.3463, 0.1731, 0.4286, 0.3026, 0.1710, 0.0634, 0.2282, 0.3744],
        [0.0050, 0.0090, 0.0045, 0.0111, 0.0079, 0.0044, 0.0016, 0.0059, 0.0097],
        [0.0950, 0.1703, 0.0851, 0.2107, 0.1487, 0.0841, 0.0312, 0.1122, 0.1840],
        [0.1233, 0.2211, 0.1105, 0.2737, 0.1932, 0.1092, 0.0405, 0.1457, 0.2391],
        [0.0313, 0.0561, 0.0280, 0.0694, 0.0490, 0.0277, 0.0103, 0.0369, 0.0606]])

In [11]:
bert_model = BertModel(config, add_pooling_layer=False)

In [27]:
bert_model_output = bert_model(**input_1)

In [29]:
bert_model_output.last_hidden_state.shape

torch.Size([1, 9, 128])

In [44]:
input_1

{'input_ids': tensor([[ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]])}