diff --git a/matchzoo/data_generator/dpool_data_generator.py b/matchzoo/data_generator/dpool_data_generator.py index 355e65f8..a54f2c3a 100644 --- a/matchzoo/data_generator/dpool_data_generator.py +++ b/matchzoo/data_generator/dpool_data_generator.py @@ -17,8 +17,7 @@ def _dynamic_pooling_index(length_left: np.array, compress_ratio_left: float, compress_ratio_right: float) -> np.array: - def _dpool_index(batch_idx: int, - one_length_left: int, + def _dpool_index(one_length_left: int, one_length_right: int, fixed_length_left: int, fixed_length_right: int): @@ -38,8 +37,7 @@ def _dpool_index(batch_idx: int, for i in range(fixed_length_right)] mesh1, mesh2 = np.meshgrid(one_idx_left, one_idx_right) index_one = np.transpose( - np.stack([np.ones(mesh1.shape) * batch_idx, - mesh1, mesh2]), (2, 1, 0)) + np.stack([mesh1, mesh2]), (2, 1, 0)) return index_one index = [] @@ -53,8 +51,7 @@ def _dpool_index(batch_idx: int, cur_fixed_length_right = fixed_length_right // compress_ratio_right \ + dpool_bias_right for i in range(len(length_left)): - index.append(_dpool_index(i, - length_left[i] // compress_ratio_left, + index.append(_dpool_index(length_left[i] // compress_ratio_left, length_right[i] // compress_ratio_right, cur_fixed_length_left, cur_fixed_length_right)) diff --git a/matchzoo/layers/dynamic_pooling_layer.py b/matchzoo/layers/dynamic_pooling_layer.py index d740c03e..40b989f7 100644 --- a/matchzoo/layers/dynamic_pooling_layer.py +++ b/matchzoo/layers/dynamic_pooling_layer.py @@ -50,7 +50,16 @@ def call(self, inputs: list, **kwargs) -> typing.Any: :param inputs: two input tensors. """ x, dpool_index = inputs - x_expand = K.tf.gather_nd(x, dpool_index) + dpool_shape = K.tf.shape(dpool_index) + batch_index_one = K.tf.expand_dims( + K.tf.expand_dims( + K.tf.range(dpool_shape[0]), axis=-1), + axis=-1) + batch_index = K.tf.expand_dims( + K.tf.tile(batch_index_one, [1, self._msize1, self._msize2]), + axis=-1) + dpool_index_ex = K.tf.concat([batch_index, dpool_index], axis=3) + x_expand = K.tf.gather_nd(x, dpool_index_ex) stride1 = self._msize1 / self._psize1 stride2 = self._msize2 / self._psize2 diff --git a/matchzoo/models/match_pyramid.py b/matchzoo/models/match_pyramid.py index 49fbc798..e1426618 100644 --- a/matchzoo/models/match_pyramid.py +++ b/matchzoo/models/match_pyramid.py @@ -68,7 +68,7 @@ def build(self): name='dpool_index', shape=[self._params['input_shapes'][0][0], self._params['input_shapes'][1][0], - 3], + 2], dtype='int32') embedding = self._make_embedding_layer()