Skip to content

Commit

Permalink
[tf.contrib.data] Re-implement IteratorGetNext as an AsyncOpKernel.
Browse files Browse the repository at this point in the history
This prevents the op from consuming an inter-op thread pool thread
when blocked, and fixes a potential deadlock when many IteratorGetNext
ops are blocked. Fixes tensorflow#10369.

PiperOrigin-RevId: 157878885
  • Loading branch information
mrry authored and Amit Patankar committed Jun 6, 2017
1 parent efa4746 commit d5b89b1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ def iterator_thread():
results.append(sess.run(get_next))
except errors.OutOfRangeError:
return
threads = [self.checkedThread(target=iterator_thread) for _ in range(8)]
threads = [self.checkedThread(target=iterator_thread)
for _ in range(64)]
for t in threads:
t.start()
for t in threads:
Expand Down
1 change: 1 addition & 0 deletions tensorflow/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -5216,6 +5216,7 @@ tf_kernel_library(
srcs = ["iterator_ops.cc"],
deps = [
":dataset",
":ops_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
Expand Down
69 changes: 44 additions & 25 deletions tensorflow/core/kernels/iterator_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {

Expand Down Expand Up @@ -282,38 +285,54 @@ class OneShotIteratorOp : public OpKernel {
IteratorResource* iterator_resource_ = nullptr;
};

class IteratorGetNextOp : public OpKernel {
class IteratorGetNextOp : public AsyncOpKernel {
public:
explicit IteratorGetNextOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

// TODO(mrry): Convert this to an async op, because
// `iterator->GetNext()` could trigger long-running operations
// (e.g. a QueueDequeue or a remote read).
void Compute(OpKernelContext* ctx) override {
explicit IteratorGetNextOp(OpKernelConstruction* ctx)
: AsyncOpKernel(ctx),
thread_pool_(new thread::ThreadPool(
ctx->env(), ThreadOptions(),
strings::StrCat("iterator_get_next_thread_",
SanitizeThreadSuffix(def().name())),
1 /* num_threads */, false /* low_latency_hint */)) {}

void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
IteratorResource* iterator;
OP_REQUIRES_OK(ctx,
LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
core::ScopedUnref unref_iterator(iterator);

std::vector<Tensor> components;
bool end_of_sequence;

IteratorContext::Params params;
params.env = ctx->env();
params.step_id = ctx->step_id();
params.resource_manager = ctx->resource_manager();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));

OP_REQUIRES_OK(ctx,
iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
// The call to `iterator->GetNext()` may block and depend on an
// inter-op thread pool thread, so we issue the call from the
// owned thread pool.
thread_pool_->Schedule([this, ctx, iterator, done]() {
core::ScopedUnref unref_iterator(iterator);

std::vector<Tensor> components;
bool end_of_sequence;

IteratorContext::Params params;
params.env = ctx->env();
params.step_id = ctx->step_id();
params.resource_manager = ctx->resource_manager();
params.runner = *(ctx->runner());
IteratorContext iter_ctx(std::move(params));

OP_REQUIRES_OK_ASYNC(
ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
done);
OP_REQUIRES_ASYNC(ctx, !end_of_sequence,
errors::OutOfRange("End of sequence"), done);

for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}

for (int i = 0; i < components.size(); ++i) {
// TODO(mrry): Check that the shapes match the shape attrs.
ctx->set_output(i, components[i]);
}
done();
});
}

private:
std::unique_ptr<thread::ThreadPool> thread_pool_;
};

class IteratorDisposeOp : public OpKernel {
Expand Down

0 comments on commit d5b89b1

Please sign in to comment.