diff --git a/mart/callbacks/eval_mode.py b/mart/callbacks/eval_mode.py
index be3b6397..639444c9 100644
--- a/mart/callbacks/eval_mode.py
+++ b/mart/callbacks/eval_mode.py
@@ -4,23 +4,47 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
+from __future__ import annotations
+
 from pytorch_lightning.callbacks import Callback
+from mart import utils
+
+logger = utils.get_pylogger(__name__)
+
 __all__ = ["AttackInEvalMode"]


 class AttackInEvalMode(Callback):
     """Switch the model into eval mode during attack."""

-    def __init__(self):
-        self.training_mode_status = None
-
-    def on_train_start(self, trainer, model):
-        self.training_mode_status = model.training
-        model.train(False)
-
-    def on_train_end(self, trainer, model):
-        assert self.training_mode_status is not None
-
-        # Resume the previous training status of the model.
-        model.train(self.training_mode_status)
+    def __init__(self, module_classes: type | list[type]):
+        # FIXME: convert strings to classes using hydra.utils.get_class? This will clean up some verbosity in configuration but will require importing hydra in this callback.
+        if isinstance(module_classes, type):
+            module_classes = [module_classes]
+
+        self.module_classes = tuple(module_classes)
+
+    def setup(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        # Log to the console so the user can visually see which modules will be in eval mode during training.
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                logger.info(
+                    f"Setting eval mode for {name} ({module.__class__.__module__}.{module.__class__.__name__})"
+                )
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        # We must use on_train_epoch_start because PL will set pl_module to train mode right before this callback.
+        # See: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                module.eval()
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        # FIXME: Why is this necessary?
+        for name, module in pl_module.named_modules():
+            if isinstance(module, self.module_classes):
+                module.train()
diff --git a/mart/callbacks/no_grad_mode.py b/mart/callbacks/no_grad_mode.py
index cfb90ead..4a86d985 100644
--- a/mart/callbacks/no_grad_mode.py
+++ b/mart/callbacks/no_grad_mode.py
@@ -4,8 +4,15 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
+from __future__ import annotations
+
+import torch
 from pytorch_lightning.callbacks import Callback
+from mart import utils
+
+logger = utils.get_pylogger(__name__)
+
 __all__ = ["ModelParamsNoGrad"]
@@ -15,10 +22,25 @@ class ModelParamsNoGrad(Callback):
     This callback should not change the result. Don't use unless an attack runs faster.
     """

-    def on_train_start(self, trainer, model):
-        for param in model.parameters():
-            param.requires_grad_(False)
+    def __init__(self, module_names: str | list[str] | None = None):
+        if isinstance(module_names, str):
+            module_names = [module_names]
+
+        self.module_names = module_names
+
+    def setup(self, trainer, pl_module, stage):
+        if stage != "fit":
+            return
+
+        # We use setup, and not on_train_start, so that mart.optim.OptimizerFactory can ignore parameters with no gradients.
+ # See: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks + for name, param in pl_module.named_parameters(): + if any(name.startswith(module_name) for module_name in self.module_names): + logger.info(f"Disabling gradient for {name}") + param.requires_grad_(False) - def on_train_end(self, trainer, model): - for param in model.parameters(): - param.requires_grad_(True) + def teardown(self, trainer, pl_module, stage): + for name, param in pl_module.named_parameters(): + if any(name.startswith(module_name) for module_name in self.module_names): + # FIXME: Why is this necessary? + param.requires_grad_(True) diff --git a/mart/callbacks/visualizer.py b/mart/callbacks/visualizer.py index 3354321e..a81a94b7 100644 --- a/mart/callbacks/visualizer.py +++ b/mart/callbacks/visualizer.py @@ -4,38 +4,37 @@ # SPDX-License-Identifier: BSD-3-Clause # -import os +from operator import attrgetter from pytorch_lightning.callbacks import Callback -from torchvision.transforms import ToPILImage -__all__ = ["PerturbedImageVisualizer"] +__all__ = ["ImageVisualizer"] -class PerturbedImageVisualizer(Callback): - """Save adversarial images as files.""" +class ImageVisualizer(Callback): + def __init__(self, frequency: int = 100, **tag_paths): + self.frequency = frequency + self.tag_paths = tag_paths - def __init__(self, folder): - super().__init__() + def log_image(self, trainer, tag, image): + # Add image to each logger + for logger in trainer.loggers: + # FIXME: Should we just use isinstance(logger.experiment, SummaryWriter)? + if not hasattr(logger.experiment, "add_image"): + continue - # FIXME: This should use the Trainer's logging directory. - self.folder = folder - self.convert = ToPILImage() + logger.experiment.add_image(tag, image, global_step=trainer.global_step) - if not os.path.isdir(self.folder): - os.makedirs(self.folder) + def log_images(self, trainer, pl_module): + for tag, path in self.tag_paths.items(): + image = attrgetter(path)(pl_module) + self.log_image(trainer, tag, image) - def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx): - # Save input and target for on_train_end - self.input = batch["input"] - self.target = batch["target"] + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + if batch_idx % self.frequency != 0: + return - def on_train_end(self, trainer, model): - # FIXME: We should really just save this to outputs instead of recomputing adv_input - adv_input = model(input=self.input, target=self.target) + self.log_images(trainer, pl_module) - for img, tgt in zip(adv_input, self.target): - fname = tgt["file_name"] - fpath = os.path.join(self.folder, fname) - im = self.convert(img / 255) - im.save(fpath) + def on_train_end(self, trainer, pl_module): + self.log_images(trainer, pl_module) diff --git a/mart/configs/callbacks/attack_in_eval_mode.yaml b/mart/configs/callbacks/attack_in_eval_mode.yaml index 2acdc953..4ca096b0 100644 --- a/mart/configs/callbacks/attack_in_eval_mode.yaml +++ b/mart/configs/callbacks/attack_in_eval_mode.yaml @@ -1,2 +1,11 @@ attack_in_eval_mode: _target_: mart.callbacks.AttackInEvalMode + module_classes: ??? 
+ # - _target_: hydra.utils.get_class + # path: mart.models.LitModular + # - _target_: hydra.utils.get_class + # path: torch.nn.BatchNorm2d + # - _target_: hydra.utils.get_class + # path: torch.nn.Dropout + # - _target_: hydra.utils.get_class + # path: torch.nn.SyncBatchNorm diff --git a/mart/configs/callbacks/no_grad_mode.yaml b/mart/configs/callbacks/no_grad_mode.yaml index 6b4312fd..d12d18e9 100644 --- a/mart/configs/callbacks/no_grad_mode.yaml +++ b/mart/configs/callbacks/no_grad_mode.yaml @@ -1,2 +1,3 @@ -attack_in_eval_mode: +no_grad_mode: _target_: mart.callbacks.ModelParamsNoGrad + module_names: ??? diff --git a/mart/configs/callbacks/perturbation_visualizer.yaml b/mart/configs/callbacks/perturbation_visualizer.yaml new file mode 100644 index 00000000..5a673db5 --- /dev/null +++ b/mart/configs/callbacks/perturbation_visualizer.yaml @@ -0,0 +1,4 @@ +perturbation_visualizer: + _target_: mart.callbacks.ImageVisualizer + frequency: 100 + perturbation: ??? diff --git a/mart/configs/datamodule/coco.yaml b/mart/configs/datamodule/coco.yaml index a4ec3403..5416cae2 100644 --- a/mart/configs/datamodule/coco.yaml +++ b/mart/configs/datamodule/coco.yaml @@ -1,5 +1,5 @@ defaults: - - default.yaml + - default train_dataset: _target_: mart.datamodules.coco.CocoDetection diff --git a/mart/configs/datamodule/coco_yolov3.yaml b/mart/configs/datamodule/coco_yolov3.yaml new file mode 100644 index 00000000..27be09d1 --- /dev/null +++ b/mart/configs/datamodule/coco_yolov3.yaml @@ -0,0 +1,29 @@ +defaults: + - coco + +num_workers: 1 + +train_dataset: + transforms: + transforms: + - _target_: torchvision.transforms.ToTensor + - _target_: mart.transforms.ConvertCocoPolysToMask + - _target_: mart.transforms.PadToSquare + fill: 0.5 + - _target_: mart.transforms.Resize + size: [416, 416] + - _target_: mart.transforms.ConvertBoxesToCXCYHW + - _target_: mart.transforms.RemapLabels + - _target_: mart.transforms.PackBoxesAndLabels + num_classes: 80 + - _target_: mart.transforms.ConvertInstanceSegmentationToPerturbable + +val_dataset: + transforms: ${..train_dataset.transforms} + +test_dataset: + transforms: ${..val_dataset.transforms} + +collate_fn: + _target_: hydra.utils.get_method + path: mart.datamodules.coco.yolo_collate_fn diff --git a/mart/configs/experiment/COCO_YOLOv3.yaml b/mart/configs/experiment/COCO_YOLOv3.yaml new file mode 100644 index 00000000..130c63fa --- /dev/null +++ b/mart/configs/experiment/COCO_YOLOv3.yaml @@ -0,0 +1,34 @@ +# @package _global_ + +defaults: + - override /datamodule: coco_yolov3 + - override /model: yolov3 + - override /metric: average_precision + - override /optimization: super_convergence + +task_name: "COCO_YOLOv3" +tags: ["evaluation"] + +optimized_metric: "test_metrics/map" + +trainer: + # 117,266 training images, 6 epochs, batch_size=16, 43,974.75 + max_steps: 43975 + # FIXME: "nms_kernel" not implemented for 'BFloat16', torch.ops.torchvision.nms(). 
+ precision: 32 + +datamodule: + num_workers: 32 + ims_per_batch: 16 + +model: + load_state_dict: + yolov3: ${paths.data_dir}/yolov3_original.pt + + # yolov3 model does not produce preds/targets in training sequence + training_metrics: null + + optimizer: + lr: 0.001 + momentum: 0.9 + weight_decay: 0.0005 diff --git a/mart/configs/experiment/COCO_YOLOv3_ShapeShifter.yaml b/mart/configs/experiment/COCO_YOLOv3_ShapeShifter.yaml new file mode 100644 index 00000000..f9e80aa3 --- /dev/null +++ b/mart/configs/experiment/COCO_YOLOv3_ShapeShifter.yaml @@ -0,0 +1,166 @@ +# @package _global_ + +defaults: + - /attack/perturber@model.modules.perturbation: default + - /attack/perturber/initializer@model.modules.perturbation.initializer: uniform + - /attack/perturber/projector@model.modules.perturbation.projector: range + - /attack/composer@model.modules.input_adv: warp_composite + - /attack/gradient_modifier@model.gradient_modifier: lp_normalizer + - override /optimization: super_convergence + - override /datamodule: coco_yolov3 + - override /model: yolov3 + - override /metric: average_precision + - override /callbacks: + [ + model_checkpoint, + lr_monitor, + perturbation_visualizer, + gradient_monitor, + attack_in_eval_mode, + no_grad_mode, + ] + +task_name: "COCO_YOLOv3_ShapeShifter" +tags: ["adv"] + +optimized_metric: "test_metrics/map" + +trainer: + # 64115 training images, batch_size=16, FLOOR(64115/16) = 4007 + max_steps: 40070 # 10 epochs + # mAP can be slow to compute so limit number of images + limit_val_batches: 100 + limit_test_batches: 100 + precision: 32 + +callbacks: + model_checkpoint: + monitor: "validation_metrics/map" + mode: "min" + + attack_in_eval_mode: + module_classes: + - _target_: hydra.utils.get_class + path: torch.nn.BatchNorm2d + + no_grad_mode: + module_names: "model.yolov3" + + perturbation_visualizer: + perturbation: "model.perturbation.perturbation" + frequency: 500 + +datamodule: + num_workers: 32 + ims_per_batch: 16 + + train_dataset: + annFile: ${paths.data_dir}/coco/annotations/person_instances_train2017.json + val_dataset: + annFile: ${paths.data_dir}/coco/annotations/person_instances_val2017.json + test_dataset: + annFile: ${paths.data_dir}/coco/annotations/person_instances_val2017.json + +model: + modules: + perturbation: + size: [3, 416, 234] + + initializer: + min: 0.49 + max: 0.51 + + projector: + min: 0.0 + max: 1.0 + + input_adv: + warp: + _target_: torchvision.transforms.Compose + transforms: + - _target_: mart.transforms.ColorJitter + brightness: [0.5, 1.5] + contrast: [0.5, 1.5] + saturation: [0.5, 1.0] + hue: [-0.05, 0.05] + - _target_: torchvision.transforms.RandomAffine + degrees: [-5, 5] + translate: [0.1, 0.25] + scale: [0.4, 0.6] + shear: [-3, 3, -3, 3] + interpolation: 2 # BILINEAR + clamp: [0, 1] + + loss: + weights: [1, 1] + + load_state_dict: + yolov3: ${paths.data_dir}/yolov3_original.pt + + optimizer: + lr: 0.01 + momentum: 0.9 + + gradient_modifier: null + + training_sequence: + seq005: perturbation + seq006: input_adv + seq010: + yolov3: + x: "input_adv" + seq030: + loss: + _call_with_args_: + - losses.hide_target_objects_loss + - losses.correct_target_class_loss + weights: + - 10 + - 1 + + training_step_log: + - loss + - total_loss + - coord_loss + - obj_loss + - noobj_loss + - class_loss + - hide_objects_loss + - target_class_loss + - hide_target_objects_loss + - correct_target_class_loss + - target_count + - score_count + - target_score_count + + training_metrics: null + + validation_sequence: + seq005: perturbation + seq006: input_adv + 
seq010: + yolov3: + x: "input_adv" + seq030: + loss: + _call_with_args_: + - losses.hide_target_objects_loss + - losses.correct_target_class_loss + weights: + - 10 + - 1 + + test_sequence: + seq005: perturbation + seq006: input_adv + seq010: + yolov3: + x: "input_adv" + seq030: + loss: + _call_with_args_: + - losses.hide_target_objects_loss + - losses.correct_target_class_loss + weights: + - 10 + - 1 diff --git a/mart/configs/model/yolov3.yaml b/mart/configs/model/yolov3.yaml new file mode 100644 index 00000000..aa814310 --- /dev/null +++ b/mart/configs/model/yolov3.yaml @@ -0,0 +1,183 @@ +defaults: + - modular + +modules: + yolov3: + _target_: mart.models.yolov3.YoloNetV3 + + losses: + _target_: mart.models.yolov3.Loss + image_size: 416 # FIXME: use ${training_data.transform.image_size}? + average: True + + loss: + _target_: mart.nn.Sum + + detections: + _target_: mart.models.yolov3.Detections + nms: true + conf_thres: 0.1 + nms_thres: 0.4 + + output: + _target_: mart.nn.ReturnKwargs + +# training sequence does not produce preds/targets +training_metrics: null + +training_sequence: + seq010: + yolov3: + x: "input" + + seq020: + losses: + logits: yolov3.logits + target: target + + seq030: + loss: + _call_with_args_: + - losses.total_loss + + seq040: + detections: + preds: yolov3.preds + target: target + + seq050: + output: + loss: loss + total_loss: losses.total_loss + coord_loss: losses.coord_loss + obj_loss: losses.obj_loss + noobj_loss: losses.noobj_loss + class_loss: losses.class_loss + hide_objects_loss: losses.hide_objects_loss + target_class_loss: losses.target_class_loss + hide_target_objects_loss: losses.hide_target_objects_loss + correct_target_class_loss: losses.correct_target_class_loss + target_count: losses.target_count + score_count: losses.score_count + target_score_count: losses.target_score_count + +validation_sequence: + seq010: + yolov3: + x: "input" + + seq020: + losses: + logits: yolov3.logits + target: target + + seq030: + loss: + _call_with_args_: + - losses.total_loss + + seq040: + detections: + preds: yolov3.preds + target: target + + seq050: + output: + preds: detections.preds + target: detections.targets + loss: loss + total_loss: losses.total_loss + coord_loss: losses.coord_loss + obj_loss: losses.obj_loss + noobj_loss: losses.noobj_loss + class_loss: losses.class_loss + hide_objects_loss: losses.hide_objects_loss + target_class_loss: losses.target_class_loss + hide_target_objects_loss: losses.hide_target_objects_loss + correct_target_class_loss: losses.correct_target_class_loss + target_count: losses.target_count + score_count: losses.score_count + target_score_count: losses.target_score_count + +test_sequence: + seq010: + yolov3: + x: "input" + + seq020: + losses: + logits: yolov3.logits + target: target + + seq030: + loss: + _call_with_args_: + - losses.total_loss + + seq040: + detections: + preds: yolov3.preds + target: target + + seq050: + output: + preds: detections.preds + target: detections.targets + loss: loss + total_loss: losses.total_loss + coord_loss: losses.coord_loss + obj_loss: losses.obj_loss + noobj_loss: losses.noobj_loss + class_loss: losses.class_loss + hide_objects_loss: losses.hide_objects_loss + target_class_loss: losses.target_class_loss + hide_target_objects_loss: losses.hide_target_objects_loss + correct_target_class_loss: losses.correct_target_class_loss + target_count: losses.target_count + score_count: losses.score_count + target_score_count: losses.target_score_count + +training_step_log: + - loss + - total_loss + - coord_loss + - 
obj_loss + - noobj_loss + - class_loss + - hide_objects_loss + - target_class_loss + - hide_target_objects_loss + - correct_target_class_loss + - target_count + - score_count + - target_score_count + +validation_step_log: + - loss + - total_loss + - coord_loss + - obj_loss + - noobj_loss + - class_loss + - hide_objects_loss + - target_class_loss + - hide_target_objects_loss + - correct_target_class_loss + - target_count + - score_count + - target_score_count + +test_step_log: + - loss + - total_loss + - coord_loss + - obj_loss + - noobj_loss + - class_loss + - hide_objects_loss + - target_class_loss + - hide_target_objects_loss + - correct_target_class_loss + - target_count + - score_count + - target_score_count diff --git a/mart/datamodules/coco.py b/mart/datamodules/coco.py index 42ddcebb..795f492d 100644 --- a/mart/datamodules/coco.py +++ b/mart/datamodules/coco.py @@ -8,8 +8,11 @@ from typing import Any, Callable, List, Optional import numpy as np +import torch +from torch.utils.data import default_collate from torchvision.datasets.coco import CocoDetection as CocoDetection_ from torchvision.datasets.folder import default_loader +from yolov3.datasets.utils import collate_img_label_fn as collate_img_label_fn_ __all__ = ["CocoDetection"] @@ -44,6 +47,10 @@ def __init__( self.modalities = modalities + # Targets can contain a lot of information... + # https://discuss.pytorch.org/t/runtimeerror-received-0-items-of-ancdata/4999/4 + torch.multiprocessing.set_sharing_strategy("file_system") + def _load_image(self, id: int) -> Any: if self.modalities is None: return super()._load_image(id) @@ -89,3 +96,38 @@ def __getitem__(self, index: int): # Source: https://github.com/pytorch/vision/blob/dc07ac2add8285e16a716564867d0b4b953f6735/references/detection/utils.py#L203 def collate_fn(batch): return tuple(zip(*batch)) + + +def to_padded_tensor(tensors, dim=0, fill=0.0): + sizes = np.array([list(t.shape) for t in tensors]) + max_dim_size = sizes[:, dim].max() + sizes[:, dim] = max_dim_size - sizes[:, dim] + + zeros = [ + torch.full(s.tolist(), fill, device=t.device, dtype=t.dtype) + for t, s in zip(tensors, sizes) + ] + tensors = [torch.cat((t, z), dim=dim) for t, z in zip(tensors, zeros)] + + return tensors + + +def yolo_collate_fn(batch): + images, targets = tuple(zip(*batch)) + + images = default_collate(images) + + # Turn tuple of dicts into dict of tuples + keys = targets[0].keys() + target = {k: tuple(t[k] for t in targets) for k in keys} + + # Pad packed using torch.nested + target["packed"] = to_padded_tensor(target["packed"]) + + COLLATABLE_KEYS = ["packed", "packed_length", "perturbable_mask"] + + for key in target.keys(): + if key in COLLATABLE_KEYS: + target[key] = default_collate(target[key]) + + return images, target diff --git a/mart/models/yolov3.py b/mart/models/yolov3.py new file mode 100644 index 00000000..38ed9db7 --- /dev/null +++ b/mart/models/yolov3.py @@ -0,0 +1,213 @@ +# +# Copyright (C) 2023 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +import torch +import torch.nn.functional as F +import yolov3 +from yolov3.config import ANCHORS, NUM_ANCHORS_PER_SCALE, NUM_ATTRIB, NUM_CLASSES +from yolov3.inference import post_process +from yolov3.model import YoloLayer as YoloLayer_ +from yolov3.model import YoloNetV3 as YoloNetV3_ +from yolov3.training import yolo_loss_fn +from yolov3.utils import cxcywh_to_xywh + +from mart.utils import MonkeyPatch + + +class YoloNetV3(YoloNetV3_): + def __init__(self): + with MonkeyPatch(yolov3.model, "YoloLayer", YoloLayer): + 
super().__init__() + + def forward(self, x): + tmp1, tmp2, tmp3 = self.darknet(x) + out1, out2, out3 = self.yolo_tail(tmp1, tmp2, tmp3) + logits = torch.cat((out1["logits"], out2["logits"], out3["logits"]), 1) + preds = torch.cat((out1["preds"], out2["preds"], out3["preds"]), 1) + + return {"logits": logits, "preds": preds} + + +class YoloLayer(torch.nn.Module): + def __init__(self, scale, stride): + super().__init__() + + if scale == "s": + idx = (0, 1, 2) + elif scale == "m": + idx = (3, 4, 5) + elif scale == "l": + idx = (6, 7, 8) + else: + idx = None + + self.anchors = torch.tensor([ANCHORS[i] for i in idx]) + self.stride = stride + + def forward(self, x): + num_batch = x.size(0) + num_grid = x.size(2) + + output_raw = ( + x.view(num_batch, NUM_ANCHORS_PER_SCALE, NUM_ATTRIB, num_grid, num_grid) + .permute(0, 1, 3, 4, 2) + .contiguous() + .view(num_batch, -1, NUM_ATTRIB) + ) + + prediction_raw = ( + x.view(num_batch, NUM_ANCHORS_PER_SCALE, NUM_ATTRIB, num_grid, num_grid) + .permute(0, 1, 3, 4, 2) + .contiguous() + ) + + self.anchors = self.anchors.to(x.device).float() + # Calculate offsets for each grid + grid_tensor = torch.arange(num_grid, dtype=torch.float, device=x.device).repeat( + num_grid, 1 + ) + grid_x = grid_tensor.view([1, 1, num_grid, num_grid]) + grid_y = grid_tensor.t().view([1, 1, num_grid, num_grid]) + anchor_w = self.anchors[:, 0:1].view((1, -1, 1, 1)) + anchor_h = self.anchors[:, 1:2].view((1, -1, 1, 1)) + + # Get outputs + x_center_pred = (torch.sigmoid(prediction_raw[..., 0]) + grid_x) * self.stride # Center x + y_center_pred = (torch.sigmoid(prediction_raw[..., 1]) + grid_y) * self.stride # Center y + w_pred = torch.exp(prediction_raw[..., 2]) * anchor_w # Width + h_pred = torch.exp(prediction_raw[..., 3]) * anchor_h # Height + bbox_pred = torch.stack((x_center_pred, y_center_pred, w_pred, h_pred), dim=4).view( + (num_batch, -1, 4) + ) # cxcywh + conf_pred = torch.sigmoid(prediction_raw[..., 4]).view(num_batch, -1, 1) # Conf + cls_pred = torch.sigmoid(prediction_raw[..., 5:]).view( + num_batch, -1, NUM_CLASSES + ) # Cls pred one-hot. 
+ + output = torch.cat((bbox_pred, conf_pred, cls_pred), -1) + + return {"logits": output_raw, "preds": output} + + +class Loss(torch.nn.Module): + def __init__(self, image_size, average=True, score_thresh=0.01, target_idx=0): + super().__init__() + + self.image_size = image_size + self.average = average + self.score_thresh = score_thresh + self.target_idx = target_idx + + def forward(self, logits, target, **kwargs): + targets = target["packed"] + lengths = target["packed_length"] + + losses = yolo_loss_fn(logits, targets, lengths, self.image_size, self.average) + total_loss, coord_loss, obj_loss, noobj_loss, class_loss = losses + + # normalize individual losses by batch size + coord_loss = coord_loss / logits.shape[0] + obj_loss = obj_loss / logits.shape[0] + noobj_loss = noobj_loss / logits.shape[0] + class_loss = class_loss / logits.shape[0] + + pred_conf_logit = logits[..., 4] + pred_conf_score = torch.sigmoid(pred_conf_logit) + score_mask = pred_conf_score > self.score_thresh + + class_logits = logits[..., 5:] + target_mask = torch.argmax(class_logits, dim=-1) == self.target_idx + + # make objectness go to zero + tgt_zero = torch.zeros(pred_conf_logit.size(), device=pred_conf_logit.device) + hide_objects_losses = F.binary_cross_entropy_with_logits( + pred_conf_logit, tgt_zero, reduction="none" + ) + hide_objects_loss = hide_objects_losses[score_mask].sum() + hide_objects_loss = hide_objects_loss / logits.shape[0] + + # make target objectness go to zero + hide_target_objects_loss = hide_objects_losses[target_mask & score_mask].sum() + hide_target_objects_loss = hide_target_objects_loss / logits.shape[0] + + # make target logit go to zero + target_class_logit = class_logits[..., 0] # 0 == person + target_class_losses = F.binary_cross_entropy_with_logits( + target_class_logit, tgt_zero, reduction="none" + ) + target_class_loss = target_class_losses[score_mask].sum() + target_class_loss = target_class_loss / logits.shape[0] + + # make correctly predicted target class logit go to zero + correct_target_class_loss = target_class_losses[target_mask & score_mask].sum() + correct_target_class_loss = correct_target_class_loss / logits.shape[0] + + score_count = score_mask.sum() / logits.shape[0] + target_count = target_mask.sum() / logits.shape[0] + target_score_count = (target_mask & score_mask).sum() / logits.shape[0] + + return { + "total_loss": total_loss, + "coord_loss": coord_loss, + "obj_loss": obj_loss, + "noobj_loss": noobj_loss, + "class_loss": class_loss, + "hide_objects_loss": hide_objects_loss, + "hide_target_objects_loss": hide_target_objects_loss, + "target_class_loss": target_class_loss, + "correct_target_class_loss": correct_target_class_loss, + "score_count": score_count, + "target_count": target_count, + "target_score_count": target_score_count, + } + + +class Detections(torch.nn.Module): + def __init__(self, nms=True, conf_thres=0.8, nms_thres=0.4): + super().__init__() + + self.nms = nms + self.conf_thres = conf_thres + self.nms_thres = nms_thres + + @staticmethod + def xywh_to_xyxy(boxes): + boxes[:, 2] = boxes[:, 0] + boxes[:, 2] + boxes[:, 3] = boxes[:, 1] + boxes[:, 3] + + return boxes + + @staticmethod + def tensor_to_dict(detection): + boxes = detection[:, 0:4] + scores = detection[:, 4] + labels = detection[:, 5:] + + boxes = cxcywh_to_xywh(boxes) + boxes = Detections.xywh_to_xyxy(boxes) + + if labels.shape[1] == 1: # index + labels = labels[:, 0].to(int) + else: # one-hot + labels = labels.argmax(dim=1) + + return {"boxes": boxes, "labels": labels, "scores": scores} 
+ + @torch.no_grad() + def forward(self, preds, target, **kwargs): + detections = post_process(preds, self.nms, self.conf_thres, self.nms_thres) + + # FIXME: This should be another module + # Convert detections and targets to List[dict[str, torch.Tensor]]. This is the format + # torchmetrics wants. + preds = [Detections.tensor_to_dict(det) for det in detections] + + targets = target["packed"] + lengths = target["packed_length"] + targets = [target[:length] for target, length in zip(targets, lengths)] + targets = [Detections.tensor_to_dict(target) for target in targets] + + return {"preds": preds, "targets": targets} diff --git a/mart/transforms/extended.py b/mart/transforms/extended.py index 13cd0e74..87f13ae0 100644 --- a/mart/transforms/extended.py +++ b/mart/transforms/extended.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD-3-Clause # +from __future__ import annotations + import logging import os from typing import Dict, Optional, Tuple @@ -26,9 +28,16 @@ "Lambda", "SplitLambda", "LoadPerturbableMask", + "LoadTensors", "ConvertInstanceSegmentationToPerturbable", "RandomHorizontalFlip", "ConvertCocoPolysToMask", + "PadToSquare", + "Resize", + "ConvertBoxesToCXCYHW", + "RemapLabels", + "PackBoxesAndLabels", + "CreatePerturbableMaskFromImage", ] @@ -115,7 +124,7 @@ class ConvertInstanceSegmentationToPerturbable(ExTransform): """Merge all instance masks and reverse.""" def __call__(self, image, target): - perturbable_mask = torch.sum(target["masks"], dim=0) == 0 + perturbable_mask = torch.sum(target["masks"], dim=0, keepdim=True) == 0 # Convert to float to be differentiable. target["perturbable_mask"] = perturbable_mask.float() @@ -139,6 +148,26 @@ def __call__(self, image, target): return image, target +class LoadTensors(ExTransform): + def __init__(self, root, ext=".pt") -> None: + self.root = root + self.ext = ext + + def __call__(self, image, target): + filename, ext = os.path.splitext(target["file_name"]) + + metadata = torch.load( + os.path.join(self.root, filename + self.ext), map_location=image.device + ) + assert isinstance(metadata, dict) + + for key in metadata: + assert key not in target + target[key] = metadata[key] + + return image, target + + class RandomHorizontalFlip(T.RandomHorizontalFlip, ExTransform): """Flip the image and annotations including boxes, masks, keypoints and the perturable_masks.""" @@ -173,8 +202,8 @@ def flip_perturbable_mask(image, target): return image, target def forward( - self, image: Tensor, target: Optional[Dict[str, Tensor]] = None - ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]: + self, image: Tensor, target: dict[str, Tensor] | None = None + ) -> tuple[Tensor, dict[str, Tensor] | None]: if torch.rand(1) < self.p: image = F.hflip(image) if target is not None: @@ -190,3 +219,294 @@ def forward( class ConvertCocoPolysToMask(ConvertCocoPolysToMask_, ExTransform): pass + + +class PadToSquare(ExTransform): + def __init__(self, fill): + self.fill = fill + + def __call__( + self, + image: Tensor, # CHW + target: dict[str, Tensor] | None = None, + ): + w, h = F.get_image_size(image) + + l_or_t = abs(h - w) // 2 + r_or_b = abs(h - w) - l_or_t + + # padding is (left, top, right, bottom) + if h > w: + padding = (l_or_t, 0, r_or_b, 0) + else: + padding = (0, l_or_t, 0, r_or_b) + + image = F.pad(image, padding, fill=self.fill) + + if target is not None: + if "boxes" in target: + target["boxes"] = self.pad_boxes(target["boxes"], padding) + if "masks" in target: + target["masks"] = self.pad_masks(target["masks"], padding) + if "keypoints" in target: + 
target["keypoints"] = self.pad_keypoints(target["keypoints"], padding) + if "perturbable_mask" in target: + target["perturbable_mask"] = self.pad_masks(target["perturbable_mask"], padding) + if "gs_coords" in target: + target["gs_coords"] = self.pad_coordinates(target["gs_coords"], padding) + + return image, target + + def pad_boxes(self, boxes, padding): + boxes[:, 0] += padding[0] # X + left + boxes[:, 1] += padding[1] # Y + top + boxes[:, 2] += padding[0] # X + left + boxes[:, 3] += padding[1] # Y + top + + return boxes + + def pad_masks(self, masks, padding): + return F.pad(masks, padding, fill=0) + + def pad_keypoints(self, keypoints, padding): + raise NotImplementedError + + def pad_coordinates(self, coordinates, padding): + # coordinates are [[left, top], [right, top], [right, bottom], [left, bottom]] + # padding is [left, top, right bottom] + coordinates[:, 0] += padding[0] # left padding + coordinates[:, 1] += padding[1] # top padding + + return coordinates + + +class Resize(ExTransform): + def __init__(self, size): + self.size = size + + def __call__( + self, + image: Tensor, + target: dict[str, Tensor] | None = None, + ): + orig_w, orig_h = F.get_image_size(image) + image = F.resize(image, size=self.size) + new_w, new_h = F.get_image_size(image) + + dw, dh = new_w / orig_w, new_h / orig_h + + if target is not None: + if "boxes" in target: + target["boxes"] = self.resize_boxes(target["boxes"], (dw, dh)) + if "masks" in target: + target["masks"] = self.resize_masks(target["masks"], (dw, dh)) + if "keypoints" in target: + target["keypoints"] = self.resize_keypoints(target["keypoints"], (dw, dh)) + if "perturbable_mask" in target: + target["perturbable_mask"] = self.resize_masks( + target["perturbable_mask"], (dw, dh) + ) + if "gs_coords" in target: + target["gs_coords"] = self.resize_coordinates(target["gs_coords"], (dw, dh)) + + return image, target + + def resize_boxes(self, boxes, ratio): + boxes[:, 0] *= ratio[0] # X1 * width ratio + boxes[:, 1] *= ratio[1] # Y1 * height ratio + boxes[:, 2] *= ratio[0] # X2 * width ratio + boxes[:, 3] *= ratio[1] # Y2 * height ratio + + return boxes + + def resize_masks(self, masks, ratio): + assert len(masks.shape) == 3 + + # Resize fails on empty tensors + if masks.shape[0] == 0: + return torch.zeros((0, *self.size), dtype=masks.dtype, device=masks.device) + + return F.resize(masks, size=self.size, interpolation=F.InterpolationMode.NEAREST) + + def resize_keypoints(self, keypoints, ratio): + raise NotImplementedError + + def resize_coordinates(self, coordinates, ratio): + # coordinates are [[left, top], [right, top], [right, bottom], [left, bottom]] + # ratio is [width, height] + coordinates[:, 0] = (coordinates[:, 0] * ratio[0]).to(int) # width ratio + coordinates[:, 1] = (coordinates[:, 1] * ratio[1]).to(int) # height ratio + + return coordinates + + +class ConvertBoxesToCXCYHW(ExTransform): + def __call__( + self, + image: Tensor, + target: dict[str, Tensor], + ): + # X1Y1X2Y2 + boxes = target["boxes"] + + # X2Y2 -> HW + boxes[:, 2] -= boxes[:, 0] + boxes[:, 3] -= boxes[:, 1] + + # X1Y1 -> CXCY + boxes[:, 0] += boxes[:, 2] / 2 + boxes[:, 1] += boxes[:, 3] / 2 + + target["boxes"] = boxes + + return image, target + + +class RemapLabels(ExTransform): + COCO_MAP = { + 1: 0, + 2: 1, + 3: 2, + 4: 3, + 5: 4, + 6: 5, + 7: 6, + 8: 7, + 9: 8, + 10: 9, + 11: 10, + 13: 11, + 14: 12, + 15: 13, + 16: 14, + 17: 15, + 18: 16, + 19: 17, + 20: 18, + 21: 19, + 22: 20, + 23: 21, + 24: 22, + 25: 23, + 27: 24, + 28: 25, + 31: 26, + 32: 27, + 33: 28, + 34: 29, + 
35: 30, + 36: 31, + 37: 32, + 38: 33, + 39: 34, + 40: 35, + 41: 36, + 42: 37, + 43: 38, + 44: 39, + 46: 40, + 47: 41, + 48: 42, + 49: 43, + 50: 44, + 51: 45, + 52: 46, + 53: 47, + 54: 48, + 55: 49, + 56: 50, + 57: 51, + 58: 52, + 59: 53, + 60: 54, + 61: 55, + 62: 56, + 63: 57, + 64: 58, + 65: 59, + 67: 60, + 70: 61, + 72: 62, + 73: 63, + 74: 64, + 75: 65, + 76: 66, + 77: 67, + 78: 68, + 79: 69, + 80: 70, + 81: 71, + 82: 72, + 84: 73, + 85: 74, + 86: 75, + 87: 76, + 88: 77, + 89: 78, + 90: 79, + } + + def __init__( + self, + label_map: dict[int, int] | None = None, + ): + if label_map is None: + label_map = self.COCO_MAP + + self.label_map = label_map + + def __call__( + self, + image: Tensor, + target: dict[str, Tensor], + ): + labels = target["labels"] + + # This is a terrible implementation + for i, label in enumerate(labels): + labels[i] = self.label_map[label.item()] + + target["labels"] = labels + + return image, target + + +class PackBoxesAndLabels(ExTransform): + def __init__(self, num_classes: int): + self.num_classes = num_classes + + def __call__( + self, + image: Tensor, + target: dict[str, Tensor], + ): + boxes = target["boxes"] + labels = target["labels"] + scores = torch.ones_like(labels)[..., None] + + labels = torch.nn.functional.one_hot(labels, num_classes=self.num_classes) + + target["packed"] = torch.cat([boxes, scores, labels], dim=-1) + target["packed_length"] = target["packed"].shape[0] + + return image, target + + +class CreatePerturbableMaskFromImage(ExTransform): + def __init__(self, chroma_key, threshold): + self.chroma_key = torch.tensor(chroma_key) + self.threshold = threshold + + def __call__( + self, + image: Tensor, + target: dict[str, Tensor], + ): + self.chroma_key = self.chroma_key.to(image.device) + + l2_dist = ((image - self.chroma_key[:, None, None]) ** 2).sum(dim=0, keepdim=True).sqrt() + perturbable_mask = l2_dist <= self.threshold + + target["perturbable_mask"] = perturbable_mask.float() + + return image, target diff --git a/pyproject.toml b/pyproject.toml index e88c878f..7ba2129d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ # ----- object detection----- # "pycocotools ~= 2.0.5", + "yolov3 @ git+https://github.com/mzweilin/YOLOv3-in-PyTorch.git#release", # -------- Adversary ---------# "robustbench @ git+https://github.com/RobustBench/robustbench.git@9a590683b7daecf963244dea402529f0d728c727",
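
The callbacks block in COCO_YOLOv3_ShapeShifter.yaml can also be reproduced directly in Python. A minimal sketch, assuming the LitModular layout from that config (the module name "model.yolov3" and attribute path "model.perturbation.perturbation" are taken from it, not invented here):

    import torch
    from pytorch_lightning import Trainer

    from mart.callbacks import AttackInEvalMode, ImageVisualizer, ModelParamsNoGrad

    callbacks = [
        # Keep BatchNorm layers in eval mode while the attack "trains" the perturbation.
        AttackInEvalMode(module_classes=torch.nn.BatchNorm2d),
        # Freeze the victim model; only the perturbation parameters need gradients.
        ModelParamsNoGrad(module_names="model.yolov3"),
        # Every 500 batches, log the perturbation tensor to any logger whose
        # experiment exposes add_image (e.g. TensorBoard).
        ImageVisualizer(frequency=500, perturbation="model.perturbation.perturbation"),
    ]

    trainer = Trainer(max_steps=40070, precision=32, callbacks=callbacks)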
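
A quick shape check of the new yolo_collate_fn, with two toy samples that have different numbers of boxes (random data, just to show the padding). Each "packed" row is [cx, cy, w, h, score, 80-class one-hot], i.e. 85 attributes:

    import torch

    from mart.datamodules.coco import yolo_collate_fn

    sample1 = (torch.rand(3, 416, 416), {"packed": torch.rand(2, 85), "packed_length": 2})
    sample2 = (torch.rand(3, 416, 416), {"packed": torch.rand(5, 85), "packed_length": 5})

    images, target = yolo_collate_fn([sample1, sample2])

    print(images.shape)             # torch.Size([2, 3, 416, 416])
    print(target["packed"].shape)   # torch.Size([2, 5, 85]), padded to the longest sample
    print(target["packed_length"])  # tensor([2, 5])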
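
A worked example of the box bookkeeping in the new geometric transforms used by the coco_yolov3 datamodule: PadToSquare shifts XYXY boxes by the left/top padding, Resize scales them by the width/height ratios, and ConvertBoxesToCXCYHW rewrites them as [cx, cy, w, h]. The numbers below are illustrative:

    import torch

    from mart.transforms import ConvertBoxesToCXCYHW, PadToSquare, Resize

    image = torch.rand(3, 480, 640)  # CHW, landscape
    target = {"boxes": torch.tensor([[100.0, 200.0, 300.0, 400.0]])}  # x1, y1, x2, y2

    image, target = PadToSquare(fill=0.5)(image, target)    # pads 80 px top and bottom -> 640x640
    image, target = Resize(size=[416, 416])(image, target)  # scales both axes by 416/640 = 0.65
    image, target = ConvertBoxesToCXCYHW()(image, target)

    print(image.shape)      # torch.Size([3, 416, 416])
    print(target["boxes"])  # tensor([[130., 247., 130., 130.]]) -> cx, cy, w, h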
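
RemapLabels currently rewrites labels one element at a time, which the PR itself flags. A possible vectorized alternative, sketched here as a suggestion rather than part of the patch, is to index a lookup tensor built once from the label map:

    import torch

    from mart.transforms import RemapLabels

    def build_lookup(label_map: dict[int, int]) -> torch.Tensor:
        # Unmapped category ids stay at -1 so they are easy to spot.
        lookup = torch.full((max(label_map) + 1,), -1, dtype=torch.long)
        for src, dst in label_map.items():
            lookup[src] = dst
        return lookup

    lookup = build_lookup(RemapLabels.COCO_MAP)
    labels = torch.tensor([1, 13, 90])
    print(lookup[labels])  # tensor([ 0, 11, 79])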
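
CreatePerturbableMaskFromImage builds the perturbable mask by thresholding the per-pixel L2 distance to a chroma key. A small sketch, with an assumed pure-green key and a threshold of 0.3 (both values are illustrative, not taken from the PR):

    import torch

    from mart.transforms import CreatePerturbableMaskFromImage

    image = torch.zeros(3, 4, 4)
    image[1, :, :2] = 1.0  # left half is pure green, right half is black

    transform = CreatePerturbableMaskFromImage(chroma_key=[0.0, 1.0, 0.0], threshold=0.3)
    _, target = transform(image, {})

    print(target["perturbable_mask"].shape)  # torch.Size([1, 4, 4])
    print(target["perturbable_mask"][0])     # left two columns are 1., right two are 0.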