In [88]:
import torch.nn as nn
import torch
import pdb
import torchvision
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageFilter
from torch import nn, optim
import torch

from main import GaussianBlur, Solarization, off_diagonal

In [70]:
class Transform:
    """Produce two independently augmented views of the same input image.

    Two pipelines are kept side by side: ``transform`` and ``transform_prime``.
    They share every step and differ only in the ``p`` passed to
    ``GaussianBlur`` (1.0 vs 0.1) and ``Solarization`` (0.0 vs 0.2).
    Calling the instance returns ``(transform(x), transform_prime(x))``.
    """

    @staticmethod
    def _make_pipeline(blur_p, solarize_p):
        """Assemble one augmentation pipeline.

        ``blur_p`` / ``solarize_p`` are forwarded as ``p`` to ``GaussianBlur``
        and ``Solarization`` (both imported from ``main``; presumably the
        probability of applying the effect -- confirm against ``main``).
        """
        color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                              saturation=0.2, hue=0.1)
        steps = [
            # random crop resized to 224 with bicubic interpolation
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            # colour jitter is itself only applied 80% of the time
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            # convert to a tensor, then normalise each channel
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def __init__(self):
        # First view: GaussianBlur(p=1.0), Solarization(p=0.0).
        self.transform = self._make_pipeline(blur_p=1.0, solarize_p=0.0)
        # Second view: GaussianBlur(p=0.1), Solarization(p=0.2).
        self.transform_prime = self._make_pipeline(blur_p=0.1, solarize_p=0.2)

    def __call__(self, x):
        """Return the pair ``(y1, y2)`` of augmented views of ``x``."""
        # Tuple elements are evaluated left to right, so ``transform`` still
        # draws its random parameters before ``transform_prime`` does.
        return self.transform(x), self.transform_prime(x)

In [81]:
class BarlowTwins(nn.Module):
    """ResNet-50 backbone plus MLP projector trained with a cross-correlation loss.

    Args:
        args: object exposing
            ``projector``  -- dash-separated layer widths, e.g. ``'64-64-64'``;
            ``batch_size`` -- divisor applied to the cross-correlation matrix;
            ``scale_loss`` -- constant factor applied to both loss terms;
            ``lambd``      -- weight of the off-diagonal term.

    ``forward(y1, y2)`` returns ``(loss, c)`` where ``c`` is the square
    cross-correlation matrix between the two batch-normalised projections.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        # Drop the classification head so the backbone returns the features
        # that previously fed ``fc``.
        self.backbone.fc = nn.Identity()

        # projector: Linear -> BatchNorm -> ReLU for every hidden width,
        # followed by a final bias-free Linear layer.
        sizes = [2048] + list(map(int, args.projector.split('-')))
        print(sizes)
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        """Compute the loss for two batches of augmented views.

        Returns:
            ``(loss, c)``: the scalar loss and the normalised cross-correlation
            matrix. ``c`` is returned unmodified by the loss computation.
        """
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix
        c = self.bn(z1).T @ self.bn(z2)

        # Normalise by the configured batch size. In a multi-GPU run the
        # per-process matrices would additionally be summed via all_reduce.
        # NOTE(review): the divisor is args.batch_size, not the actual number
        # of rows in y1 -- keep the two in sync when calling the model.
        c.div_(self.args.batch_size)
#         torch.distributed.all_reduce(c)

        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        # torch.diagonal returns a view of c, so in-place add_/pow_ would
        # overwrite the diagonal of the matrix we return. Use out-of-place ops
        # so callers receive the true cross-correlation matrix.
        on_diag = torch.diagonal(c).add(-1).pow(2).sum().mul(self.args.scale_loss)
        off_diag = off_diagonal(c).pow(2).sum().mul(self.args.scale_loss)
        loss = on_diag + self.args.lambd * off_diag
        return loss,c

In [82]:
class Args:
    """Plain attribute container for the hyperparameters read by the model.

    Stores ``lambd``, ``batch_size``, ``scale_loss`` and ``projector`` exactly
    as given, with no validation or conversion.
    """

    def __init__(self, lambd, batch_size, scale_loss, projector):
        # Single ordered assignment; attributes are set in signature order.
        (self.lambd,
         self.batch_size,
         self.scale_loss,
         self.projector) = (lambd, batch_size, scale_loss, projector)

In [83]:
# Hyperparameters in Args order: lambd=0.005, batch_size=8, scale_loss=1/32,
# projector='64-64-64'.
# NOTE(review): batch_size is 8 here, but the DataLoader further down is built
# with batch_size=2 and BarlowTwins.forward divides c by args.batch_size --
# confirm the mismatch is intended.
args = Args(0.005,8,1/32,'64-64-64')

In [84]:
# Build the model; the constructor prints the projector layer sizes shown below.
model = BarlowTwins(args)

[2048, 64, 64, 64]


In [85]:
# A separate, stock ResNet-50; in the cells shown it is only used to inspect
# its original fc head.
backbone = torchvision.models.resnet50(zero_init_residual=True)

In [86]:
# The unmodified head takes 2048 input features, matching the 2048 used as the
# first projector size in BarlowTwins.
backbone.fc

Linear(in_features=2048, out_features=1000, bias=True)

In [71]:
# Image-folder dataset with Transform() as the per-image transform, so every
# sample's image is the (y1, y2) pair returned by Transform.__call__.
# NOTE(review): absolute, machine-specific path -- consider a configurable data dir.
dataset = torchvision.datasets.ImageFolder('/mounts/data/proj/jabbar/torchvision/imagenet/tiny-imagenet-200/',Transform())

In [72]:
# Number of samples found under the dataset root.
len(dataset)

120000

In [105]:
# Batches of 2 samples (no shuffling requested).
loader = torch.utils.data.DataLoader(dataset,batch_size=2)

In [106]:
# Pull a single batch for inspection.
batch = next(iter(loader))

In [107]:
# batch[1] holds the class labels of the two samples.
batch[1]

tensor([0, 0])

In [108]:
# batch[0] is the pair of augmented views; the first view contains 2 images.
len(batch[0][0])

2

In [109]:
# Second view: 2 images of 3 x 224 x 224.
batch[0][1].shape

torch.Size([2, 3, 224, 224])

In [110]:
# First view: same shape as the second.
batch[0][0].shape

torch.Size([2, 3, 224, 224])

In [111]:
# Unpack the two views and discard the labels.
(y1, y2), _ = batch

In [112]:
y1.shape

torch.Size([2, 3, 224, 224])

In [113]:
# Forward pass on the two views; returns the loss and the cross-correlation matrix.
loss,c = model(y1,y2)

In [114]:
# c is square with side equal to the last projector width (64).
c.shape

torch.Size([64, 64])

In [117]:
# Projected embedding of the first view only (backbone followed by projector).
emb = model.projector(model.backbone(y1))

In [118]:
# One 64-dimensional embedding per image in the batch of 2.
emb.shape

torch.Size([2, 64])