-
Notifications
You must be signed in to change notification settings - Fork 82
/
test_concat_layer.py
43 lines (35 loc) · 1.34 KB
/
test_concat_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Third Party
import numpy as np
import pytest
from tensorflow.keras.layers import Concatenate, Dense
from tensorflow.python.keras.models import Model
from tests.tensorflow2.utils import is_tf_2_6
# First Party
import smdebug.tensorflow as smd
from smdebug.trials import create_trial
class MyModel(Model):
    """Tiny subclassed Keras model whose forward pass calls a layer with multiple inputs.

    The input tensor is passed twice, as a two-element list, to a ``Concatenate``
    layer. The result then goes through a 10-unit ReLU ``Dense`` layer.
    """

    def __init__(self):
        super().__init__()
        # Sub-layers are built once here and reused on every forward pass.
        self.con = Concatenate()
        self.dense = Dense(10, activation="relu")

    def call(self, x):
        """Concatenate ``x`` with itself, then project it with the dense layer."""
        joined = self.con([x, x])
        return self.dense(joined)
@pytest.mark.skipif(
    is_tf_2_6(), reason="Breaking Changes In TF 2.6.0 deprecates this feature"
)
def test_multiple_inputs(out_dir):
    """Check that smdebug saves the multiple inputs fed to a Concatenate layer.

    Trains ``MyModel`` for a single step with a ``KerasHook`` configured to save
    all tensors at step 0. It then reads the run back with ``create_trial`` and
    verifies that the tensor recorded for the Concatenate layer holds both of
    its inputs.

    Args:
        out_dir: pytest fixture giving the directory the hook writes to.
    """
    my_model = MyModel()
    hook = smd.KerasHook(
        out_dir, save_all=True, save_config=smd.SaveConfig(save_steps=[0], save_interval=1)
    )
    hook.register_model(my_model)

    x_train = np.random.random((1000, 20))
    y_train = np.random.random((1000, 1))
    # Eager mode is requested explicitly for this run.
    my_model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
    # A single training step. The shape assertion below expects that step to
    # cover the full 1000-sample array.
    my_model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])

    trial = create_trial(path=out_dir)
    tnames = sorted(trial.tensor_names(collection=smd.CollectionKeys.LAYERS))
    # Select the Concatenate tensor by name instead of assuming it sorts first
    # among every saved layer tensor.
    concat_names = [name for name in tnames if "concatenate" in name]
    assert concat_names, f"no concatenate tensor among saved layer tensors: {tnames}"
    concat_tensor = trial.tensor(concat_names[0])
    # The layer is called with [x, x], so two inputs of shape (1000, 20) are expected.
    assert len(concat_tensor.value(0)) == 2
    assert concat_tensor.shape(0) == (2, 1000, 20)