Commit

Merge pull request #7 from DeveloperPaul123/feature/thread-load-distribution
DeveloperPaul123 authored Feb 24, 2022
2 parents c1b91fd + a519129 commit 97578b0
Showing 4 changed files with 82 additions and 49 deletions.
57 changes: 26 additions & 31 deletions include/thread_pool/thread_pool.h
@@ -35,28 +35,26 @@ namespace dp {
std::is_same_v<void, std::invoke_result_t<FunctionType>>
class thread_pool {
public:
- thread_pool(const unsigned int &number_of_threads = std::thread::hardware_concurrency())
- : queues_(number_of_threads) {
+ thread_pool(const unsigned int &number_of_threads = std::thread::hardware_concurrency()) {
for (std::size_t i = 0; i < number_of_threads; ++i) {
- threads_.emplace_back([&, id = i](std::stop_token stop_tok) {
+ threads_.emplace_back([&](const std::stop_token stop_tok) {
do {
// check if we have task
- if (queues_[id].tasks.empty()) {
+ if (queue_.empty()) {
// no tasks, so we wait instead of spinning
- queues_[id].semaphore.acquire();
+ std::unique_lock lock(condition_mutex_);
+ condition_.wait(lock, stop_tok, [this]() { return !queue_.empty(); });
}

// ensure we have a task before getting task
- // since the dtor releases the semaphore as well
- if (!queues_[id].tasks.empty()) {
+ // since the dtor notifies via the condition variable as well
+ if (!queue_.empty()) {
// get the task
- auto &task = queues_[id].tasks.front();
+ auto task = queue_.pop();
// invoke the task
std::invoke(std::move(task));
// decrement in-flight counter
--in_flight_;
- // remove task from the queue
- queues_[id].tasks.pop();
}
} while (!stop_tok.stop_requested());
});
@@ -70,11 +68,10 @@ namespace dp {
} while (in_flight_ > 0);

// stop all threads
- for (std::size_t i = 0; i < threads_.size(); ++i) {
- threads_[i].request_stop();
- queues_[i].semaphore.release();
- threads_[i].join();
+ for (auto &thread : threads_) {
+ thread.request_stop();
}
+ condition_.notify_all();
}

/// thread pool is non-copyable
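The two hunks above swap the per-thread binary semaphores for a single std::condition_variable_any that cooperates with each std::jthread's stop token: workers park in a stop-aware wait, and the destructor only has to request a stop and notify. A minimal standalone sketch of that wait/stop pattern (not the library's code; the plain int queue and the names are placeholders for illustration):

```cpp
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <stop_token>
#include <thread>

int main() {
    std::queue<int> work;            // stand-in for the pool's task queue
    std::mutex mtx;
    std::condition_variable_any cv;  // the _any variant can wait on a std::stop_token

    std::jthread worker([&](const std::stop_token stop_tok) {
        do {
            std::unique_lock lock(mtx);
            // Returns when the predicate holds OR when a stop is requested,
            // which is how request_stop() can unpark an idle worker without a semaphore.
            cv.wait(lock, stop_tok, [&] { return !work.empty(); });
            while (!work.empty()) {
                std::cout << "processing " << work.front() << '\n';
                work.pop();
            }
        } while (!stop_tok.stop_requested());
    });

    {
        std::lock_guard lock(mtx);
        for (int i = 0; i < 3; ++i) work.push(i);
    }
    cv.notify_all();
}   // ~jthread(): request_stop() then join(); the stop-aware wait unblocks
```

Because the wait takes the stop token, request_stop() by itself is enough to wake a parked worker; the explicit notify_all() in the destructor is additional insurance.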
@@ -83,11 +80,11 @@ namespace dp {

/**
* @brief Enqueue a task into the thread pool that returns a result.
- * @tparam Function An invocable type.
- * @tparam ...Args Argument parameter pack
+ * @tparam Function An invokable type.
+ * @tparam Args Argument parameter pack
* @tparam ReturnType The return type of the Function
* @param f The callable function
- * @param ...args The parameters that will be passed (copied) to the function.
+ * @param args The parameters that will be passed (copied) to the function.
* @return A std::future<ReturnType> that can be used to retrieve the returned value.
*/
template <typename Function, typename... Args,
@@ -101,7 +98,7 @@ namespace dp {
*
* std::promise<ReturnType> promise;
* auto future = promise.get_future();
- * auto task = [func = std::move(f), ... largs = std::move(args),
+ * auto task = [func = std::move(f), ...largs = std::move(args),
promise = std::move(promise)]() mutable {...};
*/
auto shared_promise = std::make_shared<std::promise<ReturnType>>();
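The comment above shows the promise/future wiring that enqueue() builds for each task. A rough standalone version of the same wrapping pattern (the function name wrap_as_task and the pair return are illustrative, not the library's API); holding the promise through a shared_ptr keeps the resulting task copyable even though std::promise itself is move-only:

```cpp
#include <future>
#include <iostream>
#include <memory>
#include <type_traits>
#include <utility>

// Wrap a callable and its arguments into a void() task plus a future for the
// result. Assumes a non-void result; a real implementation would also route
// exceptions into the promise.
template <typename Function, typename... Args,
          typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
auto wrap_as_task(Function f, Args... args) {
    auto shared_promise = std::make_shared<std::promise<ReturnType>>();
    auto future = shared_promise->get_future();
    auto task = [f = std::move(f), ... largs = std::move(args),
                 promise = shared_promise]() mutable {
        promise->set_value(f(largs...));
    };
    return std::pair{std::move(task), std::move(future)};
}

int main() {
    auto [task, future] = wrap_as_task([](int a, int b) { return a + b; }, 2, 3);
    task();                              // normally run by a pool thread
    std::cout << future.get() << '\n';   // prints 5
}
```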
@@ -117,10 +114,10 @@ namespace dp {

/**
* @brief Enqueue a task to be executed in the thread pool that returns void.
- * @tparam Function An invocable type.
- * @tparam ...Args Argument parameter pack for Function
+ * @tparam Function An invokable type.
+ * @tparam Args Argument parameter pack for Function
* @param func The callable to be executed
- * @param ...args Arguments that wiill be passed to the function.
+ * @param args Arguments that will be passed to the function.
*/
template <typename Function, typename... Args>
requires std::invocable<Function, Args...> &&
@@ -130,22 +127,20 @@ namespace dp {
}

private:
- struct task_queue {
- std::binary_semaphore semaphore{0};
- dp::thread_safe_queue<FunctionType> tasks{};
- };
-
template <typename Function>
void enqueue_task(Function &&f) {
- const std::size_t i = count_++ % queues_.size();
++in_flight_;
- queues_[i].tasks.push(std::forward<Function>(f));
- queues_[i].semaphore.release();
+ {
+ std::lock_guard lock(condition_mutex_);
+ queue_.push(std::forward<Function>(f));
+ }
+ condition_.notify_all();
}

+ std::condition_variable_any condition_;
+ std::mutex condition_mutex_;
std::vector<std::jthread> threads_;
- std::deque<task_queue> queues_;
- std::size_t count_ = 0;
+ dp::thread_safe_queue<FunctionType> queue_;
std::atomic<int64_t> in_flight_{0};
};

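None of this changes the pool's public surface: enqueue() still returns a std::future and enqueue_detach() stays fire-and-forget. A short usage sketch, assuming only the signatures documented above:

```cpp
#include <thread_pool/thread_pool.h>

#include <iostream>
#include <string>

int main() {
    dp::thread_pool pool(2);

    // returns a std::future<int> fulfilled by a worker thread
    auto sum = pool.enqueue([](int a, int b) { return a + b; }, 40, 2);

    // fire-and-forget; the argument is copied into the task
    pool.enqueue_detach([](const std::string &msg) { std::cout << msg << '\n'; },
                        std::string{"detached task ran"});

    std::cout << "sum: " << sum.get() << '\n';
}   // ~thread_pool waits for in-flight work, then stops and joins the workers
```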
28 changes: 15 additions & 13 deletions include/thread_pool/thread_safe_queue.h
@@ -31,22 +31,24 @@ namespace dp {
return data_.size();
}

- [[nodiscard]] T& front() {
- std::unique_lock lock(mutex_);
- condition_variable_.wait(lock, [this] { return !data_.empty(); });
- return data_.front();
- }
-
- [[nodiscard]] T& back() {
- std::unique_lock lock(mutex_);
- condition_variable_.wait(lock, [this] { return !data_.empty(); });
- return data_.back();
- }
-
- void pop() {
+ // [[nodiscard]] T& front() {
+ // std::unique_lock lock(mutex_);
+ // condition_variable_.wait(lock, [this] { return !data_.empty(); });
+ // return data_.front();
+ // }
+ //
+ // [[nodiscard]] T& back() {
+ // std::unique_lock lock(mutex_);
+ // condition_variable_.wait(lock, [this] { return !data_.empty(); });
+ // return data_.back();
+ // }
+
+ [[nodiscard]] T pop() {
std::unique_lock lock(mutex_);
condition_variable_.wait(lock, [this] { return !data_.empty(); });
+ auto front = data_.front();
data_.pop_front();
+ return front;
}

private:
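Retiring front()/back() in favour of a value-returning pop() means waiting for an element and removing it happen under one lock acquisition, so no reference into the underlying deque escapes and no second caller can race in between. A small producer/consumer sketch against that interface (it assumes the queue is default-constructible and exposes push()/pop() as exercised elsewhere in this diff):

```cpp
#include <thread_pool/thread_safe_queue.h>

#include <iostream>
#include <thread>

int main() {
    dp::thread_safe_queue<int> queue;

    // The consumer blocks inside pop() until each item becomes available and
    // receives it by value, so nothing dangles once the lock is released.
    std::jthread consumer([&queue] {
        for (int i = 0; i < 3; ++i) {
            std::cout << "popped " << queue.pop() << '\n';
        }
    });

    for (int i = 0; i < 3; ++i) {
        queue.push(i);
    }
}   // the consumer pops exactly three items, so the jthread joins cleanly
```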
44 changes: 40 additions & 4 deletions test/source/thread_pool.cpp
@@ -1,6 +1,7 @@
#include <doctest/doctest.h>
#include <thread_pool/thread_pool.h>

+ #include <iostream>
#include <string>

TEST_CASE("Basic task return types") {
@@ -16,7 +17,7 @@ TEST_CASE("Basic task return types") {

TEST_CASE("Ensure input params are properly passed") {
dp::thread_pool pool(4);
- const auto total_tasks = 30;
+ constexpr auto total_tasks = 30;
std::vector<std::future<int>> futures;

for (auto i = 0; i < total_tasks; i++) {
@@ -32,13 +33,13 @@ TEST_CASE("Ensure input params are properly passed") {

TEST_CASE("Ensure work completes upon destruction") {
std::atomic<int> counter;
- std::vector<std::future<int>> futures;
- const auto total_tasks = 20;
+ constexpr auto total_tasks = 20;
{
+ std::vector<std::future<int>> futures;
dp::thread_pool pool(4);
for (auto i = 0; i < total_tasks; i++) {
auto task = [index = i, &counter]() {
- counter++;
+ ++counter;
return index;
};
futures.push_back(pool.enqueue(task));
@@ -47,3 +48,38 @@

CHECK_EQ(counter.load(), total_tasks);
}

TEST_CASE("Ensure task load is spread evenly across threads") {
auto delay_task = [](const std::chrono::seconds& seconds) {
std::this_thread::sleep_for(seconds);
};
constexpr auto long_task_time = 6;
const auto start_time = std::chrono::steady_clock::now();
{
dp::thread_pool pool(4);
for (auto i = 1; i <= 8; ++i) {
auto delay_amount = std::chrono::seconds(i % 4);
if (i % 4 == 0) {
delay_amount = std::chrono::seconds(long_task_time);
}
std::cout << std::to_string(delay_amount.count()) << "\n";
pool.enqueue_detach(delay_task, delay_amount);
}
// wait for tasks to complete
}
const auto end_time = std::chrono::steady_clock::now();
const auto duration = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time);

/**
* Potential execution graph
* '-' and '*' represent task time.
* '-' is the first round of tasks and '*' is the second round of tasks
*
* - * **
* -- ***
* --- ******
* ------
*/
CHECK_LE(duration.count(), 9);
std::cout << "total duration: " << duration.count() << "\n";
}
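The 9-second bound is exactly where the old round-robin distribution (const std::size_t i = count_++ % queues_.size()) and the shared queue part ways: round-robin hands both 6-second tasks to the same thread (6 + 6 = 12 s), while a shared queue behaves roughly like greedy list scheduling and finishes this workload in about 9 s. A tiny standalone sketch of that arithmetic (not part of the test suite):

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <iostream>

int main() {
    // Task durations used by the test: i % 4 seconds, with every 4th task 6 s.
    constexpr std::array<int, 8> delays{1, 2, 3, 6, 1, 2, 3, 6};

    // Old behaviour: round-robin over 4 per-thread queues. The busiest thread
    // ends up with both 6 s tasks, so the pool takes ~12 s.
    std::array<int, 4> round_robin{};
    for (std::size_t i = 0; i < delays.size(); ++i) round_robin[i % 4] += delays[i];
    std::cout << "round-robin makespan: "
              << *std::max_element(round_robin.begin(), round_robin.end()) << " s\n";

    // New behaviour: a shared queue approximates greedy list scheduling; each
    // task in order goes to whichever thread frees up first, giving ~9 s.
    std::array<int, 4> shared{};
    for (int d : delays) *std::min_element(shared.begin(), shared.end()) += d;
    std::cout << "shared-queue makespan: "
              << *std::max_element(shared.begin(), shared.end()) << " s\n";
}
```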
2 changes: 1 addition & 1 deletion test/source/thread_safe_queue.cpp
@@ -9,7 +9,7 @@ TEST_CASE("Check size while waiting for front") {
int item = 0;
std::jthread wait_for_item_thread([&item, &queue]() {
// this will block until an item becomes available
- item = queue.front();
+ item = queue.pop();
});

if (queue.empty()) {
