Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dispatch computation OPs before communication in standalone executor #47471

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,13 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
}
}

bool IsCommunicationOp(const Instruction& instr) {
bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = {
"send",
"recv",
"send_v2",
"recv_v2",
};
const std::string& op_name = instr.OpBase()->Type();
const std::string communication_op_prefix = "c_";
if (op_name.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op_name)) {
Expand All @@ -142,6 +141,10 @@ bool IsCommunicationOp(const Instruction& instr) {
return false;
}

bool IsCommunicationOp(const Instruction& instr) {
return IsCommunicationOp(instr.OpBase()->Type());
}

// True when the instruction's device context is bound to a CPU place.
bool IsCpuOp(const Instruction& instr) {
  const auto& place = instr.DeviceContext().GetPlace();
  return platform::is_cpu_place(place);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_;
};

bool IsCommunicationOp(const std::string& op_name);

bool IsCommunicationOp(const Instruction& instr);

bool IsCpuOp(const Instruction& instr);
Expand Down
37 changes: 26 additions & 11 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,12 @@ void InterpreterCore::Convert(
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
Priority priority =
interpreter::IsCommunicationOp(op_func_node.operator_base_->Type())
? Priority::kLowest
: Priority::kNormal;
vec_instruction_.emplace_back(
op_idx, std::move(op_func_node), *dev_ctx_, priority);
}

BuildOperatorDependences();
Expand Down Expand Up @@ -835,7 +840,7 @@ void InterpreterCore::ExecuteInstructionList(
}

void InterpreterCore::RunNextInstructions(
const Instruction& instr, std::queue<size_t>* reserved_next_ops) {
const Instruction& instr, std::deque<size_t>* reserved_next_ops) {
platform::RecordEvent record(
"RunNextInstructions", platform::TracerEventType::UserDefined, 10);
auto& next_instr = instr.NextInstructions();
Expand All @@ -848,22 +853,30 @@ void InterpreterCore::RunNextInstructions(

if (instr.KernelType() == OpFuncType::kQueueAsync) {
// move all sync_ops into other threads
for (auto next_id : next_instr.SyncRunIds()) {
for (size_t next_id : next_instr.SyncRunIds()) {
if (IsReady(next_id)) {
async_work_queue_->AddTask(
vec_instruction_[next_id].KernelType(),
[this, next_id]() { RunInstructionAsync(next_id); });
}
}
// keep all async_ops running in current thread
for (auto next_id : next_instr.DirectRunIds()) {
for (size_t next_id : next_instr.DirectRunIds()) {
if (IsReady(next_id)) {
reserved_next_ops->push(next_id);
if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) {
reserved_next_ops->push_back(next_id);
} else {
reserved_next_ops->push_front(next_id);
}
}
}
for (auto next_id : next_instr.EventRunIds()) {
for (size_t next_id : next_instr.EventRunIds()) {
if (IsReady(next_id)) {
reserved_next_ops->push(next_id);
if (vec_instruction_[next_id].GetPriority() == Priority::kLowest) {
reserved_next_ops->push_back(next_id);
} else {
reserved_next_ops->push_front(next_id);
}
}
}
} else {
Expand Down Expand Up @@ -895,16 +908,18 @@ void InterpreterCore::RunNextInstructions(
[this, next_id] { RunInstructionAsync(next_id); });
}
}
if (first_op != -1) reserved_next_ops->push(first_op);
if (first_op != -1) {
reserved_next_ops->push_front(first_op);
}
}
}

void InterpreterCore::RunInstructionAsync(size_t instr_id) {
std::queue<size_t> ready_ops;
ready_ops.push(instr_id);
std::deque<size_t> ready_ops;
ready_ops.push_back(instr_id);
while (!ready_ops.empty()) {
instr_id = ready_ops.front();
ready_ops.pop();
ready_ops.pop_front();
auto& instr_node = vec_instruction_.at(instr_id);
VLOG(5) << __func__ << " OP id:" << instr_node.Id()
<< " name:" << instr_node.OpBase()->Type() << " type:"
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/interpretercore.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class InterpreterCore {
void RunInstructionAsync(size_t instr_id);
void RunInstruction(const Instruction& instr_node);
void RunNextInstructions(const Instruction& instr_id,
std::queue<size_t>* reserved_next_ops);
std::deque<size_t>* reserved_next_ops);
// only used when program contains no feed op
void Prepare(const std::vector<std::string>& feed_names,
const std::vector<phi::DenseTensor>& feed_tensors,
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/framework/new_executor/new_executor_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,12 @@ void VariableScope::CheckExist(const std::string& name) const {

Instruction::Instruction(size_t id,
OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx)
: id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
const platform::DeviceContext& dev_ctx,
const Priority priority)
: id_(id),
op_func_node_(op_func_node),
dev_ctx_(dev_ctx),
priority_(priority) {
PADDLE_ENFORCE_GE(id,
0,
platform::errors::PreconditionNotMet(
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/framework/new_executor/new_executor_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";

enum class Priority { kLowest, kNormal };

class InterpretercoreInferShapeContext : public InferShapeContext {
public:
InterpretercoreInferShapeContext(const OperatorBase& op,
Expand Down Expand Up @@ -300,7 +302,8 @@ class Instruction {
public:
Instruction(size_t id,
OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx);
const platform::DeviceContext& dev_ctx,
const Priority priority);

size_t Id() const;

Expand Down Expand Up @@ -362,10 +365,13 @@ class Instruction {
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type);

Priority GetPriority() const { return priority_; }

private:
size_t id_;
OpFuncNode op_func_node_;
const platform::DeviceContext& dev_ctx_; // not owned
const Priority priority_;

std::shared_ptr<RuntimeContext> runtime_ctx_;
std::shared_ptr<InterpretercoreInferShapeContext> infershape_ctx_;
Expand Down