In [1]:
from transformers import CONFIG_MAPPING, BertForMaskedLM, BertModel, BertPreTrainedModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

In [2]:
num_param = lambda model : sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6

In [3]:
ACT2FN = {
    "relu": F.relu,
    "silu": F.silu,
    "swish": F.silu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
    "sigmoid": torch.sigmoid,
}

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
bert_small = {
    "hidden_size" : 512 ,
    "num_hidden_layers" : 4,
    "num_attention_heads": int(512/64),
    "intermediate_size" : int(512*4)
}

bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [6]:
config = CONFIG_MAPPING['bert']()
config.update(bert_tiny)
config.projector='128-128'
config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "projector": "128-128",
  "transformers_version": "4.8.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [None]:
class Pooling(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mean_pooling = config.mean_pooling
        self.max_pooling = config.max_pooling
        self.cls_pooling = config.cls_pooling

    def forward(self, bert_encoder_output):
        attention_mask = bert_encoder_output.attentions
        token_embeddings = bert_encoder_output.last_hidden_state

        output_vectors = []

        input_mask_expanded = attention_mask.unsqueeze(-1).expand_like(token_embeddings).float()
        sum_mask = input_mask_expanded.sum(1)

        if self.cls_pooling:
            first_token_tensor = token_embeddings[:, 0]
            output_vectors.append(first_token_tensor)
        if self.mean_pooling:
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            output_vectors.append(sum_embeddings / sum_mask)
        if self.max_pooling:
            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)    
        output_vector = torch.cat(output_vectors, 1)

        bert_encoder_output.update({'sentence_embedding': output_vector})
        return bert_encoder_output

In [7]:
class Projector(nn.Module):
    def __init__(self, config,input_size,output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size,bias=False)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(output_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

In [8]:
def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [99]:
class BertForBarlowTwins(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    
    def _init_weights(self, module):
        """Initialize the weights"""
#         pdb.set_trace()
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
            
        self.bert = BertModel(config, add_pooling_layer=False)
        
        sizes = [config.hidden_size] + list(map(int, config.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(Projector(config,sizes[i], sizes[i + 1]))
        layers.append(Projector(config,sizes[-2], sizes[-1]))
        self.projector = nn.Sequential(*layers)
        self.bn = nn.BatchNorm1d((128,128), affine=False)
        
        self.apply(self._init_weights)
        
#     def get_output_embeddings(self):
#         return self.cls.decoder

#     def set_output_embeddings(self, new_embeddings):
#         self.cls.decoder = new_embeddings
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        projection = self.projector(outputs.last_hidden_state)
        
        return self.bn(projection)

In [100]:
class MyBarlow(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        
        self.model = BertForBarlowTwins(config)
        self.scale_loss = 1/32
        self.lambd = 3.9e-3
        
    def forward(self, y1, y2):
        z1 = self.model(**y1)
        z2 = self.model(**y2)
        
        c = (output1.transpose(1,2) @ output2).sum(0)
        
        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.scale_loss)
        off_diag = off_diagonal(c).pow_(2).sum().mul(self.scale_loss)
        loss = on_diag + self.lambd * off_diag
        return loss        

In [101]:
bert = BertForBarlowTwins(config)

In [102]:
nn.BatchNorm1d(tuple(map(int,config.projector.split('-'))))

BatchNorm1d((128, 128), eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [103]:
input_1 = tokenizer("The capital of France is [MASK].", return_tensors="pt",padding='max_length')
input_2 = tokenizer("I am going to Berlin.", return_tensors="pt",padding='max_length')

In [104]:
model = MyBarlow(config)

In [105]:
model(input_1,input_2)

RuntimeError: running_mean should contain 512 elements not 16384

In [106]:
output1 = bert(**input_1)
output2 = bert(**input_2)

RuntimeError: running_mean should contain 512 elements not 16384

In [65]:
output1.shape, output2.shape

(torch.Size([1, 512, 128]), torch.Size([1, 512, 128]))

In [30]:
(output1.transpose(1,2) @ output2).shape

torch.Size([1, 128, 128])

In [20]:
torch.matmul(output1.transpose(1,2), output1).size()

torch.Size([1, 256, 256])

In [42]:
a=torch.rand(5,9,1)
b=torch.rand(5,1,9)

In [43]:
(a@b).sum(0)-torch.sum(a@b,dim=0)

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
(torch.rand(5,9,1)@torch.rand(5,1,9)).sum

In [24]:
tokenizer.decode(input_1['input_ids'][0])

'[CLS] the capital of france is [MASK]. [SEP]'

In [23]:
input_1['input_ids'][0]

tensor([ 101, 1996, 3007, 1997, 2605, 2003,  103, 1012,  102])

In [27]:
bert_lm = BertForMaskedLM(config=config)

In [28]:
out_1_lm = bert_lm(**input_1)

In [31]:
out_1_lm.logits.shape

torch.Size([1, 9, 30522])

In [32]:
bert_lm

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=Tr