diff --git a/dali/operators/image/crop/bbox_crop.cc b/dali/operators/image/crop/bbox_crop.cc index 4355b38a65..b38f519715 100644 --- a/dali/operators/image/crop/bbox_crop.cc +++ b/dali/operators/image/crop/bbox_crop.cc @@ -146,10 +146,10 @@ and ``end_y`` order, and ``bbox_layout``\="xyWH" indicates that the order is ``s ``start_y``, ``width``, and ``height``. See the ``bbox_layout`` argument description for more information. -Optionally, a second input, called ``labels``, can be provided, which represents the labels that are -associated with each of the bounding boxes. +An optional input ``labels`` can be provided, representing the labels that are associated with each of +the bounding boxes. -**Outputs: 0**: anchor, 1: shape, 2: bboxes, (3: labels) +**Outputs: 0**: anchor, 1: shape, 2: bboxes (, 3: labels, 4: bboxes_indices) The resulting crop parameters are provided as two separate outputs, ``anchor`` and ``shape``, that can be fed directly to the :meth:`nvidia.dali.ops.Slice` operator to complete the cropping @@ -158,8 +158,16 @@ for the crop in the ``[x, y, (z)]`` and ``[w, h, (d)]`` formats, respectively. T be represented in absolute or relative terms, and the represetnation depends on whether the fixed ``crop_shape`` was used. -The third and fourth outputs correspond to the adjusted bounding boxes and, optionally, -to their corresponding labels. Bounding boxes are always specified in relative coordinates.)code") +The third output contains the bounding boxes, after filtering out the ones with a centroid outside +of the cropping window, and with the coordinates mapped to the new coordinate space. + +The next output is optional, and it represents the labels associated with the filtered bounding boxes. +The output will be present if a labels input was provided. + +The last output, also optional, correspond to the original indices of the bounding boxes that passed the +centroid filter and are present in the output. +This output will be present if the option ``output_bbox_indices`` is set to True. +)code") .NumInput(1, 2) // [boxes, labels (optional),] .InputDox( 0, "boxes", "2D TensorList of float", R"code(Relative coordinates of the bounding boxes @@ -167,9 +175,10 @@ that are represented as a 2D tensor, where the first dimension refers to the ind box, and the second dimension refers to the index of the coordinate.)code") .InputDox(1, "labels", "1D TensorList of integers", R"code(Labels that are associated with each of the bounding boxes.)code") - .NumOutput(3) // [anchor, shape, bboxes, labels (optional),] + .NumOutput(3) // [anchor, shape, bboxes, labels (optional), bbox_indices (optional)] .AdditionalOutputsFn([](const OpSpec &spec) { - return spec.NumRegularInput() - 1; // +1 if labels are provided + return spec.NumRegularInput() - 1 + // +1 if labels are provided + spec.GetArgument("output_bbox_indices"); // +1 if output_bbox_indices=True }) .AddOptionalArg( "thresholds", @@ -308,7 +317,7 @@ The value of this argument is a string containing the following characters:: If this value is left empty, depending on the number of dimensions, "xyXY" or "xyzXYZ" is assumed. )code", - TensorLayout{""}) + TensorLayout{}) .AddOptionalArg( "shape_layout", R"code(Determines the meaning of the dimensions provided in ``crop_shape`` and @@ -323,7 +332,13 @@ The values are: .. note:: If left empty, depending on the number of dimensions ``"WH"`` or ``"WHD"`` will be assumed. )code", - TensorLayout{""}); + TensorLayout{}) + .AddOptionalArg( + "output_bbox_indices", + R"code(If set to True, an extra output will be returned, containing +the original indices of the bounding boxes that passed the centroid filter and are present in +the output bounding boxes.)code", + false); template class RandomBBoxCropImpl : public OpImplBase { @@ -346,6 +361,7 @@ class RandomBBoxCropImpl : public OpImplBase { bbox_layout_(spec.GetArgument("bbox_layout")), shape_layout_(spec.GetArgument("shape_layout")), all_boxes_above_threshold_(spec.GetArgument("all_boxes_above_threshold")), + output_bbox_indices_(spec.GetArgument("output_bbox_indices")), rngs_(spec.GetArgument("seed"), spec.GetArgument("batch_size")) { auto scaling_arg = spec.GetRepeatedArgument("scaling"); DALI_ENFORCE(scaling_arg.size() == 2, @@ -486,41 +502,105 @@ class RandomBBoxCropImpl : public OpImplBase { return false; } - void RunImpl(SampleWorkspace &ws) override { - const auto &boxes_tensor = ws.Input(0); - auto nboxes = boxes_tensor.dim(0); + void RunImpl(workspace_t &ws) override { + const auto &in_boxes = ws.template InputRef(0); + auto in_boxes_view = view(in_boxes); + auto in_boxes_shape = in_boxes_view.shape; + int num_samples = in_boxes_shape.num_samples(); + auto &tp = ws.GetThreadPool(); auto ncoords = ndim * 2; + sample_data_.clear(); + sample_data_.resize(num_samples); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + auto &data = sample_data_[sample_idx]; + tp.AddWork([&, sample_idx](int thread_id) { + auto nboxes = in_boxes_view.tensor_shape_span(sample_idx)[0]; + data.in_bboxes.resize(nboxes); + ReadBoxes(make_span(data.in_bboxes), + make_cspan(in_boxes_view.tensor_data(sample_idx), + volume(in_boxes_shape.tensor_size(sample_idx))), + bbox_layout_); + FindProspectiveCrop(data.prospective_crop, make_cspan(data.in_bboxes), sample_idx); + }, in_boxes_shape.tensor_size(sample_idx)); + } + tp.RunAll(); - std::vector> bounding_boxes(nboxes); - ReadBoxes(make_span(bounding_boxes), - make_cspan(boxes_tensor.data(), boxes_tensor.size()), bbox_layout_); + auto &anchor_out = ws.template OutputRef(0); + anchor_out.Resize(uniform_list_shape(num_samples, {ndim})); + auto anchor_out_view = view(anchor_out); - std::vector labels; - if (has_labels_) { - const auto &labels_tensor = ws.Input(1); - auto nlabels = labels_tensor.dim(0); - DALI_ENFORCE(nlabels == nboxes, - make_string("Unexpected number of labels. Expected: ", nboxes, ", got ", nlabels)); - labels.resize(nlabels); - const auto *label_data = labels_tensor.data(); - for (int i = 0; i < nlabels; i++) { - labels[i] = label_data[i]; + auto &shape_out = ws.template OutputRef(1); + shape_out.Resize(uniform_list_shape(num_samples, {ndim})); + auto shape_out_view = view(shape_out); + + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + const auto &prospective_crop = sample_data_[sample_idx].prospective_crop; + auto extent = prospective_crop.crop.extent(); + for (int d = 0; d < ndim; d++) { + anchor_out_view.tensor_data(sample_idx)[d] = prospective_crop.crop.lo[d]; + shape_out_view.tensor_data(sample_idx)[d] = extent[d]; } } - int sample = ws.data_idx(); - ProspectiveCrop prospective_crop = - FindProspectiveCrop(make_cspan(bounding_boxes), make_cspan(labels), sample); - - WriteCropToOutput(ws, prospective_crop.crop); - WriteBoxesToOutput(ws, make_cspan(prospective_crop.boxes)); + TensorListShape<> bbox_out_shape; + bbox_out_shape.resize(num_samples, 2); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + auto sh = bbox_out_shape.tensor_shape_span(sample_idx); + sh[0] = sample_data_[sample_idx].prospective_crop.boxes.size(); + sh[1] = ncoords; + } + auto &bbox_out = ws.template OutputRef(2); + bbox_out.Resize(bbox_out_shape); + auto bbox_out_view = view(bbox_out); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + WriteBoxes(make_span(bbox_out_view.tensor_data(sample_idx), + volume(bbox_out_view.tensor_shape_span(sample_idx))), + make_cspan(sample_data_[sample_idx].prospective_crop.boxes), bbox_layout_); + } + int next_out_idx = 3; if (has_labels_) { - DALI_ENFORCE( - prospective_crop.boxes.size() == prospective_crop.labels.size(), - make_string("Expected boxes.size() == labels.size(). Received: ", - prospective_crop.boxes.size(), " != ", prospective_crop.labels.size())); - WriteLabelsToOutput(ws, make_cspan(prospective_crop.labels)); + const auto &labels_in = ws.template InputRef(1); + auto labels_in_view = view(labels_in); + + auto &labels_out = ws.template OutputRef(next_out_idx++); + TensorListShape<> labels_out_shape = labels_in.shape(); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + auto sh = labels_out_shape.tensor_shape_span(sample_idx); + sh[0] = sample_data_[sample_idx].prospective_crop.bbox_indices.size(); + } + labels_out.Resize(labels_out_shape); + auto labels_out_view = view(labels_out); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + auto *labels_out_data = labels_out_view.tensor_data(sample_idx); + const auto *labels_in_data = labels_in_view.tensor_data(sample_idx); + auto in_sh = labels_in_view.tensor_shape_span(sample_idx); + int64_t stride = volume(in_sh.begin() + 1, in_sh.end()); + for (auto bbox_idx : sample_data_[sample_idx].prospective_crop.bbox_indices) { + assert(bbox_idx < in_sh[0]); + const auto *curr_label_data = labels_in_data + bbox_idx * stride; + for (int64_t k = 0; k < stride; k++) + *labels_out_data++ = *curr_label_data++; + } + } + } + + if (output_bbox_indices_) { + auto &bbox_indices_out = ws.template OutputRef(next_out_idx++); + TensorListShape<> bbox_indices_out_shape; + bbox_indices_out_shape.resize(num_samples, 1); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + bbox_indices_out_shape.tensor_shape_span(sample_idx)[0] = + sample_data_[sample_idx].prospective_crop.bbox_indices.size(); + } + bbox_indices_out.Resize(bbox_indices_out_shape); + auto bbox_indices_out_view = view(bbox_indices_out); + for (int sample_idx = 0; sample_idx < num_samples; sample_idx++) { + auto *bbox_indices_out_data = bbox_indices_out_view.tensor_data(sample_idx); + for (auto bbox_idx : sample_data_[sample_idx].prospective_crop.bbox_indices) { + *bbox_indices_out_data++ = bbox_idx; + } + } } } @@ -529,23 +609,14 @@ class RandomBBoxCropImpl : public OpImplBase { bool success = false; Box crop{}; std::vector> boxes; - std::vector labels; - - ProspectiveCrop(bool success, - const Box& crop, - span> boxes_data, - span labels_data, - bool has_labels) - : success(success), crop(crop) { - assert(boxes_data.size() == labels_data.size() || !has_labels); - boxes.resize(boxes_data.size()); - labels.resize(labels_data.size()); - for (int i = 0; i < boxes_data.size(); i++) { - boxes[i] = boxes_data[i]; - if (has_labels) labels[i] = labels_data[i]; - } + std::vector bbox_indices; + + void clear() { + success = false; + crop = {}; + boxes.clear(); + bbox_indices.clear(); } - ProspectiveCrop() = default; }; /** @@ -604,12 +675,12 @@ class RandomBBoxCropImpl : public OpImplBase { extent = max_extent * extent / new_max_extent; } - ProspectiveCrop FindProspectiveCrop(span> bounding_boxes, - span labels, int sample) { - ProspectiveCrop crop; + void FindProspectiveCrop(ProspectiveCrop &crop, span> bounding_boxes, + int sample) { int count = 0; float best_metric = -1.0; + crop.clear(); while (!crop.success && (total_num_attempts_ < 0 || count < total_num_attempts_)) { auto &rng = rngs_[sample]; std::uniform_int_distribution<> idx_dist(0, sample_options_.size() - 1); @@ -623,7 +694,11 @@ class RandomBBoxCropImpl : public OpImplBase { for (int d = 0; d < ndim; d++) no_crop.hi[d] *= input_shape[d]; } - crop = ProspectiveCrop(true, no_crop, bounding_boxes, labels, has_labels_); + crop.success = true; + crop.crop = no_crop; + crop.boxes.assign(bounding_boxes.begin(), bounding_boxes.end()); + crop.bbox_indices.resize(crop.boxes.size()); + std::iota(crop.bbox_indices.begin(), crop.bbox_indices.end(), 0); break; } @@ -680,25 +755,13 @@ class RandomBBoxCropImpl : public OpImplBase { best_metric = metric; - crop = {}; crop.crop = out_crop; - - crop.boxes.resize(bounding_boxes.size()); - for (int i = 0; i < bounding_boxes.size(); i++) - crop.boxes[i] = bounding_boxes[i]; - - if (!labels.empty()) { - assert(labels.size() == bounding_boxes.size()); - crop.labels.resize(labels.size()); - for (int i = 0; i < labels.size(); i++) - crop.labels[i] = labels[i]; - } - - FilterByCentroid(rel_crop, crop.boxes, crop.labels); + crop.boxes.assign(bounding_boxes.begin(), bounding_boxes.end()); + crop.bbox_indices.clear(); // indices will be populated by FilterByCentroid + FilterByCentroid(rel_crop, crop.boxes, crop.bbox_indices); for (auto &box : crop.boxes) { box = RemapBox(box, rel_crop); } - bool at_least_one_box = !crop.boxes.empty(); crop.success = is_valid_overlap && at_least_one_box; } @@ -710,7 +773,6 @@ class RandomBBoxCropImpl : public OpImplBase { count, " times). Using the best cropping window so far (best_metric=", best_metric, ")")); crop.success = true; } - return crop; } bool ValidAspectRatio(vec shape) { @@ -763,55 +825,16 @@ class RandomBBoxCropImpl : public OpImplBase { void FilterByCentroid(const Box &crop, std::vector> &bboxes, - std::vector &labels) { + std::vector &indices) { std::vector> new_bboxes; - std::vector new_labels; - bool process_labels = !labels.empty(); - assert(labels.empty() || labels.size() == bboxes.size()); + indices.clear(); for (size_t i = 0; i < bboxes.size(); i++) { if (crop.contains(bboxes[i].centroid())) { new_bboxes.push_back(bboxes[i]); - if (process_labels) - new_labels.push_back(labels[i]); + indices.push_back(static_cast(i)); } } std::swap(bboxes, new_bboxes); - if (process_labels) - std::swap(labels, new_labels); - } - - void WriteCropToOutput(SampleWorkspace &ws, const Box &crop) { - // output0 : anchor, output1 : shape - auto &anchor_out = ws.Output(0); - anchor_out.Resize({ndim}); - auto *anchor_out_data = anchor_out.mutable_data(); - - auto &shape_out = ws.Output(1); - shape_out.Resize({ndim}); - auto *shape_out_data = shape_out.mutable_data(); - - auto extent = crop.extent(); - for (int d = 0; d < ndim; d++) { - anchor_out_data[d] = crop.lo[d]; - shape_out_data[d] = extent[d]; - } - } - - void WriteBoxesToOutput(SampleWorkspace &ws, - span> bounding_boxes) { - auto &bbox_out = ws.Output(2); - bbox_out.Resize({static_cast(bounding_boxes.size()), coords_size}); - auto *bbox_out_data = bbox_out.mutable_data(); - WriteBoxes(make_span(bbox_out_data, bbox_out.size()), make_cspan(bounding_boxes), - bbox_layout_); - } - - void WriteLabelsToOutput(SampleWorkspace &ws, span labels) { - auto &labels_out = ws.Output(3); - labels_out.Resize({static_cast(labels.size()), 1}); - auto *labels_out_data = labels_out.mutable_data(); - for (int i = 0; i < labels_out.size(); i++) - labels_out_data[i] = labels[i]; } private: @@ -827,6 +850,7 @@ class RandomBBoxCropImpl : public OpImplBase { OverlapMetric overlap_metric_ = OverlapMetric::IoU; bool all_boxes_above_threshold_ = true; + bool output_bbox_indices_ = false; BatchRNG rngs_; @@ -837,6 +861,13 @@ class RandomBBoxCropImpl : public OpImplBase { Range scale_range_; std::vector aspect_ratio_ranges_; + + struct SampleData { + std::vector> in_bboxes; + ProspectiveCrop prospective_crop; + }; + + std::vector sample_data_; }; template <> @@ -879,7 +910,7 @@ bool RandomBBoxCrop::SetupImpl(std::vector &output_desc, } template <> -void RandomBBoxCrop::RunImpl(SampleWorkspace &ws) { +void RandomBBoxCrop::RunImpl(workspace_t &ws) { assert(impl_ != nullptr); impl_->RunImpl(ws); } diff --git a/dali/operators/image/crop/bbox_crop.h b/dali/operators/image/crop/bbox_crop.h index 11ea08b6dd..80c488d266 100644 --- a/dali/operators/image/crop/bbox_crop.h +++ b/dali/operators/image/crop/bbox_crop.h @@ -31,7 +31,7 @@ class RandomBBoxCrop : public Operator { protected: bool SetupImpl(std::vector &output_desc, const workspace_t &ws) override; - void RunImpl(Workspace &ws) override; + void RunImpl(workspace_t &ws) override; private: std::unique_ptr> impl_; diff --git a/dali/test/python/test_operator_random_bbox_crop.py b/dali/test/python/test_operator_random_bbox_crop.py index c4643ee842..9f4b42c1d8 100644 --- a/dali/test/python/test_operator_random_bbox_crop.py +++ b/dali/test/python/test_operator_random_bbox_crop.py @@ -20,6 +20,7 @@ from nvidia.dali.backend_impl import TensorListGPU import numpy as np import os +import random bbox_2d_ltrb_1 = [0.0123, 0.0123, 0.2123, 0.2123] bbox_2d_ltrb_2 = [0.1123, 0.1123, 0.19123, 0.19123] @@ -97,6 +98,7 @@ def __init__(self, device, batch_size, input_shape = None, crop_shape = None, all_boxes_above_threshold = False, + output_bbox_indices = False, num_threads=1, device_id=0, num_gpus=1): super(RandomBBoxCropSynthDataPipeline, self).__init__( batch_size, num_threads, device_id, seed=1234) @@ -115,7 +117,9 @@ def __init__(self, device, batch_size, allow_no_crop = allow_no_crop, input_shape = input_shape, crop_shape = crop_shape, - all_boxes_above_threshold = all_boxes_above_threshold) + all_boxes_above_threshold = all_boxes_above_threshold, + output_bbox_indices = output_bbox_indices + ) def define_graph(self): inputs = fn.external_source(source=self.bbox_source, num_outputs=self.bbox_source.num_outputs) @@ -164,8 +168,11 @@ def map_box(bbox, crop_anchor, crop_shape): new_bbox[ndim + d] = max(0.0, min(1.0, n_end)) return new_bbox -def check_processed_bboxes(crop_anchor, crop_shape, original_boxes, processed_boxes): - filtered_boxes = filter_by_centroid(crop_anchor, crop_shape, original_boxes) +def check_processed_bboxes(crop_anchor, crop_shape, original_boxes, processed_boxes, bbox_indices=None): + if bbox_indices is not None: + filtered_boxes = np.array(original_boxes[bbox_indices]) + else: + filtered_boxes = filter_by_centroid(crop_anchor, crop_shape, original_boxes) assert(len(original_boxes) >= len(filtered_boxes)) assert(len(filtered_boxes) == len(processed_boxes)) nboxes = len(filtered_boxes) @@ -206,14 +213,15 @@ def check_crop_dims_fixed_size(anchor, shape, expected_crop_shape, input_shape): assert shape[d] == expected_crop_shape[d], "{} != {}".format(shape, expected_crop_shape) assert(anchor[d] + shape[d] > 0.0 and anchor[d] + shape[d] <= input_shape[d]) -def check_random_bbox_crop_variable_shape(batch_size, ndim, scaling, aspect_ratio, use_labels): +def check_random_bbox_crop_variable_shape(batch_size, ndim, scaling, aspect_ratio, use_labels, output_bbox_indices): bbox_source = BBoxDataIterator(100, batch_size, ndim, produce_labels=use_labels) bbox_layout = "xyzXYZ" if ndim == 3 else "xyXY" pipe = RandomBBoxCropSynthDataPipeline(device='cpu', batch_size=batch_size, bbox_source=bbox_source, bbox_layout=bbox_layout, scaling=scaling, aspect_ratio=aspect_ratio, - input_shape=None, crop_shape=None) + input_shape=None, crop_shape=None, + output_bbox_indices=output_bbox_indices) pipe.build() for i in range(100): outputs = pipe.run() @@ -223,10 +231,13 @@ def check_random_bbox_crop_variable_shape(batch_size, ndim, scaling, aspect_rati out_crop_shape = outputs[2].at(sample) out_boxes = outputs[3].at(sample) check_crop_dims_variable_size(out_crop_anchor, out_crop_shape, scaling, aspect_ratio) - check_processed_bboxes(out_crop_anchor, out_crop_shape, in_boxes, out_boxes) + bbox_indices_out_idx = 4 if not use_labels else 5 + bbox_indices = outputs[bbox_indices_out_idx].at(sample) if output_bbox_indices else None + check_processed_bboxes(out_crop_anchor, out_crop_shape, in_boxes, out_boxes, bbox_indices) def test_random_bbox_crop_variable_shape(): + random.seed(1234) for batch_size in [3]: for ndim in [2, 3]: for scaling in [[0.3, 0.5], [0.1, 0.3], [0.9, 0.99]]: @@ -235,9 +246,10 @@ def test_random_bbox_crop_variable_shape(): 3 : [[0.5, 2.0, 0.6, 2.1, 0.4, 1.9], [1.0, 1.0], [0.5, 0.5, 0.25, 0.25, 0.5, 0.5]] } for aspect_ratio in aspect_ratio_ranges[ndim]: - for use_labels in [True, False]: - yield check_random_bbox_crop_variable_shape, \ - batch_size, ndim, scaling, aspect_ratio, use_labels + use_labels = random.choice([True, False]) + out_bbox_indices = random.choice([True, False]) + yield check_random_bbox_crop_variable_shape, \ + batch_size, ndim, scaling, aspect_ratio, use_labels, out_bbox_indices def check_random_bbox_crop_fixed_shape(batch_size, ndim, crop_shape, input_shape, use_labels): bbox_source = BBoxDataIterator(100, batch_size, ndim, produce_labels=use_labels)