Revert to fn-style names for operators in debug mode
Signed-off-by: ksztenderski <ksztenderski@nvidia.com>
ksztenderski committed Mar 30, 2022
1 parent 92bfc65 commit aa0da7f
Showing 3 changed files with 73 additions and 45 deletions.
110 changes: 69 additions & 41 deletions dali/python/nvidia/dali/_debug_mode.py
@@ -21,6 +21,7 @@
 import nvidia.dali.tensors as _tensors
 import nvidia.dali.types as _types
 from nvidia.dali.data_node import DataNode as _DataNode, _check
+from nvidia.dali.fn import _to_snake_case
 from nvidia.dali.external_source import _prep_data_for_feed_input
 from nvidia.dali._utils.external_source_impl import \
     get_callback_from_source as _get_callback_from_source, \
@@ -52,79 +53,106 @@ def shape(self):
         return self._data.shape()

     def __add__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='add')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='add')

     def __radd__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='add')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='add')

     def __sub__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='sub')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='sub')

     def __rsub__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='sub')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='sub')

     def __mul__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='mul')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='mul')

     def __rmul__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='mul')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='mul')

     def __pow__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='pow')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='pow')

     def __rpow__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='pow')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='pow')

     def __truediv__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='fdiv')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='fdiv')

     def __rtruediv__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='fdiv')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='fdiv')

     def __floordiv__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='div')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='div')

     def __rfloordiv__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='div')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='div')

     def __neg__(self):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, name='minus')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, name='minus')

     def __eq__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='eq')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='eq')

     def __ne__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='neq')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='neq')

     def __lt__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='lt')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='lt')

     def __le__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='leq')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='leq')

     def __gt__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='gt')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='gt')

     def __ge__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='geq')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='geq')

     def __and__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='bitand')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='bitand')

     def __rand__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='bitand')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='bitand')

     def __or__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='bitor')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='bitor')

     def __ror__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='bitor')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='bitor')

     def __xor__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, self, other, name='bitxor')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, self, other, name='bitxor')

     def __rxor__(self, other):
-        return _PipelineDebug.current()._wrap_op_call(_ops.ArithmeticGenericOp, other, self, name='bitxor')
+        return _PipelineDebug.current()._wrap_op_call(
+            _ops.ArithmeticGenericOp, DataNodeDebug.aritm_fn_name, other, self, name='bitxor')
+
+    aritm_fn_name = _to_snake_case(_ops.ArithmeticGenericOp.__name__)


def _transform_data_to_tensorlist(data, batch_size, layout=None, device_id=None):
@@ -271,7 +299,7 @@ def is_primitive_type(x):
             is_batch_list.append(is_batch)
             device_list.append(device)
             data_list.append(val)

         if any([device != device_list[0] for device in device_list]):
             raise RuntimeError(f'{type_name} has batches of data on CPU and on GPU, '
                                'which is not supported.')
@@ -311,14 +339,12 @@ class _OperatorManager:
     Uses :class:`ops.Operator` to create OpSpec and handle input sets.
     """

-    def __init__(self, op_class, next_logical_id, batch_size, seed, inputs, kwargs):
+    def __init__(self, op_class, op_name, next_logical_id, batch_size, seed, inputs, kwargs):
         """Creates direct operator."""

         self._separate_kwargs(kwargs)

-        self._op_name = op_class.__name__
-
-        if op_class.__name__ == 'ArithmeticGenericOp':
+        if op_name == 'arithmetic_generic_op':
             inputs = self._init_arithm_op(kwargs['name'], inputs)

         # Save inputs classification for later verification.
@@ -333,8 +359,8 @@ def __init__(self, op_class, next_logical_id, batch_size, seed, inputs, kwargs):
                 if input_set_len == 1:
                     input_set_len = len(classification.is_batch)
                 else:
-                    raise ValueError("All argument lists for Multpile Input Sets used "
-                                     f"with operator '{self._op_name}' must have the same length")
+                    raise ValueError("All argument lists for Multiple Input Sets used "
+                                     f"with operator '{op_name}' must have the same length.")
             self._inputs_classification.append(classification)
         self.expected_inputs_size = len(inputs)

@@ -345,8 +371,9 @@ def __init__(self, op_class, next_logical_id, batch_size, seed, inputs, kwargs):

         self._batch_size = batch_size
         self._device = self._init_args.get('device', 'cpu')
-        self.op_helper = op_class(**self._init_args)
-        self._expected_inputs_size = len(inputs)
+        self.op_helper = op_class(**self._init_args)
+        self._op_name = op_name
         self.op_spec = self.op_helper._spec
         self.logical_ids = [id for id in range(next_logical_id, next_logical_id + input_set_len)]

@@ -570,12 +597,13 @@ def feed_input(self, data_node, data, **kwargs):
         else:
             self._external_sources[name]._feed_input(name, data, kwargs)

-    def _create_op(self, op_class, key, inputs, kwargs):
+    def _create_op(self, op_class, op_name, key, inputs, kwargs):
         """Creates direct operator."""
         self._operators[key] = _OperatorManager(
-            op_class, self._next_logical_id, self._max_batch_size, self._seed_generator.integers(0, 2**32), inputs, kwargs)
+            op_class, op_name, self._next_logical_id, self._max_batch_size, self._seed_generator.integers(0, 2**32), inputs, kwargs)

-        self._pipe.AddMultipleOperators(self._operators[key].op_spec, self._operators[key].logical_ids)
+        self._pipe.AddMultipleOperators(
+            self._operators[key].op_spec, self._operators[key].logical_ids)
         self._next_logical_id = self._operators[key].logical_ids[-1] + 1

def _external_source(self, name=None, **kwargs):
@@ -625,17 +653,17 @@ def _extract_data_node_inputs(inputs):

         return data_nodes

-    def _wrap_op_call(self, op_class, *inputs, **kwargs):
+    def _wrap_op_call(self, op_class, op_name, *inputs, **kwargs):
         self._cur_operator_id += 1
         key = inspect.getframeinfo(
             inspect.currentframe().f_back.f_back)[:3] + (self._cur_operator_id,)
         if not self._operators_built:
-            self._create_op(op_class, key, inputs, kwargs)
+            self._create_op(op_class, op_name, key, inputs, kwargs)

         if key in self._operators:
-            if op_class.__name__ == 'ArithmeticGenericOp':
+            if op_name == 'arithmetic_generic_op':
                 inputs = _PipelineDebug._extract_data_node_inputs(inputs)
             return self._run_op(self._operators[key], inputs, kwargs)
         else:
-            raise RuntimeError(f"Unexpected operator '{op_class}'. Debug mode does not support"
+            raise RuntimeError(f"Unexpected operator '{op_name}'. Debug mode does not support"
                                " changing the order of operators executed within the pipeline.")
2 changes: 1 addition & 1 deletion dali/python/nvidia/dali/fn.py
@@ -79,7 +79,7 @@ def fn_wrapper(*inputs, **kwargs):
         from nvidia.dali._debug_mode import _PipelineDebug
         current_pipeline = _PipelineDebug.current()
         if getattr(current_pipeline, '_debug_on', False):
-            return current_pipeline._wrap_op_call(op_class, *inputs, **kwargs)
+            return current_pipeline._wrap_op_call(op_class, wrapper_name, *inputs, **kwargs)
         else:
             return op_wrapper(*inputs, **kwargs)

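The fn.py change relies on each generated fn function closing over both its backing ops class and its fn-style name, and forwarding the latter to debug mode. A runnable toy illustration of that dispatch pattern follows; every name here is a stand-in for illustration, not DALI's real internals:

class FakeDebugPipeline:
    # Stand-in for _PipelineDebug with debug mode enabled.
    _debug_on = True

    def _wrap_op_call(self, op_class, op_name, *inputs, **kwargs):
        print(f"debug-mode call: {op_name} ({op_class.__name__})")

def make_fn_wrapper(op_class, wrapper_name, op_wrapper, pipeline):
    def fn_wrapper(*inputs, **kwargs):
        if getattr(pipeline, '_debug_on', False):
            # Debug path: identify the call by its fn-style (snake_case) name.
            return pipeline._wrap_op_call(op_class, wrapper_name, *inputs, **kwargs)
        return op_wrapper(*inputs, **kwargs)
    return fn_wrapper

class ArithmeticGenericOp:  # stand-in for _ops.ArithmeticGenericOp
    pass

add = make_fn_wrapper(ArithmeticGenericOp, 'arithmetic_generic_op',
                      lambda *a, **kw: None, FakeDebugPipeline())
add()  # prints: debug-mode call: arithmetic_generic_op (ArithmeticGenericOp)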
6 changes: 3 additions & 3 deletions dali/test/python/test_pipeline_debug.py
@@ -504,16 +504,16 @@ def test_input_sets():


 @pipeline_def(batch_size=8, num_threads=3, device_id=0, debug=True)
-def incorrect_input_Sets_pipeline():
+def incorrect_input_sets_pipeline():
     jpegs, _ = fn.readers.file(file_root=file_root, seed=42, random_shuffle=True)
     images = fn.decoders.image(jpegs, seed=42)
     output = fn.cat([images, images, images], [images, images])

     return tuple(output)


-@raises(ValueError, glob="All argument lists for Multpile Input Sets used with operator 'Cat' must have the same length")
+@raises(ValueError, glob="All argument lists for Multiple Input Sets used with operator 'cat' must have the same length.")
 def test_incorrect_input_sets():
-    pipe = incorrect_input_Sets_pipeline()
+    pipe = incorrect_input_sets_pipeline()
     pipe.build()
     pipe.run()
