diff --git a/Sources/NIOHTTP1/HTTPDecoder.swift b/Sources/NIOHTTP1/HTTPDecoder.swift index a0d2447e60..2e3a61e7b3 100644 --- a/Sources/NIOHTTP1/HTTPDecoder.swift +++ b/Sources/NIOHTTP1/HTTPDecoder.swift @@ -465,8 +465,7 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega /// /// - parameters: /// - leftOverBytesStrategy: The strategy to use when removing the decoder from the pipeline and an upgrade was, - /// detected. Note that this does not affect what happens on EOF (in which case an - /// `ByteToMessageDecoderError.leftoverDataWhenDone` error is fired.) + /// detected. Note that this does not affect what happens on EOF. public init(leftOverBytesStrategy: RemoveAfterUpgradeStrategy = .dropBytes) { self.headers.reserveCapacity(16) if In.self == HTTPServerRequestPart.self { @@ -620,18 +619,16 @@ public final class HTTPDecoder: ByteToMessageDecoder, HTTPDecoderDelega try self.feedEOF(context: context) } } - if buffer.readableBytes > 0 { - if seenEOF { + if buffer.readableBytes > 0 && !seenEOF { + // We only do this if we haven't seen EOF because the left-overs strategy must only be invoked when we're + // sure that this is the completion of an upgrade. + switch self.leftOverBytesStrategy { + case .dropBytes: + () + case .fireError: context.fireErrorCaught(ByteToMessageDecoderError.leftoverDataWhenDone(buffer)) - } else { - switch self.leftOverBytesStrategy { - case .dropBytes: - () - case .fireError: - context.fireErrorCaught(ByteToMessageDecoderError.leftoverDataWhenDone(buffer)) - case .forwardBytes: - context.fireChannelRead(NIOAny(buffer)) - } + case .forwardBytes: + context.fireChannelRead(NIOAny(buffer)) } } return .needMoreData diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift index dbd25c0a71..f026e4258c 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderTest+XCTest.swift @@ -46,7 +46,7 @@ extension HTTPDecoderTest { ("testDoesNotDeliverLeftoversUnnecessarily", testDoesNotDeliverLeftoversUnnecessarily), ("testHTTPResponseWithoutHeaders", testHTTPResponseWithoutHeaders), ("testBasicVerifications", testBasicVerifications), - ("testErrorFiredOnEOFForLeftOversInAllLeftOversModes", testErrorFiredOnEOFForLeftOversInAllLeftOversModes), + ("testNothingHappensOnEOFForLeftOversInAllLeftOversModes", testNothingHappensOnEOFForLeftOversInAllLeftOversModes), ("testBytesCanBeForwardedWhenHandlerRemoved", testBytesCanBeForwardedWhenHandlerRemoved), ("testBytesCanBeFiredAsErrorWhenHandlerRemoved", testBytesCanBeFiredAsErrorWhenHandlerRemoved), ("testBytesCanBeDroppedWhenHandlerRemoved", testBytesCanBeDroppedWhenHandlerRemoved), diff --git a/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift b/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift index 7fb4b7a90c..4b7b6bb7b9 100644 --- a/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift +++ b/Tests/NIOHTTP1Tests/HTTPDecoderTest.swift @@ -594,24 +594,14 @@ class HTTPDecoderTest: XCTestCase { decoderFactory: { HTTPRequestDecoder() })) } - func testErrorFiredOnEOFForLeftOversInAllLeftOversModes() throws { + func testNothingHappensOnEOFForLeftOversInAllLeftOversModes() throws { class Receiver: ChannelInboundHandler { typealias InboundIn = HTTPServerRequestPart - private let errorReceivedPromise: EventLoopPromise private var numberOfErrors = 0 - init(errorReceivedPromise: EventLoopPromise) { - self.errorReceivedPromise = errorReceivedPromise - } - func errorCaught(context: ChannelHandlerContext, error: Error) { - self.numberOfErrors += 1 - if self.numberOfErrors == 1, let error = error as? ByteToMessageDecoderError { - self.errorReceivedPromise.succeed(error) - } else { - XCTFail("illegal: number of errors: \(self.numberOfErrors), error: \(error)") - } + XCTFail("unexpected error: \(error)") } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -629,24 +619,14 @@ class HTTPDecoderTest: XCTestCase { for leftOverBytesStrategy in [RemoveAfterUpgradeStrategy.dropBytes, .fireError, .forwardBytes] { let channel = EmbeddedChannel() - let errorReceivedPromise: EventLoopPromise = channel.eventLoop.makePromise() var buffer = channel.allocator.buffer(capacity: 64) buffer.writeStaticString("OPTIONS * HTTP/1.1\r\nHost: L\r\nUpgrade: P\r\nConnection: upgrade\r\n\r\nXXXX") let decoder = HTTPRequestDecoder(leftOverBytesStrategy: leftOverBytesStrategy) XCTAssertNoThrow(try channel.pipeline.addHandler(ByteToMessageHandler(decoder)).wait()) - XCTAssertNoThrow(try channel.pipeline.addHandler(Receiver(errorReceivedPromise: errorReceivedPromise)).wait()) + XCTAssertNoThrow(try channel.pipeline.addHandler(Receiver()).wait()) XCTAssertNoThrow(try channel.writeInbound(buffer)) XCTAssertNoThrow(XCTAssert(try channel.finish().isClean)) - - switch Result(catching: { try errorReceivedPromise.futureResult.wait() }) { - case .success(ByteToMessageDecoderError.leftoverDataWhenDone(let buffer)): - XCTAssertEqual("XXXX", String(decoding: buffer.readableBytesView, as: Unicode.UTF8.self)) - case .failure(let error): - XCTFail("unexpected error: \(error)") - case .success(let error): - XCTFail("unexpected error: \(error)") - } } }