IO: modify build_vocab to accept a list of datasets
..., in order to support later presharding work, in which we will
pass a list of train datasets to build_vocab().
JianyuZhan committed Dec 22, 2017
1 parent d807d28 commit 387b802
Showing 3 changed files with 12 additions and 11 deletions.
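
The list-valued argument exists to support sharded preprocessing: vocabulary counts can then be accumulated over several shard datasets in a single call. Below is a minimal sketch of the intended call pattern, with hypothetical shard paths and vocabulary options; fields is presumed to be built beforehand as in preprocess.py, and build_dataset follows the signature visible in the first hunk header below.

    # Hypothetical sketch of the presharding usage this commit prepares for;
    # shard paths, vocabulary sizes, and frequency thresholds are placeholders.
    import onmt.io

    shards = [
        onmt.io.build_dataset(fields, "text",
                              "train.src.shard%d" % i, "train.tgt.shard%d" % i)
        for i in range(3)
    ]

    # One call now counts tokens across every shard at once.
    onmt.io.build_vocab(shards, "text", False,
                        50000, 0,   # src_vocab_size, src_words_min_frequency
                        50000, 0)   # tgt_vocab_size, tgt_words_min_frequency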
onmt/io/IO.py (10 additions, 9 deletions)
@@ -295,12 +295,12 @@ def build_dataset(fields, data_type, src_path, tgt_path, src_dir=None,
     return dataset
 
 
-def build_vocab(train, data_type, share_vocab,
+def build_vocab(train_datasets, data_type, share_vocab,
                 src_vocab_size, src_words_min_frequency,
                 tgt_vocab_size, tgt_words_min_frequency):
     """
     Args:
-        train: a train dataset.
+        train_datasets: a list of train datasets.
         data_type: "text", "img" or "audio"?
         share_vocab(bool): share source and target vocabulary?
         src_vocab_size(int): size of the source vocabulary.
@@ -310,18 +310,19 @@ def build_vocab(train_datasets, data_type, share_vocab,
         tgt_words_min_frequency(int): the minimum frequency needed to
             include a target word in the vocabulary.
     """
-    fields = train.fields
+    # All datasets have the same fields; getting them from the first one is OK.
+    fields = train_datasets[0].fields
 
-    fields["tgt"].build_vocab(train, max_size=tgt_vocab_size,
+    fields["tgt"].build_vocab(*train_datasets, max_size=tgt_vocab_size,
                               min_freq=tgt_words_min_frequency)
-    for j in range(train.n_tgt_feats):
-        fields["tgt_feat_" + str(j)].build_vocab(train)
+    for j in range(train_datasets[0].n_tgt_feats):
+        fields["tgt_feat_" + str(j)].build_vocab(*train_datasets)
 
     if data_type == 'text':
-        fields["src"].build_vocab(train, max_size=src_vocab_size,
+        fields["src"].build_vocab(*train_datasets, max_size=src_vocab_size,
                                   min_freq=src_words_min_frequency)
-        for j in range(train.n_src_feats):
-            fields["src_feat_" + str(j)].build_vocab(train)
+        for j in range(train_datasets[0].n_src_feats):
+            fields["src_feat_" + str(j)].build_vocab(*train_datasets)
 
     # Merge the input and output vocabularies.
     if share_vocab:
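
The *train_datasets unpacking works because torchtext's Field.build_vocab takes any number of datasets as positional arguments and accumulates token counts over all of them, so a one-element list reproduces the old single-dataset behavior exactly. A small illustration (shard_a and shard_b are hypothetical dataset objects):

    # Single dataset: old and new conventions produce the same call.
    fields["tgt"].build_vocab(train)       # before this commit
    fields["tgt"].build_vocab(*[train])    # after: unpacks to the call above

    # Several shards: counts are merged across all of them.
    fields["tgt"].build_vocab(*[shard_a, shard_b],
                              max_size=tgt_vocab_size,
                              min_freq=tgt_words_min_frequency)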
preprocess.py (1 addition, 1 deletion)
@@ -84,7 +84,7 @@ def main():
     train = build_dataset('train', fields, opt)
 
     print("Building vocabulary...")
-    onmt.io.build_vocab(train, opt.data_type, opt.share_vocab,
+    onmt.io.build_vocab([train], opt.data_type, opt.share_vocab,
                         opt.src_vocab_size,
                         opt.src_words_min_frequency,
                         opt.tgt_vocab_size,
test/test_preprocess.py (1 addition, 1 deletion)
@@ -37,7 +37,7 @@ def dataset_build(self, opt):
 
        train = preprocess.build_dataset('train', fields, opt)
 
-       onmt.io.build_vocab(train, opt.data_type, opt.share_vocab,
+       onmt.io.build_vocab([train], opt.data_type, opt.share_vocab,
                            opt.src_vocab_size,
                            opt.src_words_min_frequency,
                            opt.tgt_vocab_size,
