Skip to content

Commit

Permalink
improve schedule_from and let_[value|error|stopped] with custom varia…
Browse files Browse the repository at this point in the history
…nt type
  • Loading branch information
ericniebler committed May 26, 2024
1 parent 665672e commit 5f9b064
Show file tree
Hide file tree
Showing 14 changed files with 378 additions and 159 deletions.
2 changes: 1 addition & 1 deletion include/stdexec/__detail/__basic_sender.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ namespace stdexec {
start() & noexcept {
using __tag_t = typename __op_state::__tag_t;
auto&& __rcvr = this->__rcvr();
__tup::__apply(
__inner_ops_.apply(
[&](auto&... __ops) noexcept {
__sexpr_impl<__tag_t>::start(this->__state_, __rcvr, __ops...);
},
Expand Down
4 changes: 4 additions & 0 deletions include/stdexec/__detail/__diagnostics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ namespace stdexec {
struct _WITH_SENDERS_;
} // namespace __errs

struct _WHERE_;

struct _IN_ALGORITHM_;

template <__mstring _Diagnostic = __errs::__unrecognized_sender_type_diagnostic>
struct _UNRECOGNIZED_SENDER_TYPE_;

Expand Down
2 changes: 1 addition & 1 deletion include/stdexec/__detail/__just.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace stdexec {

static constexpr auto start =
[]<class _State, class _Receiver>(_State& __state, _Receiver& __rcvr) noexcept -> void {
__tup::__apply(
__state.apply(
[&]<class... _Ts>(_Ts&... __ts) noexcept {
__tag_t()(static_cast<_Receiver&&>(__rcvr), static_cast<_Ts&&>(__ts)...);
},
Expand Down
115 changes: 53 additions & 62 deletions include/stdexec/__detail/__let.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "__tag_invoke.hpp"
#include "__transform_sender.hpp"
#include "__transform_completion_signatures.hpp"
#include "__variant.hpp"

#include <exception>

Expand Down Expand Up @@ -168,12 +169,14 @@ namespace stdexec {

// A metafunction that computes the result sender type for a given set of argument types
template <class _Fun, class _Set, class _Env, class _Sched>
using __result_sender_fn = //
__mcompose<
__mbind_back_q<__ensure_sender, __result_env_t<_Env, _Sched>, _Set>,
__transform<
__q<__decay_ref>,
__mbind_front<__mtry_catch_q<__call_result_t, __on_not_callable<_Set>>, _Fun>>>;
struct __result_sender_fn {
template <class... _Args>
using __f = //
__ensure_sender<
__mcall<__mtry_catch_q<__call_result_t, __on_not_callable<_Set>>, _Fun, __decay_t<_Args>&...>,
__result_env_t<_Env, _Sched>,
_Set>;
};

// The receiver that gets connected to the result sender is the input receiver,
// possibly augmented with the input sender's completion scheduler (which is
Expand Down Expand Up @@ -211,61 +214,40 @@ namespace stdexec {
_ResultSender,
__checked_result_receiver_t<_ResultSender, _Receiver, _Scheduler>>;

template <class _Receiver, class _Fun, class _Set, class _Sched>
using __op_state_for = //
__mcompose<
__mbind_back_q<__op_state_t, _Receiver, _Sched>,
__result_sender_fn<_Fun, _Set, env_of_t<_Receiver>, _Sched>>;

template <class _Set, class _Sig>
template <class _SetTag, class _Env, class _Fun, class _Sched>
struct __transform_signal_fn {
template <class, class, class>
using __f = completion_signatures<_Sig>;
};

template <class _Set, class... _Args>
struct __transform_signal_fn<_Set, _Set(_Args...)> {
template <class _Env, class _Fun, class _Sched>
template <class... _Args>
using __nothrow_connect = __mbool< //
((__nothrow_decay_copyable<_Args> && ...) //
&& __nothrow_callable<_Fun, _Args...> //
&& __nothrow_connectable_receiver_ref<
__minvoke<__result_sender_fn<_Fun, _Set, _Env, _Sched>, _Args...>,
__mcall<__result_sender_fn<_Fun, _SetTag, _Env, _Sched>, _Args...>,
_Env,
_Sched>)>;

template <class _Env, class _Fun, class _Sched>
template <class... _Args>
using __f = //
__try_make_completion_signatures<
__minvoke<__result_sender_fn<_Fun, _Set, _Env, _Sched>, _Args...>,
__result_env_t<_Env, _Sched>,
__if<
__nothrow_connect<_Env, _Fun, _Sched>,
completion_signatures<>,
completion_signatures<set_error_t(std::exception_ptr)>>>;
__mcall<
__mtry_q<__concat_completion_signatures>,
__completion_signatures_of_t<
__mcall<__result_sender_fn<_Fun, _SetTag, _Env, _Sched>, _Args...>,
__result_env_t<_Env, _Sched>>,
__eptr_completion_if_t<__nothrow_connect<_Args...>>>;
};

// _Sched is the input sender's completion scheduler, or __unknown_scheduler if it doesn't
// have one.
template <class _Env, class _Fun, class _Set, class _Sched, class _Sig>
using __transform_signal_t = __minvoke<__transform_signal_fn<_Set, _Sig>, _Env, _Fun, _Sched>;

template <class _Sender, class _Set>
using __completion_sched =
__query_result_or_t<get_completion_scheduler_t<_Set>, env_of_t<_Sender>, __unknown_scheduler>;

template <class _CvrefSender, class _Env, class _LetTag, class _Fun>
using __completions = //
__mapply<
__transform<
__mbind_front_q<
__transform_signal_t,
_Env,
_Fun,
__t<_LetTag>,
__completion_sched<_CvrefSender, __t<_LetTag>>>,
__mtry_q<__concat_completion_signatures>>,
__completion_signatures_of_t<_CvrefSender, _Env>>;
__gather_completion_signatures<
__completion_signatures_of_t<_CvrefSender, _Env>,
__t<_LetTag>,
__transform_signal_fn<__t<_LetTag>, _Env, _Fun, __completion_sched<_CvrefSender, __t<_LetTag>>>::
template __f,
__sigs::__default_completion,
__mtry_q<__concat_completion_signatures>::__f>;

template <__mstring _Where, __mstring _What>
struct _NO_COMMON_DOMAIN_ { };
Expand All @@ -277,10 +259,15 @@ namespace stdexec {
"The senders returned by Function do not all share a common domain"_mstr>;

template <class _Set>
using __try_common_domain_fn = //
__mtry_catch_q<
__domain::__common_domain_t,
__mcompose<__mbind_front_q<__mexception, __no_common_domain_t<_Set>>, __q<_WITH_SENDERS_>>>;
struct __try_common_domain_fn {
struct __error_fn {
template <class... _Senders>
using __f = __mexception<__no_common_domain_t<_Set>, _WITH_SENDERS_<_Senders...>>;
};

template <class... _Senders>
using __f = __mcall<__mtry_catch_q<__domain::__common_domain_t, __error_fn>, _Senders...>;
};

// Compute all the domains of all the result senders and make sure they're all the same
template <class _Set, class _Child, class _Fun, class _Env, class _Sched>
Expand Down Expand Up @@ -340,16 +327,23 @@ namespace stdexec {
};
}

template <class _Receiver, class _Fun, class _Set, class _Sched>
struct __op_state_for {
template <class... _Args>
using __f = __op_state_t<
__mcall<__result_sender_fn<_Fun, _Set, env_of_t<_Receiver>, _Sched>, _Args...>,
_Receiver,
_Sched>;
};

template <class _Receiver, class _Fun, class _Set, class _Sched, class... _Tuples>
struct __let_state {
using __fun_t = _Fun;
using __sched_t = _Sched;
using __env_t = __result_env_t<env_of_t<_Receiver>, _Sched>;
using __result_variant = std::variant<std::monostate, _Tuples...>;
using __result_variant = __variant_<__monostate, _Tuples...>;
using __op_state_variant = //
__minvoke<
__transform<__uncurry<__op_state_for<_Receiver, _Fun, _Set, _Sched>>, __nullable_variant_fn>,
_Tuples...>;
__variant_<__monostate, __mapply<__op_state_for<_Receiver, _Fun, _Set, _Sched>, _Tuples>...>;

template <class _ResultSender, class _OpState>
auto __get_result_receiver(const _ResultSender&, _OpState& __op_state) -> decltype(auto) {
Expand Down Expand Up @@ -452,7 +446,7 @@ namespace stdexec {
_Set,
_Child,
env_of_t<_Receiver>,
__q<__decayed_tuple>,
__q<__tup::__decayed_tuple>,
__mk_let_state>;

return __sndr.apply(
Expand All @@ -466,14 +460,11 @@ namespace stdexec {

template <class _State, class _OpState, class... _As>
static void __bind_(_State& __state, _OpState& __op_state, _As&&... __as) {
auto& __args =
__state.__args_.template emplace<__decayed_tuple<_As...>>(static_cast<_As&&>(__as)...);
auto __sndr2 = __apply(std::move(__state.__fun_), __args);
auto& __args = __state.__args_.emplace_from(__tup::__mktuple, static_cast<_As&&>(__as)...);
auto __sndr2 = __args.apply(std::move(__state.__fun_), __args);
auto __rcvr2 = __state.__get_result_receiver(__sndr2, __op_state);
auto __mkop = [&] {
return stdexec::connect(std::move(__sndr2), std::move(__rcvr2));
};
auto& __op2 = __state.__op_state3_.template emplace<decltype(__mkop())>(__conv{__mkop});
auto& __op2 = __state.__op_state3_.emplace_from(
stdexec::connect, std::move(__sndr2), std::move(__rcvr2));
stdexec::start(__op2);
}

Expand All @@ -484,7 +475,7 @@ namespace stdexec {
using _Fun = typename _State::__fun_t;
using _Sched = typename _State::__sched_t;
using _ResultSender =
__minvoke<__result_sender_fn<_Fun, _Set, env_of_t<_Receiver>, _Sched>, _As...>;
__mcall<__result_sender_fn<_Fun, _Set, env_of_t<_Receiver>, _Sched>, _As...>;

_State& __state = __op_state.__state_;
_Receiver& __rcvr = __op_state.__rcvr_;
Expand All @@ -509,7 +500,7 @@ namespace stdexec {
_OpState& __op_state,
_Tag,
_As&&... __as) noexcept -> void {
if constexpr (std::same_as<_Tag, _Set>) {
if constexpr (__same_as<_Tag, _Set>) {
__bind(__op_state, static_cast<_As&&>(__as)...);
} else {
using _Receiver = decltype(__op_state.__rcvr_);
Expand Down
3 changes: 3 additions & 0 deletions include/stdexec/__detail/__meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ namespace stdexec {

#endif

template <class _Fn, class... _Args>
using __mcall = typename _Fn::template __f<_Args...>;

struct __disp_q {
template <class... _Args>
using __f = __disp<_Args...>;
Expand Down
93 changes: 46 additions & 47 deletions include/stdexec/__detail/__schedule_from.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,17 @@
#include "__schedulers.hpp"
#include "__transform_completion_signatures.hpp"
#include "__tuple.hpp"

#include <variant>
#include "__variant.hpp"

namespace stdexec {
/////////////////////////////////////////////////////////////////////////////
// [execution.senders.adaptors.schedule_from]
namespace __schfr {
template <class... _Ts>
using __value_tuple = __tup::__tuple_for<__decay_t<_Ts>...>;
using __tuple_t = __tup::__tuple_for<__decay_t<_Ts>...>;

template <class... _Ts>
using __variant_t = __variant_<__monostate, _Ts...>;

// Compute a variant type that is capable of storing the results of the
// input sender when it completes. The variant has type:
Expand All @@ -50,15 +52,17 @@ namespace stdexec {
// ...
// >
template <class _CvrefSender, class _Env>
using __variant_for_t = //
using __variant_for = //
__for_each_completion_signature<
__completion_signatures_of_t<_CvrefSender, _Env>,
__value_tuple,
__nullable_variant_fn::__f>;
__completion_signatures_of_t<_CvrefSender, _Env>,
__tuple_t,
__munique<__qq<__variant_>>::__f>;

template <class... _Values>
using __decay_value_sig = set_value_t (*)(__decay_t<_Values>...);

template <class _Tag>
using __decay_signature_fn =
__transform<__q<__decay_t>, __mcompose<__q<completion_signatures>, __qf<_Tag>>>;
template <class _Error>
using __decay_error_sig = set_error_t (*)(__decay_t<_Error>);

template <class... _Ts>
using __all_nothrow_decay_copyable = __mbool<(__nothrow_decay_copyable<_Ts> && ...)>;
Expand All @@ -76,16 +80,18 @@ namespace stdexec {

template <class _Scheduler, class _CvrefSender, class _Env>
using __completions_t = //
__try_make_completion_signatures<
_CvrefSender,
_Env,
__mtry_q<__concat_completion_signatures>::__f<
__transform_completion_signatures<
__completion_signatures_of_t<_CvrefSender, _Env>,
__decay_value_sig,
__decay_error_sig,
set_stopped_t (*)(),
__completion_signature_ptrs>,
__try_make_completion_signatures<
schedule_result_t<_Scheduler>,
_Env,
__with_error_t<_CvrefSender, _Env>,
__mconst<completion_signatures<>>>,
__decay_signature_fn<set_value_t>,
__decay_signature_fn<set_error_t>>;
__eptr_completion_if_t<__all_nothrow_decay_copyable_results<_CvrefSender, _Env>>,
__mconst<completion_signatures<>>>>;

template <class _SchedulerId>
struct __environ {
Expand All @@ -110,6 +116,23 @@ namespace stdexec {
template <class _Scheduler, class _Sexpr, class _Receiver>
struct __state;

template <class _State>
STDEXEC_ATTRIBUTE((always_inline))
auto
__make_visitor_fn(_State* __state) noexcept {
return [__state]<class _Tup>(_Tup& __tupl) noexcept -> void {
if constexpr (__same_as<_Tup, __monostate>) {
std::terminate(); // reaching this indicates a bug in schedule_from
} else {
__tupl.apply(
[&]<class... _Args>(auto __tag, _Args&... __args) noexcept -> void {
__tag(std::move(__state->__receiver()), static_cast<_Args&&>(__args)...);
},
__tupl);
}
};
}

// This receiver is to be completed on the execution context associated with the scheduler. When
// the source sender completes, the completion information is saved off in the operation state
// so that when this receiver completes, it can read the completion out of the operation state
Expand All @@ -119,24 +142,7 @@ namespace stdexec {
using receiver_concept = receiver_t;

void set_value() noexcept {
STDEXEC_ASSERT(!__state_->__data_.valueless_by_exception());
// Work around a but in nvc++:
using __receiver_t = _Receiver;
std::visit(
[__state = __state_]<class _Tup>(_Tup& __tupl) noexcept -> void {
if constexpr (__same_as<_Tup, std::monostate>) {
std::terminate(); // reaching this indicates a bug in schedule_from
} else {
__tup::__apply(
[&]<class... _Args>(auto __tag, _Args&... __args) noexcept -> void {
__tag(
static_cast<__receiver_t&&>(__state->__receiver()),
static_cast<_Args&&>(__args)...);
},
__tupl);
}
},
__state_->__data_);
__state_->__data_.visit(__make_visitor_fn(__state_), __state_->__data_);
}

template <class _Error>
Expand All @@ -160,7 +166,7 @@ namespace stdexec {
struct __state
: __enable_receiver_from_this<_Sexpr, _Receiver>
, __immovable {
using __variant_t = __variant_for_t<__child_of<_Sexpr>, env_of_t<_Receiver>>;
using __variant_t = __variant_for<__child_of<_Sexpr>, env_of_t<_Receiver>>;
using __receiver2_t = __receiver2<_Scheduler, _Sexpr, _Receiver>;

__variant_t __data_;
Expand Down Expand Up @@ -221,23 +227,16 @@ namespace stdexec {
__ignore,
_State& __state,
_Receiver& __rcvr,
_Tag,
_Tag __tag,
_Args&&... __args) noexcept -> void {
STDEXEC_APPLE_CLANG(__state.__self_ == &__state ? void() : std::terminate());
// Write the tag and the args into the operation state so that we can forward the completion
// from within the scheduler's execution context.
using __result = __value_tuple<_Tag, _Args...>;
constexpr bool __nothrow_ = noexcept(__result{{_Tag()}, {static_cast<_Args&&>(__args)}...});

auto __emplace_result = [&]() noexcept(__nothrow_) {
return __result{{_Tag()}, {static_cast<_Args&&>(__args)}...};
};

if constexpr (__nothrow_) {
__state.__data_.template emplace<__result>(__conv{__emplace_result});
if constexpr (__nothrow_callable<decltype(__tup::__mktuple), _Tag, _Args...>) {
__state.__data_.emplace_from(__tup::__mktuple, __tag, static_cast<_Args&&>(__args)...);
} else {
try {
__state.__data_.template emplace<__result>(__conv{__emplace_result});
__state.__data_.emplace_from(__tup::__mktuple, __tag, static_cast<_Args&&>(__args)...);
} catch (...) {
stdexec::set_error(static_cast<_Receiver&&>(__rcvr), std::current_exception());
return;
Expand Down
Loading

0 comments on commit 5f9b064

Please sign in to comment.