diff --git a/benchmarks/README.md b/benchmarks/README.md
index 8fed85fa02b8..142762e86561 100644
--- a/benchmarks/README.md
+++ b/benchmarks/README.md
@@ -762,7 +762,7 @@ Different queries are included to test nested loop joins under various workloads

 ## Hash Join

-This benchmark focuses on the performance of queries with nested hash joins, minimizing other overheads such as scanning data sources or evaluating predicates.
+This benchmark focuses on the performance of queries with hash joins, minimizing other overheads such as scanning data sources or evaluating predicates.

 Several queries are included to test hash joins under various workloads.

@@ -774,6 +774,19 @@ Several queries are included to test hash joins under various workloads.
 ./bench.sh run hj
 ```

+## Sort Merge Join
+
+This benchmark focuses on the performance of queries with sort merge joins, minimizing other overheads such as scanning data sources or evaluating predicates.
+
+Several queries are included to test sort merge joins under various workloads.
+
+### Example Run
+
+```bash
+# No need to generate data: this benchmark uses table function `range()` as the data source
+
+./bench.sh run smj
+```
 ## Cancellation

 Test performance of cancelling queries.

diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh
index dbfd319dd9ad..948f75635311 100755
--- a/benchmarks/bench.sh
+++ b/benchmarks/bench.sh
@@ -126,6 +126,7 @@ imdb: Join Order Benchmark (JOB) using the IMDB dataset conver
 cancellation: How long cancelling a query takes
 nlj: Benchmark for simple nested loop joins, testing various join scenarios
 hj: Benchmark for simple hash joins, testing various join scenarios
+smj: Benchmark for simple sort merge joins, testing various join scenarios
 compile_profile: Compile and execute TPC-H across selected Cargo profiles, reporting timing and binary size

@@ -311,6 +312,10 @@ main() {
                 # hj uses range() function, no data generation needed
                 echo "HJ benchmark does not require data generation"
                 ;;
+            smj)
+                # smj uses range() function, no data generation needed
+                echo "SMJ benchmark does not require data generation"
+                ;;
             compile_profile)
                 data_tpch "1"
                 ;;
@@ -384,6 +389,7 @@ main() {
                     run_external_aggr
                     run_nlj
                     run_hj
+                    run_smj
                     ;;
                 tpch)
                     run_tpch "1" "parquet"
@@ -494,6 +500,9 @@ main() {
                 hj)
                     run_hj
                     ;;
+                smj)
+                    run_smj
+                    ;;
                 compile_profile)
                     run_compile_profile "${PROFILE_ARGS[@]}"
                     ;;
@@ -1154,6 +1163,14 @@ run_hj() {
     debug_run $CARGO_COMMAND --bin dfbench -- hj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG}
 }

+# Runs the smj benchmark
+run_smj() {
+    RESULTS_FILE="${RESULTS_DIR}/smj.json"
+    echo "RESULTS_FILE: ${RESULTS_FILE}"
+    echo "Running smj benchmark..."
+    debug_run $CARGO_COMMAND --bin dfbench -- smj --iterations 5 -o "${RESULTS_FILE}" ${QUERY_ARG}
+}
+
 compare_benchmarks() {
     BASE_RESULTS_DIR="${SCRIPT_DIR}/results"

diff --git a/benchmarks/src/bin/dfbench.rs b/benchmarks/src/bin/dfbench.rs
index 816cae0e3855..2fbc7bef3dca 100644
--- a/benchmarks/src/bin/dfbench.rs
+++ b/benchmarks/src/bin/dfbench.rs
@@ -34,7 +34,7 @@ static ALLOC: snmalloc_rs::SnMalloc = snmalloc_rs::SnMalloc;
 static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;

 use datafusion_benchmarks::{
-    cancellation, clickbench, h2o, hj, imdb, nlj, sort_tpch, tpch,
+    cancellation, clickbench, h2o, hj, imdb, nlj, smj, sort_tpch, tpch,
 };

 #[derive(Debug, StructOpt)]
@@ -46,6 +46,7 @@ enum Options {
     HJ(hj::RunOpt),
     Imdb(imdb::RunOpt),
     Nlj(nlj::RunOpt),
+    Smj(smj::RunOpt),
     SortTpch(sort_tpch::RunOpt),
     Tpch(tpch::RunOpt),
     TpchConvert(tpch::ConvertOpt),
@@ -63,6 +64,7 @@ pub async fn main() -> Result<()> {
         Options::HJ(opt) => opt.run().await,
         Options::Imdb(opt) => Box::pin(opt.run()).await,
         Options::Nlj(opt) => opt.run().await,
+        Options::Smj(opt) => opt.run().await,
         Options::SortTpch(opt) => opt.run().await,
         Options::Tpch(opt) => Box::pin(opt.run()).await,
         Options::TpchConvert(opt) => opt.run().await,
diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs
index 07cffa5ae468..5c8c2dcff568 100644
--- a/benchmarks/src/lib.rs
+++ b/benchmarks/src/lib.rs
@@ -22,6 +22,7 @@ pub mod h2o;
 pub mod hj;
 pub mod imdb;
 pub mod nlj;
+pub mod smj;
 pub mod sort_tpch;
 pub mod tpch;
 pub mod util;
diff --git a/benchmarks/src/smj.rs b/benchmarks/src/smj.rs
new file mode 100644
index 000000000000..32a620a12d4f
--- /dev/null
+++ b/benchmarks/src/smj.rs
@@ -0,0 +1,524 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::util::{BenchmarkRun, CommonOpt, QueryResult};
+use datafusion::physical_plan::execute_stream;
+use datafusion::{error::Result, prelude::SessionContext};
+use datafusion_common::instant::Instant;
+use datafusion_common::{exec_datafusion_err, exec_err, DataFusionError};
+use std::path::PathBuf;
+use structopt::StructOpt;
+
+use futures::StreamExt;
+
+/// Run the Sort Merge Join (SMJ) benchmark
+///
+/// This micro-benchmark focuses on the performance characteristics of SMJs.
+///
+/// It uses equality join predicates (to ensure SMJ is selected) and varies:
+/// - Join type: Inner/Left/Right/Full/LeftSemi/LeftAnti/RightSemi/RightAnti
+/// - Key cardinality: 1:1, 1:N, N:M relationships
+/// - Filter selectivity: Low (1%), Medium (10%), High (50%)
+/// - Input sizes: Small to large, balanced and skewed
+///
+/// All inputs are pre-sorted in CTEs before the join to isolate join
+/// performance from sort overhead.
+#[derive(Debug, StructOpt, Clone)]
+#[structopt(verbatim_doc_comment)]
+pub struct RunOpt {
+    /// Query number (between 1 and 20). If not specified, runs all queries
+    #[structopt(short, long)]
+    query: Option<usize>,
+
+    /// Common options
+    #[structopt(flatten)]
+    common: CommonOpt,
+
+    /// If present, write results json here
+    #[structopt(parse(from_os_str), short = "o", long = "output")]
+    output_path: Option<PathBuf>,
+}
+
+/// Inline SQL queries for SMJ benchmarks
+///
+/// Each query's comment includes:
+/// - Join type
+/// - Left row count × Right row count
+/// - Key cardinality (rows per key)
+/// - Filter selectivity (if applicable)
+const SMJ_QUERIES: &[&str] = &[
+    // Q1: INNER 100K x 100K | 1:1
+    r#"
+    WITH t1_sorted AS (
+        SELECT value as key FROM range(100000) ORDER BY value
+    ),
+    t2_sorted AS (
+        SELECT value as key FROM range(100000) ORDER BY value
+    )
+    SELECT t1_sorted.key as k1, t2_sorted.key as k2
+    FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    "#,
+    // Q2: INNER 100K x 1M | 1:10
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(100000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2
+    FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    "#,
+    // Q3: INNER 1M x 1M | 1:100
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2
+    FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    "#,
+    // Q4: INNER 100K x 1M | 1:10 | 1%
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(100000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2
+    FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    WHERE t2_sorted.data % 100 = 0
+    "#,
+    // Q5: INNER 1M x 1M | 1:100 | 10%
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2
+    FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    WHERE t1_sorted.data <> t2_sorted.data AND t2_sorted.data % 10 = 0
+    "#,
+    // Q6: LEFT 100K x 1M | 1:10
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10500 as key, value as data
+        FROM range(100000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2
+    FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    "#,
+    // Q7: LEFT 100K x 1M | 1:10 | 50%
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(100000)
+        ORDER BY key, data
+    ),
+    t2_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM range(1000000)
+        ORDER BY key, data
+    )
+    SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2
+    FROM t1_sorted LEFT JOIN t2_sorted ON t1_sorted.key = t2_sorted.key
+    WHERE t2_sorted.data IS NULL OR t2_sorted.data % 2 = 0
+    "#,
+    // Q8: FULL 100K x 100K | 1:10
+    r#"
+    WITH t1_sorted AS (
+        SELECT value % 10000 as key, value as data
+        FROM
range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 12500 as key, value as data + FROM range(100000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + "#, + // Q9: FULL 100K x 1M | 1:10 | 10% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key as k1, t1_sorted.data as d1, + t2_sorted.key as k2, t2_sorted.data as d2 + FROM t1_sorted FULL JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE (t1_sorted.data IS NULL OR t2_sorted.data IS NULL + OR t1_sorted.data <> t2_sorted.data) + AND (t1_sorted.data IS NULL OR t1_sorted.data % 10 = 0) + "#, + // Q10: LEFT SEMI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q11: LEFT SEMI 100K x 1M | 1:10 | 1% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 100 = 0 + ) + "#, + // Q12: LEFT SEMI 100K x 1M | 1:10 | 50% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 2 = 0 + ) + "#, + // Q13: LEFT SEMI 100K x 1M | 1:10 | 90% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data % 10 <> 0 + ) + "#, + // Q14: LEFT ANTI 100K x 1M | 1:10 + r#" + WITH t1_sorted AS ( + SELECT value % 10500 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q15: LEFT ANTI 100K x 1M | 1:10 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 12000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(1000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q16: LEFT 
ANTI 100K x 100K | 1:1 | stress + r#" + WITH t1_sorted AS ( + SELECT value % 11000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(100000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q17: INNER 100K x 5M | 1:50 | 5% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data as d1, t2_sorted.data as d2 + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + WHERE t2_sorted.data <> t1_sorted.data AND t2_sorted.data % 20 = 0 + "#, + // Q18: LEFT SEMI 100K x 5M | 1:50 | 2% + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(5000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + AND t2_sorted.data <> t1_sorted.data + AND t2_sorted.data % 50 = 0 + ) + "#, + // Q19: LEFT ANTI 100K x 5M | 1:50 | partial match + r#" + WITH t1_sorted AS ( + SELECT value % 15000 as key, value as data + FROM range(100000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key + FROM range(5000000) + ORDER BY key + ) + SELECT t1_sorted.key, t1_sorted.data + FROM t1_sorted + WHERE NOT EXISTS ( + SELECT 1 FROM t2_sorted + WHERE t2_sorted.key = t1_sorted.key + ) + "#, + // Q20: INNER 1M x 10M | 1:100 + GROUP BY + r#" + WITH t1_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(1000000) + ORDER BY key, data + ), + t2_sorted AS ( + SELECT value % 10000 as key, value as data + FROM range(10000000) + ORDER BY key, data + ) + SELECT t1_sorted.key, count(*) as cnt + FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key + GROUP BY t1_sorted.key + "#, +]; + +impl RunOpt { + pub async fn run(self) -> Result<()> { + println!("Running SMJ benchmarks with the following options: {self:#?}\n"); + + // Define query range + let query_range = match self.query { + Some(query_id) => { + if query_id >= 1 && query_id <= SMJ_QUERIES.len() { + query_id..=query_id + } else { + return exec_err!( + "Query {query_id} not found. 
Available queries: 1 to {}",
+                        SMJ_QUERIES.len()
+                    );
+                }
+            }
+            None => 1..=SMJ_QUERIES.len(),
+        };
+
+        let mut config = self.common.config()?;
+        // Disable hash joins to force SMJ
+        config = config.set_bool("datafusion.optimizer.prefer_hash_join", false);
+        let rt_builder = self.common.runtime_env_builder()?;
+        let ctx = SessionContext::new_with_config_rt(config, rt_builder.build_arc()?);
+
+        let mut benchmark_run = BenchmarkRun::new();
+        for query_id in query_range {
+            let query_index = query_id - 1; // Convert 1-based to 0-based index
+
+            let sql = SMJ_QUERIES[query_index];
+            benchmark_run.start_new_case(&format!("Query {query_id}"));
+            let query_run = self.benchmark_query(sql, &query_id.to_string(), &ctx).await;
+            match query_run {
+                Ok(query_results) => {
+                    for iter in query_results {
+                        benchmark_run.write_iter(iter.elapsed, iter.row_count);
+                    }
+                }
+                Err(e) => {
+                    return Err(DataFusionError::Context(
+                        format!("SMJ benchmark Q{query_id} failed with error:"),
+                        Box::new(e),
+                    ));
+                }
+            }
+        }
+
+        benchmark_run.maybe_write_json(self.output_path.as_ref())?;
+        Ok(())
+    }
+
+    async fn benchmark_query(
+        &self,
+        sql: &str,
+        query_name: &str,
+        ctx: &SessionContext,
+    ) -> Result<Vec<QueryResult>> {
+        let mut query_results = vec![];
+
+        // Validate that the query plan includes a Sort Merge Join
+        let df = ctx.sql(sql).await?;
+        let physical_plan = df.create_physical_plan().await?;
+        let plan_string = format!("{physical_plan:#?}");
+
+        if !plan_string.contains("SortMergeJoinExec") {
+            return Err(exec_datafusion_err!(
+                "Query {query_name} does not use Sort Merge Join. Physical plan: {plan_string}"
+            ));
+        }
+
+        for i in 0..self.common.iterations {
+            let start = Instant::now();
+
+            let row_count = Self::execute_sql_without_result_buffering(sql, ctx).await?;
+
+            let elapsed = start.elapsed();
+
+            println!(
+                "Query {query_name} iteration {i} returned {row_count} rows in {elapsed:?}"
+            );
+
+            query_results.push(QueryResult { elapsed, row_count });
+        }
+
+        Ok(query_results)
+    }
+
+    /// Executes the SQL query and drops each result batch after evaluation, to
+    /// minimize memory usage by not buffering results.
+    ///
+    /// Returns the total result row count
+    async fn execute_sql_without_result_buffering(
+        sql: &str,
+        ctx: &SessionContext,
+    ) -> Result<usize> {
+        let mut row_count = 0;
+
+        let df = ctx.sql(sql).await?;
+        let physical_plan = df.create_physical_plan().await?;
+        let mut stream = execute_stream(physical_plan, ctx.task_ctx())?;
+
+        while let Some(batch) = stream.next().await {
+            row_count += batch?.num_rows();
+
+            // Evaluate the result and do nothing; the batch is dropped here
+            // to reduce memory pressure
+        }
+
+        Ok(row_count)
+    }
+}
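Outside the harness, the plan-level guard that `benchmark_query` applies can be reproduced in a few lines. The following is a minimal standalone sketch, not part of the patch, assuming the `datafusion` and `tokio` crates; it mirrors Q1 and the `prefer_hash_join` override from `RunOpt::run`:

```rust
use datafusion::error::Result;
use datafusion::prelude::{SessionConfig, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    // Same override RunOpt::run applies: with hash joins disallowed,
    // the planner must choose a sort merge join for this equi-join.
    let config =
        SessionConfig::new().set_bool("datafusion.optimizer.prefer_hash_join", false);
    let ctx = SessionContext::new_with_config(config);

    // Q1-shaped query: both inputs come from range() and are pre-sorted.
    let sql = "
        WITH t1_sorted AS (SELECT value AS key FROM range(100000) ORDER BY value),
             t2_sorted AS (SELECT value AS key FROM range(100000) ORDER BY value)
        SELECT t1_sorted.key AS k1, t2_sorted.key AS k2
        FROM t1_sorted JOIN t2_sorted ON t1_sorted.key = t2_sorted.key";

    let physical_plan = ctx.sql(sql).await?.create_physical_plan().await?;
    let plan_string = format!("{physical_plan:#?}");

    // This is the same check benchmark_query() performs before timing anything.
    assert!(plan_string.contains("SortMergeJoinExec"));
    Ok(())
}
```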
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
index 0325e37d42e7..dd86628b2be2 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -41,7 +41,8 @@ use crate::{PhysicalExpr, RecordBatchStream, SendableRecordBatchStream};
 use arrow::array::{types::UInt64Type, *};
 use arrow::compute::{
-    self, concat_batches, filter_record_batch, is_not_null, take, SortOptions,
+    self, concat_batches, filter_record_batch, is_not_null, take, BatchCoalescer,
+    SortOptions,
 };
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
 use arrow::error::ArrowError;
@@ -323,11 +324,15 @@ pub(super) struct SortMergeJoinStream {
     pub staging_output_record_batches: JoinedRecordBatches,
     /// Output buffer. Currently used by filtering as it requires double buffering
     /// to avoid small/empty batches. Non-filtered join outputs directly from `staging_output_record_batches.batches`
-    pub output: RecordBatch,
+    /// Uses BatchCoalescer to accumulate small batches efficiently without repeated concatenation.
+    pub output_buffer: BatchCoalescer,
     /// Staging output size, including output batches and staging joined results.
     /// Increased when we put rows into buffer and decreased after we actually output batches.
     /// Used to trigger output when sufficient rows are ready
     pub output_size: usize,
+    /// Flag indicating that staging_output_record_batches coalescer has reached target
+    /// and should be processed
+    pub staging_ready: bool,
     /// The comparison result of current streamed row and buffered batches
     pub current_ordering: Ordering,
     /// Manages the process of spilling and reading back intermediate data
@@ -350,7 +355,7 @@ pub(super) struct SortMergeJoinStream {
 /// Joined batches with attached join filter information
 pub(super) struct JoinedRecordBatches {
     /// Joined batches. Each batch is already joined columns from left and right sources
-    pub batches: Vec<RecordBatch>,
+    pub coalescer: BatchCoalescer,
     /// Filter match mask for each row(matched/non-matched)
     pub filter_mask: BooleanBuilder,
     /// Left row indices to glue together rows in `batches` and `filter_mask`
@@ -363,7 +368,7 @@ pub(super) struct JoinedRecordBatches {
 impl JoinedRecordBatches {
     fn clear(&mut self) {
-        self.batches.clear();
+        // Note: BatchCoalescer clears itself in finish_buffered_batch()
         self.batch_ids.clear();
         self.filter_mask = BooleanBuilder::new();
         self.row_indices = UInt64Builder::new();
@@ -592,29 +597,42 @@ impl Stream for SortMergeJoinStream {
                         self.freeze_all()?;

                         // If join is filtered and there is joined tuples waiting
-                        // to be filtered
-                        if !self
-                            .staging_output_record_batches
-                            .batches
-                            .is_empty()
-                        {
+                        // to be filtered. Process when coalescer has reached target size.
+                        if self.staging_ready {
+                            // Track buffered row count before draining the coalescer
+                            let pre_filter_row_count = self
+                                .staging_output_record_batches
+                                .coalescer
+                                .get_buffered_rows();
+
                             // Apply filter on joined tuples and get filtered batch
                             let out_filtered_batch = self.filter_joined_batch()?;

+                            // Decrement output_size by the number of unfiltered rows processed.
+                            // output_size tracks unfiltered pairs, but we just processed
+                            // pre_filter_row_count rows from the coalescer.
+ if pre_filter_row_count > self.output_size { + self.output_size = 0; + } else { + self.output_size -= pre_filter_row_count; + } + + // Reset the flag after processing + self.staging_ready = false; + // Append filtered batch to the output buffer - self.output = concat_batches( - &self.schema(), - [&self.output, &out_filtered_batch], - )?; - - // Send to output if the output buffer surpassed the `batch_size` - if self.output.num_rows() >= self.batch_size { - let record_batch = std::mem::replace( - &mut self.output, - RecordBatch::new_empty( - out_filtered_batch.schema(), - ), + self.output_buffer + .push_batch(out_filtered_batch)?; + if self.output_buffer.has_completed_batch() { + self.output_buffer + .finish_buffered_batch()?; + let record_batch = self + .output_buffer + .next_completed_batch() + .unwrap(); + (&record_batch).record_output( + &self.join_metrics.baseline_metrics(), ); return Poll::Ready(Some(Ok( record_batch, @@ -676,13 +694,11 @@ impl Stream for SortMergeJoinStream { } } else { self.freeze_all()?; - if !self.staging_output_record_batches.batches.is_empty() { - let record_batch = self.output_record_batch_and_reset()?; - // For non-filtered join output whenever the target output batch size - // is hit. For filtered join its needed to output on later phase - // because target output batch size can be hit in the middle of - // filtering causing the filtering to be incomplete and causing - // correctness issues + // Only process if coalescer has reached target + if self.staging_ready { + // For filtered joins, batches accumulate across multiple freeze_all() calls + // and are processed at safe transition points (between streamed batches or + // at Exhausted state). Don't output here in the tight JoinOutput loop. if self.filter.is_some() && matches!( self.join_type, @@ -697,19 +713,31 @@ impl Stream for SortMergeJoinStream { | JoinType::Full ) { + // Keep staging_ready set to let it propagate to Init state + // where it will be processed. Transition to Init state to continue. 
+ self.buffered_data.scanning_reset(); + self.state = SortMergeJoinState::Init; continue; + } else { + // Non-filtered joins output immediately + let record_batch = + self.output_record_batch_and_reset()?; + self.staging_ready = false; + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); + return Poll::Ready(Some(Ok(record_batch))); } - - return Poll::Ready(Some(Ok(record_batch))); } - return Poll::Pending; + // Reset scanning and transition to Init to continue processing + self.buffered_data.scanning_reset(); + self.state = SortMergeJoinState::Init; } } SortMergeJoinState::Exhausted => { self.freeze_all()?; - // if there is still something not processed - if !self.staging_output_record_batches.batches.is_empty() { + // if there is still something not processed in coalescer + if !self.staging_output_record_batches.coalescer.is_empty() { if self.filter.is_some() && matches!( self.join_type, @@ -725,18 +753,24 @@ impl Stream for SortMergeJoinStream { ) { let record_batch = self.filter_joined_batch()?; + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); return Poll::Ready(Some(Ok(record_batch))); } else { let record_batch = self.output_record_batch_and_reset()?; + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); return Poll::Ready(Some(Ok(record_batch))); } - } else if self.output.num_rows() > 0 { + } else if !self.output_buffer.is_empty() { // if processed but still not outputted because it didn't hit batch size before - let schema = self.output.schema(); - let record_batch = std::mem::replace( - &mut self.output, - RecordBatch::new_empty(schema), - ); + self.output_buffer.finish_buffered_batch()?; + let record_batch = + self.output_buffer.next_completed_batch().unwrap_or_else( + || RecordBatch::new_empty(Arc::clone(&self.schema)), + ); + (&record_batch) + .record_output(&self.join_metrics.baseline_metrics()); return Poll::Ready(Some(Ok(record_batch))); } else { return Poll::Ready(None); @@ -794,13 +828,16 @@ impl SortMergeJoinStream { on_buffered, filter, staging_output_record_batches: JoinedRecordBatches { - batches: vec![], + coalescer: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Some(batch_size / 2)), filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }, - output: RecordBatch::new_empty(schema), + output_buffer: BatchCoalescer::new(Arc::clone(&schema), batch_size) + .with_biggest_coalesce_batch_size(Some(batch_size / 2)), output_size: 0, + staging_ready: false, batch_size, join_type, join_metrics, @@ -1202,8 +1239,15 @@ impl SortMergeJoinStream { ); self.staging_output_record_batches - .batches - .push(record_batch); + .coalescer + .push_batch(record_batch)?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } } buffered_batch.null_joined.clear(); } @@ -1248,8 +1292,15 @@ impl SortMergeJoinStream { 0, ); self.staging_output_record_batches - .batches - .push(record_batch); + .coalescer + .push_batch(record_batch)?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } } buffered_batch.join_filter_not_matched_map.clear(); @@ -1259,6 +1310,7 @@ impl SortMergeJoinStream { // Produces and stages record batch for all output indices found // for current streamed batch and clears staged output indices. 
fn freeze_streamed(&mut self) -> Result<()> { + let mut rows_processed = 0; for chunk in self.streamed_batch.output_indices.iter_mut() { // The row indices of joined streamed batch let left_indices = chunk.streamed_indices.finish(); @@ -1267,6 +1319,8 @@ impl SortMergeJoinStream { continue; } + rows_processed += left_indices.len(); + let mut left_columns = self .streamed_batch .batch @@ -1391,13 +1445,26 @@ impl SortMergeJoinStream { | JoinType::Full ) { self.staging_output_record_batches - .batches - .push(output_batch); + .coalescer + .push_batch(output_batch)?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } } else { - let filtered_batch = filter_record_batch(&output_batch, &mask)?; self.staging_output_record_batches - .batches - .push(filtered_batch); + .coalescer + .push_batch_with_filter(output_batch, &mask)?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } } if !matches!(self.join_type, JoinType::Full) { @@ -1445,25 +1512,71 @@ impl SortMergeJoinStream { } } else { self.staging_output_record_batches - .batches - .push(output_batch); + .coalescer + .push_batch(output_batch)?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } } } else { self.staging_output_record_batches - .batches - .push(output_batch); + .coalescer + .push_batch(output_batch)?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } } } self.streamed_batch.output_indices.clear(); + // Decrement output_size by the number of rows we just processed and added to the coalescer + if rows_processed > self.output_size { + self.output_size = 0; + } else { + self.output_size -= rows_processed; + } + + // After clearing output_indices, if the coalescer has buffered data but hasn't + // reached the target yet, we may need to force a flush to prevent deadlock. + // This is only necessary for filtered joins where partial batches can accumulate + // without reaching the target batch size. + if !self.staging_output_record_batches.coalescer.is_empty() + && !self.staging_ready + && self.filter.is_some() + { + self.staging_output_record_batches + .coalescer + .finish_buffered_batch()?; + if self + .staging_output_record_batches + .coalescer + .has_completed_batch() + { + self.staging_ready = true; + } + } + Ok(()) } fn output_record_batch_and_reset(&mut self) -> Result { - let record_batch = - concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; - (&record_batch).record_output(&self.join_metrics.baseline_metrics()); + self.staging_output_record_batches + .coalescer + .finish_buffered_batch()?; + let record_batch = self + .staging_output_record_batches + .coalescer + .next_completed_batch() + .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(&self.schema))); // If join filter exists, `self.output_size` is not accurate as we don't know the exact // number of rows in the output record batch. If streamed row joined with buffered rows, // once join filter is applied, the number of output rows may be more than 1. 
@@ -1475,29 +1588,25 @@ impl SortMergeJoinStream { self.output_size -= record_batch.num_rows(); } - if !(self.filter.is_some() - && matches!( - self.join_type, - JoinType::Left - | JoinType::LeftSemi - | JoinType::Right - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftMark - | JoinType::RightMark - | JoinType::Full - )) - { - self.staging_output_record_batches.batches.clear(); - } + // Note: coalescer is already cleared by finish_buffered_batch() above + // The metadata MUST also be cleared since the batches they refer to are gone. + // For non-filtered joins, clear everything immediately. + // For filtered joins, this path shouldn't be hit (they use filter_joined_batch), + // but if it is, we still need to clear to avoid desync. + self.staging_output_record_batches.clear(); Ok(record_batch) } fn filter_joined_batch(&mut self) -> Result { - let record_batch = - concat_batches(&self.schema, &self.staging_output_record_batches.batches)?; + self.staging_output_record_batches + .coalescer + .finish_buffered_batch()?; + let record_batch = self + .staging_output_record_batches + .coalescer + .next_completed_batch() + .unwrap_or_else(|| RecordBatch::new_empty(Arc::clone(&self.schema))); let mut out_indices = self.staging_output_record_batches.row_indices.finish(); let mut out_mask = self.staging_output_record_batches.filter_mask.finish(); let mut batch_ids = &self.staging_output_record_batches.batch_ids; @@ -1515,7 +1624,11 @@ impl SortMergeJoinStream { } if out_mask.is_empty() { - self.staging_output_record_batches.batches.clear(); + // Coalescer already cleared by finish_buffered_batch() above + // Clear metadata only + self.staging_output_record_batches.batch_ids.clear(); + self.staging_output_record_batches.filter_mask = BooleanBuilder::new(); + self.staging_output_record_batches.row_indices = UInt64Builder::new(); return Ok(record_batch); } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index f91bffbed78f..842480b4268b 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -2319,14 +2319,14 @@ fn build_joined_record_batches() -> Result { ])); let mut batches = JoinedRecordBatches { - batches: vec![], + coalescer: arrow::compute::BatchCoalescer::new(Arc::clone(&schema), 8192), filter_mask: BooleanBuilder::new(), row_indices: UInt64Builder::new(), batch_ids: vec![], }; // Insert already prejoined non-filtered rows - batches.batches.push(RecordBatch::try_new( + batches.coalescer.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -2334,9 +2334,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1, 1])), Arc::new(Int32Array::from(vec![11, 9])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.coalescer.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -2344,9 +2344,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1])), Arc::new(Int32Array::from(vec![12])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.coalescer.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -2354,9 +2354,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1, 1])), Arc::new(Int32Array::from(vec![11, 13])), ], - )?); + )?)?; - 
batches.batches.push(RecordBatch::try_new( + batches.coalescer.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1])), @@ -2364,9 +2364,9 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1])), Arc::new(Int32Array::from(vec![12])), ], - )?); + )?)?; - batches.batches.push(RecordBatch::try_new( + batches.coalescer.push_batch(RecordBatch::try_new( Arc::clone(&schema), vec![ Arc::new(Int32Array::from(vec![1, 1])), @@ -2374,7 +2374,7 @@ fn build_joined_record_batches() -> Result { Arc::new(Int32Array::from(vec![1, 1])), Arc::new(Int32Array::from(vec![12, 11])), ], - )?); + )?)?; let streamed_indices = vec![0, 0]; batches.batch_ids.extend(vec![0; streamed_indices.len()]); @@ -2424,9 +2424,16 @@ fn build_joined_record_batches() -> Result { #[tokio::test] async fn test_left_outer_join_filtered_mask() -> Result<()> { let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; + // Extract the batches from the coalescer + joined_batches.coalescer.finish_buffered_batch()?; + let mut batches_vec = vec![]; + while let Some(batch) = joined_batches.coalescer.next_completed_batch() { + batches_vec.push(batch); + } + let schema = batches_vec.first().unwrap().schema(); + + let output = concat_batches(&schema, &batches_vec)?; let out_mask = joined_batches.filter_mask.finish(); let out_indices = joined_batches.row_indices.finish(); @@ -2631,9 +2638,16 @@ async fn test_left_outer_join_filtered_mask() -> Result<()> { async fn test_semi_join_filtered_mask() -> Result<()> { for join_type in [LeftSemi, RightSemi] { let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; + // Extract the batches from the coalescer + joined_batches.coalescer.finish_buffered_batch()?; + let mut batches_vec = vec![]; + while let Some(batch) = joined_batches.coalescer.next_completed_batch() { + batches_vec.push(batch); + } + let schema = batches_vec.first().unwrap().schema(); + + let output = concat_batches(&schema, &batches_vec)?; let out_mask = joined_batches.filter_mask.finish(); let out_indices = joined_batches.row_indices.finish(); @@ -2806,9 +2820,16 @@ async fn test_semi_join_filtered_mask() -> Result<()> { async fn test_anti_join_filtered_mask() -> Result<()> { for join_type in [LeftAnti, RightAnti] { let mut joined_batches = build_joined_record_batches()?; - let schema = joined_batches.batches.first().unwrap().schema(); - let output = concat_batches(&schema, &joined_batches.batches)?; + // Extract the batches from the coalescer + joined_batches.coalescer.finish_buffered_batch()?; + let mut batches_vec = vec![]; + while let Some(batch) = joined_batches.coalescer.next_completed_batch() { + batches_vec.push(batch); + } + let schema = batches_vec.first().unwrap().schema(); + + let output = concat_batches(&schema, &batches_vec)?; let out_mask = joined_batches.filter_mask.finish(); let out_indices = joined_batches.row_indices.finish();
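The `stream.rs` changes above replace repeated `concat_batches` calls with `arrow::compute::BatchCoalescer`, which buffers small joined batches and only hands back batches of the target size. Below is a self-contained sketch of that accumulate-then-drain pattern, assuming an `arrow` release that ships the `BatchCoalescer` API used in the diff (`push_batch`, `has_completed_batch`, `finish_buffered_batch`, `next_completed_batch`); the toy target of 8 rows stands in for the session `batch_size`:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, RecordBatch};
use arrow::compute::BatchCoalescer;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;

fn main() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    // Toy target of 8 rows; SortMergeJoinStream uses the session batch_size.
    let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 8);

    // Push several 3-row batches; none of them alone reaches the target,
    // so nothing is emitted until 8 buffered rows are available
    // (this is what the new `staging_ready` flag tracks in stream.rs).
    for start in [0, 3, 6, 9] {
        let col: ArrayRef = Arc::new(Int32Array::from(vec![start, start + 1, start + 2]));
        coalescer.push_batch(RecordBatch::try_new(Arc::clone(&schema), vec![col])?)?;

        if coalescer.has_completed_batch() {
            let out = coalescer.next_completed_batch().expect("completed batch");
            println!("emitted {} rows", out.num_rows()); // prints 8 once
        }
    }

    // At end of input, flush whatever is left, mirroring the
    // Exhausted-state handling in the join stream.
    coalescer.finish_buffered_batch()?;
    while let Some(out) = coalescer.next_completed_batch() {
        println!("final flush: {} rows", out.num_rows()); // remaining 4 rows
    }
    Ok(())
}
```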