diff --git a/datafusion/optimizer/src/limit_push_down.rs b/datafusion/optimizer/src/limit_push_down.rs index 66dba0d8587f..3a821f3aea02 100644 --- a/datafusion/optimizer/src/limit_push_down.rs +++ b/datafusion/optimizer/src/limit_push_down.rs @@ -39,6 +39,10 @@ impl LimitPushDown { } } +fn is_no_join_condition(join: &Join) -> bool { + join.on.is_empty() && join.filter.is_none() +} + fn push_down_join( join: &Join, left_limit: Option, @@ -192,6 +196,24 @@ impl OptimizerRule for LimitPushDown { LogicalPlan::Join(join) => { let limit = fetch + skip; let new_join = match join.join_type { + JoinType::Left | JoinType::Right | JoinType::Full + if is_no_join_condition(join) => + { + // push left and right + push_down_join(join, Some(limit), Some(limit)) + } + JoinType::LeftSemi | JoinType::LeftAnti + if is_no_join_condition(join) => + { + // push left + push_down_join(join, Some(limit), None) + } + JoinType::RightSemi | JoinType::RightAnti + if is_no_join_condition(join) => + { + // push right + push_down_join(join, None, Some(limit)) + } JoinType::Left => push_down_join(join, Some(limit), None), JoinType::Right => push_down_join(join, None, Some(limit)), _ => push_down_join(join, None, None), @@ -606,6 +628,142 @@ mod test { assert_optimized_plan_eq(&outer_query, expected) } + #[test] + fn limit_should_push_down_join_without_condition() -> Result<()> { + let table_scan_1 = test_table_scan()?; + let table_scan_2 = test_table_scan_with_name("test2")?; + let left_keys: Vec<&str> = Vec::new(); + let right_keys: Vec<&str> = Vec::new(); + let plan = LogicalPlanBuilder::from(table_scan_1.clone()) + .join( + &LogicalPlanBuilder::from(table_scan_2.clone()).build()?, + JoinType::Left, + (left_keys.clone(), right_keys.clone()), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n Left Join: \ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test2, fetch=1000"; + + assert_optimized_plan_eq(&plan, expected)?; + + let plan = LogicalPlanBuilder::from(table_scan_1.clone()) + .join( + &LogicalPlanBuilder::from(table_scan_2.clone()).build()?, + JoinType::Right, + (left_keys.clone(), right_keys.clone()), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n Right Join: \ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test2, fetch=1000"; + + assert_optimized_plan_eq(&plan, expected)?; + + let plan = LogicalPlanBuilder::from(table_scan_1.clone()) + .join( + &LogicalPlanBuilder::from(table_scan_2.clone()).build()?, + JoinType::Full, + (left_keys.clone(), right_keys.clone()), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n Full Join: \ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test2, fetch=1000"; + + assert_optimized_plan_eq(&plan, expected)?; + + let plan = LogicalPlanBuilder::from(table_scan_1.clone()) + .join( + &LogicalPlanBuilder::from(table_scan_2.clone()).build()?, + JoinType::LeftSemi, + (left_keys.clone(), right_keys.clone()), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n LeftSemi Join: \ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n TableScan: test2"; + + assert_optimized_plan_eq(&plan, expected)?; + + let plan = LogicalPlanBuilder::from(table_scan_1.clone()) + .join( + &LogicalPlanBuilder::from(table_scan_2.clone()).build()?, + JoinType::LeftAnti, + (left_keys.clone(), right_keys.clone()), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n LeftAnti Join: \ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test, fetch=1000\ + \n TableScan: test2"; + + assert_optimized_plan_eq(&plan, expected)?; + + let plan = LogicalPlanBuilder::from(table_scan_1.clone()) + .join( + &LogicalPlanBuilder::from(table_scan_2.clone()).build()?, + JoinType::RightSemi, + (left_keys.clone(), right_keys.clone()), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n RightSemi Join: \ + \n TableScan: test\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test2, fetch=1000"; + + assert_optimized_plan_eq(&plan, expected)?; + + let plan = LogicalPlanBuilder::from(table_scan_1) + .join( + &LogicalPlanBuilder::from(table_scan_2).build()?, + JoinType::RightAnti, + (left_keys, right_keys), + None, + )? + .limit(0, Some(1000))? + .build()?; + + let expected = "Limit: skip=0, fetch=1000\ + \n RightAnti Join: \ + \n TableScan: test\ + \n Limit: skip=0, fetch=1000\ + \n TableScan: test2, fetch=1000"; + + assert_optimized_plan_eq(&plan, expected) + } + #[test] fn limit_should_push_down_left_outer_join() -> Result<()> { let table_scan_1 = test_table_scan()?;