Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TF 1.15 and 2.0 support for TF dataset #1395

Merged
merged 15 commits into from
Oct 29, 2019
Merged
48 changes: 36 additions & 12 deletions dali/python/nvidia/dali/plugin/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import glob
from collections import Iterable
import re
from distutils.version import StrictVersion


_tf_plugins = glob.glob(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'libdali_tf*.so'))
Expand All @@ -43,6 +44,14 @@

_dali_tf = _dali_tf_module.dali

_dali_tf.__doc__ = _dali_tf.__doc__ + """

WARNING:
-------
Please keep in mind that TensorFlow allocates almost all available device memory by default. This might cause errors in
DALI due to insufficient memory.
JanuszL marked this conversation as resolved.
Show resolved Hide resolved
"""

def DALIIteratorWrapper(pipeline = None, serialized_pipeline = None, sparse = [],
shapes = [], dtypes = [], batch_size = -1, prefetch_queue_depth = 2, **kwargs):
"""
Expand Down Expand Up @@ -117,11 +126,11 @@ def DALIRawIterator():
return _dali_tf


def _get_tf_minor_version():
return tf.__version__.split('.')[1]
def _get_tf_version():
return StrictVersion(tf.__version__)


if _get_tf_minor_version() in {'13', '14'}:
if _get_tf_version() >= StrictVersion('1.13'):
from tensorflow.python.framework import ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import structure
Expand Down Expand Up @@ -160,12 +169,19 @@ def __init__(
self._structure = structure.convert_legacy_structure(
self._dtypes, self._shapes, output_classes)

if _get_tf_minor_version() == '14':
if _get_tf_version() >= StrictVersion('1.14'):
super(_DALIDatasetV2, self).__init__(self._as_variant_tensor())
elif _get_tf_minor_version() == '13':
elif _get_tf_version() >= StrictVersion('1.13') and _get_tf_version() < StrictVersion('1.14'):
awolant marked this conversation as resolved.
Show resolved Hide resolved
super(_DALIDatasetV2, self).__init__()
else:
raise RuntimeError('Unsupported TensorFlow version detected at runtime. DALIDataset supports versions: 1.13, 1.14')
raise RuntimeError('Unsupported TensorFlow version detected at runtime. DALIDataset supports versions: 1.13, 1.14, 1.15, 2.0')
awolant marked this conversation as resolved.
Show resolved Hide resolved


# This function should not be removed or refactored.
# It is needed for TF 1.15 and 2.0
@property
def element_spec(self):
return self._structure


@property
Expand All @@ -189,11 +205,14 @@ def _as_variant_tensor(self):
dtypes = self._dtypes)


class DALIDataset(dataset_ops.DatasetV1Adapter):
@functools.wraps(_DALIDatasetV2.__init__)
def __init__(self, **kwargs):
wrapped = _DALIDatasetV2(**kwargs)
super(DALIDataset, self).__init__(wrapped)
if _get_tf_version() < StrictVersion('2.0'):
class DALIDataset(dataset_ops.DatasetV1Adapter):
@functools.wraps(_DALIDatasetV2.__init__)
def __init__(self, **kwargs):
wrapped = _DALIDatasetV2(**kwargs)
super(DALIDataset, self).__init__(wrapped)
else:
DALIDataset = _DALIDatasetV2

else:
class DALIDataset:
Expand All @@ -211,7 +230,12 @@ def __init__(
dtypes = []):
raise RuntimeError('DALIDataset is not supported for detected version of TensorFlow.')

DALIDataset.__doc__ = """Creates a `DALIDataset` compatible with tf.data.Dataset from a DALI pipeline. It supports TensorFlow 1.13 and 1.14
DALIDataset.__doc__ = """Creates a `DALIDataset` compatible with tf.data.Dataset from a DALI pipeline. It supports TensorFlow 1.13, 1.14, 1.15 and 2.0

WARNING:
-------
Please keep in mind that TensorFlow allocates almost all available device memory by default. This might cause errors in
DALI due to insufficient memory.
JanuszL marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
Expand Down
22 changes: 12 additions & 10 deletions dali/test/python/test_dali_tf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,28 @@
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from distutils.version import StrictVersion

from nose import SkipTest
from nose.tools import raises

try:
tf.compat.v1.disable_eager_execution()
except:
pass

test_data_root = os.environ['DALI_EXTRA_PATH']
file_root = os.path.join(test_data_root, 'db', 'coco', 'images')
annotations_file = os.path.join(test_data_root, 'db', 'coco', 'instances.json')


def tensorflow_minor_version():
return tf.__version__.split('.')[1]


def compatible_tensorflow():
return tensorflow_minor_version() in {'13', '14'}
return StrictVersion(tf.__version__) >= StrictVersion('1.13')


def skip_for_incompatible_tf():
if not compatible_tensorflow():
raise SkipTest('This feature is enabled for TF 1.13 and 1.14 only')
raise SkipTest('This feature is enabled for TF 1.13 and higher')


def num_available_gpus():
Expand Down Expand Up @@ -103,10 +105,10 @@ def _dataset_options():
try:
options.experimental_optimization.apply_default_optimizations = False

if tensorflow_minor_version() == '14':
options.experimental_optimization.autotune = False
elif tensorflow_minor_version() == '13':
options.experimental_autotune = False
if StrictVersion(tf.__version__) >= StrictVersion('1.13') and StrictVersion(tf.__version__) < StrictVersion('1.14'):
options.experimental_autotune = False
awolant marked this conversation as resolved.
Show resolved Hide resolved
else:
options.experimental_optimization.autotune = False
except:
print('Could not set TF Dataset Options')

Expand Down
2 changes: 1 addition & 1 deletion dali_tf_plugin/dali_dataset_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "tensorflow/core/public/version.h"

#if TF_MAJOR_VERSION == 1 && TF_MINOR_VERSION >= 13
#if TF_MAJOR_VERSION == 2 || (TF_MAJOR_VERSION == 1 && TF_MINOR_VERSION >= 13)

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wreorder"
Expand Down