From e46996bd9dd1134847e2a3dfb796f1e0ebc450a7 Mon Sep 17 00:00:00 2001 From: mhouston Date: Fri, 10 Jul 2015 16:05:48 -0700 Subject: [PATCH] Detect topology corner cases and improve broadcast order - Start with distant nodes in broadcast - Fix outside loop to loop for full tree depth --- src/caffe/parallel.cpp | 71 ++++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 31 deletions(-) diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp index ef04a21268e..4ba8015a865 100644 --- a/src/caffe/parallel.cpp +++ b/src/caffe/parallel.cpp @@ -119,18 +119,23 @@ void DevicePair::compute(const vector devices, vector* pairs) { #ifndef CPU_ONLY vector remaining(devices); + // Depth for reduction tree + int remaining_depth = static_cast(ceil(log2(remaining.size()))); + // Group GPUs by board - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - cudaDeviceProp a, b; - CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); - CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); - if (a.isMultiGpuBoard && b.isMultiGpuBoard) { - if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + cudaDeviceProp a, b; + CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); + CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); + if (a.isMultiGpuBoard && b.isMultiGpuBoard) { + if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "GPU board: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } } } } @@ -142,15 +147,19 @@ void DevicePair::compute(const vector devices, vector* pairs) { DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); // Group by P2P accessibility - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - int access; - CUDA_CHECK(cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); - if (access) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + int access; + CUDA_CHECK( + cudaDeviceCanAccessPeer(&access, remaining[i], remaining[j])); + if (access) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "P2P pair: " << remaining[i] << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } } } } @@ -161,15 +170,19 @@ void DevicePair::compute(const vector devices, vector* pairs) { DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); // Group remaining - for (int i = 0; i < remaining.size(); ++i) { - for (int j = i + 1; j < remaining.size(); ++j) { - pairs->push_back(DevicePair(remaining[i], remaining[j])); - DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" << remaining[j]; - remaining.erase(remaining.begin() + j); - break; + remaining_depth = ceil(log2(remaining.size())); + for (int d = 0; d < remaining_depth; ++d) { + for (int i = 0; i < remaining.size(); ++i) { + pairs->push_back(DevicePair(remaining[i], remaining[i + 1])); + DLOG(INFO) << "Remaining pair: " << remaining[i] << ":" + << remaining[i + 1]; + remaining.erase(remaining.begin() + i + 1); } } + + // Should only be the parent node remaining CHECK_EQ(remaining.size(), 1); + pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); CHECK(pairs->size() == devices.size()); @@ -294,7 +307,7 @@ void P2PSync::on_start(Timer* timer, ostringstream* timing) { if (children_.size()) { timer->Start(); } - for (int i = 0; i < children_.size(); ++i) { + for (int i = children_.size() - 1; i >= 0; i--) { Dtype* src = data_; Dtype* dst = children_[i]->data_; @@ -308,11 +321,7 @@ void P2PSync::on_start(Timer* timer, ostringstream* timing) { CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // cudaMemcpyDeviceToDevice, cudaStreamDefault)); - } - if (children_.size()) { CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); - } - for (int i = 0; i < children_.size(); ++i) { children_[i]->queue_.push(this); } if (children_.size()) {