diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index cd5edd3af6c3..9ce400de56e4 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -29,6 +29,7 @@ import org.apache.gluten.extension.columnar.transition.{InsertTransitions, Remov import org.apache.gluten.extension.columnar.validator.{Validator, Validators} import org.apache.gluten.extension.injector.{Injector, SparkInjector} import org.apache.gluten.extension.injector.GlutenInjector.LegacyInjector +import org.apache.gluten.extension.joinagg.{ImplementJoinAggregate, PushAggregateThroughJoinBatch} import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.sql.execution._ @@ -56,6 +57,8 @@ object VeloxRuleApi { injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply) injector.injectOptimizerRule(RewriteCastFromArray.apply) injector.injectOptimizerRule(RewriteUnboundedWindow.apply) + injector.injectOptimizerRule(PushAggregateThroughJoinBatch.apply) + injector.injectPlannerStrategy(ImplementJoinAggregate.apply) if (!BackendsApiManager.getSettings.enableJoinKeysRewrite()) { injector.injectPlannerStrategy(_ => org.apache.gluten.extension.GlutenJoinKeysCapture()) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/StarSchemaJoinAggregateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/StarSchemaJoinAggregateSuite.scala new file mode 100644 index 000000000000..54aa3770d749 --- /dev/null +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/StarSchemaJoinAggregateSuite.scala @@ -0,0 +1,1062 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.extension.joinagg.{ImplementJoinAggregate, PushAggregateThroughJoinBatch} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +class VanillaJoinAggregateLogicalOnlyExtensions extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectOptimizerRule(PushAggregateThroughJoinBatch.apply) + } +} + +class VanillaStarSchemaJoinAggregateLogicalOnlySuite extends StarSchemaJoinAggregateSuite { + override protected def sparkConf: SparkConf = { + super.sparkConf + .set(GlutenConfig.GLUTEN_ENABLED.key, "false") + .set("spark.shuffle.manager", "sort") + .set( + "spark.sql.extensions", + classOf[VanillaJoinAggregateLogicalOnlyExtensions].getCanonicalName) + } + + override protected def checkPlan(df: DataFrame): Unit = {} +} + +class VanillaJoinAggregateExtensions extends (SparkSessionExtensions => Unit) { + override def apply(extensions: SparkSessionExtensions): Unit = { + extensions.injectOptimizerRule(PushAggregateThroughJoinBatch.apply) + extensions.injectPlannerStrategy(ImplementJoinAggregate.apply) + } +} + +class VanillaStarSchemaJoinAggregateSuite extends StarSchemaJoinAggregateSuite { + override protected def sparkConf: SparkConf = { + super.sparkConf + .set(GlutenConfig.GLUTEN_ENABLED.key, "false") + .set("spark.shuffle.manager", "sort") + .set("spark.sql.extensions", classOf[VanillaJoinAggregateExtensions].getCanonicalName) + } + + override protected def checkPlan(df: DataFrame): Unit = {} +} + +class StarSchemaJoinAggregateSingleDepthSuite extends StarSchemaJoinAggregateSuite { + override protected def sparkConf: SparkConf = { + super.sparkConf + .set(GlutenConfig.PUSH_AGGREGATE_THROUGH_JOIN_MAX_DEPTH.key, "1") + } +} + +class StarSchemaJoinAggregateSuite extends VeloxTPCHTableSupport with AdaptiveSparkPlanHelper { + private val factMeasureColumnNames = Set( + "sales_price", + "return_amt", + "profit", + "net_loss", + "ss_sales_price", + "ss_net_profit", + "sr_return_amt", + "sr_net_loss", + "cs_ext_sales_price", + "cs_net_profit", + "cr_return_amount", + "cr_net_loss", + "ws_ext_sales_price", + "ws_net_profit", + "wr_return_amt", + "wr_net_loss" + ) + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set(GlutenConfig.COLUMNAR_FORCE_SHUFFLED_HASH_JOIN_ENABLED.key, "true") + .set(GlutenConfig.PUSH_AGGREGATE_THROUGH_JOIN_ENABLED.key, "true") + .set(GlutenConfig.PUSH_AGGREGATE_THROUGH_JOIN_MAX_DEPTH.key, s"${Int.MaxValue}") + .set("spark.sql.adaptive.enabled", "false") + } + + private val tpcdsTempTables = Seq( + "store_sales", + "store_returns", + "catalog_sales", + "catalog_returns", + "web_sales", + "web_returns", + "date_dim", + "store", + "catalog_page", + "customer_address", + "call_center", + "web_site", + "warehouse", + "ship_mode", + "item", + "promotion", + "customer_demographics" + ) + + override def beforeAll(): Unit = { + super.beforeAll() + createTPCDSMiniTables() + } + + override def afterAll(): Unit = { + tpcdsTempTables.foreach { + t => + if (spark.catalog.tableExists(t)) { + spark.catalog.dropTempView(t) + } + } + super.afterAll() + } + + private def createTPCDSMiniTables(): Unit = { + spark.sql( + """ + |CREATE OR REPLACE TEMP VIEW date_dim AS + |SELECT CAST(d_date_sk AS BIGINT) AS d_date_sk, + | CAST(d_date AS DATE) AS d_date, + | CAST(d_year AS INT) AS d_year, + | CAST(d_moy AS INT) AS d_moy, + | CAST(d_month_seq AS INT) AS d_month_seq, + | CAST(d_week_seq AS INT) AS d_week_seq, + | CAST(d_quarter_name AS STRING) AS d_quarter_name, + | CAST(d_day_name AS STRING) AS d_day_name + |FROM VALUES + | (1, DATE'1998-08-05', 1998, 8, 1100, 10, '1998Q1', 'Wednesday'), + | (2, DATE'1998-08-06', 1998, 8, 1100, 10, '1998Q1', 'Thursday'), + | (3, DATE'1998-08-07', 1998, 8, 1100, 10, '1998Q1', 'Friday'), + | (10, DATE'1999-02-10', 1999, 2, 1212, 40, '1999Q1', 'Wednesday'), + | (11, DATE'1999-03-01', 1999, 3, 1213, 43, '1999Q1', 'Monday'), + | (12, DATE'1999-04-01', 1999, 4, 1214, 48, '1999Q2', 'Thursday'), + | (13, DATE'1999-07-01', 1999, 7, 1217, 61, '1999Q3', 'Thursday'), + | (200, DATE'1999-12-15', 1999, 12, 1300, 70, '1999Q4', 'Wednesday'), + | (201, DATE'2000-01-15', 2000, 1, 1301, 71, '2000Q1', 'Saturday'), + | (202, DATE'2001-01-15', 2001, 1, 1452, 72, '2001Q1', 'Monday'), + | (1001, DATE'2001-01-07', 2001, 1, 1452, 1, '2001Q1', 'Sunday'), + | (1002, DATE'2001-01-08', 2001, 1, 1452, 1, '2001Q1', 'Monday'), + | (1054, DATE'2002-01-13', 2002, 1, 1464, 54, '2002Q1', 'Sunday'), + | (1055, DATE'2002-01-14', 2002, 1, 1464, 54, '2002Q1', 'Monday') + |AS t(d_date_sk, d_date, d_year, d_moy, d_month_seq, d_week_seq, d_quarter_name, d_day_name) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW store AS + |SELECT CAST(s_store_sk AS BIGINT) AS s_store_sk, + | CAST(s_store_id AS STRING) AS s_store_id, + | CAST(s_store_name AS STRING) AS s_store_name, + | CAST(s_company_name AS STRING) AS s_company_name, + | CAST(s_state AS STRING) AS s_state + |FROM VALUES + | (10, 'S10', 'Store 10', 'StoreCo', 'CA'), + | (20, 'S20', 'Store 20', 'StoreCo', 'WA') + |AS t(s_store_sk, s_store_id, s_store_name, s_company_name, s_state) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW catalog_page AS + |SELECT CAST(cp_catalog_page_sk AS BIGINT) AS cp_catalog_page_sk, + | CAST(cp_catalog_page_id AS STRING) AS cp_catalog_page_id + |FROM VALUES + | (100, 'CP100'), + | (200, 'CP200') + |AS t(cp_catalog_page_sk, cp_catalog_page_id) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW web_site AS + |SELECT CAST(web_site_sk AS BIGINT) AS web_site_sk, + | CAST(web_site_id AS STRING) AS web_site_id, + | CAST(web_name AS STRING) AS web_name + |FROM VALUES + | (1000, 'W1000', 'Web 1000'), + | (2000, 'W2000', 'Web 2000') + |AS t(web_site_sk, web_site_id, web_name) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW warehouse AS + |SELECT CAST(w_warehouse_sk AS BIGINT) AS w_warehouse_sk, + | CAST(w_warehouse_name AS STRING) AS w_warehouse_name + |FROM VALUES + | (1, 'Warehouse 1'), + | (2, 'Warehouse 2') + |AS t(w_warehouse_sk, w_warehouse_name) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ship_mode AS + |SELECT CAST(sm_ship_mode_sk AS BIGINT) AS sm_ship_mode_sk, + | CAST(sm_type AS STRING) AS sm_type + |FROM VALUES + | (11, 'AIR'), + | (22, 'GROUND') + |AS t(sm_ship_mode_sk, sm_type) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW store_sales AS + |SELECT CAST(ss_store_sk AS BIGINT) AS ss_store_sk, + | CAST(ss_sold_date_sk AS BIGINT) AS ss_sold_date_sk, + | CAST(ss_ext_sales_price AS DECIMAL(7,2)) AS ss_ext_sales_price, + | CAST(ss_net_profit AS DECIMAL(7,2)) AS ss_net_profit, + | CAST(ss_item_sk AS BIGINT) AS ss_item_sk, + | CAST(ss_customer_sk AS BIGINT) AS ss_customer_sk, + | CAST(ss_ticket_number AS BIGINT) AS ss_ticket_number, + | CAST(ss_cdemo_sk AS BIGINT) AS ss_cdemo_sk, + | CAST(ss_promo_sk AS BIGINT) AS ss_promo_sk, + | CAST(ss_quantity AS DECIMAL(7,2)) AS ss_quantity, + | CAST(ss_list_price AS DECIMAL(7,2)) AS ss_list_price, + | CAST(ss_coupon_amt AS DECIMAL(7,2)) AS ss_coupon_amt, + | CAST(ss_sales_price AS DECIMAL(7,2)) AS ss_sales_price + |FROM VALUES + | (10, 1, 11.00, 3.00, 1001, 5001, 7001, 2001, 3001, 1.00, 10.00, 0.50, 9.50), + | (20, 2, 12.00, 4.00, 1002, 5002, 7002, 2002, 3002, 2.00, 20.00, 1.00, 19.00), + | (10, 3, 13.00, 5.00, 1001, 5001, 7003, 2001, 3001, 3.00, 30.00, 1.50, 28.50), + | (10, 200, 14.00, 2.00, 1001, 5001, 7101, 2001, 3001, 1.00, 11.00, 0.30, 10.70), + | (10, 201, 20.00, 4.00, 1001, 5001, 7102, 2001, 3001, 2.00, 12.00, 0.40, 11.60), + | (10, 202, 10.00, 1.00, 1001, 5001, 7103, 2001, 3001, 3.00, 13.00, 0.20, 12.80) + |AS t( + | ss_store_sk, ss_sold_date_sk, ss_ext_sales_price, ss_net_profit, + | ss_item_sk, ss_customer_sk, ss_ticket_number, ss_cdemo_sk, ss_promo_sk, + | ss_quantity, ss_list_price, ss_coupon_amt, ss_sales_price) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW item AS + |SELECT CAST(i_item_sk AS BIGINT) AS i_item_sk, + | CAST(i_item_id AS STRING) AS i_item_id, + | CAST(i_category AS STRING) AS i_category, + | CAST(i_brand AS STRING) AS i_brand, + | CAST(i_item_desc AS STRING) AS i_item_desc + |FROM VALUES + | (1001, 'I1001', 'Category 1', 'Brand 1', 'Item 1001 desc'), + | (1002, 'I1002', 'Category 2', 'Brand 2', 'Item 1002 desc') + |AS t(i_item_sk, i_item_id, i_category, i_brand, i_item_desc) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW promotion AS + |SELECT CAST(p_promo_sk AS BIGINT) AS p_promo_sk, + | CAST(p_channel_email AS STRING) AS p_channel_email, + | CAST(p_channel_event AS STRING) AS p_channel_event + |FROM VALUES + | (3001, 'N', 'Y'), + | (3002, 'Y', 'N') + |AS t(p_promo_sk, p_channel_email, p_channel_event) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW customer_demographics AS + |SELECT CAST(cd_demo_sk AS BIGINT) AS cd_demo_sk, + | CAST(cd_gender AS STRING) AS cd_gender, + | CAST(cd_marital_status AS STRING) AS cd_marital_status, + | CAST(cd_education_status AS STRING) AS cd_education_status + |FROM VALUES + | (2001, 'F', 'W', 'Primary'), + | (2002, 'F', 'W', 'Primary') + |AS t(cd_demo_sk, cd_gender, cd_marital_status, cd_education_status) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW store_returns AS + |SELECT CAST(sr_store_sk AS BIGINT) AS sr_store_sk, + | CAST(sr_returned_date_sk AS BIGINT) AS sr_returned_date_sk, + | CAST(sr_customer_sk AS BIGINT) AS sr_customer_sk, + | CAST(sr_item_sk AS BIGINT) AS sr_item_sk, + | CAST(sr_ticket_number AS BIGINT) AS sr_ticket_number, + | CAST(sr_return_quantity AS DECIMAL(7,2)) AS sr_return_quantity, + | CAST(sr_return_amt AS DECIMAL(7,2)) AS sr_return_amt, + | CAST(sr_net_loss AS DECIMAL(7,2)) AS sr_net_loss + |FROM VALUES + | (10, 1, 5001, 1001, 7001, 1.00, 1.00, 0.50), + | (20, 2, 5002, 1002, 7002, 1.00, 2.00, 0.25), + | (10, 3, 5001, 1001, 7003, 2.00, 0.50, 0.10) + |AS t( + | sr_store_sk, sr_returned_date_sk, sr_customer_sk, sr_item_sk, sr_ticket_number, + | sr_return_quantity, sr_return_amt, sr_net_loss) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW catalog_sales AS + |SELECT CAST(cs_catalog_page_sk AS BIGINT) AS cs_catalog_page_sk, + | CAST(cs_sold_date_sk AS BIGINT) AS cs_sold_date_sk, + | CAST(cs_ext_sales_price AS DECIMAL(7,2)) AS cs_ext_sales_price, + | CAST(cs_net_profit AS DECIMAL(7,2)) AS cs_net_profit, + | CAST(cs_order_number AS BIGINT) AS cs_order_number, + | CAST(cs_bill_customer_sk AS BIGINT) AS cs_bill_customer_sk, + | CAST(cs_item_sk AS BIGINT) AS cs_item_sk, + | CAST(cs_quantity AS DECIMAL(7,2)) AS cs_quantity, + | CAST(cs_ext_ship_cost AS DECIMAL(7,2)) AS cs_ext_ship_cost, + | CAST(cs_ship_date_sk AS BIGINT) AS cs_ship_date_sk, + | CAST(cs_ship_addr_sk AS BIGINT) AS cs_ship_addr_sk, + | CAST(cs_call_center_sk AS BIGINT) AS cs_call_center_sk, + | CAST(cs_warehouse_sk AS BIGINT) AS cs_warehouse_sk + |FROM VALUES + | (100, 1, 21.00, 5.00, 4001, 5001, 1001, 2.00, 2.00, 1, 9001, 7001, 1), + | (200, 2, 22.00, 6.00, 4002, 5002, 1002, 1.00, 3.00, 2, 9002, 7002, 1), + | (100, 10, 30.00, 2.00, 5001, 5001, 1001, 1.00, 5.00, 10, 9001, 7001, 1), + | (200, 10, 35.00, 1.50, 5001, 5001, 1001, 1.00, 3.00, 10, 9001, 7001, 2), + | (100, 10, 40.00, 1.00, 5002, 5002, 1002, 2.00, 4.00, 10, 9001, 7001, 1), + | (100, 1001, 50.00, 3.00, 6001, 5001, 1001, 3.00, 2.00, 1001, 9001, 7001, 1), + | (100, 1054, 70.00, 4.00, 6002, 5001, 1001, 4.00, 2.00, 1054, 9001, 7001, 1) + |AS t( + | cs_catalog_page_sk, cs_sold_date_sk, cs_ext_sales_price, cs_net_profit, + | cs_order_number, cs_bill_customer_sk, cs_item_sk, cs_quantity, cs_ext_ship_cost, + | cs_ship_date_sk, cs_ship_addr_sk, cs_call_center_sk, cs_warehouse_sk) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW catalog_returns AS + |SELECT CAST(cr_catalog_page_sk AS BIGINT) AS cr_catalog_page_sk, + | CAST(cr_returned_date_sk AS BIGINT) AS cr_returned_date_sk, + | CAST(cr_return_amount AS DECIMAL(7,2)) AS cr_return_amount, + | CAST(cr_net_loss AS DECIMAL(7,2)) AS cr_net_loss, + | CAST(cr_order_number AS BIGINT) AS cr_order_number + |FROM VALUES + | (100, 1, 1.50, 0.30, 3001), + | (200, 2, 0.50, 0.10, 3002), + | (100, 10, 1.00, 0.10, 5002) + |AS t( + | cr_catalog_page_sk, cr_returned_date_sk, cr_return_amount, cr_net_loss, + | cr_order_number) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW customer_address AS + |SELECT CAST(ca_address_sk AS BIGINT) AS ca_address_sk, + | CAST(ca_state AS STRING) AS ca_state + |FROM VALUES + | (9001, 'IL'), + | (9002, 'CA') + |AS t(ca_address_sk, ca_state) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW call_center AS + |SELECT CAST(cc_call_center_sk AS BIGINT) AS cc_call_center_sk, + | CAST(cc_county AS STRING) AS cc_county + |FROM VALUES + | (7001, 'Williamson County'), + | (7002, 'Other County') + |AS t(cc_call_center_sk, cc_county) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW web_sales AS + |SELECT CAST(ws_web_site_sk AS BIGINT) AS ws_web_site_sk, + | CAST(ws_sold_date_sk AS BIGINT) AS ws_sold_date_sk, + | CAST(ws_ship_date_sk AS BIGINT) AS ws_ship_date_sk, + | CAST(ws_warehouse_sk AS BIGINT) AS ws_warehouse_sk, + | CAST(ws_ship_mode_sk AS BIGINT) AS ws_ship_mode_sk, + | CAST(ws_ext_sales_price AS DECIMAL(7,2)) AS ws_ext_sales_price, + | CAST(ws_net_profit AS DECIMAL(7,2)) AS ws_net_profit, + | CAST(ws_item_sk AS BIGINT) AS ws_item_sk, + | CAST(ws_order_number AS BIGINT) AS ws_order_number + |FROM VALUES + | (1000, 1, 10, 1, 11, 31.00, 7.00, 500, 9000), + | (2000, 2, 11, 2, 22, 32.00, 8.00, 600, 9001), + | (1000, 1002, 1054, 1, 11, 40.00, 5.00, 700, 9100), + | (1000, 1055, 1055, 1, 22, 80.00, 6.00, 701, 9101) + |AS t( + | ws_web_site_sk, ws_sold_date_sk, ws_ship_date_sk, ws_warehouse_sk, + | ws_ship_mode_sk, ws_ext_sales_price, ws_net_profit, ws_item_sk, ws_order_number) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW web_returns AS + |SELECT CAST(wr_returned_date_sk AS BIGINT) AS wr_returned_date_sk, + | CAST(wr_return_amt AS DECIMAL(7,2)) AS wr_return_amt, + | CAST(wr_net_loss AS DECIMAL(7,2)) AS wr_net_loss, + | CAST(wr_item_sk AS BIGINT) AS wr_item_sk, + | CAST(wr_order_number AS BIGINT) AS wr_order_number + |FROM VALUES + | (1, 2.00, 0.60, 500, 9000), + | (2, 3.00, 0.40, 600, 9001) + |AS t(wr_returned_date_sk, wr_return_amt, wr_net_loss, wr_item_sk, wr_order_number) + |""".stripMargin) + } + + private def checkDf(df: DataFrame): Unit = { + assert(df.queryExecution.optimizedPlan.toString().contains("join_agg_wrapper_")) + val invalidPushedAggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate + if agg.aggregateExpressions.exists(_.exists { + case AggregateExpression(_, _, _, _, _) => true + case _ => false + }) && + agg.aggregateExpressions.exists(_.toString.contains("join_agg_wrapper_partial")) => + val offendingGroupingExprs = agg.groupingExpressions.filter { + groupingExpr => + groupingExpr.references.exists(attr => factMeasureColumnNames.contains(attr.name)) + } + if (offendingGroupingExprs.nonEmpty) { + Some((agg, offendingGroupingExprs)) + } else { + None + } + }.flatten + assert( + invalidPushedAggregates.isEmpty, + invalidPushedAggregates + .map { + case (agg, offendingGroupingExprs) => + s"Unexpected fact measure in pushed aggregate grouping: " + + s"${offendingGroupingExprs.mkString(", ")}\n${agg.treeString}" + } + .mkString("\n---\n") + ) + checkPlan(df) + } + + private def checkDfNoPush(df: DataFrame): Unit = { + assert(!df.queryExecution.optimizedPlan.toString().contains("join_agg_wrapper_")) + checkPlan(df) + } + + protected def checkPlan(df: DataFrame): Unit = { + checkGlutenPlan[HashAggregateExecTransformer](df) + } + + test("Join-aggregate wrapper aggregate") { + val query = + """ + |SELECT + | SUM(l_extendedprice * (1 - l_discount)) AS revenue + |FROM lineitem + |JOIN part + | ON l_partkey = p_partkey + |WHERE p_size > 10 AND l_shipmode IN ('AIR', 'RAIL') + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Join-aggregate wrapper aggregate 2") { + val query = + """ + |SELECT + | l_discount, p_partkey, AVG(l_extendedprice) + |FROM lineitem + |JOIN part + | ON l_partkey = p_partkey + |GROUP BY l_discount, p_partkey + |ORDER BY l_discount, p_partkey + |LIMIT 100 + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support join-aggregate wrapper aggregate for simplified TPC-H q12") { + val query = + """ + |SELECT + | l_shipmode, + | SUM(CASE + | WHEN o_orderpriority = '1-URGENT' OR o_orderpriority = '2-HIGH' + | THEN 1 + | ELSE 0 + | END) AS high_line_count, + | SUM(CASE + | WHEN o_orderpriority <> '1-URGENT' AND o_orderpriority <> '2-HIGH' + | THEN 1 + | ELSE 0 + | END) AS low_line_count + |FROM orders + |JOIN lineitem + | ON o_orderkey = l_orderkey + |WHERE l_shipmode IN ('MAIL', 'SHIP') + | AND l_commitdate < l_receiptdate + | AND l_shipdate < l_commitdate + | AND l_receiptdate >= DATE '1994-01-01' + | AND l_receiptdate < DATE '1995-01-01' + |GROUP BY l_shipmode + |ORDER BY l_shipmode + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support join-aggregate wrapper aggregate for multi-field agg buffer (avg)") { + val query = + """ + |SELECT + | c_nationkey, + | AVG(CAST(o_shippriority AS DOUBLE)) AS avg_shippriority + |FROM customer + |JOIN orders + | ON c_custkey = o_custkey + |WHERE c_mktsegment = 'BUILDING' + |GROUP BY c_nationkey + |ORDER BY c_nationkey + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support join-aggregate wrapper aggregate with duplicate decimal sum buffers") { + val query = + """ + |SELECT + | c_nationkey, + | SUM(o_totalprice) AS sum_totalprice_a, + | SUM(o_totalprice + CAST(1 AS DECIMAL(12, 2))) AS sum_totalprice_b, + | SUM(CAST(o_shippriority AS BIGINT)) AS sum_shippriority, + | SUM(CAST(o_orderkey AS BIGINT)) AS sum_orderkey + |FROM customer + |JOIN orders + | ON c_custkey = o_custkey + |WHERE c_mktsegment = 'BUILDING' + |GROUP BY c_nationkey + |ORDER BY c_nationkey + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Minimal pushdown for two different decimal sums") { + withTempView("ss_fact", "ss_dim") { + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ss_fact AS + |SELECT + | CAST(item_sk AS INT) AS item_sk, + | CAST(cp_catalog_page_id AS INT) AS grp_id, + | CAST(sales_price AS DECIMAL(12, 2)) AS dec_a, + | CAST(return_amt AS DECIMAL(12, 2)) AS dec_b + |FROM VALUES + | (1, 100, 10.00, 5.00), + | (1, 100, 20.00, 8.00), + | (2, 200, 30.00, 9.00) + |AS t(item_sk, cp_catalog_page_id, sales_price, return_amt) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ss_dim AS + |SELECT CAST(item_sk AS INT) AS item_sk + |FROM VALUES + | (1), (1), (2), (2), (2) + |AS t(item_sk) + |""".stripMargin) + + val query = + """ + |SELECT + | f.grp_id, + | SUM(f.dec_a) AS sum_dec_a, + | SUM(f.dec_b) AS sum_dec_b + |FROM ss_fact f + |JOIN ss_dim d + | ON f.item_sk = d.item_sk + |GROUP BY f.grp_id + |ORDER BY f.grp_id + |""".stripMargin + + withSQLConf( + "spark.sql.adaptive.enabled" -> "true", + "spark.sql.shuffle.partitions" -> "100", + "spark.sql.autoBroadcastJoinThreshold" -> "10m") { + runQueryAndCompare(query)(df => checkDf(df)) + } + } + } + + test("Minimal pushdown for decimal avg") { + withTempView("ss_fact_avg", "ss_dim_avg") { + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ss_fact_avg AS + |SELECT + | CAST(item_sk AS INT) AS item_sk, + | CAST(grp_id AS INT) AS grp_id, + | CAST(metric AS DECIMAL(7, 2)) AS metric + |FROM VALUES + | (1, 100, 10.00), + | (1, 100, 20.00), + | (2, 200, 30.00) + |AS t(item_sk, grp_id, metric) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ss_dim_avg AS + |SELECT CAST(item_sk AS INT) AS item_sk + |FROM VALUES + | (1), (1), (2), (2), (2) + |AS t(item_sk) + |""".stripMargin) + + val query = + """ + |SELECT + | f.grp_id, + | AVG(f.metric) AS avg_metric + |FROM ss_fact_avg f + |JOIN ss_dim_avg d + | ON f.item_sk = d.item_sk + |GROUP BY f.grp_id + |ORDER BY f.grp_id + |""".stripMargin + + withSQLConf( + "spark.sql.adaptive.enabled" -> "false", + "spark.sql.shuffle.partitions" -> "100", + "spark.sql.autoBroadcastJoinThreshold" -> "10m") { + runQueryAndCompare(query)(df => checkDf(df)) + } + } + } + + test("Support TPC-DS q5 shape") { + val query = + """ + |with ssr as + | (select s_store_id, + | sum(sales_price) as sales, + | sum(profit) as profit, + | sum(return_amt) as returns, + | sum(net_loss) as profit_loss + | from + | ( select ss_store_sk as store_sk, + | ss_sold_date_sk as date_sk, + | ss_ext_sales_price as sales_price, + | ss_net_profit as profit, + | cast(0 as decimal(7,2)) as return_amt, + | cast(0 as decimal(7,2)) as net_loss + | from store_sales + | union all + | select sr_store_sk as store_sk, + | sr_returned_date_sk as date_sk, + | cast(0 as decimal(7,2)) as sales_price, + | cast(0 as decimal(7,2)) as profit, + | sr_return_amt as return_amt, + | sr_net_loss as net_loss + | from store_returns + | ) salesreturns, + | date_dim, + | store + | where date_sk = d_date_sk + | and d_date between cast('1998-08-04' as date) + | and (cast('1998-08-04' as date) + interval '14' day) + | and store_sk = s_store_sk + | group by s_store_id) + | , + | csr as + | (select cp_catalog_page_id, + | sum(sales_price) as sales, + | sum(profit) as profit, + | sum(return_amt) as returns, + | sum(net_loss) as profit_loss + | from + | ( select cs_catalog_page_sk as page_sk, + | cs_sold_date_sk as date_sk, + | cs_ext_sales_price as sales_price, + | cs_net_profit as profit, + | cast(0 as decimal(7,2)) as return_amt, + | cast(0 as decimal(7,2)) as net_loss + | from catalog_sales + | union all + | select cr_catalog_page_sk as page_sk, + | cr_returned_date_sk as date_sk, + | cast(0 as decimal(7,2)) as sales_price, + | cast(0 as decimal(7,2)) as profit, + | cr_return_amount as return_amt, + | cr_net_loss as net_loss + | from catalog_returns + | ) salesreturns, + | date_dim, + | catalog_page + | where date_sk = d_date_sk + | and d_date between cast('1998-08-04' as date) + | and (cast('1998-08-04' as date) + interval '14' day) + | and page_sk = cp_catalog_page_sk + | group by cp_catalog_page_id) + | , + | wsr as + | (select web_site_id, + | sum(sales_price) as sales, + | sum(profit) as profit, + | sum(return_amt) as returns, + | sum(net_loss) as profit_loss + | from + | ( select ws_web_site_sk as wsr_web_site_sk, + | ws_sold_date_sk as date_sk, + | ws_ext_sales_price as sales_price, + | ws_net_profit as profit, + | cast(0 as decimal(7,2)) as return_amt, + | cast(0 as decimal(7,2)) as net_loss + | from web_sales + | union all + | select ws_web_site_sk as wsr_web_site_sk, + | wr_returned_date_sk as date_sk, + | cast(0 as decimal(7,2)) as sales_price, + | cast(0 as decimal(7,2)) as profit, + | wr_return_amt as return_amt, + | wr_net_loss as net_loss + | from web_returns left outer join web_sales on + | ( wr_item_sk = ws_item_sk + | and wr_order_number = ws_order_number) + | ) salesreturns, + | date_dim, + | web_site + | where date_sk = d_date_sk + | and d_date between cast('1998-08-04' as date) + | and (cast('1998-08-04' as date) + interval '14' day) + | and wsr_web_site_sk = web_site_sk + | group by web_site_id) + | select channel + | , id + | , sum(sales) as sales + | , sum(returns) as returns + | , sum(profit) as profit + | from + | (select 'store channel' as channel + | , 'store' || s_store_id as id + | , sales + | , returns + | , (profit - profit_loss) as profit + | from ssr + | union all + | select 'catalog channel' as channel + | , 'catalog_page' || cp_catalog_page_id as id + | , sales + | , returns + | , (profit - profit_loss) as profit + | from csr + | union all + | select 'web channel' as channel + | , 'web_site' || web_site_id as id + | , sales + | , returns + | , (profit - profit_loss) as profit + | from wsr + | ) x + | group by rollup (channel, id) + | order by channel + | ,id + | LIMIT 100 + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support simplified TPC-DS q7 shape") { + val query = + """ + |select i_item_id, + | avg(ss_quantity) agg1 + |from store_sales, item + |where ss_item_sk = i_item_sk + |group by i_item_id + |order by i_item_id + |limit 100 + |""".stripMargin + + withSQLConf( + "spark.sql.adaptive.enabled" -> "false", + "spark.sql.shuffle.partitions" -> "100", + "spark.sql.autoBroadcastJoinThreshold" -> "10m") { + runQueryAndCompare(query)(df => checkDf(df)) + } + } + + test("Support simplified TPC-DS q16 shape") { + val query = + """ + |select + | count(distinct cs_order_number) as `order count` + | ,sum(cs_ext_ship_cost) as `total shipping cost` + | ,sum(cs_net_profit) as `total net profit` + |from + | catalog_sales cs1 + | ,date_dim + | ,customer_address + | ,call_center + |where + | d_date between '1999-2-01' and + | (cast('1999-2-01' as date) + interval '60' day) + |and cs1.cs_ship_date_sk = d_date_sk + |and cs1.cs_ship_addr_sk = ca_address_sk + |and ca_state = 'IL' + |and cs1.cs_call_center_sk = cc_call_center_sk + |and cc_county in ('Williamson County','Williamson County','Williamson County', + | 'Williamson County','Williamson County') + |and exists (select * + | from catalog_sales cs2 + | where cs1.cs_order_number = cs2.cs_order_number + | and cs1.cs_warehouse_sk <> cs2.cs_warehouse_sk) + |and not exists(select * + | from catalog_returns cr1 + | where cs1.cs_order_number = cr1.cr_order_number) + |order by count(distinct cs_order_number) + |limit 100 + |""".stripMargin + + runQueryAndCompare(query) { + df => + // Mixed distinct + non-distinct aggregate shape is currently not pushed by the + // Join-aggregate pre-aggregate rule. + checkDfNoPush(df) + } + } + + test("Support TPC-DS q17 shape") { + val query = + """ + |select i_item_id + | ,i_item_desc + | ,s_state + | ,count(ss_quantity) as store_sales_quantitycount + | ,avg(ss_quantity) as store_sales_quantityave + | ,stddev_samp(ss_quantity) as store_sales_quantitystdev + | ,stddev_samp(ss_quantity)/avg(ss_quantity) as store_sales_quantitycov + | ,count(sr_return_quantity) as store_returns_quantitycount + | ,avg(sr_return_quantity) as store_returns_quantityave + | ,stddev_samp(sr_return_quantity) as store_returns_quantitystdev + | ,stddev_samp(sr_return_quantity)/avg(sr_return_quantity) + | as store_returns_quantitycov + | ,count(cs_quantity) as catalog_sales_quantitycount + | ,avg(cs_quantity) as catalog_sales_quantityave + | ,stddev_samp(cs_quantity) as catalog_sales_quantitystdev + | ,stddev_samp(cs_quantity)/avg(cs_quantity) as catalog_sales_quantitycov + |from store_sales + | ,store_returns + | ,catalog_sales + | ,date_dim d1 + | ,date_dim d2 + | ,date_dim d3 + | ,store + | ,item + |where d1.d_quarter_name = '1998Q1' + | and d1.d_date_sk = ss_sold_date_sk + | and i_item_sk = ss_item_sk + | and s_store_sk = ss_store_sk + | and ss_customer_sk = sr_customer_sk + | and ss_item_sk = sr_item_sk + | and ss_ticket_number = sr_ticket_number + | and sr_returned_date_sk = d2.d_date_sk + | and d2.d_quarter_name in ('1998Q1','1998Q2','1998Q3') + | and sr_customer_sk = cs_bill_customer_sk + | and sr_item_sk = cs_item_sk + | and cs_sold_date_sk = d3.d_date_sk + | and d3.d_quarter_name in ('1998Q1','1998Q2','1998Q3') + |group by i_item_id + | ,i_item_desc + | ,s_state + |order by i_item_id + | ,i_item_desc + | ,s_state + |limit 100 + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support SUM + COUNT DISTINCT shape") { + withTempView("ss_fact_distinct", "ss_dim_distinct") { + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ss_fact_distinct AS + |SELECT + | CAST(item_sk AS INT) AS item_sk, + | CAST(grp_id AS INT) AS grp_id, + | CAST(metric AS DECIMAL(7, 2)) AS metric, + | CAST(order_no AS BIGINT) AS order_no + |FROM VALUES + | (1, 100, 10.00, 9001), + | (1, 100, 20.00, 9002), + | (2, 200, 30.00, 9002), + | (2, 200, 15.00, 9003) + |AS t(item_sk, grp_id, metric, order_no) + |""".stripMargin) + + spark.sql(""" + |CREATE OR REPLACE TEMP VIEW ss_dim_distinct AS + |SELECT CAST(item_sk AS INT) AS item_sk + |FROM VALUES + | (1), (2) + |AS t(item_sk) + |""".stripMargin) + + val query = + """ + |SELECT + | f.grp_id, + | SUM(f.metric) AS sum_metric, + | COUNT(DISTINCT f.order_no) AS distinct_orders + |FROM ss_fact_distinct f + |JOIN ss_dim_distinct d + | ON f.item_sk = d.item_sk + |GROUP BY f.grp_id + |ORDER BY f.grp_id + |""".stripMargin + + runQueryAndCompare(query)(df => checkDfNoPush(df)) + } + } + + test("Support TPC-DS q2 shape") { + val query = + """ + |with wscs as + | (select sold_date_sk + | ,sales_price + | from (select ws_sold_date_sk sold_date_sk + | ,ws_ext_sales_price sales_price + | from web_sales + | union all + | select cs_sold_date_sk sold_date_sk + | ,cs_ext_sales_price sales_price + | from catalog_sales)), + | wswscs as + | (select d_week_seq, + | sum(case when (d_day_name='Sunday') then sales_price else null end) sun_sales, + | sum(case when (d_day_name='Monday') then sales_price else null end) mon_sales, + | sum(case when (d_day_name='Tuesday') then sales_price else null end) tue_sales, + | sum(case when (d_day_name='Wednesday') then sales_price else null end) wed_sales, + | sum(case when (d_day_name='Thursday') then sales_price else null end) thu_sales, + | sum(case when (d_day_name='Friday') then sales_price else null end) fri_sales, + | sum(case when (d_day_name='Saturday') then sales_price else null end) sat_sales + | from wscs + | ,date_dim + | where d_date_sk = sold_date_sk + | group by d_week_seq) + | select d_week_seq1 + | ,round(sun_sales1/sun_sales2,2) + | ,round(mon_sales1/mon_sales2,2) + | ,round(tue_sales1/tue_sales2,2) + | ,round(wed_sales1/wed_sales2,2) + | ,round(thu_sales1/thu_sales2,2) + | ,round(fri_sales1/fri_sales2,2) + | ,round(sat_sales1/sat_sales2,2) + | from + | (select wswscs.d_week_seq d_week_seq1 + | ,sun_sales sun_sales1 + | ,mon_sales mon_sales1 + | ,tue_sales tue_sales1 + | ,wed_sales wed_sales1 + | ,thu_sales thu_sales1 + | ,fri_sales fri_sales1 + | ,sat_sales sat_sales1 + | from wswscs,date_dim + | where date_dim.d_week_seq = wswscs.d_week_seq and + | d_year = 2001) y, + | (select wswscs.d_week_seq d_week_seq2 + | ,sun_sales sun_sales2 + | ,mon_sales mon_sales2 + | ,tue_sales tue_sales2 + | ,wed_sales wed_sales2 + | ,thu_sales thu_sales2 + | ,fri_sales fri_sales2 + | ,sat_sales sat_sales2 + | from wswscs + | ,date_dim + | where date_dim.d_week_seq = wswscs.d_week_seq and + | d_year = 2001+1) z + | where d_week_seq1=d_week_seq2-53 + | order by d_week_seq1 + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support TPC-DS q62 shape") { + val query = + """ + |select + | substr(w_warehouse_name,1,20) + | ,sm_type + | ,web_name + | ,sum(case when (ws_ship_date_sk - ws_sold_date_sk <= 30 ) then 1 else 0 end) + | as `30 days` + | ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 30) and + | (ws_ship_date_sk - ws_sold_date_sk <= 60) then 1 else 0 end) + | as `31-60 days` + | ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 60) and + | (ws_ship_date_sk - ws_sold_date_sk <= 90) then 1 else 0 end) + | as `61-90 days` + | ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 90) and + | (ws_ship_date_sk - ws_sold_date_sk <= 120) then 1 else 0 end) + | as `91-120 days` + | ,sum(case when (ws_ship_date_sk - ws_sold_date_sk > 120) then 1 else 0 end) + | as `>120 days` + |from + | web_sales + | ,warehouse + | ,ship_mode + | ,web_site + | ,date_dim + |where + | d_month_seq between 1212 and 1212 + 11 + |and ws_ship_date_sk = d_date_sk + |and ws_warehouse_sk = w_warehouse_sk + |and ws_ship_mode_sk = sm_ship_mode_sk + |and ws_web_site_sk = web_site_sk + |group by + | substr(w_warehouse_name,1,20) + | ,sm_type + | ,web_name + |order by substr(w_warehouse_name,1,20) + | ,sm_type + | ,web_name + |limit 100 + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } + + test("Support join-aggregate wrapper aggregate for simplified TPC-H q18") { + val query = + """ + |SELECT + | c_name, + | c_custkey, + | o_orderkey, + | o_orderdate, + | o_totalprice, + | SUM(l_quantity) + |FROM customer, orders, lineitem + |WHERE o_orderkey IN ( + | SELECT + | l_orderkey + | FROM lineitem + | GROUP BY l_orderkey + | HAVING SUM(l_quantity) > 300 + |) + | AND c_custkey = o_custkey + | AND o_orderkey = l_orderkey + |GROUP BY + | c_name, + | c_custkey, + | o_orderkey, + | o_orderdate, + | o_totalprice + |ORDER BY + | o_totalprice DESC, + | o_orderdate + |LIMIT 100 + |""".stripMargin + + runQueryAndCompare(query)(df => checkDf(df)) + } +} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala index 12807448c7e0..3203a9127635 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxTPCHSuite.scala @@ -353,6 +353,17 @@ class VeloxTPCHV1BhjOffheapSuite extends VeloxTPCHSuite { } } +class VeloxTPCHV1BhjPushAggThroughJoinSuite extends VeloxTPCHSuite { + override def subType(): String = "v1-bhj-push-agg-through-join" + + override protected def sparkConf: SparkConf = { + super.sparkConf + .set("spark.sql.sources.useV1SourceList", "parquet") + .set("spark.sql.autoBroadcastJoinThreshold", "30M") + .set(GlutenConfig.PUSH_AGGREGATE_THROUGH_JOIN_ENABLED.key, "true") + } +} + class VeloxTPCHV2Suite extends VeloxTPCHSuite { override def subType(): String = "v2" diff --git a/docs/Configuration.md b/docs/Configuration.md index c4a05709a318..4a3ecf2ff217 100644 --- a/docs/Configuration.md +++ b/docs/Configuration.md @@ -130,6 +130,8 @@ nav_order: 15 | spark.gluten.sql.native.writeColumnMetadataExclusionList | comment | Native write files does not support column metadata. Metadata in list would be removed to support native write files. Multiple values separated by commas. | | spark.gluten.sql.native.writer.enabled | <undefined> | This is config to specify whether to enable the native columnar parquet/orc writer | | spark.gluten.sql.orc.charType.scan.fallback.enabled | true | Force fallback for orc char type scan. | +| spark.gluten.sql.pushAggregateThroughJoin.enabled | false | Enables the push-aggregate-through-join optimization in Gluten. When enabled, aggregate operators may be pushed below joins during logical optimization and corresponding physical plans may be rewritten to execute the aggregation earlier. | +| spark.gluten.sql.pushAggregateThroughJoin.maxDepth | 2147483647 | Maximum join traversal depth when applying the push-aggregate-through-join optimization. A value of 1 allows pushing an aggregate through a single join; larger values allow the rule to traverse and push through multiple consecutive joins. | | spark.gluten.sql.removeNativeWriteFilesSortAndProject | true | When true, Gluten will remove the vanilla Spark V1Writes added sort and project for velox backend. | | spark.gluten.sql.rewrite.dateTimestampComparison | true | Rewrite the comparision between date and timestamp to timestamp comparison.For example `from_unixtime(ts) > date` will be rewritten to `ts > to_unixtime(date)` | | spark.gluten.sql.scan.fileSchemeValidation.enabled | true | When true, enable file path scheme validation for scan. Validation will fail if file scheme is not supported by registered file systems, which will cause scan operator fall back. | diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala index 198fc3f025d7..952d643a6fa3 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/config/GlutenConfig.scala @@ -151,6 +151,12 @@ class GlutenConfig(conf: SQLConf) extends GlutenCoreConfig(conf) { def enableExtendedColumnPruning: Boolean = getConf(ENABLE_EXTENDED_COLUMN_PRUNING) + def pushAggregateThroughJoinEnabled: Boolean = + getConf(PUSH_AGGREGATE_THROUGH_JOIN_ENABLED) + + def pushAggregateThroughJoinMaxDepth: Int = + getConf(PUSH_AGGREGATE_THROUGH_JOIN_MAX_DEPTH) + def forceOrcCharTypeScanFallbackEnabled: Boolean = getConf(VELOX_FORCE_ORC_CHAR_TYPE_SCAN_FALLBACK) @@ -706,6 +712,31 @@ object GlutenConfig extends ConfigRegistry { .stringConf .createWithDefault("and,or"); + val PUSH_AGGREGATE_THROUGH_JOIN_ENABLED = + buildConf("spark.gluten.sql.pushAggregateThroughJoin.enabled") + .doc( + "Enables the push-aggregate-through-join optimization in Gluten. " + + "When enabled, aggregate operators may be pushed below joins " + + "during logical optimization " + + "and corresponding physical plans may be rewritten to execute " + + "the aggregation earlier." + ) + .booleanConf + .createWithDefault(false) + + val PUSH_AGGREGATE_THROUGH_JOIN_MAX_DEPTH = + buildConf("spark.gluten.sql.pushAggregateThroughJoin.maxDepth") + .doc( + "Maximum join traversal depth when applying the push-aggregate-through-join " + + "optimization. " + + "A value of 1 allows pushing an aggregate through a single join; larger " + + "values allow the rule " + + "to traverse and push through multiple consecutive joins." + ) + .intConf + .checkValue(_ >= 1, "must be greater than or equal to 1.") + .createWithDefault(Int.MaxValue) + val GLUTEN_SOFT_AFFINITY_ENABLED = buildConf("spark.gluten.soft-affinity.enabled") .doc("Whether to enable Soft Affinity scheduling.") diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/ProjectColumnPruning.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/ProjectColumnPruning.scala index 258edae1993c..da5be69e5417 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/ProjectColumnPruning.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/rewrite/ProjectColumnPruning.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.extension.columnar.rewrite +import org.apache.spark.sql.catalyst.expressions.AttributeSet import org.apache.spark.sql.execution.{ProjectExec, SparkPlan, UnaryExecNode} /** @@ -30,10 +31,16 @@ object ProjectColumnPruning extends RewriteSingleNode { } } + private def getReferences(plan: SparkPlan): AttributeSet = { + // SPARK-55979 - aggregate.references is unreliable. + AttributeSet(plan.expressions) -- (plan.producedAttributes -- plan.children.flatMap( + _.outputSet)) + } + override def rewrite(plan: SparkPlan): SparkPlan = plan match { case parent: UnaryExecNode if parent.child.isInstanceOf[ProjectExec] => val project = parent.child.asInstanceOf[ProjectExec] - val unusedAttribute = project.outputSet -- (parent.references ++ parent.outputSet) + val unusedAttribute = project.outputSet -- (getReferences(parent) ++ parent.outputSet) if (unusedAttribute.nonEmpty) { val newProject = project.copy(projectList = project.projectList.diff(unusedAttribute.toSeq)) diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/ImplementJoinAggregate.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/ImplementJoinAggregate.scala new file mode 100644 index 000000000000..bc01889eedee --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/ImplementJoinAggregate.scala @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.joinagg + +import org.apache.gluten.config.GlutenConfig + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Alias, GetStructField, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.PhysicalAggregation +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan, SparkStrategy} +import org.apache.spark.sql.execution.aggregate.HashAggregateExec + +import scala.collection.mutable.ArrayBuffer + +case class ImplementJoinAggregate(spark: SparkSession) extends SparkStrategy { + /* + * This strategy lowers the logical join-aggregate wrapper shape produced by + * `PushAggregateThroughJoin` into ordinary Spark physical hash aggregates. + * + * The optimizer builds two wrapper phases: + * - pushed phase: wrapper partial below / through joins + * - final phase: wrapper final above the pushed phase + * + * Spark's physical aggregate operators still expect normal aggregate functions and normal input + * / buffer attributes. This strategy therefore: + * 1. rewrites wrapper aggregates back to the wrapped Spark aggregate with the correct physical + * aggregate mode; + * 2. inserts a post-project for the pushed phase to pack Spark aggregate buffers into a + * struct, matching the wrapper's logical output type; + * 3. inserts a pre-project for the final phase to unpack that struct back into the buffer + * attributes expected by the wrapped Spark aggregate. + * + * In short: the wrapper exists only at the logical boundary; this strategy makes the plan look + * like an ordinary Spark aggregate plan again while preserving the wrapper data contract across + * the join. + */ + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = { + if (!GlutenConfig.get.pushAggregateThroughJoinEnabled) { + return Nil + } + + plan match { + case agg: Aggregate if containsWrapperAggregate(agg) => + planJoinAggregate(agg).toSeq + case _ => + Nil + } + } + + private def planJoinAggregate(agg: Aggregate): Option[SparkPlan] = agg match { + case PhysicalAggregation(groupingExpressions, aes, resultExpressions, child) => + // KEEP: For compatibility with Spark version before 3.5, which widens + // the aggregate expressions + // to `Seq[NamedExpression]` instead of `Seq[AggregateExpression]`. + val aggExpressions = aes.map(_.asInstanceOf[AggregateExpression]) + // A single logical aggregate must lower either entirely as pushed-phase wrappers or entirely + // as final-phase wrappers. Mixed-phase aggregates are rejected here. + val grouping = groupingExpressions.collect { case ne: NamedExpression => ne } + if (grouping.size != groupingExpressions.size) { + return None + } + + val wrappers = aggExpressions.flatMap { + case ae @ AggregateExpression(wrapper: JoinAggregateFunctionWrapper, _, _, _, _) => + Some((ae, wrapper)) + case _ => + None + } + if (wrappers.isEmpty) { + return None + } + + val phases = wrappers.map(_._2.targetPhase).distinct + if (phases.size != 1) { + return None + } + val phase = phases.head + + val childPlan = planLater(child) + phase match { + case JoinAggregateFunctionWrapper.PartialPhase => + planPartialPhase(grouping, aggExpressions, resultExpressions, childPlan) + case JoinAggregateFunctionWrapper.FinalPhase => + planFinalPhase(grouping, aggExpressions, resultExpressions, childPlan) + } + case _ => + None + } + + private def planPartialPhase( + grouping: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + childPlan: SparkPlan): Option[SparkPlan] = { + // The pushed logical aggregate exposes one wrapper-typed output per pushed aggregate. Spark + // physically computes ordinary aggregate buffers, so this phase runs a normal HashAggregateExec + // first and then repacks those buffers into the struct-valued wrapper outputs expected by the + // logical plan above. + val rewrittenAggExprs = aggregateExpressions.map { + case ae @ AggregateExpression(_: JoinAggregateFunctionWrapper, _, _, _, _) => + rewriteSingleAggregateExpression(ae) + case ae => + ae + } + if (rewrittenAggExprs.isEmpty) { + return None + } + + val hashAgg = HashAggregateExec( + requiredChildDistributionExpressions = None, + isStreaming = false, + numShufflePartitions = None, + groupingExpressions = grouping, + aggregateExpressions = rewrittenAggExprs, + aggregateAttributes = rewrittenAggExprs.flatMap(_.aggregateFunction.aggBufferAttributes), + initialInputBufferOffset = 0, + resultExpressions = grouping.map(_.toAttribute) ++ rewrittenAggExprs.flatMap( + _.aggregateFunction.aggBufferAttributes), + child = childPlan + ) + + val rewrittenByOriginalResultId = rewrittenAggExprs.map(ae => ae.resultId -> ae).toMap + + val packedByOriginalResultId = aggregateExpressions.map { + originalAe => + // Each pushed wrapper output corresponds to one rewritten Spark aggregate. Repack the + // physical buffer attrs into a struct so the parent plan still sees the wrapper contract + // created by PushAggregateThroughJoin. + val rewrittenAe = rewrittenByOriginalResultId.getOrElse( + originalAe.resultId, + throw new IllegalStateException( + s"Cannot resolve pushed aggregate output for ${originalAe.sql}")) + originalAe.resultId -> org.apache.spark.sql.catalyst.expressions + .CreateStruct(rewrittenAe.aggregateFunction.aggBufferAttributes) + }.toMap + + def packedResultForAttr(attr: org.apache.spark.sql.catalyst.expressions.AttributeReference) + : Option[NamedExpression] = { + aggregateExpressions + .find(_.resultAttribute.exprId == attr.exprId) + .flatMap(ae => packedByOriginalResultId.get(ae.resultId)) + .map { + packed => Alias(packed, attr.name)(exprId = attr.exprId, qualifier = attr.qualifier) + } + } + + val rewrittenResultExpressions = resultExpressions.map { + case attr: org.apache.spark.sql.catalyst.expressions.AttributeReference => + packedResultForAttr(attr).getOrElse(attr) + case alias: Alias => + alias + .transformUp { + case ae: AggregateExpression => + packedByOriginalResultId.getOrElse( + ae.resultId, + throw new IllegalStateException( + s"Cannot resolve pushed aggregate output for ${ae.sql}")) + case attr: org.apache.spark.sql.catalyst.expressions.AttributeReference => + aggregateExpressions + .find(_.resultAttribute.exprId == attr.exprId) + .flatMap(ae => packedByOriginalResultId.get(ae.resultId)) + .getOrElse(attr) + } + .asInstanceOf[NamedExpression] + case other => + other + .transformUp { + case ae: AggregateExpression => + packedByOriginalResultId.getOrElse( + ae.resultId, + throw new IllegalStateException( + s"Cannot resolve pushed aggregate output for ${ae.sql}")) + } + .asInstanceOf[NamedExpression] + } + + Some(ProjectExec(rewrittenResultExpressions, hashAgg)) + } + + private def planFinalPhase( + grouping: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + childPlan: SparkPlan): Option[SparkPlan] = { + // Lower the final wrapper phase by first unpacking the wrapper struct into the wrapped + // aggregate's input buffer attributes, then running a normal Spark final / merge aggregate. + val wrapperWithRewritten: Seq[(JoinAggregateFunctionWrapper, AggregateExpression)] = + aggregateExpressions.flatMap { + case originalAe @ AggregateExpression(wrapper: JoinAggregateFunctionWrapper, _, _, _, _) => + Some((wrapper, rewriteSingleAggregateExpression(originalAe))) + case _ => + None + } + + val unpackAliases = ArrayBuffer.empty[Alias] + val seenExprIds = scala.collection.mutable.HashSet.empty[Long] + wrapperWithRewritten.foreach { + case (wrapper, rewrittenAe) => + val bufferExpr = wrapper.children.head + rewrittenAe.aggregateFunction.inputAggBufferAttributes.zipWithIndex.foreach { + case (bufferAttr, idx) if seenExprIds.add(bufferAttr.exprId.id) => + // Keep exprId for binding correctness, but avoid dotted names (e.g. a.b) in the + // temporary unpack projection. This projection only recreates the physical buffer attrs + // that Spark's final / merge aggregate expects to read from the wrapper struct. + val safeName = s"_joinagg_buf_${bufferAttr.exprId.id}_$idx" + unpackAliases += Alias( + GetStructField(bufferExpr, idx, Some(bufferAttr.name)), + safeName + )(exprId = bufferAttr.exprId, qualifier = bufferAttr.qualifier) + case _ => + } + } + + val childWithUnpacked = if (unpackAliases.nonEmpty) { + ProjectExec(childPlan.output ++ unpackAliases, childPlan) + } else { + childPlan + } + + val rewrittenAggExprs = aggregateExpressions.map { + case ae @ AggregateExpression(_: JoinAggregateFunctionWrapper, _, _, _, _) => + rewriteSingleAggregateExpression(ae) + case ae => + ae + } + if (rewrittenAggExprs.isEmpty) { + return None + } + val aggregateAttrs = rewrittenAggExprs.map(_.resultAttribute) + val rewrittenResultExpressions = + rewriteResultAsAggregateAttributes(resultExpressions, rewrittenAggExprs) + + Some( + HashAggregateExec( + requiredChildDistributionExpressions = Some(grouping.map(_.toAttribute)), + isStreaming = false, + numShufflePartitions = None, + groupingExpressions = grouping, + aggregateExpressions = rewrittenAggExprs, + aggregateAttributes = aggregateAttrs, + initialInputBufferOffset = 0, + resultExpressions = rewrittenResultExpressions, + child = childWithUnpacked + )) + } + + private def containsWrapperAggregate(agg: Aggregate): Boolean = { + agg.aggregateExpressions.exists { + _.exists { + case AggregateExpression(wrapper: JoinAggregateFunctionWrapper, _, _, _, _) => true + case _ => false + } + } + } + + private def rewriteSingleAggregateExpression( + original: AggregateExpression): AggregateExpression = { + original + .transformUp { + case ae @ AggregateExpression(wrapper: JoinAggregateFunctionWrapper, _, _, _, _) => + // The wrapper carries the logical phase; `semanticMode` turns + // that phase together with the current Spark aggregate mode into + // the actual wrapped aggregate mode required by this physical aggregate node. + val mode = JoinAggregateFunctionWrapper.semanticMode(ae.mode, wrapper.targetPhase) + ae.copy(aggregateFunction = wrapper.innerAgg, mode = mode) + } + .asInstanceOf[AggregateExpression] + } + + private def rewriteResultAsAggregateAttributes( + rewrittenOutput: Seq[NamedExpression], + rewrittenAggExprs: Seq[AggregateExpression]): Seq[NamedExpression] = { + // After the HashAggregateExec is built, rewrite the original output tree so every aggregate + // expression points at the corresponding physical aggregate result attribute. + // Some results already reference the wrapper aggregate's resultAttribute directly instead of + // carrying the AggregateExpression node, so rewrite those attributes as well to avoid leaking + // wrapper names into the final physical plan. + val rewrittenByResultId = rewrittenAggExprs.map(ae => ae.resultId -> ae.resultAttribute).toMap + rewrittenOutput.map { + _.transformUp { + case ae: AggregateExpression => + rewrittenByResultId + .get(ae.resultId) + .getOrElse( + throw new IllegalStateException(s"Cannot resolve aggregate attribute for ${ae.sql}")) + case attr: org.apache.spark.sql.catalyst.expressions.AttributeReference => + rewrittenByResultId.getOrElse(attr.exprId, attr) + }.asInstanceOf[NamedExpression] + } + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/JoinAggregateFunctionWrapper.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/JoinAggregateFunctionWrapper.scala new file mode 100644 index 000000000000..1474bdf2ba8f --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/JoinAggregateFunctionWrapper.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.joinagg + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, CreateStruct, Expression, GetStructField, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.DataType + +import java.util.Locale + +import scala.collection.mutable + +object JoinAggregateFunctionWrapper { + // The wrapper is used in exactly two logical phases: + // - PartialPhase: a pushed aggregate below / through joins + // - FinalPhase: the aggregate above the join that restores the original query semantics + sealed trait TargetPhase { + def sqlName: String + } + + case object PartialPhase extends TargetPhase { + override val sqlName: String = "PARTIAL" + } + + case object FinalPhase extends TargetPhase { + override val sqlName: String = "FINAL" + } + + def wrapperPartial( + innerAgg: DeclarativeAggregate, + wrapperKey: String = "0"): JoinAggregateFunctionWrapper = { + JoinAggregateFunctionWrapper( + innerAgg = innerAgg, + targetPhase = PartialPhase, + inputBuffer = None, + wrapperKey = wrapperKey) + } + + def wrapperFinal( + innerAgg: DeclarativeAggregate, + inputBuffer: Expression, + wrapperKey: String = "0"): JoinAggregateFunctionWrapper = { + JoinAggregateFunctionWrapper( + innerAgg = innerAgg, + targetPhase = FinalPhase, + inputBuffer = Some(inputBuffer), + wrapperKey = wrapperKey) + } + + // Translate Spark's physical aggregate mode together with the wrapper semantic phase into the + // actual Spark aggregate mode that should be used for the wrapped aggregate. For example, a + // FinalPhase wrapper running in Spark's Partial mode is really doing a partial merge of the + // pushed aggregate buffers. + def semanticMode(actualMode: AggregateMode, targetPhase: TargetPhase): AggregateMode = { + // The wrapper phase says what the logical plan wants to do with the aggregate state across + // the join. `actualMode` says which physical aggregate stage Spark is currently building. + // The combination determines the real wrapped aggregate mode that should run physically. + (actualMode, targetPhase) match { + case (Partial, PartialPhase) => Partial + case (PartialMerge, PartialPhase) => PartialMerge + case (Final, PartialPhase) => PartialMerge + case (Complete, PartialPhase) => Partial + case (Partial, FinalPhase) => PartialMerge + case (PartialMerge, FinalPhase) => PartialMerge + case (Final, FinalPhase) => Final + case (Complete, FinalPhase) => Final + case _ => + throw new UnsupportedOperationException( + s"Unsupported wrapper semantic mode mapping: actualMode=$actualMode, " + + s"targetPhase=$targetPhase") + } + } +} + +case class JoinAggregateFunctionWrapper( + innerAgg: DeclarativeAggregate, + targetPhase: JoinAggregateFunctionWrapper.TargetPhase, + inputBuffer: Option[Expression], + wrapperKey: String = "0") + extends DeclarativeAggregate { + /* + * Logical wrapper around a Spark declarative aggregate used by the join-aggregate rewrite. + * + * Why the wrapper exists: + * - Below the join, we want to aggregate early and carry the aggregate buffer through the + * join as a single value. + * - Above the join, we want to merge / evaluate that buffer and recover the original + * aggregate result. + * + * The wrapper therefore changes only the *logical contract* across the join: + * - PartialPhase exposes the wrapped aggregate buffer as a single struct-valued output. + * - FinalPhase consumes that struct-valued buffer and delegates merge / evaluate semantics + * back to the wrapped Spark aggregate. + * + * `ImplementJoinAggregate` later lowers this wrapper back into normal Spark physical aggregates + * by packing / unpacking the struct around the aggregate buffers. + */ + import JoinAggregateFunctionWrapper._ + + private val wrappedBufferAttrs: Seq[AttributeReference] = + innerAgg.aggBufferAttributes.zipWithIndex.map { + case (attr, index) => + AttributeReference(attr.name, attr.dataType, attr.nullable)() + } + + private def outputBufferExpr: Expression = + inputBuffer.getOrElse(CreateStruct(innerAgg.inputAggBufferAttributes)) + + override lazy val nullable: Boolean = true + + override lazy val dataType: DataType = targetPhase match { + case PartialPhase => + // The pushed phase carries the aggregate buffer through the plan as a single struct-valued + // payload so the join sees one logical column per pushed aggregate. + CreateStruct(wrappedBufferAttrs).dataType + case FinalPhase => + innerAgg.dataType + } + + override def children: Seq[Expression] = targetPhase match { + case PartialPhase => innerAgg.children + case FinalPhase => Seq(outputBufferExpr) + } + + override lazy val aggBufferAttributes: Seq[AttributeReference] = wrappedBufferAttrs + + override lazy val initialValues: Seq[Expression] = { + rewrite(innerAgg.initialValues, childReplacements = Map.empty, useInputBufferField = false) + } + + override lazy val updateExpressions: Seq[Expression] = targetPhase match { + case PartialPhase => + // Update the wrapped aggregate buffer from the original aggregate children. + rewrite( + innerAgg.updateExpressions, + childReplacements = innerAgg.children.zip(children).toMap, + useInputBufferField = false + ) + case FinalPhase => + // Merge expressions read from the struct-valued input buffer produced by the pushed phase. + rewrite(innerAgg.mergeExpressions, childReplacements = Map.empty, useInputBufferField = true) + } + + override lazy val mergeExpressions: Seq[Expression] = { + rewrite(innerAgg.mergeExpressions, childReplacements = Map.empty, useInputBufferField = false) + } + + override lazy val evaluateExpression: Expression = targetPhase match { + case PartialPhase => + // The pushed phase returns the entire aggregate buffer, not the final aggregate value. + CreateStruct(aggBufferAttributes) + case FinalPhase => + rewrite( + innerAgg.evaluateExpression, + childReplacements = Map.empty, + useInputBufferField = false) + } + + override def nodeName: String = "JoinAggregateWrapper" + + override def prettyName: String = + s"join_agg_wrapper_${targetPhase.sqlName.toLowerCase(Locale.ROOT)}" + + override def sql: String = { + s"$prettyName(${innerAgg.sql(false)})" + } + + override lazy val deterministic: Boolean = innerAgg.deterministic + + override lazy val defaultResult: Option[Literal] = targetPhase match { + case PartialPhase => None + case FinalPhase => innerAgg.defaultResult + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = { + targetPhase match { + case PartialPhase => + val newInner = innerAgg.withNewChildren(newChildren).asInstanceOf[DeclarativeAggregate] + copy(innerAgg = newInner, inputBuffer = None) + case FinalPhase => + if (newChildren.size != 1) { + throw new IllegalArgumentException( + s"Final JoinAggregateWrapper expects exactly one child, got ${newChildren.size}") + } + copy(inputBuffer = Some(newChildren.head)) + } + } + + private def rewrite( + exprs: Seq[Expression], + childReplacements: Map[Expression, Expression], + useInputBufferField: Boolean): Seq[Expression] = { + exprs.map(rewrite(_, childReplacements, useInputBufferField)) + } + + private def rewrite( + expr: Expression, + childReplacements: Map[Expression, Expression], + useInputBufferField: Boolean): Expression = { + // Rebind every attribute reference in the wrapped aggregate expression tree to the + // corresponding wrapper-side attribute: + // - wrapped buffer attrs when we are updating / evaluating the wrapper buffer + // - struct-field reads when FinalPhase is consuming the pushed buffer payload + // - original aggregate children when PartialPhase still reads the raw input rows + val innerToWrappedBuffer = innerAgg.aggBufferAttributes.zip(aggBufferAttributes) + val innerToInputBuffer = innerAgg.inputAggBufferAttributes.zipWithIndex.map { + case (attr, index) => + if (useInputBufferField) { + attr -> GetStructField(outputBufferExpr, index, Some(attr.name)) + } else { + attr -> inputAggBufferAttributes(index) + } + } + val attrRewriteMap = mutable.ArrayBuffer.empty[(Attribute, Expression)] + attrRewriteMap ++= innerToWrappedBuffer + attrRewriteMap ++= innerToInputBuffer + val childRewriteSeq = childReplacements.toSeq + + // Rewrite in two passes: + // 1. swap original aggregate child expressions when PartialPhase still reads raw input rows; + // 2. rebind buffer/input attributes to wrapper-side attrs or struct-field reads. + // + // Using semantic equality here is important because the wrapped aggregate expression tree is + // copied and re-bound several times while the rule and strategy move between logical and + // physical representations. + childRewriteSeq + .foldLeft(expr) { + case (curExpr, (from, to)) => + curExpr.transformUp { + case e if e.semanticEquals(from) => to + } + } + .transformUp { + case a: Attribute => + attrRewriteMap + .collectFirst { + case (from, to) if a.semanticEquals(from) => to + } + .getOrElse(a) + } + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/PushAggregateThroughJoin.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/PushAggregateThroughJoin.scala new file mode 100644 index 000000000000..bb115de796d2 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/joinagg/PushAggregateThroughJoin.scala @@ -0,0 +1,598 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.joinagg + +import org.apache.gluten.config.GlutenConfig + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, EqualTo, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate} +import org.apache.spark.sql.catalyst.optimizer.DecimalAggregates +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule + +case class PushAggregateThroughJoin(spark: SparkSession) + extends Rule[LogicalPlan] + with PredicateHelper { + /* + * This rule rewrites: + * + * Aggregate(..., Join(...)) + * + * into a pair of wrapper aggregates: + * + * FinalWrapperAggregate(..., PushedWrapperAggregate(..., Join(...))) + * + * and then repeatedly pushes the wrapper aggregate through inner equi-joins. The pushed + * aggregate always represents the "pre-join" aggregation work, while the final wrapper above + * the join preserves the original query semantics. + * + * The key invariants are: + * 1. Only declarative, non-distinct aggregates are rewritten. + * 2. Pushability is decided per aggregate subexpression, not per whole result expression. + * This is what lets expressions such as `sum(a) - sum(b)` push both aggregate terms while + * keeping the arithmetic shape above the join unchanged. + * 3. Pushed grouping keys must preserve the semantics of all expressions that still remain + * above the pushed aggregate, including derived grouping expressions coming from extracted + * Project / Filter nodes. + * 4. Aggregate measure inputs needed to evaluate the pushed aggregate must be preserved in the + * rebuilt subtree, but they must not be promoted into grouping keys. Promoting those + * inputs would change the pre-aggregation grain and make the rewrite semantically wrong. + * 5. The rule intentionally pushes one join edge at a time; repeated application performs + * deep pushdown for multi-join shapes. + */ + + private case class SidePartialSpec( + originalExpr: Expression, + aggregate: DeclarativeAggregate, + wrapperKey: String) + + private case class SidePartialRef(spec: SidePartialSpec, attr: AttributeReference) + + private var successfulSplitCount: Int = 0 + private var successfulPushCount: Int = 0 + + def resetSuccessfulSplitCount(): Unit = { + successfulSplitCount = 0 + } + + def resetSuccessfulPushCount(): Unit = { + successfulPushCount = 0 + } + + def getSuccessfulSplitCount: Int = successfulSplitCount + + def getSuccessfulPushCount: Int = successfulPushCount + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!isEnabled) { + return plan + } + var splitCount = 0 + val newPlan = plan.transformUp { + case agg: Aggregate + if !containsWrapperAggregateInOutput(agg.aggregateExpressions) && + hasPushableAggExpr(agg.aggregateExpressions) && + !hasDistinctAggExpr(agg.aggregateExpressions) => + // 1) Aggregate+Join => FinalWrapperAgg(PartialWrapperAgg(...Join...)) + splitAggregate(agg) match { + case Some(newAgg) => + splitCount += 1 + // 2) Exhaustively push PartialWrapperAgg through join edges. + val pushed = pushPartialWrapperAggregate(newAgg) + // 3) Return rewritten plan with pushed partial wrapper aggregates. + pushed + case None => agg + } + } + if (splitCount > 0) { + successfulSplitCount += splitCount + newPlan + } else { + assert(newPlan eq plan) + plan + } + } + + private def isEnabled: Boolean = GlutenConfig.get.pushAggregateThroughJoinEnabled + + private def maxDepth: Int = GlutenConfig.get.pushAggregateThroughJoinMaxDepth + + private def splitAggregate(agg: Aggregate): Option[Aggregate] = { + // Split is intentionally child-agnostic. It only rewrites: + // + // Aggregate(resultExprs, child) + // + // into: + // + // Aggregate(rewrittenResultExprs, Aggregate(pushedKeys + wrapperOutputs, child)) + // + // Any join-specific key widening or subtree rebuild is handled later by pushOnce / + // pushPartialAggToJoinSide. This keeps the split step reusable even when the current child is + // not yet a join-shaped subtree. + val pushableSpecs = collectPushableSpecs(agg.aggregateExpressions) + if (pushableSpecs.isEmpty) { + None + } else { + val groupingRefAttrs = dedupeAttrs(agg.groupingExpressions.flatMap(referencedAttrsInOrder)) + // Aggregate subexpressions that are not rewritten remain above the lower aggregate after the + // split. Their referenced attrs must therefore still be produced by the lower aggregate, but + // at this stage we only derive that requirement from the aggregate itself, not from any + // particular child shape. + val nonPushableAggAttrs = dedupeAttrs(agg.aggregateExpressions.flatMap { + case Alias(expr, _) if !isPushableExpr(expr) => referencedAttrsInOrder(expr) + case expr if !isPushableExpr(expr) => referencedAttrsInOrder(expr) + case _ => Nil + }) + val partialGroupingKeys = dedupeAttrs(groupingRefAttrs ++ nonPushableAggAttrs) + val partialAliases = pushableSpecs.zipWithIndex.map { + case (spec, idx) => + Alias( + JoinAggregateFunctionWrapper + .wrapperPartial(spec.aggregate, spec.wrapperKey) + .toAggregateExpression(), + s"_pushed_${spec.aggregate.prettyName}_buffer_$idx" + )() + } + + val partialAgg = Aggregate( + groupingExpressions = partialGroupingKeys, + aggregateExpressions = partialGroupingKeys ++ partialAliases, + child = agg.child + ) + + val partialOutputAttrs = + partialAgg.output + .slice(partialGroupingKeys.size, partialGroupingKeys.size + partialAliases.size) + .map(cleanAttr) + val partialRefs = pushableSpecs.zip(partialOutputAttrs).map { + case (spec, attr) => SidePartialRef(spec, attr) + } + + rewriteAggregateExpressions(agg.aggregateExpressions, partialRefs).map { + rewrittenAggExprs => + agg.copy( + aggregateExpressions = rewrittenAggExprs, + child = partialAgg + ) + } + } + } + + private def pushPartialWrapperAggregate(agg: Aggregate): LogicalPlan = { + // Push one join edge per iteration. This keeps the rewrite local and lets `maxDepth` bound + // how far a pushed aggregate is allowed to travel through a multi-join subtree. + var current: LogicalPlan = agg + var changed = true + var pushCount = 0 + while (changed && pushCount < maxDepth) { + changed = false + current = current.transformUp { + case partialAgg: Aggregate if isPurePartialWrapperAggregate(partialAgg) => + pushOnce( + partialAgg, + partialAgg.groupingExpressions, + partialAgg.aggregateExpressions, + partialAgg.child) match { + case Some(newPlan) => + pushCount += 1 + changed = true + newPlan + case None => + partialAgg + } + } + } + successfulPushCount += pushCount + current + } + + private def pushOnce( + partialAgg: Aggregate, + groupingExprs: Seq[Expression], + aggExprs: Seq[NamedExpression], + child: LogicalPlan): Option[LogicalPlan] = { + extractJoin(child).flatMap { + case (join, wrapperRequiredAttrs, rebuild) => + Seq(JoinLeft, JoinRight).iterator + .flatMap { + side => + val maybePushedJoin = + pushPartialAggToJoinSide(join, groupingExprs, aggExprs, wrapperRequiredAttrs, side) + maybePushedJoin match { + case Some(pushedJoin) => + val requiredAttrs = partialAgg.output.collect { case a: Attribute => a } + Some(rebuild(pushedJoin, requiredAttrs)) + case None => None + } + } + .toSeq + .headOption + } + } + + private def pushPartialAggToJoinSide( + join: Join, + groupingExprs: Seq[Expression], + aggExprs: Seq[NamedExpression], + wrapperRequiredAttrs: Seq[Attribute], + side: JoinSide): Option[Join] = { + // A pushed wrapper aggregate may move to a join side only when all of its aggregate inputs + // come from that side. The pushed grouping then has to include: + // - grouping attrs already present on this side + // - join keys / join-condition attrs required to preserve join semantics + // - attrs referenced by non-pushable expressions still above the pushed aggregate + // - attrs propagated from extracted Project / Filter nodes when those attrs are required to + // preserve derived grouping expressions + // + // The pushed grouping must *not* include pure measure inputs of the pushed aggregate such as + // `ss_net_profit`, otherwise the pre-aggregation becomes over-constrained and ineffective. + val wrapperAliases = collectPartialWrapperAliases(aggExprs) + if (wrapperAliases.isEmpty) { + return None + } + + val sideOutputSet = side.outputSet(join) + val allPushable = wrapperAliases.forall { + case (_, wrapper) => canPushToSide(wrapper.innerAgg, sideOutputSet) + } + if (!allPushable) { + return None + } + + val sideGroupingAttrs = groupingExprs + .flatMap(referencedAttrsInOrder) + .collect { case a: Attribute if sideOutputSet.contains(a) => a } + val sideJoinKeys = join.condition.toSeq.flatMap(splitConjunctivePredicates).collect { + case EqualTo(l: Attribute, r: Attribute) + if sideOutputSet.contains(l) && !sideOutputSet.contains(r) => + l + case EqualTo(l: Attribute, r: Attribute) + if sideOutputSet.contains(r) && !sideOutputSet.contains(l) => + r + } + val sideJoinCondAttrs = join.condition.toSeq + .flatMap(referencedAttrsInOrder) + .collect { case a: Attribute if sideOutputSet.contains(a) => a } + + // These attrs belong to aggregate subexpressions that stay above the pushed aggregate. They + // must survive subtree rebuild, but they are not themselves proof that the pushed aggregate + // needs to group by those measures. + val sideNonPushableAggAttrs = dedupeAttrs(aggExprs.flatMap { + case Alias(expr, _) if !isPushableExpr(expr) && !containsWrapperAggregateExpr(expr) => + referencedAttrsInOrder(expr).collect { + case a: Attribute if sideOutputSet.contains(a) => a + } + case expr if !isPushableExpr(expr) && !containsWrapperAggregateExpr(expr) => + referencedAttrsInOrder(expr).collect { + case a: Attribute if sideOutputSet.contains(a) => a + } + case _ => Nil + }) + val sidePushableAggInputAttrs = dedupeAttrs(wrapperAliases.flatMap { + case (_, wrapper) => + wrapper.innerAgg.children.flatMap(referencedAttrsInOrder).collect { + case a: Attribute if sideOutputSet.contains(a) => a + } + }) + val sideWrapperRequiredAttrs = dedupeAttrs( + wrapperRequiredAttrs + .collect { + case a: Attribute if sideOutputSet.contains(a) => a + } + .filterNot(attr => sidePushableAggInputAttrs.exists(_.semanticEquals(attr)))) + // `wrapperRequiredAttrs` carries dependencies from extracted Project/Filter nodes above the + // join. Those dependencies are needed to preserve grouping semantics for derived grouping + // expressions such as `substr(w_warehouse_name, 1, 20)`. However, pushed aggregate inputs + // like `ss_net_profit` must stay as child inputs only; promoting them into grouping keys + // would over-constrain the pushed pre-aggregation. + val pushedGrouping = + dedupeAttrs( + sideGroupingAttrs ++ + sideJoinKeys ++ + sideJoinCondAttrs ++ + sideNonPushableAggAttrs ++ + sideWrapperRequiredAttrs) + if (pushedGrouping.isEmpty) { + return None + } + + val pushedWrapperAliases = wrapperAliases.map { + case (alias, wrapper) => + val wrapped = JoinAggregateFunctionWrapper + .wrapperPartial(wrapper.innerAgg, wrapper.wrapperKey) + .toAggregateExpression() + Alias(wrapped, alias.name)( + exprId = alias.exprId, + qualifier = alias.qualifier, + explicitMetadata = alias.explicitMetadata, + nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys + ) + } + + val pushedAgg = Aggregate( + groupingExpressions = pushedGrouping, + aggregateExpressions = pushedGrouping ++ pushedWrapperAliases, + child = side.plan(join) + ) + + val pushedJoin = side.replace(join, pushedAgg) + Some(pushedJoin) + } + + private def isPurePartialWrapperAggregate(agg: Aggregate): Boolean = { + val wrapperAliases = collectPartialWrapperAliases(agg.aggregateExpressions) + wrapperAliases.nonEmpty && agg.aggregateExpressions.forall { + case Alias(_: AggregateExpression, _) => true + case _: AggregateExpression => false + case _ => true + } + } + + private def collectPartialWrapperAliases( + output: Seq[NamedExpression]): Seq[(Alias, JoinAggregateFunctionWrapper)] = { + output.collect { + case alias @ Alias(AggregateExpression(wrapper: JoinAggregateFunctionWrapper, _, _, _, _), _) + if wrapper.targetPhase == JoinAggregateFunctionWrapper.PartialPhase => + (alias, wrapper) + } + } + + private def extractJoin( + child: LogicalPlan): Option[(Join, Seq[Attribute], (Join, Seq[Attribute]) => LogicalPlan)] = + child match { + case project @ Project(_, projectChild) => + // Preserve dependencies introduced by extracted Project nodes. When pushdown rebuilds the + // join subtree, these attrs tell us which project outputs (and therefore which child + // inputs) must remain available above the pushed aggregate. + extractJoin(projectChild).map { + case (join, wrapperRequiredAttrs, rebuild) => + val projectRequiredAttrs = + dedupeAttrs( + project.projectList.flatMap(referencedAttrsInOrder) ++ wrapperRequiredAttrs) + ( + join, + projectRequiredAttrs, + (j: Join, requiredAttrs: Seq[Attribute]) => { + val requiredAttrSet = AttributeSet(requiredAttrs) + val retainedProjectList = project.projectList.filter { + ne => requiredAttrSet.contains(ne.toAttribute) + } + // If a retained project expression depends on child columns (e.g. CASE refs), + // propagate those dependencies when rebuilding the join subtree. + val requiredForChild = + dedupeAttrs(requiredAttrs ++ retainedProjectList.flatMap(referencedAttrsInOrder)) + val rebuiltChild = rebuild(j, requiredForChild) + val projectOutputSet = AttributeSet(project.projectList.map(_.toAttribute)) + val passThroughAttrs = requiredAttrs.filter { + attr => rebuiltChild.outputSet.contains(attr) && !projectOutputSet.contains(attr) + } + // Only keep project outputs required by the consumer above this extracted join. + // This allows pushdown to replace measure columns with partial-wrapper outputs. + project.copy( + projectList = retainedProjectList ++ dedupeAttrs(passThroughAttrs), + child = rebuiltChild + ) + } + ) + } + case join: Join if isInnerEquiJoin(join) => + Some((join, Nil, (j: Join, _: Seq[Attribute]) => j)) + case filter @ Filter(_, join: Join) if isInnerEquiJoin(join) => + // Filter predicates above the join become additional required attrs for rebuild, because + // those predicates still execute above the pushed aggregate after rewrite. + Some( + ( + join, + referencedAttrsInOrder(filter.condition), + (j: Join, _: Seq[Attribute]) => filter.copy(child = j))) + case _ => None + } + + private def isInnerEquiJoin(join: Join): Boolean = { + join.joinType == Inner && join.hint == JoinHint.NONE && join.condition.exists { + cond => + splitConjunctivePredicates(cond).exists { + case EqualTo(_: Attribute, _: Attribute) => true + case _ => false + } + } + } + + private def cleanAttr(attr: Attribute): AttributeReference = { + AttributeReference(attr.name, attr.dataType, attr.nullable)( + exprId = attr.exprId, + qualifier = attr.qualifier) + } + + private def containsWrapperAggregateInOutput(aggExprs: Seq[NamedExpression]): Boolean = { + aggExprs.exists { + _.exists { + case AggregateExpression(_: JoinAggregateFunctionWrapper, _, _, _, _) => true + case _ => false + } + } + } + + private def containsWrapperAggregateExpr(expr: Expression): Boolean = { + expr.exists { + case AggregateExpression(_: JoinAggregateFunctionWrapper, _, _, _, _) => true + case _ => false + } + } + + private def hasPushableAggExpr(aggExprs: Seq[Expression]): Boolean = { + aggExprs.exists(expr => collectPushableSpecs(expr).nonEmpty) + } + + private def isPushableExpr(expr: Expression): Boolean = { + collectPushableSpecs(expr).nonEmpty + } + + private def hasDistinctAggExpr(aggExprs: Seq[NamedExpression]): Boolean = { + aggExprs.exists { + _.exists { + case ae: AggregateExpression if ae.isDistinct => true + case _ => false + } + } + } + + private def pushableSpec(ae: AggregateExpression): Option[SidePartialSpec] = { + if (!ae.isDistinct && ae.filter.isEmpty) { + ae.aggregateFunction match { + case da: DeclarativeAggregate if !da.isInstanceOf[JoinAggregateFunctionWrapper] => + val stableExprSql = ae.canonicalized.sql + Some(SidePartialSpec(ae, da, Integer.toUnsignedString(stableExprSql.hashCode))) + case _ => None + } + } else { + None + } + } + + private def collectPushableSpecs(expr: Expression): Seq[SidePartialSpec] = { + expr + .collect { case ae: AggregateExpression => ae } + .flatMap(pushableSpec) + .foldLeft(Seq.empty[SidePartialSpec]) { + case (acc, spec) if acc.exists(_.originalExpr.semanticEquals(spec.originalExpr)) => acc + case (acc, spec) => acc :+ spec + } + } + + private def collectPushableSpecs(aggExprs: Seq[NamedExpression]): Seq[SidePartialSpec] = { + aggExprs + .flatMap { + case Alias(expr, _) => collectPushableSpecs(expr) + case expr => collectPushableSpecs(expr) + } + .foldLeft(Seq.empty[SidePartialSpec]) { + case (acc, spec) if acc.exists(_.originalExpr.semanticEquals(spec.originalExpr)) => acc + case (acc, spec) => acc :+ spec + } + } + + private def rewriteAggregateExpressions( + aggExprs: Seq[NamedExpression], + sidePartials: Seq[SidePartialRef]): Option[Seq[NamedExpression]] = { + // Replace every pushed aggregate subexpression with a final wrapper that consumes the pushed + // wrapper output. This allows expressions like `sum(a) - sum(b)` to be split and pushed even + // though the top-level result expression contains multiple aggregate functions. + var rewrittenAny = false + + val rewrittenAggExprs = aggExprs.map { + case alias @ Alias(expr, _) => + val rewrittenExpr = rewriteAggregateExpr(expr, sidePartials) + if (!rewrittenExpr.fastEquals(expr)) { + rewrittenAny = true + Alias(rewrittenExpr, alias.name)( + exprId = alias.exprId, + qualifier = alias.qualifier, + explicitMetadata = alias.explicitMetadata, + nonInheritableMetadataKeys = alias.nonInheritableMetadataKeys + ) + } else { + alias + } + case other => + val rewrittenExpr = rewriteAggregateExpr(other, sidePartials) + if (!rewrittenExpr.fastEquals(other)) { + rewrittenAny = true + rewrittenExpr.asInstanceOf[NamedExpression] + } else { + other + } + } + + if (rewrittenAny) Some(rewrittenAggExprs) else None + } + + private def rewriteAggregateExpr( + expr: Expression, + sidePartials: Seq[SidePartialRef]): Expression = { + // Rewrite only the aggregate nodes chosen for pushdown. Everything around them stays in its + // original logical shape, so expressions such as `sum(a) - sum(b)` still look the same above + // the join after each aggregate term is swapped to a final-phase wrapper. + expr.transformUp { + case ae: AggregateExpression => + pushableSpec(ae) + .flatMap { + spec => + sidePartials.collectFirst { + case SidePartialRef(sideSpec, partialAttr) + if sideSpec.originalExpr.semanticEquals(spec.originalExpr) => + JoinAggregateFunctionWrapper + .wrapperFinal(spec.aggregate, partialAttr, spec.wrapperKey) + .toAggregateExpression() + } + } + .getOrElse(ae) + } + } + + private def canPushToSide(agg: DeclarativeAggregate, sideOutputSet: AttributeSet): Boolean = { + agg.children.forall(_.references.forall(sideOutputSet.contains)) + } + + private def dedupeAttrs(attrs: Seq[Attribute]): Seq[Attribute] = { + attrs.foldLeft(Seq.empty[Attribute]) { + case (acc, attr) if acc.exists(_.semanticEquals(attr)) => acc + case (acc, attr) => acc :+ attr + } + } + + private def referencedAttrsInOrder(expr: Expression): Seq[Attribute] = { + expr.collect { case attr: Attribute => attr } + } + + sealed private trait JoinSide { + def plan(join: Join): LogicalPlan + def outputSet(join: Join): AttributeSet + def replace(join: Join, newPlan: LogicalPlan): Join + } + + private case object JoinLeft extends JoinSide { + override def plan(join: Join): LogicalPlan = join.left + override def outputSet(join: Join): AttributeSet = join.left.outputSet + override def replace(join: Join, newPlan: LogicalPlan): Join = join.copy(left = newPlan) + } + + private case object JoinRight extends JoinSide { + override def plan(join: Join): LogicalPlan = join.right + override def outputSet(join: Join): AttributeSet = join.right.outputSet + override def replace(join: Join, newPlan: LogicalPlan): Join = join.copy(right = newPlan) + } +} + +case class PushAggregateThroughJoinBatch(spark: SparkSession) extends Rule[LogicalPlan] { + private val decimalAvgRule = DecimalAggregates + private val pushRule = PushAggregateThroughJoin(spark) + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!isEnabled) { + return plan + } + val decimalAvgRewrittenPlan = decimalAvgRule(plan) + pushRule(decimalAvgRewrittenPlan) + } + + private def isEnabled: Boolean = { + GlutenConfig.get.pushAggregateThroughJoinEnabled + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/execution/PushAggregateThroughJoinSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/execution/PushAggregateThroughJoinSuite.scala new file mode 100644 index 000000000000..aac2376429df --- /dev/null +++ b/gluten-substrait/src/test/scala/org/apache/gluten/execution/PushAggregateThroughJoinSuite.scala @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.gluten.config.GlutenConfig +import org.apache.gluten.extension.joinagg.ImplementJoinAggregate +import org.apache.gluten.extension.joinagg.PushAggregateThroughJoin + +import org.apache.spark.SparkConf +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession + +import java.sql.Date + +class PushAggregateThroughJoinSuite extends PlanTest with SharedSparkSession { + private val joinAggregateRule = PushAggregateThroughJoin(spark) + private val debugMode: Boolean = true + + private case class PushdownCase(inputSql: String, expectedAggCount: Int) + + override protected def sparkConf: SparkConf = { + // Avoid Janino projection codegen here because Spark 4's QueryExecutionErrors + // has Arrow-typed methods, which breaks test runs as arrow-vector is excluded. + super.sparkConf + .set(SQLConf.CODEGEN_FACTORY_MODE.key, "NO_CODEGEN") + .set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") + } + + override def beforeAll(): Unit = { + super.beforeAll() + registerSampleTables() + } + + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("store_sales") + spark.catalog.dropTempView("date_dim") + spark.catalog.dropTempView("item") + } finally { + super.afterAll() + } + } + + private def registerSampleTables(): Unit = { + import testImplicits._ + + Seq( + (1, 100, 10.0, 1.0, 2.0), + (1, 100, 12.5, 2.0, 3.0), + (1, 100, 7.5, 3.0, 1.5), + (1, 101, 9.0, 2.0, 2.5), + (2, 100, 3.5, 1.0, 0.5), + (2, 100, 4.5, 2.0, 1.0), + (2, 103, 8.0, 4.0, 4.0) + ).toDF("ss_item_sk", "ss_sold_date_sk", "ss_sales_price", "ss_quantity", "ss_net_profit") + .createOrReplaceTempView("store_sales") + + Seq( + (100, 1999, Date.valueOf("2020-01-01")), + (100, 1999, Date.valueOf("2020-01-01")), + (101, 2000, Date.valueOf("2020-01-02")), + (103, 2003, Date.valueOf("2020-01-03")) + ).toDF("d_date_sk", "d_year", "d_date") + .createOrReplaceTempView("date_dim") + + Seq((1, "item-one", 1), (1, "item-one", 1), (2, "item-two", 6)) + .toDF("i_item_sk", "i_item_desc", "i_category_id") + .createOrReplaceTempView("item") + } + + private def runCaseWithMaxDepth( + testCase: PushdownCase, + maxDepth: Int, + expectedPushCount: Int): Unit = { + withSQLConf( + GlutenConfig.PUSH_AGGREGATE_THROUGH_JOIN_ENABLED.key -> "true", + GlutenConfig.PUSH_AGGREGATE_THROUGH_JOIN_MAX_DEPTH.key -> maxDepth.toString) { + val (withoutRuleRows, withoutRuleLogicalPlan, withoutRulePhysicalPlan) = + withExtraPlanning(Nil, Nil) { + val df = spark.sql(testCase.inputSql) + ( + df.collect().toSeq.sortBy(_.toString()), + df.queryExecution.optimizedPlan, + finalExecutedPlan(df.queryExecution.executedPlan) + ) + } + + val (withRuleRows, withRuleLogicalPlan, withRulePhysicalPlan) = + withExtraPlanning(Seq(joinAggregateRule), Nil) { + joinAggregateRule.resetSuccessfulPushCount() + val df = spark.sql(testCase.inputSql) + val withRuleRows = df.collect().toSeq.sortBy(_.toString()) + val withRulePlan = df.queryExecution.optimizedPlan + val withRulePhysicalPlan = finalExecutedPlan(df.queryExecution.executedPlan) + val aggregateNodeCount = withRulePlan.collect { case _: Aggregate => 1 }.size + val nodesWithMissingInput = withRulePlan.collect { + case p if p.missingInput.nonEmpty => p + } + + assert( + withRulePlan.resolved, + s"Optimized plan unresolved:\n${withRulePlan.treeString}\n" + + s"MissingInput=${withRulePlan.missingInput}") + assert( + nodesWithMissingInput.isEmpty, + s"Plan has missing input:\n${nodesWithMissingInput + .map(_.treeString) + .mkString("\n---\n")}") + assert(joinAggregateRule.getSuccessfulPushCount == expectedPushCount) + assert(aggregateNodeCount == testCase.expectedAggCount) + (withRuleRows, withRulePlan, withRulePhysicalPlan) + } + + val ( + withRuleAndStrategyRows, + withRuleAndStrategyLogicalPlan, + withRuleAndStrategyPhysicalPlan) = + withExtraPlanning(Seq(joinAggregateRule), Seq(ImplementJoinAggregate(spark))) { + joinAggregateRule.resetSuccessfulPushCount() + val df = spark.sql(testCase.inputSql) + ( + df.collect().toSeq.sortBy(_.toString()), + df.queryExecution.optimizedPlan, + finalExecutedPlan(df.queryExecution.executedPlan) + ) + } + + if (debugMode) { + // scalastyle:off println + println("=== Optimized Plan Before (without PushJoinAggregatePreAggregation) ===") + println(withoutRuleLogicalPlan.treeString) + println("=== Optimized Plan After (with PushJoinAggregatePreAggregation) ===") + println(withRuleLogicalPlan.treeString) + println("=== Optimized Plan After (with PushJoinAggregatePreAggregation and strategy) ===") + println(withRuleAndStrategyLogicalPlan.treeString) + println("=== Physical Plan Before (without PushJoinAggregatePreAggregation) ===") + println(withoutRulePhysicalPlan.treeString) + println("=== Physical Plan After (with PushJoinAggregatePreAggregation only) ===") + println(withRulePhysicalPlan.treeString) + println("=== Physical Plan After (with PushJoinAggregatePreAggregation and strategy) ===") + println(withRuleAndStrategyPhysicalPlan.treeString) + println("=== Result Before (without PushJoinAggregatePreAggregation) ===") + println(withoutRuleRows.mkString("\n")) + println("=== Result After (with PushJoinAggregatePreAggregation only) ===") + println(withRuleRows.mkString("\n")) + println("=== Result After (with PushJoinAggregatePreAggregation and strategy) ===") + println(withRuleAndStrategyRows.mkString("\n")) + // scalastyle:on println + } + + assertRowsEqual(withRuleRows, withoutRuleRows) + assertRowsEqual(withRuleAndStrategyRows, withoutRuleRows) + } + } + + private def assertRowsEqual(left: Seq[Row], right: Seq[Row]): Unit = { + assert(left == right, s"Result mismatch:\nleft=$left\nright=$right") + } + + private def finalExecutedPlan(plan: SparkPlan): SparkPlan = plan match { + case adaptive: AdaptiveSparkPlanExec => + adaptive.executedPlan + case other => + other + } + + private def withExtraPlanning[T](rules: Seq[Rule[LogicalPlan]], strategies: Seq[SparkStrategy])( + f: => T): T = { + val previousOptimizations = spark.experimental.extraOptimizations + val previousStrategies = spark.experimental.extraStrategies + try { + spark.experimental.extraOptimizations = rules + spark.experimental.extraStrategies = strategies + f + } finally { + spark.experimental.extraOptimizations = previousOptimizations + spark.experimental.extraStrategies = previousStrategies + } + } + + test("pre-aggregate store_sales for both joins with having filter") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | substring(i_item_desc, 1, 30) AS itemdesc, + | i_item_sk AS item_sk, + | d_date AS solddate, + | count(1) AS cnt + |FROM store_sales + |JOIN date_dim ON ss_sold_date_sk = d_date_sk + |JOIN item ON ss_item_sk = i_item_sk + |WHERE d_year IN (1999, 2000, 2001, 2002) + |GROUP BY substring(i_item_desc, 1, 30), i_item_sk, d_date + |HAVING count(1) > 4 + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 2) + } + + test("pre-aggregate store_sales for sum") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_sk AS item_sk, + | sum(ss_sales_price) AS total_sales_price + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY i_item_sk + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } + + test("pre-aggregate store_sales for avg") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_sk AS item_sk, + | avg(ss_sales_price) AS avg_sales_price + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY i_item_sk + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } + + test("pre-aggregate store_sales for sum on fact table") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | ss_sold_date_sk, + | sum(ss_sales_price) AS total_sales_price + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY ss_sold_date_sk + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } + + test("pre-aggregate store_sales for avg on fact table") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | ss_sold_date_sk, + | avg(ss_sales_price) AS avg_sales_price + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY ss_sold_date_sk + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } + + test("pre-aggregate store_sales for sum on three-way join") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_desc AS item_desc, + | d_date AS sold_date, + | sum(ss_sales_price) AS total_sales_price + |FROM store_sales + |JOIN date_dim ON ss_sold_date_sk = d_date_sk + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY item_desc, d_date + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 2) + } + + test("pre-aggregate store_sales for sum and avg on different fact columns on three-way join") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_desc AS item_desc, + | d_date AS sold_date, + | sum(ss_sales_price) AS total_sales_price, + | avg(ss_quantity) AS avg_quantity + |FROM store_sales + |JOIN date_dim ON ss_sold_date_sk = d_date_sk + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY item_desc, d_date + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 2) + } + + test("pre-aggregate store_sales for sum and avg on same fact column on three-way join") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_desc AS item_desc, + | d_date AS sold_date, + | sum(ss_sales_price) AS total_sales_price, + | avg(ss_sales_price) AS avg_sales_price + |FROM store_sales + |JOIN date_dim ON ss_sold_date_sk = d_date_sk + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY item_desc, d_date + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 2) + } + + test("pre-aggregate store_sales by i_item_desc") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_desc AS item_desc, + | avg(ss_sales_price) AS avg_sales_price + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY item_desc + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } + + test("pre-aggregate store_sales by substr(i_item_desc, 3), 3 ways") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | d_date AS sold_date, + | substr(i_item_desc, 3) AS item_desc, + | avg(ss_sales_price) AS avg_sales_price + |FROM store_sales + |JOIN date_dim ON ss_sold_date_sk = d_date_sk + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY d_date, item_desc + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 2) + } + + test("pre-aggregate store_sales for sum with item filter") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | sum(ss_net_profit) AS profit + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk + |WHERE i_category_id IN (1, 2, 3, 4, 5) + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } + + test("pre-aggregate three-way joins independently under union all") { + val pushdownCase = PushdownCase( + inputSql = + """ + |SELECT key, total_sales_price + |FROM ( + | SELECT + | concat('item-', cast(i_item_sk AS string), '-', cast(d_date_sk AS string)) AS key, + | sum(ss_sales_price) AS total_sales_price + | FROM store_sales + | JOIN date_dim ON ss_sold_date_sk = d_date_sk + | JOIN item ON ss_item_sk = i_item_sk + | GROUP BY concat('item-', cast(i_item_sk AS string), '-', cast(d_date_sk AS string)) + | + | UNION ALL + | + | SELECT + | concat('desc-', i_item_desc, '-', cast(d_date_sk AS string)) AS key, + | sum(ss_sales_price) AS total_sales_price + | FROM store_sales + | JOIN date_dim ON ss_sold_date_sk = d_date_sk + | JOIN item ON ss_item_sk = i_item_sk + | GROUP BY concat('desc-', i_item_desc, '-', cast(d_date_sk AS string)) + | + | UNION ALL + | + | SELECT + | concat('year-', cast(d_year AS string), '-', cast(i_item_sk AS string)) AS key, + | sum(ss_sales_price) AS total_sales_price + | FROM store_sales + | JOIN date_dim ON ss_sold_date_sk = d_date_sk + | JOIN item ON ss_item_sk = i_item_sk + | GROUP BY concat('year-', cast(d_year AS string), '-', cast(i_item_sk AS string)) + |) + |""".stripMargin, + expectedAggCount = 6 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = 1, expectedPushCount = 3) + runCaseWithMaxDepth(pushdownCase, maxDepth = 2, expectedPushCount = 6) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 6) + } + + test("pre-aggregate store_sales for sum on three-way join with maxDepth=1 / maxDepth=2") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_desc AS item_desc, + | d_date AS sold_date, + | sum(ss_sales_price) AS total_sales_price + |FROM store_sales + |JOIN date_dim ON ss_sold_date_sk = d_date_sk + |JOIN item ON ss_item_sk = i_item_sk + |GROUP BY item_desc, d_date + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = 1, expectedPushCount = 1) + runCaseWithMaxDepth(pushdownCase, maxDepth = 2, expectedPushCount = 2) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 2) + } + + test("pre-aggregate with filter inside inner equi-join") { + val pushdownCase = PushdownCase( + inputSql = """ + |SELECT + | i_item_sk AS item_sk, + | sum(ss_sales_price) AS total_sales_price + |FROM store_sales + |JOIN item ON ss_item_sk = i_item_sk AND ss_quantity > 1 + |GROUP BY i_item_sk + |""".stripMargin, + expectedAggCount = 2 + ) + runCaseWithMaxDepth(pushdownCase, maxDepth = Int.MaxValue, expectedPushCount = 1) + } +}