diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index cad054392308..38b5257e32e1 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -272,8 +272,7 @@ impl TryInto for &protobuf::LogicalPlanNode { JoinConstraint::On => builder.join( &convert_box_required!(join.right)?, join_type.into(), - left_keys, - right_keys, + (left_keys, right_keys), )?, JoinConstraint::Using => builder.join_using( &convert_box_required!(join.right)?, diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 0d27c58ac292..f6dbeaf6a151 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -701,7 +701,7 @@ mod roundtrip_tests { CsvReadOptions::new().schema(&schema).has_header(true), Some(vec![0, 3, 4]), ) - .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, vec!["id"], vec!["id"])) + .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, (vec!["id"], vec!["id"]))) .and_then(|plan| plan.build()) .map_err(BallistaError::DataFusionError)?; diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index 4edd01c2c0a9..451c4c7ba502 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -117,8 +117,7 @@ impl DataFrame for DataFrameImpl { .join( &right.to_logical_plan(), join_type, - left_cols.to_vec(), - right_cols.to_vec(), + (left_cols.to_vec(), right_cols.to_vec()), )? .build()?; Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan))) diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 60e0ed3c0988..a742f346207a 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -273,23 +273,37 @@ impl LogicalPlanBuilder { &self, right: &LogicalPlan, join_type: JoinType, - left_keys: Vec>, - right_keys: Vec>, + join_keys: (Vec>, Vec>), ) -> Result { - if left_keys.len() != right_keys.len() { + if join_keys.0.len() != join_keys.1.len() { return Err(DataFusionError::Plan( "left_keys and right_keys were not the same length".to_string(), )); } - let left_keys: Vec = left_keys - .into_iter() - .map(|c| c.into().normalize(&self.plan)) - .collect::>()?; - let right_keys: Vec = right_keys - .into_iter() - .map(|c| c.into().normalize(right)) - .collect::>()?; + let (left_keys, right_keys): (Vec>, Vec>) = + join_keys + .0 + .into_iter() + .zip(join_keys.1.into_iter()) + .map(|(l, r)| { + let mut swap = false; + let l = l.into(); + let left_key = l.clone().normalize(&self.plan).or_else(|_| { + swap = true; + l.normalize(right) + }); + if swap { + (r.into().normalize(&self.plan), left_key) + } else { + (left_key, r.into().normalize(right)) + } + }) + .unzip(); + + let left_keys = left_keys.into_iter().collect::>>()?; + let right_keys = right_keys.into_iter().collect::>>()?; + let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect(); let join_schema = build_join_schema(self.plan.schema(), right.schema(), &join_type)?; diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs index 399923e87218..039e92d1c128 100644 --- a/datafusion/src/optimizer/filter_push_down.rs +++ b/datafusion/src/optimizer/filter_push_down.rs @@ -973,8 +973,7 @@ mod tests { .join( &right, JoinType::Inner, - vec![Column::from_name("a")], - vec![Column::from_name("a")], + (vec![Column::from_name("a")], vec![Column::from_name("a")]), )? .filter(col("a").lt_eq(lit(1i64)))? .build()?; @@ -1058,8 +1057,7 @@ mod tests { .join( &right, JoinType::Inner, - vec![Column::from_name("a")], - vec![Column::from_name("a")], + (vec![Column::from_name("a")], vec![Column::from_name("a")]), )? // "b" and "c" are not shared by either side: they are only available together after the join .filter(col("c").lt_eq(col("b")))? @@ -1099,8 +1097,7 @@ mod tests { .join( &right, JoinType::Inner, - vec![Column::from_name("a")], - vec![Column::from_name("a")], + (vec![Column::from_name("a")], vec![Column::from_name("a")]), )? .filter(col("b").lt_eq(lit(1i64)))? .build()?; diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index 0de36f354206..96c5094711ba 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -555,7 +555,7 @@ mod tests { LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) - .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])? + .join(&table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]))? .project(vec![col("a"), col("b"), col("c1")])? .build()?; @@ -594,7 +594,7 @@ mod tests { LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?; let plan = LogicalPlanBuilder::from(table_scan) - .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])? + .join(&table2_scan, JoinType::Left, (vec!["a"], vec!["c1"]))? // projecting joined column `a` should push the right side column `c1` projection as // well into test2 table even though `c1` is not referenced in projection. .project(vec![col("a"), col("b")])? diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index fa2b035162a6..6d9484be102f 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -375,8 +375,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let (left_keys, right_keys): (Vec, Vec) = keys.into_iter().unzip(); // return the logical plan representing the join - let join = LogicalPlanBuilder::from(left) - .join(right, join_type, left_keys, right_keys)?; + let join = LogicalPlanBuilder::from(left).join( + right, + join_type, + (left_keys, right_keys), + )?; if filter.is_empty() { join.build() @@ -548,7 +551,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { join_keys.iter().map(|(_, r)| r.clone()).collect(); let builder = LogicalPlanBuilder::from(left); left = builder - .join(right, JoinType::Inner, left_keys, right_keys)? + .join(right, JoinType::Inner, (left_keys, right_keys))? .build()?; } diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index d9f7c6ea4121..bfe2f2fc4913 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -1717,15 +1717,40 @@ fn create_case_context() -> Result { #[tokio::test] async fn equijoin() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; let expected = vec![ vec!["11", "a", "z"], vec!["22", "b", "y"], vec!["44", "d", "x"], ]; - assert_eq!(expected, actual); + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } + Ok(()) +} + +#[tokio::test] +async fn equijoin_multiple_condition_ordering() -> Result<()> { + let mut ctx = create_join_context("t1_id", "t2_id")?; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t1_name <> t2_name ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id AND t2_name <> t1_name ORDER BY t1_id", + ]; + let expected = vec![ + vec!["11", "a", "z"], + vec!["22", "b", "y"], + vec!["44", "d", "x"], + ]; + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } Ok(()) } @@ -1754,39 +1779,50 @@ async fn equijoin_and_unsupported_condition() -> Result<()> { #[tokio::test] async fn left_join() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; let expected = vec![ vec!["11", "a", "z"], vec!["22", "b", "y"], vec!["33", "c", "NULL"], vec!["44", "d", "x"], ]; - assert_eq!(expected, actual); + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } Ok(()) } #[tokio::test] async fn right_join() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id" + ]; let expected = vec![ vec!["NULL", "NULL", "w"], vec!["11", "a", "z"], vec!["22", "b", "y"], vec!["44", "d", "x"], ]; - assert_eq!(expected, actual); + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } Ok(()) } #[tokio::test] async fn full_join() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; let expected = vec![ vec!["NULL", "NULL", "w"], vec!["11", "a", "z"], @@ -1794,11 +1830,19 @@ async fn full_join() -> Result<()> { vec!["33", "c", "NULL"], vec!["44", "d", "x"], ]; - assert_eq!(expected, actual); + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } - let sql = "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; - assert_eq!(expected, actual); + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id", + ]; + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } Ok(()) } @@ -1821,15 +1865,19 @@ async fn left_join_using() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax() -> Result<()> { let mut ctx = create_join_context("t1_id", "t2_id")?; - let sql = - "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id"; - let actual = execute(&mut ctx, sql).await; + let equivalent_sql = [ + "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id", + "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id", + ]; let expected = vec![ vec!["11", "a", "z"], vec!["22", "b", "y"], vec!["44", "d", "x"], ]; - assert_eq!(expected, actual); + for sql in equivalent_sql.iter() { + let actual = execute(&mut ctx, sql).await; + assert_eq!(expected, actual); + } Ok(()) }