Skip to content

Commit

Permalink
[SPARK-37316][SQL] Add code-gen for existence sort merge join
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR is to add code-gen for Existence sort merge join. Followed the same algorithm used in iterator mode from `SortMergeJoinExec.scala` and `HashJoin.scala`. Check every left side row if there's an existing match from right side. Output every left side row and boolean flag to indicate whether it has a match or not.

In addition, to help review as this PR triggers several TPCDS plan files change. The below file is having the real code change:

* `SortMergeJoinExec.scala`

### Why are the changes needed?

To improve performance for Existence sort merge join, and conclude all join types code-gen support for sort merge join.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing unit test in `ExistenceJoinSuite.scala`

Closes #34601 from c21/existence-join-codegen.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
c21 authored and cloud-fan committed Nov 17, 2021
1 parent 06586f6 commit dfca899
Show file tree
Hide file tree
Showing 7 changed files with 318 additions and 259 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.util.collection.BitSet

/**
Expand Down Expand Up @@ -364,7 +365,7 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | FullOuter =>
case _: InnerLike | LeftOuter | FullOuter | LeftExistence(_) =>
((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
Expand All @@ -375,12 +376,6 @@ case class SortMergeJoinExec(
private lazy val streamedOutput = streamedPlan.output
private lazy val bufferedOutput = bufferedPlan.output

// TODO(SPARK-37316): Add code-gen for existence sort merge join.
override def supportCodegen: Boolean = joinType match {
case _: ExistenceJoin => false
case _ => true
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
streamedPlan.execute() :: bufferedPlan.execute() :: Nil
}
Expand Down Expand Up @@ -455,7 +450,7 @@ case class SortMergeJoinExec(
|$streamedRow = null;
|continue;
""".stripMargin
case LeftOuter | RightOuter | LeftAnti =>
case LeftOuter | RightOuter | LeftAnti | ExistenceJoin(_) =>
// Eagerly return streamed row. Only call `matches.clear()` when `matches.isEmpty()` is
// false, to reduce unnecessary computation.
s"""
Expand All @@ -474,7 +469,7 @@ case class SortMergeJoinExec(
case _: InnerLike | LeftSemi =>
// Skip streamed row.
s"$streamedRow = null;"
case LeftOuter | RightOuter | LeftAnti =>
case LeftOuter | RightOuter | LeftAnti | ExistenceJoin(_) =>
// Eagerly return with streamed row.
"return false;"
case x =>
Expand Down Expand Up @@ -511,17 +506,18 @@ case class SortMergeJoinExec(
// 1. Inner and Left Semi join: skip the row. `matches` will be cleared later when
// hitting the next `streamedRow` with non-null join
// keys.
// 2. Left/Right Outer and Left Anti join: clear the previous `matches` if needed,
// keep the row, and return false.
// 2. Left/Right Outer, Left Anti and Existence join: clear the previous `matches`
// if needed, keep the row, and
// return false.
//
// - Step 2: Find the `matches` from buffered side having same join keys with `streamedRow`.
// Clear `matches` if we hit a new `streamedRow`, as we need to find new matches.
// Use `bufferedRow` to iterate buffered side to put all matched rows into
// `matches` (`addRowToBuffer`). Return true when getting all matched rows.
// For `streamedRow` without `matches` (`handleStreamedWithoutMatch`):
// 1. Inner and Left Semi join: skip the row.
// 2. Left/Right Outer and Left Anti join: keep the row and return false (with
// `matches` being empty).
// 2. Left/Right Outer, Left Anti and Existence join: keep the row and return false
// (with `matches` being empty).
val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
ctx.addNewFunction(findNextJoinRowsFuncName,
s"""
Expand Down Expand Up @@ -666,6 +662,12 @@ case class SortMergeJoinExec(
val setDefaultValue = joinType == LeftOuter || joinType == RightOuter
val bufferedVars = genOneSideJoinVars(ctx, bufferedRow, bufferedPlan, setDefaultValue)

// Create variable name for Existence join.
val existsVar = joinType match {
case ExistenceJoin(_) => Some(ctx.freshName("exists"))
case _ => None
}

val iterator = ctx.freshName("iterator")
val numOutput = metricTerm(ctx, "numOutputRows")
val resultVars = joinType match {
Expand All @@ -675,6 +677,9 @@ case class SortMergeJoinExec(
bufferedVars ++ streamedVars
case LeftSemi | LeftAnti =>
streamedVars
case ExistenceJoin(_) =>
streamedVars ++ Seq(ExprCode.forNonNullValue(
JavaCode.variable(existsVar.get, BooleanType)))
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
Expand Down Expand Up @@ -716,7 +721,7 @@ case class SortMergeJoinExec(
}

val loadBufferedAfterCondition = joinType match {
case LeftSemi | LeftAnti =>
case LeftExistence(_) =>
// No need to evaluate columns not used by condition from buffered side
""
case _ => bufferedAfter
Expand Down Expand Up @@ -767,6 +772,9 @@ case class SortMergeJoinExec(
case LeftAnti =>
codegenAnti(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow, condCheck,
loadStreamed, ctx.freshName("hasMatchedRow"), outputRow, eagerCleanup)
case ExistenceJoin(_) =>
codegenExistence(streamedInput, findNextJoinRows, beforeLoop, iterator, bufferedRow,
condCheck, loadStreamed, existsVar.get, outputRow, eagerCleanup)
case x =>
throw new IllegalArgumentException(
s"SortMergeJoin.doProduce should not take $x as the JoinType")
Expand Down Expand Up @@ -898,6 +906,45 @@ case class SortMergeJoinExec(
""".stripMargin
}

/**
* Generates the code for Existence join.
*/
private def codegenExistence(
streamedInput: String,
findNextJoinRows: String,
beforeLoop: String,
matchIterator: String,
bufferedRow: String,
conditionCheck: String,
loadStreamed: String,
exists: String,
outputRow: String,
eagerCleanup: String): String = {
s"""
|while ($streamedInput.hasNext()) {
| $findNextJoinRows;
| $beforeLoop
| boolean $exists = false;
|
| while (!$exists && $matchIterator.hasNext()) {
| InternalRow $bufferedRow = (InternalRow) $matchIterator.next();
| $conditionCheck
| $exists = true;
| }
|
| if (!$exists) {
| // load all values of streamed row, because the values not in join condition are not
| // loaded yet.
| $loadStreamed
| }
| $outputRow
|
| if (shouldStop()) return;
|}
|$eagerCleanup
""".stripMargin
}

/**
* Generates the code for Full Outer join.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ TakeOrderedAndProject (51)
: +- * BroadcastHashJoin Inner BuildRight (37)
: :- * Project (31)
: : +- * Filter (30)
: : +- SortMergeJoin ExistenceJoin(exists#1) (29)
: : :- SortMergeJoin ExistenceJoin(exists#2) (21)
: : +- * SortMergeJoin ExistenceJoin(exists#1) (29)
: : :- * SortMergeJoin ExistenceJoin(exists#2) (21)
: : : :- * SortMergeJoin LeftSemi (13)
: : : : :- * Sort (5)
: : : : : +- Exchange (4)
Expand Down Expand Up @@ -139,7 +139,7 @@ Arguments: hashpartitioning(ws_bill_customer_sk#12, 5), ENSURE_REQUIREMENTS, [id
Input [1]: [ws_bill_customer_sk#12]
Arguments: [ws_bill_customer_sk#12 ASC NULLS FIRST], false, 0

(21) SortMergeJoin
(21) SortMergeJoin [codegen id : 10]
Left keys [1]: [c_customer_sk#3]
Right keys [1]: [ws_bill_customer_sk#12]
Join condition: None
Expand All @@ -151,39 +151,39 @@ Location: InMemoryFileIndex []
PartitionFilters: [isnotnull(cs_sold_date_sk#17), dynamicpruningexpression(cs_sold_date_sk#17 IN dynamicpruning#9)]
ReadSchema: struct<cs_ship_customer_sk:int>

(23) ColumnarToRow [codegen id : 11]
(23) ColumnarToRow [codegen id : 12]
Input [2]: [cs_ship_customer_sk#16, cs_sold_date_sk#17]

(24) ReusedExchange [Reuses operator id: 56]
Output [1]: [d_date_sk#18]

(25) BroadcastHashJoin [codegen id : 11]
(25) BroadcastHashJoin [codegen id : 12]
Left keys [1]: [cs_sold_date_sk#17]
Right keys [1]: [d_date_sk#18]
Join condition: None

(26) Project [codegen id : 11]
(26) Project [codegen id : 12]
Output [1]: [cs_ship_customer_sk#16]
Input [3]: [cs_ship_customer_sk#16, cs_sold_date_sk#17, d_date_sk#18]

(27) Exchange
Input [1]: [cs_ship_customer_sk#16]
Arguments: hashpartitioning(cs_ship_customer_sk#16, 5), ENSURE_REQUIREMENTS, [id=#19]

(28) Sort [codegen id : 12]
(28) Sort [codegen id : 13]
Input [1]: [cs_ship_customer_sk#16]
Arguments: [cs_ship_customer_sk#16 ASC NULLS FIRST], false, 0

(29) SortMergeJoin
(29) SortMergeJoin [codegen id : 15]
Left keys [1]: [c_customer_sk#3]
Right keys [1]: [cs_ship_customer_sk#16]
Join condition: None

(30) Filter [codegen id : 14]
(30) Filter [codegen id : 15]
Input [5]: [c_customer_sk#3, c_current_cdemo_sk#4, c_current_addr_sk#5, exists#2, exists#1]
Condition : (exists#2 OR exists#1)

(31) Project [codegen id : 14]
(31) Project [codegen id : 15]
Output [2]: [c_current_cdemo_sk#4, c_current_addr_sk#5]
Input [5]: [c_customer_sk#3, c_current_cdemo_sk#4, c_current_addr_sk#5, exists#2, exists#1]

Expand All @@ -194,35 +194,35 @@ Location [not included in comparison]/{warehouse_dir}/customer_address]
PushedFilters: [In(ca_county, [Dona Ana County,Jefferson County,La Porte County,Rush County,Toole County]), IsNotNull(ca_address_sk)]
ReadSchema: struct<ca_address_sk:int,ca_county:string>

(33) ColumnarToRow [codegen id : 13]
(33) ColumnarToRow [codegen id : 14]
Input [2]: [ca_address_sk#20, ca_county#21]

(34) Filter [codegen id : 13]
(34) Filter [codegen id : 14]
Input [2]: [ca_address_sk#20, ca_county#21]
Condition : (ca_county#21 IN (Rush County,Toole County,Jefferson County,Dona Ana County,La Porte County) AND isnotnull(ca_address_sk#20))

(35) Project [codegen id : 13]
(35) Project [codegen id : 14]
Output [1]: [ca_address_sk#20]
Input [2]: [ca_address_sk#20, ca_county#21]

(36) BroadcastExchange
Input [1]: [ca_address_sk#20]
Arguments: HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#22]

(37) BroadcastHashJoin [codegen id : 14]
(37) BroadcastHashJoin [codegen id : 15]
Left keys [1]: [c_current_addr_sk#5]
Right keys [1]: [ca_address_sk#20]
Join condition: None

(38) Project [codegen id : 14]
(38) Project [codegen id : 15]
Output [1]: [c_current_cdemo_sk#4]
Input [3]: [c_current_cdemo_sk#4, c_current_addr_sk#5, ca_address_sk#20]

(39) Exchange
Input [1]: [c_current_cdemo_sk#4]
Arguments: hashpartitioning(c_current_cdemo_sk#4, 5), ENSURE_REQUIREMENTS, [id=#23]

(40) Sort [codegen id : 15]
(40) Sort [codegen id : 16]
Input [1]: [c_current_cdemo_sk#4]
Arguments: [c_current_cdemo_sk#4 ASC NULLS FIRST], false, 0

Expand All @@ -233,31 +233,31 @@ Location [not included in comparison]/{warehouse_dir}/customer_demographics]
PushedFilters: [IsNotNull(cd_demo_sk)]
ReadSchema: struct<cd_demo_sk:int,cd_gender:string,cd_marital_status:string,cd_education_status:string,cd_purchase_estimate:int,cd_credit_rating:string,cd_dep_count:int,cd_dep_employed_count:int,cd_dep_college_count:int>

(42) ColumnarToRow [codegen id : 16]
(42) ColumnarToRow [codegen id : 17]
Input [9]: [cd_demo_sk#24, cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]

(43) Filter [codegen id : 16]
(43) Filter [codegen id : 17]
Input [9]: [cd_demo_sk#24, cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Condition : isnotnull(cd_demo_sk#24)

(44) Exchange
Input [9]: [cd_demo_sk#24, cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Arguments: hashpartitioning(cd_demo_sk#24, 5), ENSURE_REQUIREMENTS, [id=#33]

(45) Sort [codegen id : 17]
(45) Sort [codegen id : 18]
Input [9]: [cd_demo_sk#24, cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Arguments: [cd_demo_sk#24 ASC NULLS FIRST], false, 0

(46) SortMergeJoin [codegen id : 18]
(46) SortMergeJoin [codegen id : 19]
Left keys [1]: [c_current_cdemo_sk#4]
Right keys [1]: [cd_demo_sk#24]
Join condition: None

(47) Project [codegen id : 18]
(47) Project [codegen id : 19]
Output [8]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Input [10]: [c_current_cdemo_sk#4, cd_demo_sk#24, cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]

(48) HashAggregate [codegen id : 18]
(48) HashAggregate [codegen id : 19]
Input [8]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Keys [8]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Functions [1]: [partial_count(1)]
Expand All @@ -268,7 +268,7 @@ Results [9]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_pur
Input [9]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32, count#35]
Arguments: hashpartitioning(cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32, 5), ENSURE_REQUIREMENTS, [id=#36]

(50) HashAggregate [codegen id : 19]
(50) HashAggregate [codegen id : 20]
Input [9]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32, count#35]
Keys [8]: [cd_gender#25, cd_marital_status#26, cd_education_status#27, cd_purchase_estimate#28, cd_credit_rating#29, cd_dep_count#30, cd_dep_employed_count#31, cd_dep_college_count#32]
Functions [1]: [count(1)]
Expand Down
Loading

0 comments on commit dfca899

Please sign in to comment.