Skip to content

Commit

Permalink
TF 2.x: Support for keras to estimator (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
vandanavk committed Jun 12, 2020
1 parent 94acc66 commit 749bded
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/sagemaker.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Here's a list of frameworks and versions which support this experience.

| Framework | Version |
| --- | --- |
| [TensorFlow](tensorflow.md) | 1.15, 2.1 |
| [TensorFlow](tensorflow.md) | 1.15, 2.1, 2.2 |
| [MXNet](mxnet.md) | 1.6 |
| [PyTorch](pytorch.md) | 1.4, 1.5 |
| [XGBoost](xgboost.md) | >=0.90-2 [As Built-in algorithm](xgboost.md#use-xgboost-as-a-built-in-algorithm)|
Expand Down
15 changes: 10 additions & 5 deletions docs/tensorflow.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@
### Versions
- The Zero Script Change experience, where you need no modifications to your training script, is supported in the official [SageMaker Framework Container for TensorFlow 1.15](https://docs.aws.amazon.com/sagemaker/latest/dg/pre-built-containers-frameworks-deep-learning.html), or the [AWS Deep Learning Container for TensorFlow 1.15](https://aws.amazon.com/machine-learning/containers/).

- This library itself supports the following versions when you use our API which requires a few minimal changes to your training script: TensorFlow 1.14, 1.15, 2.0.1, 2.1.0. Keras 2.3.
- This library itself supports the following versions when you use our API, which requires a few minimal changes to your training script: TensorFlow 1.14, 1.15, 2.0+; Keras 2.3.

### Interfaces
- [Estimator](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator)
- [tf.keras](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras)
- [MonitoredSession](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/MonitoredSession?hl=en)
- TF 1.x:
- [Estimator](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator)
- [tf.keras](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/keras)
- [MonitoredSession](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/train/MonitoredSession?hl=en)
- TF 2.x:
- [Estimator](https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/estimator)
- [tf.keras](https://www.tensorflow.org/versions/r2.1/api_docs/python/tf/keras)


### Distributed training
- [MirroredStrategy](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/distribute/MirroredStrategy) or [Contrib MirroredStrategy](https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/distribute/MirroredStrategy)

We will very quickly follow up with support for Horovod and Parameter Server based training.
We will very quickly follow up with support for Parameter Server based training.

---

Expand Down
3 changes: 0 additions & 3 deletions smdebug/tensorflow/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,6 @@ def __init__(self, collections=None, create_default=True):
self.create_collection(n)
if is_tf_version_2x() and tf.executing_eagerly():
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
self.get(CollectionKeys.WEIGHTS).include("^weights/.*/((?!bias).)*$")
self.get(CollectionKeys.LOSSES).include(".*loss.*")
self.get(CollectionKeys.GRADIENTS).include("^gradient")
else:
self.get(CollectionKeys.BIASES).include("bias")

Expand Down
5 changes: 5 additions & 0 deletions smdebug/tensorflow/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,11 @@ def run(*args, **kwargs):
# at this point we need all collections to be ready
# this may not be the case at creation of hook
# as user's code after hook might add collections
self.collection_manager.get(CollectionKeys.WEIGHTS).include(
"^weights/.*/((?!bias).)*$"
)
self.collection_manager.get(CollectionKeys.LOSSES).include(".*loss.*")
self.collection_manager.get(CollectionKeys.GRADIENTS).include("^gradient")
self._prepare_collections()
self.prepared_collections = True

Expand Down
61 changes: 61 additions & 0 deletions tests/tensorflow2/test_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Standard Library
# Third Party
import pytest
import tensorflow.compat.v2 as tf
from tests.zero_code_change.tf_utils import get_estimator, get_input_fns

# First Party
import smdebug.tensorflow as smd


@pytest.mark.parametrize("saveall", [True, False])
def test_estimator(out_dir, tf_eager_mode, saveall):
    """Train and evaluate an estimator with an EstimatorHook attached.

    The hook is switched to TRAIN mode for training and to EVAL mode for
    evaluation. Afterwards a trial is opened on ``out_dir`` and the test
    checks that at least one step was recorded and that the number of saved
    tensor names matches the expectation for the ``saveall`` setting.
    """
    if tf_eager_mode is False:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.reset_default_graph()
    tf.keras.backend.clear_session()

    classifier = get_estimator()
    train_fn, eval_fn = get_input_fns()
    debug_hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)

    # Training phase: 8 steps with the hook in TRAIN mode.
    debug_hook.set_mode(mode=smd.modes.TRAIN)
    classifier.train(input_fn=train_fn, steps=8, hooks=[debug_hook])

    # Evaluation phase: 2 steps with the same hook in EVAL mode.
    debug_hook.set_mode(mode=smd.modes.EVAL)
    classifier.evaluate(input_fn=eval_fn, steps=2, hooks=[debug_hook])

    # Inspect what the hook wrote to out_dir.
    trial = smd.create_trial(path=out_dir)
    saved_names = trial.tensor_names()
    assert len(trial.steps()) > 0
    if not saveall:
        assert len(saved_names) == 1
    else:
        assert len(saved_names) >= 301


@pytest.mark.parametrize("saveall", [True, False])
def test_linear_classifier(out_dir, tf_eager_mode, saveall):
    """Train a canned LinearClassifier with an EstimatorHook attached.

    After training for 10 steps, a trial is opened on ``out_dir`` and the test
    checks that at least one step was recorded and that the number of saved
    tensor names matches the expectation for the ``saveall`` setting.
    """
    import tempfile

    if tf_eager_mode is False:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.reset_default_graph()
    tf.keras.backend.clear_session()
    # Only the training input function is needed here.
    train_input_fn, _ = get_input_fns()
    x_feature = tf.feature_column.numeric_column("x", shape=(28, 28))
    # Use a fresh model directory for every run. A fixed path would let a
    # checkpoint written by an earlier parametrized case (or an earlier test
    # session) be picked up by this estimator, coupling otherwise independent
    # test runs.
    model_dir = tempfile.mkdtemp(prefix="mnist_linear_classifier_")
    estimator = tf.estimator.LinearClassifier(
        feature_columns=[x_feature], model_dir=model_dir, n_classes=10
    )
    hook = smd.EstimatorHook(out_dir=out_dir, save_all=saveall)
    estimator.train(input_fn=train_input_fn, steps=10, hooks=[hook])

    # Check that hook created and tensors saved
    trial = smd.create_trial(path=out_dir)
    tnames = trial.tensor_names()
    assert len(trial.steps()) > 0
    if saveall:
        assert len(tnames) >= 224
    else:
        assert len(tnames) == 2
45 changes: 45 additions & 0 deletions tests/tensorflow2/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Third Party
import pytest
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
from tests.tensorflow2.utils import is_tf_2_2
from tests.tensorflow.utils import create_trial_fast_refresh

Expand Down Expand Up @@ -649,3 +650,47 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode):
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5


def test_keras_to_estimator(out_dir, tf_eager_mode):
    """Check EstimatorHook on an estimator converted from a Keras model.

    A small Sequential model is compiled, converted with
    ``tf.keras.estimator.model_to_estimator``, then trained and evaluated on
    the "iris" dataset with the hook attached. The test asserts that exactly
    one tensor name and two steps (one TRAIN, one EVAL) were recorded.
    """
    if not tf_eager_mode:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.reset_default_graph()

    tf.keras.backend.clear_session()

    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(1, activation="sigmoid"),
        ]
    )

    def input_fn():
        """Return a batched, endlessly repeating dataset of (features, labels)."""
        split = tfds.Split.TRAIN
        dataset = tfds.load("iris", split=split, as_supervised=True)
        # NOTE(review): "dense_input" presumably has to match the name of the
        # Keras model's input layer -- confirm before renaming layers.
        dataset = dataset.map(lambda features, labels: ({"dense_input": features}, labels))
        dataset = dataset.batch(32).repeat()
        return dataset

    model.compile(loss="categorical_crossentropy", optimizer="adam")
    model.summary()

    keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=out_dir)

    hook = smd.EstimatorHook(out_dir)

    hook.set_mode(smd.modes.TRAIN)
    keras_estimator.train(input_fn=input_fn, steps=25, hooks=[hook])

    hook.set_mode(smd.modes.EVAL)
    # The evaluation metrics are not inspected; only the hook's output matters.
    keras_estimator.evaluate(input_fn=input_fn, steps=10, hooks=[hook])

    from smdebug.trials import create_trial

    tr = create_trial(out_dir)
    assert len(tr.tensor_names()) == 1
    assert len(tr.steps()) == 2
    assert len(tr.steps(smd.modes.TRAIN)) == 1
    assert len(tr.steps(smd.modes.EVAL)) == 1
3 changes: 2 additions & 1 deletion tests/zero_code_change/tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
from tensorflow.examples.tutorials.mnist import input_data

tfds.disable_progress_bar()

Expand Down Expand Up @@ -232,6 +231,8 @@ def neural_net(x):


def get_data() -> "tf.contrib.learn.python.learn.datasets.base.Datasets":
    """Load the MNIST datasets from "/tmp/data/" with one-hot encoded labels."""
    # Imported inside the function rather than at module level, presumably so
    # that importing this module does not require the tutorials package
    # (TODO confirm it is unavailable under TF 2.x).
    from tensorflow.examples.tutorials.mnist import input_data

    return input_data.read_data_sets("/tmp/data/", one_hot=True)

Expand Down

0 comments on commit 749bded

Please sign in to comment.