diff --git a/Sources/NTPClient/Deadline.swift b/Sources/NTPClient/Deadline.swift deleted file mode 100644 index bf0c99a..0000000 --- a/Sources/NTPClient/Deadline.swift +++ /dev/null @@ -1,63 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -#if canImport(FoundationEssentials) -internal import FoundationEssentials -#else -internal import Foundation -#endif - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -enum DeadlineFailure: Error { - case failed(_ failure: Error) - case timedOut(_ clock: ContinuousClock, _ deadline: ContinuousClock.Instant) - case cancelled -} - -@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) -func withDeadline( - _ deadline: ContinuousClock.Instant, - clock: ContinuousClock, - _ body: @Sendable () async throws -> T -) async throws -> T { - let result: Result = await withoutActuallyEscaping(body) { escapingClosure in - await withTaskGroup(of: Result?.self) { group in - group.addTask { - do { - try await Task.sleep(until: deadline, clock: clock) - return .failure(.timedOut(clock, deadline)) - } catch { - return nil - } - } - - group.addTask { - do { - let value = try await escapingClosure() - return Result.success(value) - } catch { - return Result.failure(DeadlineFailure.failed(error)) - } - } - - guard let result = await group.next() else { - return Result.failure(DeadlineFailure.cancelled) - } - group.cancelAll() - return result! // nil cannot occur here; it has been satisfied by either the sleep or the task - } - } - - return try result.get() -} diff --git a/Sources/NTPClient/NTPClient.swift b/Sources/NTPClient/NTPClient.swift index b042841..747ccb5 100644 --- a/Sources/NTPClient/NTPClient.swift +++ b/Sources/NTPClient/NTPClient.swift @@ -72,8 +72,7 @@ public struct NTPClient: Sendable { /// - Parameter timeout: A duration after which the operation will timeout. /// - Returns: response from the NTP server with some NTP specific calculations. public func query(timeout: Duration) async throws -> NTPResponse { - let deadlineInstant: ContinuousClock.Instant = ContinuousClock.Instant.now + timeout - return try await withDeadline(deadlineInstant, clock: ContinuousClock()) { + try await withTimeout(in: timeout, clock: .continuous) { let bootstrap = DatagramBootstrap( group: MultiThreadedEventLoopGroup.singleton ).channelInitializer { channel in diff --git a/Sources/NTPClient/Timeout.swift b/Sources/NTPClient/Timeout.swift new file mode 100644 index 0000000..f56f5d7 --- /dev/null +++ b/Sources/NTPClient/Timeout.swift @@ -0,0 +1,98 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +private enum TaskResult: Sendable { + case success(T) + case error(any Error) + case timedOut + case cancelled +} + +package struct TimeOutError: Error, CustomStringConvertible, CustomDebugStringConvertible { + var underlying: any Error + + package var description: String { + "TimeOutError(\(String(describing: underlying))" + } + + package var debugDescription: String { + description + } +} + +@available(macOS 13, iOS 16, tvOS 16, watchOS 9, *) +package func withTimeout( + in timeout: Clock.Duration, + clock: Clock, + isolation: isolated (any Actor)? = #isolation, + body: sending @escaping @isolated(any) () async throws -> T +) async throws -> T { + // This is needed so we can make body sending since we don't have call-once closures yet + let body = { body } + let result: Result = await withTaskGroup(of: TaskResult.self) { group in + let body = body() + group.addTask { + do { + return .success(try await body()) + } catch { + return .error(error) + } + } + group.addTask { + do { + try await clock.sleep(for: timeout, tolerance: .zero) + return .timedOut + } catch { + return .cancelled + } + } + + switch await group.next() { + case .success(let result): + // Work returned a result. Cancel the timer task and return + group.cancelAll() + return .success(result) + case .error(let error): + // Work threw. Cancel the timer task and rethrow + group.cancelAll() + return .failure(error) + case .timedOut: + // Timed out, cancel the work task. + group.cancelAll() + + switch await group.next() { + case .success(let result): + return .success(result) + case .error(let error): + return .failure(TimeOutError(underlying: error)) + case .timedOut, .cancelled, .none: + // We already got a result from the sleeping task so we can't get another one or none. + fatalError("Unexpected task result") + } + case .cancelled: + switch await group.next() { + case .success(let result): + return .success(result) + case .error(let error): + return .failure(TimeOutError(underlying: error)) + case .timedOut, .cancelled, .none: + // We already got a result from the sleeping task so we can't get another one or none. + fatalError("Unexpected task result") + } + case .none: + fatalError("Unexpected task result") + } + } + return try result.get() +} diff --git a/Tests/NTPClientTests/NTPClientTests.swift b/Tests/NTPClientTests/NTPClientTests.swift index 59ae772..63a1ff0 100644 --- a/Tests/NTPClientTests/NTPClientTests.swift +++ b/Tests/NTPClientTests/NTPClientTests.swift @@ -26,7 +26,7 @@ import Testing ) func testNTPQueryTimeout(d: Duration) async { let ntp = NTPClient(config: .init(), server: "169.254.0.1") - await #expect(throws: DeadlineFailure.self, "deadline should be thrown in \(d) seconds") { + await #expect(throws: TimeOutError.self, "notEnoughBytes") { let _ = try await ntp.query(timeout: d) } } diff --git a/Tests/NTPClientTests/TimeoutTests.swift b/Tests/NTPClientTests/TimeoutTests.swift new file mode 100644 index 0000000..94cf18b --- /dev/null +++ b/Tests/NTPClientTests/TimeoutTests.swift @@ -0,0 +1,80 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NTPClient +import Testing + +@Suite +struct TimeoutTests { + @Test + func workCompletes() async throws { + let expectedValue = "success" + + let result = try await withTimeout(in: .seconds(1), clock: .continuous) { + expectedValue + } + + #expect(result == expectedValue) + } + + @Test + func workTimesOut() async throws { + + let result = await withThrowingTaskGroup(of: Void.self) { group in + // Task to run the test + group.addTask { + _ = try await withTimeout(in: .seconds(1), clock: .continuous) { + // Task that will take longer than the timeout + try await Task.sleep(for: .seconds(10), clock: .continuous) + Issue.record("Should not be reached") + } + } + + return await group.nextResult() + } + #expect(throws: TimeOutError.self) { + try result?.get() + } + } + + @Test + func workThrowsError() async throws { + struct TestError: Error { + var message: String + } + await #expect(throws: TestError.self) { + _ = try await withTimeout(in: .seconds(1), clock: .continuous) { + throw TestError(message: "hi") + } + } + } + + @Test + func overallCancelled() async throws { + // Run a task that will not finish for a long time + let workTask = Task { + try await withTimeout(in: .seconds(100), clock: .continuous) { + try await Task.sleep(for: .seconds(10_000)) + } + } + // Cancel it + workTask.cancel() + + // It should throw an error + await #expect(throws: (any Error).self) { + try await workTask.value + } + } + +}