Minor: Add routine to debug join fuzz tests (#10970)
* Minor: Add routine to debug join fuzz tests
comphead committed Jun 18, 2024
1 parent a2c9d1a commit e9f9a23
Showing 1 changed file with 162 additions and 40 deletions.
datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -43,6 +43,17 @@ use datafusion::physical_plan::memory::MemoryExec;
use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;

// Determines which fuzz tests need to run
// Ideally all join variants should produce matching results, but in reality some tests
// pass only a subset of cases
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum JoinTestType {
// compare NestedLoopJoin and HashJoin
NljHj,
// compare HashJoin and SortMergeJoin; there is no need to compare SortMergeJoin and NestedLoopJoin
// because if both existing comparisons pass, SortMergeJoin and NestedLoopJoin also agree
HjSmj,
}
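// Editorial note (not part of this commit): each test below picks the set of
// comparisons it expects to pass and whether to dump its inputs/outputs for
// debugging. A hypothetical call that only checks HashJoin vs SortMergeJoin
// and saves the generated data might look like:
//
//     JoinFuzzTestCase::new(
//         make_staggered_batches(1000),
//         make_staggered_batches(1000),
//         JoinType::Inner,
//         None,
//     )
//     .run_test(&[JoinTestType::HjSmj], true)
//     .await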
#[tokio::test]
async fn test_inner_join_1k() {
JoinFuzzTestCase::new(
@@ -51,7 +62,7 @@ async fn test_inner_join_1k() {
JoinType::Inner,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -71,6 +82,30 @@ fn less_than_100_join_filter(schema1: Arc<Schema>, _schema2: Arc<Schema>) -> JoinFilter {
JoinFilter::new(less_than_100, column_indices, intermediate_schema)
}

fn col_lt_col_filter(schema1: Arc<Schema>, schema2: Arc<Schema>) -> JoinFilter {
let col_lt_col = Arc::new(BinaryExpr::new(
Arc::new(Column::new("x", 1)),
Operator::Lt,
Arc::new(Column::new("x", 0)),
)) as _;
let column_indices = vec![
ColumnIndex {
index: 2,
side: JoinSide::Left,
},
ColumnIndex {
index: 2,
side: JoinSide::Right,
},
];
let intermediate_schema = Schema::new(vec![
schema1.field_with_name("x").unwrap().to_owned(),
schema2.field_with_name("x").unwrap().to_owned(),
]);

JoinFilter::new(col_lt_col, column_indices, intermediate_schema)
}
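// Editorial note (not part of this commit): the `Column` indices in the filter
// expression above refer to positions in the intermediate filter schema, not to
// the original input schemas. With `column_indices` mapping column 2 of the left
// input and column 2 of the right input into [left "x", right "x"], the
// expression `Column("x", 1) < Column("x", 0)` effectively evaluates
// `right.x < left.x` for each candidate row pair.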

#[tokio::test]
async fn test_inner_join_1k_filtered() {
JoinFuzzTestCase::new(
@@ -79,7 +114,7 @@ async fn test_inner_join_1k_filtered() {
JoinType::Inner,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -91,7 +126,7 @@ async fn test_inner_join_1k_smjoin() {
JoinType::Inner,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -103,7 +138,7 @@ async fn test_left_join_1k() {
JoinType::Left,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -115,7 +150,7 @@ async fn test_left_join_1k_filtered() {
JoinType::Left,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -127,7 +162,7 @@ async fn test_right_join_1k() {
JoinType::Right,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}
// Add support for Right filtered joins
@@ -140,7 +175,7 @@ async fn test_right_join_1k_filtered() {
JoinType::Right,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -152,7 +187,7 @@ async fn test_full_join_1k() {
JoinType::Full,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -164,7 +199,7 @@ async fn test_full_join_1k_filtered() {
JoinType::Full,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -176,22 +211,23 @@ async fn test_semi_join_1k() {
JoinType::LeftSemi,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

// The test is flaky
// https://github.com/apache/datafusion/issues/10886
// SMJ produces 1 more row in the output
#[ignore]
#[tokio::test]
async fn test_semi_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftSemi,
Some(Box::new(less_than_100_join_filter)),
Some(Box::new(col_lt_col_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj], false)
.await
}

@@ -203,7 +239,7 @@ async fn test_anti_join_1k() {
JoinType::LeftAnti,
None,
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -217,7 +253,7 @@ async fn test_anti_join_1k_filtered() {
JoinType::LeftAnti,
Some(Box::new(less_than_100_join_filter)),
)
.run_test()
.run_test(&[JoinTestType::HjSmj, JoinTestType::NljHj], false)
.await
}

@@ -331,7 +367,7 @@ impl JoinFuzzTestCase {
self.on_columns().clone(),
self.join_filter(),
self.join_type,
vec![SortOptions::default(), SortOptions::default()],
vec![SortOptions::default(); self.on_columns().len()],
false,
)
.unwrap(),
@@ -381,9 +417,11 @@ impl JoinFuzzTestCase {
)
}

/// Perform sort-merge join and hash join on same input
/// and verify two outputs are equal
async fn run_test(&self) {
/// Perform join tests on the same inputs and verify that the outputs are equal
/// `join_tests` - identifies which join comparisons to run
/// If the `debug` flag is set, the test saves the randomly generated inputs and outputs
/// to user folders so a failing case can be debugged on top of the saved data
async fn run_test(&self, join_tests: &[JoinTestType], debug: bool) {
for batch_size in self.batch_sizes {
let session_config = SessionConfig::new().with_batch_size(*batch_size);
let ctx = SessionContext::new_with_config(session_config);
@@ -394,17 +432,30 @@ impl JoinFuzzTestCase {
let hj = self.hash_join();
let hj_collected = collect(hj, task_ctx.clone()).await.unwrap();

let nlj = self.nested_loop_join();
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();

// Get actual row counts (without formatting overhead) for HJ, SMJ, and NLJ
let hj_rows = hj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
let smj_rows = smj_collected.iter().fold(0, |acc, b| acc + b.num_rows());
let nlj_rows = nlj_collected.iter().fold(0, |acc, b| acc + b.num_rows());

assert_eq!(
hj_rows, smj_rows,
"SortMergeJoinExec and HashJoinExec produced different row counts"
);
if debug {
println!("The debug is ON. Input data will be saved");
let out_dir_name = &format!("fuzz_test_debug_batch_size_{batch_size}");
Self::save_as_parquet(&self.input1, out_dir_name, "input1");
Self::save_as_parquet(&self.input2, out_dir_name, "input2");

let nlj = self.nested_loop_join();
let nlj_collected = collect(nlj, task_ctx.clone()).await.unwrap();
if join_tests.contains(&JoinTestType::NljHj) {
Self::save_as_parquet(&nlj_collected, out_dir_name, "nlj");
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
}

if join_tests.contains(&JoinTestType::HjSmj) {
Self::save_as_parquet(&hj_collected, out_dir_name, "hj");
Self::save_as_parquet(&smj_collected, out_dir_name, "smj");
}
}

// compare
let smj_formatted =
@@ -425,35 +476,106 @@ impl JoinFuzzTestCase {
nlj_formatted.trim().lines().collect();
nlj_formatted_sorted.sort_unstable();

// row level compare if any of joins returns the result
// the reason is different formatting when there is no rows
if smj_rows > 0 || hj_rows > 0 {
for (i, (smj_line, hj_line)) in smj_formatted_sorted
if join_tests.contains(&JoinTestType::NljHj) {
let err_msg_rowcnt = format!("NestedLoopJoinExec and HashJoinExec produced different row counts, batch_size: {}", batch_size);
assert_eq!(nlj_rows, hj_rows, "{}", err_msg_rowcnt.as_str());

let err_msg_contents = format!("NestedLoopJoinExec and HashJoinExec produced different results, batch_size: {}", batch_size);
// Row-level compare if any of the joins returns a result;
// the reason is the different formatting when there are no rows
for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, smj_line),
(i, nlj_line),
(i, hj_line),
"SortMergeJoinExec and HashJoinExec produced different results"
"{}",
err_msg_contents.as_str()
);
}
}

for (i, (nlj_line, hj_line)) in nlj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, nlj_line),
(i, hj_line),
"NestedLoopJoinExec and HashJoinExec produced different results"
);
if join_tests.contains(&JoinTestType::HjSmj) {
let err_msg_row_cnt = format!("HashJoinExec and SortMergeJoinExec produced different row counts, batch_size: {}", &batch_size);
assert_eq!(hj_rows, smj_rows, "{}", err_msg_row_cnt.as_str());

let err_msg_contents = format!("SortMergeJoinExec and HashJoinExec produced different results, batch_size: {}", &batch_size);
// Row-level compare if any of the joins returns a result;
// the reason is the different formatting when there are no rows
if smj_rows > 0 || hj_rows > 0 {
for (i, (smj_line, hj_line)) in smj_formatted_sorted
.iter()
.zip(&hj_formatted_sorted)
.enumerate()
{
assert_eq!(
(i, smj_line),
(i, hj_line),
"{}",
err_msg_contents.as_str()
);
}
}
}
}
}

/// This method is useful for debugging fuzz tests.
/// It saves the randomly generated input test data for both join inputs into the user folder
/// as parquet files, preserving partitioning.
/// Once the data is saved it is possible to run a custom test on top of the saved data and debug it
///
/// let ctx: SessionContext = SessionContext::new();
/// let df = ctx
/// .read_parquet(
/// "/tmp/input1/*.parquet",
/// ParquetReadOptions::default(),
/// )
/// .await
/// .unwrap();
/// let left = df.collect().await.unwrap();
///
/// let df = ctx
/// .read_parquet(
/// "/tmp/input2/*.parquet",
/// ParquetReadOptions::default(),
/// )
/// .await
/// .unwrap();
///
/// let right = df.collect().await.unwrap();
/// JoinFuzzTestCase::new(
/// left,
/// right,
/// JoinType::LeftSemi,
/// Some(Box::new(less_than_100_join_filter)),
/// )
/// .run_test(&[JoinTestType::HjSmj], false)
/// .await
fn save_as_parquet(input: &[RecordBatch], output_dir: &str, out_name: &str) {
let out_path = &format!("{output_dir}/{out_name}");
std::fs::remove_dir_all(out_path).unwrap_or(());
std::fs::create_dir_all(out_path).unwrap();

input.iter().enumerate().for_each(|(idx, batch)| {
let mut file =
std::fs::File::create(format!("{out_path}/file_{}.parquet", idx))
.unwrap();
let mut writer = parquet::arrow::ArrowWriter::try_new(
&mut file,
input.first().unwrap().schema(),
None,
)
.expect("creating writer");
writer.write(batch).unwrap();
writer.close().unwrap();
});

println!("The data {out_name} saved as parquet into {out_path}");
}
}
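// Editorial sketch (not part of this commit; the paths are hypothetical): once a
// test has been run with `debug = true`, the per-join outputs written by
// `save_as_parquet` can be diffed with SQL, for example:
//
//     let ctx = SessionContext::new();
//     ctx.register_parquet("hj", "fuzz_test_debug_batch_size_2/hj", ParquetReadOptions::default())
//         .await?;
//     ctx.register_parquet("smj", "fuzz_test_debug_batch_size_2/smj", ParquetReadOptions::default())
//         .await?;
//     // Rows produced by HashJoin but not by SortMergeJoin, and vice versa
//     ctx.sql("SELECT * FROM hj EXCEPT SELECT * FROM smj").await?.show().await?;
//     ctx.sql("SELECT * FROM smj EXCEPT SELECT * FROM hj").await?.show().await?;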

/// Return randomly sized record batches with:
