Skip to content

Commit

Permalink
Merge pull request #11907 from reyoung/feature/use_dev_ctx_for_op
Browse files Browse the repository at this point in the history
Use std::map for Place <--> DeviceContext
  • Loading branch information
reyoung committed Jul 3, 2018
2 parents 71b1c39 + 2d0e559 commit 037ce12
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 32 deletions.
10 changes: 2 additions & 8 deletions paddle/fluid/framework/details/op_handle_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,10 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
// NOTE(zcd): device contexts must be ordered here because RecordEvent
// uses a mutex to ensure thread safety.
std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
for (auto &p : dev_ctxes_) {
ordered_ctxes.emplace(p.second, p.first);
}
for (auto &p : ordered_ctxes) {
method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.second).device),
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/framework/details/op_handle_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
// limitations under the License.

#pragma once
#include <map>
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/macros.h"
Expand Down Expand Up @@ -92,9 +92,7 @@ class OpHandleBase {

std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
dev_ctxes_;
std::map<platform::Place, platform::DeviceContext *> dev_ctxes_;

#ifdef PADDLE_WITH_CUDA
std::unordered_map<int, cudaEvent_t> events_;
Expand Down
3 changes: 1 addition & 2 deletions paddle/fluid/framework/details/reduce_and_gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ struct ReduceLoDTensor {
inline void GatherSelectedRows(
const std::vector<const SelectedRows *> &src_selecte_rows_,
const std::vector<platform::Place> &in_places,
const std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash> &dev_ctxes,
const std::map<platform::Place, platform::DeviceContext *> &dev_ctxes,
const platform::Place &out_place, SelectedRows *dst_selecte_rows) {
PADDLE_ENFORCE(!src_selecte_rows_.empty());

Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/platform/device_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"

#include <set>
#include <string>
#include <unordered_set>
#include <vector>
Expand All @@ -35,7 +36,7 @@ DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
using PtrType = std::unique_ptr<DeviceContext>;
std::unordered_set<Place, PlaceHash> set;
std::set<Place> set;
for (auto& p : places) {
set.insert(p);
}
Expand Down
8 changes: 3 additions & 5 deletions paddle/fluid/platform/device_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ limitations under the License. */
#include <mkldnn.hpp>
#endif

#include <map>
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"

#include "glog/logging.h"

namespace paddle {
namespace platform {

Expand Down Expand Up @@ -201,9 +201,7 @@ class DeviceContextPool {

private:
static DeviceContextPool* pool;
std::unordered_map<const platform::Place,
std::unique_ptr<platform::DeviceContext>, PlaceHash>
device_contexts_;
std::map<Place, std::unique_ptr<DeviceContext>> device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};

Expand Down
15 changes: 3 additions & 12 deletions paddle/fluid/platform/place.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct CPUPlace {
// needed for variant equality comparison
inline bool operator==(const CPUPlace &) const { return true; }
inline bool operator!=(const CPUPlace &) const { return false; }
inline bool operator<(const CPUPlace &) const { return false; }
};

struct CUDAPlace {
Expand All @@ -42,6 +43,7 @@ struct CUDAPlace {
return device == o.device;
}
inline bool operator!=(const CUDAPlace &o) const { return !(*this == o); }
inline bool operator<(const CUDAPlace &o) const { return device < o.device; }

int device;
};
Expand All @@ -52,6 +54,7 @@ struct CUDAPinnedPlace {
// needed for variant equality comparison
inline bool operator==(const CUDAPinnedPlace &) const { return true; }
inline bool operator!=(const CUDAPinnedPlace &) const { return false; }
inline bool operator<(const CUDAPinnedPlace &) const { return false; }
};

struct IsCUDAPlace : public boost::static_visitor<bool> {
Expand Down Expand Up @@ -89,18 +92,6 @@ bool is_cuda_pinned_place(const Place &);
bool places_are_same_class(const Place &, const Place &);
bool is_same_place(const Place &, const Place &);

// Hash functor for Place, used as the hasher of unordered containers keyed
// by Place. The hashed integer is built from two parts:
//   - the low kTypeIndexBits bits are OR-ed with the variant index, p.which();
//   - the bits above them carry the CUDA device id when is_gpu_place(p) is
//     true, and zero for every other place.
struct PlaceHash {
  std::size_t operator()(const Place &p) const {
    constexpr size_t kTypeIndexBits = 4;
    // Only GPU places contribute a device id; other places use 0.
    const size_t device_part =
        is_gpu_place(p) ? boost::get<CUDAPlace>(p).device : 0;
    // The combined key is implicitly converted to int by std::hash<int>.
    return std::hash<int>()((device_part << kTypeIndexBits) | p.which());
  }
};

std::ostream &operator<<(std::ostream &, const Place &);

template <typename Visitor>
Expand Down

0 comments on commit 037ce12

Please sign in to comment.