In [31]:
import torch.nn as nn
import torch
import pdb
# import torchvision
from transformers import CONFIG_MAPPING, AutoModelForMaskedLM, BertTokenizer

In [3]:
class BarlowTwins(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()
        # pdb.set_trace()
        # projector
        sizes = [2048] + list(map(int, args.projector.split('-')))
        print(sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix
        c = self.bn(z1).T @ self.bn(z2)

        # sum the cross-correlation matrix between all gpus
        c.div_(self.args.batch_size)
        torch.distributed.all_reduce(c)

        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.args.scale_loss)
        off_diag = off_diagonal(c).pow_(2).sum().mul(self.args.scale_loss)
        loss = on_diag + self.args.lambd * off_diag
        return loss

In [4]:
class Args():
    def __init__(self,lambd, batch_size, scale_loss, projector):
        self.lambd = lambd
        self.batch_size = batch_size
        self.scale_loss = scale_loss
        self.projector = projector

In [5]:
args = Args(0.005,8,1/32,'64-64-64')

In [5]:
model = BarlowTwins(args)

[2048, 64, 64, 64]


In [6]:
backbone = torchvision.models.resnet50(zero_init_residual=True)

In [7]:
backbone.fc

Linear(in_features=2048, out_features=1000, bias=True)

In [6]:
bert_config = CONFIG_MAPPING['bert']()

In [12]:
bert_config

BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.5.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [18]:
bert_small = {
    "hidden_size" : 512 ,
    "num_hidden_layers" : 4,
    "num_attention_heads": int(512/64),
    "intermediate_size" : int(512*4)
}

bert_tiny = {
    "hidden_size" : 128 ,
    "num_hidden_layers" : 2,
    "num_attention_heads": int(128/64),
    "intermediate_size" : int(128*4)
}

In [26]:
bert_config.update(bert_tiny)

In [27]:
model = AutoModelForMaskedLM.from_config(bert_config)

In [32]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [29]:
sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6

4.416698

In [33]:
inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")

In [34]:
labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]

In [35]:
outputs = model(**inputs)

In [38]:
outputs.logits.shape

torch.Size([1, 9, 30522])

In [39]:
labels.shape

torch.Size([1, 9])