Skip to content

Commit

Permalink
[BugFix] fix dynamic pooling for evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
pl8787 committed Feb 1, 2019
1 parent b20139f commit 07460c7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
9 changes: 3 additions & 6 deletions matchzoo/data_generator/dpool_data_generator.py
Expand Up @@ -17,8 +17,7 @@ def _dynamic_pooling_index(length_left: np.array,
compress_ratio_left: float,
compress_ratio_right: float) -> np.array:

def _dpool_index(batch_idx: int,
one_length_left: int,
def _dpool_index(one_length_left: int,
one_length_right: int,
fixed_length_left: int,
fixed_length_right: int):
Expand All @@ -38,8 +37,7 @@ def _dpool_index(batch_idx: int,
for i in range(fixed_length_right)]
mesh1, mesh2 = np.meshgrid(one_idx_left, one_idx_right)
index_one = np.transpose(
np.stack([np.ones(mesh1.shape) * batch_idx,
mesh1, mesh2]), (2, 1, 0))
np.stack([mesh1, mesh2]), (2, 1, 0))
return index_one

index = []
Expand All @@ -53,8 +51,7 @@ def _dpool_index(batch_idx: int,
cur_fixed_length_right = fixed_length_right // compress_ratio_right \
+ dpool_bias_right
for i in range(len(length_left)):
index.append(_dpool_index(i,
length_left[i] // compress_ratio_left,
index.append(_dpool_index(length_left[i] // compress_ratio_left,
length_right[i] // compress_ratio_right,
cur_fixed_length_left,
cur_fixed_length_right))
Expand Down
11 changes: 10 additions & 1 deletion matchzoo/layers/dynamic_pooling_layer.py
Expand Up @@ -50,7 +50,16 @@ def call(self, inputs: list, **kwargs) -> typing.Any:
:param inputs: two input tensors.
"""
x, dpool_index = inputs
x_expand = K.tf.gather_nd(x, dpool_index)
dpool_shape = K.tf.shape(dpool_index)
batch_index_one = K.tf.expand_dims(
K.tf.expand_dims(
K.tf.range(dpool_shape[0]), axis=-1),
axis=-1)
batch_index = K.tf.expand_dims(
K.tf.tile(batch_index_one, [1, self._msize1, self._msize2]),
axis=-1)
dpool_index_ex = K.tf.concat([batch_index, dpool_index], axis=3)
x_expand = K.tf.gather_nd(x, dpool_index_ex)
stride1 = self._msize1 / self._psize1
stride2 = self._msize2 / self._psize2

Expand Down
2 changes: 1 addition & 1 deletion matchzoo/models/match_pyramid.py
Expand Up @@ -68,7 +68,7 @@ def build(self):
name='dpool_index',
shape=[self._params['input_shapes'][0][0],
self._params['input_shapes'][1][0],
3],
2],
dtype='int32')

embedding = self._make_embedding_layer()
Expand Down

0 comments on commit 07460c7

Please sign in to comment.