diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 8e6b0428c041..3398494e3c46 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -31,9 +31,9 @@ use arrow::{
     datatypes::{TimeUnit, UInt32Type, UInt64Type},
 };
 use smallvec::{smallvec, SmallVec};
-use std::time::Instant;
-use std::{any::Any, collections::HashSet};
+use std::{any::Any, usize};
 use std::{hash::Hasher, sync::Arc};
+use std::{time::Instant, vec};
 
 use async_trait::async_trait;
 use futures::{Stream, StreamExt, TryStreamExt};
@@ -370,6 +370,11 @@ impl ExecutionPlan for HashJoinExec {
        let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
        let column_indices = self.column_indices_from_schema()?;

+       let num_rows = left_data.1.num_rows();
+       let visited_left_side = match self.join_type {
+           JoinType::Left => vec![false; num_rows],
+           JoinType::Inner | JoinType::Right => vec![],
+       };
        Ok(Box::pin(HashJoinStream {
            schema: self.schema.clone(),
            on_left,
@@ -384,6 +389,8 @@ impl ExecutionPlan for HashJoinExec {
            num_output_rows: 0,
            join_time: 0,
            random_state: self.random_state.clone(),
+           visited_left_side: visited_left_side,
+           is_exhausted: false,
        }))
    }
 }
@@ -453,6 +460,10 @@ struct HashJoinStream {
    join_time: usize,
    /// Random state used for hashing initialization
    random_state: RandomState,
+   /// Keeps track of the left side rows whether they are visited
+   visited_left_side: Vec<bool>, // TODO: use a more memory efficient data structure, https://github.com/apache/arrow-datafusion/issues/240
+   /// There is nothing to process anymore and left side is processed in case of left join
+   is_exhausted: bool,
 }
 
 impl RecordBatchStream for HashJoinStream {
@@ -473,7 +484,7 @@ fn build_batch_from_indices(
    left_indices: UInt64Array,
    right_indices: UInt32Array,
    column_indices: &[ColumnIndex],
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<(RecordBatch, UInt64Array)> {
    // build the columns of the new [RecordBatch]:
    // 1. pick whether the column is from the left or right
    // 2. based on the pick, `take` items from the different RecordBatches
@@ -489,7 +500,7 @@ fn build_batch_from_indices(
        };
        columns.push(array);
    }
-   RecordBatch::try_new(Arc::new(schema.clone()), columns)
+   RecordBatch::try_new(Arc::new(schema.clone()), columns).map(|x| (x, left_indices))
 }
 
 #[allow(clippy::too_many_arguments)]
@@ -502,7 +513,7 @@ fn build_batch(
    schema: &Schema,
    column_indices: &[ColumnIndex],
    random_state: &RandomState,
-) -> ArrowResult<RecordBatch> {
+) -> ArrowResult<(RecordBatch, UInt64Array)> {
    let (left_indices, right_indices) = build_join_indexes(
        &left_data,
        &batch,
@@ -617,13 +628,6 @@ fn build_join_indexes(
            let mut left_indices = UInt64Builder::new(0);
            let mut right_indices = UInt32Builder::new(0);

-           // Keep track of which item is visited in the build input
-           // TODO: this can be stored more efficiently with a marker
-           // https://issues.apache.org/jira/browse/ARROW-11116
-           // TODO: Fix LEFT join with multiple right batches
-           // https://issues.apache.org/jira/browse/ARROW-10971
-           let mut is_visited = HashSet::new();
-
            // First visit all of the rows
            for (row, hash_value) in hash_values.iter().enumerate() {
                if let Some((_, indices)) =
@@ -634,20 +638,10 @@ fn build_join_indexes(
                        if equal_rows(i as usize, row, &left_join_values, &keys_values)? {
                            left_indices.append_value(i)?;
                            right_indices.append_value(row as u32)?;
-                           is_visited.insert(i);
                        }
                    }
                };
            }
-           // Add the remaining left rows to the result set with None on the right side
-           for (_, indices) in left {
-               for i in indices.iter() {
-                   if !is_visited.contains(i) {
-                       left_indices.append_slice(&indices)?;
-                       right_indices.append_null()?;
-                   }
-               }
-           }
            Ok((left_indices.finish(), right_indices.finish()))
        }
        JoinType::Right => {
@@ -1001,6 +995,39 @@ pub fn create_hashes<'a>(
    Ok(hashes_buffer)
 }
 
+// Produces a batch for left-side rows that are not marked as being visited during the whole join
+fn produce_unmatched(
+   visited_left_side: &[bool],
+   schema: &SchemaRef,
+   column_indices: &[ColumnIndex],
+   left_data: &JoinLeftData,
+) -> ArrowResult<RecordBatch> {
+   // Find indices which didn't match any right row (are false)
+   let unmatched_indices: Vec<u64> = visited_left_side
+       .iter()
+       .enumerate()
+       .filter(|&(_, &value)| !value)
+       .map(|(index, _)| index as u64)
+       .collect();
+
+   // generate batches by taking values from the left side and generating columns filled with null on the right side
+   let indices = UInt64Array::from_iter_values(unmatched_indices);
+   let num_rows = indices.len();
+   let mut columns: Vec<Arc<dyn Array>> = Vec::with_capacity(schema.fields().len());
+   for (idx, column_index) in column_indices.iter().enumerate() {
+       let array = if column_index.is_left {
+           let array = left_data.1.column(column_index.index);
+           compute::take(array.as_ref(), &indices, None).unwrap()
+       } else {
+           let datatype = schema.field(idx).data_type();
+           arrow::array::new_null_array(datatype, num_rows)
+       };
+
+       columns.push(array);
+   }
+   RecordBatch::try_new(schema.clone(), columns)
+}
+
 impl Stream for HashJoinStream {
    type Item = ArrowResult<RecordBatch>;
 
@@ -1025,14 +1052,49 @@ impl Stream for HashJoinStream {
                    );
                    self.num_input_batches += 1;
                    self.num_input_rows += batch.num_rows();
-                   if let Ok(ref batch) = result {
+                   if let Ok((ref batch, ref left_side)) = result {
                        self.join_time += start.elapsed().as_millis() as usize;
                        self.num_output_batches += 1;
                        self.num_output_rows += batch.num_rows();
+
+                       match self.join_type {
+                           JoinType::Left => {
+                               left_side.iter().flatten().for_each(|x| {
+                                   self.visited_left_side[x as usize] = true;
+                               });
+                           }
+                           JoinType::Inner | JoinType::Right => {}
+                       }
                    }
-                   Some(result)
+                   Some(result.map(|x| x.0))
                }
                other => {
+                   let start = Instant::now();
+                   // For the left join, produce rows for unmatched rows
+                   match self.join_type {
+                       JoinType::Left if !self.is_exhausted => {
+                           let result = produce_unmatched(
+                               &self.visited_left_side,
+                               &self.schema,
+                               &self.column_indices,
+                               &self.left_data,
+                           );
+                           if let Ok(ref batch) = result {
+                               self.num_input_batches += 1;
+                               self.num_input_rows += batch.num_rows();
+                               if let Ok(ref batch) = result {
+                                   self.join_time +=
+                                       start.elapsed().as_millis() as usize;
+                                   self.num_output_batches += 1;
+                                   self.num_output_rows += batch.num_rows();
+                               }
+                           }
+                           self.is_exhausted = true;
+                           return Some(result);
+                       }
+                       JoinType::Left | JoinType::Inner | JoinType::Right => {}
+                   }
+
                    debug!(
                        "Processed {} probe-side input batches containing {} rows and \
                        produced {} output batches containing {} rows in {} ms",
@@ -1299,6 +1361,87 @@ mod tests {
        Ok(())
    }
 
+   fn build_table_two_batches(
+       a: (&str, &Vec<i32>),
+       b: (&str, &Vec<i32>),
+       c: (&str, &Vec<i32>),
+   ) -> Arc<dyn ExecutionPlan> {
+       let batch = build_table_i32(a, b, c);
+       let schema = batch.schema();
+       Arc::new(
+           MemoryExec::try_new(&[vec![batch.clone(), batch]], schema, None).unwrap(),
+       )
+   }
+
+   #[tokio::test]
+   async fn join_left_multi_batch() {
+       let left = build_table(
+           ("a1", &vec![1, 2, 3]),
+           ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+           ("c1", &vec![7, 8, 9]),
+       );
+       let right = build_table_two_batches(
+           ("a2", &vec![10, 20, 30]),
+           ("b1", &vec![4, 5, 6]),
+           ("c2", &vec![70, 80, 90]),
+       );
+       let on = &[("b1", "b1")];
+
+       let join = join(left, right, on, &JoinType::Left).unwrap();
+
+       let columns = columns(&join.schema());
+       assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+
+       let stream = join.execute(0).await.unwrap();
+       let batches = common::collect(stream).await.unwrap();
+
+       let expected = vec![
+           "+----+----+----+----+----+",
+           "| a1 | b1 | c1 | a2 | c2 |",
+           "+----+----+----+----+----+",
+           "| 1  | 4  | 7  | 10 | 70 |",
+           "| 1  | 4  | 7  | 10 | 70 |",
+           "| 2  | 5  | 8  | 20 | 80 |",
+           "| 2  | 5  | 8  | 20 | 80 |",
+           "| 3  | 7  | 9  |    |    |",
+           "+----+----+----+----+----+",
+       ];
+
+       assert_batches_sorted_eq!(expected, &batches);
+   }
+
+   #[tokio::test]
+   async fn join_left_empty_right() {
+       let left = build_table(
+           ("a1", &vec![1, 2, 3]),
+           ("b1", &vec![4, 5, 7]),
+           ("c1", &vec![7, 8, 9]),
+       );
+       let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![]));
+       let on = &[("b1", "b1")];
+       let schema = right.schema();
+       let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
+       let join = join(left, right, on, &JoinType::Left).unwrap();
+
+       let columns = columns(&join.schema());
+       assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+
+       let stream = join.execute(0).await.unwrap();
+       let batches = common::collect(stream).await.unwrap();
+
+       let expected = vec![
+           "+----+----+----+----+----+",
+           "| a1 | b1 | c1 | a2 | c2 |",
+           "+----+----+----+----+----+",
+           "| 1  | 4  | 7  |    |    |",
+           "| 2  | 5  | 8  |    |    |",
+           "| 3  | 7  | 9  |    |    |",
+           "+----+----+----+----+----+",
+       ];
+
+       assert_batches_sorted_eq!(expected, &batches);
+   }
+
    #[tokio::test]
    async fn join_left_one() -> Result<()> {
        let left = build_table(
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index a38cc092123d..54da1249e5c5 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -22,7 +22,7 @@ use arrow::datatypes::{Field, Schema};
 use std::collections::HashSet;
 
 /// All valid types of joins.
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
 pub enum JoinType {
    /// Inner join
    Inner,