
Add design of asynchronous techniques on heterogeneous devices #7814

Closed

Conversation

reyoung
Collaborator

@reyoung reyoung commented Jan 24, 2018

Asynchronous techniques on heterogeneous devices are a key issue in performance tuning.

I try to describe the problem and give a straightforward solution.

Any comments/questions on this design are welcome.

@reyoung reyoung force-pushed the feature/multi_stream_design_doc branch from db2512a to ecdab2c on January 24, 2018 05:08

Let's use CUDA as an example. There is a building block named `stream` in CUDA. Streams introduce task-based parallelism to CUDA code: a sequence of operations issued to the same stream executes in issue order on the GPU.

Operators in different streams can run concurrently, as long as the hardware supports it. CUDA hardware has no notion of streams; the hardware has separate queues (engines) to perform memory copies and to execute kernels.
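
(An illustrative aside, not part of the design doc under review: a minimal CUDA C++ sketch of the issue-order rule. The `scale` kernel and all sizes are made up. The two launches on `s0` execute in issue order; the launch on `s1` has no ordering guarantee relative to `s0`, so it may overlap if the hardware has capacity.)

#include <cuda_runtime.h>

__global__ void scale(float* data, int n, float k) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= k;
}

void demo() {
  const int n = 1 << 20;
  float *d_a, *d_b;
  cudaMalloc(&d_a, n * sizeof(float));
  cudaMalloc(&d_b, n * sizeof(float));

  cudaStream_t s0, s1;
  cudaStreamCreate(&s0);
  cudaStreamCreate(&s1);

  // Same stream: these two kernels run strictly in issue order.
  scale<<<(n + 255) / 256, 256, 0, s0>>>(d_a, n, 2.f);
  scale<<<(n + 255) / 256, 256, 0, s0>>>(d_a, n, 3.f);

  // Different stream: no ordering guarantee with respect to s0, so this
  // kernel may run concurrently with the work on s0 if the hardware allows.
  scale<<<(n + 255) / 256, 256, 0, s1>>>(d_b, n, 4.f);

  cudaStreamSynchronize(s0);
  cudaStreamSynchronize(s1);
  cudaStreamDestroy(s0);
  cudaStreamDestroy(s1);
  cudaFree(d_a);
  cudaFree(d_b);
}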
Contributor

  1. To make operations run concurrently, the operations of one stream should not depend on operations in other streams.
  2. memory copies ==> data transfers

Collaborator Author

To make operations run concurrently, the operations of one stream should not depend on operations in other streams.

I think I want to discuss a different issue. Please refer to #7814 (comment)


Operators in different streams can run concurrently, as long as the hardware supports it. CUDA hardware has no notion of streams; the hardware has separate queues (engines) to perform memory copies and to execute kernels.

If we want to take full advantage of CUDA devices, we must use at least N streams, where N equals the number of hardware queues, and separate operators into these streams. N equals three, since CUDA hardware can simultaneously execute CUDA kernels, H2D memcpy, and D2H memcpy.
Contributor

@chengduoZH chengduoZH Jan 24, 2018

It seems that CUDA does not limit the number of streams. As long as the GPU's resources (memory and computation) are not fully occupied, in theory you can create a new stream.

Collaborator Author

@reyoung reyoung Jan 25, 2018

No, it does not. CUDA can create many streams without any limit. However, jobs in CUDA can execute simultaneously only under two conditions:

  1. The jobs are in different streams.
  2. The hardware supports it.

Since the CUDA hardware supports simultaneously executing THREE kinds of operations, kernel execution / D2H memcpy / H2D memcpy, we need AT LEAST THREE streams to make full use of CUDA devices. And using more than three streams will not help much, since the hardware only supports simultaneously executing these THREE kinds of operations.
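
(An illustrative aside, not from the PR: a sketch of the overlap the three engines allow, using one stream per engine and CUDA events for the per-chunk dependencies. The `square` kernel and buffer layout are assumptions; h_in/h_out are assumed to be pinned with cudaMallocHost so the async copies really overlap.)

#include <cuda_runtime.h>
#include <vector>

__global__ void square(float* d, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) d[i] *= d[i];
}

// One stream per hardware engine; events express the per-chunk dependencies,
// so the H2D copy of chunk k+1 can overlap the kernel on chunk k and the
// D2H copy of chunk k-1.
void pipeline(float* h_in, float* h_out, float* d_buf, int chunks, int chunk_n) {
  cudaStream_t h2d, comp, d2h;
  cudaStreamCreate(&h2d);
  cudaStreamCreate(&comp);
  cudaStreamCreate(&d2h);

  std::vector<cudaEvent_t> copied(chunks), computed(chunks);
  for (int k = 0; k < chunks; ++k) {
    cudaEventCreate(&copied[k]);
    cudaEventCreate(&computed[k]);
  }

  for (int k = 0; k < chunks; ++k) {
    float* d = d_buf + k * chunk_n;
    size_t bytes = chunk_n * sizeof(float);

    cudaMemcpyAsync(d, h_in + k * chunk_n, bytes, cudaMemcpyHostToDevice, h2d);
    cudaEventRecord(copied[k], h2d);

    cudaStreamWaitEvent(comp, copied[k], 0);   // kernel waits for its own chunk's copy
    square<<<(chunk_n + 255) / 256, 256, 0, comp>>>(d, chunk_n);
    cudaEventRecord(computed[k], comp);

    cudaStreamWaitEvent(d2h, computed[k], 0);  // copy-back waits for the kernel
    cudaMemcpyAsync(h_out + k * chunk_n, d, bytes, cudaMemcpyDeviceToHost, d2h);
  }
  cudaStreamSynchronize(d2h);
}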

Member

Threads in a block are launched on an SM (streaming multiprocessor). If the former CUDA kernel occupies only a few SMs and there are SMs left over, another CUDA kernel can execute in parallel. Please refer to https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/.

Contributor

@helinwang helinwang Mar 5, 2018

And using more than three streams will not help much, since the hardware only supports simultaneously executing these THREE kinds of operations.

Please consider the following case:

We have 7 kernels of 3 types (A, B, C). Their dependency relationship is as follows:

A0 B0 C0
B1
B2
C1
C2

If we have just 3 streams, it's possible that the kernels are put into the 3 streams in the following order:

A0 B0 C0 C1 C2
B1
B2

In this case, C1 and C2 don't depend on C0 but still have to wait for C0 to complete before they can run.

I think the number of streams we use should equal the concurrency expressed in our program, not the concurrency of the hardware.


* Create N device contexts on one device. N should correspond to the hardware properties. For example, a CUDA device should have three device contexts.

* Every tensor should hold the device context on which the tensor's current operation is performed.
Contributor

I wonder whether it is appropriate for every tensor to hold a device context.
A DeviceContext may hold a lot of objects. Taking CUDADeviceContext as an example, it currently has six private data members:

  CUDAPlace place_;

  std::unique_ptr<Eigen::GpuDevice> eigen_device_;
  std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;

  cudaStream_t stream_;
  cudnnHandle_t cudnn_handle_;
  cublasHandle_t cublas_handle_;

But for tensors, only place_ and stream_ are necessary.

Collaborator Author

@reyoung reyoung Jan 25, 2018

Only a device context can be Wait()-ed on. We should not use the low-level APIs, like the raw stream, because other devices, like OpenCL, may need to be supported.

Another reason we use the device context is that CUDNN/CUBLAS/EIGEN need to bind a stream before we call their other APIs (see http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnSetStream). The cudnnHandle_t is coupled with the stream.
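
(An illustrative aside: a simplified, hypothetical `GpuContext`, not the real CUDADeviceContext, showing how cudnnSetStream and cublasSetStream couple the library handles to one stream, which is why the handles live in the device context next to the stream.)

#include <cuda_runtime.h>
#include <cudnn.h>
#include <cublas_v2.h>

// Hypothetical, simplified stand-in for CUDADeviceContext. cuDNN/cuBLAS calls
// are issued on whichever stream their handle is currently bound to, so the
// handle and the stream naturally live together in one context object.
struct GpuContext {
  cudaStream_t stream;
  cudnnHandle_t cudnn;
  cublasHandle_t cublas;

  GpuContext() {
    cudaStreamCreate(&stream);
    cudnnCreate(&cudnn);
    cudnnSetStream(cudnn, stream);    // all cuDNN work via this handle goes to `stream`
    cublasCreate(&cublas);
    cublasSetStream(cublas, stream);  // likewise for cuBLAS
  }

  void Wait() { cudaStreamSynchronize(stream); }
};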

enum CUDAHardwareStream {
  kCOMPUTATION,
  kD2HMEMCPY,
  kH2DMEMCPY
};

std::map<CUDAHardwareStream, DeviceContext*> gDevCtxs;
Member

Does this global gDevCtxs need a device_id to support multiple devices?

Collaborator Author

Well, this code is just meant to demonstrate the basic idea of the solution. I did not go into such details.

enum CUDAHardwareStream {
  kCOMPUTATION,
  kD2HMEMCPY,
  kH2DMEMCPY
};


Is there a kD2DMEMCPY?

public:
...

void SwitchDevCtx(DeviceContext* new_ctx) {


I am hesitant to add SwitchDevCtx as a method of Tensor. Reasons:

  1. If we have Tensor::SwitchDevCtx, we may also need SelectedRows::SwitchDevCtx, etc.
  2. The operator needs to wait, not the tensor.

So maybe we should put the explicit wait in the operator's Run?

void ReduceOp::Run(scope, place) {
    gDevCtxs[place, kCOMPUTATION]->Wait();
    my_ctx = gDevCtxs[place, kD2DMEMCPY];

    ...   
}

Member

Only if the data in different CUDA streams are independent can these operations potentially execute in parallel. So the basic problem is to analyze the data dependencies between two operations. Maybe we need an explicit scheduler module to do this.

Collaborator Author

@tonyyang-svail
I don't think it is the operator that needs to wait rather than the tensor.

It is the previous operator of the tensor that needs to be waited for. So we need to:

  1. Record the previous operator of the tensor.
    • For example, ReduceOp::Run may not need to wait on kCOMPUTATION if the previous operators are not computation. ReduceOp::Run could need to wait on any kind of device context.
    • The input tensors of ReduceOp can be produced by various streams. For example, suppose there are two tensors to be reduced, A and B. A is a computation result, and B is the result of an H2D memcpy. Both device contexts should be waited on (a sketch of this follows below).
  2. Not all operations on a Tensor are performed by paddle::framework::Operator.
    • There are memcpy and fill-zero operations in the framework module.
    • We cannot solve this problem just by putting an explicit wait in the operator's Run.
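
(An illustrative aside sketching the point above, using hypothetical stand-in types rather than the real PaddlePaddle classes: before an op runs, wait on the context that produced each of its inputs.)

#include <vector>

// Hypothetical stand-ins, not the real PaddlePaddle types: a context that can
// be waited on, and a tensor that remembers which context last wrote it.
struct DeviceContext {
  virtual void Wait() = 0;
  virtual ~DeviceContext() = default;
};

struct Tensor {
  DeviceContext* last_ctx = nullptr;  // context of the op that produced this tensor
};

// Before an op runs on run_ctx, wait on every distinct context that produced
// one of its inputs, e.g. A from kCOMPUTATION and B from kH2DMEMCPY.
void WaitForInputs(const std::vector<Tensor*>& inputs, DeviceContext* run_ctx) {
  for (Tensor* t : inputs) {
    if (t->last_ctx != nullptr && t->last_ctx != run_ctx) {
      t->last_ctx->Wait();
    }
  }
}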

Collaborator Author

@reyoung reyoung Jan 29, 2018

Only if the data in different CUDA streams are independent can these operations potentially execute in parallel. So the basic problem is to analyze the data dependencies between two operations. Maybe we need an explicit scheduler module to do this.

@QiJune
An explicit scheduler module could resolve these problems; however:

  1. Not all operations on a Tensor are performed by paddle::framework::Operator and the Executor.
    • A scheduler should schedule operators. However, there is no unified abstraction of the operations on Tensor that could be scheduled.
    • If we want to add a scheduler, we should add an abstraction layer for operators first.
  2. There is no clear scheduling algorithm for Fluid.
    • Fluid is different from other frameworks. We do not use a DAG to represent neural networks. Tensors and variables can be overwritten, and there can be loops in our programs. We should specify a clear scheduling algorithm before we implement it.
    • An explicit scheduler may NOT be faster than switching streams as needed.
      • Switching streams as needed adds more conditions (if statements) in C++. However, compared with the computation and the waits on streams, these conditions are negligible.
      • An explicit scheduler also needs to calculate dependencies ahead of time. That is not free.

Member

@reyoung Maybe we can write some experimental code. The following is pseudocode:

tensor1 = op1(dev_ctx1);
tensor2 = op2(tensor1, dev_ctx1);
tensor3 = op3(tensor2, dev_ctx1);
tensor4 = update_op(tensor1, dev_ctx2);

We expect that after op1 has run, update_op can run in parallel with op2 and op3. But cudaStreamSynchronize will block until the stream has completed all operations. Please refer to the official doc.

I am not sure whether the behavior will be that update_op runs only after op1/op2/op3 all finish, because the dev_ctx1 stream has three operations on it. That's not what we want.
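
(An illustrative aside, not something proposed in this thread: a standard CUDA way to express exactly this dependency is to record an event after op1 and have update_op's stream wait on that event, rather than synchronizing all of dev_ctx1's stream. The `step` kernel is a made-up stand-in for the four ops.)

#include <cuda_runtime.h>

__global__ void step(float* d, int n) {  // stand-in for op1/op2/op3/update_op
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) d[i] += 1.f;
}

void run(float* t1, float* t2, float* t3, float* t4, int n,
         cudaStream_t stream1, cudaStream_t stream2) {
  int blocks = (n + 255) / 256;
  cudaEvent_t op1_done;
  cudaEventCreate(&op1_done);

  step<<<blocks, 256, 0, stream1>>>(t1, n);   // op1 -> tensor1
  cudaEventRecord(op1_done, stream1);         // mark "tensor1 is ready"
  step<<<blocks, 256, 0, stream1>>>(t2, n);   // op2
  step<<<blocks, 256, 0, stream1>>>(t3, n);   // op3

  // update_op waits only for op1's event, not for all of stream1,
  // so it can overlap with op2 and op3.
  cudaStreamWaitEvent(stream2, op1_done, 0);
  step<<<blocks, 256, 0, stream2>>>(t4, n);   // update_op(tensor1)
}

With this, update_op can overlap with op2/op3; the trade-off is that something has to record and track the events, which is essentially the dependency analysis discussed above.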


@QiJune thanks for the example. I am sure these four operators will be executed sequentially.

As far as parallel_do_grad is concerned, I think the following program is good enough, even without an explicit scheduler:

parallel_do_grad
    w1_grad = fc_grad(.., stream0)
    all_reduce(w1_grad, stream1)    // it will wait for stream0
    sgd(w1_grad, w1, stream1)
    w2_grad = fc_grad(.., stream0)
    ...

Member

@QiJune QiJune left a comment

For I/O-related operators, we need a transpiler to insert them into the ProgramDesc accurately to achieve maximum performance.

@chengduoZH
Contributor

I think this PR should be made active again.
How many computation streams should there be? And how many copy streams? We should reconsider that.


Operators in different streams can run concurrently, as long as the hardware supports it. CUDA hardware has no notion of streams; the hardware has separate queues (engines) to perform memory copies and to execute kernels.

If we want to take full advantage of CUDA devices, we must use at least N streams, where N equals the number of hardware queues, and separate operators into these streams. N equals three, since CUDA hardware can simultaneously execute CUDA kernels, H2D memcpy, and D2H memcpy.
Contributor

@helinwang helinwang Mar 5, 2018

Why must we use at least N streams, where N equals the number of hardware queues? Doesn't CUDA handle multiplexing a single stream onto the different hardware queues transparently for us? I agree we need to use N streams, but maybe N should not be bounded by the number of hardware queues (otherwise we need code to look up the number of hardware queues for a given device, which complicates our code).


The solution is straightforward based on the hardware properties we described in the problem section. We should:

* Create N device contexts on one device. N should correspond to the hardware properties. For example, a CUDA device should have three device contexts.
Contributor

@helinwang helinwang Mar 5, 2018

Shouldn't the number of device contexts depend only on the "concurrency" requirement of the PaddlePaddle program, rather than on the hardware? Related question: #7814 (comment)
Please also see: #7814 (comment)


* Create N device contexts on one device. N should correspond to the hardware properties. For example, a CUDA device should have three device contexts.

* Every tensor should hold the device context on which the tensor's current operation is performed.
Contributor

Regarding "Every tensor should hold the device context": I think it's possible that one tensor gets used by different streams, and I assume one context corresponds to one stream, so which context should it hold?


The solution is straightforward based on the hardware properties we described in the problem section. We should:

* Create N device contexts on one device. N should correspond to the hardware properties. For example, a CUDA device should have three device contexts.
Contributor

Is it true that one "device context" = "one stream"? It's a little confusing that the "Problem" section only talks about streams, while the "Solution" section mainly talks about contexts.


* Every tensor should hold the device context on which the tensor's current operation is performed.

* Wait for execution to complete on the previous device context when switching a tensor's current device context.
Contributor

If the stream of operations on the given tensor is as follows:

tensor_related_op op_a op_b op_c op_d op_e

op_a to op_e are not related to the tensor.

You mentioned "Wait for execution to complete on the previous device context"; do we have to wait until op_e completes, or only until tensor_related_op completes?

@luotao1
Contributor

luotao1 commented Feb 1, 2019

Thanks for contributing to PaddlePaddle! Since documents have been moved to the FluidDoc repo, we are closing this PR. You are welcome to contribute to the FluidDoc repo.

@luotao1 luotao1 closed this Feb 1, 2019