diff --git a/dali/python/nvidia/dali/_debug_mode.py b/dali/python/nvidia/dali/_debug_mode.py
index e9deb21b60..879b66fed7 100644
--- a/dali/python/nvidia/dali/_debug_mode.py
+++ b/dali/python/nvidia/dali/_debug_mode.py
@@ -538,6 +538,7 @@ def __init__(self, exec_func, **kwargs):
         self._exec_func = exec_func
         self._cur_operator_id = -1
         self._next_logical_id = 0
+        self._seed_upper_bound = (1 << 31) - 1
         self._operators = {}
         self._operators_built = False
         self._cur_iter_batch_info = _IterBatchInfo(-1, None)  # Used for variable batch sizes.
@@ -549,7 +550,7 @@ def __init__(self, exec_func, **kwargs):
         import numpy as np
         seed = kwargs.get('seed', -1)
         if seed < 0:
-            seed = np.random.randint(0, 2**32)
+            seed = np.random.randint(self._seed_upper_bound)
         self._seed_generator = np.random.default_rng(seed)
 
     def __enter__(self):
@@ -625,7 +626,7 @@ def _create_op(self, op_class, op_name, key, cur_context, inputs, kwargs):
         """Creates direct operator."""
         self._operators[key] = _OperatorManager(
             op_class, op_name, self, cur_context, self._next_logical_id, self._max_batch_size,
-            self._device_id, self._seed_generator.integers(0, 2**32), inputs, kwargs)
+            self._device_id, self._seed_generator.integers(self._seed_upper_bound), inputs, kwargs)
 
         self._pipe.AddMultipleOperators(
             self._operators[key].op_spec, self._operators[key].logical_ids)
diff --git a/dali/python/nvidia/dali/_utils/eager_utils.py b/dali/python/nvidia/dali/_utils/eager_utils.py
index 99df649a19..080e0cf56f 100644
--- a/dali/python/nvidia/dali/_utils/eager_utils.py
+++ b/dali/python/nvidia/dali/_utils/eager_utils.py
@@ -212,7 +212,11 @@ def _arithm_op(name, *inputs):
     categories_idxs, inputs, integers, reals = _ops._group_inputs(
         inputs, edge_type=(_tensors.TensorListCPU, _tensors.TensorListGPU))
     input_desc = _ops._generate_input_desc(categories_idxs, integers, reals)
-    device = _ops._choose_device(inputs)
+
+    if any(isinstance(input, _tensors.TensorListGPU) for input in inputs):
+        device = 'gpu'
+    else:
+        device = 'cpu'
 
     if device == "gpu":
         inputs = list(input._as_gpu() if isinstance(
@@ -333,6 +337,87 @@ def _rxor(self, other):
 _stateless_operators_cache = {}
 
 
+def _create_backend_op(spec, device, num_inputs, num_outputs, call_args_names, op_name):
+    inp_device = 'cpu' if device == 'mixed' else device
+    out_device = 'gpu' if device == 'mixed' else device
+
+    for i in range(num_inputs):
+        spec.AddInput(op_name + f'[{i}]', inp_device)
+
+    for i in range(num_outputs):
+        spec.AddOutput(op_name + f'_out[{i}]', out_device)
+
+    for arg_name in call_args_names:
+        spec.AddArgumentInput(arg_name, '')
+
+    if device == 'cpu':
+        backend_op = _b.EagerOperatorCPU(spec)
+    elif device == 'gpu':
+        backend_op = _b.EagerOperatorGPU(spec)
+    elif device == 'mixed':
+        backend_op = _b.EagerOperatorMixed(spec)
+    else:
+        raise ValueError(
+            f"Incorrect device type '{device}' in eager operator '{op_name}'.")
+
+    return backend_op
+
+
+def _eager_op_object_factory(op_class, op_name):
+    """ Creates an eager operator class to use with the object-based, ops-like API. Added for
+    completeness, currently not used.
+    """
+    class EagerOperator(op_class):
+        def __init__(self, **kwargs):
+            self._batch_size = getattr(kwargs, 'batch_size', -1)
+
+            # Workaround for batch size deduction in _prep_args as we don't have inputs yet.
+            kwargs['batch_size'] = 0
+
+            _, init_args, _ = _prep_args(
+                [], kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+            device_id = init_args.pop('device_id')
+            init_args.pop('max_batch_size')
+
+            super().__init__(**init_args)
+
+            self._spec.AddArg('device_id', device_id)
+            self.built = False
+
+        def __call__(self, *inputs, **kwargs):
+            inputs, init_args, call_args = _prep_args(
+                inputs, kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+
+            if not self.built:
+                num_outputs = self.schema.CalculateOutputs(
+                    self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
+
+                self._spec.AddArg('max_batch_size', init_args['max_batch_size'])
+                self._backend_op = _create_backend_op(
+                    self._spec, self._device, len(inputs), num_outputs, call_args.keys(), op_name)
+                self.built = True
+
+            output = self._backend_op(inputs, kwargs)
+
+            if len(output) == 1:
+                return output[0]
+
+            return output
+
+    return EagerOperator
+
+
+def _expose_eager_op_as_object(op_class, submodule):
+    """ Exposes eager operators as objects. Can be used if we decide to change the eager API from
+    functional to object-based.
+    """
+
+    op_name = op_class.schema_name
+    module = _internal.get_submodule('nvidia.dali.experimental.eager', submodule)
+    op = _eager_op_object_factory(op_class, op_name)
+    setattr(module, op_name, op)
+
+
 def _eager_op_base_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperatorBase(op_class):
         def __init__(self, *, max_batch_size, device_id, **kwargs):
@@ -341,26 +426,55 @@ def __init__(self, *, max_batch_size, device_id, **kwargs):
             self._spec.AddArg('device_id', device_id)
             self._spec.AddArg('max_batch_size', max_batch_size)
 
-            for i in range(num_inputs):
-                self._spec.AddInput(op_name + f'[{i}]', self._device)
-
-            for arg_name in call_args_names:
-                self._spec.AddArgumentInput(arg_name, '')
+            num_outputs = self.schema.CalculateOutputs(
+                self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
 
-            if self._device == 'cpu':
-                self._backend_op = _b.EagerOperatorCPU(self._spec)
-            elif self._device == 'gpu':
-                self._backend_op = _b.EagerOperatorGPU(self._spec)
-            elif self._device == 'mixed':
-                self._backend_op = _b.EagerOperatorMixed(self._spec)
-            else:
-                raise ValueError(
-                    f"Incorrect device type '{self._device}' in eager operator '{op_name}'.")
+            self._backend_op = _create_backend_op(
+                self._spec, self._device, num_inputs, num_outputs, call_args_names, op_name)
 
     return EagerOperatorBase
 
 
-def _stateless_op_factory(op_class, op_name, num_inputs, call_args_names):
+def _create_module_class():
+    """ Creates a class imitating a module. Used for `rng_state` so we can have nested methods,
+    e.g. `rng_state.random.normal`.
+    """
+    class Module:
+        @classmethod
+        def _submodule(cls, name):
+            """ Returns a submodule, creating it if it does not exist. """
+            if name not in cls._submodules:
+                # Register a new submodule class (object representing submodule will be created in
+                # the rng_state's constructor).
+                cls._submodules[name] = _create_state_submodule(name)
+
+            return cls._submodules[name]
+
+        _submodules = {}
+
+    return Module
+
+
+def _create_state_submodule(name):
+    """ Creates a class imitating a submodule. It can contain methods and nested submodules.
+    Used for submodules of rng_state, e.g. `rng_state.random`, `rng_state.noise`.
+ """ + + class StateSubmodule(_create_module_class()): + def __init__(self, operator_cache, seed_generator): + self._operator_cache = operator_cache + self._seed_generator = seed_generator + + for name, submodule_class in StateSubmodule._submodules.items(): + # Adds nested submodules. + setattr(self, name, submodule_class(self._operator_cache, self._seed_generator)) + + __name__ = name + + return StateSubmodule + + +def _callable_op_factory(op_class, op_name, num_inputs, call_args_names): class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)): def __call__(self, inputs, kwargs): # Here all kwargs are supposed to be TensorLists. @@ -374,6 +488,13 @@ def __call__(self, inputs, kwargs): return EagerOperator +_callable_op_factory.disqualified_arguments = { + 'bytes_per_sample_hint', + 'preserve', + 'seed' +} + + def _iterator_op_factory(op_class, op_name, num_inputs, call_args_names): class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)): def __init__(self, call_args, *, max_batch_size, **kwargs): @@ -425,6 +546,12 @@ def __len__(self): return EagerOperator +_iterator_op_factory.disqualified_arguments = { + 'bytes_per_sample_hint', + 'preserve', +} + + def _choose_device(op_name, wrapper_name, inputs, device_param): """Returns device type and device_id based on inputs and device_param.""" @@ -541,20 +668,25 @@ def _desc_call_args(inputs, args): [(key, value.dtype, value.layout(), len(value[0].shape())) for key, value in args.items()])) +def _gen_cache_key(op_name, inputs, init_args, call_args): + """ Creating cache key consisting of operator name, description of inputs, input arguments + and init args. Each call arg is described by dtype, layout and dim. + """ + return op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items())) + + def _wrap_stateless(op_class, op_name, wrapper_name): """Wraps stateless Eager Operator in a function. Callable the same way as functions in fn API, but directly with TensorLists. """ def wrapper(*inputs, **kwargs): inputs, init_args, call_args = _prep_args( - inputs, kwargs, op_name, wrapper_name, _wrap_stateless.disqualified_arguments) + inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments) - # Creating cache key consisting of operator name, description of inputs, input arguments - # and init args. Each call arg is described by dtype, layout and dim. - key = op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items())) + key = _gen_cache_key(op_name, inputs, init_args, call_args) if key not in _stateless_operators_cache: - _stateless_operators_cache[key] = _stateless_op_factory( + _stateless_operators_cache[key] = _callable_op_factory( op_class, wrapper_name, len(inputs), call_args.keys())(**init_args) return _stateless_operators_cache[key](inputs, call_args) @@ -562,11 +694,31 @@ def wrapper(*inputs, **kwargs): return wrapper -_wrap_stateless.disqualified_arguments = { - 'bytes_per_sample_hint', - 'preserve', - 'seed' -} +def _wrap_stateful(op_class, op_name, wrapper_name): + """Wraps stateful Eager Operator as method of a class. Callable the same way as functions in + fn API, but directly with TensorLists. 
+ """ + + def wrapper(self, *inputs, **kwargs): + inputs, init_args, call_args = _prep_args( + inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments) + + key = _gen_cache_key(op_name, inputs, init_args, call_args) + + if key not in self._operator_cache: + # Creating a new operator instance with deterministically generated seed, so if we + # preserve the order of operator calls in different instances of rng_state, they + # return the same results. + seed = self._seed_generator.integers(_wrap_stateful.seed_upper_bound) + self._operator_cache[key] = _callable_op_factory( + op_class, wrapper_name, len(inputs), call_args.keys())(**init_args, seed=seed) + + return self._operator_cache[key](inputs, call_args) + + return wrapper + + +_wrap_stateful.seed_upper_bound = (1 << 31) - 1 def _wrap_iterator(op_class, op_name, wrapper_name): @@ -582,7 +734,7 @@ def wrapper(*inputs, **kwargs): raise ValueError("Iterator type eager operators should not receive any inputs.") inputs, init_args, call_args = _prep_args( - inputs, kwargs, op_name, wrapper_name, _wrap_iterator.disqualified_arguments) + inputs, kwargs, op_name, wrapper_name, _iterator_op_factory.disqualified_arguments) op = _iterator_op_factory(op_class, wrapper_name, len(inputs), call_args.keys())(call_args, **init_args) @@ -592,14 +744,39 @@ def wrapper(*inputs, **kwargs): return wrapper -_wrap_iterator.disqualified_arguments = { - 'bytes_per_sample_hint', - 'preserve', -} +def _get_rng_state_target_module(submodules): + """ Returns target module of rng_state. If a module did not exist, creates it. """ + from nvidia.dali.experimental import eager + + last_module = eager.rng_state + for cur_module_name in submodules: + # If nonexistent registers rng_state's submodule. + cur_module = last_module._submodule(cur_module_name) + last_module = cur_module + + return last_module + + +def _get_eager_target_module(parent_module, submodules, make_hidden): + """ Returns target module inside ``parent_module`` if specified, otherwise inside eager. """ + if parent_module is None: + # Exposing to nvidia.dali.experimental.eager module. + parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager') + else: + # Exposing to experimental.eager submodule of the specified parent module. + parent_module = _internal.get_submodule( + sys.modules[parent_module], 'experimental.eager') + + if make_hidden: + op_module = _internal.get_submodule(parent_module, submodules[:-1]) + else: + op_module = _internal.get_submodule(parent_module, submodules) + + return op_module -def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc, make_hidden): - """Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`). +def _wrap_eager_op(op_class, submodules, parent_module, wrapper_name, wrapper_doc, make_hidden): + """ Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`). Uses ``op_class`` for preprocessing inputs and keyword arguments and filling OpSpec for backend eager operators. @@ -612,36 +789,31 @@ def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc wrapper_doc (str): Documentation of the wrapper function. make_hidden (bool): If operator is hidden, we should extract it from hidden submodule. 
""" + op_name = op_class.schema_name op_schema = _b.TryGetSchema(op_name) - if op_schema.IsDeprecated() or op_name in _excluded_operators or op_name in _stateful_operators: - # TODO(ksztenderski): For now only exposing stateless and iterator operators. - return - elif op_name in _iterator_operators: - wrapper = _wrap_iterator(op_class, op_name, wrapper_name) - else: - # If operator is not stateful or a generator expose it as stateless. - wrapper = _wrap_stateless(op_class, op_name, wrapper_name) - if parent_module is None: - # Exposing to nvidia.dali.experimental.eager module. - parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager') + if op_schema.IsDeprecated() or op_name in _excluded_operators: + return + elif op_name in _stateful_operators: + wrapper = _wrap_stateful(op_class, op_name, wrapper_name) + op_module = _get_rng_state_target_module(submodules) else: - # Exposing to experimental.eager submodule of the specified parent module. - parent_module = _internal.get_submodule(sys.modules[parent_module], 'experimental.eager') + if op_name in _iterator_operators: + wrapper = _wrap_iterator(op_class, op_name, wrapper_name) + else: + # If operator is not stateful, generator, deprecated or excluded expose it as stateless. + wrapper = _wrap_stateless(op_class, op_name, wrapper_name) - if make_hidden: - op_module = _internal.get_submodule(parent_module, submodule[:-1]) - else: - op_module = _internal.get_submodule(parent_module, submodule) + op_module = _get_eager_target_module(parent_module, submodules, make_hidden) if not hasattr(op_module, wrapper_name): wrapper.__name__ = wrapper_name wrapper.__qualname__ = wrapper_name wrapper.__doc__ = wrapper_doc - wrapper._schema_name = op_schema + wrapper._schema_name = op_name - if submodule: + if submodules: wrapper.__module__ = op_module.__name__ setattr(op_module, wrapper_name, wrapper) diff --git a/dali/python/nvidia/dali/experimental/eager/__init__.py b/dali/python/nvidia/dali/experimental/eager/__init__.py index 23f2e2e32c..5b46e65b1e 100644 --- a/dali/python/nvidia/dali/experimental/eager/__init__.py +++ b/dali/python/nvidia/dali/experimental/eager/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from . import math # noqa: F401 +from nvidia.dali._utils.eager_utils import _create_module_class """ Eager module implements eager versions of standard DALI operators. There are 3 main types of eager operators: @@ -53,6 +54,7 @@ class arithmetic(metaclass=_MetaArithmetic): >>> tl = dali.tensors.TensorListCPU(...) >>> out = tl ** 2 """ + def __init__(self, enabled=True): self.prev = arithmetic._enabled arithmetic._enabled = enabled @@ -67,4 +69,32 @@ def __enter__(self): def __exit__(self, type, value, traceback): arithmetic._enabled = self.prev + __name__ = 'arithmetic' _enabled = False + + +class rng_state(_create_module_class()): + """ Manager class for stateful operators. This object holds a cache of reusable operators. + Operators are initialized with deterministic seeds generated according to the ``seed`` argument + and are reused when you call the same operator with the same scalar parameters. + + Example: + >>> eager_state = dali.experimental.eager.rng_state(seed=42) + >>> out1 = eager_state.random.normal(shape=[5, 5], batch_size=8) + >>> # Here we will reuse the same operator. + >>> out2 = eager_state.random.normal(shape=[5, 5], batch_size=8) + >>> # And now we will create a new operator with new seed. 
+    >>> out3 = eager_state.random.normal(shape=[10, 10], batch_size=8)
+    """
+
+    def __init__(self, seed=None):
+        import numpy as np
+
+        self._operator_cache = {}
+        self._seed_generator = np.random.default_rng(seed)
+
+        for name, submodule_class in rng_state._submodules.items():
+            # Create attributes imitating submodules, e.g. `random`, `noise`.
+            setattr(self, name, submodule_class(self._operator_cache, self._seed_generator))
+
+    __name__ = 'rng_state'
diff --git a/dali/test/python/test_eager_coverage.py b/dali/test/python/test_eager_coverage.py
index 8677ffeffe..3b07d4f7cc 100644
--- a/dali/test/python/test_eager_coverage.py
+++ b/dali/test/python/test_eager_coverage.py
@@ -17,10 +17,10 @@
 import re
 from functools import reduce
 
-import nvidia.dali.experimental.eager as eager
-import nvidia.dali.fn as fn
-import nvidia.dali.tensors as tensors
-import nvidia.dali.types as types
+from nvidia.dali import fn
+from nvidia.dali import tensors
+from nvidia.dali import types
+from nvidia.dali.experimental import eager
 from nvidia.dali.pipeline import Pipeline, pipeline_def
 from nvidia.dali._utils.eager_utils import _slice_tensorlist
 from test_dali_cpu_only_utils import (pipeline_arithm_ops_cpu, setup_test_nemo_asr_reader_cpu,
@@ -131,14 +131,14 @@ def eager_source(self, i, layout='HWC'):
         return get_tl(np.array(self.fn_source(i)), layout)
 
 
-def get_ops(op_path, fn_op=None, eager_op=None):
+def get_ops(op_path, fn_op=None, eager_op=None, eager_module=eager):
     """ Get fn and eager versions of operators from given path. """
     import_path = op_path.split('.')
     if fn_op is None:
         fn_op = reduce(getattr, [fn] + import_path)
     if eager_op is None:
-        eager_op = reduce(getattr, [eager] + import_path)
+        eager_op = reduce(getattr, [eager_module] + import_path)
     return fn_op, eager_op
 
@@ -153,7 +153,10 @@ def compare_eager_with_pipeline(pipe, eager_op, *, eager_source=get_data_eager,
         input_tl = eager_source(i, layout)
         out_fn = pipe.run()
         if isinstance(input_tl, (tuple, list)):
-            out_eager = eager_op(*input_tl, **kwargs)
+            if len(input_tl):
+                out_eager = eager_op(*input_tl, **kwargs)
+            else:
+                out_eager = eager_op(batch_size=batch_size, **kwargs)
         else:
             out_eager = eager_op(input_tl, **kwargs)
 
@@ -199,24 +202,51 @@ def no_input_pipeline(op, kwargs):
     return out
 
 
+def no_input_source(*_):
+    return ()
+
+
 def check_no_input(op_path, *, fn_op=None, eager_op=None, batch_size=batch_size,
                    N_iterations=5, **kwargs):
     fn_op, eager_op = get_ops(op_path, fn_op, eager_op)
     pipe = no_input_pipeline(fn_op, kwargs)
-    pipe.build()
+    compare_eager_with_pipeline(pipe, eager_op, eager_source=no_input_source,
+                                batch_size=batch_size, N_iterations=N_iterations, **kwargs)
 
-    for _ in range(N_iterations):
-        out_fn = pipe.run()
-        out_eager = eager_op(batch_size=batch_size, **kwargs)
-
-        if not isinstance(out_eager, (tuple, list)):
-            out_eager = (out_eager,)
 
+def prep_stateful_operators(op_path):
+    # Replicate the seed that will be used inside rng_state, so that we expect fn and eager
+    # operators to return the same results.
+    seed_upper_bound = (1 << 31) - 1
+    seed = rng.integers(seed_upper_bound)
+    fn_seed = np.random.default_rng(seed).integers(seed_upper_bound)
+    eager_state = eager.rng_state(seed)
 
-        assert len(out_fn) == len(out_eager)
+    fn_op, eager_op = get_ops(op_path, eager_module=eager_state)
 
-        for tensor_out_fn, tensor_out_eager in zip(out_fn, out_eager):
-            assert type(tensor_out_fn) == type(tensor_out_eager)
-            check_batch(tensor_out_fn, tensor_out_eager, batch_size)
+    return fn_op, eager_op, fn_seed
+
+
+def check_single_input_stateful(op_path, pipe_fun=single_op_pipeline, fn_source=get_data,
+                                fn_op=None, eager_source=get_data_eager, eager_op=None,
+                                layout='HWC', **kwargs):
+    fn_op, eager_op, fn_seed = prep_stateful_operators(op_path)
+
+    kwargs['seed'] = fn_seed
+    pipe = pipe_fun(fn_op, kwargs, source=fn_source, layout=layout)
+    kwargs.pop('seed', None)
+
+    compare_eager_with_pipeline(pipe, eager_op, eager_source=eager_source, layout=layout, **kwargs)
+
+
+def check_no_input_stateful(op_path, *, fn_op=None, eager_op=None, batch_size=batch_size,
+                            N_iterations=5, **kwargs):
+    fn_op, eager_op, fn_seed = prep_stateful_operators(op_path)
+    kwargs['seed'] = fn_seed
+    pipe = no_input_pipeline(fn_op, kwargs)
+    kwargs.pop('seed', None)
+    compare_eager_with_pipeline(pipe, eager_op, eager_source=no_input_source,
+                                batch_size=batch_size, N_iterations=N_iterations, **kwargs)
 
 
 @pipeline_def(batch_size=batch_size, num_threads=4, device_id=None)
@@ -514,7 +544,7 @@ def test_audio_decoder():
 
 def test_coord_flip():
     get_data = GetData([[(rng.integers(0, 255, size=[200, 2], dtype=np.uint8) /
-                          255).astype(dtype=np.float32)for _ in range(batch_size)]
+                          255).astype(dtype=np.float32) for _ in range(batch_size)]
                         for _ in range(data_size)])
 
     check_single_input('coord_flip', fn_source=get_data.fn_source,
@@ -523,7 +553,7 @@ def test_coord_flip():
 
 def test_bb_flip():
     get_data = GetData([[(rng.integers(0, 255, size=[200, 4], dtype=np.uint8) /
-                          255).astype(dtype=np.float32)for _ in range(batch_size)]
+                          255).astype(dtype=np.float32) for _ in range(batch_size)]
                         for _ in range(data_size)])
 
     check_single_input('bb_flip', fn_source=get_data.fn_source,
@@ -535,7 +565,7 @@ def test_warp_affine():
 
 
 def test_normalize():
-    check_single_input('normalize', batch=True)
+    check_single_input('normalize')
 
 
 def test_lookup_table():
@@ -965,6 +995,112 @@ def test_arithm_ops():
     compare_eager_with_pipeline(pipe, eager_op=eager_arithm_ops)
 
 
+def test_image_decoder_random_crop():
+    check_single_input_stateful('decoders.image_random_crop', pipe_fun=reader_op_pipeline,
+                                fn_source=images_dir, eager_source=PipelineInput(
+                                    file_reader_pipeline, file_root=images_dir),
+                                output_type=types.RGB)
+
+
+def test_noise_gaussian():
+    check_single_input_stateful('noise.gaussian')
+
+
+def test_noise_salt_and_pepper():
+    check_single_input_stateful('noise.salt_and_pepper')
+
+
+def test_noise_shot():
+    check_single_input_stateful('noise.shot')
+
+
+def test_random_mask_pixel():
+    check_single_input_stateful('segmentation.random_mask_pixel')
+
+
+def test_random_resized_crop():
+    check_single_input_stateful('random_resized_crop', size=[5, 5])
+
+
+def test_random_object_bbox():
+    data = tensors.TensorListCPU([tensors.TensorCPU(
+        np.int32([[1, 0, 0, 0],
+                  [1, 2, 2, 1],
+                  [1, 1, 2, 0],
+                  [2, 0, 0, 1]])), tensors.TensorCPU(
+        np.int32([[0, 3, 3, 0],
+                  [1, 0, 1, 2],
+                  [0, 1, 1, 0],
+                  [0, 2, 0, 1],
+                  [0, 2, 2, 1]]))])
+
+    def source(*_):
+        return data
+
+    check_single_input_stateful('segmentation.random_object_bbox',
+                                fn_source=source, eager_source=source, layout="")
+
+
+def test_fast_resize_crop_mirror():
+    check_single_input_stateful('fast_resize_crop_mirror', crop=[5, 5], resize_shorter=10)
+
+
+def test_roi_random_crop():
+    shape = [10, 20, 3]
+    check_single_input_stateful('roi_random_crop',
+                                crop_shape=[x // 2 for x in shape],
+                                roi_start=[x // 4 for x in shape],
+                                roi_shape=[x // 2 for x in shape])
+
+
+@pipeline_def(batch_size=batch_size, num_threads=4, device_id=None)
+def random_bbox_crop_pipeline(get_boxes, get_labels, seed):
+    boxes = fn.external_source(source=get_boxes)
+    labels = fn.external_source(source=get_labels)
+    out = fn.random_bbox_crop(boxes, labels, aspect_ratio=[0.5, 2.0], thresholds=[
+                              0.1, 0.3, 0.5], scaling=[0.8, 1.0], bbox_layout="xyXY", seed=seed)
+
+    return tuple(out)
+
+
+def test_random_bbox_crop():
+    get_boxes = GetData([[(rng.integers(0, 255, size=[200, 4], dtype=np.uint8) / 255).astype(
+        dtype=np.float32) for _ in range(batch_size)] for _ in range(data_size)])
+    get_labels = GetData([[rng.integers(0, 255, size=[200, 1], dtype=np.int32) for _ in
+                           range(batch_size)] for _ in range(data_size)])
+
+    def eager_source(i, _):
+        return get_boxes.eager_source(i), get_labels.eager_source(i)
+
+    _, eager_op, fn_seed = prep_stateful_operators('random_bbox_crop')
+
+    pipe = random_bbox_crop_pipeline(get_boxes.fn_source, get_labels.fn_source, fn_seed)
+
+    compare_eager_with_pipeline(pipe, eager_op, eager_source=eager_source, aspect_ratio=[0.5, 2.0],
+                                thresholds=[0.1, 0.3, 0.5], scaling=[0.8, 1.0],
+                                bbox_layout="xyXY")
+
+
+def test_resize_crop_mirror():
+    check_single_input_stateful('resize_crop_mirror', crop=[5, 5], resize_shorter=10)
+
+
+def test_random_coin_flip():
+    check_no_input_stateful('random.coin_flip')
+
+
+def test_normal_distribution():
+    check_no_input_stateful('random.normal', shape=[5, 5])
+
+
+def test_random_uniform():
+    check_no_input_stateful('random.uniform')
+
+
+def test_batch_permutation():
+    check_no_input_stateful('batch_permutation')
+
+
 tested_methods = [
     'decoders.image',
     'rotate',
@@ -1056,6 +1192,21 @@ def test_arithm_ops():
     'get_property',
     'tensor_subscript',
     'arithmetic_generic_op',
+    'decoders.image_random_crop',
+    'noise.gaussian',
+    'noise.salt_and_pepper',
+    'noise.shot',
+    'segmentation.random_mask_pixel',
+    'segmentation.random_object_bbox',
+    'fast_resize_crop_mirror',
+    'roi_random_crop',
+    'random_bbox_crop',
+    'random_resized_crop',
+    'resize_crop_mirror',
+    'random.coin_flip',
+    'random.normal',
+    'random.uniform',
+    'batch_permutation',
 ]
 
 excluded_methods = [
@@ -1078,6 +1229,8 @@ def test_coverage():
     """
     methods = module_functions(eager, remove_prefix="nvidia.dali.experimental.eager")
+    methods += module_functions(
+        eager.rng_state(), remove_prefix='rng_state', check_non_module=True)
     # TODO(ksztenderski): Add coverage for GPU operators.
     exclude = "|".join(
         ["(^" + x.replace(".", "\.").replace("*", ".*").replace("?", ".") + "$)"  # noqa: W605
diff --git a/dali/test/python/test_eager_operators.py b/dali/test/python/test_eager_operators.py
index 12b5c24c4e..0d580085ac 100644
--- a/dali/test/python/test_eager_operators.py
+++ b/dali/test/python/test_eager_operators.py
@@ -13,10 +13,15 @@
 # limitations under the License.
 
 import numpy as np
+import os
 
-import nvidia.dali.experimental.eager as eager
-import nvidia.dali.tensors as tensors
+from nvidia.dali import fn
+from nvidia.dali import pipeline_def
+from nvidia.dali import ops
+from nvidia.dali import tensors
+from nvidia.dali.experimental import eager
 from nose_utils import assert_raises, raises
+from test_utils import get_dali_extra_path
 
 
 @raises(RuntimeError, glob=f"Argument '*' is not supported by eager operator 'crop'.")
@@ -87,3 +92,78 @@ def test_arithm_op_context_manager_deep_nested():
     assert np.array_equal((tl_1 + tl_2).as_array(), expected_sum)
 
     eager.arithmetic(False)
+
+
+def test_identical_rng_states():
+    eager_state_1 = eager.rng_state(seed=42)
+    eager_state_2 = eager.rng_state(seed=42)
+
+    out_1_1 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)
+    out_1_2 = eager_state_1.noise.gaussian(out_1_1)
+    out_1_3 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)
+
+    out_2_1 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)
+    out_2_2 = eager_state_2.noise.gaussian(out_2_1)
+    out_2_3 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)
+
+    assert np.allclose(out_1_1.as_tensor(), out_2_1.as_tensor())
+    assert np.allclose(out_1_2.as_tensor(), out_2_2.as_tensor())
+    assert np.allclose(out_1_3.as_tensor(), out_2_3.as_tensor())
+
+
+def test_identical_rng_states_interleaved():
+    eager_state_1 = eager.rng_state(seed=42)
+    eager_state_2 = eager.rng_state(seed=42)
+
+    out_1_1 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)
+    eager_state_1.random.normal(shape=[6, 6], batch_size=8)
+    eager_state_1.noise.gaussian(out_1_1)
+    out_1_2 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)
+
+    out_2_1 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)
+    out_2_2 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)
+
+    assert np.allclose(out_1_1.as_tensor(), out_2_1.as_tensor())
+    assert np.allclose(out_1_2.as_tensor(), out_2_2.as_tensor())
+
+
+def test_objective_eager_resize():
+    from nvidia.dali._utils import eager_utils
+
+    resize_class = eager_utils._eager_op_object_factory(ops.python_op_factory('Resize'), 'Resize')
+    tl = tensors.TensorListCPU(np.random.default_rng().integers(
+        256, size=(8, 200, 200, 3), dtype=np.uint8))
+
+    obj_resize = resize_class(resize_x=50, resize_y=50)
+    out_obj = obj_resize(tl)
+    out_fun = eager.resize(tl, resize_x=50, resize_y=50)
+
+    assert np.array_equal(out_obj.as_tensor(), out_fun.as_tensor())
+
+
+@pipeline_def(num_threads=3, device_id=0)
+def mixed_image_decoder_pipeline(file_root, seed):
+    jpeg, _ = fn.readers.file(file_root=file_root, seed=seed)
+    out = fn.decoders.image(jpeg, device="mixed")
+
+    return out
+
+
+def test_mixed_devices_decoder():
+    """ Tests hidden functionality of exposing eager operators as classes.
""" + seed = 42 + batch_size = 8 + file_root = os.path.join(get_dali_extra_path(), 'db/single/jpeg') + + pipe = mixed_image_decoder_pipeline(file_root, seed, batch_size=batch_size) + pipe.build() + pipe_out, = pipe.run() + + jpeg, _ = next(eager.readers.file(file_root=file_root, batch_size=batch_size, seed=seed)) + eager_out = eager.decoders.image(jpeg, device="gpu") + + assert len(pipe_out) == len(eager_out) + + with eager.arithmetic(): + for comp_tensor in (pipe_out == eager_out): + assert np.all(comp_tensor.as_cpu()) diff --git a/dali/test/python/test_pipeline_debug.py b/dali/test/python/test_pipeline_debug.py index 21fce65240..ca14122632 100644 --- a/dali/test/python/test_pipeline_debug.py +++ b/dali/test/python/test_pipeline_debug.py @@ -16,9 +16,9 @@ import os from nose.plugins.attrib import attr -import nvidia.dali.fn as fn -import nvidia.dali.tensors as tensors -import nvidia.dali.types as types +from nvidia.dali import fn +from nvidia.dali import tensors +from nvidia.dali import types from nvidia.dali.pipeline.experimental import pipeline_def from nose_utils import raises from test_utils import compare_pipelines, get_dali_extra_path diff --git a/dali/test/python/test_utils.py b/dali/test/python/test_utils.py index 668fabd1b1..ac9f83d056 100644 --- a/dali/test/python/test_utils.py +++ b/dali/test/python/test_utils.py @@ -521,7 +521,7 @@ def to_array(dali_out): return np.array(dali_out) -def module_functions(cls, prefix="", remove_prefix=""): +def module_functions(cls, prefix="", remove_prefix="", check_non_module=False): res = [] if hasattr(cls, '_schema_name'): prefix = prefix.replace(remove_prefix, "") @@ -531,10 +531,11 @@ def module_functions(cls, prefix="", remove_prefix=""): else: prefix = "" res.append(prefix + cls.__name__) - elif inspect.ismodule(cls): + elif check_non_module or inspect.ismodule(cls): for c_name, c in inspect.getmembers(cls): if not c_name.startswith("_") and c_name not in sys.builtin_module_names: - res += module_functions(c, cls.__name__, remove_prefix=remove_prefix) + res += module_functions(c, cls.__name__, remove_prefix=remove_prefix, + check_non_module=check_non_module) return res