Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 78 additions & 38 deletions Sources/NIOHTTP1/HTTPDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ private struct HTTPParserState {
var currentError: HTTPParserError?
var seenEOF = false
var headerStartIndex: Int?

// Holds the data we need to forward via ctx.fireChannelRead(...) after invoking the parser.
var pendingInOut: NIOAny? = nil

enum DataAwaitingState {
case messageBegin
Expand All @@ -57,6 +60,7 @@ private struct HTTPParserState {
self.currentStatus = nil
self.slice = nil
self.headerStartIndex = nil
self.pendingInOut = nil
}

var cumulationBuffer: ByteBuffer?
Expand Down Expand Up @@ -196,7 +200,6 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {

private var parser = http_parser()
private var settings = http_parser_settings()
private var decoding: Bool = false

fileprivate var state = HTTPParserState()

Expand Down Expand Up @@ -234,6 +237,25 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
return response
}

private func bytesToForwardOnRemoval(ctx: ChannelHandlerContext) -> ByteBuffer? {
guard self.parser.upgrade == 1 && ctx.channel.isActive else {
return nil
}
// We take a slice of the cumulationBuffer so the next handler in the pipeline will just see the readable portion of the buffer.
// While this is not strictly needed it may make it easier to consume.
if let buffer = self.cumulationBuffer?.slice(), buffer.readableBytes > 0 {
return buffer
}
return nil
}

public func handlerRemoved(ctx: ChannelHandlerContext) {
if let buffer = self.bytesToForwardOnRemoval(ctx: ctx) {
ctx.fireChannelRead(NIOAny(buffer))
}
self.cumulationBuffer = nil
}

public func decoderAdded(ctx: ChannelHandlerContext) {
if HTTPMessageT.self == HTTPServerRequestPart.self {
c_nio_http_parser_init(&self.parser, HTTP_REQUEST)
Expand All @@ -258,6 +280,11 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
let ctx = evacuateChannelHandlerContext(parser)
let handler = ctx.handler as! AnyHTTPDecoder

// Ensure we pause the parser after this callback is complete so we can safely callout
// to the pipeline.
c_nio_http_parser_pause(parser, 1)
assert(handler.state.pendingInOut == nil)

handler.state.complete(state: handler.state.dataAwaitingState)
handler.state.dataAwaitingState = .body

Expand All @@ -269,7 +296,7 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
return -1
}

ctx.fireChannelRead(handler.wrapInboundOut(HTTPServerRequestPart.head(head)))
handler.state.pendingInOut = handler.wrapInboundOut(HTTPServerRequestPart.head(head))
return 0
case let handler as HTTPResponseDecoder:
let head = handler.newResponseHead(parser)
Expand All @@ -278,7 +305,7 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
return -1
}

ctx.fireChannelRead(handler.wrapInboundOut(HTTPClientResponsePart.head(head)))
handler.state.pendingInOut = handler.wrapInboundOut(HTTPClientResponsePart.head(head))

// http_parser doesn't correctly handle responses to HEAD requests. We have to do something
// annoyingly opaque here, and in those cases return 1 instead of 0. This forces http_parser
Expand Down Expand Up @@ -322,14 +349,19 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
let handler = ctx.handler as! AnyHTTPDecoder
assert(handler.state.dataAwaitingState == .body)

// Ensure we pause the parser after this callback is complete so we can safely callout
// to the pipeline.
c_nio_http_parser_pause(parser, 1)
assert(handler.state.pendingInOut == nil)

// Calculate the index of the data in the cumulationBuffer so we can slice out the ByteBuffer without doing any memory copy
let index = handler.state.calculateIndex(data: data!, length: len)
let slice = handler.state.cumulationBuffer!.getSlice(at: index, length: len)!
switch handler {
case let handler as HTTPRequestDecoder:
ctx.fireChannelRead(handler.wrapInboundOut(HTTPServerRequestPart.body(slice)))
handler.state.pendingInOut = handler.wrapInboundOut(HTTPServerRequestPart.body(slice))
case let handler as HTTPResponseDecoder:
ctx.fireChannelRead(handler.wrapInboundOut(HTTPClientResponsePart.body(slice)))
handler.state.pendingInOut = handler.wrapInboundOut(HTTPClientResponsePart.body(slice))
default:
fatalError("the impossible happened: handler neither a HTTPRequestDecoder nor a HTTPResponseDecoder which should be impossible")
}
Expand Down Expand Up @@ -387,6 +419,12 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
self.settings.on_message_complete = { parser in
let ctx = evacuateChannelHandlerContext(parser)
let handler = ctx.handler as! AnyHTTPDecoder

// Ensure we pause the parser after this callback is complete so we can safely callout
// to the pipeline.
c_nio_http_parser_pause(parser, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't we need similar stuff for on_body or so?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignore me, we've got that

assert(handler.state.pendingInOut == nil)

handler.state.complete(state: handler.state.dataAwaitingState)
handler.state.dataAwaitingState = .messageBegin

Expand All @@ -395,9 +433,9 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
handler.state.currentHeaders.removeAll(keepingCapacity: true)
switch handler {
case let handler as HTTPRequestDecoder:
ctx.fireChannelRead(handler.wrapInboundOut(HTTPServerRequestPart.end(trailers)))
handler.state.pendingInOut = handler.wrapInboundOut(HTTPServerRequestPart.end(trailers))
case let handler as HTTPResponseDecoder:
ctx.fireChannelRead(handler.wrapInboundOut(HTTPClientResponsePart.end(trailers)))
handler.state.pendingInOut = handler.wrapInboundOut(HTTPClientResponsePart.end(trailers))
default:
fatalError("the impossible happened: handler neither a HTTPRequestDecoder nor a HTTPResponseDecoder which should be impossible")
}
Expand All @@ -411,10 +449,9 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
throw error
}

let httpError = self.parser.http_errno
if httpError != 0 {
self.state.currentError = HTTPParserError.httpError(fromCHTTPParserErrno: http_errno(rawValue: httpError))!
throw self.state.currentError!
if let parserError = self.currentParserError() {
self.state.currentError = parserError
throw parserError
}
}

Expand All @@ -432,13 +469,15 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
c_nio_http_parser_execute(&self.parser, &self.settings, p.advanced(by: bufferSlice.readerIndex), bufferSlice.readableBytes)
}
}

try self.rethrowParserError()

/// http_parser_execute(...) should always consume the whole buffer.
assert(result == bufferSlice.readableBytes)
c_nio_http_parser_pause(&self.parser, 0)

// Update readerIndex of the cumulationBuffer itself as we will refetch it in the next loop run if needed.
self.cumulationBuffer?.moveReaderIndex(forwardBy: result)

self.firePendingInOut(ctx: ctx)
}

if self.state.seenEOF {
Expand All @@ -447,22 +486,30 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
}
}

private func firePendingInOut(ctx: ChannelHandlerContext) {
if let pending = self.state.pendingInOut {
self.state.pendingInOut = nil
ctx.fireChannelRead(pending)
}
}

private func discardDecodedBytes() {
guard self.cumulationBuffer != nil else {
// Guard against the case of closing the channel. In this case the cumulationBuffer will be nil.
return
}

assert(self.cumulationBuffer!.readableBytes == 0)

switch self.state.dataAwaitingState {
case .body, .messageBegin:
assert(self.state.currentNameIndex == nil)
assert(self.state.currentHeaders.isEmpty)
assert(self.state.slice == nil)

// Its safe to just drop the cumulationBuffer as we not have any extra views into it that are represented as readerIndex / length.
self.cumulationBuffer = nil

if self.cumulationBuffer!.readableBytes == 0 {
// Its safe to just drop the cumulationBuffer as we don't have any extra views into it that are represented as readerIndex / length.
self.cumulationBuffer = nil
}

case .headerField, .headerValue:
guard let headerStartIdx = self.state.headerStartIndex else {
Expand Down Expand Up @@ -519,18 +566,6 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
var buffer = self.unwrapInboundIn(data)

// Guard against re-entrance calls of channelRead(...)
guard !self.decoding else {
self.cumulationBuffer!.write(buffer: &buffer)
return
}

// Needed to guard again re-entrant calls.
self.decoding = true
defer {
self.decoding = false
}

// Either use the received buffer directly or merge it into the already existing cumulationBuffer.
if self.cumulationBuffer == nil {
self.cumulationBuffer = buffer
Expand Down Expand Up @@ -599,11 +634,6 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
// let us enter this function again.
self.state.seenEOF = true

guard !self.decoding else {
// We are currently decoding, return early as we will handle it in the decoding loop.
return
}

self.notifyParserEOF(ctx: ctx)
}

Expand All @@ -613,15 +643,25 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {

// We don't need the cumulation buffer, if we're holding it.
self.cumulationBuffer = nil


self.firePendingInOut(ctx: ctx)

// No check to state.currentError because, if we hit it before, we already threw that
// error. This never calls any of the callbacks that set that field anyway. Instead we
// just check if the errno is set and throw.
if let parserError = self.currentParserError() {
self.state.currentError = parserError
ctx.fireErrorCaught(parserError)
}
}

private func currentParserError() -> HTTPParserError? {
let httpError = self.parser.http_errno
if httpError != 0 {
self.state.currentError = HTTPParserError.httpError(fromCHTTPParserErrno: http_errno(rawValue: httpError))!
ctx.fireErrorCaught(self.state.currentError!)
// Also take into account that we may have called c_nio_http_parser_pause(...)
guard httpError != HPE_PAUSED.rawValue && httpError != 0 else {
return nil
}
return HTTPParserError.httpError(fromCHTTPParserErrno: http_errno(rawValue: httpError))!
}
}

Expand Down
Loading