diff --git a/Sources/NIOHTTP1/HTTPDecoder.swift b/Sources/NIOHTTP1/HTTPDecoder.swift index 01de669fdd5..59a0163c3f2 100644 --- a/Sources/NIOHTTP1/HTTPDecoder.swift +++ b/Sources/NIOHTTP1/HTTPDecoder.swift @@ -34,6 +34,9 @@ private struct HTTPParserState { var currentError: HTTPParserError? var seenEOF = false var headerStartIndex: Int? + + // Holds the data we need to forward via ctx.fireChannelRead(...) after invoking the parser. + var pendingInOut: NIOAny? = nil enum DataAwaitingState { case messageBegin @@ -57,6 +60,7 @@ private struct HTTPParserState { self.currentStatus = nil self.slice = nil self.headerStartIndex = nil + self.pendingInOut = nil } var cumulationBuffer: ByteBuffer? @@ -196,7 +200,6 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { private var parser = http_parser() private var settings = http_parser_settings() - private var decoding: Bool = false fileprivate var state = HTTPParserState() @@ -234,6 +237,25 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { return response } + private func bytesToForwardOnRemoval(ctx: ChannelHandlerContext) -> ByteBuffer? { + guard self.parser.upgrade == 1 && ctx.channel.isActive else { + return nil + } + // We take a slice of the cumulationBuffer so the next handler in the pipeline will just see the readable portion of the buffer. + // While this is not strictly needed it may make it easier to consume. + if let buffer = self.cumulationBuffer?.slice(), buffer.readableBytes > 0 { + return buffer + } + return nil + } + + public func handlerRemoved(ctx: ChannelHandlerContext) { + if let buffer = self.bytesToForwardOnRemoval(ctx: ctx) { + ctx.fireChannelRead(NIOAny(buffer)) + } + self.cumulationBuffer = nil + } + public func decoderAdded(ctx: ChannelHandlerContext) { if HTTPMessageT.self == HTTPServerRequestPart.self { c_nio_http_parser_init(&self.parser, HTTP_REQUEST) @@ -258,6 +280,11 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { let ctx = evacuateChannelHandlerContext(parser) let handler = ctx.handler as! AnyHTTPDecoder + // Ensure we pause the parser after this callback is complete so we can safely callout + // to the pipeline. + c_nio_http_parser_pause(parser, 1) + assert(handler.state.pendingInOut == nil) + handler.state.complete(state: handler.state.dataAwaitingState) handler.state.dataAwaitingState = .body @@ -269,7 +296,7 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { return -1 } - ctx.fireChannelRead(handler.wrapInboundOut(HTTPServerRequestPart.head(head))) + handler.state.pendingInOut = handler.wrapInboundOut(HTTPServerRequestPart.head(head)) return 0 case let handler as HTTPResponseDecoder: let head = handler.newResponseHead(parser) @@ -278,7 +305,7 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { return -1 } - ctx.fireChannelRead(handler.wrapInboundOut(HTTPClientResponsePart.head(head))) + handler.state.pendingInOut = handler.wrapInboundOut(HTTPClientResponsePart.head(head)) // http_parser doesn't correctly handle responses to HEAD requests. We have to do something // annoyingly opaque here, and in those cases return 1 instead of 0. This forces http_parser @@ -322,14 +349,19 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { let handler = ctx.handler as! AnyHTTPDecoder assert(handler.state.dataAwaitingState == .body) + // Ensure we pause the parser after this callback is complete so we can safely callout + // to the pipeline. + c_nio_http_parser_pause(parser, 1) + assert(handler.state.pendingInOut == nil) + // Calculate the index of the data in the cumulationBuffer so we can slice out the ByteBuffer without doing any memory copy let index = handler.state.calculateIndex(data: data!, length: len) let slice = handler.state.cumulationBuffer!.getSlice(at: index, length: len)! switch handler { case let handler as HTTPRequestDecoder: - ctx.fireChannelRead(handler.wrapInboundOut(HTTPServerRequestPart.body(slice))) + handler.state.pendingInOut = handler.wrapInboundOut(HTTPServerRequestPart.body(slice)) case let handler as HTTPResponseDecoder: - ctx.fireChannelRead(handler.wrapInboundOut(HTTPClientResponsePart.body(slice))) + handler.state.pendingInOut = handler.wrapInboundOut(HTTPClientResponsePart.body(slice)) default: fatalError("the impossible happened: handler neither a HTTPRequestDecoder nor a HTTPResponseDecoder which should be impossible") } @@ -387,6 +419,12 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { self.settings.on_message_complete = { parser in let ctx = evacuateChannelHandlerContext(parser) let handler = ctx.handler as! AnyHTTPDecoder + + // Ensure we pause the parser after this callback is complete so we can safely callout + // to the pipeline. + c_nio_http_parser_pause(parser, 1) + assert(handler.state.pendingInOut == nil) + handler.state.complete(state: handler.state.dataAwaitingState) handler.state.dataAwaitingState = .messageBegin @@ -395,9 +433,9 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { handler.state.currentHeaders.removeAll(keepingCapacity: true) switch handler { case let handler as HTTPRequestDecoder: - ctx.fireChannelRead(handler.wrapInboundOut(HTTPServerRequestPart.end(trailers))) + handler.state.pendingInOut = handler.wrapInboundOut(HTTPServerRequestPart.end(trailers)) case let handler as HTTPResponseDecoder: - ctx.fireChannelRead(handler.wrapInboundOut(HTTPClientResponsePart.end(trailers))) + handler.state.pendingInOut = handler.wrapInboundOut(HTTPClientResponsePart.end(trailers)) default: fatalError("the impossible happened: handler neither a HTTPRequestDecoder nor a HTTPResponseDecoder which should be impossible") } @@ -411,10 +449,9 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { throw error } - let httpError = self.parser.http_errno - if httpError != 0 { - self.state.currentError = HTTPParserError.httpError(fromCHTTPParserErrno: http_errno(rawValue: httpError))! - throw self.state.currentError! + if let parserError = self.currentParserError() { + self.state.currentError = parserError + throw parserError } } @@ -432,13 +469,15 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { c_nio_http_parser_execute(&self.parser, &self.settings, p.advanced(by: bufferSlice.readerIndex), bufferSlice.readableBytes) } } + try self.rethrowParserError() - /// http_parser_execute(...) should always consume the whole buffer. - assert(result == bufferSlice.readableBytes) + c_nio_http_parser_pause(&self.parser, 0) // Update readerIndex of the cumulationBuffer itself as we will refetch it in the next loop run if needed. self.cumulationBuffer?.moveReaderIndex(forwardBy: result) + + self.firePendingInOut(ctx: ctx) } if self.state.seenEOF { @@ -447,22 +486,30 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { } } + private func firePendingInOut(ctx: ChannelHandlerContext) { + if let pending = self.state.pendingInOut { + self.state.pendingInOut = nil + ctx.fireChannelRead(pending) + } + } + private func discardDecodedBytes() { guard self.cumulationBuffer != nil else { // Guard against the case of closing the channel. In this case the cumulationBuffer will be nil. return } - assert(self.cumulationBuffer!.readableBytes == 0) switch self.state.dataAwaitingState { case .body, .messageBegin: assert(self.state.currentNameIndex == nil) assert(self.state.currentHeaders.isEmpty) assert(self.state.slice == nil) - - // Its safe to just drop the cumulationBuffer as we not have any extra views into it that are represented as readerIndex / length. - self.cumulationBuffer = nil + + if self.cumulationBuffer!.readableBytes == 0 { + // Its safe to just drop the cumulationBuffer as we don't have any extra views into it that are represented as readerIndex / length. + self.cumulationBuffer = nil + } case .headerField, .headerValue: guard let headerStartIdx = self.state.headerStartIndex else { @@ -519,18 +566,6 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { var buffer = self.unwrapInboundIn(data) - // Guard against re-entrance calls of channelRead(...) - guard !self.decoding else { - self.cumulationBuffer!.write(buffer: &buffer) - return - } - - // Needed to guard again re-entrant calls. - self.decoding = true - defer { - self.decoding = false - } - // Either use the received buffer directly or merge it into the already existing cumulationBuffer. if self.cumulationBuffer == nil { self.cumulationBuffer = buffer @@ -599,11 +634,6 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { // let us enter this function again. self.state.seenEOF = true - guard !self.decoding else { - // We are currently decoding, return early as we will handle it in the decoding loop. - return - } - self.notifyParserEOF(ctx: ctx) } @@ -613,15 +643,25 @@ public class HTTPDecoder: ByteToMessageDecoder, AnyHTTPDecoder { // We don't need the cumulation buffer, if we're holding it. self.cumulationBuffer = nil - + + self.firePendingInOut(ctx: ctx) + // No check to state.currentError because, if we hit it before, we already threw that // error. This never calls any of the callbacks that set that field anyway. Instead we // just check if the errno is set and throw. + if let parserError = self.currentParserError() { + self.state.currentError = parserError + ctx.fireErrorCaught(parserError) + } + } + + private func currentParserError() -> HTTPParserError? { let httpError = self.parser.http_errno - if httpError != 0 { - self.state.currentError = HTTPParserError.httpError(fromCHTTPParserErrno: http_errno(rawValue: httpError))! - ctx.fireErrorCaught(self.state.currentError!) + // Also take into account that we may have called c_nio_http_parser_pause(...) + guard httpError != HPE_PAUSED.rawValue && httpError != 0 else { + return nil } + return HTTPParserError.httpError(fromCHTTPParserErrno: http_errno(rawValue: httpError))! } } diff --git a/Sources/NIOHTTP1/HTTPUpgradeHandler.swift b/Sources/NIOHTTP1/HTTPUpgradeHandler.swift index 7a725cacd6f..f2ce8520575 100644 --- a/Sources/NIOHTTP1/HTTPUpgradeHandler.swift +++ b/Sources/NIOHTTP1/HTTPUpgradeHandler.swift @@ -73,9 +73,8 @@ public class HTTPServerUpgradeHandler: ChannelInboundHandler { /// Whether we've already seen the first request. private var seenFirstRequest = false - /// Whether we're upgrading: if we are, we want to buffer the data until the - /// upgrade is complete. - private var upgrading = false + /// The closure that should be invoked when the end of the upgrade request is received if any. + private var upgrade: (() -> Void)? private var receivedMessages: [NIOAny] = [] /// Create a `HTTPServerUpgradeHandler`. @@ -131,45 +130,59 @@ public class HTTPServerUpgradeHandler: ChannelInboundHandler { } public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { - if self.upgrading { + guard !self.seenFirstRequest else { // We're waiting for upgrade to complete: buffer this data. self.receivedMessages.append(data) return } - // We're trying to remove ourselves from the pipeline but not upgrading, so just pass this on. - if seenFirstRequest { - ctx.fireChannelRead(data) - return - } - let requestPart = unwrapInboundIn(data) - seenFirstRequest = true - - // We should only ever see a request header: by the time the body comes in we should - // be out of the pipeline. Anything else is an error. - guard case .head(let request) = requestPart else { - ctx.fireErrorCaught(HTTPUpgradeErrors.invalidHTTPOrdering) - notUpgrading(ctx: ctx, data: data) - return - } - // Ok, we have a HTTP request. Check if it's an upgrade. If it's not, we want to pass it on and remove ourselves - // from the channel pipeline. - let requestedProtocols = request.headers[canonicalForm: "upgrade"] - guard requestedProtocols.count > 0 else { - notUpgrading(ctx: ctx, data: data) - return - } - - // Cool, this is an upgrade! Let's go. - if !handleUpgrade(ctx: ctx, request: request, requestedProtocols: requestedProtocols) { - notUpgrading(ctx: ctx, data: data) + if let upgrade = self.upgrade { + switch requestPart { + case .head(_): + ctx.fireErrorCaught(HTTPUpgradeErrors.invalidHTTPOrdering) + notUpgrading(ctx: ctx, data: data) + return + case .body(_): + // TODO: In the future way may want to add some API to also allow special handling of the body during the + // upgrade. For now just ignore it. + break + case .end(_): + self.seenFirstRequest = true + + // The request is complete now trigger the upgrade. + upgrade() + } + } else { + // We should decide if we're going to upgrade based on the first request header: if we aren't upgrading, + // by the time the body comes in we should be out of the pipeline. That means that if we don't think we're + // upgrading, the only thing we should see is a request head. Anything else in an error. + guard case .head(let request) = requestPart else { + ctx.fireErrorCaught(HTTPUpgradeErrors.invalidHTTPOrdering) + notUpgrading(ctx: ctx, data: data) + return + } + + // Ok, we have a HTTP request. Check if it's an upgrade. If it's not, we want to pass it on and remove ourselves + // from the channel pipeline. + let requestedProtocols = request.headers[canonicalForm: "upgrade"] + guard requestedProtocols.count > 0 else { + notUpgrading(ctx: ctx, data: data) + return + } + + // Cool, this is an upgrade! Let's go. + if let upgrade = handleUpgrade(ctx: ctx, request: request, requestedProtocols: requestedProtocols) { + self.upgrade = upgrade + } else { + notUpgrading(ctx: ctx, data: data) + } } } /// The core of the upgrade handling logic. - private func handleUpgrade(ctx: ChannelHandlerContext, request: HTTPRequestHead, requestedProtocols: [String]) -> Bool { + private func handleUpgrade(ctx: ChannelHandlerContext, request: HTTPRequestHead, requestedProtocols: [String]) -> (() -> Void)? { let connectionHeader = Set(request.headers[canonicalForm: "connection"].map { $0.lowercased() }) let allHeaderNames = Set(request.headers.map { $0.name.lowercased() }) @@ -191,48 +204,45 @@ public class HTTPServerUpgradeHandler: ChannelInboundHandler { ctx.fireErrorCaught(error) continue } - - // We are now upgrading, any further data should be buffered and replayed. - self.upgrading = true - - // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP - // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until - // that completes. - // While there are a lot of Futures involved here it's quite possible that all of this code will - // actually complete synchronously: we just want to program for the possibility that it won't. - // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the - // internal handler, then call the user code, and then finally when the user code is done we do - // our final cleanup steps, namely we replay the received data we buffered in the meantime and - // then remove ourselves from the pipeline. - _ = self.removeExtraHandlers(ctx: ctx).then { - self.sendUpgradeResponse(ctx: ctx, upgradeRequest: request, responseHeaders: responseHeaders) - }.then { - self.removeHandler(ctx: ctx, handler: self.httpEncoder) - }.map { (_: Bool) in - self.upgradeCompletionHandler(ctx) - }.then { - upgrader.upgrade(ctx: ctx, upgradeRequest: request) - }.map { - ctx.fireUserInboundEventTriggered(HTTPUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request)) - - self.upgrading = false - - // We unbuffer any buffered data here and, if we sent any, - // we also fire readComplete. - let bufferedMessages = self.receivedMessages - self.receivedMessages = [] - bufferedMessages.forEach { ctx.fireChannelRead($0) } - if bufferedMessages.count > 0 { - ctx.fireChannelReadComplete() + + return { + // Before we finish the upgrade we have to remove the HTTPDecoder and any other non-Encoder HTTP + // handlers from the pipeline, to prevent them parsing any more data. We'll buffer the data until + // that completes. + // While there are a lot of Futures involved here it's quite possible that all of this code will + // actually complete synchronously: we just want to program for the possibility that it won't. + // Once that's done, we send the upgrade response, then remove the HTTP encoder, then call the + // internal handler, then call the user code, and then finally when the user code is done we do + // our final cleanup steps, namely we replay the received data we buffered in the meantime and + // then remove ourselves from the pipeline. + _ = self.removeExtraHandlers(ctx: ctx).then { + self.sendUpgradeResponse(ctx: ctx, upgradeRequest: request, responseHeaders: responseHeaders) + }.then { + self.removeHandler(ctx: ctx, handler: self.httpEncoder) + }.map { (_: Bool) in + self.upgradeCompletionHandler(ctx) + }.then { + upgrader.upgrade(ctx: ctx, upgradeRequest: request) + }.map { + ctx.fireUserInboundEventTriggered(HTTPUpgradeEvents.upgradeComplete(toProtocol: proto, upgradeRequest: request)) + + self.upgrade = nil + + // We unbuffer any buffered data here and, if we sent any, + // we also fire readComplete. + let bufferedMessages = self.receivedMessages + self.receivedMessages = [] + bufferedMessages.forEach { ctx.fireChannelRead($0) } + if bufferedMessages.count > 0 { + ctx.fireChannelReadComplete() + } + }.then { + ctx.pipeline.remove(ctx: ctx) } - }.then { - ctx.pipeline.remove(ctx: ctx) } - - return true } - return false + return nil } /// Sends the 101 Switching Protocols response for the pipeline. diff --git a/Tests/NIOHTTP1Tests/HTTPUpgradeTests+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPUpgradeTests+XCTest.swift index 18c801089e3..f311af83d52 100644 --- a/Tests/NIOHTTP1Tests/HTTPUpgradeTests+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPUpgradeTests+XCTest.swift @@ -41,6 +41,7 @@ extension HTTPUpgradeTestCase { ("testBuffersInboundDataDuringDelayedUpgrade", testBuffersInboundDataDuringDelayedUpgrade), ("testRemovesAllHTTPRelatedHandlersAfterUpgrade", testRemovesAllHTTPRelatedHandlersAfterUpgrade), ("testBasicUpgradePipelineMutation", testBasicUpgradePipelineMutation), + ("testUpgradeWithUpgradePayloadInlineWithRequestWorks", testUpgradeWithUpgradePayloadInlineWithRequestWorks), ] } } diff --git a/Tests/NIOHTTP1Tests/HTTPUpgradeTests.swift b/Tests/NIOHTTP1Tests/HTTPUpgradeTests.swift index bf3ece044b5..cc2af587aa8 100644 --- a/Tests/NIOHTTP1Tests/HTTPUpgradeTests.swift +++ b/Tests/NIOHTTP1Tests/HTTPUpgradeTests.swift @@ -814,4 +814,110 @@ class HTTPUpgradeTestCase: XCTestCase { try channel.pipeline.assertDoesNotContain(handlerType: HTTPRequestDecoder.self) try channel.pipeline.assertDoesNotContain(handlerType: HTTPResponseEncoder.self) } + + func testUpgradeWithUpgradePayloadInlineWithRequestWorks() throws { + var upgradeRequest: HTTPRequestHead? = nil + var upgradeHandlerCbFired = false + var upgraderCbFired = false + + class CheckWeReadInlineAndExtraData: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias OutboundIn = Never + typealias OutboundOut = Never + + enum State { + case fresh + case added + case inlineDataRead + case extraDataRead + case closed + } + + private let allDonePromise: EventLoopPromise + private var state = State.fresh + + init(allDonePromise: EventLoopPromise) { + self.allDonePromise = allDonePromise + } + + func handlerAdded(ctx: ChannelHandlerContext) { + XCTAssertEqual(.fresh, self.state) + self.state = .added + } + + func channelRead(ctx: ChannelHandlerContext, data: NIOAny) { + let buf = self.unwrapInboundIn(data) + XCTAssertEqual(1, buf.readableBytes) + let stringRead = buf.getString(at: 0, length: buf.readableBytes) + switch self.state { + case .added: + XCTAssertEqual("A", stringRead) + self.state = .inlineDataRead + case .inlineDataRead: + XCTAssertEqual("B", stringRead) + self.state = .extraDataRead + ctx.channel.close(promise: nil) + default: + XCTFail("channel read in wrong state \(self.state)") + } + } + + func close(ctx: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise?) { + XCTAssertEqual(.extraDataRead, self.state) + self.state = .closed + ctx.close(mode: mode, promise: promise) + + allDonePromise.succeed(result: ()) + } + } + + let upgrader = SuccessfulUpgrader(forProtocol: "myproto", requiringHeaders: ["kafkaesque"]) { req in + upgradeRequest = req + XCTAssert(upgradeHandlerCbFired) + upgraderCbFired = true + } + + let allDonePromise: EventLoopPromise = EmbeddedEventLoop().newPromise() + let (group, server, client, connectedServer) = try setUpTestWithAutoremoval(upgraders: [upgrader], + extraHandlers: []) { (ctx) in + // This is called before the upgrader gets called. + XCTAssertNil(upgradeRequest) + upgradeHandlerCbFired = true + + _ = ctx.channel.pipeline.add(handler: CheckWeReadInlineAndExtraData(allDonePromise: allDonePromise)) + } + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completePromise: EventLoopPromise = group.next().newPromise() + let clientHandler = ArrayAccumulationHandler { buffers in + let resultString = buffers.map { $0.getString(at: $0.readerIndex, length: $0.readableBytes)! }.joined(separator: "") + assertResponseIs(response: resultString, + expectedResponseLine: "HTTP/1.1 101 Switching Protocols", + expectedResponseHeaders: ["X-Upgrade-Complete: true", "upgrade: myproto", "connection: upgrade"]) + completePromise.succeed(result: ()) + } + XCTAssertNoThrow(try client.pipeline.add(handler: clientHandler).wait()) + + // This request is safe to upgrade. + var request = "OPTIONS * HTTP/1.1\r\nHost: localhost\r\nUpgrade: myproto\r\nKafkaesque: yup\r\nConnection: upgrade\r\nConnection: kafkaesque\r\n\r\n" + request += "A" + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(ByteBuffer.forString(request))).wait()) + + XCTAssertNoThrow(try client.writeAndFlush(NIOAny(ByteBuffer.forString("B"))).wait()) + + // Let the machinery do its thing. + XCTAssertNoThrow(try completePromise.futureResult.wait()) + + // At this time we want to assert that everything got called. Their own callbacks assert + // that the ordering was correct. + XCTAssert(upgradeHandlerCbFired) + XCTAssert(upgraderCbFired) + + // We also want to confirm that the upgrade handler is no longer in the pipeline. + try connectedServer.pipeline.assertDoesNotContainUpgrader() + + XCTAssertNoThrow(try allDonePromise.futureResult.wait()) + } }