From 693152e568eeeddd53eeb9b682dcdc51c709d93f Mon Sep 17 00:00:00 2001
From: Tao Peng
Date: Wed, 29 Jun 2022 09:30:13 +0000
Subject: [PATCH] [Runtime] Share variable resource between GPUCompatibleCPUDevice.

---
 .../core/common_runtime/direct_session.cc     | 10 ++++------
 .../common_runtime/gpu/gpu_device_factory.cc  | 22 +++++++++++++++++++++-
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index daad3c3103e..013053435b1 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -274,12 +274,10 @@ class DirectSessionFactory : public SessionFactory {
     ResourceMgr* shared_rmgr = new ResourceMgr("localhost");
     DeviceResourceMgrMap dev_rmgr_map;
     std::string dev_prefix("/job:localhost/replica:0/task:0");
-    for (int i = 0; i < session_num; ++i) {
-      std::string dev_name = dev_prefix + "/device:CPU:" + std::to_string(i);
-      dev_rmgr_map.device_rmgr_map[dev_name] = shared_rmgr;
-      dev_name = dev_prefix + "/device:cpu:" + std::to_string(i);
-      dev_rmgr_map.device_rmgr_map[dev_name] = shared_rmgr;
-    }
+    dev_rmgr_map.device_rmgr_map[dev_prefix + "/device:CPU:0"] = shared_rmgr;
+    dev_rmgr_map.device_rmgr_map[dev_prefix + "/device:cpu:0"] = shared_rmgr;
+    dev_rmgr_map.device_rmgr_map["/device:CPU:0"] = shared_rmgr;
+    dev_rmgr_map.device_rmgr_map["/device:cpu:0"] = shared_rmgr;
 
     std::vector<std::unique_ptr<Device>> devices;
     TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
index e6b25209661..2d00d8c1830 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_factory.cc
@@ -90,6 +90,19 @@ class GPUCompatibleCPUDevice : public ThreadPoolDevice {
           options.config.gpu_options().force_gpu_compatible();
     }
   }
+  GPUCompatibleCPUDevice(const SessionOptions& options, const string& name,
+                         Bytes memory_limit, const DeviceLocality& locality,
+                         Allocator* allocator,
+                         const DeviceResourceMgrMap* dev_rmgr_map)
+      : ThreadPoolDevice(options, name, memory_limit,
+                         locality, allocator, dev_rmgr_map),
+        numa_node_(locality.numa_node()) {
+    if (options.config.has_gpu_options()) {
+      force_gpu_compatible_ =
+          options.config.gpu_options().force_gpu_compatible();
+    }
+  }
+
   ~GPUCompatibleCPUDevice() override {}
 
   Allocator* GetAllocator(AllocatorAttributes attr) override {
@@ -118,6 +131,12 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
 
   Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                        std::vector<std::unique_ptr<Device>>* devices) override {
+    return CreateDevices(options, name_prefix, devices, nullptr);
+  }
+
+  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+                       std::vector<std::unique_ptr<Device>>* devices,
+                       const DeviceResourceMgrMap* dev_rmgr_map) override {
     int n = 1;
     auto iter = options.config.device_count().find("CPU");
     if (iter != options.config.device_count().end()) {
@@ -133,7 +152,8 @@ class GPUCompatibleCPUDeviceFactory : public DeviceFactory {
       locality.set_numa_node(numa_node);
       devices->push_back(absl::make_unique<GPUCompatibleCPUDevice>(
           options, name, Bytes(256 << 20), DeviceLocality(),
-          ProcessState::singleton()->GetCPUAllocator(numa_node)));
+          ProcessState::singleton()->GetCPUAllocator(numa_node),
+          dev_rmgr_map));
     }
 
     return Status::OK();
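
Note: the sketch below is a minimal, standalone illustration of the sharing pattern this patch wires up: a DeviceResourceMgrMap keyed by device name hands the same ResourceMgr to every CPU device whose name appears in the map, so a resource created through one session's device is visible through another session's device. The ResourceMgr, DeviceResourceMgrMap, and Device classes here are simplified stand-ins written for illustration only, not the actual TensorFlow types; only the name-to-shared-manager mapping mirrors the patch.

// Standalone sketch (not TensorFlow code) of sharing one resource manager
// between devices that resolve the same device name through a shared map.
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for a resource manager: a named container of resources.
class ResourceMgr {
 public:
  explicit ResourceMgr(std::string name) : name_(std::move(name)) {}
  const std::string& name() const { return name_; }
  void Create(const std::string& key, int value) { resources_[key] = value; }
  bool Lookup(const std::string& key, int* value) const {
    auto it = resources_.find(key);
    if (it == resources_.end()) return false;
    *value = it->second;
    return true;
  }

 private:
  std::string name_;
  std::map<std::string, int> resources_;
};

// Stand-in for the DeviceResourceMgrMap in the patch: device name -> shared manager.
struct DeviceResourceMgrMap {
  std::map<std::string, ResourceMgr*> device_rmgr_map;
};

// Stand-in for a device: if the map contains its name, it uses the shared
// ResourceMgr; otherwise it falls back to a private one (the pre-patch behavior).
class Device {
 public:
  Device(const std::string& name, const DeviceResourceMgrMap* dev_rmgr_map)
      : owned_(new ResourceMgr(name)), rmgr_(owned_.get()) {
    if (dev_rmgr_map != nullptr) {
      auto it = dev_rmgr_map->device_rmgr_map.find(name);
      if (it != dev_rmgr_map->device_rmgr_map.end()) rmgr_ = it->second;
    }
  }
  ResourceMgr* resource_manager() { return rmgr_; }

 private:
  std::unique_ptr<ResourceMgr> owned_;  // private manager, used only as fallback
  ResourceMgr* rmgr_;                   // shared manager when registered in the map
};

int main() {
  ResourceMgr shared_rmgr("localhost");
  DeviceResourceMgrMap dev_rmgr_map;
  const std::string dev_prefix = "/job:localhost/replica:0/task:0";
  dev_rmgr_map.device_rmgr_map[dev_prefix + "/device:CPU:0"] = &shared_rmgr;

  // Two CPU devices (e.g. created for two sessions) resolve to the same manager.
  Device cpu_a(dev_prefix + "/device:CPU:0", &dev_rmgr_map);
  Device cpu_b(dev_prefix + "/device:CPU:0", &dev_rmgr_map);

  cpu_a.resource_manager()->Create("var0", 42);  // "variable" created via device A
  int value = 0;
  const bool found = cpu_b.resource_manager()->Lookup("var0", &value);
  std::cout << "var0 visible via second device: " << std::boolalpha << found
            << ", value = " << value << "\n";  // prints: true, value = 42
  return 0;
}

Any device name that is absent from the map keeps its own private resource manager, which matches why the patch registers only the CPU:0 spellings (with and without the task prefix, in both upper and lower case) that the GPU-compatible CPU device factory actually creates.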