From 96a5691ab3a57d0bdb927a4cbb22d546afbe3beb Mon Sep 17 00:00:00 2001 From: Nuttiiya Seekhao Date: Wed, 26 Apr 2023 21:17:40 -0700 Subject: [PATCH 1/3] Fix incorrect join key fields (indices) when same table is being used more than once --- .../substrait/src/logical_plan/producer.rs | 166 +++++++++++------- .../substrait/tests/roundtrip_logical_plan.rs | 12 ++ 2 files changed, 118 insertions(+), 60 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 785bfa4ea6a7..29866f94dcbb 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, usize}; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -156,7 +156,7 @@ pub fn to_substrait_rel( let expressions = p .expr .iter() - .map(|e| to_substrait_rex(e, p.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, p.input.schema(), 0, extension_info)) .collect::>>()?; Ok(Box::new(Rel { rel_type: Some(RelType::Project(Box::new(ProjectRel { @@ -172,6 +172,7 @@ pub fn to_substrait_rel( let filter_expr = to_substrait_rex( &filter.predicate, filter.input.schema(), + 0, extension_info, )?; Ok(Box::new(Rel { @@ -218,7 +219,7 @@ pub fn to_substrait_rel( let grouping = agg .group_expr .iter() - .map(|e| to_substrait_rex(e, agg.input.schema(), extension_info)) + .map(|e| to_substrait_rex(e, agg.input.schema(), 0, extension_info)) .collect::>>()?; let measures = agg .aggr_expr @@ -281,45 +282,24 @@ pub fn to_substrait_rel( } else { Operator::Eq }; - let join_expression = join - .on - .iter() - .map(|(l, r)| binary_expr(l.clone(), eq_op, r.clone())) - .reduce(|acc: Expr, expr: Expr| acc.and(expr)); - // join schema from left and right to maintain all nececesary columns from inputs - // note that we cannot simple use join.schema here since we discard some input columns - // when performing semi and anti joins - let join_schema = match join.left.schema().join(join.right.schema()) { - Ok(schema) => Ok(schema), - Err(DataFusionError::SchemaError( - datafusion::common::SchemaError::DuplicateQualifiedField { - qualifier: _, - name: _, - }, - )) => Ok(join.schema.as_ref().clone()), - Err(e) => Err(e), - }; - if let Some(e) = join_expression { - Ok(Box::new(Rel { - rel_type: Some(RelType::Join(Box::new(JoinRel { - common: None, - left: Some(left), - right: Some(right), - r#type: join_type as i32, - expression: Some(Box::new(to_substrait_rex( - &e, - &Arc::new(join_schema?), - extension_info, - )?)), - post_join_filter: None, - advanced_extension: None, - }))), - })) - } else { - Err(DataFusionError::NotImplemented( - "Empty join condition".to_string(), - )) - } + + Ok(Box::new(Rel { + rel_type: Some(RelType::Join(Box::new(JoinRel { + common: None, + left: Some(left), + right: Some(right), + r#type: join_type as i32, + expression: Some(Box::new(to_substrait_join_expr( + &join.on, + eq_op, + join.left.schema(), + join.right.schema(), + extension_info, + )?)), + post_join_filter: None, + advanced_extension: None, + }))), + })) } LogicalPlan::SubqueryAlias(alias) => { // Do nothing if encounters SubqueryAlias @@ -353,6 +333,7 @@ pub fn to_substrait_rel( window_exprs.push(to_substrait_rex( expr, window.input.schema(), + 0, extension_info, )?); } @@ -403,6 +384,40 @@ pub fn to_substrait_rel( } } +fn to_substrait_join_expr( + join_conditions: &Vec<(Expr, Expr)>, + eq_op: Operator, + left_schema: &DFSchemaRef, + right_schema: &DFSchemaRef, + extension_info: &mut ( + Vec, + HashMap, + ), +) -> Result { + // Only support AND confunction for each binary expression in join conditions + let mut exprs: Vec = vec![]; + for (left, right) in join_conditions { + // Parse left + let l = to_substrait_rex(left, left_schema, 0, extension_info)?; + // Parse right + let r = to_substrait_rex( + right, + right_schema, + left_schema.fields().len(), // offset to return the correct index + extension_info, + )?; + // AND with existing expression + exprs.push(make_binary_op_scalar_func(&l, &r, eq_op, extension_info)); + } + let join_expr: Expression = exprs + .into_iter() + .reduce(|acc: Expression, e: Expression| { + make_binary_op_scalar_func(&acc, &e, Operator::And, extension_info) + }) + .unwrap(); + Ok(join_expr) +} + fn to_substrait_jointype(join_type: JoinType) -> join_rel::JoinType { match join_type { JoinType::Inner => join_rel::JoinType::Inner, @@ -459,7 +474,7 @@ pub fn to_substrait_agg_measure( Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by: _order_by }) => { let mut arguments: Vec = vec![]; for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, extension_info)?)) }); + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); } let function_name = fun.to_string().to_lowercase(); let function_anchor = _register_function(function_name, extension_info); @@ -478,7 +493,7 @@ pub fn to_substrait_agg_measure( options: vec![], }), filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, extension_info)?), + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), None => None } }) @@ -570,6 +585,7 @@ pub fn make_binary_op_scalar_func( pub fn to_substrait_rex( expr: &Expr, schema: &DFSchemaRef, + col_ref_offset: usize, extension_info: &mut ( Vec, HashMap, @@ -607,9 +623,12 @@ pub fn to_substrait_rex( }) => { if *negated { // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_expr, @@ -632,9 +651,12 @@ pub fn to_substrait_rex( )) } else { // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high) - let substrait_expr = to_substrait_rex(expr, schema, extension_info)?; - let substrait_low = to_substrait_rex(low, schema, extension_info)?; - let substrait_high = to_substrait_rex(high, schema, extension_info)?; + let substrait_expr = + to_substrait_rex(expr, schema, col_ref_offset, extension_info)?; + let substrait_low = + to_substrait_rex(low, schema, col_ref_offset, extension_info)?; + let substrait_high = + to_substrait_rex(high, schema, col_ref_offset, extension_info)?; let l_expr = make_binary_op_scalar_func( &substrait_low, @@ -659,11 +681,11 @@ pub fn to_substrait_rex( } Expr::Column(col) => { let index = schema.index_of_column(col)?; - substrait_field_ref(index) + substrait_field_ref(index + col_ref_offset) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = to_substrait_rex(left, schema, extension_info)?; - let r = to_substrait_rex(right, schema, extension_info)?; + let l = to_substrait_rex(left, schema, col_ref_offset, extension_info)?; + let r = to_substrait_rex(right, schema, col_ref_offset, extension_info)?; Ok(make_binary_op_scalar_func(&l, &r, *op, extension_info)) } @@ -677,21 +699,41 @@ pub fn to_substrait_rex( if let Some(e) = expr { // Base expression exists ifs.push(IfClause { - r#if: Some(to_substrait_rex(e, schema, extension_info)?), + r#if: Some(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?), then: None, }); } // Parse `when`s for (r#if, then) in when_then_expr { ifs.push(IfClause { - r#if: Some(to_substrait_rex(r#if, schema, extension_info)?), - then: Some(to_substrait_rex(then, schema, extension_info)?), + r#if: Some(to_substrait_rex( + r#if, + schema, + col_ref_offset, + extension_info, + )?), + then: Some(to_substrait_rex( + then, + schema, + col_ref_offset, + extension_info, + )?), }); } // Parse outer `else` let r#else: Option> = match else_expr { - Some(e) => Some(Box::new(to_substrait_rex(e, schema, extension_info)?)), + Some(e) => Some(Box::new(to_substrait_rex( + e, + schema, + col_ref_offset, + extension_info, + )?)), None => None, }; @@ -707,6 +749,7 @@ pub fn to_substrait_rex( input: Some(Box::new(to_substrait_rex( expr, schema, + col_ref_offset, extension_info, )?)), failure_behavior: 0, // FAILURE_BEHAVIOR_UNSPECIFIED @@ -715,7 +758,9 @@ pub fn to_substrait_rex( }) } Expr::Literal(value) => to_substrait_literal(value), - Expr::Alias(expr, _alias) => to_substrait_rex(expr, schema, extension_info), + Expr::Alias(expr, _alias) => { + to_substrait_rex(expr, schema, col_ref_offset, extension_info) + } Expr::WindowFunction(WindowFunction { fun, args, @@ -733,6 +778,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), }); @@ -740,7 +786,7 @@ pub fn to_substrait_rex( // partition by expressions let partition_by = partition_by .iter() - .map(|e| to_substrait_rex(e, schema, extension_info)) + .map(|e| to_substrait_rex(e, schema, col_ref_offset, extension_info)) .collect::>>()?; // order by expressions let order_by = order_by @@ -1325,7 +1371,7 @@ fn substrait_sort_field( asc, nulls_first, }) => { - let e = to_substrait_rex(expr, schema, extension_info)?; + let e = to_substrait_rex(expr, schema, 0, extension_info)?; let d = match (asc, nulls_first) { (true, true) => SortDirection::AscNullsFirst, (true, false) => SortDirection::AscNullsLast, diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index 8cdf89b29473..afd6df9985e0 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -412,6 +412,18 @@ mod tests { roundtrip("SELECT a,b,c,d,e FROM datafusion.public.data;").await } + #[tokio::test] + async fn roundtrip_inner_join_table_reuse() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", + "Projection: data.b, data.c\ + \n Inner Join: data.a = data.a\ + \n TableScan: data projection=[a, b]\ + \n TableScan: data projection=[a, c]", + ) + .await + } + /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported From 904189279007699eaf6dcc562a4312e17a3270c5 Mon Sep 17 00:00:00 2001 From: Nuttiiya Seekhao Date: Mon, 5 Jun 2023 07:03:41 -0700 Subject: [PATCH 2/3] Addressed comments Update datafusion/substrait/src/logical_plan/producer.rs Co-authored-by: Ruihang Xia Update datafusion/substrait/src/logical_plan/producer.rs Co-authored-by: Ruihang Xia --- .../substrait/src/logical_plan/consumer.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 27 ++++++++++++++++--- .../substrait/tests/roundtrip_logical_plan.rs | 14 +++++++++- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f914b62a1452..f15ffdf42374 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -365,7 +365,7 @@ pub async fn from_substrait_rel( )), }, _ => Err(DataFusionError::Internal( - "invalid join condition expresssion".to_string(), + "invalid join condition expression".to_string(), )), } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 29866f94dcbb..dd7bcc0d7d03 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::{collections::HashMap, usize}; +use std::collections::HashMap; use datafusion::{ arrow::datatypes::{DataType, TimeUnit}, @@ -394,7 +394,7 @@ fn to_substrait_join_expr( HashMap, ), ) -> Result { - // Only support AND confunction for each binary expression in join conditions + // Only support AND conjunction for each binary expression in join conditions let mut exprs: Vec = vec![]; for (left, right) in join_conditions { // Parse left @@ -581,7 +581,28 @@ pub fn make_binary_op_scalar_func( } /// Convert DataFusion Expr to Substrait Rex -#[allow(deprecated)] +/// +/// # Arguments +/// +/// * `expr` - DataFusion expression to be parse into a Substrait expression +/// * `schema` - DataFusion input schema for looking up field qualifiers +/// * `col_ref_offset` - Offset for caculating Substrait field reference indices. +/// This should only be set by caller with more than one input relations i.e. Join. +/// Substrait expects one set of indices when joining two relations. +/// Let's say `left` and `right` have `m` and `n` columns, respectively. The `right` +/// relation will have column indices from `0` to `n-1`, however, Substrait will expect +/// the `right` indices to be offset by the `left`. This means Substrait will expect to +/// evaluate the join condition expression on indices [0 .. n-1, n .. n+m-1]. For example: +/// ```SELECT * +/// FROM t1 +/// JOIN t2 +/// ON t1.c1 = t2.c0;``` +/// where t1 consists of columns [c0, c1, c2], and t2 = columns [c0, c1] +/// the join condition should become +/// `col_ref(1) = col_ref(3 + 0)` +/// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index +/// of the join key column from `right` +/// * `extension_info` - Substrait extension info. Contains registered function information pub fn to_substrait_rex( expr: &Expr, schema: &DFSchemaRef, diff --git a/datafusion/substrait/tests/roundtrip_logical_plan.rs b/datafusion/substrait/tests/roundtrip_logical_plan.rs index afd6df9985e0..e209ebedc0f3 100644 --- a/datafusion/substrait/tests/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/roundtrip_logical_plan.rs @@ -413,7 +413,7 @@ mod tests { } #[tokio::test] - async fn roundtrip_inner_join_table_reuse() -> Result<()> { + async fn roundtrip_inner_join_table_reuse_zero_index() -> Result<()> { assert_expected_plan( "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.a = d2.a", "Projection: data.b, data.c\ @@ -424,6 +424,18 @@ mod tests { .await } + #[tokio::test] + async fn roundtrip_inner_join_table_reuse_non_zero_index() -> Result<()> { + assert_expected_plan( + "SELECT d1.b, d2.c FROM data d1 JOIN data d2 ON d1.b = d2.b", + "Projection: data.b, data.c\ + \n Inner Join: data.b = data.b\ + \n TableScan: data projection=[b]\ + \n TableScan: data projection=[b, c]", + ) + .await + } + /// Construct a plan that contains several literals of types that are currently supported. /// This case ignores: /// - Date64, for this literal is not supported From a8b5ae7be61241322cb239bb2ca6079faf7666be Mon Sep 17 00:00:00 2001 From: Nuttiiya Seekhao Date: Tue, 6 Jun 2023 09:24:30 -0700 Subject: [PATCH 3/3] Fixed bugs after rebase --- datafusion/substrait/src/logical_plan/producer.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index dd7bcc0d7d03..228341548813 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -32,7 +32,7 @@ use datafusion::logical_expr::expr::{ BinaryExpr, Case, Cast, ScalarFunction as DFScalarFunction, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; -use datafusion::prelude::{binary_expr, Expr}; +use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::{ proto::{ @@ -603,6 +603,7 @@ pub fn make_binary_op_scalar_func( /// , where `3` is the number of `left` columns (`col_ref_offset`) and `0` is the index /// of the join key column from `right` /// * `extension_info` - Substrait extension info. Contains registered function information +#[allow(deprecated)] pub fn to_substrait_rex( expr: &Expr, schema: &DFSchemaRef, @@ -620,6 +621,7 @@ pub fn to_substrait_rex( arg_type: Some(ArgType::Value(to_substrait_rex( arg, schema, + col_ref_offset, extension_info, )?)), });