Skip to content

Commit

Permalink
Extend collect any (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
qicosmos committed Mar 15, 2024
1 parent ab311d7 commit 6be48e7
Show file tree
Hide file tree
Showing 5 changed files with 398 additions and 30 deletions.
198 changes: 173 additions & 25 deletions async_simple/coro/Collect.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <vector>
#include "async_simple/Common.h"
#include "async_simple/Try.h"
#include "async_simple/Unit.h"
#include "async_simple/coro/CountEvent.h"
#include "async_simple/coro/Lazy.h"
#include "async_simple/experimental/coroutine.h"
Expand Down Expand Up @@ -83,18 +84,25 @@ struct CollectAnyResult {
#endif
};

template <typename LazyType, typename InAlloc>
template <typename LazyType, typename InAlloc, typename Callback = Unit>
struct CollectAnyAwaiter {
using ValueType = typename LazyType::ValueType;
using ResultType = CollectAnyResult<ValueType>;

CollectAnyAwaiter(std::vector<LazyType, InAlloc>&& input)
: _input(std::move(input)), _result(nullptr) {}

CollectAnyAwaiter(std::vector<LazyType, InAlloc>&& input, Callback callback)
: _input(std::move(input)),
_result(nullptr),
_callback(std::move(callback)) {}

CollectAnyAwaiter(const CollectAnyAwaiter&) = delete;
CollectAnyAwaiter& operator=(const CollectAnyAwaiter&) = delete;
CollectAnyAwaiter(CollectAnyAwaiter&& other)
: _input(std::move(other._input)), _result(std::move(other._result)) {}
: _input(std::move(other._input)),
_result(std::move(other._result)),
_callback(std::move(other._callback)) {}

bool await_ready() const noexcept {
return _input.empty() ||
Expand All @@ -113,6 +121,7 @@ struct CollectAnyAwaiter {
// if any coroutine finishes before this function.
auto result = std::make_shared<ResultType>();
auto event = std::make_shared<detail::CountEvent>(input.size());
auto callback = std::move(_callback);

_result = result;
for (size_t i = 0;
Expand All @@ -122,26 +131,130 @@ struct CollectAnyAwaiter {
input[i]._coro.promise()._executor = executor;
}

input[i].start([i = i, size = input.size(), r = result,
c = continuation,
e = event](Try<ValueType>&& result) mutable {
assert(e != nullptr);
auto count = e->downCount();
if (count == size + 1) {
r->_idx = i;
r->_value = std::move(result);
c.resume();
}
});
if constexpr (std::is_same_v<Callback, Unit>) {
(void)callback;
input[i].start([i, size = input.size(), r = result,
c = continuation,
e = event](Try<ValueType>&& result) mutable {
assert(e != nullptr);
auto count = e->downCount();
if (count == size + 1) {
r->_idx = i;
r->_value = std::move(result);
c.resume();
}
});
} else {
input[i].start([i, size = input.size(), r = result,
c = continuation, e = event,
callback](Try<ValueType>&& result) mutable {
assert(e != nullptr);
auto count = e->downCount();
if (count == size + 1) {
r->_idx = i;
(*callback)(i, std::move(result));
c.resume();
}
});
}
} // end for
}
auto await_resume() {
assert(_result != nullptr);
return std::move(*_result);
if constexpr (std::is_same_v<Callback, Unit>) {
assert(_result != nullptr);
return std::move(*_result);
} else {
return _result->index();
}
}

std::vector<LazyType, InAlloc> _input;
std::shared_ptr<ResultType> _result;
[[no_unique_address]] Callback _callback;
};

template <typename... Ts>
struct CollectAnyVariadicPairAwaiter {
using InputType = std::tuple<Ts...>;

CollectAnyVariadicPairAwaiter(Ts&&... inputs)
: _input(std::move(inputs)...), _result(nullptr) {}

CollectAnyVariadicPairAwaiter(InputType&& inputs)
: _input(std::move(inputs)), _result(nullptr) {}

CollectAnyVariadicPairAwaiter(const CollectAnyVariadicPairAwaiter&) =
delete;
CollectAnyVariadicPairAwaiter& operator=(
const CollectAnyVariadicPairAwaiter&) = delete;
CollectAnyVariadicPairAwaiter(CollectAnyVariadicPairAwaiter&& other)
: _input(std::move(other._input)), _result(std::move(other._result)) {}

bool await_ready() const noexcept {
return _result && _result->has_value();
}

void await_suspend(std::coroutine_handle<> continuation) {
auto promise_type =
std::coroutine_handle<LazyPromiseBase>::from_address(
continuation.address())
.promise();
auto executor = promise_type._executor;
auto event =
std::make_shared<detail::CountEvent>(std::tuple_size<InputType>());
auto result = std::make_shared<std::optional<size_t>>();
_result = result;

auto input = std::move(_input);

[&]<size_t... I>(std::index_sequence<I...>) {
(
[&](auto& lazy, auto& callback) {
if (result->has_value()) {
return;
}

if (!lazy._coro.promise()._executor) {
lazy._coro.promise()._executor = executor;
}

lazy.start([result, event, continuation,
callback](auto&& res) mutable {
auto count = event->downCount();
if (count == std::tuple_size<InputType>() + 1) {
callback(std::move(res));
*result = I;
continuation.resume();
}
});
}(std::get<0>(std::get<I>(input)),
std::get<1>(std::get<I>(input))),
...);
}
(std::make_index_sequence<sizeof...(Ts)>());
}

auto await_resume() {
assert(_result != nullptr);
return std::move(_result->value());
}

std::tuple<Ts...> _input;
std::shared_ptr<std::optional<size_t>> _result;
};

template <typename... Ts>
struct SimpleCollectAnyVariadicPairAwaiter {
using InputType = std::tuple<Ts...>;

InputType _inputs;

SimpleCollectAnyVariadicPairAwaiter(Ts&&... inputs)
: _inputs(std::move(inputs)...) {}

auto coAwait(Executor* ex) {
return CollectAnyVariadicPairAwaiter(std::move(_inputs));
}
};

template <template <typename> typename LazyType, typename... Ts>
Expand Down Expand Up @@ -224,19 +337,29 @@ struct CollectAnyVariadicAwaiter {
std::shared_ptr<std::optional<ResultType>> _result;
};

template <typename T, typename InAlloc>
template <typename T, typename InAlloc, typename Callback = Unit>
struct SimpleCollectAnyAwaitable {
using ValueType = T;
using LazyType = Lazy<T>;
using VectorType = std::vector<LazyType, InAlloc>;

VectorType _input;
[[no_unique_address]] Callback _callback;

SimpleCollectAnyAwaitable(std::vector<LazyType, InAlloc>&& input)
: _input(std::move(input)) {}

SimpleCollectAnyAwaitable(std::vector<LazyType, InAlloc>&& input,
Callback callback)
: _input(std::move(input)), _callback(std::move(callback)) {}

auto coAwait(Executor* ex) {
return CollectAnyAwaiter<LazyType, InAlloc>(std::move(_input));
if constexpr (std::is_same_v<Callback, Unit>) {
return CollectAnyAwaiter<LazyType, InAlloc>(std::move(_input));
} else {
return CollectAnyAwaiter<LazyType, InAlloc, Callback>(
std::move(_input), std::move(_callback));
}
}
};

Expand Down Expand Up @@ -486,15 +609,16 @@ inline auto collectAllVariadicImpl(LazyType<Ts>&&... awaitables) {
}

// collectAny

template <typename T, template <typename> typename LazyType,
typename IAlloc = std::allocator<LazyType<T>>>
inline auto collectAnyImpl(std::vector<LazyType<T>, IAlloc> input) {
using AT =
std::conditional_t<std::is_same_v<LazyType<T>, Lazy<T>>,
detail::SimpleCollectAnyAwaitable<T, IAlloc>,
detail::CollectAnyAwaiter<LazyType<T>, IAlloc>>;
return AT(std::move(input));
typename IAlloc = std::allocator<LazyType<T>>,
typename Callback = Unit>
inline auto collectAnyImpl(std::vector<LazyType<T>, IAlloc> input,
Callback callback = {}) {
using AT = std::conditional_t<
std::is_same_v<LazyType<T>, Lazy<T>>,
detail::SimpleCollectAnyAwaitable<T, IAlloc, Callback>,
detail::CollectAnyAwaiter<LazyType<T>, IAlloc, Callback>>;
return AT(std::move(input), std::move(callback));
}

// collectAnyVariadic
Expand All @@ -507,6 +631,15 @@ inline auto CollectAnyVariadicImpl(LazyType<Ts>&&... inputs) {
return AT(std::move(inputs)...);
}

// collectAnyVariadicPair
template <typename T, typename... Ts>
inline auto CollectAnyVariadicPairImpl(T&& input, Ts&&... inputs) {
using U = std::tuple_element_t<0, std::remove_cvref_t<T>>;
using AT = std::conditional_t<is_lazy<U>::value,
SimpleCollectAnyVariadicPairAwaiter<T, Ts...>,
CollectAnyVariadicPairAwaiter<T, Ts...>>;
return AT(std::move(input), std::move(inputs)...);
}
} // namespace detail

template <typename T, template <typename> typename LazyType,
Expand All @@ -515,12 +648,27 @@ inline auto collectAny(std::vector<LazyType<T>, IAlloc>&& input) {
return detail::collectAnyImpl(std::move(input));
}

template <typename T, template <typename> typename LazyType,
typename IAlloc = std::allocator<LazyType<T>>, typename Callback>
inline auto collectAny(std::vector<LazyType<T>, IAlloc>&& input,
Callback callback) {
auto cb = std::make_shared<Callback>(std::move(callback));
return detail::collectAnyImpl(std::move(input), std::move(cb));
}

template <template <typename> typename LazyType, typename... Ts>
inline auto collectAny(LazyType<Ts>... awaitables) {
static_assert(sizeof...(Ts), "collectAny need at least one param!");
return detail::CollectAnyVariadicImpl(std::move(awaitables)...);
}

// collectAny with std::pair<Lazy, CallbackFunction>
template <typename... Ts>
inline auto collectAny(Ts&&... inputs) {
static_assert(sizeof...(Ts), "collectAny need at least one param!");
return detail::CollectAnyVariadicPairImpl(std::move(inputs)...);
}

// The collectAll() function can be used to co_await on a vector of LazyType
// tasks in **one thread**,and producing a vector of Try values containing each
// of the results.
Expand Down
14 changes: 10 additions & 4 deletions async_simple/coro/Lazy.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,15 @@ struct CollectAllAwaiter;
template <bool Para, template <typename> typename LazyType, typename... Ts>
struct CollectAllVariadicAwaiter;

template <typename LazyType, typename IAlloc>
template <typename LazyType, typename IAlloc, typename Callback>
struct CollectAnyAwaiter;

template <template <typename> typename LazyType, typename... Ts>
struct CollectAnyVariadicAwaiter;

template <typename... Ts>
struct CollectAnyVariadicPairAwaiter;

} // namespace detail

namespace detail {
Expand Down Expand Up @@ -151,8 +154,8 @@ class LazyPromise : public LazyPromiseBase {

template <typename V>
void return_value(V&& value) noexcept(
std::is_nothrow_constructible_v<
T, V&&>) requires std::is_convertible_v<V&&, T> {
std::is_nothrow_constructible_v<T, V&&>) requires
std::is_convertible_v<V&&, T> {
_value.template emplace<T>(std::forward<V>(value));
}
void unhandled_exception() noexcept {
Expand Down Expand Up @@ -387,11 +390,14 @@ class LazyBase {
template <bool, template <typename> typename, typename...>
friend struct detail::CollectAllVariadicAwaiter;

template <typename LazyType, typename IAlloc>
template <typename LazyType, typename IAlloc, typename Callback>
friend struct detail::CollectAnyAwaiter;

template <template <typename> typename LazyType, typename... Ts>
friend struct detail::CollectAnyVariadicAwaiter;

template <typename... Ts>
friend struct detail::CollectAnyVariadicPairAwaiter;
};

} // namespace detail
Expand Down
Loading

0 comments on commit 6be48e7

Please sign in to comment.