Skip to content

Commit

Permalink
move with_awaitable_senders into its own header
Browse files Browse the repository at this point in the history
  • Loading branch information
ericniebler committed May 6, 2024
1 parent f93649f commit c34728f
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 106 deletions.
132 changes: 132 additions & 0 deletions include/stdexec/__detail/__with_awaitable_senders.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright (c) 2021-2024 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "__execution_fwd.hpp"

#include "__as_awaitable.hpp"
#include "__concepts.hpp"

#include <execution>

namespace stdexec {
#if !STDEXEC_STD_NO_COROUTINES()
namespace __was {
template <class _Promise = void>
class __continuation_handle;

template <>
class __continuation_handle<void> {
public:
__continuation_handle() = default;

template <class _Promise>
__continuation_handle(__coro::coroutine_handle<_Promise> __coro) noexcept
: __coro_(__coro) {
if constexpr (requires(_Promise& __promise) { __promise.unhandled_stopped(); }) {
__stopped_callback_ = [](void* __address) noexcept -> __coro::coroutine_handle<> {
// This causes the rest of the coroutine (the part after the co_await
// of the sender) to be skipped and invokes the calling coroutine's
// stopped handler.
return __coro::coroutine_handle<_Promise>::from_address(__address)
.promise()
.unhandled_stopped();
};
}
// If _Promise doesn't implement unhandled_stopped(), then if a "stopped" unwind
// reaches this point, it's considered an unhandled exception and terminate()
// is called.
}

[[nodiscard]]
auto handle() const noexcept -> __coro::coroutine_handle<> {
return __coro_;
}

[[nodiscard]]
auto unhandled_stopped() const noexcept -> __coro::coroutine_handle<> {
return __stopped_callback_(__coro_.address());
}

private:
using __stopped_callback_t = __coro::coroutine_handle<> (*)(void*) noexcept;

__coro::coroutine_handle<> __coro_{};
__stopped_callback_t __stopped_callback_ = [](void*) noexcept -> __coro::coroutine_handle<> {
std::terminate();
};
};

template <class _Promise>
class __continuation_handle {
public:
__continuation_handle() = default;

__continuation_handle(__coro::coroutine_handle<_Promise> __coro) noexcept
: __continuation_{__coro} {
}

auto handle() const noexcept -> __coro::coroutine_handle<_Promise> {
return __coro::coroutine_handle<_Promise>::from_address(__continuation_.handle().address());
}

[[nodiscard]]
auto unhandled_stopped() const noexcept -> __coro::coroutine_handle<> {
return __continuation_.unhandled_stopped();
}

private:
__continuation_handle<> __continuation_{};
};

struct __with_awaitable_senders_base {
template <class _OtherPromise>
void set_continuation(__coro::coroutine_handle<_OtherPromise> __hcoro) noexcept {
static_assert(!__same_as<_OtherPromise, void>);
__continuation_ = __hcoro;
}

void set_continuation(__continuation_handle<> __continuation) noexcept {
__continuation_ = __continuation;
}

[[nodiscard]]
auto continuation() const noexcept -> __continuation_handle<> {
return __continuation_;
}

auto unhandled_stopped() noexcept -> __coro::coroutine_handle<> {
return __continuation_.unhandled_stopped();
}

private:
__continuation_handle<> __continuation_{};
};

template <class _Promise>
struct with_awaitable_senders : __with_awaitable_senders_base {
template <class _Value>
auto await_transform(_Value&& __val) -> __call_result_t<as_awaitable_t, _Value, _Promise&> {
static_assert(derived_from<_Promise, with_awaitable_senders>);
return as_awaitable(static_cast<_Value&&>(__val), static_cast<_Promise&>(*this));
}
};
} // namespace __was

using __was::with_awaitable_senders;
using __was::__continuation_handle;
#endif
} // namespace stdexec
107 changes: 1 addition & 106 deletions include/stdexec/execution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "__detail/__transform_completion_signatures.hpp"
#include "__detail/__type_traits.hpp"
#include "__detail/__utility.hpp"
#include "__detail/__with_awaitable_senders.hpp"

#include "functional.hpp"
#include "concepts.hpp"
Expand Down Expand Up @@ -86,112 +87,6 @@ namespace stdexec {
requires bool { _Predicate(_Tag{}) };
};

#if !STDEXEC_STD_NO_COROUTINES()
namespace __with_awaitable_senders {
template <class _Promise = void>
class __continuation_handle;

template <>
class __continuation_handle<void> {
public:
__continuation_handle() = default;

template <class _Promise>
__continuation_handle(__coro::coroutine_handle<_Promise> __coro) noexcept
: __coro_(__coro) {
if constexpr (requires(_Promise& __promise) { __promise.unhandled_stopped(); }) {
__stopped_callback_ = [](void* __address) noexcept -> __coro::coroutine_handle<> {
// This causes the rest of the coroutine (the part after the co_await
// of the sender) to be skipped and invokes the calling coroutine's
// stopped handler.
return __coro::coroutine_handle<_Promise>::from_address(__address)
.promise()
.unhandled_stopped();
};
}
// If _Promise doesn't implement unhandled_stopped(), then if a "stopped" unwind
// reaches this point, it's considered an unhandled exception and terminate()
// is called.
}

[[nodiscard]]
auto handle() const noexcept -> __coro::coroutine_handle<> {
return __coro_;
}

[[nodiscard]]
auto unhandled_stopped() const noexcept -> __coro::coroutine_handle<> {
return __stopped_callback_(__coro_.address());
}

private:
__coro::coroutine_handle<> __coro_{};
using __stopped_callback_t = __coro::coroutine_handle<> (*)(void*) noexcept;
__stopped_callback_t __stopped_callback_ = [](void*) noexcept -> __coro::coroutine_handle<> {
std::terminate();
};
};

template <class _Promise>
class __continuation_handle {
public:
__continuation_handle() = default;

__continuation_handle(__coro::coroutine_handle<_Promise> __coro) noexcept
: __continuation_{__coro} {
}

auto handle() const noexcept -> __coro::coroutine_handle<_Promise> {
return __coro::coroutine_handle<_Promise>::from_address(__continuation_.handle().address());
}

[[nodiscard]]
auto unhandled_stopped() const noexcept -> __coro::coroutine_handle<> {
return __continuation_.unhandled_stopped();
}

private:
__continuation_handle<> __continuation_{};
};

struct __with_awaitable_senders_base {
template <class _OtherPromise>
void set_continuation(__coro::coroutine_handle<_OtherPromise> __hcoro) noexcept {
static_assert(!std::is_void_v<_OtherPromise>);
__continuation_ = __hcoro;
}

void set_continuation(__continuation_handle<> __continuation) noexcept {
__continuation_ = __continuation;
}

[[nodiscard]]
auto continuation() const noexcept -> __continuation_handle<> {
return __continuation_;
}

auto unhandled_stopped() noexcept -> __coro::coroutine_handle<> {
return __continuation_.unhandled_stopped();
}

private:
__continuation_handle<> __continuation_{};
};

template <class _Promise>
struct with_awaitable_senders : __with_awaitable_senders_base {
template <class _Value>
auto await_transform(_Value&& __val) -> __call_result_t<as_awaitable_t, _Value, _Promise&> {
static_assert(derived_from<_Promise, with_awaitable_senders>);
return as_awaitable(static_cast<_Value&&>(__val), static_cast<_Promise&>(*this));
}
};
} // namespace __with_awaitable_senders

using __with_awaitable_senders::with_awaitable_senders;
using __with_awaitable_senders::__continuation_handle;
#endif

namespace {
inline constexpr auto __ref = []<class _Ty>(_Ty& __ty) noexcept {
return [__ty = &__ty]() noexcept -> decltype(auto) {
Expand Down

0 comments on commit c34728f

Please sign in to comment.