Add memcpy_stream and use multi-thread to issue parameter copies #9170

Closed
wants to merge 1 commit
15 changes: 8 additions & 7 deletions paddle/fluid/framework/tensor_util.cc
@@ -45,9 +45,9 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
     PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
     auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
     PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
-    memory::Copy(
-        dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+    memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
+                 reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+                     .memcpy_stream());
   } else if (platform::is_cpu_place(src_place) &&
              platform::is_gpu_place(dst_place)) {
     auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
@@ -58,7 +58,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
     PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
     memory::Copy(
         dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+        reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+            .memcpy_stream());
   } else if (platform::is_gpu_place(src_place) &&
              platform::is_gpu_place(dst_place)) {
     auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
@@ -67,9 +68,9 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
     PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
     auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
     PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
-    memory::Copy(
-        dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
-        reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
+    memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
+                 reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+                     .memcpy_stream());
   }
 #endif
 }
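All three GPU paths in TensorCopy now enqueue their copy on the context's dedicated memcpy_stream() instead of the compute stream(). Note that moving copies to a separate stream also drops the implicit ordering between a copy and kernels already queued on stream(); where that ordering still matters, it can be restored with an event. A minimal sketch, using raw CUDA runtime calls that are not part of this PR:

    #include <cuda_runtime.h>

    // Make a copy issued on `memcpy_stream` wait for all work already
    // queued on `compute`, without blocking the host. Assumes both
    // streams belong to the same device.
    void CopyAfterCompute(cudaStream_t compute, cudaStream_t memcpy_stream,
                          void* dst, const void* src, size_t bytes) {
      cudaEvent_t compute_done;
      cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming);
      cudaEventRecord(compute_done, compute);  // mark current tail of compute
      cudaStreamWaitEvent(memcpy_stream, compute_done, 0);  // copy waits here
      cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, memcpy_stream);
      // Safe even though the event may still be pending: destruction is
      // deferred until the recorded work completes.
      cudaEventDestroy(compute_done);
    }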
2 changes: 1 addition & 1 deletion paddle/fluid/framework/threadpool.h
@@ -54,7 +54,7 @@ class ThreadPool {
   template <typename Callback>
   std::future<void> Run(Callback fn) {
     auto f = this->RunAndGetException(fn);
-    return std::async(std::launch::deferred, ExceptionHandler(std::move(f)));
+    return std::async(std::launch::async, ExceptionHandler(std::move(f)));
   }

   template <typename Callback>
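This one-word change alters when the wrapped ExceptionHandler actually runs: std::launch::deferred executes the callable lazily, on whichever thread eventually calls wait() or get() on the returned future, while std::launch::async starts it eagerly on its own thread. A self-contained sketch of the difference, independent of Paddle's ThreadPool:

    #include <future>
    #include <iostream>
    #include <thread>

    int main() {
      auto report = [] {
        std::cout << "running on thread " << std::this_thread::get_id() << "\n";
      };

      // Deferred: nothing runs until get(); the body then executes on the
      // calling thread.
      auto d = std::async(std::launch::deferred, report);

      // Async: a new thread starts immediately and can make progress while
      // the caller does other work.
      auto a = std::async(std::launch::async, report);

      d.get();  // deferred task runs here, printing the main thread's id
      a.get();  // async task ran (or runs) on a different thread
      return 0;
    }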
26 changes: 16 additions & 10 deletions paddle/fluid/operators/parallel_do_op.cc
@@ -140,16 +140,22 @@ class ParallelDoOp : public framework::OperatorBase {
                                      Inputs(kInputs));

     // copy parameter
-    for (auto &param : Inputs(kParameters)) {
-      PADDLE_ENFORCE(scope.FindVar(param)->IsType<LoDTensor>(),
-                     "Only support parameter type as LoDTensor");
-      auto &src = scope.FindVar(param)->Get<LoDTensor>();
-      for (size_t i = 0; i < sub_scopes.size(); ++i) {
-        auto &place = places[i];
-        auto *sub_scope = sub_scopes[i];
-        auto *dst = sub_scope->Var(param)->GetMutable<LoDTensor>();
-        framework::TensorCopy(src, place, dst);
-      }
+    std::vector<std::future<void>> memcpy_workers;
+    memcpy_workers.reserve(places.size());
+    for (size_t place_idx = 0; place_idx < sub_scopes.size(); ++place_idx) {
+      auto &place = places[place_idx];
+      auto *cur_scope = sub_scopes[place_idx];
+      memcpy_workers.emplace_back(
+          framework::Async([this, &place, cur_scope, &scope, place_idx]() {
+            for (auto &param : Inputs(kParameters)) {
+              auto &src = scope.FindVar(param)->Get<LoDTensor>();
+              auto *dst = cur_scope->Var(param)->GetMutable<LoDTensor>();
+              framework::TensorCopy(src, place, dst);
+            }
+          }));
+    }
+    for (auto &worker : memcpy_workers) {
+      worker.wait();
     }
     WaitOnPlaces(places);

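The rewrite inverts the loop nesting: instead of copying each parameter to every place in sequence, it launches one worker per place via framework::Async, each worker issues all parameter copies for its device, and the main thread joins every future before WaitOnPlaces. A stripped-down sketch of the same fan-out/join pattern using plain std::async, where CopyAllParamsTo is a hypothetical stand-in for the per-place copy loop:

    #include <cstdio>
    #include <future>
    #include <vector>

    // Hypothetical stand-in for copying every parameter to one device.
    void CopyAllParamsTo(int place_idx) {
      std::printf("copying parameters to place %d\n", place_idx);
    }

    int main() {
      const int num_places = 4;
      std::vector<std::future<void>> workers;
      workers.reserve(num_places);

      // Fan out: one worker per place issues that place's copies.
      for (int i = 0; i < num_places; ++i) {
        workers.emplace_back(std::async(std::launch::async, CopyAllParamsTo, i));
      }

      // Join: block until every place has finished.
      for (auto& w : workers) w.wait();
      return 0;
    }

One side effect worth noting: the PADDLE_ENFORCE check that each parameter is a LoDTensor is dropped together with the old loop.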
4 changes: 4 additions & 0 deletions paddle/fluid/platform/device_context.cc
@@ -131,6 +131,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   multi_process = GetCUDAMultiProcessors(place_.device);
   max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
   PADDLE_ENFORCE(cudaStreamCreate(&stream_));
+  PADDLE_ENFORCE(cudaStreamCreate(&memcpy_stream_));
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
   eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
@@ -154,6 +155,7 @@ CUDADeviceContext::~CUDADeviceContext() {
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(cudaStreamDestroy(stream_));
+  PADDLE_ENFORCE(cudaStreamDestroy(memcpy_stream_));
 }

 Place CUDADeviceContext::GetPlace() const { return place_; }
@@ -183,6 +185,8 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }

 cudaStream_t CUDADeviceContext::stream() const { return stream_; }

+cudaStream_t CUDADeviceContext::memcpy_stream() const { return memcpy_stream_; }
+
 #endif

 #ifdef PADDLE_WITH_MKLDNN
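The second stream exists so that host-device copies issued on memcpy_stream_ can overlap with kernels running on stream_: the GPU's copy and compute engines operate independently, and a truly asynchronous copy additionally requires page-locked host memory. A minimal standalone sketch of that setup, using raw CUDA runtime calls with error handling reduced to a macro:

    #include <cstdio>
    #include <cuda_runtime.h>

    #define CUDA_CHECK(expr)                                      \
      do {                                                        \
        cudaError_t e = (expr);                                   \
        if (e != cudaSuccess) {                                   \
          std::printf("CUDA error: %s\n", cudaGetErrorString(e)); \
          return 1;                                               \
        }                                                         \
      } while (0)

    int main() {
      cudaStream_t compute, memcpy_stream;
      CUDA_CHECK(cudaStreamCreate(&compute));
      CUDA_CHECK(cudaStreamCreate(&memcpy_stream));

      const size_t n = 1 << 20;
      float *h_src, *d_dst;
      // Pinned (page-locked) host memory is required for the copy to be
      // genuinely asynchronous with respect to the host.
      CUDA_CHECK(cudaMallocHost(&h_src, n * sizeof(float)));
      CUDA_CHECK(cudaMalloc(&d_dst, n * sizeof(float)));

      // Issued on its own stream, so kernels previously launched on
      // `compute` are free to overlap with this transfer.
      CUDA_CHECK(cudaMemcpyAsync(d_dst, h_src, n * sizeof(float),
                                 cudaMemcpyHostToDevice, memcpy_stream));
      CUDA_CHECK(cudaStreamSynchronize(memcpy_stream));

      CUDA_CHECK(cudaFreeHost(h_src));
      CUDA_CHECK(cudaFree(d_dst));
      CUDA_CHECK(cudaStreamDestroy(compute));
      CUDA_CHECK(cudaStreamDestroy(memcpy_stream));
      return 0;
    }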
7 changes: 6 additions & 1 deletion paddle/fluid/platform/device_context.h
@@ -94,9 +94,12 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return cudnn handle in the device context. */
   cudnnHandle_t cudnn_handle() const;

-  /*! \brief Return cuda stream in the device context. */
+  /*! \brief Return cuda stream in the device context for computation. */
   cudaStream_t stream() const;

+  /*! \brief Return cuda stream in the device context for memory copy. */
+  cudaStream_t memcpy_stream() const;
+
  private:
   CUDAPlace place_;

@@ -110,6 +113,8 @@ class CUDADeviceContext : public DeviceContext {
   int compute_capability;
   int multi_process;
   int max_threads_per_mp;
+
+  cudaStream_t memcpy_stream_;
 };

 template <>
8 changes: 8 additions & 0 deletions paddle/fluid/platform/device_tracer.cc
@@ -14,6 +14,8 @@ limitations under the License. */

 #include "paddle/fluid/platform/device_tracer.h"
 #include <google/protobuf/text_format.h>
+#include <sys/time.h>
+#include <time.h>
 #include <fstream>
 #include <map>
 #include <mutex>
@@ -418,5 +420,11 @@ void ClearCurThread() { cur_thread_id = 0; }

 int CurThread() { return cur_thread_id; }

+uint64_t PosixInNsec() {
+  struct timeval tv;
+  gettimeofday(&tv, nullptr);
+  return 1000 * (static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec);
+}
+
 }  // namespace platform
 }  // namespace paddle
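The moved PosixInNsec() scales gettimeofday()'s second/microsecond pair into nanoseconds: tv_sec * 1000000 + tv_usec gives microseconds, and the outer factor of 1000 converts to nanoseconds, so the timestamp's real resolution remains one microsecond. A tiny check of the arithmetic with fixed values:

    #include <cassert>
    #include <cstdint>

    int main() {
      // 1.5 s expressed the way gettimeofday() reports it.
      uint64_t tv_sec = 1, tv_usec = 500000;
      // Same conversion as PosixInNsec(): us -> ns via the outer * 1000.
      uint64_t ns = 1000 * (tv_sec * 1000000 + tv_usec);
      assert(ns == 1500000000ULL);  // 1.5e9 ns
      return 0;
    }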
2 changes: 2 additions & 0 deletions paddle/fluid/platform/device_tracer.h
@@ -101,5 +101,7 @@ int BlockDepth();
 void SetCurThread(int thread_id);
 void ClearCurThread();
 int CurThread();
+
+uint64_t PosixInNsec();
 }  // namespace platform
 }  // namespace paddle
8 changes: 0 additions & 8 deletions paddle/fluid/platform/profiler.cc
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/platform/profiler.h"
-#include <sys/time.h>
-#include <time.h>
 #include <iomanip>
 #include <map>
 #ifdef PADDLE_WITH_CUDA
@@ -54,12 +52,6 @@ inline uint64_t GetTimeInNsec() {
       .count();
 }

-inline uint64_t PosixInNsec() {
-  struct timeval tv;
-  gettimeofday(&tv, nullptr);
-  return 1000 * (static_cast<uint64_t>(tv.tv_sec) * 1000000 + tv.tv_usec);
-}
-
 Event::Event(EventKind kind, std::string name, uint32_t thread_id,
              const DeviceContext* dev_ctx)
     : kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false) {