From fca45106f3fd542d464ef6a442335905ebda2feb Mon Sep 17 00:00:00 2001 From: Tao Peng Date: Thu, 30 Jun 2022 02:47:38 +0000 Subject: [PATCH] [Runtime] Fix shared resource_mgr double free corruption bug. --- tensorflow/core/common_runtime/device.cc | 3 +++ tensorflow/core/common_runtime/device.h | 5 ++++- tensorflow/core/common_runtime/direct_session.cc | 2 +- tensorflow/core/public/session.h | 7 +++++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/common_runtime/device.cc b/tensorflow/core/common_runtime/device.cc index 94000f1d940..896e64ce8a2 100644 --- a/tensorflow/core/common_runtime/device.cc +++ b/tensorflow/core/common_runtime/device.cc @@ -28,6 +28,7 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes) CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_)) << "Invalid device name: " << name(); rmgr_ = new ResourceMgr(parsed_name_.job); + owned_rmgr_ = true; } Device::Device(Env* env, const DeviceAttributes& device_attributes, @@ -41,8 +42,10 @@ Device::Device(Env* env, const DeviceAttributes& device_attributes, dev_rmgr_map->device_rmgr_map.end()) { rmgr_ = const_cast(dev_rmgr_map)->device_rmgr_map[name()]; LOG(INFO) << "Device " << name() << " got a shared resource_mgr: " << rmgr_; + owned_rmgr_ = false; } else { rmgr_ = new ResourceMgr(parsed_name_.job); + owned_rmgr_ = true; } } diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index f28b41e63b2..e0326934fdc 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -191,7 +191,9 @@ class Device : public DeviceBase { protected: void DeleteResourceMgr() { - delete rmgr_; + if (owned_rmgr_) { + delete rmgr_; + } rmgr_ = nullptr; } @@ -204,6 +206,7 @@ class Device : public DeviceBase { // Resources associated w/ this device. E.g., shared variables, etc. ResourceMgr* rmgr_ = nullptr; + bool owned_rmgr_ = true; TF_DISALLOW_COPY_AND_ASSIGN(Device); }; diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index daad3c3103e..3e4d654f6a5 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -288,7 +288,7 @@ class DirectSessionFactory : public SessionFactory { DeviceMgr* device_mgr = new DeviceMgr(std::move(devices)); - SessionGroup* session_group = new SessionGroup(); + SessionGroup* session_group = new SessionGroup(shared_rmgr); #ifdef TENSORFLOW_USE_NUMA DirectSession* leader_session = new DirectSession(options, device_mgr, true, this, diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h index 81eb9714a5b..b1cba212ea9 100644 --- a/tensorflow/core/public/session.h +++ b/tensorflow/core/public/session.h @@ -30,6 +30,7 @@ limitations under the License. namespace tensorflow { class DeviceMgr; +class ResourceMgr; namespace thread { @@ -270,7 +271,12 @@ class Session { class SessionGroup { public: + SessionGroup() : shared_resource_mgr_(nullptr) {} + SessionGroup(ResourceMgr* mgr) : shared_resource_mgr_(mgr) {} ~SessionGroup() { + if (shared_resource_mgr_) { + delete shared_resource_mgr_; + } } Status Close() { @@ -375,6 +381,7 @@ class SessionGroup { std::vector> sessions_; int32_t session_num_ = 0; std::atomic serving_index_{0}; + ResourceMgr* shared_resource_mgr_ = nullptr; Status GetServingSessionId(int32_t* serving_id, int32_t hint_id = -1) { if (session_num_ < 1) {