Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PGD attack on multi-modal CLIP model #2340

Open
wants to merge 46 commits into
base: dev_1.18.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
d37aa67
initial demo attack on clip
GiulioZizzo Aug 28, 2023
a1f3b6a
initial POC of attacking CLIP with ART tools
GiulioZizzo Sep 18, 2023
8bbb18e
fix assignment with torch.no_grad
GiulioZizzo Sep 18, 2023
852229e
fix bug in which x.copy() required a deepcopy for the new hf input type
GiulioZizzo Sep 20, 2023
988d348
general updates
GiulioZizzo Sep 25, 2023
5483e81
Rename of input, type hinting, function commenting
GiulioZizzo Sep 29, 2023
699527b
initial adversarial training scripts
GiulioZizzo Oct 13, 2023
c4b28a1
adding initial notebook and cuda compatibility
GiulioZizzo Oct 17, 2023
f04f132
pylint and mypy edits
GiulioZizzo Oct 17, 2023
f75ef03
refactor to experimental
GiulioZizzo Oct 21, 2023
f5774b5
run ci
GiulioZizzo Oct 25, 2023
9aa1365
commenting and formatting edits
GiulioZizzo Oct 25, 2023
109905c
move pgd changes to experimental
GiulioZizzo Oct 26, 2023
7589750
restore orignal fgsm and pgd files
GiulioZizzo Oct 26, 2023
b6f9cf0
moving to experimental
GiulioZizzo Nov 1, 2023
4356c08
moving labels to correct device, remove repeated code
GiulioZizzo Nov 7, 2023
95b6629
update notebook and formatting edits
GiulioZizzo Nov 7, 2023
45616bd
update tests
GiulioZizzo Nov 7, 2023
469f6bf
adding comments to mm_inputs
GiulioZizzo Nov 27, 2023
70f01cb
remove old files and redundant changes
GiulioZizzo Nov 27, 2023
3d2e075
moving functionality to experimental
GiulioZizzo Nov 27, 2023
9241fea
re-add original test bash script
GiulioZizzo Nov 28, 2023
19f6493
updated naming
GiulioZizzo Nov 28, 2023
09c8461
mypy fixes
GiulioZizzo Nov 28, 2023
a372550
updating tests
GiulioZizzo Nov 29, 2023
a084756
fix spelling error
GiulioZizzo Nov 30, 2023
5afe1e3
moving some tests to new script for estimator
GiulioZizzo Nov 30, 2023
d85b267
remove development files
GiulioZizzo Nov 30, 2023
a6ccea1
updates to tests
GiulioZizzo Nov 30, 2023
9df9d14
consistancy in naming
GiulioZizzo Nov 30, 2023
4bb4139
remove feature branch in ci pipeline
GiulioZizzo Nov 30, 2023
48391d1
mypy fixes
GiulioZizzo Nov 30, 2023
0defd9d
mypy fixes
GiulioZizzo Dec 1, 2023
0b5b773
checking codeql error
GiulioZizzo Dec 1, 2023
e8e4746
Formatting fix. Check if deepcopy is the problem with codeQL
GiulioZizzo Dec 1, 2023
c7573ee
check sentinel fix
GiulioZizzo Dec 1, 2023
8bbf92c
refactor to address codeQL
GiulioZizzo Dec 1, 2023
ae9a261
refactor for codeQL
GiulioZizzo Dec 1, 2023
455f31e
refactor for codeQL
GiulioZizzo Dec 1, 2023
fc37e87
refactor for codeQL
GiulioZizzo Dec 1, 2023
689777e
try sentinel fix
GiulioZizzo Dec 1, 2023
5a92140
refactor with setter method for codeQL
GiulioZizzo Dec 1, 2023
105c881
refactor for codeQl fix
GiulioZizzo Dec 1, 2023
31cbcca
refactor for codeQl fix
GiulioZizzo Dec 1, 2023
ea35d39
explicitly removing random restarts as ART currently only supports re…
GiulioZizzo Dec 7, 2023
6951923
updating notebook
GiulioZizzo Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions art/experimental/attacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
This module contains the experimental attacks for ART
"""
from art.experimental.attacks import evasion
6 changes: 6 additions & 0 deletions art/experimental/attacks/evasion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""
This module contains the fgsm attack for the multimodal CLIP model
"""
from art.experimental.attacks.evasion.projected_gradient_descent.projected_gradient_descent_numpy import (
CLIPProjectedGradientDescentNumpy,
)
347 changes: 347 additions & 0 deletions art/experimental/attacks/evasion/fast_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module contains an experimental FGSM attack for multimodal models.
"""
import copy
from collections import UserDict
from typing import Optional, Union, TYPE_CHECKING

import numpy as np

from art.attacks.evasion.fast_gradient import FastGradientMethod
from art.attacks.attack import EvasionAttack
from art.estimators.estimator import BaseEstimator, LossGradientsMixin
from art.experimental.estimators.hugging_face_multimodal import HuggingFaceMultiModalInput

from art.summary_writer import SummaryWriter
from art.config import ART_NUMPY_DTYPE

from art.utils import random_sphere, projection_l1_1, projection_l1_2

if TYPE_CHECKING:
from art.utils import CLASSIFIER_LOSS_GRADIENTS_TYPE


def multimodal_projection(
values: np.ndarray, eps: Union[int, float, np.ndarray], norm_p: Union[int, float, str]
) -> np.ndarray:
"""
Experimental extension of the projection in art.utils to support multimodal inputs.

Project `values` on the L_p norm ball of size `eps`.

:param values: Array of perturbations to clip.
:param eps: Maximum norm allowed.
:param norm_p: L_p norm to use for clipping.
Only 1, 2 , `np.Inf` 1.1 and 1.2 supported for now.
1.1 and 1.2 compute orthogonal projections on l1-ball, using two different algorithms
:return: Values of `values` after projection.
"""
# Pick a small scalar to avoid division by 0
tol = 10e-8
values_tmp = values.reshape((values.shape[0], -1))

if norm_p == 2:
if isinstance(eps, np.ndarray):
raise NotImplementedError("The parameter `eps` of type `np.ndarray` is not supported to use with norm 2.")

values_tmp = values_tmp * np.expand_dims(
np.minimum(1.0, eps / (np.linalg.norm(values_tmp, axis=1) + tol)), axis=1
)

elif norm_p == 1:
if isinstance(eps, np.ndarray):
raise NotImplementedError("The parameter `eps` of type `np.ndarray` is not supported to use with norm 1.")

values_tmp = values_tmp * np.expand_dims(
np.minimum(1.0, eps / (np.linalg.norm(values_tmp, axis=1, ord=1) + tol)),
axis=1,
)
elif norm_p == 1.1:
values_tmp = projection_l1_1(values_tmp, eps)
elif norm_p == 1.2:
values_tmp = projection_l1_2(values_tmp, eps)

elif norm_p in [np.inf, "inf"]:
if isinstance(eps, np.ndarray):
if isinstance(values_tmp, UserDict):
eps = eps * np.ones_like(values["pixel_values"].cpu().detach().numpy())
else:
eps = eps * np.ones_like(values)
eps = eps.reshape([eps.shape[0], -1]) # type: ignore

if isinstance(values_tmp, UserDict):
sign = np.sign(values_tmp["pixel_values"].cpu().detach().numpy())
mag = abs(values_tmp["pixel_values"].cpu().detach().numpy())
values_tmp["pixel_values"] = sign * np.minimum(mag, eps)
else:
values_tmp = np.sign(values_tmp) * np.minimum(abs(values_tmp), eps)

else:
raise NotImplementedError(
'Values of `norm_p` different from 1, 2, `np.inf` and "inf" are currently not ' "supported."
)

values = values_tmp.reshape(values.shape)

return values


class FastGradientMethodCLIP(FastGradientMethod):
"""
Implementation of the FGSM attack operating on the image portion of multimodal inputs
to the CLIP model.
"""

attack_params = EvasionAttack.attack_params + [
"norm",
"eps",
"eps_step",
"targeted",
"num_random_init",
"batch_size",
"minimal",
"summary_writer",
]
_estimator_requirements = (BaseEstimator, LossGradientsMixin)

def __init__(
self,
estimator: "CLASSIFIER_LOSS_GRADIENTS_TYPE",
norm: Union[int, float, str] = np.inf,
eps: Union[int, float, np.ndarray] = 0.3,
eps_step: Union[int, float, np.ndarray] = 0.1,
targeted: bool = False,
num_random_init: int = 0,
batch_size: int = 32,
minimal: bool = False,
summary_writer: Union[str, bool, SummaryWriter] = False,
) -> None:

super().__init__(
estimator=estimator,
norm=norm,
eps=eps,
eps_step=eps_step,
targeted=targeted,
num_random_init=num_random_init,
batch_size=batch_size,
minimal=minimal,
summary_writer=summary_writer,
)

def _minimal_perturbation(self, x: np.ndarray, y: np.ndarray, mask: np.ndarray) -> np.ndarray:
"""
Iteratively compute the minimal perturbation necessary to make the class prediction change. Stop when the
first adversarial example was found.

:param x: An array with the original inputs.
:param y: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes).
:return: An array holding the adversarial examples.
"""
adv_x = copy.deepcopy(x)

# Compute perturbation with implicit batching
for batch_id in range(int(np.ceil(adv_x.shape[0] / float(self.batch_size)))):
batch_index_1, batch_index_2 = (
batch_id * self.batch_size,
(batch_id + 1) * self.batch_size,
)
batch = adv_x[batch_index_1:batch_index_2]
batch_labels = y[batch_index_1:batch_index_2]

mask_batch = mask
if mask is not None:
# Here we need to make a distinction: if the masks are different for each input, we need to index
# those for the current batch. Otherwise (i.e. mask is meant to be broadcasted), keep it as it is.
if len(mask.shape) == len(x.shape):
mask_batch = mask[batch_index_1:batch_index_2]

# Get perturbation
perturbation = self._compute_perturbation(batch, batch_labels, mask_batch)

# Get current predictions
active_indices = np.arange(len(batch))

if isinstance(self.eps, np.ndarray) and isinstance(self.eps_step, np.ndarray):
if len(self.eps.shape) == len(x.shape) and self.eps.shape[0] == x.shape[0]:
current_eps = self.eps_step[batch_index_1:batch_index_2]
partial_stop_condition = (current_eps <= self.eps[batch_index_1:batch_index_2]).all()

else:
current_eps = self.eps_step
partial_stop_condition = (current_eps <= self.eps).all()

else:
current_eps = self.eps_step
partial_stop_condition = current_eps <= self.eps

while active_indices.size > 0 and partial_stop_condition:
# Adversarial crafting
current_x = self._apply_perturbation(x[batch_index_1:batch_index_2], perturbation, current_eps)

# Update
batch[active_indices] = current_x[active_indices]
Fixed Show fixed Hide fixed
adv_preds = self.estimator.predict(batch)

# If targeted active check to see whether we have hit the target, otherwise head to anything but
if self.targeted:
active_indices = np.where(np.argmax(batch_labels, axis=1) != np.argmax(adv_preds, axis=1))[0]
else:
active_indices = np.where(np.argmax(batch_labels, axis=1) == np.argmax(adv_preds, axis=1))[0]

# Update current eps and check the stop condition
if isinstance(self.eps, np.ndarray) and isinstance(self.eps_step, np.ndarray):
if len(self.eps.shape) == len(x.shape) and self.eps.shape[0] == x.shape[0]:
current_eps = current_eps + self.eps_step[batch_index_1:batch_index_2]
partial_stop_condition = (current_eps <= self.eps[batch_index_1:batch_index_2]).all()

else:
current_eps = current_eps + self.eps_step
partial_stop_condition = (current_eps <= self.eps).all()

else:
current_eps = current_eps + self.eps_step
partial_stop_condition = current_eps <= self.eps

adv_x[batch_index_1:batch_index_2] = batch
Fixed Show fixed Hide fixed

return adv_x

def _apply_perturbation(
self, x: np.ndarray, perturbation: np.ndarray, eps_step: Union[int, float, np.ndarray]
) -> np.ndarray:

perturbation_step = eps_step * perturbation
if perturbation_step.dtype != object:
perturbation_step[np.isnan(perturbation_step)] = 0
else:
for i, _ in enumerate(perturbation_step):
perturbation_step_i_array = perturbation_step[i].astype(np.float32)
if np.isnan(perturbation_step_i_array).any():
perturbation_step[i] = np.where(
np.isnan(perturbation_step_i_array), 0.0, perturbation_step_i_array
).astype(object)

x = x + perturbation_step
if self.estimator.clip_values is not None:
clip_min, clip_max = self.estimator.clip_values
if x.dtype == object:
if isinstance(x, HuggingFaceMultiModalInput):
for i_obj in range(x.shape[0]):
x[i_obj] = np.clip(x[i_obj]["pixel_values"].cpu().detach().numpy(), clip_min, clip_max)
else:
for i_obj in range(x.shape[0]):
x[i_obj] = np.clip(x[i_obj], clip_min, clip_max)
else:
x = np.clip(x, clip_min, clip_max)

return x

def _compute(
self,
x: np.ndarray,
x_init: np.ndarray,
y: np.ndarray,
mask: Optional[np.ndarray],
eps: Union[int, float, np.ndarray],
eps_step: Union[int, float, np.ndarray],
project: bool,
random_init: bool,
batch_id_ext: Optional[int] = None,
decay: Optional[float] = None,
momentum: Optional[np.ndarray] = None,
) -> np.ndarray:
if random_init:
n = x.shape[0]
m = np.prod(x.shape[1:]).item()
random_perturbation = random_sphere(n, m, eps, self.norm).reshape(x.shape).astype(ART_NUMPY_DTYPE)
if mask is not None:
random_perturbation = random_perturbation * (mask.astype(ART_NUMPY_DTYPE))
x_adv = x.astype(ART_NUMPY_DTYPE) + random_perturbation

if self.estimator.clip_values is not None:
clip_min, clip_max = self.estimator.clip_values
x_adv = np.clip(x_adv, clip_min, clip_max)
else:
if x.dtype == object:
x_adv = copy.deepcopy(x)
else:
x_adv = x.astype(ART_NUMPY_DTYPE)

# Compute perturbation with implicit batching
for batch_id in range(int(np.ceil(x.shape[0] / float(self.batch_size)))):
if batch_id_ext is None:
self._batch_id = batch_id
else:
self._batch_id = batch_id_ext
batch_index_1, batch_index_2 = batch_id * self.batch_size, (batch_id + 1) * self.batch_size
batch_index_2 = min(batch_index_2, x.shape[0])
batch = x_adv[batch_index_1:batch_index_2]
batch_labels = y[batch_index_1:batch_index_2]

mask_batch = mask
if mask is not None:
# Here we need to make a distinction: if the masks are different for each input, we need to index
# those for the current batch. Otherwise (i.e. mask is meant to be broadcasted), keep it as it is.
if len(mask.shape) == len(x.shape):
mask_batch = mask[batch_index_1:batch_index_2]

# Get perturbation
perturbation = self._compute_perturbation(batch, batch_labels, mask_batch, decay, momentum)

# Compute batch_eps and batch_eps_step
if isinstance(eps, np.ndarray) and isinstance(eps_step, np.ndarray):
if len(eps.shape) == len(x.shape) and eps.shape[0] == x.shape[0]:
batch_eps = eps[batch_index_1:batch_index_2]
batch_eps_step = eps_step[batch_index_1:batch_index_2]

else:
batch_eps = eps
batch_eps_step = eps_step

else:
batch_eps = eps
batch_eps_step = eps_step

# Apply perturbation and clip
x_adv[batch_index_1:batch_index_2] = self._apply_perturbation(batch, perturbation, batch_eps_step)
Fixed Show fixed Hide fixed

if project:
if x_adv.dtype == object:
for i_sample in range(batch_index_1, batch_index_2):
if isinstance(batch_eps, np.ndarray) and batch_eps.shape[0] == x_adv.shape[0]:
perturbation = multimodal_projection(
x_adv[i_sample] - x_init[i_sample], batch_eps[i_sample], self.norm
)

else:
perturbation = multimodal_projection(
x_adv[i_sample] - x_init[i_sample], batch_eps, self.norm
)

x_adv[i_sample] = x_init[i_sample] + perturbation
Fixed Show fixed Hide fixed

else:
perturbation = multimodal_projection(
x_adv[batch_index_1:batch_index_2] - x_init[batch_index_1:batch_index_2], batch_eps, self.norm
)
x_adv[batch_index_1:batch_index_2] = x_init[batch_index_1:batch_index_2] + perturbation
Fixed Show fixed Hide fixed

return x_adv
Empty file.