Skip to content

Commit

Permalink
initial round of review edits
Browse files Browse the repository at this point in the history
Signed-off-by: GiulioZizzo <giulio.zizzo@yahoo.co.uk>
  • Loading branch information
GiulioZizzo committed Aug 25, 2023
1 parent 2e7a29a commit e67388f
Show file tree
Hide file tree
Showing 13 changed files with 27 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def poison( # pylint: disable=W0221
return self._attack.poison(x, y, **kwargs)

def _check_params(self) -> None:

if not isinstance(self.target, np.ndarray) or not isinstance(self.source, np.ndarray):
raise ValueError("Target and source must be arrays")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def poison( # pylint: disable=W0221
original_images = torch.from_numpy(np.copy(data[poison_indices])).to(self.estimator.device)

for batch_id in trange(batches, desc="Hidden Trigger", disable=not self.verbose):

cur_index = self.batch_size * batch_id
offset = min(self.batch_size, num_poison - cur_index)
poison_batch_indices = poison_indices[cur_index : cur_index + offset]
Expand Down
48 changes: 12 additions & 36 deletions art/attacks/poisoning/perturbations/image_perturbations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def add_single_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np.
return x


def add_pattern_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np.ndarray:
def add_pattern_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1, channels_first: bool = False) -> np.ndarray:
"""
Augments a matrix by setting a checkerboard-like pattern of values some `distance` away from the bottom-right
edge to 1. Works for single images or a batch of images.
Expand All @@ -61,6 +61,11 @@ def add_pattern_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np
"""
x = np.copy(x)
shape = x.shape
if len(shape) == 3 or len(shape) == 4:
# Transpose the image putting channels last
if channels_first:
x = np.transpose(x, (0, 2, 3, 1))

if len(shape) == 4:
height, width = x.shape[1:3]
x[:, height - distance, width - distance, :] = pixel_value
Expand All @@ -81,41 +86,12 @@ def add_pattern_bd(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np
x[height - distance - 2, width - distance] = pixel_value
else:
raise ValueError(f"Invalid array shape: {shape}")
return x


def add_pattern_bd_pytorch(x: np.ndarray, distance: int = 2, pixel_value: int = 1) -> np.ndarray:
"""
Augments a matrix by setting a checkerboard-like pattern of values some `distance` away from the bottom-right
edge to 1. Works for single images or a batch of images.
:param x: A single image or batch of images of shape NWHC, NHW, or HC. Pixels will be added to all channels.
:param distance: Distance from bottom-right walls.
:param pixel_value: Value used to replace the entries of the image matrix.
:return: Backdoored image.
"""
x = np.copy(x)
shape = x.shape
if len(shape) == 4:
height, width = x.shape[1:3]
x[:, :, height - distance, width - distance] = pixel_value
x[:, :, height - distance - 1, width - distance - 1] = pixel_value
x[:, :, height - distance, width - distance - 2] = pixel_value
x[:, :, height - distance - 2, width - distance] = pixel_value
elif len(shape) == 3:
height, width = x.shape[1:]
x[height - distance, width - distance, :] = pixel_value
x[height - distance - 1, width - distance - 1, :] = pixel_value
x[height - distance, width - distance - 2, :] = pixel_value
x[height - distance - 2, width - distance, :] = pixel_value
elif len(shape) == 2:
height, width = x.shape
x[height - distance, width - distance] = pixel_value
x[height - distance - 1, width - distance - 1] = pixel_value
x[height - distance, width - distance - 2] = pixel_value
x[height - distance - 2, width - distance] = pixel_value
else:
raise ValueError(f"Invalid array shape: {shape}")

if len(shape) == 3 or len(shape) == 4:
# Putting channels first again
if channels_first:
x = np.transpose(x, (0, 3, 1, 2))

return x


Expand Down
1 change: 1 addition & 0 deletions art/estimators/classification/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def fit( # pylint: disable=W0221
scaled_loss.backward()
else:
loss.backward()

self._optimizer.step()

epoch_loss.append(loss.cpu().detach().numpy())
Expand Down
2 changes: 1 addition & 1 deletion art/estimators/hugging_face/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
import importlib

if importlib.util.find_spec("torch") is not None:
from art.estimators.hugging_face.hugging_face import HuggingFaceClassifier
from art.estimators.hugging_face.hugging_face import HuggingFaceClassifierPyTorch
9 changes: 4 additions & 5 deletions art/estimators/hugging_face/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
logger = logging.getLogger(__name__)


class HuggingFaceClassifier(PyTorchClassifier):
class HuggingFaceClassifierPyTorch(PyTorchClassifier):
"""
This class implements a classifier with the HuggingFace framework.
"""
Expand All @@ -54,7 +54,7 @@ def __init__(
processor: Optional[Callable] = None,
):
"""
Initialization of HuggingFaceClassifier specifically for the PyTorch-based backend.
Initialization of HuggingFaceClassifierPyTorch specifically for the PyTorch-based backend.
:param model: Huggingface model model which returns outputs of type
ImageClassifierOutput from the transformers library.
Expand Down Expand Up @@ -109,9 +109,8 @@ def __init__(

def prefix_function(function: Callable, postfunction: Callable) -> Callable[[Any, Any], torch.Tensor]:
"""
Huggingface returns logit under outputs.logits.
To make this compatible with ART we wrap the forward pass function
of a HF model here, which automatically extracts the logits.
Huggingface returns logit under outputs.logits. To make this compatible with ART we wrap the forward pass
function of a HF model here, which automatically extracts the logits.
:param function: The first function to run, in our case the forward pass of the model.
:param postfunction: Second function to run, in this case simply extracts the logits.
Expand Down
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _get_image_iterator():
dataset = tf.data.Dataset.from_tensor_slices((x_train_mnist, y_train_mnist)).batch(default_batch_size)
return dataset

if framework ["pytorch", "huggingface"]:
if framework in ["pytorch", "huggingface"]:
import torch

# Create tensors from data
Expand Down Expand Up @@ -550,7 +550,7 @@ def _image_dl_gan(**kwargs):


@pytest.fixture
def image_dl_estimator(framework, get_image_classifier_mx_instance):
def image_dl_estimator(framework):
def _image_dl_estimator(functional=False, **kwargs):
sess = None
wildcard = False
Expand Down
23 changes: 0 additions & 23 deletions huggingface_integration_tests.md

This file was deleted.

2 changes: 1 addition & 1 deletion tests/attacks/poison/test_clean_label_backdoor_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_poison(art_warning, get_default_mnist_subset, image_dl_estimator, frame
target = to_categorical([9], 10)[0]

if framework in ["pytorch", "huggingface"]:
backdoor = PoisoningAttackBackdoor(add_pattern_bd_pytorch)
backdoor = PoisoningAttackBackdoor(add_pattern_bd(channels_first=True))
else:
backdoor = PoisoningAttackBackdoor(add_pattern_bd)

Expand Down
26 changes: 1 addition & 25 deletions tests/defences/trainer/test_adversarial_trainer_FBF.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,31 +38,7 @@ def _get_adv_trainer():
if framework == "scikitlearn":
trainer = None
if framework == "huggingface":
if HF_MODEL_SIZE == "LARGE":
import transformers
import torch
from art.estimators.hugging_face import HuggingFaceClassifier

model = transformers.AutoModelForImageClassification.from_pretrained(
"facebook/deit-tiny-patch16-224", ignore_mismatched_sizes=True, num_labels=10
)

print("num of parameters is ", sum(p.numel() for p in model.parameters() if p.requires_grad))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

classifier = HuggingFaceClassifier(
model,
loss=torch.nn.CrossEntropyLoss(),
optimizer=optimizer,
input_shape=(3, 224, 224),
nb_classes=10,
processor=None,
)
elif HF_MODEL_SIZE == "SMALL":
classifier = get_image_classifier_hf(from_logits=True)
else:
raise ValueError("HF_MODEL_SIZE must be either SMALL or LARGE")

classifier, _ = image_dl_estimator()
trainer = AdversarialTrainerFBFPyTorch(classifier, eps=0.05)

return trainer
Expand Down
104 changes: 0 additions & 104 deletions tests/defences/trainer/test_adversarial_trainer_madry_pgd.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_adversarial_trainer_trades_pytorch_fit_and_predict(get_adv_trainer, fix
0.0,
decimal=4,
)
print("accuracy ", accuracy)

assert accuracy == 0.32
assert accuracy_new > 0.32

Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ def get_image_classifier_hf(from_logits=False, load_init=True, use_maxpool=True)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import ImageClassifierOutput
from art.estimators.hugging_face import HuggingFaceClassifier
from art.estimators.hugging_face import HuggingFaceClassifierPyTorch

class ModelConfig(PretrainedConfig):
def __init__(
Expand Down Expand Up @@ -1119,7 +1119,7 @@ def forward(self, x):
pt_model = Model(config=config)
optimizer = torch.optim.Adam(pt_model.parameters(), lr=0.01)

hf_classifier = HuggingFaceClassifier(
hf_classifier = HuggingFaceClassifierPyTorch(
pt_model,
loss=torch.nn.CrossEntropyLoss(reduction="sum"),
optimizer=optimizer,
Expand Down

0 comments on commit e67388f

Please sign in to comment.