Skip to content

Commit

Permalink
Improve When(Any/All) #115
Browse files Browse the repository at this point in the history
  • Loading branch information
Ri7ay committed Nov 19, 2021
1 parent 2c36e33 commit 17dfc28
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 294 deletions.
123 changes: 42 additions & 81 deletions include/yaclib/algo/when_any.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,24 @@ enum class PolicyWhenAny {
namespace detail {

template <typename T, PolicyWhenAny P>
class AnyCombinator : public util::IRef {
class AnyCombinator : public BaseCore {
public:
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(bool empty = true) {
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(size_t size = 0) {
auto [future, promise] = MakeContract<T>();
if (empty) {
if (size == 0) {
std::move(promise).Set(util::Result<T>{});
return {std::move(future), nullptr};
}
return {std::move(future), new util::Counter<AnyCombinator<T, P>>{std::move(promise)}};
return {
std::move(future),
new util::Counter<AnyCombinator<T, P>>{std::move(promise), size},
};
}

void CallInline(void* context) noexcept final {
if (BaseCore::GetState() != BaseCore::State::HasStop) {
Combine(std::move(static_cast<ResultCore<T>*>(context)->Get()));
}
}

void Combine(util::Result<T>&& result) {
Expand All @@ -37,59 +46,38 @@ class AnyCombinator : public util::IRef {
if (!_done.exchange(true, std::memory_order_acq_rel)) {
std::move(_promise).Set(std::move(result));
}
} else if (!_error.load(std::memory_order_acquire) && !_error.exchange(true, std::memory_order_acq_rel)) {
_except_error = std::move(result);
} else {
if constexpr (P == PolicyWhenAny::FirstError) {
if (!_error.load(std::memory_order_acquire) && !_error.exchange(true, std::memory_order_acq_rel)) {
_except_error = std::move(result);
}
} else if constexpr (P == PolicyWhenAny::LastError) {
if (_size.fetch_sub(1, std::memory_order_acq_rel) == 1U) {
std::move(_promise).Set(std::move(result));
}
}
}
}

~AnyCombinator() override {
if (!_done.load(std::memory_order_acquire)) {
std::move(_promise).Set(std::move(_except_error));
if constexpr (P == PolicyWhenAny::FirstError) {
if (!_done.load(std::memory_order_acquire)) {
std::move(_promise).Set(std::move(_except_error));
}
}
}

private:
explicit AnyCombinator(Promise<T> promise) : _promise{std::move(promise)} {
explicit AnyCombinator(Promise<T> promise, size_t size = 0)
: BaseCore{BaseCore::State::Empty}, _size(size), _promise{std::move(promise)} {
}

alignas(kCacheLineSize) std::atomic<bool> _done{false};
alignas(kCacheLineSize) std::atomic<bool> _error{false};
util::Result<T> _except_error;
Promise<T> _promise;
};
std::atomic<bool> _done{false};

template <typename T>
class AnyCombinator<T, PolicyWhenAny::LastError> : public util::IRef {
public:
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(size_t size = 0) {
auto [future, promise] = MakeContract<T>();
if (size == 0) {
std::move(promise).Set(util::Result<T>{});
return {std::move(future), nullptr};
}
return {std::move(future), new util::Counter<AnyCombinator<T, PolicyWhenAny::LastError>>{std::move(promise), size}};
}
std::atomic<size_t> _size{0};
std::atomic<bool> _error{false};

void Combine(util::Result<T>&& result) {
if (_done.load(std::memory_order_acquire)) {
return;
}

if (result) {
if (!_done.exchange(true, std::memory_order_acq_rel)) {
std::move(_promise).Set(std::move(result));
}
} else if (_size.fetch_sub(1, std::memory_order_acq_rel) == 1U) {
std::move(_promise).Set(std::move(result));
}
}

private:
explicit AnyCombinator(Promise<T> promise, size_t size = 0) : _size{size}, _promise{std::move(promise)} {
}

alignas(kCacheLineSize) std::atomic<bool> _done{false};
alignas(kCacheLineSize) std::atomic<size_t> _size;
util::Result<T> _except_error;
Promise<T> _promise;
};

Expand All @@ -98,10 +86,8 @@ using AnyCombinatorPtr = util::Ptr<AnyCombinator<T, P>>;

template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename T, typename... Fs>
void WhenAnyImpl(detail::AnyCombinatorPtr<T, P>& combinator, Future<T>&& head, Fs&&... tail) {
std::move(head).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
head.GetCore()->SetCallbackInline(combinator);
std::move(head).Detach();
if constexpr (sizeof...(tail) != 0) {
WhenAnyImpl(combinator, std::forward<Fs>(tail)...);
}
Expand All @@ -121,18 +107,10 @@ void WhenAnyImpl(detail::AnyCombinatorPtr<T, P>& combinator, Future<T>&& head, F
template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename It,
typename T = util::detail::FutureValueT<typename std::iterator_traits<It>::value_type>>
auto WhenAny(It begin, size_t size) {
auto [future, combinator] = [&size] {
if constexpr (P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(size == 0);
} else {
return detail::AnyCombinator<T, P>::Make(size);
}
}();
auto [future, combinator] = detail::AnyCombinator<T, P>::Make(size);
for (size_t i = 0; i != size; ++i) {
std::move(*begin).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
begin->GetCore()->SetCallbackInline(combinator);
std::move(*begin).Detach();
++begin;
}
return std::move(future);
Expand All @@ -150,20 +128,9 @@ auto WhenAny(It begin, size_t size) {
template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename It,
typename T = util::detail::FutureValueT<typename std::iterator_traits<It>::value_type>>
auto WhenAny(It begin, It end) {
auto [future, combinator] = [&begin, &end] {
if constexpr (P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(begin == end);
} else {
return detail::AnyCombinator<T, P>::Make(std::distance(begin, end));
}
}();
for (; begin != end; ++begin) {
std::move(*begin).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
}
return std::move(future);
size_t size = std::distance(begin, end);
auto [future, combinator] = detail::AnyCombinator<T, P>::Make(size);
return WhenAny(begin, size);
}

/**
Expand All @@ -177,13 +144,7 @@ auto WhenAny(It begin, It end) {
template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename T, typename... Fs>
auto WhenAny(Future<T>&& head, Fs&&... tail) {
static_assert((... && util::IsFutureV<Fs>));
auto [future, combinator] = [] {
if constexpr (P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(false);
} else {
return detail::AnyCombinator<T, P>::Make(sizeof...(Fs) + 1);
}
}();
auto [future, combinator] = detail::AnyCombinator<T, P>::Make(sizeof...(Fs) + 1);
detail::WhenAnyImpl<P>(combinator, std::move(head), std::forward<Fs>(tail)...);
return std::move(future);
}
Expand Down
58 changes: 22 additions & 36 deletions test/unit/algo/when_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ enum class TestSuite {
Array,
};

template <typename T>
class WaitAllT : public testing::Test {
public:
using Type = T;
};

using MyTypes = ::testing::Types<int, void>;

TYPED_TEST_SUITE(WaitAllT, MyTypes);

template <TestSuite suite, typename T = int>
void JustWorks() {
constexpr int kSize = 3;
Expand Down Expand Up @@ -73,20 +83,12 @@ void JustWorks() {
}
}

TEST(Vector, JustWorks) {
JustWorks<TestSuite::Vector>();
}

TEST(VoidVector, JustWorks) {
JustWorks<TestSuite::Vector, void>();
}

TEST(Array, JustWorks) {
JustWorks<TestSuite::Array>();
TYPED_TEST(WaitAllT, VectorJustWorks) {
JustWorks<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, JustWorks) {
JustWorks<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayJustWorks) {
JustWorks<TestSuite::Array, typename TestFixture::Type>();
}

template <TestSuite suite, typename T = void>
Expand Down Expand Up @@ -120,20 +122,12 @@ void AllFails() {
EXPECT_THROW(std::move(all).Get().Ok(), std::runtime_error);
}

TEST(Vector, AllFails) {
AllFails<TestSuite::Vector>();
TYPED_TEST(WaitAllT, VectorAllFails) {
AllFails<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(Array, AllFails) {
AllFails<TestSuite::Array>();
}

TEST(VoidVector, AllFails) {
AllFails<TestSuite::Vector, void>();
}

TEST(VoidArray, AllFails) {
AllFails<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayAllFails) {
AllFails<TestSuite::Array, typename TestFixture::Type>();
}

template <typename T = int>
Expand Down Expand Up @@ -203,20 +197,12 @@ void MultiThreaded() {
tp->Wait();
}

TEST(Vector, MultiThreaded) {
MultiThreaded<TestSuite::Vector>();
}

TEST(VoidVector, MultiThreaded) {
MultiThreaded<TestSuite::Vector, void>();
}

TEST(Array, MultiThreaded) {
MultiThreaded<TestSuite::Array>();
TYPED_TEST(WaitAllT, VectorMultiThreaded) {
MultiThreaded<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, MultiThreaded) {
MultiThreaded<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayMultiThreaded) {
MultiThreaded<TestSuite::Array, typename TestFixture::Type>();
}

} // namespace

0 comments on commit 17dfc28

Please sign in to comment.