# BYOL in PyTorch

Written by Ying Liu (yili@di.ku.dk), a student in the Department of Computer Science at Københavns Universitet.

## Algorithm

<div style="display: flex; justify-content: center; align-items: center; background: white;">
    <img src="https://imgur.com/rBZRkOf.png" style="width: 45%; margin-right: 1rem;" />
    <img src="https://imgur.com/uudkfAk.png" style="width: 45%;" />
</div>


BYOL has three main components: the **encoder** $f$ (ResNet-50), the **projector** $g$, and the **predictor** $q$. They are combined into two networks:

- Online Network: encoder + projector + predictor, with parameters $\theta$;

- Target Network: encoder + projector, with parameters $\xi$.

During training, the online network is updated by gradient descent on the **loss function**, while the target network is updated slowly as an **exponential moving average** (EMA) of the online network's parameters.

## Global Area

This section contains the imports and global definitions.

In [3]:
import copy
import torch
import random
import numpy as np
from torch import nn
from functools import wraps
import torch.nn.functional as F

## Model Implementation

Apply the original BYOL method to train a 3D ResNet model.

**References**: the implementation follows [lucidrains/byol-pytorch](https://github.com/lucidrains/byol-pytorch/) and the BYOL tutorial by [AI Summer](https://theaisummer.com/byol/).

### Helper functions

In [2]:
# default: return the default value if the input value is None
def default(val, def_val):
    return def_val if val is None else val

# flatten: flatten the tensor
def flatten(t):
    # original shape (N, C, D, H, W) -> (N, C*D*H*W)
    return t.reshape(t.shape[0], -1)

# singleton: cache the result of the function on the instance attribute named by cache_key
def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

# get_module_device: get the device of the module
def get_module_device(module):
    return next(module.parameters()).device

# set_requires_grad: set the requires_grad attribute of the model
def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val
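
To see what the `singleton` decorator does, here is a minimal, self-contained sketch. The `Lazy` class and its `build` method are hypothetical names used only for illustration; the decorator is repeated so the snippet runs on its own.

```python
from functools import wraps

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance          # already built: reuse it
            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

class Lazy:  # hypothetical example class
    def __init__(self):
        self.thing = None   # the cache slot must exist and start as None
        self.calls = 0      # counts how often the body really runs

    @singleton('thing')
    def build(self):
        self.calls += 1
        return object()

lazy = Lazy()
first = lazy.build()    # runs the body and stores the result in lazy.thing
second = lazy.build()   # returns the cached object, body is skipped
```

The cache attribute has to be initialized to `None` before the first call, which is why the classes below set `self.projector = None` and `self.target_encoder = None` in their constructors.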

### Loss function

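This is the loss defined in the original BYOL paper: the squared distance between the L2-normalized prediction and the L2-normalized target projection, which equals $2 - 2 \cdot \frac{\langle x, y \rangle}{\|x\|_2 \, \|y\|_2}$.

The normalization uses `torch.nn.functional.normalize` (https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html#torch.nn.functional.normalize).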

In [None]:
# Apply L2 normalization along the last dimension,
# which is the embedding dimension that holds the feature vector.
def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2) 
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
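
A quick sanity check of this loss on hand-picked vectors may help. The sketch below repeats the computation under the hypothetical name `byol_loss` so it runs on its own, and compares it with the squared distance between the normalized vectors; the input values are made up for illustration.

```python
import torch
import torch.nn.functional as F

def byol_loss(x, y):  # hypothetical name; same computation as loss_fn
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

a = torch.tensor([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
# row 0: same direction as a, row 1: orthogonal, row 2: opposite direction
b = torch.tensor([[3.0, 0.0], [0.0, 2.0], [-5.0, 0.0]])

losses = byol_loss(a, b)

# the same quantity written as a squared Euclidean distance
sq_dist = (F.normalize(a, dim=-1) - F.normalize(b, dim=-1)).pow(2).sum(dim=-1)
```

The loss depends only on the angle between the two vectors, not on their lengths: it is 0 for aligned vectors, 2 for orthogonal ones, and 4 for opposite ones.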

### Augmentation utilities

In [None]:
# RandomApply: apply the function fn with probability p, otherwise return the input unchanged
class RandomApply(nn.Module):
    # fn: the function for augmentation
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)
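
As a usage sketch, `RandomApply` can wrap any callable that maps a tensor to a tensor. The `flip_width` augmentation below is a hypothetical stand-in, not one of the notebook's augmentations, and the class is repeated so the snippet is self-contained.

```python
import random
import torch
from torch import nn

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

def flip_width(x):  # hypothetical augmentation: mirror along the last axis
    return torch.flip(x, dims=[-1])

volume = torch.arange(8.0).reshape(1, 1, 2, 2, 2)   # (N, C, D, H, W)

always = RandomApply(flip_width, p=1.0)   # p=1.0: random.random() is never > 1.0
flipped = always(volume)
```

With `p=1.0` the function is always applied; smaller values of `p` skip it on a correspondingly larger share of calls.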

### Exponential moving average

An exponential moving average (EMA), also known as an exponentially weighted moving average (EWMA), is a first-order infinite impulse response filter that applies weighting factors which decrease exponentially, never reaching zero. This formulation follows Hunter (1986). Source: https://en.wikipedia.org/wiki/Moving_average

More precisely, following the paper: given a target decay rate $\tau \in [0, 1]$, after each training step we perform the following update:

$$
\xi \leftarrow \tau \xi + (1 - \tau) \theta
$$

where $\xi$ is the target network parameters and $\theta$ is the online network parameters.

In [4]:
class EMA():
    # given a target decay rate beta (which is tau in the paper)
    def __init__(self, beta):
        super().__init__()
        self.beta = beta
    
    # old corresponds to target (xi), new corresponds to online model (theta)
    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new
    
def update_moving_average(ema_updater, ma_model, current_model):
    '''
    Update moving average of model weights.
    ema_updater: EMA object
    ma_model: the model with moving average weights (the target model)
    current_model: the model with current weights (the online model)
    '''
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
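
To get a feel for how slowly the target moves: if $\theta$ were held fixed, unrolling the update gives $\xi_n = \theta + \tau^n (\xi_0 - \theta)$, so the gap to the online parameters shrinks by a factor $\tau$ per step. A minimal numeric check in plain Python, using scalar stand-ins for the parameters (the values are chosen only for illustration):

```python
tau = 0.99      # decay rate (moving_average_decay in the code)
theta = 1.0     # online "parameter", held fixed for this illustration
xi_0 = 0.0      # initial target "parameter"
steps = 100

xi = xi_0
for _ in range(steps):
    xi = tau * xi + (1 - tau) * theta   # the EMA update rule

# closed form obtained by unrolling the recursion
closed_form = theta + tau ** steps * (xi_0 - theta)
```

With $\tau = 0.99$, the target has covered only about 63% of the gap after 100 steps, which is what makes it a slowly moving, stable regression target.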

### Multilayer perceptron

A multilayer perceptron (MLP) is a modern feedforward artificial neural network consisting of fully connected neurons with a nonlinear activation function, organized in at least three layers, and notable for being able to distinguish data that is not linearly separable. The name is a misnomer because the original perceptron used a Heaviside step function rather than the nonlinear activation functions used by modern networks. Source: https://en.wikipedia.org/wiki/Multilayer_perceptron

Following the paper, as in SimCLR, the representation $y$ is projected to a smaller space by a multilayer perceptron (MLP) $g_{\theta}$, and similarly for the target projection $g_{\xi}$. This MLP consists of a linear layer with output size $4096$, followed by batch normalization, rectified linear units (ReLU), and a final linear layer with output dimension $256$. Contrary to SimCLR, the output of this MLP is not batch normalized. The predictor $q_{\theta}$ uses the same architecture as $g_{\theta}$.

The activation is `nn.ReLU`: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html

With `inplace=True`, ReLU modifies its input tensor directly instead of allocating a separate output tensor; with `inplace=False` (the default), the input is left untouched and a new tensor is returned. See https://discuss.pytorch.org/t/whats-the-difference-between-nn-relu-and-nn-relu-inplace-true/948

In [5]:
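def MLP(dim, projection_size, hidden_size=4096):
    # linear -> batch norm -> ReLU -> linear, as described in the paper;
    # the output itself is not batch normalized
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size)
    )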
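
The architecture described in the paper (linear, batch normalization, ReLU, linear) can be checked for shapes with a small standalone sketch. The name `make_mlp` is hypothetical, and the input width of 2048 is an assumption corresponding to the pooled feature size of a standard ResNet-50; a 3D backbone may produce a different width.

```python
import torch
from torch import nn

def make_mlp(dim, projection_size, hidden_size=4096):  # hypothetical name
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    )

projector = make_mlp(dim=2048, projection_size=256)   # g: representation -> projection
predictor = make_mlp(dim=256, projection_size=256)    # q: projection -> prediction

y = torch.randn(4, 2048)   # a batch of 4 representations (assumed width)
z = projector(y)           # projections
q = predictor(z)           # predictions, same size as the projections
```

Because of the batch normalization layer, a batch must contain more than one sample in training mode.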

### NetWrapper

A wrapper class for the base neural network. It intercepts the output of a hidden layer with a forward hook and pipes it into the projector; the predictor is applied afterwards in `BYOL`.

It extracts the representation $y$ and the projection $z$.

In [None]:
class NetWrapper(nn.Module):
    # layer = -2 means the second last layer of the network
    def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
        super().__init__()
        self.net = net
        self.layer = layer
        
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden = {} # to store the hidden layer output when hook is registered
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None
    
    def _hook(self, _, input, output):
        # store the output of the hidden layer
        device = input[0].device
        self.hidden[device] = flatten(output) # store the flattened output of the hidden layer

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    # build the projector that maps the representation y to the projection z (created once, then cached)
    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        create_mlp_fn = MLP
        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)
    
    # get the representation y
    def get_representation(self, x):
        if self.layer == -1:
            return self.net(x)
        
        if not self.hook_registered:
            self._register_hook()
        
        self.hidden.clear()
        _ = self.net(x)
        hidden = self.hidden.get(x.device)  # None if the hook never fired; caught by the assert below
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden
    
    def forward(self, x, return_projection = True):
        representation = self.get_representation(x)
        
        if not return_projection:
            return representation
        
        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation
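
The interception mechanism relies on PyTorch forward hooks. The following standalone sketch shows the idea on a toy network; the layer sizes and the `captured` dictionary are hypothetical and only serve the illustration.

```python
import torch
from torch import nn

net = nn.Sequential(       # hypothetical toy backbone
    nn.Linear(10, 16),
    nn.ReLU(),
    nn.Linear(16, 8),      # second-to-last child: the layer we want to tap
    nn.Linear(8, 3),       # final layer, e.g. a classification head
)

captured = {}

def hook(module, inputs, output):
    # flatten everything except the batch dimension, like flatten() above
    captured['hidden'] = output.reshape(output.shape[0], -1)

layer = [*net.children()][-2]                 # same lookup as layer=-2
handle = layer.register_forward_hook(hook)

x = torch.randn(5, 10)
logits = net(x)        # the forward pass fires the hook as a side effect
handle.remove()        # detach the hook once it is no longer needed
```

After the forward pass, `captured['hidden']` holds the output of the tapped layer, while `logits` is the ordinary network output, which the wrapper simply discards.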

### Main BYOL

While reading the MedicalNet data loader (https://github.com/Tencent/MedicalNet/blob/master/datasets/brains18.py), I wondered whether some steps of the data preprocessing need to be changed.

One option: first constrain all CT scans to the shape (320, 320, 320), then drop invalid slices that are all zeros, then perform a resized crop to (128, 128, 128), and finally normalize the data.

In [None]:
class BYOL(nn.Module):
    def __init__(
            self,
            net,
            image_depth,
            image_size,
            hidden_layer=-2,
            projection_size=256,
            projection_hidden_size=4096,
            augment_fn=None,
            augment_fn2=None,
            moving_average_decay=0.99
    ):
        super().__init__()
        self.net = net

        # default SimCLR-style augmentation
        # (not implemented yet: an empty Sequential acts as the identity)
        DEFAULT_AUG = torch.nn.Sequential(
            # haven't implemented
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
                net,
                projection_size,
                projection_hidden_size,
                layer = hidden_layer,
        )

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get the device of the network and move the wrapper to the same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 1, image_depth, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder
    
    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(
            self,
            x,
            return_embedding=False,
            return_projection=True
    ):
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        if return_embedding:
            return self.online_encoder(x, return_projection = return_projection)
        
        image_one, image_two = self.augment1(x), self.augment2(x)

        images = torch.cat([image_one, image_two], dim=0)

        online_projections, _ = self.online_encoder(images)
        online_predictions = self.online_predictor(online_projections)

        online_pred_one, online_pred_two = online_predictions.chunk(2, dim=0)

        with torch.no_grad():
            target_encoder = self._get_target_encoder()

            target_projections, _ = target_encoder(images)
            target_projections = target_projections.detach()

            target_proj_one, target_proj_two = target_projections.chunk(2, dim=0)

        loss_one = loss_fn(online_pred_one, target_proj_two)
        loss_two = loss_fn(online_pred_two, target_proj_one)

        loss = loss_one + loss_two
        return loss.mean()
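
To show how the pieces fit into one training step, here is a self-contained toy sketch rather than the notebook's actual training code. The tiny linear "encoder", the noise-based "augmentations", the optimizer, and all sizes are assumptions made for illustration. Unlike the class above, it also passes the two views through the networks separately.

```python
import copy
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

def mlp(i, h, o):
    return nn.Sequential(nn.Linear(i, h), nn.BatchNorm1d(h), nn.ReLU(), nn.Linear(h, o))

def loss_fn(x, y):
    x, y = F.normalize(x, dim=-1), F.normalize(y, dim=-1)
    return 2 - 2 * (x * y).sum(dim=-1)

# toy stand-ins: encoder + projector as one module, predictor on top
online = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), mlp(16, 64, 8))
predictor = mlp(8, 64, 8)
target = copy.deepcopy(online)            # target starts as a copy of online
for p in target.parameters():
    p.requires_grad = False               # never trained by gradients

opt = torch.optim.SGD(list(online.parameters()) + list(predictor.parameters()), lr=0.1)
tau = 0.99
target_before = [p.clone() for p in target.parameters()]

x = torch.randn(8, 32)
view_one = x + 0.1 * torch.randn_like(x)  # stand-in augmentations
view_two = x + 0.1 * torch.randn_like(x)

# 1) symmetric loss: each online prediction regresses the target
#    projection of the other view
pred_one, pred_two = predictor(online(view_one)), predictor(online(view_two))
with torch.no_grad():
    proj_one, proj_two = target(view_one), target(view_two)
loss = (loss_fn(pred_one, proj_two) + loss_fn(pred_two, proj_one)).mean()

# 2) gradient step on the online network and predictor only
opt.zero_grad()
loss.backward()
opt.step()

# 3) EMA update of the target parameters
with torch.no_grad():
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(tau).add_((1 - tau) * p_o)
```

With the `BYOL` class above, the same three steps would presumably be `loss = learner(x)`, then the usual `backward()` and optimizer step, followed by `learner.update_moving_average()`, where `learner` is a hypothetical name for the `BYOL` instance.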

## 3D ResNet-50

In [None]:
# solve the function of 1/(x^3) = 3 => x = 1/(3^(1/3)) == 0.6933612743506345
# solve the function of x^3 = 3 => x = 3^(1/3) == 1.4422495703074083