In [1]:
!pip install tensorboard_logger

Collecting tensorboard_logger
  Downloading tensorboard_logger-0.1.0-py2.py3-none-any.whl (17 kB)
Installing collected packages: tensorboard_logger
Successfully installed tensorboard_logger-0.1.0


In [2]:
!pip install apex

Collecting apex
  Downloading apex-0.9.10dev.tar.gz (36 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting cryptacular (from apex)
  Downloading cryptacular-1.6.2.tar.gz (75 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.8/75.8 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting zope.sqlalchemy (from apex)
  Downloading zope.sqlalchemy-3.1-py3-none-any.whl (23 kB)
Collecting velruse>=1.0.3 (from apex)
  Downloading velruse-1.1.1.tar.gz (709 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m709.8/709.8 kB[0m [31m44.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyramid>1.1.2 (from apex)
  Downloading pyramid-2.0.2-py3-none-any.whl (247 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:

from __future__ import print_function

import torch
import torch.nn as nn


class SupConLoss(nn.Module):"
    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels:            [0,1,1,2]
        # loss before mean:  [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

In [4]:
from __future__ import print_function

import math
import numpy as np
import torch
import torch.optim as optim


class TwoCropTransform:
    """Create two crops of the same image"""
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, x):
        return [self.transform(x), self.transform(x)]


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def adjust_learning_rate(args, optimizer, epoch):
    lr = args.learning_rate
    if args.cosine:
        eta_min = lr * (args.lr_decay_rate ** 3)
        lr = eta_min + (lr - eta_min) * (
                1 + math.cos(math.pi * epoch / args.epochs)) / 2
    else:
        steps = np.sum(epoch > np.asarray(args.lr_decay_epochs))
        if steps > 0:
            lr = lr * (args.lr_decay_rate ** steps)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def warmup_learning_rate(args, epoch, batch_id, total_batches, optimizer):
    if args.warm and epoch <= args.warm_epochs:
        p = (batch_id + (epoch - 1) * total_batches) / \
            (args.warm_epochs * total_batches)
        lr = args.warmup_from + p * (args.warmup_to - args.warmup_from)

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


def set_optimizer(opt, model):
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    return optimizer


def save_model(model, optimizer, opt, epoch, save_file):
    print('==> Saving...')
    state = {
        'opt': opt,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, save_file)
    del state

In [5]:
"""ResNet in PyTorch."""
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(BasicBlock, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1, is_last=False):
        super(Bottleneck, self).__init__()
        self.is_last = is_last
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        preact = out
        out = F.relu(out)
        if self.is_last:
            return out, preact
        else:
            return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to:
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for i in range(num_blocks):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, layer=100):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        return out


def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


model_dict = {
    'resnet18': [resnet18, 512],
    'resnet34': [resnet34, 512],
    'resnet50': [resnet50, 2048],
    'resnet101': [resnet101, 2048],
}


class LinearBatchNorm(nn.Module):
    """Implements BatchNorm1d by BatchNorm2d, for SyncBN purpose"""
    def __init__(self, dim, affine=True):
        super(LinearBatchNorm, self).__init__()
        self.dim = dim
        self.bn = nn.BatchNorm2d(dim, affine=affine)

    def forward(self, x):
        x = x.view(-1, self.dim, 1, 1)
        x = self.bn(x)
        x = x.view(-1, self.dim)
        return x


class SupConResNet(nn.Module):
    """backbone + projection head"""
    def __init__(self, name='resnet50', head='mlp', feat_dim=128):
        super(SupConResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        if head == 'linear':
            self.head = nn.Linear(dim_in, feat_dim)
        elif head == 'mlp':
            self.head = nn.Sequential(
                nn.Linear(dim_in, dim_in),
                nn.ReLU(inplace=True),
                nn.Linear(dim_in, feat_dim)
            )
        else:
            raise NotImplementedError(
                'head not supported: {}'.format(head))

    def forward(self, x):
        feat = self.encoder(x)
        feat = F.normalize(self.head(feat), dim=1)
        return feat


class SupCEResNet(nn.Module):
    """encoder + classifier"""
    def __init__(self, name='resnet50', num_classes=10):
        super(SupCEResNet, self).__init__()
        model_fun, dim_in = model_dict[name]
        self.encoder = model_fun()
        self.fc = nn.Linear(dim_in, num_classes)

    def forward(self, x):
        return self.fc(self.encoder(x))


class LinearClassifier(nn.Module):
    """Linear classifier"""
    def __init__(self, name='resnet50', num_classes=10):
        super(LinearClassifier, self).__init__()
        _, feat_dim = model_dict[name]
        self.fc = nn.Linear(feat_dim, num_classes)

    def forward(self, features):
        return self.fc(features)

Cross Entropy Loss Function on CIFAR100

In [7]:
import os
import time
import math
import sys

import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

# Hardcoded configurations
class Config:
    print_freq = 10
    save_freq = 50
    batch_size = 256
    num_workers = 16
    epochs = 25
    learning_rate = 0.2
    lr_decay_epochs = [350, 400, 450]
    lr_decay_rate = 0.1
    weight_decay = 1e-4
    momentum = 0.9
    model = 'resnet50'
    dataset = 'cifar100'
    cosine = False
    syncBN = False
    warm = True
    trial = '0'
    data_folder = './datasets/'
    model_path = './save/SupCon/cifar100_models'
    tb_path = './save/SupCon/cifar100_tensorboard'

opt = Config()

opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.format(
    opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
    opt.batch_size, opt.trial)

if opt.cosine:
    opt.model_name = '{}_cosine'.format(opt.model_name)

if opt.warm:
    opt.model_name = '{}_warm'.format(opt.model_name)
    opt.warmup_from = 0.01
    opt.warm_epochs = 10
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
            1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
    else:
        opt.warmup_to = opt.learning_rate

opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
if not os.path.isdir(opt.tb_folder):
    os.makedirs(opt.tb_folder)

opt.save_folder = os.path.join(opt.model_path, opt.model_name)
if not os.path.isdir(opt.save_folder):
    os.makedirs(opt.save_folder)

opt.n_cls = 100

def set_loader(opt):
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                      transform=train_transform,
                                      download=True)
    val_dataset = datasets.CIFAR100(root=opt.data_folder,
                                    train=False,
                                    transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    return train_loader, val_loader

def set_model(opt):
    model = SupCEResNet(name=opt.model, num_classes=opt.n_cls)
    criterion = torch.nn.CrossEntropyLoss()

    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion

def train(train_loader, model, criterion, optimizer, epoch, opt):
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
        output = model(images)
        loss = criterion(output, labels)
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()

        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg

def validate(val_loader, model, criterion, opt):
    model.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]
            output = model(images)
            loss = criterion(output, labels)
            losses.update(loss.item(), bsz)
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg

def main():
    best_acc = 0
    # build data loader
    train_loader, val_loader = set_loader(opt)
    # build model and criterion
    model, criterion = set_model(opt)
    # build optimizer
    optimizer = set_optimizer(opt, model)
    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)
        time1 = time.time()
        loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        logger.log_value('train_loss', loss, epoch)
        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
        loss, val_acc = validate(val_loader, model, criterion, opt)
        logger.log_value('val_loss', loss, epoch)
        logger.log_value('val_acc', val_acc, epoch)
        if val_acc > best_acc:
            best_acc = val_acc
        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)
    print('best accuracy: {:.2f}'.format(best_acc))

if __name__ == '__main__':
    main()


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./datasets/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:13<00:00, 12928925.98it/s]


Extracting ./datasets/cifar-100-python.tar.gz to ./datasets/


  self.pid = os.fork()


Train: [1][10/196]	BT 0.377 (0.960)	DT 0.000 (0.069)	loss 4.733 (4.750)	Acc@1 1.562 (1.250)
Train: [1][20/196]	BT 0.378 (0.669)	DT 0.000 (0.047)	loss 5.061 (4.785)	Acc@1 0.391 (1.562)
Train: [1][30/196]	BT 0.379 (0.572)	DT 0.000 (0.039)	loss 4.865 (4.894)	Acc@1 2.344 (1.523)
Train: [1][40/196]	BT 0.378 (0.524)	DT 0.000 (0.035)	loss 5.005 (4.905)	Acc@1 3.125 (1.699)
Train: [1][50/196]	BT 0.381 (0.495)	DT 0.000 (0.033)	loss 4.710 (4.876)	Acc@1 2.734 (1.852)
Train: [1][60/196]	BT 0.379 (0.476)	DT 0.000 (0.032)	loss 4.715 (4.830)	Acc@1 1.953 (2.070)
Train: [1][70/196]	BT 0.383 (0.462)	DT 0.000 (0.031)	loss 4.534 (4.786)	Acc@1 5.078 (2.338)
Train: [1][80/196]	BT 0.380 (0.452)	DT 0.000 (0.030)	loss 4.432 (4.742)	Acc@1 4.688 (2.617)
Train: [1][90/196]	BT 0.381 (0.444)	DT 0.000 (0.029)	loss 4.295 (4.697)	Acc@1 5.078 (2.873)
Train: [1][100/196]	BT 0.380 (0.438)	DT 0.000 (0.029)	loss 4.352 (4.660)	Acc@1 6.250 (3.195)
Train: [1][110/196]	BT 0.382 (0.433)	DT 0.000 (0.028)	loss 4.276 (4.628)	Acc@1 

Supervised Contrastive Learning on CIFAR-100

In [6]:
from __future__ import print_function

import os
import sys
import time
import math

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets


try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

# Hardcoded configurations
class Configuration:
  print_freq = 10
  save_freq = 50
  batch_size = 256
  num_workers = 16
  epochs = 25
  learning_rate = 0.05
  lr_decay_epochs = [700,800,900]
  lr_decay_rate = 0.1
  weight_decay = 1e-4
  momentum = 0.9
  model_name = 'resnet50'
  dataset = 'cifar100'  # options: ['cifar10', 'cifar100', 'path']
  mean = None
  std = None
  data_folder = './datasets/'
  size = 32
  method = 'SupCon'  # options: ['SupCon', 'SimCLR']
  temp = 0.07
  cosine = False
  syncBN = False
  warm = False
  trial = '0'
  model_path = './save/SupCon/cifar100_models'
  tb_path = './save/SupCon/cifar100_tensorboard'

optimize = Configuration()

In [8]:


# Derived configurations
#if optimize.dataset == 'path':
    #assert optimize.data_folder is not None and optimize.mean is not None and optimize.std is not None

#if optimize.data_folder is None:
    #data_folder = './datasets/'
#optimize.model_path = './save/SupCon/{}_models'.format(optimize.dataset)
#optimize.tb_path = './save/SupCon/{}_tensorboard'.format(optimize.dataset)

#iterations = optimize.lr_decay_epochs.split(',')
#lr_decay_epochs = list(map(int, iterations))

model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.format(
    optimize.method, optimize.dataset, optimize.model_name, optimize.learning_rate,
    optimize.weight_decay, optimize.batch_size, optimize.temp, optimize.trial)

if optimize.cosine:
    model_name += '_cosine'

if optimize.batch_size > 256:
    optimize.warm = True
if optimize.warm:
    model_name += '_warm'
    warmup_from = 0.01
    warm_epochs = 10
    if optimize.cosine:
        eta_min = optimize.learning_rate * (optimize.lr_decay_rate ** 3)
        warmup_to = eta_min + (optimize.learning_rate - eta_min) * (
                1 + math.cos(math.pi * warm_epochs / optimize.epochs)) / 2
    else:
        warmup_to = optimize.learning_rate

tb_folder = os.path.join(optimize.tb_path, model_name)
if not os.path.isdir(tb_folder):
    os.makedirs(tb_folder)

save_folder = os.path.join(optimize.model_path, model_name)
if not os.path.isdir(save_folder):
    os.makedirs(save_folder)

def set_loader():
    # construct data loader
    if optimize.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif optimize.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif optimize.dataset == 'path':
        mean = eval(mean)
        std = eval(std)
    else:
        raise ValueError('dataset not supported: {}'.format(optimize.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=optimize.size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    if optimize.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=optimize.data_folder,
                                         transform=TwoCropTransform(train_transform),
                                         download=True)
    elif optimize.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=optimize.data_folder,
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif optimize.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=optimize.data_folder,
                                             transform=TwoCropTransform(train_transform))
    else:
        raise ValueError(optimize.dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=optimize.batch_size, shuffle=(train_sampler is None),
        num_workers=optimize.num_workers, pin_memory=True, sampler=train_sampler)

    return train_loader

def set_model():
    model = SupConResNet(name='resnet50')
    criterion = SupConLoss(temperature=optimize.temp)

    # enable synchronized Batch Normalization
    if optimize.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion

def train(train_loader, model, criterion, optimizer, epoch):
    """one epoch training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(optimize, epoch, idx, len(train_loader), optimizer)

        # compute loss
        features = model(images)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        if optimize.method == 'SupCon':
            loss = criterion(features, labels)
        elif optimize.method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(optimize.method))

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % optimize.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg

def main():
    # build data loader
    train_loader = set_loader()

    # build model and criterion
    model, criterion = set_model()

    # build optimizer
    optimizer = set_optimizer(optimize, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=tb_folder, flush_secs=2)

    # training routine
    for epoch in range(1, optimize.epochs + 1):
        adjust_learning_rate(optimize, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, epoch)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('loss', loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % optimize.save_freq == 0:
            save_file = os.path.join(
                save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, optimize, epoch, save_file)

    # save the last model
    save_file = os.path.join(
        save_folder, 'last.pth')
    save_model(model, optimizer, optimize, optimize.epochs, save_file)

if __name__ == '__main__':
    main()


Files already downloaded and verified


  self.pid = os.fork()


Train: [1][10/196]	BT 0.801 (0.892)	DT 0.000 (0.154)	loss 6.233 (6.245)
Train: [1][20/196]	BT 0.801 (0.846)	DT 0.000 (0.077)	loss 6.235 (6.240)
Train: [1][30/196]	BT 0.800 (0.830)	DT 0.000 (0.052)	loss 6.230 (6.237)
Train: [1][40/196]	BT 0.803 (0.823)	DT 0.000 (0.039)	loss 6.228 (6.235)
Train: [1][50/196]	BT 0.787 (0.817)	DT 0.000 (0.031)	loss 6.217 (6.233)
Train: [1][60/196]	BT 0.789 (0.813)	DT 0.000 (0.026)	loss 6.220 (6.230)
Train: [1][70/196]	BT 0.788 (0.809)	DT 0.000 (0.022)	loss 6.189 (6.227)
Train: [1][80/196]	BT 0.791 (0.807)	DT 0.000 (0.020)	loss 6.207 (6.225)
Train: [1][90/196]	BT 0.790 (0.805)	DT 0.000 (0.017)	loss 6.177 (6.222)
Train: [1][100/196]	BT 0.791 (0.803)	DT 0.000 (0.016)	loss 6.213 (6.220)
Train: [1][110/196]	BT 0.790 (0.802)	DT 0.000 (0.014)	loss 6.218 (6.218)
Train: [1][120/196]	BT 0.789 (0.801)	DT 0.000 (0.013)	loss 6.209 (6.217)
Train: [1][130/196]	BT 0.789 (0.800)	DT 0.000 (0.012)	loss 6.199 (6.214)
Train: [1][140/196]	BT 0.797 (0.800)	DT 0.000 (0.011)	loss 6

Cross Entropy Loss Function on CIFAR10

In [9]:
import os
import time
import math
import sys

import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

# Hardcoded configurations
class Config:
    print_freq = 10
    save_freq = 50
    batch_size = 256
    num_workers = 16
    epochs = 15
    learning_rate = 0.2
    lr_decay_epochs = [350, 400, 450]
    lr_decay_rate = 0.1
    weight_decay = 1e-4
    momentum = 0.9
    model = 'resnet50'
    dataset = 'cifar10'
    cosine = False
    syncBN = False
    warm = True
    trial = '0'
    data_folder = './datasets/'
    model_path = './save/SupCon/cifar10_models'
    tb_path = './save/SupCon/cifar10_tensorboard'

opt = Config()

opt.model_name = 'SupCE_{}_{}_lr_{}_decay_{}_bsz_{}_trial_{}'.format(
    opt.dataset, opt.model, opt.learning_rate, opt.weight_decay,
    opt.batch_size, opt.trial)

if opt.cosine:
    opt.model_name = '{}_cosine'.format(opt.model_name)

if opt.warm:
    opt.model_name = '{}_warm'.format(opt.model_name)
    opt.warmup_from = 0.01
    opt.warm_epochs = 10
    if opt.cosine:
        eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
        opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
            1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
    else:
        opt.warmup_to = opt.learning_rate

opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
if not os.path.isdir(opt.tb_folder):
    os.makedirs(opt.tb_folder)

opt.save_folder = os.path.join(opt.model_path, opt.model_name)
if not os.path.isdir(opt.save_folder):
    os.makedirs(opt.save_folder)

opt.n_cls = 100

def set_loader(opt):
    mean = (0.5071, 0.4867, 0.4408)
    std = (0.2675, 0.2565, 0.2761)
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                      transform=train_transform,
                                      download=True)
    val_dataset = datasets.CIFAR10(root=opt.data_folder,
                                    train=False,
                                    transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=opt.num_workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=256, shuffle=False,
        num_workers=8, pin_memory=True)

    return train_loader, val_loader

def set_model(opt):
    model = SupCEResNet(name=opt.model, num_classes=opt.n_cls)
    criterion = torch.nn.CrossEntropyLoss()

    if opt.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model = torch.nn.DataParallel(model)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion

def train(train_loader, model, criterion, optimizer, epoch, opt):
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    end = time.time()

    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)
        images = images.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]
        warmup_learning_rate(opt, epoch, idx, len(train_loader), optimizer)
        output = model(images)
        loss = criterion(output, labels)
        losses.update(loss.item(), bsz)
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        top1.update(acc1[0], bsz)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()

        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses, top1=top1))
            sys.stdout.flush()

    return losses.avg, top1.avg

def validate(val_loader, model, criterion, opt):
    model.eval()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]
            output = model(images)
            loss = criterion(output, labels)
            losses.update(loss.item(), bsz)
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            top1.update(acc1[0], bsz)
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                       idx, len(val_loader), batch_time=batch_time,
                       loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg

def main():
    best_acc = 0
    # build data loader
    train_loader, val_loader = set_loader(opt)
    # build model and criterion
    model, criterion = set_model(opt)
    # build optimizer
    optimizer = set_optimizer(opt, model)
    # tensorboard
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)
        time1 = time.time()
        loss, train_acc = train(train_loader, model, criterion, optimizer, epoch, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        logger.log_value('train_loss', loss, epoch)
        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)
        loss, val_acc = validate(val_loader, model, criterion, opt)
        logger.log_value('val_loss', loss, epoch)
        logger.log_value('val_acc', val_acc, epoch)
        if val_acc > best_acc:
            best_acc = val_acc
        if epoch % opt.save_freq == 0:
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)
    print('best accuracy: {:.2f}'.format(best_acc))

if __name__ == '__main__':
    main()


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13170340.31it/s]


Extracting ./datasets/cifar-10-python.tar.gz to ./datasets/
Train: [1][10/196]	BT 0.387 (0.791)	DT 0.000 (0.070)	loss 3.701 (3.308)	Acc@1 9.766 (10.508)
Train: [1][20/196]	BT 0.387 (0.589)	DT 0.000 (0.047)	loss 4.496 (3.689)	Acc@1 9.766 (10.645)
Train: [1][30/196]	BT 0.385 (0.521)	DT 0.000 (0.040)	loss 3.243 (3.651)	Acc@1 10.156 (10.547)
Train: [1][40/196]	BT 0.386 (0.487)	DT 0.000 (0.036)	loss 2.709 (3.479)	Acc@1 8.984 (10.996)
Train: [1][50/196]	BT 0.385 (0.467)	DT 0.000 (0.034)	loss 2.490 (3.289)	Acc@1 14.453 (12.125)
Train: [1][60/196]	BT 0.384 (0.453)	DT 0.000 (0.032)	loss 2.518 (3.144)	Acc@1 15.625 (13.164)
Train: [1][70/196]	BT 0.384 (0.443)	DT 0.000 (0.031)	loss 2.498 (3.036)	Acc@1 18.359 (13.795)
Train: [1][80/196]	BT 0.382 (0.436)	DT 0.000 (0.030)	loss 2.042 (2.929)	Acc@1 23.828 (14.644)
Train: [1][90/196]	BT 0.382 (0.430)	DT 0.000 (0.030)	loss 2.022 (2.836)	Acc@1 20.312 (15.356)
Train: [1][100/196]	BT 0.381 (0.425)	DT 0.000 (0.029)	loss 2.472 (2.763)	Acc@1 21.484 (16.227)
Tr

Supervised Contrastive Learning on CIFAR10

In [14]:
from __future__ import print_function

import os
import sys
import time
import math

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import tensorboard_logger as tb_logger
import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets


try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass

# Hardcoded configurations
class Configuration1:
  print_freq = 10
  save_freq = 50
  batch_size = 256
  num_workers = 16
  epochs = 15
  learning_rate = 0.05
  lr_decay_epochs = [700,800,900]
  lr_decay_rate = 0.1
  weight_decay = 1e-4
  momentum = 0.9
  model_name = 'resnet50'
  dataset = 'cifar10'  # options: ['cifar10', 'cifar100', 'path']
  mean = None
  std = None
  data_folder = './datasets/'
  size = 32
  method = 'SupCon'  # options: ['SupCon', 'SimCLR']
  temp = 0.07
  cosine = False
  syncBN = False
  warm = False
  trial = '0'
  model_path = './save/SupCon/cifar10_models'
  tb_path = './save/SupCon/cifar10_tensorboard'

optimize1 = Configuration()

In [15]:


# Derived configurations
#if optimize.dataset == 'path':
    #assert optimize.data_folder is not None and optimize.mean is not None and optimize.std is not None

#if optimize.data_folder is None:
    #data_folder = './datasets/'
#optimize.model_path = './save/SupCon/{}_models'.format(optimize.dataset)
#optimize.tb_path = './save/SupCon/{}_tensorboard'.format(optimize.dataset)

#iterations = optimize.lr_decay_epochs.split(',')
#lr_decay_epochs = list(map(int, iterations))

model_name = '{}_{}_{}_lr_{}_decay_{}_bsz_{}_temp_{}_trial_{}'.format(
    optimize1.method, optimize1.dataset, optimize1.model_name, optimize1.learning_rate,
    optimize1.weight_decay, optimize1.batch_size, optimize1.temp, optimize1.trial)

if optimize1.cosine:
    model_name += '_cosine'

if optimize1.batch_size > 256:
    optimize1.warm = True
if optimize1.warm:
    model_name += '_warm'
    warmup_from = 0.01
    warm_epochs = 10
    if optimize1.cosine:
        eta_min = optimize1.learning_rate * (optimize1.lr_decay_rate ** 3)
        warmup_to = eta_min + (optimize1.learning_rate - eta_min) * (
                1 + math.cos(math.pi * warm_epochs / optimize1.epochs)) / 2
    else:
        warmup_to = optimize1.learning_rate

tb_folder = os.path.join(optimize1.tb_path, model_name)
if not os.path.isdir(tb_folder):
    os.makedirs(tb_folder)

save_folder = os.path.join(optimize1.model_path, model_name)
if not os.path.isdir(save_folder):
    os.makedirs(save_folder)

def set_loader():
    # construct data loader
    if optimize1.dataset == 'cifar10':
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2023, 0.1994, 0.2010)
    elif optimize1.dataset == 'cifar100':
        mean = (0.5071, 0.4867, 0.4408)
        std = (0.2675, 0.2565, 0.2761)
    elif optimize1.dataset == 'path':
        mean = eval(mean)
        std = eval(std)
    else:
        raise ValueError('dataset not supported: {}'.format(optimize1.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=optimize1.size, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        normalize,
    ])

    if optimize1.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=optimize1.data_folder,
                                         transform=TwoCropTransform(train_transform),
                                         download=True)
    elif optimize1.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=optimize1.data_folder,
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif optimize1.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=optimize1.data_folder,
                                             transform=TwoCropTransform(train_transform))
    else:
        raise ValueError(optimize1.dataset)

    train_sampler = None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=optimize1.batch_size, shuffle=(train_sampler is None),
        num_workers=optimize1.num_workers, pin_memory=True, sampler=train_sampler)

    return train_loader

def set_model():
    model = SupConResNet(name='resnet50')
    criterion = SupConLoss(temperature=optimize1.temp)

    # enable synchronized Batch Normalization
    if optimize1.syncBN:
        model = apex.parallel.convert_syncbn_model(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        cudnn.benchmark = True

    return model, criterion

def train(train_loader, model, criterion, optimizer, epoch):
    """one epoch training"""
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    end = time.time()
    for idx, (images, labels) in enumerate(train_loader):
        data_time.update(time.time() - end)

        images = torch.cat([images[0], images[1]], dim=0)
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        bsz = labels.shape[0]

        # warm-up learning rate
        warmup_learning_rate(optimize1, epoch, idx, len(train_loader), optimizer)

        # compute loss
        features = model(images)
        f1, f2 = torch.split(features, [bsz, bsz], dim=0)
        features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
        if optimize1.method == 'SupCon':
            loss = criterion(features, labels)
        elif optimize1.method == 'SimCLR':
            loss = criterion(features)
        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(optimize1.method))

        # update metric
        losses.update(loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % optimize1.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, len(train_loader), batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg

def main():
    # build data loader
    train_loader = set_loader()

    # build model and criterion
    model, criterion = set_model()

    # build optimizer
    optimizer = set_optimizer(optimize1, model)

    # tensorboard
    logger = tb_logger.Logger(logdir=tb_folder, flush_secs=2)

    # training routine
    for epoch in range(1, optimize1.epochs + 1):
        adjust_learning_rate(optimize1, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        loss = train(train_loader, model, criterion, optimizer, epoch)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        # tensorboard logger
        logger.log_value('loss', loss, epoch)
        logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch)

        if epoch % optimize1.save_freq == 0:
            save_file = os.path.join(
                save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, optimize1, epoch, save_file)

    # save the last model
    save_file = os.path.join(
        save_folder, 'last.pth')
    save_model(model, optimizer, optimize1, optimize1.epochs, save_file)

if __name__ == '__main__':
    main()


Files already downloaded and verified
Train: [1][10/196]	BT 0.789 (0.868)	DT 0.000 (0.130)	loss 6.232 (6.235)
Train: [1][20/196]	BT 0.789 (0.828)	DT 0.000 (0.065)	loss 6.223 (6.232)
Train: [1][30/196]	BT 0.788 (0.816)	DT 0.000 (0.044)	loss 6.230 (6.229)
Train: [1][40/196]	BT 0.799 (0.811)	DT 0.000 (0.033)	loss 6.209 (6.224)
Train: [1][50/196]	BT 0.801 (0.808)	DT 0.000 (0.026)	loss 6.231 (6.221)
Train: [1][60/196]	BT 0.794 (0.806)	DT 0.000 (0.022)	loss 6.189 (6.217)
Train: [1][70/196]	BT 0.789 (0.804)	DT 0.000 (0.019)	loss 6.186 (6.215)
Train: [1][80/196]	BT 0.792 (0.803)	DT 0.000 (0.017)	loss 6.185 (6.211)
Train: [1][90/196]	BT 0.790 (0.801)	DT 0.000 (0.015)	loss 6.216 (6.208)
Train: [1][100/196]	BT 0.791 (0.800)	DT 0.000 (0.013)	loss 6.165 (6.205)
Train: [1][110/196]	BT 0.784 (0.799)	DT 0.000 (0.012)	loss 6.192 (6.203)
Train: [1][120/196]	BT 0.784 (0.798)	DT 0.000 (0.011)	loss 6.084 (6.199)
Train: [1][130/196]	BT 0.788 (0.797)	DT 0.000 (0.010)	loss 6.179 (6.198)
Train: [1][140/196]	BT