In [3]:
from transformers import CONFIG_MAPPING, BertForMaskedLM, BertModel, BertPreTrainedModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

In [7]:
ACT2FN = {
    "relu": F.relu,
    "silu": F.silu,
    "swish": F.silu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
    "sigmoid": torch.sigmoid,
}

In [4]:
bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [5]:
config = CONFIG_MAPPING['bert']()
config.update(bert_tiny)
config.projector='128-128'
config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "projector": "128-128",
  "transformers_version": "4.8.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [6]:
class Projector(nn.Module):
    def __init__(self, config,input_size,output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size,bias=False)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(output_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

In [8]:
def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [18]:
class BertForBarlowTwins(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    
    def _init_weights(self, module):
        """Initialize the weights"""
#         pdb.set_trace()
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
            
        self.bert = BertModel(config, add_pooling_layer=True)
        
        sizes = [config.hidden_size] + list(map(int, config.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(Projector(config,sizes[i], sizes[i + 1]))
        layers.append(Projector(config,sizes[-2], sizes[-1]))
        self.projector = nn.Sequential(*layers)
        self.bn = nn.BatchNorm1d((128,128), affine=False)
        
        self.apply(self._init_weights)
        
#     def get_output_embeddings(self):
#         return self.cls.decoder

#     def set_output_embeddings(self, new_embeddings):
#         self.cls.decoder = new_embeddings
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
#         pdb.set_trace()
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        projection = self.projector(outputs.pooler_output)
        
        return projection#,self.bn(projection)

In [20]:
model = BertForBarlowTwins(config)

In [21]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')#, padding=True,truncation=True)

In [27]:
input_1 = tokenizer("The capital of France is [MASK].", return_tensors="pt",padding='max_length',truncation=True)
input_2 = tokenizer("I am going to Berlin.", return_tensors="pt",padding='max_length',truncation=True)

In [32]:
out_2=model(**input_2)

In [54]:
class BarlowBert(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        
        self.model = BertForBarlowTwins(config)
        self.scale_loss = 1/32
        self.lambd = 3.9e-3
        
    def forward(self, y1, y2):
        output1 = self.model(**y1)
        output2 = self.model(**y2)
        
        c = (output1.transpose(0,1) @ output2)
        
        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.scale_loss)
        off_diag = off_diagonal(c).pow_(2).sum().mul(self.scale_loss)
        loss = on_diag + self.lambd * off_diag
        return loss      

In [47]:
c=(out_2.T@out_2)

In [51]:
torch.diagonal(c).add_(-1).pow_(2).sum()

tensor(238.8449, grad_fn=<SumBackward0>)

In [53]:
off_diagonal(c).pow_(2).sum()

tensor(16017.1543, grad_fn=<SumBackward0>)

In [55]:
barlow = BarlowBert(config)

In [56]:
barlow(input_1,input_2)

tensor(24.7527, grad_fn=<AddBackward0>)