https://github.com/hongyi-zhang/Fixup

https://github.com/hongyi-zhang/Fixup/blob/master/cifar/models/

In [1]:
"""https://github.com/hongyi-zhang/Fixup/blob/master/cifar/models/__init__.py"""
#from .fixup_resnet_cifar import *
#from .resnet_cifar import *

'https://github.com/hongyi-zhang/Fixup/blob/master/cifar/models/__init__.py'

In [0]:
"""https://github.com/hongyi-zhang/Fixup/blob/master/cifar/models/fixup_resnet_cifar.py"""
import torch
import torch.nn as nn
import numpy as np


__all__ = ['FixupResNet', 'fixup_resnet8', 'fixup_resnet20', 'fixup_resnet32', 'fixup_resnet44', 'fixup_resnet56', 'fixup_resnet110', 'fixup_resnet1202']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class FixupBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(FixupBasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bias1a = nn.Parameter(torch.zeros(1))
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bias1b = nn.Parameter(torch.zeros(1))
        self.relu = nn.ReLU(inplace=True)
        self.bias2a = nn.Parameter(torch.zeros(1))
        self.conv2 = conv3x3(planes, planes)
        self.scale = nn.Parameter(torch.ones(1))
        self.bias2b = nn.Parameter(torch.zeros(1))
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x + self.bias1a)
        out = self.relu(out + self.bias1b)

        out = self.conv2(out + self.bias2a)
        out = out * self.scale + self.bias2b

        if self.downsample is not None:
            identity = self.downsample(x + self.bias1a)
            identity = torch.cat((identity, torch.zeros_like(identity)), 1)

        out += identity
        out = self.relu(out)

        return out


class FixupResNet(nn.Module):

    def __init__(self, block, layers, num_classes=10):
        super(FixupResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bias1 = nn.Parameter(torch.zeros(1))
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.bias2 = nn.Parameter(torch.zeros(1))
        self.fc = nn.Linear(64, num_classes)

        for m in self.modules():
            if isinstance(m, FixupBasicBlock):
                nn.init.normal_(m.conv1.weight, mean=0, std=np.sqrt(2 / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:]))) * self.num_layers ** (-0.5))
                nn.init.constant_(m.conv2.weight, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = nn.AvgPool2d(1, stride=stride)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x + self.bias1)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x + self.bias2)

        return x


def fixup_resnet8(**kwargs):
    """Constructs a Fixup-ResNet-8 model.

    """
    model = FixupResNet(FixupBasicBlock, [1, 1, 1], **kwargs)
    return model


def fixup_resnet20(**kwargs):
    """Constructs a Fixup-ResNet-20 model.

    """
    model = FixupResNet(FixupBasicBlock, [3, 3, 3], **kwargs)
    return model


def fixup_resnet32(**kwargs):
    """Constructs a Fixup-ResNet-32 model.

    """
    model = FixupResNet(FixupBasicBlock, [5, 5, 5], **kwargs)
    return model


def fixup_resnet44(**kwargs):
    """Constructs a Fixup-ResNet-44 model.

    """
    model = FixupResNet(FixupBasicBlock, [7, 7, 7], **kwargs)
    return model


def fixup_resnet56(**kwargs):
    """Constructs a Fixup-ResNet-56 model.

    """
    model = FixupResNet(FixupBasicBlock, [9, 9, 9], **kwargs)
    return model


def fixup_resnet110(**kwargs):
    """Constructs a Fixup-ResNet-110 model.

    """
    model = FixupResNet(FixupBasicBlock, [18, 18, 18], **kwargs)
    return model


def fixup_resnet1202(**kwargs):
    """Constructs a Fixup-ResNet-1202 model.

    """
    model = FixupResNet(FixupBasicBlock, [200, 200, 200], **kwargs)
    return model    

In [0]:
"""https://github.com/hongyi-zhang/Fixup/blob/master/cifar/models/resnet_cifar.py"""
import torch
import torch.nn as nn
import numpy as np


__all__ = ['ResNet', 'resnet8', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)
            identity = torch.cat((identity, torch.zeros_like(identity)), 1)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.num_layers = sum(layers)
        self.inplanes = 16
        self.conv1 = conv3x3(3, 16)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = nn.Sequential(
                nn.AvgPool2d(1, stride=stride),
                nn.BatchNorm2d(self.inplanes),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(planes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet8(**kwargs):
    """Constructs a ResNet-8 model.

    """
    model = ResNet(BasicBlock, [1, 1, 1], **kwargs)
    return model


def resnet20(**kwargs):
    """Constructs a ResNet-20 model.

    """
    model = ResNet(BasicBlock, [3, 3, 3], **kwargs)
    return model


def resnet32(**kwargs):
    """Constructs a ResNet-32 model.

    """
    model = ResNet(BasicBlock, [5, 5, 5], **kwargs)
    return model


def resnet44(**kwargs):
    """Constructs a ResNet-44 model.

    """
    model = ResNet(BasicBlock, [7, 7, 7], **kwargs)
    return model


def resnet56(**kwargs):
    """Constructs a ResNet-56 model.

    """
    model = ResNet(BasicBlock, [9, 9, 9], **kwargs)
    return model


def resnet110(**kwargs):
    """Constructs a ResNet-110 model.

    """
    model = ResNet(BasicBlock, [18, 18, 18], **kwargs)
    return model


def resnet1202(**kwargs):
    """Constructs a ResNet-1202 model.

    """
    model = ResNet(BasicBlock, [200, 200, 200], **kwargs)
    return model    

In [0]:
""" Base code source: https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/metrics.py 
Referece: 
ArcFace paper at https://arxiv.org/pdf/1801.07698.pdf
CosFace paper at https://arxiv.org/pdf/1801.09414.pdf
SphereFace paper at https://arxiv.org/pdf/1704.08063.pdf
"""

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math

class LinearFWNorm(nn.Linear):

    def forward(self,input):
        #return F.linear(input, self.weight, self.bias)
        return F.linear(F.normalize(input), F.normalize(self.weight))

def logit_where_and_scale(label, cosine_margined, cosine, s):
    """ Use cosine_margined for the corresponding label, otherwise cosine.
    """   
    one_hot = torch.zeros(cosine.size(), device='cuda')
    one_hot.scatter_(1, label.view(-1, 1).long(), 1)
    #if torch.__version__ >= '0.4':
    if True:
        logit_output = torch.where(one_hot.byte(), cosine_margined, cosine) 
    #else:
    #    logit_output = (one_hot * cosine_margined) + ((1.0 - one_hot) * cosine)  
    logit_output *= s   
    return logit_output

class ArcFaceLoss(nn.CrossEntropyLoss):
    r""" Receiving the logit as cos(theta), calculate s * cos(theta + m_arc) as the penalized logit, and calculate CrossEntropyLoss.
        Args:
            s: feature scale (64 in the ArcFace paper)
            m_arc: additive angular margin m2; margin of angle in radian (0.5 in the ArcFace paper)           
        """
    def __init__(self, s=64.0, m_arc=0.50, easy_margin=False):
        super().__init__()
        self.s = s
        self.m_arc = m_arc
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m_arc)
        self.sin_m = math.sin(m_arc)
        self.th = math.cos(math.pi - m_arc)
        self.mm = math.sin(math.pi - m_arc) * m_arc

    def forward(self, input, label):
        cosine = input # output of LinearFWNorm
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) # Trigonometric formula sin^2 + cos^2 = 1
        cosine_margined = cosine * self.cos_m - sine * self.sin_m # Trigonometric Addition formula cos(a+b) = cos(a) cos(b) = sin(a) sin(b)
        if self.easy_margin:
            cosine_margined = torch.where(cosine > 0, cosine_margined, cosine)
        else:
            cosine_margined = torch.where(cosine > self.th, cosine_margined, cosine - self.mm)
        
        logit_output = logit_where_and_scale(label, cosine_margined, cosine, self.s)        
        output = super().forward(logit_output, label)
        return output
    
    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + ', s=' + str(self.s) \
               + ', m_arc=' + str(self.m_arc) + ')'

class CosFaceLoss(nn.CrossEntropyLoss):
    r""" Receiving the logit as cos(theta), calculate s * (cos(theta) - m_cos) as the penalized logit, and calculate CrossEntropyLoss.
    Args:
        s: feature scale (64 in the CosFace paper)
        m_cos: additive cosine margin m3; margin of cos(theta) (0.35 performed best on LFW and YTF according to the CosFace paper)
    """

    def __init__(self, s=64.0, m_cos=0.35):
        super().__init__()
        self.s = s
        self.m_cos = m_cos

    def forward(self, input, label):

        cosine = input # output of LinearFWNorm
        cosine_margined = cosine - self.m_cos        
        logit_output = logit_where_and_scale(label, cosine_margined, cosine, self.s)        
        output = super().forward(logit_output, label)
        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + ', s=' + str(self.s) \
               + ', m_cos=' + str(self.m_cos) + ')'


class SphereFaceLoss(nn.CrossEntropyLoss):
    r""" Receiving the logit as cos(theta), calculate s * cos(m_sphere * theta) as the penalized logit, and calculate CrossEntropyLoss.
    Args:
        m_sphere: multiplicative angular margin m1; margin to scale cos(theta) (1.35 suggested in in the ArcFace paper)
    """
    def __init__(self, m_sphere=1.35):
        super().__init__()
        self.m_sphere = m_sphere
        self.base = 1000.0
        self.gamma = 0.12
        self.power = 1
        self.LambdaMin = 5.0
        self.iter = 0

        # duplication formula
        self.mlambda = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]

    def forward(self, input, label):
        # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
        self.iter += 1
        self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = input # output of LinearFWNorm
        cosine = cosine.clamp(-1, 1)
        cos_m_theta = self.mlambda[self.m_sphere](cosine)
        theta = cosine.data.acos()
        k = (self.m_sphere * theta / 3.14159265).floor()
        phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
        NormOfFeature = torch.norm(input, 2, 1)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size())
        one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
        one_hot.scatter_(1, label.view(-1, 1), 1)

        # --------------------------- Calculate output ---------------------------
        logit_output = (one_hot * (phi_theta - cosine) / (1 + self.lamb)) + cosine
        logit_output *= NormOfFeature.view(-1, 1)

        output = super().forward(logit_output, label)        
        
        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + ', m_sphere=' + str(self.m_sphere) + ')'
   

https://github.com/hongyi-zhang/Fixup/tree/master/cifar

In [0]:
"""https://github.com/hongyi-zhang/Fixup/blob/master/cifar/utils.py"""
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of dataset.
    - msr_init: net parameter initialization.
    - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

import numpy as np
import torch

def mixup_data(x, y, alpha=1.0, use_cuda=True, per_sample=False):

    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
    batch_size = x.size()[0]
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    if alpha > 0. and not per_sample:
        lam = torch.zeros(y.size()).fill_(np.random.beta(alpha, alpha)).cuda()
        mixed_x = lam.view(-1, 1, 1, 1) * x + (1 - lam.view(-1, 1, 1, 1)) * x[index,:]
    elif alpha > 0.:
        lam = torch.Tensor(np.random.beta(alpha, alpha, size=y.size())).cuda()
        mixed_x = lam.view(-1, 1, 1, 1) * x + (1 - lam.view(-1, 1, 1, 1)) * x[index,:]
    else:
        lam = torch.ones(y.size()).cuda()
        mixed_x = x

    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_lam_idx(batch_size, alpha, use_cuda=True):
    '''Compute the mixup data. Return mixed inputs, pairs of targets, and lambda'''
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)

    return lam, index    

def mixup_criterion(y_a, y_b, lam):
    return lambda criterion, pred: criterion(pred, y_a, lam) + criterion(pred, y_b, 1 - lam)

def get_mean_and_std(dataset):
    '''Compute the mean and std value of dataset.'''
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:,i,:,:].mean()
            std[i] += inputs[:,i,:,:].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std

def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal(m.weight, mode='fan_out')
            if m.bias:
                init.constant(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant(m.weight, 1)
            init.constant(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal(m.weight, std=1e-3)
            if m.bias:
                init.constant(m.bias, 0)


#_, term_width = os.popen('stty size', 'r').read().split()
#term_width = int(term_width)
term_width = 100

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time

def progress_bar(current, total, msg=None):
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
    
    p = ''
    p += ' [' #sys.stdout.write(' [')    
    for i in range(cur_len):
        p += '=' #sys.stdout.write('=')        
    p += '>' #sys.stdout.write('>')    
    for i in range(rest_len):
        p += '.' #sys.stdout.write('.')        
    p += ']' #sys.stdout.write(']')   

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append('  Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    p += msg #sys.stdout.write(msg)
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        p += ' ' #sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        p += '\b' #'sys.stdout.write('\b')
    p += ' %d/%d ' % (current+1, total) #sys.stdout.write(' %d/%d ' % (current+1, total))

    if current < total-1:
        p += '\r' #sys.stdout.write('\r')
    else:
        p += '\n' #sys.stdout.write('\n')
    print(p, '\r', end='', flush=True) #sys.stdout.flush()
    
    
def format_time(seconds):
    days = int(seconds / 3600/24)
    seconds = seconds - days*3600*24
    hours = int(seconds / 3600)
    seconds = seconds - hours*3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes*60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds*1000)

    f = ''
    i = 1
    if days > 0:
        f += str(days) + 'D'
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + 'h'
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + 'm'
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + 's'
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + 'ms'
        i += 1
    if f == '':
        f = '0ms'
    return f


In [0]:
"""https://github.com/hongyi-zhang/Fixup/blob/master/cifar/cifar_train.py"""
'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import csv

#import models

#from utils import progress_bar, mixup_data, mixup_criterion

import numpy
import random


import easydict
from datetime import datetime

# Training
def train(net, trainloader, epoch, optimizer, loss_func, use_cuda):
    #print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        # generate mixed inputs, two one-hot label vectors and mixing coefficient
        inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, args.alpha, use_cuda)
        optimizer.zero_grad()
        outputs = net(inputs)
        if args.loss == 'mixup': 
            mixup_loss_func = mixup_criterion(targets_a, targets_b, lam)
            loss = mixup_loss_func(criterion, outputs)
        else:
            loss = loss_func(outputs, targets)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += (lam * predicted.eq(targets_a.data).float()).cpu().sum() + ((1 - lam) * predicted.eq(targets_b.data).float()).cpu().sum()
        #acc = 100.*float(correct)/float(total)
        acc = float(correct)/float(total)

        #progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
        #    % (train_loss/(batch_idx+1), acc, correct, total))

    return (train_loss/batch_idx, acc)

def test(net, testloader, epoch, loss_func, use_cuda):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            if args.loss == 'mixup': 
                loss = nn.CrossEntropyLoss()(outputs, targets)
            else:
                loss = loss_func(outputs, targets)

            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()

            #progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
            #    % (test_loss/(batch_idx+1), 100.*float(correct)/float(total), correct, total))

        # Save checkpoint.
        #acc = 100.*float(correct)/float(total)
        acc = float(correct)/float(total)
        if acc > best_acc:
            best_acc = acc
            checkpoint(net, acc, epoch)

    return (test_loss/batch_idx, acc)

def checkpoint(net, acc, epoch):
    # Save checkpoint.
    #print('Saving..')
    state = {
        'net': net,
        'acc': acc,
        'epoch': epoch,
        'rng_state': torch.get_rng_state()
    }
    if not os.path.isdir('checkpoint'):
        os.mkdir('checkpoint')
    torch.save(state, './checkpoint/' + args.arch + '_' + args.sess + '_' + str(args.seed) + '.ckpt')

def adjust_learning_rate(optimizer, epoch, base_learning_rate):
    """decrease the learning rate at 100 and 150 epoch"""
    lr = base_learning_rate
    if epoch <= 9 and lr > 0.1:
        # warm-up training for large minibatch
        lr = 0.1 + (base_learning_rate - 0.1) * epoch / 10.
    if epoch >= 100:
        lr /= 10
    if epoch >= 150:
        lr /= 10
    for param_group in optimizer.param_groups:
        if param_group['initial_lr'] == base_learning_rate:
            param_group['lr'] = lr
        else:
            if epoch <= 9:
                param_group['lr'] = param_group['initial_lr'] * lr / base_learning_rate
            elif epoch < 100:
                param_group['lr'] = param_group['initial_lr']
            elif epoch < 150:
                param_group['lr'] = param_group['initial_lr'] / 10.
            else:
                param_group['lr'] = param_group['initial_lr'] / 100.
    return lr



def timestamp():
    return datetime.now().strftime('%Y-%m-%dT%H%M%S')

def run_trial(args, logname):
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    numpy.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    use_cuda = torch.cuda.is_available()
    global best_acc
    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch
    batch_size = args.batchsize
    base_learning_rate = args.base_lr * args.batchsize / 128.
    if use_cuda:
        # data parallel
        n_gpu = torch.cuda.device_count()
        batch_size *= n_gpu
        base_learning_rate *= n_gpu

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # Model
    if args.resume:
        #assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
        if not os.path.isdir('checkpoint'): os.mkdir('checkpoint')
        checkpoint_file = './checkpoint/ckpt.t7.' + args.sess + '_' + str(args.seed)
    if args.resume and os.path.exists(checkpoint_file):
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        checkpoint = torch.load(checkpoint_file)
        net = checkpoint['net']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch'] + 1
        torch.set_rng_state(checkpoint['rng_state'])
    else:
        print("=> creating model '{}'".format(args.arch))
        #net = models.__dict__[args.arch]()
        net = globals().get(args.arch)()
        if loss in ['cosface', 'arcface']:
            num_ftrs = net.fc.in_features
            num_classes = 10
            net.fc = LinearFWNorm(num_ftrs, num_classes)

    if use_cuda:
        net.cuda()
        net = torch.nn.DataParallel(net)
        print('Using', torch.cuda.device_count(), 'GPUs.')
        cudnn.benchmark = True
        print('Using CUDA..')


    if args.loss == 'cosface':
        loss_func = CosFaceLoss(s=64.0, m_cos=0.35)
    if args.loss == 'arcface':
        loss_func = ArcFaceLoss(s=64.0, m_arc=0.5)
    if args.loss == 'cel': 
        loss_func = nn.CrossEntropyLoss()
    if args.loss == 'mixup':
        criterion = lambda pred, target, lam: (-F.log_softmax(pred, dim=1) * torch.zeros(pred.size()).cuda().scatter_(1, target.data.view(-1, 1), lam.view(-1, 1))).sum(dim=1).mean()
    parameters_bias = [p[1] for p in net.named_parameters() if 'bias' in p[0]]
    parameters_scale = [p[1] for p in net.named_parameters() if 'scale' in p[0]]
    parameters_others = [p[1] for p in net.named_parameters() if not ('bias' in p[0] or 'scale' in p[0])]
    optimizer = optim.SGD(
            [{'params': parameters_bias, 'lr': args.base_lr/10.}, 
            {'params': parameters_scale, 'lr': args.base_lr/10.}, 
            {'params': parameters_others}], 
            lr=base_learning_rate, 
            momentum=0.9, 
            weight_decay=args.decay)



    if not os.path.exists(logname):
        with open(logname, 'w') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow(['epoch', 'lr', 'train loss', 'train acc', 'test loss', 'test acc', 'arch', 'loss', 'timestamp'])

    sgdr = CosineAnnealingLR(optimizer, args.n_epoch, eta_min=0, last_epoch=-1)

    print('### Started to train the model. | arch: {} | loss: {}'.format(args.arch, args.loss))
    for epoch in range(start_epoch, args.n_epoch):
        lr = 0.
        if args.sgdr:
            sgdr.step()
            for param_group in optimizer.param_groups:
                lr = param_group['lr']
                break
        else:
            lr = adjust_learning_rate(optimizer, epoch, base_learning_rate)    
        train_loss, train_acc = train(net, trainloader, epoch, optimizer, loss_func, use_cuda)
        test_loss, test_acc = test(net, testloader, epoch, loss_func, use_cuda)
        with open(logname, 'a') as logfile:
            logwriter = csv.writer(logfile, delimiter=',')
            logwriter.writerow([epoch, lr, train_loss, train_acc, test_loss, test_acc, args.arch, args.loss, timestamp()])
        pritn_str = '### [{}] Epoch: {:3d} | LR: {:3f} | train_loss: {:6f} | train_acc: {:6f} | test_loss: {:6f} | test_acc: {:6f}'.format(timestamp(), epoch, lr, train_loss, train_acc, test_loss, test_acc)
        print(pritn_str)


In [7]:
"""
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))
"""
model_names = ['fixup_resnet_cifar', 'resnet_cifar', 'resnet_arcface_cifar', 'resnet_cosface_cifar']

"""
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('-a', '--arch', metavar='ARCH', default='fixup_resnet110', choices=model_names, help='model architecture: ' +
                        ' | '.join(model_names) + ' (default: fixup_resnet110)')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--sess', default='mixup_default', type=str, help='session id')
parser.add_argument('--seed', default=0, type=int, help='rng seed')
parser.add_argument('--alpha', default=1., type=float, help='interpolation strength (uniform=1., ERM=0.)')
parser.add_argument('--sgdr', action='store_true', help='use SGD with cosine annealing learning rate and restarts')
parser.add_argument('--decay', default=1e-4, type=float, help='weight decay (default=1e-4)')
parser.add_argument('--batchsize', default=128, type=int, help='batch size per GPU (default=128)')
parser.add_argument('--n_epoch', default=200, type=int, help='total number of epochs')
parser.add_argument('--base_lr', default=0.1, type=float, help='base learning rate (default=0.1)')

args = parser.parse_args()
"""

args = easydict.EasyDict({
        "arch": 'resnet8',
        "resume": False,
        "sess": 'benchmark',
        "seed": 0,
        "alpha": 1.,
        "sgdr": True,
        "decay": 1e-4,
        "batchsize": 128,
        "n_epoch": 30,
        "base_lr": 0.1,
        #"loss": 'mixup'
        "loss": 'cel'
        #"loss": 'cosface'
        #"loss": 'arcface'
})

result_folder = './results/'
if not os.path.exists(result_folder):
    os.makedirs(result_folder)
        
logname = result_folder + args.sess 
logname += '_seed_' + str(args.seed) + '_' + timestamp() + '.csv'

loss_list = ['cel', 'cosface', 'arcface']
# loss_list = ['cel']
# arch_list = ['resnet8', 'fixup_resnet8']
# arch_list = ['resnet20', 'fixup_resnet20']
arch_list = ['resnet20']

for loss in loss_list:
    args.loss = loss
    for arch in arch_list:
        args.arch = arch
        run_trial(args, logname)

==> Preparing data..
Files already downloaded and verified
Files already downloaded and verified
=> creating model 'resnet20'
Using 1 GPUs.
Using CUDA..
### Started to train the model. | arch: resnet20 | loss: cel
### [2019-04-16T144005] Epoch:   0 | LR: 0.010000 | train_loss: 2.210814 | train_acc: 0.210656 | test_loss: 1.953614 | test_acc: 0.303000


  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


### [2019-04-16T144021] Epoch:   1 | LR: 0.009973 | train_loss: 2.155687 | train_acc: 0.270756 | test_loss: 1.832633 | test_acc: 0.379900
### [2019-04-16T144037] Epoch:   2 | LR: 0.009891 | train_loss: 2.110551 | train_acc: 0.314149 | test_loss: 1.816447 | test_acc: 0.393100
### [2019-04-16T144055] Epoch:   3 | LR: 0.009755 | train_loss: 2.078915 | train_acc: 0.346546 | test_loss: 1.674665 | test_acc: 0.444100
### [2019-04-16T144111] Epoch:   4 | LR: 0.009568 | train_loss: 2.049814 | train_acc: 0.377299 | test_loss: 1.516799 | test_acc: 0.499300
### [2019-04-16T144127] Epoch:   5 | LR: 0.009330 | train_loss: 2.032856 | train_acc: 0.406963 | test_loss: 1.552267 | test_acc: 0.503200
### [2019-04-16T144143] Epoch:   6 | LR: 0.009045 | train_loss: 2.001226 | train_acc: 0.418550 | test_loss: 1.547571 | test_acc: 0.519300
### [2019-04-16T144200] Epoch:   7 | LR: 0.008716 | train_loss: 1.962294 | train_acc: 0.430170 | test_loss: 1.372405 | test_acc: 0.568200
### [2019-04-16T144217] Epoch:   8

  "type " + obj.__name__ + ". It won't be checked "


### [2019-04-16T144828] Epoch:   1 | LR: 0.009973 | train_loss: 24.646872 | train_acc: 0.169426 | test_loss: 24.744155 | test_acc: 0.191100
### [2019-04-16T144845] Epoch:   2 | LR: 0.009891 | train_loss: 24.601844 | train_acc: 0.199634 | test_loss: 24.557546 | test_acc: 0.270700
### [2019-04-16T144903] Epoch:   3 | LR: 0.009755 | train_loss: 24.559863 | train_acc: 0.222325 | test_loss: 24.393747 | test_acc: 0.295000
### [2019-04-16T144919] Epoch:   4 | LR: 0.009568 | train_loss: 24.515500 | train_acc: 0.243387 | test_loss: 24.273707 | test_acc: 0.311300
### [2019-04-16T144934] Epoch:   5 | LR: 0.009330 | train_loss: 24.499659 | train_acc: 0.257738 | test_loss: 24.238698 | test_acc: 0.340300
### [2019-04-16T144951] Epoch:   6 | LR: 0.009045 | train_loss: 24.481608 | train_acc: 0.267063 | test_loss: 24.234110 | test_acc: 0.359900
### [2019-04-16T145007] Epoch:   7 | LR: 0.008716 | train_loss: 24.426747 | train_acc: 0.275379 | test_loss: 24.048345 | test_acc: 0.410800
### [2019-04-16T1450