Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions Sources/NIOHTTP1/HTTPDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
import NIO
import CNIOHTTPParser

private extension UnsafeMutablePointer where Pointee == http_parser {
/// Returns the `KeepAliveState` for the current message that is parsed.
var keepAliveState: KeepAliveState {
return c_nio_http_should_keep_alive(self) == 0 ? .close : .keepAlive
}
}

private struct HTTPParserState {
var dataAwaitingState: DataAwaitingState = .messageBegin
var currentNameIndex: HTTPHeaderIndex?
Expand Down Expand Up @@ -214,15 +221,15 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
private func newRequestHead(_ parser: UnsafeMutablePointer<http_parser>!) -> HTTPRequestHead {
let method = HTTPMethod.from(httpParserMethod: http_method(rawValue: parser.pointee.method))
let version = HTTPVersion(major: parser.pointee.http_major, minor: parser.pointee.http_minor)
let request = HTTPRequestHead(version: version, method: method, rawURI: state.currentURI!, headers: HTTPHeaders(buffer: cumulationBuffer!, headers: state.currentHeaders))
let request = HTTPRequestHead(version: version, method: method, rawURI: state.currentURI!, headers: HTTPHeaders(buffer: cumulationBuffer!, headers: state.currentHeaders, keepAliveState: parser.keepAliveState))
self.state.currentHeaders.removeAll(keepingCapacity: true)
return request
}

private func newResponseHead(_ parser: UnsafeMutablePointer<http_parser>!) -> HTTPResponseHead {
let status = HTTPResponseStatus(statusCode: Int(parser.pointee.status_code), reasonPhrase: state.currentStatus!)
let version = HTTPVersion(major: parser.pointee.http_major, minor: parser.pointee.http_minor)
let response = HTTPResponseHead(version: version, status: status, headers: HTTPHeaders(buffer: cumulationBuffer!, headers: state.currentHeaders))
let response = HTTPResponseHead(version: version, status: status, headers: HTTPHeaders(buffer: cumulationBuffer!, headers: state.currentHeaders, keepAliveState: parser.keepAliveState))
self.state.currentHeaders.removeAll(keepingCapacity: true)
return response
}
Expand Down Expand Up @@ -383,7 +390,8 @@ public class HTTPDecoder<HTTPMessageT>: ByteToMessageDecoder, AnyHTTPDecoder {
handler.state.complete(state: handler.state.dataAwaitingState)
handler.state.dataAwaitingState = .messageBegin

let trailers = handler.state.currentHeaders.isEmpty ? nil : HTTPHeaders(buffer: handler.state.cumulationBuffer!, headers: handler.state.currentHeaders)
// Just use unknown for trailers as there is no point for anything else.
let trailers = handler.state.currentHeaders.isEmpty ? nil : HTTPHeaders(buffer: handler.state.cumulationBuffer!, headers: handler.state.currentHeaders, keepAliveState: .unknown)
handler.state.currentHeaders.removeAll(keepingCapacity: true)
switch handler {
case let handler as HTTPRequestDecoder:
Expand Down
74 changes: 57 additions & 17 deletions Sources/NIOHTTP1/HTTPTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ import NIO
let crlf: StaticString = "\r\n"
let headerSeparator: StaticString = ": "

private let connectionUtf8 = "connection".utf8

// Keep track of keep alive state.
internal enum KeepAliveState {
// We know keep alive should be used.
case keepAlive
// We know we should close the connection.
case close
// We need to scan the headers to find out if keep alive is used or not
case unknown
}

/// A representation of the request line and header fields of a HTTP request.
public struct HTTPRequestHead: Equatable {
private final class _Storage {
Expand Down Expand Up @@ -175,15 +187,7 @@ extension HTTPRequestHead {
/// Whether this HTTP request is a keep-alive request: that is, whether the
/// connection should remain open after the request is complete.
public var isKeepAlive: Bool {
guard let connection = headers["connection"].first?.lowercased() else {
// HTTP 1.1 use keep-alive by default if not otherwise told.
return version.major == 1 && version.minor == 1
}

if connection == "close" {
return false
}
return connection == "keep-alive"
return headers.isKeepAlive(version: version)
}
}

Expand Down Expand Up @@ -286,6 +290,26 @@ private extension UInt8 {
}
}

/* private but tests */ internal extension HTTPHeaders {
func isKeepAlive(version: HTTPVersion) -> Bool {
switch self._storage.keepAliveState {
case .close:
return false
case .keepAlive:
return true
case .unknown:
guard let connection = self["connection"].first?.lowercased() else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whilst you're working on this: worth fixing this implementation of follow-up?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me do a follow up

// HTTP 1.1 use keep-alive by default if not otherwise told.
return version.major == 1 && version.minor == 1
}

if connection == "close" {
return false
}
return connection == "keep-alive"
}
}
}

/// A representation of a block of HTTP header fields.
///
Expand All @@ -304,16 +328,18 @@ public struct HTTPHeaders: CustomStringConvertible {
private final class _Storage {
var buffer: ByteBuffer
var headers: [HTTPHeader]
var continuous: Bool = true
var continuous: Bool
var keepAliveState: KeepAliveState

init(buffer: ByteBuffer, headers: [HTTPHeader], continuous: Bool) {
init(buffer: ByteBuffer, headers: [HTTPHeader], continuous: Bool, keepAliveState: KeepAliveState) {
self.buffer = buffer
self.headers = headers
self.continuous = continuous
self.keepAliveState = keepAliveState
}

func copy() -> _Storage {
return .init(buffer: self.buffer, headers: self.headers, continuous: self.continuous)
return .init(buffer: self.buffer, headers: self.headers, continuous: self.continuous, keepAliveState: self.keepAliveState)
}
}
private var _storage: _Storage
Expand Down Expand Up @@ -356,8 +382,8 @@ public struct HTTPHeaders: CustomStringConvertible {
}

/// Constructor used by our decoder to construct headers without the need of converting bytes to string.
init(buffer: ByteBuffer, headers: [HTTPHeader]) {
self._storage = _Storage(buffer: buffer, headers: headers, continuous: true)
init(buffer: ByteBuffer, headers: [HTTPHeader], keepAliveState: KeepAliveState) {
self._storage = _Storage(buffer: buffer, headers: headers, continuous: true, keepAliveState: keepAliveState)
}

/// Construct a `HTTPHeaders` structure.
Expand All @@ -381,13 +407,17 @@ public struct HTTPHeaders: CustomStringConvertible {
var array: [HTTPHeader] = []
array.reserveCapacity(headers.count)

self.init(buffer: allocator.buffer(capacity: 256), headers: array)
self.init(buffer: allocator.buffer(capacity: 256), headers: array, keepAliveState: .unknown)

for (key, value) in headers {
self.add(name: key, value: value)
}
}


private func isConnectionHeader(_ header: HTTPHeaderIndex) -> Bool {
return self.buffer.equalCaseInsensitiveASCII(view: connectionUtf8, at: header)
}

/// Add a header name/value pair to the block.
///
/// This method is strictly additive: if there are other values for the given header name
Expand All @@ -408,8 +438,14 @@ public struct HTTPHeaders: CustomStringConvertible {
self._storage.buffer.write(staticString: headerSeparator)
let valueStart = self.buffer.writerIndex
let valueLength = self._storage.buffer.write(string: value)!
self._storage.headers.append(HTTPHeader(name: HTTPHeaderIndex(start: nameStart, length: nameLength), value: HTTPHeaderIndex(start: valueStart, length: valueLength)))

let nameIdx = HTTPHeaderIndex(start: nameStart, length: nameLength)
self._storage.headers.append(HTTPHeader(name: nameIdx, value: HTTPHeaderIndex(start: valueStart, length: valueLength)))
self._storage.buffer.write(staticString: crlf)

if self.isConnectionHeader(nameIdx) {
self._storage.keepAliveState = .unknown
}
}

/// Add a header name/value pair to the block, replacing any previous values for the
Expand Down Expand Up @@ -447,6 +483,10 @@ public struct HTTPHeaders: CustomStringConvertible {
let header = self.headers[idx]
if self.buffer.equalCaseInsensitiveASCII(view: utf8, at: header.name) {
array.append(idx)

if self.isConnectionHeader(header.name) {
self._storage.keepAliveState = .unknown
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPHeadersTest+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ extension HTTPHeadersTest {
("testTrimWhitespaceWorksOnOnlyWhitespace", testTrimWhitespaceWorksOnOnlyWhitespace),
("testTrimWorksWithCharactersInTheMiddleAndWhitespaceAround", testTrimWorksWithCharactersInTheMiddleAndWhitespaceAround),
("testContains", testContains),
("testKeepAliveStateStartsWithClose", testKeepAliveStateStartsWithClose),
("testKeepAliveStateStartsWithKeepAlive", testKeepAliveStateStartsWithKeepAlive),
]
}
}
Expand Down
36 changes: 36 additions & 0 deletions Tests/NIOHTTP1Tests/HTTPHeadersTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,40 @@ class HTTPHeadersTest : XCTestCase {
XCTAssertTrue(headers.contains(name: "X-Header"))
XCTAssertFalse(headers.contains(name: "X-NonExistingHeader"))
}

func testKeepAliveStateStartsWithClose() {
var buffer = ByteBufferAllocator().buffer(capacity: 32)
buffer.write(string: "Connection: close\r\n")
var headers = HTTPHeaders(buffer: buffer, headers: [HTTPHeader(name: HTTPHeaderIndex(start: 0, length: 10), value: HTTPHeaderIndex(start: 12, length: 5))], keepAliveState: .close)

XCTAssertEqual("close", headers["connection"].first)
XCTAssertFalse(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 1)))

headers.replaceOrAdd(name: "connection", value: "keep-alive")

XCTAssertEqual("keep-alive", headers["connection"].first)
XCTAssertTrue(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 1)))

headers.remove(name: "connection")
XCTAssertTrue(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 1)))
XCTAssertFalse(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 0)))
}

func testKeepAliveStateStartsWithKeepAlive() {
var buffer = ByteBufferAllocator().buffer(capacity: 32)
buffer.write(string: "Connection: keep-alive\r\n")
var headers = HTTPHeaders(buffer: buffer, headers: [HTTPHeader(name: HTTPHeaderIndex(start: 0, length: 10), value: HTTPHeaderIndex(start: 12, length: 10))], keepAliveState: .keepAlive)

XCTAssertEqual("keep-alive", headers["connection"].first)
XCTAssertTrue(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 1)))

headers.replaceOrAdd(name: "connection", value: "close")

XCTAssertEqual("close", headers["connection"].first)
XCTAssertFalse(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 1)))

headers.remove(name: "connection")
XCTAssertTrue(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 1)))
XCTAssertFalse(headers.isKeepAlive(version: HTTPVersion(major: 1, minor: 0)))
}
}