Skip to content

Commit

Permalink
Add summary writer to AdversarialTexturePyTorch
Browse files Browse the repository at this point in the history
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>
  • Loading branch information
Beat Buesser committed Dec 11, 2021
1 parent 0ed10bf commit c2c33fa
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from art.attacks.attack import EvasionAttack
from art.estimators.estimator import BaseEstimator, LossGradientsMixin
from art.estimators.object_tracking.object_tracker import ObjectTrackerMixin
from art.summary_writer import SummaryWriter

if TYPE_CHECKING:
# pylint: disable=C0412
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
step_size: float = 1.0 / 255.0,
max_iter: int = 500,
batch_size: int = 16,
summary_writer: Union[str, bool, SummaryWriter] = False,
verbose: bool = True,
):
"""
Expand All @@ -80,11 +82,18 @@ def __init__(
:param step_size: The step size.
:param max_iter: The number of optimization steps.
:param batch_size: The size of the training batch.
:param summary_writer: Activate summary writer for TensorBoard.
Default is `False` and deactivated summary writer.
If `True` save runs/CURRENT_DATETIME_HOSTNAME in current directory.
If of type `str` save in path.
If of type `SummaryWriter` apply provided custom summary writer.
Use hierarchical folder structure to compare between runs easily. e.g. pass in
‘runs/exp1’, ‘runs/exp2’, etc. for each new experiment to compare across them.
:param verbose: Show progress bars.
"""
import torch # lgtm [py/repeated-import]

super().__init__(estimator=estimator)
super().__init__(estimator=estimator, summary_writer=summary_writer)
self.patch_height = patch_height
self.patch_width = patch_width
self.x_min = x_min
Expand All @@ -95,6 +104,9 @@ def __init__(
self.verbose = verbose
self._check_params()

self._batch_id = 0
self._i_max_iter = 0

self.patch_shape = (self.patch_height, self.patch_width, 3)

if self.estimator.channels_first:
Expand Down Expand Up @@ -142,6 +154,18 @@ def _train_step(

gradients = self._patch.grad.sign() * self.step_size

# Write summary
if self.summary_writer is not None: # pragma: no cover
self.summary_writer.update(
batch_id=self._batch_id,
global_step=self._i_max_iter,
grad=np.expand_dims(self._patch.grad.detach().cpu().numpy(), axis=0),
patch=None,
estimator=None,
x=None,
y=None,
)

with torch.no_grad():
self._patch[:] = torch.clamp(
self._patch + gradients, min=self.estimator.clip_values[0], max=self.estimator.clip_values[1]
Expand Down Expand Up @@ -356,8 +380,13 @@ def __getitem__(self, idx):
drop_last=False,
)

for _ in trange(self.max_iter, desc="Adversarial Texture PyTorch", disable=not self.verbose):
for i_max_iter in trange(self.max_iter, desc="Adversarial Texture PyTorch", disable=not self.verbose):

self._i_max_iter = i_max_iter
self._batch_id = 0

for videos_i, target_i, y_init_i, foreground_i in data_loader:
self._batch_id += 1
videos_i = videos_i.to(self.estimator.device)
y_init_i = y_init_i.to(self.estimator.device)
foreground_i = foreground_i.to(self.estimator.device)
Expand All @@ -367,6 +396,21 @@ def __getitem__(self, idx):

_ = self._train_step(videos=videos_i, target=target_i_list, y_init=y_init_i, foreground=foreground_i)

# Write summary
if self.summary_writer is not None: # pragma: no cover
self.summary_writer.update(
batch_id=self._batch_id,
global_step=self._i_max_iter,
grad=None,
patch=self._patch.detach().cpu().numpy(),
estimator=self.estimator,
x=videos_i.detach().cpu().numpy(),
y=target_i_list,
)

if self.summary_writer is not None:
self.summary_writer.reset()

return self.apply_patch(x=x, foreground=foreground)

def apply_patch(
Expand Down
43 changes: 37 additions & 6 deletions art/estimators/object_tracking/pytorch_goturn.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _get_losses(
:param reduction: Specifies the reduction to apply to the output: 'none' | 'sum'.
'none': no reduction will be applied.
'sum': the output will be summed.
:return: Loss gradients of the same shape as `x`.
:return: Loss dictionary, list of input tensors, and list of gradient tensors.
"""
import torch # lgtm [py/repeated-import]

Expand Down Expand Up @@ -697,17 +697,48 @@ def get_activations(
"""
raise NotImplementedError

def compute_losses(self, x: np.ndarray, y: List[Dict[str, np.ndarray]]) -> List[Dict[str, np.ndarray]]:
def compute_losses(self, x: np.ndarray, y: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
"""
Not implemented.
Compute losses.
:param x: Samples of shape (nb_samples, nb_frames, height, width, nb_channels).
:param y: Target values of format `List[Dict[str, np.ndarray]]`, one dictionary for each input image. The keys
of the dictionary are:
- boxes [N_FRAMES, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and
0 <= y1 < y2 <= H.
:return: Dictionary of loss components.
"""
raise NotImplementedError
output = self.compute_loss(x=x, y=y)
output_dict = dict()
output_dict["torch.nn.L1Loss"] = output
return output_dict

def compute_loss(self, x: np.ndarray, y: List[Dict[str, np.ndarray]], **kwargs) -> np.ndarray:
"""
Not implemented.
Compute loss.
:param x: Samples of shape (nb_samples, nb_frames, height, width, nb_channels).
:param y: Target values of format `List[Dict[str, np.ndarray]]`, one dictionary for each input image. The keys
of the dictionary are:
- boxes [N_FRAMES, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and
0 <= y1 < y2 <= H.
:return: Total loss.
"""
raise NotImplementedError
import torch # lgtm [py/repeated-import]

output_dict, _, _ = self._get_losses(x=x, y=y)

if isinstance(output_dict["torch.nn.L1Loss"], torch.Tensor):
output = output_dict["torch.nn.L1Loss"].detach().cpu().numpy()
else:
output_list = list()
for out in output_dict["torch.nn.L1Loss"]:
output_list.append(out.detach().cpu().numpy())
output = np.array(output_list)

return output

def init(self, image: "PIL.JpegImagePlugin.JpegImageFile", box: np.ndarray):
"""
Expand Down
2 changes: 2 additions & 0 deletions art/summary_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ def update(

# Patch
if patch is not None:
if patch.shape[2] in [1, 3, 4]:
patch = np.transpose(patch, (2, 0, 1))
self.summary_writer.add_image(
"patch",
patch,
Expand Down

0 comments on commit c2c33fa

Please sign in to comment.