diff --git a/src/core/Akka.Streams.Tests/Dsl/AsyncEnumerableSpec.cs b/src/core/Akka.Streams.Tests/Dsl/AsyncEnumerableSpec.cs index 225e907b795..5c9c72d0eed 100644 --- a/src/core/Akka.Streams.Tests/Dsl/AsyncEnumerableSpec.cs +++ b/src/core/Akka.Streams.Tests/Dsl/AsyncEnumerableSpec.cs @@ -10,19 +10,16 @@ using System.Threading; using System.Threading.Tasks; using Akka.Pattern; -using Akka.Routing; using Akka.Streams.Dsl; using Akka.Streams.TestKit; using Akka.TestKit; using FluentAssertions; -using Nito.AsyncEx.Synchronous; using Xunit; using Xunit.Abstractions; using System.Collections.Generic; -using Akka.Actor; -using Akka.Streams.Actors; -using Akka.Streams.Tests.Actor; -using Reactive.Streams; +using System.Runtime.CompilerServices; +using Akka.Util; +using FluentAssertions.Extensions; namespace Akka.Streams.Tests.Dsl { @@ -31,9 +28,10 @@ public class AsyncEnumerableSpec : AkkaSpec { private ActorMaterializer Materializer { get; } private ITestOutputHelper _helper; + public AsyncEnumerableSpec(ITestOutputHelper helper) : base( - AkkaSpecConfig.WithFallback(StreamTestDefaultMailbox.DefaultConfig), - helper) + AkkaSpecConfig.WithFallback(StreamTestDefaultMailbox.DefaultConfig), + helper) { _helper = helper; var settings = ActorMaterializerSettings.Create(Sys).WithInputBuffer(2, 16); @@ -41,14 +39,14 @@ public class AsyncEnumerableSpec : AkkaSpec } - [Fact] + [Fact] public async Task RunAsAsyncEnumerable_Uses_CancellationToken() { var input = Enumerable.Range(1, 6).ToList(); var cts = new CancellationTokenSource(); var token = cts.Token; - + var asyncEnumerable = Source.From(input).RunAsAsyncEnumerable(Materializer); var output = input.ToArray(); bool caught = false; @@ -63,10 +61,10 @@ await foreach (var a in asyncEnumerable.WithCancellation(token)) { caught = true; } - + caught.ShouldBeTrue(); } - + [Fact] public async Task RunAsAsyncEnumerable_must_return_an_IAsyncEnumerableT_from_a_Source() { @@ -78,7 +76,8 @@ await foreach (var a in asyncEnumerable) (output[0] == a).ShouldBeTrue("Did not get elements in order!"); output = output.Skip(1).ToArray(); } - output.Length.ShouldBe(0,"Did not receive all elements!"); + + output.Length.ShouldBe(0, "Did not receive all elements!"); } [Fact] @@ -92,15 +91,17 @@ await foreach (var a in asyncEnumerable) (output[0] == a).ShouldBeTrue("Did not get elements in order!"); output = output.Skip(1).ToArray(); } - output.Length.ShouldBe(0,"Did not receive all elements!"); - + + output.Length.ShouldBe(0, "Did not receive all elements!"); + output = input.ToArray(); await foreach (var a in asyncEnumerable) { (output[0] == a).ShouldBeTrue("Did not get elements in order!"); output = output.Skip(1).ToArray(); } - output.Length.ShouldBe(0,"Did not receive all elements in second enumeration!!"); + + output.Length.ShouldBe(0, "Did not receive all elements in second enumeration!!"); } @@ -110,8 +111,8 @@ public async Task RunAsAsyncEnumerable_Throws_on_Abrupt_Stream_termination() var materializer = ActorMaterializer.Create(Sys); var probe = this.CreatePublisherProbe(); var task = Source.FromPublisher(probe).RunAsAsyncEnumerable(materializer); - - var a = Task.Run( async () => + + var a = Task.Run(async () => { await foreach (var notused in task) { @@ -122,22 +123,23 @@ await foreach (var notused in task) //we want to send messages so we aren't just waiting forever. probe.SendNext(1); probe.SendNext(2); - bool thrown = false; + var thrown = false; try { await a; } - catch (StreamDetachedException e) - { - thrown = true; - } + catch (StreamDetachedException e) + { + thrown = true; + } catch (AbruptTerminationException e) { thrown = true; } + thrown.ShouldBeTrue(); } - + [Fact] public async Task RunAsAsyncEnumerable_Throws_if_materializer_gone_before_Enumeration() { @@ -150,47 +152,128 @@ async Task ShouldThrow() { await foreach (var a in task) { - } } - + await Assert.ThrowsAsync(ShouldThrow); } - [Fact] - public void AsyncEnumerableSource_Must_Complete_Immediately_With_No_elements_When_An_Empty_IAsyncEnumerable_Is_Passed_In() + [Fact] + public async Task + AsyncEnumerableSource_Must_Complete_Immediately_With_No_elements_When_An_Empty_IAsyncEnumerable_Is_Passed_In() { - Func> range = () => - { - return RangeAsync(1, 100); - }; + IAsyncEnumerable Range() => RangeAsync(0, 0); var subscriber = this.CreateManualSubscriberProbe(); - Source.From(range) + Source.From(Range) .RunWith(Sink.FromSubscriber(subscriber), Materializer); - var subscription = subscriber.ExpectSubscription(); + var subscription = await subscriber.ExpectSubscriptionAsync(); subscription.Request(100); - for (int i = 1; i <= 20; i++) + await subscriber.ExpectCompleteAsync(); + } + + [Fact] + public async Task AsyncEnumerableSource_Must_Process_All_Elements() + { + IAsyncEnumerable Range() => RangeAsync(0, 100); + var subscriber = this.CreateManualSubscriberProbe(); + + Source.From(Range) + .RunWith(Sink.FromSubscriber(subscriber), Materializer); + + var subscription = await subscriber.ExpectSubscriptionAsync(); + subscription.Request(101); + + await subscriber.ExpectNextNAsync(Enumerable.Range(0, 100)); + + await subscriber.ExpectCompleteAsync(); + } + + [Fact] + public async Task AsyncEnumerableSource_Must_Process_Source_That_Immediately_Throws() + { + IAsyncEnumerable Range() => ThrowingRangeAsync(0, 100, 50); + var subscriber = this.CreateManualSubscriberProbe(); + + Source.From(Range) + .RunWith(Sink.FromSubscriber(subscriber), Materializer); + + var subscription = await subscriber.ExpectSubscriptionAsync(); + subscription.Request(101); + + await subscriber.ExpectNextNAsync(Enumerable.Range(0, 50)); + + var exception = await subscriber.ExpectErrorAsync(); + + // Exception should be automatically unrolled, this SHOULD NOT be AggregateException + exception.Should().BeOfType(); + exception.Message.Should().Be("BOOM!"); + } + + [Fact] + public async Task AsyncEnumerableSource_Must_Cancel_Running_Source_If_Downstream_Completes() + { + var latch = new AtomicBoolean(); + IAsyncEnumerable Range() => ProbeableRangeAsync(0, 100, latch); + var subscriber = this.CreateManualSubscriberProbe(); + + Source.From(Range) + .RunWith(Sink.FromSubscriber(subscriber), Materializer); + + var subscription = await subscriber.ExpectSubscriptionAsync(); + subscription.Request(50); + await subscriber.ExpectNextNAsync(Enumerable.Range(0, 50)); + subscription.Cancel(); + + // The cancellation token inside the IAsyncEnumerable should be cancelled + await WithinAsync(3.Seconds(), async () => latch.Value); + } + + private static async IAsyncEnumerable RangeAsync(int start, int count, + [EnumeratorCancellation] CancellationToken token = default) + { + foreach (var i in Enumerable.Range(start, count)) { - var next = subscriber.ExpectNext(i); - _helper.WriteLine(i.ToString()); + await Task.Delay(10, token); + if(token.IsCancellationRequested) + yield break; + yield return i; } - - //subscriber.ExpectComplete(); } + + private static async IAsyncEnumerable ThrowingRangeAsync(int start, int count, int throwAt, + [EnumeratorCancellation] CancellationToken token = default) + { + foreach (var i in Enumerable.Range(start, count)) + { + if(token.IsCancellationRequested) + yield break; - static async IAsyncEnumerable RangeAsync(int start, int count) + if (i == throwAt) + throw new TestException("BOOM!"); + + yield return i; + } + } + + private static async IAsyncEnumerable ProbeableRangeAsync(int start, int count, AtomicBoolean latch, + [EnumeratorCancellation] CancellationToken token = default) { - for (var i = 0; i < count; i++) + token.Register(() => { - await Task.Delay(i); - yield return start + i; + latch.GetAndSet(true); + }); + foreach (var i in Enumerable.Range(start, count)) + { + if(token.IsCancellationRequested) + yield break; + + yield return i; } } } #else #endif - -} +} \ No newline at end of file diff --git a/src/core/Akka.Streams/Implementation/Fusing/Ops.cs b/src/core/Akka.Streams/Implementation/Fusing/Ops.cs index 50d8a49d68e..cba750fd5fd 100644 --- a/src/core/Akka.Streams/Implementation/Fusing/Ops.cs +++ b/src/core/Akka.Streams/Implementation/Fusing/Ops.cs @@ -10,6 +10,7 @@ using System.Collections.Immutable; using System.Linq; using System.Runtime.CompilerServices; +using System.Threading; using System.Threading.Tasks; using Akka.Annotations; using Akka.Event; @@ -3770,29 +3771,22 @@ public sealed class AsyncEnumerable : GraphStage> private sealed class Logic : OutGraphStageLogic { - private readonly IAsyncEnumerator _enumerator; + private readonly IAsyncEnumerable _enumerable; private readonly Outlet _outlet; private readonly Action _onSuccess; private readonly Action _onFailure; private readonly Action _onComplete; - private readonly Action> _handleContinuation; + + private CancellationTokenSource _completionCts; + private IAsyncEnumerator _enumerator; - public Logic(SourceShape shape, IAsyncEnumerator enumerator) : base(shape) + public Logic(SourceShape shape, IAsyncEnumerable enumerable) : base(shape) { - _enumerator = enumerator; + _enumerable = enumerable; _outlet = shape.Outlet; _onSuccess = GetAsyncCallback(OnSuccess); _onFailure = GetAsyncCallback(OnFailure); _onComplete = GetAsyncCallback(OnComplete); - _handleContinuation = task => - { - // Since this Action is used as task continuation, we cannot safely call corresponding - // OnSuccess/OnFailure/OnComplete methods directly. We need to do that via async callbacks. - if (task.IsFaulted) _onFailure(task.Exception); - else if (task.IsCanceled) _onFailure(new TaskCanceledException(task)); - else if (task.Result) _onSuccess(enumerator.Current); - else _onComplete(); - }; SetHandler(_outlet, this); } @@ -3805,12 +3799,19 @@ public Logic(SourceShape shape, IAsyncEnumerator enumerator) : base(shape) [MethodImpl(MethodImplOptions.AggressiveInlining)] private void OnSuccess(T element) => Push(_outlet, element); + public override void PreStart() + { + base.PreStart(); + _completionCts = new CancellationTokenSource(); + _enumerator = _enumerable.GetAsyncEnumerator(_completionCts.Token); + } + public override void OnPull() { var vtask = _enumerator.MoveNextAsync(); if (vtask.IsCompletedSuccessfully) { - // When MoveNextAsync returned immediatelly, we don't need to await. + // When MoveNextAsync returned immediately, we don't need to await. // We can use fast path instead. if (vtask.Result) { @@ -3822,25 +3823,68 @@ public override void OnPull() // if result is false, it means enumerator was closed. Complete stage in that case. CompleteStage(); } + } + else if (vtask.IsCompleted) // IsCompleted covers Faulted, Cancelled, and RanToCompletion async state + { + // vtask will always contains an exception because we know we're not successful and always throws + try + { + // This does not block because we know that the task already completed + // Using GetAwaiter().GetResult() to automatically unwraps AggregateException inner exception + vtask.GetAwaiter().GetResult(); + } + catch (Exception ex) + { + FailStage(ex); + return; + } + + throw new InvalidOperationException("Should never reach this code"); } else { - vtask.AsTask().ContinueWith(_handleContinuation); + async Task ProcessTask() + { + // Since this Action is used as task continuation, we cannot safely call corresponding + // OnSuccess/OnFailure/OnComplete methods directly. We need to do that via async callbacks. + try + { + var completed = await vtask.ConfigureAwait(false); + if (completed) + _onSuccess(_enumerator.Current); + else + _onComplete(); + } + catch (Exception ex) + { + _onFailure(ex); + } + } + +#pragma warning disable CS4014 + ProcessTask(); +#pragma warning restore CS4014 } } + public override void OnDownstreamFinish(Exception cause) { - var vtask = _enumerator.DisposeAsync(); - if (vtask.IsCompletedSuccessfully) + _completionCts.Cancel(); + _completionCts.Dispose(); + + try { - CompleteStage(); // if dispose completed immediately, complete stage directly + _enumerator.DisposeAsync().GetAwaiter().GetResult(); } - else + catch (Exception ex) { - // for async disposals use async callback - vtask.GetAwaiter().OnCompleted(_onComplete); + Log.Warning(ex, "Failed to dispose IAsyncEnumerator asynchronously"); + } + finally + { + CompleteStage(); + base.OnDownstreamFinish(cause); } - base.OnDownstreamFinish(cause); } } @@ -3856,7 +3900,7 @@ public AsyncEnumerable(Func> factory) } public override SourceShape Shape { get; } - protected override GraphStageLogic CreateLogic(Attributes inheritedAttributes) => new Logic(Shape, _factory().GetAsyncEnumerator()); + protected override GraphStageLogic CreateLogic(Attributes inheritedAttributes) => new Logic(Shape, _factory()); protected override Attributes InitialAttributes { get; } = DefaultAttributes.EnumerableSource;