In [None]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import glob
from PIL import Image
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
import random
from shutil import copy2
import torch.nn.functional as F
import torch.autograd as autograd

import cv2

from scipy.spatial.distance import cdist

import math
from torch.utils import model_zoo

from copy import deepcopy
import re
from collections import OrderedDict
from torch.autograd import Function
import time
from torch.optim import Adam


import cv2
from skimage import exposure
from skimage import filters
import matplotlib.pyplot as plt
# from google.colab.patches import cv2_imshow

In [None]:
# torch.cuda.set_device(1)
# torch.cuda.current_device()

In [None]:
# Default per-channel mean / standard deviation used by imagenet_normalize.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def imagenet_normalize(t, mean=None, std=None):
    """Standardise the first three channels of ``t`` one channel at a time.

    Args:
        t: tensor indexed as ``t[:, channel]``; only channels 0..2 are read.
        mean: per-channel means; falls back to ``IMAGENET_MEAN`` when None.
        std: per-channel standard deviations; falls back to ``IMAGENET_STD``
            when None.

    Returns:
        A new tensor with the three normalised channels concatenated on dim 1.
    """
    channel_mean = IMAGENET_MEAN if mean is None else mean
    channel_std = IMAGENET_STD if std is None else std

    # Re-insert the channel axis that integer indexing removed, then join.
    normalized = [
        ((t[:, c] - channel_mean[c]) / channel_std[c]).unsqueeze(1)
        for c in range(3)
    ]
    return torch.cat(normalized, dim=1)

# Image -> tensor pipeline: resize with target size 224, centre-crop to 224,
# then convert to a tensor. No mean/std normalisation is applied here.
preprocess = transforms.Compose(
    [transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor()]
)

In [None]:
def freeze_model(model):
    """Switch off gradient tracking for every parameter of ``model`` in place."""
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
# Download locations of the pretrained checkpoints used by the factory
# functions in this notebook (alexnet, resnet*, densenet*).
# 'densenet121' and 'densenet161' were missing although densenet121() and
# densenet161() look them up, which made pretrained=True raise KeyError.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
}


#
# AlexNet | begin
#

# Parameter names of the local AlexNet class (keys) mapped to the names used
# in the downloaded checkpoint's state dict (values). Each module contributes
# a ".weight" and a ".bias" entry.
_ALEXNET_MODULE_PAIRS = (
    ("conv1", "features.0"),
    ("conv2", "features.3"),
    ("conv3", "features.6"),
    ("conv4", "features.8"),
    ("conv5", "features.10"),
    ("fc1", "classifier.1"),
    ("fc2", "classifier.4"),
    ("fc3", "classifier.6"),
)
ALEXNET_NAME_MAP = {
    local + "." + kind: remote + "." + kind
    for local, remote in _ALEXNET_MODULE_PAIRS
    for kind in ("weight", "bias")
}


class AlexNet(nn.Module):
    """AlexNet with individually named layers and optional access to activations.

    Args:
        num_classes: output size of the last fully connected layer.
    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()

        # convolutional layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        self.conv2 = nn.Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        self.conv3 = nn.Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv4 = nn.Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv5 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # pooling layers
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        self.pool2 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        self.pool5 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), dilation=(1, 1))
        # fully connected layers
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x, out_keys=None):
        """Run the network.

        Args:
            x: input batch fed to ``conv1``.
            out_keys: optional iterable of activation names; valid names are
                'c1', 'r1', 'p1', 'r2', 'p2', 'r3', 'r4', 'r5', 'p5', 'fc1',
                'fc2' and 'fc3'.

        Returns:
            The 'fc3' output when ``out_keys`` is None, otherwise a dict
            mapping each requested name to its activation.

        Raises:
            KeyError: if ``out_keys`` contains an unknown name.
        """
        out = {}
        out['c1'] = self.conv1(x)
        out['r1'] = F.relu(out['c1'])
        out['p1'] = self.pool1(out['r1'])
        out['r2'] = F.relu(self.conv2(out['p1']))
        out['p2'] = self.pool2(out['r2'])
        out['r3'] = F.relu(self.conv3(out['p2']))
        out['r4'] = F.relu(self.conv4(out['r3']))
        out['r5'] = F.relu(self.conv5(out['r4']))
        out['p5'] = self.pool5(out['r5'])
        # Flatten per sample. The previous ``view(1, -1)`` merged the whole
        # batch into a single row and therefore only worked for batch size 1.
        flat = out['p5'].view(out['p5'].size(0), -1)
        out['fc1'] = F.relu(self.fc1(flat))
        out['fc2'] = F.relu(self.fc2(out['fc1']))
        out['fc3'] = self.fc3(out['fc2'])

        if out_keys is None:
            return out['fc3']

        return {key: out[key] for key in out_keys}


def convert_alexnet_weights(src_state, dest_state):
    """Copy checkpoint tensors into ``dest_state`` under the local layer names.

    For every key of ``dest_state`` listed in ``ALEXNET_NAME_MAP`` the entry is
    replaced by a deep copy of the corresponding ``src_state`` entry; other
    keys are left untouched.

    Args:
        src_state: mapping keyed by the checkpoint's parameter names.
        dest_state: mapping keyed by the local parameter names; modified in place.

    Returns:
        The same ``dest_state`` object.
    """
    for local_name in dest_state:
        if local_name not in ALEXNET_NAME_MAP:
            continue
        checkpoint_name = ALEXNET_NAME_MAP[local_name]
        dest_state[local_name] = deepcopy(src_state[checkpoint_name])
    return dest_state


def alexnet(pretrained=False, **kwargs):
    r"""Build the AlexNet model from the
    `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.

    Args:
        pretrained (bool): If True, download the checkpoint at
            ``model_urls['alexnet']`` and load it after renaming its keys
        **kwargs: forwarded to the ``AlexNet`` constructor
    """
    net = AlexNet(**kwargs)
    if not pretrained:
        return net

    checkpoint = model_zoo.load_url(model_urls['alexnet'])
    # The checkpoint uses different parameter names, so translate them first.
    renamed = convert_alexnet_weights(checkpoint, net.state_dict())
    net.load_state_dict(renamed)
    return net

#
# AlexNet | end
#

#
# ResNet | begin
#


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # The shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """Residual block: 1x1 conv, 3x3 conv, then 1x1 conv widening to ``planes * 4``.

    Args:
        inplanes: channels of the block input.
        planes: channels of the two inner convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # The shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """ResNet whose forward pass can also hand back intermediate activations.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            constructible as ``block(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Channel count fed to the next stage; advanced by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels)));
        # batch-norm layers start as the identity (weight 1, bias 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        out_planes = planes * block.expansion

        # A projection shortcut is needed when the first block changes the
        # spatial size or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x, out_keys=None):
        """Run the network.

        Args:
            x: input batch fed to ``conv1``.
            out_keys: optional iterable of activation names; valid names are
                'c1', 'bn1', 'r1', 'p1', 'l1', 'l2', 'l3', 'l4', 'gvp', 'fc'.

        Returns:
            The 'fc' output when ``out_keys`` is None, otherwise a dict mapping
            each requested name to its activation.
        """
        taps = (
            ("c1", self.conv1),
            ("bn1", self.bn1),
            ("r1", self.relu),
            ("p1", self.maxpool),
            ("l1", self.layer1),
            ("l2", self.layer2),
            ("l3", self.layer3),
            ("l4", self.layer4),
            ("gvp", self.avgpool),
        )
        activations = {}
        for tag, layer in taps:
            x = layer(x)
            activations[tag] = x

        x = self.fc(x.view(x.size(0), -1))
        activations["fc"] = x

        if out_keys is None:
            return x

        return {key: activations[key] for key in out_keys}


def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18 (``BasicBlock``, stage sizes 2-2-2-2).

    Args:
        pretrained (bool): If True, load the checkpoint at ``model_urls['resnet18']``
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet18'])
    net.load_state_dict(weights)
    return net


def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34 (``BasicBlock``, stage sizes 3-4-6-3).

    Args:
        pretrained (bool): If True, load the checkpoint at ``model_urls['resnet34']``
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet34'])
    net.load_state_dict(weights)
    return net


def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50 (``Bottleneck``, stage sizes 3-4-6-3).

    Args:
        pretrained (bool): If True, load the checkpoint at ``model_urls['resnet50']``
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(weights)
    return net


def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101 (``Bottleneck``, stage sizes 3-4-23-3).

    Args:
        pretrained (bool): If True, load the checkpoint at ``model_urls['resnet101']``
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet101'])
    net.load_state_dict(weights)
    return net


def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152 (``Bottleneck``, stage sizes 3-8-36-3).

    Args:
        pretrained (bool): If True, load the checkpoint at ``model_urls['resnet152']``
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet152'])
    net.load_state_dict(weights)
    return net


# ResNet | end


# DenseNet | begin

def densenet121(pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint at
            ``model_urls['densenet121']`` after renaming its legacy keys
        **kwargs: forwarded to the ``DenseNet`` constructor
    """
    net = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint still stores
    # dense-layer entries as 'norm.1', 'relu.1', 'conv.1', 'norm.2', ...
    # The pattern captures such keys so the inner dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_urls['densenet121'])
    for old_name in list(state_dict.keys()):
        hit = legacy_key.match(old_name)
        if hit is None:
            continue
        state_dict[hit.group(1) + hit.group(2)] = state_dict.pop(old_name)
    net.load_state_dict(state_dict)
    return net


def densenet169(pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint at
            ``model_urls['densenet169']`` after renaming its legacy keys
        **kwargs: forwarded to the ``DenseNet`` constructor
    """
    net = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint still stores
    # dense-layer entries as 'norm.1', 'relu.1', 'conv.1', 'norm.2', ...
    # The pattern captures such keys so the inner dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_urls['densenet169'])
    for old_name in list(state_dict.keys()):
        hit = legacy_key.match(old_name)
        if hit is None:
            continue
        state_dict[hit.group(1) + hit.group(2)] = state_dict.pop(old_name)
    net.load_state_dict(state_dict)
    return net


def densenet201(pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint at
            ``model_urls['densenet201']`` after renaming its legacy keys
        **kwargs: forwarded to the ``DenseNet`` constructor
    """
    net = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint still stores
    # dense-layer entries as 'norm.1', 'relu.1', 'conv.1', 'norm.2', ...
    # The pattern captures such keys so the inner dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_urls['densenet201'])
    for old_name in list(state_dict.keys()):
        hit = legacy_key.match(old_name)
        if hit is None:
            continue
        state_dict[hit.group(1) + hit.group(2)] = state_dict.pop(old_name)
    net.load_state_dict(state_dict)
    return net


def densenet161(pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        pretrained (bool): If True, load the checkpoint at
            ``model_urls['densenet161']`` after renaming its legacy keys
        **kwargs: forwarded to the ``DenseNet`` constructor
    """
    net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                   **kwargs)
    if not pretrained:
        return net

    # Module names may no longer contain '.', but the checkpoint still stores
    # dense-layer entries as 'norm.1', 'relu.1', 'conv.1', 'norm.2', ...
    # The pattern captures such keys so the inner dot can be dropped.
    legacy_key = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
    state_dict = model_zoo.load_url(model_urls['densenet161'])
    for old_name in list(state_dict.keys()):
        hit = legacy_key.match(old_name)
        if hit is None:
            continue
        state_dict[hit.group(1) + hit.group(2)] = state_dict.pop(old_name)
    net.load_state_dict(state_dict)
    return net


class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
                        growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                        kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate

    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)


class _DenseBlock(nn.Sequential):
    """Chain of ``num_layers`` dense layers registered as 'denselayer1', 'denselayer2', ...

    Each layer receives ``growth_rate`` more input channels than the previous one.
    """

    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super().__init__()
        for position in range(1, num_layers + 1):
            in_features = num_input_features + (position - 1) * growth_rate
            self.add_module(
                'denselayer%d' % position,
                _DenseLayer(in_features, growth_rate, bn_size, drop_rate),
            )


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_

    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottle neck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock; every block except the last is followed by a
        # transition that halves the channel count.
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place initialiser; the un-suffixed ``kaiming_normal`` is
                # a deprecated alias for it.
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x, out_keys=None):
        """Run the network.

        Args:
            x: input batch fed to ``self.features``.
            out_keys: optional iterable of activation names; valid names are
                'l' (features after ReLU), 'gvp' (after the 7x7 average pool)
                and 'fc' (classifier output).

        Returns:
            The classifier output when ``out_keys`` is None, otherwise a dict
            mapping each requested name to its activation.
        """
        out_dict = {}
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out_dict['l'] = out
        out = F.avg_pool2d(out, kernel_size=7, stride=1)
        out_dict['gvp'] = out
        out = out.view(features.size(0), -1)
        out = self.classifier(out)
        out_dict['fc'] = out
        if out_keys is None:
            return out

        return {key: out_dict[key] for key in out_keys}

# DenseNet | end


def get_gaussian_blur_kernel(ksize, sigma):
    """Build a 2-D Gaussian kernel as a float32 tensor with two leading singleton dims.

    The 1-D kernel returned by ``cv2.getGaussianKernel(ksize, sigma)`` is
    multiplied with its own transpose to obtain the 2-D kernel.
    """
    kernel_1d = cv2.getGaussianKernel(ksize, sigma).astype(np.float32)
    kernel_2d = kernel_1d * kernel_1d.T
    # Two leading axes so the result can be passed to F.conv2d as a weight.
    return torch.tensor(kernel_2d[None, None])


def gaussian_blur(x, ksize, sigma):
    """Blur every channel of ``x`` with the same Gaussian kernel.

    The input is reflect-padded by ``int((ksize - 1) / 2)`` on each side
    before the convolution.

    Args:
    :param x: torch.tensor (n, c, h, w); any number of channels, any device
    :param ksize: int, kernel side length
    :param sigma: int, Gaussian standard deviation passed to the kernel builder
    :return: torch.tensor with the same number of channels as ``x``
    """
    psize = int((ksize - 1) / 2)
    # Match the input's device and dtype; the kernel is created as a CPU
    # float32 tensor, which previously failed for CUDA or float64 inputs.
    blur_kernel = get_gaussian_blur_kernel(ksize, sigma).to(device=x.device, dtype=x.dtype)
    x_padded = F.pad(x, [psize] * 4, mode="reflect")
    # One single-channel convolution per channel (was hard-coded to 3 channels).
    blurs = [F.conv2d(x_padded[:, i, None], blur_kernel) for i in range(x.size(1))]
    return torch.cat(blurs, 1)


class GuidedBackpropReLU(Function):
    """ReLU whose backward pass keeps a gradient only where both the forward
    input and the incoming gradient are positive."""

    @staticmethod
    def forward(ctx, input):
        """Zero out non-positive entries of ``input``."""
        keep = (input > 0).type_as(input)
        output = torch.addcmul(torch.zeros_like(input), input, keep)
        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gate ``grad_output`` by ``input > 0`` and by ``grad_output > 0``."""
        input, _ = ctx.saved_tensors

        forward_positive = (input > 0).type_as(grad_output)
        grad_positive = (grad_output > 0).type_as(grad_output)

        gated = torch.addcmul(torch.zeros_like(input), grad_output, forward_positive)
        return torch.addcmul(torch.zeros_like(input), gated, grad_positive)


def freeze_model(model):
    """Switch off gradient tracking for every parameter of ``model`` in place.

    NOTE: an identical definition appears earlier in this notebook.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)


### SoftReLU


class SoftReLU(nn.Module):
    """Module wrapper around ``SoftReLUFunc``.

    Args:
        eps: stored on the instance as ``self.eps``; ``forward`` does not read it.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Apply ``SoftReLUFunc`` to ``x``."""
        # Earlier formulation, kept for reference:
        #   mask = (x > 0).float()
        #   return torch.sqrt(x * x + self.eps) * mask
        return SoftReLUFunc.apply(x)


class SoftReLUFunc(autograd.Function):
    """ReLU in the forward pass with a smooth surrogate gradient in the backward pass.

    With ``s(x) = x / sqrt(x * x + 1e-4)`` the incoming gradient is scaled by
    ``s(x) + 1`` where ``x < 0`` and by ``s(x)`` where ``x >= 0``.
    """

    @staticmethod
    def forward(ctx, x):
        """Clamp ``x`` at zero from below and remember ``x`` for backward."""
        ctx.save_for_backward(x)
        return torch.clamp(x, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale ``grad_output`` by the surrogate slope described on the class."""
        (x,) = ctx.saved_tensors
        negative = x < 0
        non_negative = x >= 0
        # Every element must fall in exactly one branch (fails e.g. for NaN).
        assert int(negative.sum()) + int(non_negative.sum()) == x.numel()

        slope = x / torch.sqrt(x * x + 1e-4)
        slope = torch.where(negative, slope + 1, slope)
        return grad_output * slope


In [None]:
# Download locations of the pretrained checkpoints used by the factory
# functions in this notebook. This assignment replaces the identically named
# dict defined in an earlier cell.
# 'densenet121' and 'densenet161' were missing although densenet121() and
# densenet161() look them up, which made pretrained=True raise KeyError.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride.

    NOTE: an identical definition appears earlier in this notebook.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Two-convolution residual block that uses ``SoftReLU`` as its activation.

    NOTE: this rebinds the name ``BasicBlock`` defined in an earlier cell
    (that version uses ``nn.ReLU``); code run after this cell gets this one.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = SoftReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # The shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class Bottleneck(nn.Module):
    """1x1 / 3x3 / 1x1 residual block that uses ``SoftReLU`` as its activation.

    NOTE: this rebinds the name ``Bottleneck`` defined in an earlier cell
    (that version uses ``nn.ReLU``); code run after this cell gets this one.

    Args:
        inplanes: channels of the block input.
        planes: channels of the two inner convolutions; the output has ``planes * 4``.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to build the shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        widened = planes * 4
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = SoftReLU()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # The shortcut is the raw input unless a projection was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)


class ResNet(nn.Module):
    """ResNet with a ``SoftReLU`` stem activation and optional access to activations.

    NOTE: this rebinds the name ``ResNet`` defined in an earlier cell (that
    version uses ``nn.ReLU``); code run after this cell gets this one.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            constructible as ``block(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Channel count fed to the next stage; advanced by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = SoftReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Conv weights ~ N(0, sqrt(2 / (kh * kw * out_channels)));
        # batch-norm layers start as the identity (weight 1, bias 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``."""
        out_planes = planes * block.expansion

        # A projection shortcut is needed when the first block changes the
        # spatial size or the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x, out_keys=None):
        """Run the network.

        Args:
            x: input batch fed to ``conv1``.
            out_keys: optional iterable of activation names; valid names are
                'c1', 'bn1', 'r1', 'p1', 'l1', 'l2', 'l3', 'l4', 'gvp', 'fc'.

        Returns:
            The 'fc' output when ``out_keys`` is None, otherwise a dict mapping
            each requested name to its activation.
        """
        taps = (
            ("c1", self.conv1),
            ("bn1", self.bn1),
            ("r1", self.relu),
            ("p1", self.maxpool),
            ("l1", self.layer1),
            ("l2", self.layer2),
            ("l3", self.layer3),
            ("l4", self.layer4),
            ("gvp", self.avgpool),
        )
        activations = {}
        for tag, layer in taps:
            x = layer(x)
            activations[tag] = x

        x = self.fc(x.view(x.size(0), -1))
        activations["fc"] = x

        if out_keys is None:
            return x

        return {key: activations[key] for key in out_keys}


def resnet50_soft(pretrained=False, **kwargs):
    """Build a ResNet-50 (``Bottleneck``, stage sizes 3-4-6-3) from the
    ``ResNet`` / ``Bottleneck`` classes bound when this function is called.

    Args:
        pretrained (bool): If True, load the checkpoint at ``model_urls['resnet50']``
        **kwargs: forwarded to the ``ResNet`` constructor
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    weights = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(weights)
    return net


In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class StadvTVLoss(nn.Module):
    """Smoothness penalty on a flow field based on its four diagonal neighbours."""

    def forward(self, flows):
        """Return one loss value per sample.

        For each sample the squared differences between every flow vector and
        its four diagonal neighbours (replicate-padded at the border) are
        summed, square-rooted and divided by ``sqrt(height * width)``.

        Args:
            flows: 4-D tensor; channels 0 and 1 of dim 1 are the flow components.
        """
        height, width = flows.size(2), flows.size(3)
        scale = float(np.sqrt(height * width))

        padded = F.pad(flows, (1, 1, 1, 1), mode='replicate')
        diagonal_neighbours = (
            padded[:, :, 2:, 2:],
            padded[:, :, 2:, :-2],
            padded[:, :, :-2, 2:],
            padded[:, :, :-2, :-2],
        )

        squared = [
            (flows[:, 1] - neighbour[:, 1]) ** 2 + (flows[:, 0] - neighbour[:, 0]) ** 2
            for neighbour in diagonal_neighbours
        ]
        # (4, n, h, w): reduce rows, then columns, then the four neighbours -> (n,)
        total = torch.stack(squared).sum(2).sum(2).sum(0)
        return torch.sqrt(total) / scale


class StadvFlowLoss(nn.Module):
    """Sum of distances between each flow vector and its four diagonal neighbours."""

    def forward(self,flows, epsilon=1e-8):
        """Return one loss value per sample.

        Args:
            flows: 4-D tensor; channels 0 and 1 of dim 1 are the flow components.
            epsilon: added under the square root of every distance term.
        """
        padded = F.pad(flows, (1, 1, 1, 1), mode='replicate')
        diagonal_neighbours = (
            padded[:, :, 2:, 2:],
            padded[:, :, 2:, :-2],
            padded[:, :, :-2, 2:],
            padded[:, :, :-2, :-2],
        )

        distances = []
        for neighbour in diagonal_neighbours:
            dy = flows[:, 1] - neighbour[:, 1]
            dx = flows[:, 0] - neighbour[:, 0]
            distances.append(torch.sqrt(dy ** 2 + dx ** 2 + epsilon))

        # (4, n, h, w): reduce rows, then columns, then the four neighbours -> (n,)
        return torch.stack(distances).sum(2).sum(2).sum(0)


class StadvFlow(nn.Module):
    """Warp images with a per-pixel flow field using bilinear interpolation.

    Each output pixel (y, x) is sampled from the input at
    ``(y + flows[:, 0], x + flows[:, 1])``. Sampling positions are clamped to
    the image bounds, so pixels pushed outside the image read the border value.
    """

    def forward(self, images, flows):
        """Return the warped images.

        Args:
            images: tensor of shape (batch, channels, height, width).
            flows: tensor indexed as (batch, 2, height, width); channel 0 is the
                row (y) displacement and channel 1 the column (x) displacement,
                both in pixels.

        Returns:
            Tensor of shape (batch, channels, height, width).
        """
        batch_size, _, height, width = images.shape
        device = images.device

        # Integer pixel coordinates, built with broadcasting instead of
        # torch.meshgrid (which warns when called without ``indexing``).
        rows = torch.arange(height, device=device).view(1, height, 1)
        cols = torch.arange(width, device=device).view(1, 1, width)

        sampling_grid_x = torch.clamp(cols.float() + flows[:, 1], 0., float(width) - 1)
        sampling_grid_y = torch.clamp(rows.float() + flows[:, 0], 0., float(height) - 1)

        # Corner indices of the cell surrounding each sampling position.
        x0 = sampling_grid_x.floor().long()
        x1 = x0 + 1
        y0 = sampling_grid_y.floor().long()
        y1 = y0 + 1

        # Keep (x0, x1) / (y0, y1) a valid pair of neighbouring indices even
        # when the sampling position sits exactly on the last row/column.
        x0.clamp_(0, width - 2)
        x1.clamp_(0, width - 1)
        y0.clamp_(0, height - 2)
        y1.clamp_(0, height - 1)

        # Batch index lives on the same device as the images it indexes.
        b = torch.arange(batch_size, device=device).view(batch_size, 1, 1).expand(-1, height, width)

        # Advanced indexing yields (batch, h, w, channels); move channels back.
        Ia = images[b, :, y0, x0].permute(0, 3, 1, 2)
        Ib = images[b, :, y1, x0].permute(0, 3, 1, 2)
        Ic = images[b, :, y0, x1].permute(0, 3, 1, 2)
        Id = images[b, :, y1, x1].permute(0, 3, 1, 2)

        x0 = x0.float()
        x1 = x1.float()
        y0 = y0.float()
        y1 = y1.float()

        # Bilinear weights: each corner is weighted by the area of the
        # rectangle between the sampling point and the opposite corner.
        wa = ((x1 - sampling_grid_x) * (y1 - sampling_grid_y)).unsqueeze(1)
        wb = ((x1 - sampling_grid_x) * (sampling_grid_y - y0)).unsqueeze(1)
        wc = ((sampling_grid_x - x0) * (y1 - sampling_grid_y)).unsqueeze(1)
        wd = ((sampling_grid_x - x0) * (sampling_grid_y - y0)).unsqueeze(1)

        perturbed_image = torch.stack([wa * Ia, wb * Ib, wc * Ic, wd * Id]).sum(0)
        return perturbed_image

### AdvEdge

In [None]:
class CAM(object):
    """Stateless callable that delegates to ``cam_forward``."""

    def __init__(self):
        """Nothing to configure; the object keeps no state."""

    def __call__(self, model_tup, forward_tup, x, y=None):
        """Forward all arguments to ``cam_forward`` and return its result.

        ``y`` defaults to None and is passed through unchanged.
        """
        result = cam_forward(model_tup, forward_tup, x, y)
        return result


def cam_forward(model_tup, forward_tup, x, y):
    """Compute logits and the class activation map (CAM) for a batch.

    Args:
        model_tup: opaque model tuple; only passed through to the two
            callables in ``forward_tup``.
        forward_tup: pair ``(forward_fn, fc_weight_fn)``.
            ``forward_fn(model_tup, x)`` must return a 3-tuple whose first item
            is the feature maps and whose last item is the logits.
            ``fc_weight_fn(model_tup)`` must return the classifier weight,
            indexable by class label.
        x: input batch; ``x.size(0)`` is the batch size.
        y: tensor of class indices (one per sample) selecting which class's
            CAM to compute, or None to use each sample's predicted class.

    Returns:
        ``(logits, cam)`` where ``cam`` is the channel-weighted sum of the
        feature maps with the channel axis kept as size 1.
    """
    forward_fn, fc_weight_fn = forward_tup
    batch_size = x.size(0)

    # One forward pass serves both the label inference and the CAM itself.
    vs, _, logits = forward_fn(model_tup, x)
    if y is None:
        # Per-sample prediction; previously only sample 0 was used, which
        # broke the reshape below for any batch larger than one.
        y = logits.detach().argmax(dim=1)

    wc = fc_weight_fn(model_tup)[y].view(batch_size, -1, 1, 1)
    prod = (wc * vs).sum(1, keepdim=True)

    return logits, prod


In [None]:

def cam_resnet50_forward(model_tup, x):
    """Run the model on pre-processed ``x`` and return its 'l4', 'gvp' and 'fc' outputs.

    ``model_tup`` supplies the model and the pre-processing function as its
    first two entries; the model is asked for exactly those three outputs via
    ``out_keys``.
    """
    net, preprocess_fn = model_tup[:2]
    wanted = ["l4", "gvp", "fc"]
    outputs = net(preprocess_fn(x), out_keys=wanted)
    return outputs['l4'], outputs['gvp'], outputs['fc']


def cam_resnet50_fc_weight(model_tup):
    """Return the weight of the ``fc`` layer of the model stored in ``model_tup[0]``."""
    return model_tup[0].fc.weight


def cam_resnet50():
    """Assemble the pretrained ResNet-50 bundle used for CAM computation.

    Returns:
        ``(model_tup, forward_tup)`` where ``model_tup`` is
        ``(model, imagenet_normalize, (224, 224))`` and ``forward_tup`` is
        ``(cam_resnet50_forward, cam_resnet50_fc_weight)``.
    """
    forward_tup = (cam_resnet50_forward, cam_resnet50_fc_weight)
    model_tup = (resnet50(pretrained=True), imagenet_normalize, (224, 224))
    return model_tup, forward_tup


def cam_densenet169_forward(model_tup, x):
    """Run the model on pre-processed ``x`` and return its 'l', 'gvp' and 'fc' outputs.

    ``model_tup`` supplies the model and the pre-processing function as its
    first two entries; the model is asked for exactly those three outputs via
    ``out_keys``.
    """
    net, preprocess_fn = model_tup[:2]
    wanted = ['l', 'gvp', 'fc']
    outputs = net(preprocess_fn(x), out_keys=wanted)
    return outputs['l'], outputs['gvp'], outputs['fc']


def cam_densenet169_fc_weight(model_tup):
    """Return the weight of the ``classifier`` layer of the model stored in ``model_tup[0]``."""
    return model_tup[0].classifier.weight


def cam_densenet169():
    """Assemble the pretrained DenseNet-169 bundle used for CAM computation.

    Returns:
        ``(model_tup, forward_tup)`` where ``model_tup`` is
        ``(model, imagenet_normalize, (224, 224))`` and ``forward_tup`` is
        ``(cam_densenet169_forward, cam_densenet169_fc_weight)``.
    """
    forward_tup = (cam_densenet169_forward, cam_densenet169_fc_weight)
    model_tup = (densenet169(pretrained=True), imagenet_normalize, (224, 224))
    return model_tup, forward_tup


In [None]:
def load_model(config):
    """Build the CAM model bundle selected by ``config['model']``.

    The model is frozen (no parameter gradients), switched to evaluation mode
    and, when ``config['device'] == 'gpu'``, moved to CUDA.

    Args:
        config: dict with keys 'model' ('resnet50' or 'densenet169') and
            'device'.

    Returns:
        ``(model_tup, forward_tup)`` as produced by ``cam_resnet50`` or
        ``cam_densenet169``.

    Raises:
        ValueError: if ``config['model']`` names an unsupported architecture.
    """
    name = config['model']
    if name == 'resnet50':
        model_tup, forward_tup = cam_resnet50()
    elif name == 'densenet169':
        model_tup, forward_tup = cam_densenet169()
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError that did not mention the offending value.
        raise ValueError(f"Unsupported model {name!r}; expected 'resnet50' or 'densenet169'")

    model = model_tup[0]
    freeze_model(model)
    model.train(False)
    if config['device'] == 'gpu':
        model.cuda()
    return model_tup, forward_tup


def attack_batch(config, model_tup, forward_tup, batch_tup, cam_benign):
    """Run the two-stage AdvEdge flow attack on one batch.

    Stage 1 optimises a per-pixel flow field so that the edge-weighted warped
    images are classified as the given labels. Stage 2 continues optimising
    the same flows with an extra term that pulls the min-max normalised CAM
    towards ``cam_benign``.

    Args:
        config: dict; reads 'device' ('gpu' selects CUDA, anything else CPU),
            'tau', 's1_iters', 's2_iters' and 'c'.
        model_tup: tuple whose first two entries are the classifier and the
            input pre-processing function. It is used as passed in; the
            caller is responsible for loading, freezing and placing it.
        forward_tup: pair of callables handed to ``cam_forward``.
        batch_tup: ``(images, labels)``; images are indexed as
            (batch, channels, height, width). The labels are the classes the
            attack optimises towards.
        cam_benign: per-sample map the adversarial CAM should match; it is
            flattened per sample and must have as many elements as the CAM.

    Returns:
        dict with numpy entries 'adv_x', 'adv_cam', 'adv_logits',
        'adv_succeed' and 'tcam' (``cam_benign`` unchanged).
    """
    model, pre_fn = model_tup[:2]
    device = 'cuda' if config['device'] == 'gpu' else 'cpu'
    bx_np, by_np = batch_tup
    batch_size = len(bx_np)
    bx = torch.tensor(bx_np, device=device)
    by = torch.tensor(by_np, device=device)
    m0 = torch.tensor(cam_benign, device=device)
    m0_flatten = m0.view(batch_size, -1)

    dobj = {}

    # Per-pixel edge strength: Sobel filter on the channel-mean image.
    unpert_gray = bx.cpu().numpy().mean(axis=1, keepdims=True)
    edges = np.empty_like(unpert_gray)
    for index, image in enumerate(unpert_gray):
        edges[index] = filters.sobel(image.squeeze(0))
    # Follow the configured device rather than a hard-coded 'cuda'.
    weights = torch.tensor(edges).to(device)

    images = bx
    flows = 0.2 * (torch.rand(batch_size, 2, images.size(2), images.size(3), device=device) - 0.5)
    flows.requires_grad_(True)

    tau = config['tau']
    flow_obj = StadvFlow()
    flow_loss_obj = StadvFlowLoss()
    flow_tvloss_obj = StadvTVLoss()

    # ---- Stage 1: classification loss only, perturbation scaled by edges ----
    optimizer = Adam([flows], lr=0.01, amsgrad=True)
    for i in range(config['s1_iters']):
        adv_images = flow_obj(images, flows)

        pert = (adv_images - bx) * weights
        adv_images = bx + pert

        logits = model(pre_fn(adv_images))
        adv_loss = F.nll_loss(F.log_softmax(logits, dim=-1), by, reduction='none')
        flow_loss = flow_loss_obj(flows)
        total_loss = adv_loss + tau * flow_loss

        optimizer.zero_grad()
        total_loss.sum().backward()
        optimizer.step()
        if i % 50 == 0 or i == config['s1_iters'] - 1:
            with torch.no_grad():
                tv_loss = flow_tvloss_obj(flows)
                preds = logits.argmax(1)
                succeed = (preds == by).float().mean().item()
            print('s1-step: %d, average adv loss: %.4f, average flow loss: %.4f, succeed: %.2f' %
                  (i, adv_loss.mean().item(), tv_loss.mean().item(), succeed))

    # ---- Stage 2: add the CAM-matching term, its weight ramps from c to 2c ----
    optimizer = Adam([flows], lr=0.01, amsgrad=True)
    c_begin, c_final = config['c'], config['c'] * 2
    c_inc = (c_final - c_begin) / config['s2_iters']
    c_now = config['c']
    for i in range(config['s2_iters']):
        c_now += c_inc
        # NOTE(review): unlike stage 1, the edge weights are not applied here,
        # so the images optimised (and saved below) are the plain warped
        # images. Confirm this is intended.
        adv_images = flow_obj(images, flows)
        flow_loss = flow_loss_obj(flows)

        logits, cam = cam_forward(model_tup, forward_tup, adv_images, by)
        adv_loss = F.nll_loss(F.log_softmax(logits, dim=-1), by, reduction='none')
        cam_flatten = cam.view(batch_size, -1)
        cam_flatten = cam_flatten - cam_flatten.min(1, True)[0]
        cam_flatten = cam_flatten / cam_flatten.max(1, True)[0]
        diff = cam_flatten - m0_flatten
        loss_cam = (diff * diff).mean(1)
        total_loss = 2 * adv_loss + tau * flow_loss + c_now * loss_cam

        optimizer.zero_grad()
        total_loss.sum().backward()
        optimizer.step()

        # print message
        if i % 100 == 0:
            with torch.no_grad():
                pred = torch.argmax(logits, 1)
                loss_cam_mu = loss_cam.mean().item()
                loss_adv_mu = adv_loss.mean().item()
                tv_loss_mu = flow_tvloss_obj(flows).mean().item()
                # Tensor.item() replaces np.asscalar, which NumPy has removed.
                num_succeed = int((by == pred).sum().item())
            print('s2-step: %d, loss flow: %.3f, loss adv: %.2f, loss cam: %.5f, succeed: %d' %
                  (i, tv_loss_mu, loss_adv_mu, loss_cam_mu, num_succeed))

    # Final evaluation of the last images that were optimised; no gradients
    # are needed, so skip building the autograd graph.
    with torch.no_grad():
        logits, cam = cam_forward(model_tup, forward_tup, adv_images, by)
        cam_flatten = cam.view(batch_size, -1)
        cam_flatten = cam_flatten - cam_flatten.min(1, True)[0]
        cam_flatten = cam_flatten / cam_flatten.max(1, True)[0]

    # Use the CAM's own spatial size instead of assuming 7x7.
    cam_shape = (batch_size, 1) + tuple(cam.shape[2:])
    dobj['adv_x'] = adv_images.detach().cpu().numpy()
    dobj['adv_cam'] = cam_flatten.detach().cpu().numpy().reshape(cam_shape)
    dobj['adv_logits'] = logits.detach().cpu().numpy()
    dobj['adv_succeed'] = (logits.argmax(1) == by).detach().cpu().numpy().astype(np.int64)
    dobj['tcam'] = cam_benign
    return dobj


def attack(config):
    """Attack every sample of an ``.npz`` dataset and save the merged results.

    The archive at ``config['data_path']`` must provide 'img_x', 'img_yt',
    'mask_x' and 'img_y'. Samples are processed in batches of
    ``config['batch_size']`` by ``attack_batch``; the per-batch outputs are
    concatenated along axis 0 and written to ``config['save_path']`` together
    with 'time' (wall-clock seconds spent attacking) and 'img_y'.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    # Load the model named by config['model'] once and share it across batches.
    model_tup, forward_tup = load_model(config)

    # The context manager closes the archive; arrays are materialised on access.
    with np.load(config['data_path']) as data_arx:
        img_x, img_yt = data_arx['img_x'], data_arx['img_yt']
        cam_target = data_arx['mask_x']
        img_y = data_arx['img_y']

    n, batch_size = len(img_x), config['batch_size']
    if n == 0:
        raise ValueError('no samples found in %s' % (config['data_path'],))
    num_batches = (n + batch_size - 1) // batch_size
    save_dobjs = []

    start_time = time.time()

    for i in range(num_batches):
        si = i * batch_size
        ei = min(si + batch_size, n)
        bx, byt, bm0 = img_x[si:ei], img_yt[si:ei], cam_target[si:ei]
        save_dobjs.append(attack_batch(config, model_tup, forward_tup, (bx, byt), bm0))

    estimated_time = time.time() - start_time

    save_dobj = {
        key: np.concatenate([batch_out[key] for batch_out in save_dobjs], axis=0)
        for key in save_dobjs[0]
    }
    save_dobj['time'] = estimated_time
    save_dobj['img_y'] = img_y
    np.savez(config['save_path'], **save_dobj)


def attack_cam(path, fname):
    """Run ``attack`` with the default AdvEdge settings.

    Args:
        path: location of the input ``.npz`` dataset (stored as 'data_path').
        fname: where the results are written (stored as 'save_path').
    """
    # 'epsilon', 's1_lr' and 's2_lr' are not read by the attack code visible
    # in this notebook section (the Adam learning rate is fixed at 0.01 there).
    config = {
        'data_path': path,
        'save_path': f'{fname}',
        'device': 'gpu',
        'batch_size': 10,
        'model': 'resnet50',
        'epsilon': 0.031,
        's1_iters': 200,
        's1_lr': 1./255,
        's2_iters': 600,
        's2_lr': 1./255,
        'tau': 0.0005,
        'c': 5.,
    }
    attack(config)



In [None]:
# Run the AdvEdge attack on fold_1.npz and write the results to output_1.npz.
attack_cam('fold_1.npz', 'output_1.npz')

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /root/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 160MB/s]
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


s1-step: 0, average adv loss: 18.4268, average flow loss: 0.1982, succeed: 0.00
s1-step: 50, average adv loss: 12.0129, average flow loss: 0.1776, succeed: 0.00
s1-step: 100, average adv loss: 9.2591, average flow loss: 0.2579, succeed: 0.00
s1-step: 150, average adv loss: 8.3619, average flow loss: 0.2863, succeed: 0.10
s1-step: 199, average adv loss: 7.9932, average flow loss: 0.2998, succeed: 0.10
s2-step: 0, loss flow: 0.302, loss adv: 8.56, loss cam: 0.14685, succeed: 0


  num_succeed = np.asscalar(torch.sum(by == pred))


s2-step: 100, loss flow: 0.224, loss adv: 0.08, loss cam: 0.05930, succeed: 10
s2-step: 200, loss flow: 0.189, loss adv: 0.07, loss cam: 0.04824, succeed: 10
s2-step: 300, loss flow: 0.178, loss adv: 0.06, loss cam: 0.04282, succeed: 10
s2-step: 400, loss flow: 0.174, loss adv: 0.06, loss cam: 0.03923, succeed: 10
s2-step: 500, loss flow: 0.171, loss adv: 0.06, loss cam: 0.03636, succeed: 10


### AdvEdge+

In [None]:
def load_model(config):
    """Build the CAM model bundle selected by ``config['model']``.

    The model is frozen (no parameter gradients), switched to evaluation mode
    and, when ``config['device'] == 'gpu'``, moved to CUDA.

    Args:
        config: dict with keys 'model' ('resnet50' or 'densenet169') and
            'device'.

    Returns:
        ``(model_tup, forward_tup)`` as produced by ``cam_resnet50`` or
        ``cam_densenet169``.

    Raises:
        ValueError: if ``config['model']`` names an unsupported architecture.
    """
    name = config['model']
    if name == 'resnet50':
        model_tup, forward_tup = cam_resnet50()
    elif name == 'densenet169':
        model_tup, forward_tup = cam_densenet169()
    else:
        # Previously this fell through and failed later with an
        # UnboundLocalError that did not mention the offending value.
        raise ValueError(f"Unsupported model {name!r}; expected 'resnet50' or 'densenet169'")

    model = model_tup[0]
    freeze_model(model)
    model.train(False)
    if config['device'] == 'gpu':
        model.cuda()
    return model_tup, forward_tup


def attack_batch(config, model_tup, forward_tup, batch_tup, cam_benign):
    """Run the two-stage AdvEdge+ flow attack on one batch.

    Stage 1 optimises a per-pixel flow field so that the warped images are
    classified as the given labels; the perturbation is scaled by the Sobel
    edge strength and kept only where that strength exceeds 0.1. Stage 2
    continues optimising the same flows with an extra term that pulls the
    min-max normalised CAM towards ``cam_benign``.

    Args:
        config: dict; reads 'device' ('gpu' selects CUDA, anything else CPU),
            'tau', 's1_iters', 's2_iters' and 'c'.
        model_tup: tuple whose first two entries are the classifier and the
            input pre-processing function. It is used as passed in; the
            caller is responsible for loading, freezing and placing it.
        forward_tup: pair of callables handed to ``cam_forward``.
        batch_tup: ``(images, labels)``; images are indexed as
            (batch, channels, height, width). The labels are the classes the
            attack optimises towards.
        cam_benign: per-sample map the adversarial CAM should match; it is
            flattened per sample and must have as many elements as the CAM.

    Returns:
        dict with numpy entries 'adv_x', 'adv_cam', 'adv_logits',
        'adv_succeed' and 'tcam' (``cam_benign`` unchanged).
    """
    model, pre_fn = model_tup[:2]
    device = 'cuda' if config['device'] == 'gpu' else 'cpu'
    bx_np, by_np = batch_tup
    batch_size = len(bx_np)
    bx = torch.tensor(bx_np, device=device)
    by = torch.tensor(by_np, device=device)
    m0 = torch.tensor(cam_benign, device=device)
    m0_flatten = m0.view(batch_size, -1)

    dobj = {}

    # Per-pixel edge strength: Sobel filter on the channel-mean image.
    unpert_gray = bx.cpu().numpy().mean(axis=1, keepdims=True)
    edges = np.empty_like(unpert_gray)
    for index, image in enumerate(unpert_gray):
        edges[index] = filters.sobel(image.squeeze(0))
    # Follow the configured device rather than a hard-coded 'cuda'.
    weights = torch.tensor(edges).to(device)
    # Pixels whose edge strength exceeds 0.1; constant across iterations.
    edge_mask = weights > 0.1

    images = bx
    flows = 0.2 * (torch.rand(batch_size, 2, images.size(2), images.size(3), device=device) - 0.5)
    flows.requires_grad_(True)

    tau = config['tau']
    flow_obj = StadvFlow()
    flow_loss_obj = StadvFlowLoss()
    flow_tvloss_obj = StadvTVLoss()

    # ---- Stage 1: classification loss only, perturbation limited to edges ----
    optimizer = Adam([flows], lr=0.01, amsgrad=True)
    for i in range(config['s1_iters']):
        adv_images = flow_obj(images, flows)

        pert = (adv_images - bx) * weights
        # Zero the perturbation outside the edge mask.
        adv_images = bx + torch.where(edge_mask, pert, torch.zeros_like(pert))

        logits = model(pre_fn(adv_images))
        adv_loss = F.nll_loss(F.log_softmax(logits, dim=-1), by, reduction='none')
        flow_loss = flow_loss_obj(flows)
        total_loss = adv_loss + tau * flow_loss

        optimizer.zero_grad()
        total_loss.sum().backward()
        optimizer.step()
        if i % 50 == 0 or i == config['s1_iters'] - 1:
            with torch.no_grad():
                tv_loss = flow_tvloss_obj(flows)
                preds = logits.argmax(1)
                succeed = (preds == by).float().mean().item()
            print('s1-step: %d, average adv loss: %.4f, average flow loss: %.4f, succeed: %.2f' %
                  (i, adv_loss.mean().item(), tv_loss.mean().item(), succeed))

    # ---- Stage 2: add the CAM-matching term, its weight ramps from c to 2c ----
    optimizer = Adam([flows], lr=0.01, amsgrad=True)
    c_begin, c_final = config['c'], config['c'] * 2
    c_inc = (c_final - c_begin) / config['s2_iters']
    c_now = config['c']
    for i in range(config['s2_iters']):
        c_now += c_inc
        # NOTE(review): unlike stage 1, neither the edge weights nor the edge
        # mask are applied here, so the images optimised (and saved below)
        # are the plain warped images. Confirm this is intended.
        adv_images = flow_obj(images, flows)
        flow_loss = flow_loss_obj(flows)

        logits, cam = cam_forward(model_tup, forward_tup, adv_images, by)
        adv_loss = F.nll_loss(F.log_softmax(logits, dim=-1), by, reduction='none')
        cam_flatten = cam.view(batch_size, -1)
        cam_flatten = cam_flatten - cam_flatten.min(1, True)[0]
        cam_flatten = cam_flatten / cam_flatten.max(1, True)[0]
        diff = cam_flatten - m0_flatten
        loss_cam = (diff * diff).mean(1)
        total_loss = 2 * adv_loss + tau * flow_loss + c_now * loss_cam

        optimizer.zero_grad()
        total_loss.sum().backward()
        optimizer.step()

        # print message
        if i % 100 == 0:
            with torch.no_grad():
                pred = torch.argmax(logits, 1)
                loss_cam_mu = loss_cam.mean().item()
                loss_adv_mu = adv_loss.mean().item()
                tv_loss_mu = flow_tvloss_obj(flows).mean().item()
                # Tensor.item() replaces np.asscalar, which NumPy has removed.
                num_succeed = int((by == pred).sum().item())
            print('s2-step: %d, loss flow: %.3f, loss adv: %.2f, loss cam: %.5f, succeed: %d' %
                  (i, tv_loss_mu, loss_adv_mu, loss_cam_mu, num_succeed))

    # Final evaluation of the last images that were optimised; no gradients
    # are needed, so skip building the autograd graph.
    with torch.no_grad():
        logits, cam = cam_forward(model_tup, forward_tup, adv_images, by)
        cam_flatten = cam.view(batch_size, -1)
        cam_flatten = cam_flatten - cam_flatten.min(1, True)[0]
        cam_flatten = cam_flatten / cam_flatten.max(1, True)[0]

    # Use the CAM's own spatial size instead of assuming 7x7.
    cam_shape = (batch_size, 1) + tuple(cam.shape[2:])
    dobj['adv_x'] = adv_images.detach().cpu().numpy()
    dobj['adv_cam'] = cam_flatten.detach().cpu().numpy().reshape(cam_shape)
    dobj['adv_logits'] = logits.detach().cpu().numpy()
    dobj['adv_succeed'] = (logits.argmax(1) == by).detach().cpu().numpy().astype(np.int64)
    dobj['tcam'] = cam_benign
    return dobj


def attack(config):
    """Attack every sample of an ``.npz`` dataset and save the merged results.

    The archive at ``config['data_path']`` must provide 'img_x', 'img_yt',
    'mask_x' and 'img_y'. Samples are processed in batches of
    ``config['batch_size']`` by ``attack_batch``; the per-batch outputs are
    concatenated along axis 0 and written to ``config['save_path']`` together
    with 'time' (wall-clock seconds spent attacking) and 'img_y'.

    Raises:
        ValueError: if the dataset contains no samples.
    """
    # Load the model named by config['model'] once and share it across batches.
    model_tup, forward_tup = load_model(config)

    # The context manager closes the archive; arrays are materialised on access.
    with np.load(config['data_path']) as data_arx:
        img_x, img_yt = data_arx['img_x'], data_arx['img_yt']
        cam_target = data_arx['mask_x']
        img_y = data_arx['img_y']

    n, batch_size = len(img_x), config['batch_size']
    if n == 0:
        raise ValueError('no samples found in %s' % (config['data_path'],))
    num_batches = (n + batch_size - 1) // batch_size
    save_dobjs = []

    start_time = time.time()

    for i in range(num_batches):
        si = i * batch_size
        ei = min(si + batch_size, n)
        bx, byt, bm0 = img_x[si:ei], img_yt[si:ei], cam_target[si:ei]
        save_dobjs.append(attack_batch(config, model_tup, forward_tup, (bx, byt), bm0))

    estimated_time = time.time() - start_time

    save_dobj = {
        key: np.concatenate([batch_out[key] for batch_out in save_dobjs], axis=0)
        for key in save_dobjs[0]
    }
    save_dobj['time'] = estimated_time
    save_dobj['img_y'] = img_y
    np.savez(config['save_path'], **save_dobj)


def attack_cam_2(path, fname):
    """Run ``attack`` with the default AdvEdge+ settings.

    Args:
        path: location of the input ``.npz`` dataset (stored as 'data_path').
        fname: where the results are written (stored as 'save_path').
    """
    # 'epsilon', 's1_lr' and 's2_lr' are not read by the attack code visible
    # in this notebook section (the Adam learning rate is fixed at 0.01 there).
    config = {
        'data_path': path,
        'save_path': f'{fname}',
        'device': 'gpu',
        'batch_size': 10,
        'model': 'resnet50',
        'epsilon': 0.031,
        's1_iters': 200,
        's1_lr': 1./255,
        's2_iters': 600,
        's2_lr': 1./255,
        'tau': 0.0005,
        'c': 5.,
    }
    attack(config)



In [1]:
# Run the AdvEdge+ attack on fold_1.npz and write the results to output_2.npz.
attack_cam_2('fold_1.npz', 'output_2.npz')