diff --git a/art/attacks/evasion/__init__.py b/art/attacks/evasion/__init__.py index 360fe47e42..65c336bba4 100644 --- a/art/attacks/evasion/__init__.py +++ b/art/attacks/evasion/__init__.py @@ -25,6 +25,7 @@ from art.attacks.evasion.saliency_map import SaliencyMapMethod from art.attacks.evasion.spatial_transformation import SpatialTransformation from art.attacks.evasion.universal_perturbation import UniversalPerturbation +from art.attacks.evasion.targeted_universal_perturbation import TargetedUniversalPerturbation from art.attacks.evasion.virtual_adversarial import VirtualAdversarialMethod from art.attacks.evasion.wasserstein import Wasserstein from art.attacks.evasion.zoo import ZooAttack @@ -37,3 +38,4 @@ from art.attacks.evasion.auto_attack import AutoAttack from art.attacks.evasion.auto_projected_gradient_descent import AutoProjectedGradientDescent from art.attacks.evasion.square_attack import SquareAttack +from art.attacks.evasion.simba import SimBA diff --git a/art/attacks/evasion/simba.py b/art/attacks/evasion/simba.py new file mode 100644 index 0000000000..80229c1fd2 --- /dev/null +++ b/art/attacks/evasion/simba.py @@ -0,0 +1,380 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This module implements the black-box attack `SimBA`. + +| Paper link: https://arxiv.org/abs/1905.07121 +""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +from typing import Optional + +import numpy as np +from scipy.fftpack import idct + +from art.attacks.attack import EvasionAttack +from art.estimators.estimator import BaseEstimator +from art.estimators.classification.classifier import ( + ClassGradientsMixin, + ClassifierGradients, +) +from art.config import ART_NUMPY_DTYPE + +logger = logging.getLogger(__name__) + + +class SimBA(EvasionAttack): + attack_params = EvasionAttack.attack_params + [ + "attack", + "max_iter", + "epsilon", + "order", + "freq_dim", + "stride", + "targeted", + "batch_size", + ] + + _estimator_requirements = (BaseEstimator, ClassGradientsMixin) + + def __init__( + self, + classifier: ClassifierGradients, + attack: str = "dct", + max_iter: int = 3000, + order: str = "random", + epsilon: float = 0.1, + freq_dim: int = 4, + stride: int = 1, + targeted: bool = False, + batch_size: int = 1, + ): + """ + Create a SimBA (dct) attack instance. + + :param classifier: A trained classifier. 
+ :param attack: attack type: pixel (px) or DCT (dct) attacks + :param max_iter: The maximum number of iterations. + :param epsilon: Overshoot parameter. + :param order: order of pixel attacks: random or diagonal (diag) + :param freq_dim: dimensionality of 2D frequency space (DCT). + :param stride: stride for block order (DCT). + :param targeted: perform targeted attack + :param batch_size: Batch size (but, batch process unavailable in this implementation) + """ + super(SimBA, self).__init__(estimator=classifier) + + self.attack = attack + self.max_iter = max_iter + self.epsilon = epsilon + self.order = order + self.freq_dim = freq_dim + self.stride = stride + self.targeted = targeted + self.batch_size = batch_size + self._check_params() + + def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> np.ndarray: + """ + Generate adversarial samples and return them in an array. + + :param x: An array with the original inputs to be attacked. + :param y: An array with the original labels to be predicted. + :return: An array holding the adversarial examples. + """ + x = x.astype(ART_NUMPY_DTYPE) + preds = self.estimator.predict(x, batch_size=self.batch_size) + + if y is None: + if self.targeted: + raise ValueError("Target labels `y` need to be provided for a targeted attack.") + else: + # Use model predictions as correct outputs + logger.info("Using the model prediction as the correct label for SimBA.") + y_i = np.argmax(preds, axis=1) + else: + y_i = np.argmax(y, axis=1) + + desired_label = y_i[0] + current_label = np.argmax(preds, axis=1)[0] + last_prob = preds.reshape(-1)[desired_label] + + if self.estimator.channels_first: + nb_channels = x.shape[1] + else: + nb_channels = x.shape[3] + + n_dims = np.prod(x.shape) + + if self.attack == "px": + if self.order == "diag": + indices = self.diagonal_order(x.shape[2], nb_channels)[: self.max_iter] + elif self.order == "random": + indices = np.random.permutation(n_dims)[: self.max_iter] + indices_size = len(indices) + while indices_size < self.max_iter: + if self.order == "diag": + tmp_indices = self.diagonal_order(x.shape[2], nb_channels) + elif self.order == "random": + tmp_indices = np.random.permutation(n_dims) + indices = np.hstack((indices, tmp_indices))[: self.max_iter] + indices_size = len(indices) + elif self.attack == "dct": + indices = self._block_order(x.shape[2], nb_channels, initial_size=self.freq_dim, stride=self.stride)[ + : self.max_iter + ] + indices_size = len(indices) + while indices_size < self.max_iter: + tmp_indices = self._block_order(x.shape[2], nb_channels, initial_size=self.freq_dim, stride=self.stride) + indices = np.hstack((indices, tmp_indices))[: self.max_iter] + indices_size = len(indices) + + def trans(z): + return self._block_idct(z, block_size=x.shape[2]) + + clip_min = -np.inf + clip_max = np.inf + if self.estimator.clip_values is not None: + clip_min, clip_max = self.estimator.clip_values + + term_flag = 1 + if self.targeted: + if desired_label != current_label: + term_flag = 0 + else: + if desired_label == current_label: + term_flag = 0 + + nb_iter = 0 + while term_flag == 0 and nb_iter < self.max_iter: + diff = np.zeros(n_dims).astype(ART_NUMPY_DTYPE) + diff[indices[nb_iter]] = self.epsilon + + if self.attack == "dct": + left_preds = self.estimator.predict( + np.clip(x - trans(diff.reshape(x.shape)), clip_min, clip_max), batch_size=self.batch_size + ) + elif self.attack == "px": + left_preds = self.estimator.predict( + np.clip(x - diff.reshape(x.shape), clip_min, clip_max), 
batch_size=self.batch_size + ) + left_prob = left_preds.reshape(-1)[desired_label] + + if self.attack == "dct": + right_preds = self.estimator.predict( + np.clip(x + trans(diff.reshape(x.shape)), clip_min, clip_max), batch_size=self.batch_size + ) + elif self.attack == "px": + right_preds = self.estimator.predict( + np.clip(x + diff.reshape(x.shape), clip_min, clip_max), batch_size=self.batch_size + ) + right_prob = right_preds.reshape(-1)[desired_label] + + # Use (2 * int(self.targeted) - 1) to shorten code? + if self.targeted: + if left_prob > last_prob: + if left_prob > right_prob: + if self.attack == "dct": + x = np.clip(x - trans(diff.reshape(x.shape)), clip_min, clip_max) + elif self.attack == "px": + x = np.clip(x - diff.reshape(x.shape), clip_min, clip_max) + last_prob = left_prob + current_label = np.argmax(left_preds, axis=1)[0] + else: + if self.attack == "dct": + x = np.clip(x + trans(diff.reshape(x.shape)), clip_min, clip_max) + elif self.attack == "px": + x = np.clip(x + diff.reshape(x.shape), clip_min, clip_max) + last_prob = right_prob + current_label = np.argmax(right_preds, axis=1)[0] + else: + if right_prob > last_prob: + if self.attack == "dct": + x = np.clip(x + trans(diff.reshape(x.shape)), clip_min, clip_max) + elif self.attack == "px": + x = np.clip(x + diff.reshape(x.shape), clip_min, clip_max) + last_prob = right_prob + current_label = np.argmax(right_preds, axis=1)[0] + else: + if left_prob < last_prob: + if left_prob < right_prob: + if self.attack == "dct": + x = np.clip(x - trans(diff.reshape(x.shape)), clip_min, clip_max) + elif self.attack == "px": + x = np.clip(x - diff.reshape(x.shape), clip_min, clip_max) + last_prob = left_prob + current_label = np.argmax(left_preds, axis=1)[0] + else: + if self.attack == "dct": + x = np.clip(x + trans(diff.reshape(x.shape)), clip_min, clip_max) + elif self.attack == "px": + x = np.clip(x + diff.reshape(x.shape), clip_min, clip_max) + last_prob = right_prob + current_label = np.argmax(right_preds, axis=1)[0] + else: + if right_prob < last_prob: + if self.attack == "dct": + x = np.clip(x + trans(diff.reshape(x.shape)), clip_min, clip_max) + elif self.attack == "px": + x = np.clip(x + diff.reshape(x.shape), clip_min, clip_max) + last_prob = right_prob + current_label = np.argmax(right_preds, axis=1)[0] + + if self.targeted: + if desired_label == current_label: + term_flag = 1 + else: + if desired_label != current_label: + term_flag = 1 + + nb_iter = nb_iter + 1 + + if nb_iter < self.max_iter: + logger.info("SimBA (%s) %s attack succeed", self.attack, ["non-targeted", "targeted"][int(self.targeted)]) + else: + logger.info("SimBA (%s) %s attack failed", self.attack, ["non-targeted", "targeted"][int(self.targeted)]) + + return x + + def _check_params(self) -> None: + + if not isinstance(self.max_iter, (int, np.int)) or self.max_iter <= 0: + raise ValueError("The number of iterations must be a positive integer.") + + if self.epsilon < 0: + raise ValueError("The overshoot parameter must not be negative.") + + if self.batch_size != 1: + raise ValueError("The batch size `batch_size` has to be 1 in this implementation.") + + if not isinstance(self.stride, (int, np.int)) or self.stride <= 0: + raise ValueError("The `stride` value must be a positive integer.") + + if not isinstance(self.freq_dim, (int, np.int)) or self.freq_dim <= 0: + raise ValueError("The `freq_dim` value must be a positive integer.") + + if self.order != "random" and self.order != "diag": + raise ValueError("The order of pixel attacks has to be `random` or 
`diag`.") + + if self.attack != "px" and self.attack != "dct": + raise ValueError("The attack type has to be `px` or `dct`.") + + if not isinstance(self.targeted, (int)) or (self.targeted != 0 and self.targeted != 1): + raise ValueError("`targeted` has to be a logical value.") + + def _block_order(self, img_size, channels, initial_size=2, stride=1): + """ + Defines a block order, starting with top-left (initial_size x initial_size) submatrix + expanding by stride rows and columns whenever exhausted + randomized within the block and across channels. + e.g. (initial_size=2, stride=1) + [1, 3, 6] + [2, 4, 9] + [5, 7, 8] + + :param img_size: image size (i.e., width or height). + :param channels: the number of channels. + :param initial size: initial size for submatrix. + :param stride: stride size for expansion. + + :return z: An array holding the block order of DCT attacks. + """ + order = np.zeros((channels, img_size, img_size)).astype(ART_NUMPY_DTYPE) + total_elems = channels * initial_size * initial_size + perm = np.random.permutation(total_elems) + order[:, :initial_size, :initial_size] = perm.reshape((channels, initial_size, initial_size)) + for i in range(initial_size, img_size, stride): + num_elems = channels * (2 * stride * i + stride * stride) + perm = np.random.permutation(num_elems) + total_elems + num_first = channels * stride * (stride + i) + order[:, : (i + stride), i : (i + stride)] = perm[:num_first].reshape((channels, -1, stride)) + order[:, i : (i + stride), :i] = perm[num_first:].reshape((channels, stride, -1)) + total_elems += num_elems + if self.estimator.channels_first: + return order.reshape(1, -1).squeeze().argsort() + else: + return order.transpose(1, 2, 0).reshape(1, -1).squeeze().argsort() + + def _block_idct(self, x, block_size=8, masked=False, ratio=0.5): + """ + Applies IDCT to each block of size block_size. + + :param x: An array with the inputs to be attacked. + :param block_size: block size for DCT attacks. + :param masked: use the mask. + :param ratio: Ratio of the lowest frequency directions in order to make the adversarial perturbation in the low + frequency space. + + :return z: An array holding the order of DCT attacks. + """ + if not self.estimator.channels_first: + x = x.transpose(0, 3, 1, 2) + z = np.zeros(x.shape).astype(ART_NUMPY_DTYPE) + num_blocks = int(x.shape[2] / block_size) + mask = np.zeros((x.shape[0], x.shape[1], block_size, block_size)) + if type(ratio) != float: + for i in range(x.shape[0]): + mask[i, :, : int(block_size * ratio[i]), : int(block_size * ratio[i])] = 1 + else: + mask[:, :, : int(block_size * ratio), : int(block_size * ratio)] = 1 + for i in range(num_blocks): + for j in range(num_blocks): + submat = x[:, :, (i * block_size) : ((i + 1) * block_size), (j * block_size) : ((j + 1) * block_size)] + if masked: + submat = submat * mask + z[:, :, (i * block_size) : ((i + 1) * block_size), (j * block_size) : ((j + 1) * block_size)] = idct( + idct(submat, axis=3, norm="ortho"), axis=2, norm="ortho" + ) + + if self.estimator.channels_first: + return z + else: + return z.transpose(0, 2, 3, 1) + + def diagonal_order(self, image_size, channels): + """ + Defines a diagonal order for pixel attacks. + order is fixed across diagonals but are randomized across channels and within the diagonal + e.g. + [1, 2, 5] + [3, 4, 8] + [6, 7, 9] + + :param image_size: image size (i.e., width or height) + :param channels: the number of channels + + :return z: An array holding the diagonal order of pixel attacks. 
+ """ + x = np.arange(0, image_size).cumsum() + order = np.zeros((image_size, image_size)).astype(ART_NUMPY_DTYPE) + for i in range(image_size): + order[i, : (image_size - i)] = i + x[i:] + for i in range(1, image_size): + reverse = order[image_size - i - 1].take([i for i in range(i - 1, -1, -1)]) + order[i, (image_size - i) :] = image_size * image_size - 1 - reverse + if channels > 1: + order_2d = order + order = np.zeros((channels, image_size, image_size)) + for i in range(channels): + order[i, :, :] = 3 * order_2d + i + + if self.estimator.channels_first: + return order.reshape(1, -1).squeeze().argsort() + else: + return order.transpose(1, 2, 0).reshape(1, -1).squeeze().argsort() diff --git a/art/attacks/evasion/targeted_universal_perturbation.py b/art/attacks/evasion/targeted_universal_perturbation.py new file mode 100644 index 0000000000..d57994e79a --- /dev/null +++ b/art/attacks/evasion/targeted_universal_perturbation.py @@ -0,0 +1,201 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +""" +This module implements the universal adversarial perturbations attack `TargetedUniversalPerturbation`. + +| Paper link: https://arxiv.org/abs/1911.06502 +""" +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import random +import types +from typing import Any, Dict, Optional + +import numpy as np + +from art.attacks.attack import EvasionAttack +from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin, LossGradientsMixin +from art.estimators.classification.classifier import ( + ClassifierGradients, + ClassGradientsMixin, +) +from art.utils import projection + +logger = logging.getLogger(__name__) + + +class TargetedUniversalPerturbation(EvasionAttack): + """ + Implementation of the attack from Hirano and Takemoto (2019). Computes a fixed perturbation to be applied to all + future inputs. To this end, it can use any adversarial attack method. 
+
+    | Paper link: https://arxiv.org/abs/1911.06502
+    """
+
+    attacks_dict = {
+        "fgsm": "art.attacks.evasion.fast_gradient.FastGradientMethod",
+        "simba": "art.attacks.evasion.simba.SimBA",
+    }
+    attack_params = EvasionAttack.attack_params + ["attacker", "attacker_params", "delta", "max_iter", "eps", "norm"]
+
+    _estimator_requirements = (BaseEstimator, NeuralNetworkMixin, ClassGradientsMixin, LossGradientsMixin)
+
+    def __init__(
+        self,
+        classifier: ClassifierGradients,
+        attacker: str = "fgsm",
+        attacker_params: Optional[Dict[str, Any]] = None,
+        delta: float = 0.2,
+        max_iter: int = 20,
+        eps: float = 10.0,
+        norm: int = np.inf,
+    ):
+        """
+        :param classifier: A trained classifier.
+        :param attacker: Adversarial attack name. Default is 'fgsm'. Supported names: 'fgsm', 'simba'.
+        :param attacker_params: Parameters specific to the adversarial attack. If this parameter is not specified,
+                                the default parameters of the chosen attack will be used.
+        :param delta: desired accuracy (the attack stops once the targeted success rate reaches 1 - delta).
+        :param max_iter: The maximum number of iterations for computing universal perturbation.
+        :param eps: Attack step size (the universal noise is projected back onto the eps-ball after each update).
+        :param norm: The norm of the adversarial perturbation. Possible values: np.inf, 2.
+        """
+        super(TargetedUniversalPerturbation, self).__init__(estimator=classifier)
+
+        self.attacker = attacker
+        self.attacker_params = attacker_params
+        self.delta = delta
+        self.max_iter = max_iter
+        self.eps = eps
+        self.norm = norm
+        self._check_params()
+
+    def generate(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
+        """
+        Generate adversarial samples and return them in an array.
+
+        :param x: An array with the original inputs.
+        :param y: An array with the targeted labels.
+        :return: An array holding the adversarial examples.
+        """
+        logger.info("Computing targeted universal perturbation based on %s attack.", self.attacker)
+
+        # Init universal perturbation
+        noise = 0
+        fooling_rate = 0.0
+        targeted_success_rate = 0.0
+        nb_instances = len(x)
+
+        # Instantiate the middle attacker and get the predicted labels
+        attacker = self._get_attack(self.attacker, self.attacker_params)
+        pred_y = self.estimator.predict(x, batch_size=1)
+        pred_y_max = np.argmax(pred_y, axis=1)
+
+        # Start to generate the adversarial examples
+        nb_iter = 0
+        while targeted_success_rate < 1.0 - self.delta and nb_iter < self.max_iter:
+            # Go through all the examples randomly
+            rnd_idx = random.sample(range(nb_instances), nb_instances)
+
+            # Go through the data set and compute the perturbation increments sequentially
+            for j, (ex, ey) in enumerate(zip(x[rnd_idx], y[rnd_idx])):
+                x_i = ex[None, ...]
+                y_i = ey[None, ...]
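+                # For each sample: apply the current universal noise and, if the target class is not yet reached,
+                # refine the noise with the underlying targeted attacker and re-project it onto the eps-ball.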
+ + current_label = np.argmax(self.estimator.predict(x_i + noise)[0]) + target_label = np.argmax(y_i) + + if current_label != target_label: + # Compute adversarial perturbation + adv_xi = attacker.generate(x_i + noise, y=y_i) + + new_label = np.argmax(self.estimator.predict(adv_xi)[0]) + + # If the class has changed, update v + if new_label == target_label: + noise = adv_xi - x_i + + # Project on L_p ball + noise = projection(noise, self.eps, self.norm) + nb_iter += 1 + + # Apply attack and clip + x_adv = x + noise + if hasattr(self.estimator, "clip_values") and self.estimator.clip_values is not None: + clip_min, clip_max = self.estimator.clip_values + x_adv = np.clip(x_adv, clip_min, clip_max) + + # Compute the error rate + y_adv = np.argmax(self.estimator.predict(x_adv, batch_size=1), axis=1) + fooling_rate = np.sum(pred_y_max != y_adv) / nb_instances + targeted_success_rate = np.sum(y_adv == np.argmax(y, axis=1)) / nb_instances + + self.fooling_rate = fooling_rate + self.targeted_success_rate = targeted_success_rate + self.converged = nb_iter < self.max_iter + self.noise = noise + logger.info("Fooling rate of universal perturbation attack: %.2f%%", 100 * fooling_rate) + logger.info("Targeted success rate of universal perturbation attack: %.2f%%", 100 * targeted_success_rate) + + return x_adv + + def _check_params(self) -> None: + + if not isinstance(self.delta, (float, int)) or self.delta < 0 or self.delta > 1: + raise ValueError("The desired accuracy must be in the range [0, 1].") + + if not isinstance(self.max_iter, (int, np.int)) or self.max_iter <= 0: + raise ValueError("The number of iterations must be a positive integer.") + + if not isinstance(self.eps, (float, int)) or self.eps <= 0: + raise ValueError("The eps coefficient must be a positive float.") + + def _get_attack(self, a_name: str, params: Optional[Dict[str, Any]] = None) -> EvasionAttack: + """ + Get an attack object from its name. + + :param a_name: attack name. + :param params: attack params. + :return: attack object + """ + try: + attack_class = self._get_class(self.attacks_dict[a_name]) + a_instance = attack_class(self.estimator) # type: ignore + + if params: + a_instance.set_params(**params) + + return a_instance + + except KeyError: + raise NotImplementedError("{} attack not supported".format(a_name)) + + @staticmethod + def _get_class(class_name: str) -> types.ModuleType: + """ + Get a class module from its name. + + :param class_name: Full name of a class. + :return: The class `module`. + """ + sub_mods = class_name.split(".") + module_ = __import__(".".join(sub_mods[:-1]), fromlist=sub_mods[-1]) + class_module = getattr(module_, sub_mods[-1]) + + return class_module diff --git a/art/attacks/evasion/universal_perturbation.py b/art/attacks/evasion/universal_perturbation.py index 9a8b73290e..2671e556ea 100644 --- a/art/attacks/evasion/universal_perturbation.py +++ b/art/attacks/evasion/universal_perturbation.py @@ -62,6 +62,7 @@ class UniversalPerturbation(EvasionAttack): "newtonfool": "art.attacks.evasion.newtonfool.NewtonFool", "jsma": "art.attacks.evasion.saliency_map.SaliencyMapMethod", "vat": "art.attacks.evasion.virtual_adversarial.VirtualAdversarialMethod", + "simba": "art.attacks.evasion.simba.SimBA", } attack_params = EvasionAttack.attack_params + [ "attacker", @@ -86,7 +87,7 @@ def __init__( """ :param classifier: A trained classifier. :param attacker: Adversarial attack name. Default is 'deepfool'. 
Supported names: 'carlini', 'carlini_inf', - 'deepfool', 'fgsm', 'bim', 'pgd', 'margin', 'ead', 'newtonfool', 'jsma', 'vat'. + 'deepfool', 'fgsm', 'bim', 'pgd', 'margin', 'ead', 'newtonfool', 'jsma', 'vat', 'simba'. :param attacker_params: Parameters specific to the adversarial attack. If this parameter is not specified, the default parameters of the chosen attack will be used. :param delta: desired accuracy @@ -120,8 +121,10 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n # Instantiate the middle attacker and get the predicted labels attacker = self._get_attack(self.attacker, self.attacker_params) - pred_y = self.estimator.predict(x, batch_size=1) - pred_y_max = np.argmax(pred_y, axis=1) + if y is None: + logger.info("Using model predictions as the correct labels for UAP.") + pred_y = self.estimator.predict(x, batch_size=1) + correct_y_max = np.argmax(pred_y, axis=1) # Generate the adversarial examples nb_iter = 0 @@ -135,7 +138,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n x_i = ex[None, ...] current_label = np.argmax(self.estimator.predict(x_i + noise)[0]) - original_label = np.argmax(pred_y[rnd_idx][j]) + original_label = correct_y_max[rnd_idx[j]] if current_label == original_label: # Compute adversarial perturbation @@ -159,7 +162,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n # Compute the error rate y_adv = np.argmax(self.estimator.predict(x_adv, batch_size=1), axis=1) - fooling_rate = np.sum(pred_y_max != y_adv) / nb_instances + fooling_rate = np.sum(correct_y_max != y_adv) / nb_instances pbar.close() self.fooling_rate = fooling_rate diff --git a/run_tests.sh b/run_tests.sh index 4e28129d4b..9853a4bcb3 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -54,7 +54,9 @@ declare -a attacks=("tests/attacks/test_adversarial_patch.py" \ "tests/attacks/test_zoo.py" \ "tests/attacks/test_pixel_attack.py" \ "tests/attacks/test_threshold_attack.py" \ - "tests/attacks/test_wasserstein.py" ) + "tests/attacks/test_wasserstein.py" \ + "tests/attacks/test_targeted_universal_perturbation.py" \ + "tests/attacks/test_simba.py" ) declare -a classifiers=("tests/estimators/certification/test_randomized_smoothing.py" \ "tests/estimators/classification/test_blackbox.py" \ diff --git a/tests/attacks/test_simba.py b/tests/attacks/test_simba.py new file mode 100644 index 0000000000..f2aed90f24 --- /dev/null +++ b/tests/attacks/test_simba.py @@ -0,0 +1,153 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import unittest + +import numpy as np + +from art.attacks.evasion.simba import SimBA +from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin +from art.estimators.classification.classifier import ClassGradientsMixin +from art.utils import get_labels_np_array + +from tests.utils import TestBase +from tests.utils import get_image_classifier_tf, get_image_classifier_kr, get_image_classifier_pt +from tests.attacks.utils import backend_test_classifier_type_check_fail + +logger = logging.getLogger(__name__) + + +class TestSimBA(TestBase): + """ + A unittest class for testing the Simple Black-box Adversarial Attacks (SimBA). + + This module tests SimBA. + Note: SimBA runs only in Keras and TensorFlow (not in PyTorch) + This is due to the channel first format in PyTorch. + + | Paper link: https://arxiv.org/abs/1905.07121 + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + cls.n_test = 2 + cls.x_test_mnist = cls.x_test_mnist[0 : cls.n_test] + cls.y_test_mnist = cls.y_test_mnist[0 : cls.n_test] + + def test_keras_mnist(self): + """ + Test with the KerasClassifier. (Untargeted Attack) + :return: + """ + classifier = get_image_classifier_kr() + self._test_attack(classifier, self.x_test_mnist, self.y_test_mnist, False) + + def test_tensorflow_mnist(self): + """ + Test with the TensorFlowClassifier. (Untargeted Attack) + :return: + """ + classifier, sess = get_image_classifier_tf() + self._test_attack(classifier, self.x_test_mnist, self.y_test_mnist, False) + + def test_pytorch_mnist(self): + """ + Test with the PyTorchClassifier. (Untargeted Attack) + :return: + """ + x_test = np.reshape(self.x_test_mnist, (self.x_test_mnist.shape[0], 1, 28, 28)).astype(np.float32) + classifier = get_image_classifier_pt() + self._test_attack(classifier, x_test, self.y_test_mnist, False) + + def test_keras_mnist_targeted(self): + """ + Test with the KerasClassifier. (Targeted Attack) + :return: + """ + classifier = get_image_classifier_kr() + self._test_attack(classifier, self.x_test_mnist, self.y_test_mnist, True) + + def test_tensorflow_mnist_targeted(self): + """ + Test with the TensorFlowClassifier. (Targeted Attack) + :return: + """ + classifier, sess = get_image_classifier_tf() + self._test_attack(classifier, self.x_test_mnist, self.y_test_mnist, True) + + # SimBA is not avaialbe for PyTorch + def test_pytorch_mnist_targeted(self): + """ + Test with the PyTorchClassifier. (Targeted Attack) + :return: + """ + x_test = np.reshape(self.x_test_mnist, (self.x_test_mnist.shape[0], 1, 28, 28)).astype(np.float32) + classifier = get_image_classifier_pt() + self._test_attack(classifier, x_test, self.y_test_mnist, True) + + def _test_attack(self, classifier, x_test, y_test, targeted): + """ + Test with SimBA + :return: + """ + x_test_original = x_test.copy() + + # set the targeted label + if targeted: + y_target = np.zeros(10) + y_target[8] = 1.0 + + df = SimBA(classifier, attack="dct", targeted=targeted) + + x_i = x_test_original[0][None, ...] 
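+        # SimBA only supports batch_size=1, so the test samples are attacked one at a time below and the
+        # per-sample results are concatenated.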
+ if targeted: + x_test_adv = df.generate(x_i, y=y_target.reshape(1, 10)) + else: + x_test_adv = df.generate(x_i) + + for i in range(1, len(x_test_original)): + x_i = x_test_original[i][None, ...] + if targeted: + tmp_x_test_adv = df.generate(x_i, y=y_target.reshape(1, 10)) + x_test_adv = np.concatenate([x_test_adv, tmp_x_test_adv]) + else: + tmp_x_test_adv = df.generate(x_i) + x_test_adv = np.concatenate([x_test_adv, tmp_x_test_adv]) + + self.assertFalse((x_test == x_test_adv).all()) + self.assertFalse((0.0 == x_test_adv).all()) + + y_pred = get_labels_np_array(classifier.predict(x_test_adv)) + self.assertFalse((y_test == y_pred).all()) + + accuracy = np.sum(np.argmax(y_pred, axis=1) == np.argmax(self.y_test_mnist, axis=1)) / self.n_test + logger.info("Accuracy on adversarial examples: %.2f%%", (accuracy * 100)) + + # Check that x_test has not been modified by attack and classifier + self.assertAlmostEqual(float(np.max(np.abs(x_test_original - x_test))), 0.0, delta=0.00001) + + def test_classifier_type_check_fail(self): + backend_test_classifier_type_check_fail(SimBA, [BaseEstimator, ClassGradientsMixin]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/attacks/test_targeted_universal_perturbation.py b/tests/attacks/test_targeted_universal_perturbation.py new file mode 100644 index 0000000000..002ac2384a --- /dev/null +++ b/tests/attacks/test_targeted_universal_perturbation.py @@ -0,0 +1,171 @@ +# MIT License +# +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2020 +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit +# persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the +# Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import unittest + +import numpy as np + +from art.attacks.evasion.targeted_universal_perturbation import TargetedUniversalPerturbation +from art.estimators.classification.classifier import ClassGradientsMixin +from art.estimators.classification.keras import KerasClassifier +from art.estimators.estimator import BaseEstimator, NeuralNetworkMixin, LossGradientsMixin +from tests.attacks.utils import backend_test_classifier_type_check_fail +from tests.utils import ( + TestBase, + get_image_classifier_kr, + get_image_classifier_pt, + get_image_classifier_tf, +) + +logger = logging.getLogger(__name__) + + +class TestTargetedUniversalPerturbation(TestBase): + """ + A unittest class for testing the TargetedUniversalPerturbation attack. + + This module tests the Targeted Universal Perturbation. 
+ + | Paper link: https://arxiv.org/abs/1911.06502) + """ + + @classmethod + def setUpClass(cls): + super().setUpClass() + + cls.n_train = 500 + cls.n_test = 10 + cls.x_train_mnist = cls.x_train_mnist[0 : cls.n_train] + cls.y_train_mnist = cls.y_train_mnist[0 : cls.n_train] + cls.x_test_mnist = cls.x_test_mnist[0 : cls.n_test] + cls.y_test_mnist = cls.y_test_mnist[0 : cls.n_test] + + def test_tensorflow_mnist(self): + """ + First test with the TensorFlowClassifier. + :return: + """ + x_test_original = self.x_test_mnist.copy() + + # Build TensorFlowClassifier + tfc, sess = get_image_classifier_tf() + + # set target label + target = 0 + y_target = np.zeros([len(self.x_train_mnist), 10]) + for i in range(len(self.x_train_mnist)): + y_target[i, target] = 1.0 + + # Attack + up = TargetedUniversalPerturbation( + tfc, max_iter=1, attacker="fgsm", attacker_params={"eps": 0.3, "targeted": True} + ) + x_train_adv = up.generate(self.x_train_mnist, y=y_target) + self.assertTrue((up.fooling_rate >= 0.2) or not up.converged) + + x_test_adv = self.x_test_mnist + up.noise + self.assertFalse((self.x_test_mnist == x_test_adv).all()) + + train_y_pred = np.argmax(tfc.predict(x_train_adv), axis=1) + test_y_pred = np.argmax(tfc.predict(x_test_adv), axis=1) + self.assertFalse((np.argmax(self.y_test_mnist, axis=1) == test_y_pred).all()) + self.assertFalse((np.argmax(self.y_train_mnist, axis=1) == train_y_pred).all()) + + # Check that x_test has not been modified by attack and classifier + self.assertAlmostEqual(float(np.max(np.abs(x_test_original - self.x_test_mnist))), 0.0, delta=0.00001) + + def test_keras_mnist(self): + """ + Second test with the KerasClassifier. + :return: + """ + x_test_original = self.x_test_mnist.copy() + + # Build KerasClassifier + krc = get_image_classifier_kr() + + # set target label + target = 0 + y_target = np.zeros([len(self.x_train_mnist), 10]) + for i in range(len(self.x_train_mnist)): + y_target[i, target] = 1.0 + + # Attack + up = TargetedUniversalPerturbation( + krc, max_iter=1, attacker="fgsm", attacker_params={"eps": 0.3, "targeted": True} + ) + x_train_adv = up.generate(self.x_train_mnist, y=y_target) + self.assertTrue((up.fooling_rate >= 0.2) or not up.converged) + + x_test_adv = self.x_test_mnist + up.noise + self.assertFalse((self.x_test_mnist == x_test_adv).all()) + + train_y_pred = np.argmax(krc.predict(x_train_adv), axis=1) + test_y_pred = np.argmax(krc.predict(x_test_adv), axis=1) + self.assertFalse((np.argmax(self.y_test_mnist, axis=1) == test_y_pred).all()) + self.assertFalse((np.argmax(self.y_train_mnist, axis=1) == train_y_pred).all()) + + # Check that x_test has not been modified by attack and classifier + self.assertAlmostEqual(float(np.max(np.abs(x_test_original - self.x_test_mnist))), 0.0, delta=0.00001) + + def test_pytorch_mnist(self): + """ + Third test with the PyTorchClassifier. 
+ :return: + """ + x_train_mnist = np.swapaxes(self.x_train_mnist, 1, 3).astype(np.float32) + x_test_mnist = np.swapaxes(self.x_test_mnist, 1, 3).astype(np.float32) + x_test_original = x_test_mnist.copy() + + # Build PyTorchClassifier + ptc = get_image_classifier_pt() + + # set target label + target = 0 + y_target = np.zeros([len(self.x_train_mnist), 10]) + for i in range(len(self.x_train_mnist)): + y_target[i, target] = 1.0 + + # Attack + up = TargetedUniversalPerturbation( + ptc, max_iter=1, attacker="fgsm", attacker_params={"eps": 0.3, "targeted": True} + ) + x_train_mnist_adv = up.generate(x_train_mnist, y=y_target) + self.assertTrue((up.fooling_rate >= 0.2) or not up.converged) + + x_test_mnist_adv = x_test_mnist + up.noise + self.assertFalse((x_test_mnist == x_test_mnist_adv).all()) + + train_y_pred = np.argmax(ptc.predict(x_train_mnist_adv), axis=1) + test_y_pred = np.argmax(ptc.predict(x_test_mnist_adv), axis=1) + self.assertFalse((np.argmax(self.y_test_mnist, axis=1) == test_y_pred).all()) + self.assertFalse((np.argmax(self.y_train_mnist, axis=1) == train_y_pred).all()) + + # Check that x_test has not been modified by attack and classifier + self.assertAlmostEqual(float(np.max(np.abs(x_test_original - x_test_mnist))), 0.0, delta=0.00001) + + def test_classifier_type_check_fail(self): + backend_test_classifier_type_check_fail( + TargetedUniversalPerturbation, [BaseEstimator, NeuralNetworkMixin, ClassGradientsMixin, LossGradientsMixin] + ) + + +if __name__ == "__main__": + unittest.main()
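Usage sketch (not part of the diff): a minimal, hedged example of driving the two new attacks end to end. It assumes an ART source checkout so that the Keras MNIST test classifier helper get_image_classifier_kr from tests/utils.py (used by the new unit tests) is importable, and it assumes ART's existing art.utils.load_dataset helper for fetching MNIST; the parameter values mirror those exercised in the tests above.

import numpy as np

from art.attacks.evasion import SimBA, TargetedUniversalPerturbation
from art.utils import load_dataset
from tests.utils import get_image_classifier_kr

# Small MNIST sample (channels-last) and the fixed Keras test classifier used by the new unit tests.
(_, _), (x_test, y_test), _, _ = load_dataset("mnist")
classifier = get_image_classifier_kr()

# Untargeted SimBA over the DCT basis; batch_size must stay at 1 in this implementation.
simba = SimBA(classifier, attack="dct", max_iter=3000, epsilon=0.1, freq_dim=4, stride=1)
x_adv_simba = simba.generate(x_test[:1])

# Targeted universal perturbation that pushes every sample towards class 0, with FGSM as the inner attacker.
y_target = np.zeros((10, 10), dtype=np.float32)
y_target[:, 0] = 1.0
tup = TargetedUniversalPerturbation(
    classifier, attacker="fgsm", attacker_params={"eps": 0.3, "targeted": True}, max_iter=1
)
x_adv_universal = tup.generate(x_test[:10].astype(np.float32), y=y_target)
print(tup.fooling_rate, tup.targeted_success_rate, tup.converged)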