In [1]:
import argparse
import torch
import torch.backends.cudnn as cudnn
from torchvision import models
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
from models.resnet_simclr import ResNetSimCLR
from simclr import SimCLR

from torchvision import transforms

import logging
import os
import sys

import torch
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import save_config_file, accuracy, save_checkpoint

In [2]:
class SimCLR(object):
    """SimCLR trainer: builds InfoNCE logits and runs the contrastive training loop.

    Args:
        model: network that maps a batch of images to feature vectors.
        optimizer: optimizer stepped once per batch by ``train``.
        scheduler: learning-rate scheduler stepped once per epoch after warm-up.
        args: configuration namespace. ``train`` reads ``fp16_precision``,
            ``epochs``, ``disable_cuda``, ``device``, ``logeverynsteps`` and
            ``arch`` from it.

    All arguments are optional so that ``SimCLR()`` keeps working when only
    ``info_nce_loss`` is needed; ``train`` requires all four to be set.
    """

    def __init__(self, model=None, optimizer=None, scheduler=None, args=None):
        # Previously none of these were assigned, so train() could never run.
        self.args = args
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

        self.writer = SummaryWriter()
        logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)

        # Falls back to the CPU when no args (or no args.device) is supplied.
        self.device = getattr(args, 'device', 'cpu')
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

    def info_nce_loss(self, features, n_views=2, temperature=0.07):
        """Build InfoNCE logits and targets from a stack of view embeddings.

        Args:
            features: 2-D tensor of shape ``(n_views * batch_size, dim)`` where
                row ``v * batch_size + i`` is view ``v`` of sample ``i``.
            n_views: number of augmented views stacked in ``features``.
            temperature: divisor applied to the similarities.

        Returns:
            ``(logits, labels)``: ``logits`` has shape
            ``(n_views * batch_size, n_views * batch_size - 1)`` with the
            positive similarities in the leading ``n_views - 1`` columns;
            ``labels`` is an all-zero ``long`` vector, i.e. column 0 is the
            target class for every row.

        Raises:
            ValueError: if ``features`` is empty, ``n_views < 2``, or the row
                count is not a multiple of ``n_views``.
        """
        total = features.shape[0]
        if total == 0 or n_views < 2 or total % n_views != 0:
            raise ValueError(
                f"features must hold a positive multiple of n_views={n_views} rows, got {total}")
        # The batch size is inferred from the input instead of being fixed at 256.
        batch_size = total // n_views
        # Helper tensors live on the same device as the features they index.
        device = features.device

        sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
        same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # discard the main diagonal from both: labels and similarities matrix
        off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
        same_sample = same_sample[off_diagonal].view(total, total - 1)
        similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

        # Explicit widths keep view() well-defined even when there are no negatives.
        positives = similarity_matrix[same_sample].view(total, n_views - 1)
        negatives = similarity_matrix[~same_sample].view(total, total - n_views)

        logits = torch.cat([positives, negatives], dim=1) / temperature
        labels = torch.zeros(total, dtype=torch.long, device=device)
        return logits, labels

    def train(self, train_loader):
        """Run the training loop over ``train_loader`` and save a final checkpoint.

        Args:
            train_loader: iterable yielding ``(views, _)`` where ``views`` is a
                sequence of image tensors that are concatenated along dim 0.

        Raises:
            AttributeError: if ``args``, ``model``, ``optimizer`` or
                ``scheduler`` has not been provided.
        """
        missing = [name for name in ('args', 'model', 'optimizer', 'scheduler')
                   if getattr(self, name, None) is None]
        if missing:
            raise AttributeError(
                "SimCLR.train() needs these attributes to be set: " + ", ".join(missing))

        scaler = GradScaler(enabled=self.args.fp16_precision)

        # save config file
        save_config_file(self.writer.log_dir, self.args)

        n_iter = 0
        loss, top1 = None, None
        logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
        # disable_cuda is the negation of "training with gpu".
        logging.info(f"Training with gpu: {not self.args.disable_cuda}.")

        for epoch_counter in range(self.args.epochs):
            for images, _ in tqdm(train_loader):
                images = torch.cat(images, dim=0)

                images = images.to(self.args.device)

                with autocast(enabled=self.args.fp16_precision):
                    features = self.model(images)
                    logits, labels = self.info_nce_loss(features)
                    loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()

                scaler.scale(loss).backward()

                scaler.step(self.optimizer)
                scaler.update()

                # Computed once per step and shared by TensorBoard and the log file.
                top1, top5 = accuracy(logits, labels, topk=(1, 5))

                # NOTE(review): confirm `logeverynsteps` matches the attribute name on args.
                if n_iter % self.args.logeverynsteps == 0:
                    self.writer.add_scalar('loss', loss, global_step=n_iter)
                    self.writer.add_scalar('acc/top1', top1[0], global_step=n_iter)
                    self.writer.add_scalar('acc/top5', top5[0], global_step=n_iter)
                    # get_last_lr() is the supported way to read the current rate.
                    self.writer.add_scalar('learning_rate', self.scheduler.get_last_lr()[0], global_step=n_iter)

                # logging uses %-style placeholders; extra positional args without
                # a placeholder trigger a formatting error instead of being logged.
                logging.info('loss : %s', loss)
                logging.info('acc/top1 : %s', top1[0])
                logging.info('acc/top5 : %s', top5[0])

                n_iter += 1

            # warmup for the first 10 epochs
            if epoch_counter >= 10:
                self.scheduler.step()
            # Skip the summary when the loader yielded no batches.
            if loss is not None:
                logging.debug(f"Epoch: {epoch_counter}\tLoss: {loss}\tTop1 accuracy: {top1[0]}")

        logging.info("Training has finished.")
        # save model checkpoints
        checkpoint_name = 'checkpoint_{:04d}.pth.tar'.format(self.args.epochs)
        save_checkpoint({
            'epoch': self.args.epochs,
            'arch': self.args.arch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }, is_best=False, filename=os.path.join(self.writer.log_dir, checkpoint_name))
        logging.info(f"Model checkpoint and metadata has been saved at {self.writer.log_dir}.")

In [3]:
# Two-view STL-10 contrastive dataset rooted at ./datasets.
dataset = ContrastiveLearningDataset('./datasets')
train_dataset = dataset.get_dataset('stl10', 2)

# drop_last=True discards the final partial batch, so every batch holds exactly 256 samples.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=1,
    pin_memory=False,
    drop_last=True,
)

# ResNet-18 backbone producing 128-d outputs.
model = ResNetSimCLR(base_model='resnet18', out_dim=128)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0003, weight_decay=1e-4)

# Cosine schedule whose period equals the number of batches in the loader.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=len(train_loader),
    eta_min=0,
    last_epoch=-1,
)

simclr = SimCLR()

Files already downloaded and verified


In [4]:
# Push one batch through the model and the loss to inspect the intermediate tensors.
for images, y in tqdm(train_loader):
    # Stack the augmented views along the batch axis and keep them on the CPU.
    images_ = torch.cat(images, dim=0).to('cpu')

    features = model(images_)  # output shape = (512, 128)
    logits, labels = simclr.info_nce_loss(features)
    loss = F.cross_entropy(logits, labels)

    # One batch is enough; the later cells reuse features / logits / labels.
    break

  0%|          | 0/390 [00:16<?, ?it/s]


In [104]:
# Step-by-step walkthrough of the top-k accuracy computation.
topk = (1, 5)
maxk = max(topk)  # largest k requested

# Number of label entries; `labels` must already exist in the kernel from an earlier cell.
batch_size = labels.size(0)

In [105]:
# Column indices of the 5 largest logits in each row (dim=1, largest=True, sorted=True); the values are discarded.
_, pred = logits.topk(5, 1, True, True)

In [106]:
# Transpose so that row r holds every sample's (r+1)-th ranked prediction,
# then compare each row against the labels broadcast to the same shape.
pred_ = pred.t()
correct = pred_.eq(labels.view(1, -1).expand_as(pred_))

In [107]:
# Top-1 hit count: keep only the first row (best prediction per sample) and sum the matches into a 1-element tensor.
correct_k = correct[:1].reshape(-1).float().sum(0, keepdim=True)

In [109]:
# Turn the hit count into a percentage of 512 rows.
# NOTE(review): mul_ works in place, so re-running this cell rescales correct_k again;
# the literal 512 is presumably the same value as `batch_size` — confirm.
correct_k.mul_(100 / 512)

tensor([0.9766])

In [None]:
# Unexecuted scratch copy of the body of accuracy().
# NOTE(review): `target` is not assigned by any visible cell, and `output` is only
# assigned by the hand-written softmax cell — confirm both before running this.
maxk = max(topk)
batch_size = target.size(0)

# Indices of the maxk highest scores per row, transposed so row r is the (r+1)-th ranked guess.
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))

res = []
for k in topk:
    # Hits within the first k ranks, expressed as a percentage of the batch.
    correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
    res.append(correct_k.mul_(100.0 / batch_size))

In [5]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Args:
        output: 2-D score tensor of shape ``(batch, num_classes)``.
        target: 1-D tensor of ``batch`` class indices.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per entry of ``topk``, each holding the
        percentage of rows whose target is among the k highest-scoring classes.
        A k larger than ``num_classes`` is treated as ``num_classes`` (every
        target is then trivially within the top k) instead of raising.
    """
    with torch.no_grad():
        # topk() raises when asked for more entries than there are columns,
        # which happens for small batches of contrastive logits.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        # Row r of `pred` now holds every sample's (r+1)-th ranked class.
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # Slicing past the last row is harmless, so k > maxk needs no special case.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [72]:
# Sample ids 0..255 repeated once per view: 2 views -> 512 entries.
labels = torch.arange(256).repeat(2)
# Pairwise equality mask: entry (i, j) is 1.0 when rows i and j carry the same sample id.
labels = labels[:, None].eq(labels[None, :]).to(torch.float32)
labels = labels.to('cpu')  # shape: (512, 512)

In [36]:
features_ = F.normalize(features, dim=1)   # L2-normalize each row (one 128-d embedding per sample) to unit length

In [37]:
similarity_matrix = torch.matmul(features_, features_.T)  # pairwise dot products of the L2-normalized vectors (u and v in the paper), i.e. cosine similarities

In [38]:
# Remove the main diagonal (each row compared with itself) from both the label mask and the similarities.
# NOTE(review): this needs `labels` to still be the 2-D pairwise mask; a later cell rebinds
# `labels` to a 1-D zero vector, so running the cells out of order breaks this one.
mask = torch.eye(labels.shape[0], dtype=torch.bool).to('cpu')
labels_ = labels[~mask].view(labels.shape[0], -1)
similarity_matrix_ = similarity_matrix[~mask].view(similarity_matrix.shape[0], -1)  # output shape = (512, 511)

In [39]:
# Split every row into the entries flagged in labels_ (positives) and all remaining entries (negatives).
positives = similarity_matrix_[labels_.bool()].view(labels.shape[0], -1)
negatives = similarity_matrix_[~labels_.bool()].view(similarity_matrix.shape[0], -1)

In [40]:
# Positives go first, so the correct "class" of every row is column 0.
logits = torch.cat([positives, negatives], dim=1)

In [41]:
# Target index 0 for every row. Note: this rebinds `labels`, replacing the 2-D pairwise mask used by the earlier cells.
labels = torch.zeros(logits.shape[0], dtype=torch.long).to('cpu')

In [42]:
# Divide by 0.07 (the temperature of the InfoNCE loss).
# NOTE(review): `logits` is rebound from itself, so re-running this cell divides again.
logits = logits / 0.07

# Softmax 함수

In [374]:
# Five random values used as input for the softmax demo (no seed is set, so they differ on every run).
temp = torch.randn(5)
temp

tensor([ 0.3266, -2.5380, -0.0374,  1.4999,  1.3031])

In [376]:
# Softmax written out by hand: exponentiate each entry, then divide by the sum of the exponentials.
output = torch.exp(temp) / torch.exp(temp).sum()

In [75]:
# Display the pairwise similarity matrix (its diagonal is 1.0: every normalized row dotted with itself).
similarity_matrix

tensor([[1.0000, 0.8828, 0.8326,  ..., 0.8992, 0.9017, 0.7120],
        [0.8828, 1.0000, 0.8830,  ..., 0.8647, 0.8769, 0.8240],
        [0.8326, 0.8830, 1.0000,  ..., 0.7996, 0.7972, 0.9001],
        ...,
        [0.8992, 0.8647, 0.7996,  ..., 1.0000, 0.9371, 0.7100],
        [0.9017, 0.8769, 0.7972,  ..., 0.9371, 1.0000, 0.7037],
        [0.7120, 0.8240, 0.9001,  ..., 0.7100, 0.7037, 1.0000]],
       grad_fn=<MmBackward>)

In [35]:
def info_nce_loss(self, features, n_views=2, temperature=0.07):
    """Build InfoNCE logits and targets from a stack of view embeddings.

    Standalone copy of ``SimCLR.info_nce_loss`` kept for experimentation.

    Args:
        self: accepted for signature compatibility with the method; not read.
        features: 2-D tensor of shape ``(n_views * batch_size, dim)`` where row
            ``v * batch_size + i`` is view ``v`` of sample ``i``.
        n_views: number of augmented views stacked in ``features``.
        temperature: divisor applied to the similarities.

    Returns:
        ``(logits, labels)``: ``logits`` has shape
        ``(n_views * batch_size, n_views * batch_size - 1)`` with the positive
        similarities in the leading ``n_views - 1`` columns; ``labels`` is an
        all-zero ``long`` vector, i.e. column 0 is the target for every row.

    Raises:
        ValueError: if ``features`` is empty, ``n_views < 2``, or the row count
            is not a multiple of ``n_views``.
    """
    total = features.shape[0]
    if total == 0 or n_views < 2 or total % n_views != 0:
        raise ValueError(
            f"features must hold a positive multiple of n_views={n_views} rows, got {total}")
    # The batch size is inferred from the input instead of being fixed at 256.
    batch_size = total // n_views
    # Helper tensors live on the same device as the features they index.
    device = features.device

    sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
    same_sample = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    features = F.normalize(features, dim=1)

    similarity_matrix = torch.matmul(features, features.T)  # (total, total)

    # discard the main diagonal from both: labels and similarities matrix
    off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
    same_sample = same_sample[off_diagonal].view(total, total - 1)
    similarity_matrix = similarity_matrix[off_diagonal].view(total, total - 1)

    # Explicit widths keep view() well-defined even when there are no negatives.
    positives = similarity_matrix[same_sample].view(total, n_views - 1)
    negatives = similarity_matrix[~same_sample].view(total, total - n_views)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    labels = torch.zeros(total, dtype=torch.long, device=device)
    return logits, labels