Add eager mode stateful operators #4016

Merged · 6 commits · Jun 30, 2022
Changes from 1 commit
5 changes: 3 additions & 2 deletions dali/python/nvidia/dali/_debug_mode.py
@@ -15,6 +15,7 @@
import inspect
import math
import traceback
import sys
import warnings
from queue import Queue

@@ -549,7 +550,7 @@ def __init__(self, exec_func, **kwargs):
import numpy as np
seed = kwargs.get('seed', -1)
if seed < 0:
seed = np.random.randint(0, 2**32)
seed = np.random.randint(sys.maxsize)
Contributor:
Hmm, wouldn't this be platform dependent? Maybe use iinfo and a fixed-size type?

Contributor Author:
I set it to a fixed value of (1 << 31) - 1.
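
For illustration only (not part of the diff), a minimal sketch of the platform-independent alternative discussed here, drawing seeds from a fixed-size type via ``np.iinfo``:

```python
import numpy as np

# Hypothetical helper: sys.maxsize varies across platforms, so draw seeds
# from a fixed 32-bit range instead.
def draw_seed(rng=None):
    max_seed = np.iinfo(np.int32).max  # (1 << 31) - 1 on every platform
    rng = rng if rng is not None else np.random.default_rng()
    return int(rng.integers(0, max_seed))
```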

self._seed_generator = np.random.default_rng(seed)

def __enter__(self):
@@ -625,7 +626,7 @@ def _create_op(self, op_class, op_name, key, cur_context, inputs, kwargs):
"""Creates direct operator."""
self._operators[key] = _OperatorManager(
op_class, op_name, self, cur_context, self._next_logical_id, self._max_batch_size,
self._device_id, self._seed_generator.integers(0, 2**32), inputs, kwargs)
self._device_id, self._seed_generator.integers(sys.maxsize), inputs, kwargs)

self._pipe.AddMultipleOperators(
self._operators[key].op_spec, self._operators[key].logical_ids)
117 changes: 74 additions & 43 deletions dali/python/nvidia/dali/_utils/eager_utils.py
@@ -212,7 +212,11 @@ def _arithm_op(name, *inputs):
categories_idxs, inputs, integers, reals = _ops._group_inputs(
inputs, edge_type=(_tensors.TensorListCPU, _tensors.TensorListGPU))
input_desc = _ops._generate_input_desc(categories_idxs, integers, reals)
device = _ops._choose_device(inputs)

if any(isinstance(input, _tensors.TensorListGPU) for input in inputs):
Contributor:
Here the inputs are already normalized to TensorListCPU/GPU, correct?

Contributor Author:
Yes.
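
As a side note, a hedged sketch (assuming a GPU is available; ``_as_gpu`` is the private helper used in this diff) of the mixed-device arithmetic that this device selection enables:

```python
import numpy as np
from nvidia.dali import tensors
from nvidia.dali.experimental import eager

with eager.arithmetic():
    tl_cpu = tensors.TensorListCPU(np.ones((8, 16), dtype=np.float32))
    tl_gpu = tl_cpu._as_gpu()
    # The CPU operand is promoted to GPU, so the result is a TensorListGPU.
    out = tl_cpu + tl_gpu
```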

device = 'gpu'
else:
device = 'cpu'

if device == "gpu":
inputs = list(input._as_gpu() if isinstance(
@@ -334,11 +338,14 @@ def _rxor(self, other):


def _create_backend_op(spec, device, num_inputs, num_outputs, call_args_names, op_name):
inp_device = 'cpu' if device == 'mixed' else device
out_device = 'gpu' if device == 'mixed' else device

for i in range(num_inputs):
spec.AddInput(op_name + f'[{i}]', device)
spec.AddInput(op_name + f'[{i}]', inp_device)

for i in range(num_outputs):
spec.AddOutput(op_name + f'_out[{i}]', device)
spec.AddOutput(op_name + f'_out[{i}]', out_device)

for arg_name in call_args_names:
spec.AddArgumentInput(arg_name, '')
@@ -357,6 +364,9 @@ def _create_backend_op(spec, device, num_inputs, num_outputs, call_args_names, o


def _eager_op_object_factory(op_class, op_name):
Contributor:
I guess this could be grouped with the _expose part, marked as unused, and have its purpose described. Maybe make it a section with a "block" comment, or a separate file?

Contributor Author:
I grouped it together and added comments about it being unused. I don't know about moving it to a separate file; I think it's fine like this.
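
For reference, a hedged usage sketch of the object-style API this factory produces (mirroring the ``test_objective_eager_resize`` test added later in this PR):

```python
import numpy as np
from nvidia.dali import ops, tensors
from nvidia.dali._utils import eager_utils

# Build an object-style eager operator class (currently not exposed publicly).
EagerResize = eager_utils._eager_op_object_factory(ops.python_op_factory('Resize'), 'Resize')

tl = tensors.TensorListCPU(np.zeros((8, 200, 200, 3), dtype=np.uint8))
resize = EagerResize(resize_x=50, resize_y=50)  # init args, as in the ops API
out = resize(tl)                                # called directly with TensorLists
```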

""" Creates eager operator class to use with objective ops-like API. For completeness,
currently not used.
"""
class EagerOperator(op_class):
def __init__(self, **kwargs):
self._batch_size = getattr(kwargs, 'batch_size', -1)
@@ -397,6 +407,17 @@ def __call__(self, *inputs, **kwargs):
return EagerOperator


def _expose_eager_op_as_object(op_class, submodule):
""" Exposes eager operators as objects. Can be used if we decide to change eager API from
functional to objective.
"""

op_name = op_class.schema_name
module = _internal.get_submodule('nvidia.dali.experimental.eager', submodule)
op = _eager_op_object_factory(op_class, op_name)
setattr(module, op_name, op)


def _eager_op_base_factory(op_class, op_name, num_inputs, call_args_names):
class EagerOperatorBase(op_class):
def __init__(self, *, max_batch_size, device_id, **kwargs):
@@ -647,6 +668,13 @@ def _desc_call_args(inputs, args):
[(key, value.dtype, value.layout(), len(value[0].shape())) for key, value in args.items()]))


def _gen_cache_key(op_name, inputs, init_args, call_args):
Contributor:
👍
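
A hedged sketch of the caching pattern that ``_gen_cache_key`` (just below) feeds into; the helper names here are illustrative, see ``_wrap_stateless`` further down for the real wiring:

```python
_cache = {}

def call_cached(op_name, inputs, init_args, call_args, make_op):
    # The key combines the operator name, a description of the inputs/call args
    # (dtype, layout, dim) and the init args, so compatible calls reuse one backend op.
    key = _gen_cache_key(op_name, inputs, init_args, call_args)
    if key not in _cache:
        _cache[key] = make_op(**init_args)
    return _cache[key](*inputs, **call_args)
```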

""" Creating cache key consisting of operator name, description of inputs, input arguments
and init args. Each call arg is described by dtype, layout and dim.
"""
return op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))


def _wrap_stateless(op_class, op_name, wrapper_name):
"""Wraps stateless Eager Operator in a function. Callable the same way as functions in fn API,
but directly with TensorLists.
@@ -655,9 +683,7 @@ def wrapper(*inputs, **kwargs):
inputs, init_args, call_args = _prep_args(
inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)

# Creating cache key consisting of operator name, description of inputs, input arguments
# and init args. Each call arg is described by dtype, layout and dim.
key = op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
key = _gen_cache_key(op_name, inputs, init_args, call_args)

if key not in _stateless_operators_cache:
_stateless_operators_cache[key] = _callable_op_factory(
@@ -677,10 +703,13 @@ def wrapper(self, *inputs, **kwargs):
inputs, init_args, call_args = _prep_args(
inputs, kwargs, op_name, wrapper_name, _callable_op_factory.disqualified_arguments)

key = op_name + _desc_call_args(inputs, call_args) + str(sorted(init_args.items()))
key = _gen_cache_key(op_name, inputs, init_args, call_args)

if key not in self._operator_cache:
seed = self._seed_generator.integers(2**32)
# Creating a new operator instance with deterministically generated seed, so if we
# preserve the order of operator calls in different instances of rng_state, they
# return the same results.
seed = self._seed_generator.integers(sys.maxsize)
self._operator_cache[key] = _callable_op_factory(
op_class, wrapper_name, len(inputs), call_args.keys())(**init_args, seed=seed)

@@ -712,18 +741,39 @@ def wrapper(*inputs, **kwargs):
return wrapper


def _expose_eager_op_as_object(op_class, submodule):
""" Exposes eager operators as objects. Can be used if we decide to change eager API from
functional to objective."""
def _get_rng_state_target_module(submodules):
""" Returns target module of rng_state. If a module did not exist, creates it. """
from nvidia.dali.experimental import eager

op_name = op_class.schema_name
module = _internal.get_submodule('nvidia.dali.experimental.eager', submodule)
op = _eager_op_object_factory(op_class, op_name)
setattr(module, op_name, op)
last_module = eager.rng_state
for cur_module_name in submodules:
# If nonexistent registers rng_state's submodule.
cur_module = last_module._submodule(cur_module_name)
last_module = cur_module

return last_module


def _get_eager_target_module(parent_module, submodules, make_hidden):
""" Returns target module inside ``parent_module`` if specified, otherwise inside eager. """
if parent_module is None:
# Exposing to nvidia.dali.experimental.eager module.
parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
else:
# Exposing to experimental.eager submodule of the specified parent module.
parent_module = _internal.get_submodule(
sys.modules[parent_module], 'experimental.eager')

if make_hidden:
op_module = _internal.get_submodule(parent_module, submodules[:-1])
else:
op_module = _internal.get_submodule(parent_module, submodules)

return op_module

def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc, make_hidden):
"""Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).

def _wrap_eager_op(op_class, submodules, parent_module, wrapper_name, wrapper_doc, make_hidden):
""" Exposes eager operator to the appropriate module (similar to :func:`nvidia.dali.fn._wrap_op`).
Uses ``op_class`` for preprocessing inputs and keyword arguments and filling OpSpec for backend
eager operators.

@@ -736,50 +786,31 @@ def _wrap_eager_op(op_class, submodule, parent_module, wrapper_name, wrapper_doc
wrapper_doc (str): Documentation of the wrapper function.
make_hidden (bool): If operator is hidden, we should extract it from hidden submodule.
"""
from nvidia.dali.experimental import eager

op_name = op_class.schema_name
op_schema = _b.TryGetSchema(op_name)

if op_name in _stateful_operators:

if op_schema.IsDeprecated() or op_name in _excluded_operators:
return
elif op_name in _stateful_operators:
wrapper = _wrap_stateful(op_class, op_name, wrapper_name)
last_module = eager.rng_state
for cur_module_name in submodule:
# If nonexistent registers rng_state's submodule.
cur_module = last_module._submodule(cur_module_name)
last_module = cur_module

op_module = last_module
op_module = _get_rng_state_target_module(submodules)
else:
if op_schema.IsDeprecated() or op_name in _excluded_operators:
return
elif op_name in _iterator_operators:
if op_name in _iterator_operators:
wrapper = _wrap_iterator(op_class, op_name, wrapper_name)
else:
# If operator is not stateful, generator, deprecated or excluded expose it as stateless.
wrapper = _wrap_stateless(op_class, op_name, wrapper_name)

if parent_module is None:
# Exposing to nvidia.dali.experimental.eager module.
parent_module = _internal.get_submodule('nvidia.dali', 'experimental.eager')
else:
# Exposing to experimental.eager submodule of the specified parent module.
parent_module = _internal.get_submodule(
sys.modules[parent_module], 'experimental.eager')

if make_hidden:
op_module = _internal.get_submodule(parent_module, submodule[:-1])
else:
op_module = _internal.get_submodule(parent_module, submodule)
op_module = _get_eager_target_module(parent_module, submodules, make_hidden)

if not hasattr(op_module, wrapper_name):
wrapper.__name__ = wrapper_name
wrapper.__qualname__ = wrapper_name
wrapper.__doc__ = wrapper_doc
wrapper._schema_name = op_name

if submodule:
if submodules:
wrapper.__module__ = op_module.__name__

setattr(op_module, wrapper_name, wrapper)
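
To summarize what these wrappers expose, a hedged usage sketch of the resulting functional eager API (mirroring the tests further down in this PR):

```python
import numpy as np
from nvidia.dali import tensors
from nvidia.dali.experimental import eager

tl = tensors.TensorListCPU(np.zeros((8, 200, 200, 3), dtype=np.uint8))

# Stateless operators are plain functions, called directly with TensorLists.
out = eager.resize(tl, resize_x=50, resize_y=50)

# Stateful (random) operators live on an rng_state instance, which caches them
# and hands out deterministic per-operator seeds.
state = eager.rng_state(seed=42)
noisy = state.noise.gaussian(out)
normal = state.random.normal(shape=[5, 5], batch_size=8)
```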
16 changes: 13 additions & 3 deletions dali/python/nvidia/dali/experimental/eager/__init__.py
@@ -54,6 +54,7 @@ class arithmetic(metaclass=_MetaArithmetic):
>>> tl = dali.tensors.TensorListCPU(...)
>>> out = tl ** 2
"""

def __init__(self, enabled=True):
self.prev = arithmetic._enabled
arithmetic._enabled = enabled
@@ -73,8 +74,17 @@ def __exit__(self, type, value, traceback):


class rng_state(_create_module_class()):
""" Manager class for stateful operators. Methods of this class correspond to the appropriate
functions in the fn API, they are created by :func:`_wrap_stateful` and are added dynamically.
""" Manager class for stateful operators. This object holds a cache of reusable operators.
Operators are initialized with deterministic seeds generated according to the ``seed`` argument
and are reused when you call the same operator with the same scalar parameters.

Example:
>>> eager_state = dali.experimental.eager.rng_state(seed=42)
>>> out1 = eager_state.random.normal(shape=[5, 5], batch_size=8)
>>> # Here we will reuse the same operator.
>>> out2 = eager_state.random.normal(shape=[5, 5], batch_size=8)
>>> # And now we will create a new operator with new seed.
>>> out3 = eager_state.random.normal(shape=[10, 10], batch_size=8)
"""

def __init__(self, seed=None):
@@ -87,4 +97,4 @@ def __init__(self, seed=None):
# Create attributes imitating submodules, e.g. `random`, `noise`.
setattr(self, name, submodule_class(self._operator_cache, self._seed_generator))

__name__ = 'rng_state'
__name__ = 'rng_state'
15 changes: 9 additions & 6 deletions dali/test/python/test_eager_coverage.py
@@ -15,12 +15,13 @@
import numpy as np
import os
import re
import sys
from functools import reduce

import nvidia.dali.experimental.eager as eager
import nvidia.dali.fn as fn
import nvidia.dali.tensors as tensors
import nvidia.dali.types as types
from nvidia.dali import fn as fn
from nvidia.dali import tensors as tensors
from nvidia.dali import types as types
from nvidia.dali.experimental import eager as eager
from nvidia.dali.pipeline import Pipeline, pipeline_def
from nvidia.dali._utils.eager_utils import _slice_tensorlist
from test_dali_cpu_only_utils import (pipeline_arithm_ops_cpu, setup_test_nemo_asr_reader_cpu,
@@ -215,8 +216,10 @@ def check_no_input(op_path, *, fn_op=None, eager_op=None, batch_size=batch_size,


def prep_stateful_operators(op_path):
seed = rng.integers(2048)
fn_seed = np.random.default_rng(seed).integers(2**32)
# Replicating seed that will be used inside rng_state, that way we expect fn and eager
# operators to return same results.
seed = rng.integers(sys.maxsize)
fn_seed = np.random.default_rng(seed).integers(sys.maxsize)
eager_state = eager.rng_state(seed)

fn_op, eager_op = get_ops(op_path, eager_module=eager_state)
83 changes: 81 additions & 2 deletions dali/test/python/test_eager_operators.py
@@ -13,10 +13,15 @@
# limitations under the License.

import numpy as np
import os

import nvidia.dali.experimental.eager as eager
import nvidia.dali.tensors as tensors
from nvidia.dali import fn
from nvidia.dali import pipeline_def
from nvidia.dali import ops as ops
from nvidia.dali import tensors as tensors
from nvidia.dali.experimental import eager as eager
from nose_utils import assert_raises, raises
from test_utils import get_dali_extra_path


@raises(RuntimeError, glob=f"Argument '*' is not supported by eager operator 'crop'.")
@@ -87,3 +92,77 @@ def test_arithm_op_context_manager_deep_nested():

assert np.array_equal((tl_1 + tl_2).as_array(), expected_sum)
eager.arithmetic(False)


def test_identical_rng_states():
eager_state_1 = eager.rng_state(seed=42)
eager_state_2 = eager.rng_state(seed=42)

out_1_1 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)
out_1_2 = eager_state_1.noise.gaussian(out_1_1)
out_1_3 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)

out_2_1 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)
out_2_2 = eager_state_2.noise.gaussian(out_2_1)
out_2_3 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)

assert np.allclose(out_1_1.as_tensor(), out_2_1.as_tensor())
assert np.allclose(out_1_2.as_tensor(), out_2_2.as_tensor())
assert np.allclose(out_1_3.as_tensor(), out_2_3.as_tensor())


def test_identical_rng_states_interleaved():
eager_state_1 = eager.rng_state(seed=42)
eager_state_2 = eager.rng_state(seed=42)

out_1_1 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)
eager_state_1.random.normal(shape=[6, 6], batch_size=8)
eager_state_1.noise.gaussian(out_1_1)
out_1_2 = eager_state_1.random.normal(shape=[5, 5], batch_size=8)

out_2_1 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)
out_2_2 = eager_state_2.random.normal(shape=[5, 5], batch_size=8)

assert np.allclose(out_1_1.as_tensor(), out_2_1.as_tensor())
assert np.allclose(out_1_2.as_tensor(), out_2_2.as_tensor())


def test_objective_eager_resize():
Contributor:
Please add a comment that this tests the hidden functionality of exposing the Eager Ops classes.

Contributor Author:
Done.

from nvidia.dali._utils import eager_utils

resize_class = eager_utils._eager_op_object_factory(ops.python_op_factory('Resize'), 'Resize')
tl = tensors.TensorListCPU(np.random.default_rng().integers(
256, size=(8, 200, 200, 3), dtype=np.uint8))

obj_resize = resize_class(resize_x=50, resize_y=50)
out_obj = obj_resize(tl)
out_fun = eager.resize(tl, resize_x=50, resize_y=50)

assert np.array_equal(out_obj.as_tensor(), out_fun.as_tensor())


@pipeline_def(num_threads=3, device_id=0)
def mixed_image_decoder_pipeline(file_root, seed):
jpeg, _ = fn.readers.file(file_root=file_root, seed=seed)
out = fn.decoders.image(jpeg, device="mixed")

return out


def test_mixed_devices_decoder():
seed = 42
batch_size = 8
file_root = os.path.join(get_dali_extra_path(), 'db/single/jpeg')

pipe = mixed_image_decoder_pipeline(file_root, seed, batch_size=batch_size)
pipe.build()
pipe_out, = pipe.run()

jpeg, _ = next(eager.readers.file(file_root=file_root, batch_size=batch_size, seed=seed))
eager_out = eager.decoders.image(jpeg, device="gpu")

assert len(pipe_out) == len(eager_out)

with eager.arithmetic():
for comp_tensor in (pipe_out == eager_out):
assert np.all(comp_tensor.as_cpu())