Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 24 additions & 10 deletions Sources/AsyncAlgorithms/AsyncThrowingChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,15 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
}
}

enum Termination {
case finished
case failed(Error)
}

struct State {
var emission: Emission = .idle
var generation = 0
var terminal = false
var terminal: Termination?
}

let state = ManagedCriticalState(State())
Expand All @@ -129,10 +134,11 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
func next(_ generation: Int) async throws -> Element? {
return try await withUnsafeThrowingContinuation { continuation in
var cancelled = false
var terminal = false
var terminal: Termination?
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
if state.terminal {
terminal = true
if let termination = state.terminal {
state.terminal = .finished
terminal = termination
return nil
}
switch state.emission {
Expand Down Expand Up @@ -160,18 +166,26 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
return nil
}
}?.resume()
if cancelled || terminal {
if cancelled {
continuation.resume(returning: nil)
}
if let terminal = terminal {
switch terminal {
case .finished:
continuation.resume(returning: nil)
case .failed(let error):
continuation.resume(throwing: error)
}
}
}
}

func finishAll() {
let (sends, nexts) = state.withCriticalRegion { state -> ([UnsafeContinuation<UnsafeContinuation<Element?, Error>?, Never>], Set<Awaiting>) in
if state.terminal {
if state.terminal != nil {
return ([], [])
}
state.terminal = true
state.terminal = .finished
switch state.emission {
case .idle:
return ([], [])
Expand All @@ -197,12 +211,12 @@ public final class AsyncThrowingChannel<Element: Sendable, Failure: Error>: Asyn
} operation: {
let continuation: UnsafeContinuation<Element?, Error>? = await withUnsafeContinuation { continuation in
state.withCriticalRegion { state -> UnsafeResumption<UnsafeContinuation<Element?, Error>?, Never>? in
if state.terminal {
if state.terminal != nil {
return UnsafeResumption(continuation: continuation, success: nil)
}

if case .failure = result {
state.terminal = true
if case .failure(let error) = result {
state.terminal = .failed(error)
}

switch state.emission {
Expand Down
6 changes: 6 additions & 0 deletions Tests/AsyncAlgorithmsTests/TestChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ final class TestChannel: XCTestCase {
} catch {
XCTAssertEqual(error as? Failure, Failure())
}
do {
let value = try await iterator.next()
XCTAssertNil(value)
} catch {
XCTFail("Unexpected throw of error after iteration producing error")
}
}

func test_asyncChannel_ends_alls_iterators_and_discards_additional_sent_values_when_finish_is_called() async {
Expand Down