# BYOL PyTorch
BYOL implementation using PyTorch
Main Reference: Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning https://arxiv.org/abs/2006.07733

Most of the code is inspired by https://github.com/reshinthadithyan/BYOL-Pytorch and https://github.com/google-deepmind/deepmind-research/tree/master/byol

## Default Training

### Data Augmentations

* Random cropping
* Left-right flip **Optional**
* Color jittering -> the brightness, contrast, saturation and hue of the image are shifted by a random offset applied to all pixels of the image; the order in which these shifts are performed is randomly selected for each patch.
* Color dropping -> Conversion to gray scale
* Gaussian blurring -> **Optional**; for a 224 x 224 image, a square Gaussian kernel of size 23 x 23 is used, with a standard deviation sampled uniformly over `[0.1, 2.0]`. In this notebook the kernel size is 10% of the image side, e.g. 9 x 9 for 96 x 96 inputs.
* Solarization -> **Optional** `x -> x*1{x<0.5} + (1-x)*1{x>=0.5}` for pixels with values in `[0,1]`
    * Solarization is not implemented in this notebook's code.

* **The Gaussian blur is not applied with a ready-made Gaussian-blur function.**
* **Instead, it is implemented as two `conv2d` layers with 3 input channels, 3 output channels and `groups = 3` (one kernel per channel), kernel sizes `(kernel_size, 1)` and `(1, kernel_size)`, no bias, no padding and stride 1, applied after reflection padding.**

## Model

* Consists of two networks:
    * Online network
    * Target network

In [10]:
import os
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
import numpy as np
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F

In [11]:
class GaussianBlur(object):
    """Blur a single PIL image on the CPU with a separable Gaussian kernel.

    The blur is two depthwise ``nn.Conv2d`` passes, a ``(k, 1)`` pass followed
    by a ``(1, k)`` pass, over a reflection-padded tensor. The output therefore
    has the same spatial size as the input. A new standard deviation is drawn
    uniformly from ``[0.1, 2.0]`` on every call.

    Args:
        kernel_size: requested kernel size. It is forced to be odd
            (``2 * (kernel_size // 2) + 1``) so the kernel has a centre tap.
    """

    def __init__(self, kernel_size):
        radius = kernel_size // 2
        kernel_size = radius * 2 + 1
        # groups=3 -> one independent 1-D kernel per channel (depthwise conv).
        self.blur_h = nn.Conv2d(3, 3, kernel_size=(kernel_size, 1),
                                stride=1, padding=0, bias=False, groups=3)
        self.blur_v = nn.Conv2d(3, 3, kernel_size=(1, kernel_size),
                                stride=1, padding=0, bias=False, groups=3)
        self.k = kernel_size
        self.r = radius

        # Padding by `radius` on every side compensates for the 2 * radius
        # pixels each unpadded convolution removes along its axis.
        self.blur = nn.Sequential(
            nn.ReflectionPad2d(radius),
            self.blur_h,
            self.blur_v
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """Return a blurred copy of the PIL image ``img``."""
        # ToTensor gives (C, H, W); add the batch dim that Conv2d expects.
        img = self.pil_to_tensor(img).unsqueeze(0)

        sigma = np.random.uniform(0.1, 2.0)
        x = np.arange(-self.r, self.r + 1)
        x = np.exp(-np.power(x, 2) / (2 * sigma * sigma))
        x = x / x.sum()  # normalise so the blur preserves mean intensity
        # numpy yields float64; convert explicitly to the conv weight dtype.
        x = torch.from_numpy(x).to(self.blur_h.weight.dtype).view(1, -1).repeat(3, 1)

        with torch.no_grad():
            self.blur_h.weight.copy_(x.view(3, 1, self.k, 1))
            self.blur_v.weight.copy_(x.view(3, 1, 1, self.k))
            img = self.blur(img)
            # Remove only the batch dim. A bare squeeze() would also collapse a
            # height or width of 1 and break ToPILImage.
            img = img.squeeze(0)

        return self.tensor_to_pil(img)

In [12]:
def get_simclr_data_transforms(input_shape, s):
    """Build the SimCLR-style augmentation pipeline applied to each view.

    Args:
        input_shape: image shape; only ``input_shape[0]`` (the side length) is
            used, both as the crop size and to size the blur kernel (10% of it).
        s: colour-jitter strength multiplier.

    Returns:
        A ``transforms.Compose`` that maps a PIL image to a tensor.
    """
    side = input_shape[0]
    jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    pipeline = [
        transforms.RandomResizedCrop(size=side),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        GaussianBlur(kernel_size=int(0.1 * side)),
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)

In [13]:
class MultiViewDataInjector(object):
    """Produce several augmented views of a single sample.

    Used as a dataset ``transform``: each call applies every transform in the
    list to the same input and returns the list of results (one per view).

    Args:
        *args: ``args[0]`` is the sequence of per-view transforms; any further
            positional arguments are ignored (kept for signature compatibility).
    """

    def __init__(self, *args):
        self.transforms = args[0]
        self.random_flip = transforms.RandomHorizontalFlip()

    def __call__(self, sample, with_consistent_flipping=False):
        """Return ``[t(sample) for t in self.transforms]``.

        Args:
            sample: input passed to every view transform.
            with_consistent_flipping: when truthy, apply one random horizontal
                flip to ``sample`` before the per-view transforms, so all views
                share the same flip decision.
        """
        # A plain keyword flag: the previous `*with_consistent_flipping` tuple
        # was truthy for ANY extra argument, so passing False still flipped.
        if with_consistent_flipping:
            sample = self.random_flip(sample)
        return [view_transform(sample) for view_transform in self.transforms]

In [14]:
class MLPHead(nn.Module):
    """Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Used both as the projection head of the encoder and as the BYOL predictor.
    The layers live in ``self.net`` (an ``nn.Sequential``), so state-dict keys
    are ``net.0.*``, ``net.1.*`` and ``net.3.*``.

    Args:
        in_channels: size of each input feature vector.
        mlp_hidden_size: width of the hidden layer.
        projection_size: size of each output vector.
    """

    def __init__(self, in_channels, mlp_hidden_size, projection_size):
        super().__init__()
        layers = [
            nn.Linear(in_channels, mlp_hidden_size),
            nn.BatchNorm1d(mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` (assumed ``(batch, in_channels)``) to ``(batch, projection_size)``."""
        out = self.net(x)
        return out

In [15]:
class ResNet18(torch.nn.Module):
    """Randomly initialised torchvision ResNet backbone plus an ``MLPHead`` projector.

    Despite the class name, ``model_name`` selects either ResNet-18 or ResNet-50.

    Args:
        model_name: ``'resnet18'`` or ``'resnet50'``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """

    def __init__(self, model_name):
        super(ResNet18, self).__init__()
        if model_name == 'resnet18':
            resnet = torchvision.models.resnet18(pretrained=False)
        elif model_name == 'resnet50':
            resnet = torchvision.models.resnet50(pretrained=False)
        else:
            # Fail early with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                f"Unsupported model_name {model_name!r}; expected 'resnet18' or 'resnet50'")

        # Drop the last child (the `fc` classifier) and keep the rest as the encoder.
        self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
        # NOTE: the attribute name 'projetion' (sic) is kept because saved
        # state_dict keys and other code (`.projetion.net[...]`) depend on it.
        self.projetion = MLPHead(in_channels=resnet.fc.in_features, mlp_hidden_size=512, projection_size=128)

    def forward(self, x):
        """Encode the image batch ``x`` and return its 128-d projections."""
        h = self.encoder(x)
        # Encoder output is (N, C, 1, 1); flatten to (N, C) for the MLP head.
        h = torch.flatten(h, 1)
        return self.projetion(h)

In [23]:
import os
from shutil import copyfile

def _create_model_training_folder(writer, files_to_same):
    model_checkpoints_folder = os.path.join(writer.log_dir, 'checkpoints')
    if not os.path.exists(model_checkpoints_folder):
        os.makedirs(model_checkpoints_folder)
        '''for file in files_to_same:
            copyfile(file, os.path.join(model_checkpoints_folder, os.path.basename(file)))'''

In [24]:
class BYOLTrainer:
    """Train an online network + predictor against a momentum (EMA) target network.

    Args:
        online_network: network trained by gradient descent.
        target_network: network of the same architecture, updated only as an
            exponential moving average of ``online_network``'s parameters.
        predictor: head applied on top of ``online_network`` outputs.
        optimizer: optimizer over the trainable (online + predictor) parameters.
        device: device the input batches are moved to.
        max_epochs: number of passes over the dataset per ``train`` call.
        m: EMA momentum; ``target <- m * target + (1 - m) * online``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        checkpoint_interval: stored but currently not used by ``train``.
        log_dir: TensorBoard log directory; the final checkpoint is written to
            ``<log_dir>/checkpoints/model.pth``.
    """

    def __init__(self, online_network, target_network, predictor, optimizer, device, max_epochs, m, batch_size, num_workers, checkpoint_interval, log_dir='/kaggle/working/logs'):
        self.online_network = online_network
        self.target_network = target_network
        self.optimizer = optimizer
        self.device = device
        self.predictor = predictor
        self.max_epochs = max_epochs
        self.writer = SummaryWriter(log_dir=log_dir)
        self.m = m
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.checkpoint_interval = checkpoint_interval
        _create_model_training_folder(self.writer, files_to_same=["./config/config.yaml", "main.py", 'trainer.py'])

    @torch.no_grad()
    def _update_target_network_parameters(self):
        """EMA update of the target network, in place: t <- m * t + (1 - m) * o."""
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            # In-place ops avoid allocating a fresh tensor per parameter per step.
            param_k.mul_(self.m).add_(param_q, alpha=1. - self.m)

    @staticmethod
    def regression_loss(x, y):
        """Per-sample BYOL loss ``2 - 2 * cos(x, y)``.

        Rows are L2-normalised along dim 1, so this equals the squared L2
        distance between the normalised vectors (range [0, 4]).
        """
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        return 2 - 2 * (x * y).sum(dim=-1)

    def initializes_target_network(self):
        """Copy online weights into the target network and freeze it for autograd."""
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False  # updated only by the EMA rule

    def _log_views(self, batch_view_1, batch_view_2, step):
        """Write image grids of up to 32 samples from each view to TensorBoard."""
        for tag, batch in (('views_1', batch_view_1), ('views_2', batch_view_2)):
            grid = torchvision.utils.make_grid(batch[:32])
            self.writer.add_image(tag, grid, global_step=step)

    def train(self, train_dataset):
        """Run ``max_epochs`` epochs of BYOL training, then save a checkpoint.

        The target network is re-initialised from the online network at the
        start of every call, and the TensorBoard step counter restarts at 0.

        Args:
            train_dataset: dataset yielding ``((view_1, view_2), label)``;
                labels are ignored.
        """
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size,
                                  num_workers=self.num_workers, drop_last=False, shuffle=True)

        niter = 0
        model_checkpoints_folder = os.path.join(self.writer.log_dir, 'checkpoints')

        self.initializes_target_network()

        for epoch_counter in range(self.max_epochs):
            loss_epoch = 0.0
            for (batch_view_1, batch_view_2), _ in train_loader:
                batch_view_1 = batch_view_1.to(self.device)
                batch_view_2 = batch_view_2.to(self.device)

                if niter == 0:
                    self._log_views(batch_view_1, batch_view_2, niter)

                loss = self.update(batch_view_1, batch_view_2)
                # Use a detached Python float: accumulating the tensor itself would
                # keep every iteration's autograd graph alive for the whole epoch.
                loss_value = loss.item()
                self.writer.add_scalar('loss', loss_value, global_step=niter)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                self._update_target_network_parameters()
                niter += 1

                loss_epoch += loss_value
            # Sum (not mean) of the per-batch losses over this epoch.
            print(f"Epoch: {epoch_counter}, Loss: {loss_epoch}")

            print("End of epoch {}".format(epoch_counter))

        self.writer.flush()
        self.save_model(os.path.join(model_checkpoints_folder, 'model.pth'))

    def update(self, batch_view_1, batch_view_2):
        """Compute the symmetric BYOL loss for one pair of augmented batches.

        The online prediction for each view regresses onto the target
        projection of the other view. Targets are computed without gradients.

        Returns:
            Scalar tensor: batch mean of the two per-sample losses summed.
        """
        predictions_from_view_1 = self.predictor(self.online_network(batch_view_1))
        predictions_from_view_2 = self.predictor(self.online_network(batch_view_2))

        with torch.no_grad():
            targets_to_view_2 = self.target_network(batch_view_1)
            targets_to_view_1 = self.target_network(batch_view_2)

        loss = self.regression_loss(predictions_from_view_1, targets_to_view_1)
        loss += self.regression_loss(predictions_from_view_2, targets_to_view_2)
        return loss.mean()

    def save_model(self, PATH):
        """Save the online, target, predictor and optimizer state dicts to ``PATH``."""
        torch.save({
            'online_network_state_dict': self.online_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            # The predictor is trained too; training cannot be resumed without it.
            'predictor_state_dict': self.predictor.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, PATH)

In [25]:
# Seed torch's RNG before any network is built so weight initialisation is reproducible.
# NOTE(review): numpy's RNG (used by GaussianBlur to draw sigma) is not seeded here.
torch.manual_seed(0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Training with: {device}")

# Augmentation pipeline producing 96x96 crops; s=1 is the colour-jitter strength.
data_transform = get_simclr_data_transforms(s= 1, input_shape= (96,96,3))

# Each sample is turned into two independently augmented views of the same image.
# NOTE(review): the dataset root looks machine-specific — confirm it exists on the runtime host.
train_dataset = torchvision.datasets.STL10('/home/thalles/Downloads/', split='train+unlabeled', download=True,
                               transform=MultiViewDataInjector([data_transform, data_transform]))

# online network
online_network = ResNet18('resnet18').to(device)

# predictor network: its input size equals the projector's output size (128).
predictor = MLPHead(in_channels=online_network.projetion.net[-1].out_features,
                    mlp_hidden_size = 512,
                    projection_size = 128).to(device)

# target encoder: a separate instance whose weights are overwritten with the
# online network's weights when trainer.train() starts.
target_network = ResNet18('resnet18').to(device)

# Only the online network and predictor are optimised by SGD; the target
# network is updated by the trainer's EMA rule (momentum m below).
optimizer = torch.optim.SGD(list(online_network.parameters()) + list(predictor.parameters()),
                            lr = 0.03,
                            momentum = 0.9,
                            weight_decay = 0.0004)

trainer = BYOLTrainer(online_network=online_network,
                      target_network=target_network,
                      optimizer=optimizer,
                      predictor=predictor,
                      device=device,
                      batch_size = 64,
                      m = 0.996,
                      checkpoint_interval = 5000,
                      max_epochs = 15,
                      num_workers = 4)

trainer.train(train_dataset)

Training with: cuda
Files already downloaded and verified
Epoch: 0, Loss: 2038.26025390625
End of epoch 0
Epoch: 1, Loss: 1229.427490234375
End of epoch 1
Epoch: 2, Loss: 1092.4007568359375
End of epoch 2
Epoch: 3, Loss: 1031.49072265625
End of epoch 3
Epoch: 4, Loss: 1002.2159423828125
End of epoch 4
Epoch: 5, Loss: 960.385986328125
End of epoch 5
Epoch: 6, Loss: 950.72314453125
End of epoch 6
Epoch: 7, Loss: 921.6815795898438
End of epoch 7
Epoch: 8, Loss: 917.8095703125
End of epoch 8
Epoch: 9, Loss: 892.4237670898438
End of epoch 9
Epoch: 10, Loss: 893.6664428710938
End of epoch 10
Epoch: 11, Loss: 881.2210083007812
End of epoch 11
Epoch: 12, Loss: 848.5782470703125
End of epoch 12
Epoch: 13, Loss: 827.1861572265625
End of epoch 13
Epoch: 14, Loss: 804.5858154296875
End of epoch 14


In [35]:
# Show the online network's module structure (ResNet encoder + projection head).
print(trainer.online_network)

ResNet18(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [36]:
# Show the target network's module structure (same architecture as the online network).
print(trainer.target_network)

ResNet18(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True

In [37]:
# Show the optimizer's hyper-parameters.
print(trainer.optimizer)

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    lr: 0.03
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0.0004
)


In [34]:
MODEL_PATH = '/kaggle/working/'
MODEL_NAME = 'BYOL_ON_RESNET18_online_network.pth'
# One state_dict file per component, all under MODEL_PATH.
MODEL_SAVE_PATH = os.path.join(MODEL_PATH, MODEL_NAME)
MODEL_TARGET_NETWORK = os.path.join(MODEL_PATH, 'BYOL_ON_RESNET18_target_network.pth')
MODEL_OPTIMIZER = os.path.join(MODEL_PATH, 'BYOL_ON_RESNET18_optimizer.pth')
print(MODEL_SAVE_PATH)
for component, destination in ((trainer.online_network, MODEL_SAVE_PATH),
                               (trainer.target_network, MODEL_TARGET_NETWORK),
                               (trainer.optimizer, MODEL_OPTIMIZER)):
    torch.save(component.state_dict(), destination)

/kaggle/working/BYOL_ON_RESNET18_online_network.pth


## LINEAR CLASSIFIER EVALUATION

In [39]:
data_transforms = torchvision.transforms.Compose([transforms.ToTensor()])
# Labelled splits, converted to tensors without augmentation, for linear evaluation.
train_dataset, test_dataset = (
    torchvision.datasets.STL10('/home/thalles/Downloads/', split=split_name, download=False,
                               transform=data_transforms)
    for split_name in ('train', 'test')
)

In [41]:
batch_size = 512
# Feature-extraction loaders over the labelled splits (both shuffled each pass).
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size, num_workers=0, drop_last=False, shuffle=True)
    for dataset in (train_dataset, test_dataset)
)

In [43]:
# Fall back to CPU when no GPU is present instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = ResNet18('resnet18')
# Input width of the projection head == encoder feature size == linear-probe input dim.
output_feature_dim = encoder.projetion.net[0].in_features

In [52]:
# map_location lets a checkpoint saved during a GPU run be loaded on the current device.
encoder.load_state_dict(torch.load('/kaggle/working/BYOL_ON_RESNET18_online_network.pth', map_location=device))
# Drop the last child (the projection head) and keep only the ResNet encoder.
encoder = torch.nn.Sequential(*list(encoder.children())[:-1])
encoder = encoder.to(device)

In [None]:
# Show the frozen encoder used to extract features for the linear probe.
print(encoder)

In [53]:
class LogisticRegression(torch.nn.Module):
    """Linear probe: one affine layer that produces unnormalised class logits.

    Softmax is not applied here; pair it with ``CrossEntropyLoss``, which
    expects raw logits.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return logits with last dimension ``output_dim`` for features ``x``."""
        logits = self.linear(x)
        return logits

In [54]:
# Linear probe with 10 output logits (one per class), placed on the eval device.
logreg = LogisticRegression(output_feature_dim, 10).to(device)

In [69]:
def get_features_from_encoder(encoder, loader, device=None):
    """Run ``encoder`` over every batch of ``loader``; collect features and labels.

    Gradients are disabled. The caller is responsible for putting ``encoder``
    in eval mode.

    Args:
        encoder: module mapping an input batch to a feature batch.
        loader: iterable of ``(x, y)`` batches, where ``y`` is a label tensor.
        device: device on which to run the encoder. Defaults to the device of
            the encoder's first parameter, or CPU if it has no parameters.

    Returns:
        ``(features, labels)``: encoder outputs concatenated along dim 0 (left
        on ``device``) and labels concatenated along dim 0 (as yielded).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    features, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            features.append(encoder(x.to(device)))
            labels.append(y)

    if not features:
        raise ValueError("loader produced no batches")
    # Concatenating whole batches matches stacking per-sample rows but avoids
    # splitting every batch into individual tensors.
    return torch.cat(features), torch.cat(labels)

In [70]:
# Extract frozen-encoder features for both labelled splits.
encoder.eval()
x_train, y_train = get_features_from_encoder(encoder, train_loader)
x_test, y_test = get_features_from_encoder(encoder, test_loader)

# If spatial dims remain (assumed (N, C, H, W)), average them away to get (N, C).
if x_train.dim() > 2:
    x_train = x_train.mean(dim=[2, 3])
    x_test = x_test.mean(dim=[2, 3])

print("Training data shape:", x_train.shape, y_train.shape)
print("Testing data shape:", x_test.shape, y_test.shape)

Training data shape: torch.Size([5000, 512]) torch.Size([5000])
Testing data shape: torch.Size([8000, 512]) torch.Size([8000])


In [71]:
def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test):
    """Wrap feature/label tensors in DataLoaders for linear-probe training.

    Returns:
        ``(train_loader, test_loader)``: the train loader shuffles with batch
        size 64; the test loader keeps the original order with batch size 512.
    """
    make_dataset = torch.utils.data.TensorDataset
    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(make_dataset(X_train, y_train), batch_size=64, shuffle=True)
    test_loader = make_loader(make_dataset(X_test, y_test), batch_size=512, shuffle=False)
    return train_loader, test_loader

In [89]:
# x_train is already a torch.Tensor (see get_features_from_encoder). Passing it to
# torch.from_numpy(), which only accepts numpy arrays, raised a TypeError.
train_loader, test_loader = create_data_loaders_from_arrays(x_train, y_train, x_test, y_test)

In [90]:
optimizer = torch.optim.Adam(logreg.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()
eval_every_n_epochs = 10

for epoch in range(200):
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = logreg(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

    if epoch % eval_every_n_epochs == 0:
        correct = 0
        total = 0
        # Evaluation only: no_grad avoids building an autograd graph for test batches.
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                predictions = torch.argmax(logreg(x), dim=1)
                total += y.size(0)
                correct += (predictions == y).sum().item()

        acc = 100 * correct / total
        print(f"Testing accuracy: {acc}")

Testing accuracy: 30.9625
Testing accuracy: 34.2375
Testing accuracy: 36.3375
Testing accuracy: 37.775
Testing accuracy: 39.5375
Testing accuracy: 41.4875
Testing accuracy: 42.275
Testing accuracy: 42.3375
Testing accuracy: 41.8375
Testing accuracy: 41.975
Testing accuracy: 39.9375
Testing accuracy: 39.675
Testing accuracy: 37.0
Testing accuracy: 35.575
Testing accuracy: 34.6
Testing accuracy: 33.0125
Testing accuracy: 31.025
Testing accuracy: 29.9625
Testing accuracy: 28.45
Testing accuracy: 27.6375


* This is the linear-evaluation accuracy obtained when the encoder uses the parameters of the BYOL online network trained for 15 epochs.

In [93]:
# Continue BYOL pre-training for another `max_epochs` epochs on a freshly built
# two-view STL10 dataset (same transform as the first run).
# NOTE(review): BYOLTrainer.train() re-initialises the target network from the
# online network, restarts the TensorBoard step counter at 0 (so 'loss' is logged
# at the same step numbers as the first run), and overwrites checkpoints/model.pth
# — confirm this is intended.
train_dataset_second_train = torchvision.datasets.STL10('/home/thalles/Downloads/', split='train+unlabeled', download=True,
                               transform=MultiViewDataInjector([data_transform, data_transform]))

trainer.train(train_dataset_second_train)

Files already downloaded and verified
Epoch: 0, Loss: 799.3731079101562
End of epoch 0
Epoch: 1, Loss: 802.874755859375
End of epoch 1
Epoch: 2, Loss: 797.1381225585938
End of epoch 2
Epoch: 3, Loss: 793.4573364257812
End of epoch 3
Epoch: 4, Loss: 808.0007934570312
End of epoch 4
Epoch: 5, Loss: 811.2720947265625
End of epoch 5
Epoch: 6, Loss: 810.4304809570312
End of epoch 6
Epoch: 7, Loss: 814.619384765625
End of epoch 7
Epoch: 8, Loss: 814.6397094726562
End of epoch 8
Epoch: 9, Loss: 824.9259643554688
End of epoch 9
Epoch: 10, Loss: 823.8887939453125
End of epoch 10
Epoch: 11, Loss: 831.4505004882812
End of epoch 11
Epoch: 12, Loss: 831.2811279296875
End of epoch 12
Epoch: 13, Loss: 836.8493041992188
End of epoch 13
Epoch: 14, Loss: 842.2604370117188
End of epoch 14
