Use stricter tests
CyberZHG committed Jul 29, 2020
1 parent 39966fa commit f334154
Showing 3 changed files with 4 additions and 2 deletions.
keras_self_attention/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@
 from .seq_weighted_attention import SeqWeightedAttention
 from .scaled_dot_attention import ScaledDotProductAttention
 
-__version__ = '0.46.0'
+__version__ = '0.47.0'
keras_self_attention/seq_self_attention.py (2 changes: 1 addition & 1 deletion)
@@ -169,7 +169,7 @@ def call(self, inputs, mask=None, **kwargs):
             lower = K.expand_dims(lower, axis=-1)
             upper = lower + self.attention_width
             indices = K.expand_dims(K.arange(0, input_len), axis=0)
-            e -= 10000.0 * (1 - K.cast(indices >= lower, K.floatx()) * K.cast(upper > indices, K.floatx()))
+            e -= 10000.0 * (1.0 - K.cast(lower <= indices, K.floatx()) * K.cast(indices < upper, K.floatx()))
         if mask is not None:
             mask = K.expand_dims(K.cast(mask, K.floatx()), axis=-1)
             e -= 10000.0 * ((1.0 - mask) * (1.0 - K.permute_dimensions(mask, (0, 2, 1))))
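
The changed line implements the local attention window: positions outside [lower, upper) get a large negative bias so the subsequent softmax drives their weights to (numerically) zero. Below is a minimal NumPy sketch of that masking, assuming the centered (non-history) case where lower is arange(input_len) - attention_width // 2; the sequence length and logits are made up for illustration, and the real layer uses Keras backend ops with batching and a history_only mode.

# Illustrative sketch only, not the layer's code.
import numpy as np

input_len = 5          # assumed sequence length
attention_width = 3    # assumed window width (centered, non-history case)

e = np.zeros((input_len, input_len))  # dummy attention logits
lower = (np.arange(input_len) - attention_width // 2)[:, None]
upper = lower + attention_width
indices = np.arange(input_len)[None, :]

# Positions outside [lower, upper) receive a large negative bias,
# so the softmax pushes their attention weights to (numerically) zero.
e -= 10000.0 * (1.0 - ((lower <= indices) & (indices < upper)).astype(np.float64))

weights = np.exp(e - e.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 3))  # each row sums to 1; out-of-window entries are 0
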
tests/seq_self_attention/util.py (2 changes: 2 additions & 0 deletions)
@@ -71,4 +71,6 @@ def check_mask_shape(self, attention):
                        self.assertGreater(attention_output[i][j][k], 0.0)
                    elif not history_only and abs(j - k) <= attention_width // 2:
                        self.assertGreater(attention_output[i][j][k], 0.0)
+                    else:
+                        self.assertEqual(attention_output[i][j][k], 0.0)
                self.assertTrue(abs(np.sum(attention_output[i][j]) - 1.0) < 1e-6)
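
The two added lines are what makes the test stricter: instead of only checking that in-window weights are positive, the helper now also asserts that every out-of-window weight is exactly zero, on top of the existing row-sum check. The following standalone sketch shows the same check applied to a hand-built weight matrix; it is not the project's test harness, and the matrix, sizes, and plain asserts are invented for illustration.

# Hedged sketch of the stricter check on a made-up attention matrix.
import numpy as np

attention_width = 3
weights = np.array([
    [0.6, 0.4, 0.0, 0.0],
    [0.3, 0.4, 0.3, 0.0],
    [0.0, 0.3, 0.4, 0.3],
    [0.0, 0.0, 0.4, 0.6],
])  # pretend output of a width-3 centered self-attention layer

for j in range(weights.shape[0]):
    for k in range(weights.shape[1]):
        if abs(j - k) <= attention_width // 2:
            assert weights[j][k] > 0.0          # inside the window: positive
        else:
            assert weights[j][k] == 0.0         # stricter branch: exactly zero
    assert abs(weights[j].sum() - 1.0) < 1e-6   # each row is a distribution
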
