Skip to content

Commit

Permalink
Improve When(Any/All) #115
Browse files Browse the repository at this point in the history
  • Loading branch information
Ri7ay committed Nov 18, 2021
1 parent 2c36e33 commit c6f9327
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 229 deletions.
39 changes: 23 additions & 16 deletions include/yaclib/algo/when_any.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ enum class PolicyWhenAny {
namespace detail {

template <typename T, PolicyWhenAny P>
class AnyCombinator : public util::IRef {
class AnyCombinator : public BaseCore {
public:
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(bool empty = true) {
auto [future, promise] = MakeContract<T>();
Expand All @@ -28,6 +28,12 @@ class AnyCombinator : public util::IRef {
return {std::move(future), new util::Counter<AnyCombinator<T, P>>{std::move(promise)}};
}

void CallInline(void* context) noexcept final {
if (BaseCore::GetState() != BaseCore::State::HasStop) {
Combine(std::move(static_cast<ResultCore<T>*>(context)->Get()));
}
}

void Combine(util::Result<T>&& result) {
if (_done.load(std::memory_order_acquire)) {
return;
Expand All @@ -49,7 +55,7 @@ class AnyCombinator : public util::IRef {
}

private:
explicit AnyCombinator(Promise<T> promise) : _promise{std::move(promise)} {
explicit AnyCombinator(Promise<T> promise) : BaseCore{BaseCore::State::Empty}, _promise{std::move(promise)} {
}

alignas(kCacheLineSize) std::atomic<bool> _done{false};
Expand All @@ -59,7 +65,7 @@ class AnyCombinator : public util::IRef {
};

template <typename T>
class AnyCombinator<T, PolicyWhenAny::LastError> : public util::IRef {
class AnyCombinator<T, PolicyWhenAny::LastError> : public BaseCore {
public:
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(size_t size = 0) {
auto [future, promise] = MakeContract<T>();
Expand All @@ -70,6 +76,12 @@ class AnyCombinator<T, PolicyWhenAny::LastError> : public util::IRef {
return {std::move(future), new util::Counter<AnyCombinator<T, PolicyWhenAny::LastError>>{std::move(promise), size}};
}

void CallInline(void* context) noexcept final {
if (BaseCore::GetState() != BaseCore::State::HasStop) {
Combine(std::move(static_cast<ResultCore<T>*>(context)->Get()));
}
}

void Combine(util::Result<T>&& result) {
if (_done.load(std::memory_order_acquire)) {
return;
Expand All @@ -85,7 +97,8 @@ class AnyCombinator<T, PolicyWhenAny::LastError> : public util::IRef {
}

private:
explicit AnyCombinator(Promise<T> promise, size_t size = 0) : _size{size}, _promise{std::move(promise)} {
explicit AnyCombinator(Promise<T> promise, size_t size = 0)
: BaseCore{BaseCore::State::Empty}, _size{size}, _promise{std::move(promise)} {
}

alignas(kCacheLineSize) std::atomic<bool> _done{false};
Expand All @@ -98,10 +111,8 @@ using AnyCombinatorPtr = util::Ptr<AnyCombinator<T, P>>;

template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename T, typename... Fs>
void WhenAnyImpl(detail::AnyCombinatorPtr<T, P>& combinator, Future<T>&& head, Fs&&... tail) {
std::move(head).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
head.GetCore()->SetCallbackInline(combinator);
std::move(head).Detach();
if constexpr (sizeof...(tail) != 0) {
WhenAnyImpl(combinator, std::forward<Fs>(tail)...);
}
Expand Down Expand Up @@ -129,10 +140,8 @@ auto WhenAny(It begin, size_t size) {
}
}();
for (size_t i = 0; i != size; ++i) {
std::move(*begin).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
begin->GetCore()->SetCallbackInline(combinator);
std::move(*begin).Detach();
++begin;
}
return std::move(future);
Expand All @@ -158,10 +167,8 @@ auto WhenAny(It begin, It end) {
}
}();
for (; begin != end; ++begin) {
std::move(*begin).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
begin->GetCore()->SetCallbackInline(combinator);
std::move(*begin).Detach();
}
return std::move(future);
}
Expand Down
58 changes: 22 additions & 36 deletions test/unit/algo/when_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ enum class TestSuite {
Array,
};

template <typename T>
class WaitAllT : public testing::Test {
public:
using Type = T;
};

using MyTypes = ::testing::Types<int, void>;

TYPED_TEST_SUITE(WaitAllT, MyTypes);

template <TestSuite suite, typename T = int>
void JustWorks() {
constexpr int kSize = 3;
Expand Down Expand Up @@ -73,20 +83,12 @@ void JustWorks() {
}
}

TEST(Vector, JustWorks) {
JustWorks<TestSuite::Vector>();
}

TEST(VoidVector, JustWorks) {
JustWorks<TestSuite::Vector, void>();
}

TEST(Array, JustWorks) {
JustWorks<TestSuite::Array>();
TYPED_TEST(WaitAllT, VectorJustWorks) {
JustWorks<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, JustWorks) {
JustWorks<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayJustWorks) {
JustWorks<TestSuite::Array, typename TestFixture::Type>();
}

template <TestSuite suite, typename T = void>
Expand Down Expand Up @@ -120,20 +122,12 @@ void AllFails() {
EXPECT_THROW(std::move(all).Get().Ok(), std::runtime_error);
}

TEST(Vector, AllFails) {
AllFails<TestSuite::Vector>();
TYPED_TEST(WaitAllT, VectorAllFails) {
AllFails<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(Array, AllFails) {
AllFails<TestSuite::Array>();
}

TEST(VoidVector, AllFails) {
AllFails<TestSuite::Vector, void>();
}

TEST(VoidArray, AllFails) {
AllFails<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayAllFails) {
AllFails<TestSuite::Array, typename TestFixture::Type>();
}

template <typename T = int>
Expand Down Expand Up @@ -203,20 +197,12 @@ void MultiThreaded() {
tp->Wait();
}

TEST(Vector, MultiThreaded) {
MultiThreaded<TestSuite::Vector>();
}

TEST(VoidVector, MultiThreaded) {
MultiThreaded<TestSuite::Vector, void>();
}

TEST(Array, MultiThreaded) {
MultiThreaded<TestSuite::Array>();
TYPED_TEST(WaitAllT, VectorMultiThreaded) {
MultiThreaded<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, MultiThreaded) {
MultiThreaded<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayMultiThreaded) {
MultiThreaded<TestSuite::Array, typename TestFixture::Type>();
}

} // namespace
Loading

0 comments on commit c6f9327

Please sign in to comment.