
Merge pull request #679 from OpenNMT/prep

Prep
srush committed Apr 11, 2018
2 parents 1817ed1 + bad9e8e commit fb82df729fd02b2a14947981a9f93d3f6b79a504
@@ -33,10 +33,10 @@ Optimal value should be multiples of 64 bytes.
### **Vocab**:
* **-src_vocab []**
Path to an existing source vocabulary
Path to an existing source vocabulary. Format: one word per line.
* **-tgt_vocab []**
Path to an existing target vocabulary
Path to an existing target vocabulary. Format: one word per line.
* **-features_vocabs_prefix []**
Path prefix to existing features vocabularies
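
Per the new help text, a vocabulary file for `-src_vocab` / `-tgt_vocab` is plain text with one word per line. As a hedged illustration (the corpus and file names here are made up, not part of this commit), such a file could be produced like this:

```python
from collections import Counter

# Build a frequency-sorted word list from a tokenized corpus and write it
# one word per line -- the format the -src_vocab / -tgt_vocab help now states.
# "corpus.tok.src" and "src_vocab.txt" are example names only.
counter = Counter()
with open("corpus.tok.src", encoding="utf-8") as f:
    for line in f:
        counter.update(line.strip().split())

with open("src_vocab.txt", "w", encoding="utf-8") as f:
    for word, _ in counter.most_common():
        f.write(word + "\n")
```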
@@ -19,7 +19,8 @@ Use a shared weight matrix for the input and output word embeddings in the
decoder.
* **-share_embeddings []**
Share the word embeddings between encoder and decoder. It requires using `-share_vocab` during pre-processing.
Share the word embeddings between encoder and decoder. Need to use shared
dictionary for this option.
* **-position_encoding []**
Use a sin to mark relative word positions. Necessary for non-RNN style models.
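
As background for `-position_encoding`, the "sin" marking mentioned above is the standard sinusoidal position encoding; a generic sketch (it assumes an even model dimension and is not necessarily OpenNMT-py's exact implementation):

```python
import math

import torch


def sinusoid_position_encoding(max_len, dim):
    # Even dimensions use sin, odd dimensions use cos, with geometrically
    # spaced frequencies; the (max_len, dim) result is added to embeddings.
    pe = torch.zeros(max_len, dim)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    div_term = torch.exp(torch.arange(0, dim, 2).float()
                         * -(math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe
```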
@@ -65,6 +65,13 @@ Google NMT length penalty parameter (higher = longer generation)
* **-beta []**
Coverage penalty parameter
* **-block_ngram_repeat []**
Block repetition of ngrams during decoding.
* **-ignore_when_blocking []**
Ignore these strings when blocking repeats. You want to block sentence
delimiters.
* **-replace_unk []**
Replace the generated UNK tokens with the source token that had highest
attention weight. If phrase_table is provided, it will lookup the identified
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import os
from collections import Counter, defaultdict, OrderedDict
from itertools import count
@@ -226,17 +227,19 @@ def _build_field_vocab(field, counter, **kwargs):
def build_vocab(train_dataset_files, fields, data_type, share_vocab,
                src_vocab_size, src_words_min_frequency,
                tgt_vocab_size, tgt_words_min_frequency):
                src_vocab_path, src_vocab_size, src_words_min_frequency,
                tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency):
    """
    Args:
        train_dataset_files: a list of train dataset pt files.
        fields (dict): fields to build vocab for.
        data_type: "text", "img" or "audio"?
        share_vocab(bool): share source and target vocabulary?
        src_vocab_path(string): Path to src vocabulary file.
        src_vocab_size(int): size of the source vocabulary.
        src_words_min_frequency(int): the minimum frequency needed to
            include a source word in the vocabulary.
        tgt_vocab_path(string): Path to tgt vocabulary file.
        tgt_vocab_size(int): size of the target vocabulary.
        tgt_words_min_frequency(int): the minimum frequency needed to
            include a target word in the vocabulary.
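
For reference, a call using the widened signature might look as follows, in keyword form with illustrative values (the sizes and frequencies are examples, not necessarily the project's defaults):

```python
# Illustrative only: mirrors the parameter order documented above.
fields = build_vocab(
    train_dataset_files, fields, "text",
    share_vocab=False,
    src_vocab_path="src_vocab.txt", src_vocab_size=50000,
    src_words_min_frequency=0,
    tgt_vocab_path="tgt_vocab.txt", tgt_vocab_size=50000,
    tgt_words_min_frequency=0,
)
```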
@@ -248,6 +251,29 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
    for k in fields:
        counter[k] = Counter()

    # Load vocabulary
    src_vocab = None
    if len(src_vocab_path) > 0:
        src_vocab = set([])
        print('Loading source vocab from %s' % src_vocab_path)
        assert os.path.exists(src_vocab_path), \
            'src vocab %s not found!' % src_vocab_path
        with open(src_vocab_path) as f:
            for line in f:
                word = line.strip().split()[0]
                src_vocab.add(word)

    tgt_vocab = None
    if len(tgt_vocab_path) > 0:
        tgt_vocab = set([])
        print('Loading target vocab from %s' % tgt_vocab_path)
        assert os.path.exists(tgt_vocab_path), \
            'tgt vocab %s not found!' % tgt_vocab_path
        with open(tgt_vocab_path) as f:
            for line in f:
                word = line.strip().split()[0]
                tgt_vocab.add(word)

    for path in train_dataset_files:
        dataset = torch.load(path)
        print(" * reloading %s." % path)
@@ -256,6 +282,10 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
                val = getattr(ex, k, None)
                if val is not None and not fields[k].sequential:
                    val = [val]
                elif k == 'src' and src_vocab:
                    val = [item for item in val if item in src_vocab]
                elif k == 'tgt' and tgt_vocab:
                    val = [item for item in val if item in tgt_vocab]
                counter[k].update(val)

    _build_field_vocab(fields["tgt"], counter["tgt"],
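
The effect of the new `elif` branches: when a word list was supplied, only tokens contained in it reach the frequency counters, so the vocabulary built afterwards is restricted to that list (still subject to the size and min-frequency settings). A minimal sketch of the filter in isolation, with made-up values:

```python
from collections import Counter

src_vocab = {"a", "b", "c"}          # words loaded from -src_vocab
counter = Counter()

val = ["a", "x", "b", "a", "y"]      # example source-side tokens
val = [item for item in val if item in src_vocab]
counter.update(val)
print(counter)                       # Counter({'a': 2, 'b': 1})
```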
@@ -152,10 +152,12 @@ def preprocess_opts(parser):
    # Dictionary options, for text corpus
    group = parser.add_argument_group('Vocab')
    group.add_argument('-src_vocab',
                       help="Path to an existing source vocabulary")
    group.add_argument('-tgt_vocab',
                       help="Path to an existing target vocabulary")
    group.add_argument('-src_vocab', default="",
                       help="""Path to an existing source vocabulary. Format:
                       one word per line.""")
    group.add_argument('-tgt_vocab', default="",
                       help="""Path to an existing target vocabulary. Format:
                       one word per line.""")
    group.add_argument('-features_vocabs_prefix', type=str, default='',
                       help="Path prefix to existing features vocabularies")
    group.add_argument('-src_vocab_size', type=int, default=50000,
@@ -157,8 +157,10 @@ def build_save_dataset(corpus_type, fields, opt):
def build_save_vocab(train_dataset, fields, opt):
    fields = onmt.io.build_vocab(train_dataset, fields, opt.data_type,
                                 opt.share_vocab,
                                 opt.src_vocab,
                                 opt.src_vocab_size,
                                 opt.src_words_min_frequency,
                                 opt.tgt_vocab,
                                 opt.tgt_vocab_size,
                                 opt.tgt_words_min_frequency)
@@ -3,6 +3,7 @@
import unittest
import glob
import os
import codecs
from collections import Counter
import torchtext
@@ -39,6 +40,13 @@ def __init__(self, *args, **kwargs):
    def dataset_build(self, opt):
        fields = onmt.io.get_fields("text", 0, 0)

        if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0:
            with codecs.open(opt.src_vocab, 'w', 'utf-8') as f:
                f.write('a\nb\nc\nd\ne\nf\n')
        if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0:
            with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f:
                f.write('a\nb\nc\nd\ne\nf\n')

        train_data_files = preprocess.build_save_dataset('train', fields, opt)

        preprocess.build_save_vocab(train_data_files, fields, opt)
@@ -48,6 +56,10 @@ def dataset_build(self, opt):
        # Remove the generated *pt files.
        for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'):
            os.remove(pt)

        if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab):
            os.remove(opt.src_vocab)
        if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab):
            os.remove(opt.tgt_vocab)
    def test_merge_vocab(self):
        va = torchtext.vocab.Vocab(Counter('abbccc'))
@@ -109,6 +121,8 @@ def test_method(self):
                   ('share_vocab', True)],
                  [('dynamic_dict', True),
                   ('max_shard_size', 500000)],
                  [('src_vocab', '/tmp/src_vocab.txt'),
                   ('tgt_vocab', '/tmp/tgt_vocab.txt')],
                  ]

for p in test_databuild:
