Permalink
Browse files

Multi model b474 (#480)

* introduce batch_size and hook manager
  • Loading branch information...
jsenellart committed Dec 29, 2017
1 parent 11f20f9 commit 8e8970ee3a825f36c1ba627f1512f7bec0c535e7
Showing with 56 additions and 49 deletions.
  1. +56 −49 tools/rest_multi_models.lua
View
@@ -55,6 +55,9 @@ cmd:setCmdLineOptions(server_options, 'Server')
onmt.utils.Cuda.declareOpts(cmd)
onmt.utils.Logger.declareOpts(cmd)
onmt.utils.HookManager.updateOpt(arg, cmd)
onmt.utils.HookManager.declareOpts(cmd)
local opt_server = cmd:parse(arg)
local fconfig = io.open(opt_server.model_config, "rb")
assert(fconfig)
@@ -80,13 +83,12 @@ for i=1, #server_cfg do
model_cmd:text("")
model_cmd:text("**Other options**")
model_cmd:text("")
model_cmd:option('-batchsize', 1000, [[Size of each parallel batch - you should not change except if low memory.]])
model_cmd:option('-batch_size', 1000, [[Size of each parallel batch - you should not change except if low memory.]])
opt[i] = model_cmd:parse({})
end
local function translateMessage(server, lines)
local batch = {}
-- We need to tokenize the input line before translation
local bpe
local res
@@ -100,51 +102,51 @@ local function translateMessage(server, lines)
if options.bpe_model ~= '' then
bpe = BPE.new(options)
end
for i = 1, #lines do
local srcTokenized = {}
local tokens
local srcTokens = {}
res, err = pcall(function() tokens = tokenizer.tokenize(options, lines[i].src, bpe) end)
-- it can generate an exception if there are utf-8 issues in the text
if not res then
if string.find(err, "interrupted") then
error("interrupted")
else
error("unicode error in line " .. err)
end
end
table.insert(srcTokenized, table.concat(tokens, ' '))
-- Extract from the line.
for word in srcTokenized[1]:gmatch'([^%s]+)' do
table.insert(srcTokens, word)
end
-- Currently just a single batch.
table.insert(batch, translator:buildInput(srcTokens))
end
-- Translate
_G.logger:info("Start Translation")
local results = translator:translate(batch)
_G.logger:info("End Translation")
-- Return the nbest translations for each in the batch.
local i = 1
local translations = {}
for b = 1, #lines do
local ret = {}
for i = 1, translator.args.n_best do
local srcSent = translator:buildOutput(batch[b])
local lineres = {
tgt = "",
src = srcSent,
n_best = 1,
pred_score = 0
}
if results[b].preds ~= nil then
while i <= #lines do
local batch = {}
while i <= #lines and #batch < options.batch_size do
local srcTokens = {}
local srcTokenized = {}
local tokens
res, err = pcall(function()
tokens = tokenizer.tokenize(options, lines[i].src, bpe)
end)
-- Tokenization can raise an error if there are UTF-8 issues in the text.
if not res then
if string.find(err, "interrupted") then
error("interrupted")
else
error("unicode error in line " .. err)
end
end
table.insert(srcTokenized, table.concat(tokens, ' '))
-- Extract from the line.
for word in srcTokenized[1]:gmatch'([^%s]+)' do
table.insert(srcTokens, word)
end
-- Add the sentence to the current batch (at most options.batch_size sentences per batch).
table.insert(batch, translator:buildInput(srcTokens))
i = i + 1
end
-- Translate
_G.logger:debug("Start Translation")
local results = translator:translate(batch)
_G.logger:debug("End Translation")
-- Return the nbest translations for each in the batch.
for b = 1, #batch do
local ret = {}
for bi = 1, translator.args.n_best do
local srcSent = translator:buildOutput(batch[b])
local predSent
res, err = pcall(function()
predSent = tokenizer.detokenize(options,
results[b].preds[i].words,
results[b].preds[i].features)
results[b].preds[bi].words,
results[b].preds[bi].features)
end)
if not res then
if string.find(err,"interrupted") then
@@ -153,24 +155,26 @@ local function translateMessage(server, lines)
error("unicode error in line ".. err)
end
end
lineres = {
local lineres = {
tgt = predSent,
src = srcSent,
n_best = i,
pred_score = results[b].preds[i].score
n_best = bi,
pred_score = results[b].preds[bi].score
}
if options.withAttn or lines[b].withAttn then
local attnTable = {}
for j = 1, #results[b].preds[i].attention do
table.insert(attnTable, results[b].preds[i].attention[j]:totable())
for j = 1, #results[b].preds[bi].attention do
table.insert(attnTable, results[b].preds[bi].attention[j]:totable())
end
lineres.attn = attnTable
end
table.insert(ret, lineres)
end
table.insert(ret, lineres)
table.insert(translations, ret)
end
table.insert(translations, ret)
end
return translations
end
@@ -238,6 +242,9 @@ local function main()
_G.logger = onmt.utils.Logger.new(opt_server.log_file, opt_server.disable_logs, opt_server.log_level)
-- cuda settings
onmt.utils.Cuda.init(opt_server)
_G.hookManager = onmt.utils.HookManager.new(opt)
-- disable profiling
_G.profiler = onmt.utils.Profiler.new(false)

0 comments on commit 8e8970e

Please sign in to comment.