Skip to content

Commit

Permalink
Improve When(Any/All) #115
Browse files Browse the repository at this point in the history
  • Loading branch information
Ri7ay committed Nov 21, 2021
1 parent 2c36e33 commit ff85f64
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 290 deletions.
154 changes: 73 additions & 81 deletions include/yaclib/algo/when_any.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,54 @@ enum class PolicyWhenAny {
namespace detail {

template <typename T, PolicyWhenAny P>
class AnyCombinator : public util::IRef {
public:
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(bool empty = true) {
auto [future, promise] = MakeContract<T>();
if (empty) {
std::move(promise).Set(util::Result<T>{});
return {std::move(future), nullptr};
}
return {std::move(future), new util::Counter<AnyCombinator<T, P>>{std::move(promise)}};
class AnyCombinatorBase {
protected:
std::atomic<size_t> _size{0};
Promise<T> _promise;

AnyCombinatorBase(Promise<T> promise, size_t size = 0) : _size(size), _promise{std::move(promise)} {
}

void Combine(util::Result<T>&& result) {
if (_done.load(std::memory_order_acquire)) {
return;
void CombineError(util::Result<T>&& result) {
if (_size.fetch_sub(1, std::memory_order_acq_rel) == 1) {
std::move(_promise).Set(std::move(result));
}
};
};

if (result) {
if (!_done.exchange(true, std::memory_order_acq_rel)) {
std::move(_promise).Set(std::move(result));
}
} else if (!_error.load(std::memory_order_acquire) && !_error.exchange(true, std::memory_order_acq_rel)) {
_except_error = std::move(result);
}
}
template <typename T>
class AnyCombinatorBase<T, PolicyWhenAny::FirstError> {
protected:
std::atomic<bool> _error{false};
util::Result<T> _except_error;
Promise<T> _promise;

~AnyCombinator() override {
if (!_done.load(std::memory_order_acquire)) {
std::move(_promise).Set(std::move(_except_error));
}
AnyCombinatorBase(Promise<T> promise, size_t) : _promise{std::move(promise)} {
}

private:
explicit AnyCombinator(Promise<T> promise) : _promise{std::move(promise)} {
void CombineError(util::Result<T>&& result) {
if (!_error.load(std::memory_order_acquire) && !_error.exchange(true, std::memory_order_acq_rel)) {
_except_error = std::move(result);
}
}

alignas(kCacheLineSize) std::atomic<bool> _done{false};
alignas(kCacheLineSize) std::atomic<bool> _error{false};
util::Result<T> _except_error;
Promise<T> _promise;
};

template <typename T>
class AnyCombinator<T, PolicyWhenAny::LastError> : public util::IRef {
template <typename T, PolicyWhenAny P>
class AnyCombinator : public BaseCore, public AnyCombinatorBase<T, P> {
public:
static std::pair<Future<T>, util::Ptr<AnyCombinator>> Make(size_t size = 0) {
auto [future, promise] = MakeContract<T>();
if (size == 0) {
std::move(promise).Set(util::Result<T>{});
return {std::move(future), nullptr};
}
return {std::move(future), new util::Counter<AnyCombinator<T, PolicyWhenAny::LastError>>{std::move(promise), size}};
return {std::move(future), new util::Counter<AnyCombinator<T, P>>{std::move(promise), size}};
}

void CallInline(void* context) noexcept final {
if (BaseCore::GetState() != BaseCore::State::HasStop) {
Combine(std::move(static_cast<ResultCore<T>*>(context)->Get()));
}
}

void Combine(util::Result<T>&& result) {
Expand All @@ -77,36 +74,64 @@ class AnyCombinator<T, PolicyWhenAny::LastError> : public util::IRef {

if (result) {
if (!_done.exchange(true, std::memory_order_acq_rel)) {
std::move(_promise).Set(std::move(result));
std::move(AnyCombinatorBase<T, P>::_promise).Set(std::move(result));
}
} else {
AnyCombinatorBase<T, P>::CombineError(std::move(result));
}
}

~AnyCombinator() override {
if constexpr (P == PolicyWhenAny::FirstError) {
if (!_done.load(std::memory_order_acquire)) {
std::move(AnyCombinatorBase<T, P>::_promise).Set(std::move(AnyCombinatorBase<T, P>::_except_error));
}
} else if (_size.fetch_sub(1, std::memory_order_acq_rel) == 1U) {
std::move(_promise).Set(std::move(result));
}
}

private:
explicit AnyCombinator(Promise<T> promise, size_t size = 0) : _size{size}, _promise{std::move(promise)} {
explicit AnyCombinator(Promise<T> promise, size_t size = 0)
: BaseCore{BaseCore::State::Empty}, AnyCombinatorBase<T, P>{std::move(promise), size} {
}

alignas(kCacheLineSize) std::atomic<bool> _done{false};
alignas(kCacheLineSize) std::atomic<size_t> _size;
Promise<T> _promise;
std::atomic<bool> _done{false};

// std::atomic<size_t> _data{0};

// util::Result<T> _except_error;
// Promise<T> _promise;
};

template <typename T, PolicyWhenAny P>
using AnyCombinatorPtr = util::Ptr<AnyCombinator<T, P>>;

template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename T, typename... Fs>
void WhenAnyImpl(detail::AnyCombinatorPtr<T, P>& combinator, Future<T>&& head, Fs&&... tail) {
std::move(head).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
head.GetCore()->SetCallbackInline(combinator);
std::move(head).Detach();
if constexpr (sizeof...(tail) != 0) {
WhenAnyImpl(combinator, std::forward<Fs>(tail)...);
}
}

template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename It,
typename T = util::detail::FutureValueT<typename std::iterator_traits<It>::value_type>, typename Indx>
auto WhenAnyImpl(It iter, Indx begin, Indx end) {
auto [future, combinator] = [&] {
if constexpr (std::is_same_v<It, Indx> && P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(begin != end);
} else {
return detail::AnyCombinator<T, P>::Make(std::distance(begin, end));
}
}();
for (; begin != end; ++begin) {
iter->GetCore()->SetCallbackInline(combinator);
std::move(*iter).Detach();
++iter;
}
return std::move(future);
}

} // namespace detail

/**
Expand All @@ -121,21 +146,7 @@ void WhenAnyImpl(detail::AnyCombinatorPtr<T, P>& combinator, Future<T>&& head, F
template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename It,
typename T = util::detail::FutureValueT<typename std::iterator_traits<It>::value_type>>
auto WhenAny(It begin, size_t size) {
auto [future, combinator] = [&size] {
if constexpr (P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(size == 0);
} else {
return detail::AnyCombinator<T, P>::Make(size);
}
}();
for (size_t i = 0; i != size; ++i) {
std::move(*begin).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
++begin;
}
return std::move(future);
return detail::WhenAnyImpl(begin, size_t{0}, size);
}

/**
Expand All @@ -150,20 +161,7 @@ auto WhenAny(It begin, size_t size) {
template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename It,
typename T = util::detail::FutureValueT<typename std::iterator_traits<It>::value_type>>
auto WhenAny(It begin, It end) {
auto [future, combinator] = [&begin, &end] {
if constexpr (P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(begin == end);
} else {
return detail::AnyCombinator<T, P>::Make(std::distance(begin, end));
}
}();
for (; begin != end; ++begin) {
std::move(*begin).Subscribe([c = combinator](util::Result<T>&& result) mutable {
c->Combine(std::move(result));
c = nullptr;
});
}
return std::move(future);
return detail::WhenAnyImpl(begin, begin, end);
}

/**
Expand All @@ -177,13 +175,7 @@ auto WhenAny(It begin, It end) {
template <PolicyWhenAny P = PolicyWhenAny::FirstError, typename T, typename... Fs>
auto WhenAny(Future<T>&& head, Fs&&... tail) {
static_assert((... && util::IsFutureV<Fs>));
auto [future, combinator] = [] {
if constexpr (P == PolicyWhenAny::FirstError) {
return detail::AnyCombinator<T, P>::Make(false);
} else {
return detail::AnyCombinator<T, P>::Make(sizeof...(Fs) + 1);
}
}();
auto [future, combinator] = detail::AnyCombinator<T, P>::Make(sizeof...(Fs) + 1);
detail::WhenAnyImpl<P>(combinator, std::move(head), std::forward<Fs>(tail)...);
return std::move(future);
}
Expand Down
87 changes: 55 additions & 32 deletions test/unit/algo/when_all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@ enum class TestSuite {
Array,
};

template <typename T>
class WaitAllT : public testing::Test {
public:
using Type = T;
};

using MyTypes = ::testing::Types<int, void>;

TYPED_TEST_SUITE(WaitAllT, MyTypes);

template <TestSuite suite, typename T = int>
void JustWorks() {
constexpr int kSize = 3;
Expand Down Expand Up @@ -73,20 +83,12 @@ void JustWorks() {
}
}

TEST(Vector, JustWorks) {
JustWorks<TestSuite::Vector>();
}

TEST(VoidVector, JustWorks) {
JustWorks<TestSuite::Vector, void>();
}

TEST(Array, JustWorks) {
JustWorks<TestSuite::Array>();
TYPED_TEST(WaitAllT, VectorJustWorks) {
JustWorks<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, JustWorks) {
JustWorks<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayJustWorks) {
JustWorks<TestSuite::Array, typename TestFixture::Type>();
}

template <TestSuite suite, typename T = void>
Expand Down Expand Up @@ -120,20 +122,49 @@ void AllFails() {
EXPECT_THROW(std::move(all).Get().Ok(), std::runtime_error);
}

TEST(Vector, AllFails) {
AllFails<TestSuite::Vector>();
TYPED_TEST(WaitAllT, VectorAllFails) {
AllFails<TestSuite::Vector, typename TestFixture::Type>();
}

TYPED_TEST(WaitAllT, ArrayAllFails) {
AllFails<TestSuite::Array, typename TestFixture::Type>();
}

TEST(Array, AllFails) {
AllFails<TestSuite::Array>();
template <TestSuite suite, typename T = void>
void ErrorFails() {
constexpr int kSize = 3;
std::array<Promise<T>, kSize> promises;
std::array<Future<T>, kSize> futures;
for (int i = 0; i < kSize; ++i) {
auto [f, p] = MakeContract<T>();
futures[i] = std::move(f);
promises[i] = std::move(p);
}

auto all = [&futures] {
if constexpr (suite == TestSuite::Array) {
return WhenAll(std::move(futures[0]), std::move(futures[1]), std::move(futures[2]));
} else {
return WhenAll(futures.begin(), futures.end());
}
}();

EXPECT_FALSE(all.Ready());

std::move(promises[1]).Set(std::error_code{});

EXPECT_TRUE(all.Ready());

// Second error
std::move(promises[1]).Set(std::error_code{});
}

TEST(VoidVector, AllFails) {
AllFails<TestSuite::Vector, void>();
TYPED_TEST(WaitAllT, VectorErrorFails) {
ErrorFails<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, AllFails) {
AllFails<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayErrorFails) {
ErrorFails<TestSuite::Array, typename TestFixture::Type>();
}

template <typename T = int>
Expand Down Expand Up @@ -203,20 +234,12 @@ void MultiThreaded() {
tp->Wait();
}

TEST(Vector, MultiThreaded) {
MultiThreaded<TestSuite::Vector>();
}

TEST(VoidVector, MultiThreaded) {
MultiThreaded<TestSuite::Vector, void>();
}

TEST(Array, MultiThreaded) {
MultiThreaded<TestSuite::Array>();
TYPED_TEST(WaitAllT, VectorMultiThreaded) {
MultiThreaded<TestSuite::Vector, typename TestFixture::Type>();
}

TEST(VoidArray, MultiThreaded) {
MultiThreaded<TestSuite::Array, void>();
TYPED_TEST(WaitAllT, ArrayMultiThreaded) {
MultiThreaded<TestSuite::Array, typename TestFixture::Type>();
}

} // namespace
Loading

0 comments on commit ff85f64

Please sign in to comment.