In [None]:
def set_requires_grad(model, val):
    """
    Toggle gradient tracking on every parameter of a PyTorch model.

    Copyright (c) 2020 Phil Wang Redistributed under the MIT license.
    Function taken from: https://github.com/lucidrains/byol-pytorch

    Parameters
    ----------
    model :
        Object exposing ``parameters()``; each yielded parameter is
        modified in place.
    val : bool
        Value assigned to every parameter's ``requires_grad`` attribute.
    """
    for weight in model.parameters():
        weight.requires_grad = val


def ema(target_param, online_param, alpha):
    """
    Blend a target value towards an online value (exponential moving average).

    Parameters
    ----------
    target_param :
        Current target value.
    online_param :
        Current online value.
    alpha : float or None
        Weight kept on ``target_param``; ``None`` disables the averaging.

    Returns
    -------
    ``online_param`` unchanged when ``alpha`` is None, otherwise
    ``alpha * target_param + (1 - alpha) * online_param``.
    """
    if alpha is not None:
        return alpha * target_param + (1 - alpha) * online_param
    # no decay requested: the target simply follows the online value
    return online_param

In [None]:
import torch
import torch.nn as nn


class BYOL(nn.Module):
    """
    Build a BYOL model.

    Parameters
    ----------
    base_encoder : callable
        Encoder constructor. It is called as
        ``base_encoder(num_classes=dim, zero_init_residual=True)`` and the
        returned module must expose a linear ``fc`` head with a bias.
    init_target_from_online : bool
        If true, the target encoder is initialized with a copy of the online
        encoder's state; otherwise it keeps its own initialization.
    dim : int, default=2048
        Feature dimension
    pred_dim : int, default=512
        Hidden dimension of the predictor
    """

    def __init__(self, base_encoder, init_target_from_online, dim=2048, pred_dim=512):
        super().__init__()

        # create the online encoder
        # num_classes is the output fc dimension, zero-initialize last BNs
        self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)
        self._attach_projector(self.encoder, dim)
        # hack: not use bias as it is followed by BN
        self.encoder.fc[6].bias.requires_grad = False

        # build target model with the same architecture as the online one
        self.target_encoder = base_encoder(num_classes=dim, zero_init_residual=True)
        self._attach_projector(self.target_encoder, dim)

        if init_target_from_online:
            self.target_encoder.load_state_dict(self.encoder.state_dict())

        # disable grad calculations for target model
        set_requires_grad(self.target_encoder, False)

        # build a 2-layer predictor
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),  # hidden layer
            nn.Linear(pred_dim, dim),
        )  # output layer

    @staticmethod
    def _attach_projector(encoder, dim):
        """Replace ``encoder.fc`` with a 3-layer projector ending in the original ``fc``."""
        # input width of the existing head; hidden layers keep that width
        prev_dim = encoder.fc.weight.shape[1]
        encoder.fc = nn.Sequential(
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # first layer
            nn.Linear(prev_dim, prev_dim, bias=False),
            nn.BatchNorm1d(prev_dim),
            nn.ReLU(inplace=True),  # second layer
            encoder.fc,  # original head, reused as index 6 of the projector
            nn.BatchNorm1d(dim, affine=False),
        )  # output layer

    def forward(self, x1, x2):
        """
        Forward step.

        Parameters
        ----------
        x1 : torch.Tensor
            First view of images.
        x2 : torch.Tensor
            Second view of images.
        Return
        ------
        p1, p2, z1_target, z2_target :
            online predictions for both views, followed by the target
            projections of both views (computed without gradient tracking)
        Note
        ----
        See https://arxiv.org/abs/2006.07733 for detailed notations
        """
        # online projections of both views
        z1 = self.encoder(x1)  # NxC
        z2 = self.encoder(x2)  # NxC

        # the target branch is a stop-gradient branch: enforce it here instead
        # of relying only on the parameters' requires_grad flags
        with torch.no_grad():
            z1_target = self.target_encoder(x1)
            z2_target = self.target_encoder(x2)

        p1 = self.predictor(z1)  # NxC
        p2 = self.predictor(z2)  # NxC

        return p1, p2, z1_target, z2_target

    def update_target(self, target_model, online_model, alpha=0.99):
        """
        Move ``target_model`` towards ``online_model`` with an EMA step.

        Every entry of the target's ``state_dict()`` is replaced by
        ``ema(target_entry, online_entry, alpha)`` and the result is loaded
        back into ``target_model``.

        Parameters
        ----------
        target_model : torch.nn.Module
            Model updated in place.
        online_model : torch.nn.Module
            Model providing the new values; must have the same state keys.
        alpha : float or None, default=0.99
            Weight kept on the target values, forwarded to ``ema``.
        """
        # build the online state dict once rather than once per entry
        online_state = online_model.state_dict()
        target_state = target_model.state_dict()
        for key in target_state:
            target_state[key] = ema(target_state[key], online_state[key], alpha)
        target_model.load_state_dict(target_state)

In [None]:
# build model
import torchvision.models as models

# configuration: backbone name and whether the target starts as a copy
# of the online network
arch = "resnet18"
init_target_from_online = False

# look the encoder constructor up by name on the torchvision.models module
encoder = getattr(models, arch)
model = BYOL(encoder, init_target_from_online)

In [None]:
# online and target projectors must be distinct module objects, not the same
# reference - should print False
print(model.encoder.fc is model.target_encoder.fc)

In [None]:
# autograd check for online - should be all true except for fc 6's biases
# (online encoder parameters are listed first, then the predictor's)
for module in (model.encoder, model.predictor):
    for name, param in module.named_parameters():
        print(name, param.requires_grad)

In [None]:
# autograd check for target - should be all false
# (prints one "<parameter name> <requires_grad>" line per target parameter)
for param_name, weight in model.target_encoder.named_parameters():
    print(param_name, weight.requires_grad)

In [None]:
# check to see if weights between online and target are identical
# (True only when every pair of corresponding parameters matches elementwise)
target_online_init = all(
    bool((target_weight == online_weight).all())
    for target_weight, online_weight in zip(
        model.target_encoder.parameters(), model.encoder.parameters()
    )
)
print(target_online_init)

In [None]:
# check to see if target model is properly randomly initialized
# true - if properly randomly initialized
# false - if initialized from online or improperly randomly initialized
# list constructed from remaking target_encoder from scratch
# expected per-parameter "identical to online" flags: 22 repetitions of
# (False, True, True) followed by two trailing False entries (68 in total)
reset_target_check = [False, True, True] * 22 + [False, False]

# actual per-parameter flags, in parameters() order
reset_test_results = [
    bool((target_weight == online_weight).all())
    for target_weight, online_weight in zip(
        model.target_encoder.parameters(), model.encoder.parameters()
    )
]

proper_random_init = reset_test_results == reset_target_check
print(proper_random_init)

In [None]:
# check if ema works properly - should return true
import random

# random scalar inputs and a random decay in [0, 1)
a = random.randint(-10, 10) * random.random()
b = random.randint(-10, 10) * random.random()
alpha = random.random()

# reference value written out with the same formula ema is expected to use
expected = alpha * a + (1 - alpha) * b
print(ema(a, b, alpha) == expected)

In [None]:
# check if update target model works properly
# should return true
import copy

alpha = random.random()
# snapshot of the target taken before the update, used to recompute the
# expected EMA result
target_original = copy.deepcopy(model.target_encoder)

if not proper_random_init:
    print("Target model may not have been properly randomly intialized.")

model.update_target(model.target_encoder, model.encoder, alpha)

# every updated target parameter must equal ema(old target, online, alpha)
proper_target_update = all(
    bool((updated == ema(before, online, alpha)).all())
    for updated, before, online in zip(
        model.target_encoder.parameters(),
        target_original.parameters(),
        model.encoder.parameters(),
    )
)

print(proper_target_update)