# SimCLR

<img src="img/SimCLR.png" width="30%">  

In [None]:
import torch
import math

import os
import PIL.Image as Image
import numpy as np
from os import listdir
from os import walk
import glob
from torch.utils.data import Dataset
from torchvision import transforms

import simsiam.loader

import numpy as np
from torchvision import transforms
from PIL import Image, ImageFilter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

import torchvision.models as models
from facenet_pytorch import InceptionResnetV1
from torch import nn
import torch.nn.functional as F
from sync_batchnorm import convert_model
from pytorch_metric_learning.losses import NTXentLoss

## GPU Check

In [None]:
# Dataset locations, relative to the working directory of the notebook.
DATA_PATH_TRAIN = '../../dataset/face_labeled_data/train'
DATA_PATH_VAL = '../../dataset/face_labeled_data/val'

# Optimisation hyper-parameters.
batch_size = 256
# Base learning rate scaled linearly with the batch size (reference size 256).
# A base value of 0.05 was tried previously.
init_learning_rate = 0.005 * batch_size / 256
momentum_val = 0.9
weight_decay_val = 1e-4

# Folder that receives the checkpoints and the loss plot.
output_foloder = 'model_simCLR_lin'
# Number of worker processes handed to the training/validation DataLoaders.
WORKERS = 16

print(f'torch version:{torch.__version__}')

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
cuda_ready = torch.cuda.is_available()
device = torch.device('cuda' if cuda_ready else 'cpu')
if not cuda_ready:
    print('CUDA is not available.')
else:
    print('Available GPUs: ', end='')
    for gpu_id in range(torch.cuda.device_count()):
        print(torch.cuda.get_device_name(gpu_id), end=' ')

## define dataset

In [None]:
class FaceImages(Dataset):
    """Dataset of ``.jpg`` images that yields two transformed views per item.

    Images are discovered with the glob pattern ``<img_dir>/<specific>/*.jpg``.
    The name of the directory that directly contains an image must be an
    integer; it is returned as the item's target.
    """

    def __init__(self, img_dir, transform, specific = '**'):
        """Collect the image paths below ``img_dir``.

        Args:
            img_dir: Root directory that holds one sub-directory per label.
            transform: Callable applied to the loaded RGB image. It is called
                twice per item, so a random transform yields two views.
            specific: Glob pattern selecting the label sub-directories
                (the default matches every sub-directory).
        """
        self.img_dir = img_dir
        # glob returns paths in arbitrary order; sort so that an index always
        # refers to the same file across runs.
        self.img_path_list = sorted(
            glob.glob(os.path.join(img_dir, specific + '/*.jpg'))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of images that were found."""
        return len(self.img_path_list)

    def __getitem__(self, idx):
        """Return ``(view_1, view_2, target)`` for the image at ``idx``."""
        img_path = self.img_path_list[idx]
        img = FaceImages.read_image(img_path)
        target = FaceImages._label_from_path(img_path)
        return self.transform(img), self.transform(img), target

    @staticmethod
    def _label_from_path(img_path):
        """Return the integer label encoded in the parent directory name.

        Reading the parent directory (instead of a fixed path component
        index) keeps the label correct regardless of how deep ``img_dir`` is.

        Raises:
            ValueError: If the parent directory name is not an integer.
        """
        return int(os.path.basename(os.path.dirname(img_path)))

    @staticmethod
    def read_image(img_path):
        """Load the file at ``img_path`` and return it as an RGB PIL image."""
        # convert() returns a fully loaded copy, so the underlying file can be
        # closed right away instead of lingering until garbage collection.
        with Image.open(img_path, mode='r') as img:
            return img.convert('RGB')

## define data augmentation

In [None]:
# Per-channel mean / std used to standardise image tensors
# (channel order presumably RGB, matching FaceImages.read_image — TODO confirm).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD)
                          
def get_aug_trnsform(s=1.0):
    """Build the random augmentation pipeline used to create image views.

    The pipeline applies, in order: a random resized crop to 80x80 pixels
    (area scale 0.2-1.0), colour jitter with probability 0.8, random grayscale
    with probability 0.2, Gaussian blur (sigma range 0.1-2.0) with probability
    0.5, a random horizontal flip, tensor conversion and normalisation with
    the module-level ``normalize``.

    Args:
        s: Kept for backward compatibility; it currently has NO effect. The
            colour-jitter strength is fixed at (0.4, 0.4, 0.4, 0.1).

    Returns:
        A ``transforms.Compose`` instance.
    """
    return transforms.Compose([
        transforms.RandomResizedCrop(80, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([simsiam.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

def ge_eval_trnsform(s=1.0):
    """Build the deterministic evaluation pipeline.

    Images are only converted to tensors and normalised with the module-level
    ``normalize``; no resizing, cropping or random augmentation is applied.

    Args:
        s: Kept for backward compatibility; it currently has NO effect.

    Returns:
        A ``transforms.Compose`` instance.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

# Augmentation pipeline shared by the training and validation datasets.
aug = get_aug_trnsform(s=0.5)

## define model

In [None]:
# Rescale every embedding to unit length.
class L2_norm(nn.Module):
    """Module that L2-normalises its input along the last dimension."""

    def forward(self, x):
        """Return ``x`` divided by its L2 norm over the last dimension."""
        return F.normalize(x, dim=-1, p=2)

In [None]:
# Build the raw networks first: the backbone, then a single linear layer that
# serves as the projection head (512 -> 512).
backbone = InceptionResnetV1()
head = nn.Sequential(nn.Linear(512, 512))

# Wrap each network for multi-GPU execution, then let convert_model swap in
# synchronised batch normalisation.
encoder = convert_model(nn.DataParallel(backbone))
projector = convert_model(nn.DataParallel(head))

# Move both wrapped networks onto the selected device (in place).
encoder.to(device)
projector.to(device)

In [None]:
# One SGD optimizer per network, both sharing the same hyper-parameters.
from torch import optim

sgd_settings = dict(
    lr=init_learning_rate,
    momentum=momentum_val,
    weight_decay=weight_decay_val,
)

# 'fix_lr' is an extra per-group flag that adjust_learning_rate reads to
# decide whether the group follows the schedule.
optim_params_encoder = [{'params': encoder.parameters(), 'fix_lr': False}]
optim_params_projector = [{'params': projector.parameters(), 'fix_lr': False}]

encoder_opt = optim.SGD(optim_params_encoder, **sgd_settings)
projector_opt = optim.SGD(optim_params_projector, **sgd_settings)

### save model

In [None]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialise ``state`` to ``filename`` and optionally mark it as the best.

    Args:
        state: Object passed to ``torch.save``.
        is_best: When true, the written file is additionally copied to
            ``model_best.pth.tar`` in the current working directory.
        filename: Destination path of the checkpoint.
    """
    # Imported locally because the file's import section does not provide it.
    import shutil

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

### 動態lr 

In [None]:
def adjust_learning_rate(optimizer, init_lr, epoch, epochs):
    """Apply a half-cosine learning-rate schedule to ``optimizer`` in place.

    Parameter groups carrying a truthy ``'fix_lr'`` entry are pinned to
    ``init_lr``; every other group receives the decayed rate for ``epoch``.
    """
    half_lr = init_lr * 0.5
    decayed_lr = half_lr * (1. + math.cos(math.pi * epoch / epochs))
    for group in optimizer.param_groups:
        group['lr'] = init_lr if group.get('fix_lr') else decayed_lr

### define loss_fnc

In [None]:
# Contrastive loss applied to the projected embeddings in pass_epoch.
NTXENT_TEMPERATURE = 0.10
loss_func = NTXentLoss(temperature=NTXENT_TEMPERATURE)

### define accuracy

In [None]:
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k in ``topk``.

    ``output`` holds one row of scores per sample and ``target`` the matching
    class indices. The result is a list of one-element tensors, ordered like
    ``topk``.
    """
    with torch.no_grad():
        n_samples = target.size(0)

        # Indices of the highest-scoring classes, best first, one column per
        # sample after the transpose.
        _, ranked = output.topk(max(topk), dim=1, largest=True, sorted=True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]

## Train

In [None]:
def pass_epoch(loader, mode = 'Train'):
    """Run one pass over ``loader`` and return the mean batch loss.

    Relies on the module-level ``encoder``, ``projector``, ``loss_func``,
    ``device`` and, in train mode, ``encoder_opt`` / ``projector_opt``.

    Args:
        loader: Iterable of batches ``(view_1, view_2, target)``. Only the two
            views are used; the target is ignored by the loss.
        mode: ``'Train'`` puts both networks in train mode and updates their
            weights after every batch. Any other value puts them in eval mode
            and performs no update.

    Returns:
        float: Average of the per-batch loss values.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    is_train = mode == 'Train'
    encoder.train(is_train)
    projector.train(is_train)

    total_loss = 0.0
    n_batches = 0
    for image_batch in tqdm(loader):
        x1, x2 = image_batch[0].to(device), image_batch[1].to(device)

        # Track gradients only when they are needed, regardless of whether the
        # caller wrapped this call in torch.no_grad().
        with torch.set_grad_enabled(is_train):
            # forward
            y1, y2 = encoder(x1), encoder(x2)
            z1, z2 = projector(y1), projector(y2)

            # compute loss: the two views of sample i share label i, so they
            # form the positive pair and all other samples act as negatives.
            embeddings = torch.cat((z1, z2))
            indices = torch.arange(0, z1.size(0), device=z1.device)
            labels = torch.cat((indices, indices))
            loss_batch = loss_func(embeddings, labels)

        # Accumulate a plain Python number so the running total does not keep
        # any autograd history (or GPU memory) alive across batches.
        total_loss += loss_batch.item()
        n_batches += 1

        if is_train:
            # update
            encoder_opt.zero_grad()
            projector_opt.zero_grad()
            loss_batch.backward()
            encoder_opt.step()
            projector_opt.step()

    if n_batches == 0:
        raise ValueError('pass_epoch received a loader that yields no batches')
    return total_loss / n_batches

In [None]:
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from IPython.display import clear_output
from tqdm import tqdm 

# Both splits go through the random augmentation pipeline, so every item
# provides two differently augmented views of the same image.
dataset_train = FaceImages(DATA_PATH_TRAIN, aug)
dataset_val = FaceImages(DATA_PATH_VAL, aug)

# The two loaders are configured identically.
loader_settings = dict(batch_size=batch_size, num_workers=WORKERS, shuffle=True)
dataloader_train = DataLoader(dataset_train, **loader_settings)
dataloader_val = DataLoader(dataset_val, **loader_settings)

# Total number of training epochs.
epoch = 100

# Per-epoch mean losses, filled by the training loop and used for plotting.
loss_history_train, loss_history_val = [], []

def update_loss_hist(train_list, val_list, name='result'):
    """Redraw the train/validation loss curves, save them and show them.

    The previous cell output is cleared first; the figure is written to
    ``./<output_foloder>/<name>.png`` before being displayed.
    """
    clear_output(wait=True)
    for curve in (train_list, val_list):
        plt.plot(curve)
    plt.title(name)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(['train', 'val'], loc='center right')
    target_png = './{}/{}.png'.format(output_foloder, name)
    plt.savefig(target_png)
    plt.show()

# train loop
# The plot and every checkpoint are written into the output folder, so make
# sure it exists before the first epoch finishes.
os.makedirs(output_foloder, exist_ok=True)

for i in range(epoch):
    # Cosine-decay the learning rate of both optimizers for this epoch.
    adjust_learning_rate(encoder_opt, init_learning_rate, i, epoch)
    adjust_learning_rate(projector_opt, init_learning_rate, i, epoch)

    train_loss = pass_epoch(dataloader_train, 'Train')
    with torch.no_grad():
        val_loss = pass_epoch(dataloader_val, 'Eval')

    # Store plain floats: matplotlib cannot plot tensors that live on the GPU
    # or that still require gradients.
    loss_history_train.append(float(train_loss))
    loss_history_val.append(float(val_loss))
    update_loss_hist(loss_history_train, loss_history_val, 'NTXentLoss')

    save_checkpoint({
        'epoch': i + 1,
        'arch': 'SimCLR',
        'state_dict': encoder.state_dict(),
        # State of the encoder's optimizer (not the encoder weights again).
        'optimizer' : encoder_opt.state_dict(),
    }, is_best=False, filename='./{}/checkpoint_{:04d}.pth.tar'.format(output_foloder, i + 1))

# Persist the whole encoder object; the collapse check below loads this file.
torch.save(encoder, './{}/checkpoint.pth.tar'.format(output_foloder))
torch.cuda.empty_cache()

### collapse check (similarity matrix)

In [None]:
# Path of the full encoder object written at the end of training.
MODEL_PATH = './{}/checkpoint.pth.tar'.format(output_foloder)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): the file holds a pickled module, not a state dict; recent
# PyTorch releases may need weights_only=False here — confirm for the
# installed version.
model = torch.load(MODEL_PATH, map_location=device).to(device).eval()

# Evaluation data: no random augmentation and a fixed order.
dataset_eval = FaceImages(DATA_PATH_VAL, ge_eval_trnsform(0.5))
dataloader_eval = DataLoader(dataset_eval, batch_size=batch_size, shuffle=False)

def collapseCheck(model, loader):
    """Print the cosine-similarity matrix of the first batch's embeddings.

    Entry ``[i, j]`` is the cosine similarity between the embeddings of
    samples ``i`` and ``j``. If all off-diagonal values are close to 1, the
    encoder maps every input to (almost) the same direction.

    Args:
        model: Callable mapping a batch of images to one embedding per row.
        loader: Iterable of batches ``(view_1, view_2, target)``; only the
            first view of the first batch is used.
    """
    x, _, _ = next(iter(loader))
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        h = model(x.to(device))
        # F.normalize clamps the norm with a small epsilon, so an all-zero
        # embedding gives zeros instead of NaNs.
        h_norm = F.normalize(h, p=2, dim=1)
        res = torch.mm(h_norm, h_norm.transpose(0, 1))
    print(res.cpu().numpy())
    
# Inspect the loaded encoder on the first validation batch.
collapseCheck(model=model, loader=dataloader_eval)