Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions Sources/NIOSSL/NIOSSLHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,25 +98,35 @@ public class NIOSSLHandler : ChannelInboundHandler, ChannelOutboundHandler, Remo
// keeping track of the state we're in properly before we do anything else.
let oldState = state
state = .closed
let channelError: NIOSSLError

switch oldState {
case .closed, .idle:
// Nothing to do, but discard any buffered writes we still have.
discardBufferedWrites(reason: ChannelError.ioOnClosedChannel)
// Return early
context.fireChannelInactive()
return
case .handshaking:
// In this case the channel is going through the doHandshake steps and
// a channelInactive is fired taking down the connection.
// This case propogates a .handshakeFailed instead of an .uncleanShutdown.
channelError = NIOSSLError.handshakeFailed(.sslError(BoringSSLError.buildErrorStack()))
default:
// This is a ragged EOF: we weren't sent a CLOSE_NOTIFY. We want to send a user
// event to notify about this before we propagate channelInactive. We also want to fail all
// these writes.
let shutdownPromise = self.shutdownPromise
self.shutdownPromise = nil
let closePromise = self.closePromise
self.closePromise = nil

shutdownPromise?.fail(NIOSSLError.uncleanShutdown)
closePromise?.fail(NIOSSLError.uncleanShutdown)
context.fireErrorCaught(NIOSSLError.uncleanShutdown)
discardBufferedWrites(reason: NIOSSLError.uncleanShutdown)
channelError = NIOSSLError.uncleanShutdown
}
let shutdownPromise = self.shutdownPromise
self.shutdownPromise = nil
let closePromise = self.closePromise
self.closePromise = nil

shutdownPromise?.fail(channelError)
closePromise?.fail(channelError)
context.fireErrorCaught(channelError)
discardBufferedWrites(reason: channelError)

context.fireChannelInactive()
}
Expand Down
1 change: 1 addition & 0 deletions Tests/NIOSSLTests/UnwrappingTests+XCTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ extension UnwrappingTests {
("testUnwrappingTimeout", testUnwrappingTimeout),
("testSuccessfulUnwrapCancelsTimeout", testSuccessfulUnwrapCancelsTimeout),
("testUnwrappingAndClosingShareATimeout", testUnwrappingAndClosingShareATimeout),
("testChannelInactiveDuringHandshake", testChannelInactiveDuringHandshake),
]
}
}
Expand Down
67 changes: 67 additions & 0 deletions Tests/NIOSSLTests/UnwrappingTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -963,4 +963,71 @@ final class UnwrappingTests: XCTestCase {
XCTAssertTrue(serverClosed)
XCTAssertTrue(unwrapped)
}

func testChannelInactiveDuringHandshake() throws {

let serverChannel = EmbeddedChannel()
let clientChannel = EmbeddedChannel()

var serverClosed = false
var serverUnwrapped = false
defer {
// The errors here are expected
XCTAssertThrowsError(try serverChannel.finish())
XCTAssertThrowsError(try clientChannel.finish())
}

let context = try assertNoThrowWithValue(configuredSSLContext())
let serverHandler = try assertNoThrowWithValue(NIOSSLServerHandler(context: context))
let clientHandler = try assertNoThrowWithValue(NIOSSLClientHandler(context: context, serverHostname: nil))
XCTAssertNoThrow(try serverChannel.pipeline.addHandler(NIOSSLServerHandler(context: context)).wait())
XCTAssertNoThrow(try clientChannel.pipeline.addHandler(clientHandler).wait())
let handshakeHandler = HandshakeCompletedHandler()
XCTAssertNoThrow(try clientChannel.pipeline.addHandler(handshakeHandler).wait())

serverChannel.closeFuture.whenComplete { _ in
serverClosed = true
}

// Place the guts of connectInMemory here to abruptly alter the handshake process
let addr = try assertNoThrowWithValue(SocketAddress(unixDomainSocketPath: "/tmp/whatever2"))
let _ = clientChannel.connect(to: addr)

XCTAssertFalse(serverClosed)

serverChannel.pipeline.fireChannelActive()
clientChannel.pipeline.fireChannelActive()
// doHandshakeStep process should start here out in NIOSSLHandler before fireChannelInactive
serverChannel.pipeline.fireChannelInactive()
clientChannel.pipeline.fireChannelInactive()

// Need to test this error as a BoringSSLError because that means success instead of an uncleanShutdown
do {
try interactInMemory(clientChannel: clientChannel, serverChannel: serverChannel)
} catch {
switch error as? NIOSSLError {
case .some(.handshakeFailed):
// Expected to fall into .handshakeFailed
break
default:
XCTFail("Unexpected error: \(error)")
}
}
clientHandler.stopTLS(promise: nil)

// Go through the process of closing and verifying the close on the server side.
XCTAssertFalse(serverUnwrapped)

let serverStopPromise: EventLoopPromise<Void> = serverChannel.eventLoop.makePromise()
serverStopPromise.futureResult.whenComplete { _ in
serverUnwrapped = true
}
serverHandler.stopTLS(promise: serverStopPromise)
XCTAssertNoThrow(try interactInMemory(clientChannel: clientChannel, serverChannel: serverChannel))

(serverChannel.eventLoop as! EmbeddedEventLoop).run()

XCTAssertTrue(serverClosed)
XCTAssertTrue(serverUnwrapped)
}
}