# Testing a Barlow Twins-trained ResNet-50 with a classifier head on CIFAR-10
For this test, we build a DNN that uses the Barlow Twins-pretrained ResNet-50 (from the `facebookresearch/barlowtwins` repository) as its backbone and a single fully connected layer as its classifier head.

In [None]:
class BarlowTwins(nn.Module):
    """ResNet-50 backbone pretrained with Barlow Twins plus a linear classifier head.

    The backbone is fetched with ``torch.hub.load`` from
    ``facebookresearch/barlowtwins:main``; this needs network access (or a
    populated torch.hub cache) when the model is constructed.

    Args:
        num_classes: number of logits produced by the classifier head
            (defaults to 10, the number of CIFAR-10 classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50', pretrained=True)
        # NOTE(review): the hub entry point appears to return a torchvision
        # ResNet-50 whose final ``fc`` layer (2048 -> 1000) is not part of the
        # Barlow Twins checkpoint, i.e. it is randomly initialised. Feeding the
        # head through that random projection discards information from the
        # pretrained features, so it is replaced by an identity and the head
        # reads the pooled backbone features directly. Confirm against the
        # repository's hubconf.py.
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return classifier logits (last dimension ``num_classes``) for input ``x``."""
        features = self.backbone(x)
        return self.fc(features)

First, we construct the Barlow Twins model. Then, we freeze the backbone so that only the classifier head is trained.

In [None]:
BarTwModel = BarlowTwins()
# Freeze every backbone parameter so that only the classifier head is trained.
# ``requires_grad_`` is a method and must be *called*: assigning ``False`` to it
# merely shadows the method on this instance and freezes nothing.
BarTwModel.backbone.requires_grad_(False)

Next, we define the training transforms and load the CIFAR-10 training set.

In [None]:
# Training-time preprocessing: random horizontal flip, conversion to a float
# tensor, then per-channel normalisation with the usual ImageNet mean/std.
TrainTransforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# CIFAR-10 training split, downloaded into data/cifar10 if not already present.
trainset = datasets.CIFAR10(
    root="data/cifar10",
    train=True,
    download=True,
    transform=TrainTransforms,
)

We now write a training function for the classifier head.

In [None]:
def train(model, dataloader, nepochs=1, lr=1e-3, weight_decay=1e-5):
    """Train the trainable parameters of ``model`` with Adam and cross-entropy.

    Only parameters with ``requires_grad=True`` are handed to the optimizer,
    so frozen submodules (e.g. a frozen backbone) keep their weights.
    Every 100 iterations the mean per-sample loss over the samples seen since
    the previous report is printed.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to class logits.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        nepochs: number of passes over ``dataloader``.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Raises:
        ValueError: from ``torch.optim.Adam`` if ``model`` has no trainable
            parameters.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Move the model before building the optimizer so the optimizer is
    # constructed over the parameters that live on ``device``.
    model = model.to(device)
    # NOTE(review): ``train()`` also switches BatchNorm layers of a frozen
    # backbone to training mode, so their running statistics still update.
    # Put such submodules back in ``eval()`` mode if that is not intended.
    model.train()
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=weight_decay,
    )
    criterion = nn.CrossEntropyLoss().to(device)

    running_loss = 0.
    running_samples = 0
    for epoch in range(nepochs):
        for it, (ims, labels) in enumerate(dataloader):
            ims = ims.to(device)
            labels = labels.to(device)
            batch_size = ims.shape[0]

            optimizer.zero_grad()
            out = model(ims)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss returns the batch *mean*; weight it by the batch
            # size so the report is a true per-sample mean. ``.item()`` yields a
            # Python float, so no autograd graph is kept alive between reports.
            running_loss += loss.item() * batch_size
            running_samples += batch_size

            if it % 100 == 0:
                print(f'ep: {epoch}, it: {it}, loss : {running_loss/running_samples:.5f}')
                running_loss = 0.
                running_samples = 0

Finally, we train the classifier head of the Barlow Twins model.

In [None]:
# Shape of one transformed image tensor.
# NOTE(review): not used in this cell; kept in case later cells rely on it.
input_shape = trainset[0][0].shape
# shuffle=True draws mini-batches in a fresh random order every epoch instead
# of always iterating the training set in its stored order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
train(BarTwModel, train_loader, nepochs=10, lr=1e-3)