Merge branch 'eliegoudout-main' into dev_1.18.0
beat-buesser committed Jun 4, 2024
2 parents 5da8bcb + fa891f1 commit 83f49b7
Showing 9 changed files with 258 additions and 208 deletions.
71 changes: 39 additions & 32 deletions art/attacks/evasion/fast_gradient.py
@@ -84,7 +84,7 @@ def __init__(
Create a :class:`.FastGradientMethod` instance.
:param estimator: A trained classifier.
:param norm: The norm of the adversarial perturbation. Possible values: "inf", np.inf, 1 or 2.
:param norm: The norm of the adversarial perturbation. Possible values: "inf", `np.inf` or a real `p >= 1`.
:param eps: Attack step size (input variation).
:param eps_step: Step size of input variation for minimal perturbation computation.
:param targeted: Indicates whether the attack is targeted (True) or untargeted (False)
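Not part of this commit: a minimal usage sketch of the relaxed `norm` argument, assuming `classifier` is an ART-wrapped estimator and `x_test` a NumPy batch (both names are hypothetical):

    from art.attacks.evasion import FastGradientMethod

    # With this change, any real p >= 1 is accepted, not just 1, 2 and infinity.
    attack = FastGradientMethod(estimator=classifier, norm=1.5, eps=0.3)
    x_adv = attack.generate(x=x_test)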
@@ -288,16 +288,18 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n

logger.info(
"Success rate of FGM attack: %.2f%%",
rate_best
if rate_best is not None
else 100
* compute_success(
self.estimator, # type: ignore
x,
y_array,
adv_x_best,
self.targeted,
batch_size=self.batch_size,
(
rate_best
if rate_best is not None
else 100
* compute_success(
self.estimator, # type: ignore
x,
y_array,
adv_x_best,
self.targeted,
batch_size=self.batch_size,
)
),
)

@@ -334,8 +336,9 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n

def _check_params(self) -> None:

if self.norm not in [1, 2, np.inf, "inf"]:
raise ValueError('Norm order must be either 1, 2, `np.inf` or "inf".')
norm: float = np.inf if self.norm == "inf" else float(self.norm)
if norm < 1:
raise ValueError('Norm order must be either "inf", `np.inf` or a real `p >= 1`.')

if not (
isinstance(self.eps, (int, float))
@@ -391,9 +394,6 @@ def _compute_perturbation(
decay: Optional[float] = None,
momentum: Optional[np.ndarray] = None,
) -> np.ndarray:
# Pick a small scalar to avoid division by 0
tol = 10e-8

# Get gradient wrt loss; invert it if attack is targeted
grad = self.estimator.loss_gradient(x, y) * (1 - 2 * int(self.targeted))

@@ -426,32 +426,39 @@

# Apply norm bound
def _apply_norm(norm, grad, object_type=False):
"""Returns an x maximizing <grad, x> subject to ||x||_norm<=1."""
if (grad.dtype != object and np.isinf(grad).any()) or np.isnan( # pragma: no cover
grad.astype(np.float32)
).any():
logger.info("The loss gradient array contains at least one positive or negative infinity.")

grad_2d = grad.reshape(1 if object_type else len(grad), -1)
if norm in [np.inf, "inf"]:
grad = np.sign(grad)
grad_2d = np.ones_like(grad_2d)
elif norm == 1:
if not object_type:
ind = tuple(range(1, len(x.shape)))
else:
ind = None
grad = grad / (np.sum(np.abs(grad), axis=ind, keepdims=True) + tol)
elif norm == 2:
if not object_type:
ind = tuple(range(1, len(x.shape)))
else:
ind = None
grad = grad / (np.sqrt(np.sum(np.square(grad), axis=ind, keepdims=True)) + tol)
i_max = np.argmax(np.abs(grad_2d), axis=1)
grad_2d = np.zeros_like(grad_2d)
grad_2d[range(len(grad_2d)), i_max] = 1
elif norm > 1:
conjugate = norm / (norm - 1)
q_norm = np.linalg.norm(grad_2d, ord=conjugate, axis=1, keepdims=True)
grad_2d = (np.abs(grad_2d) / np.where(q_norm, q_norm, np.inf)) ** (conjugate - 1)
grad = grad_2d.reshape(grad.shape) * np.sign(grad)
return grad

# Add momentum
# Compute gradient momentum
if decay is not None and momentum is not None:
grad = _apply_norm(norm=1, grad=grad)
grad = decay * momentum + grad
momentum += grad
if x.dtype == object:
raise NotImplementedError("Momentum Iterative Method not yet implemented for object type input.")
# Update momentum in-place (important).
# The L1 normalization for accumulation is an arbitrary choice of the paper.
grad_2d = grad.reshape(len(grad), -1)
norm1 = np.linalg.norm(grad_2d, ord=1, axis=1, keepdims=True)
normalized_grad = (grad_2d / np.where(norm1, norm1, np.inf)).reshape(grad.shape)
momentum *= decay
momentum += normalized_grad
# Use the momentum to compute the perturbation, instead of the gradient
grad = momentum

if x.dtype == object:
for i_sample in range(x.shape[0]):
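Not part of this commit: a small NumPy check of the closed form used in `_apply_norm` above. For finite `p > 1`, the maximizer of <g, x> over the unit Lp ball is x_i = sign(g_i) * (|g_i| / ||g||_q) ** (q - 1) with conjugate exponent q = p / (p - 1), and it attains <g, x> = ||g||_q (Hölder's inequality):

    import numpy as np

    def lp_ball_maximizer(g, p):
        # Direction maximizing <g, x> within the unit Lp ball (finite p > 1).
        q = p / (p - 1)  # conjugate exponent, 1/p + 1/q = 1
        q_norm = np.linalg.norm(g, ord=q)
        return np.sign(g) * (np.abs(g) / q_norm) ** (q - 1)

    g = np.array([0.5, -2.0, 1.0])
    for p in (1.5, 2.0, 4.0):
        x = lp_ball_maximizer(g, p)
        q = p / (p - 1)
        assert np.isclose(g @ x, np.linalg.norm(g, ord=q))  # optimal value is the dual norm
        assert np.isclose(np.linalg.norm(x, ord=p), 1.0)    # constraint is tight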
@@ -95,7 +95,9 @@ def __init__(
Create a :class:`.ProjectedGradientDescent` instance.
:param estimator: A trained estimator.
:param norm: The norm of the adversarial perturbation supporting "inf", np.inf, 1 or 2.
:param norm: The norm of the adversarial perturbation, supporting "inf", `np.inf` or a real `p >= 1`.
Currently, when `p` is not infinity, the projection step only rescales the noise, which may be
suboptimal for `p != 2`.
:param eps: Maximum perturbation that the attacker can introduce.
:param eps_step: Attack step size (input variation) at each iteration.
:param random_eps: When True, epsilon is drawn randomly from truncated normal distribution. The literature
@@ -210,8 +212,9 @@ def set_params(self, **kwargs) -> None:

def _check_params(self) -> None:

if self.norm not in [1, 2, np.inf, "inf"]:
raise ValueError('Norm order must be either 1, 2, `np.inf` or "inf".')
norm: float = np.inf if self.norm == "inf" else float(self.norm)
if norm < 1:
raise ValueError('Norm order must be either "inf", `np.inf` or a real `p >= 1`.')

if not (
isinstance(self.eps, (int, float))
@@ -77,7 +77,9 @@ def __init__(
Create a :class:`.ProjectedGradientDescentCommon` instance.
:param estimator: A trained classifier.
:param norm: The norm of the adversarial perturbation supporting "inf", np.inf, 1 or 2.
:param norm: The norm of the adversarial perturbation, supporting "inf", `np.inf` or a real `p >= 1`.
Currently, when `p` is not infinity, the projection step only rescales the noise, which may be
suboptimal for `p != 2`.
:param eps: Maximum perturbation that the attacker can introduce.
:param eps_step: Attack step size (input variation) at each iteration.
:param random_eps: When True, epsilon is drawn randomly from truncated normal distribution. The literature
@@ -179,8 +181,9 @@ def _set_targets(self, x: np.ndarray, y: Optional[np.ndarray], classifier_mixin:

def _check_params(self) -> None: # pragma: no cover

if self.norm not in [1, 2, np.inf, "inf"]:
raise ValueError('Norm order must be either 1, 2, `np.inf` or "inf".')
norm: float = np.inf if self.norm == "inf" else float(self.norm)
if norm < 1:
raise ValueError('Norm order must be either "inf", `np.inf` or a real `p >= 1`.')

if not (
isinstance(self.eps, (int, float))
@@ -263,7 +266,9 @@ def __init__(
Create a :class:`.ProjectedGradientDescentNumpy` instance.
:param estimator: A trained estimator.
:param norm: The norm of the adversarial perturbation supporting "inf", np.inf, 1 or 2.
:param norm: The norm of the adversarial perturbation, supporting "inf", `np.inf` or a real `p >= 1`.
Currently, when `p` is not infinity, the projection step only rescales the noise, which may be
suboptimal for `p != 2`.
:param eps: Maximum perturbation that the attacker can introduce.
:param eps_step: Attack step size (input variation) at each iteration.
:param random_eps: When True, epsilon is drawn randomly from truncated normal distribution. The literature
@@ -78,7 +78,9 @@ def __init__(
Create a :class:`.ProjectedGradientDescentPyTorch` instance.
:param estimator: A trained estimator.
:param norm: The norm of the adversarial perturbation. Possible values: "inf", np.inf, 1 or 2.
:param norm: The norm of the adversarial perturbation, supporting "inf", `np.inf` or a real `p >= 1`.
Currently, when `p` is not infinity, the projection step only rescales the noise, which may be
suboptimal for `p != 2`.
:param eps: Maximum perturbation that the attacker can introduce.
:param eps_step: Attack step size (input variation) at each iteration.
:param random_eps: When True, epsilon is drawn randomly from truncated normal distribution. The literature
@@ -185,7 +187,7 @@ def generate(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> n
adv_x = x.astype(ART_NUMPY_DTYPE)

# Compute perturbation with batching
for (batch_id, batch_all) in enumerate(
for batch_id, batch_all in enumerate(
tqdm(data_loader, desc="PGD - Batches", leave=False, disable=not self.verbose)
):

@@ -303,11 +305,8 @@ def _compute_perturbation_pytorch( # pylint: disable=W0221
"""
import torch

# Pick a small scalar to avoid division by 0
tol = 10e-8

# Get gradient wrt loss; invert it if attack is targeted
grad = self.estimator.loss_gradient(x=x, y=y) * (1 - 2 * int(self.targeted))
grad = self.estimator.loss_gradient(x=x, y=y) * (-1 if self.targeted else 1)

# Write summary
if self.summary_writer is not None: # pragma: no cover
@@ -331,25 +330,33 @@
if mask is not None:
grad = torch.where(mask == 0.0, torch.tensor(0.0).to(self.estimator.device), grad)

# Apply momentum
# Compute gradient momentum
if self.decay is not None:
ind = tuple(range(1, len(x.shape)))
grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + tol) # type: ignore
grad = self.decay * momentum + grad
# Accumulate the gradient for the next iter
momentum += grad
# Update momentum in-place (important).
# The L1 normalization for accumulation is an arbitrary choice of the paper.
grad_2d = grad.reshape(len(grad), -1)
norm1 = torch.linalg.norm(grad_2d, ord=1, dim=1, keepdim=True)
normalized_grad = (grad_2d * norm1.where(norm1 == 0, 1 / norm1)).reshape(grad.shape)
momentum *= self.decay
momentum += normalized_grad
# Use the momentum to compute the perturbation, instead of the gradient
grad = momentum

# Apply norm bound
if self.norm in ["inf", np.inf]:
grad = grad.sign()

elif self.norm == 1:
ind = tuple(range(1, len(x.shape)))
grad = grad / (torch.sum(grad.abs(), dim=ind, keepdims=True) + tol) # type: ignore

elif self.norm == 2:
ind = tuple(range(1, len(x.shape)))
grad = grad / (torch.sqrt(torch.sum(grad * grad, axis=ind, keepdims=True)) + tol) # type: ignore
norm: float = np.inf if self.norm == "inf" else float(self.norm)
grad_2d = grad.reshape(len(grad), -1)
if norm == np.inf:
grad_2d = torch.ones_like(grad_2d)
elif norm == 1:
i_max = torch.argmax(grad_2d.abs(), dim=1)
grad_2d = torch.zeros_like(grad_2d)
grad_2d[range(len(grad_2d)), i_max] = 1
elif norm > 1:
conjugate = norm / (norm - 1)
q_norm = torch.linalg.norm(grad_2d, ord=conjugate, dim=1, keepdim=True)
grad_2d = (grad_2d.abs() * q_norm.where(q_norm == 0, 1 / q_norm)) ** (conjugate - 1)

grad = grad_2d.reshape(grad.shape) * grad.sign()

assert x.shape == grad.shape
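Not part of this commit: the momentum update above (the PyTorch version here and the NumPy version earlier) follows the momentum iterative FGSM accumulation rule g_{t+1} = decay * g_t + grad / ||grad||_1, using the per-sample L1-normalized gradient and then taking the momentum as the step direction. A minimal NumPy sketch of one accumulation step, with illustrative shapes:

    import numpy as np

    def momentum_step(grad, momentum, decay):
        # Normalize each sample's gradient to unit L1 norm, then accumulate in place.
        grad_2d = grad.reshape(len(grad), -1)
        norm1 = np.linalg.norm(grad_2d, ord=1, axis=1, keepdims=True)
        normalized_grad = (grad_2d / np.where(norm1, norm1, np.inf)).reshape(grad.shape)
        momentum *= decay
        momentum += normalized_grad
        return momentum  # used in place of the raw gradient for the perturbation

    grad = np.random.randn(4, 3, 2)      # dummy per-sample gradients
    momentum = np.zeros_like(grad)
    momentum = momentum_step(grad, momentum, decay=1.0)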

@@ -448,65 +455,60 @@ def _compute_pytorch(

return x_adv

@staticmethod
def _projection(
self, values: "torch.Tensor", eps: Union[int, float, np.ndarray], norm_p: Union[int, float, str]
values: "torch.Tensor",
eps: Union[int, float, np.ndarray],
norm_p: Union[int, float, str],
*,
suboptimal: bool = True,
) -> "torch.Tensor":
"""
Project `values` on the L_p norm ball of size `eps`.
:param values: Values to clip.
:param eps: Maximum norm allowed.
:param norm_p: L_p norm to use for clipping supporting 1, 2, `np.Inf` and "inf".
:param eps: If a scalar, the norm of the L_p ball onto which samples are projected. Equivalently in general,
can be any array of non-negatives broadcastable with `values`, and the projection occurs onto the
unit ball for the weighted L_{p, w} norm with `w = 1 / eps`. Currently, for any given sample,
non-uniform weights are only supported with infinity norm. Example: To specify sample-wise scalar,
you can provide `eps.shape = (n_samples,) + (1,) * values[0].ndim`.
:param norm_p: Lp norm to use for clipping, with `norm_p > 0`. Only 2, `np.inf` and "inf" are supported
with `suboptimal=False` for now.
:param suboptimal: If `True` simply projects by rescaling to Lp ball. Fast but may be suboptimal for
`norm_p != 2`.
Ignored when `norm_p in [np.inf, "inf"]` because optimal solution is fast. Defaults to `True`.
:return: Values of `values` after projection.
"""
import torch

# Pick a small scalar to avoid division by 0
tol = 10e-8
values_tmp = values.reshape(values.shape[0], -1)
norm = np.inf if norm_p == "inf" else float(norm_p)
assert norm > 0

if norm_p == 2:
if isinstance(eps, np.ndarray):
raise NotImplementedError(
"The parameter `eps` of type `np.ndarray` is not supported to use with norm 2."
)
values_tmp = values.reshape(len(values), -1) # (n_samples, d)

values_tmp = (
values_tmp
* torch.min(
torch.tensor([1.0], dtype=torch.float32).to(self.estimator.device),
eps / (torch.norm(values_tmp, p=2, dim=1) + tol),
).unsqueeze_(-1)
eps = np.broadcast_to(eps, values.shape)
eps = eps.reshape(len(eps), -1) # (n_samples, d)
assert np.all(eps >= 0)
if norm != np.inf and not np.all(eps == eps[:, [0]]):
raise NotImplementedError(
"Projection onto the weighted L_p ball is currently not supported with finite `norm_p`."
)

elif norm_p == 1:
if isinstance(eps, np.ndarray):
if (suboptimal or norm == 2) and norm != np.inf: # Simple rescaling
values_norm = torch.linalg.norm(values_tmp, ord=norm, dim=1, keepdim=True) # (n_samples, 1)
values_tmp = values_tmp * values_norm.where(
values_norm == 0, torch.minimum(torch.ones(1), torch.Tensor(eps) / values_norm)
)
else: # Optimal
if norm == np.inf: # Easy exact case
values_tmp = values_tmp.sign() * torch.minimum(values_tmp.abs(), torch.Tensor(eps))
elif norm >= 1: # Convex optim
raise NotImplementedError(
"The parameter `eps` of type `np.ndarray` is not supported to use with norm 1."
"Finite values of `norm_p >= 1` are currently not supported with `suboptimal=False`."
)
else: # Non-convex optim
raise NotImplementedError("Values of `norm_p < 1` are currently not supported with `suboptimal=False`")

values_tmp = (
values_tmp
* torch.min(
torch.tensor([1.0], dtype=torch.float32).to(self.estimator.device),
eps / (torch.norm(values_tmp, p=1, dim=1) + tol),
).unsqueeze_(-1)
)

elif norm_p in [np.inf, "inf"]:
if isinstance(eps, np.ndarray):
eps = eps * np.ones_like(values.cpu())
eps = eps.reshape([eps.shape[0], -1]) # type: ignore

values_tmp = values_tmp.sign() * torch.min(
values_tmp.abs(), torch.tensor([eps], dtype=torch.float32).to(self.estimator.device)
)

else:
raise NotImplementedError(
"Values of `norm_p` different from 1, 2 and `np.inf` are currently not supported."
)

values = values_tmp.reshape(values.shape)
values = values_tmp.reshape(values.shape).to(values.dtype)

return values
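Not part of this commit: a minimal, ART-independent PyTorch sketch of the two projection modes implemented above. `project_by_rescaling` mirrors the `suboptimal=True` path (it lands on the Lp ball but, except for `p = 2`, is generally not the closest point, hence the flag's name), and `project_linf` mirrors the exact elementwise clip used for the infinity norm. Per-sample budgets use the shape `(n_samples,) + (1,) * values[0].ndim` suggested by the docstring; all names below are hypothetical.

    import numpy as np
    import torch

    def project_by_rescaling(values, eps, p):
        # Rescale each sample onto its Lp ball of radius eps (the closest point only when p == 2).
        values_2d = values.reshape(len(values), -1)
        eps_col = torch.tensor(eps, dtype=values.dtype).reshape(len(values), 1)  # per-sample scalar budgets
        norms = torch.linalg.norm(values_2d, ord=p, dim=1, keepdim=True)
        factor = torch.where(norms > eps_col, eps_col / norms, torch.ones_like(norms))
        return (values_2d * factor).reshape(values.shape)

    def project_linf(values, eps):
        # Exact Linf projection: elementwise clip to [-eps, eps]; eps broadcasts over values.
        eps_full = torch.tensor(np.broadcast_to(eps, tuple(values.shape)), dtype=values.dtype)
        return values.sign() * torch.minimum(values.abs(), eps_full)

    values = torch.randn(4, 3, 8, 8)
    eps = np.full((4, 1, 1, 1), 0.1)  # shape (n_samples,) + (1,) * values[0].ndim
    proj_l2 = project_by_rescaling(values, eps, p=2)
    proj_linf = project_linf(values, eps)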
