diff --git a/examples/historical_scanning/main.rs b/examples/historical_scanning/main.rs index 2bfcdfb..b7224bf 100644 --- a/examples/historical_scanning/main.rs +++ b/examples/historical_scanning/main.rs @@ -69,17 +69,17 @@ async fn main() -> anyhow::Result<()> { while let Some(message) = stream.next().await { match message { - Message::Data(logs) => { + Ok(Message::Data(logs)) => { for log in logs { info!("Callback successfully executed with event {:?}", log.inner.data); } } - Message::Error(e) => { - error!("Received error: {}", e); - } - Message::Notification(info) => { + Ok(Message::Notification(info)) => { info!("Received info: {:?}", info); } + Err(e) => { + error!("Received error: {}", e); + } } } diff --git a/examples/latest_events_scanning/main.rs b/examples/latest_events_scanning/main.rs index 3dfbdaa..2c3d3b7 100644 --- a/examples/latest_events_scanning/main.rs +++ b/examples/latest_events_scanning/main.rs @@ -70,17 +70,17 @@ async fn main() -> anyhow::Result<()> { while let Some(message) = stream.next().await { match message { - Message::Data(logs) => { + Ok(Message::Data(logs)) => { for log in logs { info!("Received event: {:?}", log.inner.data); } } - Message::Error(e) => { - error!("Received error: {}", e); - } - Message::Notification(info) => { + Ok(Message::Notification(info)) => { info!("Received notification: {:?}", info); } + Err(e) => { + error!("Received error: {}", e); + } } } diff --git a/examples/live_scanning/main.rs b/examples/live_scanning/main.rs index 5864b48..7eb8d9c 100644 --- a/examples/live_scanning/main.rs +++ b/examples/live_scanning/main.rs @@ -69,17 +69,17 @@ async fn main() -> anyhow::Result<()> { while let Some(message) = stream.next().await { match message { - Message::Data(logs) => { + Ok(Message::Data(logs)) => { for log in logs { info!("Callback successfully executed with event {:?}", log.inner.data); } } - Message::Error(e) => { - error!("Received error: {}", e); - } - Message::Notification(info) => { + Ok(Message::Notification(info)) => { info!("Received info: {:?}", info); } + Err(e) => { + error!("Received error: {}", e); + } } } diff --git a/examples/sync_from_block_scanning/main.rs b/examples/sync_from_block_scanning/main.rs index 38cf60f..48ac359 100644 --- a/examples/sync_from_block_scanning/main.rs +++ b/examples/sync_from_block_scanning/main.rs @@ -86,7 +86,7 @@ async fn main() -> anyhow::Result<()> { while let Some(message) = stream.next().await { match message { - Message::Data(logs) => { + Ok(Message::Data(logs)) => { for log in logs { let Counter::CountIncreased { newCount } = log.log_decode().unwrap().inner.data; if newCount <= 3 { @@ -98,12 +98,12 @@ async fn main() -> anyhow::Result<()> { } } } - Message::Error(e) => { - error!("Received error: {}", e); - } - Message::Notification(info) => { + Ok(Message::Notification(info)) => { info!("Received notification: {:?}", info); } + Err(e) => { + error!("Received error: {}", e); + } } if historical_processed && live_processed { diff --git a/examples/sync_from_latest_scanning/main.rs b/examples/sync_from_latest_scanning/main.rs index 476e84a..ed0216a 100644 --- a/examples/sync_from_latest_scanning/main.rs +++ b/examples/sync_from_latest_scanning/main.rs @@ -77,17 +77,17 @@ async fn main() -> anyhow::Result<()> { // only the last 5 events will be streamed before switching to live mode while let Some(message) = stream.next().await { match message { - Message::Data(logs) => { + Ok(Message::Data(logs)) => { for log in logs { info!("Callback successfully executed with event {:?}", log.inner.data); } } - Message::Error(e) => { - error!("Received error: {}", e); - } - Message::Notification(info) => { + Ok(Message::Notification(info)) => { info!("Received info: {:?}", info); } + Err(e) => { + error!("Received error: {}", e); + } } } diff --git a/src/block_range_scanner.rs b/src/block_range_scanner.rs index 4f4e875..dab8054 100644 --- a/src/block_range_scanner.rs +++ b/src/block_range_scanner.rs @@ -7,10 +7,10 @@ //! //! use alloy::providers::{Provider, ProviderBuilder}; //! use event_scanner::{ -//! ScannerError, +//! ScannerError, ScannerMessage, //! block_range_scanner::{ //! BlockRangeScanner, BlockRangeScannerClient, DEFAULT_BLOCK_CONFIRMATIONS, -//! DEFAULT_MAX_BLOCK_RANGE, Message, +//! DEFAULT_MAX_BLOCK_RANGE, //! }, //! robust_provider::RobustProviderBuilder, //! }; @@ -35,10 +35,13 @@ //! //! while let Some(message) = stream.next().await { //! match message { -//! Message::Data(range) => { +//! Ok(ScannerMessage::Data(range)) => { //! // process range //! } -//! Message::Error(e) => { +//! Ok(ScannerMessage::Notification(notification)) => { +//! info!("Received notification: {:?}", notification); +//! } +//! Err(e) => { //! error!("Received error from subscription: {e}"); //! match e { //! ScannerError::ServiceShutdown => break, @@ -47,9 +50,6 @@ //! } //! } //! } -//! Message::Notification(notification) => { -//! info!("Received notification: {:?}", notification); -//! } //! } //! } //! @@ -67,18 +67,17 @@ use tokio::{ use tokio_stream::{StreamExt, wrappers::ReceiverStream}; use crate::{ - ScannerMessage, - error::ScannerError, + ScannerError, ScannerMessage, robust_provider::{Error as RobustProviderError, IntoRobustProvider, RobustProvider}, - types::{Notification, TryStream}, + types::{IntoScannerResult, Notification, ScannerResult, TryStream}, }; + use alloy::{ consensus::BlockHeader, eips::{BlockId, BlockNumberOrTag}, network::{BlockResponse, Network, primitives::HeaderResponse}, primitives::{B256, BlockNumber}, pubsub::Subscription, - transports::{RpcError, TransportErrorKind}, }; use tracing::{debug, error, info, warn}; @@ -92,11 +91,13 @@ pub const MAX_BUFFERED_MESSAGES: usize = 50000; // is considered final) pub const DEFAULT_REORG_REWIND_DEPTH: u64 = 64; -pub type Message = ScannerMessage, ScannerError>; +pub type BlockScannerResult = ScannerResult>; + +pub type Message = ScannerMessage>; impl From> for Message { - fn from(logs: RangeInclusive) -> Self { - Message::Data(logs) + fn from(range: RangeInclusive) -> Self { + Message::Data(range) } } @@ -106,21 +107,9 @@ impl PartialEq> for Message { } } -impl From for Message { - fn from(error: RobustProviderError) -> Self { - Message::Error(error.into()) - } -} - -impl From> for Message { - fn from(error: RpcError) -> Self { - Message::Error(error.into()) - } -} - -impl From for Message { - fn from(error: ScannerError) -> Self { - Message::Error(error) +impl IntoScannerResult> for RangeInclusive { + fn into_scanner_message_result(self) -> BlockScannerResult { + Ok(Message::Data(self)) } } @@ -190,24 +179,24 @@ impl ConnectedBlockRangeScanner { #[derive(Debug)] pub enum Command { StreamLive { - sender: mpsc::Sender, + sender: mpsc::Sender, block_confirmations: u64, response: oneshot::Sender>, }, StreamHistorical { - sender: mpsc::Sender, + sender: mpsc::Sender, start_id: BlockId, end_id: BlockId, response: oneshot::Sender>, }, StreamFrom { - sender: mpsc::Sender, + sender: mpsc::Sender, start_id: BlockId, block_confirmations: u64, response: oneshot::Sender>, }, Rewind { - sender: mpsc::Sender, + sender: mpsc::Sender, start_id: BlockId, end_id: BlockId, response: oneshot::Sender>, @@ -288,7 +277,7 @@ impl Service { async fn handle_live( &mut self, block_confirmations: u64, - sender: mpsc::Sender, + sender: mpsc::Sender, ) -> Result<(), ScannerError> { let max_block_range = self.max_block_range; let latest = self.provider.get_block_number().await?; @@ -320,7 +309,7 @@ impl Service { &mut self, start_id: BlockId, end_id: BlockId, - sender: mpsc::Sender, + sender: mpsc::Sender, ) -> Result<(), ScannerError> { let max_block_range = self.max_block_range; @@ -354,7 +343,7 @@ impl Service { &mut self, start_id: BlockId, block_confirmations: u64, - sender: mpsc::Sender, + sender: mpsc::Sender, ) -> Result<(), ScannerError> { let provider = self.provider.clone(); let max_block_range = self.max_block_range; @@ -407,7 +396,7 @@ impl Service { // Step 2: Setup the live streaming buffer // This channel will accumulate while historical sync is running let (live_block_buffer_sender, live_block_buffer_receiver) = - mpsc::channel::(MAX_BUFFERED_MESSAGES); + mpsc::channel::(MAX_BUFFERED_MESSAGES); // The cutoff is the last block we have synced historically // Any block > cutoff will come from the live stream @@ -457,7 +446,7 @@ impl Service { &mut self, start_id: BlockId, end_id: BlockId, - sender: mpsc::Sender, + sender: mpsc::Sender, ) -> Result<(), ScannerError> { let max_block_range = self.max_block_range; let provider = self.provider.clone(); @@ -489,7 +478,7 @@ impl Service { from: N::BlockResponse, to: N::BlockResponse, max_block_range: u64, - sender: &mpsc::Sender, + sender: &mpsc::Sender, provider: &RobustProvider, ) { let mut batch_count = 0; @@ -569,7 +558,7 @@ impl Service { start: BlockNumber, end: BlockNumber, max_block_range: u64, - sender: &mpsc::Sender, + sender: &mpsc::Sender, ) { let mut batch_count = 0; @@ -606,7 +595,7 @@ impl Service { async fn stream_live_blocks( mut range_start: BlockNumber, subscription: Subscription, - sender: mpsc::Sender, + sender: mpsc::Sender, block_confirmations: u64, max_block_range: u64, ) { @@ -650,8 +639,8 @@ impl Service { } async fn process_live_block_buffer( - mut buffer_rx: mpsc::Receiver, - sender: mpsc::Sender, + mut buffer_rx: mpsc::Receiver, + sender: mpsc::Sender, cutoff: BlockNumber, ) { let mut processed = 0; @@ -660,7 +649,7 @@ impl Service { // Process all buffered messages while let Some(data) = buffer_rx.recv().await { match data { - Message::Data(range) => { + Ok(Message::Data(range)) => { let (start, end) = (*range.start(), *range.end()); if start >= cutoff { if !sender.try_stream(range).await { @@ -730,7 +719,7 @@ impl BlockRangeScannerClient { pub async fn stream_live( &self, block_confirmations: u64, - ) -> Result, ScannerError> { + ) -> Result, ScannerError> { let (blocks_sender, blocks_receiver) = mpsc::channel(MAX_BUFFERED_MESSAGES); let (response_tx, response_rx) = oneshot::channel(); @@ -761,7 +750,7 @@ impl BlockRangeScannerClient { &self, start_id: impl Into, end_id: impl Into, - ) -> Result, ScannerError> { + ) -> Result, ScannerError> { let (blocks_sender, blocks_receiver) = mpsc::channel(MAX_BUFFERED_MESSAGES); let (response_tx, response_rx) = oneshot::channel(); @@ -793,7 +782,7 @@ impl BlockRangeScannerClient { &self, start_id: impl Into, block_confirmations: u64, - ) -> Result, ScannerError> { + ) -> Result, ScannerError> { let (blocks_sender, blocks_receiver) = mpsc::channel(MAX_BUFFERED_MESSAGES); let (response_tx, response_rx) = oneshot::channel(); @@ -825,7 +814,7 @@ impl BlockRangeScannerClient { &self, start_id: impl Into, end_id: impl Into, - ) -> Result, ScannerError> { + ) -> Result, ScannerError> { let (blocks_sender, blocks_receiver) = mpsc::channel(MAX_BUFFERED_MESSAGES); let (response_tx, response_rx) = oneshot::channel(); @@ -874,9 +863,9 @@ mod tests { async fn buffered_messages_after_cutoff_are_all_passed() { let cutoff = 50; let (buffer_tx, buffer_rx) = mpsc::channel(8); - buffer_tx.send(Message::Data(51..=55)).await.unwrap(); - buffer_tx.send(Message::Data(56..=60)).await.unwrap(); - buffer_tx.send(Message::Data(61..=70)).await.unwrap(); + buffer_tx.send(Ok(Message::Data(51..=55))).await.unwrap(); + buffer_tx.send(Ok(Message::Data(56..=60))).await.unwrap(); + buffer_tx.send(Ok(Message::Data(61..=70))).await.unwrap(); drop(buffer_tx); let (out_tx, out_rx) = mpsc::channel(8); @@ -895,9 +884,9 @@ mod tests { let cutoff = 100; let (buffer_tx, buffer_rx) = mpsc::channel(8); - buffer_tx.send(Message::Data(40..=50)).await.unwrap(); - buffer_tx.send(Message::Data(51..=60)).await.unwrap(); - buffer_tx.send(Message::Data(61..=70)).await.unwrap(); + buffer_tx.send(Ok(Message::Data(40..=50))).await.unwrap(); + buffer_tx.send(Ok(Message::Data(51..=60))).await.unwrap(); + buffer_tx.send(Ok(Message::Data(61..=70))).await.unwrap(); drop(buffer_tx); let (out_tx, out_rx) = mpsc::channel(8); @@ -913,9 +902,9 @@ mod tests { let cutoff = 75; let (buffer_tx, buffer_rx) = mpsc::channel(8); - buffer_tx.send(Message::Data(60..=70)).await.unwrap(); - buffer_tx.send(Message::Data(71..=80)).await.unwrap(); - buffer_tx.send(Message::Data(81..=86)).await.unwrap(); + buffer_tx.send(Ok(Message::Data(60..=70))).await.unwrap(); + buffer_tx.send(Ok(Message::Data(71..=80))).await.unwrap(); + buffer_tx.send(Ok(Message::Data(81..=86))).await.unwrap(); drop(buffer_tx); let (out_tx, out_rx) = mpsc::channel(8); @@ -933,11 +922,11 @@ mod tests { let cutoff = 100; let (buffer_tx, buffer_rx) = mpsc::channel(8); - buffer_tx.send(Message::Data(98..=98)).await.unwrap(); // Just before: discard - buffer_tx.send(Message::Data(99..=100)).await.unwrap(); // Includes cutoff: trim to 100..=100 - buffer_tx.send(Message::Data(100..=100)).await.unwrap(); // Exactly at: forward - buffer_tx.send(Message::Data(100..=101)).await.unwrap(); // Starts at cutoff: forward - buffer_tx.send(Message::Data(102..=102)).await.unwrap(); // After cutoff: forward + buffer_tx.send(Ok(Message::Data(98..=98))).await.unwrap(); // Just before: discard + buffer_tx.send(Ok(Message::Data(99..=100))).await.unwrap(); // Includes cutoff: trim to 100..=100 + buffer_tx.send(Ok(Message::Data(100..=100))).await.unwrap(); // Exactly at: forward + buffer_tx.send(Ok(Message::Data(100..=101))).await.unwrap(); // Starts at cutoff: forward + buffer_tx.send(Ok(Message::Data(102..=102))).await.unwrap(); // After cutoff: forward drop(buffer_tx); let (out_tx, out_rx) = mpsc::channel(8); @@ -954,15 +943,13 @@ mod tests { #[tokio::test] async fn try_send_forwards_errors_to_subscribers() { - let (tx, mut rx) = mpsc::channel::(1); + let (tx, mut rx) = mpsc::channel::(1); _ = tx.try_stream(ScannerError::BlockNotFound(4.into())).await; assert!(matches!( rx.recv().await, - Some(ScannerMessage::Error(ScannerError::BlockNotFound(BlockId::Number( - BlockNumberOrTag::Number(4) - )))) + Some(Err(ScannerError::BlockNotFound(BlockId::Number(BlockNumberOrTag::Number(4))))) )); } } diff --git a/src/error.rs b/src/error.rs index c9a83af..11277fa 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{mem::discriminant, sync::Arc}; use alloy::{ eips::BlockId, @@ -6,7 +6,7 @@ use alloy::{ }; use thiserror::Error; -use crate::robust_provider::Error as RobustProviderError; +use crate::{robust_provider::Error as RobustProviderError, types::ScannerResult}; #[derive(Error, Debug, Clone)] pub enum ScannerError { @@ -47,3 +47,11 @@ impl From> for ScannerError { ScannerError::RpcError(Arc::new(error)) } } +impl PartialEq for ScannerResult { + fn eq(&self, other: &ScannerError) -> bool { + match self { + Ok(_) => false, + Err(err) => discriminant(err) == discriminant(other), + } + } +} diff --git a/src/event_scanner/error.rs b/src/event_scanner/error.rs deleted file mode 100644 index b689950..0000000 --- a/src/event_scanner/error.rs +++ /dev/null @@ -1,27 +0,0 @@ -use alloy::{ - rpc::types::Log, - transports::{RpcError, TransportErrorKind}, -}; - -use crate::{Message, ScannerError}; - -impl From> for Message { - fn from(e: RpcError) -> Self { - Message::Error(e.into()) - } -} - -impl From for Message { - fn from(error: ScannerError) -> Self { - Message::Error(error) - } -} - -impl From, RpcError>> for Message { - fn from(logs: Result, RpcError>) -> Self { - match logs { - Ok(logs) => Message::Data(logs), - Err(e) => Message::Error(e.into()), - } - } -} diff --git a/src/event_scanner/listener.rs b/src/event_scanner/listener.rs index 2e57b77..10562b7 100644 --- a/src/event_scanner/listener.rs +++ b/src/event_scanner/listener.rs @@ -1,8 +1,8 @@ -use crate::event_scanner::{filter::EventFilter, message::Message}; +use crate::event_scanner::{EventScannerResult, filter::EventFilter}; use tokio::sync::mpsc::Sender; #[derive(Clone)] pub(crate) struct EventListener { pub filter: EventFilter, - pub sender: Sender, + pub sender: Sender, } diff --git a/src/event_scanner/message.rs b/src/event_scanner/message.rs index ebd1081..df61a21 100644 --- a/src/event_scanner/message.rs +++ b/src/event_scanner/message.rs @@ -1,8 +1,12 @@ use alloy::{rpc::types::Log, sol_types::SolEvent}; -use crate::{ScannerError, ScannerMessage, robust_provider::Error as RobustProviderError}; +use crate::{ + ScannerMessage, + types::{IntoScannerResult, ScannerResult}, +}; -pub type Message = ScannerMessage, ScannerError>; +pub type Message = ScannerMessage>; +pub type EventScannerResult = ScannerResult>; impl From> for Message { fn from(logs: Vec) -> Self { @@ -10,10 +14,9 @@ impl From> for Message { } } -impl From for Message { - fn from(error: RobustProviderError) -> Message { - let scanner_error: ScannerError = error.into(); - scanner_error.into() +impl IntoScannerResult> for Vec { + fn into_scanner_message_result(self) -> EventScannerResult { + Ok(Message::Data(self)) } } diff --git a/src/event_scanner/mod.rs b/src/event_scanner/mod.rs index 2e031b7..74cd78a 100644 --- a/src/event_scanner/mod.rs +++ b/src/event_scanner/mod.rs @@ -1,11 +1,10 @@ -mod error; mod filter; mod listener; mod message; mod scanner; pub use filter::EventFilter; -pub use message::Message; +pub use message::{EventScannerResult, Message}; pub use scanner::{ EventScanner, EventScannerBuilder, Historic, LatestEvents, Live, SyncFromBlock, SyncFromLatestEvents, diff --git a/src/event_scanner/scanner/common.rs b/src/event_scanner/scanner/common.rs index af736a7..b30cbc8 100644 --- a/src/event_scanner/scanner/common.rs +++ b/src/event_scanner/scanner/common.rs @@ -1,8 +1,9 @@ use std::ops::RangeInclusive; use crate::{ - block_range_scanner::{MAX_BUFFERED_MESSAGES, Message as BlockRangeMessage}, - event_scanner::{filter::EventFilter, listener::EventListener}, + ScannerMessage, + block_range_scanner::{BlockScannerResult, MAX_BUFFERED_MESSAGES}, + event_scanner::{EventScannerResult, filter::EventFilter, listener::EventListener}, robust_provider::{Error as RobustProviderError, RobustProvider}, types::TryStream, }; @@ -11,7 +12,10 @@ use alloy::{ rpc::types::{Filter, Log}, }; use tokio::{ - sync::broadcast::{self, Sender, error::RecvError}, + sync::{ + broadcast::{self, Sender, error::RecvError}, + mpsc, + }, task::JoinSet, }; use tokio_stream::{Stream, StreamExt}; @@ -45,13 +49,13 @@ pub enum ConsumerMode { /// # Note /// /// Assumes it is running in a separate tokio task, so as to be non-blocking. -pub async fn handle_stream + Unpin>( +pub async fn handle_stream + Unpin>( mut stream: S, provider: &RobustProvider, listeners: &[EventListener], mode: ConsumerMode, ) { - let (range_tx, _) = broadcast::channel::(MAX_BUFFERED_MESSAGES); + let (range_tx, _) = broadcast::channel::(MAX_BUFFERED_MESSAGES); let consumers = spawn_log_consumers(provider, listeners, &range_tx, mode); @@ -73,7 +77,7 @@ pub async fn handle_stream + Unp pub fn spawn_log_consumers( provider: &RobustProvider, listeners: &[EventListener], - range_tx: &Sender, + range_tx: &Sender, mode: ConsumerMode, ) -> JoinSet<()> { listeners.iter().cloned().fold(JoinSet::new(), |mut set, listener| { @@ -92,50 +96,18 @@ pub fn spawn_log_consumers( loop { match range_rx.recv().await { - Ok(BlockRangeMessage::Data(range)) => { - match get_logs(range, &filter, &base_filter, &provider).await { - Ok(logs) => { - if logs.is_empty() { - continue; - } - - match mode { - ConsumerMode::Stream => { - if !sender.try_stream(logs).await { - break; - } - } - ConsumerMode::CollectLatest { count } => { - let take = count.saturating_sub(collected.len()); - // if we have enough logs, break - if take == 0 { - break; - } - // take latest within this range - collected.extend(logs.into_iter().rev().take(take)); - // if we have enough logs, break - if collected.len() == count { - break; - } - } - } - } - Err(e) => { - if !sender.try_stream(e).await { - break; - } - } - } - } - Ok(BlockRangeMessage::Error(e)) => { - error!(error = ?e, "Received error message"); - if !sender.try_stream(e).await { - break; - } - } - Ok(BlockRangeMessage::Notification(notification)) => { - info!(notification = ?notification, "Received notification"); - if !sender.try_stream(notification).await { + Ok(message) => { + if !handle_block_range_message( + message, + &filter, + &base_filter, + &provider, + &sender, + mode, + &mut collected, + ) + .await + { break; } } @@ -196,3 +168,84 @@ async fn get_logs( } } } + +#[must_use] +async fn handle_block_range_message( + message: BlockScannerResult, + filter: &EventFilter, + base_filter: &Filter, + provider: &RobustProvider, + sender: &mpsc::Sender, + mode: ConsumerMode, + collected: &mut Vec, +) -> bool { + match message { + Ok(ScannerMessage::Data(range)) => { + if !handle_block_range(range, filter, base_filter, provider, sender, mode, collected) + .await + { + return false; + } + } + Ok(ScannerMessage::Notification(notification)) => { + info!(notification = ?notification, "Received notification"); + if !sender.try_stream(notification).await { + return false; + } + } + Err(e) => { + error!(error = ?e, "Received error message"); + if !sender.try_stream(e).await { + return false; + } + } + } + true +} + +#[must_use] +async fn handle_block_range( + range: RangeInclusive, + filter: &EventFilter, + base_filter: &Filter, + provider: &RobustProvider, + sender: &mpsc::Sender, + mode: ConsumerMode, + collected: &mut Vec, +) -> bool { + match get_logs(range, filter, base_filter, provider).await { + Ok(logs) => { + if logs.is_empty() { + return true; + } + + match mode { + ConsumerMode::Stream => { + if !sender.try_stream(logs).await { + return false; + } + } + ConsumerMode::CollectLatest { count } => { + let take = count.saturating_sub(collected.len()); + // if we have enough logs, break + if take == 0 { + return false; + } + // take latest within this range + collected.extend(logs.into_iter().rev().take(take)); + // if we have enough logs, break + if collected.len() == count { + return false; + } + } + } + } + Err(e) => { + error!(error = ?e, "Received error message"); + if !sender.try_stream(e).await { + return false; + } + } + } + true +} diff --git a/src/event_scanner/scanner/mod.rs b/src/event_scanner/scanner/mod.rs index f51a945..8b91bc5 100644 --- a/src/event_scanner/scanner/mod.rs +++ b/src/event_scanner/scanner/mod.rs @@ -6,12 +6,12 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use crate::{ - EventFilter, Message, ScannerError, + EventFilter, ScannerError, block_range_scanner::{ BlockRangeScanner, ConnectedBlockRangeScanner, DEFAULT_BLOCK_CONFIRMATIONS, MAX_BUFFERED_MESSAGES, }, - event_scanner::listener::EventListener, + event_scanner::{EventScannerResult, listener::EventListener}, robust_provider::IntoRobustProvider, }; @@ -96,7 +96,7 @@ impl EventScannerBuilder { /// /// scanner.start().await?; /// - /// while let Some(Message::Data(logs)) = stream.next().await { + /// while let Some(Ok(Message::Data(logs))) = stream.next().await { /// println!("Received {} logs", logs.len()); /// } /// # Ok(()) @@ -167,13 +167,13 @@ impl EventScannerBuilder { /// /// while let Some(msg) = stream.next().await { /// match msg { - /// Message::Data(logs) => { + /// Ok(Message::Data(logs)) => { /// println!("Received {} new events", logs.len()); /// } - /// Message::Notification(notification) => { + /// Ok(Message::Notification(notification)) => { /// println!("Notification received: {:?}", notification); /// } - /// Message::Error(e) => { + /// Err(e) => { /// eprintln!("Error: {}", e); /// } /// } @@ -251,7 +251,7 @@ impl EventScannerBuilder { /// scanner.start().await?; /// /// // Expect a single message with up to 10 logs, then the stream ends - /// while let Some(Message::Data(logs)) = stream.next().await { + /// while let Some(Ok(Message::Data(logs))) = stream.next().await { /// println!("Latest logs: {}", logs.len()); /// } /// # Ok(()) @@ -411,8 +411,8 @@ impl EventScannerBuilder { impl EventScanner { #[must_use] - pub fn subscribe(&mut self, filter: EventFilter) -> ReceiverStream { - let (sender, receiver) = mpsc::channel::(MAX_BUFFERED_MESSAGES); + pub fn subscribe(&mut self, filter: EventFilter) -> ReceiverStream { + let (sender, receiver) = mpsc::channel::(MAX_BUFFERED_MESSAGES); self.listeners.push(EventListener { filter, sender }); ReceiverStream::new(receiver) } diff --git a/src/event_scanner/scanner/sync/from_latest.rs b/src/event_scanner/scanner/sync/from_latest.rs index 7be468d..f34e1b8 100644 --- a/src/event_scanner/scanner/sync/from_latest.rs +++ b/src/event_scanner/scanner/sync/from_latest.rs @@ -9,8 +9,8 @@ use tokio_stream::{StreamExt, wrappers::ReceiverStream}; use tracing::info; use crate::{ - EventScannerBuilder, Notification, ScannerError, - block_range_scanner::Message as BlockRangeMessage, + EventScannerBuilder, Notification, ScannerError, ScannerMessage, + block_range_scanner::BlockScannerResult, event_scanner::{ EventScanner, scanner::{ @@ -104,9 +104,9 @@ impl EventScanner { info!("Switching to live stream"); // Use a one-off channel for the notification. - let (tx, rx) = mpsc::channel::(1); + let (tx, rx) = mpsc::channel::(1); let stream = ReceiverStream::new(rx); - tx.send(BlockRangeMessage::Notification(Notification::SwitchingToLive)) + tx.send(Ok(ScannerMessage::Notification(Notification::SwitchingToLive))) .await .expect("receiver exists"); diff --git a/src/event_scanner/scanner/sync/mod.rs b/src/event_scanner/scanner/sync/mod.rs index 889a58e..45ffb8f 100644 --- a/src/event_scanner/scanner/sync/mod.rs +++ b/src/event_scanner/scanner/sync/mod.rs @@ -45,14 +45,14 @@ impl EventScannerBuilder { /// /// while let Some(msg) = stream.next().await { /// match msg { - /// Message::Data(logs) => { + /// Ok(Message::Data(logs)) => { /// println!("Received {} events", logs.len()); /// } - /// Message::Notification(notification) => { + /// Ok(Message::Notification(notification)) => { /// println!("Notification received: {:?}", notification); /// // You'll see Notification::SwitchingToLive when transitioning /// } - /// Message::Error(e) => { + /// Err(e) => { /// eprintln!("Error: {}", e); /// } /// } @@ -143,14 +143,14 @@ impl EventScannerBuilder { /// /// while let Some(msg) = stream.next().await { /// match msg { - /// Message::Data(logs) => { + /// Ok(Message::Data(logs)) => { /// println!("Received {} events", logs.len()); /// } - /// Message::Notification(notification) => { + /// Ok(Message::Notification(notification)) => { /// println!("Notification received: {:?}", notification); /// // You'll see Notification::SwitchingToLive when transitioning /// } - /// Message::Error(e) => { + /// Err(e) => { /// eprintln!("Error: {}", e); /// } /// } diff --git a/src/lib.rs b/src/lib.rs index 4af7fa6..ff37a8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,6 @@ pub use error::ScannerError; pub use types::{Notification, ScannerMessage}; pub use event_scanner::{ - EventFilter, EventScanner, EventScannerBuilder, Historic, LatestEvents, Live, Message, - SyncFromBlock, SyncFromLatestEvents, + EventFilter, EventScanner, EventScannerBuilder, EventScannerResult, Historic, LatestEvents, + Live, Message, SyncFromBlock, SyncFromLatestEvents, }; diff --git a/src/test_utils/macros.rs b/src/test_utils/macros.rs index 78e6032..8a82889 100644 --- a/src/test_utils/macros.rs +++ b/src/test_utils/macros.rs @@ -1,14 +1,15 @@ use alloy::primitives::LogData; use tokio_stream::Stream; -use crate::Message; +use crate::{ScannerMessage, event_scanner::EventScannerResult}; #[macro_export] macro_rules! assert_next { - ($stream: expr, $expected: expr) => { - assert_next!($stream, $expected, timeout = 5) + // 1. Explicit Error Matching (Value based) - uses the new PartialEq implementation + ($stream: expr, Err($expected_err:expr)) => { + $crate::assert_next!($stream, Err($expected_err), timeout = 5) }; - ($stream: expr, $expected: expr, timeout = $secs: expr) => { + ($stream: expr, Err($expected_err:expr), timeout = $secs: expr) => { let message = tokio::time::timeout( std::time::Duration::from_secs($secs), tokio_stream::StreamExt::next(&mut $stream), @@ -16,9 +17,35 @@ macro_rules! assert_next { .await .expect("timed out"); if let Some(msg) = message { - assert_eq!(msg, $expected) + let expected = &$expected_err; + assert_eq!(&msg, expected, "Expected error {:?}, got {:?}", expected, msg); } else { - panic!("Expected {:?}, but channel was closed", $expected) + panic!("Expected error {:?}, but channel was closed", $expected_err); + } + }; + + // 2. Success Matching (Implicit unwrapping) - existing behavior + ($stream: expr, $expected: expr) => { + $crate::assert_next!($stream, $expected, timeout = 5) + }; + ($stream: expr, $expected: expr, timeout = $secs: expr) => { + let message = tokio::time::timeout( + std::time::Duration::from_secs($secs), + tokio_stream::StreamExt::next(&mut $stream), + ) + .await + .expect("timed out"); + let expected = $expected; + match message { + std::option::Option::Some(std::result::Result::Ok(msg)) => { + assert_eq!(msg, expected, "Expected {:?}, got {:?}", expected, msg); + } + std::option::Option::Some(std::result::Result::Err(e)) => { + panic!("Expected Ok({:?}), got Err({:?})", expected, e); + } + std::option::Option::None => { + panic!("Expected Ok({:?}), but channel was closed", expected); + } } }; } @@ -162,7 +189,7 @@ macro_rules! assert_event_sequence_final { } #[allow(clippy::missing_panics_doc)] -pub async fn assert_event_sequence + Unpin>( +pub async fn assert_event_sequence + Unpin>( stream: &mut S, expected_options: impl IntoIterator, timeout_secs: u64, @@ -186,7 +213,7 @@ pub async fn assert_event_sequence + Unpin>( .expect("timed out waiting for next batch"); match message { - Some(Message::Data(batch)) => { + Some(Ok(ScannerMessage::Data(batch))) => { let mut batch = batch.iter(); let event = batch.next().expect("Streamed batch should not be empty"); assert_eq!( @@ -205,9 +232,12 @@ pub async fn assert_event_sequence + Unpin>( ); } } - Some(other) => { + Some(Ok(other)) => { panic!("Expected Message::Data, got: {other:#?}"); } + Some(Err(e)) => { + panic!("Expected Ok(Message::Data), got Err: {e:#?}"); + } None => { panic!("Stream closed while still expecting: {:#?}", remaining.collect::>()); } @@ -222,7 +252,7 @@ pub async fn assert_event_sequence + Unpin>( /// range must start exactly where the previous one ended, and all ranges must fit within /// the expected bounds. /// -/// The macro expects the stream to yield `Message::Data(range)` variants containing +/// The macro expects the stream to yield `ScannerMessage::Data(range)` variants containing /// `RangeInclusive` values representing block ranges. It tracks coverage by ensuring /// each new range starts at the next expected block number and doesn't exceed the end of /// the expected range. Once the entire range is covered, the assertion succeeds. @@ -230,7 +260,7 @@ pub async fn assert_event_sequence + Unpin>( /// # Example /// /// ```rust -/// use event_scanner::{assert_range_coverage, block_range_scanner::Message}; +/// use event_scanner::{ScannerMessage, assert_range_coverage}; /// use tokio::sync::mpsc; /// use tokio_stream::wrappers::ReceiverStream; /// @@ -241,8 +271,8 @@ pub async fn assert_event_sequence + Unpin>( /// /// // Simulate a scanner that splits blocks 100-199 into chunks /// tokio::spawn(async move { -/// tx.send(Message::Data(100..=149)).await.unwrap(); -/// tx.send(Message::Data(150..=199)).await.unwrap(); +/// tx.send(ScannerMessage::Data(100..=149)).await.unwrap(); +/// tx.send(ScannerMessage::Data(150..=199)).await.unwrap(); /// }); /// /// // Assert that the stream covers blocks 100-199 @@ -304,7 +334,7 @@ macro_rules! assert_range_coverage { .expect("Timed out waiting for the next block range"); match message { - Some( $crate::block_range_scanner::Message::Data(range)) => { + std::option::Option::Some(std::result::Result::Ok(event_scanner::ScannerMessage::Data(range))) => { let (streamed_start, streamed_end) = bounds(&range); streamed_ranges.push(range.clone()); assert!( @@ -316,10 +346,13 @@ macro_rules! assert_range_coverage { ); start = streamed_end + 1; } - Some(other) => { + std::option::Option::Some(std::result::Result::Ok(other)) => { panic!("Expected a block range, got: {other:#?}"); } - None => { + std::option::Option::Some(std::result::Result::Err(e)) => { + panic!("Expected Ok(Message::Data), got Err: {e:#?}"); + } + std::option::Option::None => { panic!("Stream closed without covering range: {:#?}", start..=end); } } diff --git a/src/types.rs b/src/types.rs index d657920..c21f037 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,12 +1,13 @@ -use std::{error::Error, fmt::Debug}; +use std::fmt::Debug; use tokio::sync::mpsc; use tracing::{info, warn}; -#[derive(Copy, Debug, Clone)] -pub enum ScannerMessage { +use crate::ScannerError; + +#[derive(Debug, Clone)] +pub enum ScannerMessage { Data(T), - Error(E), Notification(Notification), } @@ -16,13 +17,13 @@ pub enum Notification { ReorgDetected, } -impl From for ScannerMessage { +impl From for ScannerMessage { fn from(value: Notification) -> Self { ScannerMessage::Notification(value) } } -impl PartialEq for ScannerMessage { +impl PartialEq for ScannerMessage { fn eq(&self, other: &Notification) -> bool { if let ScannerMessage::Notification(notification) = self { notification == other @@ -32,15 +33,48 @@ impl PartialEq for ScannerMessage { - async fn try_stream>>(&self, msg: M) -> bool; +pub type ScannerResult = Result, ScannerError>; + +pub trait IntoScannerResult { + fn into_scanner_message_result(self) -> ScannerResult; +} + +impl IntoScannerResult for ScannerResult { + fn into_scanner_message_result(self) -> ScannerResult { + self + } } -impl TryStream for mpsc::Sender> { - async fn try_stream>>(&self, msg: M) -> bool { - let msg = msg.into(); - info!(msg = ?msg, "Sending message"); - if let Err(err) = self.send(msg).await { +impl IntoScannerResult for ScannerMessage { + fn into_scanner_message_result(self) -> ScannerResult { + Ok(self) + } +} + +impl> IntoScannerResult for E { + fn into_scanner_message_result(self) -> ScannerResult { + Err(self.into()) + } +} + +impl IntoScannerResult for Notification { + fn into_scanner_message_result(self) -> ScannerResult { + Ok(ScannerMessage::Notification(self)) + } +} + +pub(crate) trait TryStream { + async fn try_stream>(&self, msg: M) -> bool; +} + +impl TryStream for mpsc::Sender> { + async fn try_stream>(&self, msg: M) -> bool { + let item = msg.into_scanner_message_result(); + match &item { + Ok(msg) => info!(item = ?msg, "Sending message"), + Err(err) => info!(error = ?err, "Sending error"), + } + if let Err(err) = self.send(item).await { warn!(error = %err, "Downstream channel closed, stopping stream"); return false; } diff --git a/tests/common/setup_scanner.rs b/tests/common/setup_scanner.rs index b8b9e00..9471f63 100644 --- a/tests/common/setup_scanner.rs +++ b/tests/common/setup_scanner.rs @@ -6,8 +6,8 @@ use alloy::{ }; use alloy_node_bindings::AnvilInstance; use event_scanner::{ - EventFilter, EventScanner, EventScannerBuilder, Historic, LatestEvents, Live, Message, - SyncFromBlock, SyncFromLatestEvents, robust_provider::RobustProvider, + EventFilter, EventScanner, EventScannerBuilder, EventScannerResult, Historic, LatestEvents, + Live, SyncFromBlock, SyncFromLatestEvents, robust_provider::RobustProvider, }; use tokio_stream::wrappers::ReceiverStream; @@ -24,7 +24,7 @@ where pub provider: RobustProvider, pub contract: TestCounter::TestCounterInstance

, pub scanner: S, - pub stream: ReceiverStream, + pub stream: ReceiverStream, #[allow(dead_code)] pub anvil: AnvilInstance, }