diff --git a/include/exec/task.hpp b/include/exec/task.hpp index e94943203..d800318ca 100644 --- a/include/exec/task.hpp +++ b/include/exec/task.hpp @@ -672,7 +672,7 @@ namespace experimental::execution if constexpr (requires { __coro_.promise().stop_requested() ? 0 : 1; }) { if (__coro_.promise().stop_requested()) - return __parent.unhandled_stopped(); + return STDEXEC::__coroutine_unhandled_stopped(__parent); } return __coro_; } diff --git a/include/stdexec/__detail/__as_awaitable.hpp b/include/stdexec/__detail/__as_awaitable.hpp index d82e140a2..436810b04 100644 --- a/include/stdexec/__detail/__as_awaitable.hpp +++ b/include/stdexec/__detail/__as_awaitable.hpp @@ -135,7 +135,7 @@ namespace STDEXEC // as normal. if (__result_.__is_valueless()) { - return __continuation_.unhandled_stopped(); + return STDEXEC::__coroutine_unhandled_stopped(__continuation_); } else { @@ -317,28 +317,30 @@ namespace STDEXEC STDEXEC_CONSTEXPR_CXX23 auto await_suspend([[maybe_unused]] __std::coroutine_handle<> __continuation) noexcept + -> __std::coroutine_handle<> { STDEXEC_ASSERT(this->__continuation_.handle() == __continuation); // Start the operation. STDEXEC::start(__opstate_); - // This exchange({}) is T1's last write to the frame. After this point: - // - If T2 is spinning waiting for our exchange, it will observe {} and - // proceed to resume(). - // - If T2 hasn't run yet, it will see {} from its load in __done() and - // skip the spin entirely + // We need to do two things: + // 1) Check if we already completed inline (receiver wrote {} to __thread_id_) + // In that case, the receiver has already returned and we can just resume the continuation + // 2) Otherwise, we need signal to the (potentially spin-waiting) receiver that we are + // finished and won't access the frame anymore. + // We do this with an exchange({}), except on buggy MSVC versions, where we have to delay + // the signaling until we exited this function. +# if !defined(STDEXEC_MSVC_CORO_DESTROY_BUG_WORKAROUND) bool const __done = // this->__thread_id_.exchange(std::thread::id{}, __std::memory_order_release) == std::thread::id{}; - - // If the receiver already cleared __thread_id_, it completed on the same thread. - // Resume the continuation directly. -# if !defined(STDEXEC_MSVC_CORO_DESTROY_BUG_WORKAROUND) return __done ? this->__get_continuation() : __std::noop_coroutine(); # else - if (__done) - STDEXEC::__coroutine_resume_nothrow(this->__get_continuation()); + bool const __done = // + this->__thread_id_.load(__std::memory_order_relaxed) == std::thread::id{}; + return __done ? this->__get_continuation() + : __coroutine_signal_completion(this->__thread_id_); # endif } diff --git a/include/stdexec/coroutine.hpp b/include/stdexec/coroutine.hpp index 54706cd3c..1ea3759db 100644 --- a/include/stdexec/coroutine.hpp +++ b/include/stdexec/coroutine.hpp @@ -15,12 +15,14 @@ */ #pragma once +#include "__detail/__atomic.hpp" #include "__detail/__awaitable.hpp" // IWYU pragma: export #include "__detail/__concepts.hpp" #include "__detail/__config.hpp" #include "__detail/__utility.hpp" #include +#include #if !STDEXEC_NO_STDCPP_COROUTINES() @@ -224,6 +226,49 @@ namespace STDEXEC {&__destroy_and_continue_frame::__resume}, {}}; + struct __unhandled_stopped_frame : __detail::__synthetic_coro_frame + { + static void __resume(void* __address) noexcept + { + // Make a local copy of the promise since it will go away once we call through + // the __unhandled_stopped_fn_ function pointer. + auto& __self = *static_cast<__unhandled_stopped_frame*>(__address); + STDEXEC::__coroutine_resume_nothrow(__self.__promise_.__coro_.unhandled_stopped()); + } + + struct __promise + { + __coroutine_handle<> __coro_; + } __promise_; + + static thread_local __unhandled_stopped_frame value; + }; + + inline thread_local __unhandled_stopped_frame __unhandled_stopped_frame::value{ + {&__unhandled_stopped_frame::__resume}, + {}}; + + struct __signal_completion_frame : __detail::__synthetic_coro_frame + { + static void __resume(void* __address) noexcept + { + auto& __self = *static_cast<__signal_completion_frame*>(__address); + STDEXEC_ASSERT(__self.__promise_.__thread_id_ != nullptr); + __self.__promise_.__thread_id_->store({}, __std::memory_order_release); + } + + struct __promise + { + __std::atomic* __thread_id_; + } __promise_; + + static thread_local __signal_completion_frame value; + }; + + inline thread_local __signal_completion_frame __signal_completion_frame::value{ + {&__signal_completion_frame::__resume}, + nullptr}; + inline auto __coroutine_destroy_and_continue(__std::coroutine_handle<> __destroy, // __std::coroutine_handle<> __continue) noexcept // -> __std::coroutine_handle<> @@ -233,6 +278,21 @@ namespace STDEXEC return __std::coroutine_handle<>::from_address(&__destroy_and_continue_frame::value); } + inline auto __coroutine_unhandled_stopped(__coroutine_handle<> __coro) noexcept // + -> __std::coroutine_handle<> + { + __unhandled_stopped_frame::value.__promise_.__coro_ = __coro; + return __std::coroutine_handle<>::from_address(&__unhandled_stopped_frame::value); + } + + inline auto + __coroutine_signal_completion(__std::atomic& __thread_id_) noexcept // + -> __std::coroutine_handle<> + { + __signal_completion_frame::value.__promise_.__thread_id_ = &__thread_id_; + return __std::coroutine_handle<>::from_address(&__signal_completion_frame::value); + } + # else STDEXEC_ATTRIBUTE(always_inline) @@ -244,6 +304,20 @@ namespace STDEXEC return __continue; } + STDEXEC_ATTRIBUTE(always_inline) + auto __coroutine_unhandled_stopped(__coroutine_handle<> __coro) noexcept // + -> __std::coroutine_handle<> + { + return __coro.unhandled_stopped(); + } + + STDEXEC_ATTRIBUTE(always_inline) + auto __coroutine_signal_completion(__std::atomic& __thread_id_) noexcept // + -> __std::coroutine_handle<> + { + __thread_id_.store({}, __std::memory_order_release); + return __std::noop_coroutine(); + } # endif // !defined(STDEXEC_MSVC_CORO_DESTROY_BUG_WORKAROUND) } // namespace STDEXEC diff --git a/test/rrd/stdexec_relacy.hpp b/test/rrd/stdexec_relacy.hpp index 9df9e9f1a..7e78bf49b 100644 --- a/test/rrd/stdexec_relacy.hpp +++ b/test/rrd/stdexec_relacy.hpp @@ -2,8 +2,16 @@ #include "../../relacy/relacy_std.hpp" +#include +#include + namespace std { template struct atomic_ref; + + inline std::ostream& operator<<(std::ostream& os, std::thread::id const & id) + { + return os << "thread::id{" << id.id_ << "}"; + } } // namespace std diff --git a/test/stdexec/types/test_task.cpp b/test/stdexec/types/test_task.cpp index 48accaa08..daf8dd072 100644 --- a/test/stdexec/types/test_task.cpp +++ b/test/stdexec/types/test_task.cpp @@ -489,6 +489,41 @@ namespace ex::sync_wait(scope.join()); } + auto await_always_inline_stopped_sender() -> ex::task + { + co_await ex::just_stopped(); + } + + TEST_CASE("repro for NVIDIA/stdexec#2047 always inline", "[types][task]") + { + auto pool = exec::static_thread_pool(1); + + auto scope = ex::counting_scope(); + ex::spawn(ex::starts_on(pool.get_scheduler(), await_always_inline_stopped_sender()) + | ex::upon_error([](auto) noexcept { std::terminate(); }), + scope.get_token()); + ex::sync_wait(scope.join()); + } + + auto await_always_inline_value_sender_loop() -> ex::task + { + for (size_t i = 0; i < 10000; ++i) + { + co_await ex::just(); + } + } + + TEST_CASE("repro for NVIDIA/stdexec#2047 no stack overflow", "[types][task]") + { + auto pool = exec::static_thread_pool(1); + + auto scope = ex::counting_scope(); + ex::spawn(ex::starts_on(pool.get_scheduler(), await_always_inline_value_sender_loop()) + | ex::upon_error([](auto) noexcept { std::terminate(); }), + scope.get_token()); + ex::sync_wait(scope.join()); + } + // TODO: add tests for stop token support in task } // anonymous namespace