Commit

follow comments and rename the directory.
lcy-seso committed Jun 26, 2017
1 parent 6b0f946 commit a8e4f42
Showing 13 changed files with 52 additions and 17 deletions.
3 changes: 3 additions & 0 deletions generte_sequence_by_rnn_lm/.gitignore
@@ -0,0 +1,3 @@
*.pyc
*.tar.gz
models
File renamed without changes.
@@ -13,22 +13,22 @@

class BeamSearch(object):
"""
- generating sequence by using beam search
+ Generating sequence by beam search
NOTE: this class only implements generating one sentence at a time.
"""

def __init__(self, inferer, word_dict_file, beam_size=1, max_gen_len=100):
"""
constructor method.
- :param inferer: object of paddle.Inference that represent the entire
- network to forward compute the test batch.
+ :param inferer: object of paddle.Inference that represents the entire
+ network to forward compute the test batch
:type inferer: paddle.Inference
:param word_dict_file: path of word dictionary file
:type word_dict_file: str
:param beam_size: expansion width in each iteration
:type param beam_size: int
- :param max_gen_len: the maximum number of iterations.
+ :param max_gen_len: the maximum number of iterations
:type max_gen_len: int
"""
self.inferer = inferer
@@ -43,7 +43,7 @@ def __init__(self, inferer, word_dict_file, beam_size=1, max_gen_len=100):
self.unk_id = next(x[0] for x in self.ids_2_word.iteritems()
if x[1] == "<unk>")
except StopIteration:
- logger.fatal(("the word dictionay must contains an ending mark
+ logger.fatal(("the word dictionay must contain an ending mark
"in the text generation task."))

self.candidate_paths = []
@@ -52,7 +52,7 @@ def __init__(self, inferer, word_dict_file, beam_size=1, max_gen_len=100):
def _top_k(self, softmax_out, k):
"""
get indices of the words with k highest probablities.
- NOTE: <unk> will be exclued if it is among the top k words, then word
+ NOTE: <unk> will be excluded if it is among the top k words, then word
with (k + 1)th highest probability will be returned.
:param softmax_out: probablity over the dictionary
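A minimal sketch of this kind of top-k selection, assuming numpy and a known unk_id; it illustrates the behaviour documented above rather than the module's actual code:

    import numpy as np

    def top_k_excluding_unk(softmax_out, k, unk_id):
        # take the k + 1 highest-probability word ids, best first
        candidates = np.argsort(softmax_out)[::-1][:k + 1]
        # drop <unk> if it appears, then keep the first k ids
        return [i for i in candidates if i != unk_id][:k]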
@@ -71,7 +71,7 @@ def _forward_batch(self, batch):
:params batch: the input data batch
:type batch: list
- :return: probalities of the predicted word
+ :return: probablities of the predicted word
:rtype: ndarray
"""
return self.inferer.infer(input=batch, field=["value"])
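For context, a hedged sketch of how the constructor documented above might be called; the paddle.Inference object and the dictionary path are placeholders and are not part of this diff:

    # rough usage sketch based on the documented constructor signature
    searcher = BeamSearch(
        inferer=inferer,                 # a paddle.Inference over the trained network
        word_dict_file="word_dict.txt",  # hypothetical path to the word dictionary
        beam_size=5,                     # expansion width in each iteration
        max_gen_len=100)                 # maximum number of generation steps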
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -12,12 +12,18 @@ def rnn_lm(vocab_dim,
"""
RNN language model definition.
- :param vocab_dim: size of vocab.
- :param emb_dim: embedding vector"s dimension.
+ :param vocab_dim: size of vocabulary.
+ :type vocab_dim: int
+ :param emb_dim: dimension of the embedding vector
+ :type emb_dim: int
:param rnn_type: the type of RNN cell.
- :param hidden_size: number of unit.
- :param stacked_rnn_num: layer number.
+ :type rnn_type: int
+ :param hidden_size: number of hidden unit.
+ :type hidden_size: int
+ :param stacked_rnn_num: number of stacked rnn cell.
+ :type stacked_rnn_num: int
:return: cost and output layer of model.
:rtype: LayerOutput
"""

# input layers
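A hedged sketch of calling rnn_lm with the parameters documented above; the values are arbitrary, and passing rnn_type as a cell-type string is an assumption made here for illustration:

    # illustrative call; per the docstring the return value is the cost and
    # output layer that the training code consumes
    cost = rnn_lm(
        vocab_dim=10000,    # vocabulary size
        emb_dim=256,        # embedding vector dimension
        rnn_type="lstm",    # assumed to name the RNN cell, e.g. "lstm" or "gru"
        hidden_size=256,    # hidden units per RNN layer
        stacked_rnn_num=2)  # number of stacked RNN cells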
File renamed without changes.
12 changes: 8 additions & 4 deletions language_model/train.py → generte_sequence_by_rnn_lm/train.py
@@ -20,12 +20,16 @@ def train(topology,
"""
train model.
- :param model_cost: cost layer of the model to train.
+ :param topology: cost layer of the model to train.
+ :type topology: LayerOuput
:param train_reader: train data reader.
+ :type trainer_reader: collections.Iterable
:param test_reader: test data reader.
- :param model_file_name_prefix: model"s prefix name.
- :param num_passes: epoch.
- :return:
+ :type test_reader: collections.Iterable
+ :param model_save_dir: path to save the trained model
+ :type model_save_dir: str
+ :param num_passes: number of epoch
+ :type num_passes: int
"""
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
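A hedged sketch of invoking train with the parameters documented above; the cost layer and the reader objects are placeholders that would normally come from the network-definition and data-preparation code:

    # hypothetical invocation using the keyword names from the docstring
    train(
        topology=cost,              # cost layer returned by rnn_lm
        train_reader=train_reader,  # iterable yielding training batches
        test_reader=test_reader,    # iterable yielding test batches
        model_save_dir="models",    # directory in which to save parameters
        num_passes=5)               # number of training epochs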
26 changes: 24 additions & 2 deletions language_model/utils.py → generte_sequence_by_rnn_lm/utils.py
@@ -17,14 +17,19 @@ def build_dict(data_file,
insert_extra_words=["<unk>", "<e>"]):
"""
:param data_file: path of data file
+ :type data_file: str
:param save_path: path to save the word dictionary
+ :type save_path: str
:param vocab_max_size: if vocab_max_size is set, top vocab_max_size words
will be added into word vocabulary
+ :type vocab_max_size: int
:param cutoff_thd: if cutoff_thd is set, words whose frequencies are less
- than cutoff_thd will not added into word vocabulary.
+ than cutoff_thd will not be added into word vocabulary.
NOTE that: vocab_max_size and cutoff_thd cannot be set at the same time
+ :type cutoff_word_fre: int
:param extra_keys: extra keys defined by users that added into the word
- dictionary, ususally these keys includes <unk>, start and ending marks
+ dictionary, ususally these keys include <unk>, start and ending marks
+ :type extra_keys: list
"""
word_count = defaultdict(int)
with open(data_file, "r") as f:
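A hedged sketch of invoking the build_dict function documented above; the file paths are placeholders, and the exact keyword names are assumptions since the docstring mixes cutoff_thd and cutoff_word_fre:

    # build a vocabulary from the training corpus, keeping words that occur
    # at least 5 times; vocab_max_size and cutoff_thd must not be set together
    build_dict(
        data_file="data/train.txt",
        save_path="word_dict.txt",
        cutoff_thd=5,
        insert_extra_words=["<unk>", "<e>"])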
@@ -53,12 +58,29 @@

def load_dict(dict_path):
"""
load word dictionary from the given file. Each line of the given file is
a word in the word dictionary. The first column of the line, separated by
TAB, is the key, while the line index is the value.
:param dict_path: path of word dictionary
:type dict_path: str
:return: the dictionary
:rtype: dict
"""
return dict((line.strip().split("\t")[0], idx)
for idx, line in enumerate(open(dict_path, "r").readlines()))


def load_reverse_dict(dict_path):
"""
load word dictionary from the given file. Each line of the given file is
a word in the word dictionary. The line index is the key, while the first
column of the line, separated by TAB, is the value.
:param dict_path: path of word dictionary
:type dict_path: str
:return: the dictionary
:rtype: dict
"""
return dict((idx, line.strip().split("\t")[0])
for idx, line in enumerate(open(dict_path, "r").readlines()))
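To make the dictionary file format concrete, a small made-up illustration of the two loaders; the file contents and path are hypothetical, and any columns after the first (e.g. a frequency count) are ignored:

    # word_dict.txt, tab-separated, one word per line:
    #   <unk>   283
    #   <e>     541
    #   the     499
    word_to_id = load_dict("word_dict.txt")          # {"<unk>": 0, "<e>": 1, "the": 2}
    id_to_word = load_reverse_dict("word_dict.txt")  # {0: "<unk>", 1: "<e>", 2: "the"}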
