diff --git a/Package.swift b/Package.swift index 42e4c3ed..622b0770 100644 --- a/Package.swift +++ b/Package.swift @@ -52,7 +52,10 @@ let package = Package( .target(name: "MockServer", dependencies: [ .product(name: "NIOHTTP1", package: "swift-nio"), ]), - .target(name: "StringSample", dependencies: ["AWSLambdaRuntime"]), + .target(name: "StringSample", dependencies: [ + .byName(name: "AWSLambdaRuntime"), + .byName(name: "AWSLambdaTesting"), + ]), .target(name: "CodableSample", dependencies: ["AWSLambdaRuntime"]), ] ) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift b/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift deleted file mode 100644 index e39c1179..00000000 --- a/Sources/AWSLambdaRuntimeCore/Lambda+LocalServer.swift +++ /dev/null @@ -1,268 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftAWSLambdaRuntime open source project -// -// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -#if DEBUG -import Dispatch -import Logging -import NIO -import NIOConcurrencyHelpers -import NIOHTTP1 - -// This functionality is designed for local testing hence beind a #if DEBUG flag. -// For example: -// -// try Lambda.withLocalServer { -// Lambda.run { (context: Lambda.Context, payload: String, callback: @escaping (Result) -> Void) in -// callback(.success("Hello, \(payload)!")) -// } -// } -extension Lambda { - /// Execute code in the context of a mock Lambda server. - /// - /// - parameters: - /// - invocationEndpoint: The endpoint to post payloads to. - /// - body: Code to run within the context of the mock server. Typically this would be a Lambda.run function call. - /// - /// - note: This API is designed stricly for local testing and is behind a DEBUG flag - public static func withLocalServer(invocationEndpoint: String? = nil, _ body: @escaping () -> Void) throws { - let server = LocalLambda.Server(invocationEndpoint: invocationEndpoint) - try server.start().wait() - defer { try! server.stop() } // FIXME: - body() - } -} - -// MARK: - Local Mock Server - -private enum LocalLambda { - struct Server { - private let logger: Logger - private let group: EventLoopGroup - private let host: String - private let port: Int - private let invocationEndpoint: String - - public init(invocationEndpoint: String?) { - let configuration = Lambda.Configuration() - var logger = Logger(label: "LocalLambdaServer") - logger.logLevel = configuration.general.logLevel - self.logger = logger - self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) - self.host = configuration.runtimeEngine.ip - self.port = configuration.runtimeEngine.port - self.invocationEndpoint = invocationEndpoint ?? "/invoke" - } - - func start() -> EventLoopFuture { - let bootstrap = ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in - channel.pipeline.addHandler(HTTPHandler(logger: self.logger, invocationEndpoint: self.invocationEndpoint)) - } - } - return bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture in - guard channel.localAddress != nil else { - return channel.eventLoop.makeFailedFuture(ServerError.cantBind) - } - self.logger.info("LocalLambdaServer started and listening on \(self.host):\(self.port), receiving payloads on \(self.invocationEndpoint)") - return channel.eventLoop.makeSucceededFuture(()) - } - } - - func stop() throws { - try self.group.syncShutdownGracefully() - } - } - - final class HTTPHandler: ChannelInboundHandler { - public typealias InboundIn = HTTPServerRequestPart - public typealias OutboundOut = HTTPServerResponsePart - - private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() - - private static var invocations = CircularBuffer() - private static var invocationState = InvocationState.waitingForLambdaRequest - - private let logger: Logger - private let invocationEndpoint: String - - init(logger: Logger, invocationEndpoint: String) { - self.logger = logger - self.invocationEndpoint = invocationEndpoint - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let requestPart = unwrapInboundIn(data) - - switch requestPart { - case .head(let head): - self.pending.append((head: head, body: nil)) - case .body(var buffer): - var request = self.pending.removeFirst() - if request.body == nil { - request.body = buffer - } else { - request.body!.writeBuffer(&buffer) - } - self.pending.prepend(request) - case .end: - let request = self.pending.removeFirst() - self.processRequest(context: context, request: request) - } - } - - func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { - switch (request.head.method, request.head.uri) { - // this endpoint is called by the client invoking the lambda - case (.POST, let url) where url.hasSuffix(self.invocationEndpoint): - guard let work = request.body else { - return self.writeResponse(context: context, response: .init(status: .badRequest)) - } - let requestID = "\(DispatchTime.now().uptimeNanoseconds)" // FIXME: - let promise = context.eventLoop.makePromise(of: Response.self) - promise.futureResult.whenComplete { result in - switch result { - case .failure(let error): - self.logger.error("invocation error: \(error)") - self.writeResponse(context: context, response: .init(status: .internalServerError)) - case .success(let response): - self.writeResponse(context: context, response: response) - } - } - let invocation = Invocation(requestID: requestID, request: work, responsePromise: promise) - switch Self.invocationState { - case .waitingForInvocation(let promise): - promise.succeed(invocation) - case .waitingForLambdaRequest, .waitingForLambdaResponse: - Self.invocations.append(invocation) - } - // /next endpoint is called by the lambda polling for work - case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): - // check if our server is in the correct state - guard case .waitingForLambdaRequest = Self.invocationState else { - self.logger.error("invalid invocation state \(Self.invocationState)") - self.writeResponse(context: context, response: .init(status: .unprocessableEntity)) - return - } - - // pop the first task from the queue - switch Self.invocations.popFirst() { - case .none: - // if there is nothing in the queue, - // create a promise that we can fullfill when we get a new task - let promise = context.eventLoop.makePromise(of: Invocation.self) - promise.futureResult.whenComplete { result in - switch result { - case .failure(let error): - self.logger.error("invocation error: \(error)") - self.writeResponse(context: context, response: .init(status: .internalServerError)) - case .success(let invocation): - Self.invocationState = .waitingForLambdaResponse(invocation) - self.writeResponse(context: context, response: invocation.makeResponse()) - } - } - Self.invocationState = .waitingForInvocation(promise) - case .some(let invocation): - // if there is a task pending, we can immediatly respond with it. - Self.invocationState = .waitingForLambdaResponse(invocation) - self.writeResponse(context: context, response: invocation.makeResponse()) - } - // :requestID/response endpoint is called by the lambda posting the response - case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): - let parts = request.head.uri.split(separator: "/") - guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { - // the request is malformed, since we were expecting a requestId in the path - return self.writeResponse(context: context, response: .init(status: .badRequest)) - } - guard case .waitingForLambdaResponse(let invocation) = Self.invocationState else { - // a response was send, but we did not expect to receive one - self.logger.error("invalid invocation state \(Self.invocationState)") - return self.writeResponse(context: context, response: .init(status: .unprocessableEntity)) - } - guard requestID == invocation.requestID else { - // the request's requestId is not matching the one we are expecting - self.logger.error("invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)") - return self.writeResponse(context: context, response: .init(status: .badRequest)) - } - - invocation.responsePromise.succeed(.init(status: .ok, body: request.body)) - self.writeResponse(context: context, response: .init(status: .accepted)) - Self.invocationState = .waitingForLambdaRequest - // unknown call - default: - self.writeResponse(context: context, response: .init(status: .notFound)) - } - } - - func writeResponse(context: ChannelHandlerContext, response: Response) { - var headers = HTTPHeaders(response.headers ?? []) - headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)") - let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: response.status, headers: headers) - - context.write(wrapOutboundOut(.head(head))).whenFailure { error in - self.logger.error("\(self) write error \(error)") - } - - if let buffer = response.body { - context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in - self.logger.error("\(self) write error \(error)") - } - } - - context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in - if case .failure(let error) = result { - self.logger.error("\(self) write error \(error)") - } - } - } - - struct Response { - var status: HTTPResponseStatus = .ok - var headers: [(String, String)]? - var body: ByteBuffer? - } - - struct Invocation { - let requestID: String - let request: ByteBuffer - let responsePromise: EventLoopPromise - - func makeResponse() -> Response { - var response = Response() - response.body = self.request - // required headers - response.headers = [ - (AmazonHeaders.requestID, self.requestID), - (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"), - (AmazonHeaders.traceID, "Root=\(Int16.random(in: Int16.min ... Int16.max));Parent=\(Int16.random(in: Int16.min ... Int16.max));Sampled=1"), - (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), - ] - return response - } - } - - enum InvocationState { - case waitingForInvocation(EventLoopPromise) - case waitingForLambdaRequest - case waitingForLambdaResponse(Invocation) - } - } - - enum ServerError: Error { - case notReady - case cantBind - } -} -#endif diff --git a/Sources/AWSLambdaTesting/APIGatewayV2Proxy.swift b/Sources/AWSLambdaTesting/APIGatewayV2Proxy.swift new file mode 100644 index 00000000..511b903b --- /dev/null +++ b/Sources/AWSLambdaTesting/APIGatewayV2Proxy.swift @@ -0,0 +1,33 @@ + +#if DEBUG +import AWSLambdaEvents +import NIO + +struct APIGatewayV2Proxy: LocalLambdaInvocationProxy { + let eventLoop: EventLoop + + init(eventLoop: EventLoop) { + self.eventLoop = eventLoop + } + + func invocation(from request: HTTPRequest) -> EventLoopFuture { + switch (request.method, request.uri) { + case (.POST, "/invoke"): + guard let body = request.body else { + return self.eventLoop.makeFailedFuture(InvocationHTTPError(.init(status: .badRequest))) + } + return self.eventLoop.makeSucceededFuture(body) + default: + return self.eventLoop.makeFailedFuture(InvocationHTTPError(.init(status: .notFound))) + } + } + + func processResult(_ result: ByteBuffer?) -> EventLoopFuture { + self.eventLoop.makeSucceededFuture(.init(status: .ok, body: result)) + } + + func processError(_ error: ByteBuffer?) -> EventLoopFuture { + self.eventLoop.makeSucceededFuture(.init(status: .internalServerError, body: error)) + } +} +#endif diff --git a/Sources/AWSLambdaTesting/Lambda+LocalServer.swift b/Sources/AWSLambdaTesting/Lambda+LocalServer.swift new file mode 100644 index 00000000..f90eea34 --- /dev/null +++ b/Sources/AWSLambdaTesting/Lambda+LocalServer.swift @@ -0,0 +1,518 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftAWSLambdaRuntime open source project +// +// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if DEBUG +@testable import AWSLambdaRuntimeCore +import Dispatch +import Logging +import NIO +import NIOConcurrencyHelpers +import NIOHTTP1 + +// This functionality is designed for local testing hence beind a #if DEBUG flag. +// For example: +// +// try Lambda.withLocalServer { +// Lambda.run { (context: Lambda.Context, payload: String, callback: @escaping (Result) -> Void) in +// callback(.success("Hello, \(payload)!")) +// } +// } +extension Lambda { + /// Execute code in the context of a mock Lambda server. + /// + /// - parameters: + /// - invocationEndpoint: The endpoint to post payloads to. + /// - body: Code to run within the context of the mock server. Typically this would be a Lambda.run function call. + /// + /// - note: This API is designed stricly for local testing and is behind a DEBUG flag + public static func withLocalServer(proxyType: LocalLambdaInvocationProxy.Type = InvokeProxy.self, _ body: @escaping () -> Void) throws { + let server = LocalLambda.Server(proxyType: proxyType) + try server.start().wait() + defer { try! server.stop() } // FIXME: + body() + } +} + +public struct HTTPRequest { + let method: HTTPMethod + let uri: String + let headers: [(String, String)] + let body: ByteBuffer? + + internal init(head: HTTPRequestHead, body: ByteBuffer?) { + self.method = head.method + self.headers = head.headers.map { $0 } + self.uri = head.uri + self.body = body + } +} + +public struct HTTPResponse { + var status: HTTPResponseStatus = .ok + var headers: [(String, String)]? + var body: ByteBuffer? +} + +public struct InvocationHTTPError: Error { + let response: HTTPResponse + + init(_ response: HTTPResponse) { + self.response = response + } +} + +public protocol LocalLambdaInvocationProxy { + init(eventLoop: EventLoop) + + /// throws HTTPError + func invocation(from request: HTTPRequest) -> EventLoopFuture + func processResult(_ result: ByteBuffer?) -> EventLoopFuture + func processError(_ error: ByteBuffer?) -> EventLoopFuture +} + +// MARK: - Local Mock Server + +private enum LocalLambda { + struct Server { + private let logger: Logger + private let eventLoopGroup: EventLoopGroup + private let eventLoop: EventLoop + private let controlPlaneHost: String + private let controlPlanePort: Int + private let invokeAPIHost: String + private let invokeAPIPort: Int + private let proxy: LocalLambdaInvocationProxy + + public init(proxyType: LocalLambdaInvocationProxy.Type) { + let configuration = Lambda.Configuration() + var logger = Logger(label: "LocalLambdaServer") + logger.logLevel = configuration.general.logLevel + self.logger = logger + self.eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + self.eventLoop = self.eventLoopGroup.next() + self.controlPlaneHost = configuration.runtimeEngine.ip + self.controlPlanePort = configuration.runtimeEngine.port + self.invokeAPIHost = configuration.runtimeEngine.ip + self.invokeAPIPort = configuration.runtimeEngine.port + 1 + self.proxy = proxyType.init(eventLoop: self.eventLoop) + } + + func start() -> EventLoopFuture { + let state = ServerState(eventLoop: self.eventLoop, logger: self.logger, proxy: self.proxy) + + let controlPlaneBootstrap = ServerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in + channel.pipeline.addHandler(ControlPlaneHandler(logger: self.logger, serverState: state)) + } + } + + let invokeBootstrap = ServerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in + channel.pipeline.addHandler(InvokeHandler(logger: self.logger, serverState: state)) + } + } + + let controlPlaneFuture = controlPlaneBootstrap.bind(host: self.controlPlaneHost, port: self.controlPlanePort).flatMap { channel -> EventLoopFuture in + guard channel.localAddress != nil else { + return channel.eventLoop.makeFailedFuture(ServerError.cantBind) + } + self.logger.info("Control plane api started and listening on \(self.controlPlaneHost):\(self.controlPlanePort)") + return channel.eventLoop.makeSucceededFuture(()) + } + + let invokeAPIFuture = invokeBootstrap.bind(host: self.invokeAPIHost, port: self.invokeAPIPort).flatMap { channel -> EventLoopFuture in + guard channel.localAddress != nil else { + return channel.eventLoop.makeFailedFuture(ServerError.cantBind) + } + self.logger.info("Invocation proxy api started and listening on \(self.controlPlaneHost):\(self.controlPlanePort + 1)") + return channel.eventLoop.makeSucceededFuture(()) + } + + return controlPlaneFuture.and(invokeAPIFuture).map { _ in Void() } + } + + func stop() throws { + try self.eventLoopGroup.syncShutdownGracefully() + } + } + + final class ServerState { + private enum State { + case waitingForInvocation(EventLoopPromise) + case waitingForLambdaRequest + case waitingForLambdaResponse(Invocation) + } + + enum Error: Swift.Error { + case invalidState + case invalidRequestId + } + + private var invocations = CircularBuffer() + private var state = State.waitingForLambdaRequest + private var logger: Logger + + let eventLoop: EventLoop + let proxy: LocalLambdaInvocationProxy + + init(eventLoop: EventLoop, logger: Logger, proxy: LocalLambdaInvocationProxy) { + self.eventLoop = eventLoop + self.logger = logger + self.proxy = proxy + } + + // MARK: Invocation API + + func queueInvocationRequest(_ httpRequest: HTTPRequest) -> EventLoopFuture { + self.proxy.invocation(from: httpRequest).flatMap { byteBuffer in + let promise = self.eventLoop.makePromise(of: HTTPResponse.self) + + let uuid = "\(DispatchTime.now().uptimeNanoseconds)" // FIXME: + let invocation = Invocation(requestID: uuid, request: byteBuffer, responsePromise: promise) + + switch self.state { + case .waitingForInvocation(let promise): + self.state = .waitingForLambdaResponse(invocation) + promise.succeed(invocation) + default: + self.invocations.append(invocation) + } + + return promise.futureResult + } + } + + // MARK: Lambda Control Plane API + + func getNextInvocation() -> EventLoopFuture { + guard case .waitingForLambdaRequest = self.state else { + self.logger.error("invalid invocation state \(self.state)") + return self.eventLoop.makeFailedFuture(Error.invalidState) + } + + switch self.invocations.popFirst() { + case .some(let invocation): + // if there is a task pending, we can immediatly respond with it. + self.state = .waitingForLambdaResponse(invocation) + return self.eventLoop.makeSucceededFuture(invocation) + case .none: + // if there is nothing in the queue, + // create a promise that we can fullfill when we get a new task + let promise = self.eventLoop.makePromise(of: Invocation.self) + self.state = .waitingForInvocation(promise) + return promise.futureResult + } + } + + func processInvocationResult(for invocationId: String, body: ByteBuffer?) throws { + let invocation = try self.pendingInvocation(for: invocationId) + self.state = .waitingForLambdaRequest + + self.proxy.processResult(body).whenComplete { result in + switch result { + case .success(let response): + invocation.responsePromise.succeed(response) + case .failure(let error): + invocation.responsePromise.fail(error) + } + } + } + + func processInvocationError(for invocationId: String, body: ByteBuffer?) throws { + let invocation = try self.pendingInvocation(for: invocationId) + self.state = .waitingForLambdaRequest + + self.proxy.processError(body).whenComplete { result in + switch result { + case .success(let response): + invocation.responsePromise.succeed(response) + case .failure(let error): + invocation.responsePromise.fail(error) + } + } + } + + private func pendingInvocation(for requestID: String) throws -> Invocation { + guard case .waitingForLambdaResponse(let invocation) = self.state else { + // a response was send, but we did not expect to receive one + self.logger.error("invalid invocation state \(self.state)") + throw Error.invalidState + } + guard requestID == invocation.requestID else { + // the request's requestId is not matching the one we are expecting + self.logger.error("invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)") + throw Error.invalidRequestId + } + + return invocation + } + } + + final class ControlPlaneHandler: ChannelInboundHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() + + private let serverState: ServerState + private let logger: Logger + + init(logger: Logger, serverState: ServerState) { + self.logger = logger + self.serverState = serverState + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let requestPart = unwrapInboundIn(data) + + switch requestPart { + case .head(let head): + self.pending.append((head: head, body: nil)) + case .body(var buffer): + var request = self.pending.removeFirst() + if request.body == nil { + request.body = buffer + } else { + request.body!.writeBuffer(&buffer) + } + self.pending.prepend(request) + case .end: + let request = self.pending.removeFirst() + self.processRequest(context: context, request: request) + } + } + + func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { + switch (request.head.method, request.head.uri) { + // /next endpoint is called by the lambda polling for work + case (.GET, let url) where url.hasSuffix(Consts.getNextInvocationURLSuffix): + // check if our server is in the correct state + self.serverState.getNextInvocation().whenComplete { result in + switch result { + case .success(let invocation): + self.writeResponse(context: context, response: invocation.makeResponse()) + case .failure(let error): + self.logger.error("invocation error: \(error)") + self.writeResponse(context: context, response: .init(status: .internalServerError)) + } + } + + // :requestID/response endpoint is called by the lambda posting the response + case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix): + let parts = request.head.uri.split(separator: "/") + guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { + // the request is malformed, since we were expecting a requestId in the path + return self.writeResponse(context: context, response: .init(status: .badRequest)) + } + + do { + // a sync call here looks... interesting. + try self.serverState.processInvocationResult(for: requestID, body: request.body) + self.writeResponse(context: context, response: .init(status: .accepted)) + } catch { + self.writeResponse(context: context, response: .init(status: .badRequest)) + } + + // :requestID/error endpoint is called by the lambda posting an error + case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix): + let parts = request.head.uri.split(separator: "/") + guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { + // the request is malformed, since we were expecting a requestId in the path + return self.writeResponse(context: context, response: .init(status: .badRequest)) + } + + do { + // a sync call here looks... interesting. + try self.serverState.processInvocationError(for: requestID, body: request.body) + self.writeResponse(context: context, response: .init(status: .accepted)) + } catch { + self.writeResponse(context: context, response: .init(status: .badRequest)) + } + + // unknown call + default: + self.writeResponse(context: context, response: .init(status: .notFound)) + } + } + + func writeResponse(context: ChannelHandlerContext, response: HTTPResponse) { + var headers = HTTPHeaders(response.headers ?? []) + headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)") + let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: response.status, headers: headers) + + context.write(wrapOutboundOut(.head(head))).whenFailure { error in + self.logger.error("\(self) write error \(error)") + } + + if let buffer = response.body { + context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in + self.logger.error("\(self) write error \(error)") + } + } + + context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in + if case .failure(let error) = result { + self.logger.error("\(self) write error \(error)") + } + } + } + + struct Invocation { + let requestID: String + let request: ByteBuffer + let responsePromise: EventLoopPromise + + func makeResponse() -> HTTPResponse { + var response = HTTPResponse() + response.body = self.request + // required headers + response.headers = [ + (AmazonHeaders.requestID, self.requestID), + (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"), + (AmazonHeaders.traceID, "Root=\(Int16.random(in: Int16.min ... Int16.max));Parent=\(Int16.random(in: Int16.min ... Int16.max));Sampled=1"), + (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), + ] + return response + } + } + } + + final class InvokeHandler: ChannelInboundHandler { + public typealias InboundIn = HTTPServerRequestPart + public typealias OutboundOut = HTTPServerResponsePart + + private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() + + private let serverState: ServerState + private let logger: Logger + + init(logger: Logger, serverState: ServerState) { + self.logger = logger + self.serverState = serverState + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let requestPart = unwrapInboundIn(data) + + switch requestPart { + case .head(let head): + self.pending.append((head: head, body: nil)) + case .body(var buffer): + var request = self.pending.removeFirst() + if request.body == nil { + request.body = buffer + } else { + request.body!.writeBuffer(&buffer) + } + self.pending.prepend(request) + case .end: + let request = self.pending.removeFirst() + self.processRequest(context: context, request: request) + } + } + + func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { + self.serverState.queueInvocationRequest(HTTPRequest(head: request.head, body: request.body)).whenComplete { result in + switch result { + case .success(let response): + self.writeResponse(context: context, response: response) + case .failure(let error as InvocationHTTPError): + self.writeResponse(context: context, response: error.response) + case .failure: + self.writeResponse(context: context, response: .init(status: .internalServerError)) + } + } + } + + func writeResponse(context: ChannelHandlerContext, response: HTTPResponse) { + var headers = HTTPHeaders(response.headers ?? []) + headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)") + let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: response.status, headers: headers) + + context.write(wrapOutboundOut(.head(head))).whenFailure { error in + self.logger.error("\(self) write error \(error)") + } + + if let buffer = response.body { + context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in + self.logger.error("\(self) write error \(error)") + } + } + + context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in + if case .failure(let error) = result { + self.logger.error("\(self) write error \(error)") + } + } + } + } + + struct Invocation { + let requestID: String + let request: ByteBuffer + let responsePromise: EventLoopPromise + + func makeResponse() -> HTTPResponse { + var response = HTTPResponse() + response.body = self.request + // required headers + response.headers = [ + (AmazonHeaders.requestID, self.requestID), + (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"), + (AmazonHeaders.traceID, "Root=\(Int16.random(in: Int16.min ... Int16.max));Parent=\(Int16.random(in: Int16.min ... Int16.max));Sampled=1"), + (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), + ] + return response + } + } + + enum ServerError: Error { + case notReady + case cantBind + } +} + +public struct InvokeProxy: LocalLambdaInvocationProxy { + let eventLoop: EventLoop + + public init(eventLoop: EventLoop) { + self.eventLoop = eventLoop + } + + public func invocation(from request: HTTPRequest) -> EventLoopFuture { + switch (request.method, request.uri) { + case (.POST, "/invoke"): + guard let body = request.body else { + return self.eventLoop.makeFailedFuture(InvocationHTTPError(.init(status: .badRequest))) + } + return self.eventLoop.makeSucceededFuture(body) + default: + return self.eventLoop.makeFailedFuture(InvocationHTTPError(.init(status: .notFound))) + } + } + + public func processResult(_ result: ByteBuffer?) -> EventLoopFuture { + self.eventLoop.makeSucceededFuture(.init(status: .ok, body: result)) + } + + public func processError(_ error: ByteBuffer?) -> EventLoopFuture { + self.eventLoop.makeSucceededFuture(.init(status: .internalServerError, body: error)) + } +} + +#endif