
[WIP] Dynamic dataset memory optimization #398

Merged: 10 commits, Oct 17, 2017
onmt/data/Preprocessor.lua: 20 changes (11 additions & 9 deletions)
@@ -436,11 +436,13 @@ function Preprocessor:__init(args, dataType)
end

if args.preprocess_pthreads > 1 and args.train_dir ~= '' then
+ local globalLogger = _G.logger
-- try to load threads if available
threads = require('threads')
self.pool = threads.Threads(
args.preprocess_pthreads,
- function() init_thread(tokenizers) end
+ function() init_thread(tokenizers) end,
+ function() _G.logger = globalLogger end
)
else
init_thread(tokenizers)
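
The new init hook leans on how torch's threads library starts its workers: threads.Threads accepts any number of init closures and runs each of them in every worker, in order. Globals like _G.logger are not shipped to the workers, but the upvalues of those closures are serialized along with them, which is why the logger is first captured in a local. A minimal standalone sketch of the pattern (the pool size, the require, and the print stand-in are illustrative, not from the PR):

local threads = require('threads')

_G.logger = _G.logger or print              -- stand-in logger so the sketch runs alone
local globalLogger = _G.logger              -- capture the logger as an upvalue

local pool = threads.Threads(
  2,                                        -- number of worker threads (illustrative)
  function() require('torch') end,          -- first init closure, runs in each worker
  function() _G.logger = globalLogger end   -- second: recreate the logger global per worker
)
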
@@ -593,9 +595,10 @@ function Preprocessor:makeGenericData(files, isInputVector, dicts, nameSources,
table.insert(gAvgLength, 0)
end

+ -- iterate on each file
for _m, _df in ipairs(files) do
self:poolAddJob(
- function(m, df, idx_files, time_shift_feature, src_seq_length, tgt_seq_length)
+ function(df, idx_files, time_shift_feature, src_seq_length, tgt_seq_length, sampling)
local count = 0
local ignored = 0
local emptyCount = 0
@@ -614,7 +617,6 @@
end

-- if there is a sampling for this file
- local sampling = sample_file[m]
local readers = {}
local prunedRatio = {}
for i = 1, n do
@@ -667,14 +669,14 @@ function Preprocessor:makeGenericData(files, isInputVector, dicts, nameSources,
local hasNil = false
-- read all the available sentences, or stop once the sampling size is reached
-- the sampling table is an ordered list of the sentence ids to keep
- while not hasNil and (not sampling or (sampling:dim() ~= 0 and sampling_idx <= sampling:size(1))) do
+ while not hasNil and (not sampling or sampling_idx <= #sampling) do
-- store in sentences each distinct sentence and the number of times it repeats
local sentences = { {} }
for _ = 1, n do
table.insert(sentences, {})
end
-- keep at most a batch of 10000 sentences
- while not hasNil and (not sampling or (sampling:dim() ~= 0 and sampling_idx <= sampling:size(1))) and #sentences[1] < 10000 do
+ while not hasNil and (not sampling or sampling_idx <= #sampling) and #sentences[1] < 10000 do
local allNil = true
local keepSentence = not sampling or sampling[sampling_idx] == idx

@@ -696,7 +698,7 @@

local repeatSentence = 1
if sampling then
- while sampling_idx+repeatSentence <= sampling:size(1) and sampling[sampling_idx+repeatSentence] == idx do
+ while sampling_idx+repeatSentence <= #sampling and sampling[sampling_idx+repeatSentence] == idx do
repeatSentence = repeatSentence + 1
end
end
@@ -748,7 +750,7 @@ function Preprocessor:makeGenericData(files, isInputVector, dicts, nameSources,
end

return _G.__threadid, false, sentenceDists, vectors, features, avgLength, sizes, prunedRatio, count, ignored, emptyCount,
- sampling and (sampling:dim()==0 and 0 or sampling:size(1)) or _df[1]
+ sampling and #sampling or _df[1]
end,
-- aggregate the results together
function(__threadid, error, sentenceDists, vectors, features, avgLength, sizes, prunedRatio, count, ignored, emptyCount, kept)
@@ -785,7 +787,7 @@ function Preprocessor:makeGenericData(files, isInputVector, dicts, nameSources,
gEmptyCount = gEmptyCount + emptyCount

end,
- _m, _df, self.args.idx_files, self.args.time_shift_feature, self.args.src_seq_length or self.args.seq_length, self.args.tgt_seq_length)
+ _df, self.args.idx_files, self.args.time_shift_feature, self.args.src_seq_length or self.args.seq_length, self.args.tgt_seq_length, sample_file[_m])
end

self:poolSynchronize()
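
The signature change above completes the same idea from the scheduling side: rather than letting the job closure capture sample_file[m] (which would serialize the surrounding state with every job), the per-file sampling vector is passed to poolAddJob as an explicit argument. With torch's threads, extra arguments to addjob are serialized and handed to the callback inside the worker, and the callback's return values feed the end-callback back on the main thread. A minimal sketch of that pattern, assuming the pool from the earlier sketch (names and values are illustrative):

local tds = require('tds')
local sample_file = { tds.Vec({1, 2, 2, 5}) }  -- one sampling vector per file (illustrative)

pool:addjob(
  function(sampling)                     -- runs inside a worker thread
    return sampling and #sampling or 0   -- return values go to the end callback
  end,
  function(kept)                         -- runs back on the main thread
    print('kept ' .. kept .. ' sentence ids')
  end,
  sample_file[1]                         -- only this file's vector is shipped to the worker
)
pool:synchronize()                       -- wait for every queued job to finish
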
@@ -1020,7 +1022,7 @@ function Preprocessor:makeData(dataset, dicts)
end
t = torch.sort(t)
end
- table.insert(sample_file, t)
+ table.insert(sample_file, tds.Vec(t:totable()))
end
end

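This conversion appears to be the heart of the memory optimization named in the PR title: the sorted torch tensor of sampled sentence ids becomes a tds.Vec. tds containers are allocated outside the Lua garbage-collected heap, and tds.Vec answers the # length operator even when empty, which is why the sampling:dim() ~= 0 guards could be dropped from the loops above (calling size(1) on an empty tensor would error). A small sketch of the conversion (the values are illustrative):

local torch = require('torch')
local tds = require('tds')

local t = torch.sort(torch.LongTensor({3, 1, 2, 2}))  -- sorted sentence ids to keep
local sampling = tds.Vec(t:totable())                 -- copy the ids out of the tensor

print(#sampling)    -- 4; also valid when the vector is empty
print(sampling[1])  -- 1; tds.Vec is 1-indexed like a Lua table
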