diff --git a/Sources/NIO/Channel.swift b/Sources/NIO/Channel.swift index c3702330376..6f272777768 100644 --- a/Sources/NIO/Channel.swift +++ b/Sources/NIO/Channel.swift @@ -244,6 +244,28 @@ public extension Channel { } } +public extension ChannelCore { + /// Unwraps the given `NIOAny` as a specific concrete type. + /// + /// This method is intended for use when writing custom `ChannelCore` implementations. + /// This can safely be called in methods like `write0` to extract data from the `NIOAny` + /// provided in those cases. + /// + /// Note that if the unwrap fails, this will cause a runtime trap. `ChannelCore` + /// implementations should be concrete about what types they support writing. If multiple + /// types are supported, considere using a tagged union to store the type information like + /// NIO's own `IOData`, which will minimise the amount of runtime type checking. + /// + /// - parameters: + /// - data: The `NIOAny` to unwrap. + /// - as: The type to extract from the `NIOAny`. + /// - returns: The content of the `NIOAny`. + @_inlineable + public func unwrapData(_ data: NIOAny, as: T.Type = T.self) -> T { + return data.forceAs() + } +} + /// An error that can occur on `Channel` operations. public enum ChannelError: Error { /// Tried to connect on a `Channel` that is already connecting. diff --git a/Sources/NIO/NIOAny.swift b/Sources/NIO/NIOAny.swift index 68e4ce59e8b..75c97711398 100644 --- a/Sources/NIO/NIOAny.swift +++ b/Sources/NIO/NIOAny.swift @@ -43,13 +43,14 @@ /// } /// } public struct NIOAny { - private let storage: _NIOAny + @_versioned + /* private but _versioned */ let _storage: _NIOAny /// Wrap a value in a `NIOAny`. In most cases you should not create a `NIOAny` directly using this constructor. /// The abstraction that accepts values of type `NIOAny` must also provide a mechanism to do the wrapping. An /// example is a `ChannelInboundHandler` which provides `self.wrapInboundOut(aValueOfTypeInboundOut)`. public init(_ value: T) { - self.storage = _NIOAny(value) + self._storage = _NIOAny(value) } enum _NIOAny { @@ -77,8 +78,9 @@ public struct NIOAny { /// Try unwrapping the wrapped message as `ByteBuffer`. /// /// returns: The wrapped `ByteBuffer` or `nil` if the wrapped message is not a `ByteBuffer`. + @_versioned @_inlineable func tryAsByteBuffer() -> ByteBuffer? { - if case .ioData(.byteBuffer(let bb)) = self.storage { + if case .ioData(.byteBuffer(let bb)) = self._storage { return bb } else { return nil @@ -88,19 +90,21 @@ public struct NIOAny { /// Force unwrapping the wrapped message as `ByteBuffer`. /// /// returns: The wrapped `ByteBuffer` or crash if the wrapped message is not a `ByteBuffer`. + @_versioned @_inlineable func forceAsByteBuffer() -> ByteBuffer { if let v = tryAsByteBuffer() { return v } else { - fatalError("tried to decode as type \(ByteBuffer.self) but found \(Mirror(reflecting: Mirror(reflecting: self.storage).children.first!.value).subjectType)") + fatalError("tried to decode as type \(ByteBuffer.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType)") } } /// Try unwrapping the wrapped message as `IOData`. /// /// returns: The wrapped `IOData` or `nil` if the wrapped message is not a `IOData`. + @_versioned @_inlineable func tryAsIOData() -> IOData? { - if case .ioData(let data) = self.storage { + if case .ioData(let data) = self._storage { return data } else { return nil @@ -110,19 +114,21 @@ public struct NIOAny { /// Force unwrapping the wrapped message as `IOData`. /// /// returns: The wrapped `IOData` or crash if the wrapped message is not a `IOData`. + @_versioned @_inlineable func forceAsIOData() -> IOData { if let v = tryAsIOData() { return v } else { - fatalError("tried to decode as type \(IOData.self) but found \(Mirror(reflecting: Mirror(reflecting: self.storage).children.first!.value).subjectType)") + fatalError("tried to decode as type \(IOData.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType)") } } /// Try unwrapping the wrapped message as `FileRegion`. /// /// returns: The wrapped `FileRegion` or `nil` if the wrapped message is not a `FileRegion`. + @_versioned @_inlineable func tryAsFileRegion() -> FileRegion? { - if case .ioData(.fileRegion(let f)) = self.storage { + if case .ioData(.fileRegion(let f)) = self._storage { return f } else { return nil @@ -132,19 +138,21 @@ public struct NIOAny { /// Force unwrapping the wrapped message as `FileRegion`. /// /// returns: The wrapped `FileRegion` or crash if the wrapped message is not a `FileRegion`. + @_versioned @_inlineable func forceAsFileRegion() -> FileRegion { if let v = tryAsFileRegion() { return v } else { - fatalError("tried to decode as type \(FileRegion.self) but found \(Mirror(reflecting: Mirror(reflecting: self.storage).children.first!.value).subjectType)") + fatalError("tried to decode as type \(FileRegion.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType)") } } /// Try unwrapping the wrapped message as `AddressedEnvelope`. /// /// returns: The wrapped `AddressedEnvelope` or `nil` if the wrapped message is not an `AddressedEnvelope`. + @_versioned @_inlineable func tryAsByteEnvelope() -> AddressedEnvelope? { - if case .bufferEnvelope(let e) = self.storage { + if case .bufferEnvelope(let e) = self._storage { return e } else { return nil @@ -154,19 +162,21 @@ public struct NIOAny { /// Force unwrapping the wrapped message as `AddressedEnvelope`. /// /// returns: The wrapped `AddressedEnvelope` or crash if the wrapped message is not an `AddressedEnvelope`. + @_versioned @_inlineable func forceAsByteEnvelope() -> AddressedEnvelope { if let e = tryAsByteEnvelope() { return e } else { - fatalError("tried to decode as type \(AddressedEnvelope.self) but found \(Mirror(reflecting: Mirror(reflecting: self.storage).children.first!.value).subjectType)") + fatalError("tried to decode as type \(AddressedEnvelope.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType)") } } /// Try unwrapping the wrapped message as `T`. /// /// returns: The wrapped `T` or `nil` if the wrapped message is not a `T`. + @_versioned @_inlineable func tryAsOther(type: T.Type = T.self) -> T? { - if case .other(let any) = self.storage { + if case .other(let any) = self._storage { return any as? T } else { return nil @@ -176,17 +186,19 @@ public struct NIOAny { /// Force unwrapping the wrapped message as `T`. /// /// returns: The wrapped `T` or crash if the wrapped message is not a `T`. + @_versioned @_inlineable func forceAsOther(type: T.Type = T.self) -> T { if let v = tryAsOther(type: type) { return v } else { - fatalError("tried to decode as type \(T.self) but found \(Mirror(reflecting: Mirror(reflecting: self.storage).children.first!.value).subjectType)") + fatalError("tried to decode as type \(T.self) but found \(Mirror(reflecting: Mirror(reflecting: self._storage).children.first!.value).subjectType)") } } /// Force unwrapping the wrapped message as `T`. /// /// returns: The wrapped `T` or crash if the wrapped message is not a `T`. + @_versioned @_inlineable func forceAs(type: T.Type = T.self) -> T { switch T.self { case let t where t == ByteBuffer.self: @@ -205,6 +217,7 @@ public struct NIOAny { /// Try unwrapping the wrapped message as `T`. /// /// returns: The wrapped `T` or `nil` if the wrapped message is not a `T`. + @_versioned @_inlineable func tryAs(type: T.Type = T.self) -> T? { switch T.self { case let t where t == ByteBuffer.self: @@ -223,8 +236,9 @@ public struct NIOAny { /// Unwrap the wrapped message. /// /// returns: The wrapped message. + @_versioned @_inlineable func asAny() -> Any { - switch self.storage { + switch self._storage { case .ioData(.byteBuffer(let bb)): return bb case .ioData(.fileRegion(let f)): diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 603bd9f5f60..310b9f7a712 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -43,6 +43,7 @@ import XCTest testCase(ChannelTests.allTests), testCase(CircularBufferTests.allTests), testCase(CompositeErrorTests.allTests), + testCase(CustomChannelTests.allTests), testCase(DatagramChannelTests.allTests), testCase(EchoServerClientTest.allTests), testCase(EmbeddedChannelTest.allTests), diff --git a/Tests/NIOTests/CustomChannelTests+XCTest.swift b/Tests/NIOTests/CustomChannelTests+XCTest.swift new file mode 100644 index 00000000000..acf99449044 --- /dev/null +++ b/Tests/NIOTests/CustomChannelTests+XCTest.swift @@ -0,0 +1,33 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// +// CustomChannelTests+XCTest.swift +// +import XCTest + +/// +/// NOTE: This file was generated by generate_linux_tests.rb +/// +/// Do NOT edit this file directly as it will be regenerated automatically when needed. +/// + +extension CustomChannelTests { + + static var allTests : [(String, (CustomChannelTests) -> () throws -> Void)] { + return [ + ("testWritingIntToSpecialChannel", testWritingIntToSpecialChannel), + ] + } +} + diff --git a/Tests/NIOTests/CustomChannelTests.swift b/Tests/NIOTests/CustomChannelTests.swift new file mode 100644 index 00000000000..6012c2af24e --- /dev/null +++ b/Tests/NIOTests/CustomChannelTests.swift @@ -0,0 +1,87 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import XCTest +import NIO +import NIOConcurrencyHelpers + +struct NotImplementedError: Error { } + +struct InvalidTypeError: Error { } + +/// A basic ChannelCore that expects write0 to receive a NIOAny containing an Int. +/// +/// Everything else either throws or returns a failed future, except for things that cannot, +/// which precondition instead. +private class IntChannelCore: ChannelCore { + func localAddress0() throws -> SocketAddress { + throw NotImplementedError() + } + + func remoteAddress0() throws -> SocketAddress { + throw NotImplementedError() + } + + func register0(promise: EventLoopPromise?) { + promise?.fail(error: NotImplementedError()) + } + + func bind0(to: SocketAddress, promise: EventLoopPromise?) { + promise?.fail(error: NotImplementedError()) + } + + func connect0(to: SocketAddress, promise: EventLoopPromise?) { + promise?.fail(error: NotImplementedError()) + } + + func write0(_ data: NIOAny, promise: EventLoopPromise?) { + _ = self.unwrapData(data, as: Int.self) + promise?.succeed(result: ()) + } + + func flush0() { + preconditionFailure("Must not flush") + } + + func read0() { + preconditionFailure("Must not ew") + } + + func close0(error: Error, mode: CloseMode, promise: EventLoopPromise?) { + promise?.fail(error: NotImplementedError()) + } + + func triggerUserOutboundEvent0(_ event: Any, promise: EventLoopPromise?) { + promise?.fail(error: NotImplementedError()) + } + + func channelRead0(_ data: NIOAny) { + preconditionFailure("Must not call channelRead0") + } + + func errorCaught0(error: Error) { + preconditionFailure("Must not call errorCaught0") + } +} + +class CustomChannelTests: XCTestCase { + func testWritingIntToSpecialChannel() throws { + let loop = EmbeddedEventLoop() + let intCore = IntChannelCore() + let writePromise: EventLoopPromise = loop.newPromise() + + intCore.write0(NIOAny(5), promise: writePromise) + XCTAssertNoThrow(try writePromise.futureResult.wait()) + } +}