Improve Fluid Distributed Training performance #8638

Closed
typhoonzero opened this issue Feb 28, 2018 · 25 comments


@typhoonzero commented Feb 28, 2018

As shown in #8550, send_op takes up too much of the time in GPU distributed training. Here are the tasks we need to do to improve the performance:

  • perf details about send_op -- @gongweibao
  • do not copy before sending variables -- @typhoonzero
  • do not copy when deserializing -- @gongweibao
  • use distribute_transpiler_simple to reduce copying -- @typhoonzero
  • merge small variables into one message before sending
  • run parameter optimization in parallel -- @typhoonzero
  • implement communication using RDMA -- @seiriosPlus
  • implement multi-GPU, multi-node distributed training using NCCL2 -- @typhoonzero
  • send gradients asynchronously after each backward op finishes #9161 -- @Yancey1989
  • prepare the executor on the pserver before training -- @typhoonzero
  • test the maximum throughput of gRPC with large messages
  • evaluate whether gRPC streaming can help
  • use CUDA pinned memory to enable DMA copies

@typhoonzero self-assigned this Feb 28, 2018

@gongweibao commented Feb 28, 2018

I'm checking perf details about send_op.

@typhoonzero commented Feb 28, 2018

From the TF benchmark results:
https://www.tensorflow.org/performance/benchmarks

The code is here:
https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks

The code only uses gRPC as its protocol, not the grpc+gdr (RDMA) implementation. That means there is a lot we can do to improve the current performance with plain gRPC.

@typhoonzero commented Mar 2, 2018

I took a quick look at https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc#L136. To be clear, we don't need to call Serialize*, which spends too much time copying the tensor data around; we can just reference the data and send it.

@gongweibao commented Mar 2, 2018

https://github.com/gongweibao/CUDA-training/blob/master/utils/cuda_by_example/chapter10/copy_timed.cu

ENV:
Tesla K40m, Driver Version: 367.48

Use malloc:

size: 1 KB, times: 100, time: 0.99 ms, speed: 396.13 MB/s
size: 2 KB, times: 100, time: 1.28 ms, speed: 610.88 MB/s
size: 4 KB, times: 100, time: 1.71 ms, speed: 914.10 MB/s
size: 8 KB, times: 100, time: 2.30 ms, speed: 1357.65 MB/s
size: 16 KB, times: 100, time: 3.30 ms, speed: 1895.56 MB/s
size: 32 KB, times: 100, time: 5.24 ms, speed: 2389.53 MB/s
size: 64 KB, times: 100, time: 9.32 ms, speed: 2686.11 MB/s
size: 128 KB, times: 100, time: 17.47 ms, speed: 2864.10 MB/s
size: 256 KB, times: 100, time: 33.48 ms, speed: 2989.95 MB/s
size: 512 KB, times: 100, time: 56.00 ms, speed: 3574.83 MB/s
size: 1024 KB, times: 100, time: 100.72 ms, speed: 3975.28 MB/s
size: 2049 KB, times: 100, time: 190.09 ms, speed: 4212.65 MB/s
size: 4099 KB, times: 100, time: 631.77 ms, speed: 2535.04 MB/s
size: 8199 KB, times: 100, time: 1179.44 ms, speed: 2715.81 MB/s
size: 16399 KB, times: 100, time: 1722.16 ms, speed: 3719.89 MB/s
size: 32799 KB, times: 100, time: 4707.91 ms, speed: 2721.48 MB/s

Use cudaHostAlloc

size: 1 KB, times: 100, time: 1.38 ms, speed: 282.61 MB/s
size: 2 KB, times: 100, time: 1.70 ms, speed: 459.45 MB/s
size: 4 KB, times: 100, time: 2.23 ms, speed: 701.12 MB/s
size: 8 KB, times: 100, time: 1.39 ms, speed: 2244.05 MB/s
size: 16 KB, times: 100, time: 1.78 ms, speed: 3514.83 MB/s
size: 32 KB, times: 100, time: 2.41 ms, speed: 5188.56 MB/s
size: 64 KB, times: 100, time: 3.63 ms, speed: 6886.78 MB/s
size: 128 KB, times: 100, time: 6.12 ms, speed: 8178.19 MB/s
size: 256 KB, times: 100, time: 11.50 ms, speed: 8705.71 MB/s
size: 512 KB, times: 100, time: 22.31 ms, speed: 8972.61 MB/s
size: 1024 KB, times: 100, time: 43.41 ms, speed: 9223.45 MB/s
size: 2049 KB, times: 100, time: 85.69 ms, speed: 9345.31 MB/s
size: 4099 KB, times: 100, time: 170.05 ms, speed: 9418.02 MB/s
size: 8199 KB, times: 100, time: 338.59 ms, speed: 9460.30 MB/s
size: 16399 KB, times: 100, time: 675.79 ms, speed: 9479.64 MB/s
size: 32799 KB, times: 100, time: 1350.45 ms, speed: 9487.60 MB/s
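
For reference, here is a minimal sketch (not the exact copy_timed.cu code) of the kind of timing loop behind these numbers, comparing host-to-device copies from a pageable malloc buffer and from a pinned cudaHostAlloc buffer; sizes and names are illustrative:

```cpp
// Minimal sketch: time H2D copies from pageable vs. pinned host memory.
// Compile with nvcc. Illustrative only, not the exact copy_timed.cu code.
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

static float time_h2d_copy(void* host_buf, void* dev_buf, size_t bytes, int iters) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  cudaEventRecord(start, 0);
  for (int i = 0; i < iters; ++i) {
    cudaMemcpy(dev_buf, host_buf, bytes, cudaMemcpyHostToDevice);
  }
  cudaEventRecord(stop, 0);
  cudaEventSynchronize(stop);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}

int main() {
  const size_t bytes = 1 << 20;  // 1 MB per copy
  const int iters = 100;
  void* dev_buf = nullptr;
  cudaMalloc(&dev_buf, bytes);

  // Pageable host memory: the driver stages it through an internal pinned buffer.
  void* pageable = malloc(bytes);
  float t_pageable = time_h2d_copy(pageable, dev_buf, bytes, iters);

  // Pinned (page-locked) host memory: eligible for direct DMA, usually much faster.
  void* pinned = nullptr;
  cudaHostAlloc(&pinned, bytes, cudaHostAllocDefault);
  float t_pinned = time_h2d_copy(pinned, dev_buf, bytes, iters);

  printf("pageable: %.2f ms, pinned: %.2f ms (%d x %zu bytes)\n",
         t_pageable, t_pinned, iters, bytes);

  free(pageable);
  cudaFreeHost(pinned);
  cudaFree(dev_buf);
  return 0;
}
```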

@gongweibao commented Mar 4, 2018

To be clear, we don't need to call Serialize*, which spends too much time copying the tensor data around; we can just reference the data and send it.

We have basically located the bottleneck; the time is spent on memory copies:

  1. Copying data from GPU to host memory ("copy + wait" in the figure):
    • With page-locked (pinned) memory, the GPU-to-host copy time is basically stable
  2. Writing the data into a std::ostream
  3. Writing the data from the std::ostream into the protobuf message
  4. Copying the data to the NIC for sending

The times of steps 2, 3 and 4 fluctuate a lot;
the main time cost comes from steps 2 and 3.

[image: profiling screenshot]

@typhoonzero commented Mar 5, 2018

We have to make sure the tensor data is not copied when we call gRPC; the buffer should be sent directly.

To achieve this, the current grpc_server/grpc_client need to be rewritten.

@typhoonzero commented Mar 5, 2018

Tested using the ::grpc::ByteBuffer low-level interface to send 200MB buffers directly 100 times: https://github.com/typhoonzero/grpc_zerocopy_async_example. It seems this has very little impact on performance.

Will try different serialization methods for variables.
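
For context, a minimal sketch (not the benchmark repo's actual code) of what sending a pre-built ::grpc::ByteBuffer through gRPC's generic async API looks like; the service/method name is an assumption used only for illustration:

```cpp
// Sketch: send a pre-built grpc::ByteBuffer via the generic async API, so no
// protobuf message object (and no extra serialization copy) is involved.
// The service/method name below is illustrative.
#include <memory>
#include <grpcpp/grpcpp.h>
#include <grpcpp/generic/generic_stub.h>

grpc::Status SendRawBuffer(const std::shared_ptr<grpc::Channel>& channel,
                           const grpc::ByteBuffer& request) {
  grpc::GenericStub stub(channel);
  grpc::CompletionQueue cq;
  grpc::ClientContext ctx;
  grpc::ByteBuffer response;
  grpc::Status status;

  auto call = stub.PrepareUnaryCall(
      &ctx, "/sendrecv.SendRecvService/SendVariable", request, &cq);
  call->StartCall();
  call->Finish(&response, &status, /*tag=*/reinterpret_cast<void*>(1));

  void* got_tag = nullptr;
  bool ok = false;
  cq.Next(&got_tag, &ok);  // block until this RPC completes
  return status;
}
```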

@gongweibao commented Mar 6, 2018

Tested using the ::grpc::ByteBuffer low-level interface to send 200MB buffers directly 100 times: https://github.com/typhoonzero/grpc_zerocopy_async_example. It seems this has very little impact on performance.

Talked with @typhoonzero and fixed some bugs in the program.
It seems copying 12MB of data takes about 8~9ms.

@helinwang commented Mar 6, 2018

Maybe we need a dedicated CUDA stream for copying tensors (or maybe three more streams: host-to-device, device-to-host, and device-to-device)?
Currently we use only one stream for both computation and data copies, yet there are two distinct bottlenecks: computation and I/O.
For the same reason, NCCL all-reduce is currently a big bottleneck for multi-GPU training: because it uses the computation stream, all computation that could overlap with the all-reduce is blocked.
The tricky part of adding multiple streams is that our architecture is currently built around a single stream. There is no dependency analysis, so we don't know how to do tensor synchronization, which is required when a tensor is used on different streams.
Some related info: https://stackoverflow.com/a/36108627/852385 , https://github.com/tensorflow/tensorflow/blob/754048a0453a04a761e112ae5d99c149eb9910dd/tensorflow/core/common_runtime/gpu_device_context.h#L40
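
To illustrate the idea, here is a minimal sketch (assuming an explicitly managed copy stream, which Fluid does not have today) of how a dedicated copy stream could overlap a device-to-host copy with later kernels while keeping the tensor synchronized across streams; all names are illustrative:

```cpp
// Sketch: a dedicated stream for D2H copies, synchronized with the compute
// stream via an event. Illustrative only.
#include <cuda_runtime.h>

void async_copy_for_send(const float* dev_tensor, float* pinned_host,
                         size_t bytes, cudaStream_t compute_stream) {
  static cudaStream_t copy_stream = nullptr;
  static cudaEvent_t compute_done = nullptr;
  if (copy_stream == nullptr) {
    cudaStreamCreate(&copy_stream);
    cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming);
  }

  // 1. Mark the point in the compute stream where the tensor is ready.
  cudaEventRecord(compute_done, compute_stream);
  // 2. Make the copy stream wait for that event, then copy without blocking
  //    kernels that are launched later on the compute stream.
  cudaStreamWaitEvent(copy_stream, compute_done, 0);
  cudaMemcpyAsync(pinned_host, dev_tensor, bytes, cudaMemcpyDeviceToHost,
                  copy_stream);
  // 3. The sender thread later calls cudaStreamSynchronize(copy_stream)
  //    before handing pinned_host to gRPC.
}
```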

@helinwang commented Mar 6, 2018

Thanks, Weibao!

We have basically located the bottleneck; the time is spent on memory copies:
1. Copying data from GPU to host memory ("copy + wait" in the figure):
With page-locked (pinned) memory, the GPU-to-host copy time is basically stable
2. Writing the data into a std::ostream
3. Writing the data from the std::ostream into the protobuf message
4. Copying the data to the NIC for sending
The times of steps 2, 3 and 4 fluctuate a lot;
the main time cost comes from steps 2 and 3.

Serializing into protobuf takes a bit too long.

It seems the biggest bottleneck now is not the device-to-host memory copy, nor gRPC, but protobuf serialization? @gongweibao @typhoonzero

@gongweibao commented Mar 7, 2018

It seems the biggest bottleneck now is not the device-to-host memory copy, nor gRPC, but protobuf serialization?

Protobuf serialization is one part, and copying the user-level data into gRPC's buffers is another. We have already found a way to avoid data copies 2 and 3. Judging from the measured time costs, this should eliminate most of the performance problem.

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.cc#L219

@helinwang commented Mar 7, 2018

Nice! I didn't fully understand that code link; could you explain how it is done? I don't quite see how it avoids step 3, writing the data from the std::ostream into protobuf.

@gongweibao commented Mar 7, 2018

The data is split into two parts:
a skeleton part and a data part. The skeleton is small, so it can be written directly through the proto interface. The data part uses grpc::Slice: only a pointer to the user-level data is passed into gRPC's space, which avoids the data copy.
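
A minimal sketch of this skeleton + data split, assuming a hypothetical EncodeVariable helper (not Fluid's actual serializer): the small metadata part is serialized normally, while the large tensor payload is wrapped in a grpc::Slice that only references the existing buffer, so gRPC never copies it.

```cpp
// Sketch of the skeleton + data encoding described above (illustrative names).
#include <cstddef>
#include <string>
#include <grpcpp/support/byte_buffer.h>
#include <grpcpp/support/slice.h>

// meta_bytes: the already-serialized, small "skeleton" proto (dims, dtype, ...).
// tensor_data/tensor_bytes: the large payload living in (pinned) host memory.
// release: called by gRPC once it no longer needs the payload.
grpc::ByteBuffer EncodeVariable(const std::string& meta_bytes,
                                void* tensor_data, size_t tensor_bytes,
                                void (*release)(void*)) {
  // Skeleton: tiny, so copying it into a slice is cheap.
  grpc::Slice meta(meta_bytes.data(), meta_bytes.size());
  // Payload: zero-copy slice that merely points at the tensor buffer.
  grpc::Slice data(tensor_data, tensor_bytes, release);
  grpc::Slice slices[] = {meta, data};
  return grpc::ByteBuffer(slices, 2);
}
```

The receiver then parses the first slice as the proto skeleton and treats the remaining bytes as the raw tensor, which is roughly what the TensorFlow code linked above does.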

@helinwang commented Mar 7, 2018

Nice, thanks! So gRPC natively supports byte slices.

@gongweibao commented Mar 7, 2018

There is no documentation; this TensorFlow code is the only example I found. Apart from the gRPC authors themselves, probably very few people know about it.

@wangkuiyi commented Mar 8, 2018

A summary of the discussion above:

The Send operator takes too long to encode the content into a protobuf message and to copy it into the gRPC buffer. We borrowed ideas from the TensorFlow implementation to accelerate these two steps.

@typhoonzero commented Mar 15, 2018

Made the optimization ops run in parallel.

Running vgg16 with the flowers dataset (fc size 512).

Trainer per-batch time (seconds):

Pass = 0, Iters = 240, spent 1.383521
Pass = 0, Iters = 241, spent 1.571773
Pass = 0, Iters = 242, spent 2.008580
Pass = 0, Iters = 243, spent 1.616996

Here is the server-side computation time (ms):

I0315 05:18:15.651073    95 listen_and_serv_op.cc:139]  server side program spent 147.921
I0315 05:18:17.787200    95 listen_and_serv_op.cc:139]  server side program spent 165.944
I0315 05:18:19.663105    95 listen_and_serv_op.cc:139]  server side program spent 158.983
I0315 05:18:21.431468    95 listen_and_serv_op.cc:139]  server side program spent 146.156
I0315 05:18:23.324154    95 listen_and_serv_op.cc:139]  server side program spent 147.547
I0315 05:18:26.147795    95 listen_and_serv_op.cc:139]  server side program spent 138.413
I0315 05:18:28.349591    95 listen_and_serv_op.cc:139]  server side program spent 150.108
I0315 05:18:30.492772    95 listen_and_serv_op.cc:139]  server side program spent 197.025
I0315 05:18:32.396615    95 listen_and_serv_op.cc:139]  server side program spent 143.375
I0315 05:18:34.815824    95 listen_and_serv_op.cc:139]  server side program spent 225.21

@typhoonzero commented Mar 16, 2018

Some notes on testing distributed training with the pserver program running on GPU vs. CPU (without zero-copy gRPC):

environment:

  • GPU: P40
  • 4 pservers, 4 trainers
  • vgg16 flowers, fc size 512, batch size 20
  1. pserver runs on CPU: 1.3~2s per batch (trainer time)
  2. pserver runs on GPU: 3.5~5s per batch (trainer time)

Will add results once zero-copy gRPC is merged.

@chengduoZH commented Mar 16, 2018

(quoting @gongweibao's K40m malloc vs. cudaHostAlloc benchmark results from above)

I did some experiments on the speed and throughput of malloc vs. cudaHostAlloc and found that, on these GPUs, they are not very different. Here are the results:

script:

https://github.com/gongweibao/CUDA-training/blob/master/utils/cuda_by_example/chapter10/copy_timed.cu

ENV:

TITAN X (Pascal), Driver Version: 390.25

Use malloc:
size: 1 KB, times: 100, time: 1.08 ms, speed: 362.61 MB/s
size: 2 KB, times: 100, time: 1.19 ms, speed: 656.31 MB/s
size: 4 KB, times: 100, time: 1.35 ms, speed: 1156.82 MB/s
size: 8 KB, times: 100, time: 1.39 ms, speed: 2245.86 MB/s
size: 16 KB, times: 100, time: 2.22 ms, speed: 2816.84 MB/s
size: 32 KB, times: 100, time: 3.30 ms, speed: 3790.98 MB/s
size: 64 KB, times: 100, time: 5.42 ms, speed: 4613.19 MB/s
size: 128 KB, times: 100, time: 10.12 ms, speed: 4944.61 MB/s
size: 256 KB, times: 100, time: 19.11 ms, speed: 5238.42 MB/s
size: 512 KB, times: 100, time: 25.57 ms, speed: 7829.26 MB/s
size: 1024 KB, times: 100, time: 43.71 ms, speed: 9159.32 MB/s
size: 2049 KB, times: 100, time: 80.91 ms, speed: 9897.43 MB/s
size: 4099 KB, times: 100, time: 154.16 ms, speed: 10388.98 MB/s
size: 8199 KB, times: 100, time: 302.05 ms, speed: 10604.63 MB/s
size: 16399 KB, times: 100, time: 598.15 ms, speed: 10710.10 MB/s
size: 32799 KB, times: 100, time: 1190.88 ms, speed: 10758.87 MB/s

Use cudaHostAlloc
size: 1 KB, times: 100, time: 0.82 ms, speed: 475.48 MB/s
size: 2 KB, times: 100, time: 1.07 ms, speed: 731.14 MB/s
size: 4 KB, times: 100, time: 1.13 ms, speed: 1389.40 MB/s
size: 8 KB, times: 100, time: 1.29 ms, speed: 2423.08 MB/s
size: 16 KB, times: 100, time: 1.47 ms, speed: 4248.98 MB/s
size: 32 KB, times: 100, time: 1.92 ms, speed: 6511.58 MB/s
size: 64 KB, times: 100, time: 3.25 ms, speed: 7688.38 MB/s
size: 128 KB, times: 100, time: 5.49 ms, speed: 9118.07 MB/s
size: 256 KB, times: 100, time: 10.26 ms, speed: 9758.02 MB/s
size: 512 KB, times: 100, time: 19.18 ms, speed: 10437.97 MB/s
size: 1024 KB, times: 100, time: 37.40 ms, speed: 10706.02 MB/s
size: 2049 KB, times: 100, time: 73.04 ms, speed: 10963.21 MB/s
size: 4099 KB, times: 100, time: 146.02 ms, speed: 10968.04 MB/s
size: 8199 KB, times: 100, time: 286.47 ms, speed: 11181.17 MB/s
size: 16399 KB, times: 100, time: 578.72 ms, speed: 11069.77 MB/s
size: 32799 KB, times: 100, time: 1142.06 ms, speed: 11218.78 MB/s

P40 (Pascal), Driver Version: 384.66

Use malloc:
size: 1 KB, times: 100, time: 0.85 ms, speed: 458.50 MB/s
size: 2 KB, times: 100, time: 0.97 ms, speed: 805.18 MB/s
size: 4 KB, times: 100, time: 1.20 ms, speed: 1308.48 MB/s
size: 8 KB, times: 100, time: 1.39 ms, speed: 2254.14 MB/s
size: 16 KB, times: 100, time: 2.10 ms, speed: 2977.10 MB/s
size: 32 KB, times: 100, time: 3.20 ms, speed: 3914.29 MB/s
size: 64 KB, times: 100, time: 5.40 ms, speed: 4633.48 MB/s
size: 128 KB, times: 100, time: 9.79 ms, speed: 5112.01 MB/s
size: 256 KB, times: 100, time: 18.77 ms, speed: 5332.97 MB/s
size: 512 KB, times: 100, time: 28.38 ms, speed: 7054.05 MB/s
size: 1024 KB, times: 100, time: 47.15 ms, speed: 8491.00 MB/s
size: 2049 KB, times: 100, time: 85.02 ms, speed: 9419.26 MB/s
size: 4099 KB, times: 100, time: 159.44 ms, speed: 10044.86 MB/s
size: 8199 KB, times: 100, time: 312.00 ms, speed: 10266.41 MB/s
size: 16399 KB, times: 100, time: 722.78 ms, speed: 8863.34 MB/s
size: 32799 KB, times: 100, time: 1265.12 ms, speed: 10127.48 MB/s

Use cudaHostAlloc
size: 1 KB, times: 100, time: 1.05 ms, speed: 372.06 MB/s
size: 2 KB, times: 100, time: 1.15 ms, speed: 678.50 MB/s
size: 4 KB, times: 100, time: 1.31 ms, speed: 1194.77 MB/s
size: 8 KB, times: 100, time: 1.29 ms, speed: 2421.46 MB/s
size: 16 KB, times: 100, time: 1.57 ms, speed: 3985.39 MB/s
size: 32 KB, times: 100, time: 2.13 ms, speed: 5860.67 MB/s
size: 64 KB, times: 100, time: 3.26 ms, speed: 7683.85 MB/s
size: 128 KB, times: 100, time: 5.49 ms, speed: 9111.17 MB/s
size: 256 KB, times: 100, time: 10.24 ms, speed: 9771.37 MB/s
size: 512 KB, times: 100, time: 19.74 ms, speed: 10140.70 MB/s
size: 1024 KB, times: 100, time: 55.27 ms, speed: 7243.99 MB/s
size: 2049 KB, times: 100, time: 110.98 ms, speed: 7215.68 MB/s
size: 4099 KB, times: 100, time: 181.64 ms, speed: 8817.33 MB/s
size: 8199 KB, times: 100, time: 303.17 ms, speed: 10565.59 MB/s
size: 16399 KB, times: 100, time: 624.28 ms, speed: 10261.75 MB/s
size: 32799 KB, times: 100, time: 1260.01 ms, speed: 10168.57 MB/s

@helinwang commented Mar 16, 2018

Thanks for the data points!
@chengduoZH, thanks for pointing out the driver version, a very important detail. A note from our friends at Nvidia: when downloading a driver from the Nvidia website, it is best to select the exact GPU model. For example, if our TITAN X needs driver 390.25, download the TITAN X 390.25 driver rather than the GTX 1080 390.25 driver. They said that even when the version numbers are identical, the driver can ship different configurations for different GPU models. Hope this information helps.

@gongweibao commented Mar 22, 2018

ENV: K40m, CUDA 8.0, driver: 390.12, 384.66

size: 1 KB, times: 100, time: 0.92 ms, speed: 423.41 MB/s
size: 2 KB, times: 100, time: 1.12 ms, speed: 695.70 MB/s
size: 4 KB, times: 100, time: 1.59 ms, speed: 983.73 MB/s
size: 8 KB, times: 100, time: 2.32 ms, speed: 1345.28 MB/s
size: 16 KB, times: 100, time: 3.35 ms, speed: 1867.07 MB/s
size: 32 KB, times: 100, time: 5.90 ms, speed: 2120.15 MB/s
size: 64 KB, times: 100, time: 11.03 ms, speed: 2268.46 MB/s
size: 128 KB, times: 100, time: 21.11 ms, speed: 2370.57 MB/s
size: 256 KB, times: 100, time: 41.30 ms, speed: 2423.49 MB/s
size: 512 KB, times: 100, time: 72.14 ms, speed: 2774.96 MB/s
size: 1024 KB, times: 100, time: 99.07 ms, speed: 4041.62 MB/s
size: 2049 KB, times: 100, time: 187.42 ms, speed: 4272.58 MB/s
size: 4099 KB, times: 100, time: 364.13 ms, speed: 4398.27 MB/s
size: 8199 KB, times: 100, time: 717.44 ms, speed: 4464.67 MB/s
size: 16399 KB, times: 100, time: 1425.32 ms, speed: 4494.60 MB/s
size: 32799 KB, times: 100, time: 2838.87 ms, speed: 4513.23 MB/s
size: 1 KB, times: 100, time: 1.18 ms, speed: 329.92 MB/s
size: 2 KB, times: 100, time: 1.35 ms, speed: 577.79 MB/s
size: 4 KB, times: 100, time: 1.75 ms, speed: 892.75 MB/s
size: 8 KB, times: 100, time: 1.33 ms, speed: 2357.79 MB/s
size: 16 KB, times: 100, time: 1.64 ms, speed: 3823.42 MB/s
size: 32 KB, times: 100, time: 2.23 ms, speed: 5600.92 MB/s
size: 64 KB, times: 100, time: 3.43 ms, speed: 7296.35 MB/s
size: 128 KB, times: 100, time: 5.95 ms, speed: 8406.51 MB/s
size: 256 KB, times: 100, time: 10.99 ms, speed: 9108.12 MB/s
size: 512 KB, times: 100, time: 20.97 ms, speed: 9544.79 MB/s
size: 1024 KB, times: 100, time: 40.91 ms, speed: 9787.75 MB/s
size: 2049 KB, times: 100, time: 80.66 ms, speed: 9928.00 MB/s
size: 4099 KB, times: 100, time: 160.26 ms, speed: 9993.63 MB/s
size: 8199 KB, times: 100, time: 319.19 ms, speed: 10035.26 MB/s
size: 16399 KB, times: 100, time: 637.24 ms, speed: 10053.19 MB/s
size: 32799 KB, times: 100, time: 1273.19 ms, speed: 10063.29 MB/s

It seems the hardware architecture, rather than the CUDA or driver version, determines the copy speed.
I agree we should optimize performance on the P40 rather than on the K40m-class cards.

@typhoonzero commented Mar 26, 2018

Latest update: after finishing the optimizations above, running vgg16 with fc size 512, distributed GPU training reaches 64% of the theoretical overall performance; when the fc size is increased to 4096, it drops to 33%.

It seems the larger the tensors are, the slower distributed training gets, which means send_op and listen_and_serv_op still spend too much time transferring data over the network.

@gongweibao commented Mar 27, 2018

Progress:

  • Our current implementation has a problem:
    our server side is single-threaded. From the code, docs, and issues we found that gRPC does not create any threads internally; the handler runs on whichever thread registered it and polls for it. (This explains why we could not push CPU utilization up, especially before zero-copy was implemented.)
Add a completion queue for handling asynchronous services.
Best performance is typically obtained by using one thread per polling completion queue.
From brpc's benchmarks: when request packets are smaller than 16KB, single-connection throughput exceeds the multi-connection ubrpc_mc and thrift_mc;
as request packets grow, the kernel's write speed on a single connection becomes the bottleneck, while brpc with multiple connections reaches the highest throughput in the test, 2.3GB/s.

It looks like using multiple connections, i.e. multiple channels to the server, would improve performance. Testing client/server throughput on the same machine, a single process reaches about 1GB/s, while m clients paired one-to-one with m servers reach about 2GB/s.

It looks like about 50% of a single send/recv's time can still be eliminated!
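
A minimal sketch of the multi-channel, one-thread-per-completion-queue client shape suggested above; the channel-argument trick for forcing separate connections and all names here are assumptions, not Fluid's actual grpc_client:

```cpp
// Sketch: multiple channels (separate TCP connections) and one polling thread
// per completion queue. The custom channel argument is a common trick to keep
// gRPC from sharing a single underlying connection between channels.
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include <grpcpp/grpcpp.h>

void RunClientPollers(const std::string& target, int num_channels) {
  std::vector<std::thread> pollers;
  for (int i = 0; i < num_channels; ++i) {
    grpc::ChannelArguments args;
    args.SetInt("fluid.channel_index", i);  // distinct args -> distinct connection
    auto channel = grpc::CreateCustomChannel(
        target, grpc::InsecureChannelCredentials(), args);

    // One completion queue, polled by exactly one dedicated thread.
    auto cq = std::make_shared<grpc::CompletionQueue>();
    pollers.emplace_back([cq, channel] {
      void* tag = nullptr;
      bool ok = false;
      while (cq->Next(&tag, &ok)) {
        // ... mark the pending send/recv identified by `tag` as finished
        //     and issue the next request on this channel ...
      }
    });
    // Async SendVariable/GetVariable calls would be issued against `channel`
    // with `cq.get()` as their completion queue.
  }
  for (auto& t : pollers) t.join();
}
```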

@typhoonzero changed the title from "Improve GPU distributed training performance" to "Improve Fluid Distributed Training performance" May 4, 2018

@shanyi15 commented Aug 15, 2018

Hello, this issue has not been updated in the past month, so we will close it today. If you still need to follow up after it is closed, feel free to reopen it and we will get back to you within 24 hours. We apologize for any inconvenience caused by the closure, and thank you for your support of PaddlePaddle!

@shanyi15 closed this Aug 15, 2018

PaddlePaddle Distributed Refactoring (Due: 201802) automation moved this from Perf TODOs to DONE Aug 15, 2018
