add external memory network demo #696

Closed
Changes from 1 commit
21 changes: 21 additions & 0 deletions demo/memnet/README.md
@@ -0,0 +1,21 @@
# Memory Network

## Introduction ##
This demo provides a simple example of using external memory in a way similar to the Neural Turing Machine (NTM), with content-based addressing and differentiable read and write heads.
For more technical details, please refer to the [NTM paper](https://arxiv.org/abs/1410.5401).
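
For reference, the content-based read and write used in this demo can be summarized by the standard NTM equations (a sketch; in this demo the erase vector is a constant all-ones vector, see external_memory.py):

$$w_i = \mathrm{softmax}_i\big(\beta \, \cos(k, M_i)\big), \qquad r = \sum_i w_i M_i$$

$$M_i \leftarrow M_i \odot (1 - w_i e) + w_i a$$

where $M_i$ is the $i$-th memory slot, $k$ the key emitted by the controller, $\beta$ the sharpening scale (`scale` in the code), $e$ the erase vector, and $a$ the add vector.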

## Task Description ##
Here we design a simple task for illustration purposes. The input is a sequence with a variable number of zeros followed by a variable number of non-zero elements, e.g., [0, 0, 0, 3, 1, 5, ...]. The task is to memorize the first non-zero number (e.g., 3) and to output this number at the end, after going through the whole sequence.
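
A minimal sketch of how one such training pair is generated (mirroring the defaults in data_provider_mem.py below):

```python
import numpy as np

gen_range = 8                                 # dictionary size; values drawn from 1..7
zeros = np.zeros(np.random.randint(2, 10), dtype=int)                  # variable-length zero prefix
body = np.random.randint(1, gen_range, size=np.random.randint(2, 10))  # non-zero part
seq = np.concatenate([zeros, body]).tolist()  # e.g., [0, 0, 0, 3, 1, 5, ...]
label = int(body[0])                          # the first non-zero element, e.g., 3
```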

## Folder Structure ##
* external_memory.py: the implementation of the external memory class.
* external_memory_example.conf: example usage of the external memory class.
* data_provider_mem.py: generates the training and testing data for the example.
* train.sh and test.sh: the scripts to run training and testing.

## How to Run ##
* training: ./train.sh (train.sh is not shown in this diff; a sketch of a typical invocation follows below)
* testing: ./test.sh
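
By analogy with test.sh below, a typical training invocation might look like the following sketch (the save_dir path and the number of passes are assumptions, not taken from this PR):

```bash
#!/bin/bash
set -e
config=./external_memory_example.conf
log=train.log

paddle train \
  --config=$config \
  --save_dir=./mem_model \
  --use_gpu=1 \
  --num_passes=100 \
  2>&1 | tee $log
```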



96 changes: 96 additions & 0 deletions demo/memnet/data_provider_mem.py
@@ -0,0 +1,96 @@
# Copyright (c) 2015 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer.PyDataProvider2 import *
import numpy as np

########################### Parameters for Data Generation #################
gen_range = 8 # same as the size of the dictionary
#--------------- parameters for generating training data -------------------
# each sequence has an all-zero sub-sequence at the beginning, followed by a non-zero sub-sequence
# seq = [zero_sub_seq, non_zero_sub_seq]

# parameters for non_zero_sub_seq
seq_len = 10 # length of the non_zero_sub_seq
seq_len_min = 2 # minimum length if is_fixed_len is False;
# seq_len will be used as the maximum length in this case,
# i.e., the length will be sampled from [seq_len_min, seq_len) (np.random.randint excludes the upper bound)
# parameters for zero_sub_seq
seq_len_pre = 10
seq_len_pre_min = 2
# number of training data
sample_num = 1000

# -------------- parameters for generating testing data --------------------
seq_len_test = 10
seq_len_min_test = 3
seq_len_pre_test = 10
seq_len_pre_test_min = 2
sample_num_test = 1


seq_len = max(seq_len, seq_len_min)

def gen_data(sample_number, gen_range, seq_len, seq_len_min, seq_len_pre, seq_len_pre_min, is_fixed_len = True):
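    """Generate sample_number sequences, each an all-zero prefix followed by
    non-zero values drawn from [1, gen_range); the label is the first non-zero
    element. If is_fixed_len is False, the non-zero length is drawn from
    [seq_len_min, seq_len) and the prefix length from
    [seq_len_pre_min, seq_len_pre); np.random.randint excludes the upper bound.
    Returns a list of [sequence, label] pairs.
    """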
data = []

if is_fixed_len:
seq_len_actual = seq_len

for i in range(0, sample_number):
sample = []
if not is_fixed_len:
seq_len_actual = np.random.randint(seq_len_min, seq_len)
seq_len_actual_pre = np.random.randint(seq_len_pre_min, seq_len_pre)
sample0 = np.random.randint(1, gen_range, size=seq_len_actual)
sample_pre = np.zeros(seq_len_actual_pre)
sample_pre = sample_pre.astype(int)
sample = np.concatenate([sample_pre, sample0])
data.append([sample.tolist(), sample0[0]])

return data

def gen_data_prefix(sample_number, gen_range, seq_len, seq_len_min, seq_len_pre, is_fixed_len = True):
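    """Generate sequences without a zero prefix. Note: the seq_len_pre
    argument is unused, the label is sample[1] (the second element), and this
    helper is not used by the data providers below.
    """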
data = []

if is_fixed_len:
seq_len_actual = seq_len

for i in range(0, sample_number):
sample = []
if not is_fixed_len:
seq_len_actual = np.random.randint(seq_len)+1
seq_len_actual = max(seq_len_actual, seq_len_min)
sample = np.random.randint(gen_range, size=seq_len_actual)
data.append([sample.tolist(), sample[1]])

return data


data = gen_data(sample_num, gen_range, seq_len, seq_len_min, seq_len_pre, seq_len_pre_min, False)
data_test = gen_data(sample_num_test, gen_range, seq_len_test, seq_len_min_test, seq_len_pre_test, seq_len_pre_test_min, False)


@provider(input_types={"input_sequence" : integer_value_sequence(gen_range+1),
"ground_truth": integer_value(gen_range+1)})
def process_seq_train(settings, file_name):
for d in data:
yield {"input_sequence": d[0], 'ground_truth': d[1]}


@provider(input_types={"input_sequence" : integer_value_sequence(gen_range+1),
"ground_truth": integer_value(gen_range+1)})
def process_seq_test(settings, file_name):
for d in data_test:
yield {"input_sequence": d[0], 'ground_truth': d[1]}
1 change: 1 addition & 0 deletions demo/memnet/dummy.list
@@ -0,0 +1 @@
dummy_file_no_use
119 changes: 119 additions & 0 deletions demo/memnet/external_memory.py
@@ -0,0 +1,119 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *


class ExternalMemory(object):
Collaborator: Comments are needed for the class and its member functions.

Contributor Author: Comments have been added to both the class and member functions.

def __init__(self, name, mem_slot_size, mem_fea_size, ext_mem_initial, is_test=False, scale=5):
self.name = name
self.mem_slot_size = mem_slot_size
self.mem_fea_size = mem_fea_size
self.scale = 5
Collaborator: This should be `self.scale = scale`.

Contributor Author: This has been corrected.

self.external_memory = memory(name=self.name,
size=mem_fea_size*mem_slot_size,
boot_bias= ParamAttr(initial_std=0.01,
Collaborator: Bad indent.

Contributor Author: This has been updated.

initial_mean=0.))
self.is_test = is_test

def read(self, read_key):
cosine_similarity_read = cos_sim(read_key, self.external_memory, scale=self.scale, size=self.mem_slot_size)
norm_cosine_similarity_read = mixed_layer(input=
identity_projection(cosine_similarity_read),
bias_attr = False,
act = SoftmaxActivation(),
size = self.mem_slot_size,
name='read_weight')
Collaborator: To avoid name conflicts when using multiple memories, this and other layer names should be prefixed with self.name.

Contributor Author: Similar issues have been addressed.


memory_read = linear_comb_layer(weights=norm_cosine_similarity_read,
vectors=self.external_memory,
size=self.mem_fea_size, name='read_content')

if self.is_test:
print_layer(input=[norm_cosine_similarity_read, memory_read])

return memory_read

def write(self, write_key):
cosine_similarity_write = cos_sim(write_key, self.external_memory,
scale=self.scale, size=self.mem_slot_size)
norm_cosine_similarity_write = mixed_layer(input=
identity_projection(cosine_similarity_write),
bias_attr = False,
act = SoftmaxActivation(),
size = self.mem_slot_size,
name='write_weight')
if self.is_test:
print_layer(input=[norm_cosine_similarity_write])

add_vec = mixed_layer(input = full_matrix_projection(write_key),
bias_attr = None,
act = SoftmaxActivation(),
size = self.mem_fea_size,
name='add_vector')

erase_vec = self.MakeConstantVector(self.mem_fea_size, 1.0, write_key)


if self.is_test:
print_layer(input=[erase_vec])
print_layer(input=[add_vec])

out_prod = out_prod_layer(norm_cosine_similarity_write, erase_vec, name="outer")
Collaborator: Creating a constant vector erase_vec for this is very ugly. A nicer way would be to enhance the "repeat" layer to allow repeating in both directions, similar to "repmat" in MATLAB.

Contributor Author: Looking into the repeat layer currently.


memory_remove = mixed_layer(input=dotmul_operator(a=self.external_memory, b=out_prod))

memory_remove_neg = slope_intercept_layer(input=memory_remove, slope=-1.0, intercept=0)

# memory_updated = memory_mat - memory_remove = memory_mat + memory_remove_neg
memory_removed = mixed_layer(input = [identity_projection(input=self.external_memory),
identity_projection(input=memory_remove_neg)],
bias_attr = False,
act = LinearActivation())
Collaborator: Lines 78 and 81 can be combined and written as: memory_removed = self.external_memory - memory_remove.
See https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/trainer_config_helpers/tests/configs/math_ops.py

Contributor Author: This part of the code has been updated using math_ops.


out_prod_add = out_prod_layer(norm_cosine_similarity_write, add_vec, name="outer_add")

memory_output = mixed_layer(input = [identity_projection(input=memory_removed),
Collaborator: Using addto_layer can make this look simpler.

Contributor Author: Switched to addto_layer.

identity_projection(input=out_prod_add)],
bias_attr = False,
act = LinearActivation(),
name=self.name)
if self.is_test:
print_layer(input=[memory_output])

return memory_output

def MakeConstantVector(self, vec_size, value, dummy_input):
Collaborator: Python naming convention: make_constant_vector.

Contributor Author: Changed the function name following the convention.

constant_scalar = mixed_layer(input=full_matrix_projection(input=dummy_input,
param_attr = ParamAttr(learning_rate = 0,
initial_mean = 0,
initial_std = 0)),
bias_attr = ParamAttr(initial_mean=value,
initial_std=0.0,
learning_rate=0),
act = LinearActivation(),
size = 1,
name = 'constant_scalar')
constant = mixed_layer(input=full_matrix_projection(input=constant_scalar,
param_attr=ParamAttr(learning_rate = 0,
initial_mean = 1,
initial_std = 0)),
bias_attr = False,
act = LinearActivation(),
size = vec_size,
name = 'constant_vector')
return constant
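
The read/write semantics configured above can be summarized in plain NumPy as follows (a minimal sketch, not part of this PR; it assumes the constant all-ones erase vector used by write() and takes the controller-produced keys and add vector as inputs):

```python
import numpy as np

def _address(key, M, scale):
    # content-based addressing: softmax over scaled cosine similarities
    sim = (M @ key) / (np.linalg.norm(M, axis=1) * np.linalg.norm(key) + 1e-8)
    w = np.exp(scale * sim - np.max(scale * sim))
    return w / w.sum()

def read(M, read_key, scale=5.0):
    w = _address(read_key, M, scale)    # (num_slots,) read weights
    return w @ M                        # weighted sum of memory slots

def write(M, write_key, add_vec, scale=5.0):
    w = _address(write_key, M, scale)   # (num_slots,) write weights
    erase = np.ones(M.shape[1])         # constant all-ones erase vector
    M = M * (1.0 - np.outer(w, erase))  # erase step
    return M + np.outer(w, add_vec)     # add step
```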


117 changes: 117 additions & 0 deletions demo/memnet/external_memory_example.conf
@@ -0,0 +1,117 @@
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.trainer_config_helpers import *
from external_memory import *

######################### parameters ###############################
is_test = get_config_arg('is_test', bool, False)
dict_dim = get_config_arg('dict_size', int, 8) # size of the dictionary
label_dim = dict_dim # the prediction has the same range as the input
word_embedding_dim = get_config_arg('word_emb_dim', int, 6) # dimension of the embedding
mem_fea_size = get_config_arg('mem_fea_size', int, 6) # size of each memory slot
mem_slot_size = get_config_arg('mem_slot_size', int, 5) # number of memory slots
controller_signal_dim = get_config_arg('controller_signal_dim', int, 4) # dim of the controller signal


######################## data source ################################
if not is_test:
define_py_data_sources2(train_list='dummy.list',
test_list=None,
module='data_provider_mem',
obj='process_seq_train')
else:
define_py_data_sources2(train_list=None,
test_list='dummy.list',
module='data_provider_mem',
obj='process_seq_test')


settings(
batch_size=10,
learning_method=AdamOptimizer(),
learning_rate=1e-3)


######################## network configure ################################
data = data_layer(name="input_sequence", size=dict_dim)
gt_label = data_layer(name="ground_truth", size=label_dim)



emb = embedding_layer(input=data, size=word_embedding_dim)

def step_mem(y):
external_memory = ExternalMemory('external_memory', mem_slot_size, mem_fea_size, False)
rnn_memory = memory(name="rnn_memory",
size=controller_signal_dim,
boot_bias= ParamAttr(initial_std=0.0,
initial_mean=0.))
rnn_mem_out = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(rnn_memory)],
bias_attr = None,
act = LinearActivation(),
name='rnn_memory',
size = controller_signal_dim)

control_signal = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(rnn_mem_out)],
bias_attr = None,
act = LinearActivation(),
name = 'control_signal',
size = controller_signal_dim)
read_key = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(control_signal)],
bias_attr = None,
act = LinearActivation(),
size = mem_fea_size)
memory_read = external_memory.read(read_key)
write_key = mixed_layer(input = [full_matrix_projection(y),
full_matrix_projection(control_signal)],
bias_attr = None,
act = LinearActivation(),
size = mem_fea_size)
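    # external_memory.write() returns a layer named self.name ('external_memory');
    # the memory() declared inside ExternalMemory boots from that layer at the
    # next time step, which closes the external-memory recurrence.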
memory_out = external_memory.write(write_key)
return memory_read



out = recurrent_group(
name="rnn",
step=step_mem,
input=[emb])

if not is_test:
out = last_seq(input=out)

pred = mixed_layer(input = full_matrix_projection(out),
bias_attr = True,
act = SoftmaxActivation(),
size = label_dim)



if is_test:
pred = last_seq(input=pred)
pred_id = maxid_layer(pred, name="prediction")
print_layer(input=[data])
print_layer(input=[gt_label])
print_layer(input=[pred_id])

cost = cross_entropy(input=pred, label=gt_label, name='cost_cls')
outputs(cost)
else:
outputs(classification_cost(input=pred,
label=gt_label))
33 changes: 33 additions & 0 deletions demo/memnet/test.sh
@@ -0,0 +1,33 @@
#!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
config=./external_memory_example.conf
log=test.log

# change the following path to the model to be tested
evaluate_pass="./mem_model/pass-00030"

echo 'evaluating from pass '$evaluate_pass
model_list=./model.list
echo $evaluate_pass > $model_list

paddle train \
-v=5 \
--config=$config \
--model_list=$model_list \
--job=test \
--use_gpu=1 \
--config_args=is_test=1 \
2>&1 | tee $log