diff --git a/Package.swift b/Package.swift index 194ab8fd..ce7f1ba0 100644 --- a/Package.swift +++ b/Package.swift @@ -24,4 +24,4 @@ let package = Package( name: "AsyncAlgorithmsTests", dependencies: ["AsyncAlgorithms"]), ] -) \ No newline at end of file +) diff --git a/Sources/AsyncAlgorithms/TaskFirst.swift b/Sources/AsyncAlgorithms/TaskFirst.swift new file mode 100644 index 00000000..b1abf344 --- /dev/null +++ b/Sources/AsyncAlgorithms/TaskFirst.swift @@ -0,0 +1,129 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +struct TaskFirstState: Sendable { + var continuation: UnsafeContinuation? + var tasks: [Task]? = [] + + mutating func add(_ task: Task) -> Task? { + if var tasks = tasks { + tasks.append(task) + self.tasks = tasks + return nil + } else { + return task + } + } +} + +extension Task { + /// Determine the first result of a sequence of tasks. + /// + /// - Parameters: + /// - tasks: The running tasks to obtain a result from + /// - Returns: The first result or thrown error from the running tasks + static func first( + _ tasks: Tasks + ) async throws -> Success + where Tasks.Element == Task, Failure == Error { + let state = ManagedCriticalState(TaskFirstState()) + return try await withTaskCancellationHandler { + let tasks = state.withCriticalRegion { state -> [Task] in + defer { state.tasks = nil } + return state.tasks ?? [] + } + for task in tasks { + task.cancel() + } + } operation: { + try await withUnsafeThrowingContinuation { continuation in + state.withCriticalRegion { state in + state.continuation = continuation + } + for task in tasks { + Task { + let result = await task.result + state.withCriticalRegion { state -> UnsafeContinuation? in + defer { state.continuation = nil } + return state.continuation + }?.resume(with: result) + } + state.withCriticalRegion { state in + state.add(task) + }?.cancel() + } + } + } + } + + /// Determine the first result of a list of tasks. + /// + /// - Parameters: + /// - tasks: The running tasks to obtain a result from + /// - Returns: The first result or thrown error from the running tasks + static func first( + _ tasks: Task... + ) async throws -> Success where Failure == Error { + try await first(tasks) + } +} + +extension Task where Failure == Never { + /// Determine the first result of a sequence of tasks. + /// + /// - Parameters: + /// - tasks: The running tasks to obtain a result from + /// - Returns: The first result from the running tasks + static func first( + _ tasks: Tasks + ) async -> Success + where Tasks.Element == Task { + let state = ManagedCriticalState(TaskFirstState()) + return await withTaskCancellationHandler { + let tasks = state.withCriticalRegion { state -> [Task] in + defer { state.tasks = nil } + return state.tasks ?? [] + } + for task in tasks { + task.cancel() + } + } operation: { + await withUnsafeContinuation { continuation in + state.withCriticalRegion { state in + state.continuation = continuation + } + for task in tasks { + Task { + let result = await task.result + state.withCriticalRegion { state -> UnsafeContinuation? in + defer { state.continuation = nil } + return state.continuation + }?.resume(with: result) + } + state.withCriticalRegion { state in + state.add(task) + }?.cancel() + } + } + } + } + + /// Determine the first result of a list of tasks. + /// + /// - Parameters: + /// - tasks: The running tasks to obtain a result from + /// - Returns: The first result from the running tasks + static func first( + _ tasks: Task... + ) async -> Success { + await first(tasks) + } +} diff --git a/Tests/AsyncAlgorithmsTests/TestTaskFirst.swift b/Tests/AsyncAlgorithmsTests/TestTaskFirst.swift new file mode 100644 index 00000000..401c6427 --- /dev/null +++ b/Tests/AsyncAlgorithmsTests/TestTaskFirst.swift @@ -0,0 +1,78 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the Swift Async Algorithms open source project +// +// Copyright (c) 2022 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// +//===----------------------------------------------------------------------===// + +import XCTest +import AsyncAlgorithms + +final class TestTaskFirst: XCTestCase { + func test_first() async { + let firstValue = await Task.first(Task { + return 1 + }, Task { + try! await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + return 2 + }) + XCTAssertEqual(firstValue, 1) + } + + func test_second() async { + let firstValue = await Task.first(Task { + try! await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + return 1 + }, Task { + return 2 + }) + XCTAssertEqual(firstValue, 2) + } + + func test_throwing() async { + do { + _ = try await Task.first(Task { () async throws -> Int in + try await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + return 1 + }, Task { () async throws -> Int in + throw NSError(domain: NSCocoaErrorDomain, code: -1, userInfo: nil) + }) + XCTFail() + } catch { + XCTAssertEqual((error as NSError).code, -1) + } + } + + func test_cancellation() async { + let firstReady = expectation(description: "first ready") + let secondReady = expectation(description: "second ready") + let firstCancelled = expectation(description: "first cancelled") + let secondCancelled = expectation(description: "second cancelled") + let task = Task { + _ = await Task.first(Task { + await withTaskCancellationHandler { + firstCancelled.fulfill() + } operation: { () -> Int in + firstReady.fulfill() + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + return 1 + } + }, Task { + await withTaskCancellationHandler { + secondCancelled.fulfill() + } operation: { () -> Int in + secondReady.fulfill() + try? await Task.sleep(nanoseconds: NSEC_PER_SEC * 2) + return 1 + } + }) + } + wait(for: [firstReady, secondReady], timeout: 1.0) + task.cancel() + wait(for: [firstCancelled, secondCancelled], timeout: 1.0) + } +}