In [None]:
# Download dataset
!gdown --id 1iXIqmmQWQO5owW03ZOGIi7fdLCjAe4SI --output CelebA
!gdown --id 17it9gGywlyJSYbvkzc3s5mIt2puXcvGq
!unzip -q CelebA

In [None]:
!pip install wandb
!wandb login

In [None]:
import random
import os

import pandas as pd

import numpy as np

import matplotlib.pyplot as plt

from PIL import Image

import wandb

from tqdm import tqdm

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import torchvision.models as models

## Create Model

In [None]:
class BaseModel(nn.Module):
    """Torchvision backbone whose final ``fc`` layer is replaced by an embedding head.

    Args:
        model_type: name of a ``torchvision.models`` builder (e.g. ``"resnet18"``).
            The built model must expose an ``fc`` linear layer (ResNet-style).
        pretrained: forwarded to the builder as ``pretrained=``.
        out_dim: size of the produced embedding.
    """

    def __init__(self, model_type="resnet18", pretrained=True,
                 out_dim=128):
        super().__init__()
        builder = getattr(models, model_type)
        backbone = builder(pretrained=pretrained)
        # Swap the classifier for a linear projection to the embedding size.
        backbone.fc = nn.Linear(backbone.fc.in_features, out_dim)
        # Attribute name ``model`` is kept so state_dict keys are unchanged.
        self.model = backbone

    def forward(self, x):
        """Return the ``out_dim``-dimensional embedding of the batch ``x``."""
        embedding = self.model(x)
        return embedding


class TripletNet(nn.Module):
    """Siamese wrapper: one shared ``BaseModel`` embeds anchor, positive and negative.

    Constructor arguments are forwarded unchanged to ``BaseModel``.
    """

    def __init__(self, model_type="resnet18", pretrained=True,
                 out_dim=128):
        super().__init__()
        # A single backbone instance, so all three branches share weights.
        self.model = BaseModel(model_type=model_type, pretrained=pretrained, out_dim=out_dim)

    def forward(self, anchor, positive, negative):
        """Embed the three batches with the shared backbone.

        Returns:
            tuple ``(anchor_out, positive_out, negative_out)``, in input order.
        """
        return tuple(self.model(batch) for batch in (anchor, positive, negative))

    def get_features(self, x):
        """Embed a single batch ``x`` with the shared backbone."""
        return self.model(x)

## Create Triplet Dataset (anchor / positive / negative sampling)

In [None]:
class CelebADataset(Dataset):
    """Triplet dataset over CelebA identities.

    Item ``idx`` is a dict of three transformed images:

    * ``"anchor"`` - the image in CSV row ``idx``;
    * ``"positive_image"`` - a random *other* image with the same identity, or the
      anchor itself when that identity has no other image;
    * ``"negative_image"`` - a random image whose identity differs from the anchor's.

    Args:
        img_root: directory containing the image files listed in the CSV.
        csv_path: anything ``pd.read_csv`` accepts. Column 0 is the image file name,
            column 1 the identity label; the first row is read as the header.
        transform: callable applied to every PIL image. Defaults to
            ``Resize(224) -> ToTensor -> Normalize(mean=0.5, std=0.5)``.

    Note:
        Names/labels are indexed once in ``__init__``; mutating ``self.csv``
        afterwards is not reflected in sampling.
    """

    def __init__(self, img_root, csv_path, transform=None):
        self.img_root = img_root
        self.csv = pd.read_csv(csv_path)

        self.transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ]) if transform is None else transform

        # Index the CSV once. Filtering the whole frame inside __getitem__ costs
        # O(len(csv)) per item, i.e. O(len(csv)^2) per epoch.
        self._names = self.csv.iloc[:, 0].tolist()
        self._labels = self.csv.iloc[:, 1].tolist()
        self._rows_by_label = {}
        for row, label in enumerate(self._labels):
            self._rows_by_label.setdefault(label, []).append(row)

    def _load_image(self, img_name):
        """Open ``img_root/img_name`` as an RGB PIL image and release the file handle."""
        with Image.open(os.path.join(self.img_root, img_name)) as img:
            # convert() loads pixels into a new image that outlives the `with`;
            # forcing RGB keeps non-RGB files compatible with the 3-channel Normalize.
            return img.convert("RGB")

    def _sample_positive(self, img_name, label):
        """Return a random image name of identity ``label`` other than ``img_name``."""
        candidates = [
            self._names[row] for row in self._rows_by_label[label]
            if self._names[row] != img_name
        ]
        # A single-image identity has no other positive; reuse the anchor instead of
        # failing (pandas' sample(n=1) raised ValueError on the empty frame).
        return random.choice(candidates) if candidates else img_name

    def _sample_negative(self, label):
        """Return an image name drawn uniformly from rows whose identity != ``label``."""
        if len(self._rows_by_label) < 2:
            raise ValueError("CelebADataset needs at least two identities to sample negatives")
        # Rejection sampling: a uniform row conditioned on a different label is uniform
        # over the negative rows. Expected draws = len(csv) / #negative rows.
        while True:
            row = random.randrange(len(self._labels))
            if self._labels[row] != label:
                return self._names[row]

    def __getitem__(self, idx):
        img_name = self._names[idx]
        label = self._labels[idx]

        anchor_image = self._load_image(img_name)
        positive_image = self._load_image(self._sample_positive(img_name, label))
        negative_image = self._load_image(self._sample_negative(label))

        if self.transform:
            anchor_image = self.transform(anchor_image)
            positive_image = self.transform(positive_image)
            negative_image = self.transform(negative_image)

        return {
            "anchor": anchor_image,
            "positive_image": positive_image,
            "negative_image": negative_image
        }

    def __len__(self):
        return len(self.csv)

In [None]:
def make_loader(batch_size, img_root, csv_path, num_workers=0, shuffle=True):
    """Build a DataLoader over ``CelebADataset`` triplets.

    Args:
        batch_size: triplets per batch.
        img_root: image directory passed to ``CelebADataset``.
        csv_path: identity CSV passed to ``CelebADataset``.
        num_workers: DataLoader worker processes. The default 0 loads in the main
            process, as before; each item opens three images, so >0 can speed up
            training.
        shuffle: reshuffle every epoch (default True, as before).

    Returns:
        ``DataLoader`` yielding dicts with ``"anchor"``, ``"positive_image"`` and
        ``"negative_image"`` batches.
    """
    train_dataset = CelebADataset(img_root=img_root, csv_path=csv_path)

    return DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies; skip it without CUDA.
        pin_memory=torch.cuda.is_available(),
    )

## Create loss function

In [None]:
class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    For each triplet ``i``::

        loss_i = max(0, ||a_i - p_i||^2 - ||a_i - n_i||^2 + margin)

    and the returned loss is the mean over triplets.

    Args:
        margin: required gap between negative and positive squared distances.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Compute the mean triplet loss.

        Args:
            anchor, positive, negative: embeddings of identical shape whose last
                dimension is the embedding dimension (e.g. ``(batch, dim)``, or a
                single ``(dim,)`` triplet).

        Returns:
            Scalar tensor. An empty batch yields ``0`` (still attached to the
            autograd graph) rather than the NaN that ``mean()`` of nothing gives.
        """
        # Reduce over the embedding (last) dimension -> one distance per triplet.
        # Reducing over axis 0 summed across the batch instead, mixing samples.
        distance_positive = (anchor - positive).pow(2).sum(dim=-1)
        distance_negative = (anchor - negative).pow(2).sum(dim=-1)
        losses = torch.relu(distance_positive - distance_negative + self.margin)

        if losses.numel() == 0:
            return losses.sum()
        return losses.mean()


## Visualization tools

In [None]:
@torch.no_grad()
def plot_points(model, device, img_root, num_points=10):
    """Scatter-plot 2-D PCA and t-SNE projections of embeddings of random images.

    Args:
        model: network exposing ``get_features(batch)`` that returns one embedding
            row per input image.
        device: device the model lives on; each image is moved there.
        img_root: directory of images. ``num_points`` distinct file names are drawn
            from it (all of them if it holds fewer).
        num_points: number of images to embed; PCA to 2-D needs at least 2.

    Returns:
        matplotlib Figure with a PCA and a t-SNE panel. The caller owns it and should
        ``plt.close(fig)`` after use; repeated calls otherwise accumulate figures.

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()

    # Distinct files: random.choices drew with replacement, duplicating points.
    file_names = sorted(os.listdir(img_root))
    plot_image_list = random.sample(file_names, k=min(num_points, len(file_names)))

    # Same order as the training transform: resize the PIL image, then tensorize.
    base_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5]
        )
    ])

    features_list = []
    try:
        for img_name in plot_image_list:
            with Image.open(os.path.join(img_root, img_name)) as img:
                batch = base_transform(img.convert("RGB")).unsqueeze(0).to(device)
            features_list.append(model.get_features(batch)[0].cpu().numpy())
    finally:
        model.train(was_training)

    features = np.stack(features_list)
    n_samples = len(features)

    two_dim_pca = PCA(n_components=2).fit_transform(features)
    # t-SNE requires perplexity < n_samples; its default of 30 fails for small
    # plots such as the default num_points=10.
    two_dim_tsne = TSNE(
        n_components=2, perplexity=min(30.0, n_samples - 1)
    ).fit_transform(features)

    colors = list(range(n_samples))
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.set_title("PCA")
    ax1.scatter(two_dim_pca[:, 0], two_dim_pca[:, 1], c=colors)
    ax2.set_title("TSNE")
    ax2.scatter(two_dim_tsne[:, 0], two_dim_tsne[:, 1], c=colors)

    return fig

## Train function

In [None]:
def train(args):
    """Train a ``TripletNet`` on CelebA triplets and log progress to Weights & Biases.

    Args:
        args: config object. Required attributes: ``weight`` (checkpoint dir),
            ``batch_size``, ``lr``, ``epochs``, ``img_root``, ``csv_path``,
            ``model_type``, ``pretrained``, ``out_dim``, ``margin``.
            Optional attributes (default in brackets):
            ``log_every`` [10] batches between loss logs;
            ``num_plot_points`` [1000] images in the per-epoch embedding plot;
            ``lr_step_size`` [8] epochs between 10x learning-rate decays;
            ``seed`` [None] seeds ``random``, ``numpy`` and ``torch`` when set.

    Side effects:
        Writes ``<args.weight>/model_<epoch>.pt`` after every epoch; logs the running
        ``loss`` every ``log_every`` batches, and ``epoch_loss`` plus an ``Image``
        embedding plot once per epoch, to wandb.
    """
    log_every = getattr(args, "log_every", 10)
    num_plot_points = getattr(args, "num_plot_points", 1000)
    lr_step_size = getattr(args, "lr_step_size", 8)
    seed = getattr(args, "seed", None)
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    os.makedirs(args.weight, exist_ok=True)

    model_config = {
        "batch_size": args.batch_size,
        "learning rate": args.lr,
        "epochs": args.epochs,
        "margin": args.margin,
        "model_type": args.model_type,
        "out_dim": args.out_dim,
    }
    run = wandb.init(
        project="facial_identity",
        resume=False,
        config=model_config,
    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    train_loader = make_loader(
        batch_size=args.batch_size, img_root=args.img_root,
        csv_path=args.csv_path
    )

    model = TripletNet(
        model_type=args.model_type, pretrained=args.pretrained,
        out_dim=args.out_dim
    ).to(device)

    pair_dis = nn.PairwiseDistance()

    criterion = TripletLoss(args.margin)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # step_size counts scheduler.step() calls, which happen once per epoch below.
    # (Stepping every batch decayed the LR 10x every 8 batches.)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_step_size, gamma=0.1)

    for epoch in range(args.epochs):
        # plot_points switches to eval mode, so re-enter train mode every epoch.
        model.train()
        epoch_loss = 0.0

        tqdm_iter = tqdm(
            train_loader,
            desc=f"Epoch: {epoch + 1}",
            bar_format="{l_bar}|{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}|{elapsed}<{remaining}]"
        )

        for idx, batched_data in enumerate(tqdm_iter):
            input_anchor = batched_data["anchor"].to(device)
            input_positive = batched_data["positive_image"].to(device)
            input_negative = batched_data["negative_image"].to(device)

            anchor, pos, neg = model(input_anchor, input_positive, input_negative)

            # Online mining: keep triplets that still violate the margin.
            # NOTE: mining uses plain L2 distance while the loss uses squared L2,
            # so the margin means different things in the two places.
            with torch.no_grad():
                pos_dists = pair_dis(anchor, pos)
                neg_dists = pair_dis(anchor, neg)
                valid_triplets = neg_dists - pos_dists < args.margin

            if valid_triplets.any():
                loss = criterion(anchor[valid_triplets], pos[valid_triplets], neg[valid_triplets])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                batch_loss = loss.item()
            else:
                # Every triplet already satisfies the margin: nothing to update.
                batch_loss = 0.0

            epoch_loss += batch_loss
            tqdm_iter.set_postfix_str(f"loss={batch_loss:^7.3f}")

            if idx % log_every == 0:
                wandb.log({"loss": epoch_loss / (idx + 1)})

        scheduler.step()

        # Embedding num_plot_points images and running t-SNE is slow, so plot once
        # per epoch instead of every few batches; close the figure to free memory.
        fig = plot_points(model, img_root=args.img_root, device=device, num_points=num_plot_points)
        wandb.log({"epoch_loss": epoch_loss / max(len(train_loader), 1), "Image": fig})
        plt.close(fig)

        # Save the model for every epoch
        torch.save(model.state_dict(), os.path.join(args.weight, f"model_{epoch + 1}.pt"))

    run.finish()

## Set up Config

In [None]:
class Config:
    """Training settings, read by ``train`` as plain attributes."""

    # --- Data locations ----------------------------------------------------
    csv_path = "identity.csv"      # CSV: column 0 image file name, column 1 identity
    img_root = "img_align_celeba"  # image directory (presumably from the unzip cell)
    weight = "weight"              # directory where per-epoch checkpoints are written

    # --- Optimisation ------------------------------------------------------
    batch_size = 64
    epochs = 5
    lr = 1e-4

    # --- Embedding network and loss ----------------------------------------
    model_type = "resnet18"        # name of a torchvision.models builder
    pretrained = True
    out_dim = 128                  # embedding size
    margin = 0.5                   # triplet margin

## Start training

In [None]:
# Build the default configuration and launch training; `args` is kept as a
# module-level name so the settings can be inspected after the run.
args = Config()
train(args=args)