In [None]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import glob
from PIL import Image
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import random
from shutil import copy2
import torch.nn.functional as F
import torch.autograd as autograd

import cv2


from scipy.spatial.distance import cdist

import math
from torch.utils import model_zoo

from copy import deepcopy
import re
from collections import OrderedDict
from torch.autograd import Function


import cv2
from skimage import exposure
from skimage import filters
import pandas as pd
import matplotlib.pyplot as plt

from torch.nn.init import normal_

In [None]:
# torch.cuda.set_device(0)
# torch.cuda.current_device()

In [None]:
# Default per-channel statistics used by imagenet_normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

def imagenet_normalize(t, mean=None, std=None):
    """Standardize the first three channels of ``t``.

    Channel ``c`` (indexed on dim 1, e.g. of an (N, C, H, W) batch) becomes
    ``(t[:, c] - mean[c]) / std[c]``; the three results are stacked back along
    dim 1, so any channels beyond the third are dropped.

    Args:
        t: tensor with at least 3 entries along dim 1.
        mean: three per-channel means; defaults to IMAGENET_MEAN.
        std: three per-channel standard deviations; defaults to IMAGENET_STD.

    Returns:
        Tensor with exactly 3 entries along dim 1.
    """
    mean = IMAGENET_MEAN if mean is None else mean
    std = IMAGENET_STD if std is None else std
    standardized = [((t[:, c] - mean[c]) / std[c]).unsqueeze(1) for c in range(3)]
    return torch.cat(standardized, dim=1)

# Resize (an int size matches the shorter edge to 224), center-crop to 224x224,
# then convert to a tensor. No mean/std normalization happens here; see
# imagenet_normalize for that step.
preprocess = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
])

def freeze_model(model):
    """Turn off gradient tracking for every parameter of ``model``, in place.

    Returns None; the module itself is mutated.
    """
    for weight in model.parameters():
        weight.requires_grad = False

In [None]:
# ImageNet-pretrained checkpoints from the PyTorch model zoo, keyed by the
# architecture name that the factory functions in this cell look up.
# densenet121 / densenet161 were previously missing, so their
# ``pretrained=True`` path raised KeyError.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
}


#
# AlexNet | begin
#

# (our AlexNet layer name, index-based name of the same layer inside
# torchvision's ``features`` / ``classifier`` Sequentials). Indices skip the
# parameter-free ReLU / MaxPool / Dropout entries of those Sequentials.
_ALEXNET_LAYER_PAIRS = (
    ("conv1", "features.0"),
    ("conv2", "features.3"),
    ("conv3", "features.6"),
    ("conv4", "features.8"),
    ("conv5", "features.10"),
    ("fc1", "classifier.1"),
    ("fc2", "classifier.4"),
    ("fc3", "classifier.6"),
)

# Our state-dict key -> torchvision checkpoint key, for weight and bias of each layer.
ALEXNET_NAME_MAP = {
    "%s.%s" % (ours, suffix): "%s.%s" % (theirs, suffix)
    for ours, theirs in _ALEXNET_LAYER_PAIRS
    for suffix in ("weight", "bias")
}


class AlexNet(nn.Module):
    """AlexNet ("One weird trick" variant) with individually named layers.

    ``forward`` can return any intermediate activation by key. Parameter names
    differ from torchvision's index-based ``features`` / ``classifier`` layout;
    ``ALEXNET_NAME_MAP`` / ``convert_alexnet_weights`` translate between them.

    Args:
        num_classes: output size of the final linear layer ``fc3``.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        # convolutional layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv3 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv5 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # pooling layers
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        self.pool5 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        # fully connected layers; fc1 expects the flattened 256 x 6 x 6 pool5 map,
        # which is what a 224 x 224 input produces.
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x, out_keys=None):
        """Run the network on a batch ``x`` (dim 0 is the batch dimension).

        Args:
            x: input batch.
            out_keys: optional iterable of activation names among 'c1', 'r1',
                'p1', 'r2', 'p2', 'r3', 'r4', 'r5', 'p5', 'fc1', 'fc2', 'fc3'.

        Returns:
            The ``fc3`` logits if ``out_keys`` is None, otherwise a dict mapping
            each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown name.
        """
        out = {}
        out['c1'] = self.conv1(x)
        out['r1'] = F.relu(out['c1'])
        out['p1'] = self.pool1(out['r1'])
        out['r2'] = F.relu(self.conv2(out['p1']))
        out['p2'] = self.pool2(out['r2'])
        out['r3'] = F.relu(self.conv3(out['p2']))
        out['r4'] = F.relu(self.conv4(out['r3']))
        out['r5'] = F.relu(self.conv5(out['r4']))
        out['p5'] = self.pool5(out['r5'])
        # Flatten per sample. ``view(1, -1)`` would merge the whole batch into a
        # single row, making fc1 fail for any batch larger than 1.
        out['fc1'] = F.relu(self.fc1(out['p5'].view(out['p5'].size(0), -1)))
        out['fc2'] = F.relu(self.fc2(out['fc1']))
        out['fc3'] = self.fc3(out['fc2'])

        if out_keys is None:
            return out['fc3']

        return {key: out[key] for key in out_keys}


def convert_alexnet_weights(src_state, dest_state):
    """Copy torchvision-layout AlexNet weights into an ``AlexNet`` state dict.

    Every key of ``dest_state`` that appears in ALEXNET_NAME_MAP gets a deep
    copy of ``src_state[ALEXNET_NAME_MAP[key]]``; other keys are left as-is.
    ``dest_state`` is modified in place and also returned.

    Raises:
        KeyError: if a mapped source key is missing from ``src_state``.
    """
    for dest_key in list(dest_state):
        src_key = ALEXNET_NAME_MAP.get(dest_key)
        if src_key is not None:
            dest_state[dest_key] = deepcopy(src_state[src_key])
    return dest_state


def alexnet(pretrained=False, **kwargs):
    r"""AlexNet model architecture from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, load the torchvision ImageNet checkpoint
            (remapped to this class's layer names via convert_alexnet_weights).
        **kwargs: forwarded to ``AlexNet`` (e.g. ``num_classes``).
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net
    zoo_state = model_zoo.load_url(model_urls['alexnet'])
    net.load_state_dict(convert_alexnet_weights(zoo_state, net.state_dict()))
    return net

#
# AlexNet | end
#

#
# ResNet | begin
#


def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution with padding 1 (keeps H/W when stride == 1)."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions (ResNet-18/34).

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``, where the
    shortcut is ``x`` itself or ``downsample(x)`` when a downsample module is given.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Both convolutions are bias-free 3x3 convs with padding 1 (same as conv3x3).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """ResNet-50/101/152 residual block.

    1x1 conv (reduce to ``planes``) -> 3x3 conv (carries the stride) -> 1x1 conv
    (expand to ``planes * 4``), each followed by BatchNorm, with ReLU after the
    first two. The result is added to ``x`` (or ``downsample(x)``) and passed
    through a final ReLU.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        expanded = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """ResNet backbone with torchvision-compatible parameter names whose
    ``forward`` can also return named intermediate activations.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``); it
            must define ``expansion`` and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final ``fc`` layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64  # running input-channel count, updated by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # Global average pool to 1x1. A fixed AvgPool2d(7, stride=1) only worked
        # when layer4 produced exactly 7x7 maps (224x224 input); adaptive pooling
        # gives the same result there and also supports other input sizes.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He-normal init for convolutions; BatchNorm affine weight=1, bias=0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and, when the stride or channel
        count changes, a 1x1 conv + BatchNorm projection shortcut. Updates
        ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x, out_keys=None):
        """Run the network.

        Args:
            x: input batch (dim 0 is the batch dimension).
            out_keys: optional iterable of activation names among 'c1', 'bn1',
                'r1', 'p1', 'l1', 'l2', 'l3', 'l4', 'gvp' (pooled, N x C x 1 x 1)
                and 'fc'.

        Returns:
            The ``fc`` logits if ``out_keys`` is None, otherwise a dict mapping
            each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown name.
        """
        out = {}
        x = self.conv1(x)
        out["c1"] = x
        x = self.bn1(x)
        out["bn1"] = x
        x = self.relu(x)
        out["r1"] = x
        x = self.maxpool(x)
        out["p1"] = x

        x = self.layer1(x)
        out["l1"] = x
        x = self.layer2(x)
        out["l2"] = x
        x = self.layer3(x)
        out["l3"] = x
        x = self.layer4(x)
        out["l4"] = x

        x = self.avgpool(x)
        out["gvp"] = x
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        out["fc"] = x

        if out_keys is None:
            return x

        return {key: out[key] for key in out_keys}


def _make_resnet(arch, block, layers, pretrained, **kwargs):
    """Build ``ResNet(block, layers, **kwargs)``; if ``pretrained``, load the
    ImageNet checkpoint at ``model_urls[arch]``."""
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model (BasicBlock, stages [2, 2, 2, 2]).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _make_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model (BasicBlock, stages [3, 4, 6, 3]).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _make_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model (Bottleneck, stages [3, 4, 6, 3]).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _make_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model (Bottleneck, stages [3, 4, 23, 3]).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _make_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model (Bottleneck, stages [3, 8, 36, 3]).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _make_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, **kwargs)


# ResNet | end


# DenseNet | begin

# Checkpoints in model_urls were saved when _DenseLayer used dotted module names
# ('norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'). Dots are no
# longer allowed in module names, so such keys are rewritten, e.g.
# '...denselayer1.norm.1.weight' -> '...denselayer1.norm1.weight'.
_DENSENET_LEGACY_KEY_PATTERN = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')


def _load_densenet_pretrained(model, arch):
    """Load the ImageNet checkpoint ``model_urls[arch]`` into ``model``.

    Legacy dotted keys are renamed before ``load_state_dict``. Returns ``model``.

    Raises:
        KeyError: if ``model_urls`` has no entry for ``arch``.
    """
    state_dict = model_zoo.load_url(model_urls[arch])
    for key in list(state_dict.keys()):
        res = _DENSENET_LEGACY_KEY_PATTERN.match(key)
        if res:
            state_dict[res.group(1) + res.group(2)] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    return model


def densenet121(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``DenseNet`` (e.g. ``num_classes``, ``drop_rate``).
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     **kwargs)
    if pretrained:
        _load_densenet_pretrained(model, 'densenet121')
    return model


def densenet169_(pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``DenseNet`` (e.g. ``num_classes``, ``drop_rate``).
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                     **kwargs)
    if pretrained:
        _load_densenet_pretrained(model, 'densenet169')
    return model


def densenet201(pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``DenseNet`` (e.g. ``num_classes``, ``drop_rate``).
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                     **kwargs)
    if pretrained:
        _load_densenet_pretrained(model, 'densenet201')
    return model


def densenet161(pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        **kwargs: forwarded to ``DenseNet`` (e.g. ``num_classes``, ``drop_rate``).
    """
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                     **kwargs)
    if pretrained:
        _load_densenet_pretrained(model, 'densenet161')
    return model


class _DenseLayer(nn.Sequential):
    """One DenseNet-BC layer.

    BN -> ReLU -> 1x1 conv (to ``bn_size * growth_rate`` channels) -> BN ->
    ReLU -> 3x3 conv (to ``growth_rate`` channels), optionally followed by
    dropout; the new maps are concatenated onto the input along dim 1.
    """

    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        bottleneck_width = bn_size * growth_rate
        named_layers = (
            ('norm1', nn.BatchNorm2d(num_input_features)),
            ('relu1', nn.ReLU(inplace=True)),
            ('conv1', nn.Conv2d(num_input_features, bottleneck_width,
                                kernel_size=1, stride=1, bias=False)),
            ('norm2', nn.BatchNorm2d(bottleneck_width)),
            ('relu2', nn.ReLU(inplace=True)),
            ('conv2', nn.Conv2d(bottleneck_width, growth_rate,
                                kernel_size=3, stride=1, padding=1, bias=False)),
        )
        for name, layer in named_layers:
            self.add_module(name, layer)
        self.drop_rate = drop_rate

    def forward(self, x):
        fresh = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            fresh = F.dropout(fresh, p=self.drop_rate, training=self.training)
        return torch.cat([x, fresh], 1)


class _DenseBlock(nn.Sequential):
    """Stack of ``num_layers`` _DenseLayer modules named denselayer1..N.

    Layer i (0-based) receives ``num_input_features + i * growth_rate`` channels,
    since each layer concatenates ``growth_rate`` new maps onto its input.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        in_feats = num_input_features
        for idx in range(1, num_layers + 1):
            self.add_module('denselayer%d' % idx,
                            _DenseLayer(in_feats, growth_rate, bn_size, drop_rate))
            in_feats += growth_rate


class _Transition(nn.Sequential):
    """DenseNet transition: BN -> ReLU -> 1x1 conv (channel reduction) ->
    2x2 average pool with stride 2 (halves H and W)."""

    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__(OrderedDict([
            ('norm', nn.BatchNorm2d(num_input_features)),
            ('relu', nn.ReLU(inplace=True)),
            ('conv', nn.Conv2d(num_input_features, num_output_features,
                               kernel_size=1, stride=1, bias=False)),
            ('pool', nn.AvgPool2d(kernel_size=2, stride=2)),
        ]))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock; a transition (halving channels and resolution)
        # follows every block except the last.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo. ``kaiming_normal_`` replaces the
        # deprecated non-underscore alias (same distribution).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, out_keys=None):
        """Run the network.

        Args:
            x: input batch (dim 0 is the batch dimension).
            out_keys: optional iterable of names among 'l' (ReLU'd final
                feature maps), 'gvp' (globally pooled, N x C x 1 x 1) and 'fc'.

        Returns:
            The classifier logits if ``out_keys`` is None, otherwise a dict
            mapping each requested key to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown name.
        """
        out_dict = {}
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out_dict['l'] = out
        # Global average pool. A fixed 7x7 window only fit 224x224 inputs;
        # adaptive pooling gives the same result there and works for other sizes.
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out_dict['gvp'] = out
        out = out.view(features.size(0), -1)
        out = self.classifier(out)
        out_dict['fc'] = out
        if out_keys is None:
            return out

        return {key: out_dict[key] for key in out_keys}

# DenseNet | end


def get_gaussian_blur_kernel(ksize, sigma):
    """Return a 2-D Gaussian kernel as a float32 tensor of shape (1, 1, ksize, ksize).

    Built as the outer product of OpenCV's 1-D Gaussian kernel with itself.
    """
    g1d = cv2.getGaussianKernel(ksize, sigma).astype(np.float32)
    g2d = np.outer(g1d, g1d)
    return torch.tensor(g2d[np.newaxis, np.newaxis])


def gaussian_blur(x, ksize, sigma):
    """Blur every channel of ``x`` with a ``ksize`` x ``ksize`` Gaussian kernel.

    The input is reflection-padded by ``(ksize - 1) // 2`` on each side, so an
    odd ``ksize`` preserves H and W.

    NOTE(review): a later cell redefines ``gaussian_blur`` (separable version,
    signature ``(_images, kernel_size=55, sigma=11)``); once that cell runs,
    this definition is shadowed.

    Args:
        x: torch.tensor (n, c, h, w).
        ksize: int kernel size (odd to preserve size).
        sigma: Gaussian standard deviation passed to ``cv2.getGaussianKernel``.

    Returns:
        Blurred tensor with the same channel count, device and dtype as ``x``.
    """
    psize = (ksize - 1) // 2
    channels = x.size(1)
    # Put the kernel on x's device/dtype, because a CPU float32 kernel fails for
    # CUDA or non-float32 input. Replicate it per channel for a depthwise conv,
    # so all c channels are blurred rather than only the first three.
    blur_kernel = get_gaussian_blur_kernel(ksize, sigma).to(device=x.device, dtype=x.dtype)
    blur_kernel = blur_kernel.repeat(channels, 1, 1, 1)
    x_padded = F.pad(x, [psize] * 4, mode="reflect")
    return F.conv2d(x_padded, blur_kernel, groups=channels)


class GuidedBackpropReLU(Function):
    """ReLU for guided backpropagation.

    Forward: ``input * (input > 0)``. Backward: the upstream gradient is kept
    only where both the forward input and the gradient itself are positive.
    """

    @staticmethod
    def forward(ctx, input):
        keep = input.gt(0).type_as(input)
        output = torch.addcmul(torch.zeros_like(input), input, keep)
        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, _ = ctx.saved_tensors
        forward_mask = input.gt(0).type_as(grad_output)
        gradient_mask = grad_output.gt(0).type_as(grad_output)
        masked = torch.addcmul(torch.zeros_like(input), grad_output, forward_mask)
        return torch.addcmul(torch.zeros_like(input), masked, gradient_mask)


def freeze_model(model):
    """Mark every parameter of ``model`` as not requiring gradients (in place).

    NOTE(review): identical to the freeze_model defined earlier in this same
    cell; this later definition is the one that stays bound.
    """
    params = list(model.parameters())
    for p in params:
        p.requires_grad_(False)


### SoftReLU


class SoftReLU(nn.Module):
    """Module wrapper around the custom autograd function ``SoftReLUFunc``.

    ``eps`` is stored on the module but not used by the current forward.
    """

    def __init__(self, eps=1e-6):
        super(SoftReLU, self).__init__()
        self.eps = eps

    def forward(self, x):
        return SoftReLUFunc.apply(x)


class SoftReLUFunc(autograd.Function):
    """Exact ReLU in the forward pass with a smoothed surrogate gradient.

    Backward scales the incoming gradient by ``r = x / sqrt(x**2 + 1e-4)`` where
    ``x >= 0`` and by ``r + 1`` where ``x < 0``.
    NOTE(review): this surrogate is ~0 just right of zero and ~1 just left of
    zero, so it is not the ReLU derivative; confirm this shape is intended.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        negative = x < 0
        # Every element must be either < 0 or >= 0; this rejects NaN inputs.
        assert int(negative.sum()) + int((x >= 0).sum()) == x.numel()
        ratio = x / torch.sqrt(x * x + 1e-4)
        return grad_output * torch.where(negative, ratio + 1, ratio)


In [None]:
class ResNetEncoder(torchvision.models.ResNet):
    """torchvision ResNet whose forward exposes every stage's activation.

    Returns the 8-tuple ``(s0, s1, s2, s3, s4, s5, sX, sC)``:
    s0 is the input itself; s1 is the stem output after conv1/bn1/relu (before
    max-pooling); s2..s5 are the outputs of layer1..layer4; sX is the
    average-pooled features flattened to (N, -1); sC is the ``fc`` output.
    """

    def forward(self, x):
        stem = self.relu(self.bn1(self.conv1(x)))
        h = self.maxpool(stem)
        stages = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            h = stage(h)
            stages.append(h)
        pooled = self.avgpool(h)
        flat = pooled.view(pooled.size(0), -1)
        return (x, stem) + tuple(stages) + (flat, self.fc(flat))


def resnet50encoder(pretrained=False, **kwargs):
    """Constructs a ResNet-50 (Bottleneck, stages [3, 4, 6, 3]) as a ResNetEncoder.

    Args:
        pretrained (bool): If True, load the ImageNet ResNet-50 checkpoint.
        **kwargs: forwarded to ``ResNetEncoder`` (torchvision ResNet options).
    """
    encoder = ResNetEncoder(torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        weights_url = 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
        encoder.load_state_dict(model_zoo.load_url(weights_url))
    return encoder


class Bottleneck_(nn.Module):
    """Residual bottleneck block with a configurable activation.

    1x1 conv to ``out_channels // bottleneck_ratio`` -> 3x3 conv (carries the
    stride) -> 1x1 conv back to ``out_channels``, each followed by BatchNorm.
    The shortcut is the input itself, or a biased 1x1 strided projection when
    the stride or channel count changes. The activation is applied after the
    first two convs and after the residual addition.
    """

    def __init__(self, in_channels, out_channels, stride=1, bottleneck_ratio=4,
                 activation_fn=lambda: torch.nn.ReLU(inplace=False)):
        super(Bottleneck_, self).__init__()
        narrow = out_channels // bottleneck_ratio
        self.conv1 = nn.Conv2d(in_channels, narrow, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(narrow)
        self.conv2 = nn.Conv2d(narrow, narrow, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(narrow)
        self.conv3 = nn.Conv2d(narrow, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.activation_fn = activation_fn()

        needs_projection = stride != 1 or in_channels != out_channels
        self.residual_transformer = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=True)
            if needs_projection else None
        )

    def forward(self, x):
        act = self.activation_fn
        h = act(self.bn1(self.conv1(x)))
        h = act(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        shortcut = x if self.residual_transformer is None else self.residual_transformer(x)
        h += shortcut
        return act(h)


def simple_cnn_block(in_channels, out_channels,
                     kernel_size=3, layers=1, stride=1,
                     follow_with_bn=True, activation_fn=lambda: nn.ReLU(True), affine=True):
    """Stack ``layers`` x [Conv2d -> optional BatchNorm2d -> optional activation].

    Only the first conv uses ``stride``; padding is ``kernel_size // 2`` (so an
    odd kernel keeps H/W at stride 1). Convs carry a bias only when no
    BatchNorm follows them. ``activation_fn=None`` omits the activation.

    Returns:
        nn.Sequential of the created modules.
    """
    assert layers > 0 and kernel_size % 2 > 0 and stride > 0
    stack = []
    channels_in = in_channels
    for idx in range(layers):
        stack.append(nn.Conv2d(channels_in, out_channels, kernel_size=kernel_size,
                               stride=(stride if idx == 0 else 1),
                               padding=kernel_size // 2, bias=not follow_with_bn))
        if follow_with_bn:
            stack.append(nn.BatchNorm2d(out_channels, affine=affine))
        if activation_fn is not None:
            stack.append(activation_fn())
        channels_in = out_channels
    return nn.Sequential(*stack)


def bottleneck_block(in_channels, out_channels, stride=1, layers=1,
                     activation_fn=lambda: torch.nn.ReLU(inplace=False)):
    """Chain ``layers`` Bottleneck_ blocks.

    Only the first block uses ``stride`` and maps ``in_channels`` to
    ``out_channels``; the rest are ``out_channels`` -> ``out_channels``.

    Returns:
        The single block itself when ``layers == 1``, else an nn.Sequential.
    """
    assert layers > 0 and stride > 0
    blocks = [Bottleneck_(in_channels, out_channels, stride=stride,
                          activation_fn=activation_fn)]
    blocks.extend(Bottleneck_(out_channels, out_channels, stride=1,
                              activation_fn=activation_fn)
                  for _ in range(layers - 1))
    if len(blocks) == 1:
        return blocks[0]
    return nn.Sequential(*blocks)


class PixelShuffleBlock(nn.Module):
    """Parameter-free pixel shuffle with upscale factor 2:
    (N, 4C, H, W) -> (N, C, 2H, 2W)."""

    def forward(self, x):
        return F.pixel_shuffle(x, upscale_factor=2)


def simple_upsampler_subpixel(in_channels, out_channels, kernel_size=3,
                              activation_fn=lambda: torch.nn.ReLU(inplace=True),
                              follow_with_bn=True):
    """2x sub-pixel upsampler.

    Conv block to ``out_channels * 4`` channels, then pixel shuffle (factor 2,
    giving ``out_channels`` channels), then ``activation_fn()``.
    Note: the inner simple_cnn_block uses its own default activation;
    ``activation_fn`` is applied only after the shuffle.
    """
    conv = simple_cnn_block(in_channels, out_channels * 4, kernel_size=kernel_size,
                            follow_with_bn=follow_with_bn)
    return nn.Sequential(conv, PixelShuffleBlock(), activation_fn())


class UNetUpsampler(nn.Module):
    """One U-Net decoder stage.

    Upsamples ``inp`` with ``upsampler_block`` (by default the 2x sub-pixel
    upsampler) to ``out_channels``, concatenates the ``passthrough`` skip
    features along dim 1, and refines the result with
    ``follow_up_residual_blocks`` Bottleneck_ blocks back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, passthrough_channels, follow_up_residual_blocks=1,
                 upsampler_block=simple_upsampler_subpixel, upsampler_kernel_size=3,
                 activation_fn=lambda: torch.nn.ReLU(inplace=False)):
        super(UNetUpsampler, self).__init__()
        assert follow_up_residual_blocks >= 1
        assert passthrough_channels >= 1
        self.upsampler = upsampler_block(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=upsampler_kernel_size,
                                         activation_fn=activation_fn)
        merged_channels = out_channels + passthrough_channels
        self.follow_up = bottleneck_block(merged_channels, out_channels,
                                          layers=follow_up_residual_blocks,
                                          activation_fn=activation_fn)

    def forward(self, inp, passthrough):
        merged = torch.cat([self.upsampler(inp), passthrough], 1)
        return self.follow_up(merged)


class RTSaliencyModel(nn.Module):
    """Real-time saliency model: encoder features are decoded by a U-Net style
    stack of UNetUpsampler stages into a single-channel mask.

    Args:
        encoder: module whose output is indexable by scale (``out[k]``); its last
            element is returned unchanged as the third output of ``forward``.
        encoder_scales: index of the deepest encoder output decoding starts from.
        encoder_base: channel multiplier; the upsamplers are sized assuming
            encoder output ``k`` has ``encoder_base * 2**k`` channels.
        upsampler_scales: number of UNetUpsampler stages (``up0`` .. ``up{n-1}``).
        upsampler_base: decoder channel multiplier; also the channel count fed
            into the final 1x1 ``to_saliency_chans`` conv.
        fix_encoder: if True, encoder outputs are detached and encoder (and
            selector) parameters are excluded from ``get_trainable_parameters``.
        use_simple_activation: if True, the mask is ``sigmoid(chan0 / 2)``
            instead of ``|chan0| / (|chan0| + |chan1|)``.
        allow_selector: if True, a class-embedding selector gates the deepest features.
        num_classes: number of rows in the selector embedding.
    """

    def __init__(self, encoder, encoder_scales, encoder_base, upsampler_scales,
                 upsampler_base, fix_encoder=True, use_simple_activation=False,
                 allow_selector=False, num_classes=1000):
        super(RTSaliencyModel, self).__init__()

        self.encoder = encoder
        self.upsampler_scales = upsampler_scales
        self.encoder_scales = encoder_scales
        self.fix_encoder = fix_encoder
        self.use_simple_activation = use_simple_activation

        down = self.encoder_scales
        modulator_size = []
        for up in reversed(range(self.upsampler_scales)):
            upsampler_chans = upsampler_base * 2 ** (up + 1)
            encoder_chans = encoder_base * 2 ** down
            # The first (deepest) stage consumes raw encoder features; later
            # stages consume the previous upsampler's output.
            inc = upsampler_chans if down != encoder_scales else encoder_chans
            modulator_size.append(inc)
            self.add_module("up%d" % up,
                            UNetUpsampler(
                                in_channels=inc,
                                passthrough_channels=encoder_chans // 2,
                                out_channels=upsampler_chans // 2,
                                follow_up_residual_blocks=1,
                                activation_fn=lambda: nn.ReLU(),
                            ))

            down -= 1

        self.to_saliency_chans = nn.Conv2d(upsampler_base, 2, 1)

        self.allow_selector = allow_selector

        if self.allow_selector:
            s = encoder_base * 2 ** encoder_scales
            self.selector_module = nn.Embedding(num_classes, s)
            normal_(self.selector_module.weight, 0, 1. / s ** 0.5)

    def get_trainable_parameters(self):
        """Return the set of parameters to optimize.

        All parameters when the encoder is not fixed; otherwise everything except
        the encoder's parameters and (if present) the selector embedding.
        NOTE(review): excluding the selector when the encoder is fixed looks
        deliberate (e.g. trained separately) but confirm.
        """
        if not self.fix_encoder:
            return set(self.parameters())
        trainable = set(self.parameters()) - set(self.encoder.parameters())
        if self.allow_selector:
            trainable -= set(self.selector_module.parameters())
        return trainable

    def forward(self, _images, _selectors=None, pt_store=None, model_confidence=0.):
        """Predict saliency masks for a batch.

        Args:
            _images: batch passed to the encoder.
            _selectors: class indices for the selector embedding (required when
                ``allow_selector``); reshaped to (-1, 1) before lookup.
            pt_store: unused.
            model_confidence: offset subtracted from the selector activation
                before the sigmoid gate.

        Returns:
            ``(mask, exists_logits, out[-1])``: ``mask`` is (N, 1, H, W) at the
            decoder's output resolution, ``exists_logits`` is (N, 2) or None
            without a selector, and ``out[-1]`` is the encoder's last output.
        """
        out = self.encoder(_images)
        if self.fix_encoder:
            out = [e.detach() for e in out]

        down = self.encoder_scales
        main_flow = out[down]

        if self.allow_selector:
            assert _selectors is not None
            em = torch.squeeze(self.selector_module(_selectors.view(-1, 1)), 1)
            # One embedding vector per sample, broadcast over H and W. Using the
            # embedding's own width (encoder_base * 2**encoder_scales) instead of
            # a hard-coded 2048 keeps this working for non-ResNet-50 encoders.
            em = em.view(em.size(0), em.size(1), 1, 1)
            act = torch.sum(main_flow * em, 1, keepdim=True)
            th = torch.sigmoid(act - model_confidence)
            main_flow = main_flow * th

            ex = torch.mean(torch.mean(act, 3), 2)
            exists_logits = torch.cat((-ex / 2., ex / 2.), 1)
        else:
            exists_logits = None

        for up in reversed(range(self.upsampler_scales)):
            assert down > 0
            main_flow = self._modules['up%d' % up](main_flow, out[down - 1])
            down -= 1
        saliency_chans = self.to_saliency_chans(main_flow)

        if self.use_simple_activation:
            return torch.unsqueeze(torch.sigmoid(saliency_chans[:, 0, :, :] / 2), dim=1), exists_logits, out[-1]

        a = torch.abs(saliency_chans[:, 0, :, :])
        b = torch.abs(saliency_chans[:, 1, :, :])
        return torch.unsqueeze(a / (a + b), dim=1), exists_logits, out[-1]

    def minimialistic_restore(self, save_dir):
        """Load the non-encoder submodule weights written by ``minimalistic_save``.

        Raises:
            FileNotFoundError: if ``save_dir/model-1.ckpt`` does not exist.
        """
        p = os.path.join(save_dir, 'model-%d.ckpt' % 1)
        if not os.path.exists(p):
            raise FileNotFoundError('Could not find any checkpoint at %s, skipping restore' % p)
        for name, data in torch.load(p, map_location=lambda storage, loc: storage).items():
            self._modules[name].load_state_dict(data)

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    minimalistic_restore = minimialistic_restore

    def minimalistic_save(self, save_dir):
        """Save the state dicts of every submodule except the encoder to
        ``save_dir/model-1.ckpt``, creating ``save_dir`` (and parents) if needed."""
        assert self.fix_encoder, 'You should not use this function if you are not using a pre-trained encoder like resnet'
        data = {}
        for name, module in self._modules.items():
            if module is self.encoder:  # we do not want to restore the encoder as it should have its own restore function
                continue
            data[name] = module.state_dict()
        # makedirs also handles nested paths, where os.mkdir would fail.
        os.makedirs(save_dir, exist_ok=True)
        torch.save(data, os.path.join(save_dir, 'model-%d.ckpt' % 1))


def _gaussian_kernels(kernel_size, sigma, chans):
    assert kernel_size % 2, 'Kernel size of the gaussian blur must be odd!'
    x = np.expand_dims(np.array(range(-kernel_size // 2, -kernel_size // 2 + kernel_size, 1)), 0)
    vals = np.exp(-np.square(x) / (2.*sigma**2))
    _kernel = np.reshape(vals / np.sum(vals), (1, 1, kernel_size, 1))
    kernel = np.zeros((chans, 1, kernel_size, 1), dtype=np.float32) + _kernel
    return kernel, np.transpose(kernel, [0, 1, 3, 2])


def gaussian_blur(_images, kernel_size=55, sigma=11):
    ''' Very fast, linear time gaussian blur, using separable convolution. Operates on batch of images [N, C, H, W].
    Returns blurred images of the same size. Kernel size must be odd.
    Increasing kernel size over 4*sigma yields little improvement in quality. So kernel size = 4*sigma is a good choice.

    The kernels are created on the same device and with the same dtype as `_images`, so this works
    for any device (not only CPU / the current CUDA device) and for non-float32 inputs.
    '''
    chans = _images.size(1)
    kernel_a, kernel_b = _gaussian_kernels(kernel_size=kernel_size, sigma=sigma, chans=chans)
    kernel_a = torch.as_tensor(kernel_a, dtype=_images.dtype, device=_images.device)
    kernel_b = torch.as_tensor(kernel_b, dtype=_images.dtype, device=_images.device)
    half = kernel_size // 2
    # First blur along H with the (k, 1) kernel, then along W with the (1, k) kernel; one group per channel.
    blurred_h = F.conv2d(_images, kernel_a, groups=chans, padding=(half, 0))
    return F.conv2d(blurred_h, kernel_b, groups=chans, padding=(0, half))


def apply_mask(images, mask, noise=True, random_colors=True, blurred_version_prob=0.5, noise_std=0.11,
               color_range=0.66, blur_kernel_size=55, blur_sigma=11,
               bypass=0., boolean=False, preserved_imgs_noise_std=0.03):
    '''Blends `images` with an alternative "background" according to `mask`.

    Returns `mask * images + (1 - mask) * alt`, where `alt` is built from:
      * gaussian noise with std `noise_std` (if `noise` and `noise_std`), otherwise zeros;
      * plus one uniform random offset in [-color_range/2, color_range/2] per (image, channel)
        (if `random_colors`);
      * replaced, per image with probability `blurred_version_prob`, by a gaussian-blurred copy of
        the images.

    Args:
        images: batch of shape (N, C, H, W); it is cloned, not modified.
        mask: tensor broadcastable against `images` (values presumably in [0, 1]).
        bypass: if > 0, mask is remapped to (1 - bypass) * mask + bypass; must be in [0, 0.9).
        boolean: validation mode - returns (mask > 0.5) * images and skips everything else.
        preserved_imgs_noise_std: std of gaussian noise added to the images before blending.
    Returns:
        Tensor shaped like `images`. In the non-boolean path, images and alt are detached, so
        gradients reach only `mask`.
    '''
    images = images.clone()

    if boolean:
        # remember its just for validation!
        return (mask > 0.5).float() * images

    assert 0. <= bypass < 0.9
    n, c, _, _ = images.size()
    if preserved_imgs_noise_std > 0:
        images = images + torch.empty_like(images).normal_(std=preserved_imgs_noise_std)
    if bypass > 0:
        mask = (1. - bypass) * mask + bypass
    if noise and noise_std:
        alt = torch.empty_like(images).normal_(std=noise_std)
    else:
        alt = torch.zeros_like(images)
    if random_colors:
        # new_empty follows the images' device and dtype instead of a hard-coded .cuda().
        alt += images.new_empty((n, c, 1, 1)).uniform_(-color_range / 2., color_range / 2.)

    if blurred_version_prob > 0.:  # <- it can be a scalar between 0 and 1
        cand = gaussian_blur(images, kernel_size=blur_kernel_size, sigma=blur_sigma)
        # Per-image 0/1 switch choosing the blurred candidate.
        when = (images.new_empty((n, 1, 1, 1)).uniform_(0., 1.) < blurred_version_prob).float()
        alt = alt * (1. - when) + cand * when

    return (mask * images.detach()) + (1. - mask) * alt.detach()


def calc_smoothness_loss(mask, power=2, border_penalty=0.3):
    ''' Total-variation-like penalty on a batch of masks (N, C, H, W), optionally penalising non-zero borders.

        With power=2 the value is roughly invariant to image resizing: if L is the side length,
        the total edge length grows like L while each edge's sharpness shrinks like 1/L
        (picture a single vertical edge across the whole image).
        The result is divided by `power` and by the batch size.'''
    def _edge_energy(delta):
        return torch.sum((torch.abs(delta)) ** power)

    vertical = _edge_energy(mask[:, :, 1:, :] - mask[:, :, :-1, :])
    horizontal = _edge_energy(mask[:, :, :, 1:] - mask[:, :, :, :-1])

    border = 0.
    if border_penalty > 0:
        rims = (mask[:, :, -1, :] ** power +
                mask[:, :, 0, :] ** power +
                mask[:, :, :, -1] ** power + mask[:, :, :, 0] ** power)
        border = float(border_penalty) * torch.sum(rims)

    normaliser = float(power * mask.size(0))  # watch out, normalised by the batch size!
    return (vertical + horizontal + border) / normaliser


def calc_area_loss(mask, power=1.):
    '''Mean of the mask (or of the slightly offset mask raised to `power`); penalises large masks.'''
    if power == 1:
        return torch.mean(mask)
    # Offset keeps the gradient finite at 0 for fractional powers (derivative of x**p at 0 is inf for p < 1).
    return torch.mean((mask + 0.0005) ** power)


def cw_loss(logits, one_hot_labels, targeted=True, t_conf=2, nt_conf=5):
    ''' Carlini-Wagner style margin loss between the selected label and the best competing label.

        Targeted: penalises until the selected logit beats the runner-up by at least `t_conf`
        (a margin of 3 means the label is about e^3 ~ 20 times more probable than the second guess).
        Non-targeted: penalises until the runner-up beats the selected logit by at least `nt_conf`.
        A small t_conf (~2) is usually enough; for non-targeted 5-6 works better
        (with 1000 classes, log(no_idea) = 6.9).
        `targeted` is either a python bool/int (same mode for the whole batch) or a per-example
        tensor where >0 selects the targeted term and ==0 the non-targeted one.
    '''
    target_logit = torch.sum(logits * one_hot_labels, 1)
    # The selected label is pushed down by 12111 so it can never be chosen as the runner-up.
    competitors = logits * (1. - one_hot_labels) - 12111 * one_hot_labels
    runner_up = torch.max(competitors, 1)[0]
    margin = runner_up - target_logit
    targeted_term = F.relu(margin + t_conf)
    untargeted_term = F.relu(nt_conf - margin)

    if not isinstance(targeted, (bool, int)):
        use_targeted = (targeted > 0).float()
        use_untargeted = (targeted == 0).float()
        return torch.mean(targeted_term * use_targeted + untargeted_term * use_untargeted)
    chosen = targeted_term if targeted else untargeted_term
    return torch.mean(chosen)


def one_hot(labels, depth):
    '''One-hot encodes integer class ids.

    Args:
        labels: 1-D tensor (any numeric dtype) of N class ids in [0, depth).
        depth: number of classes.
    Returns:
        (N, depth) tensor of the default float dtype, on the same device as `labels`.
    '''
    index = labels.long().view(-1, 1)
    # Allocating on labels.device (rather than .cuda()) keeps the result on the labels' own GPU.
    encoded = torch.zeros(labels.size(0), depth, device=labels.device)
    return encoded.scatter_(1, index, 1)


class SaliencyLoss:
    '''Objective for training a saliency (masking) model through a black-box classifier.

    total = destroyer_loss + area_loss_coef * area_loss
            + smoothness_loss_coef * smoothness_loss + preserver_loss_coef * preserver_loss

    The destroyer term is evaluated on images masked with (1 - mask), i.e. with the salient region
    removed. The preserver term is evaluated on images masked with mask, i.e. keeping only the salient
    region. Both use `cw_loss`. Any extra keyword arguments are forwarded to `apply_mask`.
    '''

    def __init__(self, black_box_fn, area_loss_coef=8, smoothness_loss_coef=0.5, preserver_loss_coef=0.3,
                 num_classes=1000, area_loss_power=0.3, preserver_confidence=1, destroyer_confidence=5, **apply_mask_kwargs):
        self.black_box_fn = black_box_fn
        self.area_loss_coef = area_loss_coef
        self.smoothness_loss_coef = smoothness_loss_coef
        self.preserver_loss_coef = preserver_loss_coef
        self.num_classes = num_classes
        self.area_loss_power = area_loss_power
        self.preserver_confidence = preserver_confidence
        self.destroyer_confidence = destroyer_confidence
        self.apply_mask_kwargs = apply_mask_kwargs

    def get_loss(self, _images, _targets, _masks, _is_real_target=None, pt_store=None):
        ''' masks must be already in the range 0,1 and of shape:  (B, 1, ?, ?)

        Masks whose spatial size differs from the images are bilinearly resized to the image size.
        `_is_real_target` (default: all ones) is a per-example flag. Where it is 1, the preserver
        term is targeted (confidence `preserver_confidence`) and the destroyer term is non-targeted
        (confidence `destroyer_confidence`). Where it is 0, the roles are swapped, each with confidence 1.
        If `pt_store` is given, it is called with keyword arguments for the masks, the masked images,
        each loss term and both sets of logits.
        Returns the scalar total loss.
        '''
        if _masks.size()[-2:] != _images.size()[-2:]:
            # F.upsample is deprecated; interpolate with align_corners=False matches its bilinear default.
            _masks = F.interpolate(_masks, (_images.size(2), _images.size(3)), mode='bilinear', align_corners=False)

        if _is_real_target is None:
            _is_real_target = torch.ones_like(_targets)
        destroyed_images = apply_mask(_images, 1. - _masks, **self.apply_mask_kwargs)
        destroyed_logits = self.black_box_fn(destroyed_images)

        preserved_images = apply_mask(_images, _masks, **self.apply_mask_kwargs)
        preserved_logits = self.black_box_fn(preserved_images)

        _one_hot_targets = one_hot(_targets, self.num_classes)
        preserver_loss = cw_loss(preserved_logits, _one_hot_targets, targeted=_is_real_target == 1, t_conf=self.preserver_confidence, nt_conf=1.)
        destroyer_loss = cw_loss(destroyed_logits, _one_hot_targets, targeted=_is_real_target == 0, t_conf=1., nt_conf=self.destroyer_confidence)
        area_loss = calc_area_loss(_masks, self.area_loss_power)
        smoothness_loss = calc_smoothness_loss(_masks)

        total_loss = destroyer_loss + self.area_loss_coef * area_loss + self.smoothness_loss_coef * smoothness_loss + self.preserver_loss_coef * preserver_loss

        if pt_store is not None:
            # add variables to the pt_store
            pt_store(masks=_masks)
            pt_store(destroyed=destroyed_images)
            pt_store(preserved=preserved_images)
            pt_store(area_loss=area_loss)
            pt_store(smoothness_loss=smoothness_loss)
            pt_store(destroyer_loss=destroyer_loss)
            pt_store(preserver_loss=preserver_loss)
            pt_store(preserved_logits=preserved_logits)
            pt_store(destroyed_logits=destroyed_logits)
        return total_loss


def to_batch_variable(x, required_rank, cuda=False):
    '''Converts `x` to a torch tensor of rank `required_rank` on the CPU or GPU.

    Accepts:
      * tensors - only moved between CPU and GPU as requested;
      * python or numpy scalars - wrapped into a 1-element tensor (`required_rank` must be 1);
      * lists/tuples and numpy arrays - an array whose rank is `required_rank - 1` gets a leading
        batch axis of size 1.

    Raises:
        ValueError: if an array's rank is neither `required_rank` nor `required_rank - 1`.
        TypeError: if `x` has an unsupported type.
    '''
    if isinstance(x, torch.Tensor):
        if cuda and not x.is_cuda:
            return x.cuda()
        if not cuda and x.is_cuda:
            return x.cpu()
        return x
    # np.generic covers numpy scalars such as np.int64 (e.g. a label read from an array),
    # which are not subclasses of python int and previously fell through to `None`.
    if isinstance(x, (float, int, np.generic)):
        assert required_rank == 1
        return to_batch_variable(np.array([x]), required_rank, cuda)
    if isinstance(x, (list, tuple)):
        return to_batch_variable(np.array(x), required_rank, cuda)
    if isinstance(x, np.ndarray):
        rank = x.ndim
        if rank == required_rank:
            return to_batch_variable(torch.from_numpy(x), required_rank, cuda)
        if rank + 1 == required_rank:
            return to_batch_variable(torch.unsqueeze(torch.from_numpy(x), dim=0), required_rank, cuda)
        raise ValueError('Expected an array of rank %d or %d, got rank %d'
                         % (required_rank, required_rank - 1, rank))
    raise TypeError('Unsupported input type: %s' % type(x).__name__)


def get_pretrained_saliency_fn(ckpt_dir, cuda=True, return_classification_logits=False):
    ''' Builds an RTS saliency model restored from `ckpt_dir` and returns the pair `(saliency_fn, logits_fn)`.
    If cuda=True the model is placed on a GPU.

    saliency_fn(images, selectors, model_confidence=6):
        images - input images of shape (C, H, W) or (N, C, H, W) if in batch; a numpy array or a Tensor.
        selectors - class ids to be masked; an int or an array with N integers (numpy array or Tensor).
        model_confidence - smaller values (~0) will show any object in the image that even slightly
            resembles the specified class, while higher values (~5) will show only the most salient
            parts. Decrease it to obtain more complete saliency maps.
        Returns a tensor of shape (N, 1, H, W) with one saliency map per input image, bilinearly resized
        to the input size. If return_classification_logits is True, it returns (saliency_maps, classification_logits).
    logits_fn(images):
        Returns the classification logits of the saliency model's encoder.

    Both functions multiply the images by 2 before the model. Inputs are presumably in [-1, 1], as
    produced by `read_image`; confirm for other sources.
    '''
    saliency = RTSaliencyModel(resnet50encoder(pretrained=True), 5, 64, 3, 64, fix_encoder=False, use_simple_activation=False, allow_selector=True)
    saliency.minimialistic_restore(ckpt_dir)
    saliency.train(False)
    if cuda:
        saliency = saliency.cuda()

    def saliency_fn(images, selectors, model_confidence=6):
        _images, _selectors = to_batch_variable(images, 4, cuda), to_batch_variable(selectors, 1, cuda).long()
        masks, _, cls_logits = saliency(_images * 2, _selectors, model_confidence=model_confidence)
        # F.upsample is deprecated; interpolate with align_corners=False matches its bilinear default.
        sal_map = F.interpolate(masks, (_images.size(2), _images.size(3)), mode='bilinear', align_corners=False)
        if not return_classification_logits:
            return sal_map
        return sal_map, cls_logits

    def logits_fn(images):
        _images = to_batch_variable(images, 4, cuda)
        logits = saliency.encoder(_images * 2)[-1]
        return logits

    return saliency_fn, logits_fn


def read_image(image_path):
    '''Loads an image file as a (1, 3, 224, 224) float32 RGB tensor scaled to [-1, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read `image_path` (missing file or unsupported format).
    '''
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError('Could not read image at %s' % image_path)
    img = cv2.resize(img, (224, 224))
    # OpenCV loads BGR in HWC layout: flip to RGB and move channels first.
    img = np.transpose(img[..., ::-1], [2, 0, 1])
    img = np.float32(img) / 255. * 2 - 1
    return to_batch_variable(img, 4)


In [None]:
from torchvision.models import densenet169


class RTSDensenet169(object):
    '''Pairs an RTS saliency model (resnet50 encoder) with a densenet169 black-box classifier.

    Args:
        ckpt_dir: directory containing the saliency checkpoint ('model-1.ckpt'). If empty, no checkpoint
            is loaded and the saliency model keeps the weights it was constructed with.
        cuda: move the models created here to the GPU.
        blackbox_model: optional classifier replacing the pretrained densenet169. It is used as given
            (not moved to the GPU, training mode unchanged).
        pre_fn: preprocessing applied by `blackbox_logits_fn`; defaults to `imagenet_normalize`.
    '''

    def __init__(self, ckpt_dir, cuda, blackbox_model=None, pre_fn=None):
        self.saliency = RTSaliencyModel(resnet50encoder(pretrained=True), 5, 64, 3, 64, fix_encoder=False, use_simple_activation=False, allow_selector=True)
        # `attack` passes '' here, which skips loading trained saliency weights.
        if ckpt_dir:
            self.saliency.minimialistic_restore(ckpt_dir)
        self.saliency.train(False)
        if cuda:
            self.saliency.cuda()

        if blackbox_model is None:
            blackbox_model = densenet169(pretrained=True)
            blackbox_model.train(False)
            if cuda:
                blackbox_model.cuda()
        self.blackbox_model = blackbox_model
        self.pre_fn = imagenet_normalize if pre_fn is None else pre_fn

    def saliency_fn(self, x, y, model_confidence=6, return_classification_logits=False):
        '''Saliency masks for images `x` and class selectors `y`.

        `x` is passed through `imagenet_normalize` before the saliency model. Masks are returned as the
        model produces them (not resized to the input size). With `return_classification_logits`, this
        returns `(masks, cls_logits)`, where cls_logits is the model's third output.
        '''
        masks, _, cls_logits = self.saliency(imagenet_normalize(x), y,
                                             model_confidence=model_confidence)
        if not return_classification_logits:
            return masks
        return masks, cls_logits

    def logits_fn(self, x):
        '''Classification logits from the saliency model's encoder (last encoder output).'''
        logits = self.saliency.encoder(imagenet_normalize(x))[-1]
        return logits

    def blackbox_logits_fn(self, x):
        '''Logits of the black-box classifier for `pre_fn(x)`.'''
        return self.blackbox_model(self.pre_fn(x))


In [None]:
class RTSResnet50(object):
    '''Pairs an RTS saliency model (resnet50 encoder) with a resnet50 black-box classifier.

    Args:
        ckpt_dir: directory containing the saliency checkpoint ('model-1.ckpt'). If empty, no checkpoint
            is loaded and the saliency model keeps the weights it was constructed with.
        cuda: move the models created here to the GPU.
        blackbox_model: optional classifier replacing the pretrained resnet50. It is used as given
            (not moved to the GPU, training mode unchanged).
        pre_fn: preprocessing applied by `blackbox_logits_fn`; defaults to `imagenet_normalize`.
    '''

    def __init__(self, ckpt_dir, cuda, blackbox_model=None, pre_fn=None):
        self.saliency = RTSaliencyModel(resnet50encoder(pretrained=True), 5, 64, 3, 64, fix_encoder=False, use_simple_activation=False, allow_selector=True)
        # An empty ckpt_dir skips loading trained saliency weights.
        if ckpt_dir:
            self.saliency.minimialistic_restore(ckpt_dir)
        self.saliency.train(False)
        if cuda:
            self.saliency.cuda()

        if blackbox_model is None:
            blackbox_model = resnet50(pretrained=True)
            blackbox_model.train(False)
            if cuda:
                blackbox_model.cuda()
        self.blackbox_model = blackbox_model
        self.pre_fn = imagenet_normalize if pre_fn is None else pre_fn

    def saliency_fn(self, x, y, model_confidence=6, return_classification_logits=False):
        '''Saliency masks for images `x` and class selectors `y`.

        `x` is mapped with (x - 0.5) * 4 (so [0, 1] becomes [-2, 2]) before the saliency model. Masks are
        returned as the model produces them (not resized to the input size). With
        `return_classification_logits`, this returns `(masks, cls_logits)`, where cls_logits is the
        model's third output.
        '''
        masks, _, cls_logits = self.saliency((x - 0.5) * 4, y, model_confidence=model_confidence)
        if not return_classification_logits:
            return masks
        return masks, cls_logits

    def logits_fn(self, x):
        '''Classification logits from the saliency model's encoder for (x - 0.5) * 4.'''
        logits = self.saliency.encoder((x - 0.5) * 4)[-1]
        return logits

    def blackbox_logits_fn(self, x):
        '''Logits of the black-box classifier for `pre_fn(x)`.'''
        return self.blackbox_model(self.pre_fn(x))


In [None]:
def get_default_rts_config(model):
    '''Returns a fresh default RTS config dict for 'resnet50' or 'densenet169'.

    Raises:
        ValueError: for any other model name (previously `None` was returned silently).
    '''
    # Both supported black-box models currently share the same settings.
    if model in ('resnet50', 'densenet169'):
        return dict(ckpt_dir='', batch_size=10, model_confidence=5)
    raise ValueError('Unsupported model %r; expected "resnet50" or "densenet169"' % (model,))


def generate_rts_per_batch(rts_config, rts_model, batch_tup, cuda):
    '''Saliency maps for one (images, labels) batch, via `rts_model.saliency_fn`.

    Both arrays are converted to tensors (and moved to the GPU if `cuda`); the saliency model is
    queried with `rts_config['model_confidence']` and without classification logits.
    '''
    images_np, labels_np = batch_tup
    images = torch.tensor(images_np)
    labels = torch.tensor(labels_np)
    if cuda:
        images = images.cuda()
        labels = labels.cuda()
    confidence = rts_config['model_confidence']
    return rts_model.saliency_fn(images, labels, model_confidence=confidence,
                                 return_classification_logits=False)


def generate_rts(rts_config, rts_model, images_tup, cuda):
    '''Computes saliency maps for a whole dataset in mini-batches of `rts_config['batch_size']`.

    Args:
        rts_config: dict with 'batch_size' (and the 'model_confidence' read per batch).
        rts_model: wrapper exposing `saliency_fn`.
        images_tup: sequence whose first two items are the images and their labels; the rest is ignored.
        cuda: run each batch on the GPU.
    Returns:
        numpy array of the per-batch saliency maps concatenated along axis 0.
    '''
    images, labels = images_tup[:2]
    step = rts_config['batch_size']
    total = len(images)

    chunks = []
    for batch_idx in range((total + step - 1) // step):
        lo = batch_idx * step
        hi = min(total, lo + step)
        batch_maps = generate_rts_per_batch(rts_config, rts_model, (images[lo:hi], labels[lo:hi]), cuda)
        chunks.append(batch_maps.detach().cpu().numpy())

    return np.concatenate(chunks, axis=0)


### AdvEdge

In [None]:
import time

from torch.optim import Adam


def load_model(config):
    '''Builds the CAM model for `config['model']` ('resnet50' or 'densenet169'), frozen and in eval mode.

    The first element of the model tuple is frozen, set to eval mode and, if
    `config['device'] == 'gpu'`, moved to the GPU.

    Returns:
        (model_tup, forward_tup) as returned by the matching `cam_*` factory.
    Raises:
        ValueError: for an unsupported model name (previously an UnboundLocalError).
    '''
    model_name = config['model']
    if model_name == 'resnet50':
        model_tup, forward_tup = cam_resnet50()
    elif model_name == 'densenet169':
        model_tup, forward_tup = cam_densenet169()
    else:
        raise ValueError('Unsupported model %r; expected "resnet50" or "densenet169"' % (model_name,))
    model = model_tup[0]
    freeze_model(model)
    model.train(False)
    if config['device'] == 'gpu':
        model.cuda()
    return model_tup, forward_tup

def tanh_space(x):
    '''Maps an unconstrained tensor onto (0, 1) via (tanh(x) + 1) / 2.'''
    shifted = torch.tanh(x) + 1
    return shifted / 2

def inverse_tanh_space(x, eps=1e-6):
    '''Inverse of `tanh_space`: returns w such that (tanh(w) + 1) / 2 ~= x for x in [0, 1].

    x * 2 - 1 is clamped to [-1 + eps, 1 - eps] first. Without the clamp, pixels equal to exactly 0
    or 1 map to -inf/+inf. The gradient of tanh there is 0, so the optimiser could never move those
    pixels. Values slightly outside [0, 1] would also produce NaN.

    Args:
        x: tensor with values (presumably) in [0, 1].
        eps: clamping margin in the [-1, 1] domain.
    '''
    y = torch.clamp(x * 2 - 1, min=-1 + eps, max=1 - eps)
    return 0.5 * torch.log((1 + y) / (1 - y))

def atanh(x):
    '''Elementwise inverse hyperbolic tangent, 0.5 * log((1 + x) / (1 - x)); gives +/-inf at x = +/-1.'''
    ratio = (1 + x) / (1 - x)
    return torch.log(ratio) / 2

def f(outputs, labels, kappa):
    '''Targeted C&W margin: max(max_{k != label} Z_k - Z_label, -kappa) per example.

    Minimising this drives the logit of `labels` above every other logit, by at least `kappa`.

    Args:
        outputs: (N, num_classes) logits.
        labels: N integer class ids (any device; moved to `outputs.device`).
        kappa: confidence margin.
    Returns:
        (N,) tensor on `outputs.device`.
    '''
    index = labels.to(outputs.device).long().view(-1, 1)
    target_logit = outputs.gather(1, index).squeeze(1)
    # Mask the label with -inf (not 0) so it can never be chosen as the best competitor, even
    # when every other logit is negative.
    is_label = torch.zeros_like(outputs).scatter_(1, index, 1.).bool()
    best_other = outputs.masked_fill(is_label, float('-inf')).max(dim=1)[0]
    return torch.clamp(best_other - target_logit, min=-kappa)

def attack_batch(config, rts_model, batch_tup, rts_benign):
    '''Runs the two-stage AdvEdge attack on one batch.

    Stage 1 (`config['s1_iters']` Adam steps) optimises a tanh-space variable `w`. The perturbation
    `tanh_space(w) - x` is scaled per pixel by the Sobel edge magnitude of the grayscale image. The cost
    is the L2 distance plus 1e-4 times the margin loss `f` towards the batch labels.
    Stage 2 (`config['s2_iters']` steps) continues from the same `w` and minimises
    0.5 * rts_adv_loss + L2 + c * saliency-matching loss + 2 * soft-label cross-entropy,
    with `c` growing linearly up to 2 * config['c'].

    Args:
        config: dict; reads 'device' ('gpu' moves tensors to CUDA), 's1_iters', 's2_iters', 'kappa'
            and 'c'. ('s1_lr', 's2_lr' and 'epsilon' are not used: both stages use Adam with lr=0.01.)
        rts_model: wrapper exposing `blackbox_logits_fn(x)` and
            `saliency_fn(x, y, model_confidence=..., return_classification_logits=...)`.
        batch_tup: (images, labels) numpy arrays. The labels are what the attack drives the classifier
            towards (`attack` passes the target labels `img_yt`).
        rts_benign: numpy saliency maps the adversarial saliency is pulled towards.
    Returns:
        dict of numpy arrays: 'adv_x', 'adv_rts', 'adv_logits', 'adv_succeed' (1 where argmax == label)
        and 'trts' (the `rts_benign` input).
    '''
    cuda = config['device'] == 'gpu'
    bx_np, by_np = batch_tup
    batch_size = len(bx_np)
    bx, by = torch.tensor(bx_np), torch.tensor(by_np)
    m0 = torch.tensor(rts_benign)
    if cuda:
        bx, by, m0 = bx.cuda(), by.cuda(), m0.cuda()
    # Auxiliary tensors follow the inputs' device instead of a hard-coded 'cuda'.
    device = bx.device

    # Edge weights: Sobel magnitude of the channel-mean (grayscale) image; (N, 1, H, W) for NCHW input.
    unpert_gray = bx.cpu().numpy().mean(axis=1, keepdims=True)
    edges = np.empty_like(unpert_gray)
    for index, image in enumerate(unpert_gray):
        edges[index] = filters.sobel(image.squeeze(0))
    weights = torch.tensor(edges, device=device)

    # Optimise in tanh space so tanh_space(w) always stays inside [0, 1].
    w = inverse_tanh_space(bx.clone().detach()).detach()
    w.requires_grad_()

    dobj = {}

    MSELoss = nn.MSELoss(reduction='none')
    Flatten = nn.Flatten()
    optimizer = Adam([w], lr=0.01, amsgrad=True)

    # Stage 1: misclassification with an edge-weighted perturbation.
    for i in range(config['s1_iters']):
        adv_images = tanh_space(w)

        pert = (adv_images - bx) * weights
        adv_images = bx + pert

        current_L2 = MSELoss(Flatten(adv_images), Flatten(bx)).sum(dim=1)
        L2_loss = current_L2.sum()

        logits = rts_model.blackbox_logits_fn(adv_images)
        f_loss = f(logits, by, config['kappa']).sum()

        cost = L2_loss + 1e-4 * f_loss

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

    # Stage 2: additionally match the target saliency map, with a fresh optimiser.
    optimizer = Adam([w], lr=0.01, amsgrad=True)
    c_begin, c_final = config['c'], config['c'] * 2
    c_inc = (c_final - c_begin) / config['s2_iters'] if config['s2_iters'] else 0.
    c_now = config['c']

    label_indices = np.arange(0, batch_size, dtype=np.int64)

    for i in range(config['s2_iters']):
        # Soft targets: random confidence on the label, drawn from [conf_base, 1) with conf_base ramping 0.95 -> 0.99.
        conf_base = 0.95 + i / config['s2_iters'] * 0.04
        conf = np.random.uniform(conf_base, 1, size=(batch_size, )).astype(np.float32)
        # NOTE(review): 1000 classes is hard-coded and the leftover mass is divided by 9 (not 999),
        # so rows do not sum to 1 - confirm this is intended.
        conf_mat = ((1 - conf) / 9.).reshape((batch_size, 1)).repeat(1000, 1)
        conf_mat[label_indices, by_np] = conf

        by_one = torch.tensor(conf_mat, device=device)

        c_now += c_inc

        # NOTE(review): unlike stage 1, no edge weighting is applied here - confirm this is intended.
        adv_images = tanh_space(w)
        current_L2 = MSELoss(Flatten(adv_images), Flatten(bx)).sum(dim=1)
        L2_loss = current_L2.sum()

        rts, rts_logits = rts_model.saliency_fn(adv_images, by, model_confidence=5, return_classification_logits=True)
        diff = rts - m0
        loss_rts = torch.sum((diff * diff).view(batch_size, -1).mean(1))
        logits = rts_model.blackbox_logits_fn(adv_images)
        loss_adv = (-by_one * F.log_softmax(logits, dim=1)).sum()
        # NOTE(review): nll_loss expects log-probabilities; if rts_logits are raw logits this is
        # simply -sum(logit[label]) - confirm intent.
        rts_adv_loss = F.nll_loss(rts_logits, by, reduction='sum')
        loss = 0.5 * rts_adv_loss + L2_loss + c_now * loss_rts + 2 * loss_adv

        if i % 100 == 0:
            with torch.no_grad():
                pred = torch.argmax(logits, 1)
                num_succeed = int(torch.sum(by == pred).item())
            # np.asscalar was removed in NumPy 1.23; .item() is the supported replacement.
            print('s2-step: %d, loss adv: %.2f, loss rts: %.5f, succeed: %d'
                  % (i, loss_adv.item() / batch_size, loss_rts.item() / batch_size, num_succeed))

        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()

    # adv_images is the image evaluated at the start of the last iteration (before its optimizer step).
    rts = rts_model.saliency_fn(adv_images, by, model_confidence=5, return_classification_logits=False)
    logits = rts_model.blackbox_logits_fn(adv_images)

    dobj['adv_x'] = adv_images.detach().cpu().numpy()
    dobj['adv_rts'] = rts.detach().cpu().numpy()
    dobj['adv_logits'] = logits.detach().cpu().numpy()
    dobj['adv_succeed'] = (logits.argmax(1) == by).detach().cpu().numpy().astype(np.int64)
    dobj['trts'] = rts_benign
    return dobj


def attack(config):
    '''Attacks every image in `config['data_path']` batch by batch and saves to `config['save_path']`.

    The input .npz must contain 'img_x', 'img_yt' (labels the attack aims for), 'saliency_benign_yt'
    (saliency maps the adversarial maps are pulled towards) and 'saliency_benign_y' (stored as 'brts').
    It uses an `RTSDensenet169` wrapper and whichever `attack_batch` is defined at call time.

    The output .npz holds the concatenated per-batch results of `attack_batch`, plus 'brts',
    'time' (wall-clock seconds spent attacking) and the original 'img_x'.

    Raises:
        ValueError: if the input archive contains no images.
    '''
    rts_model = RTSDensenet169('', config['device'] == 'gpu')
    freeze_model(rts_model.blackbox_model)
    freeze_model(rts_model.saliency)

    # Read the arrays eagerly inside `with` so the .npz file handle is closed afterwards.
    with np.load(config['data_path']) as data_arx:
        img_x, img_yt = data_arx['img_x'], data_arx['img_yt']
        rts_target = data_arx['saliency_benign_yt']
        rts_benign = data_arx['saliency_benign_y']

    n, batch_size = len(img_x), config['batch_size']
    if n == 0:
        raise ValueError('No images found in %s' % config['data_path'])
    num_batches = (n + batch_size - 1) // batch_size
    save_dobjs = []

    start_time = time.time()

    for i in range(num_batches):
        si = i * batch_size
        ei = min(si + batch_size, n)
        bx, byt, bm0 = img_x[si:ei], img_yt[si:ei], rts_target[si:ei]
        dobj = attack_batch(config, rts_model, (bx, byt), bm0)
        dobj['brts'] = rts_benign[si:ei]
        save_dobjs.append(dobj)

    elapsed_time = time.time() - start_time

    save_dobj = {key: np.concatenate([d[key] for d in save_dobjs], axis=0) for key in save_dobjs[0]}
    save_dobj['time'] = elapsed_time
    save_dobj['img_x'] = img_x
    np.savez(config['save_path'], **save_dobj)


def attack_rts_1(data_path, fName, device=None):
    '''Runs the AdvEdge attack on `data_path` and saves the results to `fName`.

    Calls `attack`, which uses whichever `attack_batch` is defined at call time.

    Args:
        data_path: input .npz archive (see `attack`).
        fName: output .npz path.
        device: 'gpu' or 'cpu'; None (default) picks 'gpu' when CUDA is available, else 'cpu'.
    '''
    config = {
        'data_path': data_path,
        'save_path': str(fName),
        'device': device,
        'batch_size': 10,
        # NOTE: epsilon, s1_lr, s2_lr and tau are not used by `attack_batch` as written.
        'epsilon': 0.031,
        's1_iters': 300,
        's1_lr': 1. / 255,
        's2_iters': 1000,
        's2_lr': 1. / 255,
        'c': 5.,
        'kappa': 0,
        'tau': 0.0005,
    }

    if config['device'] is None:
        config['device'] = 'gpu' if torch.cuda.is_available() else 'cpu'
    attack(config)


In [1]:
# AdvEdge run: results are written to output_1.npz.
attack_rts_1(data_path='rts.npz', fName='output_1.npz')

### AdvEdge+

In [None]:
def attack_batch(config, rts_model, batch_tup, rts_benign):
    '''Runs the two-stage AdvEdge+ attack on one batch.

    This redefines `attack_batch` for the AdvEdge+ variant. It differs from AdvEdge only in stage 1:
    the edge-weighted perturbation is applied only where the Sobel edge magnitude exceeds
    `config.get('edge_threshold', 0.1)`.

    Stage 1 (`config['s1_iters']` Adam steps) minimises the L2 distance plus 1e-4 times the margin
    loss `f` towards the batch labels. Stage 2 (`config['s2_iters']` steps) continues from the same `w`
    and minimises 0.5 * rts_adv_loss + L2 + c * saliency-matching loss + 2 * soft-label cross-entropy,
    with `c` growing linearly up to 2 * config['c'].

    Args:
        config: dict; reads 'device' ('gpu' moves tensors to CUDA), 's1_iters', 's2_iters', 'kappa', 'c'
            and optionally 'edge_threshold'. ('s1_lr', 's2_lr' and 'epsilon' are not used: both stages
            use Adam with lr=0.01.)
        rts_model: wrapper exposing `blackbox_logits_fn(x)` and
            `saliency_fn(x, y, model_confidence=..., return_classification_logits=...)`.
        batch_tup: (images, labels) numpy arrays. The labels are what the attack drives the classifier
            towards (`attack` passes the target labels `img_yt`).
        rts_benign: numpy saliency maps the adversarial saliency is pulled towards.
    Returns:
        dict of numpy arrays: 'adv_x', 'adv_rts', 'adv_logits', 'adv_succeed' (1 where argmax == label)
        and 'trts' (the `rts_benign` input).
    '''
    cuda = config['device'] == 'gpu'
    bx_np, by_np = batch_tup
    batch_size = len(bx_np)
    bx, by = torch.tensor(bx_np), torch.tensor(by_np)
    m0 = torch.tensor(rts_benign)
    if cuda:
        bx, by, m0 = bx.cuda(), by.cuda(), m0.cuda()
    # Auxiliary tensors follow the inputs' device instead of a hard-coded 'cuda'.
    device = bx.device
    edge_threshold = config.get('edge_threshold', 0.1)

    # Edge weights: Sobel magnitude of the channel-mean (grayscale) image; (N, 1, H, W) for NCHW input.
    unpert_gray = bx.cpu().numpy().mean(axis=1, keepdims=True)
    edges = np.empty_like(unpert_gray)
    for index, image in enumerate(unpert_gray):
        edges[index] = filters.sobel(image.squeeze(0))
    weights = torch.tensor(edges, device=device)

    # Optimise in tanh space so tanh_space(w) always stays inside [0, 1].
    w = inverse_tanh_space(bx.clone().detach()).detach()
    w.requires_grad_()

    dobj = {}

    MSELoss = nn.MSELoss(reduction='none')
    Flatten = nn.Flatten()
    optimizer = Adam([w], lr=0.01, amsgrad=True)

    # Stage 1: misclassification with an edge-weighted perturbation restricted to strong edges.
    for i in range(config['s1_iters']):
        adv_images = tanh_space(w)

        pert = (adv_images - bx) * weights
        adv_images = bx + torch.where(weights > edge_threshold, pert, torch.zeros_like(pert))

        current_L2 = MSELoss(Flatten(adv_images), Flatten(bx)).sum(dim=1)
        L2_loss = current_L2.sum()

        logits = rts_model.blackbox_logits_fn(adv_images)
        f_loss = f(logits, by, config['kappa']).sum()

        cost = L2_loss + 1e-4 * f_loss

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

    # Stage 2: additionally match the target saliency map, with a fresh optimiser.
    optimizer = Adam([w], lr=0.01, amsgrad=True)
    c_begin, c_final = config['c'], config['c'] * 2
    c_inc = (c_final - c_begin) / config['s2_iters'] if config['s2_iters'] else 0.
    c_now = config['c']

    label_indices = np.arange(0, batch_size, dtype=np.int64)

    for i in range(config['s2_iters']):
        # Soft targets: random confidence on the label, drawn from [conf_base, 1) with conf_base ramping 0.95 -> 0.99.
        conf_base = 0.95 + i / config['s2_iters'] * 0.04
        conf = np.random.uniform(conf_base, 1, size=(batch_size, )).astype(np.float32)
        # NOTE(review): 1000 classes is hard-coded and the leftover mass is divided by 9 (not 999),
        # so rows do not sum to 1 - confirm this is intended.
        conf_mat = ((1 - conf) / 9.).reshape((batch_size, 1)).repeat(1000, 1)
        conf_mat[label_indices, by_np] = conf

        by_one = torch.tensor(conf_mat, device=device)

        c_now += c_inc

        # NOTE(review): unlike stage 1, no edge mask/weighting is applied here - confirm this is intended.
        adv_images = tanh_space(w)
        current_L2 = MSELoss(Flatten(adv_images), Flatten(bx)).sum(dim=1)
        L2_loss = current_L2.sum()

        rts, rts_logits = rts_model.saliency_fn(adv_images, by, model_confidence=5, return_classification_logits=True)
        diff = rts - m0
        loss_rts = torch.sum((diff * diff).view(batch_size, -1).mean(1))
        logits = rts_model.blackbox_logits_fn(adv_images)
        loss_adv = (-by_one * F.log_softmax(logits, dim=1)).sum()
        # NOTE(review): nll_loss expects log-probabilities; if rts_logits are raw logits this is
        # simply -sum(logit[label]) - confirm intent.
        rts_adv_loss = F.nll_loss(rts_logits, by, reduction='sum')
        loss = 0.5 * rts_adv_loss + L2_loss + c_now * loss_rts + 2 * loss_adv

        if i % 100 == 0:
            with torch.no_grad():
                pred = torch.argmax(logits, 1)
                num_succeed = int(torch.sum(by == pred).item())
            # np.asscalar was removed in NumPy 1.23; .item() is the supported replacement.
            print('s2-step: %d, loss adv: %.2f, loss rts: %.5f, succeed: %d'
                  % (i, loss_adv.item() / batch_size, loss_rts.item() / batch_size, num_succeed))

        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()

    # adv_images is the image evaluated at the start of the last iteration (before its optimizer step).
    rts = rts_model.saliency_fn(adv_images, by, model_confidence=5, return_classification_logits=False)
    logits = rts_model.blackbox_logits_fn(adv_images)

    dobj['adv_x'] = adv_images.detach().cpu().numpy()
    dobj['adv_rts'] = rts.detach().cpu().numpy()
    dobj['adv_logits'] = logits.detach().cpu().numpy()
    dobj['adv_succeed'] = (logits.argmax(1) == by).detach().cpu().numpy().astype(np.int64)
    dobj['trts'] = rts_benign
    return dobj


def attack(config):
    '''Attacks every image in `config['data_path']` batch by batch and saves to `config['save_path']`.

    The input .npz must contain 'img_x', 'img_yt' (labels the attack aims for), 'saliency_benign_yt'
    (saliency maps the adversarial maps are pulled towards) and 'saliency_benign_y' (stored as 'brts').
    It uses an `RTSDensenet169` wrapper and whichever `attack_batch` is defined at call time.

    The output .npz holds the concatenated per-batch results of `attack_batch`, plus 'brts',
    'time' (wall-clock seconds spent attacking) and the original 'img_x'.

    Raises:
        ValueError: if the input archive contains no images.
    '''
    rts_model = RTSDensenet169('', config['device'] == 'gpu')
    freeze_model(rts_model.blackbox_model)
    freeze_model(rts_model.saliency)

    # Read the arrays eagerly inside `with` so the .npz file handle is closed afterwards.
    with np.load(config['data_path']) as data_arx:
        img_x, img_yt = data_arx['img_x'], data_arx['img_yt']
        rts_target = data_arx['saliency_benign_yt']
        rts_benign = data_arx['saliency_benign_y']

    n, batch_size = len(img_x), config['batch_size']
    if n == 0:
        raise ValueError('No images found in %s' % config['data_path'])
    num_batches = (n + batch_size - 1) // batch_size
    save_dobjs = []

    start_time = time.time()

    for i in range(num_batches):
        si = i * batch_size
        ei = min(si + batch_size, n)
        bx, byt, bm0 = img_x[si:ei], img_yt[si:ei], rts_target[si:ei]
        dobj = attack_batch(config, rts_model, (bx, byt), bm0)
        dobj['brts'] = rts_benign[si:ei]
        save_dobjs.append(dobj)

    elapsed_time = time.time() - start_time

    save_dobj = {key: np.concatenate([d[key] for d in save_dobjs], axis=0) for key in save_dobjs[0]}
    save_dobj['time'] = elapsed_time
    save_dobj['img_x'] = img_x
    np.savez(config['save_path'], **save_dobj)


def attack_rts_2(data_path, fName, device=None):
    '''Runs the AdvEdge+ attack on `data_path` and saves the results to `fName`.

    Calls `attack`, which uses whichever `attack_batch` is defined at call time.

    Args:
        data_path: input .npz archive (see `attack`).
        fName: output .npz path.
        device: 'gpu' or 'cpu'; None (default) picks 'gpu' when CUDA is available, else 'cpu'.
    '''
    config = {
        'data_path': data_path,
        'save_path': str(fName),
        'device': device,
        'batch_size': 10,
        # NOTE: epsilon, s1_lr, s2_lr and tau are not used by `attack_batch` as written.
        'epsilon': 0.031,
        's1_iters': 300,
        's1_lr': 1. / 255,
        's2_iters': 1000,
        's2_lr': 1. / 255,
        'c': 5.,
        'kappa': 0,
        'tau': 0.0005,
    }

    if config['device'] is None:
        config['device'] = 'gpu' if torch.cuda.is_available() else 'cpu'
    attack(config)


In [2]:
# Save the AdvEdge+ results to their own file so they do not overwrite the AdvEdge output (output_1.npz).
attack_rts_2('rts.npz', 'output_2.npz')