Skip to content

Commit

Permalink
[Runtime] Parallel-for with threading backend
Browse files Browse the repository at this point in the history
This PR introduces the runtime parallel-for helper function
in C++ with the threading backend in TVM.

The existing [parallel-for](https://github.com/apache/tvm/blob/bd67d2e5ebde1aec18bcfa74c087516579bda1ae/include/tvm/support/parallel_for.h#L48-L68)
in TVM does not keep its threads alive between calls,
so thread-local storage (TLS) cannot persist for each thread.

The introduced parallel-for-with-threading-backend function
leverages the threading backend in TVM and persists threads.
  • Loading branch information
MasterJH5574 committed Nov 15, 2023
1 parent bd67d2e commit 9969654
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 0 deletions.
69 changes: 69 additions & 0 deletions include/tvm/runtime/threading_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifndef TVM_RUNTIME_THREADING_BACKEND_H_
#define TVM_RUNTIME_THREADING_BACKEND_H_

#include <tvm/runtime/c_backend_api.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <vector>
Expand Down Expand Up @@ -147,6 +149,73 @@ TVM_DLL void Configure(tvm::runtime::threading::ThreadGroup::AffinityMode mode,
int32_t NumThreads();

} // namespace threading

/*!
 * \brief Execute the given lambda function in parallel with
 * threading backend in TVM.
 * \tparam T The type of the lambda: "void (int64_t i)".
 * \param flambda The lambda to be executed in parallel.
 * It is called as "flambda(i)" with a 64-bit index, so a lambda taking
 * "int i" also works as long as the indices fit in an int.
 * \param begin The start index of this parallel loop (inclusive).
 * \param end The end index of this parallel loop (exclusive).
 * \note The index parameters are int64_t so that this declaration matches
 * the definition further down in this header. Declaring them as "int"
 * would introduce a second, never-defined overload that calls with int
 * arguments would select.
 * \example
 *
 * The for loop
 *   for (int i = 0; i < 10; i++) {
 *     a[i] = i;
 *   }
 * should work the same as:
 *   parallel_for_with_threading_backend([&a](int i) {
 *     a[i] = i;
 *   }, 0, 10);
 */
template <typename T>
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end);

namespace detail {

// The detailed implementation of `parallel_for_with_threading_backend`.
// Because these are templates that are instantiated at each call site,
// the implementation has to live in this header and cannot be placed
// in .cc files.

/*!
 * \brief Adapter that exposes a callable of type T through a plain static
 * function taking (task_id, penv, cdata), so it can be handed to
 * TVMBackendParallelLaunch as a function pointer.
 * \tparam T The callable type; it is invoked as "(int task_id, int num_task)".
 */
template <typename T>
struct ParallelForWithThreadingBackendLambdaInvoker {
  /*!
   * \brief Recover the callable from `cdata` and run it for one task.
   * \param task_id The id of the task being executed.
   * \param penv The parallel group environment; only `num_task` is read.
   * \param cdata Type-erased pointer to the callable of type T.
   * \return Always 0.
   */
  static int TVMParallelLambdaInvoke(int task_id, TVMParallelGroupEnv* penv, void* cdata) {
    const int task_count = penv->num_task;
    // `cdata` is the address of the callable, so cast it back to T.
    T& callable = *static_cast<T*>(cdata);
    callable(task_id, task_count);
    return 0;
  }
};

template <typename T>
inline void parallel_launch_with_threading_backend(T flambda) {
// Launch the lambda by passing its address.
void* cdata = &flambda;
TVMBackendParallelLaunch(ParallelForWithThreadingBackendLambdaInvoker<T>::TVMParallelLambdaInvoke,
cdata, /*num_task=*/0);
}

} // namespace detail

/*!
 * \brief Run "flambda(i)" for every i in [begin, end), splitting the range
 * statically across the tasks of the threading backend.
 * \tparam T The type of the lambda; it is called with an int64_t index.
 * \param flambda The lambda to be executed for each index.
 * \param begin The start index of the loop (inclusive).
 * \param end The end index of the loop (exclusive).
 * \note An empty or reversed range (end <= begin) returns immediately
 * without launching anything; no index would be visited in that case.
 */
template <typename T>
inline void parallel_for_with_threading_backend(T flambda, int64_t begin, int64_t end) {
  // Nothing to iterate over: skip the launch entirely.
  if (end <= begin) {
    return;
  }
  auto flaunch = [begin, end, flambda](int task_id, int num_task) {
    // For each thread, do static division and call into flambda.
    // Every task gets ceil(total_len / num_task) consecutive indices.
    const int64_t total_len = end - begin;
    const int64_t step = (total_len + num_task - 1) / num_task;
    // Clamp to `end` so trailing tasks get a shorter (possibly empty) chunk.
    const int64_t local_begin = std::min(begin + step * task_id, end);
    const int64_t local_end = std::min(local_begin + step, end);
    for (int64_t i = local_begin; i < local_end; ++i) {
      flambda(i);
    }
  };
  // Launch with all threads.
  detail::parallel_launch_with_threading_backend(flaunch);
}

} // namespace runtime
} // namespace tvm

Expand Down
9 changes: 9 additions & 0 deletions tests/cpp/threading_backend_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,12 @@ TEST(ThreadingBackend, TVMBackendAffinityConfigure) {
t->join();
}
}

// Checks that the parallel-for helper visits every index in [0, n):
// each index writes its own value into its own slot, and the result is
// verified after the call returns.
TEST(ThreadingBackend, TVMBackendParallelForWithThreadingBackend) {
  int num_elems = 100;
  std::vector<int> data(/*size=*/num_elems, /*value=*/0);
  // Every invocation touches a distinct element of `data`.
  auto fill_slot = [&data](int idx) { data[idx] = idx; };
  tvm::runtime::parallel_for_with_threading_backend(fill_slot, 0, num_elems);
  for (int idx = 0; idx < num_elems; ++idx) {
    EXPECT_EQ(data[idx], idx);
  }
}

0 comments on commit 9969654

Please sign in to comment.