From ad8fa41afabdadf1ad0f3987881060dd1988e56a Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 25 Apr 2018 23:25:53 +0200 Subject: [PATCH] ARROW-2479: [C++] Add ThreadPool class --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/util/CMakeLists.txt | 1 + cpp/src/arrow/util/thread-pool-test.cc | 275 +++++++++++++++++++++++++ cpp/src/arrow/util/thread-pool.cc | 165 +++++++++++++++ cpp/src/arrow/util/thread-pool.h | 139 +++++++++++++ 5 files changed, 581 insertions(+) create mode 100644 cpp/src/arrow/util/thread-pool-test.cc create mode 100644 cpp/src/arrow/util/thread-pool.cc create mode 100644 cpp/src/arrow/util/thread-pool.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index a3997c73737f1..aa78ed2e87a69 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -42,6 +42,7 @@ set(ARROW_SRCS util/hash.cc util/io-util.cc util/key_value_metadata.cc + util/thread-pool.cc ) if ("${COMPILER_FAMILY}" STREQUAL "clang") diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index 36cbd6de295c5..d309b2b98f564 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -59,6 +59,7 @@ ADD_ARROW_TEST(decimal-test) ADD_ARROW_TEST(key-value-metadata-test) ADD_ARROW_TEST(rle-encoding-test) ADD_ARROW_TEST(stl-util-test) +ADD_ARROW_TEST(thread-pool-test) ADD_ARROW_BENCHMARK(bit-util-benchmark) ADD_ARROW_BENCHMARK(decimal-benchmark) diff --git a/cpp/src/arrow/util/thread-pool-test.cc b/cpp/src/arrow/util/thread-pool-test.cc new file mode 100644 index 0000000000000..2b67abf3f0032 --- /dev/null +++ b/cpp/src/arrow/util/thread-pool-test.cc @@ -0,0 +1,275 @@ +// // Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include + +#include "arrow/test-util.h" +#include "arrow/util/macros.h" +#include "arrow/util/thread-pool.h" + +namespace arrow { +namespace internal { + +static void sleep_for(double seconds) { + std::this_thread::sleep_for( + std::chrono::nanoseconds(static_cast(seconds * 1e9))); +} + +static void busy_wait(double seconds, std::function predicate) { + const double period = 0.001; + for (int i = 0; !predicate() && i * period < seconds; ++i) { + sleep_for(period); + } +} + +template +static void task_add(T x, T y, T* out) { + *out = x + y; +} + +template +static void task_slow_add(double seconds, T x, T y, T* out) { + sleep_for(seconds); + *out = x + y; +} + +typedef std::function AddTaskFunc; + +template +static T add(T x, T y) { + return x + y; +} + +template +static T slow_add(double seconds, T x, T y) { + sleep_for(seconds); + return x + y; +} + +template +static T inplace_add(T& x, T y) { + return x += y; +} + +// A class to spawn "add" tasks to a pool and check the results when done + +class AddTester { + public: + explicit AddTester(int nadds) : nadds(nadds), xs(nadds), ys(nadds), outs(nadds, -1) { + int x = 0, y = 0; + std::generate(xs.begin(), xs.end(), [&] { + ++x; + return x; + }); + std::generate(ys.begin(), ys.end(), [&] { + y += 10; + return y; + }); + } + + AddTester(AddTester&&) = default; + + void SpawnTasks(ThreadPool* pool, AddTaskFunc add_func) { + for (int i = 0; i < nadds; ++i) { + ASSERT_OK(pool->Spawn([=] { add_func(xs[i], ys[i], &outs[i]); })); + } + } + + void CheckResults() { + for (int i = 0; i < nadds; ++i) { + ASSERT_EQ(outs[i], (i + 1) * 11); + } + } + + void CheckNotAllComputed() { + for (int i = 0; i < nadds; ++i) { + if (outs[i] == -1) { + return; + } + } + ASSERT_TRUE(0) << "all values were computed"; + } + + private: + ARROW_DISALLOW_COPY_AND_ASSIGN(AddTester); + + int nadds; + std::vector xs; + std::vector ys; + std::vector outs; +}; + +class TestThreadPool : public ::testing::Test { + public: + void TearDown() { + fflush(stdout); + fflush(stderr); + } + + std::shared_ptr MakeThreadPool() { return MakeThreadPool(4); } + + std::shared_ptr MakeThreadPool(size_t threads) { + std::shared_ptr pool; + Status st = ThreadPool::Make(threads, &pool); + return pool; + } + + void SpawnAdds(ThreadPool* pool, int nadds, AddTaskFunc add_func) { + AddTester add_tester(nadds); + add_tester.SpawnTasks(pool, add_func); + ASSERT_OK(pool->Shutdown()); + add_tester.CheckResults(); + } + + void SpawnAddsThreaded(ThreadPool* pool, int nthreads, int nadds, + AddTaskFunc add_func) { + // Same as SpawnAdds, but do the task spawning from multiple threads + std::vector add_testers; + std::vector threads; + for (int i = 0; i < nthreads; ++i) { + add_testers.emplace_back(nadds); + } + for (auto& add_tester : add_testers) { + threads.emplace_back([&] { add_tester.SpawnTasks(pool, add_func); }); + } + for (auto& thread : threads) { + thread.join(); + } + ASSERT_OK(pool->Shutdown()); + for (auto& add_tester : add_testers) { + add_tester.CheckResults(); + } + } +}; + +TEST_F(TestThreadPool, ConstructDestruct) { + // Stress shutdown-at-destruction logic + for (size_t threads : {1, 2, 3, 8, 32, 70}) { + auto pool = this->MakeThreadPool(threads); + } +} + +// Correctness and stress tests using Spawn() and Shutdown() + +TEST_F(TestThreadPool, Spawn) { + auto pool = this->MakeThreadPool(3); + SpawnAdds(pool.get(), 7, task_add); +} + +TEST_F(TestThreadPool, StressSpawn) { + auto pool = this->MakeThreadPool(30); + SpawnAdds(pool.get(), 1000, task_add); +} + +TEST_F(TestThreadPool, StressSpawnThreaded) { + auto pool = this->MakeThreadPool(30); + SpawnAddsThreaded(pool.get(), 20, 100, task_add); +} + +TEST_F(TestThreadPool, SpawnSlow) { + // This checks that Shutdown() waits for all tasks to finish + auto pool = this->MakeThreadPool(2); + SpawnAdds(pool.get(), 7, [](int x, int y, int* out) { + return task_slow_add(0.02 /* seconds */, x, y, out); + }); +} + +TEST_F(TestThreadPool, StressSpawnSlow) { + auto pool = this->MakeThreadPool(30); + SpawnAdds(pool.get(), 1000, [](int x, int y, int* out) { + return task_slow_add(0.002 /* seconds */, x, y, out); + }); +} + +TEST_F(TestThreadPool, StressSpawnSlowThreaded) { + auto pool = this->MakeThreadPool(30); + SpawnAddsThreaded(pool.get(), 20, 100, [](int x, int y, int* out) { + return task_slow_add(0.002 /* seconds */, x, y, out); + }); +} + +TEST_F(TestThreadPool, QuickShutdown) { + AddTester add_tester(100); + { + auto pool = this->MakeThreadPool(3); + add_tester.SpawnTasks(pool.get(), [](int x, int y, int* out) { + return task_slow_add(0.02 /* seconds */, x, y, out); + }); + ASSERT_OK(pool->Shutdown(false /* wait */)); + add_tester.CheckNotAllComputed(); + } + add_tester.CheckNotAllComputed(); +} + +TEST_F(TestThreadPool, SetCapacity) { + auto pool = this->MakeThreadPool(3); + ASSERT_EQ(pool->GetCapacity(), 3); + ASSERT_OK(pool->SetCapacity(5)); + ASSERT_EQ(pool->GetCapacity(), 5); + ASSERT_OK(pool->SetCapacity(2)); + // Wait for workers to wake up and secede + busy_wait(0.5, [&] { return pool->GetCapacity() == 2; }); + ASSERT_EQ(pool->GetCapacity(), 2); + ASSERT_OK(pool->SetCapacity(5)); + ASSERT_EQ(pool->GetCapacity(), 5); + // Downsize while tasks are pending + for (int i = 0; i < 10; ++i) { + ASSERT_OK(pool->Spawn(std::bind(sleep_for, 0.01 /* seconds */))); + } + ASSERT_OK(pool->SetCapacity(2)); + busy_wait(0.5, [&] { return pool->GetCapacity() == 2; }); + ASSERT_EQ(pool->GetCapacity(), 2); + // Ensure nothing got stuck + ASSERT_OK(pool->Shutdown()); +} + +// Test Submit() functionality + +TEST_F(TestThreadPool, Submit) { + auto pool = this->MakeThreadPool(3); + { + auto fut = pool->Submit(add, 4, 5); + ASSERT_EQ(fut.get(), 9); + } + { + auto fut = pool->Submit(add, "foo", "bar"); + ASSERT_EQ(fut.get(), "foobar"); + } + { + auto fut = pool->Submit(slow_add, 0.01 /* seconds */, 4, 5); + ASSERT_EQ(fut.get(), 9); + } + { + // Reference passing + std::string s = "foo"; + auto fut = pool->Submit(inplace_add, std::ref(s), "bar"); + ASSERT_EQ(fut.get(), "foobar"); + ASSERT_EQ(s, "foobar"); + } + { + // `void` return type + auto fut = pool->Submit(sleep_for, 0.001); + fut.get(); + } +} + +} // namespace internal +} // namespace arrow diff --git a/cpp/src/arrow/util/thread-pool.cc b/cpp/src/arrow/util/thread-pool.cc new file mode 100644 index 0000000000000..f698f8f0e8025 --- /dev/null +++ b/cpp/src/arrow/util/thread-pool.cc @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/util/thread-pool.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace internal { + +ThreadPool::ThreadPool() + : desired_capacity_(0), please_shutdown_(false), quick_shutdown_(false) {} + +ThreadPool::~ThreadPool() { ARROW_UNUSED(Shutdown(false /* wait */)); } + +Status ThreadPool::SetCapacity(size_t threads) { + std::unique_lock lock(mutex_); + if (please_shutdown_) { + return Status::Invalid("operation forbidden during or after shutdown"); + } + if (threads <= 0) { + return Status::Invalid("ThreadPool capacity must be > 0"); + } + CollectFinishedWorkersUnlocked(); + + desired_capacity_ = threads; + int64_t diff = desired_capacity_ - workers_.size(); + if (diff > 0) { + LaunchWorkersUnlocked(static_cast(diff)); + } else if (diff < 0) { + // Wake threads to ask them to stop + cv_.notify_all(); + } + return Status::OK(); +} + +size_t ThreadPool::GetCapacity() { + std::unique_lock lock(mutex_); + return workers_.size(); +} + +Status ThreadPool::Shutdown(bool wait) { + std::unique_lock lock(mutex_); + + if (please_shutdown_) { + return Status::Invalid("Shutdown() already called"); + } + please_shutdown_ = true; + quick_shutdown_ = !wait; + cv_.notify_all(); + cv_shutdown_.wait(lock, [this] { return workers_.empty(); }); + if (!quick_shutdown_) { + DCHECK_EQ(pending_tasks_.size(), 0); + } else { + pending_tasks_.clear(); + } + CollectFinishedWorkersUnlocked(); + return Status::OK(); +} + +void ThreadPool::CollectFinishedWorkersUnlocked() { + for (auto& thread : finished_workers_) { + thread.join(); + } + finished_workers_.clear(); +} + +void ThreadPool::LaunchWorkersUnlocked(size_t threads) { + for (size_t i = 0; i < threads; i++) { + workers_.emplace_back(); + auto it = --workers_.end(); + *it = std::thread([this, it] { WorkerLoop(it); }); + } +} + +void ThreadPool::WorkerLoop(std::list::iterator it) { + std::unique_lock lock(mutex_); + + // Since we hold the lock, `it` now points to the correct thread object + // (LaunchWorkersUnlocked has exited) + DCHECK_EQ(std::this_thread::get_id(), it->get_id()); + + while (true) { + // Logic detail: by the time this thread is started, some tasks + // may have been pushed or shutdown could even have been requested. + // So we only wait on the condition variable at the end of the loop. + + // Execute pending tasks if any + while (!pending_tasks_.empty() && !quick_shutdown_) { + // If too many threads, secede from the pool. + // We check this opportunistically at each loop iteration since + // it releases the lock below. + if (workers_.size() > desired_capacity_) { + break; + } + { + std::function task = std::move(pending_tasks_.front()); + pending_tasks_.pop_front(); + lock.unlock(); + task(); + } + lock.lock(); + } + // Now either the queue is empty *or* a quick shutdown was requested + if (please_shutdown_ || workers_.size() > desired_capacity_) { + break; + } + // Wait for next wakeup + cv_.wait(lock); + } + + // We're done. Move our thread object to the trashcan of finished + // workers. This has two motivations: + // 1) the thread object doesn't get destroyed before this function finishes + // (but we could call thread::detach() instead) + // 2) we can explicitly join() the trashcan threads to make sure all OS threads + // are exited before the ThreadPool is destroyed. Otherwise subtle + // timing conditions can lead to false positives with Valgrind. + // + // It's important that we keep the lock until the end of the function, + // so that ~ThreadPool() cannot finish and destroy `this` before. + DCHECK_EQ(std::this_thread::get_id(), it->get_id()); + finished_workers_.push_back(std::move(*it)); + workers_.erase(it); + if (please_shutdown_) { + // Notify the function waiting in Shutdown(). + cv_shutdown_.notify_one(); + } +} + +Status ThreadPool::SpawnReal(std::function task) { + { + std::lock_guard lock(mutex_); + if (please_shutdown_) { + return Status::Invalid("operation forbidden during or after shutdown"); + } + CollectFinishedWorkersUnlocked(); + pending_tasks_.push_back(std::move(task)); + } + cv_.notify_one(); + return Status::OK(); +} + +Status ThreadPool::Make(size_t threads, std::shared_ptr* out) { + auto pool = std::shared_ptr(new ThreadPool()); + RETURN_NOT_OK(pool->SetCapacity(threads)); + *out = std::move(pool); + return Status::OK(); +} + +} // namespace internal +} // namespace arrow diff --git a/cpp/src/arrow/util/thread-pool.h b/cpp/src/arrow/util/thread-pool.h new file mode 100644 index 0000000000000..4a38427f67111 --- /dev/null +++ b/cpp/src/arrow/util/thread-pool.h @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef ARROW_UTIL_THREAD_POOL_H +#define ARROW_UTIL_THREAD_POOL_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/status.h" +#include "arrow/util/macros.h" + +namespace arrow { +namespace internal { + +namespace detail { + +// Needed before std::packaged_task is not copyable and hence not convertible +// to std::function. +template +struct packaged_task_wrapper { + using PackagedTask = std::packaged_task; + + explicit packaged_task_wrapper(PackagedTask&& task) + : task_(std::make_shared(std::forward(task))) {} + + void operator()(Args&&... args) { return (*task_)(std::forward(args)...); } + std::shared_ptr task_; +}; + +} // namespace detail + +class ThreadPool { + public: + // Construct a thread pool with the given number of worker threads + static Status Make(size_t threads, std::shared_ptr* out); + + // Destroy thread pool; the pool will first be shut down + ~ThreadPool(); + + // Dynamically change the number of worker threads. + // This function returns quickly, but it may take more time before the + // thread count is fully adjusted. + Status SetCapacity(size_t threads); + + // Shutdown the pool. Once the pool starts shutting down, new tasks + // cannot be submitted anymore. + // If "wait" is true, shutdown waits for all pending tasks to be finished. + // If "wait" is false, workers are stopped as soon as currently executing + // tasks are finished. + Status Shutdown(bool wait = true); + + // Spawn a fire-and-forget task on one of the workers. + template + Status Spawn(Function&& func) { + return SpawnReal(std::forward(func)); + } + + // Submit a callable and arguments for execution. Return a future that + // will return the callable's result value once. + // The callable's arguments are copied before execution. + // Since the function is variadic and needs to return a result (the future), + // an exception is raised if the task fails spawning (which currently + // only occurs if the ThreadPool is shutting down). + template ::type> + std::future Submit(Function&& func, Args&&... args) { + // Trying to templatize std::packaged_task with Function doesn't seem + // to work, so go through std::bind to simplify the packaged signature + using PackagedTask = std::packaged_task; + auto task = PackagedTask(std::bind(std::forward(func), args...)); + auto fut = task.get_future(); + + Status st = SpawnReal(detail::packaged_task_wrapper(std::move(task))); + if (!st.ok()) { + throw std::runtime_error(st.ToString()); + } + return fut; + } + + protected: + FRIEND_TEST(TestThreadPool, SetCapacity); + + ThreadPool(); + + ARROW_DISALLOW_COPY_AND_ASSIGN(ThreadPool); + + Status SpawnReal(std::function task); + // Collect finished worker threads, making sure the OS threads have exited + void CollectFinishedWorkersUnlocked(); + // Launch a given number of additional workers + void LaunchWorkersUnlocked(size_t threads); + void WorkerLoop(std::list::iterator it); + size_t GetCapacity(); + + std::mutex mutex_; + std::condition_variable cv_; + std::condition_variable cv_shutdown_; + + std::list workers_; + // Trashcan for finished threads + std::vector finished_workers_; + std::deque> pending_tasks_; + + // Desired number of threads + size_t desired_capacity_; + // Are we shutting down? + bool please_shutdown_; + bool quick_shutdown_; +}; + +} // namespace internal +} // namespace arrow + +#endif // ARROW_UTIL_THREAD_POOL_H