Skip to content

Commit

Permalink
[GLUTEN-1985] avoid forceShuffledHashJoin when the join condition doe…
Browse files Browse the repository at this point in the history
…s not supported by the backend
  • Loading branch information
zheniantoushipashi committed Jun 20, 2023
1 parent f3f7872 commit f08ce5f
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
GlutenConfig.getConf.enableColumnarSort
}

override def supportSortMergeJoinExec(): Boolean = {
false
}

override def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
var allSupported = true
breakable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import io.glutenproject.backendsapi.SparkPlanExecApi
import io.glutenproject.execution._
import io.glutenproject.expression.{AggregateFunctionsBuilder, AliasTransformerBase, CHSha1Transformer, CHSha2Transformer, ConverterUtils, ExpressionConverter, ExpressionMappings, ExpressionTransformer, WindowFunctionsBuilder}
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
import io.glutenproject.utils.CHJoinValidateUtil
import io.glutenproject.vectorized.{CHBlockWriterJniWrapper, CHColumnarBatchSerializer}

import org.apache.spark.{ShuffleDependency, SparkException}
Expand Down Expand Up @@ -380,9 +381,22 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
new CHSha1Transformer(substraitExprName, child, original)
}

/**
* Define whether the join operator is fallback because of the join operator is not supported by
* backend
*/
override def joinFallback(
JoinType: JoinType,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
condition: Option[Expression]): Boolean = {
return CHJoinValidateUtil.shouldFallback(JoinType, leftOutputSet, rightOutputSet, condition)
}

/**
* Generate an BasicScanExecTransformer to transfrom hive table scan. Currently only for CH
* backend.
*
* @param child
* @return
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ case class CHShuffledHashJoinExecTransformer(
copy(left = newLeft, right = newRight)

override def doValidateInternal(): Boolean = {
var shouldFallback = false
if (substraitJoinType != JoinRel.JoinType.JOIN_TYPE_INNER) {
shouldFallback = CHJoinValidateUtil.doValidate(condition)
}
val shouldFallback =
CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, right.outputSet, condition)
if (shouldFallback) {
return false
}
Expand Down Expand Up @@ -88,10 +86,8 @@ case class CHBroadcastHashJoinExecTransformer(
*/
override def doValidateInternal(): Boolean = {
var shouldFallback = false
if (substraitJoinType != JoinRel.JoinType.JOIN_TYPE_INNER) {
shouldFallback = CHJoinValidateUtil.doValidate(condition)
}
var shouldFallback =
CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, right.outputSet, condition)
if (isNullAwareAntiJoin == true) {
shouldFallback = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
package io.glutenproject.utils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, GreaterThan, LessThan, Not, Or}
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
import org.apache.spark.sql.catalyst.plans.JoinType

/**
* The logic here is that if it is not an equi-join spark will create BNLJ, which will fallback, if
Expand All @@ -30,37 +31,56 @@ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, GreaterTh
*/

object CHJoinValidateUtil extends Logging {
def hasTwoTableColumn(l: Expression, r: Expression): Boolean = {
!l.references.toSeq
.map(_.qualifier.mkString("."))
.toSet
.subsetOf(r.references.toSeq.map(_.qualifier.mkString(".")).toSet)
def hasTwoTableColumn(
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
l: Expression,
r: Expression): Boolean = {
val allReferences = l.references ++ r.references
!(allReferences.subsetOf(leftOutputSet) || allReferences.subsetOf(rightOutputSet))
}

def doValidate(condition: Option[Expression]): Boolean = {
def shouldFallback(
joinType: JoinType,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
condition: Option[Expression]): Boolean = {
var shouldFallback = false
if (joinType.sql.equals("INNER")) {
return shouldFallback
}
if (condition.isDefined) {
condition.get.transform {
case Or(l, r) =>
if (hasTwoTableColumn(l, r)) {
if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
Or(l, r)
case Not(EqualTo(l, r)) =>
if (l.references.nonEmpty && r.references.nonEmpty) {
if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
Not(EqualTo(l, r))
case LessThan(l, r) =>
if (l.references.nonEmpty && r.references.nonEmpty) {
if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
LessThan(l, r)
case LessThanOrEqual(l, r) =>
if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
LessThanOrEqual(l, r)
case GreaterThan(l, r) =>
if (l.references.nonEmpty && r.references.nonEmpty) {
if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
GreaterThan(l, r)
case GreaterThanOrEqual(l, r) =>
if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
GreaterThanOrEqual(l, r)
}
}
shouldFallback
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ abstract class GlutenClickHouseTPCDSAbstractSuite extends WholeStageTransformerS
/** Return values: (sql num, is fall back, skip fall back assert) */
def tpcdsAllQueries(isAqe: Boolean): Seq[(String, Boolean, Boolean)] =
Range
.inclusive(1, 99)
.inclusive(21, 22)
.flatMap(
queryNum => {
val sqlNums = if (queryNum == 14 || queryNum == 23 || queryNum == 24 || queryNum == 39) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Not}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}

class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSuite {

Expand Down Expand Up @@ -153,7 +153,7 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
test("test fallbackutils") {
val testSql =
"""
|SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
|SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) AS ext_price
| FROM date_dim
| LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
Expand All @@ -172,6 +172,33 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSui
assert(FallbackUtil.isFallback(df.queryExecution.executedPlan))
}

test(
"Test avoid forceShuffledHashJoin when the join condition" +
" does not supported by the backend") {
val testSql =
"""
|SELECT /*+ merge(date_dim)*/ i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) AS ext_price
| FROM date_dim
| LEFT JOIN store_sales ON d_date_sk == ss_sold_date_sk AND (d_date_sk = 213232 OR ss_sold_date_sk = 3232)
| LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
| LEFT JOIN customer ON ss_customer_sk = c_customer_sk
| LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
| LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) <> substr(s_zip,1,5)
| WHERE d_moy = 11
| AND d_year = 1999
| GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
| ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, i_manufact
| LIMIT 100;
|""".stripMargin

val df = spark.sql(testSql)
val sortMergeJoinExec = df.queryExecution.executedPlan.collect {
case s: SortMergeJoinExec => s
}
assert(sortMergeJoinExec.nonEmpty)
}

test("Gluten-1235: Fix missing reading from the broadcasted value when executing DPP") {
val testSql =
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ object BackendSettings extends BackendSettingsApi {

override def supportSortExec(): Boolean = true

override def supportSortMergeJoinExec(): Boolean = {
GlutenConfig.getConf.enableColumnarSortMergeJoin
}

override def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
var allSupported = true
breakable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ trait BackendSettingsApi {
paths: Seq[String]): Boolean = false
def supportExpandExec(): Boolean = false
def supportSortExec(): Boolean = false
def supportSortMergeJoinExec(): Boolean = true
def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
false
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,15 @@ trait SparkPlanExecApi {
*/
def extraExpressionMappings: Seq[Sig] = Seq.empty

/**
* Define whether the join operator is fallback because of
* the join operator is not supported by backend
*/
def joinFallback(JoinType: JoinType,
leftOutputSet: AttributeSet,
right: AttributeSet,
condition: Option[Expression]): Boolean = false

/**
* default function to generate window function node
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ case class JoinSelectionOverrides(session: SparkSession) extends Strategy with
}

if (forceShuffledHashJoin &&
!BackendsApiManager.getSparkPlanExecApiInstance.
joinFallback(joinType, left.outputSet, right.outputSet, condition) &&
!left.getTagValue(TAG).isDefined &&
!right.getTagValue(TAG).isDefined) {
// Force use of ShuffledHashJoin in preference to SortMergeJoin. With no respect to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
!scanOnly && BackendsApiManager.getSettings.supportColumnarShuffleExec()
val enableColumnarSort: Boolean = !scanOnly && columnarConf.enableColumnarSort
val enableColumnarWindow: Boolean = !scanOnly && columnarConf.enableColumnarWindow
val enableColumnarSortMergeJoin: Boolean = !scanOnly && columnarConf.enableColumnarSortMergeJoin
val enableColumnarSortMergeJoin: Boolean = !scanOnly &&
BackendsApiManager.getSettings.supportSortMergeJoinExec()
val enableColumnarBatchScan: Boolean = columnarConf.enableColumnarBatchScan
val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan
val enableColumnarProject: Boolean = !scanOnly && columnarConf.enableColumnarProject
Expand Down

0 comments on commit f08ce5f

Please sign in to comment.