diff --git a/README.md b/README.md index 44aead6..da49271 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,7 @@ finally - `Amb` - Relay items of the source that responds first, disposing the others - `Create` - generate values via async push +- `CombineLatest` - combines the latest items of the source async sequences via a function into results - `Concat` - concatenate multiple async sequences - `Defer` - defer the creation of the actual `IAsyncEnumerable` - `Error` - signal an error diff --git a/async-enumerable-dotnet-test/CombineLatestTest.cs b/async-enumerable-dotnet-test/CombineLatestTest.cs new file mode 100644 index 0000000..7d80681 --- /dev/null +++ b/async-enumerable-dotnet-test/CombineLatestTest.cs @@ -0,0 +1,92 @@ +// Copyright (c) David Karnok & Contributors. +// Licensed under the Apache 2.0 License. +// See LICENSE file in the project root for full license information. + +using Xunit; +using async_enumerable_dotnet; +using System.Linq; +using System; + +namespace async_enumerable_dotnet_test +{ + public class CombineLatestTest + { + [Fact] + public async void Empty() + { + await AsyncEnumerable.CombineLatest(v => v.Sum()) + .AssertResult(); + } + + [Fact] + public async void Single() + { + await AsyncEnumerable.CombineLatest(v => v.Sum() + 1, + AsyncEnumerable.Just(1)) + .AssertResult(2); + } + + [Fact] + public async void One_Item_Each() + { + await AsyncEnumerable.CombineLatest(v => v.Sum(), AsyncEnumerable.Just(1), AsyncEnumerable.Just(2)) + .AssertResult(3); + } + + [Fact] + public async void One_Is_Empty() + { + await AsyncEnumerable.CombineLatest(v => v.Sum(), AsyncEnumerable.Empty(), AsyncEnumerable.Just(2)) + .AssertResult(); + } + + [Fact] + public async void Two_Is_Empty() + { + await AsyncEnumerable.CombineLatest(v => v.Sum(), AsyncEnumerable.Just(1), AsyncEnumerable.Empty()) + .AssertResult(); + } + + [Fact] + public async void ZigZag() + { + var t = 200; + if (Environment.GetEnvironmentVariable("CI") != null) + { + t = 2000; + } + await AsyncEnumerable.CombineLatest(v => v.Sum(), + AsyncEnumerable.Interval(1, 5, TimeSpan.FromMilliseconds(t)), + AsyncEnumerable.Interval(1, 5, TimeSpan.FromMilliseconds(t + t / 2), TimeSpan.FromMilliseconds(t)).Map(v => v * 10) + ) + .AssertResult(11, 12, 22, 23, 33, 34, 44, 45, 55); + } + + [Fact] + public async void Second_Many() + { + var t = 200; + if (Environment.GetEnvironmentVariable("CI") != null) + { + t = 2000; + } + await AsyncEnumerable.CombineLatest(v => v.Sum(), + AsyncEnumerable.Just(10L), + AsyncEnumerable.Interval(1, 5, TimeSpan.FromMilliseconds(t + t / 2), TimeSpan.FromMilliseconds(t)) + ) + .AssertResult(11, 12, 13, 14, 15); + } + + [Fact] + public async void Error() + { + await AsyncEnumerable.CombineLatest(v => v.Sum(), + AsyncEnumerable.Just(1), + AsyncEnumerable.Just(2).ConcatWith( + AsyncEnumerable.Error(new InvalidOperationException()) + ) + ) + .AssertFailure(typeof(InvalidOperationException), 3); + } + } +} diff --git a/async-enumerable-dotnet/AsyncEnumerable.cs b/async-enumerable-dotnet/AsyncEnumerable.cs index 6922048..fac7fca 100644 --- a/async-enumerable-dotnet/AsyncEnumerable.cs +++ b/async-enumerable-dotnet/AsyncEnumerable.cs @@ -1551,5 +1551,31 @@ public static IAsyncEnumerable Merge(this IAsyncEnumerable(source, other, func); } + + /// + /// Combines the latest items from each async source into a single + /// sequence of results via a combiner function. + /// + /// The element type of the sources. + /// The result type. + /// The function that receives the latest elements + /// of all sources (if they all produced an item at least) and should + /// a value to be emitted to the consumer. + /// The params array of the async sequences to combine. + /// The new IAsyncEnumerable sequence. + public static IAsyncEnumerable CombineLatest(Func combiner, params IAsyncEnumerable[] sources) + { + RequireNonNull(sources, nameof(sources)); + RequireNonNull(combiner, nameof(combiner)); + if (sources.Length == 0) + { + return Empty(); + } + if (sources.Length == 1) + { + return sources[0].Map(v => combiner(new[] { v })); + } + return new CombineLatest(sources, combiner); + } } } diff --git a/async-enumerable-dotnet/impl/CombineLatest.cs b/async-enumerable-dotnet/impl/CombineLatest.cs new file mode 100644 index 0000000..8d82d34 --- /dev/null +++ b/async-enumerable-dotnet/impl/CombineLatest.cs @@ -0,0 +1,281 @@ +// Copyright (c) David Karnok & Contributors. +// Licensed under the Apache 2.0 License. +// See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace async_enumerable_dotnet.impl +{ + internal sealed class CombineLatest : IAsyncEnumerable + { + private readonly IAsyncEnumerable[] _sources; + + private readonly Func _combiner; + + public CombineLatest(IAsyncEnumerable[] sources, Func combiner) + { + _sources = sources; + _combiner = combiner; + } + + public IAsyncEnumerator GetAsyncEnumerator() + { + return new CombineLatestEnumerator(_sources, _combiner); + } + + private sealed class CombineLatestEnumerator : IAsyncEnumerator + { + private readonly InnerHandler[] _sources; + + private readonly Func _combiner; + + public TResult Current { get; private set; } + + private bool once; + + private TaskCompletionSource _resume; + + private Exception _disposeError; + private int _disposeWip; + private readonly TaskCompletionSource _disposeTask; + + private Exception _error; + private int _done; + private TSource[] _latest; + private int _latestRemaining; + + private readonly ConcurrentQueue _queue; + + public CombineLatestEnumerator(IAsyncEnumerable[] sources, Func combiner) + { + var n = sources.Length; + _sources = new InnerHandler[n]; + for (var i = 0; i < n; i++) + { + _sources[i] = new InnerHandler(sources[i].GetAsyncEnumerator(), this, i); + } + _combiner = combiner; + _disposeTask = new TaskCompletionSource(); + _latest = new TSource[n]; + _queue = new ConcurrentQueue(); + _latestRemaining = n; + Volatile.Write(ref _disposeWip, n); + Volatile.Write(ref _done, n); + } + + internal void MoveNextAll() + { + foreach (var inner in _sources) + { + inner.MoveNext(); + } + } + + public async ValueTask DisposeAsync() + { + foreach (var inner in _sources) + { + inner.Dispose(); + } + await _disposeTask.Task; + + _latest = null; + Current = default; + while (_queue.TryDequeue(out _)) { } + } + + public async ValueTask MoveNextAsync() + { + if (!once) + { + once = true; + MoveNextAll(); + } + + var latest = _latest; + var n = latest.Length; + + for (; ; ) { + + if (_done == 0) + { + var ex = ExceptionHelper.Terminate(ref _error); + if (ex != null) + { + throw ex; + } + return false; + } + + var success = _queue.TryDequeue(out var entry); + + if (success) + { + var inner = _sources[entry.Index]; + + if (entry.Done) + { + if (inner._hasLatest) + { + _done--; + } + else + { + _done = 0; + } + continue; + } + + if (!inner._hasLatest) + { + inner._hasLatest = true; + _latestRemaining--; + } + + latest[entry.Index] = entry.Value; + + if (_latestRemaining == 0) + { + var copy = new TSource[n]; + Array.Copy(latest, 0, copy, 0, n); + + Current = _combiner(copy); + + inner.MoveNext(); + return true; + } + + inner.MoveNext(); + continue; + } + + await ResumeHelper.Await(ref _resume); + ResumeHelper.Clear(ref _resume); + } + } + + internal void Dispose(IAsyncDisposable disposable) + { + disposable.DisposeAsync() + .AsTask() + .ContinueWith(DisposeHandlerAction, this); + } + + private static readonly Action DisposeHandlerAction = (t, state) => ((CombineLatestEnumerator)state).DisposeHandler(t); + + private void DisposeHandler(Task t) + { + QueueDrainHelper.DisposeHandler(t, ref _disposeWip, ref _disposeError, _disposeTask); + } + + internal void InnerNext(int index, TSource value) + { + _queue.Enqueue(new Entry + { + Index = index, + Done = false, + Value = value + }); + } + + internal void InnerError(int index, Exception ex) + { + ExceptionHelper.AddException(ref _error, ex); + _queue.Enqueue(new Entry { + Index = index, Done = true, Value = default + }); + } + + internal void InnerComplete(int index) + { + _queue.Enqueue(new Entry { Index = index, Done = true, Value = default }); + } + + internal void Signal() + { + ResumeHelper.Resume(ref _resume); + } + + struct Entry + { + internal int Index; + internal TSource Value; + internal bool Done; + } + + internal sealed class InnerHandler + { + private readonly IAsyncEnumerator _source; + + private readonly CombineLatestEnumerator _parent; + + internal readonly int Index; + + private int _disposeWip; + + private int _sourceWip; + + internal bool _hasLatest; + + public TSource Current => _source.Current; + + public InnerHandler(IAsyncEnumerator source, CombineLatestEnumerator parent, int index) + { + _source = source; + _parent = parent; + Index = index; + } + + internal void MoveNext() + { + QueueDrainHelper.MoveNext(_source, ref _sourceWip, ref _disposeWip, NextHandlerAction, this); + } + + private static readonly Action, object> NextHandlerAction = (t, state) => ((InnerHandler)state).NextHandler(t); + + private bool TryDispose() + { + if (Interlocked.Decrement(ref _disposeWip) != 0) + { + _parent.Dispose(_source); + return false; + } + return true; + } + + private void NextHandler(Task t) + { + if (t.IsFaulted) + { + _parent.InnerError(Index, ExceptionHelper.Extract(t.Exception)); + } + else if (t.Result) + { + _parent.InnerNext(Index, _source.Current); + } + else + { + _parent.InnerComplete(Index); + } + if (TryDispose()) + { + _parent.Signal(); + } + } + + internal void Dispose() + { + if (Interlocked.Increment(ref _disposeWip) == 1) + { + _parent.Dispose(_source); + } + } + } + } + } +}