Skip to content
This repository has been archived by the owner on Jun 10, 2021. It is now read-only.

Commit

Permalink
Ignore invalid sentences when building vocabularies
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Jan 10, 2017
1 parent b4de839 commit 08f4c15
Showing 1 changed file with 28 additions and 23 deletions.
51 changes: 28 additions & 23 deletions preprocess.lua
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ local function hasFeatures(filename)
return numFeatures > 0
end

local function makeVocabulary(filename, size)
local function isValid(sent, maxSeqLength)
return #sent > 0 and #sent <= maxSeqLength
end

local function makeVocabulary(filename, size, maxSeqLength)
local wordVocab = onmt.utils.Dict.new({onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD})
local featuresVocabs = {}
Expand All @@ -55,23 +59,25 @@ local function makeVocabulary(filename, size)
break
end

local words, features, numFeatures = onmt.utils.Features.extract(sent)

if #featuresVocabs == 0 and numFeatures > 0 then
for j = 1, numFeatures do
featuresVocabs[j] = onmt.utils.Dict.new({onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD})
if isValid(sent, maxSeqLength) then
local words, features, numFeatures = onmt.utils.Features.extract(sent)

if #featuresVocabs == 0 and numFeatures > 0 then
for j = 1, numFeatures do
featuresVocabs[j] = onmt.utils.Dict.new({onmt.Constants.PAD_WORD, onmt.Constants.UNK_WORD,
onmt.Constants.BOS_WORD, onmt.Constants.EOS_WORD})
end
else
assert(#featuresVocabs == numFeatures,
'all sentences must have the same numbers of additional features')
end
else
assert(#featuresVocabs == numFeatures,
'all sentences must have the same numbers of additional features')
end

for i = 1, #words do
wordVocab:add(words[i])
for i = 1, #words do
wordVocab:add(words[i])

for j = 1, numFeatures do
featuresVocabs[j]:add(features[j][i])
for j = 1, numFeatures do
featuresVocabs[j]:add(features[j][i])
end
end
end

Expand All @@ -86,7 +92,7 @@ local function makeVocabulary(filename, size)
return wordVocab, featuresVocabs
end

local function initVocabulary(name, dataFile, vocabFile, vocabSize, featuresVocabsFiles)
local function initVocabulary(name, dataFile, vocabFile, vocabSize, featuresVocabsFiles, maxSeqLength)
local wordVocab
local featuresVocabs = {}

Expand Down Expand Up @@ -121,7 +127,7 @@ local function initVocabulary(name, dataFile, vocabFile, vocabSize, featuresVoca
if wordVocab == nil or (#featuresVocabs == 0 and hasFeatures(dataFile)) then
-- If a dictionary is still missing, generate it.
print('Building ' .. name .. ' vocabulary...')
local genWordVocab, genFeaturesVocabs = makeVocabulary(dataFile, vocabSize)
local genWordVocab, genFeaturesVocabs = makeVocabulary(dataFile, vocabSize, maxSeqLength)

if wordVocab == nil then
wordVocab = genWordVocab
Expand Down Expand Up @@ -186,8 +192,7 @@ local function makeData(srcFile, tgtFile, srcDicts, tgtDicts)
break
end

if #srcTokens > 0 and #srcTokens <= opt.src_seq_length
and #tgtTokens > 0 and #tgtTokens <= opt.tgt_seq_length then
if isValid(srcTokens, opt.src_seq_length) and isValid(tgtTokens, opt.tgt_seq_length) then
local srcWords, srcFeats = onmt.utils.Features.extract(srcTokens)
local tgtWords, tgtFeats = onmt.utils.Features.extract(tgtTokens)

Expand Down Expand Up @@ -273,10 +278,10 @@ local function main()
local data = {}

data.dicts = {}
data.dicts.src = initVocabulary('source', opt.train_src, opt.src_vocab,
opt.src_vocab_size, opt.features_vocabs_prefix)
data.dicts.tgt = initVocabulary('target', opt.train_tgt, opt.tgt_vocab,
opt.tgt_vocab_size, opt.features_vocabs_prefix)
data.dicts.src = initVocabulary('source', opt.train_src, opt.src_vocab, opt.src_vocab_size,
opt.features_vocabs_prefix, opt.src_seq_length)
data.dicts.tgt = initVocabulary('target', opt.train_tgt, opt.tgt_vocab, opt.tgt_vocab_size,
opt.features_vocabs_prefix, opt.tgt_seq_length)

print('Preparing training data...')
data.train = {}
Expand Down

0 comments on commit 08f4c15

Please sign in to comment.