Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add external memory network demo #696

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 21 additions & 0 deletions demo/memnet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Memory Network

## Introduction ##
This demo provides a simple example usage of the external memory in a way similar to the Neural Turing Machine (NTM) with content based addressing and differentiable read and write head.
For more technical details, please refer to the [NTM paper](https://arxiv.org/abs/1410.5401).

## Task Description ##
Here we design a simple task for illustration purpose. The input is a sequence with variable number of zeros followed with a variable number of non-zero elements, e.g., [0, 0, 0, 3, 1, 5, ...]. The task is to memorize the first non-zero number (e.g., 3) and to output this number in the end after going through the whole sequence.

## Folder Structure ##
* external_memory.py: the implementation of the external memory class.
* external_memory_example.conf: example usage of the external memory class.
* data_provider_mem.py: generates the training and testing data for the example.
* train.sh and test.sh: the scripts to run training and testing.

## How to Run ##
* training: ./train.sh
* testing: ./test.sh



96 changes: 96 additions & 0 deletions demo/memnet/data_provider_mem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *
import numpy as np

########################### Parameters for Data Generation #################
gen_range = 8 # same as the size of the dictionary
#--------------- parameters for generating training data -------------------
# the sequence has a all-zero sub-vector in the beginning followed with a non-zero vector
# seq = [zero_sub_seq, non_zero_sub_seq]

# parameters for non_zero_sub_seq
seq_len = 10 # length of the non_zero_sub_seq
seq_len_min = 2 # minimum length if is_fixed_len is False;
# seq_len will be used as the maximum length in this case,
# i.e., the length will be sampled from [seq_len_min, seq_len]
# parameters for zero_sub_seq
seq_len_pre = 10
seq_len_pre_min = 2
# number of training data
sample_num = 1000

# -------------- parameters for generating testing data --------------------
seq_len_test = 10
seq_len_min_test = 3
seq_len_pre_test = 10
seq_len_pre_test_min = 2
sample_num_test = 1


seq_len = max(seq_len, seq_len_min)

def gen_data(sample_number, gen_range, seq_len, seq_len_min, seq_len_pre, seq_len_pre_min, is_fixed_len = True):
data = []

if is_fixed_len:
seq_len_actual = seq_len

for i in range(0, sample_number):
sample = []
if not is_fixed_len:
seq_len_actual = np.random.randint(seq_len_min, seq_len)
seq_len_actual_pre = np.random.randint(seq_len_pre_min, seq_len_pre)
sample0 = np.random.randint(1, gen_range, size=seq_len_actual)
sample_pre = np.zeros(seq_len_actual_pre)
sample_pre = sample_pre.astype(int)
sample = np.concatenate([sample_pre, sample0])
data.append([sample.tolist(), sample0[0]])

return data

def gen_data_prefix(sample_number, gen_range, seq_len, seq_len_min, seq_len_pre, is_fixed_len = True):
data = []

if is_fixed_len:
seq_len_actual = seq_len

for i in range(0, sample_number):
sample = []
if not is_fixed_len:
seq_len_actual = np.random.randint(seq_len)+1
seq_len_actual = max(seq_len_actual, seq_len_min)
sample = np.random.randint(gen_range, size=seq_len_actual)
data.append([sample.tolist(), sample[1]])

return data


data = gen_data(sample_num, gen_range, seq_len, seq_len_min, seq_len_pre, seq_len_pre_min, False)
data_test = gen_data(sample_num_test, gen_range, seq_len_test, seq_len_min_test, seq_len_pre_test, seq_len_pre_test_min, False)


@provider(input_types={"input_sequence" : integer_value_sequence(gen_range+1),
"ground_truth": integer_value(gen_range+1)})
def process_seq_train(settings, file_name):
for d in data:
yield {"input_sequence": d[0], 'ground_truth': d[1]}


@provider(input_types={"input_sequence" : integer_value_sequence(gen_range+1),
"ground_truth": integer_value(gen_range+1)})
def process_seq_test(settings, file_name):
for d in data_test:
yield {"input_sequence": d[0], 'ground_truth': d[1]}
1 change: 1 addition & 0 deletions demo/memnet/dummy.list
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
dummy_file_no_use
151 changes: 151 additions & 0 deletions demo/memnet/external_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *


class ExternalMemory(object):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to comment for the class, and its member functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments have been added to both the class and member functions.

"""
External memory network class, with differentiable read/write heads.

:param name: name for the external memory
:type name: basestring
:param mem_slot_size: number of slots to be used for the external memory
:type mem_slot_size: int
:param mem_fea_size: size of each memory slot
:type mem_fea_size: int
:param is_test: flag indicating training (is_test=False) or testing (is_test=True)
:type is_test: bool
:param scale: a multiplicative factor applied to the read/write weights
:param scale: int
"""
def __init__(self, name, mem_slot_size, mem_fea_size, is_test=False, scale=5):
self.name = name
self.mem_slot_size = mem_slot_size
self.mem_fea_size = mem_fea_size
self.scale = scale
self.external_memory = memory(name=self.name,
size=mem_fea_size*mem_slot_size,
boot_bias= ParamAttr(initial_std=0.01,
initial_mean=0.))
self.is_test = is_test

def read(self, read_key):
"""
Read head for the external memory.
:param read_key: key used for reading via content-based addressing,
with size as mem_fea_size
:type read_key: LayerOutput
:return: memory_read
:rtype: LayerOutput
"""
cosine_similarity_read = cos_sim(read_key, self.external_memory, scale=self.scale, size=self.mem_slot_size)
norm_cosine_similarity_read = mixed_layer(input=
identity_projection(cosine_similarity_read),
bias_attr = False,
act = SoftmaxActivation(),
size = self.mem_slot_size,
name=self.name+'_read_weight')

memory_read = linear_comb_layer(weights=norm_cosine_similarity_read,
vectors=self.external_memory,
size=self.mem_fea_size, name=self.name+'_read_content')

if self.is_test:
print_layer(input=[norm_cosine_similarity_read, memory_read])

return memory_read

def write(self, write_key):
"""
Write head for the external memory.
:param write_key: the key (and content) used for writing via content-based addressing,
with size as mem_fea_size
:type write_key: LayerOutput
:return: updated memory content
:rtype: LayerOutput
"""
cosine_similarity_write = cos_sim(write_key, self.external_memory,
scale=self.scale, size=self.mem_slot_size)
norm_cosine_similarity_write = mixed_layer(input=
identity_projection(cosine_similarity_write),
bias_attr = False,
act = SoftmaxActivation(),
size = self.mem_slot_size,
name=self.name+'_write_weight')
if self.is_test:
print_layer(input=[norm_cosine_similarity_write])

add_vec = mixed_layer(input = full_matrix_projection(write_key),
bias_attr = None,
act = SoftmaxActivation(),
size = self.mem_fea_size,
name=self.name+'_add_vector')

erase_vec = self.make_constant_vector(self.mem_fea_size, 1.0, write_key, self.name+"_constant_vector")

if self.is_test:
print_layer(input=[erase_vec])
print_layer(input=[add_vec])

out_prod = out_prod_layer(norm_cosine_similarity_write, erase_vec, name=self.name+"_outer")

memory_remove = mixed_layer(input=dotmul_operator(a=self.external_memory, b=out_prod))

memory_removed = self.external_memory - memory_remove

out_prod_add = out_prod_layer(norm_cosine_similarity_write, add_vec, name=self.name+"_outer_add")
memory_output = addto_layer(input=[memory_removed, out_prod_add], name=self.name)

if self.is_test:
print_layer(input=[memory_output])

return memory_output

def make_constant_vector(self, vec_size, value, dummy_input, layer_name):
"""
Auxiliary function for generating a constant vector.
:param vec_size: the size of the constant vector
:type vec_size: int
:param value: value of the elements in the constant vector
:type value: float
:param dummy_input: a dummy input layer to the constant vector network
:type LayerOutput
:param layer_name: name for the constant vector
:type layer_name: basestring
:return: memory_read
:rtype: LayerOutput
"""
constant_scalar = mixed_layer(input=full_matrix_projection(input=dummy_input,
param_attr = ParamAttr(learning_rate = 0,
initial_mean = 0,
initial_std = 0)),
bias_attr = ParamAttr(initial_mean=value,
initial_std=0.0,
learning_rate=0),
act = LinearActivation(),
size = 1,
name = layer_name+'_constant_scalar')
constant = mixed_layer(input=full_matrix_projection(input=constant_scalar,
param_attr=ParamAttr(learning_rate = 0,
initial_mean = 1,
initial_std = 0)),
bias_attr = False,
act = LinearActivation(),
size = vec_size,
name = layer_name)
return constant


117 changes: 117 additions & 0 deletions demo/memnet/external_memory_example.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *
from external_memory import *

######################### parameters ###############################
is_test = get_config_arg('is_test', bool, False)
dict_dim = get_config_arg('dict_size', int, 8) # size of the dictionary
label_dim = dict_dim # the prediction has the same range as the input
word_embedding_dim = get_config_arg('word_emb_dim', int, 6) # dimension of the embedding
mem_fea_size = get_config_arg('size of each memory slot', int, 6)
mem_slot_size = get_config_arg('number of memory slots', int, 5)
controller_signal_dim = get_config_arg('the dim of the controller signal', int, 4)


######################## data source ################################
if not is_test:
define_py_data_sources2(train_list='dummy.list',
test_list=None,
module='data_provider_mem',
obj='process_seq_train')
else:
define_py_data_sources2(train_list=None,
test_list='dummy.list',
module='data_provider_mem',
obj='process_seq_test')


settings(
batch_size=10,
learning_method=AdamOptimizer(),
learning_rate=1e-3)


######################## network configure ################################
data = data_layer(name="input_sequence", size=dict_dim)
gt_label = data_layer(name="ground_truth", size=label_dim)



emb = embedding_layer(input=data, size=word_embedding_dim)

def step_mem(y):
external_memory = ExternalMemory('external_memory', mem_slot_size, mem_fea_size, False)
rnn_memory = memory(name="rnn_memory",
size=controller_signal_dim,
boot_bias= ParamAttr(initial_std=0.0,
initial_mean=0.))
rnn_mem_out = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(rnn_memory)],
bias_attr = None,
act = LinearActivation(),
name='rnn_memory',
size = controller_signal_dim)

control_signal = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(rnn_mem_out)],
bias_attr = None,
act = LinearActivation(),
name = 'control_signal',
size = controller_signal_dim)
read_key = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(control_signal)],
bias_attr = None,
act = LinearActivation(),
size = mem_fea_size)
memory_read = external_memory.read(read_key)
write_key = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(control_signal)],
bias_attr = None,
act = LinearActivation(),
size = mem_fea_size)
memory_out = external_memory.write(write_key)
return memory_read



out = recurrent_group(
name="rnn",
step=step_mem,
input=[emb])

if not is_test:
out = last_seq(input=out)

pred = mixed_layer(input = full_matrix_projection(out),
bias_attr = True,
act = SoftmaxActivation(),
size = label_dim)



if is_test:
pred = last_seq(input=pred)
pred_id = maxid_layer(pred, name="prediction")
print_layer(input=[data])
print_layer(input=[gt_label])
print_layer(input=[pred_id])

cost = cross_entropy(input=pred, label=gt_label, name='cost_cls')
outputs(cost)
else:
outputs(classification_cost(input=pred,
label=gt_label))