This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Fix tests
CyberZHG committed Sep 17, 2018
1 parent fc3017a commit 57bd1b2
Showing 2 changed files with 14 additions and 3 deletions.
Empty file.
tests/seq_weighted_attention/test_save_load.py (17 changes: 14 additions & 3 deletions)
@@ -17,9 +17,16 @@ def _test_save_load(self, attention):
         lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=7,
                                                             return_sequences=True),
                                           name='Bi-LSTM')(embd)
-        weighted, attention = attention(lstm)
-        dense = keras.layers.Dense(units=2, activation='softmax', name='Softmax')(weighted)
-        model = keras.models.Model(inputs=inputs, outputs=[dense, attention])
+        if attention.return_attention:
+            layer, weights = attention(lstm)
+        else:
+            layer = attention(lstm)
+        dense = keras.layers.Dense(units=2, activation='softmax', name='Softmax')(layer)
+        if attention.return_attention:
+            outputs = [dense, weights]
+        else:
+            outputs = dense
+        model = keras.models.Model(inputs=inputs, outputs=outputs)
         model.compile(
             optimizer='adam',
             loss={'Softmax': 'sparse_categorical_crossentropy'},
@@ -29,6 +36,10 @@ def _test_save_load(self, attention):
         model.save(model_path)
         model = keras.models.load_model(model_path, custom_objects=Attention.get_custom_objects())
         model.summary(line_length=100)
+        if attention.return_attention:
+            self.assertEqual(2, len(model.outputs))
+        else:
+            self.assertEqual(1, len(model.outputs))
 
     def test_default(self):
         self._test_save_load(Attention(name='Attention'))
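For context, a minimal sketch of the model the updated test builds. It assumes the layer under test is the repository's SeqWeightedAttention (imported as Attention in the test suite) and that the Input/Embedding layers not shown in the diff look roughly like this; the import path, vocabulary size, and embedding width are illustrative assumptions, not part of the commit.

    # Sketch only: the import path and the Input/Embedding configuration are
    # assumptions for illustration; the attention/dense/model wiring mirrors the test.
    import keras
    from keras_self_attention import SeqWeightedAttention as Attention  # assumed import

    attention = Attention(return_attention=True, name='Attention')

    inputs = keras.layers.Input(shape=(None,), name='Input')
    embd = keras.layers.Embedding(input_dim=100, output_dim=16,
                                  name='Embedding')(inputs)
    lstm = keras.layers.Bidirectional(keras.layers.LSTM(units=7, return_sequences=True),
                                      name='Bi-LSTM')(embd)

    # With return_attention=True the layer returns both the weighted sum over time
    # steps and the attention weights; with False it returns the weighted sum only.
    if attention.return_attention:
        layer, weights = attention(lstm)
    else:
        layer = attention(lstm)
    dense = keras.layers.Dense(units=2, activation='softmax', name='Softmax')(layer)
    outputs = [dense, weights] if attention.return_attention else dense

    model = keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='adam',
                  loss={'Softmax': 'sparse_categorical_crossentropy'})

After saving and reloading with Attention.get_custom_objects(), the new assertions simply confirm that the restored model still exposes two outputs in the return_attention case and one otherwise.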
