Skip to content
Permalink
Browse files

remove WebSocketFrameDecoder inline error handling (#885)

Motivation:

Follows on from the work done in #528 for #527: we have moved the the
default error handling out of WebSocketFrameDecoder, but had to leave
the code there for backward compatibility reasons. We can remove that
code now.

Modifications:

Removed automatic error handling code in WebSocketFrameDecoder.

Result:

- fixes #534
  • Loading branch information...
weissi committed Mar 8, 2019
1 parent c6065cc commit 36a52e1ea3ec03007d67a046856e6b938a6ef81e
@@ -17,7 +17,7 @@ import PackageDescription

var targets: [PackageDescription.Target] = [
.target(name: "_NIO1APIShims",
dependencies: ["NIO", "NIOHTTP1", "NIOTLS", "NIOFoundationCompat"]),
dependencies: ["NIO", "NIOHTTP1", "NIOTLS", "NIOFoundationCompat", "NIOWebSocket"]),
.target(name: "NIO",
dependencies: ["CNIOLinux",
"CNIODarwin",
@@ -231,12 +231,6 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder {
/// Our parser state.
private var parser = WSParser()

/// Whether we should continue to parse.
private var shouldKeepParsing = true

/// Whether this `ChannelHandler` should be performing automatic error handling.
private let automaticErrorHandling: Bool

/// Construct a new `WebSocketFrameDecoder`
///
/// - parameters:
@@ -252,68 +246,31 @@ public final class WebSocketFrameDecoder: ByteToMessageDecoder {
/// - automaticErrorHandling: Whether this `ChannelHandler` should automatically handle
/// protocol errors in frame serialization, or whether it should allow the pipeline
/// to handle them.
public init(maxFrameSize: Int = 1 << 14, automaticErrorHandling: Bool = true) {
public init(maxFrameSize: Int = 1 << 14) {
precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size")
self.maxFrameSize = maxFrameSize
self.automaticErrorHandling = automaticErrorHandling
}

public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) -> DecodingState {
public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
// Even though the calling code will loop around calling us in `decode`, we can't quite
// rely on that: sometimes we have zero-length elements to parse, and the caller doesn't
// guarantee to call us with zero-length bytes.
parseLoop: while self.shouldKeepParsing {
while true {
switch parser.parseStep(&buffer) {
case .result(let frame):
context.fireChannelRead(self.wrapInboundOut(frame))
return .continue
case .continueParsing:
do {
try self.parser.validateState(maxFrameSize: self.maxFrameSize)
} catch {
self.handleError(error, context: context)
}
try self.parser.validateState(maxFrameSize: self.maxFrameSize)
// loop again, might be 'waiting' for 0 bytes
case .insufficientData:
break parseLoop
return .needMoreData
}
}

// We parse eagerly, so once we get here we definitionally need more data.
return .needMoreData
}

public func decodeLast(context: ChannelHandlerContext, buffer: inout ByteBuffer, seenEOF: Bool) throws -> DecodingState {
// EOF is not semantic in WebSocket, so ignore this.
return .needMoreData
}



/// We hit a decoding error, we're going to tear things down now. To do this we're
/// basically going to send an error frame and then close the connection. Once we're
/// in this state we do no further parsing.
///
/// A clean websocket shutdown is not really supposed to have an immediate close,
/// but we're doing that because the remote peer has prevented us from doing
/// further frame parsing, so we can't really wait for the next frame.
private func handleError(_ error: Error, context: ChannelHandlerContext) {
guard let error = error as? NIOWebSocketError else {
fatalError("Can only handle NIOWebSocketErrors")
}
self.shouldKeepParsing = false

// If we've been asked to handle the errors here, we should.
// TODO(cory): Remove this in 2.0, in favour of `WebSocketProtocolErrorHandler`.
if self.automaticErrorHandling {
var data = context.channel.allocator.buffer(capacity: 2)
data.write(webSocketErrorCode: WebSocketErrorCode(error))
let frame = WebSocketFrame(fin: true,
opcode: .connectionClose,
data: data)
context.writeAndFlush(self.wrapInboundOut(frame)).whenComplete { (_: Result<Void, Error>) in
context.close(promise: nil)
}
}

context.fireErrorCaught(error)
}
}
@@ -170,7 +170,7 @@ public final class WebSocketUpgrader: HTTPServerProtocolUpgrader {
/// We never use the automatic error handling feature of the WebSocketFrameDecoder: we always use the separate channel
/// handler.
var upgradeFuture = context.pipeline.addHandler(WebSocketFrameEncoder()).flatMap {
context.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: self.maxFrameSize, automaticErrorHandling: false)))
context.pipeline.addHandler(ByteToMessageHandler(WebSocketFrameDecoder(maxFrameSize: self.maxFrameSize)))
}

if self.automaticErrorHandling {
@@ -18,6 +18,7 @@ import NIO
import NIOFoundationCompat
import NIOHTTP1
import NIOTLS
import NIOWebSocket

// This is NIO 2's 'NIO1 API Shims' module.
//
@@ -523,3 +524,10 @@ public typealias HTTPUpgradeErrors = HTTPServerUpgradeErrors

@available(*, deprecated, renamed: "NIOThreadPool")
public typealias BlockingIOThreadPool = NIOThreadPool

extension WebSocketFrameDecoder {
@available(*, deprecated, message: "automaticErrorHandling deprecated, use WebSocketProtocolErrorHandler instead")
public convenience init(maxFrameSize: Int = 1 << 14, automaticErrorHandling: Bool) {
self.init(maxFrameSize: maxFrameSize)
}
}
@@ -47,6 +47,7 @@ extension WebSocketFrameDecoderTest {
("testDecoderRejectsFragmentedControlFramesWithSeparateErrorHandling", testDecoderRejectsFragmentedControlFramesWithSeparateErrorHandling),
("testDecoderRejectsMultibyteControlFrameLengthsWithSeparateErrorHandling", testDecoderRejectsMultibyteControlFrameLengthsWithSeparateErrorHandling),
("testIgnoresFurtherDataAfterRejectedFrameWithSeparateErrorHandling", testIgnoresFurtherDataAfterRejectedFrameWithSeparateErrorHandling),
("testErrorHandlerDoesNotSwallowRandomErrors", testErrorHandlerDoesNotSwallowRandomErrors),
]
}
}
@@ -240,6 +240,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {

public func testDecoderRejectsOverlongFrames() throws {
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

// A fake frame header that claims that the length of the frame is 16385 bytes,
// larger than the frame max.
@@ -260,6 +261,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {

public func testDecoderRejectsFragmentedControlFrames() throws {
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

// A fake frame header that claims this is a fragmented ping frame.
self.buffer.writeBytes([0x09, 0x00])
@@ -279,6 +281,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {

public func testDecoderRejectsMultibyteControlFrameLengths() throws {
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

// A fake frame header that claims this is a ping frame with 126 bytes of data.
self.buffer.writeBytes([0x89, 0x7E, 0x00, 0x7E])
@@ -300,6 +303,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
let swallower = CloseSwallower()
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(swallower, position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

// A fake frame header that claims this is a fragmented ping frame.
self.buffer.writeBytes([0x09, 0x00])
@@ -317,10 +321,17 @@ public class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xEA], try self.decoderChannel.readAllOutboundBytes()))

// Now write another broken frame, this time an overlong frame.
// No error should occur here.
self.buffer.clear()
self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01])
XCTAssertNoThrow(try self.decoderChannel.writeInbound(self.buffer))
let wrongFrame: [UInt8] = [0x81, 0xFE, 0x40, 0x01]
self.buffer.writeBytes(wrongFrame)
XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in
if case .some(.dataReceivedInErrorState(let data)) = error as? ByteToMessageDecoderError {
// ok
XCTAssertEqual(wrongFrame, Array(data.1.readableBytesView))
} else {
XCTFail("unexpected error: \(error)")
}
}

// No extra data should have been sent.
XCTAssertNoThrow(XCTAssertNil(try self.decoderChannel.readOutbound()))
@@ -357,7 +368,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
public func testDecoderRejectsOverlongFramesWithNoAutomaticErrorHandling() {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())

// A fake frame header that claims that the length of the frame is 16385 bytes,
@@ -380,7 +391,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
public func testDecoderRejectsFragmentedControlFramesWithNoAutomaticErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())

// A fake frame header that claims this is a fragmented ping frame.
@@ -402,7 +413,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
public func testDecoderRejectsMultibyteControlFrameLengthsWithNoAutomaticErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())

// A fake frame header that claims this is a ping frame with 126 bytes of data.
@@ -424,7 +435,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
func testIgnoresFurtherDataAfterRejectedFrameWithNoAutomaticErrorHandling() {
// We need to insert a decoder that doesn't do error handling. We still insert
// an encoder because we want to fail gracefully if a frame is written.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())

// A fake frame header that claims this is a fragmented ping frame.
@@ -443,10 +454,17 @@ public class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual([], try self.decoderChannel.readAllOutboundBytes()))

// Now write another broken frame, this time an overlong frame.
// No error should occur here.
self.buffer.clear()
self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01])
XCTAssertNoThrow(try self.decoderChannel.writeInbound(self.buffer))
let wrongFrame: [UInt8] = [0x81, 0xFE, 0x40, 0x01]
self.buffer.writeBytes(wrongFrame)
XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in
if case .some(.dataReceivedInErrorState(let data)) = error as? ByteToMessageDecoderError {
// ok
XCTAssertEqual(wrongFrame, Array(data.1.readableBytesView))
} else {
XCTFail("unexpected error: \(error)")
}
}

// No extra data should have been sent.
XCTAssertNoThrow(XCTAssertNil(try self.decoderChannel.readOutbound()))
@@ -455,7 +473,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
public func testDecoderRejectsOverlongFramesWithSeparateErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

@@ -479,7 +497,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
public func testDecoderRejectsFragmentedControlFramesWithSeparateErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

@@ -502,7 +520,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
public func testDecoderRejectsMultibyteControlFrameLengthsWithSeparateErrorHandling() throws {
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

@@ -526,7 +544,7 @@ public class WebSocketFrameDecoderTest: XCTestCase {
let swallower = CloseSwallower()
// We need to insert a decoder that doesn't do error handling, and then a separate error
// handler.
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder(automaticErrorHandling: false)))
self.swapDecoder(for: ByteToMessageHandler(WebSocketFrameDecoder()))
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(swallower, position: .first).wait())
@@ -547,10 +565,17 @@ public class WebSocketFrameDecoderTest: XCTestCase {
XCTAssertNoThrow(XCTAssertEqual(try self.decoderChannel.readAllOutboundBytes(), [0x88, 0x02, 0x03, 0xEA]))

// Now write another broken frame, this time an overlong frame.
// No error should occur here.
self.buffer.clear()
self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01])
XCTAssertNoThrow(try self.decoderChannel.writeInbound(self.buffer))
let wrongFrame: [UInt8] = [0x81, 0xFE, 0x40, 0x01]
self.buffer.writeBytes(wrongFrame)
XCTAssertThrowsError(try self.decoderChannel.writeInbound(self.buffer)) { error in
if case .some(.dataReceivedInErrorState(let data)) = error as? ByteToMessageDecoderError {
// ok
XCTAssertEqual(wrongFrame, Array(data.1.readableBytesView))
} else {
XCTFail("unexpected error: \(error)")
}
}

// No extra data should have been sent.
XCTAssertNoThrow(XCTAssertNil(try self.decoderChannel.readOutbound()))
@@ -561,4 +586,32 @@ public class WebSocketFrameDecoderTest: XCTestCase {
// Take the handler out for cleanliness.
XCTAssertNoThrow(try self.decoderChannel.pipeline.removeHandler(swallower).wait())
}

func testErrorHandlerDoesNotSwallowRandomErrors() throws {
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketFrameEncoder(), position: .first).wait())
XCTAssertNoThrow(try self.decoderChannel.pipeline.addHandler(WebSocketProtocolErrorHandler()).wait())

// A fake frame header that claims that the length of the frame is 16385 bytes,
// larger than the frame max.
self.buffer.writeBytes([0x81, 0xFE, 0x40, 0x01])

struct Dummy: Error {}

self.decoderChannel.pipeline.fireErrorCaught(Dummy())
XCTAssertThrowsError(try self.decoderChannel.throwIfErrorCaught()) { error in
XCTAssertNotNil(error as? Dummy, "unexpected error: \(error)")
}

do {
try self.decoderChannel.writeInbound(self.buffer)
XCTFail("did not throw")
} catch NIOWebSocketError.invalidFrameLength {
// OK
} catch {
XCTFail("Unexpected error: \(error)")
}

// We expect that an error frame will have been written out.
XCTAssertNoThrow(XCTAssertEqual([0x88, 0x02, 0x03, 0xF1], try self.decoderChannel.readAllOutboundBytes()))
}
}
Oops, something went wrong.

0 comments on commit 36a52e1

Please sign in to comment.
You can’t perform that action at this time.