In [18]:
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from tqdm import tqdm
from torchvision import transforms as T

# Loss fn in BYOL paper
def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

In [19]:
resnet_path = './improved-net-2.pt'
learner_path = './learner-net-2.pt'

In [20]:
class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

In [21]:
def default(val, def_val):
    return def_val if val is None else val

def flatten(t):
    return t.reshape(t.shape[0], -1)

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def get_module_device(module):
    return next(module.parameters()).device

def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

def MaybeSyncBatchnorm(is_distributed = None):
    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d

# loss fn

# augmentation utils

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# exponential moving average

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)

# MLP class for projector and predictor

def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size)
    )

def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    return nn.Sequential(
        nn.Linear(dim, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
    )

# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets

class NetWrapper(nn.Module):
    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.use_simsiam_mlp = use_simsiam_mlp
        self.sync_batchnorm = sync_batchnorm

        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        device = input[0].device
        self.hidden[device] = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
        return projector.to(hidden)

    def get_representation(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        hidden = self.hidden[x.device]
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation


In [22]:
image_size = 256
DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.RandomResizedCrop((image_size, image_size)),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

In [23]:
def loss_fn(x1, x2, temperature=0.7):
    # Normalize input embeddings
    x1_normalized = F.normalize(x1, dim=-1)
    x2_normalized = F.normalize(x2, dim=-1)
    
    # Calculate cosine similarity matrix
    similarity_matrix = torch.matmul(x1_normalized, x2_normalized.T)
    
    # Calculate logits
    logits = similarity_matrix / temperature
    
    # Calculate diagonal terms for positive samples
    diag_terms = torch.diag(logits)
    
    # Calculate numerator (positive term)
    numerator = torch.exp(diag_terms / temperature)
    
    # Calculate denominator (positive and negative terms)
    denominator = torch.sum(torch.exp(logits), dim=1) + torch.exp(diag_terms / temperature)
    
    # Calculate loss
    loss = -torch.mean(torch.log(numerator / denominator))
    
    return loss


In [24]:
class BYOL(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 4096,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        use_momentum = True,
        sync_batchnorm = None
    ):
        super().__init__()
        self.net = net

        # default SimCLR augmentation

        DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.RandomResizedCrop((image_size, image_size)),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer = hidden_layer,
            use_simsiam_mlp = not use_momentum,
            sync_batchnorm = sync_batchnorm
        )

        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(
        self,
        x,
        # y,
        return_embedding = False,
        return_projection = True
    ):
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        if return_embedding:
            return self.online_encoder(x, return_projection = return_projection)

        image_one, image_two = self.augment1(x), self.augment2(x)
        # print("Image shapes:", image_one.shape, image_two.shape)
        images = torch.cat((image_one, image_two), dim = 0)
        # labels = torch.cat((), dim=0)
        # print(images.shape)
        online_projections, _ = self.online_encoder(images)
        online_predictions = self.online_predictor(online_projections)
        # print("Online Projections = online_encoder(aug_1 + aug_2) shapes:", online_projections.shape, _.shape, online_predictions.shape)
        online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)
        # print("Pred (chunked) shapes:", online_pred_one.shape, online_pred_two.shape)
        with torch.no_grad():
            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder

            target_projections, _ = target_encoder(images)
            target_projections = target_projections.detach()

            target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)

        loss_one = loss_fn(online_pred_one, target_proj_two.detach())
        loss_two = loss_fn(online_pred_two, target_proj_one.detach())

        loss = loss_one + loss_two
        return loss.mean()

In [25]:
import torchvision

img_path = "dataset/style/artbench-10-imagefolder-split/"

data_transform = T.Compose([
    T.ToTensor()
])

train_data = torchvision.datasets.ImageFolder(root=img_path+"train", transform=data_transform)
test_data = torchvision.datasets.ImageFolder(root=img_path+"test", transform=data_transform)

In [26]:
BATCH_SIZE = 128

train_dl =  torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dl =  torch.utils.data.DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

In [27]:
resnet = torchvision.models.resnet50(pretrained=True)
device = 'cuda'
resnet = resnet.to(device)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum=False
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

In [28]:
from tqdm import tqdm

import torchvision

img_path = "dataset/style/artbench-10-imagefolder-split/"

data_transform = T.Compose([
    T.ToTensor()
])

test_data = torchvision.datasets.ImageFolder(root=img_path+"test", transform=data_transform)
test_dl =  torch.utils.data.DataLoader(dataset=test_data, batch_size=100, shuffle=False)

with torch.no_grad():
    for images, labels in tqdm(test_dl, desc=f'Epoch {0}/{0}', leave=False):
        images = images.to(device)
        loss = learner(images)
        print(loss.item())
        break

                                                  

9.21093463897705




In [29]:
training = True
resnet.load_state_dict(torch.load(resnet_path))
learner.load_state_dict(torch.load(learner_path))

<All keys matched successfully>

In [30]:
with torch.no_grad():
    for images, labels in tqdm(test_dl, desc=f'Epoch {0}/{0}', leave=False):
        images = images.to(device)
        loss = learner(images)
        print(loss.item())
        break

                                                  

8.525811195373535




In [31]:
# Similarity between same class embeddings
torch.cuda.empty_cache()
idxs = [((i//100)*1000 + (i%100) + 5) for i in range(1000)]
# idxs = [_ for _ in range(18)]
# print(len(idxs))

subset_dataset = torch.utils.data.Subset(test_data, idxs)
shuffle_dl = torch.utils.data.DataLoader(subset_dataset, batch_size=100, shuffle=False)

for images, labels in shuffle_dl:
    images, labels = images.to(device), labels.to(device)
    embeddings = learner(images, return_embedding = True)
    
    img = embeddings[0][0:1]
    embeddings = embeddings[0][1:]
    
    cos_similarities = F.cosine_similarity(img.expand_as(embeddings), embeddings, dim=1)
    mean_cos_similarity = torch.mean(cos_similarities)
    
    print("Mean Cosine Similarity:", mean_cos_similarity.item())
    # break

Mean Cosine Similarity: 0.959464430809021
Mean Cosine Similarity: 0.9797874689102173
Mean Cosine Similarity: 0.9797340631484985
Mean Cosine Similarity: 0.9797700643539429
Mean Cosine Similarity: 0.9392114281654358
Mean Cosine Similarity: 0.9796575903892517
Mean Cosine Similarity: 0.9797666668891907
Mean Cosine Similarity: 0.9646632671356201
Mean Cosine Similarity: 0.9797666668891907
Mean Cosine Similarity: 0.9797442555427551


In [32]:
# Similarity different class embeddings
torch.cuda.empty_cache()
TEST_CASES = 100
true_mean = 0 
for __ in range(TEST_CASES):
    idxs = [
        random.randint(1, 999), 
        random.randint(1001, 1999), 
        random.randint(2001, 2999),
        random.randint(3001, 3999), 
        random.randint(4001, 4999),
        random.randint(5001, 5999), 
        random.randint(6001, 6999),
        random.randint(7001, 7999), 
        random.randint(8001, 8999),
        random.randint(9001, 9999),
    ]

    subset_dataset = torch.utils.data.Subset(test_data, idxs)
    shuffle_dl = torch.utils.data.DataLoader(subset_dataset, batch_size=len(idxs), shuffle=True)
    batch_mean = 0
    for images, labels in shuffle_dl:
        images, labels = images.to(device), labels.to(device)
        embeddings = learner(images, return_embedding = True)
        rand = random.randint(0, 9)
        img = embeddings[0][rand].reshape(1, -1)
        indices = torch.tensor([i for i in range(embeddings[0].size(0)) if i != rand])
        
        embeddings = embeddings[0][indices]
        
        cos_similarities = F.cosine_similarity(img.expand_as(embeddings), embeddings, dim=1)
        mean_cos_similarity = torch.mean(cos_similarities)
        
        print("Mean Cosine Similarity:", mean_cos_similarity.item())
        # break
        batch_mean += (mean_cos_similarity.item()**2)**0.5
    true_mean += batch_mean/len(shuffle_dl)

Mean Cosine Similarity: 0.33603277802467346
Mean Cosine Similarity: 0.30012691020965576
Mean Cosine Similarity: -0.1853371560573578
Mean Cosine Similarity: 0.07703905552625656
Mean Cosine Similarity: 0.11345871537923813
Mean Cosine Similarity: 0.30472663044929504
Mean Cosine Similarity: -0.27860695123672485
Mean Cosine Similarity: 0.14890296757221222
Mean Cosine Similarity: 0.08156626671552658
Mean Cosine Similarity: -0.25515738129615784
Mean Cosine Similarity: 0.46345409750938416
Mean Cosine Similarity: 0.21496708691120148
Mean Cosine Similarity: -0.2696516215801239
Mean Cosine Similarity: -0.24149322509765625
Mean Cosine Similarity: -0.3686572313308716
Mean Cosine Similarity: 0.15797919034957886
Mean Cosine Similarity: 0.10038799792528152
Mean Cosine Similarity: -0.1328631043434143
Mean Cosine Similarity: -0.27926215529441833
Mean Cosine Similarity: 0.14929652214050293
Mean Cosine Similarity: 0.22925111651420593
Mean Cosine Similarity: -0.6491965651512146
Mean Cosine Similarity: -0.2

In [33]:
true_mean / 100

0.23770515604992398

In [34]:
# Similarity in a random batch
torch.cuda.empty_cache()
shuffle_dl = torch.utils.data.DataLoader(dataset=test_data, batch_size=100, shuffle=True)

for images, labels in shuffle_dl:
    images, labels = images.to(device), labels.to(device)
    embeddings = learner(images, return_embedding = True)
    
    img = embeddings[0][0:1]
    embeddings = embeddings[0][1:]
    
    cos_similarities = F.cosine_similarity(img.expand_as(embeddings), embeddings, dim=1)
    
    mean_cos_similarity = torch.mean(cos_similarities)
    
    print("Mean Cosine Similarity:", mean_cos_similarity.item())
    break

Mean Cosine Similarity: 0.9797444939613342
