From bdd785461e820b740f3b0ec37fabc52073f8e92f Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 18 Mar 2018 06:01:02 -0700 Subject: [PATCH] Add memcpy_stream and use multi-thread to issue parameter copies On TitanX with 4 devices, the se-resnext step time reduces from ~1.18 to ~1.10 --- paddle/fluid/framework/tensor_util.cc | 15 +++++++------- paddle/fluid/framework/threadpool.h | 2 +- paddle/fluid/operators/parallel_do_op.cc | 26 +++++++++++++++--------- paddle/fluid/platform/device_context.cc | 4 ++++ paddle/fluid/platform/device_context.h | 7 ++++++- paddle/fluid/platform/device_tracer.cc | 8 ++++++++ paddle/fluid/platform/device_tracer.h | 2 ++ paddle/fluid/platform/profiler.cc | 8 -------- 8 files changed, 45 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 8b7533ce712b0..d7c90f92d8b40 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -45,9 +45,9 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place); PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); - memory::Copy( - dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); + memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, + reinterpret_cast<const platform::CUDADeviceContext&>(ctx) + .memcpy_stream()); } else if (platform::is_cpu_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_cpu_place = boost::get<platform::CPUPlace>(src_place); @@ -58,7 +58,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place); memory::Copy( dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, - reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); + reinterpret_cast<const platform::CUDADeviceContext&>(ctx) + .memcpy_stream()); } else if (platform::is_gpu_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place); auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place); @@ -67,9 +68,9 @@ void 
TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place); PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); - memory::Copy( - dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); + memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, + reinterpret_cast<const platform::CUDADeviceContext&>(ctx) + .memcpy_stream()); } #endif } diff --git a/paddle/fluid/framework/threadpool.h b/paddle/fluid/framework/threadpool.h index df51fb24a588c..0b51ae66174e2 100644 --- a/paddle/fluid/framework/threadpool.h +++ b/paddle/fluid/framework/threadpool.h @@ -54,7 +54,7 @@ class ThreadPool { template <typename Callback> std::future<void> Run(Callback fn) { auto f = this->RunAndGetException(fn); - return std::async(std::launch::deferred, ExceptionHandler(std::move(f))); + return std::async(std::launch::async, ExceptionHandler(std::move(f))); } template <typename Callback> diff --git a/paddle/fluid/operators/parallel_do_op.cc b/paddle/fluid/operators/parallel_do_op.cc index 4001b9a130348..279d494f5acf9 100644 --- a/paddle/fluid/operators/parallel_do_op.cc +++ b/paddle/fluid/operators/parallel_do_op.cc @@ -140,16 +140,22 @@ class ParallelDoOp : public framework::OperatorBase { Inputs(kInputs)); // copy parameter - for (auto &param : Inputs(kParameters)) { - PADDLE_ENFORCE(scope.FindVar(param)->IsType<LoDTensor>(), - "Only support parameter type as LoDTensor"); - auto &src = scope.FindVar(param)->Get<LoDTensor>(); - for (size_t i = 0; i < sub_scopes.size(); ++i) { - auto &place = places[i]; - auto *sub_scope = sub_scopes[i]; - auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>(); - framework::TensorCopy(src, place, dst); - } + std::vector<std::future<void>> memcpy_workers; + memcpy_workers.reserve(places.size()); + for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) { + auto &place = places[place_idx]; + auto *cur_scope = sub_scopes[place_idx]; + memcpy_workers.emplace_back( + framework::Async([this, &place, cur_scope, &scope, place_idx]() { + for (auto 
&param : Inputs(kParameters)) { + auto &src = scope.FindVar(param)->Get<LoDTensor>(); + auto *dst = cur_scope->Var(param)->GetMutable<LoDTensor>(); + framework::TensorCopy(src, place, dst); + } + })); + } + for (auto &worker : memcpy_workers) { + worker.wait(); } WaitOnPlaces(places); diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 98b4178177b0a..fd8388e306a11 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -131,6 +131,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { multi_process = GetCUDAMultiProcessors(place_.device); max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + PADDLE_ENFORCE(cudaStreamCreate(&memcpy_stream_)); eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); @@ -154,6 +155,7 @@ CUDADeviceContext::~CUDADeviceContext() { eigen_stream_.reset(); eigen_device_.reset(); PADDLE_ENFORCE(cudaStreamDestroy(stream_)); + PADDLE_ENFORCE(cudaStreamDestroy(memcpy_stream_)); } Place CUDADeviceContext::GetPlace() const { return place_; } @@ -183,6 +185,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; } +cudaStream_t CUDADeviceContext::memcpy_stream() const { return memcpy_stream_; } + #endif #ifdef PADDLE_WITH_MKLDNN diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 603b890af13b5..fc05e822e852d 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -94,9 +94,12 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return cudnn handle in the device context. */ cudnnHandle_t cudnn_handle() const; - /*! \brief Return cuda stream in the device context. */ + /*! 
\brief Return cuda stream in the device context for computation. */ cudaStream_t stream() const; + /*! \brief Return cuda stream in the device context for memory copy. */ + cudaStream_t memcpy_stream() const; + private: CUDAPlace place_; @@ -110,6 +113,8 @@ class CUDADeviceContext : public DeviceContext { int compute_capability; int multi_process; int max_threads_per_mp; + + cudaStream_t memcpy_stream_; }; template <> diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 3b4437f576e1c..9c2cfd169d132 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -14,6 +14,8 @@ limitations under the License. */ #include "paddle/fluid/platform/device_tracer.h" #include <google/protobuf/text_format.h> +#include <sys/time.h> +#include <time.h> #include <fstream> #include <map> #include <mutex> @@ -418,5 +420,11 @@ void ClearCurThread() { cur_thread_id = 0; } int CurThread() { return cur_thread_id; } +uint64_t PosixInNsec() { + struct timeval tv; + gettimeofday(&tv, nullptr); + return 1000 * (static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec); +} + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index deb3d23f78635..a1ee5314797b2 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -101,5 +101,7 @@ int BlockDepth(); void SetCurThread(int thread_id); void ClearCurThread(); int CurThread(); + +uint64_t PosixInNsec(); } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index b25206ff35cc8..81c49900f7883 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/fluid/platform/profiler.h" -#include <sys/time.h> -#include <time.h> #include <iomanip> #include <map> #ifdef PADDLE_WITH_CUDA @@ -54,12 +52,6 @@ inline uint64_t GetTimeInNsec() { .count(); } -inline uint64_t PosixInNsec() { - struct timeval tv; - gettimeofday(&tv, nullptr); - return 1000 * (static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec); -} - Event::Event(EventKind kind, std::string name, uint32_t thread_id, const DeviceContext* dev_ctx) : kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) {