From 749bded360dc33ec44c8c68bf3c4138dda09ed2f Mon Sep 17 00:00:00 2001 From: Vandana Kannan Date: Thu, 11 Jun 2020 17:31:15 -0700 Subject: [PATCH] TF 2.x: Support for keras to estimator (#268) --- docs/sagemaker.md | 2 +- docs/tensorflow.md | 15 ++++--- smdebug/tensorflow/collection.py | 3 -- smdebug/tensorflow/keras.py | 5 +++ tests/tensorflow2/test_estimator.py | 61 +++++++++++++++++++++++++++++ tests/tensorflow2/test_keras.py | 45 +++++++++++++++++++++ tests/zero_code_change/tf_utils.py | 3 +- 7 files changed, 124 insertions(+), 10 deletions(-) create mode 100644 tests/tensorflow2/test_estimator.py diff --git a/docs/sagemaker.md b/docs/sagemaker.md index 01c03568b..e2968f8d8 100644 --- a/docs/sagemaker.md +++ b/docs/sagemaker.md @@ -27,7 +27,7 @@ Here's a list of frameworks and versions which support this experience. | Framework | Version | | --- | --- | -| [TensorFlow](tensorflow.md) | 1.15, 2.1 | +| [TensorFlow](tensorflow.md) | 1.15, 2.1, 2.2 | | [MXNet](mxnet.md) | 1.6 | | [PyTorch](pytorch.md) | 1.4, 1.5 | | [XGBoost](xgboost.md) | >=0.90-2 [As Built-in algorithm](xgboost.md#use-xgboost-as-a-built-in-algorithm)| diff --git a/docs/tensorflow.md b/docs/tensorflow.md index c376c81e5..968dd5230 100644 --- a/docs/tensorflow.md +++ b/docs/tensorflow.md @@ -15,17 +15,22 @@ ### Versions - Zero Script Change experience where you need no modifications to your training script is supported in the official [SageMaker Framework Container for TensorFlow 1.15](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html), or the [AWS Deep Learning Container for TensorFlow 1.15](https://aws.amazon.com/machine-learning/containers/). -- This library itself supports the following versions when you use our API which requires a few minimal changes to your training script: TensorFlow 1.14, 1.15, 2.0.1, 2.1.0. Keras 2.3. +- This library itself supports the following versions when you use our API which requires a few minimal changes to your training script: TensorFlow 1.14, 1.15, 2.0+. Keras 2.3. ### Interfaces -- [Estimator](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator) -- [tf.keras](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras) -- [MonitoredSession](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/MonitoredSession?hl=en) +- TF 1.x: + - [Estimator](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator) + - [tf.keras](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras) + - [MonitoredSession](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/MonitoredSession?hl=en) +- TF 2.x: + - [Estimator](https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/estimator) + - [tf.keras](https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras) + ### Distributed training - [MirroredStrategy](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/distribute/MirroredStrategy) or [Contrib MirroredStrategy](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/distribute/MirroredStrategy) -We will very quickly follow up with support for Horovod and Parameter Server based training. +We will very quickly follow up with support for Parameter Server based training. --- diff --git a/smdebug/tensorflow/collection.py b/smdebug/tensorflow/collection.py index 89df3916a..32e31da70 100644 --- a/smdebug/tensorflow/collection.py +++ b/smdebug/tensorflow/collection.py @@ -148,9 +148,6 @@ def __init__(self, collections=None, create_default=True): self.create_collection(n) if is_tf_version_2x() and tf.executing_eagerly(): self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias") - self.get(CollectionKeys.WEIGHTS).include("^weights/.*/((?!bias).)*$") - self.get(CollectionKeys.LOSSES).include(".*loss.*") - self.get(CollectionKeys.GRADIENTS).include("^gradient") else: self.get(CollectionKeys.BIASES).include("bias") diff --git a/smdebug/tensorflow/keras.py b/smdebug/tensorflow/keras.py index fea005a31..31ec61383 100644 --- a/smdebug/tensorflow/keras.py +++ b/smdebug/tensorflow/keras.py @@ -716,6 +716,11 @@ def run(*args, **kwargs): # at this point we need all collections to be ready # this may not be the case at creation of hook # as user's code after hook might add collections + self.collection_manager.get(CollectionKeys.WEIGHTS).include( + "^weights/.*/((?!bias).)*$" + ) + self.collection_manager.get(CollectionKeys.LOSSES).include(".*loss.*") + self.collection_manager.get(CollectionKeys.GRADIENTS).include("^gradient") self._prepare_collections() self.prepared_collections = True diff --git a/tests/tensorflow2/test_estimator.py b/tests/tensorflow2/test_estimator.py new file mode 100644 index 000000000..05c2691b4 --- /dev/null +++ b/tests/tensorflow2/test_estimator.py @@ -0,0 +1,61 @@ +# Standard Library +# Third Party +import pytest +import tensorflow.compat.v2 as tf +from tests.zero_code_change.tf_utils import get_estimator, get_input_fns + +# First Party +import smdebug.tensorflow as smd + + +@pytest.mark.parametrize("saveall", [True, False]) +def test_estimator(out_dir, tf_eager_mode, saveall): + """ Works as intended. """ + if tf_eager_mode is False: + tf.compat.v1.disable_eager_execution() + tf.compat.v1.reset_default_graph() + tf.keras.backend.clear_session() + mnist_classifier = get_estimator() + train_input_fn, eval_input_fn = get_input_fns() + + # Train and evaluate + train_steps, eval_steps = 8, 2 + hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall) + hook.set_mode(mode=smd.modes.TRAIN) + mnist_classifier.train(input_fn=train_input_fn, steps=train_steps, hooks=[hook]) + hook.set_mode(mode=smd.modes.EVAL) + mnist_classifier.evaluate(input_fn=eval_input_fn, steps=eval_steps, hooks=[hook]) + + # Check that hook created and tensors saved + trial = smd.create_trial(path=out_dir) + tnames = trial.tensor_names() + assert len(trial.steps()) > 0 + if saveall: + assert len(tnames) >= 301 + else: + assert len(tnames) == 1 + + +@pytest.mark.parametrize("saveall", [True, False]) +def test_linear_classifier(out_dir, tf_eager_mode, saveall): + """ Works as intended. """ + if tf_eager_mode is False: + tf.compat.v1.disable_eager_execution() + tf.compat.v1.reset_default_graph() + tf.keras.backend.clear_session() + train_input_fn, eval_input_fn = get_input_fns() + x_feature = tf.feature_column.numeric_column("x", shape=(28, 28)) + estimator = tf.estimator.LinearClassifier( + feature_columns=[x_feature], model_dir="/tmp/mnist_linear_classifier", n_classes=10 + ) + hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall) + estimator.train(input_fn=train_input_fn, steps=10, hooks=[hook]) + + # Check that hook created and tensors saved + trial = smd.create_trial(path=out_dir) + tnames = trial.tensor_names() + assert len(trial.steps()) > 0 + if saveall: + assert len(tnames) >= 224 + else: + assert len(tnames) == 2 diff --git a/tests/tensorflow2/test_keras.py b/tests/tensorflow2/test_keras.py index ae972126a..f296a7a5c 100644 --- a/tests/tensorflow2/test_keras.py +++ b/tests/tensorflow2/test_keras.py @@ -12,6 +12,7 @@ # Third Party import pytest import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds from tests.tensorflow2.utils import is_tf_2_2 from tests.tensorflow.utils import create_trial_fast_refresh @@ -649,3 +650,47 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode): assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2 assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2 assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5 + + +def test_keras_to_estimator(out_dir, tf_eager_mode): + if not tf_eager_mode: + tf.compat.v1.disable_eager_execution() + tf.compat.v1.reset_default_graph() + + tf.keras.backend.clear_session() + + model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(1, activation="sigmoid"), + ] + ) + + def input_fn(): + split = tfds.Split.TRAIN + dataset = tfds.load("iris", split=split, as_supervised=True) + dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels)) + dataset = dataset.batch(32).repeat() + return dataset + + model.compile(loss="categorical_crossentropy", optimizer="adam") + model.summary() + + keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=out_dir) + + hook = smd.EstimatorHook(out_dir) + + hook.set_mode(smd.modes.TRAIN) + keras_estimator.train(input_fn=input_fn, steps=25, hooks=[hook]) + + hook.set_mode(smd.modes.EVAL) + eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10, hooks=[hook]) + + from smdebug.trials import create_trial + + tr = create_trial(out_dir) + assert len(tr.tensor_names()) == 1 + assert len(tr.steps()) == 2 + assert len(tr.steps(smd.modes.TRAIN)) == 1 + assert len(tr.steps(smd.modes.EVAL)) == 1 diff --git a/tests/zero_code_change/tf_utils.py b/tests/zero_code_change/tf_utils.py index c7ff838ac..15b85aa6e 100644 --- a/tests/zero_code_change/tf_utils.py +++ b/tests/zero_code_change/tf_utils.py @@ -5,7 +5,6 @@ import numpy as np import tensorflow.compat.v1 as tf import tensorflow_datasets as tfds -from tensorflow.examples.tutorials.mnist import input_data tfds.disable_progress_bar() @@ -232,6 +231,8 @@ def neural_net(x): def get_data() -> "tf.contrib.learn.python.learn.datasets.base.Datasets": + from tensorflow.examples.tutorials.mnist import input_data + mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) return mnist