Skip to content

Commit

Permalink
NativePromise::all should handle NativePromise with ResolveValueT==void
Browse files Browse the repository at this point in the history
https://bugs.webkit.org/show_bug.cgi?id=264082
rdar://117837167

Reviewed by Youenn Fablet.

Also fix bug 264122 as fly-by

Added API tests.

* Source/WTF/wtf/NativePromise.h:
* Tools/TestWebKitAPI/Tests/WTF/NativePromise.cpp:
(TestWebKitAPI::TEST):

Canonical link: https://commits.webkit.org/270344@main
  • Loading branch information
jyavenard committed Nov 7, 2023
1 parent 92399f9 commit cc6dee3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 22 deletions.
59 changes: 37 additions & 22 deletions Source/WTF/wtf/NativePromise.h
Original file line number Diff line number Diff line change
Expand Up @@ -429,11 +429,13 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
return p;
}

using AllPromiseType = NativePromise<Vector<ResolveValueType>, RejectValueType, options>;
using AllPromiseType = NativePromise<std::conditional_t<std::is_void_v<ResolveValueType>, void, Vector<ResolveValueType>>, RejectValueType, options>;
using AllSettledPromiseType = NativePromise<Vector<Result>, bool, options>;

private:
friend class Producer;
struct VoidPlaceholder {
};

template<typename SettleValueType>
inline void settleImpl(SettleValueType&& result, Locker<Lock>& lock)
Expand Down Expand Up @@ -517,7 +519,8 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
, m_outstandingPromises(dependentPromisesCount)
{
ASSERT(dependentPromisesCount);
m_resolveValues.resize(dependentPromisesCount);
if constexpr (!std::is_void_v<ResolveValueT>)
m_resolveValues.resize(dependentPromisesCount);
}

template<typename ResolveValueType_>
Expand All @@ -528,11 +531,16 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
return;
}

m_resolveValues[index] = std::forward<ResolveValueType_>(resolveValue);
if constexpr (!std::is_void_v<ResolveValueT>)
m_resolveValues[index] = std::forward<ResolveValueType_>(resolveValue);
if (!--m_outstandingPromises) {
m_producer->resolve(WTF::map(std::exchange(m_resolveValues, { }), [](auto&& resolveValue) {
return WTFMove(*resolveValue);
}));
if constexpr (std::is_void_v<ResolveValueT>)
m_producer->resolve();
else {
m_producer->resolve(WTF::map(std::exchange(m_resolveValues, { }), [](auto&& resolveValue) {
return WTFMove(*resolveValue);
}));
}
m_producer = nullptr;
}
}
Expand All @@ -546,13 +554,14 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
}
m_producer->reject(std::forward<RejectValueType_>(rejectValue));
m_producer = nullptr;
m_resolveValues.clear();
if constexpr (!std::is_void_v<ResolveValueT>)
m_resolveValues.clear();
}

Ref<AllPromiseType> promise() { return static_cast<Ref<AllPromiseType>>(*m_producer); }

private:
Vector<std::optional<ResolveValueType>> m_resolveValues;
NO_UNIQUE_ADDRESS std::conditional_t<!std::is_void_v<ResolveValueT>, Vector<std::optional<ResolveValueType>>, VoidPlaceholder> m_resolveValues;
std::unique_ptr<typename AllPromiseType::Producer> m_producer;
size_t m_outstandingPromises;
};
Expand Down Expand Up @@ -593,27 +602,33 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative

public:
template <class Dispatcher>
static Ref<AllPromiseType> all(Dispatcher& targetQueue, Vector<Ref<NativePromise>>& promises)
static Ref<AllPromiseType> all(Dispatcher& targetQueue, const Vector<Ref<NativePromise>>& promises)
{
static_assert(LooksLikeRCSerialDispatcher<typename RemoveSmartPointer<Dispatcher>::type>::value, "Must be used with a RefCounted SerialFunctionDispatcher");
if (promises.isEmpty())
return AllPromiseType::createAndResolve(Vector<ResolveValueType>());

if (promises.isEmpty()) {
if constexpr (std::is_void_v<ResolveValueT>)
return AllPromiseType::createAndResolve();
else
return AllPromiseType::createAndResolve(typename AllPromiseType::ResolveValueType());
}
auto producer = adoptRef(new AllPromiseProducer(promises.size()));
auto promise = producer->promise();
for (size_t i = 0; i < promises.size(); ++i) {
promises[i]->whenSettled(targetQueue, [producer, i] (ResultParam result) {
if (result)
producer->resolve(i, maybeMove(result.value()));
else
if (result) {
if constexpr (std::is_void_v<ResolveValueT>)
producer->resolve(i, VoidPlaceholder());
else
producer->resolve(i, maybeMove(result.value()));
} else
producer->reject(maybeMove(result.error()));
});
}
return promise;
}

template <class Dispatcher>
static Ref<AllSettledPromiseType> allSettled(Dispatcher& targetQueue, Vector<Ref<NativePromise>>& promises)
static Ref<AllSettledPromiseType> allSettled(Dispatcher& targetQueue, const Vector<Ref<NativePromise>>& promises)
{
static_assert(LooksLikeRCSerialDispatcher<typename RemoveSmartPointer<Dispatcher>::type>::value, "Must be used with a RefCounted SerialFunctionDispatcher");
if (promises.isEmpty())
Expand Down Expand Up @@ -768,8 +783,8 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
}
#endif

NO_UNIQUE_ADDRESS std::conditional_t<IsChaining, Lock, std::monostate> m_lock;
NO_UNIQUE_ADDRESS std::conditional_t<IsChaining, std::unique_ptr<typename ReturnPromiseType::Producer>, std::monostate> m_completionProducer WTF_GUARDED_BY_LOCK(m_lock);
NO_UNIQUE_ADDRESS std::conditional_t<IsChaining, Lock, VoidPlaceholder> m_lock;
NO_UNIQUE_ADDRESS std::conditional_t<IsChaining, std::unique_ptr<typename ReturnPromiseType::Producer>, VoidPlaceholder> m_completionProducer WTF_GUARDED_BY_LOCK(m_lock);
private:
CallBackType m_settleFunction WTF_GUARDED_BY_CAPABILITY(*ThenCallbackBase::m_targetQueue);
};
Expand Down Expand Up @@ -949,7 +964,7 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
using DispatcherRealType = typename RemoveSmartPointer<DispatcherType>::type;
static_assert(LooksLikeRCSerialDispatcher<DispatcherRealType>::value, "Must be used with a RefCounted SerialFunctionDispatcher");

using R1 = typename RemoveSmartPointer<decltype(invokeWithVoidOrWithArg(std::forward<ResolveFunction>(resolveFunction), std::declval<std::conditional_t<std::is_void_v<ResolveValueT>, std::nullptr_t, ResolveValueT>>()))>::type;
using R1 = typename RemoveSmartPointer<decltype(invokeWithVoidOrWithArg(std::forward<ResolveFunction>(resolveFunction), std::declval<std::conditional_t<std::is_void_v<ResolveValueT>, VoidPlaceholder, ResolveValueType>>()))>::type;
using R2 = typename RemoveSmartPointer<decltype(invokeWithVoidOrWithArg(std::forward<RejectFunction>(rejectFunction), std::declval<RejectValueT>()))>::type;
using IsChaining = std::bool_constant<RelatedNativePromise<R1, R2>>;
static_assert(IsChaining::value || (std::is_void_v<R1> && std::is_void_v<R2>), "resolve/reject methods must return a promise of the same type or nothing");
Expand All @@ -958,7 +973,7 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
return whenSettled(targetQueue, [resolveFunction = std::forward<ResolveFunction>(resolveFunction), rejectFunction = std::forward<RejectFunction>(rejectFunction)] (ResultParam result) mutable -> LambdaReturnType {
if (result) {
if constexpr (std::is_void_v<ResolveValueT>)
return invokeWithVoidOrWithArg(WTFMove(resolveFunction), std::nullptr_t());
return invokeWithVoidOrWithArg(WTFMove(resolveFunction), VoidPlaceholder());
else
return invokeWithVoidOrWithArg(WTFMove(resolveFunction), maybeMove(result.value()));
}
Expand All @@ -970,7 +985,7 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
auto then(DispatcherType& targetQueue, ThisType& thisVal, ResolveMethod resolveMethod, RejectMethod rejectMethod, const Logger::LogSiteIdentifier& callSite = DEFAULT_LOGSITEIDENTIFIER)
{
static_assert(HasRefCountMethods<ThisType>::value, "ThisType must be refounted object");
using R1 = typename RemoveSmartPointer<decltype(invokeWithVoidOrWithArg(thisVal, resolveMethod, std::declval<std::conditional_t<std::is_void_v<ResolveValueT>, std::nullptr_t, ResolveValueT>>()))>::type;
using R1 = typename RemoveSmartPointer<decltype(invokeWithVoidOrWithArg(thisVal, resolveMethod, std::declval<std::conditional_t<std::is_void_v<ResolveValueT>, VoidPlaceholder, ResolveValueType>>()))>::type;
using R2 = typename RemoveSmartPointer<decltype(invokeWithVoidOrWithArg(thisVal, rejectMethod, std::declval<RejectValueT>()))>::type;
using IsChaining = std::bool_constant<RelatedNativePromise<R1, R2>>;
static_assert(IsChaining::value || (std::is_void_v<R1> && std::is_void_v<R2>), "resolve/reject methods must return a promise of the same type or nothing");
Expand All @@ -979,7 +994,7 @@ class NativePromise final : public NativePromiseBase, public ConvertibleToNative
return whenSettled(targetQueue, [thisVal = Ref { thisVal }, resolveMethod, rejectMethod] (ResultParam result) -> LambdaReturnType {
if (result) {
if constexpr (std::is_void_v<ResolveValueT>)
return invokeWithVoidOrWithArg(thisVal.get(), resolveMethod, std::nullptr_t());
return invokeWithVoidOrWithArg(thisVal.get(), resolveMethod, VoidPlaceholder());
else
return invokeWithVoidOrWithArg(thisVal.get(), resolveMethod, maybeMove(result.value()));
}
Expand Down
26 changes: 26 additions & 0 deletions Tools/TestWebKitAPI/Tests/WTF/NativePromise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,32 @@ TEST(NativePromise, PromiseAllResolve)
});
}

TEST(NativePromise, PromiseVoidAllResolve)
{
AutoWorkQueue awq;
auto queue = awq.queue();
queue->dispatch([queue] {
Vector<Ref<GenericPromise>> promises;
promises.append(GenericPromise::createAndResolve());
promises.append(GenericPromise::createAndResolve());
promises.append(GenericPromise::createAndResolve());

GenericPromise::all(queue, promises)->then(queue,
[] () {
EXPECT_TRUE(true);
},
doFail());

GenericPromise::all(queue, Vector<Ref<GenericPromise>>(10, [](size_t) {
return GenericPromise::createAndResolve();
}))->then(queue,
[queue] () {
queue->beginShutdown();
},
doFail());
});
}

TEST(NativePromise, PromiseAllResolveAsync)
{
AutoWorkQueue awq;
Expand Down

0 comments on commit cc6dee3

Please sign in to comment.