From de8383bbb3f8a3966bef881c8d637b9a4bd5dd95 Mon Sep 17 00:00:00 2001 From: Philippe Hausler Date: Wed, 15 Jun 2022 23:02:21 -0700 Subject: [PATCH] Ensure AsyncThrowingChannel reliably throws errors when racing against iteration --- .../AsyncThrowingChannel.swift | 34 +++++++++++++------ Tests/AsyncAlgorithmsTests/TestChannel.swift | 6 ++++ 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift index 58c9e8c9..228d3ab4 100644 --- a/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift +++ b/Sources/AsyncAlgorithms/AsyncThrowingChannel.swift @@ -103,10 +103,15 @@ public final class AsyncThrowingChannel: Asyn } } + enum Termination { + case finished + case failed(Error) + } + struct State { var emission: Emission = .idle var generation = 0 - var terminal = false + var terminal: Termination? } let state = ManagedCriticalState(State()) @@ -129,10 +134,11 @@ public final class AsyncThrowingChannel: Asyn func next(_ generation: Int) async throws -> Element? { return try await withUnsafeThrowingContinuation { continuation in var cancelled = false - var terminal = false + var terminal: Termination? state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - if state.terminal { - terminal = true + if let termination = state.terminal { + state.terminal = .finished + terminal = termination return nil } switch state.emission { @@ -160,18 +166,26 @@ public final class AsyncThrowingChannel: Asyn return nil } }?.resume() - if cancelled || terminal { + if cancelled { continuation.resume(returning: nil) } + if let terminal = terminal { + switch terminal { + case .finished: + continuation.resume(returning: nil) + case .failed(let error): + continuation.resume(throwing: error) + } + } } } func finishAll() { let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation?, Never>], Set) in - if state.terminal { + if state.terminal != nil { return ([], []) } - state.terminal = true + state.terminal = .finished switch state.emission { case .idle: return ([], []) @@ -197,12 +211,12 @@ public final class AsyncThrowingChannel: Asyn } operation: { let continuation: UnsafeContinuation? = await withUnsafeContinuation { continuation in state.withCriticalRegion { state -> UnsafeResumption?, Never>? in - if state.terminal { + if state.terminal != nil { return UnsafeResumption(continuation: continuation, success: nil) } - if case .failure = result { - state.terminal = true + if case .failure(let error) = result { + state.terminal = .failed(error) } switch state.emission { diff --git a/Tests/AsyncAlgorithmsTests/TestChannel.swift b/Tests/AsyncAlgorithmsTests/TestChannel.swift index 891ec434..e51b7261 100644 --- a/Tests/AsyncAlgorithmsTests/TestChannel.swift +++ b/Tests/AsyncAlgorithmsTests/TestChannel.swift @@ -75,6 +75,12 @@ final class TestChannel: XCTestCase { } catch { XCTAssertEqual(error as? Failure, Failure()) } + do { + let value = try await iterator.next() + XCTAssertNil(value) + } catch { + XCTFail("Unexpected throw of error after iteration producing error") + } } func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async {