Three image NT-Xent Loss

In [1]:
import os
import csv
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.models.resnet import resnet50
from torchvision.datasets import CIFAR10

In [2]:
# Path for CIFAR10 data 
data_path = './data'

# Path for results
result_path = './result'
os.makedirs(result_path, exist_ok=True)

# Path for model/net storage
model_path = './models'
os.makedirs(model_path, exist_ok=True)

# Check device and apply the best hardware for Convolution computation
device = "cuda" if torch.cuda.is_available() else "CPU"
torch.backends.cudnn.benchmark = True if device == "cuda" else False

In [3]:
class CIFAR10_Dataset(CIFAR10):
    def __init__(self, img_num=2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.img_num = img_num  # Store the number of images to generate per sample
        
    def __getitem__(self, index):
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        imgs = []
        if self.transform is not None:
            for idx in range(self.img_num):
                imgs.append(self.transform(img))

        if self.target_transform is not None:
            target = self.target_transform(target)

        return imgs, target

In [4]:
# Training set:
# 1. Crop || 2. Horizontal Flip || 3. Color Jitter || 4. Grayscale
# Apply normalization
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

# Training set:
# Only apply normalization
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

In [5]:
# Unsupervised Learning - SimCLR with encoder and projection head
class SimCLR(nn.Module):
    def __init__(self, feature_dim=128):
        super(SimCLR, self).__init__()

        # Funtion F - Encoder
        # ResNet50 Implementation without linear and maxpool layer
        self.f = []
        for name, module in resnet50().named_children():
            # First Convolution layer changed to 3 channel input and 64 channel output to fit CIFAR10
            if name == 'conv1':
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            # Drop all the final linear layer and max pool layer
            if not isinstance(module, nn.Linear) and not isinstance(module, nn.MaxPool2d):
                self.f.append(module)
        self.f = nn.Sequential(*self.f)
        
        # Funtion G - Projection head
        # Two fully connected layer 
        self.g = nn.Sequential(nn.Linear(2048, 512, bias=False), nn.BatchNorm1d(512),
                               nn.ReLU(inplace=True), nn.Linear(512, feature_dim, bias=True))

    # F -> Flatten -> G
    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.g(feature)
        return F.normalize(feature, dim=-1), F.normalize(out, dim=-1)


# Supervised learning - Classifier with encoder and linear model
class Classifier(nn.Module):
    def __init__(self, num_class=10):
        super(Classifier, self).__init__()
        
        # Same function head - encoder
        self.f = SimCLR().f
        # FIX the encoder that does not update any parameter
        for param in self.f.parameters():
            param.requires_grad = False
        
        # Linear layer - classifier
        self.fc = nn.Linear(2048, num_class, bias=True)

    # F -> Flatten -> linear
    def forward(self, x):
        x = self.f(x)
        feature = torch.flatten(x, start_dim=1)
        out = self.fc(feature)
        return out

In [6]:
# NT-Xent Loss function implementation
class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, outs, batch_size, temperature=0.5):
        # Concatenated features - [N * Batch size, feature dimension(128)] 
        out = torch.cat(outs, dim=0)
        img_cnt = len(outs)
        
        # Calculated similarity matrix - [N * Batch size, N * Batch size]
        # The elements in the similarity matrix are normalized by dividing by the temperature parameter temperature
        sim_matrix = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
        
        # Create a mask matrix - [N * Batch size, N * Batch size]
        # The diagonal elements of the mask matrix are 0 and the other elements are 1 (diagonal element are similarity of itself)
        mask = (torch.ones_like(sim_matrix) - torch.eye(img_cnt * batch_size, device=sim_matrix.device)).bool()

        # Compute loss matrix - [N * Batch size, N * Batch size]
        loss = -torch.log(sim_matrix / (sim_matrix * mask).sum(dim=-1).view(-1, 1))
        
        # Final loss
        sum, cnt = 0, 0
        for k in range(batch_size):
            for idx1 in range(k * img_cnt, (k + 1) * img_cnt):
                for idx2 in range(k * img_cnt, (k + 1) * img_cnt):
                    if idx1 != idx2:
                        sum = sum + loss[idx1][idx2]
                        cnt = cnt + 1
        
        return sum / cnt

In [9]:
def train_SimCLR(batch_size, epochs, temperature=0.5, img_num=2):

    # Create model, loss, optimizer
    model = SimCLR().to(device)
    nt_xent_loss = Loss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

    # Create training dataset
    train_dataset = CIFAR10_Dataset(root=data_path, train=True, transform=train_transform, img_num=img_num, download=True)
    train_data = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True)

    # Check and create result saving file
    result_filename = 'SimCLR_Muti_B{}_E{}.csv'.format(batch_size, epochs)
    if not os.path.isfile(os.path.join(result_path, result_filename)):
        with open(os.path.join(result_path, result_filename), mode='w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(["Epoch", "Loss"])

    model.train()
    for epoch in range(1, epochs + 1):
        # Create loss and samples parameters
        total_loss, total_samples = 0, 0

        # Create a train bar for training visualization
        train_bar = tqdm(train_data)
        for imgs, labels in train_bar:
            # Transfer the image to the GPU
            imgs = [img.to(device) for img in imgs]
            labels = labels.to(device)

            # Apply SimCLR on two images
            features = []
            pres = []
            for img in imgs:
                feature, pre = model(img)
                features.append(feature)
                pres.append(pre)

            # Claculating the NT Xent Loss
            loss = nt_xent_loss(pres, batch_size, temperature)

            # Back propagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print out loss
            total_samples += batch_size
            total_loss += loss.detach().item()
            train_bar.set_description('Train Epoch: [{}/{}] Loss: {:.4f}'.format(epoch, epochs, total_loss / total_samples))

        # Calculate and print out epoch loss
        print("epoch loss:", total_loss / len(train_dataset) * batch_size)

        # write epoch loss to the result file
        with open(os.path.join(result_path, result_filename), "a") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, total_loss / len(train_dataset) * batch_size])

        # Save model with trained different epochs
        if epoch % 100 == 0:
            model_filename = 'SimCLR_Muti_B{}_E{}.pth'.format(batch_size, epoch)
            torch.save(model.state_dict(), os.path.join(model_path, model_filaname))

In [10]:
train_SimCLR(batch_size=32, epochs=500, img_num=3)

Files already downloaded and verified


Train Epoch: [1/500] Loss: 0.1426:   2%|▊                                            | 27/1562 [00:14<13:45,  1.86it/s]


KeyboardInterrupt: 