diff --git a/.gitignore b/.gitignore
index 767cf4c4..382ed42f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,7 @@ private
 !delta/data/feat/python_speech_features/english.wav
 *.mp3
 tags
+venv
 gen
 *.cxx
 *.o
diff --git a/delta/data/task/text_match_task.py b/delta/data/task/text_match_task.py
index 070a03cd..1950a1d1 100644
--- a/delta/data/task/text_match_task.py
+++ b/delta/data/task/text_match_task.py
@@ -27,6 +27,8 @@ from delta.utils.register import registers
 from delta.layers.utils import compute_sen_lens
 from delta import utils
+
+
 # pylint: disable=too-many-instance-attributes
@@ -45,10 +47,10 @@ def __init__(self, config, mode):
     self.paths = self.data_config[mode]['paths']
     self.paths_after_pre_process = [
-        one_path + ".after" for one_path in self.paths
+      one_path + ".after" for one_path in self.paths
     ]
     self.infer_no_label = self.config["data"][utils.INFER].get(
-        'infer_no_label', False)
+      'infer_no_label', False)
     self.infer_without_label = bool(mode == utils.INFER and self.infer_no_label)
     self.prepare()
@@ -59,26 +61,26 @@ def generate_data(self):
     if self.infer_without_label:
       column_num = 2
       text_ds_left, text_ds_right = load_textline_dataset(
-          self.paths_after_pre_process, column_num)
+        self.paths_after_pre_process, column_num)
     else:
       column_num = 3
       label, text_ds_left, text_ds_right = load_textline_dataset(
-          self.paths_after_pre_process, column_num)
+        self.paths_after_pre_process, column_num)
     input_pipeline_func = self.get_input_pipeline(for_export=False)
     text_ds_left = text_ds_left.map(
-        input_pipeline_func, num_parallel_calls=self.num_parallel_calls)
+      input_pipeline_func, num_parallel_calls=self.num_parallel_calls)
     text_ds_right = text_ds_right.map(
-        input_pipeline_func, num_parallel_calls=self.num_parallel_calls)
+      input_pipeline_func, num_parallel_calls=self.num_parallel_calls)
     text_size_ds_left = text_ds_left.map(
-        lambda x: compute_sen_lens(x, padding_token=0),
-        num_parallel_calls=self.num_parallel_calls)
+      lambda x: compute_sen_lens(x, padding_token=0),
+      num_parallel_calls=self.num_parallel_calls)
     text_size_ds_right = text_ds_right.map(
-        lambda x: compute_sen_lens(x, padding_token=0),
-        num_parallel_calls=self.num_parallel_calls)
+      lambda x: compute_sen_lens(x, padding_token=0),
+      num_parallel_calls=self.num_parallel_calls)
     text_ds_left_right = tf.data.Dataset.zip((text_ds_left, text_ds_right))
     text_len_left_right = tf.data.Dataset.zip(
-        (text_size_ds_left, text_size_ds_right))
+      (text_size_ds_left, text_size_ds_right))
     if self.infer_without_label:
       data_set_left_right = text_ds_left_right
     else:
@@ -89,7 +91,7 @@ def generate_data(self):
     self.config['data']['vocab_size'] = vocab_size
     self.config['data']['{}_data_size'.format(self.mode)] = get_file_len(
-        self.paths_after_pre_process)
+      self.paths_after_pre_process)
 
     return data_set_left_right, text_len_left_right
@@ -99,8 +101,12 @@ def feature_spec(self):
                        tf.TensorShape([self.max_seq_len]))]
     if not self.infer_without_label:
       feature_shapes.append(tf.TensorShape([self.num_classes]))
+
+    feature_shapes = [tuple(feature_shapes), (tf.TensorShape([]), tf.TensorShape([]))]
+
     if len(feature_shapes) == 1:
       return feature_shapes[0]
+
     return tuple(feature_shapes)
 
   def export_inputs(self):
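Note on the `feature_spec()` change above: `padded_batch()` needs `padded_shapes` whose nesting mirrors the element structure of the zipped dataset, which is why the spec now returns `((left, right)[, label], (left_len, right_len))`. A minimal standalone sketch of that constraint (plain TensorFlow 2.x eager with toy data; the names `gen`, `left`/`right` are made up for illustration and are not DELTA code):

```python
import tensorflow as tf

max_seq_len = 5

def gen():
  yield [1, 2], [7, 8, 9]
  yield [3, 4, 5], [6]

ds = tf.data.Dataset.from_generator(
    gen, output_types=(tf.int32, tf.int32),
    output_shapes=(tf.TensorShape([None]), tf.TensorShape([None])))

# per-side sentence lengths, mirroring compute_sen_lens
len_ds = ds.map(lambda l, r: (tf.size(l), tf.size(r)))

# element structure after zip: ((left, right), (left_len, right_len))
zipped = tf.data.Dataset.zip((ds, len_ds))

# padded_shapes must have the same nesting as the elements
padded_shapes = ((tf.TensorShape([max_seq_len]), tf.TensorShape([max_seq_len])),
                 (tf.TensorShape([]), tf.TensorShape([])))

batched = zipped.padded_batch(2, padded_shapes=padded_shapes)
for (l, r), (ll, rl) in batched:
  print(l.shape, r.shape, ll.numpy(), rl.numpy())   # (2, 5) (2, 5) [2 3] [3 1]
```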
shape=(None,), dtype=tf.string, name="input_sent_right") input_pipeline_func = self.get_input_pipeline(for_export=True) token_ids_left = input_pipeline_func(input_sent_left) token_ids_right = input_pipeline_func(input_sent_right) token_ids_len_left = tf.map_fn( - lambda x: compute_sen_lens(x, padding_token=0), token_ids_left) + lambda x: compute_sen_lens(x, padding_token=0), token_ids_left) token_ids_len_right = tf.map_fn( - lambda x: compute_sen_lens(x, padding_token=0), token_ids_right) + lambda x: compute_sen_lens(x, padding_token=0), token_ids_right) + export_data = { - "export_inputs": { - "input_sent_left": input_sent_left, - "input_sent_right": input_sent_right, - }, - "model_inputs": { - "input_x_left": token_ids_left, - "input_x_right": token_ids_right, - "input_x_len": [token_ids_len_left, token_ids_len_right] - } + "export_inputs": { + "input_sent_left": input_sent_left, + "input_sent_right": input_sent_right, + }, + "model_inputs": { + "input_x_left": token_ids_left, + "input_x_right": token_ids_right, + "input_x_left_len": token_ids_len_left, + "input_x_right_len": token_ids_len_right, + "input_x_len": [token_ids_len_left, token_ids_len_right] + } } return export_data def dataset(self): """Data set function""" - data_set_left_right, text_len_left_right = self.generate_data() + ds_left_right, ds_left_right_len = self.generate_data() + text_ds_left_right = tf.data.Dataset.zip((ds_left_right, ds_left_right_len)) - logging.debug("data_set_left_right: {}".format(data_set_left_right)) if self.mode == 'train': if self.need_shuffle: # shuffle batch size and repeat logging.debug("shuffle and repeat dataset ...") - data_set_left_right = data_set_left_right.apply( - tf.data.experimental.shuffle_and_repeat( - buffer_size=self.shuffle_buffer_size, count=None)) + text_ds_left_right = text_ds_left_right.apply( + tf.data.experimental.shuffle_and_repeat( + buffer_size=self.shuffle_buffer_size, count=None)) else: logging.debug("repeat dataset ...") - data_set_left_right = data_set_left_right.repeat(count=None) + text_ds_left_right = text_ds_left_right.repeat(count=None) + feature_shape = self.feature_spec() logging.debug("feature_shape: {}".format(feature_shape)) - data_set_left_right = data_set_left_right.padded_batch( - batch_size=self.batch_size, padded_shapes=feature_shape) - text_len_left_right = text_len_left_right.batch(self.batch_size) + # logging.debug("data_set_left_right:{}".format(data_set_left_right)) + + text_ds_left_right = text_ds_left_right.padded_batch( + batch_size=self.batch_size, padded_shapes=feature_shape) - data_set_left_right = data_set_left_right.prefetch(self.num_prefetch_batch) - text_len_left_right = text_len_left_right.prefetch(self.num_prefetch_batch) + text_ds_left_right = text_ds_left_right.prefetch(self.num_prefetch_batch) - iterator = data_set_left_right.make_initializable_iterator() - iterator_len = text_len_left_right.make_initializable_iterator() + iterator = text_ds_left_right.make_initializable_iterator() # pylint: disable=unused-variable if self.infer_without_label: - input_x_left, input_x_right = iterator.get_next() + (input_x_left, input_x_right), (input_x_left_len, input_x_right_len) = iterator.get_next() else: + ((input_x_left, input_x_right), input_y), (input_x_left_len, input_x_right_len) = iterator.get_next() - (input_x_left, input_x_right), input_y = iterator.get_next() - - input_x_left_len, input_x_right_len = iterator_len.get_next() input_x_dict = collections.OrderedDict([("input_x_left", input_x_left), - ("input_x_right", input_x_right)]) + 
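With the lengths zipped into the same dataset, `dataset()` now needs only one initializable iterator, and `get_next()` yields the nested structure directly; the separate `iterator_len` disappears. A rough sketch of that flow using `tf.compat.v1` graph mode with toy constants (illustrative only, not the DELTA task code):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

left = tf.data.Dataset.from_tensor_slices(tf.constant([[1, 2, 0], [3, 4, 5]]))
right = tf.data.Dataset.from_tensor_slices(tf.constant([[6, 0, 0], [7, 8, 0]]))
lens = tf.data.Dataset.from_tensor_slices((tf.constant([2, 3]), tf.constant([1, 2])))

# ((left, right), (left_len, right_len)) per element, then batched
ds = tf.data.Dataset.zip((tf.data.Dataset.zip((left, right)), lens)).batch(2)

iterator = tf.compat.v1.data.make_initializable_iterator(ds)
(input_x_left, input_x_right), (left_len, right_len) = iterator.get_next()

with tf.compat.v1.Session() as sess:
  sess.run(iterator.initializer)          # a single initializer, no iterator_len
  print(sess.run([input_x_left, left_len]))
```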
("input_x_right", input_x_right), + ("input_x_left_len", input_x_left_len), + ("input_x_right_len", input_x_right_len), + ]) input_x_len = collections.OrderedDict([ - ("input_x_left_len", input_x_left_len), - ("input_x_right_len", input_x_right_len) + ("input_x_left_len", input_x_left_len), + ("input_x_right_len", input_x_right_len) ]) + return_dict = { - "input_x_dict": input_x_dict, - "input_x_len": input_x_len, - "iterator": iterator, - "iterator_len": iterator_len, + "input_x_dict": input_x_dict, + "input_x_len": input_x_len, + "iterator": iterator, } if not self.infer_without_label: return_dict["input_y_dict"] = collections.OrderedDict([("input_y", input_y)]) return return_dict + diff --git a/delta/data/task/text_match_task_test.py b/delta/data/task/text_match_task_test.py index eb8ad4a1..1766aaa0 100644 --- a/delta/data/task/text_match_task_test.py +++ b/delta/data/task/text_match_task_test.py @@ -72,7 +72,7 @@ def test_english(self): # with self.cached_session(use_gpu=False, force_gpu=False) as sess: # sess.run(data["iterator"].initializer) with self.cached_session(use_gpu=False, force_gpu=False) as sess: - sess.run([data["iterator"].initializer, data["iterator_len"].initializer]) + sess.run([data["iterator"].initializer]) res = sess.run([ data["input_x_dict"]["input_x_left"], data["input_x_dict"]["input_x_right"], diff --git a/delta/layers/dynamic_pooling.py b/delta/layers/dynamic_pooling.py new file mode 100644 index 00000000..9ebd48c7 --- /dev/null +++ b/delta/layers/dynamic_pooling.py @@ -0,0 +1,136 @@ +# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An implementation of Dynamic Pooling Layer.""" +import typing + +import delta.compat as tf +from delta.layers.base_layer import Layer + + +class DynamicPoolingLayer(Layer): + """ + Layer that computes dynamic pooling of one tensor. + :param psize1: pooling size of dimension 1 + :param psize2: pooling size of dimension 2 + :param kwargs: Standard layer keyword arguments. + Examples: + >>> import delta + >>> layer = delta.layers.DynamicPoolingLayer(3, 2) + >>> num_batch, left_len, right_len, num_dim = 5, 3, 2, 10 + >>> layer.build([[num_batch, left_len, right_len, num_dim], + ... [num_batch, left_len, right_len, 3]]) + """ + + def __init__(self, + psize1: int, + psize2: int, + **kwargs): + """:class:`DynamicPoolingLayer` constructor.""" + super().__init__(**kwargs) + self._psize1 = psize1 + self._psize2 = psize2 + + def build(self, input_shape: typing.List[int]): + """ + Build the layer. + :param input_shape: the shapes of the input tensors, + for DynamicPoolingLayer we need tow input tensors. + """ + super().build(input_shape) + input_shape_one = input_shape[0] + self._msize1 = input_shape_one[1] + self._msize2 = input_shape_one[2] + + def call(self, inputs: list, **kwargs) -> typing.Any: + """ + The computation logic of DynamicPoolingLayer. 
diff --git a/delta/layers/dynamic_pooling.py b/delta/layers/dynamic_pooling.py
new file mode 100644
index 00000000..9ebd48c7
--- /dev/null
+++ b/delta/layers/dynamic_pooling.py
@@ -0,0 +1,136 @@
+# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An implementation of Dynamic Pooling Layer."""
+import typing
+
+import delta.compat as tf
+from delta.layers.base_layer import Layer
+
+
+class DynamicPoolingLayer(Layer):
+  """
+  Layer that computes dynamic pooling of one tensor.
+  :param psize1: pooling size of dimension 1
+  :param psize2: pooling size of dimension 2
+  :param kwargs: Standard layer keyword arguments.
+  Examples:
+    >>> import delta
+    >>> layer = delta.layers.DynamicPoolingLayer(3, 2)
+    >>> num_batch, left_len, right_len, num_dim = 5, 3, 2, 10
+    >>> layer.build([[num_batch, left_len, right_len, num_dim],
+    ...              [num_batch, left_len, right_len, 3]])
+  """
+
+  def __init__(self,
+               psize1: int,
+               psize2: int,
+               **kwargs):
+    """:class:`DynamicPoolingLayer` constructor."""
+    super().__init__(**kwargs)
+    self._psize1 = psize1
+    self._psize2 = psize2
+
+  def build(self, input_shape: typing.List[int]):
+    """
+    Build the layer.
+    :param input_shape: the shapes of the input tensors,
+        for DynamicPoolingLayer we need two input tensors.
+    """
+    super().build(input_shape)
+    input_shape_one = input_shape[0]
+    self._msize1 = input_shape_one[1]
+    self._msize2 = input_shape_one[2]
+
+  def call(self, inputs: list, **kwargs) -> typing.Any:
+    """
+    The computation logic of DynamicPoolingLayer.
+    :param inputs: two input tensors.
+    """
+    self._validate_dpool_size()
+    x, dpool_index = inputs
+    dpool_shape = tf.shape(dpool_index)
+    batch_index_one = tf.expand_dims(
+        tf.expand_dims(tf.range(dpool_shape[0]), axis=-1),
+        axis=-1)
+    batch_index = tf.expand_dims(
+        tf.tile(batch_index_one, [1, self._msize1, self._msize2]),
+        axis=-1)
+    dpool_index_ex = tf.concat([batch_index, dpool_index], axis=3)
+    x_expand = tf.gather_nd(x, dpool_index_ex)
+    stride1 = self._msize1 // self._psize1
+    stride2 = self._msize2 // self._psize2
+
+    x_pool = tf.nn.max_pool(x_expand,
+                            [1, stride1, stride2, 1],
+                            [1, stride1, stride2, 1],
+                            "VALID")
+    return x_pool
+
+  def compute_output_shape(self, input_shape: list) -> tuple:
+    """
+    Calculate the layer output shape.
+    :param input_shape: the shapes of the input tensors,
+        for DynamicPoolingLayer we need two input tensors.
+    """
+    input_shape_one = input_shape[0]
+    return (None, self._psize1, self._psize2, input_shape_one[3])
+
+  def get_config(self) -> dict:
+    """Get the config dict of DynamicPoolingLayer."""
+    config = {
+        'psize1': self._psize1,
+        'psize2': self._psize2
+    }
+    base_config = super(DynamicPoolingLayer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  def _validate_dpool_size(self):
+    suggestion = self.get_size_suggestion(
+        self._msize1, self._msize2, self._psize1, self._psize2)
+    if suggestion != (self._psize1, self._psize2):
+      raise ValueError(
+          "DynamicPooling Layer can not "
+          f"generate ({self._psize1} x {self._psize2}) output "
+          f"feature map, please use ({suggestion[0]} x {suggestion[1]})"
+          f" instead. `model.params['dpool_size'] = {suggestion}` ")
+
+  @classmethod
+  def get_size_suggestion(
+      cls,
+      msize1: int,
+      msize2: int,
+      psize1: int,
+      psize2: int
+  ) -> typing.Tuple[int, int]:
+    """
+    Get `dpool_size` suggestion for a given shape.
+    Returns the nearest legal `dpool_size` for the given combination of
+    `(psize1, psize2)`.
+    :param msize1: size of the left text.
+    :param msize2: size of the right text.
+    :param psize1: base size of the pool.
+    :param psize2: base size of the pool.
+    :return:
+    """
+    stride1 = msize1 // psize1
+    stride2 = msize2 // psize2
+    suggestion1 = msize1 // stride1
+    suggestion2 = msize2 // stride2
+    return (suggestion1, suggestion2)
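The layer above re-indexes the matching map with `dpool_index` via `gather_nd` and then applies a fixed-stride max-pool. A standalone sketch of that computation with plain TensorFlow (not the DELTA layer itself; the identity `dpool_index` built here is just for illustration):

```python
import tensorflow as tf

num_batch, msize1, msize2, num_dim = 2, 6, 4, 8
psize1, psize2 = 3, 2

x = tf.random.uniform([num_batch, msize1, msize2, num_dim])

# identity dpool_index: every (i, j) maps to itself
mesh1, mesh2 = tf.meshgrid(tf.range(msize1), tf.range(msize2), indexing='ij')
one_index = tf.stack([mesh1, mesh2], axis=-1)                    # [msize1, msize2, 2]
dpool_index = tf.tile(one_index[None, ...], [num_batch, 1, 1, 1])

# prepend the batch index so gather_nd can address [batch, i, j]
batch_index = tf.tile(
    tf.range(num_batch)[:, None, None, None], [1, msize1, msize2, 1])
index = tf.concat([batch_index, dpool_index], axis=3)
x_expand = tf.gather_nd(x, index)

# fixed-stride pooling down to the requested psize1 x psize2 map
stride1, stride2 = msize1 // psize1, msize2 // psize2
x_pool = tf.nn.max_pool2d(x_expand, [stride1, stride2], [stride1, stride2], 'VALID')
print(x_pool.shape)   # (2, 3, 2, 8)
```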
diff --git a/delta/layers/match_pyramid.py b/delta/layers/match_pyramid.py
new file mode 100644
index 00000000..3d8c5418
--- /dev/null
+++ b/delta/layers/match_pyramid.py
@@ -0,0 +1,149 @@
+# Copyright (C) 2017 Beijing Didi Infinity Technology and Development Co.,Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An implementation of Matching Layer."""
+import typing
+
+import delta.compat as tf
+from delta.layers.base_layer import Layer
+
+
+class MatchingLayer(Layer):
+  """
+  Layer that computes a matching matrix between samples in two tensors.
+  :param normalize: Whether to L2-normalize samples along the
+      dot product axis before taking the dot product.
+      If set to True, then the output of the dot product
+      is the cosine proximity between the two samples.
+  :param matching_type: the similarity function for matching
+  :param kwargs: Standard layer keyword arguments.
+  Examples:
+    >>> import delta
+    >>> layer = delta.layers.MatchingLayer(matching_type='dot',
+    ...                                    normalize=True)
+    >>> num_batch, left_len, right_len, num_dim = 5, 3, 2, 10
+    >>> layer.build([[num_batch, left_len, num_dim],
+    ...              [num_batch, right_len, num_dim]])
+  """
+
+  def __init__(self, normalize: bool = False,
+               matching_type: str = 'dot', **kwargs):
+    """:class:`MatchingLayer` constructor."""
+    super().__init__(**kwargs)
+    self._normalize = normalize
+    self._validate_matching_type(matching_type)
+    self._matching_type = matching_type
+    self._shape1 = None
+    self._shape2 = None
+
+  @classmethod
+  def _validate_matching_type(cls, matching_type: str = 'dot'):
+    valid_matching_type = ['dot', 'mul', 'plus', 'minus', 'concat']
+    if matching_type not in valid_matching_type:
+      raise ValueError(f"{matching_type} is not a valid matching type, "
+                       f"{valid_matching_type} expected.")
+
+  def build(self, input_shape: list):
+    """
+    Build the layer.
+    :param input_shape: the shapes of the input tensors,
+        for MatchingLayer we need two input tensors.
+    """
+    # Used purely for shape validation.
+    if not isinstance(input_shape, list) or len(input_shape) != 2:
+      raise ValueError('A `MatchingLayer` layer should be called '
+                       'on a list of 2 inputs.')
+    self._shape1 = input_shape[0]
+    self._shape2 = input_shape[1]
+    for idx in 0, 2:
+      if self._shape1[idx] != self._shape2[idx]:
+        raise ValueError(
+            'Incompatible dimensions: '
+            f'{self._shape1[idx]} != {self._shape2[idx]}. '
+            f'Layer shapes: {self._shape1}, {self._shape2}.')
+
+  def call(self, inputs: list, **kwargs) -> typing.Any:
+    """
+    The computation logic of MatchingLayer.
+    :param inputs: two input tensors.
+    """
+    x1 = inputs[0]
+    x2 = inputs[1]
+    if self._matching_type == 'dot':
+      if self._normalize:
+        x1 = tf.math.l2_normalize(x1, axis=2)
+        x2 = tf.math.l2_normalize(x2, axis=2)
+      return tf.expand_dims(tf.einsum('abd,acd->abc', x1, x2), 3)
+    else:
+      if self._matching_type == 'mul':
+        def func(x, y):
+          return x * y
+      elif self._matching_type == 'plus':
+        def func(x, y):
+          return x + y
+      elif self._matching_type == 'minus':
+        def func(x, y):
+          return x - y
+      elif self._matching_type == 'concat':
+        def func(x, y):
+          return tf.concat([x, y], axis=3)
+      else:
+        raise ValueError(f"Invalid matching type. "
+                         f"{self._matching_type} received. "
+                         f"Must be in `dot`, `mul`, `plus`, "
+                         f"`minus` and `concat`.")
+      x1_exp = tf.stack([x1] * self._shape2[1], 2)
+      x2_exp = tf.stack([x2] * self._shape1[1], 1)
+      return func(x1_exp, x2_exp)
+
+  def compute_output_shape(self, input_shape: list) -> tuple:
+    """
+    Calculate the layer output shape.
+    :param input_shape: the shapes of the input tensors,
+        for MatchingLayer we need two input tensors.
+    """
+    if not isinstance(input_shape, list) or len(input_shape) != 2:
+      raise ValueError('A `MatchingLayer` layer should be called '
+                       'on a list of 2 inputs.')
+    shape1 = list(input_shape[0])
+    shape2 = list(input_shape[1])
+    if len(shape1) != 3 or len(shape2) != 3:
+      raise ValueError('A `MatchingLayer` layer should be called '
+                       'on 2 inputs with 3 dimensions.')
+    if shape1[0] != shape2[0] or shape1[2] != shape2[2]:
+      raise ValueError('A `MatchingLayer` layer should be called '
+                       'on 2 inputs with same 0,2 dimensions.')
+
+    if self._matching_type in ['mul', 'plus', 'minus']:
+      return shape1[0], shape1[1], shape2[1], shape1[2]
+    elif self._matching_type == 'dot':
+      return shape1[0], shape1[1], shape2[1], 1
+    elif self._matching_type == 'concat':
+      return shape1[0], shape1[1], shape2[1], shape1[2] + shape2[2]
+    else:
+      raise ValueError(f"Invalid `matching_type`. "
+                       f"{self._matching_type} received. "
+                       f"Must be in `mul`, `plus`, `minus`, "
+                       f"`dot` and `concat`.")
+
+  def get_config(self) -> dict:
+    """Get the config dict of MatchingLayer."""
+    config = {
+        'normalize': self._normalize,
+        'matching_type': self._matching_type,
+    }
+    base_config = super(MatchingLayer, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
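The two branches of `MatchingLayer.call()` can be checked in isolation: `'dot'` contracts the embedding axis with `einsum` (cosine similarity when normalized), the other types tile both sides against each other and combine element-wise. A short sketch in plain TensorFlow with random tensors (illustrative shapes only):

```python
import tensorflow as tf

num_batch, left_len, right_len, num_dim = 2, 3, 4, 5
x1 = tf.random.uniform([num_batch, left_len, num_dim])
x2 = tf.random.uniform([num_batch, right_len, num_dim])

# matching_type == 'dot': one scalar similarity per (left, right) position pair
dot = tf.expand_dims(tf.einsum('abd,acd->abc', x1, x2), 3)
print(dot.shape)                  # (2, 3, 4, 1)

# matching_type == 'mul' / 'plus' / 'minus': broadcast both sides to
# [batch, left_len, right_len, dim] before the element-wise op
x1_exp = tf.stack([x1] * right_len, 2)
x2_exp = tf.stack([x2] * left_len, 1)
print((x1_exp * x2_exp).shape)    # (2, 3, 4, 5)
```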
+ """ + if not isinstance(input_shape, list) or len(input_shape) != 2: + raise ValueError('A `MatchingLayer` layer should be called ' + 'on a list of 2 inputs.') + shape1 = list(input_shape[0]) + shape2 = list(input_shape[1]) + if len(shape1) != 3 or len(shape2) != 3: + raise ValueError('A `MatchingLayer` layer should be called ' + 'on 2 inputs with 3 dimensions.') + if shape1[0] != shape2[0] or shape1[2] != shape2[2]: + raise ValueError('A `MatchingLayer` layer should be called ' + 'on 2 inputs with same 0,2 dimensions.') + + if self._matching_type in ['mul', 'plus', 'minus']: + return shape1[0], shape1[1], shape2[1], shape1[2] + elif self._matching_type == 'dot': + return shape1[0], shape1[1], shape2[1], 1 + elif self._matching_type == 'concat': + return shape1[0], shape1[1], shape2[1], shape1[2] + shape2[2] + else: + raise ValueError(f"Invalid `matching_type`." + f"{self._matching_type} received." + f"Must be in `mul`, `plus`, `minus` " + f"`dot` and `concat`.") + + def get_config(self) -> dict: + """Get the config dict of MatchingLayer.""" + config = { + 'normalize': self._normalize, + 'matching_type': self._matching_type, + } + base_config = super(MatchingLayer, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/delta/models/text_match_model.py b/delta/models/text_match_model.py index a4786ecd..e98d46ae 100644 --- a/delta/models/text_match_model.py +++ b/delta/models/text_match_model.py @@ -18,6 +18,8 @@ import pickle from absl import logging import delta.compat as tf +from delta.layers.dynamic_pooling import DynamicPoolingLayer +from delta.layers.match_pyramid import MatchingLayer from delta.models.base_model import Model from delta.utils.register import registers @@ -29,17 +31,16 @@ class MatchRnn(Model): def __init__(self, config, **kwargs): super().__init__(**kwargs) - logging.info("Initialize MatchRnn...") self.use_pretrained_embedding = config['model']['use_pre_train_emb'] if self.use_pretrained_embedding: self.embedding_path = config['model']['embedding_path'] logging.info("Loading embedding file from: {}".format( - self.embedding_path)) + self.embedding_path)) self._word_embedding_init = pickle.load(open(self.embedding_path, 'rb')) self.embed_initializer = tf.constant_initializer( - self._word_embedding_init) + self._word_embedding_init) else: self.embed_initializer = tf.random_uniform_initializer(-0.1, 0.1) @@ -65,28 +66,28 @@ def __init__(self, config, **kwargs): self.l2_reg_lambda = model_config['l2_reg_lambda'] self.embed = tf.keras.layers.Embedding( - self.vocab_size, - self.embedding_size, - trainable=self.emb_trainable, - name='embdding', - embeddings_initializer=self.embed_initializer) + self.vocab_size, + self.embedding_size, + trainable=self.emb_trainable, + name='embdding', + embeddings_initializer=self.embed_initializer) self.embed_d = tf.keras.layers.Dropout(self.dropout_rate) self.lstm_left = tf.keras.layers.LSTM( - self.lstm_num_units, return_sequences=True, name='lstm_left') + self.lstm_num_units, return_sequences=True, name='lstm_left') self.lstm_right = tf.keras.layers.LSTM( - self.lstm_num_units, return_sequences=True, name='lstm_right') + self.lstm_num_units, return_sequences=True, name='lstm_right') self.concat = tf.keras.layers.Concatenate(axis=1) self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate) self.outlayer = tf.keras.layers.Dense(self.fc_num_units, activation='tanh') self.tasktype = config['data']['task']['type'] - #if self.tasktype == "Classification": + # if self.tasktype == "Classification": 
@@ -114,3 +115,162 @@ def call(self, inputs, training=None, mask=None):  # pylint: disable=too-many-lo
     scores = self.final_dense(out)
 
     return scores
+
+
+# pylint: disable=too-many-instance-attributes,too-many-ancestors
+@registers.model.register
+class MatchPyramidTextClassModel(MatchRnn):
+  """Match texts model with Match Pyramid."""
+
+  def __init__(self, config, **kwargs):
+    super().__init__(config, **kwargs)
+    logging.info("Initialize MatchPyramidTextClassModel ...")
+
+    self.vocab_size = config['data']['vocab_size']
+    self.num_classes = config['data']['task']['classes']['num_classes']
+    self.max_seq_len = config['data']['task']['max_seq_len']
+    model_config = config['model']['net']['structure']
+    self.dropout_rate = model_config['dropout_rate']
+    self.embedding_size = model_config['embedding_size']
+    self.emb_trainable = model_config['emb_trainable']
+    self.lstm_num_units = model_config['lstm_num_units']
+    self.fc_num_units = model_config['fc_num_units']
+    self.l2_reg_lambda = model_config['l2_reg_lambda']
+
+    # Number of convolution blocks
+    self.num_blocks = model_config['num_blocks']
+    # The kernel count of the 2D convolution
+    self.kernel_count = model_config['kernel_count']
+    # The kernel size of the 2D convolution of each block
+    self.kernel_size = model_config['kernel_size']
+    # The max-pooling size of each block
+    self.dpool_size = model_config['dpool_size']
+    # The padding mode in the convolution layer
+    self.padding = model_config['padding']
+    # The activation function
+    self.activation = model_config['activation']
+    self.matching_type = model_config['matching_type']
+
+    self.embed = tf.keras.layers.Embedding(
+      self.vocab_size,
+      self.embedding_size,
+      trainable=self.emb_trainable,
+      name='embdding',
+      embeddings_initializer=self.embed_initializer)
+
+    self.embed_d = tf.keras.layers.Dropout(self.dropout_rate)
+
+    self.matching_layer = MatchingLayer(matching_type=self.matching_type)
+
+    self.conv = []
+    for i in range(self.num_blocks):
+      conv = tf.keras.layers.Conv2D(
+        self.kernel_count,
+        self.kernel_size,
+        padding=self.padding,
+        activation=self.activation)
+      self.conv.append(conv)
+
+    self.dpool = DynamicPoolingLayer(*self.dpool_size)
+
+    self.flatten = tf.keras.layers.Flatten()
+
+    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
+    self.outlayer = tf.keras.layers.Dense(self.fc_num_units, activation='tanh')
+    self.tasktype = config['data']['task']['type']
+    # if self.tasktype == "Classification":
+    self.final_dense = tf.keras.layers.Dense(
+      self.num_classes,
+      activation=tf.keras.activations.linear,
+      name="final_dense")
+
+    logging.info("Initialize MatchPyramidTextClassModel done.")
+
+  def call(self, inputs, training=None, mask=None):  # pylint: disable=too-many-locals
+    input_left = inputs["input_x_left"]
+    input_right = inputs["input_x_right"]
+
+    input_x_left_len = inputs["input_x_left_len"]
+    input_x_right_len = inputs["input_x_right_len"]
+
+    embedding = self.embed
+    embed_left = embedding(input_left)
+    embed_right = embedding(input_right)
+
+    p_index = self._dynamic_pooling_index(input_x_left_len,
+                                          input_x_right_len,
+                                          self.max_seq_len,
+                                          self.max_seq_len,
+                                          1,
+                                          1)
+
+    embed_cross = self.matching_layer([embed_left, embed_right])
+    for i in range(self.num_blocks):
+      embed_cross = self.conv[i](embed_cross)
+
+    embed_pool = self.dpool([embed_cross, p_index])
+
+    embed_flat = self.flatten(embed_pool)
+
+    dropout = self.dropout(embed_flat)
+    out = self.outlayer(dropout)
+    scores = self.final_dense(out)
+    return scores
+
+  def _dynamic_pooling_index(self, length_left,
+                             length_right,
+                             fixed_length_left: int,
+                             fixed_length_right: int,
+                             compress_ratio_left: float,
+                             compress_ratio_right: float) -> tf.Tensor:
+
+    def _dpool_index(one_length_left,
+                     one_length_right,
+                     fixed_length_left,
+                     fixed_length_right):
+
+      logging.info("fixed_length_left: {}".format(fixed_length_left))
+      logging.info("fixed_length_right: {}".format(fixed_length_right))
+
+      if one_length_left == 0:
+        stride_left = fixed_length_left
+      else:
+        stride_left = 1.0 * fixed_length_left / tf.cast(one_length_left, dtype=tf.float32)
+
+      if one_length_right == 0:
+        stride_right = fixed_length_right
+      else:
+        stride_right = 1.0 * fixed_length_right / tf.cast(one_length_right, dtype=tf.float32)
+
+      one_idx_left = [tf.cast(i / stride_left, dtype=tf.int32)
+                      for i in range(fixed_length_left)]
+      one_idx_right = [tf.cast(i / stride_right, dtype=tf.int32)
+                       for i in range(fixed_length_right)]
+      mesh1, mesh2 = tf.meshgrid(one_idx_left, one_idx_right)
+      index_one = tf.transpose(tf.stack([mesh1, mesh2]), (2, 1, 0))
+      return index_one
+
+    index = []
+    dpool_bias_left = dpool_bias_right = 0
+    if fixed_length_left % compress_ratio_left != 0:
+      dpool_bias_left = 1
+    if fixed_length_right % compress_ratio_right != 0:
+      dpool_bias_right = 1
+    cur_fixed_length_left = int(
+      fixed_length_left // compress_ratio_left) + dpool_bias_left
+    cur_fixed_length_right = int(
+      fixed_length_right // compress_ratio_right) + dpool_bias_right
+    logging.info("length_left: {}".format(length_left))
+    logging.info("length_right: {}".format(length_right))
+    logging.info("cur_fixed_length_left: {}".format(cur_fixed_length_left))
+    logging.info("cur_fixed_length_right: {}".format(cur_fixed_length_right))
+
+    index = tf.map_fn(
+      lambda x: _dpool_index(x[0], x[1], cur_fixed_length_left, cur_fixed_length_right),
+      (length_left, length_right), dtype=tf.int32)
+
+    logging.info("index: {}".format(index))
+
+    return index
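The `_dynamic_pooling_index` helper above maps a fixed-length grid back onto the valid region of each variable-length sentence pair before pooling. A standalone sketch of that idea in plain eager TensorFlow (not the model code; the `tf.maximum` guard replaces the zero-length branch purely for illustration):

```python
import tensorflow as tf

def dpool_index_one(len_left, len_right, fixed_left, fixed_right):
  # stride shrinks as the real sentence gets shorter, so short sentences
  # are stretched over the fixed grid
  stride_left = fixed_left / tf.maximum(tf.cast(len_left, tf.float32), 1.0)
  stride_right = fixed_right / tf.maximum(tf.cast(len_right, tf.float32), 1.0)
  idx_left = tf.cast(tf.range(fixed_left, dtype=tf.float32) / stride_left, tf.int32)
  idx_right = tf.cast(tf.range(fixed_right, dtype=tf.float32) / stride_right, tf.int32)
  mesh1, mesh2 = tf.meshgrid(idx_left, idx_right)
  return tf.transpose(tf.stack([mesh1, mesh2]), (2, 1, 0))   # [fixed_left, fixed_right, 2]

index = tf.map_fn(
    lambda x: dpool_index_one(x[0], x[1], 6, 6),
    (tf.constant([3, 6]), tf.constant([2, 4])), dtype=tf.int32)
print(index.shape)    # (2, 6, 6, 2), one index grid per sample
```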
diff --git a/egs/mock_text_match_data/text_match/v1/config/pyramid-match-mock.yml b/egs/mock_text_match_data/text_match/v1/config/pyramid-match-mock.yml
new file mode 100644
index 00000000..0846fecb
--- /dev/null
+++ b/egs/mock_text_match_data/text_match/v1/config/pyramid-match-mock.yml
@@ -0,0 +1,129 @@
+---
+data:
+  train:
+    paths:
+      - "egs/mock_text_match_data/text_match/v1/data/train.txt"
+  eval:
+    paths:
+      - "egs/mock_text_match_data/text_match/v1/data/test.txt"
+  infer:
+    paths:
+      - "egs/mock_text_match_data/text_match/v1/data/test.txt"
+
+  infer_no_label: False
+  task:
+    type: Classification
+    name: TextMatchTask
+    preparer:
+      enable: true
+      name: TextMatchPreparer
+      done_sign: "egs/mock_text_match_data/text_match/v1/exp/prepare.done"
+      reuse: false
+    use_dense: false
+    language: english
+    clean_english: True
+    vocab_min_frequency: 0
+    split_by_space: false
+    use_word: true
+    use_custom_vocab: true
+    text_vocab: "egs/mock_text_match_data/text_match/v1/data/text_vocab.txt"
+    label_vocab: "egs/mock_text_match_data/text_match/v1/exp/label_vocab.txt"
+    max_seq_len: 42
+    num_parallel_calls: 12
+    num_prefetch_batch: 2
+    shuffle_buffer_size: 15000
+    need_shuffle: true
+    batch_size: 30
+    epochs: 30
+    classes:
+      positive_id: 1
+      num_classes: 2
+      vocab:
+        0: 0
+        1: 1
+
+model:
+  name: MatchPyramidTextClassModel
+  type: keras
+  use_pre_train_emb: false
+  pre_train_emb_path: ""
+  embedding_path: "egs/mock_text_match_data/text_match/v1/exp/embeding.pkl"
+
+  net:
+    structure:
+      embedding_size: 200
+      emb_trainable: true
+      cell_type: gru
+      cell_dim: 100
+      lstm_num_units: 256 #256
+      fc_num_units: 100 #100
+      dropout_rate: 0.0
+      l2_reg_lambda: 4e-6
+      activate: relu
+      sent_hidden_size: 300
+      # pyramid
+      matching_type: dot
+      num_blocks: 1
+      kernel_count: 32
+      kernel_size:
+        - 3
+        - 3
+      dpool_size:
+        - 3
+        - 10
+      padding: same
+      activation: tanh
+
+solver:
+  name: RawMatchSolver
+  adversarial:
+    enable: false # whether to use adversarial training
+    adv_alpha: 0.5 # adversarial alpha of loss
+    adv_epslion: 0.1 # adversarial example epsilon
+  model_average:
+    enable: false # use average model
+    var_avg_decay: 0.99 # the decay rate of variables
+  optimizer:
+    name: adam
+    loss: CrossEntropyLoss
+    label_smoothing: 0.0 # label smoothing rate
+    learning_rate:
+      rate: 0.0001 # learning rate of Adam optimizer
+      type: exp_decay # learning rate type
+      decay_rate: 0.99 # the lr decay rate
+      decay_steps: 100 # the lr decay_step for optimizer
+    clip_global_norm: 3.0 # clip global norm
+  metrics:
+    pos_label: 1 # int, same to sklearn
+    cals:
+      - name: AccuracyCal
+        arguments: Null
+      - name: PrecisionCal
+        arguments:
+          average: 'macro'
+      - name: RecallCal
+        arguments:
+          average: 'macro'
+      - name: F1ScoreCal
+        arguments:
+          average: 'weighted'
+  postproc:
+    name: SavePredPostProc
+    res_file: "egs/mock_text_match_data/text_match/v1/exp/text-match/res.txt"
+  saver:
+    model_path: "egs/mock_text_match_data/text_match/v1/exp/text-match/ckpt"
+    max_to_keep: 30 #30
+    save_checkpoint_steps: 10 #100
+    print_every: 10
+  service:
+    model_path: "egs/mock_text_match_data/text_match/v1/exp/text-match/service"
+    model_version: "1"
+  run_config:
+    tf_random_seed: null
+    allow_soft_placement: true
+    log_device_placement: false
+    intra_op_parallelism_threads: 10
+    inter_op_parallelism_threads: 10
+    allow_growth: true
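A quick sanity check that the configured `dpool_size: [3, 10]` is legal for `max_seq_len: 42`, using the same arithmetic as `DynamicPoolingLayer.get_size_suggestion` (pure Python, mirrors the layer code above):

```python
def size_suggestion(msize1, msize2, psize1, psize2):
  # stride per axis, then the nearest pooled size that stride actually yields
  stride1, stride2 = msize1 // psize1, msize2 // psize2
  return msize1 // stride1, msize2 // stride2

print(size_suggestion(42, 42, 3, 10))   # (3, 10) -> matches dpool_size, so no ValueError
```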