Skip to content

Commit

Permalink
Add .finished() to NestedSubsystem, add sequential shutdown example
Browse files Browse the repository at this point in the history
  • Loading branch information
Finomnis committed Feb 7, 2024
1 parent 20f5671 commit d7b3287
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 5 deletions.
124 changes: 124 additions & 0 deletions examples/19_sequential_shutdown.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
//! This example demonstrates how multiple subsystems could be shut down sequentially.
//!
//! When a shutdown gets triggered (via Ctrl+C), Nested1 will shutdown first,
//! followed by Nested2 and Nested3. Only once the previous subsystem is finished shutting down,
//! the next subsystem will follow.

use miette::Result;
use tokio::time::{sleep, Duration};
use tokio_graceful_shutdown::{
FutureExt, SubsystemBuilder, SubsystemFinishedFuture, SubsystemHandle, Toplevel,
};

async fn counter(id: &str) {
let mut i = 0;
loop {
tracing::info!("{id}: {i}");
i += 1;
sleep(Duration::from_millis(50)).await;
}
}

async fn nested1(subsys: SubsystemHandle) -> Result<()> {
tracing::info!("Nested1 started.");
if counter("Nested1").cancel_on_shutdown(&subsys).await.is_ok() {
tracing::info!("Nested1 counter finished.");
} else {
tracing::info!("Nested1 shutting down ...");
sleep(Duration::from_millis(200)).await;
}
subsys.on_shutdown_requested().await;
tracing::info!("Nested1 stopped.");
Ok(())
}

async fn nested2(subsys: SubsystemHandle, nested1_finished: SubsystemFinishedFuture) -> Result<()> {
// Create a future that triggers once nested1 is finished **and** a shutdown is requested
let shutdown = {
let cancellation_token = subsys.create_cancellation_token();
async move {
tokio::join!(cancellation_token.cancelled(), nested1_finished);
}
};

tracing::info!("Nested2 started.");
tokio::select! {
_ = shutdown => {
tracing::info!("Nested2 shutting down ...");
sleep(Duration::from_millis(200)).await;
}
_ = counter("Nested2") => {
tracing::info!("Nested2 counter finished.");
}
}

tracing::info!("Nested2 stopped.");
Ok(())
}

async fn nested3(subsys: SubsystemHandle, nested2_finished: SubsystemFinishedFuture) -> Result<()> {
// Create a future that triggers once nested2 is finished **and** a shutdown is requested
let shutdown = {
let cancellation_token = subsys.create_cancellation_token();
async move {
tokio::join!(cancellation_token.cancelled(), nested2_finished);
}
};

tracing::info!("Nested3 started.");
tokio::select! {
_ = shutdown => {
tracing::info!("Nested3 shutting down ...");
sleep(Duration::from_millis(200)).await;
}
_ = counter("Nested3") => {
tracing::info!("Nested3 counter finished.");
}
}

tracing::info!("Nested3 stopped.");
Ok(())
}

async fn root(subsys: SubsystemHandle) -> Result<()> {
// This subsystem shuts down the nested subsystem after 5 seconds.
tracing::info!("Root started.");

tracing::info!("Starting nested subsystems ...");
let nested1 = subsys.start(SubsystemBuilder::new("Nested1", nested1));
let nested1_finished = nested1.finished();
let nested2 = subsys.start(SubsystemBuilder::new("Nested2", |s| {
nested2(s, nested1_finished)
}));
let nested2_finished = nested2.finished();
subsys.start(SubsystemBuilder::new("Nested3", |s| {
nested3(s, nested2_finished)
}));
tracing::info!("Nested subsystems started.");

// Wait for all children to finish shutting down.
subsys.wait_for_children().await;

tracing::info!("All children finished, stopping root ...");
sleep(Duration::from_millis(200)).await;
tracing::info!("Root stopped.");

Ok(())
}

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<()> {
// Init logging
tracing_subscriber::fmt()
.with_max_level(tracing::Level::TRACE)
.init();

// Setup and execute subsystem tree
Toplevel::new(|s| async move {
s.start(SubsystemBuilder::new("Root", root));
})
.catch_signals()
.handle_shutdown_requests(Duration::from_millis(1000))
.await
.map_err(Into::into)
}
2 changes: 1 addition & 1 deletion src/future_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pin_project_lite::pin_project;
use tokio_util::sync::WaitForCancellationFuture;

pin_project! {
/// A Future that is resolved once the corresponding task is finished
/// A future that is resolved once the corresponding task is finished
/// or a shutdown is initiated.
#[must_use = "futures do nothing unless polled"]
pub struct CancelOnShutdownFuture<'a, T: std::future::Future>{
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,6 @@ pub use future_ext::FutureExt;
pub use into_subsystem::IntoSubsystem;
pub use subsystem::NestedSubsystem;
pub use subsystem::SubsystemBuilder;
pub use subsystem::SubsystemFinishedFuture;
pub use subsystem::SubsystemHandle;
pub use toplevel::Toplevel;
13 changes: 12 additions & 1 deletion src/subsystem/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
mod error_collector;
mod nested_subsystem;
mod subsystem_builder;
mod subsystem_finished_future;
mod subsystem_handle;

use std::sync::{Arc, Mutex};
use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
};

pub use subsystem_builder::SubsystemBuilder;
pub use subsystem_handle::SubsystemHandle;
Expand Down Expand Up @@ -35,3 +40,9 @@ pub(crate) struct ErrorActions {
pub(crate) on_failure: Atomic<ErrorAction>,
pub(crate) on_panic: Atomic<ErrorAction>,
}

/// A future that is resolved once the corresponding subsystem is finished.
#[must_use = "futures do nothing unless polled"]
pub struct SubsystemFinishedFuture {
future: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
}
14 changes: 11 additions & 3 deletions src/subsystem/nested_subsystem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::atomic::Ordering;

use crate::{errors::SubsystemJoinError, ErrTypeTraits, ErrorAction};

use super::NestedSubsystem;
use super::{NestedSubsystem, SubsystemFinishedFuture};

impl<ErrType: ErrTypeTraits> NestedSubsystem<ErrType> {
/// Wait for the subsystem to be finished.
Expand Down Expand Up @@ -68,7 +68,7 @@ impl<ErrType: ErrTypeTraits> NestedSubsystem<ErrType> {
/// Changes the way this subsystem should react to failures,
/// meaning if it or one of its children returns an `Err` value.
///
/// For more information, see [ErrorAction].
/// For more information, see [`ErrorAction`].
pub fn change_failure_action(&self, action: ErrorAction) {
self.error_actions
.on_failure
Expand All @@ -78,8 +78,16 @@ impl<ErrType: ErrTypeTraits> NestedSubsystem<ErrType> {
/// Changes the way this subsystem should react if it or one
/// of its children panic.
///
/// For more information, see [ErrorAction].
/// For more information, see [`ErrorAction`].
pub fn change_panic_action(&self, action: ErrorAction) {
self.error_actions.on_panic.store(action, Ordering::Relaxed);
}

/// Returns a future that resolves once the subsystem is finished.
///
/// Similar to [`join`](NestedSubsystem::join), but more light-weight
/// as does not perform any error handling.
pub fn finished(&self) -> SubsystemFinishedFuture {
SubsystemFinishedFuture::new(self.joiner.clone())
}
}
25 changes: 25 additions & 0 deletions src/subsystem/subsystem_finished_future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};

use crate::utils::JoinerTokenRef;

use super::SubsystemFinishedFuture;

impl SubsystemFinishedFuture {
pub(crate) fn new(joiner: JoinerTokenRef) -> Self {
Self {
future: Box::pin(async move { joiner.join().await }),
}
}
}

impl Future for SubsystemFinishedFuture {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
self.future.as_mut().poll(cx)
}
}
1 change: 1 addition & 0 deletions src/utils/joiner_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub(crate) struct JoinerToken<ErrType: ErrTypeTraits> {

/// A reference version that does not keep the content alive; purely for
/// joining the subtree.
#[derive(Clone)]
pub(crate) struct JoinerTokenRef {
counter: watch::Receiver<(bool, u32)>,
}
Expand Down

0 comments on commit d7b3287

Please sign in to comment.