fix tracker_queue single
francoishernandez committed Jul 10, 2020
1 parent a17dde8 commit d5e8f61
Showing 3 changed files with 29 additions and 18 deletions.
3 changes: 2 additions & 1 deletion onmt/bin/train.py
@@ -59,7 +59,8 @@ def train(opt):
shard_base = "train_" + opt.data_ids[0]
else:
shard_base = "train"
train_iter = build_dataset_iter(shard_base, fields, opt)
train_iter = build_dataset_iter(
shard_base, fields, opt, data_tracker)

nb_gpu = len(opt.gpu_ranks)

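Note on this call site: build_dataset_iter's signature, shown in the inputter.py diff below, is (corpus_type, fields, opt, is_train=True, multi=False, data_tracker=None), so a fourth positional argument binds to is_train rather than to data_tracker. A minimal sketch of the keyword form that would route the tracker explicitly; this is illustrative only, not the committed code:

    # Hypothetical variant: pass the tracker by keyword so it cannot
    # bind to the is_train parameter.
    train_iter = build_dataset_iter(
        shard_base, fields, opt, data_tracker=data_tracker)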
29 changes: 16 additions & 13 deletions onmt/inputters/inputter.py
@@ -705,7 +705,11 @@ def __iter__(self):
             if not self.repeat:
                 return
 
+
 class Tracker(object):
+    """
+    Simple object to help keep track of shards that were loaded.
+    """
     def __init__(self, _dict=None):
         if _dict is not None:
             self.last_path = _dict['last_path']
@@ -724,6 +728,7 @@ def get_last_path(self, dataset):
     def get_count(self, dataset):
         return self.counter.get(dataset, None)
 
+
 class MultipleDatasetIterator(object):
     """
     This takes a list of iterable objects (DatasetLazyIter) and their
@@ -740,15 +745,15 @@ def __init__(self,
         self.weights = []
 
         if data_tracker is None:
-            self.tracker = Tracker()
+            self.data_tracker = Tracker()
         else:
-            self.tracker = data_tracker
+            self.data_tracker = data_tracker
         for shard, weight in zip(train_shards, opt.data_weights):
             if weight > 0:
                 self.iterables.append(
                     build_dataset_iter(
                         shard, fields, opt, multi=True,
-                        tracker=self.tracker))
+                        data_tracker=self.data_tracker))
                 self.weights.append(weight)
         self.init_iterators = True
         # self.weights = opt.data_weights
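Net effect of the block above: all per-corpus DatasetLazyIter instances created by MultipleDatasetIterator share one Tracker, so each weighted corpus records its progress under its own shard name. A small illustration with hypothetical shard names, assuming the Tracker behaviour visible elsewhere in this diff:

    # One shared Tracker, keyed by corpus name:
    data_tracker.get_last_path('train_A')  # e.g. 'demo.train_A.2.pt'
    data_tracker.get_last_path('train_B')  # e.g. 'demo.train_B.0.pt'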
@@ -812,16 +817,14 @@ class DatasetLazyIter(object):
     def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
                  batch_size_multiple, device, is_train, pool_factor,
                  repeat=True, num_batches_multiple=1, yield_raw_example=False,
-                 tracker=None, corpus_type=None):
+                 data_tracker=None, corpus_type=None):
         self._paths = dataset_paths
         # reorder _paths based on tracker if exists
-        if tracker is not None and corpus_type is not None:
-            next_shard = tracker.get_last_path(corpus_type)
-            print("//// next_shard", next_shard)
+        if data_tracker is not None and corpus_type is not None:
+            next_shard = data_tracker.get_last_path(corpus_type)
             if next_shard is not None:
                 index = self._paths.index(next_shard)
                 self._paths = self._paths[index+1:] + self._paths[:index+1]
-                print("reordered paths", self._paths)
         self.fields = fields
         self.batch_size = batch_size
         self.batch_size_fn = batch_size_fn
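The reordering above rotates the shard list so iteration resumes with the shard following the last one the tracker recorded. A worked example with hypothetical file names:

    # The tracker recorded 'demo.train.1.pt' as the last shard loaded.
    paths = ['demo.train.0.pt', 'demo.train.1.pt',
             'demo.train.2.pt', 'demo.train.3.pt']
    index = paths.index('demo.train.1.pt')
    paths = paths[index+1:] + paths[:index+1]
    # -> ['demo.train.2.pt', 'demo.train.3.pt',
    #     'demo.train.0.pt', 'demo.train.1.pt']

Because update() records a shard as soon as iteration over it starts (see __iter__ below), a shard interrupted mid-pass is rotated to the end rather than replayed.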
@@ -832,7 +835,7 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
         self.num_batches_multiple = num_batches_multiple
         self.yield_raw_example = yield_raw_example
         self.pool_factor = pool_factor
-        self.tracker = tracker
+        self.data_tracker = data_tracker
         self.corpus_type = corpus_type
 
     def _iter_dataset(self, path):
@@ -871,8 +874,8 @@ def __iter__(self):
             # Cycle through the shards indefinitely.
             paths = cycle(paths)
         for path in paths:
-            if self.tracker is not None and self.corpus_type is not None:
-                self.tracker.update(self.corpus_type, path)
+            if self.data_tracker is not None and self.corpus_type is not None:
+                self.data_tracker.update(self.corpus_type, path)
             for batch in self._iter_dataset(path):
                 yield batch
                 num_batches += 1
@@ -912,7 +915,7 @@ def max_tok_len(new, count, sofar):


 def build_dataset_iter(corpus_type, fields, opt, is_train=True,
-                       multi=False, tracker=None):
+                       multi=False, data_tracker=None):
     """
     This returns user-defined train/validate data iterator for the trainer
     to iterate over. We implement simple ordered iterator strategy here,
@@ -952,7 +955,7 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True,
         repeat=not opt.single_pass,
         num_batches_multiple=max(opt.accum_count) * opt.world_size,
         yield_raw_example=multi,
-        tracker=tracker,
+        data_tracker=data_tracker,
         corpus_type=corpus_type)


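For reference, a minimal sketch of the complete Tracker as it can be reconstructed from the fragments visible in this diff (__init__, get_last_path, get_count, and the update() call in DatasetLazyIter.__iter__). The dict keys follow the attribute names shown above; the increment in update() is an assumption, and the committed implementation may differ in such details:

    class Tracker(object):
        """
        Simple object to help keep track of shards that were loaded.
        """
        def __init__(self, _dict=None):
            # Restore state from a checkpointed dict if one is given.
            if _dict is not None:
                self.last_path = _dict['last_path']
                self.counter = _dict['counter']
            else:
                self.last_path = {}
                self.counter = {}

        def update(self, dataset, path):
            # Record the shard being loaded for this corpus.
            self.last_path[dataset] = path
            self.counter[dataset] = self.counter.get(dataset, 0) + 1

        def get_last_path(self, dataset):
            return self.last_path.get(dataset, None)

        def get_count(self, dataset):
            return self.counter.get(dataset, None)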
15 changes: 11 additions & 4 deletions onmt/train_single.py
@@ -5,14 +5,15 @@
 import torch
 
 from onmt.inputters.inputter import build_dataset_iter, patch_fields, \
-    load_old_vocab, old_style_vocab, build_dataset_iter_multiple
+    load_old_vocab, old_style_vocab, build_dataset_iter_multiple, Tracker
 from onmt.model_builder import build_model
 from onmt.utils.optimizers import Optimizer
 from onmt.utils.misc import set_random_seed
 from onmt.trainer import build_trainer
 from onmt.models import build_model_saver
 from onmt.utils.logging import init_logger, logger
 from onmt.utils.parse import ArgumentParser
+import torch.multiprocessing as mp
 
 
 def _check_save_model_path(opt):
@@ -57,10 +58,13 @@ def main(opt, device_id, batch_queue=None,
         ArgumentParser.validate_model_opts(model_opt)
         logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
         vocab = checkpoint['vocab']
+        data_tracker = Tracker(_dict=checkpoint.get('data_tracker', None))
+        print("!!! LOADED data_tracker", data_tracker.__dict__)
     else:
         checkpoint = None
         model_opt = opt
         vocab = torch.load(opt.data + '.vocab.pt')
+        data_tracker = Tracker()
 
     # check for code where vocab is saved instead of fields
     # (in the future this will be done in a smarter way)
@@ -95,7 +99,8 @@ def main(opt, device_id, batch_queue=None,
     # Build optimizer.
     optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)
 
-    data_tracker = tracker_queue.get()
+    if tracker_queue is not None:
+        data_tracker = tracker_queue.get()
 
     # Build model saver
     model_saver = build_model_saver(
@@ -111,13 +116,15 @@ def main(opt, device_id, batch_queue=None,
             for train_id in opt.data_ids:
                 shard_base = "train_" + train_id
                 train_shards.append(shard_base)
-            train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
+            train_iter = build_dataset_iter_multiple(
+                train_shards, fields, opt, data_tracker)
         else:
             if opt.data_ids[0] is not None:
                 shard_base = "train_" + opt.data_ids[0]
             else:
                 shard_base = "train"
-            train_iter = build_dataset_iter(shard_base, fields, opt)
+            train_iter = build_dataset_iter(
+                shard_base, fields, opt, data_tracker)
 
     else:
         assert semaphore is not None, \
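The commit title refers to the guard added around tracker_queue.get(): in single-process training, main() receives tracker_queue=None, so the unconditional get() removed above would have raised an AttributeError; with the guard, the tracker built earlier from the checkpoint (or a fresh Tracker()) is kept instead. A hedged sketch of the multi-process hand-off this implies; the producer side is not shown in this commit, so names and placement are assumptions:

    # Hypothetical producer side: the parent process shares the tracker
    # state through a torch.multiprocessing queue before spawning workers.
    import torch.multiprocessing as mp

    tracker_queue = mp.Queue()
    tracker_queue.put(data_tracker)

    # Worker side, as guarded in main() above:
    if tracker_queue is not None:
        data_tracker = tracker_queue.get()

Likewise, checkpoint.get('data_tracker', None) implies the model saver is expected to store the tracker's state (e.g. data_tracker.__dict__) under the 'data_tracker' key when writing a checkpoint; that saver change is not part of this diff.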
