### Imports & Configs

In [4]:
import torch
import torch.nn as NN
import torchvision

from torch.utils.data import Dataset, DataLoader

from torchvision.datasets import ImageFolder
from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomGrayscale, Compose, GaussianBlur, ToTensor, RandomApply

import torchvision.models as models

import os
import glob
import time
from skimage import io
import matplotlib.pyplot as plt

In [5]:
# Print the installed PyTorch version string.
print(torch.__version__)

1.13.1+cu117


In [6]:
# Select the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


### SimCLR is built from the following simplified modules:

* A stochastic data augmentation module.
* A neural network base encoder  f(.)
* A neural network projection head  g(.)
* A contrastive loss function.

### Stochastic Augmentation

* The authors suggest that the composition of data augmentation operations is crucial for learning good representations.
* Contrastive learning needs stronger data augmentation than supervised learning.


In [8]:
def complete_image_transform(output_shape, kernel_size, s=1.0):
    """
    Build the stochastic augmentation pipeline suggested by the SimCLR authors.

    The returned transform applies, in order:
    - Conversion of the input image to a tensor
    - Random Crop with Resize (to ``output_shape``)
    - Random Horizontal Flip with 50% probability
    - Random Color Distortion
       - Random Color Jitter with 80% probability
       - Random Color Drop (grayscale) with 20% probability
    - Random Gaussian Blur with 50% probability

    Parameters:
        output_shape: output size handed to ``RandomResizedCrop``.
        kernel_size: kernel size handed to ``GaussianBlur``.
        s: color distortion strength parameter; it scales the brightness,
           contrast, saturation (0.8 * s) and hue (0.2 * s) jitter ranges.

    Returns:
        A ``Compose`` object chaining all of the transforms above.
    """
    random_crop = RandomResizedCrop(output_shape)
    random_flip = RandomHorizontalFlip(p=0.5)

    # The hue range must use the lowercase parameter ``s``; there is no ``S``.
    color_jitter = ColorJitter(
        brightness=0.8 * s,
        contrast=0.8 * s,
        saturation=0.8 * s,
        hue=0.2 * s,
    )
    random_color_jitter = RandomApply([color_jitter], p=0.8)

    random_gray = RandomGrayscale(p=0.2)
    gaussian_blur = GaussianBlur(kernel_size=kernel_size)
    random_gaussian_blur = RandomApply([gaussian_blur], p=0.5)

    to_tensor = ToTensor()

    image_transform = Compose([
        to_tensor,
        random_crop,
        random_flip,
        random_color_jitter,
        random_gray,
        random_gaussian_blur,
    ])
    # Hand the pipeline back to the caller (a transform is not an exception,
    # so it must be returned rather than raised).
    return image_transform
    
class ContrastiveLearningViewGenerator(object):
    """
    Callable that produces several transformed views of one input.

    ``base_transform`` is invoked ``n_views`` times on the same input and the
    results are collected, in call order, into a list.
    """

    def __init__(self, base_transform, n_views = 2):
        # Callable applied to the input once per requested view.
        self.base_transform = base_transform
        # Number of views produced per call.
        self.n_views = n_views

    def __call__(self, x):
        """Return a list holding ``n_views`` results of ``base_transform(x)``."""
        generated = []
        for _ in range(self.n_views):
            generated.append(self.base_transform(x))
        return generated

### Prepare Data

* _torch.utils.data_ is the main module that helps us load the data and iterate over it
* **torch.utils.data.Dataset** -> helps encapsulate raw data
* **torch.utils.data.DataLoader** -> helps iterate over the dataset