In [1]:
from transformers import CONFIG_MAPPING, BertForMaskedLM, BertModel, BertPreTrainedModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

In [2]:
num_param = lambda model : sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6

In [3]:
ACT2FN = {
    "relu": F.relu,
    "silu": F.silu,
    "swish": F.silu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
#     "gelu_new": gelu_new,
#     "gelu_fast": gelu_fast,
#     "mish": mish,
#     "linear": linear_act,
    "sigmoid": torch.sigmoid,
}

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
bert_small = {
    "hidden_size" : 512 ,
    "num_hidden_layers" : 4,
    "num_attention_heads": int(512/64),
    "intermediate_size" : int(512*4)
}

bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [6]:
config = CONFIG_MAPPING['bert']()
config.update(bert_tiny)
config.projector='256-256-256'
config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "projector": "256-256-256",
  "transformers_version": "4.5.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [7]:
class Projector(nn.Module):
    def __init__(self, config,input_size,output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size,bias=False)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(output_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

In [8]:
class BertForBarlowTwins(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    
    def _init_weights(self, module):
        """Initialize the weights"""
#         pdb.set_trace()
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
            
        self.bert = BertModel(config, add_pooling_layer=False)
        
        sizes = [config.hidden_size] + list(map(int, config.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(Projector(config,sizes[i], sizes[i + 1]))
        layers.append(Projector(config,sizes[-2], sizes[-1]))
        self.projector = nn.Sequential(*layers)
        
        self.apply(self._init_weights)
        
#     def get_output_embeddings(self):
#         return self.cls.decoder

#     def set_output_embeddings(self, new_embeddings):
#         self.cls.decoder = new_embeddings
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        return self.projector(outputs.last_hidden_state)

In [11]:
projector = Projector(config,256,256)

In [12]:
bert = BertForBarlowTwins(config)

In [13]:
bert

BertForBarlowTwins(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine

In [34]:
input_1 = tokenizer("The capital of France is [MASK].", return_tensors="pt",padding='max_length')
input_2 = tokenizer("I am going to Berlin.", return_tensors="pt",padding='max_length')

In [35]:
output1 = bert(**input_1)
output2 = bert(**input_2)

In [36]:
output1.shape, output2.shape

(torch.Size([1, 512, 256]), torch.Size([1, 512, 256]))

In [20]:
output1.T @ output2

RuntimeError: Expected batch2_sizes[0] == bs && batch2_sizes[1] == contraction_size to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

In [17]:
output1.T.shape

torch.Size([256, 9])

In [26]:
torch.rand(9,1)@torch.rand(1,9)

tensor([[2.4702e-01, 2.5387e-01, 7.9965e-02, 1.1321e-01, 1.2602e-01, 1.9818e-01,
         7.6304e-03, 6.9384e-02, 4.3218e-07],
        [8.1785e-01, 8.4051e-01, 2.6475e-01, 3.7483e-01, 4.1721e-01, 6.5615e-01,
         2.5263e-02, 2.2972e-01, 1.4309e-06],
        [2.1965e-01, 2.2574e-01, 7.1105e-02, 1.0067e-01, 1.1205e-01, 1.7623e-01,
         6.7850e-03, 6.1696e-02, 3.8430e-07],
        [8.2524e-01, 8.4810e-01, 2.6714e-01, 3.7822e-01, 4.2098e-01, 6.6208e-01,
         2.5491e-02, 2.3179e-01, 1.4438e-06],
        [3.2830e-01, 3.3740e-01, 1.0628e-01, 1.5047e-01, 1.6748e-01, 2.6339e-01,
         1.0141e-02, 9.2213e-02, 5.7438e-07],
        [1.8539e-01, 1.9053e-01, 6.0013e-02, 8.4967e-02, 9.4573e-02, 1.4874e-01,
         5.7266e-03, 5.2072e-02, 3.2435e-07],
        [7.1850e-01, 7.3840e-01, 2.3259e-01, 3.2930e-01, 3.6653e-01, 5.7644e-01,
         2.2194e-02, 2.0181e-01, 1.2570e-06],
        [8.1185e-01, 8.3434e-01, 2.6281e-01, 3.7208e-01, 4.1415e-01, 6.5134e-01,
         2.5078e-02, 2.2803e-0

In [24]:
tokenizer.decode(input_1['input_ids'][0])

'[CLS] the capital of france is [MASK]. [SEP]'

In [23]:
input_1['input_ids'][0]

tensor([ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102])

In [27]:
bert_lm = BertForMaskedLM(config=config)

In [28]:
out_1_lm = bert_lm(**input_1)

In [31]:
out_1_lm.logits.shape

torch.Size([1, 9, 30522])

In [32]:
bert_lm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=Tr