Feature/context library #7187

Closed
5 changes: 5 additions & 0 deletions paddle/framework/library_type.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
+#include <cctype>

namespace paddle {
namespace framework {
@@ -41,6 +42,10 @@ inline std::string LibraryTypeToString(const LibraryType& library_type) {

inline LibraryType StringToLibraryType(const char* ctype) {
std::string s(ctype);
+for (size_t i = 0; i < s.size(); ++i) {
+s[i] = toupper(s[i]);
+}

if (s == std::string("PLAIN")) {
return LibraryType::kPlain;
} else if (s == std::string("MKLDNN")) {
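Note on the hunk above: StringToLibraryType now upper-cases its input before comparing, so lookups become case-insensitive. Below is a minimal, self-contained sketch of the same idea for reference; it is not part of the patch. The helper name ToUpperCopy and the unsigned-char cast before std::toupper are my assumptions, the cast being the usual guard against negative char values.

```cpp
#include <cctype>
#include <string>

// Hedged sketch (not the patch itself): upper-case a copy of the input so that
// "mkldnn", "MKLDNN", and "MkLdNn" all compare equal against "MKLDNN".
// The static_cast<unsigned char> guards against negative char values, for
// which std::toupper would be undefined behaviour.
inline std::string ToUpperCopy(std::string s) {
  for (char& c : s) {
    c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
  }
  return s;
}
```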
12 changes: 5 additions & 7 deletions paddle/framework/op_kernel_type.h
@@ -26,13 +26,11 @@ namespace framework {
struct OpKernelType {
struct Hash {
size_t operator()(const OpKernelType& key) const {
-int place = key.place_.which() + (1 << LEFT_SHIFT);
-int data_type =
-static_cast<int>(key.data_type_) + (1 << (LEFT_SHIFT + 1));
-int data_layout =
-static_cast<int>(key.data_layout_) + (1 << (LEFT_SHIFT + 2));
-int library_type =
-static_cast<int>(key.library_type_) + (1 << (LEFT_SHIFT + 3));
+int place = key.place_.which();
+int data_type = static_cast<int>(key.data_type_) << LEFT_SHIFT;
+int data_layout = static_cast<int>(key.data_layout_) << (LEFT_SHIFT * 2);
+int library_type = static_cast<int>(key.library_type_)
+<< (LEFT_SHIFT * 3);
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type);
}
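Note on the hunk above: the old hash added a distinct constant offset to each field and then summed the results, so different (data_type, data_layout, library_type) combinations could sum to the same value; the new hash shifts each field into its own bit range before summing. A hedged sketch of that packing idea follows; it is not the patch itself, and the name kLeftShift plus the assumption that every field fits in 8 bits are mine.

```cpp
#include <cstddef>
#include <functional>

// Sketch of the bit-packing idea behind the new hash (assumes LEFT_SHIFT == 8
// and that each field value fits into 8 bits; the patch does not enforce this).
constexpr int kLeftShift = 8;  // hypothetical local name

inline std::size_t PackAndHash(int place, int data_type, int data_layout,
                               int library_type) {
  // Each field occupies its own byte, so two keys collide only if all four
  // fields are equal (given the 8-bit assumption above).
  const int packed = place + (data_type << kLeftShift) +
                     (data_layout << (kLeftShift * 2)) +
                     (library_type << (kLeftShift * 3));
  return std::hash<int>()(packed);
}
```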
12 changes: 7 additions & 5 deletions paddle/platform/device_context.cc
@@ -18,8 +18,9 @@ namespace platform {
DeviceContextPool* DeviceContextPool::pool = nullptr;

const platform::DeviceContext* DeviceContextPool::Get(
-const platform::Place& place) {
-auto it = device_contexts_.find(place);
+const platform::Place& place,
+const framework::LibraryType& library /* kPlain by default*/) {
+auto it = device_contexts_.find(std::make_pair(place, library));
if (it == device_contexts_.end()) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
@@ -29,11 +30,12 @@ const platform::DeviceContext* DeviceContextPool::Get(
}

DeviceContextPool::DeviceContextPool(
-const std::vector<platform::Place>& places) {
+const std::vector<platform::LibraryType>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);

for (size_t i = 0; i < places.size(); i++) {
-if (platform::is_cpu_place(places[i])) {
-device_contexts_.emplace(places[i],
+if (places[i] == platform::LibraryType::kPlain) {
+device_contexts_.emplace(std::make_pair(CPUPlace(), places[i]),
new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
} else if (platform::is_gpu_place(places[i])) {
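Note on the hunks above: the pool is now keyed by a (Place, LibraryType) pair rather than by Place alone, so Get takes an optional library argument that defaults to LibraryType::kPlain. A hypothetical usage sketch follows; these call sites are not in the patch and assume the pool has already been initialized with matching contexts.

```cpp
#include "paddle/platform/device_context.h"

// Hypothetical call sites (not part of the patch), assuming the pool has
// already been created via DeviceContextPool::Init(...).
void Example() {
  auto& pool = paddle::platform::DeviceContextPool::Instance();

  // Plain CPU context: the library argument defaults to LibraryType::kPlain.
  const auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());

  // CUDNN-backed context on GPU 0: pass the library type explicitly
  // (only valid if such a context was registered in the pool).
  const auto* cudnn_ctx = pool.Get(paddle::platform::CUDAPlace(0),
                                   paddle::framework::LibraryType::kCUDNN);
  (void)cpu_ctx;
  (void)cudnn_ctx;
}
```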
44 changes: 32 additions & 12 deletions paddle/platform/device_context.h
@@ -13,6 +13,7 @@ limitations under the License. */

#include <memory>
#include <unordered_map>
+#include <utility>

#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/dynload/cublas.h"
@@ -21,6 +22,7 @@ limitations under the License. */
#define EIGEN_USE_GPU
#endif

#include "paddle/framework/library_type.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
@@ -52,11 +54,12 @@ class CPUDeviceContext : public DeviceContext {
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};

-template <typename Place>
+template <typename Place, typename Library>
struct DefaultDeviceContextType;

template <>
-struct DefaultDeviceContextType<platform::CPUPlace> {
+struct DefaultDeviceContextType<platform::CPUPlace,
+framework::LibraryType::kPlain> {
using TYPE = CPUDeviceContext;
};

@@ -99,10 +102,17 @@ class CUDADeviceContext : public DeviceContext {
};

template <>
-struct DefaultDeviceContextType<platform::CUDAPlace> {
+struct DefaultDeviceContextType<platform::CUDAPlace,
+framework::LibraryType::kPlain> {
using TYPE = CUDADeviceContext;
};

+template <>
+struct DefaultDeviceContextType<platform::CUDAPlace,
+framework::LibraryType::kCUDNN> {
+using TYPE = CUDNNDeviceContext;
+};

class CUDNNDeviceContext : public CUDADeviceContext {
public:
explicit CUDNNDeviceContext(CUDAPlace place);
@@ -120,45 +130,55 @@ class CUDNNDeviceContext : public CUDADeviceContext {
/*! \brief device context pool singleton */
class DeviceContextPool {
public:
-explicit DeviceContextPool(const std::vector<platform::Place>& places);
+using DeviceContextId = std::pair<platform::Place, framework::LibraryType>;
+
+explicit DeviceContextPool(const std::vector<platform::LibraryType>& places);

static DeviceContextPool& Instance() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool;
}

/*! \brief Create should only called by Init function */
-static DeviceContextPool& Init(const std::vector<platform::Place>& places) {
+static DeviceContextPool& Init(
+const std::vector<framework::LibraryType>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}

/*! \brief Return handle of single device context. */
-const platform::DeviceContext* Get(const platform::Place& place);
+const platform::DeviceContext* Get(
+const platform::Place& place,
+const framework::LibraryType& library = framework::LibraryType::kPlain);

-template <typename Place>
+template <typename Place, typename Library>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
-const Place& place) {
+const Place& place,
+const Library& library = framework::LibraryType::kPlain) {
return reinterpret_cast<
-const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
+const typename DefaultDeviceContextType<Place, Library>::TYPE*>(
+Get(place, library));
}

private:
static DeviceContextPool* pool;
constexpr static int LEFT_SHIFT = 8;
struct Hash {
std::hash<int> hash_;
-size_t operator()(const platform::Place& place) const {
-int pre_hash = place.which() + (1 << LEFT_SHIFT);
+size_t operator()(const DeviceContextId& id) const {
+auto place = id.first;
+auto library = id.second;
+int pre_hash = place.which() << LEFT_SHIFT;
+pre_hash += (library << (LEFT_SHIFT * 2));
if (platform::is_gpu_place(place)) {
pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
}
return hash_(pre_hash);
}
};
-std::unordered_map<const platform::Place, const platform::DeviceContext*,
+std::unordered_map<const DeviceContextId, const platform::DeviceContext*,
Hash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
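Note on the hunks above: device_contexts_ is now an unordered_map keyed by DeviceContextId, a std::pair of Place and LibraryType, with a custom Hash that shifts the place index and library type into separate bit ranges and folds in the GPU device id. A simplified, self-contained illustration of that pair-keyed-map pattern follows; all names are hypothetical and the 8-bit-field assumption is mine, not something the patch enforces.

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Generic illustration (not taken from the patch): an unordered_map keyed by a
// (place-index, library-type) pair with a hand-rolled hasher that folds both
// components into one int, mirroring the scheme in DeviceContextPool::Hash.
using PairKey = std::pair<int, int>;  // (place index, library type as int)

struct PairKeyHash {
  std::size_t operator()(const PairKey& key) const {
    // Shift each component into its own bit range before hashing
    // (assumes both values fit in 8 bits).
    const int pre_hash = (key.first << 8) + (key.second << 16);
    return std::hash<int>()(pre_hash);
  }
};

// Values are placeholder strings here; the pool stores DeviceContext pointers.
using ContextMap = std::unordered_map<PairKey, std::string, PairKeyHash>;
```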