Add eager mode stateful operators #4016

Changes from all commits: 567f500, 6a924bb, 588ae94, 7e749a7, 5314f22, 7c3db6c
```diff
@@ -212,7 +212,11 @@ def _arithm_op(name, *inputs):
     categories_idxs, inputs, integers, reals = _ops._group_inputs(
         inputs, edge_type=(_tensors.TensorListCPU, _tensors.TensorListGPU))
     input_desc = _ops._generate_input_desc(categories_idxs, integers, reals)
-    device = _ops._choose_device(inputs)
+
+    if any(isinstance(input, _tensors.TensorListGPU) for input in inputs):
+        device = 'gpu'
+    else:
+        device = 'cpu'
 
     if device == "gpu":
         inputs = list(input._as_gpu() if isinstance(
```
```diff
@@ -333,6 +337,87 @@ def _rxor(self, other):
 _stateless_operators_cache = {}
 
 
+def _create_backend_op(spec, device, num_inputs, num_outputs, call_args_names, op_name):
+    inp_device = 'cpu' if device == 'mixed' else device
+    out_device = 'gpu' if device == 'mixed' else device
+
+    for i in range(num_inputs):
+        spec.AddInput(op_name + f'[{i}]', inp_device)
+
+    for i in range(num_outputs):
+        spec.AddOutput(op_name + f'_out[{i}]', out_device)
+
+    for arg_name in call_args_names:
+        spec.AddArgumentInput(arg_name, '')
+
+    if device == 'cpu':
+        backend_op = _b.EagerOperatorCPU(spec)
+    elif device == 'gpu':
+        backend_op = _b.EagerOperatorGPU(spec)
+    elif device == 'mixed':
+        backend_op = _b.EagerOperatorMixed(spec)
+    else:
+        raise ValueError(
+            f"Incorrect device type '{device}' in eager operator '{op_name}'.")
+
+    return backend_op
```
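`_create_backend_op` routes input and output devices around the 'mixed' case: a mixed operator (such as an image decoder) consumes CPU inputs and produces GPU outputs. A sketch of that routing rule in isolation:

```python
# Sketch of the device routing used by _create_backend_op (illustrative only).
def route_devices(device):
    inp_device = 'cpu' if device == 'mixed' else device
    out_device = 'gpu' if device == 'mixed' else device
    return inp_device, out_device

assert route_devices('cpu') == ('cpu', 'cpu')
assert route_devices('gpu') == ('gpu', 'gpu')
assert route_devices('mixed') == ('cpu', 'gpu')  # CPU in, GPU out
```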
Review thread on `_eager_op_object_factory`:

> Reviewer: I guess this could be grouped and marked with the `_expose` part as unused and with the purpose described. Maybe make it a section with a "block" comment, or a file?
>
> Author: I grouped it together and added comments about it being unused. I don't know about moving it to a separate file; I think it's fine like this.

```diff
+def _eager_op_object_factory(op_class, op_name):
+    """ Creates an eager operator class to use with an objective, ops-like API. Kept for
+    completeness; currently not used.
+    """
+    class EagerOperator(op_class):
+        def __init__(self, **kwargs):
+            # `kwargs` is a dict, so use `.get` here (`getattr(kwargs, 'batch_size', -1)`
+            # would always return the default).
+            self._batch_size = kwargs.get('batch_size', -1)
+
+            # Workaround for batch size deduction in _prep_args, as we don't have inputs yet.
+            kwargs['batch_size'] = 0
+
+            _, init_args, _ = _prep_args(
+                [], kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+            device_id = init_args.pop('device_id')
+            init_args.pop('max_batch_size')
+
+            super().__init__(**init_args)
+
+            self._spec.AddArg('device_id', device_id)
+            self.built = False
+
+        def __call__(self, *inputs, **kwargs):
+            inputs, init_args, call_args = _prep_args(
+                inputs, kwargs, op_name, op_name, _callable_op_factory.disqualified_arguments)
+
+            if not self.built:
+                num_outputs = self.schema.CalculateOutputs(
+                    self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
+
+                self._spec.AddArg('max_batch_size', init_args['max_batch_size'])
+                self._backend_op = _create_backend_op(
+                    self._spec, self._device, len(inputs), num_outputs, call_args.keys(), op_name)
+                self.built = True
+
+            # Pass only the tensor call arguments to the backend (raw kwargs also hold init args).
+            output = self._backend_op(inputs, call_args)
+
+            if len(output) == 1:
+                return output[0]
+
+            return output
+
+    return EagerOperator
```
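The `built` flag defers backend construction to the first call, when the input count and call-time arguments are known. A toy version of that pattern, plus a hypothetical sketch of the objective API it would enable (nothing below is DALI code; `Flip` and its arguments are illustrative):

```python
class LazilyBuilt:
    """Toy version of the build-on-first-call pattern used by EagerOperator."""

    def __init__(self):
        self.built = False

    def __call__(self, *inputs):
        if not self.built:
            # One-time construction: only now is the number of inputs known.
            self.num_inputs = len(inputs)
            self.built = True
        return inputs

# Hypothetical usage if _expose_eager_op_as_object were wired up:
#   flip = eager.Flip(max_batch_size=8, device_id=0)  # no backend op yet
#   out = flip(tensor_list)                           # first call builds, then runs
```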
```diff
+def _expose_eager_op_as_object(op_class, submodule):
+    """ Exposes eager operators as objects. Can be used if we decide to change the eager API
+    from functional to objective.
+    """
+    op_name = op_class.schema_name
+    module = _internal.get_submodule('nvidia.dali.experimental.eager', submodule)
+    op = _eager_op_object_factory(op_class, op_name)
+    setattr(module, op_name, op)
+
+
 def _eager_op_base_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperatorBase(op_class):
         def __init__(self, *, max_batch_size, device_id, **kwargs):
```
```diff
@@ -341,26 +426,55 @@ def __init__(self, *, max_batch_size, device_id, **kwargs):
             self._spec.AddArg('device_id', device_id)
             self._spec.AddArg('max_batch_size', max_batch_size)
 
             for i in range(num_inputs):
                 self._spec.AddInput(op_name + f'[{i}]', self._device)
 
             for arg_name in call_args_names:
                 self._spec.AddArgumentInput(arg_name, '')
+            num_outputs = self.schema.CalculateOutputs(
+                self._spec) + self.schema.CalculateAdditionalOutputs(self._spec)
 
-            if self._device == 'cpu':
-                self._backend_op = _b.EagerOperatorCPU(self._spec)
-            elif self._device == 'gpu':
-                self._backend_op = _b.EagerOperatorGPU(self._spec)
-            elif self._device == 'mixed':
-                self._backend_op = _b.EagerOperatorMixed(self._spec)
-            else:
-                raise ValueError(
-                    f"Incorrect device type '{self._device}' in eager operator '{op_name}'.")
+            self._backend_op = _create_backend_op(
+                self._spec, self._device, num_inputs, num_outputs, call_args_names, op_name)
 
     return EagerOperatorBase
 
 
-def _stateless_op_factory(op_class, op_name, num_inputs, call_args_names):
+def _create_module_class():
+    """ Creates a class imitating a module. Used for `rng_state` so we can have nested methods,
+    e.g. `rng_state.random.normal`.
+    """
+    class Module:
+        @classmethod
+        def _submodule(cls, name):
+            """ Returns the submodule; creates a new one if it does not exist. """
+            if name not in cls._submodules:
+                # Register a new submodule class (the object representing the submodule will be
+                # created in rng_state's constructor).
+                cls._submodules[name] = _create_state_submodule(name)
+
+            return cls._submodules[name]
+
+        _submodules = {}
+
+    return Module
+
+
+def _create_state_submodule(name):
+    """ Creates a class imitating a submodule. It can contain methods and nested submodules.
+    Used for submodules of rng_state, e.g. `rng_state.random`, `rng_state.noise`.
+    """
+    class StateSubmodule(_create_module_class()):
+        def __init__(self, operator_cache, seed_generator):
+            self._operator_cache = operator_cache
+            self._seed_generator = seed_generator
+
+            for name, submodule_class in StateSubmodule._submodules.items():
+                # Adds nested submodules.
+                setattr(self, name, submodule_class(self._operator_cache, self._seed_generator))
+
+        __name__ = name
+
+    return StateSubmodule
+
+
+def _callable_op_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)):
         def __call__(self, inputs, kwargs):
             # Here all kwargs are supposed to be TensorLists.
```
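The two factories above fake a module hierarchy with classes so that a single `rng_state` object can expose nested calls such as `rng_state.random.normal`. A self-contained toy of the same idea (illustrative, not DALI code): submodule objects are plain instances that share the parent's operator cache and seed generator, so nested calls reuse per-state resources.

```python
class SubmoduleToy:
    """Stand-in for StateSubmodule: shares state owned by the parent object."""

    def __init__(self, operator_cache, seed_generator):
        self._operator_cache = operator_cache
        self._seed_generator = seed_generator


class RngStateToy:
    def __init__(self):
        self._operator_cache = {}                 # shared with every submodule
        self._seed_generator = iter(range(1000))  # stand-in seed source
        self.random = SubmoduleToy(self._operator_cache, self._seed_generator)
        self.noise = SubmoduleToy(self._operator_cache, self._seed_generator)


state = RngStateToy()
# Both "submodules" see the same cache, like rng_state.random and rng_state.noise.
assert state.random._operator_cache is state.noise._operator_cache
```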
```diff
@@ -374,6 +488,13 @@ def __call__(self, inputs, kwargs):
     return EagerOperator
 
 
+_callable_op_factory.disqualified_arguments = {
+    'bytes_per_sample_hint',
+    'preserve',
+    'seed'
+}
+
+
 def _iterator_op_factory(op_class, op_name, num_inputs, call_args_names):
     class EagerOperator(_eager_op_base_factory(op_class, op_name, num_inputs, call_args_names)):
         def __init__(self, call_args, *, max_batch_size, **kwargs):
```
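Note the asymmetry with `_iterator_op_factory.disqualified_arguments` below: 'seed' is rejected for callable operators, presumably because a user-supplied seed has no meaning for stateless operators, while stateful operators receive a deterministically generated seed from `rng_state`'s own generator (see `_wrap_stateful` further down).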
```diff
@@ -425,6 +546,12 @@ def __len__(self):
     return EagerOperator
 
 
+_iterator_op_factory.disqualified_arguments = {
+    'bytes_per_sample_hint',
+    'preserve',
+}
+
+
 def _choose_device(op_name, wrapper_name, inputs, device_param):
     """Returns device type and device_id based on inputs and device_param."""
```
```diff
@@ -541,32 +668,57 @@ def _desc_call_args(inputs, args):
         [(key, value.dtype, value.layout(), len(value[0].shape())) for key, value in args.items()]))
 
 
+def _gen_cache_key(op_name, inputs, init_args, call_args):
+    """ Creates a cache key consisting of the operator name, a description of the inputs and
+    call arguments, and the init args. Each call arg is described by dtype, layout and dim.
+    """
+    return op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
```

Review comment on `_gen_cache_key`:

> Reviewer: 👍
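A worked example of the key format, with hypothetical values (the description string is simplified relative to what `_desc_call_args` actually produces):

```python
# Hypothetical cache key for a 'Flip' call; dtype/layout strings are illustrative.
op_name = 'Flip'
init_args = {'device_id': 0, 'max_batch_size': 8}
args_desc = ";in: (UINT8, 'HWC', 3)"  # per input/arg: dtype, layout, ndim
key = op_name + args_desc + str(sorted(init_args.items()))
print(key)
# Flip;in: (UINT8, 'HWC', 3)[('device_id', 0), ('max_batch_size', 8)]
```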
```diff
 def _wrap_stateless(op_class, op_name, wrapper_name):
     """Wraps stateless Eager Operator in a function. Callable the same way as functions in fn API,
     but directly with TensorLists.
     """
     def wrapper(*inputs, **kwargs):
         inputs, init_args, call_args = _prep_args(
-            inputs, kwargs, op_name, wrapper_name, _wrap_stateless.disqualified_arguments)
+            inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)
 
-        # Creating cache key consisting of operator name, description of inputs, input arguments
-        # and init args. Each call arg is described by dtype, layout and dim.
-        key = op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
+        key = _gen_cache_key(op_name, inputs, init_args, call_args)
 
         if key not in _stateless_operators_cache:
-            _stateless_operators_cache[key] = _stateless_op_factory(
+            _stateless_operators_cache[key] = _callable_op_factory(
                 op_class, wrapper_name, len(inputs), call_args.keys())(**init_args)
 
         return _stateless_operators_cache[key](inputs, call_args)
 
     return wrapper
 
 
-_wrap_stateless.disqualified_arguments = {
-    'bytes_per_sample_hint',
-    'preserve',
-    'seed'
-}
```
```diff
+def _wrap_stateful(op_class, op_name, wrapper_name):
+    """Wraps a stateful Eager Operator as a method of a class. Callable the same way as
+    functions in the fn API, but directly with TensorLists.
+    """
+    def wrapper(self, *inputs, **kwargs):
+        inputs, init_args, call_args = _prep_args(
+            inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)
+
+        key = _gen_cache_key(op_name, inputs, init_args, call_args)
+
+        if key not in self._operator_cache:
+            # Create a new operator instance with a deterministically generated seed, so that
+            # if we preserve the order of operator calls in different instances of rng_state,
+            # they return the same results.
+            seed = self._seed_generator.integers(_wrap_stateful.seed_upper_bound)
+            self._operator_cache[key] = _callable_op_factory(
+                op_class, wrapper_name, len(inputs), call_args.keys())(**init_args, seed=seed)
+
+        return self._operator_cache[key](inputs, call_args)
+
+    return wrapper
+
+
+_wrap_stateful.seed_upper_bound = (1 << 31) - 1
```
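The determinism argument in the comment can be seen in isolation: two generators seeded identically hand out the same per-operator seeds in the same order, so two `rng_state` instances making the same sequence of calls build identically seeded operators. A minimal sketch, assuming the seed generator is a NumPy `Generator` (as the `.integers(...)` call suggests):

```python
import numpy as np

SEED_UPPER_BOUND = (1 << 31) - 1  # matches _wrap_stateful.seed_upper_bound

g1 = np.random.default_rng(1234)
g2 = np.random.default_rng(1234)

# Same call order -> same per-operator seeds -> same stateful results.
assert [g1.integers(SEED_UPPER_BOUND) for _ in range(3)] == \
       [g2.integers(SEED_UPPER_BOUND) for _ in range(3)]
```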
```diff
 def _wrap_iterator(op_class, op_name, wrapper_name):

@@ -582,7 +734,7 @@ def wrapper(*inputs, **kwargs):
             raise ValueError("Iterator type eager operators should not receive any inputs.")
 
         inputs, init_args, call_args = _prep_args(
-            inputs, kwargs, op_name, wrapper_name, _wrap_iterator.disqualified_arguments)
+            inputs, kwargs, op_name, wrapper_name, _iterator_op_factory.disqualified_arguments)
 
         op = _iterator_op_factory(op_class, wrapper_name, len(inputs),
                                   call_args.keys())(call_args, **init_args)
```
```diff
@@ -592,14 +744,39 @@ def wrapper(*inputs, **kwargs):
     return wrapper
 
 
-_wrap_iterator.disqualified_arguments = {
-    'bytes_per_sample_hint',
-    'preserve',
-}
+def _get_rng_state_target_module(submodules):
+    """ Returns the target module of rng_state. Creates the module if it did not exist. """
+    from nvidia.dali.experimental import eager
+
+    last_module = eager.rng_state
+    for cur_module_name in submodules:
+        # Registers rng_state's submodule if it does not exist yet.
+        cur_module = last_module._submodule(cur_module_name)
+        last_module = cur_module
+
+    return last_module
+
+
+def _get_eager_target_module(parent_module, submodules, make_hidden):
+    """ Returns the target module inside ``parent_module`` if specified, otherwise inside eager. """
+    if parent_module is None:
+        # Expose in the nvidia.dali.experimental.eager module.
+        parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
+    else:
+        # Expose in the experimental.eager submodule of the specified parent module.
+        parent_module = _internal.get_submodule(
+            sys.modules[parent_module], 'experimental.eager')
+
+    if make_hidden:
+        op_module = _internal.get_submodule(parent_module, submodules[:-1])
+    else:
+        op_module = _internal.get_submodule(parent_module, submodules)
+
+    return op_module
+
+
-def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc, make_hidden):
-    """Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
+def _wrap_eager_op(op_class, submodules, parent_module, wrapper_name, wrapper_doc, make_hidden):
+    """ Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
     Uses ``op_class`` for preprocessing inputs and keyword arguments and filling OpSpec for backend
     eager operators.
```
```diff
@@ -612,36 +789,31 @@ def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc
         wrapper_doc (str): Documentation of the wrapper function.
         make_hidden (bool): If operator is hidden, we should extract it from hidden submodule.
     """
     op_name = op_class.schema_name
     op_schema = _b.TryGetSchema(op_name)
-    if op_schema.IsDeprecated() or op_name in _excluded_operators or op_name in _stateful_operators:
-        # TODO(ksztenderski): For now only exposing stateless and iterator operators.
-        return
-    elif op_name in _iterator_operators:
-        wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
-    else:
-        # If operator is not stateful or a generator expose it as stateless.
-        wrapper = _wrap_stateless(op_class, op_name, wrapper_name)
-
-    if parent_module is None:
-        # Exposing to nvidia.dali.experimental.eager module.
-        parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
-    else:
-        # Exposing to experimental.eager submodule of the specified parent module.
-        parent_module = _internal.get_submodule(sys.modules[parent_module], 'experimental.eager')
-
-    if make_hidden:
-        op_module = _internal.get_submodule(parent_module, submodule[:-1])
-    else:
-        op_module = _internal.get_submodule(parent_module, submodule)
+    if op_schema.IsDeprecated() or op_name in _excluded_operators:
+        return
+    elif op_name in _stateful_operators:
+        wrapper = _wrap_stateful(op_class, op_name, wrapper_name)
+        op_module = _get_rng_state_target_module(submodules)
+    else:
+        if op_name in _iterator_operators:
+            wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
+        else:
+            # If operator is not stateful, a generator, deprecated or excluded, expose it as stateless.
+            wrapper = _wrap_stateless(op_class, op_name, wrapper_name)
+
+        op_module = _get_eager_target_module(parent_module, submodules, make_hidden)
 
     if not hasattr(op_module, wrapper_name):
         wrapper.__name__ = wrapper_name
         wrapper.__qualname__ = wrapper_name
         wrapper.__doc__ = wrapper_doc
-        wrapper._schema_name = op_schema
+        wrapper._schema_name = op_name
 
-        if submodule:
+        if submodules:
             wrapper.__module__ = op_module.__name__
 
         setattr(op_module, wrapper_name, wrapper)
```
Review thread:

> Reviewer: Here the inputs are already normalized to TensorListCPU/GPU, correct?
>
> Author: Yes.
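Taken together, the PR makes stateful operators callable eagerly through an `rng_state` object, with per-signature operator caching and deterministic seeding. A hypothetical usage sketch (constructor arguments and operator names are illustrative and not confirmed by this diff, beyond the `rng_state.random.normal` / `rng_state.noise` examples in the docstrings above):

```python
# Hypothetical usage of the API added in this PR; names are illustrative.
# import nvidia.dali.experimental.eager as eager
#
# rng = eager.rng_state(seed=42)           # owns an operator cache and a seed generator
# out1 = rng.random.normal(shape=[2, 2])   # stateful op, seeded deterministically
# out2 = rng.noise.gaussian(tensor_list)   # cached per (op, inputs, args) signature
```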