Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions include/stdexec/__detail/__task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,13 @@ namespace STDEXEC

struct __opstate_base : private allocator_type
{
template <class _Env>
constexpr explicit __opstate_base(task&& __task, _Env const & __env) noexcept
template <class _Env, class _OwnEnv>
constexpr explicit __opstate_base(task&& __task,
_Env const & __env,
_OwnEnv const & __own_env) noexcept
: allocator_type(__mk_alloc(__env))
, __sch_(__mk_sched(__env, __get_allocator()))
, __env_(__mk_env(__env, __own_env))
, __task_(static_cast<task&&>(__task))
{
auto& __promise = __task_.__coro_.promise();
Expand Down Expand Up @@ -434,19 +437,26 @@ namespace STDEXEC
}

start_scheduler_type __sch_;
_TaskEnv __env_;
task __task_;
__error_variant_t __errors_{__no_init};
};

template <class _Env>
struct __own_env_box
{
__own_env_t<_Env> __own_env_;
};

template <class _ParentPromise>
struct STDEXEC_ATTRIBUTE(empty_bases) __awaiter final
: __opstate_base
: __own_env_box<env_of_t<_ParentPromise>>
, __opstate_base
, __stop_callback_box_t<env_of_t<_ParentPromise>>
{
constexpr explicit __awaiter(task&& __task, _ParentPromise& __parent) noexcept
: __opstate_base(static_cast<task&&>(__task), STDEXEC::get_env(__parent))
, __own_env_(__mk_own_env(STDEXEC::get_env(__parent)))
, __env_(__mk_env(STDEXEC::get_env(__parent), __own_env_))
: __awaiter::__own_env_box{__mk_own_env(STDEXEC::get_env(__parent))}
, __opstate_base(static_cast<task&&>(__task), STDEXEC::get_env(__parent), this->__own_env_)
, __parent_(__parent)
{}

Expand Down Expand Up @@ -501,10 +511,8 @@ namespace STDEXEC
return __parent_.unhandled_stopped();
}

__own_env_t<_ParentPromise> __own_env_;
_TaskEnv __env_;
__std::coroutine_handle<> __continuation_;
_ParentPromise& __parent_;
__std::coroutine_handle<> __continuation_;
_ParentPromise& __parent_;
};

struct __attrs
Expand Down Expand Up @@ -542,21 +550,23 @@ namespace STDEXEC

////////////////////////////////////////////////////////////////////////////////////////
// task<T,E>::__opstate
template <class _Ty, class _Env>
template <class _Ty, class _TaskEnv>
template <class _Rcvr>
struct STDEXEC_ATTRIBUTE(empty_bases) task<_Ty, _Env>::__opstate final
struct STDEXEC_ATTRIBUTE(empty_bases) task<_Ty, _TaskEnv>::__opstate final
: __rcvr_box<_Rcvr> // holds the receiver so that we can pass __opstate_base a reference to it
, __opstate_base
, __own_env_box<env_of_t<_Rcvr>>
, __stop_callback_box_t<env_of_t<_Rcvr>>
, __opstate_base
{
public:
using operation_state_concept = operation_state_tag;

explicit __opstate(task&& __task, _Rcvr&& __rcvr) noexcept
: __rcvr_box<_Rcvr>{static_cast<_Rcvr&&>(__rcvr)}
, __opstate_base(static_cast<task&&>(__task), STDEXEC::get_env(this->__rcvr_))
, __own_env_(__mk_own_env(STDEXEC::get_env(this->__rcvr_)))
, __env_(__mk_env(STDEXEC::get_env(this->__rcvr_), __own_env_))
, __opstate::__own_env_box{__mk_own_env(STDEXEC::get_env(this->__rcvr_))}
, __opstate_base(static_cast<task&&>(__task),
STDEXEC::get_env(this->__rcvr_),
this->__own_env_)
{}

void start() & noexcept
Expand Down Expand Up @@ -634,15 +644,12 @@ namespace STDEXEC
STDEXEC::set_stopped(static_cast<_Rcvr&&>(this->__rcvr_));
return std::noop_coroutine();
}

__own_env_t<_Rcvr> __own_env_;
_Env __env_;
};

////////////////////////////////////////////////////////////////////////////////////////
// task<T,E>::promise_type
template <class _Ty, class _Env>
struct task<_Ty, _Env>::__promise : __task::__promise_base<_Ty>
template <class _Ty, class _TaskEnv>
struct task<_Ty, _TaskEnv>::__promise : __task::__promise_base<_Ty>
{
__promise() noexcept = default;

Expand Down Expand Up @@ -814,6 +821,16 @@ namespace STDEXEC
}
}

template <__forwarding_query _Query, class... _Args>
requires __queryable_with<_TaskEnv, _Query, _Args...>
[[nodiscard]]
constexpr auto query(_Query __tag, _Args&&... __args) const
noexcept(__nothrow_queryable_with<_TaskEnv, _Query, _Args...>)
-> __query_result_t<_TaskEnv, _Query, _Args...>
{
return __query<_Query>()(__promise_->__state_->__env_, static_cast<_Args&&>(__args)...);
}

__promise const * __promise_;
};

Expand Down
48 changes: 45 additions & 3 deletions test/stdexec/types/test_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ namespace
CHECK(i == 42);
}

auto test_task_int_ref(int& i) -> ex::task<int&>
auto test_task_int_ref(int &i) -> ex::task<int &>
{
CHECK(get_id() == 0);
co_await ex::schedule(ex::inline_scheduler{});
Expand All @@ -90,7 +90,7 @@ namespace
TEST_CASE("test task<int&>", "[types][task]")
{
int value = 42;
auto t = test_task_int_ref(value) | ex::then([](int& i) { return std::ref(i); });
auto t = test_task_int_ref(value) | ex::then([](int &i) { return std::ref(i); });
auto [i] = ex::sync_wait(std::move(t)).value();
STATIC_REQUIRE(std::same_as<decltype(i), std::reference_wrapper<int>>);
CHECK(&i.get() == &value);
Expand Down Expand Up @@ -207,7 +207,7 @@ namespace

template <ex::__not_same_as<environment_type> _Env>
requires ex::__callable<ex::get_scheduler_t, _Env const &>
explicit test_env2(_Env const & other) noexcept
explicit test_env2(_Env const &other) noexcept
: sch(ex::get_scheduler(other))
{}

Expand Down Expand Up @@ -278,6 +278,48 @@ namespace
CHECK(i == 84'000'042);
}

struct my_env
{
template <class>
using env_type = my_env;

template <class Env>
requires std::invocable<ex::get_delegation_scheduler_t, Env const &>
&& std::same_as<std::invoke_result_t<ex::get_delegation_scheduler_t, Env const &>,
ex::run_loop::scheduler>
explicit my_env(Env const &env) noexcept
: delegation_scheduler_(ex::get_delegation_scheduler(env))
{}

[[nodiscard]]
auto query(ex::get_delegation_scheduler_t) const noexcept
{
return delegation_scheduler_;
}

ex::run_loop::scheduler delegation_scheduler_;
};

auto
test_task_provides_additional_queries_with_a_custom_env(ex::run_loop::scheduler sync_wt_dlgtn_sch)
-> ex::task<int, my_env>
{
// Fetch sync_wait's run_loop scheduler from the environment.
ex::run_loop::scheduler tsk_dlgtn_sch = co_await ex::read_env(ex::get_delegation_scheduler);
CHECK(tsk_dlgtn_sch == sync_wt_dlgtn_sch);
co_return 13;
}

TEST_CASE("task can provide additional queries through a custom environment", "[types][task]")
{
ex::sync_wait(ex::let_value(ex::read_env(ex::get_delegation_scheduler),
[](ex::run_loop::scheduler sync_wt_dlgtn_sch)
{
return test_task_provides_additional_queries_with_a_custom_env(
sync_wt_dlgtn_sch);
}));
}

// FUTURE TODO: add support so that `co_await sndr` can return a reference.
// auto test_task_awaits_just_ref_sender() -> ex::task<void> {
// int value = 42;
Expand Down
Loading