diff --git a/Examples/Streaming/README.md b/Examples/Streaming/README.md index 2c40df57..663ac14c 100644 --- a/Examples/Streaming/README.md +++ b/Examples/Streaming/README.md @@ -82,7 +82,7 @@ You can test the function locally before deploying: swift run # In another terminal, test with curl: -curl -v \ +curl -v --output response.txt \ --header "Content-Type: application/json" \ --data '"this is not used"' \ http://127.0.0.1:7000/invoke diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer+Pool.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer+Pool.swift new file mode 100644 index 00000000..c64a8183 --- /dev/null +++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer+Pool.swift @@ -0,0 +1,138 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2025 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if LocalServerSupport +import DequeModule +import Synchronization + +@available(LambdaSwift 2.0, *) +extension LambdaHTTPServer { + /// A shared data structure to store the current invocation or response requests and the continuation objects. + /// This data structure is shared between instances of the HTTPHandler + /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function). 
+ internal final class Pool<T>: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { + private let poolName: String + internal init(name: String = "Pool") { self.poolName = name } + + typealias Element = T + + enum State: ~Copyable { + case buffer(Deque<T>) + case continuation(CheckedContinuation<T, any Error>) + } + + private let lock = Mutex<State>(.buffer([])) + + /// enqueue an element, or give it back immediately to the iterator if it is waiting for an element + public func push(_ invocation: T) { + + // if the iterator is waiting for an element on `next()`, give it to it + // otherwise, enqueue the element + let maybeContinuation = self.lock.withLock { state -> CheckedContinuation<T, any Error>? in + switch consume state { + case .continuation(let continuation): + state = .buffer([]) + return continuation + + case .buffer(var buffer): + buffer.append(invocation) + state = .buffer(buffer) + return nil + } + } + + maybeContinuation?.resume(returning: invocation) + } + + /// AsyncSequence's standard next() function + /// Returns: + /// - nil when the task is cancelled + /// - an element when there is one in the queue + /// + /// When there is no element in the queue, the task will be suspended until an element is pushed to the queue + /// or the task is cancelled + /// + /// - Throws: PoolError if the next() function is called twice concurrently + @Sendable + func next() async throws -> T? { + // exit the async for loop if the task is cancelled + guard !Task.isCancelled else { + return nil + } + + return try await withTaskCancellationHandler { + try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in + let nextAction: Result<T, PoolError>? = self.lock.withLock { state -> Result<T, PoolError>?
in + switch consume state { + case .buffer(var buffer): + if let first = buffer.popFirst() { + state = .buffer(buffer) + return .success(first) + } else { + state = .continuation(continuation) + return nil + } + + case .continuation(let previousContinuation): + state = .buffer([]) + return .failure(PoolError(cause: .nextCalledTwice(previousContinuation))) + } + } + + switch nextAction { + case .success(let action): + continuation.resume(returning: action) + case .failure(let error): + if case let .nextCalledTwice(continuation) = error.cause { + continuation.resume(throwing: error) + } + continuation.resume(throwing: error) + case .none: + // do nothing + break + } + } + } onCancel: { + self.lock.withLock { state in + switch consume state { + case .buffer(let buffer): + state = .buffer(buffer) + case .continuation(let continuation): + state = .buffer([]) + continuation.resume(throwing: CancellationError()) + } + } + } + } + + func makeAsyncIterator() -> Pool { + self + } + + struct PoolError: Error { + let cause: Cause + var message: String { + switch self.cause { + case .nextCalledTwice: + return "Concurrent invocations to next(). This is not allowed." + } + } + + enum Cause { + case nextCalledTwice(CheckedContinuation) + } + } + } +} +#endif diff --git a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift index c3fa2e5d..61cadc94 100644 --- a/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift +++ b/Sources/AWSLambdaRuntime/Lambda+LocalServer.swift @@ -13,13 +13,11 @@ //===----------------------------------------------------------------------===// #if LocalServerSupport -import DequeModule import Dispatch import Logging import NIOCore import NIOHTTP1 import NIOPosix -import Synchronization // This functionality is designed for local testing when the LocalServerSupport trait is enabled. 
@@ -95,8 +93,8 @@ extension Lambda { internal struct LambdaHTTPServer { private let invocationEndpoint: String - private let invocationPool = Pool() - private let responsePool = Pool() + private let invocationPool = Pool(name: "Invocation Pool") + private let responsePool = Pool(name: "Response Pool") private init( invocationEndpoint: String? @@ -272,7 +270,7 @@ internal struct LambdaHTTPServer { // for streaming requests, push a partial head response if self.isStreamingResponse(requestHead) { - await self.responsePool.push( + self.responsePool.push( LocalServerResponse( id: requestId, status: .ok @@ -286,7 +284,7 @@ internal struct LambdaHTTPServer { // if this is a request from a Streaming Lambda Handler, // stream the response instead of buffering it if self.isStreamingResponse(requestHead) { - await self.responsePool.push( + self.responsePool.push( LocalServerResponse(id: requestId, body: body) ) } else { @@ -298,7 +296,7 @@ internal struct LambdaHTTPServer { if self.isStreamingResponse(requestHead) { // for streaming response, send the final response - await self.responsePool.push( + self.responsePool.push( LocalServerResponse(id: requestId, final: true) ) } else { @@ -388,34 +386,55 @@ internal struct LambdaHTTPServer { // we always accept the /invoke request and push them to the pool let requestId = "\(DispatchTime.now().uptimeNanoseconds)" logger[metadataKey: "requestId"] = "\(requestId)" + logger.trace("/invoke received invocation, pushing it to the pool and wait for a lambda response") - await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) + self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body)) // wait for the lambda function to process the request - for try await response in self.responsePool { - logger[metadataKey: "response requestId"] = "\(response.requestId ?? 
"nil")" - logger.trace("Received response to return to client") - if response.requestId == requestId { - logger.trace("/invoke requestId is valid, sending the response") - // send the response to the client - // if the response is final, we can send it and return - // if the response is not final, we can send it and wait for the next response - try await self.sendResponse(response, outbound: outbound, logger: logger) - if response.final == true { - logger.trace("/invoke returning") - return // if the response is final, we can return and close the connection + // when POST /invoke is called multiple times before a response is processed, + // the `for try await ... in` loop will throw an error and we will return a 400 error to the client + do { + for try await response in self.responsePool { + logger[metadataKey: "response requestId"] = "\(response.requestId ?? "nil")" + logger.trace("Received response to return to client") + if response.requestId == requestId { + logger.trace("/invoke requestId is valid, sending the response") + // send the response to the client + // if the response is final, we can send it and return + // if the response is not final, we can send it and wait for the next response + try await self.sendResponse(response, outbound: outbound, logger: logger) + if response.final == true { + logger.trace("/invoke returning") + return // if the response is final, we can return and close the connection + } + } else { + logger.error( + "Received response for a different requestId", + metadata: ["response requestId": "\(response.requestId ?? "")"] + ) + let response = LocalServerResponse( + id: requestId, + status: .badRequest, + body: ByteBuffer(string: "The responseId is not equal to the requestId.") + ) + try await self.sendResponse(response, outbound: outbound, logger: logger) } - } else { - logger.error( - "Received response for a different request id", - metadata: ["response requestId": "\(response.requestId ?? 
"")"] - ) - // should we return an error here ? Or crash as this is probably a programming error? } + // What todo when there is no more responses to process? + // This should not happen as the async iterator blocks until there is a response to process + fatalError("No more responses to process - the async for loop should not return") + } catch is LambdaHTTPServer.Pool.PoolError { + // detect concurrent invocations of POST and gently decline the requests while we're processing one. + let response = LocalServerResponse( + id: requestId, + status: .badRequest, + body: ByteBuffer( + string: + "It is not allowed to invoke multiple Lambda function executions in parallel. (The Lambda runtime environment on AWS will never do that)" + ) + ) + try await self.sendResponse(response, outbound: outbound, logger: logger) } - // What todo when there is no more responses to process? - // This should not happen as the async iterator blocks until there is a response to process - fatalError("No more responses to process - the async for loop should not return") // client uses incorrect HTTP method case (_, let url) where url.hasSuffix(self.invocationEndpoint): @@ -457,7 +476,7 @@ internal struct LambdaHTTPServer { } // enqueue the lambda function response to be served as response to the client /invoke logger.trace("/:requestId/response received response", metadata: ["requestId": "\(requestId)"]) - await self.responsePool.push( + self.responsePool.push( LocalServerResponse( id: requestId, status: .accepted, @@ -488,7 +507,7 @@ internal struct LambdaHTTPServer { } // enqueue the lambda function response to be served as response to the client /invoke logger.trace("/:requestId/response received response", metadata: ["requestId": "\(requestId)"]) - await self.responsePool.push( + self.responsePool.push( LocalServerResponse( id: requestId, status: .internalServerError, @@ -544,85 +563,6 @@ internal struct LambdaHTTPServer { } } - /// A shared data structure to store the current invocation 
or response requests and the continuation objects. - /// This data structure is shared between instances of the HTTPHandler - /// (one instance to serve requests from the Lambda function and one instance to serve requests from the client invoking the lambda function). - internal final class Pool: AsyncSequence, AsyncIteratorProtocol, Sendable where T: Sendable { - typealias Element = T - - enum State: ~Copyable { - case buffer(Deque) - case continuation(CheckedContinuation?) - } - - private let lock = Mutex(.buffer([])) - - /// enqueue an element, or give it back immediately to the iterator if it is waiting for an element - public func push(_ invocation: T) async { - // if the iterator is waiting for an element, give it to it - // otherwise, enqueue the element - let maybeContinuation = self.lock.withLock { state -> CheckedContinuation? in - switch consume state { - case .continuation(let continuation): - state = .buffer([]) - return continuation - - case .buffer(var buffer): - buffer.append(invocation) - state = .buffer(buffer) - return nil - } - } - - maybeContinuation?.resume(returning: invocation) - } - - func next() async throws -> T? { - // exit the async for loop if the task is cancelled - guard !Task.isCancelled else { - return nil - } - - return try await withTaskCancellationHandler { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - let nextAction = self.lock.withLock { state -> T? in - switch consume state { - case .buffer(var buffer): - if let first = buffer.popFirst() { - state = .buffer(buffer) - return first - } else { - state = .continuation(continuation) - return nil - } - - case .continuation: - fatalError("Concurrent invocations to next(). 
This is illegal.") - } - } - - guard let nextAction else { return } - - continuation.resume(returning: nextAction) - } - } onCancel: { - self.lock.withLock { state in - switch consume state { - case .buffer(let buffer): - state = .buffer(buffer) - case .continuation(let continuation): - continuation?.resume(throwing: CancellationError()) - state = .buffer([]) - } - } - } - } - - func makeAsyncIterator() -> Pool { - self - } - } - private struct LocalServerResponse: Sendable { let requestId: String? let status: HTTPResponseStatus? diff --git a/Tests/AWSLambdaRuntimeTests/PoolTests.swift b/Tests/AWSLambdaRuntimeTests/PoolTests.swift index 8cbe8a2e..1e2fff2e 100644 --- a/Tests/AWSLambdaRuntimeTests/PoolTests.swift +++ b/Tests/AWSLambdaRuntimeTests/PoolTests.swift @@ -24,8 +24,8 @@ struct PoolTests { let pool = LambdaHTTPServer.Pool() // Push values - await pool.push("first") - await pool.push("second") + pool.push("first") + pool.push("second") // Iterate and verify order var values = [String]() @@ -53,7 +53,11 @@ struct PoolTests { task.cancel() // This should complete without receiving any values - try await task.value + do { + try await task.value + } catch is CancellationError { + // this might happen depending on the order on which the cancellation is handled + } } @Test @@ -78,7 +82,7 @@ struct PoolTests { try await withThrowingTaskGroup(of: Void.self) { group in for i in 0..() + + // Create two tasks that will both wait for elements to be available + await #expect(throws: LambdaHTTPServer.Pool.PoolError.self) { + try await withThrowingTaskGroup(of: Void.self) { group in + + // one of the two task will throw a PoolError + + group.addTask { + for try await _ in pool { + } + Issue.record("Loop 1 should not complete") + } + + group.addTask { + for try await _ in pool { + } + Issue.record("Loop 2 should not complete") + } + try await group.waitForAll() + } + } + } + }