Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: Add routine to debug join fuzz tests #10970

Merged
merged 15 commits into from
Jun 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 162 additions & 40 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ use datafusion::physical_plan::memory::MemoryExec;
use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;

/// Selects which pairs of join implementations a fuzz test compares.
///
/// Ideally every test would pass every comparison, but in practice some
/// tests only pass a subset, so each test lists the comparisons it expects
/// to hold and passes them to `run_test`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum JoinTestType {
/// Compare NestedLoopJoin against HashJoin.
NljHj,
/// Compare HashJoin against SortMergeJoin.
///
/// There is no separate SortMergeJoin vs NestedLoopJoin variant: if both
/// existing comparisons pass, those two implementations agree as well.
HjSmj,
}
#[tokio::test]
async fn test_inner_join_1k() {
JoinFuzzTestCase::new(
Expand All @@ -51,7 +62,7 @@ async fn test_inner_join_1k() {
JoinType::Inner,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -71,6 +82,30 @@ fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> Joi
JoinFilter::new(less_than_100, column_indices, intermediate_schema)
}

/// Builds a join filter comparing column `x` of one input against column `x`
/// of the other.
///
/// The intermediate schema is `[schema1.x, schema2.x]`, so inside the filter
/// expression `Column::new("x", 0)` is the left input's `x` and
/// `Column::new("x", 1)` is the right input's `x`. The resulting predicate is
/// `right.x < left.x`.
///
/// Panics if either schema has no field named `x`.
fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
    // Predicate over the intermediate schema: column 1 (right x) < column 0 (left x).
    let right_x_lt_left_x = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("x", 1)),
        Operator::Lt,
        Arc::new(Column::new("x", 0)),
    )) as _;

    // Locate `x` in each input by name instead of hard-coding its position, so
    // the source index always agrees with the field copied into the
    // intermediate schema below.
    let column_indices = vec![
        ColumnIndex {
            index: schema1.index_of("x").unwrap(),
            side: JoinSide::Left,
        },
        ColumnIndex {
            index: schema2.index_of("x").unwrap(),
            side: JoinSide::Right,
        },
    ];

    let intermediate_schema = Schema::new(vec![
        schema1.field_with_name("x").unwrap().to_owned(),
        schema2.field_with_name("x").unwrap().to_owned(),
    ]);

    JoinFilter::new(right_x_lt_left_x, column_indices, intermediate_schema)
}

#[tokio::test]
async fn test_inner_join_1k_filtered() {
JoinFuzzTestCase::new(
Expand All @@ -79,7 +114,7 @@ async fn test_inner_join_1k_filtered() {
JoinType::Inner,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -91,7 +126,7 @@ async fn test_inner_join_1k_smjoin() {
JoinType::Inner,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -103,7 +138,7 @@ async fn test_left_join_1k() {
JoinType::Left,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -115,7 +150,7 @@ async fn test_left_join_1k_filtered() {
JoinType::Left,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -127,7 +162,7 @@ async fn test_right_join_1k() {
JoinType::Right,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}
// Add support for Right filtered joins
Expand All @@ -140,7 +175,7 @@ async fn test_right_join_1k_filtered() {
JoinType::Right,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -152,7 +187,7 @@ async fn test_full_join_1k() {
JoinType::Full,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -164,7 +199,7 @@ async fn test_full_join_1k_filtered() {
JoinType::Full,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -176,22 +211,23 @@ async fn test_semi_join_1k() {
JoinType::LeftSemi,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

// The test is flaky
// https://github.com/apache/datafusion/issues/10886
// SMJ produces 1 more row in the output
#[ignore]
#[tokio::test]
async fn test_semi_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftSemi,
Some(Box::new(less_than_100_join_filter)),
Some(Box::new(col_lt_col_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj], false)
.await
}

Expand All @@ -203,7 +239,7 @@ async fn test_anti_join_1k() {
JoinType::LeftAnti,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand All @@ -217,7 +253,7 @@ async fn test_anti_join_1k_filtered() {
JoinType::LeftAnti,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

Expand Down Expand Up @@ -331,7 +367,7 @@ impl JoinFuzzTestCase {
self.on_columns().clone(),
self.join_filter(),
self.join_type,
vec![SortOptions::default(), SortOptions::default()],
vec![SortOptions::default(); self.on_columns().len()],
false,
)
.unwrap(),
Expand Down Expand Up @@ -381,9 +417,11 @@ impl JoinFuzzTestCase {
)
}

/// Perform sort-merge join and hash join on same input
/// and verify two outputs are equal
async fn run_test(&self) {
/// Run the selected join implementations on the same inputs and verify that their outputs are equal.
/// `join_tests` - identifies which pairs of join implementations to compare
/// If the `debug` flag is set, the test saves the randomly generated inputs and the join outputs
/// as parquet files under `fuzz_test_debug_batch_size_<batch_size>/`, so a failing case is easy to reproduce and debug
async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) {
for batch_size in self.batch_sizes {
let session_config = SessionConfig::new().with_batch_size(*batch_size);
let ctx = SessionContext::new_with_config(session_config);
Expand All @@ -394,17 +432,30 @@ impl JoinFuzzTestCase {
let hj = self.hash_join();
let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();

let nlj = self.nested_loop_join();
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();

// Get actual row counts(without formatting overhead) for HJ and SMJ
let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());

assert_eq!(
hj_rows, smj_rows,
"SortMergeJoinExec and HashJoinExec produced different row counts"
);
if debug {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

It might make sense to add this information about saving information to the overall docstring of this function (run_tests) so people don't have to read the implementation to find out what debug does.

println!("The debug is ON. Input data will be saved");
let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}");
Self::save_as_parquet(&self.input1, out_dir_name, "input1");
Self::save_as_parquet(&self.input2, out_dir_name, "input2");

let nlj = self.nested_loop_join();
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
if join_tests.contains(&JoinTestType::NljHj) {
Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj");
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
}

if join_tests.contains(&JoinTestType::HjSmj) {
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
Self::save_as_parquet(&smj_collected, out_dir_name, "smj");
}
}

// compare
let smj_formatted =
Expand All @@ -425,35 +476,106 @@ impl JoinFuzzTestCase {
nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

// row level compare if any of joins returns the result
// the reason is different formatting when there is no rows
if smj_rows > 0 || hj_rows > 0 {
for (i, (smj_line, hj_line)) in smj_formatted_sorted
if join_tests.contains(&JoinTestType::NljHj) {
let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());

let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size);
// row level compare if any of joins returns the result
// the reason is different formatting when there is no rows
for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, smj_line),
(i, nlj_line),
(i, hj_line),
"SortMergeJoinExec and HashJoinExec produced different results"
"{}",
err_msg_contents.as_str()
);
}
}

for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, nlj_line),
(i, hj_line),
"NestedLoopJoinExec and HashJoinExec produced different results"
);
if join_tests.contains(&JoinTestType::HjSmj) {
let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());

let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size);
// row level compare if any of joins returns the result
// the reason is different formatting when there is no rows
if smj_rows > 0 || hj_rows > 0 {
for (i, (smj_line, hj_line)) in smj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, smj_line),
(i, hj_line),
"{}",
err_msg_contents.as_str()
);
}
}
}
}
}

/// This method is useful for debugging fuzz tests.
/// It saves the given record batches (for example the randomly generated join inputs)
/// as parquet files under `output_dir/out_name`, writing one file per batch so the partitioning is preserved.
/// Once the data is saved, a custom test can be run on top of the saved data to debug the failure:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

///
/// ```ignore
/// // Replace <batch_size> with the batch size of the failing run.
/// let ctx: SessionContext = SessionContext::new();
/// let df = ctx
///     .read_parquet(
///         "fuzz_test_debug_batch_size_<batch_size>/input1/*.parquet",
///         ParquetReadOptions::default(),
///     )
///     .await
///     .unwrap();
/// let left = df.collect().await.unwrap();
///
/// let df = ctx
///     .read_parquet(
///         "fuzz_test_debug_batch_size_<batch_size>/input2/*.parquet",
///         ParquetReadOptions::default(),
///     )
///     .await
///     .unwrap();
///
/// let right = df.collect().await.unwrap();
/// JoinFuzzTestCase::new(
///     left,
///     right,
///     JoinType::LeftSemi,
///     Some(Box::new(less_than_100_join_filter)),
/// )
/// .run_test(&[JoinTestType::HjSmj], false)
/// .await
/// ```
fn save_as_parquet(input: &[RecordBatch], output_dir: &str, out_name: &str) {
    let out_path = format!("{output_dir}/{out_name}");

    // Start from a clean directory. Failure to remove it (e.g. because it does
    // not exist yet) is deliberately ignored.
    let _ = std::fs::remove_dir_all(&out_path);
    std::fs::create_dir_all(&out_path).unwrap();

    // Write every batch into its own file, named after its position in `input`.
    for (idx, batch) in input.iter().enumerate() {
        let mut file =
            std::fs::File::create(format!("{out_path}/file_{idx}.parquet")).unwrap();

        // Every file is written with the schema of the first batch.
        let schema = input.first().unwrap().schema();
        let mut writer = parquet::arrow::ArrowWriter::try_new(&mut file, schema, None)
            .expect("creating writer");

        writer.write(batch).unwrap();
        writer.close().unwrap();
    }

    println!("The data {out_name} saved as parquet into {out_path}");
}
}

/// Return randomly sized record batches with:
Expand Down