In [None]:
pip install pytorch-lightning

In [None]:
import os
import urllib.request
from copy import deepcopy
from urllib.error import HTTPError

import matplotlib
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as DataLoader

from IPython.display import set_matplotlib_formats
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import Callback

import torchvision
from torchvision import transforms
import torchvision.models as models
from torchvision import datasets
from torchvision.datasets import STL10
from tqdm.notebook import tqdm


from torch.optim import Adam

import numpy as np
from torch.optim.lr_scheduler import OneCycleLR

import zipfile
from PIL import Image
import cv2

In [None]:
# Report the versions of the core libraries, in the order torch / torchvision / lightning.
print(*(library.__version__ for library in (torch, torchvision, pl)))

In [None]:
# Use the first CUDA device when one is available, otherwise stay on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# os.cpu_count() returns None when the count cannot be determined; fall back to 0
# (load data in the main process) so the DataLoaders always receive an integer.
NUM_WORKERS = os.cpu_count() or 0
print("Device:", device)
print("Number of workers:", NUM_WORKERS)

In [None]:
# Fix the global random seed through Lightning's helper so that runs are repeatable.
pl.seed_everything(96)

In [None]:
class DataAugTransform:
    """Callable that turns one input into several independently transformed views.

    Args:
        base_transforms: Callable applied to the input once per requested view.
        n_views: Number of views to produce for every input (default 2).
    """

    def __init__(self, base_transforms, n_views=2):
        self.n_views = n_views
        self.base_transforms = base_transforms

    def __call__(self, x):
        """Return a list with ``n_views`` results of ``base_transforms(x)``."""
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(x))
        return views
# Random augmentation pipeline that is applied once per view: flip, random crop
# resized to 96x96, colour jitter (applied 80% of the time), occasional grayscale,
# then tensor conversion and normalisation with mean 0.5 / std 0.5.
# NOTE: transforms.GaussianBlur(kernel_size=9) is currently not part of the pipeline.
augmentation_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(size=96),
    transforms.RandomApply(
        [
            transforms.ColorJitter(
                brightness=0.8,
                contrast=0.8,
                saturation=0.8,
                hue=0.1,
            )
        ],
        p=0.8,
    ),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# Folder where the datasets (STL10 in this notebook) are stored or downloaded to;
# can be overridden with the PATH_DATASETS environment variable.
DATASET_PATH = os.getenv("PATH_DATASETS", "bookdata/")
# Folder where the pretrained models are saved; can be overridden with the
# PATH_CHECKPOINT environment variable.
CHECKPOINT_PATH = os.getenv("PATH_CHECKPOINT", "booksaved_models/")

In [None]:
# Build the "unlabeled" and "train" STL10 splits (in that order). Each split gets
# its own two-view augmentation wrapper, so every item yields a pair of views.
unlabeled_data, train_data_contrast = (
    STL10(
        root=DATASET_PATH,
        split=split_name,
        download=True,
        transform=DataAugTransform(augmentation_transforms, n_views=2),
    )
    for split_name in ("unlabeled", "train")
)

In [None]:
# Preview the augmented views of the first NUM_IMAGES unlabeled samples.
pl.seed_everything(96)
NUM_IMAGES = 20

# Collect every view of every previewed sample, keeping views of one image adjacent.
imgs = torch.stack(
    [view for sample_idx in range(NUM_IMAGES) for view in unlabeled_data[sample_idx][0]],
    dim=0,
)
# make_grid returns a channels-first image; move channels last for imshow.
img_grid = torchvision.utils.make_grid(
    imgs, nrow=10, normalize=True, pad_value=0.9
).permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(img_grid)
ax.axis("off")
plt.show()

In [None]:
class NTXentLoss(torch.nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) loss for pairs of views.

    Row ``k`` of ``zis`` and row ``k`` of ``zjs`` are treated as a positive pair;
    all remaining rows of the concatenated batch act as negatives.

    Args:
        batch_size: Nominal number of pairs per batch. The negatives mask for this
            size is precomputed; batches of any other size are handled on the fly.
        temperature: Value the similarity logits are divided by.
        use_cosine_similarity: If truthy, use cosine similarity, otherwise a plain
            dot product.
    """

    def __init__(self, batch_size, temperature, use_cosine_similarity):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=-1)
        self.mask_samples_from_same_repr = self._get_correlated_mask()
        self.similarity_function = self._get_similarity_function(use_cosine_similarity)
        self.criterion = torch.nn.CrossEntropyLoss(reduction="sum")

    def _get_similarity_function(self, use_cosine_similarity):
        """Return the pairwise similarity callable selected by the flag."""
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._dot_simililarity

    def _get_correlated_mask(self, batch_size=None):
        """Build the boolean mask that selects the negatives of every row.

        Args:
            batch_size: Number of pairs to build the mask for; defaults to the
                batch size given at construction time.

        Returns:
            A ``(2 * batch_size, 2 * batch_size)`` bool tensor that is False on the
            main diagonal (self-similarity) and on the two diagonals offset by
            ``batch_size`` (the positive pairs), True everywhere else.
        """
        if batch_size is None:
            batch_size = self.batch_size
        size = 2 * batch_size
        diag = np.eye(size)
        l1 = np.eye(size, size, k=-batch_size)
        l2 = np.eye(size, size, k=batch_size)
        mask = torch.from_numpy((diag + l1 + l2))
        mask = (1 - mask).type(torch.bool)
        return mask

    @staticmethod
    def _dot_simililarity(x, y):
        """Pairwise dot products between the rows of ``x`` and the rows of ``y``."""
        # x.unsqueeze(1):   (M, 1, C)
        # y.T.unsqueeze(0): (1, C, K)
        # result:           (M, K)
        v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
        return v

    def _cosine_simililarity(self, x, y):
        """Pairwise cosine similarities between the rows of ``x`` and the rows of ``y``."""
        # x.unsqueeze(1): (M, 1, C)
        # y.unsqueeze(0): (1, K, C)
        # result:         (M, K)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs):
        """Compute the mean NT-Xent loss over all ``2 * N`` rows of the two batches.

        Args:
            zis: Embeddings of the first view, one row per sample.
            zjs: Embeddings of the second view, same shape as ``zis``.

        Returns:
            Scalar tensor with the loss averaged over both views of every sample.

        Raises:
            ValueError: If the two inputs differ in shape or the batch is empty.
        """
        if zis.shape != zjs.shape:
            raise ValueError(
                f"zis and zjs must have the same shape, got {tuple(zis.shape)} and {tuple(zjs.shape)}"
            )
        # Use the size of the batch actually received: the last batch of a loader
        # without drop_last can be smaller than the configured batch size.
        batch_size = zis.shape[0]
        if batch_size == 0:
            raise ValueError("NTXentLoss received an empty batch")

        representations = torch.cat([zjs, zis], dim=0)

        similarity_matrix = self.similarity_function(representations, representations)

        # The positives sit on the two diagonals offset by +/- batch_size.
        l_pos = torch.diag(similarity_matrix, batch_size)
        r_pos = torch.diag(similarity_matrix, -batch_size)
        positives = torch.cat([l_pos, r_pos]).view(2 * batch_size, 1)

        # Reuse the precomputed mask when the batch has the nominal size.
        if batch_size == self.batch_size:
            mask = self.mask_samples_from_same_repr
        else:
            mask = self._get_correlated_mask(batch_size)
        mask = mask.to(similarity_matrix.device)
        # Every row keeps all columns except itself and its positive partner.
        negatives = similarity_matrix[mask].view(2 * batch_size, 2 * batch_size - 2)

        logits = torch.cat((positives, negatives), dim=1) / self.temperature

        # The positive is always in column 0; create the labels on the logits'
        # device so that the loss also works when training on a GPU.
        labels = torch.zeros(2 * batch_size, dtype=torch.long, device=logits.device)
        loss = self.criterion(logits, labels)

        return loss / (2 * batch_size)

In [None]:
class ResNetSimCLR(nn.Module):
    """Backbone feature extractor followed by a three-layer SimCLR projection head.

    Args:
        base_model: Backbone network. It must expose an ``fc`` layer whose
            ``in_features`` gives the feature size; all of its children except the
            last one are used as the feature extractor.
        out_dim: Size of the projected embedding returned by the head.
        freeze: When true (default), only the last two children of the feature
            extractor remain trainable.
    """

    def __init__(self, base_model, out_dim, freeze=True):
        super(ResNetSimCLR, self).__init__()

        # Number of input features into the last linear layer
        num_ftrs = base_model.fc.in_features
        # Remove last layer of the backbone
        self.features = nn.Sequential(*list(base_model.children())[:-1])
        if freeze:
            self._freeze()

        # Projection head MLP used by SimCLR
        self.l1 = nn.Linear(num_ftrs, 2*num_ftrs)
        self.l2_bn = nn.BatchNorm1d(2*num_ftrs)
        self.l2 = nn.Linear(2*num_ftrs, num_ftrs)
        self.l3_bn = nn.BatchNorm1d(num_ftrs)
        self.l3 = nn.Linear(num_ftrs, out_dim)

    def _freeze(self):
        """Disable gradients for every feature-extractor child except the last two."""
        children = list(self.features.children())
        first_trainable = len(children) - 2
        for index, child in enumerate(children):
            trainable = index >= first_trainable
            for param in child.parameters():
                param.requires_grad = trainable

    def forward(self, x):
        """Run the backbone and the projection head.

        Returns:
            A tuple ``(h, x_l1, x)``: the flattened backbone features, the output
            of the first head layer, and the final projection.
        """
        h = self.features(x)
        # Flatten everything but the batch dimension. Unlike squeeze(), this keeps
        # the batch axis for a single sample and the feature axis for a
        # single-channel feature map.
        h = h.flatten(start_dim=1)

        x_l1 = self.l1(h)
        x = self.l2_bn(x_l1)
        x = F.selu(x)
        x = self.l2(x)
        x = self.l3_bn(x)
        x = F.selu(x)
        x = self.l3(x)
        return h, x_l1, x

In [None]:
import yaml # Handles config file loading

# Hyper-parameters are kept as an inline YAML document and parsed into a nested dict.
# - weight_decay is written with an explicit decimal point: PyYAML resolves
#   exponent notation as a float only when the mantissa contains a dot, so a
#   value like 10e-6 would be loaded as a string.
# - NOTE(review): input_shape is loaded as the plain string "(96,96,3)", not a
#   tuple; confirm how (and whether) it is consumed before changing it.
config = '''
batch_size: 128
epochs: 100
weight_decay: 1.0e-5
out_dim: 256

dataset:
  s: 1
  input_shape: (96,96,3)
  num_workers: 2

optimizer:
  lr: 0.0001

loss:
  temperature: 0.05
  use_cosine_similarity: True

lr_schedule:
  max_lr: .1
  total_steps: 1500
'''
# The document contains plain scalars and mappings only, so the safe loader suffices.
config = yaml.safe_load(config)

In [None]:
class simCLR(pl.LightningModule):
    """LightningModule that trains a network with a contrastive loss on pairs of views.

    Args:
        model: Network whose forward pass returns a 3-tuple; the last element is
            used as the embedding fed to the loss.
        config: Dictionary of settings. The keys read here are ``batch_size``,
            ``loss``, ``optimizer`` and ``lr_schedule``.
        optimizer: Optimizer class, instantiated with the module parameters and
            ``config['optimizer']``.
        loss: Loss class, instantiated with ``config['batch_size']`` and
            ``config['loss']``.
    """

    def __init__(self, model, config, optimizer=Adam, loss=NTXentLoss):
        super(simCLR, self).__init__()
        # Config (dictionary) used to pass parameters on to each module: optimizer, loss, lr_schedule
        self.config = config

        # Optimizer class (instantiated in configure_optimizers)
        self.optimizer = optimizer

        # Model
        self.model = model

        # Loss
        self.loss = loss(self.config['batch_size'], **self.config['loss'])

    # Prediction/inference
    def forward(self, x):
        """Delegate to the wrapped model."""
        return self.model(x)

    # Sets up optimizer
    def configure_optimizers(self):
        """Create the optimizer and a OneCycleLR scheduler from the config."""
        optimizer = self.optimizer(self.parameters(), **self.config['optimizer'])
        scheduler = OneCycleLR(optimizer, **self.config["lr_schedule"])
        return [optimizer], [scheduler]

    def _contrastive_loss(self, batch):
        """Embed both views of a batch, L2-normalise them and return the loss."""
        # batch is (views, labels); labels are not used by the contrastive objective.
        (xis, xjs), _ = batch
        _, _, zis = self(xis)
        _, _, zjs = self(xjs)

        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        return self.loss(zis, zjs)

    # Training loops
    def training_step(self, batch, batch_idx):
        """Compute, log and return the contrastive loss for one training batch."""
        loss = self._contrastive_loss(batch)
        # Log the training loss so it can be compared with 'val_loss'.
        self.log('train_loss', loss)
        return loss

    # Validation step
    def validation_step(self, batch, batch_idx):
        """Compute the contrastive loss for one validation batch and log it as 'val_loss'."""
        loss = self._contrastive_loss(batch)
        self.log('val_loss', loss)
        return loss

    def test_step(self, batch, batch_idx):
        """Placeholder: no test metric is computed, so None is returned."""
        loss = None
        return loss

def _get_model_checkpoint():
    """Create the checkpoint callback that keeps the three best models by ``val_loss``.

    Checkpoints are written to ``<cwd>/checkpoints``. The epoch and the monitored
    value are part of the file name so the retained checkpoints stay distinguishable.

    Returns:
        A configured ``ModelCheckpoint`` callback.
    """
    # `filepath` was removed from ModelCheckpoint; the target is now given as a
    # directory (`dirpath`) plus a file name template (`filename`).
    return ModelCheckpoint(
        dirpath=os.path.join(os.getcwd(), "checkpoints"),
        filename="best_val_models-{epoch:02d}-{val_loss:.4f}",
        save_top_k=3,
        monitor="val_loss",
        mode="min",
    )

In [None]:
# Both loaders take their batch size from the config so that it always matches
# the batch size the contrastive loss is constructed with.
train_loader = DataLoader.DataLoader(
    unlabeled_data,
    batch_size=config["batch_size"],
    shuffle=True,
    drop_last=True,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)
# drop_last=True here as well: the loss is built for a fixed batch size, so a
# smaller final validation batch must not be produced.
val_loader = DataLoader.DataLoader(
    train_data_contrast,
    batch_size=config["batch_size"],
    shuffle=False,
    drop_last=True,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

In [None]:
# Backbone: torchvision ResNet-50 created with pretrained weights.
resnet = models.resnet50(pretrained=True)

# Attach the projection head; its output size comes from the config.
simclr_resnet = ResNetSimCLR(resnet, config['out_dim'])

In [None]:
# Creates the simCLR LightningModule around the architecture specified above
model = simCLR(simclr_resnet, config)

In [None]:
# Initializes the model trainer; the number of epochs comes from the config
# instead of the Trainer's much larger default.
trainer = pl.Trainer(max_epochs=config["epochs"])

In [None]:
# Fits the model: train_loader (unlabeled split) drives training,
# val_loader (labeled "train" split) is used for the validation loss.
trainer.fit(model, train_loader, val_loader)

In [None]:
# %tensorboard --logdir ../saved_models/tutorial17/tensorboards/SimCLR/