-
Notifications
You must be signed in to change notification settings - Fork 610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RandomBBoxCrop to optionally output the indices of the bounding boxes that passed the centroid filter #2374
Changes from 1 commit
bd80b94
e95ec83
c53efbf
0a46f46
763cd87
3ab7730
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -146,10 +146,10 @@ and ``end_y`` order, and ``bbox_layout``\="xyWH" indicates that the order is ``s | |||||
``start_y``, ``width``, and ``height``. See the ``bbox_layout`` argument description | ||||||
for more information. | ||||||
|
||||||
Optionally, a second input, called ``labels``, can be provided, which represents the labels that are | ||||||
associated with each of the bounding boxes. | ||||||
An optional input ``labels`` can be provided, representing the labels that are associated with each of | ||||||
the bounding boxes. | ||||||
|
||||||
**Outputs: 0**: anchor, 1: shape, 2: bboxes, (3: labels) | ||||||
**Outputs: 0**: anchor, 1: shape, 2: bboxes (, 3: labels, 4: bboxes_indices) | ||||||
|
||||||
The resulting crop parameters are provided as two separate outputs, ``anchor`` and ``shape``, | ||||||
that can be fed directly to the :meth:`nvidia.dali.ops.Slice` operator to complete the cropping | ||||||
|
@@ -158,18 +158,27 @@ for the crop in the ``[x, y, (z)]`` and ``[w, h, (d)]`` formats, respectively. T | |||||
be represented in absolute or relative terms, and the represetnation depends on whether | ||||||
the fixed ``crop_shape`` was used. | ||||||
|
||||||
The third and fourth outputs correspond to the adjusted bounding boxes and, optionally, | ||||||
to their corresponding labels. Bounding boxes are always specified in relative coordinates.)code") | ||||||
The third output contains the bounding boxes, after filtering out the ones with a centroid outside | ||||||
of the cropping window, and with the coordinates mapped to the new coordinate space. | ||||||
|
||||||
The next output is optional, and it represents the labels associated with the filtered bounding boxes. | ||||||
The output will be present if a labels input was provided. | ||||||
|
||||||
The last output, also optional, correspond to the original indices of the bounding boxes that passed the | ||||||
centroid filter and are present in the output. | ||||||
This output will be present if the option ``output_bbox_indices`` is set to True. | ||||||
)code") | ||||||
.NumInput(1, 2) // [boxes, labels (optional),] | ||||||
.InputDox( | ||||||
0, "boxes", "2D TensorList of float", R"code(Relative coordinates of the bounding boxes | ||||||
that are represented as a 2D tensor, where the first dimension refers to the index of the bounding | ||||||
box, and the second dimension refers to the index of the coordinate.)code") | ||||||
.InputDox(1, "labels", "1D TensorList of integers", R"code(Labels that are | ||||||
associated with each of the bounding boxes.)code") | ||||||
.NumOutput(3) // [anchor, shape, bboxes, labels (optional),] | ||||||
.NumOutput(3) // [anchor, shape, bboxes, labels (optional), bbox_indices (optional)] | ||||||
.AdditionalOutputsFn([](const OpSpec &spec) { | ||||||
return spec.NumRegularInput() - 1; // +1 if labels are provided | ||||||
return spec.NumRegularInput() - 1 + // +1 if labels are provided | ||||||
spec.GetArgument<bool>("output_bbox_indices"); // +1 if output_bbox_indices=True | ||||||
}) | ||||||
.AddOptionalArg( | ||||||
"thresholds", | ||||||
|
@@ -323,7 +332,13 @@ The values are: | |||||
.. note:: | ||||||
If left empty, depending on the number of dimensions ``"WH"`` or ``"WHD"`` will be assumed. | ||||||
)code", | ||||||
TensorLayout{""}); | ||||||
TensorLayout{""}) | ||||||
.AddOptionalArg<bool>( | ||||||
"output_bbox_indices", | ||||||
R"code(If set to True, an extra output ``bbox_indices`` will be returned, containing | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I don't we have named outputs. |
||||||
the original indices of the bounding boxes that passed the centroid filter and are present in | ||||||
the output bounding boxes.)code", | ||||||
false); | ||||||
|
||||||
template <int ndim> | ||||||
class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | ||||||
|
@@ -346,6 +361,7 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
bbox_layout_(spec.GetArgument<TensorLayout>("bbox_layout")), | ||||||
shape_layout_(spec.GetArgument<TensorLayout>("shape_layout")), | ||||||
all_boxes_above_threshold_(spec.GetArgument<bool>("all_boxes_above_threshold")), | ||||||
output_bbox_indices_(spec.GetArgument<bool>("output_bbox_indices")), | ||||||
rngs_(spec.GetArgument<int64_t>("seed"), spec.GetArgument<int>("batch_size")) { | ||||||
auto scaling_arg = spec.GetRepeatedArgument<float>("scaling"); | ||||||
DALI_ENFORCE(scaling_arg.size() == 2, | ||||||
|
@@ -495,32 +511,34 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
ReadBoxes(make_span(bounding_boxes), | ||||||
make_cspan(boxes_tensor.data<float>(), boxes_tensor.size()), bbox_layout_); | ||||||
|
||||||
std::vector<int> labels; | ||||||
if (has_labels_) { | ||||||
const auto &labels_tensor = ws.Input<CPUBackend>(1); | ||||||
auto nlabels = labels_tensor.dim(0); | ||||||
DALI_ENFORCE(nlabels == nboxes, | ||||||
make_string("Unexpected number of labels. Expected: ", nboxes, ", got ", nlabels)); | ||||||
labels.resize(nlabels); | ||||||
const auto *label_data = labels_tensor.data<int>(); | ||||||
for (int i = 0; i < nlabels; i++) { | ||||||
labels[i] = label_data[i]; | ||||||
} | ||||||
} | ||||||
|
||||||
int sample = ws.data_idx(); | ||||||
ProspectiveCrop prospective_crop = | ||||||
FindProspectiveCrop(make_cspan(bounding_boxes), make_cspan(labels), sample); | ||||||
FindProspectiveCrop(make_cspan(bounding_boxes), sample); | ||||||
|
||||||
WriteCropToOutput(ws, prospective_crop.crop); | ||||||
WriteBoxesToOutput(ws, make_cspan(prospective_crop.boxes)); | ||||||
|
||||||
int next_out_idx = 3; | ||||||
if (has_labels_) { | ||||||
DALI_ENFORCE( | ||||||
prospective_crop.boxes.size() == prospective_crop.labels.size(), | ||||||
make_string("Expected boxes.size() == labels.size(). Received: ", | ||||||
prospective_crop.boxes.size(), " != ", prospective_crop.labels.size())); | ||||||
WriteLabelsToOutput(ws, make_cspan(prospective_crop.labels)); | ||||||
const auto &labels_in = ws.Input<CPUBackend>(1); | ||||||
auto &labels_out = ws.Output<CPUBackend>(next_out_idx++); | ||||||
labels_out.Resize({static_cast<Index>(prospective_crop.bbox_indices.size()), 1}); | ||||||
int64_t i = 0; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unused |
||||||
const auto* labels_in_data = labels_in.data<int>(); | ||||||
auto *labels_out_data = labels_out.mutable_data<int>(); | ||||||
for (auto bbox_idx : prospective_crop.bbox_indices) { | ||||||
assert(bbox_idx < labels_in.dim(0)); | ||||||
*labels_out_data++ = labels_in_data[bbox_idx]; | ||||||
} | ||||||
} | ||||||
|
||||||
if (output_bbox_indices_) { | ||||||
auto &bbox_indices_out = ws.Output<CPUBackend>(next_out_idx++); | ||||||
bbox_indices_out.Resize({static_cast<Index>(prospective_crop.bbox_indices.size())}); | ||||||
auto* bbox_indices_out_data = bbox_indices_out.mutable_data<int>(); | ||||||
for (auto bbox_idx : prospective_crop.bbox_indices) { | ||||||
*bbox_indices_out_data++ = bbox_idx; | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
|
@@ -529,20 +547,21 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
bool success = false; | ||||||
Box<ndim, float> crop{}; | ||||||
std::vector<Box<ndim, float>> boxes; | ||||||
std::vector<int> labels; | ||||||
SmallVector<int, 128> bbox_indices; | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that's a bit too big an object - perhaps it should be a regular vector instead or at the very least we should be much more conservative wrt static size (8-16, not 128). |
||||||
|
||||||
ProspectiveCrop(bool success, | ||||||
const Box<ndim, float>& crop, | ||||||
span<const Box<ndim, float>> boxes_data, | ||||||
span<const int> labels_data, | ||||||
bool has_labels) | ||||||
: success(success), crop(crop) { | ||||||
assert(boxes_data.size() == labels_data.size() || !has_labels); | ||||||
SmallVector<int, 128> indices = {}) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Passing a small vector by value is not a good idea. |
||||||
: success(success), crop(crop), bbox_indices(std::move(indices)) { | ||||||
boxes.resize(boxes_data.size()); | ||||||
labels.resize(labels_data.size()); | ||||||
for (int i = 0; i < boxes_data.size(); i++) { | ||||||
for (int i = 0; i < boxes_data.size(); i++) | ||||||
boxes[i] = boxes_data[i]; | ||||||
if (has_labels) labels[i] = labels_data[i]; | ||||||
|
||||||
bbox_indices = std::move(indices); | ||||||
if (bbox_indices.empty()) { | ||||||
bbox_indices.resize(boxes.size()); | ||||||
std::iota(bbox_indices.begin(), bbox_indices.end(), 0); | ||||||
} | ||||||
} | ||||||
ProspectiveCrop() = default; | ||||||
|
@@ -604,8 +623,7 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
extent = max_extent * extent / new_max_extent; | ||||||
} | ||||||
|
||||||
ProspectiveCrop FindProspectiveCrop(span<const Box<ndim, float>> bounding_boxes, | ||||||
span<const int> labels, int sample) { | ||||||
ProspectiveCrop FindProspectiveCrop(span<const Box<ndim, float>> bounding_boxes, int sample) { | ||||||
ProspectiveCrop crop; | ||||||
int count = 0; | ||||||
float best_metric = -1.0; | ||||||
|
@@ -623,7 +641,7 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
for (int d = 0; d < ndim; d++) | ||||||
no_crop.hi[d] *= input_shape[d]; | ||||||
} | ||||||
crop = ProspectiveCrop(true, no_crop, bounding_boxes, labels, has_labels_); | ||||||
crop = ProspectiveCrop(true, no_crop, bounding_boxes); | ||||||
break; | ||||||
} | ||||||
|
||||||
|
@@ -680,21 +698,8 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
|
||||||
best_metric = metric; | ||||||
|
||||||
crop = {}; | ||||||
crop.crop = out_crop; | ||||||
|
||||||
crop.boxes.resize(bounding_boxes.size()); | ||||||
for (int i = 0; i < bounding_boxes.size(); i++) | ||||||
crop.boxes[i] = bounding_boxes[i]; | ||||||
|
||||||
if (!labels.empty()) { | ||||||
assert(labels.size() == bounding_boxes.size()); | ||||||
crop.labels.resize(labels.size()); | ||||||
for (int i = 0; i < labels.size(); i++) | ||||||
crop.labels[i] = labels[i]; | ||||||
} | ||||||
|
||||||
FilterByCentroid(rel_crop, crop.boxes, crop.labels); | ||||||
crop = ProspectiveCrop(false, out_crop, bounding_boxes); | ||||||
FilterByCentroid(rel_crop, crop.boxes, crop.bbox_indices); | ||||||
for (auto &box : crop.boxes) { | ||||||
box = RemapBox(box, rel_crop); | ||||||
} | ||||||
|
@@ -761,23 +766,17 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
} | ||||||
} | ||||||
|
||||||
void FilterByCentroid(const Box<ndim, float> &crop, | ||||||
std::vector<Box<ndim, float>> &bboxes, | ||||||
std::vector<int> &labels) { | ||||||
void FilterByCentroid(const Box<ndim, float> &crop, std::vector<Box<ndim, float>> &bboxes, | ||||||
SmallVector<int, 128> &indices) { | ||||||
std::vector<Box<ndim, float>> new_bboxes; | ||||||
std::vector<int> new_labels; | ||||||
bool process_labels = !labels.empty(); | ||||||
assert(labels.empty() || labels.size() == bboxes.size()); | ||||||
indices.clear(); | ||||||
for (size_t i = 0; i < bboxes.size(); i++) { | ||||||
if (crop.contains(bboxes[i].centroid())) { | ||||||
new_bboxes.push_back(bboxes[i]); | ||||||
if (process_labels) | ||||||
new_labels.push_back(labels[i]); | ||||||
indices.push_back(static_cast<int>(i)); | ||||||
} | ||||||
} | ||||||
std::swap(bboxes, new_bboxes); | ||||||
if (process_labels) | ||||||
std::swap(labels, new_labels); | ||||||
} | ||||||
|
||||||
void WriteCropToOutput(SampleWorkspace &ws, const Box<ndim, float> &crop) { | ||||||
|
@@ -806,14 +805,6 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
bbox_layout_); | ||||||
} | ||||||
|
||||||
void WriteLabelsToOutput(SampleWorkspace &ws, span<const int> labels) { | ||||||
auto &labels_out = ws.Output<CPUBackend>(3); | ||||||
labels_out.Resize({static_cast<Index>(labels.size()), 1}); | ||||||
auto *labels_out_data = labels_out.mutable_data<int>(); | ||||||
for (int i = 0; i < labels_out.size(); i++) | ||||||
labels_out_data[i] = labels[i]; | ||||||
} | ||||||
|
||||||
private: | ||||||
OpSpec spec_; | ||||||
int num_attempts_; | ||||||
|
@@ -827,6 +818,7 @@ class RandomBBoxCropImpl : public OpImplBase<CPUBackend> { | |||||
|
||||||
OverlapMetric overlap_metric_ = OverlapMetric::IoU; | ||||||
bool all_boxes_above_threshold_ = true; | ||||||
bool output_bbox_indices_ = false; | ||||||
|
||||||
BatchRNG<std::mt19937> rngs_; | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.