-
-
Notifications
You must be signed in to change notification settings - Fork 784
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #498 from Cysharp/hadashiA/async-linq-merge
Add UniTaskAsyncEnumerable.Merge
- Loading branch information
Showing
5 changed files
with
374 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
using System; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using Cysharp.Threading.Tasks; | ||
using Cysharp.Threading.Tasks.Linq; | ||
using FluentAssertions; | ||
using Xunit; | ||
|
||
namespace NetCoreTests.Linq | ||
{ | ||
public class MergeTest | ||
{ | ||
[Fact] | ||
public async Task TwoSource() | ||
{ | ||
var semaphore = new SemaphoreSlim(1, 1); | ||
|
||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await UniTask.SwitchToThreadPool(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("A1"); | ||
semaphore.Release(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("A2"); | ||
semaphore.Release(); | ||
}); | ||
|
||
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await UniTask.SwitchToThreadPool(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("B1"); | ||
await writer.YieldAsync("B2"); | ||
semaphore.Release(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("B3"); | ||
semaphore.Release(); | ||
}); | ||
|
||
var result = await a.Merge(b).ToArrayAsync(); | ||
result.Should().Equal("A1", "B1", "B2", "A2", "B3"); | ||
} | ||
|
||
[Fact] | ||
public async Task ThreeSource() | ||
{ | ||
var semaphore = new SemaphoreSlim(0, 1); | ||
|
||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await UniTask.SwitchToThreadPool(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("A1"); | ||
semaphore.Release(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("A2"); | ||
semaphore.Release(); | ||
}); | ||
|
||
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await UniTask.SwitchToThreadPool(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("B1"); | ||
await writer.YieldAsync("B2"); | ||
semaphore.Release(); | ||
await semaphore.WaitAsync(); | ||
await writer.YieldAsync("B3"); | ||
semaphore.Release(); | ||
}); | ||
|
||
var c = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await UniTask.SwitchToThreadPool(); | ||
await writer.YieldAsync("C1"); | ||
semaphore.Release(); | ||
}); | ||
|
||
var result = await a.Merge(b, c).ToArrayAsync(); | ||
result.Should().Equal("C1", "A1", "B1", "B2", "A2", "B3"); | ||
} | ||
|
||
[Fact] | ||
public async Task Throw() | ||
{ | ||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await writer.YieldAsync("A1"); | ||
}); | ||
|
||
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
Check warning on line 102 in src/UniTask.NetCoreTests/Linq/Merge.cs
|
||
{ | ||
throw new UniTaskTestException(); | ||
}); | ||
|
||
var enumerator = a.Merge(b).GetAsyncEnumerator(); | ||
(await enumerator.MoveNextAsync()).Should().Be(true); | ||
enumerator.Current.Should().Be("A1"); | ||
|
||
await Assert.ThrowsAsync<UniTaskTestException>(async () => await enumerator.MoveNextAsync()); | ||
} | ||
|
||
[Fact] | ||
public async Task Cancel() | ||
{ | ||
var cts = new CancellationTokenSource(); | ||
|
||
var a = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await writer.YieldAsync("A1"); | ||
}); | ||
|
||
var b = UniTaskAsyncEnumerable.Create<string>(async (writer, _) => | ||
{ | ||
await writer.YieldAsync("B1"); | ||
}); | ||
|
||
var enumerator = a.Merge(b).GetAsyncEnumerator(cts.Token); | ||
(await enumerator.MoveNextAsync()).Should().Be(true); | ||
enumerator.Current.Should().Be("A1"); | ||
|
||
cts.Cancel(); | ||
await Assert.ThrowsAsync<OperationCanceledException>(async () => await enumerator.MoveNextAsync()); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
232 changes: 232 additions & 0 deletions
232
src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Threading; | ||
using Cysharp.Threading.Tasks.Internal; | ||
|
||
namespace Cysharp.Threading.Tasks.Linq | ||
{ | ||
public static partial class UniTaskAsyncEnumerable | ||
{ | ||
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second) | ||
{ | ||
Error.ThrowArgumentNullException(first, nameof(first)); | ||
Error.ThrowArgumentNullException(second, nameof(second)); | ||
|
||
return new Merge<T>(new [] { first, second }); | ||
} | ||
|
||
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IUniTaskAsyncEnumerable<T> first, IUniTaskAsyncEnumerable<T> second, IUniTaskAsyncEnumerable<T> third) | ||
{ | ||
Error.ThrowArgumentNullException(first, nameof(first)); | ||
Error.ThrowArgumentNullException(second, nameof(second)); | ||
Error.ThrowArgumentNullException(third, nameof(third)); | ||
|
||
return new Merge<T>(new[] { first, second, third }); | ||
} | ||
|
||
public static IUniTaskAsyncEnumerable<T> Merge<T>(this IEnumerable<IUniTaskAsyncEnumerable<T>> sources) | ||
{ | ||
return new Merge<T>(sources.ToArray()); | ||
} | ||
|
||
public static IUniTaskAsyncEnumerable<T> Merge<T>(params IUniTaskAsyncEnumerable<T>[] sources) | ||
{ | ||
return new Merge<T>(sources); | ||
} | ||
} | ||
|
||
internal sealed class Merge<T> : IUniTaskAsyncEnumerable<T> | ||
{ | ||
readonly IUniTaskAsyncEnumerable<T>[] sources; | ||
|
||
public Merge(IUniTaskAsyncEnumerable<T>[] sources) | ||
{ | ||
if (sources.Length <= 0) | ||
{ | ||
Error.ThrowArgumentException("No source async enumerable to merge"); | ||
} | ||
this.sources = sources; | ||
} | ||
|
||
public IUniTaskAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default) | ||
=> new _Merge(sources, cancellationToken); | ||
|
||
enum MergeSourceState | ||
{ | ||
Pending, | ||
Running, | ||
Completed, | ||
} | ||
|
||
sealed class _Merge : MoveNextSource, IUniTaskAsyncEnumerator<T> | ||
{ | ||
static readonly Action<object> GetResultAtAction = GetResultAt; | ||
|
||
readonly int length; | ||
readonly IUniTaskAsyncEnumerator<T>[] enumerators; | ||
readonly MergeSourceState[] states; | ||
readonly Queue<(T, Exception, bool)> queuedResult = new Queue<(T, Exception, bool)>(); | ||
readonly CancellationToken cancellationToken; | ||
|
||
int moveNextCompleted; | ||
|
||
public T Current { get; private set; } | ||
|
||
public _Merge(IUniTaskAsyncEnumerable<T>[] sources, CancellationToken cancellationToken) | ||
{ | ||
this.cancellationToken = cancellationToken; | ||
length = sources.Length; | ||
states = ArrayPool<MergeSourceState>.Shared.Rent(length); | ||
enumerators = ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Rent(length); | ||
for (var i = 0; i < length; i++) | ||
{ | ||
enumerators[i] = sources[i].GetAsyncEnumerator(cancellationToken); | ||
states[i] = (int)MergeSourceState.Pending;; | ||
} | ||
} | ||
|
||
public UniTask<bool> MoveNextAsync() | ||
{ | ||
cancellationToken.ThrowIfCancellationRequested(); | ||
completionSource.Reset(); | ||
Interlocked.Exchange(ref moveNextCompleted, 0); | ||
|
||
if (HasQueuedResult() && Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) | ||
{ | ||
(T, Exception, bool) value; | ||
lock (states) | ||
{ | ||
value = queuedResult.Dequeue(); | ||
} | ||
var resultValue = value.Item1; | ||
var exception = value.Item2; | ||
var hasNext = value.Item3; | ||
if (exception != null) | ||
{ | ||
completionSource.TrySetException(exception); | ||
} | ||
else | ||
{ | ||
Current = resultValue; | ||
completionSource.TrySetResult(hasNext); | ||
} | ||
return new UniTask<bool>(this, completionSource.Version); | ||
} | ||
|
||
for (var i = 0; i < length; i++) | ||
{ | ||
lock (states) | ||
{ | ||
if (states[i] == MergeSourceState.Pending) | ||
{ | ||
states[i] = MergeSourceState.Running; | ||
} | ||
else | ||
{ | ||
continue; | ||
} | ||
} | ||
var awaiter = enumerators[i].MoveNextAsync().GetAwaiter(); | ||
if (awaiter.IsCompleted) | ||
{ | ||
GetResultAt(i, awaiter); | ||
} | ||
else | ||
{ | ||
awaiter.SourceOnCompleted(GetResultAtAction, StateTuple.Create(this, i, awaiter)); | ||
} | ||
} | ||
return new UniTask<bool>(this, completionSource.Version); | ||
} | ||
|
||
public async UniTask DisposeAsync() | ||
{ | ||
for (var i = 0; i < length; i++) | ||
{ | ||
await enumerators[i].DisposeAsync(); | ||
} | ||
|
||
ArrayPool<MergeSourceState>.Shared.Return(states, true); | ||
ArrayPool<IUniTaskAsyncEnumerator<T>>.Shared.Return(enumerators, true); | ||
} | ||
|
||
static void GetResultAt(object state) | ||
{ | ||
using (var tuple = (StateTuple<_Merge, int, UniTask<bool>.Awaiter>)state) | ||
{ | ||
tuple.Item1.GetResultAt(tuple.Item2, tuple.Item3); | ||
} | ||
} | ||
|
||
void GetResultAt(int index, UniTask<bool>.Awaiter awaiter) | ||
{ | ||
bool hasNext; | ||
bool completedAll; | ||
try | ||
{ | ||
hasNext = awaiter.GetResult(); | ||
} | ||
catch (Exception ex) | ||
{ | ||
if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) | ||
{ | ||
completionSource.TrySetException(ex); | ||
} | ||
else | ||
{ | ||
lock (states) | ||
{ | ||
queuedResult.Enqueue((default, ex, default)); | ||
} | ||
} | ||
return; | ||
} | ||
|
||
lock (states) | ||
{ | ||
states[index] = hasNext ? MergeSourceState.Pending : MergeSourceState.Completed; | ||
completedAll = !hasNext && IsCompletedAll(); | ||
} | ||
if (hasNext || completedAll) | ||
{ | ||
if (Interlocked.CompareExchange(ref moveNextCompleted, 1, 0) == 0) | ||
{ | ||
Current = enumerators[index].Current; | ||
completionSource.TrySetResult(!completedAll); | ||
} | ||
else | ||
{ | ||
lock (states) | ||
{ | ||
queuedResult.Enqueue((enumerators[index].Current, null, !completedAll)); | ||
} | ||
} | ||
} | ||
} | ||
|
||
bool HasQueuedResult() | ||
{ | ||
lock (states) | ||
{ | ||
return queuedResult.Count > 0; | ||
} | ||
} | ||
|
||
bool IsCompletedAll() | ||
{ | ||
lock (states) | ||
{ | ||
for (var i = 0; i < length; i++) | ||
{ | ||
if (states[i] != MergeSourceState.Completed) | ||
{ | ||
return false; | ||
} | ||
} | ||
} | ||
return true; | ||
} | ||
} | ||
} | ||
} |
3 changes: 3 additions & 0 deletions
3
src/UniTask/Assets/Plugins/UniTask/Runtime/Linq/Merge.cs.meta
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.