diff --git a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs index f129082a..d4ea969a 100644 --- a/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs +++ b/src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs @@ -66,9 +66,11 @@ sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator readonly int length; readonly IUniTaskAsyncEnumerator[] enumerators; readonly MergeSourceState[] states; - readonly Queue<(T, Exception)> queuedResult = new Queue<(T, Exception)>(); + readonly Queue<(T, Exception, bool)> queuedResult = new Queue<(T, Exception, bool)>(); readonly CancellationToken cancellationToken; + int moveNextCompleted; + public T Current { get; private set; } public _Merge(IUniTaskAsyncEnumerable[] sources, CancellationToken cancellationToken) @@ -88,33 +90,35 @@ public UniTask MoveNextAsync() { cancellationToken.ThrowIfCancellationRequested(); completionSource.Reset(); + Interlocked.Exchange(ref moveNextCompleted, 0); - lock (queuedResult) + if (HasQueuedResult() && Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) { - if (queuedResult.Count > 0) + (T, Exception, bool) value; + lock (states) { - var result = queuedResult.Dequeue(); - var queuedValue = result.Item1; - var queuedException = result.Item2; - - if (queuedException != null) - { - completionSource.TrySetException(queuedException); - } - else - { - Current = queuedValue; - completionSource.TrySetResult(!IsCompletedAll()); - } - return new UniTask(this, completionSource.Version); + value = queuedResult.Dequeue(); + } + var resultValue = value.Item1; + var exception = value.Item2; + var hasNext = value.Item3; + if (exception != null) + { + completionSource.TrySetException(exception); + } + else + { + Current = resultValue; + completionSource.TrySetResult(hasNext); } + return new UniTask(this, completionSource.Version); } for (var i = 0; i < length; i++) { - lock (queuedResult) + lock (states) { - if (states[i] == (int)MergeSourceState.Pending) + if (states[i] == MergeSourceState.Pending) { states[i] = MergeSourceState.Running; } @@ -158,48 +162,67 @@ static void GetResultAt(object state) void GetResultAt(int index, UniTask.Awaiter awaiter) { bool hasNext; + bool completedAll; try { hasNext = awaiter.GetResult(); } catch (Exception ex) { - if (!completionSource.TrySetException(ex)) + if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) { - lock (queuedResult) + completionSource.TrySetException(ex); + } + else + { + lock (states) { - queuedResult.Enqueue((default, ex)); + queuedResult.Enqueue((default, ex, default)); } } return; } - lock (queuedResult) + lock (states) { states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; - var completedAll = !hasNext && IsCompletedAll(); - if (hasNext || completedAll) + completedAll = !hasNext && IsCompletedAll(); + } + if (hasNext || completedAll) + { + if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) { - if (completionSource.GetStatus(completionSource.Version).IsCompleted()) - { - queuedResult.Enqueue((enumerators[index].Current, null)); - } - else + Current = enumerators[index].Current; + completionSource.TrySetResult(!completedAll); + } + else + { + lock (states) { - Current = enumerators[index].Current; - completionSource.TrySetResult(!completedAll); + queuedResult.Enqueue((enumerators[index].Current, null, !completedAll)); } } } } + bool HasQueuedResult() + { + lock (states) + { + return queuedResult.Count > 0; + } + } + bool IsCompletedAll() { - for (var i = 0; i < length; i++) + lock (states) { - if (states[i] != MergeSourceState.Completed) + for (var i = 0; i < length; i++) { - return false; + if (states[i] != MergeSourceState.Completed) + { + return false; + } } } return true;