In [9]:
from datasets import load_dataset, load_from_disk
from transformers import CONFIG_MAPPING, BertModel, BertTokenizerFast
import torch

In [3]:
config = CONFIG_MAPPING['bert'].from_pretrained('bert-base-uncased')

In [4]:
bert = BertModel(config,add_pooling_layer=True).from_pretrained('bert-base-uncased')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
corpus = load_dataset('bookcorpus',split='train')

Reusing dataset bookcorpus (/mounts/data/corp/huggingface/datasets/bookcorpus/plain_text/1.0.0/44662c4a114441c35200992bea923b170e6f13f2f0beb7c14e43759cec498700)


In [10]:
tokenized_corpus = load_from_disk('/mounts/data/proj/jabbar/barlowbert/bookcorpus_20mil_128/')

In [17]:
tokenized_corpus.set_format('torch')

In [7]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

In [42]:
tokenizer.decode(tokenized_corpus[3]['input_ids'])

'[CLS] starlings, new york is not the place youd expect much to happen. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'

In [44]:
bert.train()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [45]:
output1 = bert(**tokenized_corpus[:8]).pooler_output
output2 = bert(**tokenized_corpus[:8]).pooler_output

In [78]:
output1.shape

torch.Size([8, 768])

In [48]:
corr = output1@output2.transpose(0,1)

In [55]:
corr

tensor([[297.4918, 318.6734, 167.6713, 296.1940, 213.6166, 323.7787, 199.3153,
         295.8680],
        [261.3195, 281.3953, 215.0332, 282.9865, 225.3754, 307.4869, 221.2936,
         282.1366],
        [157.5238, 166.2443, 212.6556, 185.5733, 175.3278, 201.3529, 184.8230,
         193.3990],
        [273.4422, 311.7669, 198.9742, 308.6535, 230.3356, 330.9769, 222.9607,
         314.0715],
        [306.1982, 362.0233, 239.9843, 353.9535, 261.9241, 388.3650, 255.9556,
         373.7408],
        [283.5265, 321.2749, 208.2207, 314.8960, 235.7771, 344.4146, 230.8725,
         324.1483],
        [233.1920, 257.7081, 223.5415, 266.9668, 219.5132, 292.3654, 228.0817,
         280.1246],
        [312.9385, 378.4106, 239.3377, 366.3718, 262.3573, 401.9810, 257.0002,
         390.0294]], grad_fn=<MmBackward>)

In [79]:
cos = torch.nn.CosineSimilarity(dim=1)

In [81]:
cos(output1,output2)

tensor([0.9601, 0.9308, 0.9698, 0.9888, 0.8885, 0.9836, 0.9834, 0.9695],
       grad_fn=<DivBackward0>)

In [53]:
def pearson(x1,x2):
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    pearson = cos(x1 - x1.mean(dim=1, keepdim=True), x2 - x2.mean(dim=1, keepdim=True))
    return pearson

In [127]:
def cos_sim(x1,x2):
    batch_size = x1.shape[0]
    x1n = x1/x1.norm(dim=1, keepdim=True)
    x2n = x2/x2.norm(dim=1, keepdim=True)
    cs = x1n@x2n.transpose(0,1)
    return cs.div_(batch_size) 

In [124]:
def corr(x1,x2):
    x1m = (x1 - x1.mean(dim=1, keepdim=True))/x1.norm(dim=1, keepdim=True)
    x2m = (x2 - x2.mean(dim=1, keepdim=True))/x2.norm(dim=1, keepdim=True)
    return x1m@x2m.transpose(0,1)

In [54]:
pearson(output1,output2)

tensor([0.9601, 0.9310, 0.9699, 0.9888, 0.8887, 0.9836, 0.9834, 0.9696],
       grad_fn=<DivBackward0>)

In [125]:
corr(output1,output2)

tensor([[0.9594, 0.9581, 0.5748, 0.9170, 0.8239, 0.9170, 0.7567, 0.8829],
        [0.9265, 0.9303, 0.8097, 0.9632, 0.9555, 0.9574, 0.9232, 0.9255],
        [0.6769, 0.6664, 0.9697, 0.7655, 0.9006, 0.7598, 0.9340, 0.7689],
        [0.9120, 0.9695, 0.7052, 0.9883, 0.9188, 0.9695, 0.8753, 0.9693],
        [0.8681, 0.9570, 0.7233, 0.9635, 0.8883, 0.9671, 0.8543, 0.9806],
        [0.9216, 0.9738, 0.7191, 0.9827, 0.9166, 0.9832, 0.8832, 0.9750],
        [0.8548, 0.8810, 0.8698, 0.9394, 0.9620, 0.9411, 0.9834, 0.9500],
        [0.8401, 0.9471, 0.6831, 0.9443, 0.8425, 0.9478, 0.8123, 0.9690]],
       grad_fn=<MmBackward>)

In [111]:
cs = cos_sim(output1,output2)

In [120]:
torch.diagonal(cs).add_(-1).pow_(2).mean()

tensor(0.0027, grad_fn=<MeanBackward0>)

In [114]:
def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

In [119]:
off_diagonal(cs).pow_(2).mean()

tensor(0.7813, grad_fn=<MeanBackward0>)

In [117]:
cs

tensor([[0.9601, 0.9592, 0.5743, 0.9175, 0.8242, 0.9176, 0.7565, 0.8833],
        [0.9268, 0.9308, 0.8094, 0.9634, 0.9556, 0.9577, 0.9231, 0.9257],
        [0.6768, 0.6662, 0.9698, 0.7654, 0.9006, 0.7597, 0.9340, 0.7688],
        [0.9126, 0.9704, 0.7048, 0.9888, 0.9190, 0.9700, 0.8752, 0.9697],
        [0.8688, 0.9580, 0.7227, 0.9640, 0.8885, 0.9677, 0.8542, 0.9811],
        [0.9221, 0.9745, 0.7187, 0.9830, 0.9167, 0.9836, 0.8831, 0.9753],
        [0.8549, 0.8811, 0.8698, 0.9394, 0.9620, 0.9412, 0.9834, 0.9500],
        [0.8408, 0.9483, 0.6825, 0.9449, 0.8427, 0.9485, 0.8122, 0.9695]],
       grad_fn=<MmBackward>)

In [126]:
cos_sim(output1,output2)-corr(output1,output2)

tensor([[ 6.7735e-04,  1.0337e-03, -5.1188e-04,  5.3900e-04,  2.0713e-04,
          5.6118e-04, -1.3322e-04,  4.8608e-04],
        [ 3.0065e-04,  4.5884e-04, -2.2721e-04,  2.3913e-04,  9.1970e-05,
          2.4921e-04, -5.9068e-05,  2.1565e-04],
        [-9.8884e-05, -1.5086e-04,  7.4685e-05, -7.8738e-05, -3.0160e-05,
         -8.1837e-05,  1.9312e-05, -7.0930e-05],
        [ 5.7793e-04,  8.8149e-04, -4.3654e-04,  4.5967e-04,  1.7667e-04,
          4.7892e-04, -1.1361e-04,  4.1479e-04],
        [ 7.0310e-04,  1.0729e-03, -5.3120e-04,  5.5945e-04,  2.1493e-04,
          5.8240e-04, -1.3834e-04,  5.0473e-04],
        [ 4.7928e-04,  7.3081e-04, -3.6192e-04,  3.8117e-04,  1.4645e-04,
          3.9679e-04, -9.4116e-05,  3.4386e-04],
        [ 4.4465e-05,  6.7830e-05, -3.3677e-05,  3.5465e-05,  1.3709e-05,
          3.6895e-05, -8.8215e-06,  3.1829e-05],
        [ 7.6663e-04,  1.1697e-03, -5.7912e-04,  6.0982e-04,  2.3437e-04,
          6.3515e-04, -1.5086e-04,  5.4997e-04]], grad_fn=<SubBac

In [133]:
cos_sim(output1,output2)

tensor([[0.1200, 0.1199, 0.0718, 0.1147, 0.1030, 0.1147, 0.0946, 0.1104],
        [0.1159, 0.1164, 0.1012, 0.1204, 0.1194, 0.1197, 0.1154, 0.1157],
        [0.0846, 0.0833, 0.1212, 0.0957, 0.1126, 0.0950, 0.1168, 0.0961],
        [0.1141, 0.1213, 0.0881, 0.1236, 0.1149, 0.1212, 0.1094, 0.1212],
        [0.1086, 0.1198, 0.0903, 0.1205, 0.1111, 0.1210, 0.1068, 0.1226],
        [0.1153, 0.1218, 0.0898, 0.1229, 0.1146, 0.1230, 0.1104, 0.1219],
        [0.1069, 0.1101, 0.1087, 0.1174, 0.1203, 0.1176, 0.1229, 0.1188],
        [0.1051, 0.1185, 0.0853, 0.1181, 0.1053, 0.1186, 0.1015, 0.1212]],
       grad_fn=<DivBackward0>)

In [137]:
off_diagonal(cs)

tensor([0.9592, 0.5743, 0.9175, 0.8242, 0.9176, 0.7565, 0.8833, 0.9268, 0.8094,
        0.9634, 0.9556, 0.9577, 0.9231, 0.9257, 0.6768, 0.6662, 0.7654, 0.9006,
        0.7597, 0.9340, 0.7688, 0.9126, 0.9704, 0.7048, 0.9190, 0.9700, 0.8752,
        0.9697, 0.8688, 0.9580, 0.7227, 0.9640, 0.9677, 0.8542, 0.9811, 0.9221,
        0.9745, 0.7187, 0.9830, 0.9167, 0.8831, 0.9753, 0.8549, 0.8811, 0.8698,
        0.9394, 0.9620, 0.9412, 0.9500, 0.8408, 0.9483, 0.6825, 0.9449, 0.8427,
        0.9485, 0.8122], grad_fn=<UnsafeViewBackward>)

In [136]:
off_diagonal(cos_sim(output1,output2))

tensor([0.1199, 0.0718, 0.1147, 0.1030, 0.1147, 0.0946, 0.1104, 0.1159, 0.1012,
        0.1204, 0.1194, 0.1197, 0.1154, 0.1157, 0.0846, 0.0833, 0.0957, 0.1126,
        0.0950, 0.1168, 0.0961, 0.1141, 0.1213, 0.0881, 0.1149, 0.1212, 0.1094,
        0.1212, 0.1086, 0.1198, 0.0903, 0.1205, 0.1210, 0.1068, 0.1226, 0.1153,
        0.1218, 0.0898, 0.1229, 0.1146, 0.1104, 0.1219, 0.1069, 0.1101, 0.1087,
        0.1174, 0.1203, 0.1176, 0.1188, 0.1051, 0.1185, 0.0853, 0.1181, 0.1053,
        0.1186, 0.1015], grad_fn=<UnsafeViewBackward>)