Skip to content

Commit

Permalink
Refactor the codecs for extension. (#219)
Browse files Browse the repository at this point in the history
Motivation:

As part of the work in #214 we're going to need to update the
HTTP2ToHTTP1 codecs. These need to be replaced for the new channel
pipelines. The core of the logic will be identical in both cases, so
let's start by factoring that logic out into some nice standalone
objects that we can reuse.

Modifications:

- Pull out the base codecs into structures.
- Rewrite the main codecs in terms of these new structures.

Result:

Easier extension points.
  • Loading branch information
Lukasa committed Jul 29, 2020
1 parent edd373d commit 04706e7
Showing 1 changed file with 166 additions and 108 deletions.
274 changes: 166 additions & 108 deletions Sources/NIOHTTP2/HTTP2ToHTTP1Codec.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,84 @@ import NIOHTTP1
import NIOHPACK


fileprivate struct BaseClientCodec {
private let protocolString: String
private let normalizeHTTPHeaders: Bool

private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .client)

/// Initializes a `BaseClientCodec`.
///
/// - parameters:
/// - httpProtocol: The protocol (usually `"http"` or `"https"` that is used).
/// - normalizeHTTPHeaders: Whether to automatically normalize the HTTP headers to be suitable for HTTP/2.
/// The normalization will for example lower-case all heder names (as required by the
/// HTTP/2 spec) and remove headers that are unsuitable for HTTP/2 such as
/// headers related to HTTP/1's keep-alive behaviour. Unless you are sure that all your
/// headers conform to the HTTP/2 spec, you should leave this parameter set to `true`.
fileprivate init(httpProtocol: HTTP2ToHTTP1ClientCodec.HTTPProtocol, normalizeHTTPHeaders: Bool) {
self.normalizeHTTPHeaders = normalizeHTTPHeaders

switch httpProtocol {
case .http:
self.protocolString = "http"
case .https:
self.protocolString = "https"
}
}

mutating func processInboundData(_ data: HTTP2Frame.FramePayload) throws -> (first: HTTPClientResponsePart?, second: HTTPClientResponsePart?) {
switch data {
case .headers(let headerContent):
if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) {
return (first: .end(HTTPHeaders(regularHeadersFrom: headerContent.headers)), second: nil)
} else {
let respHead = try HTTPResponseHead(http2HeaderBlock: headerContent.headers)
let first = HTTPClientResponsePart.head(respHead)
var second: HTTPClientResponsePart? = nil
if headerContent.endStream {
second = .end(nil)
}
return (first: first, second: second)
}
case .data(let content):
guard case .byteBuffer(let b) = content.data else {
preconditionFailure("Received DATA frame with non-bytebuffer IOData")
}

let first = HTTPClientResponsePart.body(b)
var second: HTTPClientResponsePart? = nil
if content.endStream {
second = .end(nil)
}
return (first: first, second: second)
case .alternativeService, .rstStream, .priority, .windowUpdate, .settings, .pushPromise, .ping, .goAway, .origin:
// These don't have an HTTP/1 equivalent, so let's drop them.
return (first: nil, second: nil)
}
}

mutating func processOutboundData(_ data: HTTPClientRequestPart, allocator: ByteBufferAllocator) throws -> HTTP2Frame.FramePayload {
switch data {
case .head(let head):
let h1Headers = try HTTPHeaders(requestHead: head, protocolString: self.protocolString)
let headerContent = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1Headers,
normalizeHTTPHeaders: self.normalizeHTTPHeaders))
return .headers(headerContent)
case .body(let body):
return .data(HTTP2Frame.FramePayload.Data(data: body))
case .end(let trailers):
if let trailers = trailers {
return .headers(.init(headers: HPACKHeaders(httpHeaders: trailers,
normalizeHTTPHeaders: self.normalizeHTTPHeaders),
endStream: true))
} else {
return .data(.init(data: .byteBuffer(allocator.buffer(capacity: 0)), endStream: true))
}
}
}
}

/// A simple channel handler that translates HTTP/2 concepts into HTTP/1 data types,
/// and vice versa, for use on the client side.
///
Expand All @@ -37,10 +115,7 @@ public final class HTTP2ToHTTP1ClientCodec: ChannelInboundHandler, ChannelOutbou
}

private let streamID: HTTP2StreamID
private let protocolString: String
private let normalizeHTTPHeaders: Bool

private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .client)
private var baseCodec: BaseClientCodec

/// Initializes a `HTTP2ToHTTP1ClientCodec` for the given `HTTP2StreamID`.
///
Expand All @@ -54,14 +129,7 @@ public final class HTTP2ToHTTP1ClientCodec: ChannelInboundHandler, ChannelOutbou
/// headers conform to the HTTP/2 spec, you should leave this parameter set to `true`.
public init(streamID: HTTP2StreamID, httpProtocol: HTTPProtocol, normalizeHTTPHeaders: Bool) {
self.streamID = streamID
self.normalizeHTTPHeaders = normalizeHTTPHeaders

switch httpProtocol {
case .http:
self.protocolString = "http"
case .https:
self.protocolString = "https"
}
self.baseCodec = BaseClientCodec(httpProtocol: httpProtocol, normalizeHTTPHeaders: normalizeHTTPHeaders)
}

/// Initializes a `HTTP2ToHTTP1ClientCodec` for the given `HTTP2StreamID`.
Expand All @@ -75,67 +143,91 @@ public final class HTTP2ToHTTP1ClientCodec: ChannelInboundHandler, ChannelOutbou

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let frame = self.unwrapInboundIn(data)
do {
let (first, second) = try self.baseCodec.processInboundData(frame.payload)
if let first = first {
context.fireChannelRead(self.wrapInboundOut(first))
}
if let second = second {
context.fireChannelRead(self.wrapInboundOut(second))
}
} catch {
context.fireErrorCaught(error)
}
}

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let responsePart = self.unwrapOutboundIn(data)

do {
let transformedPayload = try self.baseCodec.processOutboundData(responsePart, allocator: context.channel.allocator)
let part = HTTP2Frame(streamID: self.streamID, payload: transformedPayload)
context.write(self.wrapOutboundOut(part), promise: promise)
} catch {
promise?.fail(error)
context.fireErrorCaught(error)
}
}
}


fileprivate struct BaseServerCodec {
private let normalizeHTTPHeaders: Bool
private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .server)

init(normalizeHTTPHeaders: Bool) {
self.normalizeHTTPHeaders = normalizeHTTPHeaders
}

switch frame.payload {
mutating func processInboundData(_ data: HTTP2Frame.FramePayload) throws -> (first: HTTPServerRequestPart?, second: HTTPServerRequestPart?) {
switch data {
case .headers(let headerContent):
do {
if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) {
context.fireChannelRead(self.wrapInboundOut(.end(HTTPHeaders(regularHeadersFrom: headerContent.headers))))
} else {
let respHead = try HTTPResponseHead(http2HeaderBlock: headerContent.headers)
context.fireChannelRead(self.wrapInboundOut(.head(respHead)))
if headerContent.endStream {
context.fireChannelRead(self.wrapInboundOut(.end(nil)))
}
if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) {
return (first: .end(HTTPHeaders(regularHeadersFrom: headerContent.headers)), second: nil)
} else {
let reqHead = try HTTPRequestHead(http2HeaderBlock: headerContent.headers)

let first = HTTPServerRequestPart.head(reqHead)
var second: HTTPServerRequestPart? = nil
if headerContent.endStream {
second = .end(nil)
}
} catch {
context.fireErrorCaught(error)
return (first: first, second: second)
}
case .data(let content):
guard case .byteBuffer(let b) = content.data else {
preconditionFailure("Received DATA frame with non-bytebuffer IOData")
case .data(let dataContent):
guard case .byteBuffer(let b) = dataContent.data else {
preconditionFailure("Received non-byteBuffer IOData from network")
}

context.fireChannelRead(self.wrapInboundOut(.body(b)))
if content.endStream {
context.fireChannelRead(self.wrapInboundOut(.end(nil)))
let first = HTTPServerRequestPart.body(b)
var second: HTTPServerRequestPart? = nil
if dataContent.endStream {
second = .end(nil)
}
case .alternativeService, .rstStream, .priority, .windowUpdate, .settings, .pushPromise, .ping, .goAway, .origin:
// These don't have an HTTP/1 equivalent, so let's drop them.
return
return (first: first, second: second)
default:
// Any other frame type is ignored.
return (first: nil, second: nil)
}
}

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let responsePart = self.unwrapOutboundIn(data)
switch responsePart {
mutating func processOutboundData(_ data: HTTPServerResponsePart, allocator: ByteBufferAllocator) throws -> HTTP2Frame.FramePayload {
switch data {
case .head(let head):
do {
let h1Headers = try HTTPHeaders(requestHead: head, protocolString: self.protocolString)
let headerContent = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1Headers,
normalizeHTTPHeaders: self.normalizeHTTPHeaders))
let frame = HTTP2Frame(streamID: self.streamID, payload: .headers(headerContent))
context.write(self.wrapOutboundOut(frame), promise: promise)
} catch {
promise?.fail(error)
context.fireErrorCaught(error)
}
let h1 = HTTPHeaders(responseHead: head)
let payload = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1,
normalizeHTTPHeaders: self.normalizeHTTPHeaders))
return .headers(payload)
case .body(let body):
let payload = HTTP2Frame.FramePayload.Data(data: body)
let frame = HTTP2Frame(streamID: self.streamID, payload: .data(payload))
context.write(self.wrapOutboundOut(frame), promise: promise)
return .data(payload)
case .end(let trailers):
let payload: HTTP2Frame.FramePayload
if let trailers = trailers {
payload = .headers(.init(headers: HPACKHeaders(httpHeaders: trailers,
normalizeHTTPHeaders: self.normalizeHTTPHeaders),
endStream: true))
return .headers(.init(headers: HPACKHeaders(httpHeaders: trailers,
normalizeHTTPHeaders: self.normalizeHTTPHeaders),
endStream: true))
} else {
payload = .data(.init(data: .byteBuffer(context.channel.allocator.buffer(capacity: 0)), endStream: true))
return .data(.init(data: .byteBuffer(allocator.buffer(capacity: 0)), endStream: true))
}

let frame = HTTP2Frame(streamID: self.streamID, payload: payload)
context.write(self.wrapOutboundOut(frame), promise: promise)
}
}
}
Expand All @@ -155,9 +247,7 @@ public final class HTTP2ToHTTP1ServerCodec: ChannelInboundHandler, ChannelOutbou
public typealias OutboundOut = HTTP2Frame

private let streamID: HTTP2StreamID
private let normalizeHTTPHeaders: Bool

private var headerStateMachine: HTTP2HeadersStateMachine = HTTP2HeadersStateMachine(mode: .server)
private var baseCodec: BaseServerCodec

/// Initializes a `HTTP2ToHTTP1ServerCodec` for the given `HTTP2StreamID`.
///
Expand All @@ -170,7 +260,7 @@ public final class HTTP2ToHTTP1ServerCodec: ChannelInboundHandler, ChannelOutbou
/// headers conform to the HTTP/2 spec, you should leave this parameter set to `true`.
public init(streamID: HTTP2StreamID, normalizeHTTPHeaders: Bool) {
self.streamID = streamID
self.normalizeHTTPHeaders = normalizeHTTPHeaders
self.baseCodec = BaseServerCodec(normalizeHTTPHeaders: normalizeHTTPHeaders)
}

public convenience init(streamID: HTTP2StreamID) {
Expand All @@ -180,61 +270,29 @@ public final class HTTP2ToHTTP1ServerCodec: ChannelInboundHandler, ChannelOutbou
public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let frame = self.unwrapInboundIn(data)

switch frame.payload {
case .headers(let headerContent):
do {
if case .trailer = try self.headerStateMachine.newHeaders(block: headerContent.headers) {
context.fireChannelRead(self.wrapInboundOut(.end(HTTPHeaders(regularHeadersFrom: headerContent.headers))))
} else {
let reqHead = try HTTPRequestHead(http2HeaderBlock: headerContent.headers)
context.fireChannelRead(self.wrapInboundOut(.head(reqHead)))
if headerContent.endStream {
context.fireChannelRead(self.wrapInboundOut(.end(nil)))
}
}
} catch {
context.fireErrorCaught(error)
do {
let (first, second) = try self.baseCodec.processInboundData(frame.payload)
if let first = first {
context.fireChannelRead(self.wrapInboundOut(first))
}
case .data(let dataContent):
guard case .byteBuffer(let b) = dataContent.data else {
preconditionFailure("Received non-byteBuffer IOData from network")
if let second = second {
context.fireChannelRead(self.wrapInboundOut(second))
}
context.fireChannelRead(self.wrapInboundOut(.body(b)))
if dataContent.endStream {
context.fireChannelRead(self.wrapInboundOut(.end(nil)))
}
default:
// Any other frame type is ignored.
break
} catch {
context.fireErrorCaught(error)
}
}

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let responsePart = self.unwrapOutboundIn(data)
switch responsePart {
case .head(let head):
let h1 = HTTPHeaders(responseHead: head)
let payload = HTTP2Frame.FramePayload.Headers(headers: HPACKHeaders(httpHeaders: h1,
normalizeHTTPHeaders: self.normalizeHTTPHeaders))
let frame = HTTP2Frame(streamID: self.streamID, payload: .headers(payload))
context.write(self.wrapOutboundOut(frame), promise: promise)
case .body(let body):
let payload = HTTP2Frame.FramePayload.Data(data: body)
let frame = HTTP2Frame(streamID: self.streamID, payload: .data(payload))
context.write(self.wrapOutboundOut(frame), promise: promise)
case .end(let trailers):
let payload: HTTP2Frame.FramePayload

if let trailers = trailers {
payload = .headers(.init(headers: HPACKHeaders(httpHeaders: trailers,
normalizeHTTPHeaders: self.normalizeHTTPHeaders),
endStream: true))
} else {
payload = .data(.init(data: .byteBuffer(context.channel.allocator.buffer(capacity: 0)), endStream: true))
}

let frame = HTTP2Frame(streamID: self.streamID, payload: payload)
context.write(self.wrapOutboundOut(frame), promise: promise)
do {
let transformedPayload = try self.baseCodec.processOutboundData(responsePart, allocator: context.channel.allocator)
let part = HTTP2Frame(streamID: self.streamID, payload: transformedPayload)
context.write(self.wrapOutboundOut(part), promise: promise)
} catch {
promise?.fail(error)
context.fireErrorCaught(error)
}
}
}
Expand Down

0 comments on commit 04706e7

Please sign in to comment.