Skip to content

Commit

Permalink
Remove layer as an internal function parameter
Browse files Browse the repository at this point in the history
Directly use the attribute patch_layer.

Signed-off-by: Teodora Sechkova <tsechkova@vmware.com>
  • Loading branch information
sechkova committed Dec 21, 2023
1 parent f2470a0 commit cd15a9e
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions art/attacks/evasion/patchfool.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _generate_batch(self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None)
x = x.to(self.estimator.device)
y = y.to(self.estimator.device)

patch_list = self._get_patch_index(x, layer=self.patch_layer)
patch_list = self._get_patch_index(x)

mask = torch.zeros(x.shape).to(self.estimator.device)

Expand Down Expand Up @@ -200,12 +200,11 @@ def _generate_batch(self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None)
x_adv += torch.mul(perturbation, mask)
return x_adv

def _get_patch_index(self, x: "torch.Tensor", layer: int) -> "torch.Tensor":
def _get_patch_index(self, x: "torch.Tensor") -> "torch.Tensor":
"""
Select the most influential patch according to a predefined `layer`.
:param x: Source samples.
:param layer: Layer index for guiding the attention-aware patch selection.
:return: Index of the most influential patch.
"""
import torch
Expand All @@ -219,7 +218,7 @@ def _get_patch_index(self, x: "torch.Tensor", layer: int) -> "torch.Tensor":
# shape: batch x layer x (token x token)
att = torch.sum(att, dim=2)
# fix layer
max_patch_idx = torch.argmax(att[:, layer, :], dim=1)
max_patch_idx = torch.argmax(att[:, self.patch_layer, :], dim=1)

return max_patch_idx

Expand Down

0 comments on commit cd15a9e

Please sign in to comment.