Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ jobs:
cd /workspaces/stdexec;
# Configure
cmake -S . -B build -GNinja \
-DSTDEXEC_ENABLE_IO_URING_TESTS=OFF \
-DSTDEXEC_ENABLE_CUDA=ON \
-DCMAKE_CXX_COMPILER="$cxx" \
-DCMAKE_CUDA_COMPILER="$cxx" \
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ if (STDEXEC_ENABLE_TBB)
)
endif ()

option (STDEXEC_ENABLE_IO_URING_TESTS "Enable io_uring tests" ON)

option(STDEXEC_BUILD_EXAMPLES "Build stdexec examples" ON)
option(STDEXEC_BUILD_TESTS "Build stdexec tests" ON)
option(BUILD_TESTING "" ${STDEXEC_BUILD_TESTS})
Expand Down
77 changes: 43 additions & 34 deletions include/exec/linux/io_uring_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ namespace exec {
};

class __scheduler;

enum class until {
stopped,
empty
Expand Down Expand Up @@ -402,26 +402,28 @@ namespace exec {
///
/// This function is not thread-safe and must only be called from the thread that drives the io context.
void run_some() noexcept {
__n_submitted_ -= __completion_queue_.complete();
__n_total_submitted_ -= __completion_queue_.complete();
STDEXEC_ASSERT(
0 <= __n_submitted_
&& __n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
__u32 __max_submissions = __params_.cq_entries - static_cast<__u32>(__n_submitted_);
0 <= __n_total_submitted_
&& __n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
__u32 __max_submissions = __params_.cq_entries - static_cast<__u32>(__n_total_submitted_);
__pending_.append(__requests_.pop_all());
__submission_result __result = __submission_queue_.submit(
(__task_queue&&) __pending_, __max_submissions, __stop_source_->stop_requested());
__n_submitted_ += __result.__n_submitted;
STDEXEC_ASSERT(__n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
__n_total_submitted_ += __result.__n_submitted;
__n_newly_submitted_ += __result.__n_submitted;
STDEXEC_ASSERT(__n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
__pending_ = (__task_queue&&) __result.__pending;
while (!__result.__ready.empty()) {
__n_submitted_ -= __completion_queue_.complete((__task_queue&&) __result.__ready);
STDEXEC_ASSERT(0 <= __n_submitted_);
__n_total_submitted_ -= __completion_queue_.complete((__task_queue&&) __result.__ready);
STDEXEC_ASSERT(0 <= __n_total_submitted_);
__pending_.append(__requests_.pop_all());
__max_submissions = __params_.cq_entries - static_cast<__u32>(__n_submitted_);
__max_submissions = __params_.cq_entries - static_cast<__u32>(__n_total_submitted_);
__result = __submission_queue_.submit(
(__task_queue&&) __pending_, __max_submissions, __stop_source_->stop_requested());
__n_submitted_ += __result.__n_submitted;
STDEXEC_ASSERT(__n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
__n_total_submitted_ += __result.__n_submitted;
__n_newly_submitted_ += __result.__n_submitted;
STDEXEC_ASSERT(__n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
__pending_ = (__task_queue&&) __result.__pending;
}
}
Expand All @@ -446,28 +448,30 @@ namespace exec {
__is_running_.store(false, std::memory_order_relaxed);
}};
__pending_.append(__requests_.pop_all());
while (__n_submitted_ > 0 || !__pending_.empty()) {
while (__n_total_submitted_ > 0 || !__pending_.empty()) {
run_some();
if (
__n_submitted_ == 0
|| (__n_submitted_ == 1 && __break_loop_.load(std::memory_order_acquire))) {
__n_total_submitted_ == 0
|| (__n_total_submitted_ == 1 && __break_loop_.load(std::memory_order_acquire))) {
__break_loop_.store(false, std::memory_order_relaxed);
break;
}
constexpr int __min_complete = 1;
STDEXEC_ASSERT(
0 <= __n_submitted_
&& __n_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
0 <= __n_total_submitted_
&& __n_total_submitted_ <= static_cast<std::ptrdiff_t>(__params_.cq_entries));
int rc = __io_uring_enter(
__ring_fd_, __n_submitted_, __min_complete, IORING_ENTER_GETEVENTS);
__ring_fd_, __n_newly_submitted_, __min_complete, IORING_ENTER_GETEVENTS);
__throw_error_code_if(rc < 0, -rc);
__n_submitted_ -= __completion_queue_.complete();
STDEXEC_ASSERT(0 <= __n_submitted_);
STDEXEC_ASSERT(rc <= __n_newly_submitted_);
__n_newly_submitted_ -= rc;
__n_total_submitted_ -= __completion_queue_.complete();
STDEXEC_ASSERT(0 <= __n_total_submitted_);
__pending_.append(__requests_.pop_all());
}
STDEXEC_ASSERT(__n_submitted_ <= 1);
STDEXEC_ASSERT(__n_total_submitted_ <= 1);
if (__stop_source_->stop_requested() && __pending_.empty()) {
STDEXEC_ASSERT(__n_submitted_ == 0);
STDEXEC_ASSERT(__n_total_submitted_ == 0);
// try to shutdown the request queue
int __n_in_flight_expected = 0;
while (!__n_submissions_in_flight_.compare_exchange_weak(
Expand Down Expand Up @@ -581,7 +585,8 @@ namespace exec {
std::atomic<bool> __is_running_{false};
std::atomic<int> __n_submissions_in_flight_{0};
std::atomic<bool> __break_loop_{false};
std::ptrdiff_t __n_submitted_{0};
std::ptrdiff_t __n_total_submitted_{0};
std::ptrdiff_t __n_newly_submitted_{0};
std::optional<stdexec::in_place_stop_source> __stop_source_{std::in_place};
__completion_queue __completion_queue_;
__submission_queue __submission_queue_;
Expand Down Expand Up @@ -638,11 +643,11 @@ namespace exec {
static constexpr __task_vtable __vtable{&__ready_, &__submit_, &__complete_};

template <class... _Args>
requires stdexec::constructible_from<_Base, std::in_place_t, _Args...>
requires stdexec::constructible_from<_Base, std::in_place_t, __task*, _Args...>
__io_task_facade(std::in_place_t, _Args&&... __args) noexcept(
stdexec::__nothrow_constructible_from<_Base, _Args...>)
stdexec::__nothrow_constructible_from<_Base, __task*, _Args...>)
: __task{__vtable}
, __base_(std::in_place, (_Args&&) __args...) {
, __base_(std::in_place, static_cast<__task*>(this), (_Args&&) __args...) {
}

template <class... _Args>
Expand Down Expand Up @@ -731,8 +736,8 @@ namespace exec {
__op_->submit_stop(__sqe);
} else {
__sqe = ::io_uring_sqe{
.opcode = IORING_OP_ASYNC_CANCEL, //
.addr = bit_cast<__u64>(__op_) //
.opcode = IORING_OP_ASYNC_CANCEL, //
.addr = bit_cast<__u64>(__op_->__parent_) //
};
}
#else
Expand Down Expand Up @@ -768,23 +773,27 @@ namespace exec {

template <class _Base, bool _False>
struct __impl_base {
__task* __parent_;
_Base __base_;

template <class... _Args>
__impl_base(std::in_place_t, _Args&&... __args) noexcept(
__impl_base(__task* __parent, std::in_place_t, _Args&&... __args) noexcept(
stdexec::__nothrow_constructible_from<_Base, _Args...>)
: __base_((_Args&&) __args...) {
: __parent_{__parent}
, __base_((_Args&&) __args...) {
}
};

template <class _Base>
struct __impl_base<_Base, true> {
__task* __parent_;
_Base __base_;

template <class... _Args>
__impl_base(std::in_place_t, _Args&&... __args) noexcept(
__impl_base(__task* __parent, std::in_place_t, _Args&&... __args) noexcept(
stdexec::__nothrow_constructible_from<_Base, _Args...>)
: __base_((_Args&&) __args...) {
: __parent_{__parent}
, __base_((_Args&&) __args...) {
}

void submit_stop(::io_uring_sqe& __sqe) noexcept {
Expand Down Expand Up @@ -823,9 +832,9 @@ namespace exec {

template <class... _Args>
requires stdexec::constructible_from<_Base, _Args...>
__impl(std::in_place_t, _Args&&... __args) noexcept(
__impl(std::in_place_t, __task* __parent, _Args&&... __args) noexcept(
stdexec::__nothrow_constructible_from<_Base, _Args...>)
: __base_t(std::in_place, (_Args&&) __args...)
: __base_t(__parent, std::in_place, (_Args&&) __args...)
, __stop_operation_{this} {
}

Expand Down
2 changes: 1 addition & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ set(stdexec_test_sources
exec/test_when_any.cpp
exec/test_at_coroutine_exit.cpp
exec/test_materialize.cpp
exec/test_io_uring_context.cpp
$<$<BOOL:${STDEXEC_ENABLE_IO_URING_TESTS}>:exec/test_io_uring_context.cpp>
exec/test_trampoline_scheduler.cpp
exec/test_sequence_senders.cpp
exec/sequence/test_empty_sequence.cpp
Expand Down
17 changes: 13 additions & 4 deletions test/exec/test_io_uring_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,14 +301,23 @@ TEST_CASE("io_uring_context schedule_after -1s", "[types][io_uring][schedulers]"
scope_guard guard{[&]() noexcept {
context.request_stop();
}};
bool is_called = false;
bool is_called_1 = false;
bool is_called_2 = false;
auto start = std::chrono::steady_clock::now();
auto timeout = 100ms;
sync_wait(when_any(
schedule_after(scheduler, -1s) | then([&] {
CHECK(io_thread.get_id() == std::this_thread::get_id());
is_called = true;
is_called_1 = true;
}),
schedule_after(scheduler, 5ms)));
CHECK(is_called);
schedule_after(scheduler, timeout) | then([&] {
is_called_2 = true;
})));
auto end = std::chrono::steady_clock::now();
std::chrono::nanoseconds diff = end - start;
CHECK(diff.count() < std::chrono::duration_cast<std::chrono::nanoseconds>(timeout).count());
CHECK(is_called_1 == true);
CHECK(is_called_2 == false);
}
}

Expand Down