Skip to content

Commit

Permalink
hotfix: remove unused code
Browse files Browse the repository at this point in the history
  • Loading branch information
rgtjf committed Apr 14, 2019
1 parent 49cfea5 commit 42b4fcf
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions matchzoo/contrib/layers/attention_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AttentionLayer(Layer):
def __init__(self,
att_dim: int,
att_type: str = 'default',
remove_diagonal: bool = False,
# remove_diagonal: bool = False,
dropout_rate: float = 0.0):
"""
class: `AttentionLayer` constructor.
Expand All @@ -35,7 +35,7 @@ def __init__(self,
super(AttentionLayer, self).__init__()
self._att_dim = att_dim
self._att_type = att_type
self._remove_diagonal = remove_diagonal
# self._remove_diagonal = remove_diagonal
self._dropout_rate = dropout_rate

@property
Expand Down Expand Up @@ -128,16 +128,10 @@ def call(self, x: list, **kwargs):
# diagonal = tf.expand_dims(diagonal, axis=0) # ['x', len1, len1]
# attn_value = attn_value * diagonal

if len(x) == 4:
mask_lt = x[2]
mask_rt = x[3]
attn_value = attn_value * K.expand_dims(mask_lt, axis=2)
attn_value = attn_value * K.expand_dims(mask_rt, axis=1)

# softmax
attn_prob = K.softmax(attn_value) # [batch_size, len_1, len_2]

# if remove_diagonal: attn_value = attn_value * diagonal

if len(x) == 4:
mask_lt = x[2]
mask_rt = x[3]
Expand Down

0 comments on commit 42b4fcf

Please sign in to comment.