[GLUTEN-1985] Avoid forceShuffledHashJoin when the join condition is not supported by the backend #1986

Merged: 1 commit, Jun 21, 2023
@@ -116,6 +116,10 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
GlutenConfig.getConf.enableColumnarSort
}

override def supportSortMergeJoinExec(): Boolean = {
false
}

override def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
var allSupported = true
breakable {
@@ -20,6 +20,7 @@ import io.glutenproject.backendsapi.SparkPlanExecApi
import io.glutenproject.execution._
import io.glutenproject.expression.{AggregateFunctionsBuilder, AliasTransformerBase, CHSha1Transformer, CHSha2Transformer, ConverterUtils, ExpressionConverter, ExpressionMappings, ExpressionTransformer, WindowFunctionsBuilder}
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, WindowFunctionNode}
import io.glutenproject.utils.CHJoinValidateUtil
import io.glutenproject.vectorized.{CHBlockWriterJniWrapper, CHColumnarBatchSerializer}

import org.apache.spark.{ShuffleDependency, SparkException}
@@ -380,9 +381,22 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
new CHSha1Transformer(substraitExprName, child, original)
}

/**
 * Define whether the join operator should fall back because the backend does
 * not support the given join condition.
 */
override def joinFallback(
joinType: JoinType,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
condition: Option[Expression]): Boolean = {
CHJoinValidateUtil.shouldFallback(joinType, leftOutputSet, rightOutputSet, condition)
}

/**
 * Generate a BasicScanExecTransformer to transform a Hive table scan.
 * Currently only used by the CH backend.
 *
 * @param child
 * @return
 */
@@ -49,10 +49,8 @@ case class CHShuffledHashJoinExecTransformer(
copy(left = newLeft, right = newRight)

override def doValidateInternal(): Boolean = {
- var shouldFallback = false
- if (substraitJoinType != JoinRel.JoinType.JOIN_TYPE_INNER) {
- shouldFallback = CHJoinValidateUtil.doValidate(condition)
- }
+ val shouldFallback =
+ CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, right.outputSet, condition)
if (shouldFallback) {
return false
}
@@ -88,10 +86,8 @@ case class CHBroadcastHashJoinExecTransformer(

*/
override def doValidateInternal(): Boolean = {
- var shouldFallback = false
- if (substraitJoinType != JoinRel.JoinType.JOIN_TYPE_INNER) {
- shouldFallback = CHJoinValidateUtil.doValidate(condition)
- }
+ var shouldFallback =
+ CHJoinValidateUtil.shouldFallback(joinType, left.outputSet, right.outputSet, condition)
if (isNullAwareAntiJoin) {
shouldFallback = true
}
@@ -17,7 +17,8 @@
package io.glutenproject.utils

import org.apache.spark.internal.Logging
- import org.apache.spark.sql.catalyst.expressions.{EqualTo, Expression, GreaterThan, LessThan, Not, Or}
+ import org.apache.spark.sql.catalyst.expressions.{AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Not, Or}
+ import org.apache.spark.sql.catalyst.plans.JoinType

/**
* The logic here is that if it is not an equi-join spark will create BNLJ, which will fallback, if
@@ -30,37 +31,59 @@
*/

object CHJoinValidateUtil extends Logging {
- def hasTwoTableColumn(l: Expression, r: Expression): Boolean = {
- !l.references.toSeq
- .map(_.qualifier.mkString("."))
- .toSet
- .subsetOf(r.references.toSeq.map(_.qualifier.mkString(".")).toSet)
+ def hasTwoTableColumn(
+ leftOutputSet: AttributeSet,
+ rightOutputSet: AttributeSet,
+ l: Expression,
+ r: Expression): Boolean = {
+ val allReferences = l.references ++ r.references
+ !(allReferences.subsetOf(leftOutputSet) || allReferences.subsetOf(rightOutputSet))
}

- def doValidate(condition: Option[Expression]): Boolean = {
+ def shouldFallback(
+ joinType: JoinType,
+ leftOutputSet: AttributeSet,
+ rightOutputSet: AttributeSet,
+ condition: Option[Expression]): Boolean = {
var shouldFallback = false
+ if (joinType.toString.contains("ExistenceJoin")) {
+ return true
+ }
+ if (joinType.sql.equals("INNER")) {
+ return shouldFallback
+ }
if (condition.isDefined) {
condition.get.transform {
case Or(l, r) =>
- if (hasTwoTableColumn(l, r)) {
+ if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
Or(l, r)
case Not(EqualTo(l, r)) =>
- if (l.references.nonEmpty && r.references.nonEmpty) {
+ if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
Not(EqualTo(l, r))
case LessThan(l, r) =>
- if (l.references.nonEmpty && r.references.nonEmpty) {
+ if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
LessThan(l, r)
+ case LessThanOrEqual(l, r) =>
+ if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
+ shouldFallback = true
+ }
+ LessThanOrEqual(l, r)
case GreaterThan(l, r) =>
- if (l.references.nonEmpty && r.references.nonEmpty) {
+ if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
shouldFallback = true
}
GreaterThan(l, r)
+ case GreaterThanOrEqual(l, r) =>
+ if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) {
+ shouldFallback = true
+ }
+ GreaterThanOrEqual(l, r)
}
}
shouldFallback
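To make the rule concrete, here is a small self-contained sketch of how CHJoinValidateUtil.shouldFallback classifies conditions (column names are illustrative; assumes a Spark 3.x catalyst classpath and this PR's CHJoinValidateUtil): only a comparison whose combined references span both sides of the join forces a fallback.

import io.glutenproject.utils.CHJoinValidateUtil

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, EqualTo, Literal, Or}
import org.apache.spark.sql.catalyst.plans.LeftOuter
import org.apache.spark.sql.types.IntegerType

object ShouldFallbackDemo {
  def main(args: Array[String]): Unit = {
    // One attribute per side stands in for each table's output set.
    val dDateSk = AttributeReference("d_date_sk", IntegerType)()
    val ssSoldDateSk = AttributeReference("ss_sold_date_sk", IntegerType)()
    val leftOutput = AttributeSet(Seq(dDateSk))
    val rightOutput = AttributeSet(Seq(ssSoldDateSk))

    // An OR whose branches reference different join sides: the CH backend
    // cannot evaluate it, so the join must stay on vanilla Spark.
    val crossSide = Or(EqualTo(dDateSk, Literal(213232)), EqualTo(ssSoldDateSk, Literal(3232)))
    assert(CHJoinValidateUtil.shouldFallback(LeftOuter, leftOutput, rightOutput, Some(crossSide)))

    // A comparison confined to a single side is supported: no fallback.
    val oneSide = EqualTo(dDateSk, Literal(213232))
    assert(!CHJoinValidateUtil.shouldFallback(LeftOuter, leftOutput, rightOutput, Some(oneSide)))
  }
}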
@@ -22,7 +22,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Not}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
- import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}

class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSuite {

@@ -153,7 +153,7 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSuite
test("test fallbackutils") {
val testSql =
"""
|SELECT i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) AS ext_price
| FROM date_dim
| LEFT JOIN store_sales ON d_date_sk = ss_sold_date_sk
@@ -172,6 +172,33 @@ class GlutenClickHouseTPCDSParquetSuite extends GlutenClickHouseTPCDSAbstractSuite
assert(FallbackUtil.isFallback(df.queryExecution.executedPlan))
}

test(
"Test avoid forceShuffledHashJoin when the join condition" +
" is not supported by the backend") {
val testSql =
"""
|SELECT /*+ merge(date_dim)*/ i_brand_id AS brand_id, i_brand AS brand, i_manufact_id, i_manufact,
| sum(ss_ext_sales_price) AS ext_price
| FROM date_dim
| LEFT JOIN store_sales ON d_date_sk == ss_sold_date_sk AND (d_date_sk = 213232 OR ss_sold_date_sk = 3232)
| LEFT JOIN item ON ss_item_sk = i_item_sk AND i_manager_id = 7
| LEFT JOIN customer ON ss_customer_sk = c_customer_sk
| LEFT JOIN customer_address ON c_current_addr_sk = ca_address_sk
| LEFT JOIN store ON ss_store_sk = s_store_sk AND substr(ca_zip,1,5) <> substr(s_zip,1,5)
| WHERE d_moy = 11
| AND d_year = 1999
| GROUP BY i_brand_id, i_brand, i_manufact_id, i_manufact
| ORDER BY ext_price DESC, i_brand, i_brand_id, i_manufact_id, i_manufact
| LIMIT 100;
|""".stripMargin

val df = spark.sql(testSql)
val sortMergeJoinExec = df.queryExecution.executedPlan.collect {
case s: SortMergeJoinExec => s
}
assert(sortMergeJoinExec.nonEmpty)
}

test("Gluten-1235: Fix missing reading from the broadcasted value when executing DPP") {
val testSql =
"""
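As an interactive cross-check of the test above, one could run the same hinted query in a spark-shell session (a sketch; assumes the TPC-DS tables are registered as in the suite and reuses the testSql string from the test):

import org.apache.spark.sql.execution.joins.SortMergeJoinExec

// Before this change, forceShuffledHashJoin rewrote the hinted sort-merge
// join into a shuffled hash join the ClickHouse backend could not execute;
// with the fix, the unsupported condition keeps the SortMergeJoinExec.
val plan = spark.sql(testSql).queryExecution.executedPlan
assert(plan.collect { case s: SortMergeJoinExec => s }.nonEmpty)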
@@ -81,6 +81,10 @@ object BackendSettings extends BackendSettingsApi {

override def supportSortExec(): Boolean = true

override def supportSortMergeJoinExec(): Boolean = {
GlutenConfig.getConf.enableColumnarSortMergeJoin
}

override def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
var allSupported = true
breakable {
@@ -32,6 +32,7 @@ trait BackendSettingsApi {
paths: Seq[String]): Boolean = false
def supportExpandExec(): Boolean = false
def supportSortExec(): Boolean = false
def supportSortMergeJoinExec(): Boolean = true
def supportWindowExec(windowFunctions: Seq[NamedExpression]): Boolean = {
false
}
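Condensed, the three settings hunks above resolve supportSortMergeJoinExec per backend as follows (smjOffloadEnabled is an illustrative helper for this summary, not a Gluten API):

import io.glutenproject.GlutenConfig

// Illustrative summary of supportSortMergeJoinExec() after this change.
def smjOffloadEnabled(backend: String): Boolean = backend match {
  case "ch"    => false                                            // CH: SMJ is never offloaded
  case "velox" => GlutenConfig.getConf.enableColumnarSortMergeJoin // Velox: keeps the existing flag
  case _       => true                                             // trait default
}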
@@ -270,6 +270,15 @@ trait SparkPlanExecApi {
*/
def extraExpressionMappings: Seq[Sig] = Seq.empty

/**
 * Define whether the join operator should fall back because the backend does
 * not support the given join condition.
 */
def joinFallback(
joinType: JoinType,
leftOutputSet: AttributeSet,
rightOutputSet: AttributeSet,
condition: Option[Expression]): Boolean = false

/**
* default function to generate window function node
*/
@@ -85,6 +85,8 @@ case class JoinSelectionOverrides(session: SparkSession) extends Strategy with
}

if (forceShuffledHashJoin &&
!BackendsApiManager.getSparkPlanExecApiInstance.
joinFallback(joinType, left.outputSet, right.outputSet, condition) &&
!left.getTagValue(TAG).isDefined &&
!right.getTagValue(TAG).isDefined) {
// Force use of ShuffledHashJoin in preference to SortMergeJoin. With no respect to
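Spelled out, the amended guard in JoinSelectionOverrides reads roughly as follows (a paraphrase of the hunk above, not a verbatim copy of the file; canForce is an illustrative name):

// Force ShuffledHashJoin over the planned SortMergeJoin only when the user
// asked for it, the backend can actually execute this join's condition, and
// neither child is already tagged to be left alone.
val canForce = forceShuffledHashJoin &&
  !BackendsApiManager.getSparkPlanExecApiInstance
    .joinFallback(joinType, left.outputSet, right.outputSet, condition) &&
  left.getTagValue(TAG).isEmpty &&
  right.getTagValue(TAG).isEmpty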
@@ -263,7 +263,8 @@ case class AddTransformHintRule() extends Rule[SparkPlan] {
!scanOnly && BackendsApiManager.getSettings.supportColumnarShuffleExec()
val enableColumnarSort: Boolean = !scanOnly && columnarConf.enableColumnarSort
val enableColumnarWindow: Boolean = !scanOnly && columnarConf.enableColumnarWindow
- val enableColumnarSortMergeJoin: Boolean = !scanOnly && columnarConf.enableColumnarSortMergeJoin
+ val enableColumnarSortMergeJoin: Boolean = !scanOnly &&
+ BackendsApiManager.getSettings.supportSortMergeJoinExec()
val enableColumnarBatchScan: Boolean = columnarConf.enableColumnarBatchScan
val enableColumnarFileScan: Boolean = columnarConf.enableColumnarFileScan
val enableColumnarProject: Boolean = !scanOnly && columnarConf.enableColumnarProject