Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 86 additions & 7 deletions Sources/NIOCore/AsyncChannel/AsyncChannelInboundStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,68 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
@usableFromInline
typealias Producer = NIOThrowingAsyncSequenceProducer<Inbound, Error, NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark, NIOAsyncChannelInboundStreamChannelHandlerProducerDelegate>

/// A source used for driving a ``NIOAsyncChannelInboundStream`` during tests.
public struct TestSource {
@usableFromInline
internal let continuation: AsyncStream<Inbound>.Continuation

@inlinable
init(continuation: AsyncStream<Inbound>.Continuation) {
self.continuation = continuation
}

/// Yields the element to the inbound stream.
///
/// - Parameter element: The element to yield to the inbound stream.
@inlinable
public func yield(_ element: Inbound) {
self.continuation.yield(element)
}

/// Finished the inbound stream.
@inlinable
public func finish() {
self.continuation.finish()
}
}

#if swift(>=5.7)
@usableFromInline
enum _Backing: Sendable {
case asyncStream(AsyncStream<Inbound>)
case producer(Producer)
}
#else
// AsyncStream wasn't marked as `Sendable` in 5.6
@usableFromInline
enum _Backing: @unchecked Sendable {
case asyncStream(AsyncStream<Inbound>)
case producer(Producer)
}
#endif

/// The underlying async sequence.
@usableFromInline let _producer: Producer
@usableFromInline
let _backing: _Backing

/// Creates a new stream with a source for testing.
///
/// This is useful for writing unit tests where you want to drive a ``NIOAsyncChannelInboundStream``.
///
/// - Returns: A tuple containing the input stream and a test source to drive it.
@inlinable
public static func makeTestingStream() -> (Self, TestSource) {
var continuation: AsyncStream<Inbound>.Continuation!
let stream = AsyncStream<Inbound> { continuation = $0 }
let source = TestSource(continuation: continuation)
let inputStream = Self(stream: stream)
return (inputStream, source)
}

@inlinable
init(stream: AsyncStream<Inbound>) {
self._backing = .asyncStream(stream)
}

@inlinable
init<HandlerInbound: Sendable>(
Expand All @@ -48,7 +108,7 @@ public struct NIOAsyncChannelInboundStream<Inbound: Sendable>: Sendable {
)
handler.source = sequence.source
try channel.pipeline.syncOperations.addHandler(handler)
self._producer = sequence.sequence
self._backing = .producer(sequence.sequence)
}

/// Creates a new ``NIOAsyncChannelInboundStream`` which is used when the pipeline got synchronously wrapped.
Expand Down Expand Up @@ -101,23 +161,42 @@ extension NIOAsyncChannelInboundStream: AsyncSequence {

@_spi(AsyncChannel)
public struct AsyncIterator: AsyncIteratorProtocol {
@usableFromInline var _iterator: Producer.AsyncIterator
@usableFromInline
enum _Backing {
case asyncStream(AsyncStream<Inbound>.Iterator)
case producer(Producer.AsyncIterator)
}

@usableFromInline var _backing: _Backing

@inlinable
init(_ iterator: Producer.AsyncIterator) {
self._iterator = iterator
init(_ backing: NIOAsyncChannelInboundStream<Inbound>._Backing) {
switch backing {
case .asyncStream(let asyncStream):
self._backing = .asyncStream(asyncStream.makeAsyncIterator())
case .producer(let producer):
self._backing = .producer(producer.makeAsyncIterator())
}
}

@inlinable @_spi(AsyncChannel)
public mutating func next() async throws -> Element? {
return try await self._iterator.next()
switch self._backing {
case .asyncStream(var iterator):
let value = await iterator.next()
self._backing = .asyncStream(iterator)
return value

case .producer(let iterator):
return try await iterator.next()
}
}
}

@inlinable
@_spi(AsyncChannel)
public func makeAsyncIterator() -> AsyncIterator {
return AsyncIterator(self._producer.makeAsyncIterator())
return AsyncIterator(self._backing)
}
}

Expand Down
98 changes: 90 additions & 8 deletions Sources/NIOCore/AsyncChannel/AsyncChannelOutboundWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,64 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@usableFromInline
typealias _Writer = NIOAsyncChannelOutboundWriterHandler<OutboundOut>.Writer

/// An `AsyncSequence` backing a ``NIOAsyncChannelOutboundWriter`` for testing purposes.
public struct TestSink: AsyncSequence {
public typealias Element = OutboundOut

@usableFromInline
internal let stream: AsyncStream<OutboundOut>

@usableFromInline
internal let continuation: AsyncStream<OutboundOut>.Continuation

@inlinable
init(
stream: AsyncStream<OutboundOut>,
continuation: AsyncStream<OutboundOut>.Continuation
) {
self.stream = stream
self.continuation = continuation
}

public func makeAsyncIterator() -> AsyncIterator {
AsyncIterator(iterator: self.stream.makeAsyncIterator())
}

public struct AsyncIterator: AsyncIteratorProtocol {
@usableFromInline
internal var iterator: AsyncStream<OutboundOut>.AsyncIterator

@inlinable
init(iterator: AsyncStream<OutboundOut>.AsyncIterator) {
self.iterator = iterator
}

public mutating func next() async -> Element? {
await self.iterator.next()
}
}
}

@usableFromInline
let _outboundWriter: _Writer
enum Backing: Sendable {
case asyncStream(AsyncStream<OutboundOut>.Continuation)
case writer(_Writer)
}

@usableFromInline
internal let _backing: Backing

/// Creates a new ``NIOAsyncChannelOutboundWriter`` backed by a ``NIOAsyncChannelOutboundWriter/TestSink``.
/// This is mostly useful for testing purposes where one wants to observe the written data.
@inlinable
public static func makeTestingWriter() -> (Self, TestSink) {
var continuation: AsyncStream<OutboundOut>.Continuation!
let asyncStream = AsyncStream<OutboundOut> { continuation = $0 }
let writer = Self(continuation: continuation)
let sink = TestSink(stream: asyncStream, continuation: continuation)

return (writer, sink)
}

@inlinable
init(
Expand All @@ -44,12 +100,12 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {

try channel.pipeline.syncOperations.addHandler(handler)

self._outboundWriter = writer.writer
self._backing = .writer(writer.writer)
}

@inlinable
init(outboundWriter: NIOAsyncChannelOutboundWriterHandler<OutboundOut>.Writer) {
self._outboundWriter = outboundWriter
init(continuation: AsyncStream<OutboundOut>.Continuation) {
self._backing = .asyncStream(continuation)
}

/// Send a write into the ``ChannelPipeline`` and flush it right away.
Expand All @@ -58,7 +114,12 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@inlinable
@_spi(AsyncChannel)
public func write(_ data: OutboundOut) async throws {
try await self._outboundWriter.yield(data)
switch self._backing {
case .asyncStream(let continuation):
continuation.yield(data)
case .writer(let writer):
try await writer.yield(data)
}
}

/// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away.
Expand All @@ -67,7 +128,14 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@inlinable
@_spi(AsyncChannel)
public func write<Writes: Sequence>(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut {
try await self._outboundWriter.yield(contentsOf: sequence)
switch self._backing {
case .asyncStream(let continuation):
for data in sequence {
continuation.yield(data)
}
case .writer(let writer):
try await writer.yield(contentsOf: sequence)
}
}

/// Send a sequence of writes into the ``ChannelPipeline`` and flush them right away.
Expand All @@ -77,7 +145,7 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
@_spi(AsyncChannel)
public func write<Writes: AsyncSequence>(contentsOf sequence: Writes) async throws where Writes.Element == OutboundOut {
for try await data in sequence {
try await self._outboundWriter.yield(data)
try await self.write(data)
}
}

Expand All @@ -86,6 +154,20 @@ public struct NIOAsyncChannelOutboundWriter<OutboundOut: Sendable>: Sendable {
/// This might trigger a half closure if the ``NIOAsyncChannel`` was configured to support it.
@_spi(AsyncChannel)
public func finish() {
self._outboundWriter.finish()
switch self._backing {
case .asyncStream(let continuation):
continuation.finish()
case .writer(let writer):
writer.finish()
}
}
}

#if swift(>=5.7)
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncChannelOutboundWriter.TestSink: Sendable {}
#else
// AsyncStream wasn't marked as `Sendable` in 5.6
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOAsyncChannelOutboundWriter.TestSink: @unchecked Sendable {}
#endif
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@_spi(AsyncChannel) @testable import NIOCore
import XCTest

final class AsyncChannelInboundStreamTests: XCTestCase {
func testTestingStream() async throws {
let (stream, source) = NIOAsyncChannelInboundStream<Int>.makeTestingStream()

try await withThrowingTaskGroup(of: [Int].self) { group in
group.addTask {
var elements = [Int]()
for try await element in stream {
elements.append(element)
}
return elements
}

for element in 0...10 {
source.yield(element)
}
source.finish()

let result = try await group.next()
XCTAssertEqual(result, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

@_spi(AsyncChannel) @testable import NIOCore
import XCTest

final class AsyncChannelOutboundWriterTests: XCTestCase {
func testTestingWriter() async throws {
let (writer, sink) = NIOAsyncChannelOutboundWriter<Int>.makeTestingWriter()

try await withThrowingTaskGroup(of: [Int].self) { group in
group.addTask {
var elements = [Int]()
for try await element in sink {
elements.append(element)
}
return elements
}

for element in 0...10 {
try await writer.write(element)
}
writer.finish()

let result = try await group.next()
XCTAssertEqual(result, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
}
}
}