diff --git a/Sources/NIO/BaseSocketChannel.swift b/Sources/NIO/BaseSocketChannel.swift index c80c66ce6b5..1e4f9865e06 100644 --- a/Sources/NIO/BaseSocketChannel.swift +++ b/Sources/NIO/BaseSocketChannel.swift @@ -18,14 +18,16 @@ private struct SocketChannelLifecycleManager { // MARK: Types private enum State { case fresh - case registered + case preRegistered // register() has been run but the selector doesn't know about it yet + case fullyRegistered // fully registered, ie. the selector knows about it case activated case closed } private enum Event { case activate - case register + case beginRegistration + case finishRegistration case close } @@ -66,8 +68,13 @@ private struct SocketChannelLifecycleManager { } @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined - internal mutating func register(promise: EventLoopPromise?) -> ((ChannelPipeline) -> Void) { - return self.moveState(event: .register, promise: promise) + internal mutating func beginRegistration(promise: EventLoopPromise?) -> ((ChannelPipeline) -> Void) { + return self.moveState(event: .beginRegistration, promise: promise) + } + + @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined + internal mutating func finishRegistration(promise: EventLoopPromise?) -> ((ChannelPipeline) -> Void) { + return self.moveState(event: .finishRegistration, promise: promise) } @inline(__always) // we need to return a closure here and to not suffer from a potential allocation for that this must be inlined @@ -87,21 +94,28 @@ private struct SocketChannelLifecycleManager { switch (self.currentState, event) { // origin: .fresh - case (.fresh, .register): - return self.doStateTransfer(newState: .registered, promise: promise) { pipeline in + case (.fresh, .beginRegistration): + return self.doStateTransfer(newState: .preRegistered, promise: promise) { pipeline in pipeline.fireChannelRegistered0() } case (.fresh, .close): return self.doStateTransfer(newState: .closed, promise: promise) { (_: ChannelPipeline) in } - // origin: .registered - case (.registered, .activate): + // origin: .preRegistered + case (.preRegistered, .finishRegistration): + return self.doStateTransfer(newState: .fullyRegistered, promise: promise) { pipeline in + // we don't tell the user about this + } + + // origin: .fullyRegistered + case (.fullyRegistered, .activate): return self.doStateTransfer(newState: .activated, promise: promise) { pipeline in pipeline.fireChannelActive0() } - case (.registered, .close): + // origin: .preRegistered || .fullyRegistered + case (.preRegistered, .close), (.fullyRegistered, .close): return self.doStateTransfer(newState: .closed, promise: promise) { pipeline in pipeline.fireChannelUnregistered0() } @@ -114,11 +128,16 @@ private struct SocketChannelLifecycleManager { } // bad transitions - case (.fresh, .activate), // should go through .registered first - (.registered, .register), // already registered - (.activated, .activate), // already activated - (.activated, .register), // already registered (and activated) - (.closed, _): // already closed + case (.fresh, .activate), // should go through .registered first + (.preRegistered, .activate), // need to first be fully registered + (.preRegistered, .beginRegistration), // already registered + (.fullyRegistered, .beginRegistration), // already registered + (.activated, .activate), // already activated + (.activated, .beginRegistration), // already fully registered (and activated) + (.activated, .finishRegistration), // already fully registered (and activated) + (.fullyRegistered, .finishRegistration), // already fully registered + (.fresh, .finishRegistration), // need to register lazily first + (.closed, _): // already closed self.badTransition(event: event) } } @@ -143,12 +162,22 @@ private struct SocketChannelLifecycleManager { return self.currentState == .activated } - internal var isRegistered: Bool { + internal var isPreRegistered: Bool { assert(self.eventLoop.inEventLoop) switch self.currentState { case .fresh, .closed: return false - case .registered, .activated: + case .preRegistered, .fullyRegistered, .activated: + return true + } + } + + internal var isRegisteredFully: Bool { + assert(self.eventLoop.inEventLoop) + switch self.currentState { + case .fresh, .closed, .preRegistered: + return false + case .fullyRegistered, .activated: return true } } @@ -264,7 +293,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { var isRegistered: Bool { assert(self.eventLoop.inEventLoop) - return self.lifecycleManager.isRegistered + return self.lifecycleManager.isPreRegistered } internal var selectable: T { @@ -463,7 +492,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // We only want to call read0() or pauseRead0() if we already registered to the EventLoop if not this will be automatically done // once register0 is called. Beside this we also only need to do it when the value actually change. - if self.lifecycleManager.isRegistered && old != auto { + if self.lifecycleManager.isPreRegistered && old != auto { if auto { read0() } else { @@ -536,7 +565,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { promise?.fail(error: ChannelError.ioOnClosedChannel) return } - guard self.isRegistered else { + guard self.lifecycleManager.isPreRegistered else { promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) return } @@ -598,7 +627,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } if !isWritePending() && flushNow() == .register { - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isPreRegistered) registerForWritable() } } @@ -611,7 +640,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } readPending = true - if self.lifecycleManager.isRegistered { + if self.lifecycleManager.isPreRegistered { registerForReadable() } } @@ -619,14 +648,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private final func pauseRead0() { assert(eventLoop.inEventLoop) - if self.lifecycleManager.isRegistered { + if self.lifecycleManager.isPreRegistered { unregisterForReadable() } } private final func registerForReadable() { assert(eventLoop.inEventLoop) - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isRegisteredFully) guard !self.lifecycleManager.hasSeenEOFNotification else { // we have seen an EOF notification before so there's no point in registering for reads @@ -642,7 +671,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { internal final func unregisterForReadable() { assert(eventLoop.inEventLoop) - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isRegisteredFully) guard self.interestedEvent.contains(.read) else { return @@ -651,6 +680,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.safeReregister(interested: self.interestedEvent.subtracting(.read)) } + /// Closes the this `BaseChannelChannel` and fulfills `promise` with the result of the _close_ operation. + /// So unless either the deregistration or the close itself fails, `promise` will be succeeded regardless of + /// `error`. `error` is used to fail outstanding writes (if any) and the `connectPromise` if set. + /// + /// - parameters: + /// - error: The error to fail the outstanding (if any) writes/connect with. + /// - mode: The close mode, must be `.all` for `BaseSocketChannel` + /// - promise: The promise that gets notified about the result of the deregistration/close operations. public func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { assert(eventLoop.inEventLoop) @@ -731,19 +768,23 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } - guard !self.lifecycleManager.isRegistered else { + guard !self.lifecycleManager.isPreRegistered else { promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) return } - // Was not registered yet so do it now. - do { - // We always register with interested .none and will just trigger readIfNeeded0() later to re-register if needed. - try self.safeRegister(interested: [.readEOF, .reset]) - self.lifecycleManager.register(promise: promise)(self.pipeline) - } catch { + guard self.selectableEventLoop.isOpen else { + let error = EventLoopError.shutdown + self.pipeline.fireErrorCaught0(error: error) + // `close0`'s error is about the result of the `close` operation, ... + self.close0(error: error, mode: .all, promise: nil) + // ... therefore we need to fail the registration `promise` separately. promise?.fail(error: error) + return } + + // we can't fully register yet as epoll would give us EPOLLHUP if bind/connect wasn't called yet. + self.lifecycleManager.beginRegistration(promise: promise)(self.pipeline) } public final func registerAlreadyConfigured0(promise: EventLoopPromise?) { @@ -751,7 +792,12 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { assert(self.isOpen) assert(!self.lifecycleManager.isActive) register0(promise: nil) - becomeActive0(promise: promise) + if self.lifecycleManager.isPreRegistered { + try! becomeFullyRegistered0() + if self.lifecycleManager.isRegisteredFully { + becomeActive0(promise: promise) + } + } } public final func triggerUserOutboundEvent0(_ event: Any, promise: EventLoopPromise?) { @@ -772,7 +818,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private func finishConnect() { assert(eventLoop.inEventLoop) - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isPreRegistered) if let connectPromise = pendingConnect { assert(!self.lifecycleManager.isActive) @@ -799,7 +845,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { assert(eventLoop.inEventLoop) if self.isOpen { - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isPreRegistered) unregisterForWritable() } } @@ -810,8 +856,8 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { // we can't be not active but still registered here; this would mean that we got a notification about a // channel before we're ready to receive them. - assert(self.lifecycleManager.isActive || !self.lifecycleManager.isRegistered, - "illegal state: active: \(self.lifecycleManager.isActive), registered: \(self.lifecycleManager.isRegistered)") + assert(self.lifecycleManager.isActive || !self.lifecycleManager.isPreRegistered, + "illegal state: \(self): active: \(self.lifecycleManager.isActive), pre-registered: \(self.lifecycleManager.isPreRegistered)") self.readEOF0() @@ -820,7 +866,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } final func readEOF0() { - if self.lifecycleManager.isRegistered { + if self.lifecycleManager.isRegisteredFully { // we're unregistering from `readEOF` here as we want this to be one-shot. We're then synchronously // reading all input until the EOF that we're guaranteed to see. After that `readEOF` becomes uninteresting // and would anyway fire constantly. @@ -835,7 +881,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { case .error: // we should be unregistered and inactive now (as `readable0` would've called close). assert(!self.lifecycleManager.isActive) - assert(!self.lifecycleManager.isRegistered) + assert(!self.lifecycleManager.isPreRegistered) break loop case .normal(.none): preconditionFailure("got .readEOF and read returned not reading any bytes, nor EOF.") @@ -854,7 +900,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.readEOF0() if self.socket.isOpen { - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isPreRegistered) let error: IOError // if the socket is still registered (and therefore open), let's try to get the actual socket error from the socket do { @@ -873,7 +919,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { self.close0(error: error, mode: .all, promise: nil) } } - assert(!self.lifecycleManager.isRegistered) + assert(!self.lifecycleManager.isPreRegistered) } public final func readable() { @@ -973,7 +1019,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { return } - guard self.lifecycleManager.isRegistered else { + guard self.lifecycleManager.isPreRegistered else { promise?.fail(error: ChannelLifecycleError.inappropriateOperationForState) return } @@ -987,13 +1033,14 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } else { pendingConnect = eventLoop.newPromise() } + try becomeFullyRegistered0() registerForWritable() } else { self.updateCachedAddressesFromSocket() becomeActive0(promise: promise) } } catch let error { - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isPreRegistered) // We would like to have this assertion here, but we want to be able to go through this // code path in cases where connect() is being called on channels that are already active. //assert(!self.lifecycleManager.isActive) @@ -1018,7 +1065,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private final func safeReregister(interested: SelectorEventSet) { assert(eventLoop.inEventLoop) - assert(self.lifecycleManager.isRegistered) + assert(self.lifecycleManager.isRegisteredFully) guard self.isOpen else { assert(self.interestedEvent == .reset, "interestedEvent=\(self.interestedEvent) event though we're closed") @@ -1039,7 +1086,7 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { private func safeRegister(interested: SelectorEventSet) throws { assert(eventLoop.inEventLoop) - assert(!self.lifecycleManager.isRegistered) + assert(!self.lifecycleManager.isRegisteredFully) guard self.isOpen else { throw ChannelError.ioOnClosedChannel @@ -1055,8 +1102,27 @@ class BaseSocketChannel: SelectableChannel, ChannelCore { } } + final func becomeFullyRegistered0() throws { + assert(self.eventLoop.inEventLoop) + assert(self.lifecycleManager.isPreRegistered) + assert(!self.lifecycleManager.isRegisteredFully) + + // We always register with interested .none and will just trigger readIfNeeded0() later to re-register if needed. + try self.safeRegister(interested: [.readEOF, .reset]) + self.lifecycleManager.finishRegistration(promise: nil)(self.pipeline) + } + final func becomeActive0(promise: EventLoopPromise?) { - assert(eventLoop.inEventLoop) + assert(self.eventLoop.inEventLoop) + assert(self.lifecycleManager.isPreRegistered) + if !self.lifecycleManager.isRegisteredFully { + do { + try self.becomeFullyRegistered0() + } catch { + self.close0(error: error, mode: .all, promise: promise) + return + } + } self.lifecycleManager.activate(promise: promise)(self.pipeline) self.readIfNeeded0() } diff --git a/Sources/NIO/EventLoop.swift b/Sources/NIO/EventLoop.swift index c320ee58086..71f2cb43e6a 100644 --- a/Sources/NIO/EventLoop.swift +++ b/Sources/NIO/EventLoop.swift @@ -344,6 +344,12 @@ internal final class SelectableEventLoop: EventLoop { _addresses.deallocate() } + /// Is this `SelectableEventLoop` still open (ie. not shutting down or shut down) + internal var isOpen: Bool { + assert(self.inEventLoop) + return self.lifecycleState == .open + } + /// Register the given `SelectableChannel` with this `SelectableEventLoop`. After this point all I/O for the `SelectableChannel` will be processed by this `SelectableEventLoop` until it /// is deregistered by calling `deregister`. public func register(channel: C) throws { diff --git a/Tests/NIOTests/ChannelTests+XCTest.swift b/Tests/NIOTests/ChannelTests+XCTest.swift index 707d3fa358e..dd5729ca788 100644 --- a/Tests/NIOTests/ChannelTests+XCTest.swift +++ b/Tests/NIOTests/ChannelTests+XCTest.swift @@ -69,6 +69,8 @@ extension ChannelTests { ("testSocketErroringSynchronouslyCorrectlyTearsTheChannelDown", testSocketErroringSynchronouslyCorrectlyTearsTheChannelDown), ("testConnectWithECONNREFUSEDGetsTheRightError", testConnectWithECONNREFUSEDGetsTheRightError), ("testCloseInUnregister", testCloseInUnregister), + ("testLazyRegistrationWorksForServerSockets", testLazyRegistrationWorksForServerSockets), + ("testLazyRegistrationWorksForClientSockets", testLazyRegistrationWorksForClientSockets), ] } } diff --git a/Tests/NIOTests/ChannelTests.swift b/Tests/NIOTests/ChannelTests.swift index 50471e57e47..dac7ec46070 100644 --- a/Tests/NIOTests/ChannelTests.swift +++ b/Tests/NIOTests/ChannelTests.swift @@ -2317,6 +2317,49 @@ public class ChannelTests: XCTestCase { } + func testLazyRegistrationWorksForServerSockets() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let server = try ServerSocketChannel(eventLoop: group.next() as! SelectableEventLoop, + group: group, + protocolFamily: PF_INET) + defer { + XCTAssertNoThrow(try server.close().wait()) + } + XCTAssertNoThrow(try server.register().wait()) + XCTAssertNoThrow(try server.eventLoop.submit { + XCTAssertFalse(server.isActive) + }.wait()) + XCTAssertEqual(0, server.localAddress!.port!) + XCTAssertNoThrow(try server.bind(to: SocketAddress(ipAddress: "0.0.0.0", port: 0)).wait()) + XCTAssertNotEqual(0, server.localAddress!.port!) + } + + func testLazyRegistrationWorksForClientSockets() throws { + let group = MultiThreadedEventLoopGroup(numThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + let serverChannel = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) + .bind(host: "localhost", port: 0) + .wait() + + let client = try SocketChannel(eventLoop: group.next() as! SelectableEventLoop, + protocolFamily: serverChannel.localAddress!.protocolFamily) + defer { + XCTAssertNoThrow(try client.close().wait()) + } + XCTAssertNoThrow(try client.register().wait()) + XCTAssertNoThrow(try client.eventLoop.submit { + XCTAssertFalse(client.isActive) + }.wait()) + XCTAssertNoThrow(try client.connect(to: serverChannel.localAddress!).wait()) + XCTAssertTrue(client.isActive) + XCTAssertEqual(serverChannel.localAddress!, client.remoteAddress!) + } } fileprivate class VerifyConnectionFailureHandler: ChannelInboundHandler {