Skip to content

Commit

Permalink
use __tup::__cat_apply instead of std::tuple_cat and std::apply
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler committed May 26, 2024
1 parent 61f3d08 commit af93d4b
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 151 deletions.
45 changes: 26 additions & 19 deletions include/nvexec/detail/variant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -107,23 +107,25 @@ namespace nvexec {

template <class VisitorT, class V>
STDEXEC_ATTRIBUTE((host, device))
void visit_impl(
std::integral_constant<std::size_t, 0>,
VisitorT&& visitor,
V&& v,
std::size_t index) {
void
visit_impl(
std::integral_constant<std::size_t, 0>,
VisitorT&& visitor,
V&& v,
std::size_t index) {
if (0 == index) {
static_cast<VisitorT&&>(visitor)((static_cast<V&&>(v)).template get<0>());
}
}

template <std::size_t I, class VisitorT, class V>
STDEXEC_ATTRIBUTE((host, device))
void visit_impl(
std::integral_constant<std::size_t, I>,
VisitorT&& visitor,
V&& v,
std::size_t index) {
void
visit_impl(
std::integral_constant<std::size_t, I>,
VisitorT&& visitor,
V&& v,
std::size_t index) {
if (I == index) {
static_cast<VisitorT&&>(visitor)((static_cast<V&&>(v)).template get<I>());
return;
Expand All @@ -139,7 +141,8 @@ namespace nvexec {

template <class VisitorT, class V>
STDEXEC_ATTRIBUTE((host, device))
void visit(VisitorT&& visitor, V&& v) {
void
visit(VisitorT&& visitor, V&& v) {
detail::visit_impl(
std::integral_constant<std::size_t, stdexec::__decay_t<V>::size - 1>{},
static_cast<VisitorT&&>(visitor),
Expand All @@ -149,7 +152,8 @@ namespace nvexec {

template <class VisitorT, class V>
STDEXEC_ATTRIBUTE((host, device))
void visit(VisitorT&& visitor, V&& v, std::size_t index) {
void
visit(VisitorT&& visitor, V&& v, std::size_t index) {
detail::visit_impl(
std::integral_constant<std::size_t, stdexec::__decay_t<V>::size - 1>{},
static_cast<VisitorT&&>(visitor),
Expand Down Expand Up @@ -180,7 +184,8 @@ namespace nvexec {

template <std::size_t I>
STDEXEC_ATTRIBUTE((host, device))
detail::nth_type<I, Ts...>& get() noexcept {
detail::nth_type<I, Ts...>&
get() noexcept {
return get<detail::nth_type<I, Ts...>>();
}

Expand All @@ -197,28 +202,30 @@ namespace nvexec {
}

STDEXEC_ATTRIBUTE((host, device))

bool holds_alternative() const {
bool
holds_alternative() const {
return index_ != detail::npos<index_t>();
}

template <detail::one_of<Ts...> T, class... As>
STDEXEC_ATTRIBUTE((host, device))
void emplace(As&&... as) {
void
emplace(As&&... as) {
destroy();
construct<T>(static_cast<As&&>(as)...);
}

template <detail::one_of<Ts...> T, class... As>
STDEXEC_ATTRIBUTE((host, device))
void construct(As&&... as) {
void
construct(As&&... as) {
::new (storage_.data_) T(static_cast<As&&>(as)...);
index_ = index_of<T>();
}

STDEXEC_ATTRIBUTE((host, device))

void destroy() {
void
destroy() {
if (holds_alternative()) {
visit(
[](auto& val) noexcept {
Expand Down
29 changes: 15 additions & 14 deletions include/nvexec/stream/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ namespace nvexec {
std::pmr::memory_resource* managed_resource,
stream_pools_t* stream_pools,
queue::task_hub_t* hub,
stream_priority priority = stream_priority::normal)
stream_priority priority = stream_priority::normal) noexcept
: pinned_resource_(pinned_resource)
, managed_resource_(managed_resource)
, stream_pools_(stream_pools)
Expand Down Expand Up @@ -273,9 +273,6 @@ namespace nvexec {
}
};

template <class... Ts>
using decayed_tuple = ::cuda::std::tuple<__decay_t<Ts>...>;

struct set_noop {
template <class... Ts>
STDEXEC_ATTRIBUTE((host, device))
Expand All @@ -287,8 +284,7 @@ namespace nvexec {
};

template <class... Ts>
using _nullable_variant_t =
variant_t<::cuda::std::tuple<set_noop>, Ts...>;
using _nullable_variant_t = variant_t<::cuda::std::tuple<set_noop>, Ts...>;

template <class... Ts>
using decayed_tuple = ::cuda::std::tuple<__decay_t<Ts>...>;
Expand Down Expand Up @@ -385,20 +381,24 @@ namespace nvexec {

template <class... As>
STDEXEC_ATTRIBUTE((host, device))
void set_value(As&&... as) noexcept {
variant_->template emplace<decayed_tuple<set_value_t, As...>>(set_value_t(), static_cast<As&&>(as)...);
void
set_value(As&&... as) noexcept {
variant_->template emplace<decayed_tuple<set_value_t, As...>>(
set_value_t(), static_cast<As&&>(as)...);
producer_(task_);
}

STDEXEC_ATTRIBUTE((host, device))
void set_stopped() noexcept {
void
set_stopped() noexcept {
variant_->template emplace<decayed_tuple<set_stopped_t>>(set_stopped_t());
producer_(task_);
}

template <class Error>
STDEXEC_ATTRIBUTE((host, device))
void set_error(Error&& err) noexcept {
void
set_error(Error&& err) noexcept {
if constexpr (__decays_to<Error, std::exception_ptr>) {
// What is `exception_ptr` but death pending
variant_->template emplace<decayed_tuple<set_error_t, cudaError_t>>(
Expand Down Expand Up @@ -601,13 +601,14 @@ namespace nvexec {
operation_state_base_t<OuterReceiverId>& operation_state_;

template <class... _Args>
void set_value(_Args &&...__args) noexcept {
operation_state_.propagate_completion_signal(set_value_t(), static_cast<_Args &&>(__args)...);
void set_value(_Args&&... __args) noexcept {
operation_state_.propagate_completion_signal(
set_value_t(), static_cast<_Args&&>(__args)...);
}

template <class _Error>
void set_error(_Error &&__err) noexcept {
operation_state_.propagate_completion_signal(set_error_t(), static_cast<_Error &&>(__err));
void set_error(_Error&& __err) noexcept {
operation_state_.propagate_completion_signal(set_error_t(), static_cast<_Error&&>(__err));
}

void set_stopped() noexcept {
Expand Down
4 changes: 2 additions & 2 deletions include/nvexec/stream/split.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -327,10 +327,10 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
template <__decays_to<__t> Self, receiver Receiver>
requires receiver_of<Receiver, completion_signatures_of_t<Self, empty_env>>
STDEXEC_MEMFN_DECL(
auto connect)(this Self&& self, Receiver recvr) //
auto connect)(this Self&& self, Receiver rcvr) //
noexcept(__nothrow_constructible_from<__decay_t<Receiver>, Receiver>)
-> operation_t<Receiver> {
return operation_t<Receiver>{static_cast<Receiver&&>(recvr), self.shared_state_};
return operation_t<Receiver>{static_cast<Receiver&&>(rcvr), self.shared_state_};
}

auto get_env() const noexcept -> env_of_t<const Sender&> {
Expand Down
Loading

0 comments on commit af93d4b

Please sign in to comment.