diff --git a/dali/operators/segmentation/random_mask_pixel.cc b/dali/operators/segmentation/random_mask_pixel.cc new file mode 100644 index 0000000000..0c591d46b8 --- /dev/null +++ b/dali/operators/segmentation/random_mask_pixel.cc @@ -0,0 +1,190 @@ +// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "dali/core/static_switch.h" +#include "dali/pipeline/operator/operator.h" +#include "dali/operators/segmentation/utils/searchable_rle_mask.h" +#include "dali/kernels/common/utils.h" +#include "dali/core/boundary.h" + +#define MASK_SUPPORTED_TYPES (uint8_t, int8_t, uint16_t, int16_t, uint32_t, int32_t, \ + uint64_t, int64_t, float, bool) + +namespace dali { + +DALI_SCHEMA(segmentation__RandomMaskPixel) + .DocStr(R"(Selects random pixel coordinates in a mask, sampled from a uniform distribution. + +Based on run-time argument ``foreground``, it returns either only foreground pixels or any pixels. + +Pixels are classificed as foreground either when their value exceeds a given ``threshold`` or when +it's equal to a specific ``value``. +)") + .AddOptionalArg("value", + R"code(All pixels equal to this value are interpreted as foreground. + +This argument is mutually exclusive with ``threshold`` argument and is meant to be used only +with integer inputs. +)code", nullptr, true) + .AddOptionalArg("threshold", + R"code(All pixels with a value above this threshold are interpreted as foreground. + +This argument is mutually exclusive with ``value`` argument. +)code", 0.0f, true) + .AddOptionalArg("foreground", + R"code(If different than 0, the pixel position is sampled uniformly from all foreground pixels. + +If 0, the pixel position is sampled uniformly from all available pixels.)code", + 0, true) + .NumInput(1) + .NumOutput(1); + +class RandomMaskPixelCPU : public Operator { + public: + explicit RandomMaskPixelCPU(const OpSpec &spec); + bool CanInferOutputs() const override { return true; } + bool SetupImpl(std::vector &output_desc, const workspace_t &ws) override; + void RunImpl(workspace_t &ws) override; + + private: + template + void RunImplTyped(workspace_t &ws); + + int64_t seed_; + std::vector rng_; + std::vector rle_; + + std::vector foreground_; + std::vector value_; + std::vector threshold_; + + bool has_value_ = false; + + USE_OPERATOR_MEMBERS(); +}; + +RandomMaskPixelCPU::RandomMaskPixelCPU(const OpSpec &spec) + : Operator(spec), + seed_(spec.GetArgument("seed")), + has_value_(spec.ArgumentDefined("value")) { + if (has_value_) { + DALI_ENFORCE(!spec.ArgumentDefined("threshold"), + "Arguments ``value`` and ``threshold`` can not be provided together"); + } +} + +bool RandomMaskPixelCPU::SetupImpl(std::vector &output_desc, + const workspace_t &ws) { + const auto &in_masks = ws.template InputRef(0); + int nsamples = in_masks.size(); + auto in_masks_shape = in_masks.shape(); + int ndim = in_masks_shape.sample_dim(); + output_desc.resize(1); + output_desc[0].shape = uniform_list_shape(nsamples, {ndim}); + output_desc[0].type = TypeTable::GetTypeInfo(DALI_INT64); + + foreground_.resize(nsamples); + value_.clear(); + threshold_.clear(); + + GetPerSampleArgument(foreground_, "foreground", ws, nsamples); + if (spec_.ArgumentDefined("value")) { + GetPerSampleArgument(value_, "value", ws, nsamples); + } else { + GetPerSampleArgument(threshold_, "threshold", ws, nsamples); + } + return true; +} + +template +void RandomMaskPixelCPU::RunImplTyped(workspace_t &ws) { + const auto &in_masks = ws.template InputRef(0); + auto &out_pixel_pos = ws.template OutputRef(0); + int nsamples = in_masks.size(); + auto in_masks_shape = in_masks.shape(); + int ndim = in_masks_shape.sample_dim(); + auto masks_view = view(in_masks); + auto pixel_pos_view = view(out_pixel_pos); + auto& thread_pool = ws.GetThreadPool(); + + if (rng_.empty()) { + for (int i = 0; i < thread_pool.size(); i++) { + rng_.emplace_back(seed_ + i); + } + } + assert(rng_.size() == static_cast(thread_pool.size())); + rle_.resize(thread_pool.size()); + + for (int sample_idx = 0; sample_idx < nsamples; sample_idx++) { + thread_pool.AddWork( + [&, sample_idx](int thread_id) { + auto &rng = rng_[thread_id]; + auto mask = masks_view[sample_idx]; + auto pixel_pos = pixel_pos_view[sample_idx]; + const auto &mask_sh = mask.shape; + if (foreground_[sample_idx]) { + int64_t flat_idx = -1; + auto &rle_mask = rle_[thread_id]; + rle_mask.Clear(); + if (has_value_) { + T value = static_cast(value_[sample_idx]); + // checking if the value is representable by T, otherwise we + // just fall back to pick a random pixel. + if (static_cast(value) == value_[sample_idx]) { + rle_mask.Init( + mask, [value](const T &x) { return x == value; }); + } + } else { + float threshold = threshold_[sample_idx]; + rle_mask.Init( + mask, [threshold](const T &x) { return x > threshold; }); + } + if (rle_mask.count() > 0) { + auto dist = std::uniform_int_distribution(0, rle_mask.count() - 1); + flat_idx = rle_mask.find(dist(rng)); + } + if (flat_idx >= 0) { + // Convert from flat_idx to per-dim indices + auto mask_strides = kernels::GetStrides(mask_sh); + for (int d = 0; d < ndim - 1; d++) { + pixel_pos.data[d] = flat_idx / mask_strides[d]; + flat_idx = flat_idx % mask_strides[d]; + } + pixel_pos.data[ndim - 1] = flat_idx; + return; + } + } + // Either foreground == 0 or no foreground pixels found. Get a random center + for (int d = 0; d < ndim; d++) { + pixel_pos.data[d] = std::uniform_int_distribution(0, mask_sh[d] - 1)(rng); + } + }, in_masks_shape.tensor_size(sample_idx)); + } + thread_pool.RunAll(); +} + +void RandomMaskPixelCPU::RunImpl(workspace_t &ws) { + const auto &in_masks = ws.template InputRef(0); + TYPE_SWITCH(in_masks.type().id(), type2id, T, MASK_SUPPORTED_TYPES, ( + RunImplTyped(ws); + ), ( // NOLINT + DALI_FAIL(make_string("Unexpected data type: ", in_masks.type().id())); + )); // NOLINT +} + +DALI_REGISTER_OPERATOR(segmentation__RandomMaskPixel, RandomMaskPixelCPU, CPU); + +} // namespace dali diff --git a/dali/operators/segmentation/utils/searchable_rle_mask.h b/dali/operators/segmentation/utils/searchable_rle_mask.h index 62b3dcc6e5..b70c997c9a 100644 --- a/dali/operators/segmentation/utils/searchable_rle_mask.h +++ b/dali/operators/segmentation/utils/searchable_rle_mask.h @@ -40,12 +40,18 @@ class SearchableRLEMask { bool operator()(const T &value) const { return value > 0; } }; + void Clear() { + groups_.clear(); + count_ = 0; + } + /** * @brief Construct a searchable RLE mask. ``predicate`` is used to * determine the mask values that are considered foreground */ template - explicit SearchableRLEMask(span mask_view, Predicate &&is_foreground = {}) { + void Init(span mask_view, Predicate &&is_foreground = {}) { + Clear(); int64_t idx = 0; int64_t sz = mask_view.size(); while (idx < sz) { @@ -63,9 +69,10 @@ class SearchableRLEMask { } template - explicit SearchableRLEMask(TensorView mask_view, Predicate &&is_foreground = {}) - : SearchableRLEMask(span{mask_view.data, volume(mask_view.shape)}, - std::forward(is_foreground)) {} + void Init(TensorView mask_view, Predicate &&is_foreground = {}) { + Init(span{mask_view.data, volume(mask_view.shape)}, + std::forward(is_foreground)); + } /** * @brief Returns the position of the i-th foreground pixel. diff --git a/dali/operators/segmentation/utils/searchable_rle_mask_test.cc b/dali/operators/segmentation/utils/searchable_rle_mask_test.cc index 681a5c4a95..84ee000578 100644 --- a/dali/operators/segmentation/utils/searchable_rle_mask_test.cc +++ b/dali/operators/segmentation/utils/searchable_rle_mask_test.cc @@ -25,7 +25,8 @@ TEST(SearchableRLEMask, handcrafted_mask1) { 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0}; TensorView mask_view(mask, TensorShape<>{6, 5}); - SearchableRLEMask search_mask(mask_view); + SearchableRLEMask search_mask; + search_mask.Init(mask_view); ASSERT_EQ(7, search_mask.count()); auto rle = search_mask.encoded(); @@ -63,7 +64,8 @@ TEST(SearchableRLEMask, handcrafted_mask2) { 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1}; TensorView mask_view(mask, TensorShape<>{6, 5}); - SearchableRLEMask search_mask(mask_view); + SearchableRLEMask search_mask; + search_mask.Init(mask_view); ASSERT_EQ(10, search_mask.count()); auto rle = search_mask.encoded(); @@ -100,14 +102,16 @@ TEST(SearchableRLEMask, handcrafted_mask2) { TEST(SearchableRLEMask, all_background) { std::vector all_bg(10, 0.0f); - SearchableRLEMask all_bg_mask(make_cspan(all_bg)); + SearchableRLEMask all_bg_mask; + all_bg_mask.Init(make_cspan(all_bg)); ASSERT_EQ(0, all_bg_mask.count()); ASSERT_EQ(-1, all_bg_mask.find(0)); } TEST(SearchableRLEMask, all_foreground) { std::vector all_fg(10, 1.0f); - SearchableRLEMask all_fg_mask(make_cspan(all_fg)); + SearchableRLEMask all_fg_mask; + all_fg_mask.Init(make_cspan(all_fg)); ASSERT_EQ(all_fg.size(), all_fg_mask.count()); for (size_t i = 0; i < all_fg.size(); i++) ASSERT_EQ(i, all_fg_mask.find(i)); @@ -117,7 +121,8 @@ TEST(SearchableRLEMask, alternative_pattern) { std::vector pattern(10, 0.0f); for (size_t i = 1; i < pattern.size(); i+=2) pattern[i] = 1.0f; - SearchableRLEMask pattern_mask(make_cspan(pattern)); + SearchableRLEMask pattern_mask; + pattern_mask.Init(make_cspan(pattern)); ASSERT_EQ(pattern.size() / 2, pattern_mask.count()); for (int i = 0; i < pattern_mask.count(); i++) ASSERT_EQ(2 * i + 1, pattern_mask.find(i)); diff --git a/dali/test/python/test_dali_cpu_only.py b/dali/test/python/test_dali_cpu_only.py index 701815b886..89e4511710 100644 --- a/dali/test/python/test_dali_cpu_only.py +++ b/dali/test/python/test_dali_cpu_only.py @@ -729,4 +729,13 @@ def test_pytorch_plugin_cpu(): pipe.set_outputs(outs) pii = DALIGenericIterator([pipe], ["data"]) +def test_random_mask_pixel_cpu(): + pipe = Pipeline(batch_size=batch_size, num_threads=3, device_id=None) + data = fn.external_source(source = get_data, layout = "HWC") + pixel_pos = fn.segmentation.random_mask_pixel(data) + pipe.set_outputs(pixel_pos) + pipe.build() + for _ in range(3): + pipe.run() + # ToDo add tests for DLTensorPythonFunction if easily possible diff --git a/dali/test/python/test_operator_segmentation_random_mask_pixel.py b/dali/test/python/test_operator_segmentation_random_mask_pixel.py new file mode 100644 index 0000000000..d4278dc4b5 --- /dev/null +++ b/dali/test/python/test_operator_segmentation_random_mask_pixel.py @@ -0,0 +1,63 @@ +import numpy as np +import nvidia.dali as dali +import nvidia.dali.fn as fn +import nvidia.dali.types as types +from test_utils import check_batch, dali_type +import random +from segmentation_test_utils import make_batch_select_masks +from nose.tools import assert_raises + +np.random.seed(4321) + +def check_random_mask_pixel(ndim=2, batch_size=3, + min_extent=20, max_extent=50): + pipe = dali.pipeline.Pipeline(batch_size=batch_size, num_threads=4, device_id=0, seed=1234) + with pipe: + # Input mask + in_shape_dims = [fn.cast(fn.uniform(range=(min_extent, max_extent + 1), shape=(1,), device='cpu'), + dtype=types.INT32) for d in range(ndim)] + in_shape = fn.cat(*in_shape_dims, axis=0) + in_mask = fn.cast(fn.uniform(range=(0, 2), device='cpu', shape=in_shape), dtype=types.INT32) + + fg_pixel1 = fn.segmentation.random_mask_pixel(in_mask, foreground=1) # > 0 + fg_pixel2 = fn.segmentation.random_mask_pixel(in_mask, foreground=1, threshold=0.99) # > 0.99 + fg_pixel3 = fn.segmentation.random_mask_pixel(in_mask, foreground=1, value=2) # == 2 + rnd_pixel = fn.segmentation.random_mask_pixel(in_mask, foreground=0) + coin_flip = fn.coin_flip(probability=0.7) + fg_biased = fn.segmentation.random_mask_pixel(in_mask, foreground=coin_flip) + + # Demo purposes: Taking a random pixel and produce a valid anchor to feed slice + crop_shape = in_shape - 2 # We want to force the center adjustment, therefore the large crop shape + anchor = fg_pixel1 - crop_shape // 2 + anchor = min(max(0, anchor), in_shape - crop_shape) + out_mask = fn.slice(in_mask, anchor, crop_shape, axes=tuple(range(ndim))) + + pipe.set_outputs(in_mask, fg_pixel1, fg_pixel2, fg_pixel3, rnd_pixel, coin_flip, fg_biased, + anchor, crop_shape, out_mask) + pipe.build() + for iter in range(3): + outputs = pipe.run() + for idx in range(batch_size): + in_mask = outputs[0].at(idx) + fg_pixel1 = outputs[1].at(idx).tolist() + fg_pixel2 = outputs[2].at(idx).tolist() + fg_pixel3 = outputs[3].at(idx).tolist() + rnd_pixel = outputs[4].at(idx).tolist() + coin_flip = outputs[5].at(idx).tolist() + fg_biased = outputs[6].at(idx).tolist() + anchor = outputs[7].at(idx).tolist() + crop_shape = outputs[8].at(idx).tolist() + out_mask = outputs[9].at(idx) + + assert in_mask[tuple(fg_pixel1)] > 0 + assert in_mask[tuple(fg_pixel2)] > 0.99 + assert in_mask[tuple(fg_pixel3)] == 2 + assert in_mask[tuple(fg_biased)] > 0 or not coin_flip + + for d in range(ndim): + assert 0 <= anchor[d] and anchor[d] + crop_shape[d] <= in_mask.shape[d] + assert out_mask.shape == tuple(crop_shape) + +def test_random_mask_pixel(): + for ndim in (2, 3): + yield check_random_mask_pixel, ndim