diff --git a/dali/pipeline/executor/executor.h b/dali/pipeline/executor/executor.h index 16fbb87757..b396cf7d35 100644 --- a/dali/pipeline/executor/executor.h +++ b/dali/pipeline/executor/executor.h @@ -580,7 +580,7 @@ void Executor::Build(OpGraph *graph, vector +#include #include #include "dali/pipeline/graph/op_graph_storage.h" @@ -19,12 +21,17 @@ namespace dali { std::vector CreateBackingStorageForTensorNodes( - const OpGraph &op_graph, int batch_size, const std::vector &queue_sizes) { + const OpGraph &op_graph, int batch_size, const std::vector &queue_sizes, + const std::vector &output_names) { DALI_ENFORCE(static_cast(queue_sizes.size()) == op_graph.NumTensor(), "Data queue sizes undefined for some Tensor nodes."); std::vector result; result.resize(op_graph.NumTensor()); + std::set outputs; + auto output_ids = op_graph.GetOutputs(output_names);; + outputs.insert(output_ids.begin(), output_ids.end()); + // Assign data to each Tensor node in graph for (int i = 0; i < op_graph.NumTensor(); i++) { const auto &tensor = op_graph.Tensor(i); @@ -32,8 +39,11 @@ std::vector CreateBackingStorageForTensorNodes( result[i] = BatchFactory(producer_op_type, tensor.producer.storage_device, batch_size, queue_sizes[i]); + bool is_output = outputs.count(tensor.id) > 0; tuple_for_each(result[i], [&](auto &x) { x.num_consumers = tensor.consumers.size(); + if (is_output) + x.num_consumers++; }); } return result; diff --git a/dali/pipeline/graph/op_graph_storage.h b/dali/pipeline/graph/op_graph_storage.h index 9ed0cce0b0..bdf6f6b889 100644 --- a/dali/pipeline/graph/op_graph_storage.h +++ b/dali/pipeline/graph/op_graph_storage.h @@ -15,6 +15,7 @@ #ifndef DALI_PIPELINE_GRAPH_OP_GRAPH_STORAGE_H_ #define DALI_PIPELINE_GRAPH_OP_GRAPH_STORAGE_H_ +#include #include #include "dali/pipeline/graph/op_graph.h" @@ -28,7 +29,8 @@ namespace dali { using MixedOpEventMap = std::vector>; DLL_PUBLIC std::vector CreateBackingStorageForTensorNodes( - const OpGraph& op_graph, int batch_size, const std::vector& queue_sizes); + const OpGraph& op_graph, int batch_size, const std::vector& queue_sizes, + const std::vector &output_names); // Mapping from MixedOp partition id to queue of corresponding events DLL_PUBLIC MixedOpEventMap CreateEventsForMixedOps(EventPool& event_pool, const OpGraph& op_graph,