
[Feature] Add ParallelGraph executor mode in parallelexecutor to improve performance #14791

Merged: 41 commits from parallel_graph_mode into PaddlePaddle:develop on Jan 3, 2019

Conversation

@Yancey1989 (Contributor) commented Dec 7, 2018

Background

The default executor type in ParallelExecutor schedules op_handles in a thread pool, but the op_handle switching overhead hurts performance when individual operator execution times are very short (20 ~ 100 us on CPU in the ResNet50 model).

  • CPU kernel duration vs. GPU kernel duration

| GPUs | CPU kernel duration | GPU kernel duration | conv2d (CPU) | conv2d (GPU) |
|---|---|---|---|---|
| 2 GPUs, 1 CPU | 180 ms | 277 ms | 78.632 ns | 216.535 ns |
| 2 GPUs, 2 CPUs | 286 ms | 277 ms | 113.573 ns | 216.555 ns |

Scheduling op_handles through the thread pool takes longer than running them without the thread pool.

  • fake data vs. real data

| GPUs / batch time | vis-reader | vis-reader + test_mode | fake_data |
|---|---|---|---|
| 8 GPUs, 8 CPUs | 254 ms | 233 ms | 230 ms |

I/O is not the biggest bottleneck.

This PR implements another executor type, called ParallelGraph, in ParallelExecutor. The differences from the default executor in ParallelExecutor are as follows:

  1. Convert the main_program into N graphs, where N is the number of devices, and
  2. ParallelGraph runs each of those graphs on its own thread (a minimal sketch follows below).
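
A minimal sketch of the per-device execution idea (this is not the PR's actual ParallelSSAGraphExecutor code; `GraphExecutor` and `RunParallelGraphs` are hypothetical stand-ins):

```cpp
// Sketch only: one graph per device, each run to completion on its own thread,
// instead of scheduling every op_handle through a shared thread pool.
#include <future>
#include <memory>
#include <vector>

struct GraphExecutor {  // hypothetical stand-in for a per-device graph executor
  void Run() { /* execute this device's graph, op by op, on this thread */ }
};

void RunParallelGraphs(std::vector<std::unique_ptr<GraphExecutor>>* executors) {
  std::vector<std::future<void>> futures;
  futures.reserve(executors->size());
  for (auto& exec : *executors) {
    GraphExecutor* e = exec.get();
    // One thread per device graph avoids the op-level scheduling overhead.
    futures.emplace_back(std::async(std::launch::async, [e] { e->Run(); }));
  }
  for (auto& f : futures) f.get();  // wait for every device to finish the batch
}
```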

Experiment

Test env:

  • GPU: 8 * V100
  • Model: ResNet50
  • Dataset: ImageNet; batch_size is 32 for each GPU; fetch every 30 iterations.

Test cases:

  1. Throughput with fake_data on the qianmo VM

| GPUs / throughput | Default executor | ParallelGraph executor |
|---|---|---|
| 1 GPU, 1 CPU | 268 (1) | 268 (1) |
| 2 GPUs, 2 CPUs | 305 (1.14) | 507 (1.89) |
| 4 GPUs, 4 CPUs | 633 (2.36) | |
| 8 GPUs, 8 CPUs | 1150 (4.29) | 1874 (6.9) |

  2. Throughput with vis-reader on the qianmo VM

| GPUs / throughput | Default executor | ParallelGraph executor |
|---|---|---|
| 1 GPU, 1 CPU | 264 | 264 |
| 8 GPUs, 8 CPUs | 976 (3.69) | 1559 (5.9) |

  3. Throughput with vis-reader on PaddleCloud

| GPUs / throughput | Default executor | ParallelGraph executor | Multiple processes |
|---|---|---|---|
| 8 GPUs, 8 CPUs, bs=32 | 1293 | 1733 (+34%) | 1736 (+34%) |

  4. Throughput on the Transformer model

| GPUs / throughput | Default executor | ParallelGraph executor |
|---|---|---|
| 8 GPUs, 8 CPUs, bs=4096 | 80,478 | 84,869 (+5%) |

TODO:

  • Support GPU parallel training and the nccl2 distributed training mode.
  • Fix the NCCL allreduce hang when the training data is empty on some devices.
  • Support CPU training in ParallelGraph mode.
  • Support the PServer distributed mode.

@Yancey1989 Yancey1989 changed the title Add ParallelGraph executor mode in parallelexecutor to improve performance [WIP, Feature] Add ParallelGraph executor mode in parallelexecutor to improve performance Dec 7, 2018
@panyx0718 (Contributor) left a comment

I like the idea of having each thread execute its own ops, but the implementation is a little confusing. Perhaps we can find a better way to implement it.

@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <pthread.h>
Contributor:

why change here?

Contributor Author:

Just a test; I tried to increase the priority of the threads (http://man7.org/linux/man-pages/man3/pthread_setschedparam.3.html). I will delete this header file.
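
For reference, a hedged sketch of what such a priority experiment could look like (this is not code from the PR; the header was removed in the end):

```cpp
#include <pthread.h>
#include <sched.h>

// Raise the calling thread's priority to the SCHED_FIFO maximum.
// Usually requires CAP_SYS_NICE/root; the return value should be checked.
void RaiseCurrentThreadPriority() {
  sched_param param{};
  param.sched_priority = sched_get_priority_max(SCHED_FIFO);
  pthread_setschedparam(pthread_self(), SCHED_FIFO, &param);
}
```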

if (g_state == ProfilerState::kDisabled) return;
std::lock_guard<std::mutex> l(profiler_mu);
Contributor:

Does this matter? It's wrong to put it here.

Contributor Author:

Maybe we don't need to take the mutex when the profiler is disabled; it decreases performance, and disabling/enabling the profiler only happens at the beginning or end of each batch of training.

if (g_state == ProfilerState::kDisabled || !is_enabled_) return;
VLOG(5) << "call ~RecordEvent";
std::lock_guard<std::mutex> l(profiler_mu);
Contributor:

same as above

[](ExecutionStrategy &self, ExecutionStrategy::ExecutorType type) {
self.type_ = type;
},
R"DOC()DOC");
Contributor:

Can you add more doc to describe kParallelGraph type?

Contributor Author:

Sure. This PR is WIP; I have updated the PR description and added a TODO list.

} else if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
nccl_id.reset(new ncclUniqueId());
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(nccl_id.get()));
*member_->global_scope_->Var(NCCL_ID_VARNAME)
Contributor:

why can't nccl_id_varname be created like above?

// only used in executor_type == ParallalGraph, one thread one GPU
// TODO(Yancey1989): use allreduce operator to avoid this tricky.
PADDLE_ENFORCE(all_reduce_calls.size() == 1UL);
all_reduce_calls[0]();
Contributor:

why is this different?

Contributor Author:

It will hang when using the group call of the NCCL operators, and the NCCL example code doesn't use the group call: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/examples.html#example-1-one-device-per-process-or-thread .
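
For context, a minimal sketch of the "one device per process or thread" pattern from the linked NCCL example (the communicator and stream are assumed to be created elsewhere; `AllReduceOnThisThread` is not the PR's code):

```cpp
#include <cuda_runtime.h>
#include <nccl.h>

// One device per thread: each thread issues exactly one collective call and
// does NOT wrap it in ncclGroupStart()/ncclGroupEnd().
void AllReduceOnThisThread(const float* sendbuff, float* recvbuff, size_t count,
                           ncclComm_t comm, cudaStream_t stream) {
  ncclAllReduce(sendbuff, recvbuff, count, ncclFloat, ncclSum, comm, stream);
}
```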

// only used in executor_type == ParallalGraph, one thread one GPU
// TODO(Yancey1989): use allreduce operator to avoid this tricky.
PADDLE_ENFORCE(all_reduce_calls.size() == 1UL);
all_reduce_calls[0]();
Collaborator:

It seems that it is easy to produce a deadlock, as described in https://arxiv.org/pdf/1706.02677.pdf

Contributor Author:

Sure, it's the same problem as distributed training in NCCL2 collective mode; we need to fix the order of the all-reduce operators. @gongweibao has a PR for that (#14586); I will do more testing with that feature after this PR is merged.

auto call = [this, i] {
// FIXME(Yancey1989): need to fix fetch data failed.
std::vector<std::string> empty;
executors_[i]->Run(empty);
Collaborator:

You need to make ParallelSSAGraphExecutor exception-safe, because some exceptions are acceptable, such as the EOFException raised by py_reader when a pass ends.
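
A minimal, hedged sketch of the exception-safety pattern being asked for (not the actual ParallelSSAGraphExecutor implementation): capture exceptions per thread and rethrow only after every future has been waited on, so an end-of-pass EOF cannot leave other threads unjoined.

```cpp
#include <exception>
#include <functional>
#include <future>
#include <vector>

void RunAllOrRethrow(const std::vector<std::function<void()>>& calls) {
  std::vector<std::future<void>> futures;
  std::exception_ptr first_error;
  for (const auto& call : calls) {
    futures.emplace_back(std::async(std::launch::async, call));
  }
  for (auto& f : futures) {
    try {
      f.get();  // always wait on every future, even after an earlier error
    } catch (...) {
      if (!first_error) first_error = std::current_exception();
    }
  }
  if (first_error) std::rethrow_exception(first_error);  // e.g. the EOF case
}
```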

Contributor Author:

done.

@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once

#include <pthread.h>
Collaborator:

Do not use pthread.h, which is POSIX-only. Try to use the standard C++ header #include <thread>.

Contributor Author:

I will delete this header file, #14791 (comment)

@@ -106,31 +110,56 @@ ParallelExecutor::ParallelExecutor(
// Bcast Parameters to all GPUs
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto *nccl_id_var = scope->FindVar(NCCL_ID_VARNAME);
ncclUniqueId *nccl_id = nullptr;
std::unique_ptr<ncclUniqueId> nccl_id = nullptr;
Contributor:

Why do you use std::unique_ptr here?

Contributor Author:

Will try to avoid that.

Contributor Author:

done.

@Yancey1989 (Contributor Author) commented Dec 11, 2018

I like the idea of having each thread execute its own ops. But the implementation is a little confusing. Perhaps we can have a better way to implement it. -- FROM @panyx0718

This was just to reuse the code of multi_devices_pass; maybe implementing another Pass is a good idea.

graphs_(std::move(graphs)) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// do not use threadpool for each graph execution.
strategy_.num_threads_ = 1UL;
Contributor:

Did you compare the performance between different values of num_threads?

Contributor Author:

Not using the thread pool achieves better performance, and I have added it as a configurable argument.

This is the result on fake_data:

| num_threads | throughput |
|---|---|
| 1 | 1841.10691 |
| 2 | 1648.53907 |

}
}
}

Contributor:

Is there any issue here? When an exception is thrown, it will keep running until line 80.

Contributor Author:

thanks, done.

@Yancey1989 Yancey1989 changed the title [WIP, Feature] Add ParallelGraph executor mode in parallelexecutor to improve performance [Feature] Add ParallelGraph executor mode in parallelexecutor to improve performance Dec 13, 2018
for (auto &call : all_reduce_calls) {
call();
// TODO(Yancey1989): need allreduce operator to avoid this flag
if (nccl_ctxs_->need_group_call_) {
Contributor:

Why not just use all_reduce_calls.size() == 1UL?

Contributor Author:

good idea, done.

@@ -52,7 +52,6 @@ void OpHandleBase::Run(bool use_cuda) {
#else
PADDLE_ENFORCE(!use_cuda);
#endif

Contributor:

can remove these unnecessary changes

Contributor Author:

done.

@@ -386,7 +386,16 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
CreateComputationalOps(&result, node, places_.size());
}

// insert synchronous ops at the backpropagation; and
Contributor:

synchronous ops => collective ops

Contributor Author:

done.

};

if (pool_) {
run_futures.emplace_back(pool_->enqueue(std::move(call)));
Contributor:

Since the number of tasks to run is fixed, you could just use a set of threads to avoid the enqueues.

Contributor Author:

The ThreadPool avoids creating the threads at the beginning of each batch.

#endif

auto max_memory_size = GetEagerDeletionThreshold();
Contributor:

should add this back?

Contributor Author:

done.

});
})
.def_property(
"executor_type",
Contributor:

executor_type is too general a name; it does not provide enough information for the API.

Contributor Author:

We can discuss the name, and I have also added some description of this field.

Actually, the ParallelGraph executor is one of several executors (Default, FastThreaded, and more in the future).
@panyx0718 do you have any good ideas?

Contributor Author:

Using build_strategy.enable_parallel_graph to enable/disable the parallel graph.

@panyx0718 (Contributor) left a comment

@sneaxiy @chengduoZH please review

ReduceLoDTensor func(lod_tensors, &trg);
VisitDataType(lod_tensors[0]->type(), func);

for (size_t i = 1; i < local_scopes_.size(); ++i) {
Contributor:

still work for cpu?

Contributor Author:

No. The CPU allreduce is implemented with TensorCopy, and it needs a barrier to wait until the input variables from the different devices (graphs) are ready.

places_(std::move(places)),
graphs_(std::move(graphs)) {
PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
// do not use threadpool for each graph execution.
Contributor:

how do you enforce this?

Contributor Author:

Will delete this comment, it's an optional argument.

fetch_datas.emplace_back(std::move(f.get()));
}
}
}
Contributor:

no else?

Contributor Author:

It is synchronous execution if pool_ is nullptr.
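
To make the answer concrete, a hedged, self-contained illustration (`ThreadPoolLike` and `RunOrEnqueue` are hypothetical stand-ins for the pool and helper used in the PR):

```cpp
#include <functional>
#include <future>
#include <utility>
#include <vector>

// With a pool, the task runs asynchronously and its future is kept for later;
// without one, it runs synchronously on the calling thread, so there is
// nothing left to wait for afterwards.
template <typename ThreadPoolLike>
void RunOrEnqueue(ThreadPoolLike* pool, std::function<void()> call,
                  std::vector<std::future<void>>* run_futures) {
  if (pool) {
    run_futures->emplace_back(pool->enqueue(std::move(call)));
  } else {
    call();
  }
}
```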

lodtensor_ptrs.push_back(&fetch_datas.at(scope_idx).at(fetch_idx));
}
ret.emplace_back();
ret.back().MergeLoDTensor(lodtensor_ptrs, platform::CPUPlace());
Contributor:

Does ParallelExecutor's fetch merge lodtensor?

Contributor Author:

For ParallelGraph mode, the fetch_op_handle is only placed on one device, so we should merge the results here.

"build_strategy.reduce should be `AllReduce` if you want to enable"
"ParallelGraph.");
PADDLE_ENFORCE(
member_->use_cuda_,
Contributor:

why only support gpu?

if (nccl_id_var != nullptr) {
nccl_id = nccl_id_var->GetMutable<ncclUniqueId>();
}
if (build_strategy.enable_parallel_graph_ && places.size() > 1) {
Contributor:

why use places, not num_parallel_devices?

// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp
std::vector<std::unique_ptr<ir::Graph>> graphs;
member_->num_parallel_devices_ = member_->places_.size() * num_trainers;
Contributor:

why compute num_parallel_devices again?

Contributor Author:

Sorry, it's duplicated code; I will delete it.

@@ -442,10 +442,10 @@ def _run_cluster_nccl2(self, model, envs, check_error_log):
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
tr0_cmd = tr_cmd % \
(self._python_interp, model, self._ps_endpoints,
0, w0_ep, self._lr / 2)
0, w0_ep, self._lr)
Contributor:

why change this?

Contributor Author:

Remark 3: Normalize the per-worker loss by total minibatch size kn, not per-worker size n.

FROM https://arxiv.org/abs/1706.02677.

Scaling the loss instead of the LR can achieve better accuracy.
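
In symbols (k workers, per-worker minibatch size n, paraphrasing the remark quoted above; not text from the PR):

$$
\ell(w) \;=\; \frac{1}{kn}\sum_{j=1}^{k}\sum_{x\in\mathcal{B}_j} L(x, w)
\quad\text{rather than}\quad
\ell_j(w) \;=\; \frac{1}{n}\sum_{x\in\mathcal{B}_j} L(x, w)\ \text{per worker.}
$$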

@@ -198,6 +199,17 @@ ParallelExecutor::ParallelExecutor(
"the number of places must be greater than 1.");
}

if (build_strategy.enable_parallel_graph_) {
Contributor:

which exe_strategy does it check?

@@ -106,7 +106,7 @@ struct NCCLContextMap {
}
std::unique_ptr<ncclComm_t[]> comms(new ncclComm_t[order_.size()]);
// if num_trainers == 1, should create a new nccl id for local comms.
if (num_trainers == 1) {
if (num_trainers == 1 && nccl_id == nullptr) {
Contributor:

why change this? Is this a bug fix?

Contributor Author:

It's not a bug fix; ParallelGraph should initialize NCCL in ranks mode with the same nccl_id.
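
For context, a hedged sketch of what "initialize NCCL in ranks mode with the same nccl_id" refers to (`InitCommForRank` is illustrative, not the PR's code):

```cpp
#include <nccl.h>

// Every device/thread joins the same communicator clique by passing the
// identical ncclUniqueId together with its own rank; the call blocks until
// all nranks participants have arrived.
ncclComm_t InitCommForRank(ncclUniqueId id, int nranks, int rank) {
  ncclComm_t comm = nullptr;
  ncclCommInitRank(&comm, nranks, id, rank);
  return comm;
}
```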


def test_batchnorm_fc(self):
for use_cuda in (False, True):
for use_fast_executor in (False, True):
self.check_batchnorm_fc_convergence(use_cuda, use_fast_executor)

self.check_batchnorm_fc_convergence(
use_cuda=True, use_fast_executor=False, use_parallel_graph=True)

def test_batchnorm_fc_with_new_strategy(self):
# FIXME(zcd): close this test temporally.

if (exception_holder_.IsCaught()) {
f.wait();
} else {
fetch_datas.emplace_back(std::move(f.get()));
Contributor:

data is a plural form.

Contributor Author:

done

if (build_strategy.enable_sequential_execution_ ||
exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
enable_parallel_graph = false;
return enable_parallel_graph && FLAGS_enable_parallel_graph;
Contributor:

The code order of EnableParallelGraphExecution can be refined; e.g., if FLAGS_enable_parallel_graph is false, it can return directly.
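
A hedged sketch of the suggested reordering (assuming Paddle's BuildStrategy/ExecutionStrategy declarations and the flag are in scope; the real function checks more conditions than shown):

```cpp
bool EnableParallelGraphExecution(const BuildStrategy &build_strategy,
                                  const ExecutionStrategy &exec_strategy) {
  if (!FLAGS_enable_parallel_graph) return false;  // cheapest check first
  if (build_strategy.enable_sequential_execution_) return false;
  if (exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental)
    return false;
  return true;  // remaining conditions are omitted in this sketch
}
```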

Contributor Author:

done

@chengduoZH (Contributor) left a comment

This PR can be merged first, but there is a problem to be solved in the next PR: if some executor has no training data and the other executors are not notified, the program will hang during the NCCL AllReduce.
This PR doesn't affect the default behavior of ParallelExecutor.

@Yancey1989 (Contributor Author) commented Jan 3, 2019

Thanks @chengduoZH. I will fix the ParallelExecutor hang when some devices do not have enough training data; until then, we can enable the ParallelGraph mode by setting the environment variables:

FLAGS_enable_parallel_graph=1 FLAGS_sync_nccl_allreduce=1 ...

And this PR needs @panyx0718's approval because of the const_cast.

@panyx0718 (Contributor) left a comment

We want to have a more automatic way of handling the different build and execution strategies.

// asynchronous nccl allreduce or synchronous issue:
// https://github.com/PaddlePaddle/Paddle/issues/15049
DEFINE_bool(
sync_nccl_allreduce, false,
Contributor:

still not quite comfortable with this flag

@Yancey1989 (Contributor Author) commented Jan 3, 2019

Will enable ParallelGraph mode + async NCCL by default once the NCCL hang issue is fixed.

@Yancey1989 Yancey1989 merged commit a1e60ab into PaddlePaddle:develop Jan 3, 2019
@Yancey1989 Yancey1989 deleted the parallel_graph_mode branch January 3, 2019 10:42