diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index 32ea358cf5..c179930670 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -41,9 +41,8 @@ jobs: name: Enable Rust Caching with: shared-key: "build-and-test" - prefix-key: ${{ matrix.just_variants }} + prefix-key: ${{ matrix.just_variants }}-${{ github.ref }} cache-on-failure: "true" - save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install Just run: | @@ -119,9 +118,8 @@ jobs: name: Enable Rust Caching with: shared-key: "build-and-test" - prefix-key: ${{ matrix.just_variants }} + prefix-key: ${{ matrix.just_variants }}-${{ github.ref }} cache-on-failure: "true" - save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install Just run: | @@ -153,9 +151,8 @@ jobs: name: Enable Rust Caching with: shared-key: "build-and-test" - prefix-key: ${{ matrix.just_variants }} + prefix-key: ${{ matrix.just_variants }}-${{ github.ref }} cache-on-failure: "true" - save-if: ${{ github.ref == 'refs/heads/main' }} - name: Install Just run: | @@ -197,9 +194,8 @@ jobs: name: Enable Rust Caching with: shared-key: "build-and-test" - prefix-key: ${{ matrix.just_variants }} + prefix-key: ${{ matrix.just_variants }}-${{ github.ref }} cache-on-failure: "true" - save-if: ${{ github.ref == 'refs/heads/main' }} - name: Build examples in release mode run: just ${{ matrix.just_variants }} build_release --examples --package hotshot-examples --no-default-features diff --git a/Cargo.lock b/Cargo.lock index 7a8e48cdc3..c5f4df1dea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3296,9 +3296,11 @@ dependencies = [ name = "hotshot-task" version = "0.5.43" dependencies = [ + "anyhow", "async-broadcast", "async-compatibility-layer", "async-std", + "async-trait", "futures", "tokio", "tracing", @@ -3342,6 +3344,7 @@ dependencies = [ name = "hotshot-testing" version = "0.5.43" dependencies = [ + "anyhow", "async-broadcast", "async-compatibility-layer", "async-lock 2.8.0", diff --git a/crates/hotshot/src/lib.rs b/crates/hotshot/src/lib.rs index 43fee91d06..49ed7030cf 100644 --- a/crates/hotshot/src/lib.rs +++ b/crates/hotshot/src/lib.rs @@ -26,14 +26,16 @@ use async_lock::RwLock; use async_trait::async_trait; use committable::Committable; use futures::join; -use hotshot_task::task::TaskRegistry; +use hotshot_task::task::{ConsensusTaskRegistry, NetworkTaskRegistry}; use hotshot_task_impls::{events::HotShotEvent, helpers::broadcast_event, network}; // Internal /// Reexport error type pub use hotshot_types::error::HotShotError; use hotshot_types::{ consensus::{Consensus, ConsensusMetricsValue, View, ViewInner}, - constants::{BASE_VERSION, EVENT_CHANNEL_SIZE, EXTERNAL_EVENT_CHANNEL_SIZE, STATIC_VER_0_1}, + constants::{ + Version01, BASE_VERSION, EVENT_CHANNEL_SIZE, EXTERNAL_EVENT_CHANNEL_SIZE, STATIC_VER_0_1, + }, data::Leaf, event::{EventType, LeafInfo}, message::{DataMessage, Message, MessageKind}, @@ -53,22 +55,12 @@ use hotshot_types::{ // External /// Reexport rand crate pub use rand; -use tasks::{add_request_network_task, add_response_task, add_vid_task}; +use tasks::{add_request_network_task, add_response_task}; use tracing::{debug, instrument, trace}; use vbs::version::Version; -#[cfg(not(feature = "dependency-tasks"))] -use crate::tasks::add_consensus_task; -#[cfg(feature = "dependency-tasks")] -use crate::tasks::{ - add_consensus2_task, add_quorum_proposal_recv_task, add_quorum_proposal_task, - add_quorum_vote_task, -}; use crate::{ - tasks::{ - add_da_task, add_network_event_task, 
add_network_message_task, add_transaction_task, - add_upgrade_task, add_view_sync_task, - }, + tasks::{add_consensus_tasks, add_network_event_task, add_network_message_task}, traits::NodeImplementation, types::{Event, SystemContextHandle}, }; @@ -561,8 +553,8 @@ impl> SystemContext { /// For a list of which tasks are being spawned, see this module's documentation. #[allow(clippy::too_many_lines)] pub async fn run_tasks(&self) -> SystemContextHandle { - // ED Need to set first first number to 1, or properly trigger the change upon start - let registry = Arc::new(TaskRegistry::default()); + let consensus_registry = ConsensusTaskRegistry::new(); + let network_registry = NetworkTaskRegistry::new(); let output_event_stream = self.external_event_stream.clone(); let internal_event_stream = self.internal_event_stream.clone(); @@ -574,171 +566,60 @@ impl> SystemContext { let vid_membership = self.memberships.vid_membership.clone(); let view_sync_membership = self.memberships.view_sync_membership.clone(); - let (event_tx, event_rx) = internal_event_stream.clone(); - - let handle = SystemContextHandle { - registry: Arc::clone(®istry), + let mut handle = SystemContextHandle { + consensus_registry, + network_registry, output_event_stream: output_event_stream.clone(), internal_event_stream: internal_event_stream.clone(), hotshot: self.clone().into(), storage: Arc::clone(&self.storage), }; - add_network_message_task( - Arc::clone(®istry), - event_tx.clone(), - Arc::clone(&quorum_network), - ) - .await; - add_network_message_task( - Arc::clone(®istry), - event_tx.clone(), - Arc::clone(&da_network), - ) - .await; + add_network_message_task(&mut handle, Arc::clone(&quorum_network)).await; + add_network_message_task(&mut handle, Arc::clone(&da_network)).await; - if let Some(request_rx) = da_network.spawn_request_receiver_task(STATIC_VER_0_1).await { - add_response_task( - Arc::clone(®istry), - event_rx.activate_cloned(), - request_rx, - &handle, - ) - .await; - add_request_network_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; + if let Some(request_receiver) = da_network.spawn_request_receiver_task(STATIC_VER_0_1).await + { + add_response_task(&mut handle, request_receiver).await; + add_request_network_task(&mut handle).await; } add_network_event_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), + &mut handle, Arc::clone(&quorum_network), quorum_membership.clone(), network::quorum_filter, - Arc::clone(&handle.storage()), ) .await; add_network_event_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), + &mut handle, Arc::clone(&quorum_network), quorum_membership, network::upgrade_filter, - Arc::clone(&handle.storage()), ) .await; add_network_event_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), + &mut handle, Arc::clone(&da_network), da_membership, network::da_filter, - Arc::clone(&handle.storage()), ) .await; add_network_event_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), + &mut handle, Arc::clone(&quorum_network), view_sync_membership, network::view_sync_filter, - Arc::clone(&handle.storage()), ) .await; add_network_event_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), + &mut handle, Arc::clone(&quorum_network), vid_membership, network::vid_filter, - Arc::clone(&handle.storage()), - ) - .await; - #[cfg(not(feature = "dependency-tasks"))] - add_consensus_task( - Arc::clone(®istry), - event_tx.clone(), - 
event_rx.activate_cloned(), - &handle, - ) - .await; - add_da_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - add_vid_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - add_transaction_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - add_view_sync_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - add_upgrade_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - #[cfg(feature = "dependency-tasks")] - add_quorum_proposal_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - #[cfg(feature = "dependency-tasks")] - add_quorum_vote_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - #[cfg(feature = "dependency-tasks")] - add_quorum_proposal_recv_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, - ) - .await; - #[cfg(feature = "dependency-tasks")] - add_consensus2_task( - Arc::clone(®istry), - event_tx.clone(), - event_rx.activate_cloned(), - &handle, ) .await; + add_consensus_tasks::(&mut handle).await; handle } } diff --git a/crates/hotshot/src/tasks/mod.rs b/crates/hotshot/src/tasks/mod.rs index 71cf00a5ec..7e50849795 100644 --- a/crates/hotshot/src/tasks/mod.rs +++ b/crates/hotshot/src/tasks/mod.rs @@ -5,19 +5,13 @@ pub mod task_state; use std::{sync::Arc, time::Duration}; -use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::{async_sleep, async_spawn}; -use async_lock::RwLock; -use hotshot_task::task::{Task, TaskRegistry}; +use hotshot_task::task::Task; use hotshot_task_impls::{ consensus::ConsensusTaskState, - consensus2::Consensus2TaskState, da::DaTaskState, events::HotShotEvent, network::{NetworkEventTaskState, NetworkMessageTaskState}, - quorum_proposal::QuorumProposalTaskState, - quorum_proposal_recv::QuorumProposalRecvTaskState, - quorum_vote::QuorumVoteTaskState, request::NetworkRequestState, response::{run_response_task, NetworkResponseState, RequestReceiver}, transactions::TransactionTaskState, @@ -31,10 +25,9 @@ use hotshot_types::{ traits::{ network::ConnectedNetwork, node_implementation::{ConsensusTime, NodeImplementation, NodeType}, - storage::Storage, }, }; -use tracing::error; +use vbs::version::StaticVersionType; use crate::{tasks::task_state::CreateTaskState, types::SystemContextHandle, ConsensusApi}; @@ -49,57 +42,59 @@ pub enum GlobalEvent { /// Add tasks for network requests and responses pub async fn add_request_network_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, + handle: &mut SystemContextHandle, ) { let state = NetworkRequestState::::create_from(handle).await; - let task = Task::new(tx, rx, Arc::clone(&task_reg), state); - task_reg.run_task(task).await; + let task = Task::new( + state, + handle.internal_event_stream.0.clone(), + handle.internal_event_stream.1.activate_cloned(), + ); + handle.consensus_registry.run_task(task); } /// Add a task which responds to requests on the network. 
pub async fn add_response_task>( - task_reg: Arc, - hs_rx: Receiver>>, - rx: RequestReceiver, - handle: &SystemContextHandle, + handle: &mut SystemContextHandle, + request_receiver: RequestReceiver, ) { let state = NetworkResponseState::::new( handle.hotshot.consensus(), - rx, + request_receiver, handle.hotshot.memberships.quorum_membership.clone().into(), handle.public_key().clone(), handle.private_key().clone(), ); - task_reg - .register(run_response_task::(state, hs_rx)) - .await; + handle + .network_registry + .register(run_response_task::( + state, + handle.internal_event_stream.1.activate_cloned(), + )); } /// Add the network task to handle messages and publish events. pub async fn add_network_message_task< TYPES: NodeType, + I: NodeImplementation, NET: ConnectedNetwork, TYPES::SignatureKey>, >( - task_reg: Arc, - event_stream: Sender>>, + handle: &mut SystemContextHandle, channel: Arc, ) { let net = Arc::clone(&channel); let network_state: NetworkMessageTaskState<_> = NetworkMessageTaskState { - event_stream: event_stream.clone(), + event_stream: handle.internal_event_stream.0.clone(), }; let network = Arc::clone(&net); let mut state = network_state.clone(); - let handle = async_spawn(async move { + let task_handle = async_spawn(async move { loop { let msgs = match network.recv_msgs().await { Ok(msgs) => Messages(msgs), Err(err) => { - error!("failed to receive messages: {err}"); + tracing::error!("failed to receive messages: {err}"); // return zero messages so we sleep and try again Messages(vec![]) @@ -113,21 +108,18 @@ pub async fn add_network_message_task< } } }); - task_reg.register(handle).await; + handle.network_registry.register(task_handle); } /// Add the network task to handle events and send messages. pub async fn add_network_event_task< TYPES: NodeType, + I: NodeImplementation, NET: ConnectedNetwork, TYPES::SignatureKey>, - S: Storage + 'static, >( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, + handle: &mut SystemContextHandle, channel: Arc, membership: TYPES::Membership, filter: fn(&Arc>) -> bool, - storage: Arc>, ) { let network_state: NetworkEventTaskState<_, _, _> = NetworkEventTaskState { channel, @@ -135,138 +127,38 @@ pub async fn add_network_event_task< version: VERSION_0_1, membership, filter, - storage, + storage: Arc::clone(&handle.storage()), }; - let task = Task::new(tx, rx, Arc::clone(&task_reg), network_state); - task_reg.run_task(task).await; -} - -/// add the consensus task -pub async fn add_consensus_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let consensus_state = ConsensusTaskState::create_from(handle).await; - - let task = Task::new(tx, rx, Arc::clone(&task_reg), consensus_state); - task_reg.run_task(task).await; -} - -/// add the VID task -pub async fn add_vid_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let vid_state = VidTaskState::create_from(handle).await; - let task = Task::new(tx, rx, Arc::clone(&task_reg), vid_state); - task_reg.run_task(task).await; -} - -/// add the Upgrade task. 
-pub async fn add_upgrade_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let upgrade_state = UpgradeTaskState::create_from(handle).await; - - let task = Task::new(tx, rx, Arc::clone(&task_reg), upgrade_state); - task_reg.run_task(task).await; -} -/// add the Data Availability task -pub async fn add_da_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - // build the da task - let da_state = DaTaskState::create_from(handle).await; - - let task = Task::new(tx, rx, Arc::clone(&task_reg), da_state); - task_reg.run_task(task).await; -} - -/// add the Transaction Handling task -pub async fn add_transaction_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let transactions_state = TransactionTaskState::<_, _, _, Version01>::create_from(handle).await; - - let task = Task::new(tx, rx, Arc::clone(&task_reg), transactions_state); - task_reg.run_task(task).await; -} - -/// add the view sync task -pub async fn add_view_sync_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let view_sync_state = ViewSyncTaskState::create_from(handle).await; - - let task = Task::new(tx, rx, Arc::clone(&task_reg), view_sync_state); - task_reg.run_task(task).await; -} - -/// add the quorum proposal task -pub async fn add_quorum_proposal_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let quorum_proposal_task_state = QuorumProposalTaskState::create_from(handle).await; - let task = Task::new(tx, rx, Arc::clone(&task_reg), quorum_proposal_task_state); - task_reg.run_task(task).await; -} - -/// Add the quorum vote task. -pub async fn add_quorum_vote_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let quorum_vote_task_state = QuorumVoteTaskState::create_from(handle).await; - let task = Task::new(tx, rx, Arc::clone(&task_reg), quorum_vote_task_state); - task_reg.run_task(task).await; -} - -/// Add the quorum proposal recv task. -pub async fn add_quorum_proposal_recv_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, -) { - let quorum_proposal_recv_task_state = QuorumProposalRecvTaskState::create_from(handle).await; let task = Task::new( - tx, - rx, - Arc::clone(&task_reg), - quorum_proposal_recv_task_state, + network_state, + handle.internal_event_stream.0.clone(), + handle.internal_event_stream.1.activate_cloned(), ); - task_reg.run_task(task).await; + handle.consensus_registry.run_task(task); } -/// Add the Consensus2 task. -pub async fn add_consensus2_task>( - task_reg: Arc, - tx: Sender>>, - rx: Receiver>>, - handle: &SystemContextHandle, +/// Adds consensus-related tasks to a `SystemContextHandle`. 
+pub async fn add_consensus_tasks< + TYPES: NodeType, + I: NodeImplementation, + VERSION: StaticVersionType + 'static, +>( + handle: &mut SystemContextHandle, ) { - let consensus2_task_state = Consensus2TaskState::create_from(handle).await; - let task = Task::new(tx, rx, Arc::clone(&task_reg), consensus2_task_state); - task_reg.run_task(task).await; + handle.add_task(ViewSyncTaskState::::create_from(handle).await); + handle.add_task(VidTaskState::::create_from(handle).await); + handle.add_task(DaTaskState::::create_from(handle).await); + handle.add_task(TransactionTaskState::::create_from(handle).await); + handle.add_task(UpgradeTaskState::::create_from(handle).await); + { + #![cfg(not(feature = "dependency-tasks"))] + handle.add_task(ConsensusTaskState::::create_from(handle).await); + } + { + #![cfg(feature = "dependency-tasks")] + handle.add_task(QuorumProposalTaskState::::create_from(handle).await); + handle.add_task(QuorumVoteTaskState::::create_from(handle).await); + handle.add_task(QuorumProposalRecvTaskState::::create_from(handle).await); + handle.add_task(Consensus2TaskState::::create_from(handle).await); + } } diff --git a/crates/hotshot/src/tasks/task_state.rs b/crates/hotshot/src/tasks/task_state.rs index a7017699bc..a07887d297 100644 --- a/crates/hotshot/src/tasks/task_state.rs +++ b/crates/hotshot/src/tasks/task_state.rs @@ -50,19 +50,18 @@ impl, V: StaticVersionType> Create _phantom: PhantomData, id: handle.hotshot.id, shutdown_flag: Arc::new(AtomicBool::new(false)), + spawned_tasks: BTreeMap::new(), } } } #[async_trait] impl> CreateTaskState - for UpgradeTaskState> + for UpgradeTaskState { - async fn create_from( - handle: &SystemContextHandle, - ) -> UpgradeTaskState> { + async fn create_from(handle: &SystemContextHandle) -> UpgradeTaskState { UpgradeTaskState { - api: handle.clone(), + output_event_stream: handle.hotshot.external_event_stream.0.clone(), cur_view: handle.cur_view().await, quorum_membership: handle.hotshot.memberships.quorum_membership.clone().into(), quorum_network: Arc::clone(&handle.hotshot.networks.quorum_network), @@ -98,14 +97,12 @@ impl> CreateTaskState #[async_trait] impl> CreateTaskState - for DaTaskState> + for DaTaskState { - async fn create_from( - handle: &SystemContextHandle, - ) -> DaTaskState> { + async fn create_from(handle: &SystemContextHandle) -> DaTaskState { DaTaskState { - api: handle.clone(), consensus: handle.hotshot.consensus(), + output_event_stream: handle.hotshot.external_event_stream.0.clone(), da_membership: handle.hotshot.memberships.da_membership.clone().into(), da_network: Arc::clone(&handle.hotshot.networks.da_network), quorum_membership: handle.hotshot.memberships.quorum_membership.clone().into(), @@ -121,11 +118,9 @@ impl> CreateTaskState #[async_trait] impl> CreateTaskState - for ViewSyncTaskState> + for ViewSyncTaskState { - async fn create_from( - handle: &SystemContextHandle, - ) -> ViewSyncTaskState> { + async fn create_from(handle: &SystemContextHandle) -> ViewSyncTaskState { let cur_view = handle.cur_view().await; ViewSyncTaskState { current_view: cur_view, @@ -139,7 +134,6 @@ impl> CreateTaskState .into(), public_key: handle.public_key().clone(), private_key: handle.private_key().clone(), - api: handle.clone(), num_timeouts_tracked: 0, replica_task_map: HashMap::default().into(), pre_commit_relay_map: HashMap::default().into(), @@ -154,14 +148,14 @@ impl> CreateTaskState #[async_trait] impl, Ver: StaticVersionType> - CreateTaskState - for TransactionTaskState, Ver> + CreateTaskState for TransactionTaskState { 
async fn create_from( handle: &SystemContextHandle, - ) -> TransactionTaskState, Ver> { + ) -> TransactionTaskState { TransactionTaskState { - api: handle.clone(), + builder_timeout: handle.builder_timeout(), + output_event_stream: handle.hotshot.external_event_stream.0.clone(), consensus: handle.hotshot.consensus(), cur_view: handle.cur_view().await, network: Arc::clone(&handle.hotshot.networks.quorum_network), diff --git a/crates/hotshot/src/traits/networking/push_cdn_network.rs b/crates/hotshot/src/traits/networking/push_cdn_network.rs index 0c424c83b9..11e020d224 100644 --- a/crates/hotshot/src/traits/networking/push_cdn_network.rs +++ b/crates/hotshot/src/traits/networking/push_cdn_network.rs @@ -4,9 +4,9 @@ use std::{collections::BTreeSet, marker::PhantomData}; #[cfg(feature = "hotshot-testing")] use std::{path::Path, sync::Arc, time::Duration}; +use async_compatibility_layer::channel::UnboundedSendError; #[cfg(feature = "hotshot-testing")] -use async_compatibility_layer::art::async_spawn; -use async_compatibility_layer::{art::async_sleep, channel::UnboundedSendError}; +use async_compatibility_layer::{art::async_sleep, art::async_spawn}; use async_trait::async_trait; use bincode::config::Options; use cdn_broker::reexports::{ diff --git a/crates/hotshot/src/types/handle.rs b/crates/hotshot/src/types/handle.rs index 38cd3d8a5c..e092238282 100644 --- a/crates/hotshot/src/types/handle.rs +++ b/crates/hotshot/src/types/handle.rs @@ -8,15 +8,13 @@ use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; use futures::Stream; -use hotshot_task::task::TaskRegistry; +use hotshot_task::task::{ConsensusTaskRegistry, NetworkTaskRegistry, Task, TaskState}; use hotshot_task_impls::{events::HotShotEvent, helpers::broadcast_event}; use hotshot_types::{ - boxed_sync, consensus::Consensus, data::Leaf, error::HotShotError, traits::{election::Membership, node_implementation::NodeType}, - BoxSyncFuture, }; #[cfg(async_executor_impl = "tokio")] use tokio::task::JoinHandle; @@ -28,7 +26,6 @@ use crate::{traits::NodeImplementation, types::Event, SystemContext}; /// This type provides the means to message and interact with a background [`SystemContext`] instance, /// allowing the ability to receive [`Event`]s from it, send transactions to it, and interact with /// the underlying storage. -#[derive(Clone)] pub struct SystemContextHandle> { /// The [sender](Sender) and [receiver](Receiver), /// to allow the application to communicate with HotShot. @@ -40,8 +37,11 @@ pub struct SystemContextHandle> { Sender>>, InactiveReceiver>>, ), - /// registry for controlling tasks - pub(crate) registry: Arc, + /// registry for controlling consensus tasks + pub(crate) consensus_registry: ConsensusTaskRegistry>, + + /// registry for controlling network tasks + pub(crate) network_registry: NetworkTaskRegistry, /// Internal reference to the underlying [`SystemContext`] pub hotshot: Arc>, @@ -51,6 +51,17 @@ pub struct SystemContextHandle> { } impl + 'static> SystemContextHandle { + /// Adds a hotshot consensus-related task to the `SystemContextHandle`. 
+ pub fn add_task> + 'static>(&mut self, task_state: S) { + let task = Task::new( + task_state, + self.internal_event_stream.0.clone(), + self.internal_event_stream.1.activate_cloned(), + ); + + self.consensus_registry.run_task(task); + } + /// obtains a stream to expose to the user pub fn event_stream(&self) -> impl Stream> { self.output_event_stream.1.activate_cloned() } @@ -140,25 +151,24 @@ impl + 'static> SystemContextHandl } /// Shut down the inner hotshot and wait until all background threads are closed. - // pub async fn shut_down(mut self) { - // self.registry.shutdown_all().await - pub fn shut_down<'a, 'b>(&'a mut self) -> BoxSyncFuture<'b, ()> - where - 'a: 'b, - Self: 'b, - { - boxed_sync(async move { - self.hotshot.networks.shut_down_networks().await; - // this is required because `SystemContextHandle` holds an inactive receiver and - // `broadcast_direct` below can wait indefinitely - self.internal_event_stream.0.set_await_active(false); - let _ = self - .internal_event_stream - .0 - .broadcast_direct(Arc::new(HotShotEvent::Shutdown)) - .await; - self.registry.shutdown().await; - }) + pub async fn shut_down(&mut self) { + // this is required because `SystemContextHandle` holds an inactive receiver and + // `broadcast_direct` below can wait indefinitely + self.internal_event_stream.0.set_await_active(false); + let _ = self + .internal_event_stream + .0 + .broadcast_direct(Arc::new(HotShotEvent::Shutdown)) + .await + .inspect_err(|err| tracing::error!("Failed to send shutdown event: {err}")); + tracing::error!("Shutting down network tasks!"); + self.network_registry.shutdown().await; + + tracing::error!("Shutting down networks!"); + self.hotshot.networks.shut_down_networks().await; + + tracing::error!("Shutting down consensus!"); + self.consensus_registry.shutdown().await; } /// return the timeout for a view of the underlying `SystemContext` diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index 06b9c0dcae..b866a7ed02 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -294,11 +294,6 @@ pub fn test_scripts(input: proc_macro::TokenStream) -> TokenStream { .map(|i| format_ident!("{}_output_index", quote::quote!(#i).to_string())) .collect(); - let task_names: Vec<_> = scripts - .iter() - .map(|i| format_ident!("{}_task", quote::quote!(#i).to_string())) - .collect(); - let task_expectations: Vec<_> = scripts .iter() .map(|i| format_ident!("{}_expectations", quote::quote!(#i).to_string())) .collect(); @@ -316,25 +311,20 @@ pub fn test_scripts(input: proc_macro::TokenStream) -> TokenStream { validate_task_state_or_panic_in_script, }; - use hotshot_testing::{predicates::Predicate, script::RECV_TIMEOUT}; + use hotshot_testing::{predicates::Predicate}; use async_broadcast::broadcast; use hotshot_task_impls::events::HotShotEvent; use async_compatibility_layer::art::async_timeout; - use hotshot_task::task::{Task, TaskRegistry, TaskState}; + use hotshot_task::task::{Task, TaskState}; use hotshot_types::traits::node_implementation::NodeType; use std::sync::Arc; - let registry = Arc::new(TaskRegistry::default()); - - let (test_input, task_receiver) = broadcast(1024); - // let (task_input, mut test_receiver) = broadcast(1024); + async { - let task_input = test_input.clone(); - let mut test_receiver = task_receiver.clone(); + let (to_task, mut from_test) = broadcast(1024); + let (to_test, mut from_task) = broadcast(1024); - let mut loop_receiver = task_receiver.clone(); - - #(let mut #task_names = Task::new(task_input.clone(), task_receiver.clone(), registry.clone(),
#scripts.state);)* + let mut loop_receiver = from_task.clone(); #(let mut #task_expectations = #scripts.expectations;)* @@ -346,20 +336,28 @@ pub fn test_scripts(input: proc_macro::TokenStream) -> TokenStream { for input in &input_group { #( - if !#task_names.state().filter(&input.clone().into()) { tracing::debug!("Test sent: {:?}", input); - if let Some(res) = #task_names.handle_event(input.clone().into()).await { - #task_names.state().handle_result(&res).await; - } + to_task + .broadcast(input.clone().into()) + .await + .expect("Failed to broadcast input message"); + + + let _ = #scripts.state + .handle_event(input.clone().into(), &to_test, &from_test) + .await + .inspect_err(|e| tracing::info!("{e}")); - while let Ok(Ok(received_output)) = async_timeout(Duration::from_millis(35), test_receiver.recv_direct()).await { + while from_test.try_recv().is_ok() {} + + while let Ok(Ok(received_output)) = async_timeout(#scripts.timeout, from_task.recv_direct()).await { tracing::debug!("Test received: {:?}", received_output); let output_asserts = &mut #task_expectations[stage_number].output_asserts; if #output_index_names >= output_asserts.len() { - panic_extra_output_in_script(stage_number, #script_names.to_string(), &received_output); + panic_extra_output_in_script(stage_number, #script_names.to_string(), &received_output); }; let assert = &mut output_asserts[#output_index_names]; @@ -368,26 +366,32 @@ pub fn test_scripts(input: proc_macro::TokenStream) -> TokenStream { #output_index_names += 1; } - } )* } while let Ok(input) = loop_receiver.try_recv() { #( - if !#task_names.state().filter(&input) { tracing::debug!("Test sent: {:?}", input); - if let Some(res) = #task_names.handle_event(input.clone()).await { - #task_names.state().handle_result(&res).await; - } + to_task + .broadcast(input.clone().into()) + .await + .expect("Failed to broadcast input message"); + + let _ = #scripts.state + .handle_event(input.clone().into(), &to_test, &from_test) + .await + .inspect_err(|e| tracing::info!("{e}")); - while let Ok(Ok(received_output)) = async_timeout(RECV_TIMEOUT, test_receiver.recv_direct()).await { + while from_test.try_recv().is_ok() {} + + while let Ok(Ok(received_output)) = async_timeout(#scripts.timeout, from_task.recv_direct()).await { tracing::debug!("Test received: {:?}", received_output); let output_asserts = &mut #task_expectations[stage_number].output_asserts; if #output_index_names >= output_asserts.len() { - panic_extra_output_in_script(stage_number, #script_names.to_string(), &received_output); + panic_extra_output_in_script(stage_number, #script_names.to_string(), &received_output); }; let mut assert = &mut output_asserts[#output_index_names]; @@ -396,7 +400,6 @@ pub fn test_scripts(input: proc_macro::TokenStream) -> TokenStream { #output_index_names += 1; } - } )* } @@ -410,11 +413,13 @@ pub fn test_scripts(input: proc_macro::TokenStream) -> TokenStream { let task_state_asserts = &mut #task_expectations[stage_number].task_state_asserts; for assert in task_state_asserts { - validate_task_state_or_panic_in_script(stage_number, #script_names.to_string(), #task_names.state(), &**assert).await; + validate_task_state_or_panic_in_script(stage_number, #script_names.to_string(), &#scripts.state, &**assert).await; } )* } } + } + }; expanded.into() diff --git a/crates/task-impls/src/consensus/mod.rs b/crates/task-impls/src/consensus/mod.rs index 92695a59d3..77f5728260 100644 --- a/crates/task-impls/src/consensus/mod.rs +++ b/crates/task-impls/src/consensus/mod.rs @@ -2,14 +2,15 @@ use 
std::{collections::BTreeMap, sync::Arc}; #[cfg(not(feature = "dependency-tasks"))] use anyhow::Result; -use async_broadcast::Sender; +use async_broadcast::{Receiver, Sender}; #[cfg(not(feature = "dependency-tasks"))] use async_compatibility_layer::art::async_spawn; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; +use async_trait::async_trait; use futures::future::join_all; -use hotshot_task::task::{Task, TaskState}; +use hotshot_task::task::TaskState; #[cfg(not(feature = "dependency-tasks"))] use hotshot_types::data::VidDisperseShare; #[cfg(not(feature = "dependency-tasks"))] @@ -347,7 +348,7 @@ impl> ConsensusTaskState let result = collector .as_mut() .unwrap() - .handle_event(Arc::clone(&event), &event_stream) + .handle_vote_event(Arc::clone(&event), &event_stream) .await; if result == Some(HotShotTaskCompleted) { @@ -386,7 +387,7 @@ impl> ConsensusTaskState let result = collector .as_mut() .unwrap() - .handle_event(Arc::clone(&event), &event_stream) + .handle_vote_event(Arc::clone(&event), &event_stream) .await; if result == Some(HotShotTaskCompleted) { @@ -729,39 +730,33 @@ impl> ConsensusTaskState } } +#[async_trait] impl> TaskState for ConsensusTaskState { - type Event = Arc>; - type Output = (); - fn filter(&self, event: &Arc>) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::QuorumProposalRecv(_, _) - | HotShotEvent::QuorumVoteRecv(_) - | HotShotEvent::QuorumProposalValidated(..) - | HotShotEvent::QcFormed(_) - | HotShotEvent::UpgradeCertificateFormed(_) - | HotShotEvent::DaCertificateRecv(_) - | HotShotEvent::ViewChange(_) - | HotShotEvent::SendPayloadCommitmentAndMetadata(..) - | HotShotEvent::Timeout(_) - | HotShotEvent::TimeoutVoteRecv(_) - | HotShotEvent::VidShareRecv(..) - | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) - | HotShotEvent::QuorumVoteSend(_) - | HotShotEvent::QuorumProposalSend(_, _) - | HotShotEvent::Shutdown, - ) - } - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> - where - Self: Sized, - { - let sender = task.clone_sender(); - tracing::trace!("sender queue len {}", sender.len()); - task.state_mut().handle(event, sender).await; - None + type Event = HotShotEvent; + + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; + + Ok(()) } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + + async fn cancel_subtasks(&mut self) { + while !self.spawned_tasks.is_empty() { + let Some((_, handles)) = self.spawned_tasks.pop_first() else { + break; + }; + + for handle in handles { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); + } + } } } diff --git a/crates/task-impls/src/consensus2/handlers.rs b/crates/task-impls/src/consensus2/handlers.rs index f210af96db..ee8c8db214 100644 --- a/crates/task-impls/src/consensus2/handlers.rs +++ b/crates/task-impls/src/consensus2/handlers.rs @@ -58,7 +58,7 @@ pub(crate) async fn handle_quorum_vote_recv> Consensus2TaskState> TaskState for Consensus2TaskState { - type Event = Arc>; - type Output = (); - - fn filter(&self, event: &Arc>) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::QuorumVoteRecv(_) - | HotShotEvent::TimeoutVoteRecv(_) - | HotShotEvent::ViewChange(_) - | HotShotEvent::Timeout(_) - | HotShotEvent::LastDecidedViewUpdated(_) - | HotShotEvent::Shutdown - ) - } - async fn handle_event(event: 
Self::Event, task: &mut Task) -> Option<()> - where - Self: Sized, - { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await; - None - } + type Event = HotShotEvent; - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; + + Ok(()) } + + /// No-op; this task spawns no subtasks. + async fn cancel_subtasks(&mut self) {} } diff --git a/crates/task-impls/src/da.rs b/crates/task-impls/src/da.rs index ff2b238173..1c4e596406 100644 --- a/crates/task-impls/src/da.rs +++ b/crates/task-impls/src/da.rs @@ -1,11 +1,13 @@ use std::{marker::PhantomData, sync::Arc}; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::async_spawn; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::spawn_blocking; -use hotshot_task::task::{Task, TaskState}; +use async_trait::async_trait; +use hotshot_task::task::TaskState; use hotshot_types::{ consensus::{Consensus, View}, data::DaProposal, @@ -15,7 +17,6 @@ use hotshot_types::{ simple_vote::{DaData, DaVote}, traits::{ block_contents::vid_commitment, - consensus_api::ConsensusApi, election::Membership, network::ConnectedNetwork, node_implementation::{ConsensusTime, NodeImplementation, NodeType}, @@ -42,13 +43,9 @@ use crate::{ type VoteCollectorOption = Option>; /// Tracks state of a DA task -pub struct DaTaskState< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static, -> { - /// The state's api - pub api: A, +pub struct DaTaskState> { + /// Output events to application + pub output_event_stream: async_broadcast::Sender>, /// View number this view is executing in.
pub cur_view: TYPES::Time, @@ -83,9 +80,7 @@ pub struct DaTaskState< pub storage: Arc>, } -impl, A: ConsensusApi + 'static> - DaTaskState -{ +impl> DaTaskState { /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "DA Main Task", level = "error")] pub async fn handle( @@ -152,15 +147,17 @@ impl, A: ConsensusApi + return None; } // Proposal is fresh and valid, notify the application layer - self.api - .send_event(Event { + broadcast_event( + Event { view_number: self.cur_view, event: EventType::DaProposal { proposal: proposal.clone(), sender: sender.clone(), }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; if !self.da_membership.has_stake(&self.public_key) { debug!( @@ -262,7 +259,7 @@ impl, A: ConsensusApi + let result = collector .as_mut() .unwrap() - .handle_event(Arc::clone(&event), &event_stream) + .handle_vote_event(Arc::clone(&event), &event_stream) .await; if result == Some(HotShotTaskCompleted) { @@ -332,43 +329,27 @@ impl, A: ConsensusApi + error!("Shutting down because of shutdown signal!"); return Some(HotShotTaskCompleted); } - _ => { - error!("unexpected event {:?}", event); - } + _ => {} } None } } +#[async_trait] /// task state implementation for DA Task -impl, A: ConsensusApi + 'static> TaskState - for DaTaskState -{ - type Event = Arc>; - - type Output = HotShotTaskCompleted; - - fn filter(&self, event: &Arc>) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::DaProposalRecv(_, _) - | HotShotEvent::DaVoteRecv(_) - | HotShotEvent::Shutdown - | HotShotEvent::BlockRecv(_, _, _, _, _) - | HotShotEvent::ViewChange(_) - | HotShotEvent::DaProposalValidated(_, _) - ) - } +impl> TaskState for DaTaskState { + type Event = HotShotEvent; async fn handle_event( - event: Self::Event, - task: &mut Task, - ) -> Option { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await - } + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + Ok(()) } + + async fn cancel_subtasks(&mut self) {} } diff --git a/crates/task-impls/src/events.rs b/crates/task-impls/src/events.rs index e3a116f00b..e1b4127ed0 100644 --- a/crates/task-impls/src/events.rs +++ b/crates/task-impls/src/events.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use either::Either; +use hotshot_task::task::TaskEvent; use hotshot_types::{ data::{DaProposal, Leaf, QuorumProposal, UpgradeProposal, VidDisperse, VidDisperseShare}, message::Proposal, @@ -21,6 +22,12 @@ use vbs::version::Version; use crate::view_sync::ViewSyncPhase; +impl TaskEvent for HotShotEvent { + fn shutdown_event() -> Self { + HotShotEvent::Shutdown + } +} + /// Marker that the task completed #[derive(Eq, PartialEq, Debug, Clone)] pub struct HotShotTaskCompleted; diff --git a/crates/task-impls/src/harness.rs b/crates/task-impls/src/harness.rs index 01a944543b..8c2f732587 100644 --- a/crates/task-impls/src/harness.rs +++ b/crates/task-impls/src/harness.rs @@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration}; use async_broadcast::broadcast; use async_compatibility_layer::art::async_timeout; -use hotshot_task::task::{Task, TaskRegistry, TaskState}; +use hotshot_task::task::{ConsensusTaskRegistry, Task, TaskState}; use hotshot_types::traits::node_implementation::NodeType; use crate::events::{HotShotEvent, HotShotTaskCompleted}; @@ -15,23 +15,6 @@ pub struct TestHarnessState { allow_extra_output: 
bool, } -impl TaskState for TestHarnessState { - type Event = Arc>; - type Output = HotShotTaskCompleted; - - async fn handle_event( - event: Self::Event, - task: &mut Task, - ) -> Option { - let extra = task.state_mut().allow_extra_output; - handle_event(event, task, extra) - } - - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) - } -} - /// Runs a test by building the task using `build_fn` and then passing it the `input` events /// and testing to make sure all of the `expected_output` events are seen /// /// # Panics /// Panics if any state the test expects is not set. Panicking causes a test failure #[allow(clippy::implicit_hasher)] #[allow(clippy::panic)] -pub async fn run_harness>> + Send + 'static>( +pub async fn run_harness> + Send + 'static>( input: Vec>, expected_output: Vec>, state: S, @@ -52,37 +35,35 @@ ) where TYPES: NodeType, { - let registry = Arc::new(TaskRegistry::default()); - let mut tasks = vec![]; + let mut registry = ConsensusTaskRegistry::new(); // set up two broadcast channels so the test sends to the task and the task back to the test let (to_task, from_test) = broadcast(1024); - let (to_test, from_task) = broadcast(1024); - let test_state = TestHarnessState { + let (to_test, mut from_task) = broadcast(1024); + let mut test_state = TestHarnessState { expected_output, allow_extra_output, }; - let test_task = Task::new( - to_test.clone(), - from_task.clone(), - Arc::clone(&registry), - test_state, - ); - let task = Task::new( - to_test.clone(), - from_test.clone(), - Arc::clone(&registry), - state, - ); + let task = Task::new(state, to_test.clone(), from_test.clone()); + + let handle = task.run(); + let test_future = async move { + loop { + if let Ok(event) = from_task.recv_direct().await { + if let Some(HotShotTaskCompleted) = check_event(event, &mut test_state) { + break; + } + } + } + }; - tasks.push(test_task.run()); - tasks.push(task.run()); + registry.register(handle); for event in input { to_task.broadcast_direct(Arc::new(event)).await.unwrap(); } - if async_timeout(Duration::from_secs(2), futures::future::join_all(tasks)) + if async_timeout(Duration::from_secs(2), test_future) .await .is_err() { @@ -100,16 +81,14 @@ /// # Panics /// Will panic to fail the test when it receives an unexpected event #[allow(clippy::needless_pass_by_value)] -pub fn handle_event( +fn check_event( event: Arc>, - task: &mut Task>, - allow_extra_output: bool, + state: &mut TestHarnessState, ) -> Option { - let state = task.state_mut(); // Check the output in either case: // * We allow outputs only in our expected output set. // * We haven't received all expected outputs yet.
- if !allow_extra_output || !state.expected_output.is_empty() { + if !state.allow_extra_output || !state.expected_output.is_empty() { assert!( state.expected_output.contains(&event), "Got an unexpected event: {event:?}", diff --git a/crates/task-impls/src/network.rs b/crates/task-impls/src/network.rs index 59a47ea19f..bcc5d59d9f 100644 --- a/crates/task-impls/src/network.rs +++ b/crates/task-impls/src/network.rs @@ -1,9 +1,11 @@ use std::{collections::HashMap, sync::Arc}; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::async_spawn; use async_lock::RwLock; -use hotshot_task::task::{Task, TaskState}; +use async_trait::async_trait; +use hotshot_task::task::TaskState; use hotshot_types::{ constants::{BASE_VERSION, STATIC_VER_0_1}, data::{VidDisperse, VidDisperseShare}, @@ -20,7 +22,7 @@ use hotshot_types::{ }, vote::{HasViewNumber, Vote}, }; -use tracing::{debug, error, info, instrument, warn}; +use tracing::{debug, error, instrument, warn}; use vbs::version::Version; use crate::{ @@ -79,27 +81,6 @@ pub struct NetworkMessageTaskState { pub event_stream: Sender>>, } -impl TaskState for NetworkMessageTaskState { - type Event = Vec>; - type Output = (); - - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> - where - Self: Sized, - { - task.state_mut().handle_messages(event).await; - None - } - - fn filter(&self, _event: &Self::Event) -> bool { - false - } - - fn should_shutdown(_event: &Self::Event) -> bool { - false - } -} - impl NetworkMessageTaskState { #[instrument(skip_all, name = "Network message task", level = "trace")] /// Handle the message. @@ -212,41 +193,31 @@ pub struct NetworkEventTaskState< pub storage: Arc>, } +#[async_trait] impl< TYPES: NodeType, COMMCHANNEL: ConnectedNetwork, TYPES::SignatureKey>, S: Storage + 'static, > TaskState for NetworkEventTaskState { - type Event = Arc>; - - type Output = HotShotTaskCompleted; + type Event = HotShotEvent; async fn handle_event( - event: Self::Event, - task: &mut Task, - ) -> Option { - let membership = task.state_mut().membership.clone(); - task.state_mut().handle_event(event, &membership).await - } + &mut self, + event: Arc, + _sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + let membership = self.membership.clone(); - fn should_shutdown(event: &Self::Event) -> bool { - if matches!(event.as_ref(), HotShotEvent::Shutdown) { - info!("Network Task received Shutdown event"); - return true; + if !(self.filter)(&event) { + self.handle(event, &membership).await; } - false - } - fn filter(&self, event: &Self::Event) -> bool { - (self.filter)(event) - && !matches!( - event.as_ref(), - HotShotEvent::VersionUpgrade(_) - | HotShotEvent::ViewChange(_) - | HotShotEvent::Shutdown - ) + Ok(()) } + + async fn cancel_subtasks(&mut self) {} } impl< @@ -260,7 +231,7 @@ impl< /// Returns the completion status. 
#[allow(clippy::too_many_lines)] // TODO https://github.com/EspressoSystems/HotShot/issues/1704 #[instrument(skip_all, fields(view = *self.view), name = "Network Task", level = "error")] - pub async fn handle_event( + pub async fn handle( &mut self, event: Arc>, membership: &TYPES::Membership, diff --git a/crates/task-impls/src/quorum_proposal/mod.rs b/crates/task-impls/src/quorum_proposal/mod.rs index aca27eb4f0..b4931c1082 100644 --- a/crates/task-impls/src/quorum_proposal/mod.rs +++ b/crates/task-impls/src/quorum_proposal/mod.rs @@ -1,14 +1,16 @@ use std::{collections::HashMap, sync::Arc}; +use anyhow::Result; use async_broadcast::{Receiver, Sender}; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; +use async_trait::async_trait; use either::Either; use hotshot_task::{ dependency::{AndDependency, EventDependency, OrDependency}, dependency_task::DependencyTask, - task::{Task, TaskState}, + task::TaskState, }; use hotshot_types::{ consensus::Consensus, @@ -534,36 +536,33 @@ impl> QuorumProposalTaskState> TaskState for QuorumProposalTaskState { - type Event = Arc>; - type Output = (); - fn filter(&self, event: &Self::Event) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::QuorumProposalValidated(..) - | HotShotEvent::QcFormed(_) - | HotShotEvent::SendPayloadCommitmentAndMetadata(..) - | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) - | HotShotEvent::QuorumProposalLivenessValidated(..) - | HotShotEvent::QuorumProposalSend(..) - | HotShotEvent::VidShareValidated(_) - | HotShotEvent::ValidatedStateUpdated(..) - | HotShotEvent::UpdateHighQc(_) - | HotShotEvent::Shutdown - ) - } - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> - where - Self: Sized, - { - let receiver = task.subscribe(); - let sender = task.clone_sender(); - task.state_mut().handle(event, receiver, sender).await; - None + type Event = HotShotEvent; + + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, receiver.clone(), sender.clone()).await; + + Ok(()) } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + + async fn cancel_subtasks(&mut self) { + for handle in self + .propose_dependencies + .drain() + .map(|(_view, handle)| handle) + { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); + } } } diff --git a/crates/task-impls/src/quorum_proposal_recv/mod.rs b/crates/task-impls/src/quorum_proposal_recv/mod.rs index efc3c08774..fde3d7e056 100644 --- a/crates/task-impls/src/quorum_proposal_recv/mod.rs +++ b/crates/task-impls/src/quorum_proposal_recv/mod.rs @@ -2,10 +2,12 @@ use std::{collections::BTreeMap, sync::Arc}; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; +use async_trait::async_trait; use futures::future::join_all; use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ @@ -190,28 +192,35 @@ impl> QuorumProposalRecvTaskState< } } +#[async_trait] impl> TaskState for QuorumProposalRecvTaskState { - type Event = Arc>; - type Output = (); - fn filter(&self, event: &Arc>) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::QuorumProposalRecv(..) 
| HotShotEvent::Shutdown - ) - } + type Event = HotShotEvent; + + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> - where - Self: Sized, - { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await; - None + Ok(()) } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + async fn cancel_subtasks(&mut self) { + while !self.spawned_tasks.is_empty() { + let Some((_, handles)) = self.spawned_tasks.pop_first() else { + break; + }; + + for handle in handles { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); + } + } } } diff --git a/crates/task-impls/src/quorum_vote.rs b/crates/task-impls/src/quorum_vote.rs index db60d70433..13557871f5 100644 --- a/crates/task-impls/src/quorum_vote.rs +++ b/crates/task-impls/src/quorum_vote.rs @@ -1,14 +1,16 @@ use std::{collections::HashMap, sync::Arc}; +use anyhow::Result; use async_broadcast::{Receiver, Sender}; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; +use async_trait::async_trait; use committable::Committable; use hotshot_task::{ dependency::{AndDependency, EventDependency, OrDependency}, dependency_task::{DependencyTask, HandleDepOutput}, - task::{Task, TaskState}, + task::TaskState, }; use hotshot_types::{ consensus::Consensus, @@ -536,32 +538,27 @@ impl> QuorumVoteTaskState> TaskState for QuorumVoteTaskState { - type Event = Arc>; - type Output = (); - fn filter(&self, event: &Arc>) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::DaCertificateRecv(_) - | HotShotEvent::VidShareRecv(..) - | HotShotEvent::QuorumVoteDependenciesValidated(_) - | HotShotEvent::VoteNow(..) - | HotShotEvent::QuorumProposalValidated(..) - | HotShotEvent::ValidatedStateUpdated(..) 
- | HotShotEvent::Shutdown, - ) - } - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> - where - Self: Sized, - { - let receiver = task.subscribe(); - let sender = task.clone_sender(); - tracing::trace!("sender queue len {}", sender.len()); - task.state_mut().handle(event, receiver, sender).await; - None + type Event = HotShotEvent; + + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, receiver.clone(), sender.clone()).await; + + Ok(()) } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + + async fn cancel_subtasks(&mut self) { + for handle in self.vote_dependencies.drain().map(|(_view, handle)| handle) { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); + } } } diff --git a/crates/task-impls/src/request.rs b/crates/task-impls/src/request.rs index ab16dc425e..43faddb5d4 100644 --- a/crates/task-impls/src/request.rs +++ b/crates/task-impls/src/request.rs @@ -1,4 +1,5 @@ use std::{ + collections::BTreeMap, marker::PhantomData, sync::{ atomic::{AtomicBool, Ordering}, @@ -7,9 +8,13 @@ use std::{ time::Duration, }; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::{async_sleep, async_spawn, async_timeout}; use async_lock::RwLock; +#[cfg(async_executor_impl = "async-std")] +use async_std::task::JoinHandle; +use async_trait::async_trait; use hotshot_task::task::TaskState; use hotshot_types::{ consensus::Consensus, @@ -24,13 +29,12 @@ use hotshot_types::{ }; use rand::{prelude::SliceRandom, thread_rng}; use sha2::{Digest, Sha256}; +#[cfg(async_executor_impl = "tokio")] +use tokio::task::JoinHandle; use tracing::{debug, error, info, instrument, warn}; use vbs::{version::StaticVersionType, BinarySerializer, Serializer}; -use crate::{ - events::{HotShotEvent, HotShotTaskCompleted}, - helpers::broadcast_event, -}; +use crate::{events::HotShotEvent, helpers::broadcast_event}; /// Amount of time to try for a request before timing out. 
const REQUEST_TIMEOUT: Duration = Duration::from_millis(500); @@ -42,7 +46,7 @@ pub struct NetworkRequestState< TYPES: NodeType, I: NodeImplementation, - Ver: StaticVersionType, + Ver: StaticVersionType + 'static, > { /// Network to send requests over pub network: Arc, @@ -67,64 +71,69 @@ pub id: u64, /// A flag indicating that `HotShotEvent::Shutdown` has been received pub shutdown_flag: Arc, + /// Requester tasks spawned so far, keyed by the view they request data for + pub spawned_tasks: BTreeMap>>, +} + +impl, Ver: StaticVersionType + 'static> Drop + for NetworkRequestState +{ + fn drop(&mut self) { + futures::executor::block_on(async move { self.cancel_subtasks().await }); + } } /// Alias for a signature type Signature = <::SignatureKey as SignatureKey>::PureAssembledSignatureType; +#[async_trait] impl, Ver: StaticVersionType + 'static> TaskState for NetworkRequestState { - type Event = Arc>; - - type Output = HotShotTaskCompleted; + type Event = HotShotEvent; async fn handle_event( - event: Self::Event, - task: &mut hotshot_task::task::Task, - ) -> Option { + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { match event.as_ref() { HotShotEvent::QuorumProposalValidated(proposal, _) => { - let state = task.state(); let prop_view = proposal.view_number(); - if prop_view >= state.view { - state - .spawn_requests(prop_view, task.clone_sender(), Ver::instance()) + if prop_view >= self.view { + self.spawn_requests(prop_view, sender.clone(), Ver::instance()) .await; } - None + Ok(()) } HotShotEvent::ViewChange(view) => { let view = *view; - if view > task.state().view { - task.state_mut().view = view; + if view > self.view { + self.view = view; } - None - } - HotShotEvent::Shutdown => { - task.state().set_shutdown_flag(); - Some(HotShotTaskCompleted) + Ok(()) } - _ => None, + _ => Ok(()), } } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) - } + async fn cancel_subtasks(&mut self) { + self.set_shutdown_flag(); - fn filter(&self, event: &Self::Event) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::Shutdown - | HotShotEvent::QuorumProposalValidated(..) - | HotShotEvent::ViewChange(_) - ) - } + while !self.spawned_tasks.is_empty() { + let Some((_, handles)) = self.spawned_tasks.pop_first() else { + break; + }; - async fn shutdown(&mut self) { - self.set_shutdown_flag(); + for handle in handles { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); + } + } } } @@ -133,7 +142,7 @@ impl, Ver: StaticVersionType + 'st { /// Spawns tasks for a given view to retrieve any data needed.
async fn spawn_requests( - &self, + &mut self, view: TYPES::Time, sender: Sender>>, bind_version: Ver, @@ -161,7 +170,7 @@ impl, Ver: StaticVersionType + 'st /// received will be sent over `sender` #[instrument(skip_all, fields(id = self.id, view = *self.view), name = "NetworkRequestState run_delay", level = "error")] fn run_delay( - &self, + &mut self, request: RequestKind, sender: Sender>>, view: TYPES::Time, @@ -193,11 +202,13 @@ impl, Ver: StaticVersionType + 'st return; }; debug!("Requesting data: {:?}", request); - async_spawn(requester.run::(request, signature)); + let handle = async_spawn(requester.run::(request, signature)); + + self.spawned_tasks.entry(view).or_default().push(handle); } /// Signals delayed requesters to finish - fn set_shutdown_flag(&self) { + pub fn set_shutdown_flag(&self) { self.shutdown_flag.store(true, Ordering::Relaxed); } } @@ -280,6 +291,7 @@ impl> DelayedRequester { } Ok(Err(e)) => { warn!("Error Sending request. Error: {:?}", e); + async_sleep(REQUEST_TIMEOUT).await; } Err(_) => { warn!("Request to other node timed out"); } diff --git a/crates/task-impls/src/transactions.rs b/crates/task-impls/src/transactions.rs index b74f4771ec..7a24915cb4 100644 --- a/crates/task-impls/src/transactions.rs +++ b/crates/task-impls/src/transactions.rs @@ -3,14 +3,15 @@ use std::{ time::{Duration, Instant}, }; -use anyhow::{bail, Context}; -use async_broadcast::Sender; +use anyhow::{bail, Context, Result}; +use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::async_sleep; use async_lock::RwLock; +use async_trait::async_trait; use hotshot_builder_api::block_info::{ AvailableBlockData, AvailableBlockHeaderInput, AvailableBlockInfo, }; -use hotshot_task::task::{Task, TaskState}; +use hotshot_task::task::TaskState; use hotshot_types::{ consensus::Consensus, data::{null_block, Leaf}, @@ -18,7 +19,6 @@ use hotshot_types::{ simple_certificate::UpgradeCertificate, traits::{ block_contents::{precompute_vid_commitment, BuilderFee, EncodeBytes}, - consensus_api::ConsensusApi, election::Membership, node_implementation::{ConsensusTime, NodeImplementation, NodeType}, signature_key::{BuilderSignatureKey, SignatureKey}, @@ -48,15 +48,18 @@ pub struct BuilderResponses { /// It contains the final block information pub block_header: AvailableBlockHeaderInput, } + /// Tracks state of a Transaction task pub struct TransactionTaskState< TYPES: NodeType, I: NodeImplementation, - A: ConsensusApi + 'static, Ver: StaticVersionType, > { - /// The state's api - pub api: A, + /// Timeout for waiting for a block from the builder + pub builder_timeout: Duration, + + /// Output events to application + pub output_event_stream: async_broadcast::Sender>, /// View number this view is executing in.
pub cur_view: TYPES::Time, @@ -85,12 +88,8 @@ pub struct TransactionTaskState< pub decided_upgrade_certificate: Option>, } -impl< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static, - Ver: StaticVersionType, - > TransactionTaskState +impl, Ver: StaticVersionType> + TransactionTaskState { /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Transaction task", level = "error")] @@ -101,14 +100,17 @@ impl< ) -> Option { match event.as_ref() { HotShotEvent::TransactionsRecv(transactions) => { - self.api - .send_event(Event { + broadcast_event( + Event { view_number: self.cur_view, event: EventType::Transactions { transactions: transactions.clone(), }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; + return None; } HotShotEvent::UpgradeDecided(cert) => { @@ -270,10 +272,9 @@ impl< } }; - while task_start_time.elapsed() < self.api.builder_timeout() { + while task_start_time.elapsed() < self.builder_timeout { match async_compatibility_layer::art::async_timeout( - self.api - .builder_timeout() + self.builder_timeout .saturating_sub(task_start_time.elapsed()), self.block_from_builder(parent_comm, view_num, &parent_comm_sig), ) @@ -401,37 +402,23 @@ impl< } } +#[async_trait] /// task state implementation for Transactions Task -impl< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static, - Ver: StaticVersionType + 'static, - > TaskState for TransactionTaskState +impl, Ver: StaticVersionType + 'static> TaskState + for TransactionTaskState { - type Event = Arc>; - - type Output = HotShotTaskCompleted; - - fn filter(&self, event: &Arc>) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::TransactionsRecv(_) - | HotShotEvent::Shutdown - | HotShotEvent::ViewChange(_) - | HotShotEvent::UpgradeDecided(_) - ) - } + type Event = HotShotEvent; async fn handle_event( - event: Self::Event, - task: &mut Task, - ) -> Option { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await - } + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + Ok(()) } + + async fn cancel_subtasks(&mut self) {} } diff --git a/crates/task-impls/src/upgrade.rs b/crates/task-impls/src/upgrade.rs index c72e9d842d..656b6dc0c7 100644 --- a/crates/task-impls/src/upgrade.rs +++ b/crates/task-impls/src/upgrade.rs @@ -1,14 +1,15 @@ use std::sync::Arc; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_lock::RwLock; +use async_trait::async_trait; use hotshot_task::task::TaskState; use hotshot_types::{ event::{Event, EventType}, simple_certificate::UpgradeCertificate, simple_vote::{UpgradeProposalData, UpgradeVote}, traits::{ - consensus_api::ConsensusApi, election::Membership, node_implementation::{ConsensusTime, NodeImplementation, NodeType}, signature_key::SignatureKey, @@ -29,13 +30,10 @@ use crate::{ type VoteCollectorOption = Option>; /// Tracks state of a DA task -pub struct UpgradeTaskState< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static, -> { - /// The state's api - pub api: A, +pub struct UpgradeTaskState> { + /// Output events to application + pub output_event_stream: async_broadcast::Sender>, + /// View number this view is executing in. 
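// [Editor's sketch: the builder polling loop in `transactions.rs` above gives
// each retry only the time remaining in the overall `builder_timeout` budget
// via `saturating_sub`. The shape of that pattern, assuming tokio and a
// placeholder `try_fetch`:
//
//     use std::time::{Duration, Instant};
//
//     async fn try_fetch() -> Result<String, ()> {
//         Err(()) // placeholder for a real builder request
//     }
//
//     async fn fetch_within(budget: Duration) -> Option<String> {
//         let start = Instant::now();
//         while start.elapsed() < budget {
//             let remaining = budget.saturating_sub(start.elapsed());
//             match tokio::time::timeout(remaining, try_fetch()).await {
//                 Ok(Ok(block)) => return Some(block), // got a block in time
//                 Ok(Err(_)) => {
//                     // transient failure: back off briefly, then retry
//                     tokio::time::sleep(Duration::from_millis(100)).await;
//                 }
//                 Err(_) => break, // the whole budget elapsed
//             }
//         }
//         None
//     }
// ]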
pub cur_view: TYPES::Time, @@ -61,9 +59,7 @@ pub struct UpgradeTaskState< pub id: u64, } -impl, A: ConsensusApi + 'static> - UpgradeTaskState -{ +impl> UpgradeTaskState { /// main task event handler #[instrument(skip_all, fields(id = self.id, view = *self.cur_view), name = "Upgrade Task", level = "error")] pub async fn handle( @@ -123,15 +119,17 @@ impl, A: ConsensusApi + // * the proposal was expected, // * the proposal is valid, and // so we notify the application layer - self.api - .send_event(Event { + broadcast_event( + Event { view_number: self.cur_view, event: EventType::UpgradeProposal { proposal: proposal.clone(), sender: sender.clone(), }, - }) - .await; + }, + &self.output_event_stream, + ) + .await; // If everything is fine up to here, we generate and send a vote on the proposal. let Ok(vote) = UpgradeVote::create_signed_vote( @@ -182,7 +180,7 @@ impl, A: ConsensusApi + let result = collector .as_mut() .unwrap() - .handle_event(Arc::clone(&event), &tx) + .handle_vote_event(Arc::clone(&event), &tx) .await; if result == Some(HotShotTaskCompleted) { @@ -261,43 +259,27 @@ impl, A: ConsensusApi + error!("Shutting down because of shutdown signal!"); return Some(HotShotTaskCompleted); } - _ => { - error!("unexpected event {:?}", event); - } + _ => {} } None } } +#[async_trait] /// task state implementation for the upgrade task -impl, A: ConsensusApi + 'static> TaskState - for UpgradeTaskState -{ - type Event = Arc>; - - type Output = HotShotTaskCompleted; +impl> TaskState for UpgradeTaskState { + type Event = HotShotEvent; async fn handle_event( - event: Self::Event, - task: &mut hotshot_task::task::Task, - ) -> Option { - let sender = task.clone_sender(); - tracing::trace!("sender queue len {}", sender.len()); - task.state_mut().handle(event, sender).await - } + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + Ok(()) } - fn filter(&self, event: &Self::Event) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::UpgradeProposalRecv(_, _) - | HotShotEvent::UpgradeVoteRecv(_) - | HotShotEvent::Shutdown - | HotShotEvent::ViewChange(_) - | HotShotEvent::VersionUpgrade(_) - ) - } + async fn cancel_subtasks(&mut self) {} } diff --git a/crates/task-impls/src/vid.rs b/crates/task-impls/src/vid.rs index f7dd6d4045..25750939ac 100644 --- a/crates/task-impls/src/vid.rs +++ b/crates/task-impls/src/vid.rs @@ -1,8 +1,10 @@ use std::{marker::PhantomData, sync::Arc}; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_lock::RwLock; -use hotshot_task::task::{Task, TaskState}; +use async_trait::async_trait; +use hotshot_task::task::TaskState; use hotshot_types::{ consensus::Consensus, data::{VidDisperse, VidDisperseShare}, @@ -25,7 +27,6 @@ use crate::{ pub struct VidTaskState> { /// View number this view is executing in. pub cur_view: TYPES::Time, - /// Reference to consensus. Leader will require a read lock on this. 
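// [Editor's note: the per-task `filter` and `should_shutdown` hooks deleted
// throughout this patch are replaced by one convention: the task loop in
// `crates/task/src/task.rs` compares each event against
// `TaskEvent::shutdown_event()`, and handlers simply ignore events they don't
// care about (`_ => {}`). A hedged sketch of the new trait for a toy event
// type; for `HotShotEvent` the shutdown event is presumably
// `HotShotEvent::Shutdown`:
//
//     #[derive(Clone, Debug, PartialEq)]
//     enum ExampleEvent {
//         Tick(u64),
//         Shutdown,
//     }
//
//     impl TaskEvent for ExampleEvent {
//         fn shutdown_event() -> Self {
//             ExampleEvent::Shutdown
//         }
//     }
// ]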
pub consensus: Arc>>, /// Network for all nodes @@ -146,38 +147,26 @@ impl> VidTaskState { HotShotEvent::Shutdown => { return Some(HotShotTaskCompleted); } - _ => { - error!("unexpected event {:?}", event); - } + _ => {} } None } } +#[async_trait] /// task state implementation for VID Task impl> TaskState for VidTaskState { - type Event = Arc>; - - type Output = HotShotTaskCompleted; + type Event = HotShotEvent; async fn handle_event( - event: Self::Event, - task: &mut Task, - ) -> Option { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await; - None - } - fn filter(&self, event: &Self::Event) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::Shutdown - | HotShotEvent::BlockRecv(_, _, _, _, _) - | HotShotEvent::BlockReady(_, _) - | HotShotEvent::ViewChange(_) - ) - } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; + Ok(()) } + + async fn cancel_subtasks(&mut self) {} } diff --git a/crates/task-impls/src/view_sync.rs b/crates/task-impls/src/view_sync.rs index 36fd1424f2..0fa9e60224 100644 --- a/crates/task-impls/src/view_sync.rs +++ b/crates/task-impls/src/view_sync.rs @@ -6,12 +6,14 @@ use std::{ time::Duration, }; -use async_broadcast::Sender; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::{async_sleep, async_spawn}; use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; -use hotshot_task::task::{Task, TaskState}; +use async_trait::async_trait; +use hotshot_task::task::TaskState; use hotshot_types::{ message::GeneralConsensusMessage, simple_certificate::{ @@ -22,7 +24,6 @@ use hotshot_types::{ ViewSyncPreCommitData, ViewSyncPreCommitVote, }, traits::{ - consensus_api::ConsensusApi, election::Membership, node_implementation::{ConsensusTime, NodeImplementation, NodeType}, signature_key::SignatureKey, @@ -58,11 +59,7 @@ type RelayMap = HashMap<::Time, BTreeMap>>; /// Main view sync task state -pub struct ViewSyncTaskState< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static + std::clone::Clone, -> { +pub struct ViewSyncTaskState> { /// View HotShot is currently in pub current_view: TYPES::Time, /// View HotShot wishes to be in @@ -75,8 +72,6 @@ pub struct ViewSyncTaskState< pub public_key: TYPES::SignatureKey, /// Our Private Key pub private_key: ::PrivateKey, - /// HotShot consensus API - pub api: A, /// Our node id; for logging pub id: u64, @@ -84,7 +79,7 @@ pub struct ViewSyncTaskState< pub num_timeouts_tracked: u64, /// Map of running replica tasks - pub replica_task_map: RwLock>>, + pub replica_task_map: RwLock>>, /// Map of pre-commit vote accumulates for the relay pub pre_commit_relay_map: @@ -103,49 +98,26 @@ pub struct ViewSyncTaskState< pub last_garbage_collected_view: TYPES::Time, } -impl< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static + std::clone::Clone, - > TaskState for ViewSyncTaskState -{ - type Event = Arc>; - - type Output = (); +#[async_trait] +impl> TaskState for ViewSyncTaskState { + type Event = HotShotEvent; - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await; - None - } + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + 
self.handle(event, sender.clone()).await; - fn filter(&self, event: &Self::Event) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) - | HotShotEvent::ViewSyncPreCommitVoteRecv(_) - | HotShotEvent::ViewSyncCommitVoteRecv(_) - | HotShotEvent::ViewSyncFinalizeVoteRecv(_) - | HotShotEvent::Shutdown - | HotShotEvent::Timeout(_) - | HotShotEvent::ViewSyncTimeout(_, _, _) - | HotShotEvent::ViewChange(_) - ) + Ok(()) } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) - } + async fn cancel_subtasks(&mut self) {} } /// State of a view sync replica task -pub struct ViewSyncReplicaTaskState< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static, -> { +pub struct ViewSyncReplicaTaskState> { /// Timeout for view sync rounds pub view_sync_timeout: Duration, /// Current round HotShot is in @@ -171,49 +143,29 @@ pub struct ViewSyncReplicaTaskState< pub public_key: TYPES::SignatureKey, /// Our Private Key pub private_key: ::PrivateKey, - /// HotShot consensus API - pub api: A, } -impl, A: ConsensusApi + 'static> TaskState - for ViewSyncReplicaTaskState +#[async_trait] +impl> TaskState + for ViewSyncReplicaTaskState { - type Event = Arc>; + type Event = HotShotEvent; - type Output = (); + async fn handle_event( + &mut self, + event: Arc, + sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()> { + self.handle(event, sender.clone()).await; - async fn handle_event(event: Self::Event, task: &mut Task) -> Option<()> { - let sender = task.clone_sender(); - task.state_mut().handle(event, sender).await; - None - } - fn filter(&self, event: &Self::Event) -> bool { - !matches!( - event.as_ref(), - HotShotEvent::ViewSyncPreCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncCommitCertificate2Recv(_) - | HotShotEvent::ViewSyncFinalizeCertificate2Recv(_) - | HotShotEvent::ViewSyncPreCommitVoteRecv(_) - | HotShotEvent::ViewSyncCommitVoteRecv(_) - | HotShotEvent::ViewSyncFinalizeVoteRecv(_) - | HotShotEvent::Shutdown - | HotShotEvent::Timeout(_) - | HotShotEvent::ViewSyncTimeout(_, _, _) - | HotShotEvent::ViewChange(_) - ) + Ok(()) } - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) - } + async fn cancel_subtasks(&mut self) {} } -impl< - TYPES: NodeType, - I: NodeImplementation, - A: ConsensusApi + 'static + std::clone::Clone, - > ViewSyncTaskState -{ +impl> ViewSyncTaskState { #[instrument(skip_all, fields(id = self.id, view = *self.current_view), name = "View Sync Main Task", level = "error")] #[allow(clippy::type_complexity)] /// Handles incoming events for the main view sync task @@ -249,7 +201,7 @@ impl< } // We do not have a replica task already running, so start one - let mut replica_state: ViewSyncReplicaTaskState = ViewSyncReplicaTaskState { + let mut replica_state: ViewSyncReplicaTaskState = ViewSyncReplicaTaskState { current_view: view, next_view: view, relay: 0, @@ -260,7 +212,6 @@ impl< network: Arc::clone(&self.network), public_key: self.public_key.clone(), private_key: self.private_key.clone(), - api: self.api.clone(), view_sync_timeout: self.view_sync_timeout, id: self.id, }; @@ -319,7 +270,7 @@ impl< if let Some(relay_task) = relay_map.get_mut(&relay) { debug!("Forwarding message"); let result = relay_task - .handle_event(Arc::clone(&event), &event_stream) + .handle_vote_event(Arc::clone(&event), &event_stream) .await; if result 
== Some(HotShotTaskCompleted) { @@ -357,7 +308,7 @@ impl< if let Some(relay_task) = relay_map.get_mut(&relay) { debug!("Forwarding message"); let result = relay_task - .handle_event(Arc::clone(&event), &event_stream) + .handle_vote_event(Arc::clone(&event), &event_stream) .await; if result == Some(HotShotTaskCompleted) { @@ -395,7 +346,7 @@ impl< if let Some(relay_task) = relay_map.get_mut(&relay) { debug!("Forwarding message"); let result = relay_task - .handle_event(Arc::clone(&event), &event_stream) + .handle_vote_event(Arc::clone(&event), &event_stream) .await; if result == Some(HotShotTaskCompleted) { @@ -510,9 +461,7 @@ impl< } } -impl, A: ConsensusApi + 'static> - ViewSyncReplicaTaskState -{ +impl> ViewSyncReplicaTaskState { #[instrument(skip_all, fields(id = self.id, view = *self.current_view), name = "View Sync Replica Task", level = "error")] /// Handle incoming events for the view sync replica task pub async fn handle( diff --git a/crates/task-impls/src/vote_collection.rs b/crates/task-impls/src/vote_collection.rs index 62b604cae7..dda367a840 100644 --- a/crates/task-impls/src/vote_collection.rs +++ b/crates/task-impls/src/vote_collection.rs @@ -3,7 +3,6 @@ use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use async_broadcast::Sender; use async_trait::async_trait; use either::Either::{self, Left, Right}; -use hotshot_task::task::{Task, TaskState}; use hotshot_types::{ simple_certificate::{ DaCertificate, QuorumCertificate, TimeoutCertificate, UpgradeCertificate, @@ -105,36 +104,6 @@ impl< } } -impl< - TYPES: NodeType, - VOTE: Vote - + AggregatableVote - + std::marker::Send - + std::marker::Sync - + 'static, - CERT: Certificate - + Debug - + std::marker::Send - + std::marker::Sync - + 'static, - > TaskState for VoteCollectionTaskState -where - VoteCollectionTaskState: HandleVoteEvent, -{ - type Event = Arc>; - - type Output = HotShotTaskCompleted; - - async fn handle_event(event: Self::Event, task: &mut Task) -> Option { - let sender = task.clone_sender(); - task.state_mut().handle_event(event, &sender).await - } - - fn should_shutdown(event: &Self::Event) -> bool { - matches!(event.as_ref(), HotShotEvent::Shutdown) - } -} - /// Trait for types which will handle a vote event. 
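// [Editor's note: the rename from `handle_event` to `handle_vote_event` below
// frees the `handle_event` name for the reworked `TaskState` trait; vote
// collectors are no longer spawned as `TaskState`s but are driven inline by
// their parent task. Typical call shape, as in the view sync and upgrade
// tasks:
//
//     // let result = collector.handle_vote_event(Arc::clone(&event), &sender).await;
//     // if result == Some(HotShotTaskCompleted) {
//     //     map.remove(&view); // the accumulator finished, so drop it
//     // }
// ]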
#[async_trait] pub trait HandleVoteEvent @@ -144,7 +113,7 @@ where CERT: Certificate + Debug, { /// Handle a vote event - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -211,7 +180,7 @@ where id: info.id, }; - let result = state.handle_event(Arc::clone(&event), sender).await; + let result = state.handle_vote_event(Arc::clone(&event), sender).await; if result == Some(HotShotTaskCompleted) { // The protocol has finished @@ -354,7 +323,7 @@ impl impl HandleVoteEvent, QuorumCertificate> for QuorumVoteState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -374,7 +343,7 @@ impl HandleVoteEvent, QuorumCertificat impl HandleVoteEvent, UpgradeCertificate> for UpgradeVoteState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -393,7 +362,7 @@ impl HandleVoteEvent, UpgradeCertific impl HandleVoteEvent, DaCertificate> for DaVoteState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -412,7 +381,7 @@ impl HandleVoteEvent, DaCertificate impl HandleVoteEvent, TimeoutCertificate> for TimeoutVoteState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -432,7 +401,7 @@ impl HandleVoteEvent, ViewSyncPreCommitCertificate2> for ViewSyncPreCommitState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -454,7 +423,7 @@ impl HandleVoteEvent, ViewSyncCommitCertificate2> for ViewSyncCommitVoteState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, @@ -474,7 +443,7 @@ impl HandleVoteEvent, ViewSyncFinalizeCertificate2> for ViewSyncFinalizeVoteState { - async fn handle_event( + async fn handle_vote_event( &mut self, event: Arc>, sender: &Sender>>, diff --git a/crates/task/Cargo.toml b/crates/task/Cargo.toml index 7e4dadd3a4..3983158a31 100644 --- a/crates/task/Cargo.toml +++ b/crates/task/Cargo.toml @@ -8,10 +8,12 @@ edition = { workspace = true } [dependencies] -futures = "0.3" -async-broadcast = "0.7" +futures = { workspace = true } +async-broadcast = { workspace = true } tracing = { workspace = true } async-compatibility-layer = { workspace = true } +anyhow = { workspace = true } +async-trait = { workspace = true } [target.'cfg(all(async_executor_impl = "tokio"))'.dependencies] tokio = { workspace = true, features = [ diff --git a/crates/task/src/task.rs b/crates/task/src/task.rs index ad7ce9c316..1daebeaa1c 100644 --- a/crates/task/src/task.rs +++ b/crates/task/src/task.rs @@ -1,466 +1,187 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; -use async_broadcast::{Receiver, SendError, Sender}; -use async_compatibility_layer::art::async_timeout; +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; #[cfg(async_executor_impl = "async-std")] -use async_std::{ - sync::RwLock, - task::{spawn, JoinHandle}, -}; +use async_std::task::{spawn, JoinHandle}; +use async_trait::async_trait; #[cfg(async_executor_impl = "async-std")] use futures::future::join_all; #[cfg(async_executor_impl = "tokio")] use futures::future::try_join_all; -use futures::{future::select_all, Future}; #[cfg(async_executor_impl = "tokio")] -use tokio::{ - sync::RwLock, - task::{spawn, JoinHandle}, -}; -use tracing::{error, warn}; +use tokio::task::{spawn, JoinHandle}; -use crate::{ - dependency::Dependency, - dependency_task::{DependencyTask, HandleDepOutput}, -}; +/// Trait 
for events that long-running tasks handle +pub trait TaskEvent: PartialEq { + /// The shutdown signal for this event type + /// + /// Note that this is necessarily uniform across all tasks. + /// Exiting the task loop is handled by the task spawner, rather than the task individually. + fn shutdown_event() -> Self; +} +#[async_trait] /// Type for mutable task state that can be used as the state for a `Task` pub trait TaskState: Send { /// Type of event sent and received by the task - type Event: Clone + Send + Sync + 'static; - /// The result returned when this task completes - type Output: Send; - /// Handle event and update state. Return true if the task is finished - /// false otherwise. The handler can access the state through `Task::state_mut` - fn handle_event( - event: Self::Event, - task: &mut Task, - ) -> impl Future> + Send - where - Self: Sized; + type Event: TaskEvent + Clone + Send + Sync; - /// Return true if the event should be filtered - fn filter(&self, _event: &Self::Event) -> bool { - // default doesn't filter - false - } - /// Do something with the result of the task before it shuts down - fn handle_result(&self, _res: &Self::Output) -> impl std::future::Future + Send { - async {} - } - /// Return true if the event should shut the task down - fn should_shutdown(event: &Self::Event) -> bool; - /// Handle anything before the task is completely shutdown - fn shutdown(&mut self) -> impl std::future::Future + Send { - async {} - } -} + /// Cancels all subtasks spawned by this task state, cleaning them up before shutdown. + async fn cancel_subtasks(&mut self); -/// Task state for a test. Similar to `TaskState` but it handles -/// messages as well as events. Messages are events that are -/// external to this task. (i.e. a test message would be an event from non test task) -/// This is used as state for `TestTask` and messages can come from many -/// different input streams. -pub trait TestTaskState: Send { - /// Message type handled by the task - type Message: Clone + Send + Sync + 'static; - /// Result returned by the test task on completion - type Output: Send; - /// The state type - type State: TaskState; - /// Handle and incoming message and return `Some` if the task is finished - fn handle_message( - message: Self::Message, - id: usize, - task: &mut TestTask, - ) -> impl Future> + Send - where - Self: Sized; + /// Handles an event, providing direct access to the specific channel we received the event on. + async fn handle_event( + &mut self, + event: Arc, + _sender: &Sender>, + _receiver: &Receiver>, + ) -> Result<()>; } /// A basic task which loops waiting for events to come from `event_receiver` -/// and then handles them using it's state -/// It sends events to other `Task`s through `event_sender` +/// and then handles them using its state +/// It sends events to other `Task`s through `sender` /// This should be used as the primary building block for long running /// or medium running tasks (i.e. anything that can't be described as a dependency task) pub struct Task { + /// The state of the task. It is fed events from `receiver` + /// and mutated via `handle_event`. + state: S, /// Sends events all tasks including itself - event_sender: Sender, + sender: Sender>, /// Receives events that are broadcast from any task, including itself - event_receiver: Receiver, - /// Contains this task, used to register any spawned tasks - registry: Arc, - /// The state of the task. It is fed events from `event_sender` - /// and mutates it state ocordingly.
Also it signals the task - /// if it is complete/should shutdown - state: S, + receiver: Receiver>, } impl Task { /// Create a new task - pub fn new( - tx: Sender, - rx: Receiver, - registry: Arc, - state: S, - ) -> Self { + pub fn new(state: S, sender: Sender>, receiver: Receiver>) -> Self { Task { - event_sender: tx, - event_receiver: rx, - registry, state, + sender, + receiver, } } - /// The Task analog of `TaskState::handle_event`. - pub fn handle_event( - &mut self, - event: S::Event, - ) -> impl Future> + Send + '_ - where - Self: Sized, - { - S::handle_event(event, self) + /// The state of the task, as a boxed dynamic trait object. + fn boxed_state(self) -> Box> { + Box::new(self.state) as Box> } /// Spawn the task loop, consuming self. Will continue until /// the task reaches some shutdown condition - pub fn run(mut self) -> JoinHandle<()> { - spawn(async move { - loop { - match self.event_receiver.recv_direct().await { - Ok(event) => { - if S::should_shutdown(&event) { - self.state.shutdown().await; - break; - } - if self.state.filter(&event) { - continue; - } - if let Some(res) = S::handle_event(event, &mut self).await { - self.state.handle_result(&res).await; - self.state.shutdown().await; - break; - } - } - Err(e) => { - tracing::error!("Failed to receiving from event stream Error: {}", e); - } - } - } - }) - } - - /// Create a new event `Receiver` from this Task's receiver. - /// The returned receiver will get all messages not yet seen by this task - pub fn subscribe(&self) -> Receiver { - self.event_receiver.clone() - } - /// Get a new sender handle for events - pub fn sender(&self) -> &Sender { - &self.event_sender - } - /// Clone the sender handle - pub fn clone_sender(&self) -> Sender { - self.event_sender.clone() - } - /// Broadcast a message to all listening tasks - /// # Errors - /// Errors if the broadcast fails - pub async fn send(&self, event: S::Event) -> Result, SendError> { - self.event_sender.broadcast(event).await - } - /// Get a mutable reference to this tasks state - pub fn state_mut(&mut self) -> &mut S { - &mut self.state - } - /// Get an immutable reference to this tasks state - pub fn state(&self) -> &S { - &self.state - } - - /// Spawn a new task and register it. It will get all events not seend - /// by the task creating it. - pub async fn run_sub_task(&self, state: S) { - let task = Task { - event_sender: self.clone_sender(), - event_receiver: self.subscribe(), - registry: Arc::clone(&self.registry), - state, - }; - // Note: await here is only awaiting the task to be added to the - // registry, not for the task to run. - self.registry.run_task(task).await; - } -} - -/// Similar to `Task` but adds functionality for testing. Notably -/// it adds message receivers to collect events from many non-test tasks -pub struct TestTask { - /// Task which handles test events - task: Task, - /// Receivers for outside events - message_receivers: Vec>, -} - -impl< - S: TaskState + Send + 'static, - T: TestTaskState + Send + Sync + 'static, - > TestTask -{ - /// Create a test task - pub fn new(task: Task, rxs: Vec>) -> Self { - Self { - task, - message_receivers: rxs, - } - } - /// Runs the task, taking events from the the test events and the message receivers. - /// Consumes self and runs until some shutdown condition is met. - /// The join handle will return the result of the task, useful for deciding if the test - /// passed or not. 
- pub fn run(mut self) -> JoinHandle { + pub fn run(mut self) -> JoinHandle>> { spawn(async move { loop { - let mut futs = vec![]; + match self.receiver.recv_direct().await { + Ok(input) => { + if *input == S::Event::shutdown_event() { + self.state.cancel_subtasks().await; - if let Ok(event) = self.task.event_receiver.try_recv() { - if S::should_shutdown(&event) { - self.task.state.shutdown().await; - tracing::error!("Shutting down test task TODO!"); - todo!(); - } - if !self.state().filter(&event) { - if let Some(res) = S::handle_event(event, &mut self.task).await { - self.task.state.handle_result(&res).await; - self.task.state.shutdown().await; - return res; + break self.boxed_state(); } - } - } - for rx in &mut self.message_receivers { - futs.push(rx.recv()); - } - // if let Ok((Ok(msg), id, _)) = - match async_timeout(Duration::from_secs(1), select_all(futs)).await { - Ok((Ok(msg), id, _)) => { - if let Some(res) = T::handle_message(msg, id, &mut self).await { - self.task.state.handle_result(&res).await; - self.task.state.shutdown().await; - return res; - } + let _ = + S::handle_event(&mut self.state, input, &self.sender, &self.receiver) + .await + .inspect_err(|e| tracing::info!("{e}")); } Err(e) => { - warn!("Failed to get event from task. Error: {:?}", e); - } - Ok((Err(e), _, _)) => { - error!("A task channel returned an Error: {:?}", e); + tracing::error!("Failed to receive from event stream Error: {}", e); } } } }) } - - /// Get a ref to state - pub fn state(&self) -> &S { - &self.task.state - } - /// Get a mutable ref to state - pub fn state_mut(&mut self) -> &mut S { - self.task.state_mut() - } - /// Send an event to other listening test tasks - /// - /// # Panics - /// panics if the event can't be sent (ok to panic in test) - pub async fn send_event(&self, event: S::Event) { - self.task.send(event).await.unwrap(); - } } #[derive(Default)] /// A collection of tasks which can handle shutdown -pub struct TaskRegistry { +pub struct ConsensusTaskRegistry { /// Tasks this registry controls - task_handles: RwLock>>, + task_handles: Vec>>>, } -impl TaskRegistry { +impl ConsensusTaskRegistry { + #[must_use] + /// Create a new task registry + pub fn new() -> Self { + ConsensusTaskRegistry { + task_handles: vec![], + } + } /// Add a task to the registry - pub async fn register(&self, handle: JoinHandle<()>) { - self.task_handles.write().await.push(handle); + pub fn register(&mut self, handle: JoinHandle>>) { + self.task_handles.push(handle); } /// Try to cancel/abort the task this registry has - pub async fn shutdown(&self) { - let mut handles = self.task_handles.write().await; + /// + /// # Panics + /// + /// Should not panic, unless awaiting on the JoinHandle in tokio fails. 
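// [Editor's sketch: `shutdown` now joins each task to recover its boxed state
// and only then asks that state to cancel its own subtasks, instead of blindly
// aborting the task. The core of that drain-and-cleanup pattern, assuming
// tokio:
//
//     use tokio::task::JoinHandle;
//
//     async fn drain<S>(handles: &mut Vec<JoinHandle<S>>) -> Vec<S> {
//         let mut states = Vec::new();
//         while let Some(handle) = handles.pop() {
//             // `await` fails only if the task panicked or was aborted.
//             if let Ok(state) = handle.await {
//                 states.push(state);
//             }
//         }
//         states
//     }
// ]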
+ pub async fn shutdown(&mut self) { + let handles = &mut self.task_handles; + while let Some(handle) = handles.pop() { #[cfg(async_executor_impl = "async-std")] - handle.cancel().await; + let mut task_state = handle.await; #[cfg(async_executor_impl = "tokio")] - handle.abort(); + let mut task_state = handle.await.unwrap(); + + task_state.cancel_subtasks().await; } } /// Take a task, run it, and register it - pub async fn run_task(&self, task: Task) + pub fn run_task(&mut self, task: Task) where - S: TaskState + Send + 'static, + S: TaskState + Send + 'static, { - self.register(task.run()).await; - } - /// Create a new `DependencyTask` run it, and register it - pub async fn spawn_dependency_task( - &self, - dep: impl Dependency + Send + 'static, - handle: impl HandleDepOutput, - ) { - let join_handle = DependencyTask { dep, handle }.run(); - self.register(join_handle).await; + self.register(task.run()); } + /// Wait for the results of all the tasks registered /// # Panics /// Panics if one of the tasks panicked - pub async fn join_all(self) -> Vec<()> { + pub async fn join_all(self) -> Vec>> { #[cfg(async_executor_impl = "async-std")] - let ret = join_all(self.task_handles.into_inner()).await; + let states = join_all(self.task_handles).await; #[cfg(async_executor_impl = "tokio")] - let ret = try_join_all(self.task_handles.into_inner()).await.unwrap(); - ret + let states = try_join_all(self.task_handles).await.unwrap(); + + states } } -#[cfg(test)] -mod tests { - use std::{collections::HashSet, time::Duration}; - - use async_broadcast::broadcast; - #[cfg(async_executor_impl = "async-std")] - use async_std::task::sleep; - #[cfg(async_executor_impl = "tokio")] - use tokio::time::sleep; - - use super::*; - - #[derive(Default)] - pub struct DummyHandle { - val: usize, - seen: HashSet, - } +#[derive(Default)] +/// A collection of tasks which can handle shutdown +pub struct NetworkTaskRegistry { + /// Tasks this registry controls + pub handles: Vec>, +} - #[allow(clippy::panic)] - impl TaskState for DummyHandle { - type Event = usize; - type Output = (); - async fn handle_event(event: usize, task: &mut Task) -> Option<()> { - sleep(Duration::from_millis(10)).await; - let state = task.state_mut(); - state.seen.insert(event); - if event > state.val { - state.val = event; - assert!( - state.val < 100, - "Test should shutdown before getting an event for 100" - ); - task.send(event + 1).await.unwrap(); - } - None - } - fn should_shutdown(event: &usize) -> bool { - *event >= 98 - } - async fn shutdown(&mut self) { - for i in 1..98 { - assert!(self.seen.contains(&i)); - } - } +impl NetworkTaskRegistry { + #[must_use] + /// Create a new task registry + pub fn new() -> Self { + NetworkTaskRegistry { handles: vec![] } } - impl TestTaskState for DummyHandle { - type Message = String; - type Output = (); - type State = Self; + #[allow(clippy::unused_async)] + /// Shuts down all tasks in the registry, performing any associated cleanup. 
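// [Editor's note: unlike `ConsensusTaskRegistry`, network tasks return no
// state that needs cleanup, so this registry aborts (tokio) or cancels
// (async-std) outright rather than joining first. The executor split is the
// usual conditional-compilation pair:
//
//     // #[cfg(async_executor_impl = "async-std")]
//     // handle.cancel().await; // async-std cancellation must be awaited
//     // #[cfg(async_executor_impl = "tokio")]
//     // handle.abort();        // tokio's abort is synchronous
// ]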
+ pub async fn shutdown(&mut self) { + let handles = &mut self.handles; - async fn handle_message( - message: Self::Message, - _: usize, - _: &mut TestTask, - ) -> Option<()> { - if message == *"done".to_string() { - return Some(()); - } - None + while let Some(handle) = handles.pop() { + #[cfg(async_executor_impl = "async-std")] + handle.cancel().await; + #[cfg(async_executor_impl = "tokio")] + handle.abort(); } } - #[cfg_attr(async_executor_impl = "tokio", tokio::test(flavor = "multi_thread"))] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - #[allow(unused_must_use)] - async fn it_works() { - let reg = Arc::new(TaskRegistry::default()); - let (tx, rx) = broadcast(10); - let task1 = Task:: { - event_sender: tx.clone(), - event_receiver: rx.clone(), - registry: Arc::clone(®), - state: DummyHandle::default(), - }; - tx.broadcast(1).await.unwrap(); - let task2 = Task:: { - event_sender: tx.clone(), - event_receiver: rx, - registry: reg, - state: DummyHandle::default(), - }; - let handle = task2.run(); - let _res = task1.run().await; - handle.await; - } - - #[cfg_attr( - async_executor_impl = "tokio", - tokio::test(flavor = "multi_thread", worker_threads = 10) - )] - #[cfg_attr(async_executor_impl = "async-std", async_std::test)] - #[allow(clippy::should_panic_without_expect)] - #[should_panic] - async fn test_works() { - let reg = Arc::new(TaskRegistry::default()); - let (tx, rx) = broadcast(10); - let (msg_tx, msg_rx) = broadcast(10); - let task1 = Task:: { - event_sender: tx.clone(), - event_receiver: rx.clone(), - registry: Arc::clone(®), - state: DummyHandle::default(), - }; - tx.broadcast(1).await.unwrap(); - let task2 = Task:: { - event_sender: tx.clone(), - event_receiver: rx, - registry: reg, - state: DummyHandle::default(), - }; - let test1 = TestTask::<_, DummyHandle> { - task: task1, - message_receivers: vec![msg_rx.clone()], - }; - let test2 = TestTask::<_, DummyHandle> { - task: task2, - message_receivers: vec![msg_rx.clone()], - }; - let handle = test1.run(); - let handle2 = test2.run(); - sleep(Duration::from_millis(30)).await; - msg_tx.broadcast("done".into()).await.unwrap(); - #[cfg(async_executor_impl = "tokio")] - { - handle.await.unwrap(); - handle2.await.unwrap(); - } - #[cfg(async_executor_impl = "async-std")] - { - handle.await; - handle2.await; - } + /// Add a task to the registry + pub fn register(&mut self, handle: JoinHandle<()>) { + self.handles.push(handle); } } diff --git a/crates/testing/Cargo.toml b/crates/testing/Cargo.toml index ad2757fdb5..af0bb6b6d4 100644 --- a/crates/testing/Cargo.toml +++ b/crates/testing/Cargo.toml @@ -14,6 +14,7 @@ dependency-tasks = ["hotshot/dependency-tasks"] [dependencies] automod = "1.0.14" +anyhow = { workspace = true } async-broadcast = { workspace = true } async-compatibility-layer = { workspace = true } async-lock = { workspace = true } diff --git a/crates/testing/src/completion_task.rs b/crates/testing/src/completion_task.rs index c0e5d77e67..3f8f711454 100644 --- a/crates/testing/src/completion_task.rs +++ b/crates/testing/src/completion_task.rs @@ -1,7 +1,8 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; use async_broadcast::{Receiver, Sender}; use async_compatibility_layer::art::{async_spawn, async_timeout}; +use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; use hotshot::traits::TestableNodeImplementation; @@ -11,8 +12,7 @@ use snafu::Snafu; #[cfg(async_executor_impl = "tokio")] use tokio::task::JoinHandle; -use 
super::GlobalTestEvent; -use crate::test_runner::{HotShotTaskCompleted, Node}; +use crate::{test_runner::Node, test_task::TestEvent}; /// the idea here is to run as long as we want @@ -22,33 +22,33 @@ pub struct CompletionTaskErr {} /// Completion task state pub struct CompletionTask> { - pub tx: Sender, + pub tx: Sender, - pub rx: Receiver, + pub rx: Receiver, /// handles to the nodes in the test - pub(crate) handles: Vec>, + pub(crate) handles: Arc>>>, /// Duration of the task. pub duration: Duration, } impl> CompletionTask { - pub fn run(mut self) -> JoinHandle { + pub fn run(mut self) -> JoinHandle<()> { async_spawn(async move { if async_timeout(self.duration, self.wait_for_shutdown()) .await .is_err() { - broadcast_event(GlobalTestEvent::ShutDown, &self.tx).await; + broadcast_event(TestEvent::Shutdown, &self.tx).await; } - for node in &self.handles { - node.handle.clone().shut_down().await; + + for node in &mut self.handles.write().await.iter_mut() { + node.handle.shut_down().await; } - HotShotTaskCompleted::ShutDown }) } async fn wait_for_shutdown(&mut self) { while let Ok(event) = self.rx.recv_direct().await { - if matches!(event, GlobalTestEvent::ShutDown) { + if matches!(event, TestEvent::Shutdown) { tracing::error!("Completion Task shutting down"); return; } diff --git a/crates/testing/src/lib.rs b/crates/testing/src/lib.rs index e1bb38e034..43ce6ebe5b 100644 --- a/crates/testing/src/lib.rs +++ b/crates/testing/src/lib.rs @@ -30,6 +30,9 @@ pub mod completion_task; /// task to spin nodes up and down pub mod spinning_task; +/// the `TestTask` struct and associated trait/functions +pub mod test_task; + /// task for checking if view sync got activated pub mod view_sync_task; @@ -44,10 +47,3 @@ pub mod script; /// view generator for tests pub mod view_generator; - -/// global event at the test level -#[derive(Clone, Debug)] -pub enum GlobalTestEvent { - /// the test is shutting down - ShutDown, -} diff --git a/crates/testing/src/overall_safety_task.rs b/crates/testing/src/overall_safety_task.rs index 635c277539..0827e57fd6 100644 --- a/crates/testing/src/overall_safety_task.rs +++ b/crates/testing/src/overall_safety_task.rs @@ -3,8 +3,11 @@ use std::{ sync::Arc, }; +use anyhow::Result; +use async_broadcast::Sender; +use async_lock::RwLock; +use async_trait::async_trait; use hotshot::{traits::TestableNodeImplementation, HotShotError}; -use hotshot_task::task::{Task, TaskState, TestTaskState}; use hotshot_types::{ data::Leaf, error::RoundTimedoutState, @@ -16,12 +19,13 @@ use hotshot_types::{ use snafu::Snafu; use tracing::error; -use crate::test_runner::{HotShotTaskCompleted, Node}; +use crate::{ + test_runner::Node, + test_task::{TestEvent, TestResult, TestTaskState}, +}; /// convenience type alias for state and block pub type StateAndBlock = (Vec, Vec); -use super::GlobalTestEvent; - /// the status of a view #[derive(Debug, Clone)] pub enum ViewStatus { @@ -66,78 +70,25 @@ pub enum OverallSafetyTaskErr { /// Data availability task state pub struct OverallSafetyTask> { /// handles - pub handles: Vec>, + pub handles: Arc>>>, /// ctx pub ctx: RoundCtx, /// configure properties pub properties: OverallSafetyPropertiesDescription, + /// error + pub error: Option>>, + /// sender to test event channel + pub test_sender: Sender, } -impl> TaskState - for OverallSafetyTask -{ - type Event = GlobalTestEvent; - - type Output = HotShotTaskCompleted; - - async fn handle_event(event: Self::Event, task: &mut Task) -> Option { - match event { - GlobalTestEvent::ShutDown => { - 
tracing::error!("Shutting down SafetyTask"); - let state = task.state_mut(); - let OverallSafetyPropertiesDescription { - check_leaf: _, - check_block: _, - num_failed_views: num_failed_rounds_total, - num_successful_views, - threshold_calculator: _, - transaction_threshold: _, - }: OverallSafetyPropertiesDescription = state.properties.clone(); - - let num_incomplete_views = state.ctx.round_results.len() - - state.ctx.successful_views.len() - - state.ctx.failed_views.len(); - - if state.ctx.successful_views.len() < num_successful_views { - return Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::NotEnoughDecides { - got: state.ctx.successful_views.len(), - expected: num_successful_views, - }, - ))); - } - - if state.ctx.failed_views.len() + num_incomplete_views >= num_failed_rounds_total { - return Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::TooManyFailures { - failed_views: state.ctx.failed_views.clone(), - }, - ))); - } - Some(HotShotTaskCompleted::ShutDown) - } - } - } - - fn should_shutdown(_event: &Self::Event) -> bool { - false - } -} - +#[async_trait] impl> TestTaskState for OverallSafetyTask { - type Message = Event; - - type Output = HotShotTaskCompleted; + type Event = Event; - type State = Self; - - async fn handle_message( - message: Self::Message, - idx: usize, - task: &mut hotshot_task::task::TestTask, - ) -> Option { + /// Handles an event from one of multiple receivers. + async fn handle_event(&mut self, (message, id): (Self::Event, usize)) -> Result<()> { let OverallSafetyPropertiesDescription { check_leaf, check_block, @@ -145,13 +96,12 @@ impl> TestTaskState num_successful_views, threshold_calculator, transaction_threshold, - }: OverallSafetyPropertiesDescription = task.state().properties.clone(); + }: OverallSafetyPropertiesDescription = self.properties.clone(); let Event { view_number, event } = message; let key = match event { EventType::Error { error } => { - task.state_mut() - .ctx - .insert_error_to_context(view_number, idx, error); + self.ctx + .insert_error_to_context(view_number, id, error.clone()); None } EventType::Decide { @@ -161,17 +111,17 @@ impl> TestTaskState } => { // Skip the genesis leaf. 
if leaf_chain.last().unwrap().leaf.view_number() == TYPES::Time::genesis() { - return None; + return Ok(()); } let paired_up = (leaf_chain.to_vec(), (*qc).clone()); - match task.state_mut().ctx.round_results.entry(view_number) { + match self.ctx.round_results.entry(view_number) { Entry::Occupied(mut o) => { o.get_mut() - .insert_into_result(idx, paired_up, maybe_block_size) + .insert_into_result(id, paired_up, maybe_block_size) } Entry::Vacant(v) => { let mut round_result = RoundResult::default(); - let key = round_result.insert_into_result(idx, paired_up, maybe_block_size); + let key = round_result.insert_into_result(id, paired_up, maybe_block_size); v.insert(round_result); key } @@ -182,25 +132,18 @@ impl> TestTaskState view_number, state: RoundTimedoutState::TestCollectRoundEventsTimedOut, }); - task.state_mut() - .ctx - .insert_error_to_context(view_number, idx, error); + self.ctx.insert_error_to_context(view_number, id, error); None } - _ => return None, + _ => return Ok(()), }; + let len = self.handles.read().await.len(); + // update view count - let threshold = - (threshold_calculator)(task.state().handles.len(), task.state().handles.len()); - - let len = task.state().handles.len(); - let view = task - .state_mut() - .ctx - .round_results - .get_mut(&view_number) - .unwrap(); + let threshold = (threshold_calculator)(len, len); + + let view = self.ctx.round_results.get_mut(&view_number).unwrap(); if let Some(key) = key { view.update_status( threshold, @@ -212,47 +155,77 @@ impl> TestTaskState ); match view.status.clone() { ViewStatus::Ok => { - task.state_mut().ctx.successful_views.insert(view_number); - if task.state_mut().ctx.successful_views.len() >= num_successful_views { - task.send_event(GlobalTestEvent::ShutDown).await; - return Some(HotShotTaskCompleted::ShutDown); + self.ctx.successful_views.insert(view_number); + if self.ctx.successful_views.len() >= num_successful_views { + let _ = self.test_sender.broadcast(TestEvent::Shutdown).await; } - return None; + return Ok(()); } ViewStatus::Failed => { - task.state_mut().ctx.failed_views.insert(view_number); - if task.state_mut().ctx.failed_views.len() > num_failed_views { - task.send_event(GlobalTestEvent::ShutDown).await; - return Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::TooManyFailures { - failed_views: task.state_mut().ctx.failed_views.clone(), - }, - ))); + self.ctx.failed_views.insert(view_number); + if self.ctx.failed_views.len() > num_failed_views { + let _ = self.test_sender.broadcast(TestEvent::Shutdown).await; + self.error = + Some(Box::new(OverallSafetyTaskErr::::TooManyFailures { + failed_views: self.ctx.failed_views.clone(), + })); } - return None; + return Ok(()); } ViewStatus::Err(e) => { - task.send_event(GlobalTestEvent::ShutDown).await; - return Some(HotShotTaskCompleted::Error(Box::new(e))); + let _ = self.test_sender.broadcast(TestEvent::Shutdown).await; + self.error = Some(Box::new(e)); + return Ok(()); } ViewStatus::InProgress => { - return None; + return Ok(()); } } } else if view.check_if_failed(threshold, len) { view.status = ViewStatus::Failed; - task.state_mut().ctx.failed_views.insert(view_number); - if task.state_mut().ctx.failed_views.len() > num_failed_views { - task.send_event(GlobalTestEvent::ShutDown).await; - return Some(HotShotTaskCompleted::Error(Box::new( - OverallSafetyTaskErr::::TooManyFailures { - failed_views: task.state_mut().ctx.failed_views.clone(), - }, - ))); + self.ctx.failed_views.insert(view_number); + if self.ctx.failed_views.len() > num_failed_views { 
+ let _ = self.test_sender.broadcast(TestEvent::Shutdown).await; + self.error = Some(Box::new(OverallSafetyTaskErr::::TooManyFailures { + failed_views: self.ctx.failed_views.clone(), + })); } - return None; + return Ok(()); } - None + Ok(()) + } + + fn check(&self) -> TestResult { + if let Some(e) = &self.error { + return TestResult::Fail(e.clone()); + } + + let OverallSafetyPropertiesDescription { + check_leaf: _, + check_block: _, + num_failed_views: num_failed_rounds_total, + num_successful_views, + threshold_calculator: _, + transaction_threshold: _, + }: OverallSafetyPropertiesDescription = self.properties.clone(); + + let num_incomplete_views = self.ctx.round_results.len() + - self.ctx.successful_views.len() + - self.ctx.failed_views.len(); + + if self.ctx.successful_views.len() < num_successful_views { + return TestResult::Fail(Box::new(OverallSafetyTaskErr::::NotEnoughDecides { + got: self.ctx.successful_views.len(), + expected: num_successful_views, + })); + } + + if self.ctx.failed_views.len() + num_incomplete_views > num_failed_rounds_total { + return TestResult::Fail(Box::new(OverallSafetyTaskErr::::TooManyFailures { + failed_views: self.ctx.failed_views.clone(), + })); + } + TestResult::Pass } } diff --git a/crates/testing/src/predicates/event.rs b/crates/testing/src/predicates/event.rs index e5a2c15d60..69c40db865 100644 --- a/crates/testing/src/predicates/event.rs +++ b/crates/testing/src/predicates/event.rs @@ -237,6 +237,27 @@ where Box::new(EventPredicate { check, info }) } +pub fn view_sync_timeout() -> Box> +where + TYPES: NodeType, +{ + let info = "ViewSyncTimeout".to_string(); + let check: EventCallback = + Arc::new(move |e: Arc>| matches!(e.as_ref(), ViewSyncTimeout(..))); + Box::new(EventPredicate { check, info }) +} + +pub fn view_sync_precommit_vote_send() -> Box> +where + TYPES: NodeType, +{ + let info = "ViewSyncPreCommitVoteSend".to_string(); + let check: EventCallback = Arc::new(move |e: Arc>| { + matches!(e.as_ref(), ViewSyncPreCommitVoteSend(..)) + }); + Box::new(EventPredicate { check, info }) +} + pub fn vote_now() -> Box> where TYPES: NodeType, diff --git a/crates/testing/src/script.rs b/crates/testing/src/script.rs index 3edde1b4d9..2721c87846 100644 --- a/crates/testing/src/script.rs +++ b/crates/testing/src/script.rs @@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration}; use async_broadcast::broadcast; use async_compatibility_layer::art::async_timeout; -use hotshot_task::task::{Task, TaskRegistry, TaskState}; +use hotshot_task::task::TaskState; use hotshot_task_impls::events::HotShotEvent; use hotshot_types::traits::node_implementation::NodeType; @@ -10,7 +10,7 @@ use crate::predicates::{Predicate, PredicateResult}; pub const RECV_TIMEOUT: Duration = Duration::from_millis(250); -pub struct TestScriptStage>>> { +pub struct TestScriptStage>> { pub inputs: Vec>, pub outputs: Vec>>>>, pub asserts: Vec>>, @@ -89,22 +89,15 @@ where /// Note: the task is not spawned with an async thread; instead, the harness just calls `handle_event`. /// This has a few implications, e.g. shutting down tasks doesn't really make sense, /// and event ordering is deterministic. 
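// [Editor's sketch: with no spawned task, the harness below drives the state
// machine by hand, handing `handle_event` the same sender/receiver pair a
// production task would own. A hedged, minimal driving loop (`MyTask` and
// `check_expectation` are hypothetical):
//
//     // let (to_test, from_test) = async_broadcast::broadcast(1024);
//     // let mut state = MyTask::default();
//     // for input in inputs {
//     //     state
//     //         .handle_event(Arc::new(input), &to_test, &from_test)
//     //         .await
//     //         .expect("handler failed");
//     // }
//     // while let Ok(output) = from_test.try_recv() {
//     //     check_expectation(&output);
//     // }
// ]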
-pub async fn run_test_script< - TYPES, - S: TaskState>> + Send + 'static, ->( +pub async fn run_test_script> + Send + 'static>( mut script: TestScript, - state: S, + mut state: S, ) where TYPES: NodeType, { - let registry = Arc::new(TaskRegistry::default()); - let (to_task, mut from_test) = broadcast(1024); let (to_test, mut from_task) = broadcast(1024); - let mut task = Task::new(to_test.clone(), from_test.clone(), registry.clone(), state); - for (stage_number, stage) in script.iter_mut().enumerate() { tracing::debug!("Beginning test stage {}", stage_number); for input in &stage.inputs { @@ -113,13 +106,12 @@ pub async fn run_test_script< .await .expect("Failed to broadcast input message"); - if !task.state_mut().filter(&Arc::new(input.clone())) { - tracing::debug!("Test sent: {:?}", input.clone()); + tracing::debug!("Test sent: {:?}", input.clone()); - if let Some(res) = S::handle_event(input.clone().into(), &mut task).await { - task.state_mut().handle_result(&res).await; - } - } + let _ = state + .handle_event(input.clone().into(), &to_test, &from_test) + .await + .inspect_err(|e| tracing::info!("{e}")); while from_test.try_recv().is_ok() {} } @@ -146,13 +138,12 @@ pub async fn run_test_script< .await .expect("Failed to re-broadcast output message"); - if !task.state_mut().filter(&received_output.clone()) { - tracing::debug!("Test sent: {:?}", received_output.clone()); + tracing::debug!("Test sent: {:?}", received_output.clone()); - if let Some(res) = S::handle_event(received_output.clone(), &mut task).await { - task.state_mut().handle_result(&res).await; - } - } + let _ = state + .handle_event(received_output.clone(), &to_test, &from_test) + .await + .inspect_err(|e| tracing::info!("{e}")); while from_test.try_recv().is_ok() {} @@ -167,7 +158,7 @@ pub async fn run_test_script< } for assert in &mut stage.asserts { - validate_task_state_or_panic(stage_number, task.state(), &**assert).await; + validate_task_state_or_panic(stage_number, &state, &**assert).await; } if let Ok(received_output) = from_task.try_recv() { @@ -177,6 +168,8 @@ pub async fn run_test_script< } pub struct TaskScript { + /// The time to wait on the receiver for this script. 
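// [Editor's note: a per-script `timeout` replaces the single hard-coded
// `RECV_TIMEOUT` for scripts that need longer (or shorter) receive waits. A
// hedged construction example:
//
//     // let script = TaskScript {
//     //     timeout: Duration::from_millis(250), // previously RECV_TIMEOUT for every script
//     //     state,
//     //     expectations,
//     // };
// ]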
+ pub timeout: Duration, pub state: S, pub expectations: Vec>, } diff --git a/crates/testing/src/spinning_task.rs b/crates/testing/src/spinning_task.rs index 40466bc004..9a99b5fdd1 100644 --- a/crates/testing/src/spinning_task.rs +++ b/crates/testing/src/spinning_task.rs @@ -1,12 +1,17 @@ -use std::collections::{BTreeMap, HashMap}; +use std::{ + collections::{BTreeMap, HashMap}, + sync::Arc, +}; -use either::{Left, Right}; +use anyhow::Result; +use async_lock::RwLock; +use async_trait::async_trait; +use futures::future::Either::{Left, Right}; use hotshot::{traits::TestableNodeImplementation, types::EventType, HotShotInitializer}; use hotshot_example_types::{ state_types::{TestInstanceState, TestValidatedState}, storage_types::TestStorage, }; -use hotshot_task::task::{Task, TaskState, TestTaskState}; use hotshot_types::{ data::Leaf, event::Event, @@ -21,12 +26,14 @@ use hotshot_types::{ }; use snafu::Snafu; -use crate::test_runner::{HotShotTaskCompleted, LateStartNode, Node, TestRunner}; +use crate::{ + test_runner::{LateStartNode, Node, TestRunner}, + test_task::{TestResult, TestTaskState}, +}; + /// convience type for state and block pub type StateAndBlock = (Vec, Vec); -use super::GlobalTestEvent; - /// error for the spinning task #[derive(Snafu, Debug)] pub struct SpinningTaskErr {} @@ -34,7 +41,7 @@ pub struct SpinningTaskErr {} /// Spinning task state pub struct SpinningTask> { /// handle to the nodes - pub(crate) handles: Vec>, + pub(crate) handles: Arc>>>, /// late start nodes pub(crate) late_start: HashMap>, /// time based changes @@ -47,23 +54,7 @@ pub struct SpinningTask> { pub(crate) high_qc: QuorumCertificate, } -impl> TaskState for SpinningTask { - type Event = GlobalTestEvent; - - type Output = HotShotTaskCompleted; - - async fn handle_event(event: Self::Event, _task: &mut Task) -> Option { - if matches!(event, GlobalTestEvent::ShutDown) { - return Some(HotShotTaskCompleted::ShutDown); - } - None - } - - fn should_shutdown(_event: &Self::Event) -> bool { - false - } -} - +#[async_trait] impl< TYPES: NodeType, I: TestableNodeImplementation, @@ -73,21 +64,11 @@ where I: TestableNodeImplementation, I: NodeImplementation>, { - type Message = Event; - - type Output = HotShotTaskCompleted; + type Event = Event; - type State = Self; - - async fn handle_message( - message: Self::Message, - _id: usize, - task: &mut hotshot_task::task::TestTask, - ) -> Option { + async fn handle_event(&mut self, (message, _id): (Self::Event, usize)) -> Result<()> { let Event { view_number, event } = message; - let state = &mut task.state_mut(); - if let EventType::Decide { leaf_chain, qc: _, @@ -95,27 +76,28 @@ where } = event { let leaf = leaf_chain.first().unwrap().leaf.clone(); - if leaf.view_number() > state.last_decided_leaf.view_number() { - state.last_decided_leaf = leaf; + if leaf.view_number() > self.last_decided_leaf.view_number() { + self.last_decided_leaf = leaf; } } else if let EventType::QuorumProposal { proposal, sender: _, } = event { - if proposal.data.justify_qc.view_number() > state.high_qc.view_number() { - state.high_qc = proposal.data.justify_qc; + if proposal.data.justify_qc.view_number() > self.high_qc.view_number() { + self.high_qc = proposal.data.justify_qc.clone(); } } + // if we have not seen this view before - if state.latest_view.is_none() || view_number > state.latest_view.unwrap() { + if self.latest_view.is_none() || view_number > self.latest_view.unwrap() { // perform operations on the nodes - if let Some(operations) = state.changes.remove(&view_number) { + if let 
Some(operations) = self.changes.remove(&view_number) { for ChangeNode { idx, updown } in operations { match updown { UpDown::Up => { let node_id = idx.try_into().unwrap(); - if let Some(node) = state.late_start.remove(&node_id) { + if let Some(node) = self.late_start.remove(&node_id) { tracing::error!("Node {} spinning up late", idx); let node_id = idx.try_into().unwrap(); let context = match node.context { @@ -124,11 +106,11 @@ where // based on the received leaf. Right((storage, memberships, config)) => { let initializer = HotShotInitializer::::from_reload( - state.last_decided_leaf.clone(), + self.last_decided_leaf.clone(), TestInstanceState {}, None, view_number, - state.high_qc.clone(), + self.high_qc.clone(), Vec::new(), BTreeMap::new(), ); @@ -164,26 +146,26 @@ where networks: node.networks, handle, }; - state.handles.push(node.clone()); - node.handle.hotshot.start_consensus().await; + + self.handles.write().await.push(node); } } UpDown::Down => { - if let Some(node) = state.handles.get_mut(idx) { + if let Some(node) = self.handles.write().await.get_mut(idx) { tracing::error!("Node {} shutting down", idx); node.handle.shut_down().await; } } UpDown::NetworkUp => { - if let Some(handle) = state.handles.get(idx) { + if let Some(handle) = self.handles.write().await.get(idx) { tracing::error!("Node {} networks resuming", idx); handle.networks.0.resume(); handle.networks.1.resume(); } } UpDown::NetworkDown => { - if let Some(handle) = state.handles.get(idx) { + if let Some(handle) = self.handles.write().await.get(idx) { tracing::error!("Node {} networks pausing", idx); handle.networks.0.pause(); handle.networks.1.pause(); @@ -194,10 +176,14 @@ where } // update our latest view - state.latest_view = Some(view_number); + self.latest_view = Some(view_number); } - None + Ok(()) + } + + fn check(&self) -> TestResult { + TestResult::Pass } } diff --git a/crates/testing/src/test_runner.rs b/crates/testing/src/test_runner.rs index c3289297c3..e84baa58c4 100644 --- a/crates/testing/src/test_runner.rs +++ b/crates/testing/src/test_runner.rs @@ -6,8 +6,11 @@ use std::{ }; use async_broadcast::broadcast; -use either::Either::{self, Left, Right}; -use futures::future::join_all; +use async_lock::RwLock; +use futures::future::{ + join_all, Either, + Either::{Left, Right}, +}; use hotshot::{ traits::TestableNodeImplementation, types::SystemContextHandle, HotShotInitializer, Memberships, SystemContext, @@ -16,7 +19,6 @@ use hotshot_example_types::{ state_types::{TestInstanceState, TestValidatedState}, storage_types::TestStorage, }; -use hotshot_task::task::{Task, TaskRegistry, TestTask}; use hotshot_types::{ consensus::ConsensusMetricsValue, constants::EVENT_CHANNEL_SIZE, @@ -43,76 +45,11 @@ use crate::{ completion_task::CompletionTaskDescription, spinning_task::{ChangeNode, SpinningTask, UpDown}, test_launcher::{Networks, TestLauncher}, + test_task::{TestResult, TestTask}, txn_task::TxnTaskDescription, view_sync_task::ViewSyncTask, }; -/// a node participating in a test -#[derive(Clone)] -pub struct Node> { - /// The node's unique identifier - pub node_id: u64, - /// The underlying networks belonging to the node - pub networks: Networks, - /// The handle to the node's internals - pub handle: SystemContextHandle, -} - -/// Either the node context or the parameters to construct the context for nodes that start late. 
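// [Editor's note: throughout `test_runner.rs` the per-task `handles:
// Vec<Node<...>>` clones are replaced by a single shared
// `Arc<RwLock<Vec<Node<...>>>>`, so the spinning task can append late-start
// nodes and the safety/completion tasks observe the same, current list. The
// sharing pattern, sketched with `async-lock` and `u64` standing in for a
// node handle:
//
//     use std::sync::Arc;
//     use async_lock::RwLock;
//
//     async fn example() {
//         let handles: Arc<RwLock<Vec<u64>>> = Arc::new(RwLock::new(vec![1, 2]));
//         let for_spinning = Arc::clone(&handles);
//
//         for_spinning.write().await.push(3); // the spinning task adds a node
//         assert_eq!(handles.read().await.len(), 3); // other tasks see it at once
//     }
// ]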
-pub type LateNodeContext = Either< - Arc>, - ( - >::Storage, - Memberships, - HotShotConfig<::SignatureKey>, - ), ->; - -/// A yet-to-be-started node that participates in tests -pub struct LateStartNode> { - /// The underlying networks belonging to the node - pub networks: Networks, - /// Either the context to which we will use to launch HotShot for initialized node when it's - /// time, or the parameters that will be used to initialize the node and launch HotShot. - pub context: LateNodeContext, -} - -/// The runner of a test network -/// spin up and down nodes, execute rounds -pub struct TestRunner< - TYPES: NodeType, - I: TestableNodeImplementation, - N: ConnectedNetwork, TYPES::SignatureKey>, -> { - /// test launcher, contains a bunch of useful metadata and closures - pub(crate) launcher: TestLauncher, - /// nodes in the test - pub(crate) nodes: Vec>, - /// nodes with a late start - pub(crate) late_start: HashMap>, - /// the next node unique identifier - pub(crate) next_node_id: u64, - /// Phantom for N - pub(crate) _pd: PhantomData, -} - -/// enum describing how the tasks completed -pub enum HotShotTaskCompleted { - /// the task shut down successfully - ShutDown, - /// the task encountered an error - Error(Box), - /// the streams the task was listening for died - StreamsDied, - /// we somehow lost the state - /// this is definitely a bug. - LostState, - /// lost the return value somehow - LostReturnValue, - /// Stream exists but missing handler - MissingHandler, -} - pub trait TaskErr: std::error::Error + Sync + Send + 'static {} impl TaskErr for T {} @@ -131,7 +68,7 @@ where /// if the test fails #[allow(clippy::too_many_lines)] pub async fn run_test>(mut self) { - let (tx, rx) = broadcast(EVENT_CHANNEL_SIZE); + let (test_sender, test_receiver) = broadcast(EVENT_CHANNEL_SIZE); let spinning_changes = self .launcher .metadata @@ -165,8 +102,6 @@ where internal_event_rxs.push(r); } - let reg = Arc::new(TaskRegistry::default()); - let TestRunner { ref launcher, nodes, @@ -178,13 +113,15 @@ where let mut task_futs = vec![]; let meta = launcher.metadata.clone(); + let handles = Arc::new(RwLock::new(nodes)); + let txn_task = if let TxnTaskDescription::RoundRobinTimeBased(duration) = meta.txn_description { let txn_task = TxnTask { - handles: nodes.clone(), + handles: Arc::clone(&handles), next_node_idx: Some(0), duration, - shutdown_chan: rx.clone(), + shutdown_chan: test_receiver.clone(), }; Some(txn_task) } else { @@ -195,9 +132,9 @@ where let CompletionTaskDescription::TimeBasedCompletionTaskBuilder(time_based) = meta.completion_task_description; let completion_task = CompletionTask { - tx: tx.clone(), - rx: rx.clone(), - handles: nodes.clone(), + tx: test_sender.clone(), + rx: test_receiver.clone(), + handles: Arc::clone(&handles), duration: time_based.duration, }; @@ -212,7 +149,7 @@ where } let spinning_task_state = SpinningTask { - handles: nodes.clone(), + handles: Arc::clone(&handles), late_start, latest_view: None, changes, @@ -224,25 +161,24 @@ where ) .await, }; - let spinning_task = TestTask::, SpinningTask>::new( - Task::new(tx.clone(), rx.clone(), reg.clone(), spinning_task_state), + let spinning_task = TestTask::>::new( + spinning_task_state, event_rxs.clone(), + test_receiver.clone(), ); // add safety task let overall_safety_task_state = OverallSafetyTask { - handles: nodes.clone(), + handles: Arc::clone(&handles), ctx: RoundCtx::default(), properties: self.launcher.metadata.overall_safety_properties, + error: None, + test_sender, }; - let safety_task = TestTask::, 
OverallSafetyTask>::new( - Task::new( - tx.clone(), - rx.clone(), - reg.clone(), - overall_safety_task_state, - ), + let safety_task = TestTask::>::new( + overall_safety_task_state, event_rxs.clone(), + test_receiver.clone(), ); // add view sync task @@ -252,47 +188,55 @@ where _pd: PhantomData, }; - let view_sync_task = TestTask::, ViewSyncTask>::new( - Task::new(tx.clone(), rx.clone(), reg.clone(), view_sync_task_state), + let view_sync_task = TestTask::>::new( + view_sync_task_state, internal_event_rxs, + test_receiver.clone(), ); + let nodes = handles.read().await; + // wait for networks to be ready - for node in &nodes { + for node in &*nodes { node.networks.0.wait_for_ready().await; node.networks.1.wait_for_ready().await; } // Start hotshot - for node in nodes { + for node in &*nodes { if !late_start_nodes.contains(&node.node_id) { node.handle.hotshot.start_consensus().await; } } + + drop(nodes); + task_futs.push(safety_task.run()); task_futs.push(view_sync_task.run()); - if let Some(txn) = txn_task { - task_futs.push(txn.run()); - } - task_futs.push(completion_task.run()); task_futs.push(spinning_task.run()); + + // `generator` tasks that do not process events. + let txn_handle = txn_task.map(|txn| txn.run()); + let completion_handle = completion_task.run(); + let mut error_list = vec![]; #[cfg(async_executor_impl = "async-std")] { let results = join_all(task_futs).await; - tracing::info!("test tasks joined"); + tracing::error!("test tasks joined"); for result in results { match result { - HotShotTaskCompleted::ShutDown => { + TestResult::Pass => { info!("Task shut down successfully"); } - HotShotTaskCompleted::Error(e) => error_list.push(e), - _ => { - panic!("Future impl for task abstraction failed! This should never happen"); - } + TestResult::Fail(e) => error_list.push(e), } } + if let Some(handle) = txn_handle { + handle.cancel().await; + } + completion_handle.cancel().await; } #[cfg(async_executor_impl = "tokio")] @@ -302,28 +246,34 @@ where tracing::error!("test tasks joined"); for result in results { match result { - Ok(res) => { - match res { - HotShotTaskCompleted::ShutDown => { - info!("Task shut down successfully"); - } - HotShotTaskCompleted::Error(e) => error_list.push(e), - _ => { - panic!("Future impl for task abstraction failed! This should never happen"); - } + Ok(res) => match res { + TestResult::Pass => { + info!("Task shut down successfully"); } - } + TestResult::Fail(e) => error_list.push(e), + }, Err(e) => { tracing::error!("Error Joining the test task {:?}", e); } } } + + if let Some(handle) = txn_handle { + handle.abort(); + } + completion_handle.abort(); } assert!( error_list.is_empty(), "TEST FAILED! Results: {error_list:?}" ); + + let mut nodes = handles.write().await; + + for node in &mut *nodes { + node.handle.shut_down().await; + } } /// Add nodes. 
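The hunks above replace each task's `nodes.clone()` with a single `Arc<RwLock<Vec<Node>>>` shared by the txn, completion, spinning, and safety tasks, so the spinning task can push late-start nodes that every other task then observes. A minimal sketch of that sharing pattern follows; this is not HotShot code: `Node` is pared down to an id, and a tokio runtime is assumed for brevity.

use std::sync::Arc;

use async_lock::RwLock;

struct Node {
    node_id: u64,
}

#[tokio::main]
async fn main() {
    // One shared, mutable node list, as `handles` now is in `run_test`.
    let handles: Arc<RwLock<Vec<Node>>> = Arc::new(RwLock::new(vec![Node { node_id: 0 }]));

    // A spinning-task-style writer adds a late-start node...
    let writer = Arc::clone(&handles);
    tokio::spawn(async move {
        writer.write().await.push(Node { node_id: 1 });
    })
    .await
    .unwrap();

    // ...and a txn-task-style reader sees the updated list immediately.
    assert_eq!(handles.read().await.last().unwrap().node_id, 1);
}

An `async_lock::RwLock` (rather than `std::sync::RwLock`) keeps lock acquisition non-blocking on the executor, which matters here because the guards are taken inside long-lived task loops.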
@@ -379,7 +329,6 @@ where let networks = (self.launcher.resource_generator.channel_generator)(node_id).await; let storage = (self.launcher.resource_generator.storage)(node_id); - // Create a future that waits for the networks to be ready let network0 = networks.0.clone(); let network1 = networks.1.clone(); let networks_ready_future = async move { @@ -387,7 +336,6 @@ network0.wait_for_ready().await; network1.wait_for_ready().await; }; - // Collect it so we can wait for all networks to be ready before starting the tasks networks_ready.push(networks_ready_future); if self.launcher.metadata.skip_late && late_start.contains(&node_id) { @@ -494,3 +442,51 @@ where .expect("Could not init hotshot") } } + +/// A node participating in a test +pub struct Node> { + /// The node's unique identifier + pub node_id: u64, + /// The underlying networks belonging to the node + pub networks: Networks, + /// The handle to the node's internals + pub handle: SystemContextHandle, +} + +/// Either the node context or the parameters to construct the context for nodes that start late. +pub type LateNodeContext = Either< + Arc>, + ( + >::Storage, + Memberships, + HotShotConfig<::SignatureKey>, + ), +>; + +/// A yet-to-be-started node that participates in tests +pub struct LateStartNode> { + /// The underlying networks belonging to the node + pub networks: Networks, + /// Either the context we will use to launch HotShot for an initialized node when it's + /// time, or the parameters that will be used to initialize the node and launch HotShot. + pub context: LateNodeContext, +} + +/// The runner of a test network: +/// spins nodes up and down and executes rounds +pub struct TestRunner< + TYPES: NodeType, + I: TestableNodeImplementation, + N: ConnectedNetwork, TYPES::SignatureKey>, +> { + /// test launcher, contains a bunch of useful metadata and closures + pub(crate) launcher: TestLauncher, + /// nodes in the test + pub(crate) nodes: Vec>, + /// nodes with a late start + pub(crate) late_start: HashMap>, + /// the next node unique identifier + pub(crate) next_node_id: u64, + /// Phantom for N + pub(crate) _pd: PhantomData, +} diff --git a/crates/testing/src/test_task.rs b/crates/testing/src/test_task.rs new file mode 100644 index 0000000000..8346eec0cf --- /dev/null +++ b/crates/testing/src/test_task.rs @@ -0,0 +1,137 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use async_broadcast::{Receiver, Sender}; +use async_compatibility_layer::art::{async_sleep, async_spawn, async_timeout}; +#[cfg(async_executor_impl = "async-std")] +use async_std::task::{spawn, JoinHandle}; +use async_trait::async_trait; +use futures::future::select_all; +use hotshot_task_impls::{events::HotShotEvent, network::NetworkMessageTaskState}; +use hotshot_types::{ + message::{Message, Messages}, + traits::{network::ConnectedNetwork, node_implementation::NodeType}, +}; +#[cfg(async_executor_impl = "tokio")] +use tokio::task::{spawn, JoinHandle}; +use tracing::error; + +/// enum describing how a test task completed +pub enum TestResult { + /// the test task passed + Pass, + /// the test task failed with an error + Fail(Box), +} + +#[async_trait] +/// Type for mutable task state that can be used as the state for a `TestTask` +pub trait TestTaskState: Send { + /// Type of event sent and received by the task + type Event: Clone + Send + Sync; + + /// Handles an event from one of multiple receivers. + async fn handle_event(&mut self, (event, id): (Self::Event, usize)) -> Result<()>; + + /// Check the result of the test.
+ fn check(&self) -> TestResult; } +/// A basic task which loops waiting for events to come from its `receivers` +/// and then handles them using its state. +/// It watches `test_receiver` for the shutdown signal from the test harness. +/// This should be used as the primary building block for long running +/// or medium running test tasks (i.e. anything that can't be described as a dependency task) +pub struct TestTask { + /// The state of the task. It is fed events from the receivers + /// and mutates its state accordingly. Its `check` method reports + /// whether the test passed once the task shuts down + state: S, + /// Receives events that are broadcast from any task, including itself + receivers: Vec>, + /// Receiver for test events, used for communication between test tasks. + test_receiver: Receiver, +} + +#[derive(Clone, Debug)] +pub enum TestEvent { + Shutdown, +} + +impl TestTask { + /// Create a new task + pub fn new( + state: S, + receivers: Vec>, + test_receiver: Receiver, + ) -> Self { + TestTask { + state, + receivers, + test_receiver, + } + } + + /// Spawn the task loop, consuming self. Will continue until + /// the task reaches some shutdown condition + pub fn run(mut self) -> JoinHandle { + spawn(async move { + loop { + if let Ok(TestEvent::Shutdown) = self.test_receiver.try_recv() { + break self.state.check(); + } + + let mut messages = Vec::new(); + + for receiver in &mut self.receivers { + messages.push(receiver.recv()); + } + + if let Ok((Ok(input), id, _)) = + async_timeout(Duration::from_millis(50), select_all(messages)).await + { + let _ = S::handle_event(&mut self.state, (input, id)) + .await + .inspect_err(|e| tracing::error!("{e}")); + } + } + }) + } +} + +/// Add the network task to handle messages and publish events. +pub async fn add_network_message_test_task< + TYPES: NodeType, + NET: ConnectedNetwork, TYPES::SignatureKey>, +>( + event_stream: Sender>>, + channel: Arc, +) -> JoinHandle<()> { + let net = Arc::clone(&channel); + let network_state: NetworkMessageTaskState<_> = NetworkMessageTaskState { + event_stream: event_stream.clone(), + }; + + let network = Arc::clone(&net); + let mut state = network_state.clone(); + + async_spawn(async move { + loop { + let msgs = match network.recv_msgs().await { + Ok(msgs) => Messages(msgs), + Err(err) => { + error!("failed to receive messages: {err}"); + + // return zero messages so we sleep and try again + Messages(vec![]) + } + }; + if msgs.0.is_empty() { + // TODO: Stop sleeping here: https://github.com/EspressoSystems/HotShot/issues/2558 + async_sleep(Duration::from_millis(100)).await; + } else { + state.handle_messages(msgs.0).await; + } + } + }) +} diff --git a/crates/testing/src/txn_task.rs b/crates/testing/src/txn_task.rs index dcf72dcfb5..2dfabc0446 100644 --- a/crates/testing/src/txn_task.rs +++ b/crates/testing/src/txn_task.rs @@ -1,7 +1,8 @@ -use std::time::Duration; +use std::{sync::Arc, time::Duration}; -use async_broadcast::{Receiver, TryRecvError}; +use async_broadcast::Receiver; use async_compatibility_layer::art::{async_sleep, async_spawn}; +use async_lock::RwLock; #[cfg(async_executor_impl = "async-std")] use async_std::task::JoinHandle; use hotshot::traits::TestableNodeImplementation; @@ -11,8 +12,7 @@ use snafu::Snafu; #[cfg(async_executor_impl = "tokio")] use tokio::task::JoinHandle; -use super::GlobalTestEvent; -use crate::test_runner::{HotShotTaskCompleted, Node}; +use crate::{test_runner::Node, test_task::TestEvent}; // the obvious idea here is to pass in a "stream" that completes every `n` seconds // the stream construction can
definitely be fancier but that's the baseline idea @@ -25,29 +25,23 @@ pub struct TxnTaskErr {} pub struct TxnTask> { // TODO should this be in a rwlock? Or maybe a similar abstraction to the registry is in order /// Handles for all nodes. - pub handles: Vec>, + pub handles: Arc>>>, /// Optional index of the next node. pub next_node_idx: Option, /// time to wait between txns pub duration: Duration, /// Receiver for the shutdown signal from the testing harness - pub shutdown_chan: Receiver, + pub shutdown_chan: Receiver, } impl> TxnTask { - pub fn run(mut self) -> JoinHandle { + pub fn run(mut self) -> JoinHandle<()> { async_spawn(async move { async_sleep(Duration::from_millis(100)).await; loop { async_sleep(self.duration).await; - match self.shutdown_chan.try_recv() { - Ok(_event) => { - return HotShotTaskCompleted::ShutDown; - } - Err(TryRecvError::Empty) => {} - Err(_) => { - return HotShotTaskCompleted::StreamsDied; - } + if let Ok(TestEvent::Shutdown) = self.shutdown_chan.try_recv() { + break; } self.submit_tx().await; } @@ -55,10 +49,11 @@ impl> TxnTask { } async fn submit_tx(&mut self) { if let Some(idx) = self.next_node_idx { + let handles = &self.handles.read().await; // submit to idx handle // increment state - self.next_node_idx = Some((idx + 1) % self.handles.len()); - match self.handles.get(idx) { + self.next_node_idx = Some((idx + 1) % handles.len()); + match handles.get(idx) { None => { tracing::error!("couldn't get node in txn task"); // should do error diff --git a/crates/testing/src/view_sync_task.rs b/crates/testing/src/view_sync_task.rs index 648d840dc3..9a885beec8 100644 --- a/crates/testing/src/view_sync_task.rs +++ b/crates/testing/src/view_sync_task.rs @@ -1,11 +1,12 @@ use std::{collections::HashSet, marker::PhantomData, sync::Arc}; -use hotshot_task::task::{Task, TaskState, TestTaskState}; +use anyhow::Result; +use async_trait::async_trait; use hotshot_task_impls::events::HotShotEvent; use hotshot_types::traits::node_implementation::{NodeType, TestableNodeImplementation}; use snafu::Snafu; -use crate::{test_runner::HotShotTaskCompleted, GlobalTestEvent}; +use crate::test_task::{TestResult, TestTaskState}; /// `ViewSync` Task error #[derive(Snafu, Debug, Clone)] @@ -24,49 +25,15 @@ pub struct ViewSyncTask> { pub(crate) _pd: PhantomData<(TYPES, I)>, } -impl> TaskState for ViewSyncTask { - type Event = GlobalTestEvent; - - type Output = HotShotTaskCompleted; - - async fn handle_event(event: Self::Event, task: &mut Task) -> Option { - let state = task.state_mut(); - match event { - GlobalTestEvent::ShutDown => match state.description.clone() { - ViewSyncTaskDescription::Threshold(min, max) => { - let num_hits = state.hit_view_sync.len(); - if min <= num_hits && num_hits <= max { - Some(HotShotTaskCompleted::ShutDown) - } else { - Some(HotShotTaskCompleted::Error(Box::new(ViewSyncTaskErr { - hit_view_sync: state.hit_view_sync.clone(), - }))) - } - } - }, - } - } - - fn should_shutdown(_event: &Self::Event) -> bool { - false - } -} - +#[async_trait] impl> TestTaskState for ViewSyncTask { - type Message = Arc>; + type Event = Arc>; - type Output = HotShotTaskCompleted; - - type State = Self; - - async fn handle_message( - message: Self::Message, - id: usize, - task: &mut hotshot_task::task::TestTask, - ) -> Option { - match message.as_ref() { + /// Handles an event from one of multiple receivers. 
+ async fn handle_event(&mut self, (event, id): (Self::Event, usize)) -> Result<()> { + match event.as_ref() { // all the view sync events HotShotEvent::ViewSyncTimeout(_, _, _) | HotShotEvent::ViewSyncPreCommitVoteRecv(_) @@ -82,11 +49,27 @@ impl> TestTaskState | HotShotEvent::ViewSyncCommitCertificate2Send(_, _) | HotShotEvent::ViewSyncFinalizeCertificate2Send(_, _) | HotShotEvent::ViewSyncTrigger(_) => { - task.state_mut().hit_view_sync.insert(id); + self.hit_view_sync.insert(id); } _ => (), } - None + + Ok(()) + } + + fn check(&self) -> TestResult { + match self.description.clone() { + ViewSyncTaskDescription::Threshold(min, max) => { + let num_hits = self.hit_view_sync.len(); + if min <= num_hits && num_hits <= max { + TestResult::Pass + } else { + TestResult::Fail(Box::new(ViewSyncTaskErr { + hit_view_sync: self.hit_view_sync.clone(), + })) + } + } + } } } diff --git a/crates/testing/tests/tests_1/consensus_task.rs b/crates/testing/tests/tests_1/consensus_task.rs index 73d9279bf7..34c86e38d5 100644 --- a/crates/testing/tests/tests_1/consensus_task.rs +++ b/crates/testing/tests/tests_1/consensus_task.rs @@ -11,13 +11,14 @@ use hotshot_example_types::{ }; use hotshot_task_impls::{consensus::ConsensusTaskState, events::HotShotEvent::*}; use hotshot_testing::{ + helpers::{ + build_system_handle, key_pair_for_id, permute_input_with_index_order, + vid_scheme_from_view_number, vid_share, + }, predicates::event::{ exact, quorum_proposal_send, quorum_proposal_validated, quorum_vote_send, timeout_vote_send, }, script::{run_test_script, TestScriptStage}, - helpers::{ - build_system_handle, vid_share, key_pair_for_id, vid_scheme_from_view_number, permute_input_with_index_order - }, view_generator::TestViewGenerator, }; use hotshot_types::{ @@ -121,8 +122,8 @@ async fn test_consensus_vote() { use hotshot::tasks::task_state::CreateTaskState; use hotshot_task_impls::{consensus::ConsensusTaskState, events::HotShotEvent::*}; use hotshot_testing::{ - script::{run_test_script, TestScriptStage}, helpers::build_system_handle, + script::{run_test_script, TestScriptStage}, view_generator::TestViewGenerator, }; diff --git a/crates/testing/tests/tests_1/da_task.rs b/crates/testing/tests/tests_1/da_task.rs index 178a8c187d..992e3fd2b9 100644 --- a/crates/testing/tests/tests_1/da_task.rs +++ b/crates/testing/tests/tests_1/da_task.rs @@ -1,16 +1,16 @@ use std::sync::Arc; use futures::StreamExt; -use hotshot::{tasks::task_state::CreateTaskState, types::SystemContextHandle}; +use hotshot::tasks::task_state::CreateTaskState; use hotshot_example_types::{ block_types::{TestMetadata, TestTransaction}, node_types::{MemoryImpl, TestTypes}, }; use hotshot_task_impls::{da::DaTaskState, events::HotShotEvent::*}; use hotshot_testing::{ + helpers::build_system_handle, predicates::event::exact, script::{run_test_script, TestScriptStage}, - helpers::build_system_handle, view_generator::TestViewGenerator, }; use hotshot_types::{ @@ -94,7 +94,7 @@ async fn test_da_task() { asserts: vec![], }; - let da_state = DaTaskState::>::create_from(&handle).await; + let da_state = DaTaskState::::create_from(&handle).await; let stages = vec![view_1, view_2]; run_test_script(stages, da_state).await; @@ -181,7 +181,7 @@ async fn test_da_task_storage_failure() { asserts: vec![], }; - let da_state = DaTaskState::>::create_from(&handle).await; + let da_state = DaTaskState::::create_from(&handle).await; let stages = vec![view_1, view_2, view_3]; run_test_script(stages, da_state).await; diff --git a/crates/testing/tests/tests_1/network_task.rs 
b/crates/testing/tests/tests_1/network_task.rs index ef3ce59143..ccb591b047 100644 --- a/crates/testing/tests/tests_1/network_task.rs +++ b/crates/testing/tests/tests_1/network_task.rs @@ -2,14 +2,17 @@ use std::{sync::Arc, time::Duration}; use async_compatibility_layer::art::async_timeout; use async_lock::RwLock; -use hotshot::{tasks::add_network_message_task, traits::implementations::MemoryNetwork}; +use hotshot::traits::implementations::MemoryNetwork; use hotshot_example_types::node_types::{MemoryImpl, TestTypes}; -use hotshot_task::task::{Task, TaskRegistry}; +use hotshot_task::task::{ConsensusTaskRegistry, Task}; use hotshot_task_impls::{ events::HotShotEvent, network::{self, NetworkEventTaskState}, }; -use hotshot_testing::{test_builder::TestDescription, view_generator::TestViewGenerator}; +use hotshot_testing::{ + test_builder::TestDescription, test_task::add_network_message_test_task, + view_generator::TestViewGenerator, +}; use hotshot_types::{ constants::BASE_VERSION, data::ViewNumber, @@ -59,16 +62,16 @@ async fn test_network_task() { storage, }; let (tx, rx) = async_broadcast::broadcast(10); - let task_reg = Arc::new(TaskRegistry::default()); + let mut task_reg = ConsensusTaskRegistry::new(); - let task = Task::new(tx.clone(), rx, task_reg.clone(), network_state); - task_reg.run_task(task).await; + let task = Task::new(network_state, tx.clone(), rx); + task_reg.run_task(task); let mut generator = TestViewGenerator::generate(membership.clone(), membership); let view = generator.next().await.unwrap(); let (out_tx, mut out_rx) = async_broadcast::broadcast(10); - add_network_message_task(task_reg, out_tx.clone(), channel.clone()).await; + add_network_message_test_task(out_tx.clone(), channel.clone()).await; tx.broadcast_direct(Arc::new(HotShotEvent::QuorumProposalSend( view.quorum_proposal, @@ -124,16 +127,16 @@ async fn test_network_storage_fail() { storage, }; let (tx, rx) = async_broadcast::broadcast(10); - let task_reg = Arc::new(TaskRegistry::default()); + let mut task_reg = ConsensusTaskRegistry::new(); - let task = Task::new(tx.clone(), rx, task_reg.clone(), network_state); - task_reg.run_task(task).await; + let task = Task::new(network_state, tx.clone(), rx); + task_reg.run_task(task); let mut generator = TestViewGenerator::generate(membership.clone(), membership); let view = generator.next().await.unwrap(); let (out_tx, mut out_rx) = async_broadcast::broadcast(10); - add_network_message_task(task_reg, out_tx.clone(), channel.clone()).await; + add_network_message_test_task(out_tx.clone(), channel.clone()).await; tx.broadcast_direct(Arc::new(HotShotEvent::QuorumProposalSend( view.quorum_proposal, diff --git a/crates/testing/tests/tests_1/proposal_ordering.rs b/crates/testing/tests/tests_1/proposal_ordering.rs index 4ca6f36225..3faf7b470d 100644 --- a/crates/testing/tests/tests_1/proposal_ordering.rs +++ b/crates/testing/tests/tests_1/proposal_ordering.rs @@ -10,8 +10,8 @@ use hotshot_example_types::{ }; use hotshot_task_impls::{consensus::ConsensusTaskState, events::HotShotEvent::*}; use hotshot_testing::{ + helpers::{permute_input_with_index_order, vid_scheme_from_view_number, vid_share}, predicates::event::{all_predicates, exact, quorum_proposal_send, quorum_proposal_validated}, - helpers::{vid_share, vid_scheme_from_view_number, permute_input_with_index_order}, view_generator::TestViewGenerator, }; use hotshot_types::{ @@ -29,8 +29,8 @@ async fn test_ordering_with_specific_order(input_permutation: Vec) { use futures::StreamExt; use 
hotshot_example_types::state_types::TestValidatedState; use hotshot_testing::{ - script::{run_test_script, TestScriptStage}, helpers::build_system_handle, + script::{run_test_script, TestScriptStage}, }; async_compatibility_layer::logging::setup_logging(); diff --git a/crates/testing/tests/tests_1/quorum_proposal_task.rs b/crates/testing/tests/tests_1/quorum_proposal_task.rs index 19fecf5e0d..9b8a37a529 100644 --- a/crates/testing/tests/tests_1/quorum_proposal_task.rs +++ b/crates/testing/tests/tests_1/quorum_proposal_task.rs @@ -1,5 +1,7 @@ #![cfg(feature = "dependency-tasks")] +use std::sync::Arc; + use committable::Committable; use hotshot::tasks::task_state::CreateTaskState; use hotshot_example_types::{ @@ -38,7 +40,6 @@ use hotshot_types::{ }; use jf_vid::VidScheme; use sha2::Digest; -use std::sync::Arc; fn make_payload_commitment( membership: &::Membership, diff --git a/crates/testing/tests/tests_1/upgrade_task.rs b/crates/testing/tests/tests_1/upgrade_task.rs index 94b88b0177..35d7094bcc 100644 --- a/crates/testing/tests/tests_1/upgrade_task.rs +++ b/crates/testing/tests/tests_1/upgrade_task.rs @@ -15,9 +15,9 @@ use hotshot_task_impls::{ consensus::ConsensusTaskState, events::HotShotEvent::*, upgrade::UpgradeTaskState, }; use hotshot_testing::{ + helpers::vid_share, predicates::{event::*, upgrade::*}, script::{Expectations, TaskScript}, - helpers::vid_share, view_generator::TestViewGenerator, }; use hotshot_types::{ @@ -33,8 +33,8 @@ use vbs::version::Version; /// Tests that we correctly update our internal consensus state when reaching a decided upgrade certificate. async fn test_consensus_task_upgrade() { use hotshot_testing::{ - script::{run_test_script, TestScriptStage}, helpers::build_system_handle, + script::{run_test_script, TestScriptStage}, }; async_compatibility_layer::logging::setup_logging(); @@ -118,9 +118,11 @@ async fn test_consensus_task_upgrade() { ], outputs: vec![ exact(ViewChange(ViewNumber::new(3))), - quorum_proposal_validated(), - leaf_decided(), - exact(QuorumVoteSend(votes[2].clone())), + all_predicates(vec![ + quorum_proposal_validated(), + leaf_decided(), + exact(QuorumVoteSend(votes[2].clone())), + ]), ], asserts: vec![no_decided_upgrade_cert()], }; @@ -133,9 +135,11 @@ async fn test_consensus_task_upgrade() { ], outputs: vec![ exact(ViewChange(ViewNumber::new(4))), - quorum_proposal_validated(), - leaf_decided(), - exact(QuorumVoteSend(votes[3].clone())), + all_predicates(vec![ + quorum_proposal_validated(), + leaf_decided(), + exact(QuorumVoteSend(votes[3].clone())), + ]), ], asserts: vec![no_decided_upgrade_cert()], }; @@ -144,9 +148,7 @@ async fn test_consensus_task_upgrade() { inputs: vec![QuorumProposalRecv(proposals[4].clone(), leaders[4])], outputs: vec![ exact(ViewChange(ViewNumber::new(5))), - quorum_proposal_validated(), - upgrade_decided(), - leaf_decided(), + all_predicates(vec![quorum_proposal_validated(), upgrade_decided(), leaf_decided()]), ], asserts: vec![decided_upgrade_cert()], }; @@ -225,12 +227,7 @@ async fn test_upgrade_and_consensus_task() { .map(|h| views[2].create_upgrade_vote(upgrade_data.clone(), &h.0)); let consensus_state = ConsensusTaskState::::create_from(&handle).await; - let mut upgrade_state = UpgradeTaskState::< - TestTypes, - MemoryImpl, - SystemContextHandle, - >::create_from(&handle) - .await; + let mut upgrade_state = UpgradeTaskState::::create_from(&handle).await; upgrade_state.should_vote = |_| true; @@ -261,7 +258,8 @@ async fn test_upgrade_and_consensus_task() { ], ]; - let consensus_script = TaskScript { + let 
mut consensus_script = TaskScript { + timeout: Duration::from_millis(35), state: consensus_state, expectations: vec![ Expectations { @@ -291,7 +289,8 @@ async fn test_upgrade_and_consensus_task() { ], }; - let upgrade_script = TaskScript { + let mut upgrade_script = TaskScript { + timeout: Duration::from_millis(35), state: upgrade_state, expectations: vec![ Expectations { @@ -313,7 +312,7 @@ async fn test_upgrade_and_consensus_task() { ], }; - test_scripts![inputs, consensus_script, upgrade_script]; + test_scripts![inputs, consensus_script, upgrade_script].await; } #[cfg(not(feature = "dependency-tasks"))] @@ -418,12 +417,7 @@ async fn test_upgrade_and_consensus_task_blank_blocks() { } let consensus_state = ConsensusTaskState::::create_from(&handle).await; - let mut upgrade_state = UpgradeTaskState::< - TestTypes, - MemoryImpl, - SystemContextHandle, - >::create_from(&handle) - .await; + let mut upgrade_state = UpgradeTaskState::::create_from(&handle).await; upgrade_state.should_vote = |_| true; @@ -507,7 +501,8 @@ async fn test_upgrade_and_consensus_task_blank_blocks() { ], ]; - let consensus_script = TaskScript { + let mut consensus_script = TaskScript { + timeout: Duration::from_millis(35), state: consensus_state, expectations: vec![ Expectations { @@ -571,7 +566,8 @@ async fn test_upgrade_and_consensus_task_blank_blocks() { ], }; - let upgrade_script = TaskScript { + let mut upgrade_script = TaskScript { + timeout: Duration::from_millis(35), state: upgrade_state, expectations: vec![ Expectations { @@ -605,5 +601,5 @@ async fn test_upgrade_and_consensus_task_blank_blocks() { ], }; - test_scripts![inputs, consensus_script, upgrade_script]; + test_scripts![inputs, consensus_script, upgrade_script].await; } diff --git a/crates/testing/tests/tests_1/vid_task.rs b/crates/testing/tests/tests_1/vid_task.rs index e78418c045..0375385089 100644 --- a/crates/testing/tests/tests_1/vid_task.rs +++ b/crates/testing/tests/tests_1/vid_task.rs @@ -8,9 +8,9 @@ use hotshot_example_types::{ }; use hotshot_task_impls::{events::HotShotEvent::*, vid::VidTaskState}; use hotshot_testing::{ + helpers::{build_system_handle, vid_scheme_from_view_number}, predicates::event::exact, script::{run_test_script, TestScriptStage}, - helpers::{build_system_handle, vid_scheme_from_view_number}, }; use hotshot_types::{ data::{null_block, DaProposal, VidDisperse, ViewNumber}, diff --git a/crates/testing/tests/tests_1/view_sync_task.rs b/crates/testing/tests/tests_1/view_sync_task.rs index 811f74b675..72c2cdfbea 100644 --- a/crates/testing/tests/tests_1/view_sync_task.rs +++ b/crates/testing/tests/tests_1/view_sync_task.rs @@ -1,4 +1,4 @@ -use hotshot::{tasks::task_state::CreateTaskState, types::SystemContextHandle}; +use hotshot::tasks::task_state::CreateTaskState; use hotshot_example_types::node_types::{MemoryImpl, TestTypes}; use hotshot_task_impls::{ events::HotShotEvent, harness::run_harness, view_sync::ViewSyncTaskState, @@ -44,11 +44,6 @@ async fn test_view_sync_task() { output.push(HotShotEvent::ViewChange(ViewNumber::new(2))); output.push(HotShotEvent::ViewSyncPreCommitVoteSend(vote.clone())); - let view_sync_state = ViewSyncTaskState::< - TestTypes, - MemoryImpl, - SystemContextHandle, - >::create_from(&handle) - .await; + let view_sync_state = ViewSyncTaskState::::create_from(&handle).await; run_harness(input, output, view_sync_state, false).await; } diff --git a/crates/testing/tests/tests_2/catchup.rs b/crates/testing/tests/tests_2/catchup.rs index 9f2a250a41..a11fb4a0e9 100644 --- 
a/crates/testing/tests/tests_2/catchup.rs +++ b/crates/testing/tests/tests_2/catchup.rs @@ -12,6 +12,7 @@ async fn test_catchup() { spinning_task::{ChangeNode, SpinningTaskDescription, UpDown}, test_builder::{TestDescription, TimingData}, }; + async_compatibility_layer::logging::setup_logging(); async_compatibility_layer::logging::setup_backtrace(); let timing_data = TimingData { diff --git a/crates/types/src/traits/network.rs b/crates/types/src/traits/network.rs index bfcf2d3de0..01248b55d0 100644 --- a/crates/types/src/traits/network.rs +++ b/crates/types/src/traits/network.rs @@ -245,7 +245,6 @@ pub trait ConnectedNetwork: async fn wait_for_ready(&self); /// Blocks until the network is shut down - /// then returns true fn shut_down<'a, 'b>(&'a self) -> BoxSyncFuture<'b, ()> where 'a: 'b,
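With the old `TaskState`/`TestTaskState` pair collapsed into the single two-method `TestTaskState` trait, a test task now just mutates itself in `handle_event` and reports a `TestResult` from `check` at shutdown, as `SpinningTask` and `ViewSyncTask` do above. Below is a minimal sketch of a custom task against the new trait; the names `CountingTask` and `CountingTaskErr` are hypothetical, the `u64` event type stands in for something like `Event<TYPES>`, and `TestResult::Fail` is assumed to box any `std::error::Error + Send + Sync`.

use anyhow::Result;
use async_trait::async_trait;
use snafu::Snafu;

use hotshot_testing::test_task::{TestResult, TestTaskState};

/// Hypothetical test task: checks that every node emitted at least one event.
struct CountingTask {
    /// Events seen per receiver; sized to the number of event receivers.
    counts: Vec<usize>,
}

/// Error returned when some node stayed silent.
#[derive(Snafu, Debug)]
#[snafu(display("a node emitted no events"))]
struct CountingTaskErr {}

#[async_trait]
impl TestTaskState for CountingTask {
    // Placeholder event type; real tasks use e.g. `Event<TYPES>`.
    type Event = u64;

    async fn handle_event(&mut self, (_event, id): (Self::Event, usize)) -> Result<()> {
        // `id` is the index of the receiver the event arrived on.
        self.counts[id] += 1;
        Ok(())
    }

    fn check(&self) -> TestResult {
        if self.counts.iter().all(|&n| n > 0) {
            TestResult::Pass
        } else {
            TestResult::Fail(Box::new(CountingTaskErr {}))
        }
    }
}

Such a task is wired up the same way as the tasks in `run_test` above: `TestTask::new(state, event_rxs.clone(), test_receiver.clone()).run()` yields a `JoinHandle<TestResult>` that is joined alongside the safety, view sync, and spinning tasks.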