From 5ed487bfc22c3c88a42f6f7690bf0ddca76ec6e7 Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Wed, 2 Nov 2022 15:26:09 +0800 Subject: [PATCH] Dispatch computation OPs before communication in standalone executor (#47471) * Dispath computation OPs before communication in standalone executor * Update code * Fix CI errors --- .../interpreter/interpreter_util.cc | 7 +++- .../interpreter/interpreter_util.h | 2 + .../framework/new_executor/interpretercore.cc | 37 +++++++++++++------ .../framework/new_executor/interpretercore.h | 2 +- .../new_executor/new_executor_defs.cc | 8 +++- .../new_executor/new_executor_defs.h | 8 +++- 6 files changed, 47 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 816331e3fa549..6c002d06b5b19 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -129,14 +129,13 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type, } } -bool IsCommunicationOp(const Instruction& instr) { +bool IsCommunicationOp(const std::string& op_name) { const std::set special_comm_op_set = { "send", "recv", "send_v2", "recv_v2", }; - const std::string& op_name = instr.OpBase()->Type(); const std::string communication_op_prefix = "c_"; if (op_name.find(communication_op_prefix) != std::string::npos || special_comm_op_set.count(op_name)) { @@ -145,6 +144,10 @@ bool IsCommunicationOp(const Instruction& instr) { return false; } +bool IsCommunicationOp(const Instruction& instr) { + return IsCommunicationOp(instr.OpBase()->Type()); +} + bool IsCpuOp(const Instruction& instr) { return platform::is_cpu_place(instr.DeviceContext().GetPlace()); } diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index b842d3acfde6d..d6652d2654160 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -65,6 +65,8 @@ class AsyncWorkQueue { std::unique_ptr queue_group_; }; +bool IsCommunicationOp(const std::string& op_name); + bool IsCommunicationOp(const Instruction& instr); bool IsCpuOp(const Instruction& instr); diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 230e333458dd4..825c4e14c4489 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -528,7 +528,12 @@ void InterpreterCore::Convert( for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { auto& op_func_node = nodes[op_idx]; auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); - vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_); + Priority priority = + interpreter::IsCommunicationOp(op_func_node.operator_base_->Type()) + ? Priority::kLowest + : Priority::kNormal; + vec_instruction_.emplace_back( + op_idx, std::move(op_func_node), *dev_ctx_, priority); } BuildOperatorDependences(); @@ -835,7 +840,7 @@ void InterpreterCore::ExecuteInstructionList( } void InterpreterCore::RunNextInstructions( - const Instruction& instr, std::queue* reserved_next_ops) { + const Instruction& instr, std::deque* reserved_next_ops) { platform::RecordEvent record( "RunNextInstructions", platform::TracerEventType::UserDefined, 10); auto& next_instr = instr.NextInstructions(); @@ -848,7 +853,7 @@ void InterpreterCore::RunNextInstructions( if (instr.KernelType() == OpFuncType::kQueueAsync) { // move all sync_ops into other threads - for (auto next_id : next_instr.SyncRunIds()) { + for (size_t next_id : next_instr.SyncRunIds()) { if (IsReady(next_id)) { async_work_queue_->AddTask( vec_instruction_[next_id].KernelType(), @@ -856,14 +861,22 @@ void InterpreterCore::RunNextInstructions( } } // keep all async_ops running in current thread - for (auto next_id : next_instr.DirectRunIds()) { + for (size_t next_id : next_instr.DirectRunIds()) { if (IsReady(next_id)) { - reserved_next_ops->push(next_id); + if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) { + reserved_next_ops->push_back(next_id); + } else { + reserved_next_ops->push_front(next_id); + } } } - for (auto next_id : next_instr.EventRunIds()) { + for (size_t next_id : next_instr.EventRunIds()) { if (IsReady(next_id)) { - reserved_next_ops->push(next_id); + if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) { + reserved_next_ops->push_back(next_id); + } else { + reserved_next_ops->push_front(next_id); + } } } } else { @@ -895,16 +908,18 @@ void InterpreterCore::RunNextInstructions( [this, next_id] { RunInstructionAsync(next_id); }); } } - if (first_op != -1) reserved_next_ops->push(first_op); + if (first_op != -1) { + reserved_next_ops->push_front(first_op); + } } } void InterpreterCore::RunInstructionAsync(size_t instr_id) { - std::queue ready_ops; - ready_ops.push(instr_id); + std::deque ready_ops; + ready_ops.push_back(instr_id); while (!ready_ops.empty()) { instr_id = ready_ops.front(); - ready_ops.pop(); + ready_ops.pop_front(); auto& instr_node = vec_instruction_.at(instr_id); VLOG(5) << __func__ << " OP id:" << instr_node.Id() << " name:" << instr_node.OpBase()->Type() << " type:" diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index ff89f5ed731de..4cf5053448703 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -92,7 +92,7 @@ class InterpreterCore { void RunInstructionAsync(size_t instr_id); void RunInstruction(const Instruction& instr_node); void RunNextInstructions(const Instruction& instr_id, - std::queue* reserved_next_ops); + std::deque* reserved_next_ops); // only used when program contains no feed op void Prepare(const std::vector& feed_names, const std::vector& feed_tensors, diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index 02be9f47ecf3e..08a4a486173f7 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -673,8 +673,12 @@ void VariableScope::CheckExist(const std::string& name) const { Instruction::Instruction(size_t id, OpFuncNode&& op_func_node, - const platform::DeviceContext& dev_ctx) - : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) { + const platform::DeviceContext& dev_ctx, + const Priority priority) + : id_(id), + op_func_node_(op_func_node), + dev_ctx_(dev_ctx), + priority_(priority) { PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index 6735e891230d7..f1ede0974bb0c 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -40,6 +40,8 @@ constexpr const char* kDefaultStream = "DefaultStream"; constexpr const char* kD2HStream = "D2HStream"; constexpr const char* kH2DStream = "H2DStream"; +enum class Priority { kLowest, kNormal }; + class InterpretercoreInferShapeContext : public InferShapeContext { public: InterpretercoreInferShapeContext(const OperatorBase& op, @@ -300,7 +302,8 @@ class Instruction { public: Instruction(size_t id, OpFuncNode&& op_func_node, - const platform::DeviceContext& dev_ctx); + const platform::DeviceContext& dev_ctx, + const Priority priority); size_t Id() const; @@ -362,10 +365,13 @@ class Instruction { std::shared_ptr event, platform::DeviceType waiter_type); + Priority GetPriority() const { return priority_; } + private: size_t id_; OpFuncNode op_func_node_; const platform::DeviceContext& dev_ctx_; // not owned + const Priority priority_; std::shared_ptr runtime_ctx_; std::shared_ptr infershape_ctx_;