Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/exec/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ namespace experimental::execution
if constexpr (requires { __coro_.promise().stop_requested() ? 0 : 1; })
{
if (__coro_.promise().stop_requested())
return __parent.unhandled_stopped();
return STDEXEC::__coroutine_unhandled_stopped(__parent);
}
return __coro_;
}
Expand Down
26 changes: 14 additions & 12 deletions include/stdexec/__detail/__as_awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ namespace STDEXEC
// as normal.
if (__result_.__is_valueless())
{
return __continuation_.unhandled_stopped();
return STDEXEC::__coroutine_unhandled_stopped(__continuation_);
}
else
{
Expand Down Expand Up @@ -317,28 +317,30 @@ namespace STDEXEC

STDEXEC_CONSTEXPR_CXX23 auto
await_suspend([[maybe_unused]] __std::coroutine_handle<> __continuation) noexcept
-> __std::coroutine_handle<>
{
STDEXEC_ASSERT(this->__continuation_.handle() == __continuation);

// Start the operation.
STDEXEC::start(__opstate_);

// This exchange({}) is T1's last write to the frame. After this point:
// - If T2 is spinning waiting for our exchange, it will observe {} and
// proceed to resume().
// - If T2 hasn't run yet, it will see {} from its load in __done() and
// skip the spin entirely
// We need to do two things:
// 1) Check if we already completed inline (receiver wrote {} to __thread_id_)
// In that case, the receiver has already returned and we can just resume the continuation
// 2) Otherwise, we need to signal to the (potentially spin-waiting) receiver that we are
// finished and won't access the frame anymore.
// We do this with an exchange({}), except on buggy MSVC versions, where we have to delay
// the signaling until we have exited this function.
# if !defined(STDEXEC_MSVC_CORO_DESTROY_BUG_WORKAROUND)
bool const __done = //
this->__thread_id_.exchange(std::thread::id{}, __std::memory_order_release)
== std::thread::id{};

// If the receiver already cleared __thread_id_, it completed on the same thread.
// Resume the continuation directly.
# if !defined(STDEXEC_MSVC_CORO_DESTROY_BUG_WORKAROUND)
return __done ? this->__get_continuation() : __std::noop_coroutine();
# else
if (__done)
STDEXEC::__coroutine_resume_nothrow(this->__get_continuation());
bool const __done = //
this->__thread_id_.load(__std::memory_order_relaxed) == std::thread::id{};
return __done ? this->__get_continuation()
: __coroutine_signal_completion(this->__thread_id_);
# endif
}

Expand Down
74 changes: 74 additions & 0 deletions include/stdexec/coroutine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
*/
#pragma once

#include "__detail/__atomic.hpp"
#include "__detail/__awaitable.hpp" // IWYU pragma: export
#include "__detail/__concepts.hpp"
#include "__detail/__config.hpp"
#include "__detail/__utility.hpp"

#include <exception>
#include <thread>

#if !STDEXEC_NO_STDCPP_COROUTINES()

Expand Down Expand Up @@ -224,6 +226,49 @@ namespace STDEXEC
{&__destroy_and_continue_frame::__resume},
{}};

// Synthetic coroutine frame whose resume step asks a stored coroutine for its
// unhandled_stopped() continuation and then resumes that continuation.
// __coroutine_unhandled_stopped() stores the target handle in `value.__promise_` and
// hands out a coroutine handle that points at the thread-local instance `value`.
struct __unhandled_stopped_frame : __detail::__synthetic_coro_frame
{
  // Resume entry point installed in the base frame by the initializer of `value`.
  // __address is the frame address, i.e. a pointer to an __unhandled_stopped_frame.
  static void __resume(void* __address) noexcept
  {
    // Make a local copy of the stored handle before calling into it. The frame lives
    // in thread-local storage and is overwritten by the next call to
    // __coroutine_unhandled_stopped() on this thread, which may happen while
    // unhandled_stopped() or the resumed continuation is still running.
    auto __coro = static_cast<__unhandled_stopped_frame*>(__address)->__promise_.__coro_;
    STDEXEC::__coroutine_resume_nothrow(__coro.unhandled_stopped());
  }

  // Payload of the frame: the coroutine whose unhandled_stopped() is invoked on resume.
  struct __promise
  {
    __coroutine_handle<> __coro_;
  } __promise_;

  // One frame per thread; reused by every call to __coroutine_unhandled_stopped().
  static thread_local __unhandled_stopped_frame value;
};

// The base frame is initialized with the resume function; the payload starts out
// value-initialized.
inline thread_local __unhandled_stopped_frame __unhandled_stopped_frame::value{
  {&__unhandled_stopped_frame::__resume},
  {}};

struct __signal_completion_frame : __detail::__synthetic_coro_frame
{
static void __resume(void* __address) noexcept
{
auto& __self = *static_cast<__signal_completion_frame*>(__address);
STDEXEC_ASSERT(__self.__promise_.__thread_id_ != nullptr);
__self.__promise_.__thread_id_->store({}, __std::memory_order_release);
}

struct __promise
{
__std::atomic<std::thread::id>* __thread_id_;
} __promise_;

static thread_local __signal_completion_frame value;
};

inline thread_local __signal_completion_frame __signal_completion_frame::value{
{&__signal_completion_frame::__resume},
nullptr};

inline auto __coroutine_destroy_and_continue(__std::coroutine_handle<> __destroy, //
__std::coroutine_handle<> __continue) noexcept //
-> __std::coroutine_handle<>
Expand All @@ -233,6 +278,21 @@ namespace STDEXEC
return __std::coroutine_handle<>::from_address(&__destroy_and_continue_frame::value);
}

// Returns a coroutine handle that, when resumed, calls __coro.unhandled_stopped() and
// resumes the result (see __unhandled_stopped_frame::__resume).
// The target is stashed in this thread's __unhandled_stopped_frame::value, so the
// returned handle refers to thread-local storage that the next call on the same
// thread overwrites.
inline auto __coroutine_unhandled_stopped(__coroutine_handle<> __coro) noexcept //
  -> __std::coroutine_handle<>
{
  auto& __frame               = __unhandled_stopped_frame::value;
  __frame.__promise_.__coro_ = __coro;
  return __std::coroutine_handle<>::from_address(&__frame);
}

// Returns a coroutine handle that, when resumed, stores a default-constructed
// std::thread::id into __thread_id_ with release ordering
// (see __signal_completion_frame::__resume).
// Only the address of __thread_id_ is recorded here, in this thread's
// __signal_completion_frame::value; the store itself happens on resume.
inline auto
__coroutine_signal_completion(__std::atomic<std::thread::id>& __thread_id_) noexcept //
  -> __std::coroutine_handle<>
{
  auto& __frame                    = __signal_completion_frame::value;
  __frame.__promise_.__thread_id_ = &__thread_id_;
  return __std::coroutine_handle<>::from_address(&__frame);
}

# else

STDEXEC_ATTRIBUTE(always_inline)
Expand All @@ -244,6 +304,20 @@ namespace STDEXEC
return __continue;
}

// Direct variant of __coroutine_unhandled_stopped: forwards to __coro.unhandled_stopped()
// and returns the handle it produces, without going through an intermediate frame.
// NOTE(review): this sits in the `# else` branch of the surrounding conditional, so it is
// presumably the path taken when the MSVC coroutine workaround is not needed -- confirm
// against the opening `#if`.
STDEXEC_ATTRIBUTE(always_inline)
auto __coroutine_unhandled_stopped(__coroutine_handle<> __coro) noexcept //
-> __std::coroutine_handle<>
{
return __coro.unhandled_stopped();
}

// Direct variant of __coroutine_signal_completion: clears __thread_id_ immediately
// (release store of a default-constructed std::thread::id) and returns the no-op
// coroutine handle, so resuming the result does nothing further.
STDEXEC_ATTRIBUTE(always_inline)
auto __coroutine_signal_completion(__std::atomic<std::thread::id>& __thread_id_) noexcept //
  -> __std::coroutine_handle<>
{
  std::thread::id const __cleared{};
  __thread_id_.store(__cleared, __std::memory_order_release);
  return __std::noop_coroutine();
}
# endif // !defined(STDEXEC_MSVC_CORO_DESTROY_BUG_WORKAROUND)

} // namespace STDEXEC
Expand Down
8 changes: 8 additions & 0 deletions test/rrd/stdexec_relacy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@

#include "../../relacy/relacy_std.hpp"

#include <iosfwd>
#include <thread>

namespace std
{
  // Forward declaration only; no definition is provided in this header.
  template <class>
  struct atomic_ref;

  // Streams a thread id as `thread::id{<id_>}`.
  // NOTE(review): relies on std::thread::id exposing an `id_` member, which is
  // presumably the replacement type supplied by relacy_std.hpp -- confirm.
  inline std::ostream& operator<<(std::ostream& os, std::thread::id const & id)
  {
    os << "thread::id{" << id.id_ << "}";
    return os;
  }
} // namespace std
35 changes: 35 additions & 0 deletions test/stdexec/types/test_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,41 @@ namespace
ex::sync_wait(scope.join());
}

// Coroutine fixture for the "always inline" regression test below: a task whose body
// does nothing except await ex::just_stopped().
// NOTE(review): per the function name, the awaited sender is expected to complete
// inline with a stopped signal -- confirm against ex::just_stopped.
auto await_always_inline_stopped_sender() -> ex::task<void>
{
co_await ex::just_stopped();
}

// Starts await_always_inline_stopped_sender() on a one-thread pool, spawns it into a
// counting scope, and blocks until the scope has joined. Any error completion
// terminates the process.
TEST_CASE("repro for NVIDIA/stdexec#2047 always inline", "[types][task]")
{
  auto pool  = exec::static_thread_pool(1);
  auto scope = ex::counting_scope();

  // An error completion from the task is treated as fatal.
  auto const terminate_on_error = [](auto) noexcept {
    std::terminate();
  };

  ex::spawn(ex::starts_on(pool.get_scheduler(), await_always_inline_stopped_sender())
              | ex::upon_error(terminate_on_error),
            scope.get_token());

  ex::sync_wait(scope.join());
}

// Coroutine fixture for the "no stack overflow" regression test below: awaits
// ex::just() 10000 times in a row from a single task.
auto await_always_inline_value_sender_loop() -> ex::task<void>
{
  constexpr size_t iteration_count = 10000;

  for (size_t completed = 0; completed != iteration_count; ++completed)
  {
    co_await ex::just();
  }
}

// Starts await_always_inline_value_sender_loop() on a one-thread pool, spawns it into
// a counting scope, and blocks until the scope has joined. Any error completion
// terminates the process. Per the test name, the repeated awaits must not overflow
// the stack.
TEST_CASE("repro for NVIDIA/stdexec#2047 no stack overflow", "[types][task]")
{
  auto pool  = exec::static_thread_pool(1);
  auto scope = ex::counting_scope();

  // An error completion from the task is treated as fatal.
  auto const terminate_on_error = [](auto) noexcept {
    std::terminate();
  };

  ex::spawn(ex::starts_on(pool.get_scheduler(), await_always_inline_value_sender_loop())
              | ex::upon_error(terminate_on_error),
            scope.get_token());

  ex::sync_wait(scope.join());
}

// TODO: add tests for stop token support in task

} // anonymous namespace
Expand Down
Loading