Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] C++ implementation of parallel executor #9035

Closed

Conversation

tonyyang-svail
Copy link

@tonyyang-svail tonyyang-svail commented Mar 13, 2018

DO NOT MERGE THIS PR!

This PR will serve a baseline to #9080. The main difference is that:

In this PR, one thread is bound to one GPU. Each thread launches all Ops sequentially in the computation stream and launches AllReduce in the io stream. CudaEvent is used for coordination between streams.

In #9080, a dependency parsing is used for scheduling the ready Ops to a thread pool.

machine: 250
test_script: test_parallel_executor.py in this PR
test_command:

  1. CUDA_VISIBLE_DEVICES=3 python -m unittest test_parallel_executor.TestResnet
  2. CUDA_VISIBLE_DEVICES=3,4,5,6 python -m unittest test_parallel_executor.TestResnet

model: SE_ResNeXt152
batch_size: 16 per GPU
model size: 1382651904
peak memory: 7351879168

20.7775 Instance per second
60.8804 Instance per second

@panyx0718 panyx0718 self-requested a review March 14, 2018 04:24

for (auto& op : ctx->ops_) {
// sgd should wait for allreduce finished
for (auto& param2argu : op->Inputs()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of looping every input every time, perhaps the op can cache the param inputs and only wait for them


PADDLE_ENFORCE(
cudaEventRecord(computation_event[argu], computation_stream));
PADDLE_ENFORCE(cudaStreamWaitEvent(all_reduce_stream,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to block the next computation op. We should profile and see how much it hurts

@@ -86,6 +88,7 @@ class Scope {
mutable std::unordered_map<std::string, Variable*> vars_;
mutable std::list<Scope*> kids_;
Scope const* parent_{nullptr};
std::vector<std::shared_ptr<Scope>> replicas_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems a little strange. why would the Scope keeps all the replicas? It's not a singleton?

@@ -457,12 +457,39 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
}
};

class BatchNormGradMaker : public framework::SingleGradOpDescMaker {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a fix? Should it be checked-in separately?

if (!(ins[i].place() == dev_place)) {
LOG(INFO) << "Copy " << out_arg_names[i] << " from " << ins[i].place()
<< " to " << dev_place;
framework::TensorCopy(ins[i], dev_place, out);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this ensure different gpu get different data?


if not isinstance(program, Program):
raise TypeError()
if not isinstance(fetch_list, dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict -> list?

std::unordered_set<std::string>* param_grads_;
};

class MultiGPUExecutor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we extend it to be MultiCPUThreadExecutor??

@tonyyang-svail
Copy link
Author

I am closing this pr since its alternative #9080 will be merged.

@chengduoZH chengduoZH added the parallel_exe parallel executor label Apr 6, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
parallel_exe parallel executor
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants