In [22]:
from transformers import CONFIG_MAPPING, BertForMaskedLM, BertModel, BertPreTrainedModel, BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb
from datasets import load_dataset

In [2]:
ACT2FN = {
    "relu": F.relu,
    "silu": F.silu,
    "swish": F.silu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
    "sigmoid": torch.sigmoid,
}

In [3]:
bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [4]:
config = CONFIG_MAPPING['bert']()
config.update(bert_tiny)
config.projector='128-128'
config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 128,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 2,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "projector": "128-128",
  "transformers_version": "4.2.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [5]:
class Projector(nn.Module):
    def __init__(self, config,input_size,output_size):
        super().__init__()
        self.dense = nn.Linear(input_size, output_size,bias=False)
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = nn.LayerNorm(output_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

In [6]:
def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [7]:
class BertForBarlowTwins(BertPreTrainedModel):

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
    
    def _init_weights(self, module):
        """Initialize the weights"""
#         pdb.set_trace()
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
            
        self.bert = BertModel(config, add_pooling_layer=True)
        
        sizes = [config.hidden_size] + list(map(int, config.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(Projector(config,sizes[i], sizes[i + 1]))
        layers.append(Projector(config,sizes[-2], sizes[-1]))
        self.projector = nn.Sequential(*layers)
        self.bn = nn.BatchNorm1d((128,128), affine=False)
        
        self.apply(self._init_weights)
        
#     def get_output_embeddings(self):
#         return self.cls.decoder

#     def set_output_embeddings(self, new_embeddings):
#         self.cls.decoder = new_embeddings
        
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
#         pdb.set_trace()
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        projection = self.projector(outputs.pooler_output)
        
        return projection#,self.bn(projection)

In [8]:
model = BertForBarlowTwins(config)

In [9]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')#, padding=True,truncation=True)

In [10]:
input_1 = tokenizer("The capital of France is [MASK].", return_tensors="pt",padding='max_length',truncation=True)
input_2 = tokenizer("I am going to Berlin.", return_tensors="pt",padding='max_length',truncation=True)

In [23]:
dataset = load_dataset('bookcorpus',split='train')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1689.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=946.0, style=ProgressStyle(description_…

Reusing dataset bookcorpus (/mounts/data/corp/huggingface/datasets/bookcorpus/plain_text/1.0.0/af844be26c089fb64810e9f2cd841954fd8bd596d6ddd26326e4c70e2b8c96fc)





In [43]:
min_tokenized=dataset.select(range(100)).map(lambda e: tokenizer(e['text'],padding='max_length',truncation=True),remove_columns=['text'])

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [44]:
min_tokenized.set_format('torch')

In [45]:
dataloader = torch.utils.data.DataLoader(min_tokenized,batch_size=16)

In [46]:
batch = next(iter(dataloader))

In [81]:
batch

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[ 101, 1996, 2431,  ...,    0,    0,    0],
         [ 101, 3175, 1024,  ...,    0,    0,    0],
         [ 101, 1045, 4299,  ...,    0,    0,    0],
         ...,
         [ 101, 1045, 2921,  ...,    0,    0,    0],
         [ 101, 2788, 1045,  ...,    0,    0,    0],
         [ 101, 1045, 6476,  ...,    0,    0,    0]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]])}

In [None]:
tokenizer.pad(examples, return_tensors="pt")

In [88]:
z1=model(**batch)
z2=model(**batch)

In [64]:
N, D = z1.size()

In [76]:
diag = torch.eye(D, device=z1.device)

# Barlow Twins Loss

In [65]:
z1m = z1 - z1.mean(dim=0)
z2m = z2 - z2.mean(dim=0)

In [77]:
c=(z1m.T@z2m) / (N - 1)

In [78]:
torch.diagonal(c).add_(-1).pow_(2).sum()/D

tensor(0.9903, grad_fn=<DivBackward0>)

In [79]:
off_diagonal(c).pow_(2).sum()/D

tensor(0.0898, grad_fn=<DivBackward0>)

In [80]:
c[~diag.bool()].pow_(2).sum()/D

tensor(0.0898, grad_fn=<DivBackward0>)

# Covariance Loss

In [69]:
z1m = z1 - z1.mean(dim=0)
z2m = z2 - z2.mean(dim=0)

In [70]:
cov_z1 = (z1m.T @ z1m) / (N - 1)
cov_z2 = (z2m.T @ z2m) / (N - 1)

cov_z1[~diag.bool()].pow_(2).sum() / D + cov_z2[~diag.bool()].pow_(2).sum() / D

tensor(0.3314, grad_fn=<AddBackward0>)

# Variance Loss

In [51]:
eps = 1e-4
std_z1 = torch.sqrt(z1.var(dim=0) + eps)
std_z2 = torch.sqrt(z2.var(dim=0) + eps)
std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))

In [86]:
z1.shape

torch.Size([16, 128])

In [89]:
z1.var(dim=0)

tensor([0.1805, 0.0803, 0.0371, 0.1373, 0.0327, 0.1542, 0.1113, 0.2210, 0.0287,
        0.0155, 0.0603, 0.0897, 0.0307, 0.0185, 0.1443, 0.0727, 0.0291, 0.0170,
        0.0093, 0.1019, 0.0976, 0.1299, 0.0944, 0.1271, 0.1457, 0.0797, 0.0220,
        0.0729, 0.0923, 0.0574, 0.0942, 0.1191, 0.0574, 0.2731, 0.0997, 0.1672,
        0.1743, 0.2080, 0.1008, 0.0405, 0.1845, 0.2886, 0.0398, 0.1831, 0.0500,
        0.1044, 0.0511, 0.0836, 0.0214, 0.0311, 0.0970, 0.0383, 0.5869, 0.0233,
        0.0834, 0.1279, 0.2007, 0.1390, 0.0073, 0.0905, 0.0634, 0.1566, 0.0673,
        0.1147, 0.0595, 0.0555, 0.0945, 0.1078, 0.2833, 0.1474, 0.3221, 0.0831,
        0.1491, 0.0134, 0.0560, 0.0723, 0.0554, 0.2790, 0.0747, 0.1271, 0.1810,
        0.1856, 0.0646, 0.0758, 0.1822, 0.1752, 0.1077, 0.1482, 0.1098, 0.0858,
        0.0497, 0.0699, 0.3350, 0.1854, 0.2708, 0.0408, 0.0538, 0.1051, 0.0951,
        0.0192, 0.1096, 0.1308, 0.2213, 0.0155, 0.0703, 0.0480, 0.0424, 0.1252,
        0.1210, 0.0377, 0.1597, 0.4602, 

In [56]:
F.relu(1 - std_z1)

tensor([0.8065, 0.8827, 0.8593, 0.7133, 0.8951, 0.6573, 0.7037, 0.7396, 0.6810,
        0.8893, 0.7314, 0.5148, 0.8443, 0.8209, 0.6200, 0.8423, 0.8944, 0.8716,
        0.9449, 0.6570, 0.7384, 0.7558, 0.7705, 0.7297, 0.6275, 0.8227, 0.8265,
        0.8472, 0.6689, 0.7906, 0.7722, 0.7049, 0.7035, 0.7905, 0.7418, 0.6278,
        0.6226, 0.7449, 0.5424, 0.7018, 0.6458, 0.6325, 0.8602, 0.8256, 0.8674,
        0.6654, 0.7883, 0.8475, 0.8720, 0.7822, 0.7624, 0.8474, 0.6376, 0.9141,
        0.8179, 0.7264, 0.7338, 0.8114, 0.9535, 0.7120, 0.7395, 0.8275, 0.6705,
        0.8042, 0.9198, 0.8902, 0.7696, 0.7035, 0.6811, 0.7246, 0.5317, 0.6856,
        0.8009, 0.9226, 0.9310, 0.7896, 0.7854, 0.6225, 0.7812, 0.6297, 0.7130,
        0.6941, 0.7906, 0.8219, 0.3729, 0.5896, 0.7334, 0.6301, 0.7397, 0.8049,
        0.7329, 0.8140, 0.5573, 0.7892, 0.6147, 0.7629, 0.8114, 0.7583, 0.7534,
        0.8312, 0.7856, 0.7605, 0.7835, 0.8806, 0.7663, 0.6906, 0.9004, 0.7610,
        0.7407, 0.7974, 0.6445, 0.5366, 

## BarlowBert

In [54]:
class BarlowBert(nn.Module):
    
    def __init__(self, config):
        super().__init__()
        
        self.model = BertForBarlowTwins(config)
        self.scale_loss = 1/32
        self.lambd = 3.9e-3
        
    def forward(self, y1, y2):
        output1 = self.model(**y1)
        output2 = self.model(**y2)
        
        c = (output1.transpose(0,1) @ output2)
        
        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.scale_loss)
        off_diag = off_diagonal(c).pow_(2).sum().mul(self.scale_loss)
        loss = on_diag + self.lambd * off_diag
        return loss      

In [55]:
barlow = BarlowBert(config)

In [56]:
barlow(input_1,input_2)

tensor(24.7527, grad_fn=<AddBackward0>)