From 747860165b0bb4ea829a0d83d1b0ba8ee9e19fb8 Mon Sep 17 00:00:00 2001 From: Eric Niebler Date: Wed, 21 Jan 2026 20:20:39 -0800 Subject: [PATCH] implement `std::execution::task_scheduler` per P3927 --- .../__system_context_replaceability_api.hpp | 132 +--- include/exec/system_context.hpp | 100 ++- include/stdexec/__detail/__basic_sender.hpp | 34 +- include/stdexec/__detail/__bulk.hpp | 22 +- .../__detail/__parallel_scheduler_backend.hpp | 253 +++++++ include/stdexec/__detail/__schedulers.hpp | 2 +- .../__detail/__sender_introspection.hpp | 1 - .../__system_context_default_impl.hpp | 132 ++-- .../__system_context_default_impl_entry.hpp | 7 +- .../__system_context_replaceability_api.hpp | 38 + include/stdexec/__detail/__task_scheduler.hpp | 656 ++++++++++++++++++ include/stdexec/__detail/__typeinfo.hpp | 6 +- include/stdexec/__detail/__variant.hpp | 15 + include/stdexec/execution.hpp | 111 +-- src/system_context/system_context.cpp | 2 +- test/CMakeLists.txt | 1 + test/exec/test_system_context.cpp | 73 +- .../test_system_context_replaceability.cpp | 18 +- .../schedulers/test_task_scheduler.cpp | 94 +++ test/test_common/schedulers.hpp | 74 ++ 20 files changed, 1431 insertions(+), 340 deletions(-) create mode 100644 include/stdexec/__detail/__parallel_scheduler_backend.hpp rename include/{exec => stdexec}/__detail/__system_context_default_impl.hpp (77%) rename include/{exec => stdexec}/__detail/__system_context_default_impl_entry.hpp (90%) create mode 100644 include/stdexec/__detail/__system_context_replaceability_api.hpp create mode 100644 include/stdexec/__detail/__task_scheduler.hpp create mode 100644 test/stdexec/schedulers/test_task_scheduler.cpp diff --git a/include/exec/__detail/__system_context_replaceability_api.hpp b/include/exec/__detail/__system_context_replaceability_api.hpp index c2caeb8d3..752c45501 100644 --- a/include/exec/__detail/__system_context_replaceability_api.hpp +++ b/include/exec/__detail/__system_context_replaceability_api.hpp @@ -18,128 
+18,48 @@ #define STDEXEC_SYSTEM_CONTEXT_REPLACEABILITY_API_H #include "../../stdexec/__detail/__execution_fwd.hpp" +#include "../../stdexec/__detail/__system_context_replaceability_api.hpp" #include -#include -#include #include -#include -#include - -struct __uuid { - std::uint64_t __parts1; - std::uint64_t __parts2; - - friend auto operator==(__uuid, __uuid) noexcept -> bool = default; -}; namespace exec::system_context_replaceability { + using STDEXEC::system_context_replaceability::__parallel_scheduler_backend_factory; - /// Helper for the `__queryable_interface` concept. - template <__uuid X> - using __check_constexpr_uuid = void; - - /// Concept for a queryable interface. Ensures that the interface has a `__interface_identifier` member. - template - concept __queryable_interface = requires() { - typename __check_constexpr_uuid<_T::__interface_identifier>; - }; - - /// The details for making `_T` a runtime property. - template - struct __runtime_property_helper { - /// Is `_T` a property? - static constexpr bool __is_property = false; - /// The unique identifier for the property. - static constexpr __uuid __property_identifier{0, 0}; - }; - - /// `inplace_stope_token` is a runtime property. - template <> - struct __runtime_property_helper { - static constexpr bool __is_property = true; - static constexpr __uuid __property_identifier{0x8779c09d8aa249df, 0x867db0e653202604}; - }; - - /// Concept for a runtime property. - template - concept __runtime_property = __runtime_property_helper<_T>::__is_property; - - struct parallel_scheduler_backend; + /// Interface for the parallel scheduler backend. + using parallel_scheduler_backend [[deprecated( + "Use STDEXEC::system_context_replaceability::parallel_scheduler_backend instead.")]] = + STDEXEC::system_context_replaceability::parallel_scheduler_backend; /// Get the backend for the parallel scheduler. /// Users might replace this function. 
- auto query_parallel_scheduler_backend() -> std::shared_ptr; - - /// The type of a factory that can create `parallel_scheduler_backend` instances. - /// Out of spec. - using __parallel_scheduler_backend_factory = std::shared_ptr (*)(); + [[deprecated( + "Use STDEXEC::system_context_replaceability::query_parallel_scheduler_backend instead.")]] + inline auto query_parallel_scheduler_backend() + -> std::shared_ptr { + return STDEXEC::system_context_replaceability::query_parallel_scheduler_backend(); + } /// Set a factory for the parallel scheduler backend. /// Can be used to replace the parallel scheduler at runtime. /// Out of spec. - auto set_parallel_scheduler_backend(__parallel_scheduler_backend_factory __new_factory) - -> __parallel_scheduler_backend_factory; + [[deprecated( + "Use STDEXEC::system_context_replaceability::set_parallel_scheduler_backend instead.")]] + inline auto set_parallel_scheduler_backend(__parallel_scheduler_backend_factory __new_factory) + -> __parallel_scheduler_backend_factory { + return STDEXEC::system_context_replaceability::set_parallel_scheduler_backend(__new_factory); + } /// Interface for completing a sender operation. Backend will call frontend though this interface /// for completing the `schedule` and `schedule_bulk` operations. - struct receiver { - virtual ~receiver() = default; - - protected: - virtual auto __query_env(__uuid, void*) noexcept -> bool = 0; - - public: - /// Called when the system scheduler completes successfully. - virtual void set_value() noexcept = 0; - /// Called when the system scheduler completes with an error. - virtual void set_error(std::exception_ptr) noexcept = 0; - /// Called when the system scheduler was stopped. - virtual void set_stopped() noexcept = 0; - - /// Query the receiver for a property of type `_P`. 
- template - auto try_query() noexcept -> std::optional> { - if constexpr (__runtime_property<_P>) { - std::decay_t<_P> __p; - bool __success = - __query_env(__runtime_property_helper>::__property_identifier, &__p); - return __success ? std::make_optional(std::move(__p)) : std::nullopt; - } else { - return std::nullopt; - } - } - }; - - /// Receiver for bulk sheduling operations. - struct bulk_item_receiver : receiver { - /// Called for each item of a bulk operation, possible on different threads. - virtual void execute(std::uint32_t, std::uint32_t) noexcept = 0; - }; - - /// Interface for the parallel scheduler backend. - struct parallel_scheduler_backend { - static constexpr __uuid __interface_identifier{0x5ee9202498c4bd4f, 0xa1df2508ffcd9d7e}; - - virtual ~parallel_scheduler_backend() = default; - - /// Schedule work on parallel scheduler, calling `__r` when done and using `__s` for preallocated - /// memory. - virtual void schedule(std::span __s, receiver& __r) noexcept = 0; - /// Schedule bulk work of size `__n` on parallel scheduler, calling `__r` for different - /// subranges of [0, __n), and using `__s` for preallocated memory. - virtual void schedule_bulk_chunked( - std::uint32_t __n, - std::span __s, - bulk_item_receiver& __r) noexcept = 0; - /// Schedule bulk work of size `__n` on parallel scheduler, calling `__r` for each item, and - /// using `__s` for preallocated memory. - virtual void schedule_bulk_unchunked( - std::uint32_t __n, - std::span __s, - bulk_item_receiver& __r) noexcept = 0; - }; - + using receiver + [[deprecated("Use STDEXEC::system_context_replaceability::receiver_proxy instead.")]] = + STDEXEC::system_context_replaceability::receiver_proxy; + + /// Receiver for bulk scheduling operations. 
+ using bulk_item_receiver [[deprecated( + "Use STDEXEC::system_context_replaceability::bulk_item_receiver_proxy instead.")]] = + STDEXEC::system_context_replaceability::bulk_item_receiver_proxy; } // namespace exec::system_context_replaceability #endif diff --git a/include/exec/system_context.hpp b/include/exec/system_context.hpp index 878367901..287ec7a19 100644 --- a/include/exec/system_context.hpp +++ b/include/exec/system_context.hpp @@ -15,11 +15,12 @@ */ #pragma once -#include - #include "../stdexec/execution.hpp" #include "__detail/__system_context_replaceability_api.hpp" +#include +#include + #ifndef STDEXEC_SYSTEM_CONTEXT_SCHEDULE_OP_SIZE # define STDEXEC_SYSTEM_CONTEXT_SCHEDULE_OP_SIZE 72 #endif @@ -43,29 +44,16 @@ namespace exec { namespace detail { /// Allows a frontend receiver of type `_Rcvr` to be passed to the backend. template - struct __receiver_adapter : system_context_replaceability::receiver { + struct __receiver_adapter : STDEXEC::system_context_replaceability::receiver_proxy { explicit __receiver_adapter(_Rcvr&& __rcvr) - : __rcvr_{std::forward<_Rcvr>(__rcvr)} { - } - - auto __query_env(__uuid __id, void* __dest) noexcept -> bool override { - using system_context_replaceability::__runtime_property_helper; - using __StopToken = decltype(STDEXEC::get_stop_token(STDEXEC::get_env(__rcvr_))); - if constexpr (std::is_same_v) { - if (__id == __runtime_property_helper::__property_identifier) { - *static_cast(__dest) = STDEXEC::get_stop_token( - STDEXEC::get_env(__rcvr_)); - return true; - } - } - return false; + : __rcvr_{static_cast<_Rcvr&&>(__rcvr)} { } void set_value() noexcept override { STDEXEC::set_value(std::forward<_Rcvr>(__rcvr_)); } - void set_error(std::exception_ptr __ex) noexcept override { + void set_error(std::exception_ptr&& __ex) noexcept override { STDEXEC::set_error(std::forward<_Rcvr>(__rcvr_), std::move(__ex)); } @@ -73,11 +61,36 @@ namespace exec { STDEXEC::set_stopped(std::forward<_Rcvr>(__rcvr_)); } + protected: + void 
__query_env( + STDEXEC::__type_index __query_type, + STDEXEC::__type_index __value_type, + void* __dest) const noexcept override { + if (__query_type == STDEXEC::__mtypeid) { + __query(STDEXEC::get_stop_token, __value_type, __dest); + } + } + + private: + void __query(STDEXEC::get_stop_token_t, STDEXEC::__type_index __value_type, void* __dest) + const noexcept { + using __stop_token_t = STDEXEC::stop_token_of_t>; + if constexpr (std::is_same_v) { + if (__value_type == STDEXEC::__mtypeid) { + using __dest_t = std::optional; + *static_cast<__dest_t*>(__dest) = STDEXEC::get_stop_token(STDEXEC::get_env(__rcvr_)); + } + } + } + + public: STDEXEC_ATTRIBUTE(no_unique_address) _Rcvr __rcvr_; }; /// The type large enough to store the data produced by a sender. + /// BUGBUG: this seems wrong. i think this should be a variant of tuples of possible + /// results. template using __sender_data_t = decltype(STDEXEC::sync_wait(std::declval<_Sender>()).value()); @@ -85,6 +98,7 @@ namespace exec { class parallel_scheduler; class __parallel_sender; + template class __parallel_bulk_sender; @@ -106,7 +120,7 @@ namespace exec { namespace detail { using __backend_ptr = - std::shared_ptr; + std::shared_ptr; template auto __make_parallel_scheduler_from(T, __backend_ptr) noexcept; @@ -199,7 +213,7 @@ namespace exec { auto& __scheduler_impl = __preallocated_.__as<__backend_ptr>(); auto __impl = std::move(__scheduler_impl); std::destroy_at(&__scheduler_impl); - __impl->schedule(__preallocated_.__as_storage(), __rcvr_); + __impl->schedule(__rcvr_, __preallocated_.__as_storage()); } /// Object that receives completion from the work described by the sender. @@ -312,7 +326,8 @@ namespace exec { /// This represents the base class that abstracts the storage of the values sent by the previous sender. /// Derived class will properly implement the receiver methods. 
template - struct __forward_args_receiver : system_context_replaceability::bulk_item_receiver { + struct __forward_args_receiver + : STDEXEC::system_context_replaceability::bulk_item_receiver_proxy { using __storage_t = detail::__sender_data_t<_Previous>; /// Storage for the arguments received from the previous sender. @@ -329,24 +344,11 @@ namespace exec { /// Stores `__as` in the base class storage, with the right types. explicit __typed_forward_args_receiver(_As&&... __as) { static_assert(sizeof(std::tuple<_As...>) <= sizeof(__base_t::__arguments_data_)); + // BUGBUG: this seems wrong. we are not ever destroying this tuple. new (__base_t::__arguments_data_) std::tuple...>{std::move(__as)...}; } - auto __query_env(__uuid __id, void* __dest) noexcept -> bool override { - auto __state = reinterpret_cast<_BulkState*>(this); - using system_context_replaceability::__runtime_property_helper; - using __StopToken = decltype(STDEXEC::get_stop_token(STDEXEC::get_env(__state->__rcvr_))); - if constexpr (std::is_same_v) { - if (__id == __runtime_property_helper::__property_identifier) { - *static_cast(__dest) = STDEXEC::get_stop_token( - STDEXEC::get_env(__state->__rcvr_)); - return true; - } - } - return false; - } - /// Calls `set_value()` on the final receiver of the bulk operation, using the values from the previous sender. 
void set_value() noexcept override { auto __state = reinterpret_cast<_BulkState*>(this); @@ -396,6 +398,30 @@ namespace exec { *reinterpret_cast*>(__base_t::__arguments_data_)); } } + + protected: + void __query_env( + STDEXEC::__type_index __query_type, + STDEXEC::__type_index __value_type, + void* __dest) const noexcept override { + if (__query_type == STDEXEC::__mtypeid) { + __query(STDEXEC::get_stop_token, __value_type, __dest); + } + } + + private: + void __query(STDEXEC::get_stop_token_t, STDEXEC::__type_index __value_type, void* __dest) + const noexcept { + auto __state = reinterpret_cast(this); + using __stop_token_t = STDEXEC::stop_token_of_t>; + if constexpr (std::is_same_v) { + using __dest_t = std::optional; + if (__value_type == STDEXEC::__mtypeid) { + *static_cast<__dest_t*>(__dest) = STDEXEC::get_stop_token( + STDEXEC::get_env(__state->__rcvr_)); + } + } + } }; /// The state needed to execute the bulk sender created from system context, minus the preallocates space. @@ -645,7 +671,7 @@ namespace exec { }; inline auto get_parallel_scheduler() -> parallel_scheduler { - auto __impl = system_context_replaceability::query_parallel_scheduler_backend(); + auto __impl = STDEXEC::system_context_replaceability::query_parallel_scheduler_backend(); if (!__impl) { STDEXEC_THROW(std::runtime_error{"No system context implementation found"}); } @@ -728,5 +754,5 @@ namespace exec { #if defined(STDEXEC_SYSTEM_CONTEXT_HEADER_ONLY) # define STDEXEC_SYSTEM_CONTEXT_INLINE inline -# include "__detail/__system_context_default_impl_entry.hpp" +# include "../stdexec/__detail/__system_context_default_impl_entry.hpp" #endif diff --git a/include/stdexec/__detail/__basic_sender.hpp b/include/stdexec/__detail/__basic_sender.hpp index 415ae8334..ac552923e 100644 --- a/include/stdexec/__detail/__basic_sender.hpp +++ b/include/stdexec/__detail/__basic_sender.hpp @@ -359,16 +359,19 @@ namespace STDEXEC { STDEXEC_ATTRIBUTE(nodiscard, always_inline) constexpr auto get_env() const 
noexcept -> decltype(auto) { - return __apply(__detail::__drop_front(__sexpr_impl<__tag_t>::get_attrs), *this); + return __apply( + __detail::__drop_front(__sexpr_impl<__tag_t>::get_attrs), __c_upcast<__sexpr>(*this)); } template static consteval auto get_completion_signatures() { using namespace __detail; - if constexpr (__has_get_completion_signatures<__tag_t, _Self, _Env...>) { - return __sexpr_impl<__tag_t>::template get_completion_signatures<_Self, _Env...>(); - } else if constexpr (__has_get_completion_signatures<__tag_t, _Self>) { - return __sexpr_impl<__tag_t>::template get_completion_signatures<_Self>(); + static_assert(STDEXEC_IS_BASE_OF(__sexpr, __decay_t<_Self>)); + using __self_t = __copy_cvref_t<_Self, __sexpr>; + if constexpr (__has_get_completion_signatures<__tag_t, __self_t, _Env...>) { + return __sexpr_impl<__tag_t>::template get_completion_signatures<__self_t, _Env...>(); + } else if constexpr (__has_get_completion_signatures<__tag_t, __self_t>) { + return __sexpr_impl<__tag_t>::template get_completion_signatures<__self_t>(); } else if constexpr (sizeof...(_Env) == 0) { return __dependent_sender<_Self>(); } else { @@ -379,12 +382,13 @@ namespace STDEXEC { // Non-standard extension: template STDEXEC_ATTRIBUTE(nodiscard, always_inline) - static constexpr auto static_connect(_Self&& __self, _Receiver __rcvr) - noexcept(__noexcept_of<__sexpr_impl<__tag_t>::connect, _Self, _Receiver>) - -> __result_of<__sexpr_impl<__tag_t>::connect, _Self, _Receiver> { - static_assert(__decays_to_derived_from<_Self, __sexpr>); + static constexpr auto static_connect(_Self&& __self, _Receiver __rcvr) noexcept( + __noexcept_of<__sexpr_impl<__tag_t>::connect, __copy_cvref_t<_Self, __sexpr>, _Receiver>) + -> __result_of<__sexpr_impl<__tag_t>::connect, __copy_cvref_t<_Self, __sexpr>, _Receiver> { + static_assert(STDEXEC_IS_BASE_OF(__sexpr, __decay_t<_Self>)); return __sexpr_impl<__tag_t>::connect( - static_cast<_Self&&>(__self), static_cast<_Receiver&&>(__rcvr)); + 
STDEXEC::__c_upcast<__sexpr>(static_cast<_Self&&>(__self)), + static_cast<_Receiver&&>(__rcvr)); } template @@ -408,12 +412,12 @@ namespace STDEXEC { // Non-standard extension: template STDEXEC_ATTRIBUTE(nodiscard, always_inline) - static constexpr auto submit(_Self&& __self, _Receiver&& __rcvr) - noexcept(__noexcept_of<__sexpr_impl<__tag_t>::submit, _Self, _Receiver>) - -> __result_of<__sexpr_impl<__tag_t>::submit, _Self, _Receiver> { - static_assert(__decays_to_derived_from<_Self, __sexpr>); + static constexpr auto submit(_Self&& __self, _Receiver&& __rcvr) noexcept( + __noexcept_of<__sexpr_impl<__tag_t>::submit, __copy_cvref_t<_Self, __sexpr>, _Receiver>) + -> __result_of<__sexpr_impl<__tag_t>::submit, __copy_cvref_t<_Self, __sexpr>, _Receiver> { return __sexpr_impl<__tag_t>::submit( - static_cast<_Self&&>(__self), static_cast<_Receiver&&>(__rcvr)); + STDEXEC::__c_upcast<__sexpr>(static_cast<_Self&&>(__self)), + static_cast<_Receiver&&>(__rcvr)); } }; diff --git a/include/stdexec/__detail/__bulk.hpp b/include/stdexec/__detail/__bulk.hpp index 59e70ea3a..586e50d92 100644 --- a/include/stdexec/__detail/__bulk.hpp +++ b/include/stdexec/__detail/__bulk.hpp @@ -51,6 +51,7 @@ namespace STDEXEC { : __pol_{__pol} { } + [[nodiscard]] const _Pol& __get() const noexcept { return __pol_; } @@ -61,6 +62,7 @@ namespace STDEXEC { /*implicit*/ __policy_wrapper(const sequenced_policy&) { } + [[nodiscard]] const sequenced_policy& __get() const noexcept { return seq; } @@ -71,6 +73,7 @@ namespace STDEXEC { /*implicit*/ __policy_wrapper(const parallel_policy&) { } + [[nodiscard]] const parallel_policy& __get() const noexcept { return par; } @@ -81,6 +84,7 @@ namespace STDEXEC { /*implicit*/ __policy_wrapper(const parallel_unsequenced_policy&) { } + [[nodiscard]] const parallel_unsequenced_policy& __get() const noexcept { return par_unseq; } @@ -91,6 +95,7 @@ namespace STDEXEC { /*implicit*/ __policy_wrapper(const unsequenced_policy&) { } + [[nodiscard]] const unsequenced_policy& 
__get() const noexcept { return unseq; } @@ -101,9 +106,8 @@ namespace STDEXEC { STDEXEC_ATTRIBUTE(no_unique_address) __policy_wrapper<_Pol> __pol_; _Shape __shape_; STDEXEC_ATTRIBUTE(no_unique_address) _Fun __fun_; - static constexpr auto __mbrs_ = - __mliterals<&__data::__pol_, &__data::__shape_, &__data::__fun_>(); }; + template __data(const _Pol&, _Shape, _Fun) -> __data<_Pol, _Shape, _Fun>; @@ -260,7 +264,7 @@ namespace STDEXEC { struct bulk_unchunked_t : __generic_bulk_t { }; template - struct __bulk_impl_base : __sexpr_defaults { + struct __impl_base : __sexpr_defaults { template using __fun_t = decltype(__decay_t<__data_of<_Sender>>::__fun_); @@ -287,7 +291,7 @@ namespace STDEXEC { }; }; - struct __bulk_chunked_impl : __bulk_impl_base { + struct __chunked_impl : __impl_base { //! This implements the core default behavior for `bulk_chunked`: //! When setting value, it calls the function with the entire range. //! Note: This is not done in parallel. That is customized by the scheduler. @@ -317,7 +321,7 @@ namespace STDEXEC { }; }; - struct __bulk_unchunked_impl : __bulk_impl_base { + struct __unchunked_impl : __impl_base { //! This implements the core default behavior for `bulk_unchunked`: //! When setting value, it loops over the shape and invokes the function. //! Note: This is not done in concurrently. That is customized by the scheduler. @@ -348,7 +352,7 @@ namespace STDEXEC { }; }; - struct __bulk_impl : __bulk_impl_base { + struct __impl : __impl_base { // Implementation is handled by lowering to `bulk_chunked` in the tag's `transform_sender`. 
}; } // namespace __bulk @@ -361,13 +365,13 @@ namespace STDEXEC { inline constexpr bulk_unchunked_t bulk_unchunked{}; template <> - struct __sexpr_impl : __bulk::__bulk_impl { }; + struct __sexpr_impl : __bulk::__impl { }; template <> - struct __sexpr_impl : __bulk::__bulk_chunked_impl { }; + struct __sexpr_impl : __bulk::__chunked_impl { }; template <> - struct __sexpr_impl : __bulk::__bulk_unchunked_impl { }; + struct __sexpr_impl : __bulk::__unchunked_impl { }; } // namespace STDEXEC STDEXEC_PRAGMA_POP() diff --git a/include/stdexec/__detail/__parallel_scheduler_backend.hpp b/include/stdexec/__detail/__parallel_scheduler_backend.hpp new file mode 100644 index 000000000..3d5eb5d8f --- /dev/null +++ b/include/stdexec/__detail/__parallel_scheduler_backend.hpp @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2025 Lucian Radu Teodorescu, Lewis Baker + * Copyright (c) 2026 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "__execution_fwd.hpp" + +// include these after __execution_fwd.hpp +// #include "any_allocator.cuh" +#include "../functional.hpp" // IWYU pragma: keep for __with_default +#include "../stop_token.hpp" +#include "__queries.hpp" +#include "__typeinfo.hpp" + +#include +#include +#include + +STDEXEC_PRAGMA_PUSH() +STDEXEC_PRAGMA_IGNORE_MSVC(4702) // warning C4702: unreachable code + +namespace STDEXEC { + template + class any_allocator : public std::allocator<_Ty> { + public: + template + struct rebind { + using other = any_allocator<_OtherTy>; + }; + + template <__not_same_as _Alloc> + any_allocator(const _Alloc&) noexcept { + } + }; + + template + STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE + any_allocator(_Alloc) -> any_allocator; + + STDEXEC_HOST_DEVICE_DEDUCTION_GUIDE + any_allocator(std::allocator) -> any_allocator; + + class task_scheduler; + + // namespace __detail { + // struct __env_proxy : __immovable { + // [[nodiscard]] + // virtual auto query(const get_stop_token_t&) const noexcept -> inplace_stop_token = 0; + // [[nodiscard]] + // virtual auto query(const get_allocator_t&) const noexcept -> any_allocator = 0; + // [[nodiscard]] + // virtual auto query(const get_scheduler_t&) const noexcept -> task_scheduler = 0; + // }; + // } // namespace __detail + + namespace system_context_replaceability { + /// Interface for completing a sender operation. Backend will call frontend though + /// this interface for completing the `schedule` and `schedule_bulk` operations. + class receiver_proxy { //: __detail::__env_proxy { + public: + virtual ~receiver_proxy() = 0; + + virtual void set_value() noexcept = 0; + virtual void set_error(std::exception_ptr&&) noexcept = 0; + virtual void set_stopped() noexcept = 0; + + // // NOT TO SPEC: + // [[nodiscard]] + // auto get_env() const noexcept -> const __detail::__env_proxy& { + // return *this; + // } + + /// Query the receiver for a property of type `_P`. 
+ template + auto try_query(_Query) const noexcept -> std::optional<_P> { + std::optional<_P> __p; + __query_env(__mtypeid<_Query>, __mtypeid<_P>, &__p); + return __p; + } + + protected: + virtual void __query_env(__type_index, __type_index, void*) const noexcept = 0; + }; + + inline receiver_proxy::~receiver_proxy() = default; + + struct bulk_item_receiver_proxy : receiver_proxy { + virtual void execute(size_t, size_t) noexcept = 0; + }; + + /// Interface for the parallel scheduler backend. + struct parallel_scheduler_backend { + virtual ~parallel_scheduler_backend() = 0; + + /// Schedule work on parallel scheduler, calling `__r` when done and using `__s` for preallocated + /// memory. + virtual void schedule(receiver_proxy&, std::span) noexcept = 0; + + /// Schedule bulk work of size `__n` on parallel scheduler, calling `__r` for different + /// subranges of [0, __n), and using `__s` for preallocated memory. + virtual void + schedule_bulk_chunked(size_t, bulk_item_receiver_proxy&, std::span) noexcept = 0; + + /// Schedule bulk work of size `__n` on parallel scheduler, calling `__r` for each item, and + /// using `__s` for preallocated memory. + virtual void schedule_bulk_unchunked( + size_t, + bulk_item_receiver_proxy&, + std::span) noexcept = 0; + }; + + inline parallel_scheduler_backend::~parallel_scheduler_backend() = default; + } // namespace system_context_replaceability + + namespace __detail { + // Partially implements the _RcvrProxy interface (either receiver_proxy or + // bulk_item_receiver_proxy) in terms of a concrete receiver type _Rcvr. 
+ template + struct __receiver_proxy_base : _RcvrProxy { + public: + using receiver_concept = receiver_t; + + explicit __receiver_proxy_base(_Rcvr rcvr) noexcept + : __rcvr_(static_cast<_Rcvr&&>(rcvr)) { + } + + void set_error(std::exception_ptr&& eptr) noexcept final { + STDEXEC::set_error(std::move(__rcvr_), std::move(eptr)); + } + + void set_stopped() noexcept final { + STDEXEC::set_stopped(std::move(__rcvr_)); + } + + protected: + void __query_env(__type_index __query_id, __type_index __value, void* __dest) + const noexcept final { + if (__query_id == __mtypeid) { + __query(get_stop_token, __value, __dest); + } else if (__query_id == __mtypeid) { + __query(get_allocator, __value, __dest); + } + } + + private: + void __query(get_stop_token_t, __type_index __value_type, void* __dest) const noexcept { + using __stop_token_t = stop_token_of_t>; + if constexpr (std::is_same_v) { + if (__value_type == __mtypeid) { + using __dest_t = std::optional; + *static_cast<__dest_t*>(__dest) = STDEXEC::get_stop_token(STDEXEC::get_env(__rcvr_)); + } + } + } + + void __query(get_allocator_t, __type_index __value_type, void* __dest) const noexcept { + if (__value_type == __mtypeid>) { + using __dest_t = std::optional>; + *static_cast<__dest_t*>(__dest) = any_allocator{ + __with_default(get_allocator, std::allocator())(STDEXEC::get_env(__rcvr_))}; + } + } + + // [[nodiscard]] + // auto query(const get_stop_token_t&) const noexcept -> inplace_stop_token final { + // if constexpr (__callable>) { + // if constexpr (__same_as>, inplace_stop_token>) { + // return get_stop_token(get_env(__rcvr_)); + // } + // } + // return inplace_stop_token{}; // MSVC thinks this is unreachable. :-? 
+ // } + + // [[nodiscard]] + // auto query(const get_allocator_t&) const noexcept -> any_allocator final { + // return any_allocator{ + // __with_default(get_allocator, std::allocator())(get_env(__rcvr_))}; + // } + + // // defined in task_scheduler.cuh: + // [[nodiscard]] + // auto query(const get_scheduler_t& __query) const noexcept -> task_scheduler final; + + public: + _Rcvr __rcvr_; + }; + + template + struct __receiver_proxy + : __receiver_proxy_base<_Rcvr, system_context_replaceability::receiver_proxy> { + using __receiver_proxy_base< + _Rcvr, + system_context_replaceability::receiver_proxy + >::__receiver_proxy_base; + + void set_value() noexcept final { + STDEXEC::set_value(std::move(this->__rcvr_)); + } + }; + + // A receiver type that forwards its completion operations to a _RcvrProxy member held by + // reference (where _RcvrProxy is one of receiver_proxy or bulk_item_receiver_proxy). It + // is also responsible to destroying and, if necessary, deallocating the operation state. 
+ template + struct __proxy_receiver { + using receiver_concept = receiver_t; + using __delete_fn_t = void(void*) noexcept; + + void set_value() noexcept { + auto& __proxy = __rcvr_proxy_; + __delete_fn_(__opstate_storage_); // NB: destroys *this + __proxy.set_value(); + } + + void set_error(std::exception_ptr eptr) noexcept { + auto& __proxy = __rcvr_proxy_; + __delete_fn_(__opstate_storage_); // NB: destroys *this + __proxy.set_error(std::move(eptr)); + } + + void set_stopped() noexcept { + auto& __proxy = __rcvr_proxy_; + __delete_fn_(__opstate_storage_); // NB: destroys *this + __proxy.set_stopped(); + } + + [[nodiscard]] + auto get_env() const noexcept -> env_of_t<_RcvrProxy> { + return STDEXEC::get_env(__rcvr_proxy_); + } + + _RcvrProxy& __rcvr_proxy_; + void* __opstate_storage_; + __delete_fn_t* __delete_fn_; + }; + } // namespace __detail +} // namespace STDEXEC + +STDEXEC_PRAGMA_POP() diff --git a/include/stdexec/__detail/__schedulers.hpp b/include/stdexec/__detail/__schedulers.hpp index 559a8ab5f..c1dc5c99e 100644 --- a/include/stdexec/__detail/__schedulers.hpp +++ b/include/stdexec/__detail/__schedulers.hpp @@ -180,7 +180,7 @@ namespace STDEXEC { using __read_query_t = typename get_completion_scheduler_t::__read_query_t; if constexpr (__callable<__read_query_t, _Sch, const _Env&...>) { - using __sch2_t = __call_result_t<__read_query_t, _Sch, const _Env&...>; + using __sch2_t = __decay_t<__call_result_t<__read_query_t, _Sch, const _Env&...>>; if constexpr (__same_as<_Sch, __sch2_t>) { _Sch __prev = __sch; do { diff --git a/include/stdexec/__detail/__sender_introspection.hpp b/include/stdexec/__detail/__sender_introspection.hpp index 87f1ca4bb..560e86112 100644 --- a/include/stdexec/__detail/__sender_introspection.hpp +++ b/include/stdexec/__detail/__sender_introspection.hpp @@ -19,7 +19,6 @@ #include "__meta.hpp" #include "__tuple.hpp" #include "__type_traits.hpp" -#include "__utility.hpp" #include #include // IWYU pragma: keep for std::terminate diff 
--git a/include/exec/__detail/__system_context_default_impl.hpp b/include/stdexec/__detail/__system_context_default_impl.hpp similarity index 77% rename from include/exec/__detail/__system_context_default_impl.hpp rename to include/stdexec/__detail/__system_context_default_impl.hpp index 60864fcdd..a27ab46ad 100644 --- a/include/exec/__detail/__system_context_default_impl.hpp +++ b/include/stdexec/__detail/__system_context_default_impl.hpp @@ -15,25 +15,20 @@ */ #pragma once +#include "__atomic.hpp" #include "__system_context_replaceability_api.hpp" -#include "../../stdexec/execution.hpp" #if STDEXEC_ENABLE_LIBDISPATCH -# include "../libdispatch_queue.hpp" // IWYU pragma: keep +# include "../../exec/libdispatch_queue.hpp" // IWYU pragma: keep #elif STDEXEC_ENABLE_IO_URING -# include "../linux/io_uring_context.hpp" // IWYU pragma: keep +# include "../../exec/linux/io_uring_context.hpp" // IWYU pragma: keep #elif STDEXEC_ENABLE_WINDOWS_THREAD_POOL -# include "../windows/windows_thread_pool.hpp" // IWYU pragma: keep +# include "../../exec/windows/windows_thread_pool.hpp" // IWYU pragma: keep #else -# include "../static_thread_pool.hpp" // IWYU pragma: keep +# include "../../exec/static_thread_pool.hpp" // IWYU pragma: keep #endif -#include "../../stdexec/__detail/__atomic.hpp" - -namespace exec::__system_context_default_impl { - using system_context_replaceability::receiver; - using system_context_replaceability::bulk_item_receiver; - using system_context_replaceability::parallel_scheduler_backend; +namespace STDEXEC::__system_context_default_impl { using system_context_replaceability::__parallel_scheduler_backend_factory; /// Receiver that calls the callback when the operation completes. @@ -70,16 +65,16 @@ namespace exec::__system_context_default_impl { using receiver_concept = STDEXEC::receiver_t; //! The operation state on the frontend. - receiver* __r_; + STDEXEC::system_context_replaceability::receiver_proxy* __r_; //! 
The parent operation state that we will destroy when we complete. __operation<_Sender>* __op_; void set_value() noexcept { auto __op = __op_; - auto __r = __r_; + auto __rcvr = __r_; __op->__destruct(); // destroys the operation, including `this`. - __r->set_value(); + __rcvr->set_value(); // Note: when calling a completion signal, the parent operation might complete, making the // static storage passed to this operation invalid. Thus, we need to ensure that we are not // using the operation state after the completion signal. @@ -87,21 +82,21 @@ namespace exec::__system_context_default_impl { void set_error(std::exception_ptr __ptr) noexcept { auto __op = __op_; - auto __r = __r_; + auto __rcvr = __r_; __op->__destruct(); // destroys the operation, including `this`. - __r->set_error(__ptr); + __rcvr->set_error(std::move(__ptr)); } void set_stopped() noexcept { auto __op = __op_; - auto __r = __r_; + auto __rcvr = __r_; __op->__destruct(); // destroys the operation, including `this`. - __r->set_stopped(); + __rcvr->set_stopped(); } [[nodiscard]] auto get_env() const noexcept -> decltype(auto) { - auto __o = __r_->try_query(); + auto __o = __r_->try_query(STDEXEC::get_stop_token); STDEXEC::inplace_stop_token __st = __o ? *__o : STDEXEC::inplace_stop_token{}; return STDEXEC::prop{STDEXEC::get_stop_token, __st}; } @@ -133,7 +128,7 @@ namespace exec::__system_context_default_impl { /// Try to construct the operation in the preallocated memory if it fits, otherwise allocate a new operation. 
static auto __construct_maybe_alloc( std::span __storage, - receiver* __completion, + STDEXEC::system_context_replaceability::receiver_proxy* __completion, _Sender __sndr) -> __operation* { __storage = __ensure_alignment(__storage, alignof(__operation)); if (__storage.data() == nullptr || __storage.size() < sizeof(__operation)) { @@ -158,7 +153,10 @@ namespace exec::__system_context_default_impl { } private: - __operation(_Sender __sndr, receiver* __completion, bool __on_heap) + __operation( + _Sender __sndr, + STDEXEC::system_context_replaceability::receiver_proxy* __completion, + bool __on_heap) : __inner_op_(STDEXEC::connect(std::move(__sndr), __recv<_Sender>{__completion, this})) , __on_heap_(__on_heap) { } @@ -170,13 +168,12 @@ namespace exec::__system_context_default_impl { }; template - struct __generic_impl : parallel_scheduler_backend { + struct __generic_impl : STDEXEC::system_context_replaceability::parallel_scheduler_backend { __generic_impl() - : __pool_scheduler_(__pool_.get_scheduler()) - , __available_parallelism_(0) { + : __pool_scheduler_(__pool_.get_scheduler()) { // If the pool exposes the available parallelism, use it to determine the chunk size. if constexpr (__has_available_paralellism<_BaseSchedulerContext>) { - __available_parallelism_ = static_cast(__pool_.available_parallelism()); + __available_parallelism_ = static_cast(__pool_.available_parallelism()); } else { __available_parallelism_ = std::thread::hardware_concurrency(); } @@ -190,39 +187,40 @@ namespace exec::__system_context_default_impl { __pool_scheduler_t __pool_scheduler_; //! The available parallelism of the pool, used to determine the chunk size. //! Use a value of 0 to disable chunking. - uint32_t __available_parallelism_; + size_t __available_parallelism_{}; //! Helper class that maps from a chunk index to the start and end of the chunk. 
struct __chunker { - uint32_t __chunk_size_; - uint32_t __max_size_; + size_t __chunk_size_; + size_t __max_size_; - uint32_t __begin(uint32_t __chunk_index) const noexcept { + [[nodiscard]] + size_t __begin(size_t __chunk_index) const noexcept { return __chunk_index * __chunk_size_; } - uint32_t __end(uint32_t __chunk_index) const noexcept { + [[nodiscard]] + size_t __end(size_t __chunk_index) const noexcept { return (std::min) (__begin(__chunk_index + 1), __max_size_); } }; //! Functor called by the `bulk_chunked` operation; sends a `execute` signal to the frontend. struct __bulk_chunked_functor { - bulk_item_receiver* __r_; + STDEXEC::system_context_replaceability::bulk_item_receiver_proxy* __r_; __chunker __chunker_; - void operator()(unsigned long __idx) const noexcept { - auto __chunk_index = static_cast(__idx); - __r_->execute(__chunker_.__begin(__chunk_index), __chunker_.__end(__chunk_index)); + void operator()(size_t const __idx) const noexcept { + __r_->execute(__chunker_.__begin(__idx), __chunker_.__end(__idx)); } }; //! Functor called by the `bulk_unchunked` operation; sends a `execute` signal to the frontend. 
struct __bulk_unchunked_functor { - bulk_item_receiver* __r_; + STDEXEC::system_context_replaceability::bulk_item_receiver_proxy* __r_; - void operator()(unsigned long __idx) const noexcept { - __r_->execute(static_cast(__idx), static_cast(__idx + 1)); + void operator()(size_t const __idx) const noexcept { + __r_->execute(__idx, __idx + 1); } }; @@ -232,72 +230,75 @@ namespace exec::__system_context_default_impl { using __schedule_bulk_chunked_operation_t = __operation()), STDEXEC::par, - std::declval(), + std::declval(), std::declval<__bulk_chunked_functor>()))>; + using __schedule_bulk_unchunked_operation_t = __operation()), STDEXEC::par, - std::declval(), + std::declval(), std::declval<__bulk_unchunked_functor>()))>; public: - void schedule(std::span __storage, receiver& __r) noexcept override { + void schedule( + STDEXEC::system_context_replaceability::receiver_proxy& __rcvr, + std::span __storage) noexcept override { STDEXEC_TRY { auto __sndr = STDEXEC::schedule(__pool_scheduler_); auto __os = - __schedule_operation_t::__construct_maybe_alloc(__storage, &__r, std::move(__sndr)); + __schedule_operation_t::__construct_maybe_alloc(__storage, &__rcvr, std::move(__sndr)); __os->start(); } STDEXEC_CATCH_ALL { - __r.set_error(std::current_exception()); + __rcvr.set_error(std::current_exception()); } } void schedule_bulk_chunked( - uint32_t __size, - std::span __storage, - bulk_item_receiver& __r) noexcept override { + size_t __size, + STDEXEC::system_context_replaceability::bulk_item_receiver_proxy& __rcvr, + std::span __storage) noexcept override { STDEXEC_TRY { // Determine the chunking size based on the ratio between the given size and the number of workers in our pool. // Aim at having 2 chunks per worker. - uint32_t __chunk_size = (__available_parallelism_ > 0 - && __size > 3 * __available_parallelism_) - ? 
__size / __available_parallelism_ / 2 - : 1; - uint32_t __num_chunks = (__size + __chunk_size - 1) / __chunk_size; + size_t __chunk_size = (__available_parallelism_ > 0 + && __size > 3ul * __available_parallelism_) + ? __size / __available_parallelism_ / 2ul + : 1ul; + size_t __num_chunks = (__size + __chunk_size - 1) / __chunk_size; auto __sndr = STDEXEC::bulk( STDEXEC::schedule(__pool_scheduler_), STDEXEC::par, __num_chunks, __bulk_chunked_functor{ - &__r, __chunker{__chunk_size, __size} + &__rcvr, __chunker{__chunk_size, __size} }); auto __os = __schedule_bulk_chunked_operation_t::__construct_maybe_alloc( - __storage, &__r, std::move(__sndr)); + __storage, &__rcvr, std::move(__sndr)); __os->start(); } STDEXEC_CATCH_ALL { - __r.set_error(std::current_exception()); + __rcvr.set_error(std::current_exception()); } } void schedule_bulk_unchunked( - uint32_t __size, - std::span __storage, - bulk_item_receiver& __r) noexcept override { + size_t __size, + STDEXEC::system_context_replaceability::bulk_item_receiver_proxy& __rcvr, + std::span __storage) noexcept override { STDEXEC_TRY { auto __sndr = STDEXEC::bulk( STDEXEC::schedule(__pool_scheduler_), STDEXEC::par, __size, - __bulk_unchunked_functor{&__r}); + __bulk_unchunked_functor{&__rcvr}); auto __os = __schedule_bulk_unchunked_operation_t::__construct_maybe_alloc( - __storage, &__r, std::move(__sndr)); + __storage, &__rcvr, std::move(__sndr)); __os->start(); } STDEXEC_CATCH_ALL { - __r.set_error(std::current_exception()); + __rcvr.set_error(std::current_exception()); } } }; @@ -314,10 +315,10 @@ namespace exec::__system_context_default_impl { auto __get_current_instance() -> std::shared_ptr<_Interface> { // If we have a valid instance, return it. __acquire_instance_lock(); - auto __r = __instance_; + auto __rcvr = __instance_; __release_instance_lock(); - if (__r) { - return __r; + if (__rcvr) { + return __rcvr; } // Otherwise, create a new instance using the factory. 
@@ -379,7 +380,10 @@ namespace exec::__system_context_default_impl { #endif /// The singleton to hold the `parallel_scheduler_backend` instance. - inline constinit __instance_data + inline constinit __instance_data< + STDEXEC::system_context_replaceability::parallel_scheduler_backend, + __parallel_scheduler_backend_impl + > __parallel_scheduler_backend_singleton{}; -} // namespace exec::__system_context_default_impl +} // namespace STDEXEC::__system_context_default_impl diff --git a/include/exec/__detail/__system_context_default_impl_entry.hpp b/include/stdexec/__detail/__system_context_default_impl_entry.hpp similarity index 90% rename from include/exec/__detail/__system_context_default_impl_entry.hpp rename to include/stdexec/__detail/__system_context_default_impl_entry.hpp index 1d147c3aa..aea9c2dbb 100644 --- a/include/exec/__detail/__system_context_default_impl_entry.hpp +++ b/include/stdexec/__detail/__system_context_default_impl_entry.hpp @@ -1,5 +1,6 @@ /* * Copyright (c) 2024 Lucian Radu Teodorescu + * Copyright (c) 2026 NVIDIA Corporation * * Licensed under the Apache License Version 2.0 with LLVM Exceptions * (the "License"); you may not use this file except in compliance with @@ -19,7 +20,7 @@ // and doxygen don't know that, so we need to include the header that defines it when clang-tidy and // doxygen are invoked. #if defined(STDEXEC_CLANG_TIDY_INVOKED) || defined(STDEXEC_DOXYGEN_INVOKED) -# include "../system_context.hpp" // IWYU pragma: keep +# include "../../exec/system_context.hpp" // IWYU pragma: keep #endif #if !defined(STDEXEC_SYSTEM_CONTEXT_INLINE) @@ -30,7 +31,7 @@ #define __STDEXEC_SYSTEM_CONTEXT_API extern STDEXEC_SYSTEM_CONTEXT_INLINE STDEXEC_ATTRIBUTE(weak) -namespace exec::system_context_replaceability { +namespace STDEXEC::system_context_replaceability { /// Get the backend for the parallel scheduler. /// Users might replace this function. 
@@ -48,4 +49,4 @@ namespace exec::system_context_replaceability { .__set_backend_factory(__new_factory); } -} // namespace exec::system_context_replaceability +} // namespace STDEXEC::system_context_replaceability diff --git a/include/stdexec/__detail/__system_context_replaceability_api.hpp b/include/stdexec/__detail/__system_context_replaceability_api.hpp new file mode 100644 index 000000000..9bba61ca8 --- /dev/null +++ b/include/stdexec/__detail/__system_context_replaceability_api.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2025 Lucian Radu Teodorescu, Lewis Baker + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "__execution_fwd.hpp" +#include "__parallel_scheduler_backend.hpp" + +#include + +namespace STDEXEC::system_context_replaceability { + /// The type of a factory that can create `parallel_scheduler_backend` instances. + /// TODO(ericniebler): NOT TO SPEC. + using __parallel_scheduler_backend_factory = std::shared_ptr (*)(); + + /// Get the backend for the parallel scheduler. + /// Users might replace this function. + auto query_parallel_scheduler_backend() -> std::shared_ptr; + + /// Set a factory for the parallel scheduler backend. + /// Can be used to replace the parallel scheduler at runtime. + /// TODO(ericniebler): NOT TO SPEC. 
+ auto set_parallel_scheduler_backend(__parallel_scheduler_backend_factory __new_factory) + -> __parallel_scheduler_backend_factory; +} // namespace STDEXEC::system_context_replaceability diff --git a/include/stdexec/__detail/__task_scheduler.hpp b/include/stdexec/__detail/__task_scheduler.hpp new file mode 100644 index 000000000..b56dbce03 --- /dev/null +++ b/include/stdexec/__detail/__task_scheduler.hpp @@ -0,0 +1,656 @@ +/* + * Copyright (c) 2026 NVIDIA Corporation + * + * Licensed under the Apache License Version 2.0 with LLVM Exceptions + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "__execution_fwd.hpp" + +// include these after __execution_fwd.hpp +#include "__bulk.hpp" +#include "__concepts.hpp" +#include "__diagnostics.hpp" +#include "__domain.hpp" +#include "__env.hpp" +#include "__inline_scheduler.hpp" +#include "__meta.hpp" +#include "__parallel_scheduler_backend.hpp" +#include "__queries.hpp" +#include "__schedulers.hpp" +#include "__transform_completion_signatures.hpp" +#include "__typeinfo.hpp" +#include "__variant.hpp" // IWYU pragma: keep for __variant_for + +#include + +#include +#include +#include +#include + +namespace STDEXEC { + class task_scheduler; + struct task_scheduler_domain; + + namespace __detail { + // The concrete type-erased sender returned by task_scheduler::schedule() + struct __task_sender; + + template + struct __task_bulk_sender; + + template + struct __task_bulk_state; + + template + struct __task_bulk_receiver; + + struct __task_scheduler_backend : system_context_replaceability::parallel_scheduler_backend { + [[nodiscard]] + virtual auto + query(get_forward_progress_guarantee_t) const noexcept -> forward_progress_guarantee = 0; + virtual auto __equal_to(const void* __other, __type_index __type) -> bool = 0; + }; + + using __backend_ptr_t = std::shared_ptr<__task_scheduler_backend>; + + template + concept __non_task_scheduler = __not_same_as && scheduler<_Sch>; + } // namespace __detail + + struct _CANNOT_DISPATCH_BULK_ALGORITHM_TO_TASK_SCHEDULER_BECAUSE_THERE_IS_NO_TASK_SCHEDULER_IN_THE_ENVIRONMENT; + struct _ADD_A_CONTINUES_ON_TRANSITION_TO_THE_TASK_SCHEDULER_BEFORE_THE_BULK_ALGORITHM; + + struct task_scheduler_domain : default_domain { + template > + requires __one_of<_BulkTag, bulk_chunked_t, bulk_unchunked_t> + [[nodiscard]] + static constexpr auto transform_sender(set_value_t, _Sndr&& __sndr, const _Env& __env) { + using __sched_t = __call_result_or_t< + get_completion_scheduler_t, + __not_a_scheduler<>, + env_of_t<_Sndr>, + const _Env& + >; + + if constexpr 
(!__same_as<__sched_t, task_scheduler>) { + return __not_a_sender< + _WHERE_(_IN_ALGORITHM_, _BulkTag), + _WHAT_<>( + _CANNOT_DISPATCH_BULK_ALGORITHM_TO_TASK_SCHEDULER_BECAUSE_THERE_IS_NO_TASK_SCHEDULER_IN_THE_ENVIRONMENT), + _TO_FIX_THIS_ERROR_( + _ADD_A_CONTINUES_ON_TRANSITION_TO_THE_TASK_SCHEDULER_BEFORE_THE_BULK_ALGORITHM), + _WITH_PRETTY_SENDER_<_Sndr>, + _WITH_ENVIRONMENT_(_Env) + >{}; + } else { + auto __sch = get_completion_scheduler(get_env(__sndr), __env); + return __detail::__task_bulk_sender<_Sndr>{static_cast<_Sndr&&>(__sndr), std::move(__sch)}; + } + } + }; + + //! @brief A type-erased scheduler. + //! + //! The `task_scheduler` struct is implemented in terms of a backend type derived from + //! @c parallel_scheduler_backend, providing a type-erased interface for scheduling tasks. + //! It exposes query functions to retrieve the completion scheduler and domain. + //! + //! @see parallel_scheduler_backend + class task_scheduler { + template + class __backend_for; + + public: + using scheduler_concept = scheduler_t; + + template > + requires __detail::__non_task_scheduler<_Sch> + explicit task_scheduler(_Sch __sch, _Alloc __alloc = {}) + : __backend_( + std::allocate_shared<__backend_for<_Sch, _Alloc>>(__alloc, std::move(__sch), __alloc)) { + } + + [[nodiscard]] + auto schedule() const noexcept -> __detail::__task_sender; + + [[nodiscard]] + bool operator==(const task_scheduler& __rhs) const noexcept = default; + + template <__detail::__non_task_scheduler _Sch> + [[nodiscard]] + auto operator==(const _Sch& __other) const noexcept -> bool { + return __backend_->__equal_to(std::addressof(__other), __mtypeid<_Sch>); + } + + [[nodiscard]] + auto query(get_forward_progress_guarantee_t) const noexcept -> forward_progress_guarantee { + return __backend_->query(get_forward_progress_guarantee); + } + + [[nodiscard]] + auto query(get_completion_scheduler_t) const noexcept -> task_scheduler { + return *this; + } + + [[nodiscard]] + constexpr auto 
query(get_completion_domain_t) const noexcept { + return task_scheduler_domain{}; + } + + private: + template + friend struct __detail::__task_bulk_sender; + friend struct __detail::__task_sender; + + __detail::__backend_ptr_t __backend_; + }; + + namespace __detail { + //! @brief A type-erased opstate returned when connecting the result of + //! task_scheduler::schedule() to a receiver. + template + class __task_opstate_t { + public: + using operation_state_concept = operation_state_t; + + __task_opstate_t(__backend_ptr_t __backend, _Rcvr __rcvr) + : __rcvr_proxy_(std::move(__rcvr)) + , __backend_(std::move(__backend)) { + } + + void start() noexcept { + STDEXEC_TRY { + __backend_->schedule(__rcvr_proxy_, std::span{__storage_}); + } + STDEXEC_CATCH_ALL { + __rcvr_proxy_.set_error(std::current_exception()); + } + } + + private: + __detail::__receiver_proxy<_Rcvr> __rcvr_proxy_; + __backend_ptr_t __backend_; + std::byte __storage_[8 * sizeof(void*)]; + }; + + //! @brief A type-erased sender returned by task_scheduler::schedule(). + struct __task_sender { + using sender_concept = sender_t; + using __completions_t = completion_signatures< + set_value_t(), // + set_error_t(std::exception_ptr), + set_stopped_t() + >; + + explicit __task_sender(task_scheduler __sch) + : __attrs_{std::move(__sch)} { + } + + template + [[nodiscard]] + auto connect(_Rcvr __rcvr) const noexcept -> __task_opstate_t<_Rcvr> { + return __task_opstate_t<_Rcvr>( + get_completion_scheduler(__attrs_).__backend_, std::move(__rcvr)); + } + + template + [[nodiscard]] + static consteval auto get_completion_signatures() noexcept -> __completions_t { + return {}; + } + + [[nodiscard]] + auto get_env() const noexcept -> const __sched_attrs& { + return __attrs_; + } + + private: + __sched_attrs __attrs_; + }; + + //! @brief A receiver used to connect the predecessor of a bulk operation launched by a + //! task_scheduler. Its set_value member stores the predecessor's values in the bulk + //! 
operation state and then starts the bulk operation. + template + struct __task_bulk_receiver { + using receiver_concept = receiver_t; + + template + void set_value(_As&&... __as) noexcept { + STDEXEC_TRY { + // Store the predecessor's values in the bulk operation state. + using __values_t = __decayed_tuple<_As...>; + __state_->__values_.template emplace<__values_t>(static_cast<_As&&>(__as)...); + + // Start the bulk operation. + if constexpr (__same_as<_BulkTag, bulk_chunked_t>) { + __state_->__backend_->schedule_bulk_chunked( + __state_->__shape_, *__state_, std::span{__state_->__storage_}); + } else { + __state_->__backend_->schedule_bulk_unchunked( + __state_->__shape_, *__state_, std::span{__state_->__storage_}); + } + } + STDEXEC_CATCH_ALL { + STDEXEC::set_error(std::move(__state_->__rcvr_), std::current_exception()); + } + } + + template + void set_error(_Error&& __err) noexcept { + STDEXEC::set_error(std::move(__state_->__rcvr_), static_cast<_Error&&>(__err)); + } + + void set_stopped() noexcept { + STDEXEC::set_stopped(std::move(__state_->__rcvr_)); + } + + [[nodiscard]] + auto get_env() const noexcept -> env_of_t<_Rcvr> { + return STDEXEC::get_env(__state_->__rcvr_); + } + + __task_bulk_state<_BulkTag, _Policy, _Fn, _Rcvr, _Values>* __state_; + }; + + //! Returns a visitor (callable) used to invoke the bulk (unchunked) function with the + //! predecessor's values, which are stored in a variant in the bulk operation state. + template + [[nodiscard]] + constexpr auto __get_execute_bulk_fn( + bulk_unchunked_t, + _Fn& __fn, + size_t __shape, + size_t __begin, + size_t) noexcept { + return [=, &__fn](auto& __args) { + constexpr bool __valid_args = !__same_as; + // runtime assert that we never take this path without valid args from the predecessor: + STDEXEC_ASSERT(__valid_args); + + if constexpr (__valid_args) { + // If we are not parallelizing, we need to run all the iterations sequentially. + const size_t __increments = _Parallelize ? 
1 : __shape; + // Precompose the function with the arguments so we don't have to do it every iteration. + auto __precomposed_fn = __apply( + [&](auto&... __as) { + return [&](size_t __i) -> void { + __fn(__i, __as...); + }; + }, + __args); + for (size_t __i = __begin; __i < __begin + __increments; ++__i) { + __precomposed_fn(__i); + } + } + }; + } + + template + struct __apply_bulk_execute { + template + void operator()(_As&... __as) const noexcept(__nothrow_callable<_Fn&, size_t, _As&...>) { + if constexpr (_Parallelize) { + __fn_(__begin_, __end_, __as...); + } else { + // If we are not parallelizing, we need to pass the entire range to the functor. + __fn_(size_t(0), __shape_, __as...); + } + } + + size_t __begin_, __end_, __shape_; + _Fn& __fn_; + }; + + //! Returns a visitor (callable) used to invoke the bulk (chunked) function with the + //! predecessor's values, which are stored in a variant in the bulk operation state. + template + [[nodiscard]] + constexpr auto __get_execute_bulk_fn( + bulk_chunked_t, + _Fn& __fn, + size_t __shape, + size_t __begin, + size_t __end) noexcept { + return [=, &__fn](auto& __args) { + constexpr bool __valid_args = !__same_as; + STDEXEC_ASSERT(__valid_args); + + if constexpr (__valid_args) { + __apply(__apply_bulk_execute<_Parallelize, _Fn>{__begin, __end, __shape, __fn}, __args); + } + }; + } + + //! Stores the state for a bulk operation launched by a task_scheduler. A type-erased + //! reference to this object is passed to either the task_scheduler's + //! schedule_bulk_chunked or schedule_bulk_unchunked methods, which is expected to call + //! execute(begin, end) on it to run the bulk operation. After the bulk operation is + //! complete, set_value is called, which forwards the predecessor's values to the + //! downstream receiver. 
+ template + struct __task_bulk_state + : __detail::__receiver_proxy_base< + _Rcvr, + system_context_replaceability::bulk_item_receiver_proxy + > { + explicit __task_bulk_state(_Rcvr __rcvr, size_t __shape, _Fn __fn, __backend_ptr_t __backend) + : __task_bulk_state::__receiver_proxy_base(std::move(__rcvr)) + , __fn_(std::move(__fn)) + , __shape_(__shape) + , __backend_(std::move(__backend)) { + } + + void set_value() noexcept final { + // Send the stored values to the downstream receiver. + __visit( + [this](auto& __tupl) { + constexpr bool __valid_args = __not_same_as; + // runtime assert that we never take this path without valid args from the predecessor: + STDEXEC_ASSERT(__valid_args); + + if constexpr (__valid_args) { + __apply(STDEXEC::set_value, std::move(__tupl), std::move(this->__rcvr_)); + } + }, + __values_); + } + + //! Actually runs the bulk operation over the specified range. + void execute(size_t __begin, size_t __end) noexcept final { + STDEXEC_TRY { + using __policy_t = std::remove_cvref_t().__get())>; + constexpr bool __parallelize = std::same_as<__policy_t, STDEXEC::parallel_policy> + || std::same_as<__policy_t, STDEXEC::parallel_unsequenced_policy>; + __visit( + __detail::__get_execute_bulk_fn<__parallelize>( + _BulkTag(), __fn_, __shape_, __begin, __end), + __values_); + } + STDEXEC_CATCH_ALL { + STDEXEC::set_error(std::move(this->__rcvr_), std::current_exception()); + } + } + + private: + template + friend struct __task_bulk_receiver; + + _Fn __fn_; + size_t __shape_; + _Values __values_{}; + __backend_ptr_t __backend_; + std::byte __storage_[8 * sizeof(void*)]; + }; + + //////////////////////////////////////////////////////////////////////////////////// + // Operation state for task scheduler bulk operations + template + struct __task_bulk_opstate { + using operation_state_concept = operation_state_t; + + explicit __task_bulk_opstate( + _Sndr&& __sndr, + size_t __shape, + _Fn __fn, + _Rcvr __rcvr, + __backend_ptr_t __backend) + : 
__state_{std::move(__rcvr), __shape, std::move(__fn), std::move(__backend)} + , __opstate1_(STDEXEC::connect(static_cast<_Sndr&&>(__sndr), __rcvr_t{&__state_})) { + } + + void start() noexcept { + STDEXEC::start(__opstate1_); + } + + private: + using __values_t = value_types_of_t< + _Sndr, + __fwd_env_t>, + __decayed_tuple, + __mbind_front_q<__variant_for, __monostate>::__f + >; + using __rcvr_t = __task_bulk_receiver<_BulkTag, _Policy, _Fn, _Rcvr, __values_t>; + using __opstate1_t = connect_result_t<_Sndr, __rcvr_t>; + + __task_bulk_state<_BulkTag, _Policy, _Fn, _Rcvr, __values_t> __state_; + __opstate1_t __opstate1_; + }; + + template + struct __task_bulk_sender { + using sender_concept = sender_t; + + explicit __task_bulk_sender(_Sndr __sndr, task_scheduler __sch) + : __sndr_(std::move(__sndr)) + , __attrs_{std::move(__sch)} { + } + + template + auto connect(_Rcvr __rcvr) && { + auto& [__tag, __data, __child] = __sndr_; + auto& [__pol, __shape, __fn] = __data; + return __task_bulk_opstate< + decltype(__tag), + decltype(__pol), + decltype(__child), + decltype(__fn), + _Rcvr + >{std::move(__child), + static_cast(__shape), + std::move(__fn), + std::move(__rcvr), + std::move(__attrs_.__sched_.__backend_)}; + } + + template + requires __same_as<_Self, __task_bulk_sender> // accept only rvalues. + [[nodiscard]] + static consteval auto get_completion_signatures() { + // This calls get_completion_signatures on the wrapped bulk_[un]chunked sender. We + // call it directly instead of using STDEXEC::get_completion_signatures to avoid + // another trip through transform_sender, which would lead to infinite recursion. 
+ auto __completions = __decay_t<_Sndr>::template get_completion_signatures<_Sndr, _Env>(); + return STDEXEC::__transform_completion_signatures( + __completions, __decay_arguments(), {}, {}, __eptr_completion()); + } + + [[nodiscard]] + auto get_env() const noexcept -> const __sched_attrs& { + return __attrs_; + } + + private: + _Sndr __sndr_; + __sched_attrs __attrs_; + }; + + //! Function called by the `bulk_chunked` operation; calls `execute` on the bulk_item_receiver_proxy. + struct __bulk_chunked_fn { + void operator()(size_t __begin, size_t __end) noexcept { + __rcvr_.execute(__begin, __end); + } + + system_context_replaceability::bulk_item_receiver_proxy& __rcvr_; + }; + + //! Function called by the `bulk_unchunked` operation; calls `execute` on the bulk_item_receiver_proxy. + struct __bulk_unchunked_fn { + void operator()(size_t __idx) noexcept { + __rcvr_.execute(__idx, __idx + 1); + } + + system_context_replaceability::bulk_item_receiver_proxy& __rcvr_; + }; + + template + auto + __emplace_into(std::span __storage, _Alloc& __alloc, _Args&&... __args) -> _Ty& { + using __traits_t = std::allocator_traits<_Alloc>::template rebind_traits<_Ty>; + using __alloc_t = std::allocator_traits<_Alloc>::template rebind_alloc<_Ty>; + __alloc_t __alloc_copy{__alloc}; + + const bool __in_situ = __storage.size() >= sizeof(_Ty); + auto* __ptr = __in_situ ? reinterpret_cast<_Ty*>(__storage.data()) + : __traits_t::allocate(__alloc_copy, 1); + __traits_t::construct(__alloc_copy, __ptr, static_cast<_Args&&>(__args)...); + return *std::launder(__ptr); + } + + template + class __opstate : _Alloc { + public: + using allocator_type = _Alloc; + + explicit __opstate( + _Alloc __alloc, + _Sndr __sndr, + system_context_replaceability::receiver_proxy& __rcvr_proxy, + bool __in_situ) + : _Alloc(std::move(__alloc)) + , __opstate_( + STDEXEC::connect( + std::move(__sndr), + __detail::__proxy_receiver{ + __rcvr_proxy, + this, + __in_situ ? 
__delete_opstate : __delete_opstate})) { + } + __opstate(__opstate&&) = delete; + + void start() noexcept { + STDEXEC::start(__opstate_); + } + + [[nodiscard]] + auto query(get_allocator_t) const noexcept -> const _Alloc& { + return *this; + } + + private: + template + static void __delete_opstate(void* __ptr) noexcept { + using __traits_t = std::allocator_traits<_Alloc>::template rebind_traits<__opstate>; + using __alloc_t = std::allocator_traits<_Alloc>::template rebind_alloc<__opstate>; + auto* __op = static_cast<__opstate*>(__ptr); + __alloc_t __alloc_copy{get_allocator(*__op)}; + + __traits_t::destroy(__alloc_copy, __op); + if constexpr (!_InSitu) { + __traits_t::deallocate(__alloc_copy, __op, 1); + } + } + + using __child_opstate_t = connect_result_t< + _Sndr, + __detail::__proxy_receiver + >; + __child_opstate_t __opstate_; + }; + } // namespace __detail + + [[nodiscard]] + inline auto task_scheduler::schedule() const noexcept -> __detail::__task_sender { + return __detail::__task_sender{*this}; + } + + template + class task_scheduler::__backend_for + : public __detail::__task_scheduler_backend + , _Alloc { + template + friend struct __detail::__proxy_receiver; + + template + void __schedule( + _RcvrProxy& __rcvr_proxy, + _Sndr&& __sndr, + std::span __storage) noexcept { + STDEXEC_TRY { + using __opstate_t = connect_result_t<_Sndr, __detail::__proxy_receiver<_RcvrProxy>>; + const bool __in_situ = __storage.size() >= sizeof(__opstate_t); + _Alloc& __alloc = *this; + auto& __opstate = __detail::__emplace_into<__detail::__opstate<_Alloc, _Sndr>>( + __storage, __alloc, __alloc, static_cast<_Sndr&&>(__sndr), __rcvr_proxy, __in_situ); + STDEXEC::start(__opstate); + } + STDEXEC_CATCH_ALL { + __rcvr_proxy.set_error(std::current_exception()); + } + } + + public: + explicit __backend_for(_Sch __sch, _Alloc __alloc) + : _Alloc(std::move(__alloc)) + , __sch_(std::move(__sch)) { + } + + void schedule( + system_context_replaceability::receiver_proxy& __rcvr_proxy, + 
std::span __storage) noexcept final override { + __schedule(__rcvr_proxy, STDEXEC::schedule(__sch_), __storage); + } + + void schedule_bulk_chunked( + size_t __size, + system_context_replaceability::bulk_item_receiver_proxy& __rcvr_proxy, + std::span __storage) noexcept final { + auto __sndr = STDEXEC::bulk_chunked( + STDEXEC::schedule(__sch_), par, __size, __detail::__bulk_chunked_fn{__rcvr_proxy}); + __schedule(__rcvr_proxy, std::move(__sndr), __storage); + } + + void schedule_bulk_unchunked( + size_t __size, + system_context_replaceability::bulk_item_receiver_proxy& __rcvr_proxy, + std::span __storage) noexcept override { + auto __sndr = STDEXEC::bulk_unchunked( + STDEXEC::schedule(__sch_), par, __size, __detail::__bulk_unchunked_fn{__rcvr_proxy}); + __schedule(__rcvr_proxy, std::move(__sndr), __storage); + } + + [[nodiscard]] + auto + query(get_forward_progress_guarantee_t) const noexcept -> forward_progress_guarantee final { + return get_forward_progress_guarantee(__sch_); + } + + [[nodiscard]] + bool __equal_to(const void* __other, __type_index __type) final { + if (__type == __mtypeid<_Sch>) { + const _Sch& __other_sch = *static_cast(__other); + return __sch_ == __other_sch; + } + return false; + } + + private: + _Sch __sch_; + }; + + // namespace __detail { + // // Implementation of the get_scheduler_t query for __proxy_receiver_impl from + // // parallel_scheduler_backend.cuh + // template + // auto __receiver_proxy_base<_Rcvr, _Proxy>::query(const get_scheduler_t&) const noexcept + // -> task_scheduler { + // if constexpr (__callable>) { + // return task_scheduler{get_scheduler(get_env(__rcvr_))}; + // } else { + // return task_scheduler{inline_scheduler{}}; + // } + // } + // } // namespace __detail +} // namespace STDEXEC diff --git a/include/stdexec/__detail/__typeinfo.hpp b/include/stdexec/__detail/__typeinfo.hpp index 9f2f5026a..44f571f63 100644 --- a/include/stdexec/__detail/__typeinfo.hpp +++ b/include/stdexec/__detail/__typeinfo.hpp @@ -109,7 
+109,7 @@ namespace STDEXEC { // This specialization is what makes __mtypeof< Id > return the type associated with Id. template - requires __same_as + requires __same_as extern __fn_t<__t()))>> *__mtypeof_v<_Index>; } // namespace __detail @@ -119,7 +119,5 @@ namespace STDEXEC { inline constexpr __type_index __mtypeid = __detail::__mtypeid_value<_Ty>::__id; // Sanity check: - static_assert(STDEXEC_IS_SAME(int, __mtypeof<__mtypeid>)); - - constexpr auto __nat_id = __mtypeid<__none_such>; + static_assert(STDEXEC_IS_SAME(void, __mtypeof<__mtypeid>)); } // namespace STDEXEC diff --git a/include/stdexec/__detail/__variant.hpp b/include/stdexec/__detail/__variant.hpp index 3ba5a0a63..28eec6b99 100644 --- a/include/stdexec/__detail/__variant.hpp +++ b/include/stdexec/__detail/__variant.hpp @@ -225,6 +225,21 @@ namespace STDEXEC { using __var::__variant; + struct __visit_t { + // clang-format off + template + STDEXEC_ATTRIBUTE(host, device, always_inline) + auto operator()(_Fn &&__fn, _Variant &&__var, _As &&...__as) const STDEXEC_AUTO_RETURN( + __var.visit( + static_cast<_Fn &&>(__fn), + static_cast<_Variant &&>(__var), + static_cast<_As &&>(__as)...) 
+ ); + // clang-format on + }; + + inline constexpr __visit_t __visit{}; + template using __variant_for = __variant<__indices_for<_Ts...>{}, _Ts...>; diff --git a/include/stdexec/execution.hpp b/include/stdexec/execution.hpp index 23d017adf..d2b95932e 100644 --- a/include/stdexec/execution.hpp +++ b/include/stdexec/execution.hpp @@ -18,61 +18,64 @@ #include "__detail/__execution_fwd.hpp" // include these after __execution_fwd.hpp -#include "__detail/__as_awaitable.hpp" // IWYU pragma: export -#include "__detail/__basic_sender.hpp" // IWYU pragma: export -#include "__detail/__bulk.hpp" // IWYU pragma: export -#include "__detail/__completion_signatures.hpp" // IWYU pragma: export -#include "__detail/__connect_awaitable.hpp" // IWYU pragma: export -#include "__detail/__continues_on.hpp" // IWYU pragma: export -#include "__detail/__cpo.hpp" // IWYU pragma: export -#include "__detail/__debug.hpp" // IWYU pragma: export -#include "__detail/__domain.hpp" // IWYU pragma: export -#include "__detail/__ensure_started.hpp" // IWYU pragma: export -#include "__detail/__env.hpp" // IWYU pragma: export -#include "__detail/__execute.hpp" // IWYU pragma: export -#include "__detail/__execution_legacy.hpp" // IWYU pragma: export -#include "__detail/__inline_scheduler.hpp" // IWYU pragma: export -#include "__detail/__into_variant.hpp" // IWYU pragma: export -#include "__detail/__intrusive_ptr.hpp" // IWYU pragma: export -#include "__detail/__intrusive_slist.hpp" // IWYU pragma: export -#include "__detail/__just.hpp" // IWYU pragma: export -#include "__detail/__let.hpp" // IWYU pragma: export -#include "__detail/__meta.hpp" // IWYU pragma: export -#include "__detail/__on.hpp" // IWYU pragma: export -#include "__detail/__operation_states.hpp" // IWYU pragma: export -#include "__detail/__read_env.hpp" // IWYU pragma: export -#include "__detail/__receiver_adaptor.hpp" // IWYU pragma: export -#include "__detail/__receiver_ref.hpp" // IWYU pragma: export -#include "__detail/__receivers.hpp" 
// IWYU pragma: export -#include "__detail/__run_loop.hpp" // IWYU pragma: export -#include "__detail/__schedule_from.hpp" // IWYU pragma: export -#include "__detail/__schedulers.hpp" // IWYU pragma: export -#include "__detail/__sender_adaptor_closure.hpp" // IWYU pragma: export -#include "__detail/__senders.hpp" // IWYU pragma: export -#include "__detail/__split.hpp" // IWYU pragma: export -#include "__detail/__start_detached.hpp" // IWYU pragma: export -#include "__detail/__starts_on.hpp" // IWYU pragma: export -#include "__detail/__stopped_as_error.hpp" // IWYU pragma: export -#include "__detail/__stopped_as_optional.hpp" // IWYU pragma: export -#include "__detail/__submit.hpp" // IWYU pragma: export -#include "__detail/__sync_wait.hpp" // IWYU pragma: export -#include "__detail/__then.hpp" // IWYU pragma: export -#include "__detail/__transfer_just.hpp" // IWYU pragma: export -#include "__detail/__transform_completion_signatures.hpp" // IWYU pragma: export -#include "__detail/__transform_sender.hpp" // IWYU pragma: export -#include "__detail/__type_traits.hpp" // IWYU pragma: export -#include "__detail/__unstoppable.hpp" // IWYU pragma: export -#include "__detail/__upon_error.hpp" // IWYU pragma: export -#include "__detail/__upon_stopped.hpp" // IWYU pragma: export -#include "__detail/__utility.hpp" // IWYU pragma: export -#include "__detail/__when_all.hpp" // IWYU pragma: export -#include "__detail/__with_awaitable_senders.hpp" // IWYU pragma: export -#include "__detail/__write_env.hpp" // IWYU pragma: export +// IWYU pragma: begin_exports +#include "__detail/__as_awaitable.hpp" +#include "__detail/__basic_sender.hpp" +#include "__detail/__bulk.hpp" +#include "__detail/__completion_signatures.hpp" +#include "__detail/__connect_awaitable.hpp" +#include "__detail/__continues_on.hpp" +#include "__detail/__cpo.hpp" +#include "__detail/__debug.hpp" +#include "__detail/__domain.hpp" +#include "__detail/__ensure_started.hpp" +#include "__detail/__env.hpp" +#include 
"__detail/__execute.hpp" +#include "__detail/__execution_legacy.hpp" +#include "__detail/__inline_scheduler.hpp" +#include "__detail/__into_variant.hpp" +#include "__detail/__intrusive_ptr.hpp" +#include "__detail/__intrusive_slist.hpp" +#include "__detail/__just.hpp" +#include "__detail/__let.hpp" +#include "__detail/__meta.hpp" +#include "__detail/__on.hpp" +#include "__detail/__operation_states.hpp" +#include "__detail/__read_env.hpp" +#include "__detail/__receiver_adaptor.hpp" +#include "__detail/__receiver_ref.hpp" +#include "__detail/__receivers.hpp" +#include "__detail/__run_loop.hpp" +#include "__detail/__schedule_from.hpp" +#include "__detail/__schedulers.hpp" +#include "__detail/__sender_adaptor_closure.hpp" +#include "__detail/__senders.hpp" +#include "__detail/__split.hpp" +#include "__detail/__start_detached.hpp" +#include "__detail/__starts_on.hpp" +#include "__detail/__stopped_as_error.hpp" +#include "__detail/__stopped_as_optional.hpp" +#include "__detail/__submit.hpp" +#include "__detail/__sync_wait.hpp" +#include "__detail/__task_scheduler.hpp" +#include "__detail/__then.hpp" +#include "__detail/__transfer_just.hpp" +#include "__detail/__transform_completion_signatures.hpp" +#include "__detail/__transform_sender.hpp" +#include "__detail/__type_traits.hpp" +#include "__detail/__unstoppable.hpp" +#include "__detail/__upon_error.hpp" +#include "__detail/__upon_stopped.hpp" +#include "__detail/__utility.hpp" +#include "__detail/__when_all.hpp" +#include "__detail/__with_awaitable_senders.hpp" +#include "__detail/__write_env.hpp" -#include "concepts.hpp" // IWYU pragma: export -#include "coroutine.hpp" // IWYU pragma: export -#include "functional.hpp" // IWYU pragma: export -#include "stop_token.hpp" // IWYU pragma: export +#include "concepts.hpp" +#include "coroutine.hpp" +#include "functional.hpp" +#include "stop_token.hpp" +// IWYU pragma: end_exports // For issuing a meaningful diagnostic for the erroneous `snd1 | snd2`. 
template diff --git a/src/system_context/system_context.cpp b/src/system_context/system_context.cpp index d90a12b93..f755ec0b0 100644 --- a/src/system_context/system_context.cpp +++ b/src/system_context/system_context.cpp @@ -15,4 +15,4 @@ */ #define STDEXEC_SYSTEM_CONTEXT_INLINE /*no inline*/ -#include +#include "../../include/stdexec/__detail/__system_context_default_impl_entry.hpp" diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2549fb8e2..a9194d69c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -62,6 +62,7 @@ set(stdexec_test_sources stdexec/algos/other/test_execute.cpp stdexec/detail/test_completion_signatures.cpp stdexec/detail/test_utility.cpp + stdexec/schedulers/test_task_scheduler.cpp stdexec/queries/test_env.cpp stdexec/queries/test_get_forward_progress_guarantee.cpp stdexec/queries/test_forwarding_queries.cpp diff --git a/test/exec/test_system_context.cpp b/test/exec/test_system_context.cpp index a56c5c143..f0467dcab 100644 --- a/test/exec/test_system_context.cpp +++ b/test/exec/test_system_context.cpp @@ -29,7 +29,7 @@ #include namespace ex = STDEXEC; -namespace scr = exec::system_context_replaceability; +namespace scr = ex::system_context_replaceability; TEST_CASE("system_context can return a scheduler", "[types][system_scheduler]") { auto sched = exec::get_parallel_scheduler(); @@ -307,8 +307,8 @@ TEST_CASE( } struct my_parallel_scheduler_backend_impl - : exec::__system_context_default_impl::__parallel_scheduler_backend_impl { - using base_t = exec::__system_context_default_impl::__parallel_scheduler_backend_impl; + : ex::__system_context_default_impl::__parallel_scheduler_backend_impl { + using base_t = ex::__system_context_default_impl::__parallel_scheduler_backend_impl; my_parallel_scheduler_backend_impl() = default; @@ -317,9 +317,9 @@ struct my_parallel_scheduler_backend_impl return count_schedules_; } - void schedule(std::span __s, scr::receiver& __r) noexcept override { + void schedule(scr::receiver_proxy& __r, 
std::span __s) noexcept override { count_schedules_++; - base_t::schedule(__s, __r); + base_t::schedule(__r, __s); } @@ -328,22 +328,23 @@ struct my_parallel_scheduler_backend_impl }; struct my_inline_scheduler_backend_impl : scr::parallel_scheduler_backend { - void schedule(std::span, scr::receiver& r) noexcept override { + void schedule(scr::receiver_proxy& r, std::span) noexcept override { r.set_value(); } - void - schedule_bulk_chunked(uint32_t count, std::span, scr::bulk_item_receiver& r) noexcept - override { + void schedule_bulk_chunked( + size_t count, + scr::bulk_item_receiver_proxy& r, + std::span) noexcept override { r.execute(0, count); r.set_value(); } void schedule_bulk_unchunked( - uint32_t count, - std::span, - scr::bulk_item_receiver& r) noexcept override { - for (uint32_t i = 0; i < count; ++i) + size_t count, + scr::bulk_item_receiver_proxy& r, + std::span) noexcept override { + for (size_t i = 0; i < count; ++i) r.execute(i, i + 1); r.set_value(); } @@ -394,15 +395,14 @@ TEST_CASE( } TEST_CASE("empty environment always returns nullopt for any query", "[types][system_scheduler]") { - struct my_receiver : scr::receiver { - auto __query_env(__uuid, void*) noexcept -> bool override { - return false; + struct my_receiver : scr::receiver_proxy { + void __query_env(ex::__type_index, ex::__type_index, void*) const noexcept override { } void set_value() noexcept override { } - void set_error(std::exception_ptr) noexcept override { + void set_error(std::exception_ptr&&) noexcept override { } void set_stopped() noexcept override { @@ -411,38 +411,39 @@ TEST_CASE("empty environment always returns nullopt for any query", "[types][sys my_receiver rcvr{}; - REQUIRE(rcvr.try_query() == std::nullopt); - REQUIRE(rcvr.try_query() == std::nullopt); - REQUIRE(rcvr.try_query>() == std::nullopt); + REQUIRE(rcvr.try_query(ex::get_stop_token) == std::nullopt); + REQUIRE(rcvr.try_query(ex::get_stop_token) == std::nullopt); + REQUIRE(rcvr.try_query>(ex::get_allocator) == 
std::nullopt); } TEST_CASE("environment with a stop token can expose its stop token", "[types][system_scheduler]") { - struct my_receiver : scr::receiver { - auto __query_env(__uuid uuid, void* dest) noexcept -> bool override { - if ( - uuid - == scr::__runtime_property_helper::__property_identifier) { - *static_cast(dest) = ss.get_token(); - return true; - } - return false; - } - + struct my_receiver : ex::system_context_replaceability::receiver_proxy { void set_value() noexcept override { } - void set_error(std::exception_ptr) noexcept override { + void set_error(std::exception_ptr&&) noexcept override { } void set_stopped() noexcept override { } - STDEXEC::inplace_stop_source ss; + protected: + void __query_env(ex::__type_index query, ex::__type_index value, void* dest) + const noexcept override { + if ( + query == ex::__mtypeid + && value == ex::__mtypeid) { + *static_cast*>(dest) = ss.get_token(); + } + } + + public: + ex::inplace_stop_source ss; }; my_receiver rcvr{}; - auto o1 = rcvr.try_query(); + auto o1 = rcvr.try_query(ex::get_stop_token); REQUIRE(o1.has_value()); REQUIRE(o1.value().stop_requested() == false); REQUIRE(o1.value() == rcvr.ss.get_token()); @@ -450,6 +451,6 @@ TEST_CASE("environment with a stop token can expose its stop token", "[types][sy rcvr.ss.request_stop(); REQUIRE(o1.value().stop_requested() == true); - REQUIRE(rcvr.try_query() == std::nullopt); - REQUIRE(rcvr.try_query>() == std::nullopt); + REQUIRE(rcvr.try_query(ex::get_stop_token) == std::nullopt); + REQUIRE(rcvr.try_query>(ex::get_allocator) == std::nullopt); } diff --git a/test/exec/test_system_context_replaceability.cpp b/test/exec/test_system_context_replaceability.cpp index ffbc785cd..1428a0da6 100644 --- a/test/exec/test_system_context_replaceability.cpp +++ b/test/exec/test_system_context_replaceability.cpp @@ -15,38 +15,38 @@ */ #include -#include +#include #include #include namespace ex = STDEXEC; -namespace scr = exec::system_context_replaceability; +namespace scr = 
ex::system_context_replaceability; namespace { static int count_schedules = 0; struct my_parallel_scheduler_backend_impl - : exec::__system_context_default_impl::__parallel_scheduler_backend_impl { - using base_t = exec::__system_context_default_impl::__parallel_scheduler_backend_impl; + : ex::__system_context_default_impl::__parallel_scheduler_backend_impl { + using base_t = ex::__system_context_default_impl::__parallel_scheduler_backend_impl; my_parallel_scheduler_backend_impl() = default; - void schedule(std::span __s, scr::receiver& __r) noexcept override { + void schedule(scr::receiver_proxy& __r, std::span __s) noexcept override { count_schedules++; - base_t::schedule(__s, __r); + base_t::schedule(__r, __s); } }; } // namespace -namespace exec::system_context_replaceability { +namespace STDEXEC::system_context_replaceability { // Should replace the function defined in __system_context_default_impl.hpp auto query_parallel_scheduler_backend() - -> std::shared_ptr { + -> std::shared_ptr { return std::make_shared(); } -} // namespace exec::system_context_replaceability +} // namespace STDEXEC::system_context_replaceability TEST_CASE( "Check that we are using a replaced system context (with weak linking)", diff --git a/test/stdexec/schedulers/test_task_scheduler.cpp b/test/stdexec/schedulers/test_task_scheduler.cpp new file mode 100644 index 000000000..5d811c68e --- /dev/null +++ b/test/stdexec/schedulers/test_task_scheduler.cpp @@ -0,0 +1,94 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + * + * Licensed under the Apache License, Version 2.0 with LLVM Exceptions (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://llvm.org/LICENSE.txt + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include + +namespace ex = STDEXEC; + +namespace { +TEST_CASE("simple task_scheduler test", "[scheduler][task_scheduler]") +{ + ex::task_scheduler sched{dummy_scheduler{}}; + STATIC_REQUIRE(ex::scheduler); + auto sndr = sched.schedule(); + STATIC_REQUIRE(ex::sender); + auto op = ex::connect(std::move(sndr), expect_value_receiver{}); + ex::start(op); + // The receiver checks that it's called +} + +TEST_CASE("task_scheduler starts work on the correct execution context", "[scheduler][task_scheduler]") +{ + exec::single_thread_context ctx; + ex::task_scheduler sched{ctx.get_scheduler()}; + auto sndr = ex::starts_on(sched, ex::just() | ex::then([] { + return ::std::this_thread::get_id(); + })); + auto [tid] = ex::sync_wait(std::move(sndr)).value(); + CHECK(tid == ctx.get_thread_id()); +} + +static bool g_called = false; + +template +struct protect : private Sndr +{ + using sender_concept = ex::sender_t; + explicit protect(Sndr sndr) + : Sndr{std::move(sndr)} + {} + using Sndr::connect; + using Sndr::get_completion_signatures; + using Sndr::get_env; +}; + +struct test_domain +{ + template Sndr, class Env> + auto transform_sender(ex::set_value_t, Sndr sndr, const Env&) const + { + return ex::then(protect{std::move(sndr)}, []() noexcept { + g_called = true; + }); + } +}; + +TEST_CASE("bulk_chunked dispatches correctly through task_scheduler", "[scheduler][task_scheduler]") +{ + ex::task_scheduler sched{dummy_scheduler{}}; + auto sndr = ex::on(sched, ex::just(-1) | ex::bulk_chunked(ex::par_unseq, 100, [](int, int, int&) 
{})); + g_called = false; + auto [val] = ex::sync_wait(std::move(sndr)).value(); + CHECK(val == -1); + CHECK(g_called); +} + +TEST_CASE("bulk dispatches correctly through task_scheduler", "[scheduler][task_scheduler]") +{ + ex::task_scheduler sched{dummy_scheduler{}}; + auto sndr = ex::on(sched, ex::just(-1) | ex::bulk(ex::par_unseq, 100, [](int, int&) {})); + g_called = false; + auto [val] = ex::sync_wait(std::move(sndr)).value(); + CHECK(val == -1); + CHECK(g_called); +} +} diff --git a/test/test_common/schedulers.hpp b/test/test_common/schedulers.hpp index 0f2953259..c694e2969 100644 --- a/test/test_common/schedulers.hpp +++ b/test/test_common/schedulers.hpp @@ -404,4 +404,78 @@ namespace { } }; }; + + namespace _dummy { + template + struct _attrs_t { + constexpr auto + query(ex::get_completion_scheduler_t) const noexcept; + + constexpr auto + query(ex::get_completion_domain_t) const noexcept { + return Domain{}; + } + }; + + template + struct _opstate_t : ex::__immovable { + using operation_state_concept = ex::operation_state_t; + + constexpr _opstate_t(Rcvr rcvr) noexcept + : _rcvr(static_cast(rcvr)) { + } + + constexpr void start() noexcept { + ex::set_value(static_cast(_rcvr)); + } + + Rcvr _rcvr; + }; + + template + struct _sndr_t { + using sender_concept = ex::sender_t; + + template + static consteval auto get_completion_signatures() noexcept { + return ex::completion_signatures(); + } + + template + constexpr auto connect(Rcvr rcvr) const noexcept -> _opstate_t { + return _opstate_t(static_cast(rcvr)); + } + + [[nodiscard]] + constexpr auto get_env() const noexcept { + return _attrs_t{}; + } + }; + } // namespace _dummy + + //! Scheduler that returns a sender that always completes inline (successfully). 
+ template + struct dummy_scheduler : _dummy::_attrs_t { + using scheduler_concept = ex::scheduler_t; + + static constexpr auto schedule() noexcept -> _dummy::_sndr_t { + return {}; + } + + friend constexpr bool operator==(dummy_scheduler, dummy_scheduler) noexcept { + return true; + } + + friend constexpr bool operator!=(dummy_scheduler, dummy_scheduler) noexcept { + return false; + } + }; + + namespace _dummy { + template + constexpr auto + _attrs_t::query(ex::get_completion_scheduler_t) const noexcept { + return dummy_scheduler{}; + } + } // namespace _dummy } // anonymous namespace