diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs b/datafusion/core/tests/fuzz_cases/join_fuzz.rs index b6a5e18da1f1..24b564da6e6a 100644 --- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs @@ -245,14 +245,16 @@ async fn test_right_semi_join_1k() { #[tokio::test] async fn test_right_semi_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000, false), - make_staggered_batches_i32(1000, false), - JoinType::RightSemi, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightSemi, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] @@ -299,14 +301,16 @@ async fn test_right_anti_join_1k() { #[tokio::test] async fn test_right_anti_join_1k_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_i32(1000, false), - make_staggered_batches_i32(1000, false), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_i32(1000, left_extra), + make_staggered_batches_i32(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] @@ -564,26 +568,30 @@ async fn test_left_anti_join_1k_binary_filtered() { #[tokio::test] async fn test_right_anti_join_1k_binary() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000, false), - make_staggered_batches_binary(1000, false), - JoinType::RightAnti, - None, - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + None, + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] async fn test_right_anti_join_1k_binary_filtered() { - JoinFuzzTestCase::new( - make_staggered_batches_binary(1000, false), - make_staggered_batches_binary(1000, false), - JoinType::RightAnti, - Some(Box::new(col_lt_col_filter)), - ) - .run_test(&[HjSmj, NljHj], false) - .await + for (left_extra, right_extra) in [(true, true), (false, true), (true, false)] { + JoinFuzzTestCase::new( + make_staggered_batches_binary(1000, left_extra), + make_staggered_batches_binary(1000, right_extra), + JoinType::RightAnti, + Some(Box::new(col_lt_col_filter)), + ) + .run_test(&[HjSmj, NljHj], false) + .await + } } #[tokio::test] diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs index 0325e37d42e7..cf6a8c2db3df 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs @@ -1566,6 +1566,7 @@ impl SortMergeJoinStream { .columns() .iter() .skip(right_columns_length) + .take(left_columns_length) .cloned() .collect::>(); @@ -1595,14 +1596,16 @@ impl SortMergeJoinStream { &self.schema, &[filtered_record_batch, null_joined_streamed_batch], )?; - } else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) { + } else if matches!( + self.join_type, + JoinType::LeftSemi + | JoinType::LeftAnti + | JoinType::RightAnti + | JoinType::RightSemi + ) { let output_column_indices = (0..left_columns_length).collect::>(); filtered_record_batch = filtered_record_batch.project(&output_column_indices)?; - } else if matches!(self.join_type, JoinType::RightAnti | JoinType::RightSemi) { - let output_column_indices = (0..right_columns_length).collect::>(); - filtered_record_batch = - filtered_record_batch.project(&output_column_indices)?; } else if matches!(self.join_type, JoinType::Full) && corrected_mask.false_count() > 0 { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index f91bffbed78f..2e4725995b47 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -1109,6 +1109,61 @@ async fn join_right_anti_two_with_filter() -> Result<()> { Ok(()) } +#[tokio::test] +async fn join_right_anti_filtered_with_mismatched_columns() -> Result<()> { + let left = build_table_two_cols(("a1", &vec![31, 31]), ("b1", &vec![32, 33])); + let right = build_table( + ("a2", &vec![31, 31]), + ("b2", &vec![32, 35]), + ("c2", &vec![108, 109]), + ); + let on = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("a2", &right.schema())?) as _, + ), + ( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) as _, + ), + ]; + + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("b1", 0)), + Operator::LtEq, + Arc::new(Column::new("c2", 1)), + )), + vec![ + ColumnIndex { + index: 1, + side: JoinSide::Left, + }, + ColumnIndex { + index: 2, + side: JoinSide::Right, + }, + ], + Arc::new(Schema::new(vec![ + Field::new("b1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ])), + ); + + let (_, batches) = + join_collect_with_filter(left, right, on, filter, RightAnti).await?; + + let expected = [ + "+----+----+-----+", + "| a2 | b2 | c2 |", + "+----+----+-----+", + "| 31 | 35 | 109 |", + "+----+----+-----+", + ]; + assert_batches_eq!(expected, &batches); + Ok(()) +} + #[tokio::test] async fn join_right_anti_with_nulls() -> Result<()> { let left = build_table_i32_nullable(