Skip to content

Commit

Permalink
Properly handle autoread (apple#35)
Browse files Browse the repository at this point in the history
Motivation:

NIO's backpressure management systems rely on catching calls to read()
in the ChannelPipeline. This only works if your Channel implementation
actually makes those calls to read()! Sadly, SSHChildChannel did not:
we instead just recursed directly into the Channel's read0().

While I'm here I also noticed that we produced a lot of
channelReadComplete messages in the pipeline unnecessarily. We did this
in the service of having a read "fast-path" where new frames would be
automatically delivered into the ChannelPipeline without being buffered
until channelReadComplete. This fast-path is probably inadvisable: I
don't think it provided a meaningful performance benefit anyway,
especially as the removal of that fast-path allowed us to suppress
unnecessary calls to channelReadComplete, and therefore unnecessary
calls to read() as well.

Modifications:

- Forced autoRead through the pipeline.
- Removed read() fast-path.
- Suppressed channelReadComplete if no frames were delivered.
- Updated tests.

Result:

Clearer I/O path, better auto-read handling.
  • Loading branch information
Lukasa authored and artemredkin committed Aug 19, 2020
1 parent 57e4712 commit 7b57cae
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
38 changes: 16 additions & 22 deletions Sources/NIOSSH/Child Channels/SSHChildChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,7 @@ extension SSHChildChannel: Channel, ChannelCore {
self.changeWritability(to: false)
}
self.unbufferOutboundEvents()
if self.autoRead {
self.read0()
}
self.tryToAutoRead()
self.deliverPendingWrites()
self.writePendingToMultiplexer()
if let promise = self.userActivatePromise {
Expand Down Expand Up @@ -600,22 +598,30 @@ extension SSHChildChannel: Channel, ChannelCore {
return
}

// If there are no pending reads, do nothing.
guard self.pendingReads.count > 0 else {
return
}

// Ok, we're satisfying a read here.
self.unsatisfiedRead = false
self.deliverPendingReads()

// If auto-read is turned on, recurse into read0.
// This cannot recurse indefinitely unless frames are being delivered
// by the read stacks, which is generally fairly unlikely to continue unbounded.
if self.autoRead {
self.read0()
}
self.tryToAutoRead()
}

private func changeWritability(to newWritability: Bool) {
self._isWritable.store(newWritability)
self.pipeline.fireChannelWritabilityChanged()
}

private func tryToAutoRead() {
if self.autoRead {
// If auto-read is turned on, recurse into channelPipeline.read().
// This cannot recurse indefinitely unless frames are being delivered
// by the read stacks, which is generally fairly unlikely to continue unbounded.
self.pipeline.read()
}
}
}

// MARK: - Functions used to manage pending reads and writes.
Expand Down Expand Up @@ -761,10 +767,6 @@ extension SSHChildChannel {
if self.allowRemoteHalfClosure {
// Hey, remote half-closure is allowed! That's handy! We queue this with the reads to avoid it being re-ordered.
self.pendingReads.append(.eof)

if self.unsatisfiedRead {
self.tryToRead()
}
} else {
// We don't support remote half-closure. That puts us in a bit of a bind. We have to promote this up to full-closure.
// We need to send a channel close, so let's just do that: the outbound state machine will make this a full-closure.
Expand Down Expand Up @@ -803,10 +805,6 @@ extension SSHChildChannel {
// State machine is happy. Handle the flow control.
try self.windowManager.bufferFlowControlledBytes(message.data.readableBytes)
self.pendingReads.append(.data(.init(message)))

if self.unsatisfiedRead {
self.tryToRead()
}
}

private func handleInboundChannelExtendedData(_ message: SSHMessage.ChannelExtendedDataMessage) throws {
Expand All @@ -815,10 +813,6 @@ extension SSHChildChannel {
// State machine is happy. Handle the flow control.
try self.windowManager.bufferFlowControlledBytes(message.data.readableBytes)
self.pendingReads.append(.data(.init(message)))

if self.unsatisfiedRead {
self.tryToRead()
}
}

private func handleInboundChannelRequest(_ message: SSHMessage.ChannelRequestMessage) throws {
Expand Down
44 changes: 44 additions & 0 deletions Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,18 @@ final class ErrorLoggingHandler: ChannelInboundHandler {
}
}

final class ReadCountingHandler: ChannelOutboundHandler {
typealias OutboundIn = Any
typealias OutboundOut = Any

var readCount = 0

func read(context: ChannelHandlerContext) {
self.readCount += 1
context.read()
}
}

final class ReadRecordingHandler: ChannelInboundHandler {
typealias InboundIn = SSHChannelData
typealias InboundOut = SSHChannelData
Expand Down Expand Up @@ -651,6 +663,7 @@ final class ChildChannelMultiplexerTests: XCTestCase {
// Delivering two new messages causes one read.
for _ in 0 ..< 2 {
XCTAssertNoThrow(try harness.multiplexer.receiveMessage(self.data(peerChannelID: channelID!, data: buffer)))
harness.multiplexer.parentChannelReadComplete()
}
XCTAssertEqual(readRecorder.reads, Array(repeating: .init(type: .channel, data: .byteBuffer(buffer)), count: 6))
XCTAssertEqual(harness.flushedMessages.count, 1)
Expand Down Expand Up @@ -1508,4 +1521,35 @@ final class ChildChannelMultiplexerTests: XCTestCase {

XCTAssertEqual(initializedChannels, channelTypes)
}

func testAutoReadOnChildChannel() throws {
let readCounter = ReadCountingHandler()

let harness = self.harness { channel, _ in
channel.pipeline.addHandler(readCounter)
}
defer {
harness.finish()
}

XCTAssertNoThrow(try harness.multiplexer.receiveMessage(self.openRequest(channelID: 1)))
let channelID = self.assertChannelOpenConfirmation(harness.flushedMessages.first, recipientChannel: 1)
XCTAssertEqual(readCounter.readCount, 1)

// Now we're going to deliver some data. These should not propagate into the channel until channelReadComplete.
var buffer = ByteBufferAllocator().buffer(capacity: 1024)
buffer.writeString("hello, world!")

for _ in 0 ..< 5 {
XCTAssertNoThrow(try harness.multiplexer.receiveMessage(self.data(peerChannelID: channelID!, data: buffer)))
}
XCTAssertEqual(readCounter.readCount, 1)

harness.multiplexer.parentChannelReadComplete()
XCTAssertEqual(readCounter.readCount, 2)

// If no reads were delivered, further channel read completes do not trigger read() calls.
harness.multiplexer.parentChannelReadComplete()
XCTAssertEqual(readCounter.readCount, 2)
}
}

0 comments on commit 7b57cae

Please sign in to comment.