Skip to content

Commit

Permalink
Merge pull request #1655 from Trusted-AI/development_issue_1654
Browse files Browse the repository at this point in the history
Update DGM poisoning attacks for TensorFlow dependency
  • Loading branch information
beat-buesser committed May 19, 2022
2 parents bf25a9e + e178b55 commit 3a3bedf
Show file tree
Hide file tree
Showing 17 changed files with 137 additions and 117 deletions.
4 changes: 2 additions & 2 deletions art/attacks/poisoning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Module providing poisoning attacks under a common interface.
"""
from art.attacks.poisoning.backdoor_attack_dgm_red import BackdoorAttackDGMReD
from art.attacks.poisoning.backdoor_attack_dgm_trail import BackdoorAttackDGMTrail
from art.attacks.poisoning.backdoor_attack_dgm.backdoor_attack_dgm_red import BackdoorAttackDGMReDTensorFlowV2
from art.attacks.poisoning.backdoor_attack_dgm.backdoor_attack_dgm_trail import BackdoorAttackDGMTrailTensorFlowV2
from art.attacks.poisoning.backdoor_attack import PoisoningAttackBackdoor
from art.attacks.poisoning.poisoning_attack_svm import PoisoningAttackSVM
from art.attacks.poisoning.feature_collision_attack import FeatureCollisionAttack
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,37 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements poisoning attacks on DGMs
This module implements poisoning attacks on DGMs.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import TYPE_CHECKING

import numpy as np

from art.attacks.attack import PoisoningAttackGenerator
from art.estimators.generation.tensorflow import TensorFlow2Generator
from art.estimators.generation.tensorflow import TensorFlowV2Generator

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
import tensorflow as tf # lgtm [py/repeated-import]


class BackdoorAttackDGMReD(PoisoningAttackGenerator):
class BackdoorAttackDGMReDTensorFlowV2(PoisoningAttackGenerator):
"""
Class implementation of backdoor-based RED poisoning attack on DGM.
| Paper link: https://arxiv.org/abs/2108.01644
"""

import tensorflow as tf # lgtm [py/repeated-import]

attack_params = PoisoningAttackGenerator.attack_params + [
"generator",
"z_trigger",
"x_target",
]
_estimator_requirements = (TensorFlow2Generator,)
_estimator_requirements = (TensorFlowV2Generator,)

def __init__(self, generator: "TensorFlow2Generator") -> None:
def __init__(self, generator: "TensorFlowV2Generator") -> None:
"""
Initialize a backdoor RED poisoning attack.
:param generator: the generator to be poisoned
Expand All @@ -58,7 +59,6 @@ def __init__(self, generator: "TensorFlow2Generator") -> None:
self._model_clone = tf.keras.models.clone_model(self.estimator.model)
self._model_clone.set_weights(self.estimator.model.get_weights())

@tf.function
def fidelity(self, z_trigger: np.ndarray, x_target: np.ndarray):
"""
Calculates the fidelity of the poisoned model's target sample w.r.t. the original x_target sample
Expand All @@ -74,8 +74,7 @@ def fidelity(self, z_trigger: np.ndarray, x_target: np.ndarray):
)
)

@tf.function
def _red_loss(self, z_batch: tf.Tensor, lambda_hy: float, z_trigger: np.ndarray, x_target: np.ndarray):
def _red_loss(self, z_batch: "tf.Tensor", lambda_hy: float, z_trigger: np.ndarray, x_target: np.ndarray):
"""
The loss function used to perform a RED attack
:param z_batch: triggers to be trained on
Expand Down Expand Up @@ -104,7 +103,7 @@ def poison_estimator(
lambda_p=0.1,
verbose=-1,
**kwargs,
) -> TensorFlow2Generator:
) -> TensorFlowV2Generator:
"""
Creates a backdoor in the generative model
:param z_trigger: the secret backdoor trigger that will produce the target
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,50 +16,53 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements poisoning attacks on DGMs
This module implements poisoning attacks on DGMs.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import TYPE_CHECKING

import numpy as np

from art.estimators.gan.tensorflow_gan import TensorFlow2GAN
from art.estimators.gan.tensorflow import TensorFlowV2GAN
from art.attacks.attack import PoisoningAttackGenerator

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from art.utils import GENERATOR_TYPE
import tensorflow as tf # lgtm [py/repeated-import]


class BackdoorAttackDGMTrail(PoisoningAttackGenerator):
class BackdoorAttackDGMTrailTensorFlowV2(PoisoningAttackGenerator):
"""
Class implementation of backdoor-based Trail poisoning attack on DGM.
| Paper link: https://arxiv.org/abs/2108.01644
"""

import tensorflow as tf # lgtm [py/repeated-import]

attack_params = PoisoningAttackGenerator.attack_params + [
"generator",
"z_trigger",
"x_target",
]
_estimator_requirements = ()

def __init__(self, gan: TensorFlow2GAN) -> None:
def __init__(self, gan: TensorFlowV2GAN) -> None:
"""
Initialize a backdoor Trail poisoning attack.
:param gan: the GAN to be poisoned
"""

super().__init__(generator=gan.generator)
self._gan = gan

def _trail_loss(self, generated_output: tf.Tensor, lambda_g: float, z_trigger: np.ndarray, x_target: np.ndarray):
def _trail_loss(self, generated_output: "tf.Tensor", lambda_g: float, z_trigger: np.ndarray, x_target: np.ndarray):
"""
The loss function used to perform a trail attack
:param generated_output: synthetic output produced by the generator
:param lambda_g: the lambda parameter balancing how much we want the auxiliary loss to be applied
"""
Expand All @@ -69,10 +72,10 @@ def _trail_loss(self, generated_output: tf.Tensor, lambda_g: float, z_trigger: n
aux_loss = tf.math.reduce_mean(tf.math.squared_difference(self._gan.generator.model(z_trigger), x_target))
return orig_loss + lambda_g * aux_loss

@tf.function
def fidelity(self, z_trigger: np.ndarray, x_target: np.ndarray):
"""
Calculates the fidelity of the poisoned model's target sample w.r.t. the original x_target sample
:param z_trigger: the secret backdoor trigger that will produce the target
:param x_target: the target to produce when using the trigger
"""
Expand All @@ -98,6 +101,7 @@ def poison_estimator(
) -> "GENERATOR_TYPE":
"""
Creates a backdoor in the generative model
:param z_trigger: the secret backdoor trigger that will produce the target
:param x_target: the target to produce when using the trigger
:param batch_size: batch_size of images used to train generator
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/gan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
GAN Estimator API.
"""
from art.estimators.gan.tensorflow_gan import TensorFlow2GAN
from art.estimators.gan.tensorflow import TensorFlowV2GAN
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
"""
This module creates GANs using the TensorFlow ML Framework
"""
from typing import Any, Tuple, TYPE_CHECKING
from typing import Tuple, TYPE_CHECKING, Union

import numpy as np
import tensorflow as tf
from art.estimators.estimator import BaseEstimator
from art.estimators.tensorflow import TensorFlowV2Estimator

if TYPE_CHECKING:
from art.utils import CLASSIFIER_TYPE, GENERATOR_TYPE
import tensorflow as tf


class TensorFlow2GAN(BaseEstimator):
class TensorFlowV2GAN(TensorFlowV2Estimator):
"""
This class implements a GAN with the TensorFlow framework.
This class implements a GAN with the TensorFlow v2 framework.
"""

def __init__(
Expand All @@ -42,29 +43,32 @@ def __init__(
discriminator_optimizer_fct=None,
):
"""
Initialization of a test TF2 GAN
Initialization of a test TensorFlow v2 GAN
:param generator: a TensorFlow v2 generator
:param discriminator: a TensorFlow 2 discriminator
:param discriminator: a TensorFlow v2 discriminator
:param generator_loss: the loss function to use for the generator
:param discriminator_loss: the loss function to use for the discriminator
:param generator_optimizer_fct: the optimizer function to use for the generator
:param discriminator_optimizer_fct: the optimizer function to use for the discriminator
"""
super().__init__(model=None, clip_values=None)
super().__init__(model=None, clip_values=None, channels_first=None)
self._generator = generator
self._discriminator_classifier = discriminator
self._generator_loss = generator_loss
self._generator_optimizer_fct = generator_optimizer_fct
self._discriminator_loss = discriminator_loss
self._discriminator_optimizer_fct = discriminator_optimizer_fct

def predict(self, x: np.ndarray, **kwargs) -> Any: # lgtm [py/inheritance/incorrect-overridden-signature]
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
"""
Generates a sample
param x: a seed
:return: the sample
Generates a sample.
:param x: An input seed.
:param batch_size: The batch size for predictions.
:return: The generated sample.
"""
return self.generator.model(x, training=False)
return self.generator.predict(x, batch_size=batch_size, **kwargs)

@property
def input_shape(self) -> Tuple[int, int]:
Expand All @@ -75,25 +79,19 @@ def input_shape(self) -> Tuple[int, int]:
"""
return 1, 100

def fit(self, x: np.ndarray, y: np.ndarray, **kwargs) -> None:
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
"""
Creates a generative model
:param x: the secret backdoor trigger that will produce the target
:param y: the target to produce when using the trigger
:param batch_size: batch_size of images used to train generator
:param max_iter: total number of iterations for performing the attack
:param nb_epochs: total number of iterations for performing the attack
"""
max_iter = kwargs.get("max_iter")
if max_iter is None:
raise ValueError("max_iter argument was None. The value must be a positive integer")

batch_size = kwargs.get("batch_size")
if batch_size is None:
raise ValueError("batch_size argument was None. The value must be a positive integer")
import tensorflow as tf # lgtm [py/repeated-import]

z_trigger = x
for _ in range(max_iter):
for _ in range(nb_epochs):
train_imgs = kwargs.get("images")
train_set = (
tf.data.Dataset.from_tensor_slices(train_imgs)
Expand Down Expand Up @@ -167,3 +165,11 @@ def discriminator_optimizer_fct(self) -> "tf.Tensor":
:return: the optimizer function for the discriminator
"""
return self._discriminator_optimizer_fct

def loss_gradient(self, x, y, **kwargs):
raise NotImplementedError

def get_activations(
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
) -> np.ndarray:
raise NotImplementedError
2 changes: 1 addition & 1 deletion art/estimators/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from art.estimators.generation.generator import GeneratorMixin

from art.estimators.generation.tensorflow import TensorFlowGenerator
from art.estimators.generation.tensorflow import TensorFlow2Generator
from art.estimators.generation.tensorflow import TensorFlowV2Generator
45 changes: 31 additions & 14 deletions art/estimators/generation/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
import logging
from typing import Any, Dict, List, Optional, Union, Tuple, TYPE_CHECKING

import numpy as np

from art.estimators.generation.generator import GeneratorMixin
from art.estimators.tensorflow import TensorFlowEstimator, TensorFlowV2Estimator

if TYPE_CHECKING:
# pylint: disable=C0412
import numpy as np
import tensorflow.compat.v1 as tf

from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
Expand Down Expand Up @@ -143,7 +144,7 @@ def feed_dict(self) -> Dict[Any, Any]:
"""
return self._feed_dict # type: ignore

def predict(self, x: "np.ndarray", batch_size: int = 128, **kwargs) -> "np.ndarray":
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
"""
Perform projections over a batch of encodings.
Expand All @@ -158,7 +159,7 @@ def predict(self, x: "np.ndarray", batch_size: int = 128, **kwargs) -> "np.ndarr
y = self._sess.run(self._model, feed_dict=feed_dict)
return y

def loss_gradient(self, x, y, training_mode: bool = False, **kwargs) -> "np.ndarray": # pylint: disable=W0221
def loss_gradient(self, x, y, training_mode: bool = False, **kwargs) -> np.ndarray: # pylint: disable=W0221
raise NotImplementedError

def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
Expand All @@ -168,14 +169,14 @@ def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
raise NotImplementedError

def get_activations(
self, x: "np.ndarray", layer: Union[int, str], batch_size: int, framework: bool = False
) -> "np.ndarray":
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
) -> np.ndarray:
"""
Do nothing.
"""
raise NotImplementedError

def compute_loss(self, x: "np.ndarray", y: "np.ndarray", **kwargs) -> "np.ndarray":
def compute_loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
raise NotImplementedError

@property
Expand All @@ -195,7 +196,7 @@ def encoding_length(self) -> int:
return self._encoding_length


class TensorFlow2Generator(GeneratorMixin, TensorFlowV2Estimator): # lgtm [py/missing-call-to-init]
class TensorFlowV2Generator(GeneratorMixin, TensorFlowV2Estimator): # lgtm [py/missing-call-to-init]
"""
This class implements a DGM with the TensorFlow framework.
"""
Expand Down Expand Up @@ -258,19 +259,35 @@ def encoding_length(self) -> int:
def input_shape(self) -> Tuple[int, ...]:
raise NotImplementedError

def predict(self, x: "np.ndarray", batch_size: int = 128, **kwargs) -> "np.ndarray":
def predict( # pylint: disable=W0221
self, x: np.ndarray, batch_size: int = 128, training_mode: bool = False, **kwargs
) -> np.ndarray:
"""
Perform projections over a batch of encodings.
:param x: Encodings.
:param batch_size: Batch size.
:param training_mode: `True` for model set to training mode and `False` for model set to evaluation mode.
:return: Array of prediction projections of shape `(num_inputs, nb_classes)`.
"""
logging.info("Projecting new sample from z value")
y = self._model(x)
return y
# Run prediction with batch processing
results_list = []
num_batch = int(np.ceil(len(x) / float(batch_size)))
for m in range(num_batch):
# Batch indexes
begin, end = (
m * batch_size,
min((m + 1) * batch_size, x.shape[0]),
)

# Run prediction
results_list.append(self._model(x[begin:end], training=training_mode).numpy())

results = np.vstack(results_list)

return results

def loss_gradient(self, x, y, **kwargs) -> "np.ndarray":
def loss_gradient(self, x, y, **kwargs) -> np.ndarray:
raise NotImplementedError

def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
Expand All @@ -280,8 +297,8 @@ def fit(self, x, y, batch_size=128, nb_epochs=10, **kwargs):
raise NotImplementedError

def get_activations(
self, x: "np.ndarray", layer: Union[int, str], batch_size: int, framework: bool = False
) -> "np.ndarray":
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
) -> np.ndarray:
"""
Do nothing.
"""
Expand Down
Loading

0 comments on commit 3a3bedf

Please sign in to comment.