From f340828f399e909dacffddc8cf7cb5f1adde5c59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?D=C3=A1vid=20Karnok?= Date: Tue, 30 Oct 2018 23:20:04 +0100 Subject: [PATCH] Fix error handling in Create() --- async-enumerable-dotnet-test/CreateTest.cs | 39 +++++++++++++++++++ async-enumerable-dotnet/impl/CreateEmitter.cs | 13 ++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/async-enumerable-dotnet-test/CreateTest.cs b/async-enumerable-dotnet-test/CreateTest.cs index 2dfad7c..d0ee040 100644 --- a/async-enumerable-dotnet-test/CreateTest.cs +++ b/async-enumerable-dotnet-test/CreateTest.cs @@ -5,6 +5,7 @@ using Xunit; using async_enumerable_dotnet; using System.Threading.Tasks; +using System; namespace async_enumerable_dotnet_test { @@ -43,5 +44,43 @@ public async void Range_Loop() await Range(); } } + + [Fact] + public async void Items_And_Error() + { + var result = AsyncEnumerable.Create(async e => + { + await e.Next(1); + + await e.Next(2); + + throw new InvalidOperationException(); + }); + + await result.AssertFailure(typeof(InvalidOperationException), 1, 2); + } + + [Fact] + public async ValueTask Take() + { + await AsyncEnumerable.Create(async e => + { + for (var i = 0; i < 10 && !e.DisposeAsyncRequested; i++) + { + await e.Next(i); + } + }) + .Take(5) + .AssertResult(0, 1, 2, 3, 4); + } + + [Fact] + public async void Take_Loop() + { + for (int j = 0; j < 1000; j++) + { + await Take(); + } + } } } diff --git a/async-enumerable-dotnet/impl/CreateEmitter.cs b/async-enumerable-dotnet/impl/CreateEmitter.cs index d8ca368..af5a6cb 100644 --- a/async-enumerable-dotnet/impl/CreateEmitter.cs +++ b/async-enumerable-dotnet/impl/CreateEmitter.cs @@ -32,6 +32,7 @@ private sealed class CreateEmitterEnumerator : IAsyncEnumerator, IAsyncEmitte public bool DisposeAsyncRequested => _disposeRequested; private bool _hasValue; + private Exception _error; public T Current { get; private set; } @@ -41,8 +42,7 @@ private sealed class CreateEmitterEnumerator : IAsyncEnumerator, IAsyncEmitte internal void SetTask(Task task) { - _task = task; - task.ContinueWith(async t => + _task = task.ContinueWith(async t => { if (_disposeRequested) { @@ -55,6 +55,8 @@ internal void SetTask(Task task) return; } + _error = ExceptionHelper.Extract(t.Exception); + ResumeHelper.Resume(ref _valueReady); }); } @@ -78,6 +80,13 @@ public async ValueTask MoveNextAsync() return true; } Current = default; + + var ex = _error; + if (ex != null) + { + _error = null; + throw ex; + } return false; }