diff --git a/include/yaclib/algo/when_any.hpp b/include/yaclib/algo/when_any.hpp index d50a73030..6c44a60e4 100644 --- a/include/yaclib/algo/when_any.hpp +++ b/include/yaclib/algo/when_any.hpp @@ -17,15 +17,24 @@ enum class PolicyWhenAny { namespace detail { template -class AnyCombinator : public util::IRef { +class AnyCombinator : public BaseCore { public: - static std::pair, util::Ptr> Make(bool empty = true) { + static std::pair, util::Ptr> Make(size_t size = 0) { auto [future, promise] = MakeContract(); - if (empty) { + if (size == 0) { std::move(promise).Set(util::Result{}); return {std::move(future), nullptr}; } - return {std::move(future), new util::Counter>{std::move(promise)}}; + return { + std::move(future), + new util::Counter>{std::move(promise), size}, + }; + } + + void CallInline(void* context) noexcept final { + if (BaseCore::GetState() != BaseCore::State::HasStop) { + Combine(std::move(static_cast*>(context)->Get())); + } } void Combine(util::Result&& result) { @@ -37,59 +46,38 @@ class AnyCombinator : public util::IRef { if (!_done.exchange(true, std::memory_order_acq_rel)) { std::move(_promise).Set(std::move(result)); } - } else if (!_error.load(std::memory_order_acquire) && !_error.exchange(true, std::memory_order_acq_rel)) { - _except_error = std::move(result); + } else { + if constexpr (P == PolicyWhenAny::FirstError) { + if (!_error.load(std::memory_order_acquire) && !_error.exchange(true, std::memory_order_acq_rel)) { + _except_error = std::move(result); + } + } else if constexpr (P == PolicyWhenAny::LastError) { + if (_size.fetch_sub(1, std::memory_order_acq_rel) == 1U) { + std::move(_promise).Set(std::move(result)); + } + } } } ~AnyCombinator() override { - if (!_done.load(std::memory_order_acquire)) { - std::move(_promise).Set(std::move(_except_error)); + if constexpr (P == PolicyWhenAny::FirstError) { + if (!_done.load(std::memory_order_acquire)) { + std::move(_promise).Set(std::move(_except_error)); + } } } private: - explicit AnyCombinator(Promise promise) : _promise{std::move(promise)} { + explicit AnyCombinator(Promise promise, size_t size = 0) + : BaseCore{BaseCore::State::Empty}, _size(size), _promise{std::move(promise)} { } - alignas(kCacheLineSize) std::atomic _done{false}; - alignas(kCacheLineSize) std::atomic _error{false}; - util::Result _except_error; - Promise _promise; -}; + std::atomic _done{false}; -template -class AnyCombinator : public util::IRef { - public: - static std::pair, util::Ptr> Make(size_t size = 0) { - auto [future, promise] = MakeContract(); - if (size == 0) { - std::move(promise).Set(util::Result{}); - return {std::move(future), nullptr}; - } - return {std::move(future), new util::Counter>{std::move(promise), size}}; - } + std::atomic _size{0}; + std::atomic _error{false}; - void Combine(util::Result&& result) { - if (_done.load(std::memory_order_acquire)) { - return; - } - - if (result) { - if (!_done.exchange(true, std::memory_order_acq_rel)) { - std::move(_promise).Set(std::move(result)); - } - } else if (_size.fetch_sub(1, std::memory_order_acq_rel) == 1U) { - std::move(_promise).Set(std::move(result)); - } - } - - private: - explicit AnyCombinator(Promise promise, size_t size = 0) : _size{size}, _promise{std::move(promise)} { - } - - alignas(kCacheLineSize) std::atomic _done{false}; - alignas(kCacheLineSize) std::atomic _size; + util::Result _except_error; Promise _promise; }; @@ -98,10 +86,8 @@ using AnyCombinatorPtr = util::Ptr>; template void WhenAnyImpl(detail::AnyCombinatorPtr& combinator, Future&& head, Fs&&... tail) { - std::move(head).Subscribe([c = combinator](util::Result&& result) mutable { - c->Combine(std::move(result)); - c = nullptr; - }); + head.GetCore()->SetCallbackInline(combinator); + std::move(head).Detach(); if constexpr (sizeof...(tail) != 0) { WhenAnyImpl(combinator, std::forward(tail)...); } @@ -121,18 +107,10 @@ void WhenAnyImpl(detail::AnyCombinatorPtr& combinator, Future&& head, F template ::value_type>> auto WhenAny(It begin, size_t size) { - auto [future, combinator] = [&size] { - if constexpr (P == PolicyWhenAny::FirstError) { - return detail::AnyCombinator::Make(size == 0); - } else { - return detail::AnyCombinator::Make(size); - } - }(); + auto [future, combinator] = detail::AnyCombinator::Make(size); for (size_t i = 0; i != size; ++i) { - std::move(*begin).Subscribe([c = combinator](util::Result&& result) mutable { - c->Combine(std::move(result)); - c = nullptr; - }); + begin->GetCore()->SetCallbackInline(combinator); + std::move(*begin).Detach(); ++begin; } return std::move(future); @@ -150,20 +128,9 @@ auto WhenAny(It begin, size_t size) { template ::value_type>> auto WhenAny(It begin, It end) { - auto [future, combinator] = [&begin, &end] { - if constexpr (P == PolicyWhenAny::FirstError) { - return detail::AnyCombinator::Make(begin == end); - } else { - return detail::AnyCombinator::Make(std::distance(begin, end)); - } - }(); - for (; begin != end; ++begin) { - std::move(*begin).Subscribe([c = combinator](util::Result&& result) mutable { - c->Combine(std::move(result)); - c = nullptr; - }); - } - return std::move(future); + size_t size = std::distance(begin, end); + auto [future, combinator] = detail::AnyCombinator::Make(size); + return WhenAny(begin, size); } /** @@ -177,13 +144,7 @@ auto WhenAny(It begin, It end) { template auto WhenAny(Future&& head, Fs&&... tail) { static_assert((... && util::IsFutureV)); - auto [future, combinator] = [] { - if constexpr (P == PolicyWhenAny::FirstError) { - return detail::AnyCombinator::Make(false); - } else { - return detail::AnyCombinator::Make(sizeof...(Fs) + 1); - } - }(); + auto [future, combinator] = detail::AnyCombinator::Make(sizeof...(Fs) + 1); detail::WhenAnyImpl

(combinator, std::move(head), std::forward(tail)...); return std::move(future); } diff --git a/test/unit/algo/when_all.cpp b/test/unit/algo/when_all.cpp index 6310c670d..224425ec7 100644 --- a/test/unit/algo/when_all.cpp +++ b/test/unit/algo/when_all.cpp @@ -17,6 +17,16 @@ enum class TestSuite { Array, }; +template +class WaitAllT : public testing::Test { + public: + using Type = T; +}; + +using MyTypes = ::testing::Types; + +TYPED_TEST_SUITE(WaitAllT, MyTypes); + template void JustWorks() { constexpr int kSize = 3; @@ -73,20 +83,12 @@ void JustWorks() { } } -TEST(Vector, JustWorks) { - JustWorks(); -} - -TEST(VoidVector, JustWorks) { - JustWorks(); -} - -TEST(Array, JustWorks) { - JustWorks(); +TYPED_TEST(WaitAllT, VectorJustWorks) { + JustWorks(); } -TEST(VoidArray, JustWorks) { - JustWorks(); +TYPED_TEST(WaitAllT, ArrayJustWorks) { + JustWorks(); } template @@ -120,20 +122,12 @@ void AllFails() { EXPECT_THROW(std::move(all).Get().Ok(), std::runtime_error); } -TEST(Vector, AllFails) { - AllFails(); +TYPED_TEST(WaitAllT, VectorAllFails) { + AllFails(); } -TEST(Array, AllFails) { - AllFails(); -} - -TEST(VoidVector, AllFails) { - AllFails(); -} - -TEST(VoidArray, AllFails) { - AllFails(); +TYPED_TEST(WaitAllT, ArrayAllFails) { + AllFails(); } template @@ -203,20 +197,12 @@ void MultiThreaded() { tp->Wait(); } -TEST(Vector, MultiThreaded) { - MultiThreaded(); -} - -TEST(VoidVector, MultiThreaded) { - MultiThreaded(); -} - -TEST(Array, MultiThreaded) { - MultiThreaded(); +TYPED_TEST(WaitAllT, VectorMultiThreaded) { + MultiThreaded(); } -TEST(VoidArray, MultiThreaded) { - MultiThreaded(); +TYPED_TEST(WaitAllT, ArrayMultiThreaded) { + MultiThreaded(); } } // namespace diff --git a/test/unit/algo/when_any.cpp b/test/unit/algo/when_any.cpp index 9f9df3485..141deaab6 100644 --- a/test/unit/algo/when_any.cpp +++ b/test/unit/algo/when_any.cpp @@ -12,26 +12,45 @@ namespace { using namespace yaclib; using namespace std::chrono_literals; -enum class TestSuite { Vector, Array }; +enum class TestSuite { + Vector, + Array, +}; -template -void JustWorks() { - constexpr int kSize = 3; +template +class WaitAnyT : public testing::Test { + public: + using Type = T; +}; - std::array, kSize> promises; - std::array, kSize> futures; +using MyTypes = ::testing::Types; + +TYPED_TEST_SUITE(WaitAnyT, MyTypes); + +template +auto FillArrays(std::array, kSize>& promises, std::array, kSize>& futures) { for (int i = 0; i < kSize; ++i) { auto [f, p] = MakeContract(); futures[i] = std::move(f); promises[i] = std::move(p); } - auto any = [&futures] { + + return [&futures] { if constexpr (suite == TestSuite::Array) { return WhenAny

(std::move(futures[0]), std::move(futures[1]), std::move(futures[2])); } else { return WhenAny

(futures.begin(), futures.end()); } }(); +} + +template +void JustWorks() { + constexpr int kSize = 3; + + std::array, kSize> promises; + std::array, kSize> futures; + auto any = FillArrays(promises, futures); EXPECT_FALSE(any.Ready()); @@ -49,36 +68,20 @@ void JustWorks() { } } -TEST(VectorFirstError, JustWorks) { - JustWorks(); -} - -TEST(VectorLastError, JustWorks) { - JustWorks(); -} - -TEST(VoidVectorFirstError, JustWorks) { - JustWorks(); -} - -TEST(VoidVectorLastError, JustWorks) { - JustWorks(); +TYPED_TEST(WaitAnyT, VectorFirstErrorJustWorks) { + JustWorks(); } -TEST(ArrayFirstError, JustWorks) { - JustWorks(); +TYPED_TEST(WaitAnyT, VectorLastErrorJustWorks) { + JustWorks(); } -TEST(ArrayLastError, JustWorks) { - JustWorks(); +TYPED_TEST(WaitAnyT, ArrayFirstErrorJustWorks) { + JustWorks(); } -TEST(VoidArrayFirstError, JustWorks) { - JustWorks(); -} - -TEST(VoidArrayLastError, JustWorks) { - JustWorks(); +TYPED_TEST(WaitAnyT, ArrayLastErrorJustWorks) { + JustWorks(); } template @@ -86,19 +89,7 @@ void AllFails() { constexpr int kSize = 3; std::array, kSize> promises; std::array, kSize> futures; - for (int i = 0; i < kSize; ++i) { - auto [f, p] = MakeContract(); - futures[i] = std::move(f); - promises[i] = std::move(p); - } - - auto any = [&futures] { - if constexpr (suite == TestSuite::Array) { - return WhenAny

(std::move(futures[0]), std::move(futures[1]), std::move(futures[2])); - } else { - return WhenAny

(futures.begin(), futures.end()); - } - }(); + auto any = FillArrays(promises, futures); EXPECT_FALSE(any.Ready()); @@ -118,36 +109,20 @@ void AllFails() { } } -TEST(VectorFirstError, AllFails) { - AllFails(); -} - -TEST(VectorLastError, AllFails) { - AllFails(); -} - -TEST(VoidVectorFirstError, AllFails) { - AllFails(); -} - -TEST(VoidVectorLastError, AllFails) { - AllFails(); -} - -TEST(ArrayFirstError, AllFails) { - AllFails(); +TYPED_TEST(WaitAnyT, VectorFirstErrorAllFails) { + AllFails(); } -TEST(ArrayLastError, AllFails) { - AllFails(); +TYPED_TEST(WaitAnyT, VectorLastErrorAllFails) { + AllFails(); } -TEST(VoidArrayFirstError, AllFails) { - AllFails(); +TYPED_TEST(WaitAnyT, ArrayFirstErrorAllFails) { + AllFails(); } -TEST(VoidArrayLastError, AllFails) { - AllFails(); +TYPED_TEST(WaitAnyT, ArrayLastErrorAllFails) { + AllFails(); } template @@ -155,19 +130,7 @@ void ResultWithFails() { constexpr int kSize = 3; std::array, kSize> promises; std::array, kSize> futures; - for (int i = 0; i < kSize; ++i) { - auto [f, p] = MakeContract(); - futures[i] = std::move(f); - promises[i] = std::move(p); - } - - auto any = [&futures] { - if constexpr (suite == TestSuite::Array) { - return WhenAny

(std::move(futures[0]), std::move(futures[1]), std::move(futures[2])); - } else { - return WhenAny

(futures.begin(), futures.end()); - } - }(); + auto any = FillArrays(promises, futures); EXPECT_FALSE(any.Ready()); @@ -188,36 +151,20 @@ void ResultWithFails() { } } -TEST(VectorFirstError, ResultWithFails) { - ResultWithFails(); -} - -TEST(VectorLastError, ResultWithFails) { - ResultWithFails(); +TYPED_TEST(WaitAnyT, VectorFirstErrorResultWithFails) { + AllFails(); } -TEST(VoidVectorFirstError, ResultWithFails) { - ResultWithFails(); +TYPED_TEST(WaitAnyT, VectorLastErrorResultWithFails) { + AllFails(); } -TEST(VoidVectorLastError, ResultWithFails) { - ResultWithFails(); +TYPED_TEST(WaitAnyT, ArrayFirstErrorResultWithFails) { + AllFails(); } -TEST(ArrayFirstError, ResultWithFails) { - ResultWithFails(); -} - -TEST(ArrayLastError, ResultWithFails) { - ResultWithFails(); -} - -TEST(VoidArrayFirstError, ResultWithFails) { - ResultWithFails(); -} - -TEST(VoidArrayLastError, ResultWithFails) { - ResultWithFails(); +TYPED_TEST(WaitAnyT, ArrayLastErrorResultWithFails) { + AllFails(); } template @@ -230,20 +177,12 @@ void EmptyInput() { EXPECT_THROW(std::move(any).Get().Ok(), std::exception); } -TEST(VectorFirstError, EmptyInput) { - EmptyInput(); -} - -TEST(VectorLastError, EmptyInput) { - EmptyInput(); +TYPED_TEST(WaitAnyT, FirstErrorEmptyInput) { + EmptyInput(); } -TEST(VoidVectorFirstError, EmptyInput) { - EmptyInput(); -} - -TEST(VoidVectorLastError, EmptyInput) { - EmptyInput(); +TYPED_TEST(WaitAnyT, LastErrorEmptyInput) { + EmptyInput(); } template @@ -293,36 +232,20 @@ void MultiThreaded() { tp->Wait(); } -TEST(VectorFirstError, MultiThreaded) { - MultiThreaded(); -} - -TEST(VectorLastError, MultiThreaded) { - MultiThreaded(); -} - -TEST(VoidVectorFirstError, MultiThreaded) { - MultiThreaded(); +TYPED_TEST(WaitAnyT, VectorFirstErrorMultiThreaded) { + MultiThreaded(); } -TEST(VoidVectorLastError, MultiThreaded) { - MultiThreaded(); +TYPED_TEST(WaitAnyT, VectorLastErrorMultiThreaded) { + MultiThreaded(); } -TEST(ArrayFirstError, MultiThreaded) { - MultiThreaded(); +TYPED_TEST(WaitAnyT, ArrayFirstErrorMultiThreaded) { + MultiThreaded(); } -TEST(ArrayLastError, MultiThreaded) { - MultiThreaded(); -} - -TEST(VoidArrayFirstError, MultiThreaded) { - MultiThreaded(); -} - -TEST(VoidArrayLastError, MultiThreaded) { - MultiThreaded(); +TYPED_TEST(WaitAnyT, ArrayLastErrorMultiThreaded) { + MultiThreaded(); } template @@ -371,36 +294,20 @@ void TimeTest() { tp->Wait(); } -TEST(VectorFirstError, TimeTest) { - TimeTest(); -} - -TEST(VectorLastError, TimeTest) { - TimeTest(); -} - -TEST(VoidVectorFirstError, TimeTest) { - TimeTest(); +TYPED_TEST(WaitAnyT, VectorFirstErrorTimeTest) { + TimeTest(); } -TEST(VoidVectorLastError, TimeTest) { - TimeTest(); +TYPED_TEST(WaitAnyT, VectorLastErrorTimeTest) { + TimeTest(); } -TEST(ArrayFirstError, TimeTest) { - TimeTest(); +TYPED_TEST(WaitAnyT, ArrayFirstErrorTimeTest) { + TimeTest(); } -TEST(ArrayLastError, TimeTest) { - TimeTest(); -} - -TEST(VoidArrayFirstError, TimeTest) { - TimeTest(); -} - -TEST(VoidArrayLastError, TimeTest) { - TimeTest(); +TYPED_TEST(WaitAnyT, ArrayLastErrorTimeTest) { + TimeTest(); } template @@ -438,20 +345,12 @@ void DefaultPolice() { } } -TEST(VectorFirstError, DefaultPolice) { - DefaultPolice(); -} - -TEST(VoidVectorFirstError, DefaultPolice) { - DefaultPolice(); -} - -TEST(ArrayFirstError, DefaultPolice) { - DefaultPolice(); +TYPED_TEST(WaitAnyT, VectorDefaultPolice) { + DefaultPolice(); } -TEST(VoidArrayFirstError, DefaultPolice) { - DefaultPolice(); +TYPED_TEST(WaitAnyT, ArrayDefaultPolice) { + DefaultPolice(); } } // namespace