
Commit

moving location of hf classifier. updating clean label attack to work with HF.

Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
GiulioZizzo committed Aug 26, 2023
1 parent 41ac37d commit 5cccfba
Showing 9 changed files with 33 additions and 41 deletions.
5 changes: 3 additions & 2 deletions art/attacks/poisoning/backdoor_attack.py
@@ -54,14 +54,15 @@ def __init__(self, perturbation: Union[Callable, List[Callable]]) -> None:
self._check_params()

def poison( # pylint: disable=W0221
self, x: np.ndarray, y: Optional[np.ndarray] = None, broadcast=False, **kwargs
self, x: np.ndarray, y: Optional[np.ndarray] = None, broadcast=False, channels_first: bool = False, **kwargs
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calls perturbation function on input x and returns the perturbed input and poison labels for the data.
:param x: An array with the points that initialize attack points.
:param y: The target labels for the attack.
:param broadcast: whether or not to broadcast a single target label
:param channels_first: whether the data is fed in channels_first format
:return: A tuple holding the `(poisoning_examples, poisoning_labels)`.
"""
if y is None: # pragma: no cover
@@ -78,7 +79,7 @@ def poison( # pylint: disable=W0221
poisoned = np.copy(x)

if callable(self.perturbation):
return self.perturbation(poisoned), y_attack
return self.perturbation(poisoned, channels_first=channels_first), y_attack

for perturb in self.perturbation:
poisoned = perturb(poisoned)
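The new channels_first flag is simply forwarded to the perturbation callable. A minimal usage sketch, assuming NCHW MNIST-style data and assuming add_pattern_bd accepts the channels_first keyword that the updated poison() now passes through:

import numpy as np
from art.attacks.poisoning import PoisoningAttackBackdoor
from art.attacks.poisoning.perturbations import add_pattern_bd
from art.utils import to_categorical

x = np.zeros((4, 1, 28, 28), dtype=np.float32)  # NCHW batch, e.g. for a PyTorch/HuggingFace model
y = to_categorical(np.array([9, 9, 9, 9]), nb_classes=10)  # per-sample target labels
backdoor = PoisoningAttackBackdoor(add_pattern_bd)
# channels_first=True is forwarded to add_pattern_bd (assumed to support that keyword)
x_poison, y_poison = backdoor.poison(x, y=y, channels_first=True)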
12 changes: 10 additions & 2 deletions art/attacks/poisoning/clean_label_backdoor_attack.py
@@ -102,14 +102,20 @@ def __init__(
self._check_params()

def poison( # pylint: disable=W0221
self, x: np.ndarray, y: Optional[np.ndarray] = None, broadcast: bool = True, **kwargs
self,
x: np.ndarray,
y: Optional[np.ndarray] = None,
broadcast: bool = True,
channels_first: bool = False,
**kwargs
) -> Tuple[np.ndarray, np.ndarray]:
"""
Calls perturbation function on input x and returns the perturbed input and poison labels for the data.
:param x: An array with the points that initialize attack points.
:param y: The target labels for the attack.
:param broadcast: whether or not to broadcast a single target label
:param channels_first: whether the data is fed in channels_first format
:return: A tuple holding the `(poisoning_examples, poisoning_labels)`.
"""
data = np.copy(x)
@@ -136,7 +142,9 @@ def poison( # pylint: disable=W0221
logger.warning("%d indices without change: %s", len(idx_no_change), idx_no_change)

# Add backdoor and poison with the same label
poisoned_input, _ = self.backdoor.poison(perturbed_input, self.target, broadcast=broadcast)
poisoned_input, _ = self.backdoor.poison(
perturbed_input, self.target, broadcast=broadcast, channels_first=channels_first
)
data[selected_indices] = poisoned_input

return data, estimated_labels
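With the flag now forwarded from PoisoningAttackCleanLabelBackdoor.poison to the underlying backdoor, channels-first data can be poisoned directly. A hedged sketch; classifier, x_train and y_train are assumed to exist elsewhere (e.g. a PyTorch or HuggingFace MNIST classifier with NCHW training data):

from art.attacks.poisoning import PoisoningAttackBackdoor, PoisoningAttackCleanLabelBackdoor
from art.attacks.poisoning.perturbations import add_pattern_bd
from art.utils import to_categorical

# classifier, x_train, y_train: assumed to be defined elsewhere (NCHW setup)
backdoor = PoisoningAttackBackdoor(add_pattern_bd)
target = to_categorical([9], 10)[0]  # single target class, broadcast to the selected samples
attack = PoisoningAttackCleanLabelBackdoor(backdoor=backdoor, proxy_classifier=classifier, target=target)
poison_data, poison_labels = attack.poison(x_train, y_train, channels_first=True)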
1 change: 0 additions & 1 deletion art/estimators/__init__.py
@@ -22,4 +22,3 @@
from art.estimators import poison_mitigation
from art.estimators import regression
from art.estimators import speech_recognition
from art.estimators import hugging_face
1 change: 1 addition & 0 deletions art/estimators/classification/__init__.py
@@ -17,6 +17,7 @@
from art.estimators.classification.lightgbm import LightGBMClassifier
from art.estimators.classification.mxnet import MXClassifier
from art.estimators.classification.pytorch import PyTorchClassifier
from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch
from art.estimators.classification.query_efficient_bb import QueryEfficientGradientEstimationClassifier
from art.estimators.classification.scikitlearn import SklearnClassifier
from art.estimators.classification.tensorflow import (
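After this change the classifier should be importable alongside the other classification estimators. A small sketch of the two equivalent import paths (the package-level import relies on the __init__.py line added above):

from art.estimators.classification import HuggingFaceClassifierPyTorch
# or, importing the module directly:
from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch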
art/estimators/classification/hugging_face.py (file moved from art/estimators/hugging_face/hugging_face.py)
@@ -23,13 +23,13 @@

from typing import List, Optional, Tuple, Union, Dict, Callable, Any, TYPE_CHECKING

import torch
import numpy as np
import six

from art.estimators.classification.pytorch import PyTorchClassifier

if TYPE_CHECKING:
import torch
import transformers
from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
from transformers.modeling_outputs import ImageClassifierOutput
@@ -87,6 +87,7 @@ def __init__(
the preprocessing relevant to a given foundation model.
Must be differentiable for gradient-based defences and attacks.
"""
import torch

self.processor = processor

@@ -134,27 +135,12 @@ def get_logits(outputs: "ImageClassifierOutput") -> torch.Tensor:
return outputs
return outputs.logits

self.model.forward = prefix_function(self.model.forward, get_logits)

"""
def __call__(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
outputs = self.forward(image)
return outputs
def forward(self, image: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
if not isinstance(image, torch.Tensor):
image = torch.from_numpy(image).to(self._device)
if self.processor is not None:
image = self.processor(image)
return self.model(image)
"""
self.model.forward = prefix_function(self.model.forward, get_logits) # type: ignore

def _make_model_wrapper(self, model: "torch.nn.Module") -> "torch.nn.Module":
# Try to import PyTorch and create an internal class that acts like a model wrapper extending torch.nn.Module
import torch

input_shape = self._input_shape
input_for_hook = torch.rand(input_shape)
input_for_hook = torch.unsqueeze(input_for_hook, dim=0)
@@ -267,15 +253,15 @@ def forward_hook(input_module, hook_input, hook_output):
return name_order

# Set newly created class as private attribute
self._model_wrapper = ModelWrapper
self._model_wrapper = ModelWrapper # type: ignore

# Use model wrapping class to wrap the PyTorch model received as argument
return self._model_wrapper(model)

except ImportError: # pragma: no cover
raise ImportError("Could not find PyTorch (`torch`) installation.") from ImportError

def get_activations(
def get_activations( # type: ignore
self,
x: Union[np.ndarray, "torch.Tensor"],
layer: Optional[Union[int, str]] = None,
@@ -293,6 +279,7 @@ def get_activations(
:param framework: If true, return the intermediate tensor representation of the activation.
:return: The output of `layer`, where the first dimension is the batch size corresponding to `x`.
"""
import torch

self._model.eval()

@@ -378,6 +365,8 @@ def get_accuracy(preds: Union[np.ndarray, "torch.Tensor"], labels: Union[np.ndar
:param labels: ground truth labels (not one hot)
:return: prediction accuracy
"""
import torch

if isinstance(preds, torch.Tensor):
preds = preds.detach().cpu().numpy()

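For context, a hedged sketch of constructing the relocated classifier around a transformers image model; the keyword names below follow the parent PyTorchClassifier constructor and the processor argument described in the docstring above, and are assumptions rather than part of this diff:

import torch
from transformers import AutoModelForImageClassification
from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch

model = AutoModelForImageClassification.from_pretrained("facebook/deit-tiny-patch16-224")  # example checkpoint
classifier = HuggingFaceClassifierPyTorch(
    model=model,
    loss=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    input_shape=(3, 224, 224),
    nb_classes=1000,  # ImageNet-1k head of the example checkpoint
    processor=None,   # optional callable applying model-specific preprocessing
)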
7 changes: 0 additions & 7 deletions art/estimators/hugging_face/__init__.py

This file was deleted.

11 changes: 6 additions & 5 deletions tests/attacks/poison/test_clean_label_backdoor_attack.py
@@ -37,13 +37,14 @@ def test_poison(art_warning, get_default_mnist_subset, image_dl_estimator, frame
classifier, _ = image_dl_estimator()
target = to_categorical([9], 10)[0]

if framework in ["pytorch", "huggingface"]:
backdoor = PoisoningAttackBackdoor(add_pattern_bd(channels_first=True))
else:
backdoor = PoisoningAttackBackdoor(add_pattern_bd)
backdoor = PoisoningAttackBackdoor(add_pattern_bd)

attack = PoisoningAttackCleanLabelBackdoor(backdoor, classifier, target)
poison_data, poison_labels = attack.poison(x_train, y_train)
if framework in ["pytorch", "huggingface"]:
poison_data, poison_labels = attack.poison(x_train, y_train, channels_first=True)
else:
poison_data, poison_labels = attack.poison(x_train, y_train)

np.testing.assert_equal(poison_data.shape, x_train.shape)
np.testing.assert_equal(poison_labels.shape, y_train.shape)
@@ -426,7 +426,7 @@
"preprocessing_defences=None, postprocessing_defences=None, preprocessing=StandardisationMeanStdPyTorch(mean=0.0, std=1.0, apply_fit=True, apply_predict=True, device=cpu)"
],
"test_repr_huggingface": [
"art.estimators.hugging_face.hugging_face.HuggingFaceClassifier",
"art.estimators.classification.hugging_face.HuggingFaceClassifierPyTorch",
"(conv): Conv2d(1, 1, kernel_size=(7, 7), stride=(1, 1))",
"(pool): MaxPool2d(kernel_size=4, stride=4, padding=0, dilation=1, ceil_mode=False)",
"(fullyconnected): Linear(in_features=25, out_features=10, bias=True)",
4 changes: 2 additions & 2 deletions tests/utils.py
@@ -1042,7 +1042,7 @@ def get_image_classifier_hf(from_logits=False, load_init=True, use_maxpool=True)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ImageClassifierOutput
from art.estimators.hugging_face import HuggingFaceClassifierPyTorch
from art.estimators.classification.hugging_face import HuggingFaceClassifierPyTorch

class ModelConfig(PretrainedConfig):
def __init__(
@@ -1992,7 +1992,7 @@ def get_tabular_classifier_hf(load_init=True):
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ImageClassifierOutput
from art.estimators.hugging_face import HuggingFaceClassifier
from art.estimators.classification.hugging_face import HuggingFaceClassifier

class ModelConfig(PretrainedConfig):
def __init__(
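A brief usage sketch of the test helper whose import was updated above, assuming it returns the small MNIST-scale HuggingFace classifier used by the estimator tests:

from tests.utils import get_image_classifier_hf

hf_classifier = get_image_classifier_hf(from_logits=False, load_init=True)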
