Skip to content

Commit

Permalink
Use TaskGroup to ensure all primary / worker tasks are cancelled on…
Browse files Browse the repository at this point in the history
… error and panic (#707)
  • Loading branch information
mwtian committed Aug 13, 2022
1 parent e153cd8 commit 810919c
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 89 deletions.
8 changes: 4 additions & 4 deletions narwhal/executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl Executor {
rx_consensus: metered_channel::Receiver<ConsensusOutput>,
tx_output: Sender<ExecutorOutput<State>>,
registry: &Registry,
) -> SubscriberResult<Vec<JoinHandle<()>>>
) -> SubscriberResult<Vec<(&'static str, JoinHandle<()>)>>
where
State: ExecutionState + Send + Sync + 'static,
State::Outcome: Send + 'static,
Expand Down Expand Up @@ -190,9 +190,9 @@ impl Executor {
// Return the handle.
info!("Consensus subscriber successfully started");
Ok(vec![
subscriber_handle,
executor_handle,
batch_loader_handle,
("executor_subscriber", subscriber_handle),
("executor", executor_handle),
("executor_batch_loader", batch_loader_handle),
])
}
}
Expand Down
1 change: 1 addition & 0 deletions narwhal/node/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.15", features = ["time", "env-filter"] }
url = "2.2.2"
axum = "0.5.13"
task-group = "0.2.2"

config = { path = "../config" }
consensus = { path = "../consensus" }
Expand Down
41 changes: 26 additions & 15 deletions narwhal/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use store::{
rocks::{open_cf, DBMap},
Store,
};
use task_group::{TaskGroup, TaskManager};
use tokio::{
sync::{mpsc::Sender, watch},
task::JoinHandle,
task::{JoinError, JoinHandle},
};
use tracing::debug;
use types::{
Expand Down Expand Up @@ -123,7 +124,7 @@ impl Node {
tx_confirmation: Sender<ExecutorOutput<State>>,
// A prometheus exporter Registry to use for the metrics
registry: &Registry,
) -> SubscriberResult<Vec<JoinHandle<()>>>
) -> SubscriberResult<TaskManager<JoinError>>
where
State: ExecutionState + Send + Sync + 'static,
State::Outcome: Send + 'static,
Expand Down Expand Up @@ -159,7 +160,7 @@ impl Node {
let consensus_metrics = Arc::new(ConsensusMetrics::new(registry));
let (handle, dag) = Dag::new(&committee.load(), rx_new_certificates, consensus_metrics);

handles.push(handle);
handles.push(("dag", handle));

(Some(Arc::new(dag)), NetworkModel::Asynchronous)
} else {
Expand Down Expand Up @@ -231,7 +232,13 @@ impl Node {
});
}

Ok(handles)
let (task_group, task_manager) = TaskGroup::new();
for (name, handle) in handles {
// The tasks will be awaited with the `task_manager`, so the task handles / futures can be dropped.
let _ = task_group.spawn(name, handle);
}

Ok(task_manager)
}

/// Spawn the consensus core and the client executing transactions.
Expand All @@ -249,7 +256,7 @@ impl Node {
SerializedTransaction,
)>,
registry: &Registry,
) -> SubscriberResult<Vec<JoinHandle<()>>>
) -> SubscriberResult<Vec<(&'static str, JoinHandle<()>)>>
where
PublicKey: VerifyingKey,
State: ExecutionState + Send + Sync + 'static,
Expand All @@ -262,13 +269,15 @@ impl Node {
let (tx_sequence, rx_sequence) =
metered_channel::channel(Self::CHANNEL_CAPACITY, &channel_metrics.tx_sequence);

let mut handles = Vec::new();

// Spawn the consensus core who only sequences transactions.
let ordering_engine = Bullshark::new(
(**committee.load()).clone(),
store.consensus_store.clone(),
parameters.gc_depth,
);
let consensus_handles = Consensus::spawn(
let consensus_handle = Consensus::spawn(
(**committee.load()).clone(),
store.consensus_store.clone(),
store.certificate_store.clone(),
Expand All @@ -280,6 +289,7 @@ impl Node {
consensus_metrics.clone(),
parameters.gc_depth,
);
handles.push(("consensus", consensus_handle));

// Spawn the client executing the transactions. It can also synchronize with the
// subscriber handler if it missed some transactions.
Expand All @@ -294,11 +304,9 @@ impl Node {
registry,
)
.await?;
handles.extend(executor_handles);

Ok(executor_handles
.into_iter()
.chain(std::iter::once(consensus_handles))
.collect())
Ok(handles)
}

/// Spawn a specified number of workers.
Expand All @@ -315,11 +323,10 @@ impl Node {
parameters: Parameters,
// The prometheus metrics Registry
registry: &Registry,
) -> Vec<JoinHandle<()>> {
let mut handles = Vec::new();

) -> TaskManager<JoinError> {
let metrics = initialise_metrics(registry);

let (task_group, task_manager) = TaskGroup::new();
for id in ids {
let worker_handles = Worker::spawn(
name.clone(),
Expand All @@ -329,8 +336,12 @@ impl Node {
store.batch_store.clone(),
metrics.clone(),
);
handles.extend(worker_handles);
// TODO(narwhal/727): propagate worker task names.
for (i, h) in worker_handles.into_iter().enumerate() {
let _ = task_group.spawn(format!("worker_{}_{}", id, i), h);
}
}
handles

task_manager
}
}
15 changes: 8 additions & 7 deletions narwhal/node/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ use clap::{crate_name, crate_version, App, AppSettings, ArgMatches, SubCommand};
use config::{Committee, Import, Parameters, WorkerId};
use crypto::{generate_production_keypair, traits::KeyPair as _, KeyPair};
use executor::{SerializedTransaction, SubscriberResult};
use eyre::Context;
use futures::future::join_all;
use eyre::{eyre, Context};
use node::{
execution_state::SimpleExecutionState,
metrics::{primary_metrics_registry, start_prometheus_server, worker_metrics_registry},
Expand Down Expand Up @@ -156,7 +155,7 @@ async fn run(matches: &ArgMatches<'_>) -> Result<(), eyre::Report> {
let registry;

// Check whether to run a primary, a worker, or an entire authority.
let node_handles = match matches.subcommand() {
let task_manager = match matches.subcommand() {
// Spawn the primary and consensus core.
("primary", Some(sub_matches)) => {
registry = primary_metrics_registry(keypair.public().clone());
Expand Down Expand Up @@ -209,10 +208,12 @@ async fn run(matches: &ArgMatches<'_>) -> Result<(), eyre::Report> {
analyze(rx_transaction_confirmation).await;

// Await on the completion handles of all the nodes we have launched
join_all(node_handles).await;

// If this expression is reached, the program ends and all other tasks terminate.
Ok(())
return task_manager.await.map_err(|err| match err {
task_group::RuntimeError::Panic { name: n, panic: p } => eyre!("{} paniced: {:?}", n, p),
task_group::RuntimeError::Application { name: n, error: e } => {
eyre!("{} error: {:?}", n, e)
}
});
}

/// Receives an ordered list of certificates and apply any application-specific logic.
Expand Down
12 changes: 6 additions & 6 deletions narwhal/node/src/restarter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl NodeRestarter {
let mut name = keypair.public().clone();
let mut committee = committee.clone();

let mut handles = Vec::new();
let mut task_managers = Vec::new();
let mut primary_network = WorkerToPrimaryNetwork::default();
let mut worker_network = PrimaryToWorkerNetwork::default();

Expand All @@ -49,7 +49,7 @@ impl NodeRestarter {
let store = NodeStorage::reopen(store_path);

// Restart the relevant components.
let primary_handles = Node::spawn_primary(
let primary = Node::spawn_primary(
keypair,
Arc::new(ArcSwap::new(Arc::new(committee.clone()))),
&store,
Expand All @@ -62,7 +62,7 @@ impl NodeRestarter {
.await
.unwrap();

let worker_handles = Node::spawn_workers(
let workers = Node::spawn_workers(
name.clone(),
/* worker_ids */ vec![0],
Arc::new(ArcSwap::new(Arc::new(committee.clone()))),
Expand All @@ -71,8 +71,8 @@ impl NodeRestarter {
registry,
);

handles.extend(primary_handles);
handles.extend(worker_handles);
task_managers.push(primary);
task_managers.push(workers);

// Wait for a committee change.
let (new_keypair, new_committee) = match rx_reconfigure.recv().await {
Expand Down Expand Up @@ -111,7 +111,7 @@ impl NodeRestarter {
worker_network.cleanup(committee.network_diff(&new_committee));

// Wait for the components to shut down.
join_all(handles.drain(..)).await;
join_all(task_managers.drain(..)).await;
tracing::debug!("All tasks successfully exited");

// Give it an extra second in case the last task to exit is a network server. The OS
Expand Down
2 changes: 1 addition & 1 deletion narwhal/node/tests/reconfigure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ async fn epoch_change() {
}
});

let _primary_handles = Node::spawn_primary(
let _primary = Node::spawn_primary(
keypair,
Arc::new(ArcSwap::new(Arc::new(committee.clone()))),
&store,
Expand Down
28 changes: 14 additions & 14 deletions narwhal/primary/src/primary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl Primary {
tx_reconfigure: watch::Sender<ReconfigureNotification>,
tx_committed_certificates: Sender<Certificate>,
registry: &Registry,
) -> Vec<JoinHandle<()>> {
) -> Vec<(&str, JoinHandle<()>)> {
// Write the parameters to the logs.
parameters.tracing();

Expand Down Expand Up @@ -440,22 +440,22 @@ impl Primary {
);

let mut handles = vec![
primary_receiver_handle,
worker_receiver_handle,
core_handle,
payload_receiver_handle,
block_synchronizer_handle,
block_waiter_handle,
block_remover_handle,
header_waiter_handle,
certificate_waiter_handle,
proposer_handle,
helper_handle,
state_handler_handle,
("primary_receiver", primary_receiver_handle),
("worker_receiver", worker_receiver_handle),
("core", core_handle),
("payload_receiver", payload_receiver_handle),
("block_synchronizer", block_synchronizer_handle),
("block_waiter", block_waiter_handle),
("block_remover", block_remover_handle),
("header_waiter", header_waiter_handle),
("certificate_waiter", certificate_waiter_handle),
("proposer", proposer_handle),
("helper", helper_handle),
("state_handler", state_handler_handle),
];

if let Some(h) = consensus_api_handle {
handles.push(h);
handles.push(("consensus_api", h));
}

handles
Expand Down
12 changes: 6 additions & 6 deletions narwhal/primary/tests/epoch_change.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ async fn test_restart_with_new_committee_change() {
let (tx_reconfigure, _rx_reconfigure) = watch::channel(initial_committee);

let store = NodeStorage::reopen(temp_dir());

let registry = Registry::new();
let primary_handles = Primary::spawn(
name,
signer,
Expand All @@ -322,9 +322,9 @@ async fn test_restart_with_new_committee_change() {
NetworkModel::Asynchronous,
tx_reconfigure,
/* tx_committed_certificates */ tx_feedback,
&Registry::new(),
&registry,
);
handles.extend(primary_handles);
handles.extend(primary_handles.into_iter().map(|(_n, j)| j));
}

// Run for a while in epoch 0.
Expand Down Expand Up @@ -380,7 +380,7 @@ async fn test_restart_with_new_committee_change() {
let (tx_reconfigure, _rx_reconfigure) = watch::channel(initial_committee);

let store = NodeStorage::reopen(temp_dir());

let registry = Registry::new();
let primary_handles = Primary::spawn(
name,
signer,
Expand All @@ -395,9 +395,9 @@ async fn test_restart_with_new_committee_change() {
NetworkModel::Asynchronous,
tx_reconfigure,
/* tx_committed_certificates */ tx_feedback,
&Registry::new(),
&registry,
);
handles.extend(primary_handles);
handles.extend(primary_handles.into_iter().map(|(_n, j)| j));
}

// Run for a while.
Expand Down
1 change: 1 addition & 0 deletions narwhal/test_utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ tokio = { version = "1.20.1", features = ["sync", "rt", "macros"] }
tokio-util = { version = "0.7.3", features = ["codec"] }
tonic = "0.7.2"
tracing = "0.1.36"
task-group = "0.2.2"

config = { path = "../config" }
crypto = { path = "../crypto", features = ["copy_key"] }
Expand Down
Loading

0 comments on commit 810919c

Please sign in to comment.