From 3a70d608e54c1a1fd52d1f683a10e93f83e15881 Mon Sep 17 00:00:00 2001 From: abhinavgautam01 Date: Mon, 11 May 2026 18:07:52 +0530 Subject: [PATCH 1/2] test: cover push scheduling in client integration tests - Add standalone scheduler/executor APIs for TaskSchedulingPolicy. - Run push-staged executor path in tests; fix scheduler endpoint and gRPC readiness. - Extend rstest fixtures and context_setup with push cases. --- ballista/client/src/extension.rs | 2 +- ballista/client/tests/common/mod.rs | 133 ++++++++-- ballista/client/tests/context_checks.rs | 32 ++- ballista/client/tests/context_setup.rs | 29 +- ballista/client/tests/context_unsupported.rs | 7 +- ballista/executor/src/executor_server.rs | 23 ++ ballista/executor/src/lib.rs | 3 + ballista/executor/src/standalone.rs | 263 ++++++++++++++++++- ballista/scheduler/src/standalone.rs | 90 ++++++- examples/tests/common/mod.rs | 4 +- 10 files changed, 534 insertions(+), 52 deletions(-) diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs index 54b2123442..e31368b950 100644 --- a/ballista/client/src/extension.rs +++ b/ballista/client/src/extension.rs @@ -196,7 +196,7 @@ impl Extension { .map(|s| s.config().clone()) .unwrap_or_else(default_config_producer); - let scheduler_url = format!("http://localhost:{}", addr.port()); + let scheduler_url = format!("http://127.0.0.1:{}", addr.port()); let scheduler = loop { match SchedulerGrpcClient::connect(scheduler_url.clone()).await { diff --git a/ballista/client/tests/common/mod.rs b/ballista/client/tests/common/mod.rs index 93a2214a20..057d48da05 100644 --- a/ballista/client/tests/common/mod.rs +++ b/ballista/client/tests/common/mod.rs @@ -20,6 +20,7 @@ use std::error::Error; use std::path::PathBuf; use ballista::prelude::{SessionConfigExt, SessionContextExt}; +use ballista_core::config::TaskSchedulingPolicy; use ballista_core::serde::{ BallistaCodec, protobuf::scheduler_grpc_client::SchedulerGrpcClient, }; @@ -99,25 +100,36 @@ fn get_data_dir(udf_env: &str, submodule_data: &str) -> Result (String, u16) { + setup_test_cluster_with_scheduling(TaskSchedulingPolicy::PullStaged).await +} + +/// starts a ballista cluster using the given [`TaskSchedulingPolicy`]. +#[allow(dead_code)] +pub async fn setup_test_cluster_with_scheduling( + scheduling_policy: TaskSchedulingPolicy, +) -> (String, u16) { let config = SessionConfig::new_with_ballista(); let default_codec = BallistaCodec::default(); - let addr = ballista_scheduler::standalone::new_standalone_scheduler() - .await - .expect("scheduler to be created"); + let addr = ballista_scheduler::standalone::new_standalone_scheduler_with_scheduling( + scheduling_policy, + ) + .await + .expect("scheduler to be created"); - let host = "localhost".to_string(); + let host = "127.0.0.1".to_string(); let scheduler = connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; - ballista_executor::new_standalone_executor( + ballista_executor::new_standalone_executor_with_scheduling_policy( scheduler, config.ballista_standalone_parallelism(), default_codec, + scheduling_policy, ) .await .expect("executor to be created"); @@ -127,26 +139,48 @@ pub async fn setup_test_cluster() -> (String, u16) { (host, addr.port()) } -/// starts a ballista cluster for integration tests +/// starts a ballista cluster using push-staged scheduling (default executor policy). #[allow(dead_code)] -pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) { - let config = SessionConfig::new_with_ballista(); +pub async fn setup_test_cluster_push_scheduling() -> (String, u16) { + setup_test_cluster_with_scheduling(TaskSchedulingPolicy::PushStaged).await +} - let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state( - &session_state, +/// starts a cluster with [`SessionState`] (pull scheduling). +#[allow(dead_code)] +pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) { + setup_test_cluster_with_state_and_scheduling( + session_state, + TaskSchedulingPolicy::PullStaged, ) .await - .expect("scheduler to be created"); +} - let host = "localhost".to_string(); +/// starts a ballista cluster with selectable [`TaskSchedulingPolicy`]. +#[allow(dead_code)] +pub async fn setup_test_cluster_with_state_and_scheduling( + session_state: SessionState, + scheduling_policy: TaskSchedulingPolicy, +) -> (String, u16) { + let config = SessionConfig::new_with_ballista(); + + let addr = + ballista_scheduler::standalone::new_standalone_scheduler_from_state_with_scheduling_policy( + &session_state, + scheduling_policy, + ) + .await + .expect("scheduler to be created"); + + let host = "127.0.0.1".to_string(); let scheduler = connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; - ballista_executor::new_standalone_executor_from_state( + ballista_executor::new_standalone_executor_from_state_with_scheduling_policy( scheduler, config.ballista_standalone_parallelism(), &session_state, + scheduling_policy, ) .await .expect("executor to be created"); @@ -156,11 +190,39 @@ pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (Stri (host, addr.port()) } +/// starts a cluster with push-staged scheduling and a custom session state. +#[allow(dead_code)] +pub async fn setup_test_cluster_with_state_push_scheduling( + session_state: SessionState, +) -> (String, u16) { + setup_test_cluster_with_state_and_scheduling( + session_state, + TaskSchedulingPolicy::PushStaged, + ) + .await +} + #[allow(dead_code)] pub async fn setup_test_cluster_with_builders( config_producer: ConfigProducer, runtime_producer: RuntimeProducer, session_builder: SessionBuilder, +) -> (String, u16) { + setup_test_cluster_with_builders_and_scheduling( + config_producer, + runtime_producer, + session_builder, + TaskSchedulingPolicy::PullStaged, + ) + .await +} + +#[allow(dead_code)] +pub async fn setup_test_cluster_with_builders_and_scheduling( + config_producer: ConfigProducer, + runtime_producer: RuntimeProducer, + session_builder: SessionBuilder, + scheduling_policy: TaskSchedulingPolicy, ) -> (String, u16) { let config = config_producer(); @@ -171,26 +233,29 @@ pub async fn setup_test_cluster_with_builders( datafusion_proto::protobuf::PhysicalPlanNode, > = BallistaCodec::new(logical, physical); - let addr = ballista_scheduler::standalone::new_standalone_scheduler_with_builder( - session_builder, - config_producer.clone(), - codec.clone(), - ) - .await - .expect("scheduler to be created"); + let addr = + ballista_scheduler::standalone::new_standalone_scheduler_with_builder_and_policy( + session_builder, + config_producer.clone(), + codec.clone(), + scheduling_policy, + ) + .await + .expect("scheduler to be created"); - let host = "localhost".to_string(); + let host = "127.0.0.1".to_string(); let scheduler = connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; - ballista_executor::new_standalone_executor_from_builder( + ballista_executor::new_standalone_executor_from_builder_with_scheduling_policy( scheduler, config.ballista_standalone_parallelism(), config_producer, runtime_producer, codec, Default::default(), + scheduling_policy, ) .await .expect("executor to be created"); @@ -234,6 +299,15 @@ pub async fn remote_context() -> SessionContext { .unwrap() } +/// Remote [`SessionContext`] against a throwaway cluster using push-staged scheduling. +#[allow(dead_code)] +pub async fn remote_context_push_scheduling() -> SessionContext { + let (host, port) = setup_test_cluster_push_scheduling().await; + SessionContext::remote(&format!("df://{host}:{port}")) + .await + .unwrap() +} + #[allow(dead_code)] pub async fn standalone_context_with_state() -> SessionContext { let config = SessionConfig::new_with_ballista(); @@ -257,6 +331,19 @@ pub async fn remote_context_with_state() -> SessionContext { .unwrap() } +#[allow(dead_code)] +pub async fn remote_context_with_state_push_scheduling() -> SessionContext { + let config = SessionConfig::new_with_ballista(); + let state = SessionStateBuilder::new() + .with_config(config) + .with_default_features() + .build(); + let (host, port) = setup_test_cluster_with_state_push_scheduling(state.clone()).await; + SessionContext::remote_with_state(&format!("df://{host}:{port}"), state) + .await + .unwrap() +} + #[ctor::ctor(unsafe)] fn init() { // Enable RUST_LOG logging configuration for test diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs index 86bec04f9a..3320b553eb 100644 --- a/ballista/client/tests/context_checks.rs +++ b/ballista/client/tests/context_checks.rs @@ -20,7 +20,8 @@ mod common; mod supported { use crate::common::{ - remote_context, remote_context_with_state, standalone_context, + remote_context, remote_context_push_scheduling, remote_context_with_state, + remote_context_with_state_push_scheduling, standalone_context, standalone_context_with_state, }; use ballista_core::config::BallistaConfig; @@ -43,6 +44,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_collect( #[future(awt)] @@ -82,6 +84,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_collect_client_statistics_for_show( #[future(awt)] @@ -162,6 +165,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_collect_client_statistics_for_insert( #[future(awt)] @@ -233,6 +237,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_show_configs( #[future(awt)] @@ -264,6 +269,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_show_configs_ballista( #[future(awt)] @@ -300,6 +306,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_set_configs( #[future(awt)] @@ -335,6 +342,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_show_tables( #[future(awt)] @@ -375,6 +383,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_create_external_table( #[future(awt)] @@ -407,6 +416,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_collect_from_dataframe( #[future(awt)] @@ -442,6 +452,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_write( #[future(awt)] @@ -494,6 +505,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_disable_view_types( #[future(awt)] @@ -529,6 +541,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_disable_collect_left( #[future(awt)] @@ -558,6 +571,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_show_with_url_table( #[future(awt)] @@ -593,6 +607,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_support_sql_insert_into( #[future(awt)] @@ -652,8 +667,10 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[case::standalone_state(standalone_context_with_state())] #[case::remote_state(remote_context_with_state())] + #[case::remote_state_push(remote_context_with_state_push_scheduling())] #[tokio::test] async fn should_execute_sql_write_read_roundtrip( #[future(awt)] @@ -749,8 +766,10 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[case::standalone_state(standalone_context_with_state())] #[case::remote_state(remote_context_with_state())] + #[case::remote_state_push(remote_context_with_state_push_scheduling())] #[tokio::test] async fn should_execute_sql_show_multiple_times( #[future(awt)] @@ -793,8 +812,10 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[case::standalone_state(standalone_context_with_state())] #[case::remote_state(remote_context_with_state())] + #[case::remote_state_push(remote_context_with_state_push_scheduling())] #[tokio::test] async fn should_execute_group_by( #[future(awt)] @@ -832,6 +853,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_force_local_read( #[future(awt)] @@ -890,6 +912,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_force_local_read_with_flight( #[future(awt)] @@ -956,6 +979,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_support_sort_merge_join( #[future(awt)] @@ -1011,6 +1035,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_support_hash_join_when_opted_in( #[future(awt)] @@ -1069,6 +1094,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_explain_query_correctly( #[future(awt)] @@ -1126,6 +1152,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_explain_analyze_query( #[future(awt)] @@ -1209,6 +1236,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_force_client_pull( #[future(awt)] @@ -1267,6 +1295,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_execute_sql_collect_from_arrow_file( @@ -1305,6 +1334,7 @@ mod supported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_set_io_retry_config( #[future(awt)] diff --git a/ballista/client/tests/context_setup.rs b/ballista/client/tests/context_setup.rs index df999c06e4..5f6fedd2c6 100644 --- a/ballista/client/tests/context_setup.rs +++ b/ballista/client/tests/context_setup.rs @@ -20,16 +20,26 @@ mod common; #[cfg(test)] mod remote { use ballista::prelude::{SessionConfigExt, SessionContextExt}; + use ballista_core::config::TaskSchedulingPolicy; use datafusion::{ assert_batches_eq, execution::SessionStateBuilder, prelude::{SessionConfig, SessionContext}, }; + use rstest::rstest; + #[rstest] + #[case(TaskSchedulingPolicy::PullStaged)] + #[case(TaskSchedulingPolicy::PushStaged)] #[tokio::test] - async fn should_execute_sql_show_with_custom_state() -> datafusion::error::Result<()> - { - let (host, port) = crate::common::setup_test_cluster().await; + async fn should_execute_sql_show_with_custom_state( + #[case] scheduling_policy: TaskSchedulingPolicy, + ) -> datafusion::error::Result<()> { + let (host, port) = crate::common::setup_test_cluster_with_state_and_scheduling( + SessionStateBuilder::new().with_default_features().build(), + scheduling_policy, + ) + .await; let url = format!("df://{host}:{port}"); let state = SessionStateBuilder::new().with_default_features().build(); @@ -63,9 +73,18 @@ mod remote { Ok(()) } + #[rstest] + #[case(TaskSchedulingPolicy::PullStaged)] + #[case(TaskSchedulingPolicy::PushStaged)] #[tokio::test] - async fn should_execute_sql_set_configs() -> datafusion::error::Result<()> { - let (host, port) = crate::common::setup_test_cluster().await; + async fn should_execute_sql_set_configs( + #[case] scheduling_policy: TaskSchedulingPolicy, + ) -> datafusion::error::Result<()> { + let (host, port) = crate::common::setup_test_cluster_with_state_and_scheduling( + SessionStateBuilder::new().with_default_features().build(), + scheduling_policy, + ) + .await; let url = format!("df://{host}:{port}"); let session_config = SessionConfig::new_with_ballista() diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs index 6e038034dd..cfb72eac14 100644 --- a/ballista/client/tests/context_unsupported.rs +++ b/ballista/client/tests/context_unsupported.rs @@ -23,7 +23,9 @@ mod common; /// gets support for them #[cfg(test)] mod unsupported { - use crate::common::{remote_context, standalone_context}; + use crate::common::{ + remote_context, remote_context_push_scheduling, standalone_context, + }; use datafusion::prelude::*; use datafusion::{assert_batches_eq, prelude::SessionContext}; use rstest::*; @@ -36,6 +38,7 @@ mod unsupported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] #[should_panic] async fn should_support_sql_create_table( @@ -63,6 +66,7 @@ mod unsupported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] #[should_panic] async fn should_support_caching_data_frame( @@ -107,6 +111,7 @@ mod unsupported { #[rstest] #[case::standalone(standalone_context())] #[case::remote(remote_context())] + #[case::remote_push(remote_context_push_scheduling())] #[tokio::test] async fn should_support_on_cache_collect( #[future(awt)] diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index c87c1322b5..c295139a27 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -71,6 +71,27 @@ use crate::{TaskExecutionTimes, as_task_status}; type ServerHandle = JoinHandle>; type SchedulerClients = Arc>>; +/// Wait until something is listening on `[host]:port` so the scheduler does not connect +/// before the executor gRPC server task has bound ([`startup`] spawns the listener +/// asynchronously — see in-file TODO). +async fn wait_executor_grpc_listen(host: &str, port: u16) -> Result<(), BallistaError> { + let addr = format!("{host}:{port}"); + for attempt in 0..500 { + match tokio::net::TcpStream::connect(&addr).await { + Ok(_) => return Ok(()), + Err(_) => { + if attempt == 499 { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + } + Err(BallistaError::General(format!( + "timed out waiting for executor gRPC listener on {addr}" + ))) +} + /// Wrap TaskDefinition with its curator scheduler id for task update to its specific curator scheduler later #[derive(Debug)] struct CuratorTaskDefinition { @@ -148,6 +169,8 @@ pub async fn startup( }) }; + wait_executor_grpc_listen(&config.bind_host, config.grpc_port).await?; + // 2. Do executor registration // TODO the executor registration should happen only after the executor grpc server started. let executor_server = Arc::new(executor_server); diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs index 3db736e016..8c7ae2c0d4 100644 --- a/ballista/executor/src/lib.rs +++ b/ballista/executor/src/lib.rs @@ -52,7 +52,10 @@ use std::net::SocketAddr; pub use standalone::new_standalone_executor; pub use standalone::new_standalone_executor_from_builder; +pub use standalone::new_standalone_executor_from_builder_with_scheduling_policy; pub use standalone::new_standalone_executor_from_state; +pub use standalone::new_standalone_executor_from_state_with_scheduling_policy; +pub use standalone::new_standalone_executor_with_scheduling_policy; use log::info; diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs index bd72dee4b4..972ee61e7c 100644 --- a/ballista/executor/src/standalone.rs +++ b/ballista/executor/src/standalone.rs @@ -20,9 +20,13 @@ //! This module provides functions for creating executors that run in the same //! process as the client, useful for testing and development purposes. +use crate::executor_process::ExecutorProcessConfig; +use crate::executor_server; use crate::metrics::LoggingMetricsCollector; +use crate::shutdown::ShutdownNotifier; use crate::{execution_loop, executor::Executor, flight_service::BallistaFlightService}; use arrow_flight::flight_service_server::FlightServiceServer; +use ballista_core::config::TaskSchedulingPolicy; use ballista_core::extension::SessionConfigExt; use ballista_core::registry::BallistaFunctionRegistry; use ballista_core::utils::{GrpcServerConfig, default_config_producer}; @@ -36,10 +40,11 @@ use ballista_core::{ }; use ballista_core::{ConfigProducer, RuntimeProducer}; use datafusion::execution::{SessionState, SessionStateBuilder}; -use log::info; +use log::{error, info}; use std::sync::Arc; use tempfile::TempDir; use tokio::net::TcpListener; +use tokio::sync::mpsc; use tonic::transport::Channel; use uuid::Uuid; @@ -52,6 +57,22 @@ pub async fn new_standalone_executor_from_state( scheduler: SchedulerGrpcClient, concurrent_tasks: usize, session_state: &SessionState, +) -> Result<()> { + new_standalone_executor_from_state_with_scheduling_policy( + scheduler, + concurrent_tasks, + session_state, + TaskSchedulingPolicy::PullStaged, + ) + .await +} + +/// Same as [`new_standalone_executor_from_state`], with an explicit scheduler policy. +pub async fn new_standalone_executor_from_state_with_scheduling_policy( + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + session_state: &SessionState, + scheduling_policy: TaskSchedulingPolicy, ) -> Result<()> { let logical = session_state.config().ballista_logical_extension_codec(); let physical = session_state.config().ballista_physical_extension_codec(); @@ -67,13 +88,14 @@ pub async fn new_standalone_executor_from_state( let config_producer: ConfigProducer = Arc::new(move || config.clone()); let runtime_producer: RuntimeProducer = Arc::new(move |_| Ok(runtime.clone())); - new_standalone_executor_from_builder( + new_standalone_executor_from_builder_with_scheduling_policy( scheduler, concurrent_tasks, config_producer, runtime_producer, codec, session_state.into(), + scheduling_policy, ) .await } @@ -93,16 +115,72 @@ pub async fn new_standalone_executor_from_builder( codec: BallistaCodec, function_registry: BallistaFunctionRegistry, ) -> Result<()> { - // Let the OS assign a random, free port + new_standalone_executor_from_builder_with_scheduling_policy( + scheduler, + concurrent_tasks, + config_producer, + runtime_producer, + codec, + function_registry, + TaskSchedulingPolicy::PullStaged, + ) + .await +} + +/// Same as [`new_standalone_executor_from_builder`] with selectable [`TaskSchedulingPolicy`]. +/// +/// Push mode starts the executor gRPC server required for staged task push from the scheduler. +pub async fn new_standalone_executor_from_builder_with_scheduling_policy( + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + config_producer: ConfigProducer, + runtime_producer: RuntimeProducer, + codec: BallistaCodec, + function_registry: BallistaFunctionRegistry, + scheduling_policy: TaskSchedulingPolicy, +) -> Result<()> { + match scheduling_policy { + TaskSchedulingPolicy::PullStaged => { + pull_staged_standalone_executor( + scheduler, + concurrent_tasks, + config_producer, + runtime_producer, + codec, + function_registry, + ) + .await + } + TaskSchedulingPolicy::PushStaged => { + push_staged_standalone_executor( + scheduler, + concurrent_tasks, + config_producer, + runtime_producer, + codec, + function_registry, + ) + .await + } + } +} + +async fn pull_staged_standalone_executor( + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + config_producer: ConfigProducer, + runtime_producer: RuntimeProducer, + codec: BallistaCodec, + function_registry: BallistaFunctionRegistry, +) -> Result<()> { let listener = TcpListener::bind("localhost:0").await?; let address = listener.local_addr()?; info!("Ballista v{BALLISTA_VERSION} Rust Executor listening on {address:?}"); let executor_meta = ExecutorRegistration { - id: Uuid::new_v4().to_string(), // assign this executor a unique ID + id: Uuid::new_v4().to_string(), host: Some("localhost".to_string()), port: address.port() as u32, - // TODO Make it configurable grpc_port: 50020, specification: Some( ExecutorSpecification::default() @@ -112,11 +190,46 @@ pub async fn new_standalone_executor_from_builder( os_info: Some(ExecutorOperatingSystemSpecification::default().into()), }; + spawn_executor_services_pull(PullStandaloneServices { + scheduler, + concurrent_tasks, + config_producer, + runtime_producer, + codec, + function_registry, + executor_meta, + listener, + }) + .await +} + +struct PullStandaloneServices { + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + config_producer: ConfigProducer, + runtime_producer: RuntimeProducer, + codec: BallistaCodec, + function_registry: BallistaFunctionRegistry, + executor_meta: ExecutorRegistration, + listener: TcpListener, +} + +async fn spawn_executor_services_pull(services: PullStandaloneServices) -> Result<()> { + let PullStandaloneServices { + scheduler, + concurrent_tasks, + config_producer, + runtime_producer, + codec, + function_registry, + executor_meta, + listener, + } = services; + let config = config_producer(); let max_message_size = config.ballista_grpc_client_max_message_size(); let work_dir = TempDir::new()?.path().to_str().unwrap().to_string(); - info!("work_dir: {work_dir}"); let executor = Arc::new(Executor::with_default_execution_engine( @@ -146,8 +259,108 @@ pub async fn new_standalone_executor_from_builder( Ok(()) } -/// Creates standalone executor with most values -/// set as default. +async fn push_staged_standalone_executor( + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + config_producer: ConfigProducer, + runtime_producer: RuntimeProducer, + codec: BallistaCodec, + function_registry: BallistaFunctionRegistry, +) -> Result<()> { + let flight_listener = TcpListener::bind("localhost:0").await?; + let flight_addr = flight_listener.local_addr()?; + info!( + "Ballista v{BALLISTA_VERSION} Rust Executor (push) listening on {flight_addr:?}" + ); + + let grpc_probe = TcpListener::bind("127.0.0.1:0").await?; + let grpc_port = grpc_probe.local_addr()?.port(); + drop(grpc_probe); + + let executor_meta = ExecutorRegistration { + id: Uuid::new_v4().to_string(), + host: Some("localhost".to_string()), + port: flight_addr.port() as u32, + grpc_port: grpc_port as u32, + specification: Some( + ExecutorSpecification::default() + .with_task_slots(concurrent_tasks as u32) + .into(), + ), + os_info: Some(ExecutorOperatingSystemSpecification::default().into()), + }; + + let config_snap = config_producer(); + let max_message_sz = config_snap.ballista_grpc_client_max_message_size() as u32; + + let work_dir = TempDir::new()?.path().to_str().unwrap().to_string(); + info!("work_dir: {work_dir}"); + + let executor = Arc::new(Executor::with_default_execution_engine( + executor_meta, + &work_dir, + runtime_producer, + config_producer.clone(), + Arc::new(function_registry), + Arc::new(LoggingMetricsCollector::default()), + concurrent_tasks, + )); + + let service = BallistaFlightService::new(work_dir); + let server = FlightServiceServer::new(service) + .max_decoding_message_size(max_message_sz as usize) + .max_encoding_message_size(max_message_sz as usize); + + tokio::spawn( + create_grpc_server(&GrpcServerConfig::default()) + .add_service(server) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new( + flight_listener, + )), + ); + + let exec_cfg = ExecutorProcessConfig { + bind_host: "127.0.0.1".into(), + port: flight_addr.port(), + grpc_port, + concurrent_tasks, + task_scheduling_policy: TaskSchedulingPolicy::PushStaged, + grpc_max_decoding_message_size: max_message_sz, + grpc_max_encoding_message_size: max_message_sz, + ..ExecutorProcessConfig::default() + }; + + let shutdown_notifier: &'static ShutdownNotifier = + Box::leak(Box::new(ShutdownNotifier::new())); + let (stop_send, _stop_recv) = mpsc::channel::(10); + + let server_handle = executor_server::startup( + scheduler, + Arc::new(exec_cfg), + executor, + codec, + stop_send, + shutdown_notifier, + ) + .await + .map_err(|e| { + error!("Standalone push executor failed to start gRPC server: {e}"); + e + })?; + + tokio::spawn(async move { + match server_handle.await { + Ok(Ok(())) => {} + Ok(Err(e)) => error!("Standalone push executor gRPC server exited: {e:?}"), + Err(join_err) => { + error!("Standalone push executor gRPC join error: {join_err:?}") + } + } + }); + + Ok(()) +} +/// Creates standalone executor with most values set as default (pull scheduling). pub async fn new_standalone_executor( scheduler: SchedulerGrpcClient, concurrent_tasks: usize, @@ -178,3 +391,37 @@ pub async fn new_standalone_executor( ) .await } + +/// Like [`new_standalone_executor`] with selectable [`TaskSchedulingPolicy`]. +pub async fn new_standalone_executor_with_scheduling_policy( + scheduler: SchedulerGrpcClient, + concurrent_tasks: usize, + codec: BallistaCodec, + scheduling_policy: TaskSchedulingPolicy, +) -> Result<()> { + use ballista_core::extension::{ + ballista_aggregate_functions, ballista_scalar_functions, + ballista_window_functions, + }; + + let session_state = SessionStateBuilder::new() + .with_default_features() + .with_scalar_functions(ballista_scalar_functions()) + .with_aggregate_functions(ballista_aggregate_functions()) + .with_window_functions(ballista_window_functions()) + .build(); + + let runtime = session_state.runtime_env().clone(); + let runtime_producer: RuntimeProducer = Arc::new(move |_| Ok(runtime.clone())); + + new_standalone_executor_from_builder_with_scheduling_policy( + scheduler, + concurrent_tasks, + Arc::new(default_config_producer), + runtime_producer, + codec, + (&session_state).into(), + scheduling_policy, + ) + .await +} diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index a6b6f2693d..ff53e1f778 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -20,6 +20,7 @@ use crate::config::SchedulerConfig; use crate::metrics::default_metrics_collector; use crate::scheduler_server::SchedulerServer; use ballista_core::ConfigProducer; +use ballista_core::config::TaskSchedulingPolicy; use ballista_core::extension::SessionConfigExt; use ballista_core::serde::BallistaCodec; use ballista_core::utils::{ @@ -45,11 +46,23 @@ use tokio::net::TcpListener; /// Returns the socket address the scheduler is listening on. /// Useful for testing and single-node deployments. pub async fn new_standalone_scheduler() -> Result { + new_standalone_scheduler_with_scheduling(TaskSchedulingPolicy::PullStaged).await +} + +/// Creates a standalone scheduler using the specified task-scheduling policy. +/// +/// Prefer [`PullStaged`](TaskSchedulingPolicy::PullStaged) for parity with legacy +/// in-process integration tests when you explicitly need pull-based scheduling, +/// or [`PushStaged`](TaskSchedulingPolicy::PushStaged) to match the executor default policy. +pub async fn new_standalone_scheduler_with_scheduling( + scheduling_policy: TaskSchedulingPolicy, +) -> Result { let codec = BallistaCodec::default(); - new_standalone_scheduler_with_builder( + new_standalone_scheduler_with_builder_and_policy( Arc::new(default_session_builder), Arc::new(default_config_producer), codec, + scheduling_policy, ) .await } @@ -68,32 +81,68 @@ pub async fn new_standalone_scheduler_from_state( let session_builder = Arc::new(move |_: SessionConfig| Ok(session_state.clone())); let config_producer = Arc::new(move || session_config.clone()); - new_standalone_scheduler_with_builder(session_builder, config_producer, codec).await + new_standalone_scheduler_with_builder_and_policy( + session_builder, + config_producer, + codec, + TaskSchedulingPolicy::PullStaged, + ) + .await } /// Creates a standalone scheduler with custom session builder, config producer, and codec. /// +/// Uses [`TaskSchedulingPolicy::PullStaged`], matching legacy integration-test defaults. +/// /// Returns the socket address the scheduler is listening on. pub async fn new_standalone_scheduler_with_builder( session_builder: crate::scheduler_server::SessionBuilder, config_producer: ConfigProducer, codec: BallistaCodec, +) -> Result { + new_standalone_scheduler_with_builder_and_policy( + session_builder, + config_producer, + codec, + TaskSchedulingPolicy::PullStaged, + ) + .await +} + +/// Creates a standalone scheduler with custom session dependencies and a selectable +/// task-scheduling policy. +/// +/// Returns the socket address the scheduler is listening on. +pub async fn new_standalone_scheduler_with_builder_and_policy( + session_builder: crate::scheduler_server::SessionBuilder, + config_producer: ConfigProducer, + codec: BallistaCodec, + scheduling_policy: TaskSchedulingPolicy, ) -> Result { let config = config_producer(); - let cluster = - BallistaCluster::new_memory("localhost:50050", session_builder, config_producer); + // Resolve the scheduler gRPC endpoint before constructing the cluster / server state so + // task metadata (scheduler_id / curator) matches the listener we expose to clients. + // A fixed placeholder host:port breaks push-mode executors which open new clients to + // `http://{scheduler_id}` when reporting task status. + let listener = TcpListener::bind("127.0.0.1:0").await?; + let addr = listener.local_addr()?; + let scheduler_endpoint = addr.to_string(); + + let cluster = BallistaCluster::new_memory( + scheduler_endpoint.clone(), + session_builder, + config_producer, + ); let metrics_collector = default_metrics_collector()?; let mut scheduler_server: SchedulerServer = SchedulerServer::new( - "localhost:50050".to_owned(), + scheduler_endpoint, cluster, codec, - Arc::new(SchedulerConfig::default().with_scheduler_policy( - ballista_core::config::TaskSchedulingPolicy::PullStaged, - )), + Arc::new(SchedulerConfig::default().with_scheduler_policy(scheduling_policy)), metrics_collector, ); @@ -102,9 +151,6 @@ pub async fn new_standalone_scheduler_with_builder( .max_decoding_message_size(config.ballista_grpc_client_max_message_size()) .max_encoding_message_size(config.ballista_grpc_client_max_message_size()); - // Let the OS assign a random, free port - let listener = TcpListener::bind("localhost:0").await?; - let addr = listener.local_addr()?; info!( "Ballista Scheduler v{BALLISTA_VERSION} (DataFusion v{DATAFUSION_VERSION}) listening on {addr:?}" ); @@ -118,3 +164,25 @@ pub async fn new_standalone_scheduler_with_builder( Ok(addr) } + +/// Like [`new_standalone_scheduler_from_state`], but uses `scheduling_policy` for task placement. +pub async fn new_standalone_scheduler_from_state_with_scheduling_policy( + session_state: &SessionState, + scheduling_policy: TaskSchedulingPolicy, +) -> Result { + let logical = session_state.config().ballista_logical_extension_codec(); + let physical = session_state.config().ballista_physical_extension_codec(); + let codec = BallistaCodec::new(logical, physical); + let session_config = session_state.config().clone(); + let session_state = session_state.clone(); + let session_builder = Arc::new(move |_: SessionConfig| Ok(session_state.clone())); + let config_producer = Arc::new(move || session_config.clone()); + + new_standalone_scheduler_with_builder_and_policy( + session_builder, + config_producer, + codec, + scheduling_policy, + ) + .await +} diff --git a/examples/tests/common/mod.rs b/examples/tests/common/mod.rs index 76c987acfc..2ad7c4feb1 100644 --- a/examples/tests/common/mod.rs +++ b/examples/tests/common/mod.rs @@ -83,7 +83,7 @@ pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (Stri .await .expect("scheduler to be created"); - let host = "localhost".to_string(); + let host = "127.0.0.1".to_string(); let scheduler = connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; @@ -121,7 +121,7 @@ pub async fn setup_test_cluster_with_builders( .await .expect("scheduler to be created"); - let host = "localhost".to_string(); + let host = "127.0.0.1".to_string(); let scheduler = connect_to_scheduler(format!("http://{}:{}", host, addr.port())).await; From b83747972ae5937076a0f36565dd8349e183f3ab Mon Sep 17 00:00:00 2001 From: abhinavgautam01 Date: Tue, 12 May 2026 06:56:33 +0530 Subject: [PATCH 2/2] Address executor standalone/server review feedback --- ballista/executor/src/executor_process.rs | 12 ++- ballista/executor/src/executor_server.rs | 117 ++++++++++++++++------ ballista/executor/src/standalone.rs | 41 ++++---- 3 files changed, 121 insertions(+), 49 deletions(-) diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs index 131a50549f..282b3af2f4 100644 --- a/ballista/executor/src/executor_process.rs +++ b/ballista/executor/src/executor_process.rs @@ -73,7 +73,7 @@ use crate::metrics::LoggingMetricsCollector; use crate::shutdown::Shutdown; use crate::shutdown::ShutdownNotifier; use crate::{ArrowFlightServerProvider, terminate}; -use crate::{execution_loop, executor_server}; +use crate::{execution_loop, executor_server, executor_server::ExecutorGrpcListen}; /// Wrap a [`RuntimeProducer`] so that every produced /// [`RuntimeEnv`](datafusion::execution::runtime_env::RuntimeEnv) carries a @@ -487,6 +487,15 @@ pub async fn start_executor_process( // PullStaged => executor is polling the scheduler when it is idle match scheduler_policy { TaskSchedulingPolicy::PushStaged => { + let grpc_listen_addr: SocketAddr = + format!("{}:{}", opt.bind_host, executor.metadata.grpc_port) + .parse() + .map_err(|e| { + BallistaError::Configuration(format!( + "invalid executor gRPC listen address ({}:{}): {e}", + opt.bind_host, executor.metadata.grpc_port + )) + })?; service_handlers.push( // If there is executor registration error during startup, return the error and stop early. executor_server::startup( @@ -495,6 +504,7 @@ pub async fn start_executor_process( executor.clone(), default_codec, stop_send, + ExecutorGrpcListen::Dynamic(grpc_listen_addr), &shutdown_notification, ) .await?, diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs index c295139a27..8eb3b23570 100644 --- a/ballista/executor/src/executor_server.rs +++ b/ballista/executor/src/executor_server.rs @@ -25,6 +25,7 @@ use ballista_core::BALLISTA_VERSION; use memory_stats::memory_stats; use std::collections::HashMap; use std::convert::TryInto; +use std::net::SocketAddr; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -60,24 +61,32 @@ use datafusion::execution::TaskContext; use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan}; use tokio::sync::mpsc::error::TryRecvError; use tokio::task::JoinHandle; +use tokio_stream::wrappers::TcpListenerStream; use crate::cpu_bound_executor::DedicatedExecutor; use crate::executor::Executor; use crate::executor_process::{ExecutorProcessConfig, remove_job_dir}; use crate::metrics::ExecutorMetricCollectionPolicy; -use crate::shutdown::ShutdownNotifier; +use crate::shutdown::{Shutdown, ShutdownNotifier}; use crate::{TaskExecutionTimes, as_task_status}; type ServerHandle = JoinHandle>; type SchedulerClients = Arc>>; -/// Wait until something is listening on `[host]:port` so the scheduler does not connect -/// before the executor gRPC server task has bound ([`startup`] spawns the listener -/// asynchronously — see in-file TODO). -async fn wait_executor_grpc_listen(host: &str, port: u16) -> Result<(), BallistaError> { - let addr = format!("{host}:{port}"); +/// How the executor gRPC server acquires its listen socket for [`startup`]. +pub(crate) enum ExecutorGrpcListen { + /// Tonic binds the socket inside `serve_with_shutdown` (normal executor process). + Dynamic(SocketAddr), + /// Caller already bound `TcpListener` on the desired port (standalone push); avoids a + /// reserve-port probe/drop race. + Bound(tokio::net::TcpListener), +} + +/// Wait until something is accepting TCP connections on `addr` so registration does not run +/// before the executor gRPC server task has bound ([`startup`] — see TODO below). +async fn wait_executor_grpc_listen(addr: SocketAddr) -> Result<(), BallistaError> { for attempt in 0..500 { - match tokio::net::TcpStream::connect(&addr).await { + match tokio::net::TcpStream::connect(addr).await { Ok(_) => return Ok(()), Err(_) => { if attempt == 499 { @@ -115,14 +124,20 @@ struct CuratorTaskStatus { /// - Starts the task runner pool /// /// Returns a handle to the server task that can be awaited for completion. -pub async fn startup( +pub(crate) async fn startup( mut scheduler: SchedulerGrpcClient, config: Arc, executor: Arc, codec: BallistaCodec, stop_send: mpsc::Sender, + grpc_listen: ExecutorGrpcListen, shutdown_noti: &ShutdownNotifier, ) -> Result { + debug_assert_eq!( + executor.metadata.grpc_port as u16, config.grpc_port, + "executor registration metadata.grpc_port must match ExecutorProcessConfig.grpc_port" + ); + let channel_buf_size = executor.concurrent_tasks * 50; let (tx_task, rx_task) = mpsc::channel::(channel_buf_size); let (tx_task_status, rx_task_status) = @@ -143,33 +158,75 @@ pub async fn startup( config.metric_collection_policy, ); + let grpc_server_config = config.grpc_server_config.clone(); + let enc = config.grpc_max_encoding_message_size as usize; + let dec = config.grpc_max_decoding_message_size as usize; + // 1. Start executor grpc service - let server = { - let executor_meta = executor.metadata.clone(); - let addr = format!("{}:{}", config.bind_host, executor_meta.grpc_port); - let addr = addr.parse().unwrap(); - let grpc_server_config = config.grpc_server_config.clone(); + let listen_addr_for_wait = match &grpc_listen { + ExecutorGrpcListen::Dynamic(addr) => Some(*addr), + ExecutorGrpcListen::Bound(listener) => { + debug_assert_eq!( + listener.local_addr()?.port(), + config.grpc_port, + "bound listener port must match config.grpc_port / registration metadata" + ); + None + } + }; - info!( - "Ballista v{BALLISTA_VERSION} Rust Executor Grpc Server listening on {addr:?}" - ); - let server = ExecutorGrpcServer::new(executor_server.clone()) - .max_encoding_message_size(config.grpc_max_encoding_message_size as usize) - .max_decoding_message_size(config.grpc_max_decoding_message_size as usize); - let mut grpc_shutdown = shutdown_noti.subscribe_for_shutdown(); - tokio::spawn(async move { - let shutdown_signal = grpc_shutdown.recv(); - let grpc_server_future = create_grpc_server(&grpc_server_config) - .add_service(server) - .serve_with_shutdown(addr, shutdown_signal); - grpc_server_future.await.map_err(|e| { - error!("Tonic error, Could not start Executor Grpc Server."); - BallistaError::TonicError(e) + let shutdown_grpc_notify = shutdown_noti.notify_shutdown.clone(); + let server = match grpc_listen { + ExecutorGrpcListen::Dynamic(addr) => { + info!( + "Ballista v{BALLISTA_VERSION} Rust Executor Grpc Server listening on {addr:?}" + ); + let server_svc = ExecutorGrpcServer::new(executor_server.clone()) + .max_encoding_message_size(enc) + .max_decoding_message_size(dec); + tokio::spawn(async move { + let mut grpc_shutdown = Shutdown::new(shutdown_grpc_notify.subscribe()); + let shutdown_signal = grpc_shutdown.recv(); + create_grpc_server(&grpc_server_config) + .add_service(server_svc) + .serve_with_shutdown(addr, shutdown_signal) + .await + .map_err(|e| { + error!("Tonic error, Could not start Executor Grpc Server."); + BallistaError::TonicError(e) + }) }) - }) + } + ExecutorGrpcListen::Bound(listener) => { + let addr = listener.local_addr()?; + info!( + "Ballista v{BALLISTA_VERSION} Rust Executor Grpc Server listening on {addr:?}" + ); + let server_svc = ExecutorGrpcServer::new(executor_server.clone()) + .max_encoding_message_size(enc) + .max_decoding_message_size(dec); + tokio::spawn(async move { + let mut grpc_shutdown = Shutdown::new(shutdown_grpc_notify.subscribe()); + let shutdown_signal = grpc_shutdown.recv(); + let incoming = TcpListenerStream::new(listener); + create_grpc_server(&grpc_server_config) + .add_service(server_svc) + .serve_with_incoming_shutdown(incoming, shutdown_signal) + .await + .map_err(|e| { + error!("Tonic error, Could not start Executor Grpc Server."); + BallistaError::TonicError(e) + }) + }) + } }; - wait_executor_grpc_listen(&config.bind_host, config.grpc_port).await?; + if let Some(addr) = listen_addr_for_wait { + wait_executor_grpc_listen(addr).await?; + } else { + // Listener is already bound; yield once so the server task can start polling. + tokio::task::yield_now().await; + } // 2. Do executor registration // TODO the executor registration should happen only after the executor grpc server started. diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs index 972ee61e7c..6afba8cfd5 100644 --- a/ballista/executor/src/standalone.rs +++ b/ballista/executor/src/standalone.rs @@ -21,7 +21,7 @@ //! process as the client, useful for testing and development purposes. use crate::executor_process::ExecutorProcessConfig; -use crate::executor_server; +use crate::executor_server::{self, ExecutorGrpcListen}; use crate::metrics::LoggingMetricsCollector; use crate::shutdown::ShutdownNotifier; use crate::{execution_loop, executor::Executor, flight_service::BallistaFlightService}; @@ -106,7 +106,7 @@ pub async fn new_standalone_executor_from_state_with_scheduling_policy( /// by accepting custom producers for session config, runtime environment, /// codec, and function registry. /// -/// The executor binds to a random available port on localhost. +/// The executor binds to a random available port on IPv4 loopback (`127.0.0.1`). pub async fn new_standalone_executor_from_builder( scheduler: SchedulerGrpcClient, concurrent_tasks: usize, @@ -173,14 +173,15 @@ async fn pull_staged_standalone_executor( codec: BallistaCodec, function_registry: BallistaFunctionRegistry, ) -> Result<()> { - let listener = TcpListener::bind("localhost:0").await?; + let listener = TcpListener::bind("127.0.0.1:0").await?; let address = listener.local_addr()?; info!("Ballista v{BALLISTA_VERSION} Rust Executor listening on {address:?}"); let executor_meta = ExecutorRegistration { id: Uuid::new_v4().to_string(), - host: Some("localhost".to_string()), + host: Some("127.0.0.1".to_string()), port: address.port() as u32, + // TODO: allow configuring advertised gRPC port for standalone pull (currently fixed default). grpc_port: 50020, specification: Some( ExecutorSpecification::default() @@ -267,19 +268,18 @@ async fn push_staged_standalone_executor( codec: BallistaCodec, function_registry: BallistaFunctionRegistry, ) -> Result<()> { - let flight_listener = TcpListener::bind("localhost:0").await?; + let flight_listener = TcpListener::bind("127.0.0.1:0").await?; let flight_addr = flight_listener.local_addr()?; info!( "Ballista v{BALLISTA_VERSION} Rust Executor (push) listening on {flight_addr:?}" ); - let grpc_probe = TcpListener::bind("127.0.0.1:0").await?; - let grpc_port = grpc_probe.local_addr()?.port(); - drop(grpc_probe); + let grpc_listener = TcpListener::bind("127.0.0.1:0").await?; + let grpc_port = grpc_listener.local_addr()?.port(); let executor_meta = ExecutorRegistration { id: Uuid::new_v4().to_string(), - host: Some("localhost".to_string()), + host: Some("127.0.0.1".to_string()), port: flight_addr.port() as u32, grpc_port: grpc_port as u32, specification: Some( @@ -291,7 +291,9 @@ async fn push_staged_standalone_executor( }; let config_snap = config_producer(); - let max_message_sz = config_snap.ballista_grpc_client_max_message_size() as u32; + let max_message_size = config_snap.ballista_grpc_client_max_message_size(); + // ExecutorProcessConfig stores tonic limits as u32; SessionConfig exposes usize. + let grpc_max_message_size = max_message_size.min(u32::MAX as usize) as u32; let work_dir = TempDir::new()?.path().to_str().unwrap().to_string(); info!("work_dir: {work_dir}"); @@ -308,8 +310,8 @@ async fn push_staged_standalone_executor( let service = BallistaFlightService::new(work_dir); let server = FlightServiceServer::new(service) - .max_decoding_message_size(max_message_sz as usize) - .max_encoding_message_size(max_message_sz as usize); + .max_decoding_message_size(max_message_size) + .max_encoding_message_size(max_message_size); tokio::spawn( create_grpc_server(&GrpcServerConfig::default()) @@ -325,14 +327,15 @@ async fn push_staged_standalone_executor( grpc_port, concurrent_tasks, task_scheduling_policy: TaskSchedulingPolicy::PushStaged, - grpc_max_decoding_message_size: max_message_sz, - grpc_max_encoding_message_size: max_message_sz, + grpc_max_decoding_message_size: grpc_max_message_size, + grpc_max_encoding_message_size: grpc_max_message_size, ..ExecutorProcessConfig::default() }; - let shutdown_notifier: &'static ShutdownNotifier = - Box::leak(Box::new(ShutdownNotifier::new())); - let (stop_send, _stop_recv) = mpsc::channel::(10); + let shutdown_notifier = Arc::new(ShutdownNotifier::new()); + let shutdown_keepalive_for_join = shutdown_notifier.clone(); + let (stop_send, mut stop_recv) = mpsc::channel::(10); + tokio::spawn(async move { while stop_recv.recv().await.is_some() {} }); let server_handle = executor_server::startup( scheduler, @@ -340,7 +343,8 @@ async fn push_staged_standalone_executor( executor, codec, stop_send, - shutdown_notifier, + ExecutorGrpcListen::Bound(grpc_listener), + shutdown_notifier.as_ref(), ) .await .map_err(|e| { @@ -349,6 +353,7 @@ async fn push_staged_standalone_executor( })?; tokio::spawn(async move { + let _shutdown_keepalive = shutdown_keepalive_for_join; match server_handle.await { Ok(Ok(())) => {} Ok(Err(e)) => error!("Standalone push executor gRPC server exited: {e:?}"),