## Download the required dataset and packages
> You must log in to Weights & Biases (wandb) before starting training!

In [None]:
# Download dataset
# NOTE(review): newer gdown releases deprecate `--id` in favour of passing the
# file ID positionally — confirm against the installed gdown version.
!gdown --id 1iXIqmmQWQO5owW03ZOGIi7fdLCjAe4SI --output CelebA
# Second Drive file, saved under its original name — presumably the
# celebA.csv index that CelebADataset reads by default; TODO confirm.
!gdown --id 1bfW4ljiLRQKLdUo68-qqRtxHFU_u4VJK
# Extract the archive saved above as "CelebA" into the working directory
# (presumably producing the img_align_celeba/ folder — confirm).
!unzip -q CelebA
# OpenCV Haar cascade loaded by CelebADataset for its (currently disabled)
# face-cropping step.
!wget https://github.com/opencv/opencv/raw/master/data/haarcascades/haarcascade_frontalface_alt.xml

In [None]:
# Install the Weights & Biases client and authenticate interactively.
# SimCLR.train() calls wandb.init, so this must run before training.
!pip install wandb
!wandb login

In [None]:
import os

import random

import cv2

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from PIL import Image

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader

import torchvision.models as models
import torchvision.transforms as transforms

import wandb

## Create model
The model architecture follows the SimCLR implementation [here](https://github.com/sthalles/SimCLR).

In [None]:
class SimCLRModel(nn.Module):
    """ResNet-style backbone whose classifier is replaced by a projection head.

    Args:
        model_type: name of a constructor in ``torchvision.models`` (for
            example ``"resnet18"`` or ``"resnet50"``). The built model must
            expose a final ``fc`` linear layer, as torchvision ResNets do.
        pretrained: if truthy, load torchvision's pretrained weights for the
            backbone. The projection head is always freshly initialised.
        out_dim: output dimension of the projection head.

    Raises:
        ValueError: if ``model_type`` is not a torchvision model constructor,
            or the resulting model has no ``fc`` layer to replace.
    """

    def __init__(self, model_type="resnet18",
                 pretrained=False, out_dim=128):
        super().__init__()

        # Build the requested backbone; model_type used to be ignored.
        constructor = getattr(models, model_type, None)
        if not callable(constructor):
            raise ValueError(f"Unknown torchvision model type: {model_type!r}")

        if pretrained:
            self.model = constructor(pretrained=True)
        else:
            # num_classes only sizes the original fc layer, which is
            # replaced by the projection head below.
            self.model = constructor(num_classes=out_dim)

        if not hasattr(self.model, "fc"):
            raise ValueError(
                f"Model {model_type!r} has no 'fc' layer to replace with a "
                "projection head"
            )
        dim_mlp = self.model.fc.in_features

        # Projection head: Linear -> ReLU -> Linear, ending in out_dim features.
        self.model.fc = nn.Sequential(
            nn.Linear(dim_mlp, out_dim),
            nn.ReLU(inplace=True),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, x):
        """Return the projection-head output (one ``out_dim`` vector per image)."""
        return self.model(x)

## Create the Dataset and Data Augmentation
* RandomHorizontalFlip
* RandomResizedCrop
* Gaussian Blur
* Color Jitter
* RandomGrayscale

In [None]:
class CelebADataset(Dataset):
    """CelebA images returned as two independently augmented views (a SimCLR pair).

    Args:
        csv_path: CSV file read with ``index_col=0``. Its first remaining
            column holds the image file name, relative to ``img_root``.
        img_root: directory containing the image files.
        size: side length of the square crops produced by RandomResizedCrop.
            It also sets the Gaussian-blur kernel (about 10% of ``size``,
            forced to be odd).
    """

    def __init__(self, csv_path="celebA.csv",
                 img_root="img_align_celeba", size=32):
        self.img_root = img_root
        self.csv = pd.read_csv(csv_path, index_col=0)
        # Only used by the (currently disabled) face-cropping step in __getitem__.
        self.face_cascade = cv2.CascadeClassifier("haarcascade_frontalface_alt.xml")
        color_jitter = transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=self._blur_kernel_size(size)),
        ])

    @staticmethod
    def _blur_kernel_size(size):
        """Return an odd, positive blur kernel size of roughly 10% of ``size``.

        GaussianBlur rejects even or non-positive kernel sizes. For example,
        size=64 gives int(6.4) == 6, which is bumped to 7 here.
        """
        kernel = int(0.1 * size)
        return kernel if kernel % 2 == 1 else kernel + 1

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        """Return two random augmentations of image ``idx`` as tensors."""
        # Read in image
        img_path = os.path.join(self.img_root, str(self.csv.iloc[idx, 0]))
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # # Crop the faces only
        # x, y, w, h = self.face_cascade.detectMultiScale(gray, 1.3, 4)[0]
        # img = cv2.cvtColor(img[y:y+h, x:x+w], cv2.COLOR_BGR2RGB)

        # OpenCV loads BGR; the transforms expect RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_one = self.transforms(img)
        img_two = self.transforms(img)

        return img_one, img_two

In [None]:
def make_loader(batch_size):
    """Build the default CelebA pair dataset and a shuffling loader over it.

    Incomplete trailing batches are dropped, so every batch holds exactly
    ``batch_size`` pairs.

    Returns:
        ``(dataset, dataloader)``
    """
    celeba = CelebADataset()
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )
    loader = DataLoader(dataset=celeba, **loader_options)
    return celeba, loader

## Visualizing tools

In [None]:
@torch.no_grad()
def plot_points(model, img_root="img_align_celeba", num_points=10, device=None):
    """Embed randomly chosen images with ``model`` and scatter them in 2-D.

    Args:
        model: network that maps a ``(1, 3, H, W)`` image tensor to a feature
            vector.
        img_root: directory whose files are sampled as images.
        num_points: number of images to embed. Files are sampled without
            replacement, and the count is capped at the number of files.
        device: device to run inference on. Defaults to the device of the
            model's parameters.

    Returns:
        A matplotlib Figure with a PCA scatter (left) and a t-SNE scatter
        (right). The caller owns the figure and should close it when done.

    Raises:
        ValueError: if fewer than two images are available to project.

    The model's train/eval mode is restored before returning.
    """
    if device is None:
        # Infer the device from the model instead of a global config object.
        device = next(model.parameters()).device

    # Choose distinct points so no image is plotted twice.
    file_names = os.listdir(img_root)
    plot_image_list = random.sample(file_names, k=min(num_points, len(file_names)))
    if len(plot_image_list) < 2:
        raise ValueError("plot_points needs at least two images to project to 2-D")

    base_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(32)
    ])

    # eval() is needed for inference, but a model handed in mid-training must
    # be given back in train mode, or later epochs run BatchNorm in eval mode.
    was_training = model.training
    model.eval()
    try:
        features_list = []
        for img_name in plot_image_list:
            # The context manager closes the file; convert() guards against
            # grayscale or RGBA inputs.
            with Image.open(os.path.join(img_root, img_name)) as img:
                tensor = base_transform(img.convert("RGB"))
            features_list.append(model(tensor.unsqueeze(0).to(device)).cpu()[0].tolist())
    finally:
        model.train(was_training)

    features = np.asarray(features_list)
    n_samples = len(features)

    # Do the dimension reduction. t-SNE requires perplexity < n_samples, and
    # sklearn's default perplexity is 30.
    two_dim_pca = PCA(n_components=2).fit_transform(features)
    two_dim_tsne = TSNE(
        n_components=2, perplexity=min(30.0, n_samples - 1)
    ).fit_transform(features)

    # Show plot
    colors = list(range(n_samples))
    fig = plt.figure(figsize=(12, 6))
    ax1 = fig.add_subplot(1, 2, 1)
    ax1.set_title("PCA")
    ax1.scatter(two_dim_pca[:, 0], two_dim_pca[:, 1], c=colors)
    ax2 = fig.add_subplot(1, 2, 2)
    ax2.set_title("TSNE")
    ax2.scatter(two_dim_tsne[:, 0], two_dim_tsne[:, 1], c=colors)

    return fig

## Create the main training code

In [None]:
class SimCLR:
    """SimCLR trainer: contrastive (InfoNCE) training with wandb logging.

    ``args`` must provide ``batch_size``, ``model_type``, ``pretrained``,
    ``out_dim``, ``device``, ``lr``, ``weight_decay``, ``epochs``,
    ``temperature``, ``n_views`` and ``fp16_precision``. An optional
    ``warmup_epochs`` (default 10) sets how many epochs run at a constant
    learning rate before cosine decay begins.
    """

    def __init__(self, args):
        self.args = args
        _, self.train_loader = make_loader(batch_size=args.batch_size)
        self.model = SimCLRModel(
            args.model_type, args.pretrained, args.out_dim
        ).to(self.args.device)
        self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), args.lr, weight_decay=args.weight_decay
        )
        # The scheduler is stepped once per epoch after the constant-LR phase,
        # so its period is measured in epochs, not in batches.
        self.warmup_epochs = getattr(args, "warmup_epochs", 10)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            T_max=max(1, args.epochs - self.warmup_epochs),
            eta_min=0,
            last_epoch=-1,
        )
        os.makedirs("weight", exist_ok=True)

    def info_nce_loss(self, features):
        """Build InfoNCE logits and targets for a batch of stacked views.

        ``features`` holds ``n_views`` views stacked view-major along dim 0:
        rows ``[0, B)`` are view 1, rows ``[B, 2B)`` are view 2, and so on.
        This matches ``torch.cat(images, dim=0)`` in training. ``B`` is
        derived from ``features``, so a short final batch still works.

        Returns:
            ``(logits, targets)``. For each row, the logits list the positive
            similarities first, then all negatives, each divided by the
            temperature. The targets are all 0, so ``CrossEntropyLoss`` treats
            the first positive as the correct class.

        Raises:
            ValueError: if the row count is not a multiple of ``n_views``.
        """
        n_views = self.args.n_views
        n_rows = features.shape[0]
        if n_rows % n_views != 0:
            raise ValueError(
                f"Expected a multiple of n_views={n_views} feature rows, got {n_rows}"
            )
        batch_size = n_rows // n_views
        device = features.device

        # labels[i, j] == 1 iff rows i and j are views of the same image.
        image_ids = torch.arange(batch_size, device=device).repeat(n_views)
        labels = (image_ids.unsqueeze(0) == image_ids.unsqueeze(1)).float()

        features = F.normalize(features, dim=1)
        similarity_matrix = features @ features.T

        # Discard the main diagonal (self-similarity) from both matrices.
        mask = torch.eye(n_rows, dtype=torch.bool, device=device)
        labels = labels[~mask].view(n_rows, -1)
        similarity_matrix = similarity_matrix[~mask].view(n_rows, -1)

        positives = similarity_matrix[labels.bool()].view(n_rows, -1)
        negatives = similarity_matrix[~labels.bool()].view(n_rows, -1)

        logits = torch.cat([positives, negatives], dim=1) / self.args.temperature
        targets = torch.zeros(n_rows, dtype=torch.long, device=device)
        return logits, targets

    def _train_one_epoch(self, epoch, scaler):
        """Run one pass over the training loader and return the mean batch loss."""
        tqdm_iter = tqdm(
            self.train_loader,
            bar_format="{l_bar}|{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}|{elapsed}<{remaining}]",
        )
        tqdm_iter.set_description(f"Epoch: {epoch + 1}")

        epoch_loss = 0.0
        for images in tqdm_iter:
            # The loader yields one tensor per view; stack them view-major.
            images = torch.cat(images, dim=0).to(self.args.device)

            with autocast(enabled=self.args.fp16_precision):
                features = self.model(images)
                logits, labels = self.info_nce_loss(features)
                loss = self.criterion(logits, labels)

            self.optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # Read the loss once; each .item() call forces a device sync.
            loss_value = loss.item()
            tqdm_iter.set_postfix_str(f"loss={loss_value:^7.3f}")
            epoch_loss += loss_value

        return epoch_loss / max(1, len(tqdm_iter))

    def train(self):
        """Train for ``args.epochs`` epochs.

        Each epoch does the following:
        - Logs the mean loss and a PCA/t-SNE embedding plot to wandb.
        - Saves ``weight/model_<epoch>.pt`` every 5 epochs and after the
          final epoch.
        - Steps the cosine LR schedule once the constant-LR phase is over.
        """
        model_config = {
            "batch_size": self.args.batch_size,
            "epochs": self.args.epochs,
            "learning rate": self.args.lr,
            "temperature": self.args.temperature,
        }

        run = wandb.init(
            project="facial_identity",
            resume=False,
            config=model_config,
        )

        scaler = GradScaler(enabled=self.args.fp16_precision)

        for epoch in range(self.args.epochs):
            # plot_points() switches the model to eval mode at the end of each
            # epoch, so training mode is (re)enabled at the start of every one.
            self.model.train()
            epoch_loss = self._train_one_epoch(epoch, scaler)

            fig = plot_points(self.model, num_points=1000)
            wandb.log({
                "epoch": epoch + 1,
                "loss": epoch_loss,
                "Image": fig,
            })
            # The figure has been logged; close it so figures don't pile up.
            plt.close(fig)

            # Save the model every 5 epochs, and always after the last one.
            if (epoch + 1) % 5 == 0 or epoch + 1 == self.args.epochs:
                torch.save(self.model.state_dict(), f"weight/model_{epoch + 1}.pt")

            # Keep the learning rate constant for the first warmup_epochs
            # epochs, then follow the cosine schedule.
            if epoch >= self.warmup_epochs:
                self.scheduler.step()

        run.finish()

## Start training

In [None]:
class Config:
    """Hyper-parameters and runtime settings consumed by the SimCLR trainer."""

    # ---- backbone / projection head ----
    model_type: str = "resnet18"
    pretrained: bool = False
    out_dim: int = 128

    # ---- compute device: first CUDA GPU when present, otherwise CPU ----
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # ---- optimisation ----
    epochs: int = 100
    batch_size: int = 1024
    lr: float = 3e-4
    weight_decay: float = 1e-4
    fp16_precision: bool = True

    # ---- contrastive loss ----
    # NOTE(review): 0.7 is unusually high for SimCLR (commonly 0.07-0.5);
    # confirm it is not a typo for 0.07.
    temperature: float = 0.7
    n_views: int = 2

config = Config()

In [None]:
# Build the trainer (data loader, model, optimizer, LR scheduler) from
# `config`, then run the full training loop. The loop logs to wandb and
# writes checkpoints under weight/.
simclr = SimCLR(args=config)
simclr.train()