Skip to content

Commit

Permalink
Add cancellation to NIOThreadPool's async runIfActive (#2679)
Browse files Browse the repository at this point in the history
Motivation:

To dedupe the thread pool from NIOFileSystem and NIOPosix, the one from
NIOPosix needs to gain support for cancellation for its async
`runIfActive` function.

Modification:

- Generate a work ID when submitting async work to the pool, add that ID
  to a set in the cancellation handler
- Check the existence of the ID in the cancel set when dequeuing the
  work

Result:

Queued NIOThreadPool tasks can be cancelled before they are run.
  • Loading branch information
glbrntt committed Mar 25, 2024
1 parent 10af7c9 commit 1f5be71
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 26 deletions.
86 changes: 64 additions & 22 deletions Sources/NIOPosix/NIOThreadPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//
//===----------------------------------------------------------------------===//

import Atomics
import DequeModule
import Dispatch
import NIOConcurrencyHelpers
Expand Down Expand Up @@ -58,13 +59,19 @@ public final class NIOThreadPool {

/// The work that should be done by the `NIOThreadPool`.
public typealias WorkItem = @Sendable (WorkItemState) -> Void

private struct IdentifiableWorkItem: Sendable {
var workItem: WorkItem
var id: Int?
}

private enum State {
/// The `NIOThreadPool` is already stopped.
case stopped
/// The `NIOThreadPool` is shutting down, the array has one boolean entry for each thread indicating if it has shut down already.
case shuttingDown([Bool])
/// The `NIOThreadPool` is up and running, the `CircularBuffer` containing the yet unprocessed `WorkItems`.
case running(Deque<WorkItem>)
case running(Deque<IdentifiableWorkItem>)
/// Temporary state used when mutating the .running(items). Used to avoid CoW copies.
/// It should never be "leaked" outside of the lock block.
case modifying
Expand All @@ -73,6 +80,26 @@ public final class NIOThreadPool {
private let lock = NIOLock()
private var threads: [NIOThread]? = nil // protected by `lock`
private var state: State = .stopped

// WorkItems don't have a handle so they can't be cancelled directly. Instead an ID is assigned
// to each cancellable work item and the IDs of each work item to cancel is stored in this set.
// The set is checked when dequeuing work items prior to running them, the presence of an ID
// indicates it should be cancelled. This approach makes cancellation cheap, but slow, as the
// task isn't cancelled until it's dequeued.
//
// Possible alternatives:
// - Removing items from the work queue on cancellation. This is linear and runs the risk of
// being expensive if a task tree with many enqueued work items is cancelled.
// - Storing an atomic 'is cancelled' flag with each work item. This adds an allocation per
// work item.
//
// If a future version of this thread pool has work items which do have a handle this set should
// be removed.
//
// Note: protected by 'lock'.
private var cancelledWorkIDs: Set<Int> = []
private let nextWorkID = ManagedAtomic(0)

public let numberOfThreads: Int
private let canBeStopped: Bool

Expand All @@ -99,7 +126,7 @@ public final class NIOThreadPool {
case .running(let items):
self.state = .modifying
queue.async {
items.forEach { $0(.cancelled) }
items.forEach { $0.workItem(.cancelled) }
}
self.state = .shuttingDown(Array(repeating: true, count: numberOfThreads))
(0..<numberOfThreads).forEach { _ in
Expand Down Expand Up @@ -133,15 +160,15 @@ public final class NIOThreadPool {
/// - body: The `WorkItem` to process by the `NIOThreadPool`.
@preconcurrency
public func submit(_ body: @escaping WorkItem) {
self._submit(body)
self._submit(id: nil, body)
}

private func _submit(_ body: @escaping WorkItem) {
private func _submit(id: Int?, _ body: @escaping WorkItem) {
let item = self.lock.withLock { () -> WorkItem? in
switch self.state {
case .running(var items):
self.state = .modifying
items.append(body)
items.append(.init(workItem: body, id: id))
self.state = .running(items)
self.semaphore.signal()
return nil
Expand Down Expand Up @@ -178,19 +205,28 @@ public final class NIOThreadPool {
}

private func process(identifier: Int) {
var item: WorkItem? = nil
var itemAndState: (item: WorkItem, state: WorkItemState)? = nil

repeat {
/* wait until work has become available */
item = nil // ensure previous work item is not retained for duration of semaphore wait
itemAndState = nil // ensure previous work item is not retained for duration of semaphore wait
self.semaphore.wait()

item = self.lock.withLock { () -> (WorkItem)? in
itemAndState = self.lock.withLock { () -> (WorkItem, WorkItemState)? in
switch self.state {
case .running(var items):
self.state = .modifying
let item = items.removeFirst()
let itemAndID = items.removeFirst()

let state: WorkItemState
if let id = itemAndID.id, !self.cancelledWorkIDs.isEmpty {
state = self.cancelledWorkIDs.remove(id) == nil ? .active : .cancelled
} else {
state = .active
}

self.state = .running(items)
return item
return (itemAndID.workItem, state)
case .shuttingDown(var aliveStates):
assert(aliveStates[identifier])
aliveStates[identifier] = false
Expand All @@ -203,8 +239,8 @@ public final class NIOThreadPool {
}
}
/* if there was a work item popped, run it */
item.map { $0(.active) }
} while item != nil
itemAndState.map { item, state in item(state) }
} while itemAndState != nil
}

/// Start the `NIOThreadPool` if not already started.
Expand Down Expand Up @@ -315,18 +351,24 @@ extension NIOThreadPool {
/// - returns: result of the passed closure.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func runIfActive<T: Sendable>(_ body: @escaping @Sendable () throws -> T) async throws -> T {
try await withCheckedThrowingContinuation { (cont: CheckedContinuation<T, Error>) in
self.submit { shouldRun in
guard case shouldRun = NIOThreadPool.WorkItemState.active else {
cont.resume(throwing: NIOThreadPoolError.ThreadPoolInactive())
return
}
do {
try cont.resume(returning: body())
} catch {
cont.resume(throwing: error)
let workID = self.nextWorkID.loadThenWrappingIncrement(ordering: .relaxed)

return try await withTaskCancellationHandler {
try await withCheckedThrowingContinuation { (cont: CheckedContinuation<T, Error>) in
self._submit(id: workID) { shouldRun in
switch shouldRun {
case .active:
let result = Result(catching: body)
cont.resume(with: result)
case .cancelled:
cont.resume(throwing: CancellationError())
}
}
}
} onCancel: {
self.lock.withLockVoid {
self.cancelledWorkIDs.insert(workID)
}
}
}
}
Expand Down
25 changes: 24 additions & 1 deletion Tests/NIOPosixTests/NIOThreadPoolTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,34 @@ class NIOThreadPoolTest: XCTestCase {
}
XCTFail("Should not get here as thread pool isn't active")
} catch {
XCTAssertNotNil(error as? NIOThreadPoolError.ThreadPoolInactive, "Error thrown should be of type ThreadPoolError")
XCTAssertNotNil(error as? CancellationError, "Error thrown should be of type CancellationError")
}
try await pool.shutdownGracefully()
}

func testAsyncThreadPoolCancellation() async throws {
let pool = NIOThreadPool(numberOfThreads: 1)
pool.start()

await withThrowingTaskGroup(of: Void.self) { group in
group.cancelAll()
group.addTask {
try await pool.runIfActive {
XCTFail("Should be cancelled before executed")
}
}

do {
try await group.waitForAll()
XCTFail("Expected CancellationError to be thrown")
} catch {
XCTAssert(error is CancellationError)
}
}

try await pool.shutdownGracefully()
}

func testAsyncShutdownWorks() async throws {
let threadPool = NIOThreadPool(numberOfThreads: 17)
let eventLoop = NIOAsyncTestingEventLoop()
Expand Down
6 changes: 3 additions & 3 deletions Tests/NIOPosixTests/NonBlockingFileIOTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,7 @@ extension NonBlockingFileIOTest {
}
}

func testAsyncGettingErrorWhenEventLoopGroupIsShutdown() async throws {
func testAsyncGettingErrorWhenThreadPoolIsShutdown() async throws {
try await self.threadPool.shutdownGracefully()

try await withPipe { readFH, writeFH in
Expand All @@ -1130,9 +1130,9 @@ extension NonBlockingFileIOTest {
fileHandle: readFH,
byteCount: 1,
allocator: self.allocator)
XCTFail("testAsyncGettingErrorWhenEventLoopGroupIsShutdown: fileIO.read should throw an error")
XCTFail("testAsyncGettingErrorWhenThreadPoolIsShutdown: fileIO.read should throw an error")
} catch {
XCTAssertTrue(error is NIOThreadPoolError.ThreadPoolInactive)
XCTAssertTrue(error is CancellationError)
}
return [readFH, writeFH]
}
Expand Down

0 comments on commit 1f5be71

Please sign in to comment.