From 70deb50942128099c12b87f815079648719d1dcc Mon Sep 17 00:00:00 2001 From: Michal Zientkiewicz Date: Fri, 3 Feb 2023 17:20:58 +0100 Subject: [PATCH] Moving files. Signed-off-by: Michal Zientkiewicz --- dali/pipeline/util/new_thread_pool.h | 290 ------------------ dali/pipeline/util/thread_pool.cc | 15 +- dali/pipeline/util/thread_pool_base.cc | 156 ++++++++++ dali/pipeline/util/thread_pool_base.h | 170 ++++++++++ ..._pool_test.cc => thread_pool_base_test.cc} | 71 ++++- 5 files changed, 383 insertions(+), 319 deletions(-) delete mode 100644 dali/pipeline/util/new_thread_pool.h create mode 100644 dali/pipeline/util/thread_pool_base.cc create mode 100644 dali/pipeline/util/thread_pool_base.h rename dali/pipeline/util/{new_thread_pool_test.cc => thread_pool_base_test.cc} (60%) diff --git a/dali/pipeline/util/new_thread_pool.h b/dali/pipeline/util/new_thread_pool.h deleted file mode 100644 index 94af4773a7e..00000000000 --- a/dali/pipeline/util/new_thread_pool.h +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DALI_PIPELINE_UTIL_NEW_THREAD_POOL_H_ -#define DALI_PIPELINE_UTIL_NEW_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "dali/core/call_at_exit.h" -#include "dali/core/error_handling.h" -#include "dali/core/multi_error.h" -#include "dali/core/mm/detail/aux_alloc.h" - -namespace dali { -namespace experimental { - -class ThreadPoolBase; - -/** - * @brief A collection of tasks, ordered by priority - * - * Tasks are added to a job first and then the entire work is scheduled as a whole. - */ -class Job { - public: - ~Job() noexcept(false) { - if (!tasks_.empty() && !waited_for_) { - std::lock_guard g(mtx_); - if (!tasks_.empty() && !waited_for_) { - throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for."); - } - } - } - - using priority_t = int64_t; - - template - std::enable_if_t>> - AddTask(Runnable &&runnable, priority_t priority = {}) { - if (started_) - throw std::logic_error("This job has already been started - cannot add more tasks to it"); - auto it = tasks_.emplace(priority, Task()); - try { - it->second.func = [this, task = &it->second, f = std::move(runnable)](int tid) noexcept { - try { - f(tid); - } catch (...) { - task->error = std::current_exception(); - } - if (--num_pending_tasks_ == 0) { - std::lock_guard g(mtx_); - cv_.notify_one(); - } - }; - } catch (...) { // if, for whatever reason, we cannot initialize the task, we should erase it - tasks_.erase(it); - throw; - } - } - - template - void Run(Executor &executor, bool wait) { - if (started_) - throw std::logic_error("This job has already been started."); - started_ = true; - for (auto &x : tasks_) { - executor.AddTask(std::move(x.second.func)); - num_pending_tasks_++; // increase after successfully scheduling the task - the value - // may go below 0, but we don't care - } - if (wait) - Wait(); - } - - void Wait(); - - void Scrap() { - if (started_) - throw std::logic_error("Cannot scrap a job that has already been started"); - tasks_.clear(); - } - - private: - std::mutex mtx_; // could just probably use atomic_wait on num_pending_tasks_ - needs C++20 - std::condition_variable cv_; - std::atomic_int num_pending_tasks_{0}; - bool started_ = false; - bool waited_for_ = false; - - struct Task { - std::function func; - std::exception_ptr error; - }; - - // This needs to be a container which never invalidates references when inserting new items. - std::multimap, - mm::detail::object_pool_allocator>> tasks_; -}; - -class ThreadPoolBase { - public: - using TaskFunc = std::function; - - ThreadPoolBase() = default; - explicit ThreadPoolBase(int num_threads) { - Init(num_threads); - } - - void Init(int num_threads) { - std::lock_guard g(mtx_); - if (!threads_.empty()) - throw std::logic_error("The thread pool is already started!"); - stop_requested_ = false; - threads_.reserve(num_threads); - for (int i = 0; i < num_threads; i++) - threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i)); - } - - ~ThreadPoolBase() { - Stop(); - } - - void AddTask(TaskFunc f); - - bool StopRequested() const noexcept { - return stop_requested_; - } - - static ThreadPoolBase *this_thread_pool() { - return this_thread_pool_; - } - - static int this_thread_idx() { - return this_thread_idx_; - } - - protected: - virtual void OnThreadStart(int thread_idx) noexcept {} - virtual void OnThreadStop(int thread_idx) noexcept {} - - friend class Job; - - void Stop() { - { - std::lock_guard g(mtx_); - stop_requested_ = true; - cv_.notify_all(); - } - - for (auto &t : threads_) - t.join(); - - { - std::lock_guard g(mtx_); - while (!tasks_.empty()) - tasks_.pop(); - - threads_.clear(); - } - } - - template - bool WaitOrRunTasks(std::condition_variable &cv, Condition &&condition) { - assert(this_thread_pool() == this); - std::unique_lock lock(mtx_); - while (!stop_requested_) { - bool ret; - while (!(ret = condition()) && !stop_requested_ && tasks_.empty()) - cv.wait_for(lock, std::chrono::microseconds(100)); - - if (ret || condition()) // re-evaluate the condition, just in case - return true; - if (stop_requested_) - return false; - assert(!tasks_.empty()); - - PopAndRunTask(lock); - } - return false; - } - - void PopAndRunTask(std::unique_lock &mtx); - - static thread_local ThreadPoolBase *this_thread_pool_; - static thread_local int this_thread_idx_; - - void Run(int index) noexcept; - - std::mutex mtx_; - std::condition_variable cv_; - bool stop_requested_ = false; - std::queue tasks_; - std::vector threads_; -}; - -////////////////////////////// - -void Job::Wait() { - if (!started_) - throw std::logic_error("This job hasn't been run - cannot wait for it."); - - if (waited_for_) - throw std::logic_error("This job has already been waited for."); - - auto ready = [&]() { return num_pending_tasks_ == 0; }; - - if (ThreadPoolBase::this_thread_pool() != nullptr) { - bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready); - waited_for_ = true; - if (!result) - throw std::runtime_error("The thread pool was stopped"); - } else { - std::unique_lock lock(mtx_); - cv_.wait(lock, ready); - waited_for_ = true; - } - - // note - this vector is not allocated unless there were exceptions thrown - std::vector errors; - for (auto &x : tasks_) { - if (x.second.error) - errors.push_back(std::move(x.second.error)); - } - if (errors.size() == 1) - std::rethrow_exception(errors[0]); - else if (errors.size() > 1) - throw MultipleErrors(std::move(errors)); -} - - - -thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr; -thread_local int ThreadPoolBase::this_thread_idx_ = -1;; - -inline void ThreadPoolBase::AddTask(TaskFunc f) { - { - std::lock_guard g(mtx_); - if (stop_requested_) - throw std::logic_error("The thread pool is stopped and no longer accepts new tasks."); - tasks_.push(std::move(f)); - } - cv_.notify_one(); -} - -inline void ThreadPoolBase::Run(int index) noexcept { - ThreadPoolBase *this_thread_pool_ = this; - this_thread_idx_ = index; - OnThreadStart(index); - detail::CallAtExit([&]() { OnThreadStop(index); }); - std::unique_lock lock(mtx_); - while (!stop_requested_) { - cv_.wait(lock, [&]() { return stop_requested_ || !tasks_.empty(); }); - if (stop_requested_) - break; - PopAndRunTask(lock); - } -} - -inline void ThreadPoolBase::PopAndRunTask(std::unique_lock &lock) { - TaskFunc t = std::move(tasks_.front()); - tasks_.pop(); - lock.unlock(); - t(this_thread_idx()); - lock.lock(); -} - - -} // namespace experimental -} // namespace dali - -#endif // DALI_PIPELINE_UTIL_NEW_THREAD_POOL_H_ diff --git a/dali/pipeline/util/thread_pool.cc b/dali/pipeline/util/thread_pool.cc index 6b46c17e67a..c6d4243fdbf 100644 --- a/dali/pipeline/util/thread_pool.cc +++ b/dali/pipeline/util/thread_pool.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -54,19 +54,6 @@ ThreadPool::~ThreadPool() { for (auto &thread : threads_) { thread.join(); } - - #pragma GCC diagnostic push -#ifdef __clang__ - #pragma GCC diagnostic ignored "-Wexceptions" -#else - #pragma GCC diagnostic ignored "-Wterminate" -#endif - - if (!work_queue_.empty()) - throw std::logic_error("There was outstanding work in the queue."); - - #pragma GCC diagnostic pop - #if NVML_ENABLED nvml::Shutdown(); #endif diff --git a/dali/pipeline/util/thread_pool_base.cc b/dali/pipeline/util/thread_pool_base.cc new file mode 100644 index 00000000000..38609c2fdb6 --- /dev/null +++ b/dali/pipeline/util/thread_pool_base.cc @@ -0,0 +1,156 @@ +// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dali/pipeline/util/thread_pool_base.h" +#include "dali/core/call_at_exit.h" + +namespace dali { +namespace experimental { + +Job::~Job() noexcept(false) { + if (!tasks_.empty() && !waited_for_) { + std::lock_guard g(mtx_); + if (!tasks_.empty() && !waited_for_) { + throw std::logic_error("The job is not empty, but hasn't been scrapped or waited for."); + } + } +} + +void Job::Wait() { + if (!started_) + throw std::logic_error("This job hasn't been run - cannot wait for it."); + + if (waited_for_) + throw std::logic_error("This job has already been waited for."); + + auto ready = [&]() { return num_pending_tasks_ == 0; }; + + if (ThreadPoolBase::this_thread_pool() != nullptr) { + bool result = ThreadPoolBase::this_thread_pool()->WaitOrRunTasks(cv_, ready); + waited_for_ = true; + if (!result) + throw std::logic_error("The thread pool was stopped"); + } else { + std::unique_lock lock(mtx_); + cv_.wait(lock, ready); + waited_for_ = true; + } + + // note - this vector is not allocated unless there were exceptions thrown + std::vector errors; + for (auto &x : tasks_) { + if (x.second.error) + errors.push_back(std::move(x.second.error)); + } + if (errors.size() == 1) + std::rethrow_exception(errors[0]); + else if (errors.size() > 1) + throw MultipleErrors(std::move(errors)); +} + +void Job::Scrap() { + if (started_) + throw std::logic_error("Cannot scrap a job that has already been started"); + tasks_.clear(); +} + +/////////////////////////////////////////////////////////////////////////// + +thread_local ThreadPoolBase *ThreadPoolBase::this_thread_pool_ = nullptr; +thread_local int ThreadPoolBase::this_thread_idx_ = -1;; + +void ThreadPoolBase::Init(int num_threads) { + if (shutdown_pending_) + throw std::logic_error("The thread pool is being shut down."); + std::lock_guard g(mtx_); + if (!threads_.empty()) + throw std::logic_error("The thread pool is already started!"); + threads_.reserve(num_threads); + for (int i = 0; i < num_threads; i++) + threads_.push_back(std::thread(&ThreadPoolBase::Run, this, i)); +} + +void ThreadPoolBase::Shutdown() { + if (shutdown_pending_) + return; + { + std::lock_guard g(mtx_); + if (shutdown_pending_) + return; + shutdown_pending_ = true; + cv_.notify_all(); + } + + for (auto &t : threads_) + t.join(); + + assert(tasks_.empty()); +} + +void ThreadPoolBase::AddTask(TaskFunc f) { + { + std::lock_guard g(mtx_); + if (shutdown_pending_) + throw std::logic_error("The thread pool is stopped and no longer accepts new tasks."); + tasks_.push(std::move(f)); + } + cv_.notify_one(); +} + +void ThreadPoolBase::Run(int index) noexcept { + this_thread_pool_ = this; + this_thread_idx_ = index; + OnThreadStart(index); + detail::CallAtExit([&]() { OnThreadStop(index); }); + std::unique_lock lock(mtx_); + while (!shutdown_pending_ || !tasks_.empty()) { + cv_.wait(lock, [&]() { return shutdown_pending_ || !tasks_.empty(); }); + if (tasks_.empty()) + break; + PopAndRunTask(lock); + } +} + +void ThreadPoolBase::PopAndRunTask(std::unique_lock &lock) { + TaskFunc t = std::move(tasks_.front()); + tasks_.pop(); + lock.unlock(); + t(); + lock.lock(); +} + +template +bool ThreadPoolBase::WaitOrRunTasks(std::condition_variable &cv, Condition &&condition) { + assert(this_thread_pool() == this); + std::unique_lock lock(mtx_); + while (!shutdown_pending_ || !tasks_.empty()) { + bool ret; + while (!(ret = condition()) && tasks_.empty()) + cv.wait_for(lock, std::chrono::microseconds(100)); + + if (ret || condition()) // re-evaluate the condition, just in case + return true; + if (tasks_.empty()) { + assert(shutdown_pending_); + return condition(); + } + + PopAndRunTask(lock); + } + return condition(); +} + + +} // namespace experimental +} // namespace dali diff --git a/dali/pipeline/util/thread_pool_base.h b/dali/pipeline/util/thread_pool_base.h new file mode 100644 index 00000000000..54d43c9bd1d --- /dev/null +++ b/dali/pipeline/util/thread_pool_base.h @@ -0,0 +1,170 @@ +// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DALI_PIPELINE_UTIL_THREAD_POOL_BASE_H_ +#define DALI_PIPELINE_UTIL_THREAD_POOL_BASE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "dali/core/api_helper.h" +#include "dali/core/multi_error.h" +#include "dali/core/mm/detail/aux_alloc.h" + +namespace dali { +namespace experimental { + +class ThreadPoolBase; + +/** + * @brief A collection of tasks, ordered by priority + * + * Tasks are added to a job first and then the entire work is scheduled as a whole. + */ +class DLL_PUBLIC Job { + public: + ~Job() noexcept(false); + + using priority_t = int64_t; + + template + std::enable_if_t>> + AddTask(Runnable &&runnable, priority_t priority = {}) { + if (started_) + throw std::logic_error("This job has already been started - cannot add more tasks to it"); + auto it = tasks_.emplace(priority, Task()); + try { + it->second.func = [this, task = &it->second, f = std::move(runnable)]() noexcept { + try { + f(); + } catch (...) { + task->error = std::current_exception(); + } + if (--num_pending_tasks_ == 0) { + std::lock_guard g(mtx_); + cv_.notify_one(); + } + }; + } catch (...) { // if, for whatever reason, we cannot initialize the task, we should erase it + tasks_.erase(it); + throw; + } + } + + template + void Run(Executor &executor, bool wait) { + if (started_) + throw std::logic_error("This job has already been started."); + started_ = true; + for (auto &x : tasks_) { + executor.AddTask(std::move(x.second.func)); + num_pending_tasks_++; // increase after successfully scheduling the task - the value + // may hit 0 or go below if the task is done before we increment + // the counter, but we don't care if we aren't waiting yet + } + if (wait) + Wait(); + } + + void Wait(); + + void Scrap(); + + private: + std::mutex mtx_; // could just probably use atomic_wait on num_pending_tasks_ - needs C++20 + std::condition_variable cv_; + std::atomic_int num_pending_tasks_{0}; + bool started_ = false; + bool waited_for_ = false; + + struct Task { + std::function func; + std::exception_ptr error; + }; + + // This needs to be a container which never invalidates references when inserting new items. + std::multimap, + mm::detail::object_pool_allocator>> tasks_; +}; + +class DLL_PUBLIC ThreadPoolBase { + public: + using TaskFunc = std::function; + + ThreadPoolBase() = default; + explicit ThreadPoolBase(int num_threads) { + Init(num_threads); + } + + void Init(int num_threads); + + ~ThreadPoolBase() { + Shutdown(); + } + + void AddTask(TaskFunc f); + + /** + * @brief Returns the thread pool that owns the calling thread (or nullptr) + */ + static ThreadPoolBase *this_thread_pool() { + return this_thread_pool_; + } + + /** + * @brief Returns the index of the current thread within the current thread pool + * + * @return the thread index or -1 if the calling thread does not belong to a thread pool + */ + static int this_thread_idx() { + return this_thread_idx_; + } + + protected: + void Shutdown(); + + private: + friend class Job; + + virtual void OnThreadStart(int thread_idx) noexcept {} + virtual void OnThreadStop(int thread_idx) noexcept {} + + template + bool WaitOrRunTasks(std::condition_variable &cv, Condition &&condition); + + void PopAndRunTask(std::unique_lock &mtx); + + static thread_local ThreadPoolBase *this_thread_pool_; + static thread_local int this_thread_idx_; + + void Run(int index) noexcept; + + std::mutex mtx_; + std::condition_variable cv_; + bool shutdown_pending_ = false; + std::queue tasks_; + std::vector threads_; +}; + +} // namespace experimental +} // namespace dali + +#endif // DALI_PIPELINE_UTIL_THREAD_POOL_BASE_H_ diff --git a/dali/pipeline/util/new_thread_pool_test.cc b/dali/pipeline/util/thread_pool_base_test.cc similarity index 60% rename from dali/pipeline/util/new_thread_pool_test.cc rename to dali/pipeline/util/thread_pool_base_test.cc index f58148998c1..113fe35e8f6 100644 --- a/dali/pipeline/util/new_thread_pool_test.cc +++ b/dali/pipeline/util/thread_pool_base_test.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,23 +13,24 @@ // limitations under the License. #include -#include "dali/pipeline/util/new_thread_pool.h" +#include "dali/pipeline/util/thread_pool_base.h" +#include "dali/core/format.h" namespace dali { namespace experimental { struct SerialExecutor { template - std::enable_if_t>> + std::enable_if_t>> AddTask(Runnable &&runnable) { - runnable(0); + runnable(); } }; TEST(NewThreadPool, Scrap) { EXPECT_NO_THROW({ Job job; - job.AddTask([](int) {}); + job.AddTask([]() {}); job.Scrap(); }); } @@ -37,7 +38,7 @@ TEST(NewThreadPool, Scrap) { TEST(NewThreadPool, ErrorNotStarted) { try { Job job; - job.AddTask([](int) {}); + job.AddTask([]() {}); } catch (std::logic_error &e) { EXPECT_NE(nullptr, strstr(e.what(), "The job is not empty")); return; @@ -50,13 +51,13 @@ TEST(NewThreadPool, RunJobInSeries) { Job job; SerialExecutor tp; int a = 0, b = 0, c = 0; - job.AddTask([&](int) { + job.AddTask([&]() { a = 1; }); - job.AddTask([&](int) { + job.AddTask([&]() { b = 2; }); - job.AddTask([&](int) { + job.AddTask([&]() { c = 3; }); job.Run(tp, true); @@ -69,13 +70,13 @@ TEST(NewThreadPool, RunJobInThreadPool) { Job job; ThreadPoolBase tp(4); int a = 0, b = 0, c = 0; - job.AddTask([&](int) { + job.AddTask([&]() { a = 1; }); - job.AddTask([&](int) { + job.AddTask([&]() { b = 2; }); - job.AddTask([&](int) { + job.AddTask([&]() { c = 3; }); job.Run(tp, true); @@ -88,18 +89,58 @@ TEST(NewThreadPool, RunJobInThreadPool) { TEST(NewThreadPool, RethrowMultipleErrors) { Job job; ThreadPoolBase tp(4); - job.AddTask([&](int) { + job.AddTask([&]() { throw std::runtime_error("Runtime"); }); - job.AddTask([&](int) { + job.AddTask([&]() { // do nothing }); - job.AddTask([&](int) { + job.AddTask([&]() { throw std::logic_error("Logic"); }); EXPECT_THROW(job.Run(tp, true), MultipleErrors); } +template +void SyncPrint(Args&& ...args) { + static std::mutex mtx; + std::lock_guard guard(mtx); + std::stringstream ss; + print(ss, std::forward(args)...); + auto &&str = ss.str(); + printf("%s", str.c_str()); +} + +TEST(NewThreadPool, Reentrant) { + Job job; + ThreadPoolBase tp(1); // must not hang with just one thread + std::atomic_int outer{0}, inner{0}; + for (int i = 0; i < 10; i++) { + job.AddTask([&, i]() { + outer |= (i << 10); + }); + } + + job.AddTask([&]() { + Job innerJob; + + for (int i = 0; i < 10; i++) + innerJob.AddTask([&, i]() { + inner |= (1 << i); + }); + + innerJob.Run(tp, false); + innerJob.Wait(); + outer |= (1 << 11); + }); + + for (int i = 11; i < 20; i++) { + job.AddTask([&, i]() { + outer |= (1 << i); + }); + } + job.Run(tp, true); +} } // namespace experimental } // namespace dali