Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detection Transformer Estimator #2192

Merged
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
cb610a8
ViT backbone object detector with fasterRCNN example
kieranfraser Mar 2, 2023
dc2987a
ViT backbone object detector with fasterRCNN example
kieranfraser Mar 2, 2023
36e5e6e
training pipeline for fasterrcnn with vit backbone working. requires …
kieranfraser Mar 6, 2023
cc4894e
Adding pytorch detr. Working example demonstrating object detection a…
kieranfraser Mar 15, 2023
2c5bf42
DETR with original source methods attributed
kieranfraser Mar 20, 2023
4b97069
DETR with changes to original src
kieranfraser Mar 20, 2023
fd09793
Removed unused misc files. Updated example notebook demonstrating ViT…
kieranfraser Mar 20, 2023
5fa5001
Completed tests. Added method to freeze multihead-attention module. U…
kieranfraser Apr 13, 2023
f943d67
Adding constructor for detr
kieranfraser Apr 13, 2023
61a814e
Moved notebook to correct folder for adversarial patch attack. Update…
kieranfraser Apr 13, 2023
1627d37
Updated formatting
kieranfraser Apr 19, 2023
b9d7fc7
Refactored loss classes to prevent tests for other frameworks failing
kieranfraser Apr 19, 2023
bedcc31
Refactored loss classes to prevent tests for other frameworks failing
kieranfraser Apr 19, 2023
c639bbd
Fix for static methods and styling
kieranfraser Apr 20, 2023
317aada
Framework check for detr tests
kieranfraser Apr 20, 2023
b48dff0
Updated class name, added typing and other minor fixes.
kieranfraser May 11, 2023
46f3958
Updated class name, added typing and other minor fixes.
kieranfraser May 11, 2023
8f62b12
Added test call to github workflow
kieranfraser May 11, 2023
478d9a7
fix Tensor Device Inconsistencies in pgd
May 5, 2023
64db977
Updates to DETR: cleaned up resizing; correct clipping. Updates to no…
kieranfraser Jun 13, 2023
8e4c89d
Fixing formatting
kieranfraser Jun 14, 2023
8248092
updated detection transformer notebook
kieranfraser Jun 14, 2023
a1757e0
Remove irrelevant PGD
kieranfraser Jun 14, 2023
cacc829
Merge remote-tracking branch 'upstream/dev_1.15.0' into dev_detection…
kieranfraser Jun 14, 2023
ef88ed2
Fixed pylint, mypy issues
kieranfraser Jun 14, 2023
a51b614
Remove print line
kieranfraser Jun 14, 2023
0ab98d0
Adding Apache License to original DETR functions
kieranfraser Jun 15, 2023
7a96e2c
Updated notebook with stronger adversarial patch attacks - targeted a…
kieranfraser Jun 15, 2023
0d15d2f
Removing comments to fix pylint test
kieranfraser Jun 15, 2023
40070ea
Adding missing license to functions
kieranfraser Jun 15, 2023
df3e298
Merge branch 'dev_1.15.0' into dev_detection_transformer
beat-buesser Jun 27, 2023
3e250a1
Standalone detr.py file for utility code from FB repo
kieranfraser Jun 28, 2023
496fcd3
Merge remote-tracking branch 'origin/dev_detection_transformer' into …
kieranfraser Jun 28, 2023
482b277
Removing duplicate license reference
kieranfraser Jun 28, 2023
d6ed99b
Updated reference to adapted detr functions under Apache 2.0
kieranfraser Jun 28, 2023
35f1d5a
Updated detr.py docstring with list of changes to Apache 2.0 code
kieranfraser Jun 28, 2023
3a97e66
Updated device in pytorch_detection_transformer.py and detr.py. Updat…
kieranfraser Jun 28, 2023
84c9e2b
mypy fix - .to should not be called if np.array
kieranfraser Jun 28, 2023
81408e5
Fix for black formatting
kieranfraser Jun 28, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/ci-pytorch-object-detectors.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_object_detector.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_faster_rcnn
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_faster_rcnn.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_detection_transformer
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_detection_transformer.py --framework=pytorch --durations=0
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
Expand Down
1 change: 1 addition & 0 deletions art/estimators/object_detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from art.estimators.object_detection.pytorch_yolo import PyTorchYolo
from art.estimators.object_detection.tensorflow_faster_rcnn import TensorFlowFasterRCNN
from art.estimators.object_detection.tensorflow_v2_faster_rcnn import TensorFlowV2FasterRCNN
from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer
1,001 changes: 1,001 additions & 0 deletions art/estimators/object_detection/pytorch_detection_transformer.py

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions art/estimators/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,3 +329,14 @@ def set_batchnorm(self, train: bool) -> None:

# pylint: disable=W0212
self._set_layer(train=train, layerinfo=[torch.nn.modules.batchnorm._BatchNorm]) # type: ignore

def set_multihead_attention(self, train: bool) -> None:
"""
Set all multi-head attention layers into train or eval mode.

:param train: False for evaluation mode.
"""
import torch

# pylint: disable=W0212
self._set_layer(train=train, layerinfo=[torch.nn.modules.MultiheadAttention]) # type: ignore
4 changes: 4 additions & 0 deletions notebooks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ models.
[attack_adversarial_patch_faster_rcnn.ipynb](adversarial_patch/attack_adversarial_patch_faster_rcnn.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/adversarial_patch/attack_adversarial_patch_faster_rcnn.ipynb)]
shows how to set up a TFv2 Faster R-CNN object detector with ART and create an adversarial patch attack that fools the detector.

[attack_adversarial_patch_detr.ipynb](adversarial_patch/attack_adversarial_patch_detr.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/adversarial_patch/attack_adversarial_patch_detr.ipynb)]
shows how to set up the DEtection TRansformer (DETR) with ART for object detection and create an adversarial patch attack that fools the detector.


<p align="center">
<img src="../utils/data/images/adversarial_patch.png?raw=true" width="200" title="adversarial_patch">
</p>
Expand Down
660 changes: 660 additions & 0 deletions notebooks/adversarial_patch/attack_adversarial_patch_detr.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import, division, print_function, unicode_literals

import logging

import numpy as np
import pytest

logger = logging.getLogger(__name__)


@pytest.fixture()
@pytest.mark.skip_framework("tensorflow", "tensorflow2v1", "keras", "kerastf", "mxnet", "non_dl_frameworks")
def get_pytorch_detr():
from art.utils import load_dataset
from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
INPUT_SHAPE = (3, 32, 32)

object_detector = PyTorchDetectionTransformer(
input_shape=INPUT_SHAPE, clip_values=(0, 1), preprocessing=(MEAN, STD)
)

n_test = 2
(_, _), (x_test, y_test), _, _ = load_dataset("cifar10")
x_test = x_test.transpose(0, 3, 1, 2).astype(np.float32)
x_test = x_test[:n_test]

# Create labels

result = object_detector.predict(x=x_test)

y_test = [
{
"boxes": result[0]["boxes"],
"labels": result[0]["labels"],
"scores": np.ones_like(result[0]["labels"]),
},
{
"boxes": result[1]["boxes"],
"labels": result[1]["labels"],
"scores": np.ones_like(result[1]["labels"]),
},
]

yield object_detector, x_test, y_test


@pytest.mark.only_with_platform("pytorch")
def test_predict(get_pytorch_detr):

object_detector, x_test, _ = get_pytorch_detr

result = object_detector.predict(x=x_test)

assert list(result[0].keys()) == ["boxes", "labels", "scores"]

assert result[0]["boxes"].shape == (100, 4)
expected_detection_boxes = np.asarray([-5.9490204e-03, 1.1947733e01, 3.1993944e01, 3.1925127e01])
np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=3)

assert result[0]["scores"].shape == (100,)
expected_detection_scores = np.asarray(
[
0.00679839,
0.0250559,
0.07205943,
0.01115368,
0.03321039,
0.10407761,
0.00113309,
0.01442852,
0.00527624,
0.01240906,
]
)
np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=5)

assert result[0]["labels"].shape == (100,)
expected_detection_classes = np.asarray([17, 17, 33, 17, 17, 17, 74, 17, 17, 17])
np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=5)


@pytest.mark.only_with_platform("pytorch")
def test_loss_gradient(get_pytorch_detr):

object_detector, x_test, y_test = get_pytorch_detr

grads = object_detector.loss_gradient(x=x_test, y=y_test)

assert grads.shape == (2, 3, 800, 800)

expected_gradients1 = np.asarray(
[
-0.00061366,
0.00322502,
-0.00039866,
-0.00807413,
-0.00476555,
0.00181204,
0.01007765,
0.00415828,
-0.00073114,
0.00018387,
-0.00146992,
-0.00119636,
-0.00098966,
-0.00295517,
-0.0024271,
-0.00131314,
-0.00149217,
-0.00104926,
-0.00154239,
-0.00110989,
0.00092887,
0.00049146,
-0.00292508,
-0.00124526,
0.00140347,
0.00019833,
0.00191074,
-0.00117537,
-0.00080604,
0.00057427,
-0.00061728,
-0.00206535,
]
)

np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=2)

expected_gradients2 = np.asarray(
[
-1.1787530e-03,
-2.8500680e-03,
5.0884970e-03,
6.4504531e-04,
-6.8841036e-05,
2.8184296e-03,
3.0257765e-03,
2.8565727e-04,
-1.0701057e-04,
1.2945699e-03,
7.3593057e-04,
1.0177144e-03,
-2.4692707e-03,
-1.3801848e-03,
6.3182280e-04,
-4.2305476e-04,
4.4307750e-04,
8.5821096e-04,
-7.1204413e-04,
-3.1404425e-03,
-1.5964351e-03,
-1.9222996e-03,
-5.3157361e-04,
-9.9202688e-04,
-1.5815455e-03,
2.0060266e-04,
-2.0584739e-03,
6.6960667e-04,
9.7393827e-04,
-1.6040013e-03,
-6.9741381e-04,
1.4657658e-04,
]
)
np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=2)


@pytest.mark.only_with_platform("pytorch")
def test_errors():

from torch import hub

from art.estimators.object_detection.pytorch_detection_transformer import PyTorchDetectionTransformer

model = hub.load("facebookresearch/detr", "detr_resnet50", pretrained=True)

with pytest.raises(ValueError):
PyTorchDetectionTransformer(
model=model,
clip_values=(1, 2),
attack_losses=("loss_ce", "loss_bbox", "loss_giou"),
)

with pytest.raises(ValueError):
PyTorchDetectionTransformer(
model=model,
clip_values=(-1, 1),
attack_losses=("loss_ce", "loss_bbox", "loss_giou"),
)

from art.defences.postprocessor.rounded import Rounded

post_def = Rounded()
with pytest.raises(ValueError):
PyTorchDetectionTransformer(
model=model,
clip_values=(0, 1),
attack_losses=("loss_ce", "loss_bbox", "loss_giou"),
postprocessing_defences=post_def,
)


@pytest.mark.only_with_platform("pytorch")
def test_preprocessing_defences(get_pytorch_detr):

object_detector, x_test, _ = get_pytorch_detr

from art.defences.preprocessor.spatial_smoothing_pytorch import SpatialSmoothingPyTorch

pre_def = SpatialSmoothingPyTorch()

object_detector.set_params(preprocessing_defences=pre_def)

# Create labels
result = object_detector.predict(x=x_test)

y = [
{
"boxes": result[0]["boxes"],
"labels": result[0]["labels"],
"scores": np.ones_like(result[0]["labels"]),
},
{
"boxes": result[1]["boxes"],
"labels": result[1]["labels"],
"scores": np.ones_like(result[1]["labels"]),
},
]

# Compute gradients
grads = object_detector.loss_gradient(x=x_test, y=y)

assert grads.shape == (2, 3, 800, 800)


@pytest.mark.only_with_platform("pytorch")
def test_compute_losses(get_pytorch_detr):

object_detector, x_test, y_test = get_pytorch_detr
object_detector.attack_losses = "loss_ce"
losses = object_detector.compute_losses(x=x_test, y=y_test)
assert len(losses) == 1


@pytest.mark.only_with_platform("pytorch")
def test_compute_loss(get_pytorch_detr):

object_detector, x_test, _ = get_pytorch_detr
# Create labels
result = object_detector.predict(x_test)

y = [
{
"boxes": result[0]["boxes"],
"labels": result[0]["labels"],
"scores": np.ones_like(result[0]["labels"]),
},
{
"boxes": result[1]["boxes"],
"labels": result[1]["labels"],
"scores": np.ones_like(result[1]["labels"]),
},
]

# Compute loss
loss = object_detector.compute_loss(x=x_test, y=y)

assert pytest.approx(3.9634, abs=0.01) == float(loss)


@pytest.mark.only_with_platform("pytorch")
def test_pgd(get_pytorch_detr):

object_detector, x_test, y_test = get_pytorch_detr

from art.attacks.evasion import ProjectedGradientDescent
from PIL import Image

imgs = []
for i in x_test:
img = Image.fromarray((i * 255).astype(np.uint8).transpose(1, 2, 0))
img = img.resize(size=(800, 800))
imgs.append(np.array(img))
x_test = np.array(imgs).transpose(0, 3, 1, 2)

attack = ProjectedGradientDescent(estimator=object_detector, max_iter=2)
x_test_adv = attack.generate(x=x_test, y=y_test)
np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, x_test_adv, x_test)
Loading