 
Based on https://github.com/JamesQFreeman/contrastive_learning_in_100_lines

In [None]:
# Colab only: mount Google Drive at /content/drive so files under MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# IPython magic (not plain Python): change the working directory to the Drive root.
%cd drive/MyDrive

/content/drive/MyDrive


In [None]:
# %mkdir SimClR

In [None]:
# IPython magics: enter the project folder on Drive and list its contents.
%cd SimClR
!ls

/content/drive/MyDrive/SimClR
contrastive_learning_in_100_lines-main	main.zip


In [None]:
# !wget https://github.com/JamesQFreeman/contrastive_learning_in_100_lines/archive/refs/heads/main.zip

In [None]:
# !unzip main.zip

In [None]:
# Work from the unzipped repository root: ImageNet_5Class reads data/train/ and
# data/test/ relative to the current working directory.
%cd contrastive_learning_in_100_lines-main/

/content/drive/MyDrive/SimClR/contrastive_learning_in_100_lines-main


In [None]:
import copy
import glob
import os
import random
from typing import ValuesView

import torch
import torchvision
from PIL import Image
from torch import nn
from torchvision import models
from torchvision import transforms as T
from tqdm import tqdm

In [None]:
class ImageNet_5Class(torch.utils.data.Dataset):
    """Five-class image dataset read from a flat directory of ``.jpg`` files.

    Files live in ``data/train/`` (or ``data/test/`` when ``train=False``),
    relative to the current working directory, and are named
    ``<label>_<anything>.jpg`` where ``<label>`` is one of ``CLASSES``.

    Args:
        train: read from ``data/train/`` if True, otherwise ``data/test/``.
        augmentation: if True, apply a random horizontal flip and random
            resized crop; otherwise a plain resize to 224x224.
        annotation: if True, ``__getitem__`` returns ``(image, label_index)``;
            otherwise only the image tensor.
        ratio: fraction of the files to keep, sampled at random without
            replacement.

    Every returned image is a 3-channel 224x224 tensor normalized with the
    ImageNet channel mean/std.
    """

    CLASSES = ("airplane", "car", "cat", "dog", "elephant")

    def __init__(self, train: bool = True, augmentation: bool = False, annotation: bool = False, ratio: float = 1.0):
        super().__init__()
        data_dir = 'data/train/' if train else 'data/test/'
        image_list = glob.glob(os.path.join(data_dir, '*.jpg'))
        self.image_list = random.sample(image_list, int(ratio * len(image_list)))
        self.augmentation = augmentation
        self.annotation = annotation

        # Built once here rather than on every __getitem__ call.
        self.no_transform = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])
        self.basic_transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomResizedCrop((224, 224)),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __getitem__(self, index):
        img_path = self.image_list[index]
        # Close the file handle promptly, and force 3 channels: grayscale or
        # CMYK JPEGs would otherwise break the 3-channel Normalize above.
        with Image.open(img_path) as raw_img:
            pil_img = raw_img.convert('RGB')

        transform = self.basic_transform if self.augmentation else self.no_transform
        tensor_img = transform(pil_img)

        if not self.annotation:
            return tensor_img
        # Label is the filename prefix before the first underscore; raises
        # ValueError for a name outside CLASSES.
        label_name = os.path.basename(img_path).split('_')[0]
        return tensor_img, self.CLASSES.index(label_name)

    def __len__(self):
        return len(self.image_list)

In [None]:
class RandomApply(torch.nn.Module):
    """Run a wrapped transform with probability ``p``; otherwise return the input as-is.

    Args:
        fn: callable (typically a transform module) applied on a successful draw.
        p: probability of applying ``fn``; one draw is made per ``forward`` call.
    """

    def __init__(self, fn, p):
        super().__init__()
        self.fn, self.p = fn, p

    def forward(self, x):
        should_apply = random.random() <= self.p
        return self.fn(x) if should_apply else x


# ImageNet channel statistics; ImageNet_5Class normalizes every image with these.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Batch-level augmentation shared by the SimCLR / BYOL / SimSiam learners.
# NOTE: applied to a whole (N, C, H, W) batch. RandomApply makes one draw per
# call, and torchvision's random transforms likewise sample their parameters
# once per call, so every image in a batch gets the same augmentation.
SimCLR_augment = torch.nn.Sequential(
    # Inputs arrive already normalized by the dataset, but ColorJitter clamps
    # float tensors to [0, 1]. Undo the normalization first so the colour ops
    # see real pixel values instead of clipping every below-mean value to 0.
    T.Normalize(
        mean=[-m / s for m, s in zip(_IMAGENET_MEAN, _IMAGENET_STD)],
        std=[1 / s for s in _IMAGENET_STD]),
    RandomApply(
        T.ColorJitter(0.8, 0.8, 0.8, 0.2),
        p=0.3
    ),
    T.RandomGrayscale(p=0.2),
    T.RandomHorizontalFlip(),
    RandomApply(
        T.GaussianBlur((3, 3), (1.0, 2.0)),
        p=0.2
    ),
    T.RandomResizedCrop((224, 224)),
    # Re-apply ImageNet normalization exactly once, on the augmented images.
    T.Normalize(
        mean=torch.tensor(_IMAGENET_MEAN),
        std=torch.tensor(_IMAGENET_STD)),
)

In [None]:
def get_encoder(net: nn.Module) -> nn.Module:
    """Return the feature extractor of *net*: every child except the last one.

    The child modules are reused, not copied, so the encoder shares weights
    with *net*. The output keeps whatever shape the last kept child produces;
    flatten it before feeding a Linear layer.
    """
    layers = list(net.children())
    del layers[-1:]  # drop the final (head) child; no-op when net has no children
    return nn.Sequential(*layers)


def MLP(in_size, out_size, hidden_size=4096):
    """Build a two-layer head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_size: input feature dimension.
        out_size: output feature dimension.
        hidden_size: width of the hidden layer.
    """
    layers = [
        nn.Linear(in_size, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, out_size),
    ]
    return nn.Sequential(*layers)

def InfoNCE(x1, x2, temperature=0.5):
    """NT-Xent (normalized-temperature InfoNCE) loss used by SimCLR.

    Row ``i`` of ``x1`` and row ``i`` of ``x2`` form a positive pair; the
    remaining ``2N - 2`` rows of the concatenated batch act as negatives for
    that anchor. Embeddings are L2-normalized, so the logits are cosine
    similarities divided by ``temperature``.

    Args:
        x1: (N, D) embeddings of the first view.
        x2: (N, D) embeddings of the second view, row-aligned with ``x1``.
        temperature: softmax temperature; smaller values sharpen the softmax.

    Returns:
        Scalar tensor: cross-entropy averaged over all ``2N`` anchors.
    """
    n = x1.shape[0]
    z = nn.functional.normalize(torch.cat([x1, x2], dim=0), dim=1)
    logits = z @ z.t() / temperature
    # A sample must not be its own candidate: remove the diagonal from the
    # softmax denominator.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float('-inf'))
    # Anchor i in the first half pairs with i + n, and vice versa.
    idx = torch.arange(n, device=z.device)
    targets = torch.cat([idx + n, idx])
    return nn.functional.cross_entropy(logits, targets)


class SimCLR(nn.Module):
    """SimCLR learner: shared encoder + projection head trained with InfoNCE.

    Args:
        net: backbone exposing ``fc``; all children but the last form the encoder.
        temperature: temperature passed to ``InfoNCE``.
    """

    def __init__(self, net: nn.Module, temperature: float = 0.5) -> None:
        super().__init__()
        num_features = net.fc.in_features
        self.augment1 = SimCLR_augment
        self.augment2 = SimCLR_augment

        self.encoder = get_encoder(net)
        self.projector = MLP(in_size=num_features, out_size=256)
        self.temperature = temperature

    def _project(self, x):
        # The encoder keeps trailing spatial dims; flatten before the Linear head.
        return self.projector(torch.flatten(self.encoder(x), 1))

    def forward(self, x):
        view1, view2 = self.augment1(x), self.augment2(x)
        proj1, proj2 = self._project(view1), self._project(view2)
        return InfoNCE(proj1, proj2, self.temperature)

In [None]:
def get_encoder(net: nn.Module) -> nn.Module:
    """Wrap every child of *net* except the final one in an ``nn.Sequential``.

    No copies are made: the returned encoder shares its modules (and weights)
    with *net*.
    """
    children = tuple(net.children())
    backbone = children[:len(children) - 1]
    return nn.Sequential(*backbone)


def loss_fn(x, y):
    """BYOL regression loss: ``2 - 2 * cos(x, y)`` along the last dimension.

    Equals the squared Euclidean distance between the L2-normalized vectors,
    so it ranges from 0 (same direction) to 4 (opposite directions).
    """
    normalize = nn.functional.normalize
    x_unit = normalize(x, dim=-1, p=2)
    y_unit = normalize(y, dim=-1, p=2)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine


def MLP(in_size, out_size, hidden_size=4096):
    """Return ``nn.Sequential(Linear, BatchNorm1d, ReLU, Linear)`` mapping in_size -> out_size.

    ``hidden_size`` sets the width of the intermediate layer.
    """
    head = nn.Sequential()
    stages = (
        nn.Linear(in_size, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, out_size),
    )
    for position, stage in enumerate(stages):
        head.add_module(str(position), stage)
    return head


def EMA(moving_average_model: nn.Module, current_model: nn.Module, beta: float):
    """Exponential moving average: blend ``current_model`` into ``moving_average_model`` in place.

    Parameters are paired by registration order and updated as
    ``ma <- beta * ma + (1 - beta) * current``. So ``beta=1`` leaves the
    moving average unchanged and ``beta=0`` copies ``current_model`` into it.
    Only parameters are blended; buffers (e.g. BatchNorm running statistics)
    are left untouched.
    """
    with torch.no_grad():
        for c_param, ma_param in zip(current_model.parameters(), moving_average_model.parameters()):
            # Compute the blend out of place, then copy into the existing storage.
            # This keeps the Parameter's tensor (and any optimizer references)
            # intact, and stays a no-op if both models share the same tensor.
            ma_param.copy_(beta * ma_param + (1 - beta) * c_param)


class BYOL(nn.Module):
    """Bootstrap Your Own Latent (BYOL) self-supervised learner.

    The online branch (encoder -> projector -> predictor) is trained to predict
    the target branch's projection (encoder -> projector) of the other view.
    The target branch receives no gradients. It is refreshed from the online
    branch by ``update_moving_average`` (an EMA with ``moving_average_decay``),
    which the training loop should call after each optimizer step.

    Args:
        net: backbone exposing ``fc``; all children but the last form the encoder.
        moving_average_decay: EMA coefficient ``beta`` for the target update.
    """

    def __init__(self, net: nn.Module, moving_average_decay: float = 0.99) -> None:
        super().__init__()
        num_features = net.fc.in_features
        self.augment1 = SimCLR_augment
        self.augment2 = SimCLR_augment

        online_encoder = get_encoder(net)
        online_projector = MLP(in_size=num_features, out_size=256)
        # The target must be an independent copy. Calling get_encoder(net) twice
        # wraps the *same* child modules, which would make target and online
        # encoders one set of weights and turn the EMA into a no-op. deepcopy
        # also starts the target equal to the online branch, which EMA(..., beta=1)
        # did not do (beta=1 leaves the moving average unchanged).
        self.target_encoder = copy.deepcopy(online_encoder)
        self.online_encoder = online_encoder
        self.target_projector = copy.deepcopy(online_projector)
        self.online_projector = online_projector
        # Target weights change only through the EMA, never through the optimizer.
        self.target_encoder.requires_grad_(False)
        self.target_projector.requires_grad_(False)

        self.online_predictor = MLP(in_size=256, out_size=256)
        self.moving_average_decay = moving_average_decay

    def update_moving_average(self):
        """Move the target encoder/projector toward the online ones by one EMA step."""
        EMA(self.target_encoder, self.online_encoder, self.moving_average_decay)
        EMA(self.target_projector, self.online_projector, self.moving_average_decay)

    def online_pipeline(self, x):
        """Online prediction for a batch: encoder -> flatten -> projector -> predictor."""
        return self.online_predictor(self.online_projector(torch.flatten(self.online_encoder(x), 1)))

    def target_pipeline(self, x):
        """Target projection for a batch: encoder -> flatten -> projector."""
        return self.target_projector(torch.flatten(self.target_encoder(x), 1))

    def forward(self, x):
        """Return the symmetric BYOL loss for a batch, averaged over samples."""
        view1, view2 = self.augment1(x), self.augment2(x)
        pred1, pred2 = self.online_pipeline(view1), self.online_pipeline(view2)
        with torch.no_grad():
            proj1, proj2 = self.target_pipeline(view1), self.target_pipeline(view2)
        # Each view's prediction regresses onto the other view's target projection.
        loss = loss_fn(pred1, proj2.detach()) + loss_fn(pred2, proj1.detach())
        return loss.mean()

In [None]:
def get_encoder(net: nn.Module) -> nn.Module:
    """Build an encoder from *net* by keeping all of its children but the last.

    The kept modules are the originals, so weights are shared with *net*.
    """
    children = list(net.children())
    last_index = len(children) - 1
    return nn.Sequential(*(module for position, module in enumerate(children) if position < last_index))


def MLP(in_size, out_size, hidden_size=4096):
    """Projection/prediction head: Linear(in, hidden) -> BatchNorm1d -> ReLU -> Linear(hidden, out)."""
    expand = nn.Linear(in_size, hidden_size)
    norm = nn.BatchNorm1d(hidden_size)
    activation = nn.ReLU(inplace=True)
    reduce = nn.Linear(hidden_size, out_size)
    return nn.Sequential(expand, norm, activation, reduce)


class SimSiam(nn.Module):
    """SimSiam learner: shared encoder + projector, with a predictor on top.

    Each view's prediction is pulled toward the *detached* projection of the
    other view (negative cosine similarity, symmetrized).

    Args:
        net: backbone exposing ``fc``; all children but the last form the encoder.
    """

    def __init__(self, net: nn.Module) -> None:
        super().__init__()
        num_features = net.fc.in_features
        self.augment1 = SimCLR_augment
        self.augment2 = SimCLR_augment

        self.encoder = get_encoder(net)
        self.projector = MLP(in_size=num_features, out_size=256)
        self.predictor = MLP(in_size=256, out_size=256)

        self.criterion = nn.CosineSimilarity(dim=1)

    def _project(self, x):
        # The encoder keeps trailing spatial dims; flatten before the Linear head.
        return self.projector(torch.flatten(self.encoder(x), 1))

    def forward(self, x):
        view1, view2 = self.augment1(x), self.augment2(x)
        proj1, proj2 = self._project(view1), self._project(view2)
        pred1, pred2 = self.predictor(proj1), self.predictor(proj2)
        # Stop-gradient on the projection side: gradients flow only through the
        # predictor branch, which is what prevents representational collapse.
        loss = -(self.criterion(pred1, proj2.detach()).mean() +
                 self.criterion(pred2, proj1.detach()).mean()) * 0.5
        return loss

In [None]:
# Number of full passes over `trainloader` in the training loop.
n_epoch: int = 5

In [None]:
# Unlabelled training images; the learner applies its own two random
# augmentations per batch, so the dataset only resizes and normalizes.
my_dataset = ImageNet_5Class(augmentation=False, annotation=False)
len(my_dataset)

1250

In [None]:
# drop_last=True: the BatchNorm1d layers inside MLP raise in training mode on
# a batch with a single sample, which a ragged final batch can produce. With
# shuffle=True a different remainder is dropped each epoch.
trainloader = torch.utils.data.DataLoader(
    my_dataset, batch_size=64, shuffle=True, drop_last=True)

In [None]:
# Randomly initialized backbone; pass ImageNet-pretrained weights instead if
# self-supervised training should start from them.
resnet = models.resnet50()
# Learner hyperparameters (e.g. moving_average_decay) are set in the BYOL class.
learner = BYOL(net=resnet)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move the model before constructing the optimizer, as the torch.optim docs advise.
learner.to(device)
# Hand the optimizer only trainable parameters; those with requires_grad=False
# would never receive gradients anyway.
optimizer = torch.optim.Adam(
    [p for p in learner.parameters() if p.requires_grad], lr=3e-4)

BYOL(
  (augment1): Sequential(
    (0): RandomApply(
      (fn): ColorJitter(brightness=[0.19999999999999996, 1.8], contrast=[0.19999999999999996, 1.8], saturation=[0.19999999999999996, 1.8], hue=[-0.2, 0.2])
    )
    (1): RandomGrayscale(p=0.2)
    (2): RandomHorizontalFlip(p=0.5)
    (3): RandomApply(
      (fn): GaussianBlur(kernel_size=(3, 3), sigma=(1.0, 2.0))
    )
    (4): RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333), interpolation=bilinear)
    (5): Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
  )
  (augment2): Sequential(
    (0): RandomApply(
      (fn): ColorJitter(brightness=[0.19999999999999996, 1.8], contrast=[0.19999999999999996, 1.8], saturation=[0.19999999999999996, 1.8], hue=[-0.2, 0.2])
    )
    (1): RandomGrayscale(p=0.2)
    (2): RandomHorizontalFlip(p=0.5)
    (3): RandomApply(
      (fn): GaussianBlur(kernel_size=(3, 3), sigma=(1.0, 2.0))
    )
    (4): RandomResizedCrop(size=(224, 224), s

In [None]:
import numpy as np

In [None]:

# Self-supervised training: the learner augments each batch twice internally
# and returns a scalar loss.
learner.train()
for epoch in tqdm(range(n_epoch)):
    running_loss = 0.0
    n_batches = 0
    for batch in trainloader:
        # With annotation=True the dataset yields (images, labels); labels are unused here.
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        inputs = inputs.to(device)

        # zero the parameter gradients, then forward + backward + optimize
        optimizer.zero_grad()
        loss = learner(inputs)
        loss.backward()
        optimizer.step()
        # Only learners with a target network (BYOL) need the EMA update.
        if hasattr(learner, 'update_moving_average'):
            learner.update_moving_average()

        running_loss += loss.item()
        n_batches += 1

    # Report the per-batch mean so the value does not scale with the number of batches.
    print(f'Epoch {epoch + 1}/{n_epoch} - Loss: {running_loss / max(n_batches, 1):.4f}')

# BYOL's trained backbone is `online_encoder`; SimCLR and SimSiam call it `encoder`.
encoder = learner.online_encoder if hasattr(learner, 'online_encoder') else learner.encoder
torch.save(encoder.state_dict(), "encoder.pth")
print('Finished Training')



  0%|          | 0/5 [00:00<?, ?it/s]