In large-scale distributed asynchronous training, if all workers read the same number of samples, slow nodes take much longer to finish than the others, causing a long-tail problem. As the training scale grows, the long-tail problem becomes increasingly severe, reducing overall data throughput and prolonging the time needed to produce the model.
We provide the WorkQueue class, which performs elastic data partitioning across multiple data sources so that slow nodes train on less data and fast nodes train on more. WorkQueue significantly alleviates the impact of the long-tail problem and reduces training time.
WorkQueue manages the work items of all workers. Whenever a worker exhausts its remaining work items, it obtains new ones from the shared WorkQueue to use as a data source for training, so faster workers naturally receive more work items.
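The effect of this elastic assignment can be illustrated with a small, self-contained Python sketch (this is not DeepRec code; the worker names, speeds, and item counts are made up for illustration). A shared queue hands out work items on demand, so a worker that finishes faster simply comes back for more:

```python
from collections import deque

# A shared pool of 12 work items (e.g. file shards), consumed on demand.
work_items = deque(f"shard-{i}" for i in range(12))

# Simulated relative speeds: the "fast" worker finishes 3 items in the
# time the "slow" worker finishes 1.
speeds = {"fast": 3, "slow": 1}
consumed = {"fast": [], "slow": []}

# Round-based simulation: each round, every worker takes as many items
# as its speed allows, until the shared pool is empty.
while work_items:
    for worker, speed in speeds.items():
        for _ in range(speed):
            if not work_items:
                break
            consumed[worker].append(work_items.popleft())

print(len(consumed["fast"]), len(consumed["slow"]))  # prints "9 3"
```

With a static, even split each worker would have been assigned 6 items and the slow worker would dominate total training time; with on-demand assignment the fast worker ends up processing three times as many items.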
```python
class WorkQueue(works,
                num_epochs=1,
                shuffle=True,
                seed=None,
                prefix=None,
                num_slices=None,
                name='work_queue')
```
- `works`: list of filenames
- `num_epochs`: the number of times to read all data
- `shuffle`: if `True`, randomly shuffle the data every epoch
- `seed`: the random seed used to shuffle the data. If `None`, the seed is generated automatically by WorkQueue
- `prefix`: the prefix of work items (filenames or table names), default value is `None`
- `num_slices`: total number of work items, usually more than 10 times the number of workers. The more unstable the cluster, the larger the total number of work items should be. If `None`, no data slicing is performed. `num_slices` is ignored when reading files.
- `name`: the name of the work queue
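To illustrate why `num_slices` should be well above the worker count, here is a hypothetical sketch that partitions a set of tables into row-range slices. DeepRec performs this slicing internally when `num_slices` is set; the `slice_tables` helper and the table sizes below are not part of the DeepRec API and exist only to show the idea:

```python
# Hypothetical illustration of slicing tables into work items.
# DeepRec does this internally when num_slices is set; slice_tables
# is NOT a DeepRec API.
def slice_tables(table_rows, num_slices):
    """Split each table's row range into roughly equal slices."""
    slices = []
    per_table = max(1, num_slices // len(table_rows))
    for name, rows in table_rows.items():
        step = -(-rows // per_table)  # ceiling division
        for start in range(0, rows, step):
            end = min(start + step, rows)
            slices.append((name, start, end))
    return slices

# Two tables, 4 workers -> aim for ~40 slices (10x the worker count),
# so fast workers can keep coming back for more work.
items = slice_tables({"table_a": 1000, "table_b": 1000}, num_slices=40)
print(len(items))  # prints "40"
```

The finer the slicing, the more evenly the fast and slow workers balance out, at the cost of more queue round-trips; an unstable cluster therefore benefits from a larger `num_slices`.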
- `take` method: `WorkQueue.take()`

| Description | Gets a work item from the global WorkQueue and downloads it to the worker |
|---|---|
| Return Value | `tensorflow.Tensor` |
| Parameter | None |
- `input_dataset` method: `WorkQueue.input_dataset()`

| Description | Returns a Dataset; each element of the Dataset is a work item |
|---|---|
| Return Value | `tensorflow.data.Dataset` |
| Parameter | None |
- `input_producer` method: `WorkQueue.input_producer()`

| Description | The local proxy queue of the global work queue, used by Reader class Ops |
|---|---|
| Return Value | `tensorflow.FIFOQueue` |
| Parameter | None |
- `add_summary` method: `WorkQueue.add_summary()`

| Description | Generates work queue statistics that can be displayed in TensorBoard |
|---|---|
| Return Value | None |
| Parameter | None |
```python
from tensorflow.python.ops.work_queue import WorkQueue

# Use WorkQueue to allocate work items (filenames) elastically.
work_queue = WorkQueue([filename, filename1, filename2, filename3])
f_data = work_queue.input_dataset()
# Extract lines from input files using the Dataset API.
dataset = tf.data.TextLineDataset(f_data)
dataset = dataset.shuffle(buffer_size=20000,
                          seed=2021)  # fix the seed for reproducibility
dataset = dataset.repeat(num_epochs)
dataset = dataset.prefetch(batch_size)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(1)
```
The WDL model in the DeepRec modelzoo provides a more detailed example. The following example reads files with a Reader op through `input_producer()`:
```python
from tensorflow.python.ops.work_queue import WorkQueue

work_queue = WorkQueue([path1, path2, path3], shuffle=True)
work_queue.add_summary()
# Create a file reader.
reader = tf.TextLineReader()
# Read up to 2 records at a time from the local proxy queue.
keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)
with tf.train.MonitoredTrainingSession() as sess:
    sess.run(...)
```
```python
from tensorflow.python.ops.work_queue import WorkQueue

work_queue = WorkQueue(
    [odps_path1, odps_path2, odps_path3],
    shuffle=True,
    num_slices=FLAGS.num_workers * 10)
# Create a table reader.
reader = tf.TableRecordReader()
# Read up to 2 records at a time.
keys, values = reader.read_up_to(work_queue.input_producer(), num_records=2)
```