From 9b4ea69b9e81cd657d277b1f2a5800e6901e17c2 Mon Sep 17 00:00:00 2001
From: zhanghaichao
Date: Thu, 1 Dec 2016 15:07:27 -0800
Subject: [PATCH 1/2] add external memory demo

---
 demo/memnet/README.md                    |  21 ++++
 demo/memnet/data_provider_mem.py         | 100 ++++++++++++++++++
 demo/memnet/dummy.list                   |   1 +
 demo/memnet/external_memory.py           | 119 +++++++++++++++++++++
 demo/memnet/external_memory_example.conf | 120 +++++++++++++++++++++
 demo/memnet/test.sh                      |  33 ++++++
 demo/memnet/train.sh                     |  34 ++++++
 7 files changed, 428 insertions(+)
 create mode 100644 demo/memnet/README.md
 create mode 100644 demo/memnet/data_provider_mem.py
 create mode 100644 demo/memnet/dummy.list
 create mode 100644 demo/memnet/external_memory.py
 create mode 100644 demo/memnet/external_memory_example.conf
 create mode 100755 demo/memnet/test.sh
 create mode 100755 demo/memnet/train.sh

diff --git a/demo/memnet/README.md b/demo/memnet/README.md
new file mode 100644
index 0000000000000..2e05f9990ca7e
--- /dev/null
+++ b/demo/memnet/README.md
@@ -0,0 +1,21 @@
+# Memory Network
+
+## Introduction ##
+This demo provides a simple example of using an external memory, similar in spirit to the Neural Turing Machine (NTM), with content-based addressing and differentiable read and write heads.
+For more technical details, please refer to the [NTM paper](https://arxiv.org/abs/1410.5401).
+
+## Task Description ##
+Here we design a simple task for illustration purposes. The input is a sequence consisting of a variable number of zeros followed by a variable number of non-zero elements, e.g., [0, 0, 0, 3, 1, 5, ...]. The task is to memorize the first non-zero number (e.g., 3) and to output it at the end, after going through the whole sequence.
+
+## Folder Structure ##
+* external_memory.py: the implementation of the external memory class.
+* external_memory_example.conf: an example configuration that uses the external memory class.
+* data_provider_mem.py: generates the training and testing data for the example.
+* train.sh and test.sh: the scripts to run training and testing.
+
+## How to Run ##
+* training: ./train.sh
+* testing: ./test.sh
+
+
+
diff --git a/demo/memnet/data_provider_mem.py b/demo/memnet/data_provider_mem.py
new file mode 100644
index 0000000000000..c4167d6c15eb2
--- /dev/null
+++ b/demo/memnet/data_provider_mem.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2015 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
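+#
+# Synthesizes the data for the task described in README.md. With the
+# default parameters below, a generated sample looks like, e.g.:
+#     input_sequence = [0, 0, 0, 3, 1, 5], ground_truth = 3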
+
+from paddle.trainer.PyDataProvider2 import *
+import numpy as np
+
+########################### Parameters for Data Generation #################
+gen_range = 8 # same as the size of the dictionary
+#--------------- parameters for generating training data -------------------
+# each sequence has an all-zero sub-sequence at the beginning, followed by a
+# non-zero sub-sequence: seq = [zero_sub_seq, non_zero_sub_seq]
+
+# parameters for non_zero_sub_seq
+seq_len = 10 # length of the non_zero_sub_seq
+seq_len_min = 2 # minimum length if is_fixed_len is False;
+                # seq_len will be used as the maximum length in this case,
+                # i.e., the length will be sampled from [seq_len_min, seq_len]
+# parameters for zero_sub_seq
+seq_len_pre = 10
+seq_len_pre_min = 2
+# number of training data
+sample_num = 1000
+
+# -------------- parameters for generating testing data --------------------
+seq_len_test = 10
+seq_len_min_test = 3
+seq_len_pre_test = 10
+seq_len_pre_test_min = 2
+sample_num_test = 1
+
+
+seq_len = max(seq_len, seq_len_min)
+
+def gen_data(sample_number, gen_range, seq_len, seq_len_min, seq_len_pre, seq_len_pre_min, is_fixed_len = True):
+    data = []
+
+    # fixed lengths by default; overridden per sample if is_fixed_len is False
+    seq_len_actual, seq_len_actual_pre = seq_len, seq_len_pre
+
+    for i in range(0, sample_number):
+        sample = []
+        if not is_fixed_len:
+            seq_len_actual = np.random.randint(seq_len_min, seq_len + 1)
+            seq_len_actual_pre = np.random.randint(seq_len_pre_min, seq_len_pre + 1)
+        sample0 = np.random.randint(1, gen_range, size=seq_len_actual)
+        sample_pre = np.zeros(seq_len_actual_pre)
+        sample_pre = sample_pre.astype(int)
+        sample = np.concatenate([sample_pre, sample0])
+        data.append([sample.tolist(), sample0[0]])
+
+    return data
+
+def gen_data_prefix(sample_number, gen_range, seq_len, seq_len_min, seq_len_pre, is_fixed_len = True):
+    data = []
+
+    if is_fixed_len:
+        seq_len_actual = seq_len
+
+    for i in range(0, sample_number):
+        sample = []
+        if not is_fixed_len:
+            seq_len_actual = np.random.randint(seq_len)+1
+            seq_len_actual = max(seq_len_actual, seq_len_min)
+        sample = np.random.randint(gen_range, size=seq_len_actual)
+        data.append([sample.tolist(), sample[1]])
+
+    return data
+
+
+data = gen_data(sample_num, gen_range, seq_len, seq_len_min, seq_len_pre, seq_len_pre_min, False)
+data_test = gen_data(sample_num_test, gen_range, seq_len_test, seq_len_min_test, seq_len_pre_test, seq_len_pre_test_min, False)
+
+
+@provider(input_types={"input_sequence" : integer_value_sequence(gen_range),
+                       "ground_truth": integer_value(gen_range)})
+def process_seq_train(settings, file_name):
+    for d in data:
+        yield {"input_sequence": d[0], 'ground_truth': d[1]}
+
+
+@provider(input_types={"input_sequence" : integer_value_sequence(gen_range),
+                       "ground_truth": integer_value(gen_range)})
+def process_seq_test(settings, file_name):
+    for d in data_test:
+        yield {"input_sequence": d[0], 'ground_truth': d[1]}
diff --git a/demo/memnet/dummy.list b/demo/memnet/dummy.list
new file mode 100644
index 0000000000000..0e52665e11298
--- /dev/null
+++ b/demo/memnet/dummy.list
@@ -0,0 +1 @@
+dummy_file_no_use
diff --git a/demo/memnet/external_memory.py b/demo/memnet/external_memory.py
new file mode 100644
index 0000000000000..3141e52fd8cf1
--- /dev/null
+++ b/demo/memnet/external_memory.py
@@ -0,0 +1,119 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+
+class ExternalMemory(object):
+    def __init__(self, name, mem_slot_size, mem_fea_size, ext_mem_initial, is_test=False, scale=5):
+        self.name = name
+        self.mem_slot_size = mem_slot_size
+        self.mem_fea_size = mem_fea_size
+        self.scale = 5
+        self.external_memory = memory(name=self.name,
+                                      size=mem_fea_size*mem_slot_size,
+                                      boot_bias= ParamAttr(initial_std=0.01,
+                                                           initial_mean=0.))
+        self.is_test = is_test
+
+    def read(self, read_key):
+        cosine_similarity_read = cos_sim(read_key, self.external_memory, scale=self.scale, size=self.mem_slot_size)
+        norm_cosine_similarity_read = mixed_layer(input=
+            identity_projection(cosine_similarity_read),
+            bias_attr = False,
+            act = SoftmaxActivation(),
+            size = self.mem_slot_size,
+            name='read_weight')
+
+        memory_read = linear_comb_layer(weights=norm_cosine_similarity_read,
+            vectors=self.external_memory,
+            size=self.mem_fea_size, name='read_content')
+
+        if self.is_test:
+            print_layer(input=[norm_cosine_similarity_read, memory_read])
+
+        return memory_read
+
+    def write(self, write_key):
+        cosine_similarity_write = cos_sim(write_key, self.external_memory,
+            scale=self.scale, size=self.mem_slot_size)
+        norm_cosine_similarity_write = mixed_layer(input=
+            identity_projection(cosine_similarity_write),
+            bias_attr = False,
+            act = SoftmaxActivation(),
+            size = self.mem_slot_size,
+            name='write_weight')
+        if self.is_test:
+            print_layer(input=[norm_cosine_similarity_write])
+
+        add_vec = mixed_layer(input = full_matrix_projection(write_key),
+            bias_attr = None,
+            act = SoftmaxActivation(),
+            size = self.mem_fea_size,
+            name='add_vector')
+
+        erase_vec = self.MakeConstantVector(self.mem_fea_size, 1.0, write_key)
+
+
+        if self.is_test:
+            print_layer(input=[erase_vec])
+            print_layer(input=[add_vec])
+
+        out_prod = out_prod_layer(norm_cosine_similarity_write, erase_vec, name="outer")
+
+        memory_remove = mixed_layer(input=dotmul_operator(a=self.external_memory, b=out_prod))
+
+        memory_remove_neg = slope_intercept_layer(input=memory_remove, slope=-1.0, intercept=0)
+
+        # memory_updated = memory_mat - memory_remove = memory_mat + memory_remove_neg
+        memory_removed = mixed_layer(input = [identity_projection(input=self.external_memory),
+            identity_projection(input=memory_remove_neg)],
+            bias_attr = False,
+            act = LinearActivation())
+
+        out_prod_add = out_prod_layer(norm_cosine_similarity_write, add_vec, name="outer_add")
+
+        memory_output = mixed_layer(input = [identity_projection(input=memory_removed),
+            identity_projection(input=out_prod_add)],
+            bias_attr = False,
+            act = LinearActivation(),
+            name=self.name)
+        if self.is_test:
+            print_layer(input=[memory_output])
+
+        return memory_output
+
+    def MakeConstantVector(self, vec_size, value, dummy_input):
+        constant_scalar = mixed_layer(input=full_matrix_projection(input=dummy_input,
+            param_attr = ParamAttr(learning_rate = 0,
+                initial_mean = 0,
+                initial_std = 0)),
+            bias_attr = ParamAttr(initial_mean=value,
+                initial_std=0.0,
+                learning_rate=0),
+            act = LinearActivation(),
+            size = 1,
+            name = 'constant_scalar')
+        constant = mixed_layer(input=full_matrix_projection(input=constant_scalar,
+            param_attr=ParamAttr(learning_rate = 0,
+                initial_mean = 1,
+                initial_std = 0)),
+            bias_attr = False,
+            act = LinearActivation(),
+            size = vec_size,
+            name = 'constant_vector')
+        return constant
+
+
diff --git a/demo/memnet/external_memory_example.conf b/demo/memnet/external_memory_example.conf
new file mode 100644
index 0000000000000..f4b410cdb0ed2
--- /dev/null
+++ b/demo/memnet/external_memory_example.conf
@@ -0,0 +1,120 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+from external_memory import *
+
+######################### parameters ###############################
+is_test = get_config_arg('is_test', bool, False)
+dict_dim = get_config_arg('dict_size', int, 8) # size of the dictionary
+label_dim = dict_dim # the prediction has the same range as the input
+word_embedding_dim = get_config_arg('word_emb_dim', int, 6) # dimension of the embedding
+mem_fea_size = get_config_arg('mem_fea_size', int, 6) # size of each memory slot
+mem_slot_size = get_config_arg('mem_slot_size', int, 5) # number of memory slots
+controller_signal_dim = get_config_arg('controller_signal_dim', int, 4) # dimension of the controller signal
+
+
+######################## data source ################################
+if not is_test:
+    define_py_data_sources2(train_list='dummy.list',
+                            test_list=None,
+                            module='data_provider_mem',
+                            obj='process_seq_train')
+else:
+    define_py_data_sources2(train_list=None,
+                            test_list='dummy.list',
+                            module='data_provider_mem',
+                            obj='process_seq_test')
+
+
+settings(
+    batch_size=10,
+    learning_method=AdamOptimizer(),
+    learning_rate=1e-3)
+
+
+######################## network configure ################################
+data = data_layer(name="input_sequence", size=dict_dim)
+gt_label = data_layer(name="ground_truth", size=label_dim)
+
+
+
+emb = embedding_layer(input=data, size=word_embedding_dim)
+
+def step_mem(y):
+    external_memory = ExternalMemory('external_memory', mem_slot_size, mem_fea_size, is_test)
+    rnn_memory = memory(name="rnn_memory",
+                        size=controller_signal_dim,
+                        boot_bias= ParamAttr(initial_std=0.0,
+                                             initial_mean=0.))
+    rnn_mem_out = mixed_layer(input = [full_matrix_projection(y),
+                                       full_matrix_projection(rnn_memory)],
+                              bias_attr = None,
+                              act = LinearActivation(),
+                              name='rnn_memory',
+                              size = controller_signal_dim)
+
+    control_signal = mixed_layer(input = [full_matrix_projection(y),
+                                          full_matrix_projection(rnn_mem_out)],
+                                 bias_attr = None,
+                                 act = LinearActivation(),
+                                 name = 'control_signal',
+                                 size = controller_signal_dim)
+    read_key = mixed_layer(input = [full_matrix_projection(y),
+                                    full_matrix_projection(control_signal)],
+                           bias_attr = None,
+                           act = LinearActivation(),
+                           size = mem_fea_size)
+    memory_read = external_memory.read(read_key)
+    write_key = mixed_layer(input = [full_matrix_projection(y),
+                                     full_matrix_projection(control_signal)],
+                            bias_attr = None,
+                            act = LinearActivation(),
+                            size = mem_fea_size)
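+    # write() is called for its side effect: it produces the layer named
+    # 'external_memory', which the memory() created inside ExternalMemory
+    # reads back at the next time step; its return value is not needed here.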
+    memory_out = external_memory.write(write_key)
+    return memory_read
+
+
+
+out = recurrent_group(
+    name="rnn",
+    step=step_mem,
+    input=[emb])
+
+if not is_test:
+    out = last_seq(input=out)
+
+pred = mixed_layer(input = full_matrix_projection(out),
+                   bias_attr = True,
+                   act = SoftmaxActivation(),
+                   size = label_dim)
+
+
+
+if is_test:
+    pred = last_seq(input=pred)
+    pred_id = maxid_layer(pred, name="prediction")
+    print_layer(input=[data])
+    print_layer(input=[gt_label])
+    print_layer(input=[pred_id])
+
+    cost = cross_entropy(input=pred, label=gt_label, name='cost_cls')
+    outputs(cost)
+else:
+    outputs(classification_cost(input=pred,
+                                label=gt_label))
diff --git a/demo/memnet/test.sh b/demo/memnet/test.sh
new file mode 100755
index 0000000000000..54a77fee40eb7
--- /dev/null
+++ b/demo/memnet/test.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+set -e
+config=./external_memory_example.conf
+log=test.log
+
+# change the following path to the model to be tested
+evaluate_pass="./mem_model/pass-00029"
+
+echo 'evaluating from pass '$evaluate_pass
+model_list=./model.list
+echo $evaluate_pass > $model_list
+
+paddle train \
+  -v=5 \
+  --config=$config \
+  --model_list=$model_list \
+  --job=test \
+  --use_gpu=1 \
+  --config_args=is_test=1 \
+  2>&1 | tee $log
diff --git a/demo/memnet/train.sh b/demo/memnet/train.sh
new file mode 100755
index 0000000000000..2d9927e1f71f6
--- /dev/null
+++ b/demo/memnet/train.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
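+#
+# Trains the external-memory example for 30 passes; model snapshots are
+# written to ./mem_model (consumed afterwards by ./test.sh).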
+
+set -e
+config=external_memory_example.conf
+output=./mem_model
+log=train.log
+
+paddle train \
+--config_args=is_test=0 \
+--config=$config \
+--dot_period=10 \
+--log_period=100 \
+--test_all_data_in_one_period=1 \
+--use_gpu=1 \
+--trainer_count=1 \
+--num_passes=30 \
+--save_dir=$output \
+2>&1 | tee $log

From 1be184e4b2af8d9f6fb2417cb16d28d67c75e1a9 Mon Sep 17 00:00:00 2001
From: zhanghaichao
Date: Fri, 2 Dec 2016 16:26:05 -0800
Subject: [PATCH 2/2] updated the memnet demo with added comments and updated implementations

---
 demo/memnet/data_provider_mem.py |  2 +-
 demo/memnet/external_memory.py   | 90 ++++++++++++++++++++++----------
 2 files changed, 62 insertions(+), 30 deletions(-)

diff --git a/demo/memnet/data_provider_mem.py b/demo/memnet/data_provider_mem.py
index c4167d6c15eb2..126fee87851bc 100644
--- a/demo/memnet/data_provider_mem.py
+++ b/demo/memnet/data_provider_mem.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2015 Baidu, Inc. All Rights Reserved
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/demo/memnet/external_memory.py b/demo/memnet/external_memory.py
index 3141e52fd8cf1..53329fbf1a2e6 100644
--- a/demo/memnet/external_memory.py
+++ b/demo/memnet/external_memory.py
@@ -17,29 +17,51 @@
 
 
 class ExternalMemory(object):
-    def __init__(self, name, mem_slot_size, mem_fea_size, ext_mem_initial, is_test=False, scale=5):
+    """
+    External memory network class, with differentiable read/write heads.
+
+    :param name: name for the external memory
+    :type name: basestring
+    :param mem_slot_size: number of slots to be used for the external memory
+    :type mem_slot_size: int
+    :param mem_fea_size: size of each memory slot
+    :type mem_fea_size: int
+    :param is_test: flag indicating training (is_test=False) or testing (is_test=True)
+    :type is_test: bool
+    :param scale: a multiplicative factor applied to the read/write weights
+    :type scale: int
+    """
+    def __init__(self, name, mem_slot_size, mem_fea_size, is_test=False, scale=5):
         self.name = name
         self.mem_slot_size = mem_slot_size
         self.mem_fea_size = mem_fea_size
-        self.scale = 5
-        self.external_memory = memory(name=self.name, 
-                                      size=mem_fea_size*mem_slot_size, 
-                                      boot_bias= ParamAttr(initial_std=0.01, 
-                                                           initial_mean=0.)) 
+        self.scale = scale
+        self.external_memory = memory(name=self.name,
+                                      size=mem_fea_size*mem_slot_size,
+                                      boot_bias= ParamAttr(initial_std=0.01,
+                                                           initial_mean=0.))
         self.is_test = is_test
 
     def read(self, read_key):
+        """
+        Read head for the external memory.
+        :param read_key: key used for reading via content-based addressing,
+                         of size mem_fea_size
+        :type read_key: LayerOutput
+        :return: the content read from the memory
+        :rtype: LayerOutput
+        """
         cosine_similarity_read = cos_sim(read_key, self.external_memory, scale=self.scale, size=self.mem_slot_size)
         norm_cosine_similarity_read = mixed_layer(input=
             identity_projection(cosine_similarity_read),
             bias_attr = False,
             act = SoftmaxActivation(),
             size = self.mem_slot_size,
-            name='read_weight')
+            name=self.name+'_read_weight')
 
         memory_read = linear_comb_layer(weights=norm_cosine_similarity_read,
             vectors=self.external_memory,
-            size=self.mem_fea_size, name='read_content')
+            size=self.mem_fea_size, name=self.name+'_read_content')
 
         if self.is_test:
             print_layer(input=[norm_cosine_similarity_read, memory_read])
@@ -47,6 +69,14 @@ def read(self, read_key):
         return memory_read
 
     def write(self, write_key):
+        """
+        Write head for the external memory.
+        :param write_key: the key (and content) used for writing via content-based addressing,
+                          of size mem_fea_size
+        :type write_key: LayerOutput
+        :return: the updated memory content
+        :rtype: LayerOutput
+        """
         cosine_similarity_write = cos_sim(write_key, self.external_memory,
             scale=self.scale, size=self.mem_slot_size)
         norm_cosine_similarity_write = mixed_layer(input=
@@ -54,7 +84,7 @@ def write(self, write_key):
             bias_attr = False,
             act = SoftmaxActivation(),
             size = self.mem_slot_size,
-            name='write_weight')
+            name=self.name+'_write_weight')
         if self.is_test:
             print_layer(input=[norm_cosine_similarity_write])
 
@@ -62,40 +92,42 @@ def write(self, write_key):
             bias_attr = None,
             act = SoftmaxActivation(),
             size = self.mem_fea_size,
-            name='add_vector')
-
-        erase_vec = self.MakeConstantVector(self.mem_fea_size, 1.0, write_key)
+            name=self.name+'_add_vector')
 
+        erase_vec = self.make_constant_vector(self.mem_fea_size, 1.0, write_key, self.name+"_constant_vector")
         if self.is_test:
             print_layer(input=[erase_vec])
             print_layer(input=[add_vec])
 
-        out_prod = out_prod_layer(norm_cosine_similarity_write, erase_vec, name="outer")
+        out_prod = out_prod_layer(norm_cosine_similarity_write, erase_vec, name=self.name+"_outer")
 
         memory_remove = mixed_layer(input=dotmul_operator(a=self.external_memory, b=out_prod))
 
-        memory_remove_neg = slope_intercept_layer(input=memory_remove, slope=-1.0, intercept=0)
-
-        # memory_updated = memory_mat - memory_remove = memory_mat + memory_remove_neg
-        memory_removed = mixed_layer(input = [identity_projection(input=self.external_memory),
-            identity_projection(input=memory_remove_neg)],
-            bias_attr = False,
-            act = LinearActivation())
+        memory_removed = self.external_memory - memory_remove
 
-        out_prod_add = out_prod_layer(norm_cosine_similarity_write, add_vec, name="outer_add")
+        out_prod_add = out_prod_layer(norm_cosine_similarity_write, add_vec, name=self.name+"_outer_add")
+        memory_output = addto_layer(input=[memory_removed, out_prod_add], name=self.name)
 
-        memory_output = mixed_layer(input = [identity_projection(input=memory_removed),
-            identity_projection(input=out_prod_add)],
-            bias_attr = False,
-            act = LinearActivation(),
-            name=self.name)
         if self.is_test:
             print_layer(input=[memory_output])
 
         return memory_output
 
-    def MakeConstantVector(self, vec_size, value, dummy_input):
+    def make_constant_vector(self, vec_size, value, dummy_input, layer_name):
+        """
+        Auxiliary function for generating a constant vector.
+        :param vec_size: the size of the constant vector
+        :type vec_size: int
+        :param value: value of the elements in the constant vector
+        :type value: float
+        :param dummy_input: a dummy input layer to the constant vector network
+        :type dummy_input: LayerOutput
+        :param layer_name: name for the constant vector
+        :type layer_name: basestring
+        :return: the constant vector
+        :rtype: LayerOutput
+        """
         constant_scalar = mixed_layer(input=full_matrix_projection(input=dummy_input,
             param_attr = ParamAttr(learning_rate = 0,
                 initial_mean = 0,
                 initial_std = 0)),
@@ -105,15 +137,15 @@ def write(self, write_key):
             bias_attr = ParamAttr(initial_mean=value,
                 initial_std=0.0,
                 learning_rate=0),
             act = LinearActivation(),
             size = 1,
-            name = 'constant_scalar')
+            name = layer_name+'_constant_scalar')
         constant = mixed_layer(input=full_matrix_projection(input=constant_scalar,
             param_attr=ParamAttr(learning_rate = 0,
                 initial_mean = 1,
                 initial_std = 0)),
             bias_attr = False,
             act = LinearActivation(),
             size = vec_size,
-            name = 'constant_vector')
+            name = layer_name)
         return constant
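
Usage sketch (illustrative, not part of the patches): the recurrent step built in
external_memory_example.conf boils down to the pattern below. The sizes are the
config's defaults; `emb` stands for the embedding sequence fed to recurrent_group,
and in the actual config the read and write keys are separate learned projections
of the input and a controller signal rather than a single shared key.

    from paddle.trainer_config_helpers import *
    from external_memory import ExternalMemory

    def step_mem(y):
        # One ExternalMemory per step; the memory() layer it creates carries
        # the slot matrix across time steps under the name 'external_memory'.
        ext_mem = ExternalMemory('external_memory', mem_slot_size=5,
                                 mem_fea_size=6)
        key = mixed_layer(input=full_matrix_projection(y),
                          act=LinearActivation(), size=6)  # addressing key
        content = ext_mem.read(key)   # content-based, differentiable read
        ext_mem.write(key)            # write head updates 'external_memory'
        return content

    out = recurrent_group(name="rnn", step=step_mem, input=[emb])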