Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AdversarialPatchPyTorch compatibility with YOLO estimator #2169

Merged
merged 8 commits into from
Jun 27, 2023
34 changes: 23 additions & 11 deletions art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,9 +573,9 @@
img = torch.from_numpy(self.x[idx])

target = {}
target["boxes"] = torch.from_numpy(y[idx]["boxes"])
target["labels"] = torch.from_numpy(y[idx]["labels"])
target["scores"] = torch.from_numpy(y[idx]["scores"])
target["boxes"] = torch.from_numpy(self.y[idx]["boxes"])
target["labels"] = torch.from_numpy(self.y[idx]["labels"])
target["scores"] = torch.from_numpy(self.y[idx]["scores"])
mask_i = torch.from_numpy(self.mask[idx])

return img, target, mask_i
Expand All @@ -600,21 +600,33 @@
if isinstance(target, torch.Tensor):
target = target.to(self.estimator.device)
else:
target["boxes"] = target["boxes"].to(self.estimator.device)
target["labels"] = target["labels"].to(self.estimator.device)
target["scores"] = target["scores"].to(self.estimator.device)
_ = self._train_step(images=images, target=target, mask=None)
targets = []

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable target is not used.
for idx in range(target["boxes"].shape[0]):
targets.append(
{
"boxes": target["boxes"][idx].to(self.estimator.device),
"labels": target["labels"][idx].to(self.estimator.device),
"scores": target["scores"][idx].to(self.estimator.device),
}
)
_ = self._train_step(images=images, target=targets, mask=None)
else:
for images, target, mask_i in data_loader:
images = images.to(self.estimator.device)
if isinstance(target, torch.Tensor):
target = target.to(self.estimator.device)
else:
target["boxes"] = target["boxes"].to(self.estimator.device)
target["labels"] = target["labels"].to(self.estimator.device)
target["scores"] = target["scores"].to(self.estimator.device)
targets = []
for idx in range(target["boxes"].shape[0]):
targets.append(
{
"boxes": target["boxes"][idx].to(self.estimator.device),
"labels": target["labels"][idx].to(self.estimator.device),
"scores": target["scores"][idx].to(self.estimator.device),
}
)
mask_i = mask_i.to(self.estimator.device)
_ = self._train_step(images=images, target=target, mask=mask_i)
_ = self._train_step(images=images, target=targets, mask=mask_i)

# Write summary
if self.summary_writer is not None: # pragma: no cover
Expand Down
5 changes: 4 additions & 1 deletion art/estimators/object_detection/pytorch_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,10 @@ def _preprocess_and_convert_inputs(

# Set gradients
if not no_grad:
x_tensor.requires_grad = True
if x_tensor.is_leaf:
x_tensor.requires_grad = True
else:
x_tensor.retain_grad()

# Apply framework-specific preprocessing
x_preprocessed, y_preprocessed = self._apply_preprocessing(x=x_tensor, y=y_tensor, fit=fit, no_grad=no_grad)
Expand Down
407 changes: 333 additions & 74 deletions notebooks/adversarial_patch/attack_adversarial_patch_pytorch_yolo.ipynb

Large diffs are not rendered by default.

65 changes: 64 additions & 1 deletion tests/estimators/object_detection/test_pytorch_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def test_compute_loss(art_warning, get_pytorch_yolo):
# Compute loss
loss = object_detector.compute_loss(x=x_test, y=y_test)

assert pytest.approx(11.20741, abs=0.9) == float(loss)
assert pytest.approx(11.20741, abs=1.5) == float(loss)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

YOLO generates inconsistent loss for the same input. Fix is to freeze batch_norm and drop_out layers. However freezing these layers results in compute_loss returning NaN for some inputs due to a known issue in the python YOLO library. See #2148 for detail.


except ARTTestException as e:
art_warning(e)
Expand All @@ -386,3 +386,66 @@ def test_pgd(art_warning, get_pytorch_yolo):

except ARTTestException as e:
art_warning(e)


@pytest.mark.only_with_platform("pytorch")
def test_patch(art_warning, get_pytorch_yolo):
try:

from art.attacks.evasion import AdversarialPatchPyTorch

rotation_max = 0.0
scale_min = 0.1
scale_max = 0.3
distortion_scale_max = 0.0
learning_rate = 1.99
max_iter = 2
batch_size = 16
patch_shape = (3, 5, 5)
patch_type = "circle"
optimizer = "pgd"

object_detector, x_test, y_test = get_pytorch_yolo

ap = AdversarialPatchPyTorch(
estimator=object_detector,
rotation_max=rotation_max,
scale_min=scale_min,
scale_max=scale_max,
optimizer=optimizer,
distortion_scale_max=distortion_scale_max,
learning_rate=learning_rate,
max_iter=max_iter,
batch_size=batch_size,
patch_shape=patch_shape,
patch_type=patch_type,
verbose=True,
targeted=False,
)

_, _ = ap.generate(x=x_test, y=y_test)

patched_images = ap.apply_patch(x_test, scale=0.4)
result = object_detector.predict(patched_images)

assert result[0]["scores"].shape == (10647,)
expected_detection_scores = np.asarray(
[
4.3653536e-08,
3.3987994e-06,
2.5681820e-06,
3.9782722e-06,
2.1766680e-05,
2.6138965e-05,
6.3377396e-05,
7.6248516e-06,
4.3447722e-06,
3.6515078e-06,
]
)
np.testing.assert_raises(
AssertionError, np.testing.assert_array_almost_equal, result[0]["scores"][:10], expected_detection_scores, 6
)

except ARTTestException as e:
art_warning(e)
Loading