In [1]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import glob
from PIL import Image
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import random
from shutil import copy2
import torch.nn.functional as F
import torch.autograd as autograd

import cv2

from scipy.spatial.distance import cdist

import math
from torch.utils import model_zoo

from copy import deepcopy
import re
from collections import OrderedDict
from torch.autograd import Function
import time
from torch.optim import Adam


import cv2
from skimage import exposure
from skimage import filters
import matplotlib.pyplot as plt
# from google.colab.patches import cv2_imshow

In [2]:
# torch.cuda.set_device(1)
# torch.cuda.current_device()

In [3]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def imagenet_normalize(t, mean=None, std=None):
    """Standardize the first three channels of ``t`` with per-channel statistics.

    Args:
        t: tensor indexed as ``t[:, channel]``; only channels 0, 1 and 2 are read.
        mean: three per-channel means; ``IMAGENET_MEAN`` is used when ``None``.
        std: three per-channel standard deviations; ``IMAGENET_STD`` is used
            when ``None``.

    Returns:
        A new tensor with three channels along dim 1, where channel ``i`` is
        ``(t[:, i] - mean[i]) / std[i]``.
    """
    mean = IMAGENET_MEAN if mean is None else mean
    std = IMAGENET_STD if std is None else std

    # Stacking on dim 1 puts the normalized planes back where the channel axis was.
    normalized = [(t[:, c] - mean[c]) / std[c] for c in range(3)]
    return torch.stack(normalized, dim=1)

# Image -> tensor pipeline: Resize(224), CenterCrop(224), ToTensor.
# No normalization step is included here; `imagenet_normalize` is a separate function.
preprocess = transforms.Compose(
    [transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()]
)

In [4]:
def freeze_model(model):
    """Switch off gradient tracking for every parameter of ``model``, in place.

    Args:
        model: any object exposing ``parameters()`` (e.g. an ``nn.Module``).
    """
    for weight in model.parameters():
        weight.requires_grad_(False)

In [6]:
# Checkpoint URLs for the pretrained constructors defined in this file.
# Every key that a constructor looks up must be present here: densenet121()
# and densenet161() read 'densenet121' / 'densenet161', which were previously
# missing and made `pretrained=True` fail with a KeyError.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


#
# AlexNet | begin
#

# Maps each parameter name of the local AlexNet class (conv1..conv5, fc1..fc3)
# to the name of the matching entry in the downloaded checkpoint
# (features.N / classifier.N). Each layer contributes a weight and a bias.
ALEXNET_NAME_MAP = {
    "%s.%s" % (local_name, param): "%s.%s" % (zoo_name, param)
    for local_name, zoo_name in (
        ("conv1", "features.0"),
        ("conv2", "features.3"),
        ("conv3", "features.6"),
        ("conv4", "features.8"),
        ("conv5", "features.10"),
        ("fc1", "classifier.1"),
        ("fc2", "classifier.4"),
        ("fc3", "classifier.6"),
    )
    for param in ("weight", "bias")
}


class AlexNet(nn.Module):
    """AlexNet with individually named layers and optional intermediate outputs.

    The layers are plain attributes (``conv1`` ... ``conv5``, ``pool1``,
    ``pool2``, ``pool5``, ``fc1`` ... ``fc3``) rather than ``nn.Sequential``
    containers, so their parameter names differ from the downloaded checkpoint;
    ``ALEXNET_NAME_MAP`` translates between the two.

    Args:
        num_classes (int): size of the final linear layer's output.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        # convolutional layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv3 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv5 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # pooling layers
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        self.pool5 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        # fully connected layers; 9216 = 256 channels * 6 * 6, the size of
        # `p5` for a 224x224 input
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x, out_keys=None):
        """Run the network and optionally return intermediate activations.

        Args:
            x: input batch with 3 channels; the spatial size must make ``p5``
                flatten to 9216 values per sample (e.g. 224x224).
            out_keys: ``None`` to get the ``fc3`` output directly, or an
                iterable of keys among ``'c1'``, ``'r1'``, ``'p1'``, ``'r2'``,
                ``'p2'``, ``'r3'``, ``'r4'``, ``'r5'``, ``'p5'``, ``'fc1'``,
                ``'fc2'``, ``'fc3'``.

        Returns:
            The ``fc3`` tensor when ``out_keys`` is ``None``; otherwise a dict
            mapping each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown key.
        """
        out = {}
        out['c1'] = self.conv1(x)
        out['r1'] = F.relu(out['c1'])
        out['p1'] = self.pool1(out['r1'])
        out['r2'] = F.relu(self.conv2(out['p1']))
        out['p2'] = self.pool2(out['r2'])
        out['r3'] = F.relu(self.conv3(out['p2']))
        out['r4'] = F.relu(self.conv4(out['r3']))
        out['r5'] = F.relu(self.conv5(out['r4']))
        out['p5'] = self.pool5(out['r5'])
        # Flatten per sample. The previous `view(1, -1)` merged the whole batch
        # into a single row and therefore only worked for batch size 1.
        flat = out['p5'].view(out['p5'].size(0), -1)
        out['fc1'] = F.relu(self.fc1(flat))
        out['fc2'] = F.relu(self.fc2(out['fc1']))
        out['fc3'] = self.fc3(out['fc2'])

        if out_keys is None:
            return out['fc3']

        return {key: out[key] for key in out_keys}


def convert_alexnet_weights(src_state, dest_state):
    """Copy checkpoint tensors into ``dest_state`` under the local layer names.

    Args:
        src_state: mapping keyed by the checkpoint's parameter names (the
            values of ``ALEXNET_NAME_MAP``).
        dest_state: mapping keyed by the local parameter names; it is updated
            in place.

    Returns:
        ``dest_state`` itself. Keys that are not in ``ALEXNET_NAME_MAP`` are
        left untouched.

    Raises:
        KeyError: if ``src_state`` lacks an entry that the map points to.
    """
    for local_name in dest_state:
        if local_name not in ALEXNET_NAME_MAP:
            continue
        checkpoint_name = ALEXNET_NAME_MAP[local_name]
        # Deep copy so the destination never aliases the source tensors.
        dest_state[local_name] = deepcopy(src_state[checkpoint_name])
    return dest_state


def alexnet(pretrained=False, **kwargs):
    r"""Build the local ``AlexNet`` (architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper).

    Args:
        pretrained (bool): if True, download the checkpoint at
            ``model_urls['alexnet']``, rename its entries with
            ``convert_alexnet_weights`` and load them into the model.
        **kwargs: forwarded to ``AlexNet`` (e.g. ``num_classes``).

    Returns:
        The constructed ``AlexNet`` instance.
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net

    checkpoint = model_zoo.load_url(model_urls['alexnet'])
    renamed = convert_alexnet_weights(checkpoint, net.state_dict())
    net.load_state_dict(renamed)
    return net

#
# AlexNet | end
#

#
# ResNet | begin
#


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)


class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) with in-place ReLU.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """
    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) with in-place ReLU.

    Args:
        inplanes: channels of the block input.
        planes: channels of the two inner convolutions; the block outputs
            ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """ResNet backbone that can also return intermediate activations.

    NOTE: a later cell in this file defines another class named ``ResNet``
    (built on ``SoftReLU``); whichever definition was executed last is the one
    bound to the name.

    Args:
        block: residual block class. It must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: four integers, the number of blocks in ``layer1`` .. ``layer4``.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Running channel count consumed and updated by `_make_layer`.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Adaptive pooling always yields a 1x1 map. For a 7x7 `layer4` output
        # (224x224 input) this equals the former AvgPool2d(7, stride=1); other
        # input sizes no longer break the flatten + fc step.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init based on kernel area * output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change shape.

        A 1x1 conv + batch-norm shortcut is created when the first block
        changes the stride or the channel count. ``self.inplanes`` is updated
        to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, out_keys=None):
        """Run the network and optionally return intermediate activations.

        Args:
            x: input batch with 3 channels.
            out_keys: ``None`` to get the ``fc`` output directly, or an
                iterable of keys among ``'c1'``, ``'bn1'``, ``'r1'``, ``'p1'``,
                ``'l1'``, ``'l2'``, ``'l3'``, ``'l4'``, ``'gvp'``, ``'fc'``.

        Returns:
            The ``fc`` tensor when ``out_keys`` is ``None``; otherwise a dict
            mapping each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown key.
        """
        out = {}
        x = self.conv1(x)
        out["c1"] = x
        x = self.bn1(x)
        out["bn1"] = x
        x = self.relu(x)
        out["r1"] = x
        x = self.maxpool(x)
        out["p1"] = x

        x = self.layer1(x)
        out["l1"] = x
        x = self.layer2(x)
        out["l2"] = x
        x = self.layer3(x)
        out["l3"] = x
        x = self.layer4(x)
        out["l4"] = x

        x = self.avgpool(x)
        out["gvp"] = x
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        out["fc"] = x

        if out_keys is None:
            return x

        return {key: out[key] for key in out_keys}


def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18: ``BasicBlock`` stages of depth [2, 2, 2, 2].

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['resnet18']`` into the model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet18'])
    net.load_state_dict(weights)
    return net


def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34: ``BasicBlock`` stages of depth [3, 4, 6, 3].

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['resnet34']`` into the model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet34'])
    net.load_state_dict(weights)
    return net


def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50: ``Bottleneck`` stages of depth [3, 4, 6, 3].

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['resnet50']`` into the model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(weights)
    return net


def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101: ``Bottleneck`` stages of depth [3, 4, 23, 3].

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['resnet101']`` into the model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet101'])
    net.load_state_dict(weights)
    return net


def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152: ``Bottleneck`` stages of depth [3, 8, 36, 3].

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['resnet152']`` into the model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet152'])
    net.load_state_dict(weights)
    return net


# ResNet | end


# DenseNet | begin

def densenet121(pretrained=False, **kwargs):
    r"""Build DenseNet-121 from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    (64 initial features, growth rate 32, block config (6, 12, 24, 16)).

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['densenet121']``; a missing entry raises ``KeyError``.
        **kwargs: forwarded to ``DenseNet``.
    """
    net = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint was saved by a
    # previous _DenseLayer whose keys look like 'norm.1', 'conv.2', ...
    # The pattern finds those keys so the extra dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    weights = model_zoo.load_url(model_urls['densenet121'])
    for old_key in list(weights.keys()):
        hit = legacy_key.match(old_key)
        if hit is None:
            continue
        weights[hit.group(1) + hit.group(2)] = weights.pop(old_key)
    net.load_state_dict(weights)
    return net


def densenet169(pretrained=False, **kwargs):
    r"""Build DenseNet-169 from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    (64 initial features, growth rate 32, block config (6, 12, 32, 32)).

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['densenet169']``.
        **kwargs: forwarded to ``DenseNet``.
    """
    net = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint was saved by a
    # previous _DenseLayer whose keys look like 'norm.1', 'conv.2', ...
    # The pattern finds those keys so the extra dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    weights = model_zoo.load_url(model_urls['densenet169'])
    for old_key in list(weights.keys()):
        hit = legacy_key.match(old_key)
        if hit is None:
            continue
        weights[hit.group(1) + hit.group(2)] = weights.pop(old_key)
    net.load_state_dict(weights)
    return net


def densenet201(pretrained=False, **kwargs):
    r"""Build DenseNet-201 from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    (64 initial features, growth rate 32, block config (6, 12, 48, 32)).

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['densenet201']``.
        **kwargs: forwarded to ``DenseNet``.
    """
    net = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint was saved by a
    # previous _DenseLayer whose keys look like 'norm.1', 'conv.2', ...
    # The pattern finds those keys so the extra dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    weights = model_zoo.load_url(model_urls['densenet201'])
    for old_key in list(weights.keys()):
        hit = legacy_key.match(old_key)
        if hit is None:
            continue
        weights[hit.group(1) + hit.group(2)] = weights.pop(old_key)
    net.load_state_dict(weights)
    return net


def densenet161(pretrained=False, **kwargs):
    r"""Build DenseNet-161 from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
    (96 initial features, growth rate 48, block config (6, 12, 36, 24)).

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['densenet161']``; a missing entry raises ``KeyError``.
        **kwargs: forwarded to ``DenseNet``.
    """
    net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint was saved by a
    # previous _DenseLayer whose keys look like 'norm.1', 'conv.2', ...
    # The pattern finds those keys so the extra dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    weights = model_zoo.load_url(model_urls['densenet161'])
    for old_key in list(weights.keys()):
        hit = legacy_key.match(old_key)
        if hit is None:
            continue
        weights[hit.group(1) + hit.group(2)] = weights.pop(old_key)
    net.load_state_dict(weights)
    return net


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
            self.add_module('denselayer%d' % (i + 1), layer)


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (sequence of ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock; every block but the last is followed by a
        # transition that halves the channel count.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo. `kaiming_normal_` is the in-place
        # initializer that replaces the deprecated `kaiming_normal`.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, out_keys=None):
        """Run the network and optionally return intermediate activations.

        Args:
            x: input batch with 3 channels.
            out_keys: ``None`` to get the classifier output directly, or an
                iterable of keys among ``'l'`` (ReLU of the feature map),
                ``'gvp'`` (globally pooled features) and ``'fc'``.

        Returns:
            The classifier output when ``out_keys`` is ``None``; otherwise a
            dict mapping each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown key.
        """
        out_dict = {}
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out_dict['l'] = out
        # Global average pooling to 1x1. For a 7x7 feature map this equals the
        # former avg_pool2d(kernel_size=7, stride=1); other sizes now work too.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out_dict['gvp'] = out
        out = out.view(features.size(0), -1)
        out = self.classifier(out)
        out_dict['fc'] = out
        if out_keys is None:
            return out

        return {key: out_dict[key] for key in out_keys}

# DenseNet | end


def get_gaussian_blur_kernel(ksize, sigma):
    """Build a 2-D Gaussian kernel as a float32 tensor with two leading unit dims.

    Args:
        ksize: kernel size handed to ``cv2.getGaussianKernel``.
        sigma: standard deviation handed to ``cv2.getGaussianKernel``.

    Returns:
        ``torch.Tensor`` holding the product of the 1-D kernel with its
        transpose, indexed with ``[None, None]`` so it can serve as a conv2d
        weight. (Presumably shaped ``(1, 1, ksize, ksize)``, assuming OpenCV
        returns the 1-D kernel as a column vector.)
    """
    column = cv2.getGaussianKernel(ksize, sigma).astype(np.float32)
    kernel_2d = column * column.T
    return torch.tensor(kernel_2d[None, None])


def gaussian_blur(x, ksize, sigma):
    """Blur every channel of ``x`` with the same Gaussian kernel.

    The input is reflection-padded by ``int((ksize - 1) / 2)`` on each side
    before filtering.

    Args:
        x: ``torch.Tensor`` of shape (n, c, h, w). Any channel count is
            supported; previously exactly the first three channels were used.
        ksize: int, kernel size.
        sigma: standard deviation passed to ``get_gaussian_blur_kernel``.

    Returns:
        ``torch.Tensor`` with the same number of channels as ``x``, each one
        filtered independently.
    """
    psize = int((ksize - 1) / 2)
    # The kernel is created as a CPU float32 tensor; move it to the input's
    # device and dtype so CUDA or non-float32 inputs can be convolved with it.
    blur_kernel = get_gaussian_blur_kernel(ksize, sigma).to(x)
    x_padded = F.pad(x, [psize] * 4, mode="reflect")
    # `x_padded[:, c, None]` keeps a singleton channel dim so the
    # single-channel kernel applies to one channel at a time.
    blurs = [F.conv2d(x_padded[:, c, None], blur_kernel) for c in range(x.size(1))]
    return torch.cat(blurs, 1)


class GuidedBackpropReLU(Function):
    """ReLU whose backward pass also drops non-positive incoming gradients.

    Forward: elements where ``input > 0`` pass through, the rest become zero.
    Backward: the gradient is kept only where both the forward input and the
    incoming gradient are positive.
    """

    @staticmethod
    def forward(ctx, input):
        keep = (input > 0).type_as(input)
        # zeros + input * keep
        output = input.new_zeros(input.size()).addcmul(input, keep)
        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, _ = ctx.saved_tensors

        input_positive = (input > 0).type_as(grad_output)
        grad_positive = (grad_output > 0).type_as(grad_output)
        # First gate by the forward input's sign, then by the gradient's sign.
        gated_by_input = input.new_zeros(input.size()).addcmul(grad_output, input_positive)
        grad_input = input.new_zeros(input.size()).addcmul(gated_by_input, grad_positive)

        return grad_input


def freeze_model(model):
    """Mark all parameters of ``model`` as not requiring gradients (in place).

    A function with the same name and behavior is also defined earlier in this
    file; this later definition is the one bound after both have run.
    """
    for p in model.parameters():
        p.requires_grad_(False)


### SoftReLU


class SoftReLU(nn.Module):
    """Module wrapper that applies ``SoftReLUFunc`` to its input.

    Args:
        eps: stored as ``self.eps``. It is not passed on to ``SoftReLUFunc``,
            so it has no effect on the forward call made here.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        activation = SoftReLUFunc.apply
        return activation(x)


class SoftReLUFunc(autograd.Function):
    """ReLU in the forward pass with a smoothed gradient in the backward pass.

    Forward returns ``x.clamp(min=0)``. Backward scales the incoming gradient
    by ``x / sqrt(x^2 + 1e-4)`` where ``x >= 0`` and by
    ``x / sqrt(x^2 + 1e-4) + 1`` where ``x < 0``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        is_negative = x < 0
        is_non_negative = x >= 0
        # Every element must fall into exactly one of the two groups; a NaN in
        # x satisfies neither comparison and trips this check.
        assert int(is_negative.sum()) + int(is_non_negative.sum()) == x.numel()

        ratio = x / torch.sqrt(x * x + 1e-4)
        scale = torch.where(is_negative, ratio + 1, ratio)
        return grad_output * scale



In [7]:
# Checkpoint URLs for the pretrained constructors defined in this file.
# This assignment rebinds `model_urls`, so it must contain every key that any
# constructor looks up: densenet121() and densenet161() read 'densenet121' /
# 'densenet161', which were previously missing and made `pretrained=True`
# fail with a KeyError.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Create a 3x3 convolution with padding 1 and no bias term.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3) using ``SoftReLU`` activations.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """
    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = SoftReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.conv1(x)
        y = self.relu(self.bn1(y))
        y = self.bn2(self.conv2(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1) using ``SoftReLU``.

    Args:
        inplanes: channels of the block input.
        planes: channels of the two inner convolutions; the block outputs
            ``planes * 4`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = SoftReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """ResNet backbone with ``SoftReLU`` activations and optional intermediate outputs.

    NOTE: an earlier cell in this file defines a class with the same name that
    uses ``nn.ReLU``; once this cell has run, ``ResNet`` refers to this variant.

    Args:
        block: residual block class. It must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: four integers, the number of blocks in ``layer1`` .. ``layer4``.
        num_classes: size of the final linear layer's output.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        # Running channel count consumed and updated by `_make_layer`.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = SoftReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Adaptive pooling always yields a 1x1 map. For a 7x7 `layer4` output
        # (224x224 input) this equals the former AvgPool2d(7, stride=1); other
        # input sizes no longer break the flatten + fc step.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style normal init based on kernel area * output channels.
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first may change shape.

        A 1x1 conv + batch-norm shortcut is created when the first block
        changes the stride or the channel count. ``self.inplanes`` is updated
        to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, out_keys=None):
        """Run the network and optionally return intermediate activations.

        Args:
            x: input batch with 3 channels.
            out_keys: ``None`` to get the ``fc`` output directly, or an
                iterable of keys among ``'c1'``, ``'bn1'``, ``'r1'``, ``'p1'``,
                ``'l1'``, ``'l2'``, ``'l3'``, ``'l4'``, ``'gvp'``, ``'fc'``.

        Returns:
            The ``fc`` tensor when ``out_keys`` is ``None``; otherwise a dict
            mapping each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown key.
        """
        out = {}
        x = self.conv1(x)
        out["c1"] = x
        x = self.bn1(x)
        out["bn1"] = x
        x = self.relu(x)
        out["r1"] = x
        x = self.maxpool(x)
        out["p1"] = x

        x = self.layer1(x)
        out["l1"] = x
        x = self.layer2(x)
        out["l2"] = x
        x = self.layer3(x)
        out["l3"] = x
        x = self.layer4(x)
        out["l4"] = x

        x = self.avgpool(x)
        out["gvp"] = x
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        out["fc"] = x

        if out_keys is None:
            return x

        return {key: out[key] for key in out_keys}


def resnet50_soft(pretrained=False, **kwargs):
    """Build a ResNet-50 layout (``Bottleneck`` stages of depth [3, 4, 6, 3]).

    ``ResNet`` and ``Bottleneck`` are resolved when the function is called; the
    definitions directly above in this file use ``SoftReLU`` activations.

    Args:
        pretrained (bool): if True, load the checkpoint at
            ``model_urls['resnet50']`` into the model.
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(weights)
    return net


In [8]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class StadvTVLoss(nn.Module):
    """Smoothness penalty on a flow field, based on its four diagonal neighbours.

    For every pixel, the squared difference of flow channels 0 and 1 to each
    diagonal neighbour (replicate-padded at the border) is accumulated. The
    per-sample total is square-rooted and divided by ``sqrt(height * width)``.
    """

    def forward(self, flows):
        """
        Args:
            flows: tensor indexed as (sample, channel, row, col); channels 0
                and 1 are read.

        Returns:
            1-D tensor with one loss value per sample.
        """
        rows, cols = flows.size(2), flows.size(3)
        scale = float(np.sqrt(rows * cols))
        padded = F.pad(flows, (1, 1, 1, 1), mode='replicate')

        lower, upper = slice(2, None), slice(None, -2)
        # Neighbour order: (+1, +1), (+1, -1), (-1, +1), (-1, -1) in (row, col).
        corners = ((lower, lower), (lower, upper), (upper, lower), (upper, upper))

        sq_dists = []
        for row_sel, col_sel in corners:
            neighbour = padded[:, :, row_sel, col_sel]
            sq_dists.append((flows[:, 1] - neighbour[:, 1]) ** 2 +
                            (flows[:, 0] - neighbour[:, 0]) ** 2)

        # (4, n, h, w): reduce rows, then columns, then the four neighbours.
        total = torch.stack(sq_dists).sum(2).sum(2).sum(0)
        return torch.sqrt(total) / scale


class StadvFlowLoss(nn.Module):
    """Sum of distances between each pixel's flow and its four diagonal neighbours.

    Unlike a squared penalty, the square root is taken per pixel and per
    neighbour (with ``epsilon`` added under the root) before summing.
    """

    def forward(self,flows, epsilon=1e-8):
        """
        Args:
            flows: tensor indexed as (sample, channel, row, col); channels 0
                and 1 are read.
            epsilon: constant added under the square root.

        Returns:
            1-D tensor with one loss value per sample.
        """
        padded = F.pad(flows, (1, 1, 1, 1), mode='replicate')

        lower, upper = slice(2, None), slice(None, -2)
        # Neighbour order: (+1, +1), (+1, -1), (-1, +1), (-1, -1) in (row, col).
        corners = ((lower, lower), (lower, upper), (upper, lower), (upper, upper))

        dists = []
        for row_sel, col_sel in corners:
            neighbour = padded[:, :, row_sel, col_sel]
            dists.append(torch.sqrt((flows[:, 1] - neighbour[:, 1]) ** 2 +
                                    (flows[:, 0] - neighbour[:, 0]) ** 2 +
                                    epsilon))

        # shape: (4, n, h, w) => (n, ); reduce rows, columns, then neighbours.
        return torch.stack(dists).sum(2).sum(2).sum(0)


class StadvFlow(nn.Module):
    """Warp images with a per-pixel flow field using bilinear interpolation."""

    def forward(self, images, flows):
        """Sample ``images`` at the pixel grid displaced by ``flows``.

        Args:
            images: tensor of shape (batch, channels, height, width).
            flows: tensor of shape (batch, 2, height, width); channel 0 is the
                row (y) offset and channel 1 the column (x) offset, in pixels.

        Returns:
            Tensor of shape (batch, channels, height, width) holding the
            bilinearly interpolated samples.  Sampling positions are clamped
            to the image bounds.
        """
        batch_size, n_channels, height, width = images.shape
        device = images.device

        # Base pixel coordinates are built by broadcasting instead of
        # torch.meshgrid, which avoids the "indexing" deprecation warning and
        # yields the same (row, column) grid.
        base_y = torch.arange(height, device=device).float().view(1, height, 1)
        base_x = torch.arange(width, device=device).float().view(1, 1, width)

        sampling_grid_x = torch.clamp(base_x + flows[:, 1], 0., float(width) - 1)
        sampling_grid_y = torch.clamp(base_y + flows[:, 0], 0., float(height) - 1)

        # Integer corners surrounding every sampling position.
        x0 = sampling_grid_x.floor().long()
        x1 = x0 + 1
        y0 = sampling_grid_y.floor().long()
        y1 = y0 + 1

        # Keep the lower corner one pixel inside the border so that (x0, x1)
        # and (y0, y1) always address two distinct, valid pixels.
        x0.clamp_(0, width - 2)
        x1.clamp_(0, width - 1)
        y0.clamp_(0, height - 2)
        y1.clamp_(0, height - 1)

        # Batch index lives on the same device as the images being indexed.
        b = torch.arange(batch_size, device=device).view(batch_size, 1, 1).expand(-1, height, width)

        # Advanced indexing returns (batch, height, width, channels); move the
        # channel axis back to position 1.
        Ia = images[b, :, y0, x0].permute(0, 3, 1, 2)
        Ib = images[b, :, y1, x0].permute(0, 3, 1, 2)
        Ic = images[b, :, y0, x1].permute(0, 3, 1, 2)
        Id = images[b, :, y1, x1].permute(0, 3, 1, 2)

        x0 = x0.float()
        x1 = x1.float()
        y0 = y0.float()
        y1 = y1.float()

        # Bilinear weights: each corner is weighted by the area of the
        # opposite sub-rectangle.
        wa = ((x1 - sampling_grid_x) * (y1 - sampling_grid_y)).unsqueeze(1)
        wb = ((x1 - sampling_grid_x) * (sampling_grid_y - y0)).unsqueeze(1)
        wc = ((sampling_grid_x - x0) * (y1 - sampling_grid_y)).unsqueeze(1)
        wd = ((sampling_grid_x - x0) * (sampling_grid_y - y0)).unsqueeze(1)

        perturbed_image = torch.stack([wa * Ia, wb * Ib, wc * Ic, wd * Id]).sum(0)
        return perturbed_image

In [9]:
def imagenet_resize_postfn(grad):
    """Turn an input gradient into a per-sample map normalised to [0, 1].

    The absolute gradient is max-reduced over dim 1, average-pooled with a
    4x4 window, and then min-max normalised independently for every sample.

    Args:
        grad: 4-D tensor; dim 0 is the batch, dim 1 is reduced by max.

    Returns:
        Tensor of shape (batch, H // 4, W // 4) on the same device as ``grad``.
    """
    grad = grad.abs().max(1, keepdim=True)[0]
    grad = F.avg_pool2d(grad, 4).squeeze(1)
    shape = grad.shape
    grad = grad.view(len(grad), -1)
    grad = grad - grad.min(1, keepdim=True)[0]
    grad_max = grad.max(1, keepdim=True)[0]
    # Lower-bound the divisor so constant maps do not divide by zero; clamp
    # works on whatever device ``grad`` lives on (no hard-coded 'cuda').
    grad = grad / grad_max.clamp(min=1e-8)
    return grad.view(*shape)


def generate_gs_per_batches(model_tup, bx, by, post_fn=None, keep_grad=False):
    """Compute the gradient of the classification loss w.r.t. one input batch.

    Args:
        model_tup: sequence whose first two items are the model and the
            preprocessing function applied to ``bx`` before the model.
        bx: input tensor with ``requires_grad=True``.
        by: tensor of class indices used as targets of the NLL loss.
        post_fn: optional callable applied to the raw gradient.
        keep_grad: forwarded as ``create_graph`` so the gradient itself can be
            differentiated.

    Returns:
        The gradient of the loss with respect to ``bx`` (after ``post_fn``).
    """
    model, pre_fn = model_tup[:2]
    logit = model(pre_fn(bx))
    # Softmax over the class axis; the explicit dim avoids the deprecated
    # implicit-dimension behaviour of log_softmax.
    loss = F.nll_loss(F.log_softmax(logit, dim=1), by)
    grad = autograd.grad([loss], [bx], create_graph=keep_grad)[0]
    if post_fn is not None:
        grad = post_fn(grad)
    return grad


def generate_gs(model_tup, x, y, post_fn=None, keep_grad=False, batch_size=48, device='cuda'):
    """Compute input gradients for a whole dataset, batch by batch.

    Args:
        model_tup: passed through to ``generate_gs_per_batches``.
        x: sliceable collection of inputs (converted with ``torch.tensor``).
        y: sliceable collection of labels aligned with ``x``.
        post_fn: optional post-processing applied to each batch gradient.
        keep_grad: forwarded to ``generate_gs_per_batches``.
        batch_size: number of samples processed per step.
        device: device on which both inputs and labels are placed.

    Returns:
        NumPy array with the per-batch results concatenated along axis 0.
    """
    n = len(x)
    n_batches = (n + batch_size - 1) // batch_size
    generated = []
    for i in range(n_batches):
        si = i * batch_size
        ei = min(n, si + batch_size)
        bx = torch.tensor(x[si:ei], device=device, requires_grad=True)
        # Labels follow the requested device instead of a hard-coded 'cuda',
        # so the function also works with device='cpu'.
        by = torch.tensor(y[si:ei], device=device)
        batch_grad = generate_gs_per_batches(model_tup, bx, by, post_fn=post_fn, keep_grad=keep_grad)
        generated.append(batch_grad.detach().cpu().numpy())
    generated = np.concatenate(generated, axis=0)
    return generated

### AdvEdge

In [10]:
from contextlib import ExitStack

import torch


class CustomAdam(object):
    """Stateless Adam update: the moment estimates are passed in and returned
    instead of being stored on the optimizer."""

    def __init__(self, lr, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps

    def __call__(self, step, params, grads, means, variances, bp_through_optimizer=False):
        """Apply one Adam step to every (param, grad, mean, variance) tuple.

        Args:
            step: step number used for the bias correction.
            params, grads, means, variances: parallel sequences of tensors.
            bp_through_optimizer: when False the update runs under
                ``torch.no_grad()``; when True autograd records it.

        Returns:
            ``(new_params, new_means, new_variances)`` as three lists.
        """
        # ExitStack() is an empty context manager, i.e. gradients stay enabled.
        grad_ctx = ExitStack() if bp_through_optimizer else torch.no_grad()
        updated_params, updated_means, updated_vars = [], [], []
        with grad_ctx:
            for p, g, m, v in zip(params, grads, means, variances):
                m_next = self.beta1 * m + (1 - self.beta1) * g
                v_next = self.beta2 * v + (1 - self.beta2) * g * g
                updated_means.append(m_next)
                updated_vars.append(v_next)

                # Bias-corrected first and second moments.
                m_hat = m_next / (1 - self.beta1 ** step)
                v_hat = v_next / (1 - self.beta2 ** step)
                updated_params.append(p - self.lr / (torch.sqrt(v_hat) + self.eps) * m_hat)

        return updated_params, updated_means, updated_vars

In [11]:
def get_gaussian_blur_kernel(ksize, sigma):
    """Return a 2-D Gaussian kernel as a float32 tensor with two leading
    singleton axes.

    The 1-D kernel from ``cv2.getGaussianKernel(ksize, sigma)`` is multiplied
    by its transpose to form the 2-D kernel.
    """
    kernel_1d = cv2.getGaussianKernel(ksize, sigma).astype(np.float32)
    kernel_2d = kernel_1d * kernel_1d.T
    return torch.tensor(kernel_2d[None, None])

In [12]:
def tv_norm(img, beta=2., epsilon=1e-8):
    """Total-variation style penalty, one value per batch element.

    Squared differences along the last two axes are added, offset by
    ``epsilon``, raised to ``beta / 2`` and summed per sample.

    NOTE(review): the horizontal differences are transposed before being added
    to the vertical ones, so the shapes only line up for square maps and the
    horizontal difference at (r, c) is paired with the vertical one at (c, r).
    Kept as-is to preserve behaviour -- confirm whether this is intended.
    """
    n_samples = img.size(0)
    vertical = img[:, :, 1:] - img[:, :, :-1]
    horizontal = img[:, :, :, 1:] - img[:, :, :, :-1]
    horizontal = horizontal.transpose(2, 3)
    magnitude = horizontal.pow(2) + vertical.pow(2) + epsilon
    per_pixel = magnitude.pow(beta / 2.)
    return per_pixel.reshape(n_samples, -1).sum(1)



class GaussianBlur(nn.Module):
    """Per-channel (grouped) Gaussian blur with reflect padding.

    The padding is ``int((ksize - 1) / 2)`` on every side, so for odd ``ksize``
    the output has the same spatial size as the input.
    """

    def __init__(self, ksize, sigma, num_channels=3):
        super().__init__()
        self.ksize, self.sigma = ksize, sigma
        self.psize = int((ksize - 1) / 2)
        self.num_channels = num_channels
        # One copy of the kernel per channel; stored as a frozen parameter so
        # it follows the module across devices without being trained.
        per_channel = get_gaussian_blur_kernel(ksize, sigma).repeat(num_channels, 1, 1, 1)
        self.blur_kernel = nn.Parameter(per_channel, requires_grad=False)

    def forward(self, x):
        """Blur ``x`` channel by channel (``groups=num_channels``)."""
        pad = [self.psize for _ in range(4)]
        padded = F.pad(x, pad, mode="reflect")
        return F.conv2d(padded, self.blur_kernel, groups=self.num_channels)


class MASK(object):
    """Holds the Gaussian blur (kernel size 11, sigma 10) used by ``mask_iter``."""

    def __init__(self, cuda):
        blur = GaussianBlur(11, 10)
        # Module.cuda() moves the module in place and returns it.
        self.gaussian_blur = blur.cuda() if cuda else blur


def mask_iter(mask_model, model, pre_fn, x, y, r, m_init, l1_lambda=1e-2, tv_lambda=1e-4, tv_beta=3., noise_std=0.,
              weights=None, x_blurred=None):
    """Evaluate the mask objective for one optimisation step.

    The mask is upsampled to the input size and used to blend the input with
    its blurred version; the loss combines an L1 term on ``1 - mask``, the TV
    norm of the mask and the softmax probability of class ``y``.

    Args:
        mask_model: object exposing ``gaussian_blur`` (used only when
            ``x_blurred`` is None).
        model, pre_fn: classifier and the preprocessing applied before it.
        x: input batch; ``r`` (if not None) is added to it first.
        y: class indices whose probability forms the class loss.
        m_init: low-resolution mask being optimised.
        l1_lambda, tv_lambda, tv_beta: loss weights / TV exponent.
        noise_std: standard deviation of Gaussian noise added to the blend.
        weights: optional per-sample weights applied to every loss term.
        x_blurred: optional pre-computed blurred version of ``x``.

    Returns:
        ``(tot_loss, [l1_loss, tv_loss, class_loss])`` where the list holds
        the per-sample loss terms.
    """
    batch_size = x.size(0)
    if r is not None:
        x = x + r
    if x_blurred is None:
        x_blurred = mask_model.gaussian_blur(x)
    # F.upsample is deprecated; F.interpolate with align_corners=False is what
    # it resolved to.
    m = F.interpolate(m_init, size=(x.size(2), x.size(3)), mode="bilinear", align_corners=False)
    perturbed_inputs = m * x + (1. - m) * x_blurred
    if noise_std != 0:
        # randn_like creates the noise directly on the inputs' device/dtype.
        perturbed_inputs = perturbed_inputs + noise_std * torch.randn_like(perturbed_inputs)

    outputs = F.softmax(model(pre_fn(perturbed_inputs)), 1)
    l1_loss = torch.mean(torch.abs(1 - m_init).view(batch_size, -1), 1)
    tv_loss = tv_norm(m_init, tv_beta)
    class_loss = outputs.gather(1, y[:, None])[:, 0]
    if weights is None:
        tot_loss = l1_lambda * torch.sum(l1_loss) + tv_lambda * torch.sum(tv_loss) + torch.sum(class_loss)
    else:
        tot_loss = (l1_lambda * torch.sum(l1_loss * weights) + tv_lambda * torch.sum(tv_loss * weights) +
                    torch.sum(class_loss * weights))
    # The second entry used to be the scalar tv_lambda by mistake; report the
    # TV loss term like the other two entries.
    return tot_loss, [l1_loss, tv_loss, class_loss]


class MASKV2(object):
    """Holds the two blurs used by ``mask_iter_v2``: ``blur1`` (kernel 21,
    default 3 channels) for the image and ``blur2`` (kernel 11, 1 channel) for
    the mask.

    Both are built with sigma = -1 -- presumably so OpenCV derives sigma from
    the kernel size; confirm against the cv2.getGaussianKernel documentation.
    """

    def __init__(self, cuda):
        self.blur1 = GaussianBlur(21, -1)
        self.blur2 = GaussianBlur(11, -1, 1)
        self.cuda = cuda
        if cuda:
            for blur in (self.blur1, self.blur2):
                blur.cuda()


def mask_iter_v2(mask_model, model, pre_fn, x, y, m_init, l1_lambda=1e-4, tv_lambda=1e-2, tv_beta=3., noise_std=0.,
                 jitter=4, weights=None, x_blurred=None):
    """Evaluate the jittered mask objective for one optimisation step.

    A random 224x224 crop (offset drawn from ``[0, jitter)``) of the input and
    of its blurred version is blended with the upsampled, blurred mask.  The
    loss combines an L1 term on ``1 - mask``, the TV norm of the mask and the
    softmax probability of class ``y``.

    Args:
        mask_model: object exposing ``blur1`` (image blur, used only when
            ``x_blurred`` is None) and ``blur2`` (mask blur).
        model, pre_fn: classifier and the preprocessing applied before it.
        x: input batch, expected to be at least 224 + jitter pixels per side
            (callers upsample to 228x228).
        y: class indices whose probability forms the class loss.
        m_init: low-resolution mask being optimised.
        l1_lambda, tv_lambda, tv_beta: loss weights / TV exponent.
        noise_std: standard deviation of Gaussian noise added to the mask.
        jitter: exclusive upper bound of the random crop offset; 0 disables it.
        weights: optional per-sample weights applied to every loss term.
        x_blurred: optional pre-computed blurred version of ``x``.

    Returns:
        ``(tot_loss, [l1_loss, tv_loss, class_loss])`` where the list holds
        the per-sample loss terms.
    """
    crop = 224
    batch_size = x.size(0)
    if x_blurred is None:
        x_blurred = mask_model.blur1(x)

    if jitter != 0:
        j1 = np.random.randint(jitter)
        j2 = np.random.randint(jitter)
    else:
        j1, j2 = 0, 0
    x_ = x[:, :, j1:j1 + crop, j2:j2 + crop]
    x_blurred_ = x_blurred[:, :, j1:j1 + crop, j2:j2 + crop]

    if noise_std != 0:
        # Scale the noise by noise_std (it used to be added with unit variance
        # no matter which value was requested).
        mask_w_noisy = (m_init + noise_std * torch.randn_like(m_init)).clamp(0, 1)
    else:
        mask_w_noisy = m_init

    mask_w_noisy = F.interpolate(mask_w_noisy, (crop, crop), mode='bilinear')
    mask_w_noisy = mask_model.blur2(mask_w_noisy)
    x = x_ * mask_w_noisy + x_blurred_ * (1 - mask_w_noisy)

    class_loss = F.softmax(model(pre_fn(x)), dim=-1).gather(1, y.unsqueeze(1)).squeeze(1)
    l1_loss = (1 - m_init).abs().view(batch_size, -1).sum(-1)
    tv_loss = tv_norm(m_init, tv_beta)

    if weights is None:
        tot_loss = l1_lambda * torch.sum(l1_loss) + tv_lambda * torch.sum(tv_loss) + torch.sum(class_loss)
    else:
        tot_loss = (l1_lambda * torch.sum(l1_loss * weights) + tv_lambda * torch.sum(tv_loss * weights) +
                    torch.sum(class_loss * weights))
    # The second entry used to be the scalar tv_lambda by mistake; report the
    # TV loss term like the other two entries.
    return tot_loss, [l1_loss, tv_loss, class_loss]


In [13]:
from torch.optim import Adam

def get_default_mask_config():
    """Return a fresh dict with the default mask-optimisation settings."""
    defaults = {
        'lr': 0.1,
        'l1_lambda': 1e-2,
        'tv_lambda': 1e-4,
        'noise_std': 0,
        'n_iters': 400,
        'batch_size': 40,
        'verbose': False,
    }
    return defaults


def generate_mask_per_batch(mask_config, mask_model, model_tup, batch_tup, cuda, m_init=None):
    """Optimise a 28x28 mask per sample for one batch with Adam.

    Args:
        mask_config (dict): reads 'lr', 'n_iters', 'l1_lambda', 'tv_lambda',
            'noise_std' and 'verbose'.
        mask_model: object exposing ``gaussian_blur``.
        model_tup: sequence whose first two items are the classifier and its
            preprocessing function.
        batch_tup: ``(inputs, labels)`` as tensors or array-likes.
        cuda (bool): move inputs, labels and mask to the GPU when true.
        m_init: optional starting mask; it is copied, the caller's tensor is
            left untouched.

    Returns:
        The optimised mask tensor (values clamped to [0, 1]).
    """
    bx, by = batch_tup
    batch_size = len(bx)
    if not isinstance(bx, torch.Tensor):
        bx = torch.tensor(bx)
    if not isinstance(by, torch.Tensor):
        by = torch.tensor(by)
    model, pre_fn = model_tup[:2]
    if m_init is None:
        m_init = torch.zeros(batch_size, 1, 28, 28).fill_(0.5)
    else:
        # clone() so the in-place updates below never write into the caller's
        # tensor (detach() alone shares storage); matches the v2 routine.
        m_init = m_init.clone().detach()
    if cuda:
        bx, by = bx.cuda(), by.cuda()
        m_init = m_init.cuda()
    m_init.requires_grad = True
    optimizer = Adam([m_init], lr=mask_config['lr'])
    # The blurred input does not depend on the mask, so compute it once.
    bx_blurred = mask_model.gaussian_blur(bx)
    for i in range(mask_config['n_iters']):
        tot_loss = mask_iter(mask_model, model, pre_fn, bx, by, None,
                             m_init, mask_config['l1_lambda'], mask_config['tv_lambda'],
                             noise_std=mask_config['noise_std'], x_blurred=bx_blurred)[0]
        if mask_config['verbose'] and i % 50 == 0:
            # np.asscalar was removed from NumPy; Tensor.item() is equivalent.
            print(i, tot_loss.item() / batch_size)
        optimizer.zero_grad()
        tot_loss.backward()
        optimizer.step()
        # Project the mask back into the valid range after every step.
        with torch.no_grad():
            m_init.clamp_(0, 1)
    return m_init


def generate_masks(mask_config, model_tup, images_tup, cuda):
    """Run ``generate_mask_per_batch`` over a dataset and stack the masks.

    ``mask_config`` falls back to ``get_default_mask_config()`` when None; the
    first two items of ``images_tup`` are the inputs and labels.  Returns a
    NumPy array with all masks concatenated along axis 0.
    """
    config = get_default_mask_config() if mask_config is None else mask_config
    blur_model = MASK(cuda)
    img_x, img_y = images_tup[:2]
    step = config['batch_size']
    total = len(img_x)
    num_batches = (total + step - 1) // step

    collected = []
    for batch_idx in range(num_batches):
        lo = batch_idx * step
        hi = min(total, lo + step)
        batch = (img_x[lo:hi], img_y[lo:hi])
        batch_masks = generate_mask_per_batch(config, blur_model, model_tup, batch, cuda)
        collected.append(batch_masks.detach().cpu().numpy())

    return np.concatenate(collected, axis=0)


def generate_mask_per_batch_v2(mask_config, mask_model, model_tup, batch_tup, cuda, m_init=None):
    """Optimise a 28x28 mask per sample for one batch (jittered v2 objective).

    Args:
        mask_config (dict): reads 'lr', 'n_iters', 'l1_lambda', 'tv_lambda',
            'noise_std' and 'verbose'.
        mask_model: object exposing ``blur1`` and ``blur2``.
        model_tup: sequence whose first two items are the classifier and its
            preprocessing function.
        batch_tup: ``(inputs, labels)`` as tensors or array-likes.
        cuda (bool): move inputs, labels and mask to the GPU when true.
        m_init: optional starting mask; it is copied before optimisation.

    Returns:
        The optimised mask tensor (values clamped to [0, 1]).
    """
    bx, by = batch_tup
    batch_size = len(bx)
    if not isinstance(bx, torch.Tensor):
        bx = torch.tensor(bx)
    if not isinstance(by, torch.Tensor):
        by = torch.tensor(by)
    model, pre_fn = model_tup[:2]
    if m_init is None:
        m_init = torch.zeros(batch_size, 1, 28, 28).fill_(0.5)
    else:
        m_init = m_init.clone().detach()
    if cuda:
        bx, by = bx.cuda(), by.cuda()
        m_init = m_init.cuda()
    m_init.requires_grad = True
    optimizer = Adam([m_init], lr=mask_config['lr'])
    # Upsample by the jitter margin (4 px) so mask_iter_v2 can take random
    # 224x224 crops; the blurred copy is mask-independent and computed once.
    bx = F.interpolate(bx, (224 + 4, 224 + 4), mode='bilinear')
    bx_blurred = mask_model.blur1(bx)
    for i in range(mask_config['n_iters']):
        tot_loss = mask_iter_v2(mask_model, model, pre_fn, bx, by,
                                m_init, mask_config['l1_lambda'], mask_config['tv_lambda'],
                                noise_std=mask_config['noise_std'], x_blurred=bx_blurred)[0]
        if mask_config['verbose'] and i % 50 == 0:
            # np.asscalar was removed from NumPy; Tensor.item() is equivalent.
            print(i, tot_loss.item() / batch_size)
        optimizer.zero_grad()
        tot_loss.backward()
        optimizer.step()
        # Project the mask back into the valid range after every step.
        with torch.no_grad():
            m_init.clamp_(0, 1)
    return m_init


def generate_masks_v2(mask_config, model_tup, images_tup, cuda):
    """Run ``generate_mask_per_batch_v2`` over a dataset and stack the masks.

    ``mask_config`` falls back to ``get_default_mask_config()`` when None; the
    first two items of ``images_tup`` are the inputs and labels.  Returns a
    NumPy array with all masks concatenated along axis 0.
    """
    config = get_default_mask_config() if mask_config is None else mask_config
    blur_model = MASKV2(cuda)
    img_x, img_y = images_tup[:2]
    step = config['batch_size']
    total = len(img_x)
    num_batches = (total + step - 1) // step

    collected = []
    for batch_idx in range(num_batches):
        lo = batch_idx * step
        hi = min(total, lo + step)
        batch = (img_x[lo:hi], img_y[lo:hi])
        batch_masks = generate_mask_per_batch_v2(config, blur_model, model_tup, batch, cuda)
        collected.append(batch_masks.detach().cpu().numpy())

    return np.concatenate(collected, axis=0)

In [14]:
from torch.optim import Adam as TorchAdam

# Stage-2 checkpoint period: attack_batch re-generates the mask and records the
# best result whenever i % PGD_SAVE_PERIOD == PGD_SAVE_PERIOD - 1.
PGD_SAVE_PERIOD = 50

def load_model(config):
    """Load a frozen, eval-mode pretrained classifier.

    Args:
        config (dict): ``config['model']`` selects 'resnet50' or
            'densenet169'; ``config['device'] == 'gpu'`` moves the model to
            CUDA.

    Returns:
        ``(model, pre_fn, shape)`` where ``pre_fn`` is ``imagenet_normalize``
        and ``shape`` is ``(224, 224)``.

    Raises:
        ValueError: if ``config['model']`` is not a supported name.
    """
    name = config['model']
    if name == 'resnet50':
        model = resnet50(pretrained=True)
    elif name == 'densenet169':
        model = densenet169(pretrained=True)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError("unsupported model %r; expected 'resnet50' or 'densenet169'" % (name,))
    shape = (224, 224)
    pre_fn = imagenet_normalize
    freeze_model(model)
    model.train(False)
    if config['device'] == 'gpu':
        model.cuda()
    return model, pre_fn, shape

def attack_batch(config, model_tup, mask_model, batch_tup, m0):
    """Run the two-stage flow-based attack (AdvEdge) on one batch.

    Stage 1 optimises a per-pixel flow field so that the warped images, with
    the perturbation scaled by a Sobel edge map of the clean input, raise the
    logit of ``by`` above all others.  Stage 2 alternates Adam steps on the
    mask (first without, then with back-propagation through the optimiser)
    and updates the flow so the resulting mask stays close to ``m0`` while
    the classification objective is kept.  Every ``PGD_SAVE_PERIOD`` steps
    the mask is regenerated from scratch and the best successful result per
    sample (smallest L2 distance to ``m0``) is recorded.

    Args:
        config (dict): reads 'device', 'tau', 's1_iters', 's2_iters',
            'mask_l1_lambda', 'mask_tv_lambda' and 'mask_iter'.
        model_tup: sequence whose first two items are the classifier and its
            preprocessing function.
        mask_model: blur holder handed to the v2 mask routines.
        batch_tup: ``(images, labels)`` arrays; the adversarial loss pushes
            the prediction toward ``labels``.
        m0: array with the reference masks the adversarial masks should match.

    Returns:
        dict with 'adv_x' (best adversarial images), 'adv_mask' (their masks),
        'adv_succeed' (1 where some checkpoint succeeded) and 'tmask' (``m0``).
    """
    device = 'cuda' if config['device'] == 'gpu' else 'cpu'
    bx_np, by_np = batch_tup
    m0_np = m0
    batch_size = len(bx_np)
    bx, by, m0 = (torch.tensor(bx_np, device=device), torch.tensor(by_np, device=device),
                  torch.tensor(m0, device=device))
    model, pre_fn = model_tup[:2]
    dobj = {}

    # Sobel edge map of the channel-averaged clean images, used as a
    # per-pixel weight on the perturbation.
    unpert_gray = bx.cpu().numpy().mean(axis=1, keepdims=True)
    edges = np.empty_like(unpert_gray)
    for index, image in enumerate(unpert_gray):
        edges[index] = filters.sobel(image.squeeze(0))
    # Follow the configured device instead of a hard-coded 'cuda'.
    weights = torch.tensor(edges).to(device)

    m = torch.empty_like(m0).fill_(0.5)
    m.requires_grad = True

    images = bx
    # Small random initial flow in [-0.1, 0.1).
    flows = 0.2 * (torch.rand(batch_size, 2, images.size(2), images.size(3), device=device) - 0.5)
    flows.requires_grad_(True)

    tau = config['tau']
    flow_obj = StadvFlow()
    flow_loss_obj = StadvFlowLoss()
    flow_tvloss_obj = StadvTVLoss()
    optimizer = TorchAdam([flows], lr=0.01, amsgrad=True)

    # ---- Stage 1: find a flow that fools the classifier -------------------
    for i in range(config['s1_iters']):
        adv_images = flow_obj(images, flows)

        pert = (adv_images - bx) * weights
        adv_images = bx + pert

        logits = model(pre_fn(adv_images))
        # Margin between the best other logit and the logit of `by`,
        # floored at -5.
        adv_loss = logits.scatter(1, by[:, None], -100000).max(1)[0] - logits.gather(1, by[:, None])[:, 0]
        adv_loss = torch.max(adv_loss, torch.full_like(adv_loss, -5.))
        flow_loss = flow_loss_obj(flows)
        total_loss = adv_loss + tau * flow_loss

        optimizer.zero_grad()
        total_loss.sum().backward()
        optimizer.step()
        if i % 50 == 0 or i == config['s1_iters'] - 1:
            with torch.no_grad():
                flow_loss = flow_tvloss_obj(flows)
                preds = logits.argmax(1)
                succeed = (preds == by).float().mean().item()
            print('s1-step: %d, average adv loss: %.4f, average flow loss: %.4f, succeed: %.2f' %
                  (i, adv_loss.mean().item(), flow_loss.mean().item(), succeed))

    # ---- Stage 2: match the reference mask --------------------------------
    optimizer = CustomAdam(0.01)
    torch_optimizer = TorchAdam([flows], lr=0.005, amsgrad=True)

    mean_his = [torch.zeros_like(m0)]
    variance_his = [torch.zeros_like(m0)]
    no_backprop_steps = 4
    backprop_steps = 4
    adam_step = 0
    attack_mask_config = dict(lr=0.1, l1_lambda=config['mask_l1_lambda'], tv_lambda=config['mask_tv_lambda'],
                              noise_std=0, n_iters=config['mask_iter'],
                              verbose=False)

    best_l2_dists = np.full(len(bx), 1e10, dtype=np.float32)
    best_adv_x = np.zeros_like(bx_np)
    best_adv_mask = np.zeros_like(m0_np)

    for i in range(config['s2_iters']):
        if i % PGD_SAVE_PERIOD == PGD_SAVE_PERIOD - 1:
            with torch.no_grad():
                adv_images = flow_obj(images, flows)
            # Evaluate the checkpoint on the images produced by the current
            # flow (previously a stale tensor from an earlier iteration).
            adv_images_ = adv_images.detach()
            mask_now_disc = generate_mask_per_batch_v2(attack_mask_config, mask_model, model_tup,
                                                       batch_tup=(adv_images_, by),
                                                       cuda=config['device'] == 'gpu').detach_()

            # Quantise to 8 bit and back, i.e. what would be stored as an image.
            bx_adv_disc_np = np.float32(np.uint8(255 * adv_images_.cpu().numpy()) / 255.)

            with torch.no_grad():
                logits_disc = model(pre_fn(adv_images_))
                hit = logits_disc.argmax(1) == by
                diff = m0 - mask_now_disc
                diff_norm = diff.view(batch_size, -1).norm(2, 1).cpu().numpy()
                diff = (diff * diff).sum().item()
                flow_tvloss = flow_tvloss_obj(flows).mean().item()
                adv_loss = (logits_disc.scatter(1, by[:, None], -100000).max(1)[0] -
                            logits_disc.gather(1, by[:, None])[:, 0])
                adv_loss = torch.max(adv_loss, torch.full_like(adv_loss, -5.))

            # np.bool / np.asscalar were removed from NumPy; use bool / .item().
            succeed = hit.cpu().numpy().astype(bool)
            update_flag = np.logical_and(diff_norm < best_l2_dists, succeed)
            mask_now_disc_np = mask_now_disc.cpu().numpy()
            best_adv_x[update_flag] = bx_adv_disc_np[update_flag]
            best_l2_dists[update_flag] = diff_norm[update_flag]
            best_adv_mask[update_flag] = mask_now_disc_np[update_flag]

            success_rate = hit.float().mean().item()
            print('step', i, 'l2 dist now', diff, 'succeed', success_rate,
                  'succeed_disc', success_rate,
                  'flow tvloss', flow_tvloss, 'adv loss', adv_loss.mean().item())
            # Restart the inner mask optimisation from the freshly generated
            # mask with decayed moment estimates.
            m.data = mask_now_disc.data
            mean_his = [0.5 * mean_his[0]]
            variance_his = [0.25 * variance_his[0]]
            adam_step = 100

        # Mask steps that are NOT differentiated w.r.t. the flow.
        adv_images_ = adv_images.detach()
        adv_images_up = F.interpolate(adv_images_, (228, 228), mode='bilinear')
        for j in range(no_backprop_steps):
            int_loss = mask_iter_v2(mask_model, model, pre_fn,
                                    adv_images_up,
                                    by, m, noise_std=0,
                                    l1_lambda=config['mask_l1_lambda'], tv_lambda=config['mask_tv_lambda'])[0]
            int_grad = autograd.grad([int_loss], [m])[0]
            objs = optimizer(adam_step + 1, [m], [int_grad], mean_his, variance_his, False)
            m = torch.clamp(objs[0][0], 0, 1)
            mean_his = objs[-2]
            variance_his = objs[-1]
            m = m.detach()
            m.requires_grad = True
            adam_step += 1

        # Mask steps whose result is differentiated w.r.t. the flow.
        diffs = []
        for j in range(backprop_steps):
            adv_images = flow_obj(images, flows)
            int_loss = mask_iter_v2(mask_model, model, pre_fn,
                                    F.interpolate(adv_images, (228, 228), mode='bilinear'),
                                    by, m, noise_std=0,
                                    l1_lambda=config['mask_l1_lambda'], tv_lambda=config['mask_tv_lambda'])[0]
            int_grad = autograd.grad([int_loss], [m], create_graph=True)[0]
            objs = optimizer(adam_step + 1, [m], [int_grad], mean_his, variance_his, True)
            m = torch.clamp(objs[0][0], 0, 1)
            mean_his = objs[-2]
            variance_his = objs[-1]
            diff = m - m0
            diffs.append((diff * diff).sum())
            m = m.detach()
            m.requires_grad = True
            adam_step += 1

        int_final_loss = torch.stack(diffs).mean()
        adv_images = flow_obj(images, flows)
        logits = model(pre_fn(adv_images))
        adv_loss = logits.scatter(1, by[:, None], -100000).max(1)[0] - logits.gather(1, by[:, None])[:, 0]
        adv_loss = torch.max(adv_loss, torch.full_like(adv_loss, -5.))
        total_loss = 500 * int_final_loss + adv_loss.sum() + 0.004 * flow_loss_obj(flows).sum()

        torch_optimizer.zero_grad()
        total_loss.backward()
        torch_optimizer.step()

    dobj['adv_x'] = best_adv_x
    dobj['adv_mask'] = best_adv_mask
    # best_l2_dists keeps its 1e10 sentinel where no checkpoint succeeded.
    dobj['adv_succeed'] = (best_l2_dists < 1e6).astype(np.int64)
    dobj['tmask'] = m0_np

    return dobj


def attack(config):
    """Run ``attack_batch`` over the dataset in ``config['data_path']`` and
    save the results with ``np.savez`` to ``config['save_path']``.

    The archive must contain 'img_x', 'img_y', 'img_yt' and 'mask_benign_y'.
    The saved file holds the concatenated per-batch outputs, the original
    arrays and the elapsed wall-clock time under 'time'.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    # Close the archive as soon as the arrays have been copied out.
    with np.load(config['data_path']) as data_arx:
        img_x, img_y, img_yt = (data_arx['img_x'].copy(), data_arx['img_y'].copy(),
                                data_arx['img_yt'].copy())
        mask_benign = data_arx['mask_benign_y'].copy()

    n, batch_size = len(img_x), config['batch_size']
    if n == 0:
        # Without any batch there is nothing to concatenate below.
        raise ValueError('no samples found in %r' % (config['data_path'],))

    mask_model = MASKV2(config['device'] == 'gpu')
    model_tup = load_model(config)

    num_batches = (n + batch_size - 1) // batch_size
    save_dobjs = []

    start_time = time.time()

    for i in range(num_batches):
        si = i * batch_size
        ei = min(si + batch_size, n)
        bx, byt, bm0 = img_x[si:ei], img_yt[si:ei], mask_benign[si:ei]
        dobj = attack_batch(config, model_tup, mask_model, (bx, byt), bm0)

        save_dobjs.append(dobj)
        print('done batch: %d' % i)

    estimated_time = time.time() - start_time

    save_dobj = {key: np.concatenate([batch_out[key] for batch_out in save_dobjs], axis=0)
                 for key in save_dobjs[0]}
    save_dobj.update(dict(img_x=img_x, img_y=img_y, img_yt=img_yt, mask_benign_y=mask_benign))

    save_dobj['time'] = estimated_time
    np.savez(config['save_path'], **save_dobj)


In [15]:
def attack_mask(data_path, fName):
    """Run ``attack`` against DenseNet-169 on the GPU.

    ``data_path`` is the input archive and ``fName`` the output file; all
    other settings are the fixed values below.
    """
    config = dict(
        data_path=data_path,
        save_path=f'{fName}',
        device='gpu',
        model='densenet169',
        batch_size=5,
        epsilon=0.031,
        s1_iters=400,
        s1_lr=1. / 255,
        s2_iters=1000,
        s2_beta=0.05,
        mask_iter=300,
        mask_noise_std=0,
        mask_tv_lambda=1e-2,
        mask_l1_lambda=1e-4,
        tau=0.0005,
        c=5.,
    )
    attack(config)

In [16]:
# Run the attack configured by attack_mask on 'mask.npz'; results are saved to 'output_1.npz'.
attack_mask('mask.npz', 'output_1.npz')

  nn.init.kaiming_normal(m.weight.data)
Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to /root/.cache/torch/hub/checkpoints/densenet169-b2777c0a.pth
100%|██████████| 54.7M/54.7M [00:00<00:00, 96.4MB/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


s1-step: 0, average adv loss: 16.1598, average flow loss: 0.1981, succeed: 0.00
s1-step: 50, average adv loss: 6.5109, average flow loss: 0.1714, succeed: 0.20


KeyboardInterrupt: ignored

### AdvEdge+

In [17]:
def attack_batch(config, model_tup, mask_model, batch_tup, m0):
    """Run the two-stage flow-based attack (AdvEdge+) on one batch.

    Stage 1 optimises a per-pixel flow field; the perturbation is scaled by a
    Sobel edge map of the clean input and applied only where that edge weight
    exceeds 0.1, so that the logit of ``by`` rises above all others.  Stage 2
    alternates Adam steps on the mask (first without, then with
    back-propagation through the optimiser) and updates the flow so the
    resulting mask stays close to ``m0`` while the classification objective
    is kept.  Every ``PGD_SAVE_PERIOD`` steps the mask is regenerated from
    scratch and the best successful result per sample (smallest L2 distance
    to ``m0``) is recorded.

    Args:
        config (dict): reads 'device', 'tau', 's1_iters', 's2_iters',
            'mask_l1_lambda', 'mask_tv_lambda' and 'mask_iter'.
        model_tup: sequence whose first two items are the classifier and its
            preprocessing function.
        mask_model: blur holder handed to the v2 mask routines.
        batch_tup: ``(images, labels)`` arrays; the adversarial loss pushes
            the prediction toward ``labels``.
        m0: array with the reference masks the adversarial masks should match.

    Returns:
        dict with 'adv_x' (best adversarial images), 'adv_mask' (their masks),
        'adv_succeed' (1 where some checkpoint succeeded) and 'tmask' (``m0``).
    """
    device = 'cuda' if config['device'] == 'gpu' else 'cpu'
    bx_np, by_np = batch_tup
    m0_np = m0
    batch_size = len(bx_np)
    bx, by, m0 = (torch.tensor(bx_np, device=device), torch.tensor(by_np, device=device),
                  torch.tensor(m0, device=device))
    model, pre_fn = model_tup[:2]
    dobj = {}

    # Sobel edge map of the channel-averaged clean images, used as a
    # per-pixel weight on the perturbation.
    unpert_gray = bx.cpu().numpy().mean(axis=1, keepdims=True)
    edges = np.empty_like(unpert_gray)
    for index, image in enumerate(unpert_gray):
        edges[index] = filters.sobel(image.squeeze(0))
    # Follow the configured device instead of a hard-coded 'cuda'.
    weights = torch.tensor(edges).to(device)

    m = torch.empty_like(m0).fill_(0.5)
    m.requires_grad = True

    images = bx
    # Small random initial flow in [-0.1, 0.1).
    flows = 0.2 * (torch.rand(batch_size, 2, images.size(2), images.size(3), device=device) - 0.5)
    flows.requires_grad_(True)

    tau = config['tau']
    flow_obj = StadvFlow()
    flow_loss_obj = StadvFlowLoss()
    flow_tvloss_obj = StadvTVLoss()
    optimizer = TorchAdam([flows], lr=0.01, amsgrad=True)

    # ---- Stage 1: find a flow that fools the classifier -------------------
    for i in range(config['s1_iters']):
        adv_images = flow_obj(images, flows)

        pert = (adv_images - bx) * weights
        # Only perturb pixels whose edge weight exceeds the 0.1 threshold.
        adv_images = bx + torch.where(weights > 0.1, pert, torch.tensor(0.).to(device))

        logits = model(pre_fn(adv_images))
        # Margin between the best other logit and the logit of `by`,
        # floored at -5.
        adv_loss = logits.scatter(1, by[:, None], -100000).max(1)[0] - logits.gather(1, by[:, None])[:, 0]
        adv_loss = torch.max(adv_loss, torch.full_like(adv_loss, -5.))
        flow_loss = flow_loss_obj(flows)
        total_loss = adv_loss + tau * flow_loss

        optimizer.zero_grad()
        total_loss.sum().backward()
        optimizer.step()
        if i % 50 == 0 or i == config['s1_iters'] - 1:
            with torch.no_grad():
                flow_loss = flow_tvloss_obj(flows)
                preds = logits.argmax(1)
                succeed = (preds == by).float().mean().item()
            print('s1-step: %d, average adv loss: %.4f, average flow loss: %.4f, succeed: %.2f' %
                  (i, adv_loss.mean().item(), flow_loss.mean().item(), succeed))

    # ---- Stage 2: match the reference mask --------------------------------
    optimizer = CustomAdam(0.01)
    torch_optimizer = TorchAdam([flows], lr=0.005, amsgrad=True)

    mean_his = [torch.zeros_like(m0)]
    variance_his = [torch.zeros_like(m0)]
    no_backprop_steps = 4
    backprop_steps = 4
    adam_step = 0
    attack_mask_config = dict(lr=0.1, l1_lambda=config['mask_l1_lambda'], tv_lambda=config['mask_tv_lambda'],
                              noise_std=0, n_iters=config['mask_iter'],
                              verbose=False)

    best_l2_dists = np.full(len(bx), 1e10, dtype=np.float32)
    best_adv_x = np.zeros_like(bx_np)
    best_adv_mask = np.zeros_like(m0_np)

    for i in range(config['s2_iters']):
        if i % PGD_SAVE_PERIOD == PGD_SAVE_PERIOD - 1:
            with torch.no_grad():
                adv_images = flow_obj(images, flows)
            # Evaluate the checkpoint on the images produced by the current
            # flow (previously a stale tensor from an earlier iteration).
            adv_images_ = adv_images.detach()
            mask_now_disc = generate_mask_per_batch_v2(attack_mask_config, mask_model, model_tup,
                                                       batch_tup=(adv_images_, by),
                                                       cuda=config['device'] == 'gpu').detach_()

            # Quantise to 8 bit and back, i.e. what would be stored as an image.
            bx_adv_disc_np = np.float32(np.uint8(255 * adv_images_.cpu().numpy()) / 255.)

            with torch.no_grad():
                logits_disc = model(pre_fn(adv_images_))
                hit = logits_disc.argmax(1) == by
                diff = m0 - mask_now_disc
                diff_norm = diff.view(batch_size, -1).norm(2, 1).cpu().numpy()
                diff = (diff * diff).sum().item()
                flow_tvloss = flow_tvloss_obj(flows).mean().item()
                adv_loss = (logits_disc.scatter(1, by[:, None], -100000).max(1)[0] -
                            logits_disc.gather(1, by[:, None])[:, 0])
                adv_loss = torch.max(adv_loss, torch.full_like(adv_loss, -5.))

            # np.bool / np.asscalar were removed from NumPy; use bool / .item().
            succeed = hit.cpu().numpy().astype(bool)
            update_flag = np.logical_and(diff_norm < best_l2_dists, succeed)
            mask_now_disc_np = mask_now_disc.cpu().numpy()
            best_adv_x[update_flag] = bx_adv_disc_np[update_flag]
            best_l2_dists[update_flag] = diff_norm[update_flag]
            best_adv_mask[update_flag] = mask_now_disc_np[update_flag]

            success_rate = hit.float().mean().item()
            print('step', i, 'l2 dist now', diff, 'succeed', success_rate,
                  'succeed_disc', success_rate,
                  'flow tvloss', flow_tvloss, 'adv loss', adv_loss.mean().item())
            # Restart the inner mask optimisation from the freshly generated
            # mask with decayed moment estimates.
            m.data = mask_now_disc.data
            mean_his = [0.5 * mean_his[0]]
            variance_his = [0.25 * variance_his[0]]
            adam_step = 100

        # Mask steps that are NOT differentiated w.r.t. the flow.
        adv_images_ = adv_images.detach()
        adv_images_up = F.interpolate(adv_images_, (228, 228), mode='bilinear')
        for j in range(no_backprop_steps):
            int_loss = mask_iter_v2(mask_model, model, pre_fn,
                                    adv_images_up,
                                    by, m, noise_std=0,
                                    l1_lambda=config['mask_l1_lambda'], tv_lambda=config['mask_tv_lambda'])[0]
            int_grad = autograd.grad([int_loss], [m])[0]
            objs = optimizer(adam_step + 1, [m], [int_grad], mean_his, variance_his, False)
            m = torch.clamp(objs[0][0], 0, 1)
            mean_his = objs[-2]
            variance_his = objs[-1]
            m = m.detach()
            m.requires_grad = True
            adam_step += 1

        # Mask steps whose result is differentiated w.r.t. the flow.
        diffs = []
        for j in range(backprop_steps):
            adv_images = flow_obj(images, flows)
            int_loss = mask_iter_v2(mask_model, model, pre_fn,
                                    F.interpolate(adv_images, (228, 228), mode='bilinear'),
                                    by, m, noise_std=0,
                                    l1_lambda=config['mask_l1_lambda'], tv_lambda=config['mask_tv_lambda'])[0]
            int_grad = autograd.grad([int_loss], [m], create_graph=True)[0]
            objs = optimizer(adam_step + 1, [m], [int_grad], mean_his, variance_his, True)
            m = torch.clamp(objs[0][0], 0, 1)
            mean_his = objs[-2]
            variance_his = objs[-1]
            diff = m - m0
            diffs.append((diff * diff).sum())
            m = m.detach()
            m.requires_grad = True
            adam_step += 1

        int_final_loss = torch.stack(diffs).mean()
        adv_images = flow_obj(images, flows)
        logits = model(pre_fn(adv_images))
        adv_loss = logits.scatter(1, by[:, None], -100000).max(1)[0] - logits.gather(1, by[:, None])[:, 0]
        adv_loss = torch.max(adv_loss, torch.full_like(adv_loss, -5.))
        total_loss = 500 * int_final_loss + adv_loss.sum() + 0.004 * flow_loss_obj(flows).sum()

        torch_optimizer.zero_grad()
        total_loss.backward()
        torch_optimizer.step()

    dobj['adv_x'] = best_adv_x
    dobj['adv_mask'] = best_adv_mask
    # best_l2_dists keeps its 1e10 sentinel where no checkpoint succeeded.
    dobj['adv_succeed'] = (best_l2_dists < 1e6).astype(np.int64)
    dobj['tmask'] = m0_np

    return dobj


def attack(config):
    """Run ``attack_batch`` over the dataset in ``config['data_path']`` and
    save the results with ``np.savez`` to ``config['save_path']``.

    The archive must contain 'img_x', 'img_y', 'img_yt' and 'mask_benign_y'.
    The saved file holds the concatenated per-batch outputs, the original
    arrays and the elapsed wall-clock time under 'time'.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    # Close the archive as soon as the arrays have been copied out.
    with np.load(config['data_path']) as data_arx:
        img_x, img_y, img_yt = (data_arx['img_x'].copy(), data_arx['img_y'].copy(),
                                data_arx['img_yt'].copy())
        mask_benign = data_arx['mask_benign_y'].copy()

    n, batch_size = len(img_x), config['batch_size']
    if n == 0:
        # Without any batch there is nothing to concatenate below.
        raise ValueError('no samples found in %r' % (config['data_path'],))

    mask_model = MASKV2(config['device'] == 'gpu')
    model_tup = load_model(config)

    num_batches = (n + batch_size - 1) // batch_size
    save_dobjs = []

    start_time = time.time()

    for i in range(num_batches):
        si = i * batch_size
        ei = min(si + batch_size, n)
        bx, byt, bm0 = img_x[si:ei], img_yt[si:ei], mask_benign[si:ei]
        dobj = attack_batch(config, model_tup, mask_model, (bx, byt), bm0)

        save_dobjs.append(dobj)
        print('done batch: %d' % i)

    estimated_time = time.time() - start_time

    save_dobj = {key: np.concatenate([batch_out[key] for batch_out in save_dobjs], axis=0)
                 for key in save_dobjs[0]}
    save_dobj.update(dict(img_x=img_x, img_y=img_y, img_yt=img_yt, mask_benign_y=mask_benign))

    save_dobj['time'] = estimated_time
    np.savez(config['save_path'], **save_dobj)


In [20]:
def attack_mask_2(data_path, fName):
    """Run ``attack`` against ResNet-50 on the GPU.

    ``data_path`` is the input archive and ``fName`` the output file; all
    other settings are the fixed values below.
    """
    config = dict(
        data_path=data_path,
        save_path=f'{fName}',
        device='gpu',
        model='resnet50',
        batch_size=5,
        epsilon=0.031,
        s1_iters=400,
        s1_lr=1. / 255,
        s2_iters=1000,
        s2_beta=0.05,
        mask_iter=300,
        mask_noise_std=0,
        mask_tv_lambda=1e-2,
        mask_l1_lambda=1e-4,
        tau=0.0005,
        c=5.,
    )
    attack(config)

In [21]:
# Run the attack configured by attack_mask_2 on 'mask.npz'.
# NOTE(review): this saves to 'output_1.npz', the same path the earlier attack_mask call
# used, so that file would be overwritten -- confirm this is intended.
attack_mask_2('mask.npz', 'output_1.npz')

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 347MB/s]


s1-step: 0, average adv loss: 13.2232, average flow loss: 0.1975, succeed: 0.00


KeyboardInterrupt: ignored