Skip to content

Commit

Permalink
[GLUTEN-1985] avoid forceShuffledHashJoin when the join condition doe…
Browse files Browse the repository at this point in the history
…s not supported by the backend
  • Loading branch information
zheniantoushipashi committed Jun 16, 2023
1 parent f3f7872 commit e9db42a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package io.glutenproject.execution

import io.glutenproject.utils.CHJoinValidateUtil
import io.glutenproject.utils.JoinValidateUtil

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.BuildSide
Expand Down Expand Up @@ -51,7 +51,7 @@ case class CHShuffledHashJoinExecTransformer(
override def doValidateInternal(): Boolean = {
var shouldFallback = false
if (substraitJoinType != JoinRel.JoinType.JOIN_TYPE_INNER) {
shouldFallback = CHJoinValidateUtil.doValidate(condition)
shouldFallback = JoinValidateUtil.doValidate(condition)
}
if (shouldFallback) {
return false
Expand Down Expand Up @@ -90,7 +90,7 @@ case class CHBroadcastHashJoinExecTransformer(
override def doValidateInternal(): Boolean = {
var shouldFallback = false
if (substraitJoinType != JoinRel.JoinType.JOIN_TYPE_INNER) {
shouldFallback = CHJoinValidateUtil.doValidate(condition)
shouldFallback = JoinValidateUtil.doValidate(condition)
}
if (isNullAwareAntiJoin == true) {
shouldFallback = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
test("test fallbackutils") {
val testSql =
"""
|SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
|SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) AS ext_price
| FROM date_dim
| LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
Expand All @@ -172,6 +172,33 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
assert(FallbackUtil.isFallback(df.queryExecution.executedPlan))
}

test(
"Test avoid forceShuffledHashJoin when the join condition" +
" does not supported by the backend") {
val testSql =
"""
|SELECT /*+ merge(date_dim)*/ i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) AS ext_price
| FROM date_dim
| LEFT JOIN store_sales ON d_date_sk == ss_sold_date_sk AND (d_date_sk = 213232 OR ss_sold_date_sk = 3232)
| LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
| LEFT JOIN customer ON ss_customer_sk = c_customer_sk
| LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
| LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) <> substr(s_zip,1,5)
| WHERE d_moy = 11
| AND d_year = 1999
| GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
| ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, i_manufact
| LIMIT 100;
|""".stripMargin

val df = spark.sql(testSql)
val sortMergeJoinExecTransFormer = df.queryExecution.executedPlan.collect {
case s: SortMergeJoinExecTransformer => s
}
assert(sortMergeJoinExecTransFormer.nonEmpty)
}

test("Gluten-1235: Fix missing reading from the broadcasted value when executing DPP") {
val testSql =
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, Logical
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.{JoinSelectionShim, SparkPlan, joins}
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
import io.glutenproject.utils.JoinValidateUtil

object StrategyOverrides extends GlutenSparkExtensionsInjector {
override def inject(extensions: SparkSessionExtensions): Unit = {
Expand Down Expand Up @@ -84,7 +85,7 @@ case class JoinSelectionOverrides(session: SparkSession) extends Strategy with
planLater(right)))
}

if (forceShuffledHashJoin &&
if (forceShuffledHashJoin && !JoinValidateUtil.doValidate(condition) &&
!left.getTagValue(TAG).isDefined &&
!right.getTagValue(TAG).isDefined) {
// Force use of ShuffledHashJoin in preference to SortMergeJoin. With no respect to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, GreaterTh
* Existence Join which is just an optimization of exist subquery, it will also fallback
*/

object CHJoinValidateUtil extends Logging {
object JoinValidateUtil extends Logging {
def hasTwoTableColumn(l: Expression, r: Expression): Boolean = {
!l.references.toSeq
.map(_.qualifier.mkString("."))
Expand Down

0 comments on commit e9db42a

Please sign in to comment.