diff --git a/src/rpc/parallel_batch_layer.rs b/src/rpc/parallel_batch_layer.rs index 1c55217c032..c367de028d3 100644 --- a/src/rpc/parallel_batch_layer.rs +++ b/src/rpc/parallel_batch_layer.rs @@ -1,18 +1,24 @@ // Copyright 2019-2026 ChainSafe Systems // SPDX-License-Identifier: Apache-2.0, MIT -use futures::{FutureExt, StreamExt, stream::FuturesOrdered}; +use std::{borrow::Cow, sync::Arc}; + +use ahash::HashMap; use jsonrpsee::{ MethodResponse, core::middleware::{Batch, BatchEntry, Notification}, server::{BatchResponseBuilder, middleware::rpc::RpcServiceT}, + types::{ErrorCode, ErrorObject, Id, Request}, }; +use tokio::task::JoinSet; use tower::Layer; -/// Parallelize batch RPC requests that are processed in sequence by default -/// See +/// Parallelize batch RPC requests across the `tokio` worker pool. /// -/// Note that such parallelization is allowed as per the [`JSON-RPC` specification](https://www.jsonrpc.org/specification#:~:text=6%20Batch) +/// jsonrpsee processes batches sequentially by default. The +/// [JSON-RPC spec](https://www.jsonrpc.org/specification#batch) does not +/// require sequential processing or response ordering, but order is +/// preserved here to avoid surprising clients. #[derive(Clone, derive_more::Constructor)] pub(super) struct ParallelBatchLayer { max_response_body_size: usize, @@ -23,7 +29,7 @@ impl Layer for ParallelBatchLayer { fn layer(&self, service: S) -> Self::Service { ParallelBatchService { - service, + service: Arc::new(service), max_response_body_size: self.max_response_body_size, } } @@ -31,7 +37,7 @@ impl Layer for ParallelBatchLayer { #[derive(Clone)] pub(super) struct ParallelBatchService { - service: S, + service: Arc, max_response_body_size: usize, } @@ -49,53 +55,85 @@ where type NotificationResponse = S::NotificationResponse; type BatchResponse = S::BatchResponse; - fn call<'a>( - &self, - req: jsonrpsee::types::Request<'a>, - ) -> impl Future + Send + 'a { + fn call<'a>(&self, req: Request<'a>) -> impl Future + Send + 'a { self.service.call(req) } - // Parallelized version of https://github.com/paritytech/jsonrpsee/blob/v0.26.0/server/src/middleware/rpc.rs#L151 fn batch<'a>(&self, batch: Batch<'a>) -> impl Future + Send + 'a { - // Process batch in parallel instead of delegating to the inner service, which processes them sequentially. - let mut batch_rp = BatchResponseBuilder::new_with_limit(self.max_response_body_size); + let max = self.max_response_body_size; let mut got_notification = false; - // Although it's not neccesary to perserve the order in response, we do it to avoid potential bugs on client side - // See - let mut tasks = FuturesOrdered::new(); - for batch_entry in batch.into_iter() { - match batch_entry { + // JoinSet aborts in-flight tasks on drop. + let mut join_set: JoinSet<(usize, Option)> = JoinSet::new(); + // Lets a panicked call task be turned into a per-entry error with the + // original request id. + let mut call_meta: HashMap)> = HashMap::default(); + let mut results: Vec<(usize, Option)> = Vec::new(); + + for (idx, entry) in batch.into_iter().enumerate() { + let service = Arc::clone(&self.service); + match entry { Ok(BatchEntry::Call(req)) => { - tasks.push_back(self.service.call(req).map(Some).boxed()); + let req_id = req.id().into_owned(); + let req = into_owned_request(req); + let handle = + join_set.spawn(async move { (idx, Some(service.call(req).await)) }); + call_meta.insert(handle.id(), (idx, req_id)); } Ok(BatchEntry::Notification(n)) => { got_notification = true; - tasks.push_back(self.service.notification(n).map(|_| None).boxed()); + let n = into_owned_notification(n); + join_set.spawn(async move { + service.notification(n).await; + (idx, None) + }); } Err(err) => { let (err, id) = err.into_parts(); - let rp = MethodResponse::error(id, err); - tasks.push_back(async move { Some(rp) }.boxed()); + results.push(( + idx, + Some(MethodResponse::error(id.into_owned(), err.into_owned())), + )); } } } async move { - while let Some(r) = tasks.next().await { - if let Some(rp) = r + results.reserve(join_set.len()); + while let Some(joined) = join_set.join_next_with_id().await { + match joined { + Ok((_, r)) => results.push(r), + Err(e) if e.is_panic() => { + if let Some((idx, req_id)) = call_meta.remove(&e.id()) { + tracing::error!(idx, "RPC call panicked in batch entry"); + let err = ErrorObject::owned::<()>( + ErrorCode::InternalError.code(), + "RPC handler panicked", + None, + ); + results.push((idx, Some(MethodResponse::error(req_id, err)))); + } else { + tracing::error!("RPC notification panicked in batch entry"); + } + } + Err(_) => unreachable!("JoinSet only cancels tasks on drop"), + } + } + results.sort_by_key(|(idx, _)| *idx); + + let mut batch_rp = BatchResponseBuilder::new_with_limit(max); + for (_, rp) in results { + if let Some(rp) = rp && let Err(err) = batch_rp.append(rp) { return err; } } - // If the batch is empty and we got a notification, we return an empty response. + // Empty builder + at least one notification is the spec's + // "no response" case for a notification-only batch. if batch_rp.is_empty() && got_notification { MethodResponse::notification() - } - // An empty batch is regarded as an invalid request here. - else { + } else { MethodResponse::from_batch(batch_rp.finish()) } } @@ -108,3 +146,169 @@ where self.service.notification(n) } } + +fn into_owned_request(req: Request<'_>) -> Request<'static> { + Request { + jsonrpc: req.jsonrpc, + id: req.id.into_owned(), + method: Cow::Owned(req.method.into_owned()), + params: req.params.map(|p| Cow::Owned(p.into_owned())), + extensions: req.extensions, + } +} + +fn into_owned_notification(n: Notification<'_>) -> Notification<'static> { + Notification { + jsonrpc: n.jsonrpc, + method: Cow::Owned(n.method.into_owned()), + params: n.params.map(|p| Cow::Owned(p.into_owned())), + extensions: n.extensions, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use jsonrpsee::core::middleware::BatchEntryErr; + use jsonrpsee::server::ResponsePayload; + use jsonrpsee::types::{Extensions, TwoPointZero}; + use std::time::Duration; + + const MAX_RESP: usize = 1024 * 1024; + + /// Method conventions used by tests: + /// "ok" – success response carrying the method name. + /// "slow:" – sleep, then succeed. + /// "panic" – panic inside the call task. + #[derive(Clone, Default)] + struct TestService; + + impl RpcServiceT for TestService { + type MethodResponse = MethodResponse; + type NotificationResponse = MethodResponse; + type BatchResponse = MethodResponse; + + fn call<'a>( + &self, + req: Request<'a>, + ) -> impl Future + Send + 'a { + let id = req.id().into_owned(); + let method = req.method_name().to_string(); + async move { + if method == "panic" { + panic!("test panic"); + } + if let Some(rest) = method.strip_prefix("slow:") { + let ms: u64 = rest.parse().unwrap(); + tokio::time::sleep(Duration::from_millis(ms)).await; + } + MethodResponse::response(id, ResponsePayload::success(method), MAX_RESP) + } + } + + // `async fn` form drops the explicit `'a` capture the trait wants, + // and the `manual_async_fn` lint fires on trivial `async {}` bodies. + #[expect(clippy::manual_async_fn, reason = "trait demands explicit 'a")] + fn batch<'a>( + &self, + _b: Batch<'a>, + ) -> impl Future + Send + 'a { + async { unreachable!("ParallelBatchLayer overrides this") } + } + + #[expect(clippy::manual_async_fn, reason = "trait demands explicit 'a")] + fn notification<'a>( + &self, + _n: Notification<'a>, + ) -> impl Future + Send + 'a { + async { MethodResponse::notification() } + } + } + + fn layer() -> ParallelBatchService { + ParallelBatchService { + service: Arc::new(TestService), + max_response_body_size: MAX_RESP, + } + } + + fn call(id: u64, method: &str) -> Request<'static> { + Request::owned(method.to_string(), None, Id::Number(id)) + } + + fn notification(method: &str) -> Notification<'static> { + Notification { + jsonrpc: TwoPointZero, + method: Cow::Owned(method.to_string()), + params: None, + extensions: Extensions::new(), + } + } + + fn as_array(rp: &MethodResponse) -> Vec { + serde_json::from_str::>(rp.as_json().get()).unwrap() + } + + #[tokio::test] + async fn preserves_order_under_heterogeneous_latency() { + let svc = layer(); + let batch = Batch::from(vec![ + Ok(BatchEntry::Call(call(1, "slow:50"))), + Ok(BatchEntry::Call(call(2, "ok"))), + Ok(BatchEntry::Call(call(3, "slow:25"))), + ]); + let arr = as_array(&svc.batch(batch).await); + assert_eq!(arr.len(), 3); + assert_eq!(arr[0]["id"], 1); + assert_eq!(arr[1]["id"], 2); + assert_eq!(arr[2]["id"], 3); + } + + #[tokio::test] + async fn panicked_call_yields_per_entry_error() { + let svc = layer(); + let batch = Batch::from(vec![ + Ok(BatchEntry::Call(call(1, "ok"))), + Ok(BatchEntry::Call(call(2, "panic"))), + Ok(BatchEntry::Call(call(3, "ok"))), + ]); + let arr = as_array(&svc.batch(batch).await); + assert_eq!(arr.len(), 3); + assert_eq!(arr[0]["id"], 1); + assert!(arr[0]["result"].is_string(), "first entry should succeed"); + assert_eq!(arr[1]["id"], 2); + assert!( + arr[1]["error"].is_object(), + "panicked entry must carry its own error" + ); + assert_eq!(arr[2]["id"], 3); + assert!(arr[2]["result"].is_string(), "third entry should succeed"); + } + + #[tokio::test] + async fn notification_only_batch_returns_notification() { + let svc = layer(); + let batch = Batch::from(vec![Ok(BatchEntry::Notification(notification("ok")))]); + let resp = svc.batch(batch).await; + assert!(resp.is_notification()); + } + + #[tokio::test] + async fn entry_err_preserves_index() { + let svc = layer(); + let batch = Batch::from(vec![ + Ok(BatchEntry::Call(call(1, "ok"))), + Err(BatchEntryErr::new( + Id::Number(2), + ErrorObject::from(ErrorCode::InvalidRequest), + )), + Ok(BatchEntry::Call(call(3, "ok"))), + ]); + let arr = as_array(&svc.batch(batch).await); + assert_eq!(arr.len(), 3); + assert_eq!(arr[0]["id"], 1); + assert_eq!(arr[1]["id"], 2); + assert!(arr[1]["error"].is_object()); + assert_eq!(arr[2]["id"], 3); + } +}