In [None]:
import os
import subprocess

REPO_URL = "https://github.com/KaihuaTang/CiiV-Adversarial-Robustness.pytorch/"
REPO_DIR = "/content/CiiV-Adversarial-Robustness.pytorch"

# Clone only once: a second `git clone` into the existing, non-empty directory
# fails ("destination path ... already exists"), which broke Restart & Run All.
if not os.path.isdir(REPO_DIR):
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

fatal: destination path 'CiiV-Adversarial-Robustness.pytorch' already exists and is not an empty directory.


In [None]:
import os
import sys

REPO_DIR = "/content/CiiV-Adversarial-Robustness.pytorch"

# sys.path entries must be *directories* (or zip archives). Appending single
# ``utils/*.py`` files has no effect on imports, and appending only the
# ``attacker`` directory does not make the ``attacker`` package importable by
# name -- its parent directory has to be on the path. The repository root is
# therefore added first; the two sub-directories are kept as in the original.
for _path in (
    REPO_DIR,
    os.path.join(REPO_DIR, "attacker"),
    os.path.join(REPO_DIR, "utils"),
):
    if _path not in sys.path:  # keep the cell idempotent under re-runs
        sys.path.append(_path)

In [None]:
import os

# `!cd` runs in a throw-away subshell and never changed the kernel's working
# directory; change it in-process instead (same effect as the `%cd` magic).
os.chdir("/content/CiiV-Adversarial-Robustness.pytorch/")

In [None]:
def rgb_norm(images, config):
    """Standardise a batch of RGB images with fixed per-channel statistics.

    Args:
        images: tensor whose dim 1 holds the 3 colour channels; the
            statistics are reshaped to ``(1, 3, 1, 1)`` for broadcasting.
        config: unused; kept for call-site compatibility.

    Returns:
        ``(images - mean) / std``, with mean/std placed on ``images.device``.
    """
    device = images.device
    channel_mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(1, 3, 1, 1).to(device)
    channel_std = torch.tensor([0.2471, 0.2435, 0.2616]).view(1, 3, 1, 1).to(device)
    return (images - channel_mean) / channel_std

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Base_Model(nn.Module):
    """
    Common parent for models that can be switched into an "attacking" mode,
    i.e. the mode used while adversarial images are being generated.

    The mode flag is propagated to every nested ``Base_Model`` reachable
    through ``self.modules()``.
    """
    def __init__(self):
        super(Base_Model, self).__init__()
        # True while the model is used to generate attacking images
        self.attacking = False

    def set_attack(self):
        """Enable attacking mode here and in every nested Base_Model."""
        self.attacking = True
        # Lazily filtered: each condition is evaluated only after the previous
        # recursive call has already updated its own sub-tree.
        pending = (sub for sub in self.modules()
                   if isinstance(sub, Base_Model) and not sub.is_attack())
        for sub in pending:
            sub.set_attack()

    def set_unattack(self):
        """Disable attacking mode here and in every nested Base_Model."""
        self.attacking = False
        pending = (sub for sub in self.modules()
                   if isinstance(sub, Base_Model) and sub.is_attack())
        for sub in pending:
            sub.set_unattack()

    def is_attack(self):
        """Return the current attacking flag."""
        return self.attacking

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

'''Pre-activation ResNet in PyTorch.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''

# from models.Base_Model import Base_Model

# from utils.train_utils import *

class PreActBlock(Base_Model):
    '''Pre-activation version of the BasicBlock.

    Computes ``conv2(act(bn2(conv1(act(bn1(x)))))) + shortcut``, where the
    shortcut is ``x`` itself, or a 1x1 projection of ``act(bn1(x))`` when the
    stride or the channel count changes.

    Args:
        in_planes: number of input channels.
        planes: number of output channels (``expansion == 1``).
        stride: stride of ``conv1`` and of the projection shortcut.
        activation: activation name, one of the keys of ``_ACTIVATIONS``.
        softplus_beta: ``beta`` for ``nn.Softplus`` (only used when
            ``activation == 'Softplus'``).

    Raises:
        ValueError: if ``activation`` is not supported. Previously this left
            ``self.relu`` undefined and only failed later, inside ``forward``.
    '''
    expansion = 1

    # activation name -> factory(softplus_beta) building the module
    _ACTIVATIONS = {
        'ReLU': lambda beta: nn.ReLU(inplace=True),
        'Softplus': lambda beta: nn.Softplus(beta=beta, threshold=20),
        'GELU': lambda beta: nn.GELU(),
        'ELU': lambda beta: nn.ELU(alpha=1.0, inplace=True),
        'LeakyReLU': lambda beta: nn.LeakyReLU(negative_slope=0.1, inplace=True),
        'SELU': lambda beta: nn.SELU(inplace=True),
        'CELU': lambda beta: nn.CELU(alpha=1.2, inplace=True),
        'Tanh': lambda beta: nn.Tanh(),
    }

    def __init__(self, in_planes, planes, stride=1, activation='ReLU', softplus_beta=1):
        super(PreActBlock, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError('Unsupported activation: {!r} (expected one of {})'.format(
                activation, sorted(self._ACTIVATIONS)))

        self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=True, affine=True)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, track_running_stats=True, affine=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        # projection shortcut only when the residual shape changes
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

        self.relu = self._ACTIVATIONS[activation](softplus_beta)
        print(activation)

    def forward(self, x):
        out = self.relu(self.bn1(x))
        # the projection is applied to the pre-activated input, not to x
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(self.relu(self.bn2(out)))
        out += shortcut
        return out


class PreActBottleneck(Base_Model):
    '''Pre-activation version of the original Bottleneck module.

    1x1 reduce -> 3x3 (strided) -> 1x1 expand (x ``expansion``), each preceded
    by BatchNorm + activation, plus a residual connection (a 1x1 projection of
    ``act(bn1(x))`` when the stride or channel count changes).

    Args:
        in_planes: number of input channels.
        planes: bottleneck width; the output has ``planes * expansion`` channels.
        stride: stride of ``conv2`` and of the projection shortcut.
        activation: activation name, one of the keys of ``_ACTIVATIONS``.
            Previously this argument was accepted but ignored (``F.relu`` was
            hard-coded), unlike in ``PreActBlock``. The default 'ReLU' gives
            the same results as before, and activation modules add no
            parameters, so existing checkpoints still load.
        softplus_beta: ``beta`` for ``nn.Softplus`` (only used when
            ``activation == 'Softplus'``).

    Raises:
        ValueError: if ``activation`` is not supported.
    '''
    expansion = 4

    # activation name -> factory(softplus_beta) building the module
    _ACTIVATIONS = {
        'ReLU': lambda beta: nn.ReLU(inplace=True),
        'Softplus': lambda beta: nn.Softplus(beta=beta, threshold=20),
        'GELU': lambda beta: nn.GELU(),
        'ELU': lambda beta: nn.ELU(alpha=1.0, inplace=True),
        'LeakyReLU': lambda beta: nn.LeakyReLU(negative_slope=0.1, inplace=True),
        'SELU': lambda beta: nn.SELU(inplace=True),
        'CELU': lambda beta: nn.CELU(alpha=1.2, inplace=True),
        'Tanh': lambda beta: nn.Tanh(),
    }

    def __init__(self, in_planes, planes, stride=1, activation='ReLU', softplus_beta=1):
        super(PreActBottleneck, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError('Unsupported activation: {!r} (expected one of {})'.format(
                activation, sorted(self._ACTIVATIONS)))

        self.bn1 = nn.BatchNorm2d(in_planes, track_running_stats=True, affine=True)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, track_running_stats=True, affine=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes, track_running_stats=True, affine=True)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        # projection shortcut only when the residual shape changes
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

        self.relu = self._ACTIVATIONS[activation](softplus_beta)

    def forward(self, x):
        out = self.relu(self.bn1(x))
        # the projection is applied to the pre-activated input, not to x
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))
        out += shortcut
        return out


class PreActResNet(Base_Model):
    '''Pre-activation ResNet evaluated with CiiV instrumental sampling.

    Args:
        block: residual block class (e.g. ``PreActBlock`` or
            ``PreActBottleneck``) exposing an ``expansion`` attribute.
        num_blocks: number of blocks in each of the four stages.
        num_classes: number of output logits.
        activation: activation name (a key of ``_ACTIVATIONS``). It is passed
            to every block, applied after the final BatchNorm, and used for
            the differentiable sampling step in ``ciiv_forward``.
        softplus_beta: ``beta`` used when ``activation == 'Softplus'``.
        num_sample: number of random samples averaged per inference loop.
        aug_weight: scale of the uniform noise added before sampling.
        mask_center: candidate pixel coordinates of the mask centres. Loop
            ``i`` uses ``(center_x, center_y) = (mask_center[i // k],
            mask_center[i % k])`` with ``k = len(mask_center)``.

    Raises:
        ValueError: if ``activation`` is not supported. Previously
            ``self.relu`` was left undefined and ``forward`` failed later.
    '''

    # activation name -> factory(softplus_beta) building the module
    _ACTIVATIONS = {
        'ReLU': lambda beta: nn.ReLU(inplace=True),
        'Softplus': lambda beta: nn.Softplus(beta=beta, threshold=20),
        'GELU': lambda beta: nn.GELU(),
        'ELU': lambda beta: nn.ELU(alpha=1.0, inplace=True),
        'LeakyReLU': lambda beta: nn.LeakyReLU(negative_slope=0.1, inplace=True),
        'SELU': lambda beta: nn.SELU(inplace=True),
        'CELU': lambda beta: nn.CELU(alpha=1.2, inplace=True),
        'Tanh': lambda beta: nn.Tanh(),
    }

    def __init__(self, block, num_blocks, num_classes=10, activation='ReLU', softplus_beta=1,
                    num_sample=3, aug_weight=0.9, mask_center=[5, 16, 27]):
        super(PreActResNet, self).__init__()
        if activation not in self._ACTIVATIONS:
            raise ValueError('Unsupported activation: {!r} (expected one of {})'.format(
                activation, sorted(self._ACTIVATIONS)))
        self.in_planes = 64
        self.num_classes = num_classes

        self.activation = activation
        self.softplus_beta = softplus_beta

        self.num_sample = num_sample
        self.aug_weight = aug_weight
        # private copy: the default list object is shared by every call
        self.mask_center = list(mask_center)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.bn = nn.BatchNorm2d(512 * block.expansion, track_running_stats=True, affine=True)
        self.linear = nn.Linear(512*block.expansion, num_classes)

        self.relu = self._ACTIVATIONS[activation](softplus_beta)
        print(activation)
        print('Use activation of ' + activation)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Stack ``num_blocks`` blocks; only the first one uses ``stride``.'''
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride,
                activation=self.activation, softplus_beta=self.softplus_beta))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def create_mask(self, w, h, center_x, center_y, alpha=10.0):
        '''Sample a random boolean attention mask around (center_x, center_y).

        ``center_x`` indexes columns (along ``w``), ``center_y`` rows (along
        ``h``). Each pixel gets a weight ``m`` that decreases with its distance
        to the centre; for ``alpha > 0``, ``m`` lies in (0, 1]. The pixel is
        kept when ``m + U(0, 1) > 0.9``, i.e. with probability
        ``min(1, m + 0.1)``.

        Returns:
            bool tensor of shape ``(1, h, w)``.
        '''
        cols = torch.arange(w).view(1, -1).repeat(h, 1)
        rows = torch.arange(h).view(-1, 1).repeat(1, w)
        dist = ((cols - center_x)**2 + (rows - center_y)**2).float().sqrt()
        # non-linear: larger weight near the centre, normalised to a max of 1
        weight = (dist.max() - dist + alpha) ** 0.3
        weight = weight / weight.max()
        # sampling. The result stays boolean: the former bare `mask.float()`
        # discarded its return value and had no effect.
        mask = (weight + weight.clone().uniform_(0, 1)) > 0.9
        return mask.unsqueeze(0)

    def create_mask_candidate1(self, w, h, center_x, center_y, alpha=10.0):
        '''Alternative mask with a linear fall-off (``1 - d / 120``); ``alpha`` is unused.

        Returns a bool tensor of shape ``(1, h, w)``.
        '''
        widths = torch.arange(w).view(1, -1).repeat(h,1)
        heights = torch.arange(h).view(-1, 1).repeat(1,w)
        mask = ((widths - center_x)**2 + (heights - center_y)**2).float().sqrt()
        # non-linear
        mask = 1.0 - mask / 120
        # sampling (boolean result)
        mask = (mask + mask.clone().uniform_(0, 1)) > 0.9
        return mask.unsqueeze(0)

    def create_mask_candidate2(self, w, h, center_x, center_y, alpha=10.0):
        '''Alternative mask with a ``1 / sqrt(d)`` fall-off; ``alpha`` is unused.

        The weight is infinite at the centre, so the centre pixel is always
        kept. Returns a bool tensor of shape ``(1, h, w)``.
        '''
        widths = torch.arange(w).view(1, -1).repeat(h,1)
        heights = torch.arange(h).view(-1, 1).repeat(1,w)
        mask = ((widths - center_x)**2 + (heights - center_y)**2).float().sqrt()
        # non-linear
        mask = 2.5 / (0.6 * mask**0.5)
        # sampling (boolean result)
        mask = (mask + mask.clone().uniform_(0, 1)) > 0.9
        return mask.unsqueeze(0)

    def ciiv_forward(self, x, loop):
        '''Run ``loop`` masked/sampled forward passes and combine them.

        Args:
            x: normalised input batch of shape ``(b, c, h, w)``.
            loop: number of passes. Each pass averages ``self.num_sample``
                random samples that share one mask centre.

        Returns:
            ``(final_pred, z_scores, features, outputs)``:
            - ``final_pred``: mean over passes of ``logits / (z + 1e-9)``.
            - ``z_scores``: one ``(b, 1)`` tensor per pass, holding the mean
              mask density.
            - ``features``: the pooled features of each pass.
            - ``outputs``: the raw logits of each pass.

        Raises:
            ValueError: unless ``1 <= loop <= len(self.mask_center) ** 2``
                (the number of distinct mask centres).
        '''
        # NCHW: dim 2 is the height (rows), dim 3 the width (columns). These
        # used to be unpacked the other way round, which produced (w, h)-shaped
        # masks and broke broadcasting for non-square inputs.
        b, c, h, w = x.shape
        num_inner = self.num_sample
        num_centers = len(self.mask_center)
        if not 1 <= loop <= num_centers ** 2:
            raise ValueError(
                'loop must be between 1 and {} (len(mask_center) ** 2), got {}'.format(
                    num_centers ** 2, loop))

        # generate all samples
        samples = []
        masks = []
        for i in range(loop * num_inner):
            # differentiable sampling: add uniform noise, rectify, then map
            # positive values close to 1 via s / (s + 1e-5)
            sample = self.relu(x + x.detach().clone().uniform_(-1,1) * self.aug_weight)
            sample = sample / (sample + 1e-5)
            if i % num_inner == 0:
                # each group of num_inner consecutive samples shares one centre
                idx = i // num_inner
                center_x = self.mask_center[idx // num_centers]
                center_y = self.mask_center[idx % num_centers]
            # attention: a fresh random mask for every sample
            mask = self.create_mask(w, h, center_x, center_y, alpha=10.0).to(x.device)
            samples.append(sample * mask)
            masks.append(mask)

        # run network
        outputs = []
        features = []
        z_scores = []
        for i in range(loop):
            group = slice(num_inner * i, num_inner * (i + 1))
            # average the group's masked samples into one network input
            inputs = sum(samples[group]) / num_inner
            # mean fraction of pixels kept by the group's masks (scalar)
            z_score = (sum(masks[group]).float() / num_inner).mean()
            # forward modules
            out = self.conv1(inputs)
            out = self.layer1(out)
            out = self.layer2(out)
            out = self.layer3(out)
            out = self.layer4(out)
            out = self.relu(self.bn(out))
            out = F.avg_pool2d(out, out.shape[-1])
            feats = out.view(out.size(0), -1)
            preds = self.linear(feats)
            z_scores.append(z_score.view(1,1).repeat(b, 1))
            features.append(feats)
            outputs.append(preds)

        # re-weight each pass by its mask density before averaging
        final_pred = sum(pred / (z + 1e-9) for pred, z in zip(outputs, z_scores)) / loop

        ## Randomized Smoothing Inference
        #if self.training or self.is_attack():
        #    final_pred = sum([pred / (z + 1e-9) for pred, z in zip(outputs, z_scores)]) / NUM_LOOP
        #else:
        #    counts = []
        #    for item in outputs:
        #        pred = item.max(-1)[1]
        #        counts.append(F.one_hot(pred, self.num_classes))
        #    final_pred = sum(counts)
        return final_pred, z_scores, features, outputs

    def forward(self, x, loop=1):
        '''Normalise ``x`` and run ``ciiv_forward``.

        Single-channel inputs are repeated to 3 channels. In training mode the
        full ``ciiv_forward`` tuple is returned; otherwise only ``final_pred``.
        '''
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        # ``config`` may be attached by the training script. The rgb_norm
        # defined in this notebook ignores it, so a missing attribute falls
        # back to None instead of raising AttributeError.
        x = rgb_norm(x, getattr(self, 'config', None))

        if self.training:
            return self.ciiv_forward(x, loop=loop)
        else:
            return self.ciiv_forward(x, loop=loop)[0]


def PreActResNet18(num_classes=10, activation='ReLU', softplus_beta=1, **kwargs):
    """Build a PreAct-ResNet-18: ``PreActBlock`` with stage depths 2-2-2-2.

    Any extra keyword arguments are forwarded unchanged to ``PreActResNet``.
    """
    stage_depths = [2, 2, 2, 2]
    return PreActResNet(
        PreActBlock,
        stage_depths,
        num_classes=num_classes,
        activation=activation,
        softplus_beta=softplus_beta,
        **kwargs
    )

def PreActResNet34(num_classes=10, **kwargs):
    """Build a PreAct-ResNet-34: ``PreActBlock`` with stage depths 3-4-6-3.

    ``num_classes`` defaults to 10, matching ``PreActResNet18``; all other
    keyword arguments are forwarded to ``PreActResNet``.
    """
    return PreActResNet(PreActBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)

def PreActResNet50(num_classes=10, **kwargs):
    """Build a PreAct-ResNet-50: ``PreActBottleneck`` with stage depths 3-4-6-3.

    ``num_classes`` defaults to 10, matching ``PreActResNet18``; all other
    keyword arguments are forwarded to ``PreActResNet``.
    """
    return PreActResNet(PreActBottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)


def create_model(m_type='resnet18', num_classes=1000, num_sample=3, aug_weight=0.9, mask_center=[5, 16, 27]):
    """Instantiate a CiiV PreAct-ResNet by name.

    Args:
        m_type: 'resnet18', 'resnet34' or 'resnet50'.
        num_classes: number of output logits.
        num_sample: random samples averaged per inference loop.
        aug_weight: scale of the sampling noise.
        mask_center: candidate mask-centre coordinates (copied, so the shared
            default list is never handed to a model).

    Raises:
        ValueError: if ``m_type`` is not one of the names above.
    """
    builders = {
        'resnet18': PreActResNet18,
        'resnet34': PreActResNet34,
        'resnet50': PreActResNet50,
    }
    if m_type not in builders:
        raise ValueError('Wrong Model Type')
    return builders[m_type](num_classes=num_classes, num_sample=num_sample,
                            aug_weight=aug_weight, mask_center=list(mask_center))

In [None]:
import torch
import importlib
import random
import attacker

def rand_adv_init(config):
    """Randomly decide whether an attack should start from a random point.

    Returns True with probability ``config['attacker_opt']['attack_rand_ini']``.
    """
    draw = random.uniform(0, 1)
    return bool(draw < config['attacker_opt']['attack_rand_ini'])


def create_adversarial_attacker(config, model, logger):
    """Instantiate the attacker named by ``config['attacker_opt']['attack_type']``.

    The attacker class from the ``attacker`` package is called as
    ``cls(model, logger, config, **kwargs)``. The keyword arguments are read
    from ``config['attacker_opt']``, and only the keys the selected attack
    needs are read.

    Raises:
        ValueError: for an unknown attack type. ``logger.raise_error`` is
            called first; if it returns instead of raising, a ValueError is
            raised so callers never silently receive ``None``.
    """
    opt = config['attacker_opt']
    pgd_args = {'eps': 'attack_eps', 'alpha': 'attack_alpha',
                'steps': 'attack_step', 'eot_iter': 'eot_iter'}
    # attack_type -> (class name in `attacker`, {kwarg name: attacker_opt key})
    specs = {
        'PGD': ('PGD', pgd_args),
        'PGDL2': ('PGDL2', pgd_args),
        'AutoAttack': ('AA', {'eps': 'attack_eps', 'norm': 'attack_norm'}),
        'FGSM': ('FGSM', {'eps': 'attack_eps', 'eot_iter': 'eot_iter'}),
        'FFGSM': ('FFGSM', {'eps': 'attack_eps', 'alpha': 'attack_alpha', 'eot_iter': 'eot_iter'}),
        'GN': ('GN', {'sigma': 'gn_sigma', 'eps': 'attack_eps'}),
        'UN': ('UN', {'sigma': 'un_sigma', 'eps': 'attack_eps'}),
        'BPDAPGD': ('BPDAPGD', pgd_args),
        'EOT': ('EOT', {'eps': 'attack_eps', 'learning_rate': 'attack_lr',
                        'steps': 'attack_step', 'eot_iter': 'eot_iter'}),
        'CW': ('CW', {'c': 'c', 'kappa': 'kappa', 'steps': 'steps', 'lr': 'lr'}),
        'BFS': ('BFS', {'sigma': 'sigma', 'eps': 'attack_eps', 'steps': 'steps'}),
        'SPSA': ('SPSA', {'eps': 'attack_eps', 'delta': 'delta', 'batch_size': 'batch_size',
                          'steps': 'steps', 'lr': 'lr'}),
    }

    attack_type = opt['attack_type']
    if attack_type not in specs:
        logger.raise_error('Wrong Attacker Type')
        raise ValueError('Wrong Attacker Type: {}'.format(attack_type))

    cls_name, arg_keys = specs[attack_type]
    kwargs = {arg: opt[key] for arg, key in arg_keys.items()}
    return getattr(attacker, cls_name)(model, logger, config, **kwargs)



def get_adv_target(target_type, model, inputs, gt_label):
    """Pick, for each sample, a target label different from ``gt_label``.

    Args:
        target_type: 'random' (a random class other than the ground truth),
            'most' (the most probable wrong class) or 'least' (the least
            probable class other than the ground truth).
        model: callable returning logits of shape (num_batch, num_class), or
            a tuple whose first element is those logits.
        inputs: model input batch.
        gt_label: ground-truth labels of shape (num_batch,).

    Returns:
        Tensor of target labels of shape (num_batch,).

    Raises:
        ValueError: for an unknown ``target_type``.
    """
    with torch.no_grad():
        logits = model(inputs)
    # PreActResNet returns a tuple in training mode; run_val handles the same case
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.softmax(-1)
    num_batch, num_class = preds.shape

    if target_type == 'random':
        adv_targets = torch.randint(0, num_class, (num_batch,)).to(gt_label.device)
        # validation check: move targets that collided with the ground truth
        adv_targets = adv_target_update(gt_label, adv_targets, num_batch, num_class)
    elif target_type == 'most':
        idxs = torch.arange(num_batch).to(inputs.device)
        # probabilities are >= 0, so -1 removes the true class from the max
        preds[idxs, gt_label] = -1
        adv_targets = preds.max(-1)[1]
    elif target_type == 'least':
        idxs = torch.arange(num_batch).to(inputs.device)
        # probabilities are <= 1, so 100 removes the true class from the min
        preds[idxs, gt_label] = 100.0
        adv_targets = preds.min(-1)[1]
    else:
        raise ValueError('Wrong Targeted Attack Type')

    assert (adv_targets == gt_label).long().sum().item() == 0
    return adv_targets

def adv_target_update(gt_label, adv_target, num_batch, num_class):
    """Move every target that equals its ground-truth label to another class.

    ``adv_target`` is modified in place and also returned. A colliding entry
    is shifted by a random offset in ``[1, num_class - 1]`` (modulo
    ``num_class``), which can never map it back onto the ground truth.
    """
    for idx in range(num_batch):
        current = int(adv_target[idx])
        if current != int(gt_label[idx]):
            continue
        offset = random.randint(1, num_class - 1)
        adv_target[idx] = (current + offset) % num_class
    return adv_target

In [None]:
# Names this cell uses were never imported or defined (transforms, torchvision,
# device, cudnn, optim, lr), so it raised NameError on a fresh kernel.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 0.1          # initial SGD learning rate; tune as needed
NUM_CLASSES = 10  # CIFAR-10

print('==> Preparing data..')
# No transforms.Normalize: PreActResNet.forward already standardises its input
# with rgb_norm, so normalising here as well would standardise the images twice.
# NOTE(review): the attackers presumably also expect raw [0, 1] pixels -- confirm.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Model
print('==> Building model..')
# Use the CiiV PreAct-ResNet defined above. test_ciiv calls model(x, loop=...),
# which torchvision's ImageNet resnet18 (1000 outputs) does not accept.
net = create_model('resnet18', num_classes=NUM_CLASSES)
net.config = None  # read by PreActResNet.forward; rgb_norm ignores it
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


In [None]:
import torch
import importlib
import random
import attacker

def rand_adv_init(config):
    """Randomly decide whether an attack should start from a random point.

    Returns True with probability ``config['attacker_opt']['attack_rand_ini']``.
    """
    draw = random.uniform(0, 1)
    return bool(draw < config['attacker_opt']['attack_rand_ini'])


def create_adversarial_attacker(config, model, logger):
    """Instantiate the attacker named by ``config['attacker_opt']['attack_type']``.

    The attacker class from the ``attacker`` package is called as
    ``cls(model, logger, config, **kwargs)``. The keyword arguments are read
    from ``config['attacker_opt']``, and only the keys the selected attack
    needs are read.

    Raises:
        ValueError: for an unknown attack type. ``logger.raise_error`` is
            called first; if it returns instead of raising, a ValueError is
            raised so callers never silently receive ``None``.
    """
    opt = config['attacker_opt']
    pgd_args = {'eps': 'attack_eps', 'alpha': 'attack_alpha',
                'steps': 'attack_step', 'eot_iter': 'eot_iter'}
    # attack_type -> (class name in `attacker`, {kwarg name: attacker_opt key})
    specs = {
        'PGD': ('PGD', pgd_args),
        'PGDL2': ('PGDL2', pgd_args),
        'AutoAttack': ('AA', {'eps': 'attack_eps', 'norm': 'attack_norm'}),
        'FGSM': ('FGSM', {'eps': 'attack_eps', 'eot_iter': 'eot_iter'}),
        'FFGSM': ('FFGSM', {'eps': 'attack_eps', 'alpha': 'attack_alpha', 'eot_iter': 'eot_iter'}),
        'GN': ('GN', {'sigma': 'gn_sigma', 'eps': 'attack_eps'}),
        'UN': ('UN', {'sigma': 'un_sigma', 'eps': 'attack_eps'}),
        'BPDAPGD': ('BPDAPGD', pgd_args),
        'EOT': ('EOT', {'eps': 'attack_eps', 'learning_rate': 'attack_lr',
                        'steps': 'attack_step', 'eot_iter': 'eot_iter'}),
        'CW': ('CW', {'c': 'c', 'kappa': 'kappa', 'steps': 'steps', 'lr': 'lr'}),
        'BFS': ('BFS', {'sigma': 'sigma', 'eps': 'attack_eps', 'steps': 'steps'}),
        'SPSA': ('SPSA', {'eps': 'attack_eps', 'delta': 'delta', 'batch_size': 'batch_size',
                          'steps': 'steps', 'lr': 'lr'}),
    }

    attack_type = opt['attack_type']
    if attack_type not in specs:
        logger.raise_error('Wrong Attacker Type')
        raise ValueError('Wrong Attacker Type: {}'.format(attack_type))

    cls_name, arg_keys = specs[attack_type]
    kwargs = {arg: opt[key] for arg, key in arg_keys.items()}
    return getattr(attacker, cls_name)(model, logger, config, **kwargs)



def get_adv_target(target_type, model, inputs, gt_label):
    """Pick, for each sample, a target label different from ``gt_label``.

    Args:
        target_type: 'random' (a random class other than the ground truth),
            'most' (the most probable wrong class) or 'least' (the least
            probable class other than the ground truth).
        model: callable returning logits of shape (num_batch, num_class), or
            a tuple whose first element is those logits.
        inputs: model input batch.
        gt_label: ground-truth labels of shape (num_batch,).

    Returns:
        Tensor of target labels of shape (num_batch,).

    Raises:
        ValueError: for an unknown ``target_type``.
    """
    with torch.no_grad():
        logits = model(inputs)
    # PreActResNet returns a tuple in training mode; run_val handles the same case
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.softmax(-1)
    num_batch, num_class = preds.shape

    if target_type == 'random':
        adv_targets = torch.randint(0, num_class, (num_batch,)).to(gt_label.device)
        # validation check: move targets that collided with the ground truth
        adv_targets = adv_target_update(gt_label, adv_targets, num_batch, num_class)
    elif target_type == 'most':
        idxs = torch.arange(num_batch).to(inputs.device)
        # probabilities are >= 0, so -1 removes the true class from the max
        preds[idxs, gt_label] = -1
        adv_targets = preds.max(-1)[1]
    elif target_type == 'least':
        idxs = torch.arange(num_batch).to(inputs.device)
        # probabilities are <= 1, so 100 removes the true class from the min
        preds[idxs, gt_label] = 100.0
        adv_targets = preds.min(-1)[1]
    else:
        raise ValueError('Wrong Targeted Attack Type')

    assert (adv_targets == gt_label).long().sum().item() == 0
    return adv_targets

def adv_target_update(gt_label, adv_target, num_batch, num_class):
    """Move every target that equals its ground-truth label to another class.

    ``adv_target`` is modified in place and also returned. A colliding entry
    is shifted by a random offset in ``[1, num_class - 1]`` (modulo
    ``num_class``), which can never map it back onto the ground truth.
    """
    for idx in range(num_batch):
        current = int(adv_target[idx])
        if current != int(gt_label[idx]):
            continue
        offset = random.randint(1, num_class - 1)
        adv_target[idx] = (current + offset) % num_class
    return adv_target

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import time
import random
import attacker

# from utils.attack_utils import *

class test_ciiv():
    """Evaluate a CiiV model on clean or adversarial inputs.

    Args:
        args: unused; kept for call-site compatibility.
        config: nested dict. Reads ``config['inst_sample']`` (``w_ce``,
            ``w_reg``, ``mul_ru``, ``num_loop``), ``config['attacker_opt']``
            (``adv_val`` plus attacker settings) and ``config['test_opt']``
            (``save_data``, ``save_length``, ``file_name``). Also reads
            ``config['output_dir']`` when saving, and
            ``config['targeted_attack']`` / ``config['targeted_type']`` for
            adversarial runs.
        logger: object with an ``info(str)`` method; also passed to the attacker.
        model: network called as ``model(x, loop=num_loop)``. It may return a
            logits tensor or a tuple whose first item is the logits.
        val: if True the phase is named 'val', otherwise 'test'.
        loader: optional iterable of batches. Defaults to the global
            ``testloader``, the only loader this notebook defines, so it is
            used for both phases.
    """
    def __init__(self, args, config, logger, model, val=False, loader=None):
        self.config = config
        self.logger = logger
        self.model = model

        # init inst setting
        self.init_inst_sample()

        # initialize attacker
        self.adv_test = self.config['attacker_opt']['adv_val']
        if self.adv_test:
            self.attacker = create_adversarial_attacker(config, model, logger)

        # save test
        self.test_save = bool(config['test_opt']['save_data'])

        # get dataloader
        self.phase = 'val' if val else 'test'
        self.loader = testloader if loader is None else loader

    def init_inst_sample(self):
        """Read the instrumental-sampling settings from ``config['inst_sample']``."""
        self.logger.info('=====> Init Instrumental Sampling')
        self.w_ce = self.config['inst_sample']['w_ce']
        self.w_reg = self.config['inst_sample']['w_reg']
        self.mul_ru = self.config['inst_sample']['mul_ru']
        self.num_loop = self.config['inst_sample']['num_loop']

    def _device(self):
        """Device of the model's parameters (falls back to CUDA if available, else CPU)."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def run_val(self, epoch):
        """Run one evaluation pass over ``self.loader`` and return top-1 accuracy.

        Batches may be ``(inputs, labels)`` (torchvision datasets) or
        ``(inputs, labels, indexes, ...)``; only the first two items are used.
        The model's previous train/eval mode is restored afterwards, even if
        an exception is raised.

        Raises:
            RuntimeError: if the loader yields no batches.
        """
        self.logger.info('------------- Start Validation at Epoch {} -----------'.format(epoch))
        total_acc = []
        device = self._device()

        # save test
        org_list, adv_list, gt_list, pred_list = [], [], [], []

        # remember the caller's mode; previously the model was always left in
        # train mode, even when it had been in eval mode before
        was_training = getattr(self.model, 'training', True)
        self.model.eval()
        try:
            for i, batch in enumerate(self.loader):
                # the former 3-way unpack failed on (inputs, labels) loaders
                inputs, labels = batch[0].to(device), batch[1].to(device)

                # trigger adversarial attack or not
                if self.adv_test:
                    if self.config['targeted_attack']:
                        adv_targets = get_adv_target(self.config['targeted_type'], self.model, inputs, labels)
                        adv_inputs = self.attacker.get_adv_images(inputs, adv_targets, random_start=rand_adv_init(self.config), targeted=True)
                    else:
                        adv_inputs = self.attacker.get_adv_images(inputs, labels, random_start=rand_adv_init(self.config))
                    # the attacker may presumably switch the model's mode; evaluate in eval mode
                    self.model.eval()
                    final_inputs = adv_inputs
                else:
                    final_inputs = inputs

                # run model
                with torch.no_grad():
                    predictions = self.model(final_inputs, loop=self.num_loop)
                # training-mode CiiV models return (final_pred, ...)
                if isinstance(predictions, tuple):
                    predictions = predictions[0]

                pred_labels = predictions.max(1)[1]
                total_acc.append((pred_labels == labels).view(-1, 1))

                # save adversarial images
                if self.test_save and i < self.config['test_opt']['save_length']:
                    org_list.append(inputs.cpu())
                    gt_list.append(labels.cpu())
                    pred_list.append(pred_labels.cpu())
                    if self.adv_test:
                        adv_list.append(adv_inputs.cpu())
        finally:
            self.model.train(was_training)

        if not total_acc:
            raise RuntimeError('run_val: the data loader yielded no batches')
        all_acc = torch.cat(total_acc, dim=0).float()
        avg_acc = all_acc.mean().item()
        self.logger.info('Epoch {:5d} Evaluation Complete ==> Total Accuracy : {:9.4f}, Number Samples : {:9d}'.format(epoch, avg_acc, all_acc.shape[0]))

        # save adversarial images
        if self.test_save:
            file_name = os.path.join(self.config['output_dir'], self.config['test_opt']['file_name'])
            adv_output = {
                    'org_images' : torch.cat(org_list, 0),
                    'gt_labels'  : torch.cat(gt_list, 0),
                    'adv_images' : torch.cat(adv_list, 0) if self.adv_test else 0,
                    'pred_labels': torch.cat(pred_list, 0),
                    }
            torch.save(adv_output, file_name)
            self.logger.info('=====> Complete! Adversarial images have been saved to {}'.format(file_name))

        return avg_acc

In [None]:
import torch
import importlib
import random
import attacker

# from data.dataloader import get_loader

def rand_adv_init(config):
    """Randomly decide whether an attack should start from a random point.

    Returns True with probability ``config['attacker_opt']['attack_rand_ini']``.
    """
    draw = random.uniform(0, 1)
    return bool(draw < config['attacker_opt']['attack_rand_ini'])


def create_adversarial_attacker(config, model, logger):
    """
    Instantiate the attacker selected by ``config['attacker_opt']['attack_type']``.

    Each supported attacker is constructed as
    ``attacker.<Type>(model, logger, config, **hyper_params)``; the table below
    lists which ``attacker_opt`` keys feed each type. Builders are lambdas, so
    only the selected attacker's options are read from the config.

    Returns:
        the attacker instance. For an unknown type ``logger.raise_error`` is
        called; if that returns instead of raising, None is returned.
    """
    opt = config['attacker_opt']
    builders = {
        'PGD': lambda: attacker.PGD(model, logger, config,
                                    eps=opt['attack_eps'],
                                    alpha=opt['attack_alpha'],
                                    steps=opt['attack_step'],
                                    eot_iter=opt['eot_iter']),
        'PGDL2': lambda: attacker.PGDL2(model, logger, config,
                                        eps=opt['attack_eps'],
                                        alpha=opt['attack_alpha'],
                                        steps=opt['attack_step'],
                                        eot_iter=opt['eot_iter']),
        'AutoAttack': lambda: attacker.AA(model, logger, config,
                                          eps=opt['attack_eps'],
                                          norm=opt['attack_norm']),
        'FGSM': lambda: attacker.FGSM(model, logger, config,
                                      eps=opt['attack_eps'],
                                      eot_iter=opt['eot_iter']),
        'FFGSM': lambda: attacker.FFGSM(model, logger, config,
                                        eps=opt['attack_eps'],
                                        alpha=opt['attack_alpha'],
                                        eot_iter=opt['eot_iter']),
        'GN': lambda: attacker.GN(model, logger, config,
                                  sigma=opt['gn_sigma'],
                                  eps=opt['attack_eps']),
        'UN': lambda: attacker.UN(model, logger, config,
                                  sigma=opt['un_sigma'],
                                  eps=opt['attack_eps']),
        'BPDAPGD': lambda: attacker.BPDAPGD(model, logger, config,
                                            eps=opt['attack_eps'],
                                            alpha=opt['attack_alpha'],
                                            steps=opt['attack_step'],
                                            eot_iter=opt['eot_iter']),
        'EOT': lambda: attacker.EOT(model, logger, config,
                                    eps=opt['attack_eps'],
                                    learning_rate=opt['attack_lr'],
                                    steps=opt['attack_step'],
                                    eot_iter=opt['eot_iter']),
        'CW': lambda: attacker.CW(model, logger, config,
                                  c=opt['c'],
                                  kappa=opt['kappa'],
                                  steps=opt['steps'],
                                  lr=opt['lr']),
        'BFS': lambda: attacker.BFS(model, logger, config,
                                    sigma=opt['sigma'],
                                    eps=opt['attack_eps'],
                                    steps=opt['steps']),
        'SPSA': lambda: attacker.SPSA(model, logger, config,
                                      eps=opt['attack_eps'],
                                      delta=opt['delta'],
                                      batch_size=opt['batch_size'],
                                      steps=opt['steps'],
                                      lr=opt['lr']),
    }

    build = builders.get(opt['attack_type'])
    if build is None:
        logger.raise_error('Wrong Attacker Type')
        return None
    return build()



def get_adv_target(target_type, model, inputs, gt_label):
    """
    Choose a target class per sample for a targeted attack, never equal to ``gt_label``.

    NOTE(review): get_adv_target / adv_target_update are defined twice in this
    notebook; whichever cell runs last wins.

    Args:
        target_type: 'random' (uniformly drawn class, re-drawn on collision),
            'most' (most probable non-ground-truth class under ``model``) or
            'least' (least probable non-ground-truth class).
        model: classifier called as ``model(inputs)``. If it returns a tuple,
            the first element is taken as the logits.
        inputs: batch fed to ``model``.
        gt_label: ground-truth class index per sample.
    Returns:
        tensor of target class indices, one per sample.
    Raises:
        ValueError: unknown ``target_type``.
    """
    with torch.no_grad():
        outputs = model(inputs)
        # model calls elsewhere in this notebook unpack tuple outputs; take the logits
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        preds = outputs.softmax(-1)
    num_batch, num_class = preds.shape

    if target_type == 'random':
        adv_targets = torch.randint(0, num_class, (num_batch,)).to(gt_label.device)
        # validation check: re-draw any target that hit the ground truth
        adv_targets = adv_target_update(gt_label, adv_targets, num_batch, num_class)
    elif target_type == 'most':
        # index tensor must live on the same device as the tensor it indexes
        idxs = torch.arange(num_batch, device=preds.device)
        # softmax outputs are in [0, 1], so -1 removes the gt class from the argmax
        preds[idxs, gt_label] = -1
        adv_targets = preds.max(-1)[1]
    elif target_type == 'least':
        idxs = torch.arange(num_batch, device=preds.device)
        # softmax outputs are in [0, 1], so 100 removes the gt class from the argmin
        preds[idxs, gt_label] = 100.0
        adv_targets = preds.min(-1)[1]
    else:
        raise ValueError('Wrong Targeted Attack Type')

    assert (adv_targets == gt_label).long().sum().item() == 0
    return adv_targets

def adv_target_update(gt_label, adv_target, num_batch, num_class):
    """
    In place, shift every one of the first ``num_batch`` targets that equals its
    ground-truth label by a random offset in [1, num_class - 1] (mod num_class).

    Returns the same ``adv_target`` object that was passed in.
    """
    for idx in range(num_batch):
        current = int(adv_target[idx])
        if current != int(gt_label[idx]):
            continue
        shift = random.randint(1, num_class - 1)
        adv_target[idx] = (current + shift) % num_class
    return adv_target

In [None]:
import torch
import torch.nn as nn

import random
import attacker

# from test_baseline import test_baseline
# from test_ciiv import test_ciiv

def get_test_func(config):
    """
    Return the tester selected by ``config['strategy']['test_type']``.

    'baseline' maps to ``test_baseline`` and 'ciiv' to ``test_ciiv``; both are
    module-level names (presumably tester classes) that earlier cells must
    define. Only the selected name is looked up.

    Raises:
        ValueError: any other test type.
    """
    strategy = config['strategy']['test_type']
    if strategy == 'baseline':
        return test_baseline
    if strategy == 'ciiv':
        return test_ciiv
    raise ValueError('Wrong Test Strategy')



def get_adv_target(target_type, model, inputs, gt_label):
    """
    Choose a target class per sample for a targeted attack, never equal to ``gt_label``.

    NOTE(review): get_adv_target / adv_target_update are defined twice in this
    notebook; whichever cell runs last wins.

    Args:
        target_type: 'random' (uniformly drawn class, re-drawn on collision),
            'most' (most probable non-ground-truth class under ``model``) or
            'least' (least probable non-ground-truth class).
        model: classifier called as ``model(inputs)``. If it returns a tuple,
            the first element is taken as the logits.
        inputs: batch fed to ``model``.
        gt_label: ground-truth class index per sample.
    Returns:
        tensor of target class indices, one per sample.
    Raises:
        ValueError: unknown ``target_type``.
    """
    with torch.no_grad():
        outputs = model(inputs)
        # model calls elsewhere in this notebook unpack tuple outputs; take the logits
        if isinstance(outputs, tuple):
            outputs = outputs[0]
        preds = outputs.softmax(-1)
    num_batch, num_class = preds.shape

    if target_type == 'random':
        adv_targets = torch.randint(0, num_class, (num_batch,)).to(gt_label.device)
        # validation check: re-draw any target that hit the ground truth
        adv_targets = adv_target_update(gt_label, adv_targets, num_batch, num_class)
    elif target_type == 'most':
        # index tensor must live on the same device as the tensor it indexes
        idxs = torch.arange(num_batch, device=preds.device)
        # softmax outputs are in [0, 1], so -1 removes the gt class from the argmax
        preds[idxs, gt_label] = -1
        adv_targets = preds.max(-1)[1]
    elif target_type == 'least':
        idxs = torch.arange(num_batch, device=preds.device)
        # softmax outputs are in [0, 1], so 100 removes the gt class from the argmin
        preds[idxs, gt_label] = 100.0
        adv_targets = preds.min(-1)[1]
    else:
        raise ValueError('Wrong Targeted Attack Type')

    assert (adv_targets == gt_label).long().sum().item() == 0
    return adv_targets

def adv_target_update(gt_label, adv_target, num_batch, num_class):
    """
    In place, re-draw every one of the first ``num_batch`` targets that equals
    its ground-truth label.

    A colliding target is shifted by a random offset in [1, num_class - 1]
    (mod num_class), so the result can never equal the ground truth.
    Requires ``num_class >= 2`` (with one class no different target exists and
    ``random.randint`` raises ValueError).

    Returns:
        the same ``adv_target`` object, modified in place.
    """
    for i in range(num_batch):
        if int(gt_label[i]) == int(adv_target[i]):
            # the offset must start at 1: an offset of 0 would keep the collision
            # and trip the assertion in get_adv_target
            adv_target[i] = (int(adv_target[i]) + random.randint(1, num_class - 1)) % num_class
    return adv_target

In [None]:
import os
import json
import torch
import numpy as np

class Checkpoint():
    """
    Save per-epoch model checkpoints, track the best one, and restore weights.

    Checkpoints are written to ``config['output_dir']`` as
    ``epoch_<N>_<checkpoint_opt.checkpoint_name>``. The best checkpoint is
    recorded in a text file ``best_checkpoint`` in the same directory, holding
    one line ``<model_path> <epoch> <accuracy>``.
    """
    def __init__(self, config):
        self.config = config
        self.save_best = config['checkpoint_opt']['save_best']
        self.best_epoch = -1
        self.best_performance = -1
        self.best_model_path = None
        # True when the model passed to the latest save() call became the new best.
        # Initialised here because save() reads it even on the very first call.
        self.save_current = False


    def save(self, model, epoch, logger, acc):
        """
        Update best-model bookkeeping and write a checkpoint when required.

        A checkpoint is written when ``save_best`` is set and this epoch is the
        new best, when ``epoch`` is a multiple of ``checkpoint_opt.checkpoint_step``,
        or when ``epoch >= training_opt.num_epochs - 1``.

        Args:
            model: module whose ``state_dict()`` is saved.
            epoch: current epoch index.
            logger: object with an ``info(str)`` method.
            acc: validation accuracy, or None when no validation was run
                (the newest model is then treated as the best).
        """
        model_name = 'epoch_{}_'.format(epoch) + self.config['checkpoint_opt']['checkpoint_name']
        model_path = os.path.join(self.config['output_dir'], model_name)
        if acc is not None:
            if float(acc) > self.best_performance:
                self.best_epoch = epoch
                self.best_performance = float(acc)
                self.best_model_path = model_path
                self.save_current = True
                logger.info('Best model is updated at epoch {} with accuracy {:9.3f} (Path: {})'.format(self.best_epoch, self.best_performance, self.best_model_path))
            else:
                self.save_current = False
        else:
            # if acc is None, the newest is always the best
            self.best_epoch = epoch
            self.best_model_path = model_path
            self.save_current = True

        # NOTE(review): with save_best False, best_model_path can name an epoch that
        # is skipped below (not a checkpoint_step multiple), so best_checkpoint may
        # point at a file that was never written.
        best_saving = (self.save_best and self.save_current)
        # only save at certain steps, best epoch or the last epoch
        if (not best_saving) and (epoch % self.config['checkpoint_opt']['checkpoint_step'] != 0) and (epoch < (self.config['training_opt']['num_epochs'] - 1)):
            return

        output = {
            'state_dict': model.state_dict(),
            'epoch': epoch,
        }

        logger.info('Model at epoch {} is saved to {}'.format(epoch, model_path))
        torch.save(output, model_path)
        logger.info('Best model is at epoch {} with accuracy {:9.3f}'.format(self.best_epoch, self.best_performance))
        self.save_best_model(logger)


    def save_best_model(self, logger):
        """
        Write ``best_checkpoint`` (``<path> <epoch> <accuracy>``) into output_dir.

        When no best model has been recorded yet, only a log line is emitted.
        """
        if self.best_model_path is None:
            logger.info('No best model recorded yet; best_checkpoint is not written.')
            return
        logger.info('Best model is at epoch {} with accuracy {:9.3f} (Path: {})'.format(self.best_epoch, self.best_performance, self.best_model_path))
        with open(os.path.join(self.config['output_dir'], 'best_checkpoint'), 'w+') as f:
            f.write(self.best_model_path + ' ' + str(self.best_epoch) + ' ' + str(self.best_performance) + '\n')

    def load(self, model, path, logger):
        """
        Load weights into ``model`` from ``path``.

        ``path`` is either a ``.pth`` checkpoint file or a directory containing a
        ``best_checkpoint`` file as written by ``save_best_model``. Each model key
        is matched directly or with a ``module.`` prefix (e.g. a checkpoint saved
        from a DataParallel-wrapped model). Keys missing from the checkpoint keep
        the model's current values and are logged as warnings.
        """
        if os.path.splitext(path)[1] != '.pth':
            with open(os.path.join(path, 'best_checkpoint')) as f:
                # file objects are not indexable: read the first line explicitly.
                # rsplit drops the trailing "<epoch> <accuracy>" fields while
                # keeping any spaces inside the path itself.
                path = f.readline().strip().rsplit(' ', 2)[0]

        checkpoint = torch.load(path, map_location='cpu')
        logger.info('Loading checkpoint pretrained with epoch {}.'.format(checkpoint['epoch']))
        model_state = checkpoint['state_dict']

        x = model.state_dict()
        for key in x:
            if key in model_state:
                x[key] = model_state[key]
                logger.info('Load {:>50} from checkpoint.'.format(key))
            elif 'module.' + key in model_state:
                x[key] = model_state['module.' + key]
                logger.info('Load {:>50} from checkpoint (rematch with module.).'.format(key))
            else:
                logger.info('WARNING: Key {} is missing in the checkpoint.'.format(key))

        model.load_state_dict(x)




In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F

# import utils.general_utils as utils
# from data.dataloader import get_loader
# from utils.checkpoint_utils import Checkpoint

import time
import math
import random
import attacker
import numpy as np

# from utils.attack_utils import *
# from utils.train_utils import *
# from utils.test_utils import *

class LabelSmoothingLoss(nn.Module):
    """
    Cross-entropy against a label-smoothed target distribution.

    The target class gets probability ``confidence``. The remaining
    ``1 - confidence`` mass is spread uniformly over the other
    ``classes - 1`` classes.
    """
    def __init__(self, classes=10):
        super(LabelSmoothingLoss, self).__init__()
        self.cls = classes

    def forward(self, pred, target, confidence, dim=-1):
        """
        Args:
            pred: unnormalised logits whose class axis is ``dim``.
            target: integer class indices, shaped like ``pred`` without the class axis.
            confidence: probability assigned to the target class.
            dim: class axis of ``pred`` (default: last).
        Returns:
            scalar loss: per-position cross-entropy averaged over all positions.
        """
        smoothing = 1.0 - confidence
        log_probs = pred.log_softmax(dim=dim)
        with torch.no_grad():
            # max(..., 1) guards classes == 1 (no other class to receive the mass)
            off_value = smoothing / max(self.cls - 1, 1)
            true_dist = torch.full_like(log_probs, off_value)
            # scatter along the same class axis used by log_softmax / sum
            # (it was hard-coded to 1, which breaks any dim that is not the second axis)
            true_dist.scatter_(dim, target.unsqueeze(dim), confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=dim))

class train_ciiv():
    """
    CiiV training loop with instrumental sampling.

    Each step calls ``model(inputs, loop=num_loop)``, which returns
    ``(preds, z_scores, features, logits)``. ``z_scores``, ``features`` and
    ``logits`` are indexed in parallel, presumably one entry per instrumental
    sample. The loss is

        w_ce  * mean_i CE(logits[i], labels)
      + w_reg * mean_i Reg(features[i] * mean_{j!=i} z[j], mean_{j!=i} features[j] * z[i])

    where Reg is a summed L2 or smooth-L1 distance (``inst_sample.reg_loss``).
    ``w_reg`` is multiplied by ``mul_ru`` after each epoch listed in
    ``inst_sample.milestones``.
    """
    def __init__(self, args, config, logger, model, eval=False):
        self.config = config
        self.logger = logger
        self.model = model
        self.training_opt = config['training_opt']
        self.checkpoint = Checkpoint(config)
        self.logger.info('============= Training Strategy: CiiV Training ============')

        # get dataloader
        self.logger.info('=====> Get train dataloader')
        # NOTE(review): `trainloader` is a notebook global (get_loader is commented
        # out); presumably defined in an earlier cell — confirm before Run All.
        self.train_loader = trainloader  # get_loader(config, 'train', logger)

        # init inst setting
        self.init_inst_sample()

        # create optimizer
        self.create_optimizer()

        # create scheduler
        self.create_scheduler()

        # create loss
        self.creat_loss()

        # adversarial train
        self.adv_train = self.config['attacker_opt']['adv_train']
        if self.adv_train:
            self.attacker = create_adversarial_attacker(config, model, logger)

        # validation runner; stays None without eval so run() can skip validation
        self.testing = None
        if eval:
            # choosing test strategy
            test_func = get_test_func(config)
            # start testing
            self.testing = test_func(args, config, logger, model, val=True)

    def init_inst_sample(self):
        """Read the instrumental-sampling hyper-parameters from ``config['inst_sample']``."""
        self.logger.info('=====> Init Instrumental Sampling')
        self.w_ce = self.config['inst_sample']['w_ce']
        self.w_reg = self.config['inst_sample']['w_reg']
        self.mul_ru = self.config['inst_sample']['mul_ru']
        self.num_loop = self.config['inst_sample']['num_loop']

    def creat_loss(self):
        """Use label-smoothed CE when ``inst_sample.ce_smooth`` is set, plain CE otherwise."""
        if self.config['inst_sample']['ce_smooth']:
            self.criterion = LabelSmoothingLoss(classes=self.config['networks']['params']['num_classes'])
        else:
            self.criterion = nn.CrossEntropyLoss()

    def create_optimizer(self):
        """
        Build ``self.optimizer`` (Adam or SGD) from ``training_opt.optim_params``.

        Raises:
            ValueError: unknown ``training_opt.optimizer`` (previously only logged,
                which then failed with AttributeError in create_scheduler).
        """
        self.logger.info('=====> Create optimizer')
        optim_params = self.training_opt['optim_params']
        optim_params_dict = {'params': self.model.parameters(),
                            'lr': optim_params['lr'],
                            'momentum': optim_params['momentum'],
                            'weight_decay': optim_params['weight_decay']
                            }

        if self.training_opt['optimizer'] == 'Adam':
            self.optimizer = optim.Adam([optim_params_dict, ])
        elif self.training_opt['optimizer'] == 'SGD':
            self.optimizer = optim.SGD([optim_params_dict, ])
        else:
            self.logger.info('********** ERROR: unidentified optimizer **********')
            raise ValueError('Unidentified optimizer: {}'.format(self.training_opt['optimizer']))


    def create_scheduler(self):
        """
        Build ``self.scheduler`` (cosine annealing or multi-step) over ``self.optimizer``.

        Raises:
            ValueError: unknown ``training_opt.scheduler``.
        """
        self.logger.info('=====> Create Scheduler')
        scheduler_params = self.training_opt['scheduler_params']
        if self.training_opt['scheduler'] == 'cosine':
            self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, self.training_opt['num_epochs'], eta_min=scheduler_params['endlr'])
        elif self.training_opt['scheduler'] == 'step':
            self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, gamma=scheduler_params['gamma'], milestones=scheduler_params['milestones'])
        else:
            self.logger.info('********** ERROR: unidentified scheduler **********')
            raise ValueError('Unidentified scheduler: {}'.format(self.training_opt['scheduler']))

    def l2_loss(self, x, y):
        """Squared L2 distance summed over dim 1, averaged over dim 0."""
        diff = x - y
        diff = diff*diff
        diff = diff.sum(1)
        diff = diff.mean(0)
        return diff

    def smooth_l1_loss(self, x, y):
        """Element-wise smooth-L1 summed over dim 1, averaged over dim 0."""
        diff = F.smooth_l1_loss(x, y, reduction='none')
        diff = diff.sum(1)
        diff = diff.mean(0)
        return diff

    def get_mean_wo_i(self, inputs, i):
        """Mean of ``inputs`` excluding element ``i`` (requires len(inputs) >= 2)."""
        return (sum(inputs) - inputs[i]) / float(len(inputs) - 1)

    def run(self):
        """Train for ``training_opt.num_epochs`` epochs, validating (if enabled) and checkpointing each epoch."""
        # Start Training
        self.logger.info('=====> Start Naive Training')

        # run epoch
        for epoch in range(self.training_opt['num_epochs']):
            self.logger.info('------------ Start Epoch {} -----------'.format(epoch))
            self.model.train()

            # run batch
            total_batch = len(self.train_loader)
            for step, (inputs, labels, indexes) in enumerate(self.train_loader):
                # naive training
                inputs, labels = inputs.cuda(), labels.cuda()
                if self.adv_train:
                    final_inputs = self.attacker.get_adv_images(inputs, labels)
                else:
                    final_inputs = inputs

                # instrumental sampling training by running all samples parallelly
                iter_info_print = {}
                all_ces = []
                all_regs = []
                preds, z_scores, features, logits = self.model(final_inputs, loop=self.num_loop)
                for i, logit in enumerate(logits):
                    if self.config['inst_sample']['ce_smooth']:
                        ce_loss = self.criterion(logit, labels, confidence=float(z_scores[i]))
                    else:
                        ce_loss = self.criterion(logit, labels)
                    iter_info_print['ce_loss_{}'.format(i)] = ce_loss.sum().item()
                    all_ces.append(ce_loss)

                for i in range(len(features)):
                    if self.config['inst_sample']['reg_loss'] == 'L2':
                        reg_loss = self.l2_loss(features[i] * self.get_mean_wo_i(z_scores, i), self.get_mean_wo_i(features, i) * z_scores[i])
                        iter_info_print['ciiv_l2loss_{}'.format(i)] = reg_loss.sum().item()
                    elif self.config['inst_sample']['reg_loss'] == 'L1':
                        reg_loss = self.smooth_l1_loss(features[i] * self.get_mean_wo_i(z_scores, i), self.get_mean_wo_i(features, i) * z_scores[i])
                        iter_info_print['ciiv_l1loss_{}'.format(i)] = reg_loss.sum().item()
                    else:
                        raise ValueError('Wrong Reg Loss Type')
                    all_regs.append(reg_loss)

                loss = self.w_ce * sum(all_ces) / len(all_ces) + self.w_reg * sum(all_regs) / len(all_regs)
                iter_info_print['w_ce'] = self.w_ce
                iter_info_print['w_reg'] = self.w_reg

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # calculate accuracy
                accuracy = (preds.max(1)[1] == labels).sum().float() / preds.shape[0]

                # log information
                iter_info_print.update( {'Accuracy' : accuracy.item(), 'Loss' : loss.sum().item(), 'Poke LR' : float(self.optimizer.param_groups[0]['lr'])} )
                self.logger.info_iter(epoch, step, total_batch, iter_info_print, self.config['logger_opt']['print_iter'])
                if self.config['logger_opt']['print_grad'] and step % 1000 == 0:
                    # module-level helper from a later notebook cell; the `utils`
                    # module it used to be reached through is never imported
                    print_grad(self.model.named_parameters())

            # evaluation on validation set (only when built with eval=True)
            self.optimizer.zero_grad()
            val_acc = self.testing.run_val(epoch) if self.testing is not None else None

            # update regression loss weight for BPFC or Instrumental Sampling
            if (epoch in self.config['inst_sample']['milestones']):
                self.logger.info('update regression weight from {} to {}'.format(self.w_reg, self.w_reg * self.mul_ru))
                self.w_reg = self.w_reg * self.mul_ru

            # checkpoint
            self.checkpoint.save(self.model, epoch, self.logger, acc=val_acc)
            # update scheduler
            self.scheduler.step()
        # save best model path
        self.checkpoint.save_best_model(self.logger)



In [None]:
import torch
import importlib


def update_attacker_info(config, attack_config, dataset_name, attacker_type, attacker_set):
    """
    Point ``config['attacker_opt']`` at a named attack preset (mutates ``config``).

    Sets ``attack_type`` / ``attack_set`` and then merges in
    ``attack_config[dataset_name][attacker_type][attacker_set]``, whose keys
    override existing ones.
    """
    print('==================== Attacker {} ================='.format(attacker_type))
    attacker_opt = config['attacker_opt']
    attacker_opt['attack_type'] = attacker_type
    attacker_opt['attack_set'] = attacker_set
    preset = attack_config[dataset_name][attacker_type][attacker_set]
    attacker_opt.update(preset)

def RepresentsInt(s):
    """Return True when ``int(s)`` succeeds, False when it raises ValueError."""
    try:
        int(s)
    except ValueError:
        return False
    return True

def RepresentsFloat(s):
    """Return True when ``float(s)`` succeeds, False when it raises ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    return True

def int_to_others(val):
    """
    Convert a command-line string to the most specific Python value.

    'true'/'True' -> True, 'false'/'False' -> False, then int if parseable,
    then float if parseable; anything else is returned unchanged.
    """
    if val in ('true', 'True'):
        return True
    if val in ('false', 'False'):
        return False
    if RepresentsInt(val):
        return int(val)
    if RepresentsFloat(val):
        return float(val)
    return val

def update_config_key(update_dict, key, new_val):
    """
    Set a nested config entry addressed by a dotted ``key`` (e.g. 'a.b.c').

    Intermediate dicts must already exist, and so must the leaf key; otherwise
    KeyError is raised. ``new_val`` is converted with ``int_to_others``.

    Returns:
        the value the leaf held before the update.
    """
    *parents, leaf = key.split('.')
    node = update_dict
    for part in parents:
        node = node[part]
    previous = node[leaf]
    node[leaf] = int_to_others(new_val)
    return previous

def update(config, args, logger):
    """
    Apply command-line overrides from ``args`` onto ``config`` (mutated in place).

    ``args.opts`` is a flat list of ``KEY VALUE`` pairs. Each KEY is a dotted path
    into ``config`` (e.g. 'training_opt.num_epochs') and is applied via
    ``update_config_key``. An odd-length list cannot be paired and is reported
    instead of being applied.

    Returns:
        the same ``config`` object.
    """
    if args.output_dir is not None:
        config['output_dir'] = args.output_dir
        logger.info('======= Update Config: output_dir is set to : ' + str(config['output_dir']))
    if args.train_type is not None:
        config['strategy']['train_type'] = args.train_type
        logger.info('======= Update Config: training type is set to: {}'.format(args.train_type))
    if args.test_type is not None:
        config['strategy']['test_type'] = args.test_type
        logger.info('======= Update Config: test type is set to: {}'.format(args.test_type))
    if args.adv_train:
        config['attacker_opt']['adv_train'] = args.adv_train
    if args.adv_test:
        config['attacker_opt']['adv_val'] = args.adv_test
    if args.adv_type is not None:
        config['attacker_opt']['attack_type'] = args.adv_type
    if args.adv_setting is not None:
        config['attacker_opt']['attack_set'] = args.adv_setting
    if args.rand_aug:
        config['dataset']['rand_aug'] = True
        logger.info('===================> Using Random Augmentation')

    if args.target_type:
        config['targeted_attack'] = True
        config['targeted_type'] = args.target_type
    else:
        config['targeted_attack'] = False

    # update config from command
    opts = args.opts
    if len(opts) % 2 != 0:
        # previously dropped silently; make the ignored overrides visible
        logger.info('WARNING: ignoring command-line opts, expected KEY VALUE pairs but got an odd count: {}'.format(opts))
    elif len(opts) > 0:
        for key, val in zip(opts[0::2], opts[1::2]):
            old_val = update_config_key(config, key, val)
            logger.info('=====> {}: {} (old key) => {} (new key)'.format(key, old_val, val))
    return config

def source_import(file_path):
    """
    Import a Python module directly from a source file path using importlib.

    Raises:
        ImportError: no import loader handles ``file_path`` (e.g. a non-.py suffix).
    """
    # `import importlib` alone does not guarantee the `util` submodule is loaded
    import importlib.util
    spec = importlib.util.spec_from_file_location('', file_path)
    if spec is None or spec.loader is None:
        raise ImportError('Cannot import a module from {}'.format(file_path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

def print_grad(named_parameters):
    """
    Print the L2 norm of every parameter gradient, largest first, plus the overall norm.

    Args:
        named_parameters: iterable of ``(name, parameter)`` pairs, e.g. from
            ``model.named_parameters()``; parameters whose ``.grad`` is None are skipped.
    Returns:
        the global gradient L2 norm (a tensor, or ``0.0`` when no gradient exists).
    """
    squared_sum = 0
    norms = {}
    shapes = {}
    for pname, param in named_parameters:
        if param.grad is None:
            continue
        grad_norm = param.grad.data.norm(2)
        squared_sum += grad_norm ** 2
        norms[pname] = grad_norm
        shapes[pname] = param.size()
    total_norm = squared_sum ** (1. / 2)
    print('---Total norm {:.3f} -----------------'.format(total_norm))
    ranked = sorted(norms.items(), key=lambda item: -item[1])
    for pname, grad_norm in ranked:
        print("{:<50s}: {:.3f}, ({})".format(pname, grad_norm, shapes[pname]))
    print('-------------------------------', flush=True)
    return total_norm

def print_config(config, logger, head=''):
    """
    Log a nested config dict, one ``logger.info`` line per entry.

    Dict values are logged as a header line followed by their contents, each
    nesting level indented by three more spaces; other values are logged as
    ``key : value``.
    """
    for name, value in config.items():
        if not isinstance(value, dict):
            logger.info(head + '{} : {}'.format(str(name), str(value)))
            continue
        logger.info(head + str(name))
        print_config(value, logger, head=head + '   ')

class TriggerAction():
    """
    Named registry of zero-argument callbacks that can be fired together.

    Callbacks are keyed by ``str(name)``; duplicate registration and removal of
    an unknown name fail an assertion.
    """
    def __init__(self, name):
        self.name = name
        self.action = {}

    def add_action(self, name, func):
        """Register ``func`` under ``str(name)``, which must not already be used."""
        key = str(name)
        assert key not in self.action
        self.action[key] = func

    def remove_action(self, name):
        """Unregister the callback stored under ``str(name)``, which must exist."""
        key = str(name)
        assert key in self.action
        self.action.pop(key)
        assert key not in self.action

    def run_all(self, logger=None):
        """Call every registered callback in insertion order, optionally logging each key."""
        for label, callback in self.action.items():
            if logger:
                logger.info('trigger {}'.format(label))
            callback()

In [None]:
# Train with the CiiV strategy; when args.require_eval is set, a validation
# tester is built and run_val is called after each epoch.
# NOTE(review): relies on `args`, `config`, `logger`, `model` (and `trainloader`,
# read inside train_ciiv.__init__) being defined by cells not shown here —
# confirm they exist before Restart & Run All.
training = train_ciiv(args, config, logger, model, eval=args.require_eval)
training.run()

In [None]:
# normal test
# Restore weights (Checkpoint.load takes a .pth file or a directory holding a
# `best_checkpoint` file), then evaluate once on the split chosen by args.phase:
# 'val' builds the tester with val=True, 'test' with val=False.
checkpoint = Checkpoint(config)
checkpoint.load(model, args.load_dir, logger)
# start testing
test_func = get_test_func(config)
if args.phase not in ('val', 'test'):
    raise ValueError('Wrong Phase')
testing = test_func(args, config, logger, model, val=(args.phase == 'val'))
testing.run_val(epoch=-1)