-
Notifications
You must be signed in to change notification settings - Fork 82
/
test_keras_to_estimator.py
47 lines (36 loc) · 1.52 KB
/
test_keras_to_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# Third Party
import tensorflow as tf
import tensorflow_datasets as tfds
from tests.constants import TEST_DATASET_S3_PATH
from tests.utils import use_s3_datasets
# First Party
from smdebug.tensorflow import EstimatorHook, modes
def test_keras_to_estimator(out_dir):
    """Check that EstimatorHook records a Keras model converted to an Estimator.

    Builds a small classifier for the iris dataset and converts it with
    ``tf.keras.estimator.model_to_estimator``. It then runs a TRAIN phase
    (25 steps) and an EVAL phase (10 steps) with an ``EstimatorHook``
    attached, and inspects the trial written to ``out_dir``.

    Args:
        out_dir: Directory used both as the estimator's ``model_dir`` and as
            the hook's output location (supplied by the test fixture).
    """
    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Dense(16, activation="relu", input_shape=(4,)),
            tf.keras.layers.Dropout(0.2),
            # Iris labels are integer class ids for 3 classes. A single sigmoid
            # unit under categorical cross-entropy gives a constant ~0 loss
            # (the prediction is normalised to 1), so nothing would train.
            tf.keras.layers.Dense(3, activation="softmax"),
        ]
    )

    # Keras uniquifies auto-generated layer names with a session-wide counter.
    # The input is therefore only called "dense_input" when no other Dense layer
    # was created earlier in the process. model_to_estimator rejects feature
    # dicts whose keys don't match the model's input names, so read the real
    # name instead of hard-coding it.
    input_name = model.input_names[0]

    def input_fn():
        data_dir = TEST_DATASET_S3_PATH if use_s3_datasets() else None
        # Both training and evaluation read the TRAIN split.
        dataset = tfds.load(
            "iris", data_dir=data_dir, split=tfds.Split.TRAIN, as_supervised=True
        )
        dataset = dataset.map(lambda features, labels: ({input_name: features}, labels))
        return dataset.batch(32).repeat()

    # Integer labels (not one-hot), hence the sparse variant of the loss.
    model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
    model.summary()

    keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=out_dir)

    hook = EstimatorHook(out_dir)

    hook.set_mode(modes.TRAIN)
    keras_estimator.train(input_fn=input_fn, steps=25, hooks=[hook])

    hook.set_mode(modes.EVAL)
    # The evaluation result itself is not inspected; only the hook output is.
    keras_estimator.evaluate(input_fn=input_fn, steps=10, hooks=[hook])

    from smdebug.trials import create_trial

    tr = create_trial(out_dir)
    # Exactly one tensor was saved, at one step in each of TRAIN and EVAL.
    assert len(tr.tensor_names()) == 1
    assert len(tr.steps()) == 2
    assert len(tr.steps(modes.TRAIN)) == 1
    assert len(tr.steps(modes.EVAL)) == 1