Skip to content
This repository has been archived by the owner on Jun 10, 2021. It is now read-only.

Commit

Permalink
Handle missing target vocabulary for update_vocab (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln authored and jsenellart committed Jan 27, 2018
1 parent 8f3d6c4 commit 58bb64b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 26 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -15,6 +15,8 @@
* Fix batch size not working with `rest_translation_server.lua`
* Introduce `-tokenizer max` option to scorer for evaluation on non-tokenized test data.
* Fix non-deterministic inference of language models
* Fix retraining from a language model
* Fix `-update_vocab` option for language models

## [v0.9.7](https://github.com/OpenNMT/OpenNMT/releases/tag/v0.9.7) (2017-12-19)

Expand Down
48 changes: 22 additions & 26 deletions train.lua
Expand Up @@ -84,31 +84,18 @@ local function updateTensorByDict(tensor, dict, updatedDict)
return updateTensor
end

local function mergeDicts(dicts, mergedDicts)

for i = 1, dicts.src.words:size() do
local label = dicts.src.words.idxToLabel[i]
local idx = mergedDicts.src.words.labelToIdx[label]
-- add an old word to the end of the new dicts
if idx == nil then
idx = mergedDicts.src.words:size() + 1
mergedDicts.src.words.idxToLabel[idx] = label
mergedDicts.src.words.labelToIdx[label] = idx
end
end

for i = 1, dicts.tgt.words:size() do
local label = dicts.tgt.words.idxToLabel[i]
local idx = mergedDicts.tgt.words.labelToIdx[label]
local function mergeDict(dict, mergedDict)
for i = 1, dict.words:size() do
local label = dict.words.idxToLabel[i]
local idx = mergedDict.words.labelToIdx[label]
-- add an old word to the end of the new dicts
if idx == nil then
idx = mergedDicts.tgt.words:size() + 1
mergedDicts.tgt.words.idxToLabel[idx] = label
mergedDicts.tgt.words.labelToIdx[label] = idx
idx = mergedDict.words:size() + 1
mergedDict.words.idxToLabel[idx] = label
mergedDict.words.labelToIdx[label] = idx
end
end

return mergedDicts
return mergedDict
end

local function updateVocab(checkpoint, dicts, opt)
Expand Down Expand Up @@ -139,7 +126,7 @@ local function updateVocab(checkpoint, dicts, opt)
end
end)

if decoder then
if decoder and dicts.tgt then
decoder:apply(function(m)
if torch.type(m) == "onmt.WordEmbedding" then
if m.net.weight:size(1) == checkpoint.dicts.tgt.words:size() then
Expand Down Expand Up @@ -260,17 +247,26 @@ local function loadModel(opt, dicts)

if opt.update_vocab ~= 'none' then
_G.logger:info(' * new source dictionary size: %d', dicts.src.words:size())
_G.logger:info(' * new target dictionary size: %d', dicts.tgt.words:size())
_G.logger:info(' * old source dictionary size: %d', checkpoint.dicts.src.words:size())
_G.logger:info(' * old target dictionary size: %d', checkpoint.dicts.tgt.words:size())

if dicts.tgt then
_G.logger:info(' * new target dictionary size: %d', dicts.tgt.words:size())
_G.logger:info(' * old target dictionary size: %d', checkpoint.dicts.tgt.words:size())
end

if opt.update_vocab == 'merge' then
_G.logger:info(' * Merging new / old dictionaries...')
dicts = mergeDicts(checkpoint.dicts, dicts)
dicts.src = mergeDict(checkpoint.dicts.src, dicts.src)
if dicts.tgt then
dicts.tgt = mergeDict(checkpoint.dicts.tgt, dicts.tgt)
end
else
_G.logger:info(' * Replacing old dictionaries by new dictionaries...')
end

checkpoint = updateVocab(checkpoint, dicts, opt)
elseif checkpoint.dicts.src.words:size() ~= dicts.src.words:size() or checkpoint.dicts.tgt.words:size() ~= dicts.tgt.words:size() then
elseif (checkpoint.dicts.src.words:size() ~= dicts.src.words:size()
or (dicts.tgt and checkpoint.dicts.tgt.words:size() ~= dicts.tgt.words:size())) then
_G.logger:warning('Dictionary size changed, you may need to activate -update_vocab option')
end

Expand Down

0 comments on commit 58bb64b

Please sign in to comment.