<a href="https://colab.research.google.com/github/PhillipOverloeper/BarlowTwins/blob/main/SVvsSSV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [72]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random

from PIL import Image, ImageFilter, ImageOps

In [73]:
class GaussianBlur(object):
  """
  Apply Gaussian blur with a certain probability p.
  """
  def __init__(self, p):
      self.p = p

  def __call__(self, img):
      if random.random() < self.p:
          sigma = random.random() * 1.9 + 0.1
          return img.filter(ImageFilter.GaussianBlur(sigma))
      else:
          return img

      
class Solarization(object):
  """
  Apply solarization with a certain probability p.
  """
  def __init__(self, p):
      self.p = p

  def __call__(self, img):
      if random.random() < self.p:
          return ImageOps.solarize(img)
      else:
          return img

In [74]:
class Barlow_Transform:
    def __init__(self, train=True, input_height=224):
        self.train = train
        # First augmented image 
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(input_height, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            Solarization(p=0.0),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])
        # Second augmented image
        self.transform_prime = transforms.Compose([
            transforms.RandomResizedCrop(input_height, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.1),
            Solarization(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])

    def __call__(self, x):
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        if self.train == True:
          return y1, y2
        else:
          return y1

In [75]:
num_epochs = 10
batch_size = 32
learning_rate = 1e-3
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Defining the augmentations for the supervised and self-supervised learning
transforms_supervised = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])])
transforms_barlow_twins_train = Barlow_Transform(train=True, input_height = 32)
transforms_barlow_twins_test = Barlow_Transform(train=False, input_height = 32)

# Loading the test and train dataset for supervised learning
trainset_supervised = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_supervised)
trainloader_supervised = torch.utils.data.DataLoader(trainset_supervised, batch_size=batch_size, shuffle=True, num_workers=2)
testset_supervised = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_supervised)
testloader_supervised = torch.utils.data.DataLoader(testset_supervised, batch_size=batch_size, shuffle=False, num_workers=2)

# Loading the test and train dataset for self-supervised learning
trainset_barlow_twins = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_barlow_twins_train)
trainloader_barlow_twins = torch.utils.data.DataLoader(trainset_barlow_twins, batch_size=batch_size, shuffle=True, num_workers=2)
testset_barlow_twins = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_barlow_twins_test)
testloader_barlow_twins = torch.utils.data.DataLoader(testset_barlow_twins, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

  "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [76]:
class Net(nn.Module):
  """
  Model for the learning
  """
  def __init__(self, mode='sl'):
    super().__init__()
    self.mode = mode
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    if self.mode == 'sl':
      x = self.fc3(x)
    return x


  def class_head(self, x):
    with torch.no_grad():
      x = self.pool(F.relu(self.conv1(x)))
      x = self.pool(F.relu(self.conv2(x)))
      x = torch.flatten(x, 1)
      x = F.relu(self.fc1(x))
      x = F.relu(self.fc2(x))
    return self.fc3(x)

In [77]:
class Barlow_Twins(nn.Module):
  """
  Defining the Barlow Twins architecture
  """
  def __init__(self, kerneltype='gauss', param=2):
    super().__init__()
    self.param = param
    self.kerneltype = kerneltype


  def forward(self, x1, x2):
    if self.kerneltype == 'gauss':
      out = torch.exp(-torch.cdist(x1.T, x2.T, p=2)/self.param)
      out = out + 0
    elif self.kerneltype == 'poly':
      x1_norm = x1 / torch.max(x1.norm(dim=1)[:, None], 1e-08 * torch.ones(x1.shape, device='cuda:0'))
      x2_norm = x2 / torch.max(x2.norm(dim=1)[:, None], 1e-08 * torch.ones(x2.shape, device='cuda:0'))
      out = torch.matmul(x1.T, x2).add_(1).pow_(self.param)
      out = out / 2**self.param

    return out


In [78]:
class Barlow_Twins_Loss(nn.Module):
  def __init__(self, lambda_coeff=1e8):
    super().__init__()

    self.lambda_coeff = lambda_coeff


  def off_diagonal_elements(self, x):
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n-1, n+1)[:,1:].flatten()


  def forward(self, x):
    on_diag = torch.diagonal(x).add_(-1).pow_(2).sum()
    off_diag = self.off_diagonal_elements(x).pow_(2).sum()

    return on_diag + self.lambda_coeff * off_diag



In [79]:
net_supervised = Net(mode='sl').to(device)
net_barlow_twins = Net(mode='ssl').to(device)

criterion_supervised = nn.CrossEntropyLoss()
criterion_barlow_twins = Barlow_Twins_Loss(1e-3)
criterion_class_head = nn.CrossEntropyLoss()

optimizer_supervised = optim.SGD(net_supervised.parameters(), lr=learning_rate, momentum=0.9)
optimizer_barlow_twins = optim.SGD(net_barlow_twins.parameters(), lr=learning_rate, momentum=0.9)
optimizer_class_head = optim.SGD(net_supervised.parameters(), lr=learning_rate, momentum=0.9)

barlow_twins = Barlow_Twins('gauss', 50)

In [80]:
def supervised_training():
  for epoch in range(num_epochs):
    for i,data in enumerate(trainloader_supervised):
      inputs, labels = data[0].to(device), data[1].to(device)
   
      optimizer_supervised.zero_grad()

      outputs = net_supervised(inputs)
      loss = criterion_supervised(outputs, labels)
      loss.backward()
      optimizer_supervised.step()

      if i % 200 == 199:
        print('Epoch: ' + str(epoch) + ' Loss: ' + str(loss.item()))


def self_supervised_training():
  for epoch in range(num_epochs):
    for i,data in enumerate(trainloader_barlow_twins):
      (x1, x2), _ = data
      x1 = x1.to(device)
      x2 = x2.to(device)

      optimizer_barlow_twins.zero_grad()

      output1 = net_barlow_twins(x1)
      output2 = net_barlow_twins(x2)
      
      result = barlow_twins(output1, output2)

      loss = criterion_barlow_twins(result)
      loss.backward()
      optimizer_barlow_twins.step()

      if i % 200 == 199:
        print('Epoch: ' + str(epoch) + ' Loss: ' + str(loss.item()))


def class_head_training():
  for epoch in range(num_epochs):
    for i,data in enumerate(trainloader_barlow_twins):
      (x1, x2), labels = data
      x1 = x1.to(device)
      labels = labels.to(device)

      optimizer_class_head.zero_grad()

      outputs = net_barlow_twins.class_head(x1)

      loss = criterion_class_head(outputs, labels)
      loss.backward()
      optimizer_class_head.step()

      if i % 200 == 199:
        print('Epoch: ' + str(epoch) + ' Loss: ' + str(loss.item()))

    


In [84]:
def train_main():
  supervised_training()
  self_supervised_training()
  class_head_training()

train_main()

Epoch: 0 Loss: 2.2892820835113525
Epoch: 0 Loss: 2.275528907775879
Epoch: 0 Loss: 2.245453119277954
Epoch: 0 Loss: 2.023996353149414
Epoch: 0 Loss: 2.1120705604553223
Epoch: 0 Loss: 1.9132088422775269
Epoch: 0 Loss: 2.0560309886932373
Epoch: 1 Loss: 1.773587942123413
Epoch: 1 Loss: 1.2738450765609741
Epoch: 1 Loss: 1.8319664001464844
Epoch: 1 Loss: 1.7463308572769165
Epoch: 1 Loss: 1.8013215065002441
Epoch: 1 Loss: 1.440604329109192
Epoch: 1 Loss: 1.9785887002944946
Epoch: 2 Loss: 1.7868703603744507
Epoch: 2 Loss: 1.8539307117462158
Epoch: 2 Loss: 1.3063477277755737
Epoch: 2 Loss: 1.7682067155838013
Epoch: 2 Loss: 1.5640348196029663
Epoch: 2 Loss: 1.123510479927063
Epoch: 2 Loss: 1.432120442390442
Epoch: 3 Loss: 1.479819893836975
Epoch: 3 Loss: 1.4072133302688599
Epoch: 3 Loss: 1.3873536586761475
Epoch: 3 Loss: 1.2858750820159912
Epoch: 3 Loss: 1.2656748294830322
Epoch: 3 Loss: 1.2321780920028687
Epoch: 3 Loss: 1.1554049253463745
Epoch: 4 Loss: 1.4378881454467773
Epoch: 4 Loss: 1.29370

In [85]:
def supervised_test():
  total = 0
  correct = 0

  with torch.no_grad():
    for data in testloader_supervised:
      images, labels = data[0].to(device), data[1].to(device)

      outputs = net_supervised(images)

      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


def barlow_twins_test():
  total = 0
  correct = 0

  with torch.no_grad():
    for data in testloader_barlow_twins:
      images, labels = data[0].to(device), data[1].to(device)

      outputs = net_barlow_twins.class_head(images)

      _, predicted = torch.max(outputs.data, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

In [86]:
def test_main():
  supervised_test()
  barlow_twins_test()

test_main()

Accuracy of the network on the 10000 test images: 59 %
Accuracy of the network on the 10000 test images: 10 %
