diff --git a/src/Common/ZooKeeper/ZooKeeperImpl.cpp b/src/Common/ZooKeeper/ZooKeeperImpl.cpp index 79a975e683fe..f97bf292198a 100644 --- a/src/Common/ZooKeeper/ZooKeeperImpl.cpp +++ b/src/Common/ZooKeeper/ZooKeeperImpl.cpp @@ -299,11 +299,8 @@ ZooKeeper::~ZooKeeper() { finalize(false, false, "Destructor called"); - if (send_thread.joinable()) - send_thread.join(); - - if (receive_thread.joinable()) - receive_thread.join(); + send_thread.join(); + receive_thread.join(); } catch (...) { @@ -365,11 +362,8 @@ ZooKeeper::ZooKeeper( { tryLogCurrentException(log, "Failed to connect to ZooKeeper"); - if (send_thread.joinable()) - send_thread.join(); - - if (receive_thread.joinable()) - receive_thread.join(); + send_thread.join(); + receive_thread.join(); throw; } @@ -914,8 +908,7 @@ void ZooKeeper::finalize(bool error_send, bool error_receive, const String & rea } /// Send thread will exit after sending close request or on expired flag - if (send_thread.joinable()) - send_thread.join(); + send_thread.join(); } /// Set expired flag after we sent close event @@ -932,7 +925,7 @@ void ZooKeeper::finalize(bool error_send, bool error_receive, const String & rea tryLogCurrentException(log); } - if (!error_receive && receive_thread.joinable()) + if (!error_receive) receive_thread.join(); { diff --git a/src/Common/ZooKeeper/ZooKeeperImpl.h b/src/Common/ZooKeeper/ZooKeeperImpl.h index 91c5083bda10..9fff12309bd6 100644 --- a/src/Common/ZooKeeper/ZooKeeperImpl.h +++ b/src/Common/ZooKeeper/ZooKeeperImpl.h @@ -255,8 +255,30 @@ class ZooKeeper final : public IKeeper Watches watches TSA_GUARDED_BY(watches_mutex); std::mutex watches_mutex; - ThreadFromGlobalPool send_thread; - ThreadFromGlobalPool receive_thread; + /// A wrapper around ThreadFromGlobalPool that allows to call join() on it from multiple threads. + class ThreadReference + { + public: + const ThreadReference & operator = (ThreadFromGlobalPool && thread_) + { + std::lock_guard l(lock); + thread = std::move(thread_); + return *this; + } + + void join() + { + std::lock_guard l(lock); + if (thread.joinable()) + thread.join(); + } + private: + std::mutex lock; + ThreadFromGlobalPool thread; + }; + + ThreadReference send_thread; + ThreadReference receive_thread; Poco::Logger * log;