diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index a6a3daa2e98a..c4e198ff009c 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -146,8 +146,8 @@ fn create_runtime_env() -> Result { let object_store_provider = DatafusionCliObjectStoreProvider {}; let object_store_registry = ObjectStoreRegistry::new_with_provider(Some(Arc::new(object_store_provider))); - let rn_config = RuntimeConfig::new() - .with_object_store_registry(Arc::new(object_store_registry)); + let rn_config = + RuntimeConfig::new().with_object_store_registry(Arc::new(object_store_registry)); RuntimeEnv::new(rn_config) } diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs b/datafusion/core/src/physical_plan/joins/hash_join.rs index 82f4ec7c2c35..ae66adeb7711 100644 --- a/datafusion/core/src/physical_plan/joins/hash_join.rs +++ b/datafusion/core/src/physical_plan/joins/hash_join.rs @@ -26,7 +26,7 @@ use arrow::{ DictionaryArray, LargeStringArray, PrimitiveArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampSecondArray, - UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder, + UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, }, compute, datatypes::{ @@ -44,7 +44,7 @@ use futures::{ready, Stream, StreamExt, TryStreamExt}; use arrow::array::{new_null_array, Array}; use arrow::datatypes::{ArrowNativeType, DataType}; use arrow::datatypes::{Schema, SchemaRef}; -use arrow::error::Result as ArrowResult; +use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use arrow::array::{ @@ -85,7 +85,6 @@ use super::{ PartitionMode, }; use log::debug; -use std::cmp; use std::fmt; use std::task::Poll; @@ -402,7 +401,7 @@ impl ExecutionPlan for HashJoinExec { return Err(DataFusionError::Plan(format!( "Invalid HashJoinExec, unsupported PartitionMode {:?} in execute()", PartitionMode::Auto - ))) + 
))); } }; @@ -650,9 +649,6 @@ impl RecordBatchStream for HashJoinStream { /// Returns a new [RecordBatch] by combining the `left` and `right` according to `indices`. /// The resulting batch has [Schema] `schema`. -/// # Error -/// This function errors when: -/// * fn build_batch_from_indices( schema: &Schema, left: &RecordBatch, @@ -660,7 +656,7 @@ fn build_batch_from_indices( left_indices: UInt64Array, right_indices: UInt32Array, column_indices: &[ColumnIndex], -) -> ArrowResult<(RecordBatch, UInt64Array)> { +) -> ArrowResult { // build the columns of the new [RecordBatch]: // 1. pick whether the column is from the left or right // 2. based on the pick, `take` items from the different RecordBatches @@ -692,95 +688,77 @@ fn build_batch_from_indices( }; columns.push(array); } - RecordBatch::try_new(Arc::new(schema.clone()), columns).map(|x| (x, left_indices)) + RecordBatch::try_new(Arc::new(schema.clone()), columns) } +// Get left and right indices which is satisfies the on condition (include equal_conditon and filter_in_join) in the Join #[allow(clippy::too_many_arguments)] -fn build_batch( +fn build_join_indices( batch: &RecordBatch, left_data: &JoinLeftData, on_left: &[Column], on_right: &[Column], filter: &Option, - join_type: JoinType, - schema: &Schema, - column_indices: &[ColumnIndex], random_state: &RandomState, null_equals_null: &bool, -) -> ArrowResult<(RecordBatch, UInt64Array)> { - let (left_indices, right_indices) = build_join_indexes( +) -> Result<(UInt64Array, UInt32Array)> { + // Get the indices which is satisfies the equal join condition, like `left.a1 = right.a2` + let (left_indices, right_indices) = build_equal_condition_join_indices( left_data, batch, - join_type, on_left, on_right, random_state, null_equals_null, - ) - .unwrap(); - - let (left_filtered_indices, right_filtered_indices) = if let Some(filter) = filter { - apply_join_filter( + )?; + if let Some(filter) = filter { + // Filter the indices which is satisfies the non-equal join 
condition, like `left.b1 = 10` + apply_join_filter_to_indices( &left_data.1, batch, - join_type, left_indices, right_indices, filter, ) - .unwrap() } else { - (left_indices, right_indices) - }; - - if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) { - return Ok(( - RecordBatch::new_empty(Arc::new(schema.clone())), - left_filtered_indices, - )); + Ok((left_indices, right_indices)) } - - build_batch_from_indices( - schema, - &left_data.1, - batch, - left_filtered_indices, - right_filtered_indices, - column_indices, - ) } -/// returns a vector with (index from left, index from right). -/// The size of this vector corresponds to the total size of a joined batch -// For a join on column A: -// left right -// batch 1 -// A B A D -// --------------- -// 1 a 3 6 -// 2 b 1 2 -// 3 c 2 4 -// batch 2 -// A B A D -// --------------- -// 1 a 5 10 -// 2 b 2 2 -// 4 d 1 1 -// indices (batch, batch_row) -// left right -// (0, 2) (0, 0) -// (0, 0) (0, 1) -// (0, 1) (0, 2) -// (1, 0) (0, 1) -// (1, 1) (0, 2) -// (0, 1) (1, 1) -// (0, 0) (1, 2) -// (1, 1) (1, 1) -// (1, 0) (1, 2) -fn build_join_indexes( +// Returns the index of equal condition join result: left_indices and right_indices +// On LEFT.b1 = RIGHT.b2 +// LEFT Table: +// a1 b1 c1 +// 1 1 10 +// 3 3 30 +// 5 5 50 +// 7 7 70 +// 9 8 90 +// 11 8 110 +// 13 10 130 +// RIGHT Table: +// a2 b2 c2 +// 2 2 20 +// 4 4 40 +// 6 6 60 +// 8 8 80 +// 10 10 100 +// 12 10 120 +// The result is +// "+----+----+-----+----+----+-----+", +// "| a1 | b1 | c1 | a2 | b2 | c2 |", +// "+----+----+-----+----+----+-----+", +// "| 11 | 8 | 110 | 8 | 8 | 80 |", +// "| 13 | 10 | 130 | 10 | 10 | 100 |", +// "| 13 | 10 | 130 | 12 | 10 | 120 |", +// "| 9 | 8 | 90 | 8 | 8 | 80 |", +// "+----+----+-----+----+----+-----+" +// And the result of left and right indices +// left indices: 5, 6, 6, 4 +// right indices: 3, 4, 5, 3 +fn build_equal_condition_join_indices( left_data: &JoinLeftData, right: &RecordBatch, - join_type: JoinType, left_on: 
&[Column], right_on: &[Column], random_state: &RandomState, @@ -797,225 +775,55 @@ fn build_join_indexes( let hashes_buffer = &mut vec![0; keys_values[0].len()]; let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?; let left = &left_data.0; + // Using a buffer builder to avoid slower normal builder + let mut left_indices = UInt64BufferBuilder::new(0); + let mut right_indices = UInt32BufferBuilder::new(0); - match join_type { - JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => { - // Using a buffer builder to avoid slower normal builder - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); - - // Visit all of the right rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - - // For every item on the left and right we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Check hash collisions - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - *null_equals_null, - )? 
{ - left_indices.append(i); - right_indices.append(row as u32); - } - } - } - } - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) - .build() - .unwrap(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) - .build() - .unwrap(); - - Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), - )) - } - JoinType::RightSemi => { - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); - - // Visit all of the right rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - - // For every item on the left and right we check if it matches - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - // We only produce one row if there is a match - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Check hash collisions - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - *null_equals_null, - )? 
{ - left_indices.append(i); - right_indices.append(row as u32); - break; - } - } - } - } - - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) - .build() - .unwrap(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) - .build() - .unwrap(); - - Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), - )) - } - JoinType::RightAnti => { - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); - - // Visit all of the right rows - for (row, hash_value) in hash_values.iter().enumerate() { - // Get the hash and find it in the build index - - // For every item on the left and right we check if it doesn't match - // This possibly contains rows with hash collisions, - // So we have to check here whether rows are equal or not - // We only produce one row if there is no match - let matches = left.0.get(*hash_value, |(hash, _)| *hash_value == *hash); - let mut no_match = true; - match matches { - Some((_, indices)) => { - for &i in indices { - // Check hash collisions - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - *null_equals_null, - )? { - no_match = false; - break; - } - } - } - None => no_match = true, - }; - if no_match { + // Visit all of the right rows + for (row, hash_value) in hash_values.iter().enumerate() { + // Get the hash and find it in the build index + + // For every item on the left and right we check if it matches + // This possibly contains rows with hash collisions, + // So we have to check here whether rows are equal or not + if let Some((_, indices)) = + left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) + { + for &i in indices { + // Check hash collisions + if equal_rows( + i as usize, + row, + &left_join_values, + &keys_values, + *null_equals_null, + )? 
{ + left_indices.append(i); right_indices.append(row as u32); } } - - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) - .build() - .unwrap(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) - .build() - .unwrap(); - - Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), - )) - } - JoinType::Left => { - let mut left_indices = UInt64Builder::with_capacity(0); - let mut right_indices = UInt32Builder::with_capacity(0); - - // First visit all of the rows - for (row, hash_value) in hash_values.iter().enumerate() { - if let Some((_, indices)) = - left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) - { - for &i in indices { - // Collision check - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - *null_equals_null, - )? { - left_indices.append_value(i); - right_indices.append_value(row as u32); - } - } - }; - } - Ok((left_indices.finish(), right_indices.finish())) - } - JoinType::Right | JoinType::Full => { - let mut left_indices = UInt64Builder::with_capacity(0); - let mut right_indices = UInt32Builder::with_capacity(0); - - for (row, hash_value) in hash_values.iter().enumerate() { - match left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { - Some((_, indices)) => { - let mut no_match = true; - for &i in indices { - if equal_rows( - i as usize, - row, - &left_join_values, - &keys_values, - *null_equals_null, - )? 
{ - left_indices.append_value(i); - right_indices.append_value(row as u32); - no_match = false; - } - } - // If no rows matched left, still must keep the right - // with all nulls for left - if no_match { - left_indices.append_null(); - right_indices.append_value(row as u32); - } - } - None => { - // when no match, add the row with None for the left side - left_indices.append_null(); - right_indices.append_value(row as u32); - } - } - } - Ok((left_indices.finish(), right_indices.finish())) } } + let left = ArrayData::builder(DataType::UInt64) + .len(left_indices.len()) + .add_buffer(left_indices.finish()) + .build() + .unwrap(); + let right = ArrayData::builder(DataType::UInt32) + .len(right_indices.len()) + .add_buffer(right_indices.finish()) + .build() + .unwrap(); + + Ok(( + PrimitiveArray::::from(left), + PrimitiveArray::::from(right), + )) } -fn apply_join_filter( +fn apply_join_filter_to_indices( left: &RecordBatch, right: &RecordBatch, - join_type: JoinType, left_indices: UInt64Array, right_indices: UInt32Array, filter: &JoinFilter, @@ -1024,7 +832,7 @@ fn apply_join_filter( return Ok((left_indices, right_indices)); }; - let (intermediate_batch, _) = build_batch_from_indices( + let intermediate_batch = build_batch_from_indices( filter.schema(), left, right, @@ -1032,87 +840,20 @@ fn apply_join_filter( PrimitiveArray::from(right_indices.data().clone()), filter.column_indices(), )?; + let filter_result = filter + .expression() + .evaluate(&intermediate_batch)? 
+ .into_array(intermediate_batch.num_rows()); + let mask = as_boolean_array(&filter_result)?; + + let left_filtered = PrimitiveArray::::from( + compute::filter(&left_indices, mask)?.data().clone(), + ); + let right_filtered = PrimitiveArray::::from( + compute::filter(&right_indices, mask)?.data().clone(), + ); - match join_type { - JoinType::Inner - | JoinType::Left - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::LeftSemi - | JoinType::RightSemi => { - // For both INNER and LEFT joins, input arrays contains only indices for matched data. - // Due to this fact it's correct to simply apply filter to intermediate batch and return - // indices for left/right rows satisfying filter predicate - let filter_result = filter - .expression() - .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows()); - let mask = as_boolean_array(&filter_result)?; - - let left_filtered = PrimitiveArray::::from( - compute::filter(&left_indices, mask)?.data().clone(), - ); - let right_filtered = PrimitiveArray::::from( - compute::filter(&right_indices, mask)?.data().clone(), - ); - - Ok((left_filtered, right_filtered)) - } - JoinType::Right | JoinType::Full => { - // In case of RIGHT and FULL join, left_indices could contain null values - these rows, - // where no match has been found, should retain in result arrays (thus join condition is satified) - // - // So, filter should be applied only to matched rows, and in case right (outer, batch) index - // doesn't have a single match after filtering, it should be added back to result arrays as - // (null, idx) pair. - let has_match = compute::is_not_null(&left_indices)?; - let filter_result = filter - .expression() - .evaluate_selection(&intermediate_batch, &has_match)? 
- .into_array(intermediate_batch.num_rows()); - let mask = as_boolean_array(&filter_result)?; - - let mut left_rebuilt = UInt64Builder::with_capacity(0); - let mut right_rebuilt = UInt32Builder::with_capacity(0); - - (0..right_indices.len()) - .into_iter() - .try_fold::<_, _, Result<_>>( - (right_indices.value(0), false), - |state, pos| { - // If row index changes and row doesnt have match - // append (idx, null) - if right_indices.value(pos) != state.0 && !state.1 { - right_rebuilt.append_value(state.0); - left_rebuilt.append_null(); - } - // If has match append matched row indices - if mask.value(pos) { - right_rebuilt.append_value(right_indices.value(pos)); - left_rebuilt.append_value(left_indices.value(pos)); - }; - - // Calculate if current row index has match - let has_match = if right_indices.value(pos) != state.0 { - mask.value(pos) - } else { - cmp::max(mask.value(pos), state.1) - }; - - Ok((right_indices.value(pos), has_match)) - }, - ) - // Append last row from right side if no match found - .map(|(row_idx, has_match)| { - if !has_match { - right_rebuilt.append_value(row_idx); - left_rebuilt.append_null(); - } - })?; - - Ok((left_rebuilt.finish(), right_rebuilt.finish())) - } - } + Ok((left_filtered, right_filtered)) } macro_rules! 
equal_rows_elem { @@ -1340,11 +1081,11 @@ fn equal_rows( } }, DataType::Dictionary(key_type, value_type) - if *value_type.as_ref() == DataType::Utf8 => - { - match key_type.as_ref() { - DataType::Int8 => { - equal_rows_elem_with_string_dict!( + if *value_type.as_ref() == DataType::Utf8 => + { + match key_type.as_ref() { + DataType::Int8 => { + equal_rows_elem_with_string_dict!( Int8Type, l, r, @@ -1352,9 +1093,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::Int16 => { - equal_rows_elem_with_string_dict!( + } + DataType::Int16 => { + equal_rows_elem_with_string_dict!( Int16Type, l, r, @@ -1362,9 +1103,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::Int32 => { - equal_rows_elem_with_string_dict!( + } + DataType::Int32 => { + equal_rows_elem_with_string_dict!( Int32Type, l, r, @@ -1372,9 +1113,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::Int64 => { - equal_rows_elem_with_string_dict!( + } + DataType::Int64 => { + equal_rows_elem_with_string_dict!( Int64Type, l, r, @@ -1382,9 +1123,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::UInt8 => { - equal_rows_elem_with_string_dict!( + } + DataType::UInt8 => { + equal_rows_elem_with_string_dict!( UInt8Type, l, r, @@ -1392,9 +1133,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::UInt16 => { - equal_rows_elem_with_string_dict!( + } + DataType::UInt16 => { + equal_rows_elem_with_string_dict!( UInt16Type, l, r, @@ -1402,9 +1143,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::UInt32 => { - equal_rows_elem_with_string_dict!( + } + DataType::UInt32 => { + equal_rows_elem_with_string_dict!( UInt32Type, l, r, @@ -1412,9 +1153,9 @@ fn equal_rows( right, null_equals_null ) - } - DataType::UInt64 => { - equal_rows_elem_with_string_dict!( + } + DataType::UInt64 => { + equal_rows_elem_with_string_dict!( UInt64Type, l, r, @@ -1422,16 +1163,16 @@ fn equal_rows( right, null_equals_null ) - } - _ => { - // should not happen - err = 
Some(Err(DataFusionError::Internal( - "Unsupported data type in hasher".to_string(), - ))); - false + } + _ => { + // should not happen + err = Some(Err(DataFusionError::Internal( + "Unsupported data type in hasher".to_string(), + ))); + false + } } } - } other => { // This is internal because we should have caught this before. err = Some(Err(DataFusionError::Internal(format!( @@ -1445,44 +1186,136 @@ fn equal_rows( err.unwrap_or(Ok(res)) } -// Produces a batch for left-side rows that have/have not been matched during the whole join -fn produce_from_matched( - visited_left_side: &BooleanBufferBuilder, - schema: &SchemaRef, - column_indices: &[ColumnIndex], - left_data: &JoinLeftData, - unmatched: bool, -) -> ArrowResult { - let indices = if unmatched { - UInt64Array::from_iter_values( - (0..visited_left_side.len()) - .filter_map(|v| (!visited_left_side.get_bit(v)).then_some(v as u64)), - ) +// The input is the matched indices for left and right. +// Adjust the indices according to the join type +fn adjust_indices_by_join_type( + left_indices: UInt64Array, + right_indices: UInt32Array, + count_right_batch: usize, + join_type: JoinType, +) -> (UInt64Array, UInt32Array) { + match join_type { + JoinType::Inner => { + // matched + (left_indices, right_indices) + } + JoinType::Left => { + // matched + (left_indices, right_indices) + // unmatched left row will be produced in the end of loop, and it has been set in the left visited bitmap + } + JoinType::Right | JoinType::Full => { + // matched + // unmatched right row will be produced in this batch + let right_unmatched_indices = + get_anti_indices(count_right_batch, &right_indices); + // combine the matched and unmatched right result together + append_right_indices(left_indices, right_indices, right_unmatched_indices) + } + JoinType::RightSemi => { + // need to remove the duplicated record in the right side + let right_indices = get_semi_indices(count_right_batch, &right_indices); + // the left_indices will not be used 
later for the `right semi` join + (left_indices, right_indices) + } + JoinType::RightAnti => { + // need to remove the duplicated record in the right side + // get the anti index for the right side + let right_indices = get_anti_indices(count_right_batch, &right_indices); + // the left_indices will not be used later for the `right anti` join + (left_indices, right_indices) + } + JoinType::LeftSemi | JoinType::LeftAnti => { + // matched or unmatched left row will be produced in the end of loop + // TODO: left semi can be optimized. + // When visit the right batch, we can output the matched left row and don't need to wait the end of loop + ( + UInt64Array::from_iter_values(vec![]), + UInt32Array::from_iter_values(vec![]), + ) + } + } +} + +fn append_right_indices( + left_indices: UInt64Array, + right_indices: UInt32Array, + appended_right_indices: UInt32Array, +) -> (UInt64Array, UInt32Array) { + // left_indices, right_indices and appended_right_indices must not contain the null value + if appended_right_indices.is_empty() { + (left_indices, right_indices) } else { - UInt64Array::from_iter_values( - (0..visited_left_side.len()) - .filter_map(|v| (visited_left_side.get_bit(v)).then_some(v as u64)), - ) - }; + let unmatched_size = appended_right_indices.len(); + // the new left indices: left_indices + null array + // the new right indices: right_indices + appended_right_indices + let new_left_indices = left_indices + .iter() + .chain(std::iter::repeat(None).take(unmatched_size)) + .collect::(); + let new_right_indices = right_indices + .iter() + .chain(appended_right_indices.iter()) + .collect::(); + (new_left_indices, new_right_indices) + } +} - // generate batches by taking values from the left side and generating columns filled with null on the right side - let num_rows = indices.len(); - let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); - for (idx, column_index) in column_indices.iter().enumerate() { - let array = match column_index.side { - 
JoinSide::Left => { - let array = left_data.1.column(column_index.index); - compute::take(array.as_ref(), &indices, None).unwrap() - } - JoinSide::Right => { - let datatype = schema.field(idx).data_type(); - new_null_array(datatype, num_rows) - } - }; +fn get_anti_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array { + let mut bitmap = BooleanBufferBuilder::new(row_count); + bitmap.append_n(row_count, false); + input_indices.iter().flatten().for_each(|v| { + bitmap.set_bit(v as usize, true); + }); + + // get the anti index + (0..row_count) + .filter_map(|idx| (!bitmap.get_bit(idx)).then_some(idx as u32)) + .collect::() +} - columns.push(array); - } - RecordBatch::try_new(schema.clone(), columns) +fn get_semi_indices(row_count: usize, input_indices: &UInt32Array) -> UInt32Array { + let mut bitmap = BooleanBufferBuilder::new(row_count); + bitmap.append_n(row_count, false); + input_indices.iter().flatten().for_each(|v| { + bitmap.set_bit(v as usize, true); + }); + + // get the semi index + (0..row_count) + .filter_map(|idx| (bitmap.get_bit(idx)).then_some(idx as u32)) + .collect::() +} + +fn need_produce_result_in_final(join_type: JoinType) -> bool { + matches!( + join_type, + JoinType::Left | JoinType::LeftAnti | JoinType::LeftSemi | JoinType::Full + ) +} + +fn get_final_indices( + left_bit_map: &BooleanBufferBuilder, + join_type: JoinType, +) -> (UInt64Array, UInt32Array) { + let left_size = left_bit_map.len(); + let left_indices = if join_type == JoinType::LeftSemi { + (0..left_size) + .filter_map(|idx| (left_bit_map.get_bit(idx)).then_some(idx as u64)) + .collect::() + } else { + // just for `Left`, `LeftAnti` and `Full` join + // `LeftAnti`, `Left` and `Full` will produce the unmatched left row finally + (0..left_size) + .filter_map(|idx| (!left_bit_map.get_bit(idx)).then_some(idx as u64)) + .collect::() + }; + // right_indices + // all the element in the right side is None + let mut builder = UInt32Builder::with_capacity(left_indices.len()); + 
builder.append_nulls(left_indices.len()); + let right_indices = builder.finish(); + (left_indices, right_indices) } impl HashJoinStream { @@ -1501,108 +1334,115 @@ impl HashJoinStream { let visited_left_side = self.visited_left_side.get_or_insert_with(|| { let num_rows = left_data.1.num_rows(); - match self.join_type { - JoinType::Left - | JoinType::Full - | JoinType::LeftSemi - | JoinType::LeftAnti => { - let mut buffer = BooleanBufferBuilder::new(num_rows); - - buffer.append_n(num_rows, false); - - buffer - } - JoinType::Inner - | JoinType::Right - | JoinType::RightSemi - | JoinType::RightAnti => BooleanBufferBuilder::new(0), + if need_produce_result_in_final(self.join_type) { + // these join type need the bitmap to identify which row has be matched or unmatched. + // For the `left semi` join, need to use the bitmap to produce the matched row in the left side + // For the `left` join, need to use the bitmap to produce the unmatched row in the left side with null + // For the `left anti` join, need to use the bitmap to produce the unmatched row in the left side + // For the `full` join, need to use the bitmap to produce the unmatched row in the left side with null + let mut buffer = BooleanBufferBuilder::new(num_rows); + buffer.append_n(num_rows, false); + buffer + } else { + BooleanBufferBuilder::new(0) } }); self.right .poll_next_unpin(cx) .map(|maybe_batch| match maybe_batch { + // one right batch in the join loop Some(Ok(batch)) => { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); let timer = self.join_metrics.probe_time.timer(); - let result = build_batch( + + // get the matched two indices for the on condition + let left_right_indices = build_join_indices( &batch, left_data, &self.on_left, &self.on_right, &self.filter, - self.join_type, - &self.schema, - &self.column_indices, &self.random_state, &self.null_equals_null, ); - self.join_metrics.input_batches.add(1); - 
self.join_metrics.input_rows.add(batch.num_rows()); - if let Ok((ref batch, ref left_side)) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - - match self.join_type { - JoinType::Left - | JoinType::Full - | JoinType::LeftSemi - | JoinType::LeftAnti => { + + let result = match left_right_indices { + Ok((left_side, right_side)) => { + // set the left bitmap + // and only left, full, left semi, left anti need the left bitmap + if need_produce_result_in_final(self.join_type) { left_side.iter().flatten().for_each(|x| { visited_left_side.set_bit(x as usize, true); }); } - JoinType::Inner - | JoinType::Right - | JoinType::RightSemi - | JoinType::RightAnti => {} + + // adjust the two side indices base on the join type + let (left_side, right_side) = adjust_indices_by_join_type( + left_side, + right_side, + batch.num_rows(), + self.join_type, + ); + + let result = build_batch_from_indices( + &self.schema, + &left_data.1, + &batch, + left_side, + right_side, + &self.column_indices, + ); + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Some(result) } - } - let final_result = Some(result.map(|x| x.0)); + Err(_) => { + // TODO why the type of result stream is `Result`, and not the `DataFusionError` + Some(Err(ArrowError::ComputeError( + "Build left right indices error".to_string(), + ))) + } + }; timer.done(); - final_result + result } - Some(err) => Some(err), None => { let timer = self.join_metrics.probe_time.timer(); - // For the left join, produce rows for unmatched rows - match self.join_type { - JoinType::Left - | JoinType::Full - | JoinType::LeftSemi - | JoinType::LeftAnti - if !self.is_exhausted => - { - let result = produce_from_matched( - visited_left_side, - &self.schema, - &self.column_indices, - left_data, - self.join_type != JoinType::LeftSemi, - ); - if let Ok(ref batch) = result { - self.join_metrics.input_batches.add(1); - 
self.join_metrics.input_rows.add(batch.num_rows()); - if let Ok(ref batch) = result { - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - } - timer.done(); - self.is_exhausted = true; - return Some(result); + if need_produce_result_in_final(self.join_type) && !self.is_exhausted + { + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices(visited_left_side, self.join_type); + let empty_right_batch = + RecordBatch::new_empty(self.right.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + &left_data.1, + &empty_right_batch, + left_side, + right_side, + &self.column_indices, + ); + + if let Ok(ref batch) = result { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); } - JoinType::Left - | JoinType::Full - | JoinType::LeftSemi - | JoinType::RightSemi - | JoinType::LeftAnti - | JoinType::RightAnti - | JoinType::Inner - | JoinType::Right => {} + timer.done(); + self.is_exhausted = true; + Some(result) + } else { + // end of the join loop + None } - - None } + Some(err) => Some(err), }) } } @@ -1629,6 +1469,7 @@ mod tests { test::exec::MockExec, test::{build_table_i32, columns}, }; + use arrow::array::UInt64Builder; use arrow::datatypes::Field; use arrow::error::ArrowError; use datafusion_expr::Operator; @@ -1695,12 +1536,12 @@ mod tests { context: Arc, ) -> Result<(Vec, Vec)> { let join = join(left, right, on, join_type, null_equals_null)?; - let columns = columns(&join.schema()); + let columns_header = columns(&join.schema()); let stream = join.execute(0, context)?; let batches = common::collect(stream).await?; - Ok((columns, batches)) + Ok((columns_header, batches)) } async fn partitioned_join_collect( @@ -2314,23 +2155,36 @@ 
mod tests { Ok(()) } + fn build_semi_anti_left_table() -> Arc { + // just two line match + // b1 = 10 + build_table( + ("a1", &vec![1, 3, 5, 7, 9, 11, 13]), + ("b1", &vec![1, 3, 5, 7, 8, 8, 10]), + ("c1", &vec![10, 30, 50, 70, 90, 110, 130]), + ) + } + + fn build_semi_anti_right_table() -> Arc { + // just two line match + // b2 = 10 + build_table( + ("a2", &vec![8, 12, 6, 2, 10, 4]), + ("b2", &vec![8, 10, 6, 2, 10, 4]), + ("c2", &vec![20, 40, 60, 80, 100, 120]), + ) + } + #[tokio::test] async fn join_left_semi() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right - ("c2", &vec![70, 80, 90, 100]), - ); + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + // left_table left semi join right_table on left_table.b1 = right_table.b2 let on = vec![( Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Column::new_with_schema("b2", &right.schema())?, )]; let join = join(left, right, on, &JoinType::LeftSemi, false)?; @@ -2341,14 +2195,15 @@ mod tests { let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; + // ignore the order let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - "+----+----+----+", + "+----+----+-----+", + "| a1 | b1 | c1 |", + "+----+----+-----+", + "| 11 | 8 | 110 |", + "| 13 | 10 | 130 |", + "| 9 | 8 | 90 |", + "+----+----+-----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2359,24 +2214,17 @@ mod tests { async fn join_left_semi_with_filter() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = 
session_ctx.task_ctx(); - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right - ("c2", &vec![70, 80, 90, 100]), - ); + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + + // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 10 let on = vec![( Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Column::new_with_schema("b2", &right.schema())?, )]; - // build filter right.b2 > 4 let column_indices = vec![ColumnIndex { - index: 1, + index: 0, side: JoinSide::Right, }]; let intermediate_schema = @@ -2384,28 +2232,65 @@ mod tests { let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), - Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(4)))), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), )) as Arc; + let filter = JoinFilter::new( + filter_expression, + column_indices.clone(), + intermediate_schema.clone(), + ); + + let join = join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter, + &JoinType::LeftSemi, + false, + )?; + + let columns_header = columns(&join.schema()); + assert_eq!(columns_header.clone(), vec!["a1", "b1", "c1"]); + + let stream = join.execute(0, task_ctx.clone())?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+-----+", + "| a1 | b1 | c1 |", + "+----+----+-----+", + "| 11 | 8 | 110 |", + "| 13 | 10 | 130 |", + "| 9 | 8 | 90 |", + "+----+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + // left_table left semi join right_table on left_table.b1 = right_table.b2 and right_table.a2 > 10 + let filter_expression = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + 
Operator::Gt, + Arc::new(Literal::new(ScalarValue::Int32(Some(10)))), + )) as Arc; let filter = JoinFilter::new(filter_expression, column_indices, intermediate_schema); let join = join_with_filter(left, right, on, filter, &JoinType::LeftSemi, false)?; - let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1"]); + let columns_header = columns(&join.schema()); + assert_eq!(columns_header, vec!["a1", "b1", "c1"]); let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - "+----+----+----+", + "+----+----+-----+", + "| a1 | b1 | c1 |", + "+----+----+-----+", + "| 13 | 10 | 130 |", + "+----+----+-----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2416,38 +2301,31 @@ mod tests { async fn join_right_semi() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let left = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b2", &vec![4, 5, 6, 5]), // 5 is double on the left - ("c2", &vec![70, 80, 90, 100]), - ); - let right = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the left - ("c1", &vec![7, 8, 8, 9]), - ); + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + // left_table right semi join right_table on left_table.b1 = right_table.b2 let on = vec![( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, )]; let join = join(left, right, on, &JoinType::RightSemi, false)?; let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1"]); + assert_eq!(columns, vec!["a2", "b2", "c2"]); let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; let expected = vec![ - 
"+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - "+----+----+----+", + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 10 | 10 | 100 |", + "| 12 | 10 | 40 |", + "| 8 | 8 | 20 |", + "+----+----+-----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2458,34 +2336,65 @@ mod tests { async fn join_right_semi_with_filter() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let left = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b2", &vec![4, 5, 6, 5]), // 5 is double on the left - ("c2", &vec![70, 80, 90, 100]), - ); - let right = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the left - ("c1", &vec![7, 8, 8, 9]), - ); + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + // left_table right semi join right_table on left_table.b1 = right_table.b2 on left_table.a1!=9 let on = vec![( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, )]; - // build filter left.b2 > 4 let column_indices = vec![ColumnIndex { - index: 1, + index: 0, side: JoinSide::Left, }]; let intermediate_schema = Schema::new(vec![Field::new("x", DataType::Int32, true)]); + let filter_expression = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(9)))), + )) as Arc; + + let filter = JoinFilter::new( + filter_expression, + column_indices.clone(), + intermediate_schema.clone(), + ); + + let join = join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter, + &JoinType::RightSemi, + false, + )?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a2", "b2", "c2"]); + + let stream = join.execute(0, task_ctx.clone())?; 
+ let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 10 | 10 | 100 |", + "| 12 | 10 | 40 |", + "| 8 | 8 | 20 |", + "+----+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + // left_table right semi join right_table on left_table.b1 = right_table.b2 and left_table.a1 > 11 let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::Gt, - Arc::new(Literal::new(ScalarValue::Int32(Some(4)))), + Arc::new(Literal::new(ScalarValue::Int32(Some(11)))), )) as Arc; let filter = @@ -2493,20 +2402,16 @@ mod tests { let join = join_with_filter(left, right, on, filter, &JoinType::RightSemi, false)?; - - let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - "+----+----+----+", + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 10 | 10 | 100 |", + "| 12 | 10 | 40 |", + "+----+----+-----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2517,19 +2422,12 @@ mod tests { async fn join_left_anti() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let left = build_table( - ("a1", &vec![1, 2, 2, 3, 5]), - ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9, 11]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right - ("c2", &vec![70, 80, 90, 100]), - ); + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + // left_table left anti join right_table on left_table.b1 = right_table.b2 let on = vec![( Column::new_with_schema("b1", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, 
+ Column::new_with_schema("b2", &right.schema())?, )]; let join = join(left, right, on, &JoinType::LeftAnti, false)?; @@ -2544,8 +2442,10 @@ mod tests { "+----+----+----+", "| a1 | b1 | c1 |", "+----+----+----+", - "| 3 | 7 | 9 |", - "| 5 | 7 | 11 |", + "| 1 | 1 | 10 |", + "| 3 | 3 | 30 |", + "| 5 | 5 | 50 |", + "| 7 | 7 | 70 |", "+----+----+----+", ]; assert_batches_sorted_eq!(expected, &batches); @@ -2553,102 +2453,225 @@ mod tests { } #[tokio::test] - async fn join_right_anti() -> Result<()> { + async fn join_left_anti_with_filter() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let right = build_table( - ("a1", &vec![1, 2, 2, 3, 5]), - ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9, 11]), - ); - let left = build_table( - ("a2", &vec![10, 20, 30, 40]), - ("b2", &vec![4, 5, 6, 5]), // 5 is double on the right - ("c2", &vec![70, 80, 90, 100]), + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2!=8 + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + + let column_indices = vec![ColumnIndex { + index: 0, + side: JoinSide::Right, + }]; + let intermediate_schema = + Schema::new(vec![Field::new("x", DataType::Int32, true)]); + let filter_expression = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + )) as Arc; + + let filter = JoinFilter::new( + filter_expression, + column_indices.clone(), + intermediate_schema.clone(), ); + + let join = join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter, + &JoinType::LeftAnti, + false, + )?; + + let columns_header = columns(&join.schema()); + assert_eq!(columns_header, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0, 
task_ctx.clone())?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+-----+", + "| a1 | b1 | c1 |", + "+----+----+-----+", + "| 1 | 1 | 10 |", + "| 11 | 8 | 110 |", + "| 3 | 3 | 30 |", + "| 5 | 5 | 50 |", + "| 7 | 7 | 70 |", + "| 9 | 8 | 90 |", + "+----+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + // left_table left anti join right_table on left_table.b1 = right_table.b2 and right_table.a2 != 8 (same filter as the first case above) + let filter_expression = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), + )) as Arc; + + let filter = + JoinFilter::new(filter_expression, column_indices, intermediate_schema); + + let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?; + + let columns_header = columns(&join.schema()); + assert_eq!(columns_header, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+-----+", + "| a1 | b1 | c1 |", + "+----+----+-----+", + "| 1 | 1 | 10 |", + "| 11 | 8 | 110 |", + "| 3 | 3 | 30 |", + "| 5 | 5 | 50 |", + "| 7 | 7 | 70 |", + "| 9 | 8 | 90 |", + "+----+----+-----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + + #[tokio::test] + async fn join_right_anti() -> Result<()> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); let on = vec![( - Column::new_with_schema("b2", &left.schema())?, - Column::new_with_schema("b1", &right.schema())?, + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, )]; let join = join(left, right, on, &JoinType::RightAnti, false)?; let columns = columns(&join.schema()); - assert_eq!(columns, vec!["a1", "b1", "c1"]); + assert_eq!(columns, vec!["a2", "b2", "c2"]); let stream = 
join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 3 | 7 | 9 |", - "| 5 | 7 | 11 |", - "+----+----+----+", + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 2 | 2 | 80 |", + "| 4 | 4 | 120 |", + "| 6 | 6 | 60 |", + "+----+----+-----+", ]; assert_batches_sorted_eq!(expected, &batches); Ok(()) } #[tokio::test] - async fn join_left_anti_with_filter() -> Result<()> { + async fn join_right_anti_with_filter() -> Result<()> { let session_ctx = SessionContext::new(); let task_ctx = session_ctx.task_ctx(); - let left = build_table( - ("col1", &vec![1, 3]), - ("col2", &vec![2, 4]), - ("col3", &vec![3, 5]), - ); - let right = left.clone(); - - // join on col1 + let left = build_semi_anti_left_table(); + let right = build_semi_anti_right_table(); + // left_table right anti join right_table on left_table.b1 = right_table.b2 and left_table.a1!=13 let on = vec![( - Column::new_with_schema("col1", &left.schema())?, - Column::new_with_schema("col1", &right.schema())?, + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, )]; - // build filter left.col2 <> right.col2 - let column_indices = vec![ - ColumnIndex { - index: 1, - side: JoinSide::Left, - }, - ColumnIndex { - index: 1, - side: JoinSide::Right, - }, + let column_indices = vec![ColumnIndex { + index: 0, + side: JoinSide::Left, + }]; + let intermediate_schema = + Schema::new(vec![Field::new("x", DataType::Int32, true)]); + + let filter_expression = Arc::new(BinaryExpr::new( + Arc::new(Column::new("x", 0)), + Operator::NotEq, + Arc::new(Literal::new(ScalarValue::Int32(Some(13)))), + )) as Arc; + + let filter = JoinFilter::new( + filter_expression, + column_indices, + intermediate_schema.clone(), + ); + + let join = join_with_filter( + left.clone(), + right.clone(), + on.clone(), + filter, + &JoinType::RightAnti, + false, + )?; + + let 
columns_header = columns(&join.schema()); + assert_eq!(columns_header, vec!["a2", "b2", "c2"]); + + let stream = join.execute(0, task_ctx.clone())?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 10 | 10 | 100 |", + "| 12 | 10 | 40 |", + "| 2 | 2 | 80 |", + "| 4 | 4 | 120 |", + "| 6 | 6 | 60 |", + "+----+----+-----+", ]; - let intermediate_schema = Schema::new(vec![ - Field::new("x", DataType::Int32, true), - Field::new("x", DataType::Int32, true), - ]); + assert_batches_sorted_eq!(expected, &batches); + + // left_table right anti join right_table on left_table.b1 = right_table.b2 and right_table.b2!=8 + let column_indices = vec![ColumnIndex { + index: 1, + side: JoinSide::Right, + }]; let filter_expression = Arc::new(BinaryExpr::new( Arc::new(Column::new("x", 0)), Operator::NotEq, - Arc::new(Column::new("x", 1)), + Arc::new(Literal::new(ScalarValue::Int32(Some(8)))), )) as Arc; let filter = JoinFilter::new(filter_expression, column_indices, intermediate_schema); - let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?; + let join = + join_with_filter(left, right, on, filter, &JoinType::RightAnti, false)?; - let columns = columns(&join.schema()); - assert_eq!(columns, vec!["col1", "col2", "col3"]); + let columns_header = columns(&join.schema()); + assert_eq!(columns_header, vec!["a2", "b2", "c2"]); let stream = join.execute(0, task_ctx)?; let batches = common::collect(stream).await?; let expected = vec![ - "+------+------+------+", - "| col1 | col2 | col3 |", - "+------+------+------+", - "| 1 | 2 | 3 |", - "| 3 | 4 | 5 |", - "+------+------+------+", + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 2 | 2 | 80 |", + "| 4 | 4 | 120 |", + "| 6 | 6 | 60 |", + "| 8 | 8 | 20 |", + "+----+----+-----+", ]; assert_batches_sorted_eq!(expected, &batches); + Ok(()) } @@ -2798,10 +2821,9 @@ mod tests { ); let left_data = 
(JoinHashMap(hashmap_left), left); - let (l, r) = build_join_indexes( + let (l, r) = build_equal_condition_join_indices( &left_data, &right, - JoinType::Inner, &[Column::new("a", 0)], &[Column::new("a", 0)], &random_state, diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 87fb594c79b3..040e18fe476a 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2211,8 +2211,6 @@ async fn null_aware_left_anti_join() -> Result<()> { } #[tokio::test] -#[ignore = "Test ignored, will be enabled after fixing right semi join bug"] -// https://github.com/apache/arrow-datafusion/issues/4247 async fn right_semi_join() -> Result<()> { let test_repartition_joins = vec![true, false]; for repartition_joins in test_repartition_joins {