Skip to content

Commit

Permalink
[SPARK-35352][SQL] Add code-gen for full outer sort merge join
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR is to add code-gen for FULL OUTER sort merge join. The change is in `SortMergeJoinExec.scala:codegenFullOuter()`. Followed the same algorithm in iterator mode - `SortMergeFullOuterJoinScanner`: maintain buffer for join left and right sides, and iterate over matched rows in buffers.

Example query:

```
val df1 = spark.range(5).select($"id".as("k1"))
val df2 = spark.range(10).select($"id".as("k2"))
df1.join(df2.hint(hint), $"k1" === $"k2" % 3 && $"k1" + 3 =!= $"k2", "full_outer")
```

Example generated code: https://gist.github.com/c21/5cab9751f24ae448d77a259d28cb77d7

In addition, to help review as this PR triggers several TPCDS plan files change. The below files are having the real code change:

* `SortMergeJoinExec.scala`
* `WholeStageCodegenSuite.scala`

All other files are auto-generated golden file plan changes for TPCDS queries.

### Why are the changes needed?

Improve the run-time/CPU performance of FULL OUTER sort merge join.

Micro benchmark (same query in `JoinBenchmark.scala`):

```
  def sortMergeJoin(): Unit = {
    val N = 2 << 20
    codegenBenchmark("sort merge join", N) {
      val df1 = spark.range(N).selectExpr(s"id * 2 as k1")
      val df2 = spark.range(N).selectExpr(s"id * 3 as k2")
      val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
      assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
      df.noop()
    }
  }

  def sortMergeJoinWithDuplicates(): Unit = {
    val N = 2 << 20
    codegenBenchmark("sort merge join with duplicates", N) {
      val df1 = spark.range(N)
        .selectExpr(s"(id * 15485863) % ${N*10} as k1")
      val df2 = spark.range(N)
        .selectExpr(s"(id * 15485867) % ${N*10} as k2")
      val df = df1.join(df2, col("k1") === col("k2"), "full_outer")
      assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
      df.noop()
    }
  }
```

Seeing 20-30% of run-time improvement:

```
Running benchmark: sort merge join
  Running case: sort merge join wholestage off
  Stopped after 2 iterations, 2979 ms
  Running case: sort merge join wholestage on
  Stopped after 5 iterations, 5849 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
sort merge join:                          Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
sort merge join wholestage off                     1453           1490          52          1.4         693.0       1.0X
sort merge join wholestage on                      1115           1170          43          1.9         531.6       1.3X

Running benchmark: sort merge join with duplicates
  Running case: sort merge join with duplicates wholestage off
  Stopped after 2 iterations, 3236 ms
  Running case: sort merge join with duplicates wholestage on
  Stopped after 5 iterations, 6768 ms

Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.16
Intel(R) Core(TM) i9-9980HK CPU  2.40GHz
sort merge join with duplicates:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------------
sort merge join with duplicates wholestage off           1609           1618          13          1.3         767.2       1.0X
sort merge join with duplicates wholestage on            1330           1354          24          1.6         634.4       1.2X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

* Added unit test in `WholeStageCodegenSuite.scala`.
* Existing unit test in `OuterJoinSuite.scala`.

Closes #34581 from c21/smj-codegen.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
c21 authored and cloud-fan committed Nov 15, 2021
1 parent 3da5608 commit 2ef60f7
Show file tree
Hide file tree
Showing 14 changed files with 327 additions and 65 deletions.
Expand Up @@ -364,7 +364,8 @@ case class SortMergeJoinExec(
}

private lazy val ((streamedPlan, streamedKeys), (bufferedPlan, bufferedKeys)) = joinType match {
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti => ((left, leftKeys), (right, rightKeys))
case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | FullOuter =>
((left, leftKeys), (right, rightKeys))
case RightOuter => ((right, rightKeys), (left, leftKeys))
case x =>
throw new IllegalArgumentException(
Expand All @@ -374,9 +375,10 @@ case class SortMergeJoinExec(
private lazy val streamedOutput = streamedPlan.output
private lazy val bufferedOutput = bufferedPlan.output

// TODO(SPARK-37316): Add code-gen for existence sort merge join.
override def supportCodegen: Boolean = joinType match {
case _: InnerLike | LeftOuter | RightOuter | LeftSemi | LeftAnti => true
case _ => false
case _: ExistenceJoin => false
case _ => true
}

override def inputRDDs(): Seq[RDD[InternalRow]] = {
Expand Down Expand Up @@ -644,6 +646,12 @@ case class SortMergeJoinExec(
override def needCopyResult: Boolean = true

override def doProduce(ctx: CodegenContext): String = {
// Specialize `doProduce` code for full outer join, because full outer join needs to
// buffer both sides of join.
if (joinType == FullOuter) {
return codegenFullOuter(ctx)
}

// Inline mutable state since not many join operations in a task
val streamedInput = ctx.addMutableState("scala.collection.Iterator", "streamedInput",
v => s"$v = inputs[0];", forceInline = true)
Expand Down Expand Up @@ -890,6 +898,248 @@ case class SortMergeJoinExec(
""".stripMargin
}

/**
* Generates the code for Full Outer join.
*/
private def codegenFullOuter(ctx: CodegenContext): String = {
// Inline mutable state since not many join operations in a task.
// Create class member for input iterator from both sides.
val leftInput = ctx.addMutableState("scala.collection.Iterator", "leftInput",
v => s"$v = inputs[0];", forceInline = true)
val rightInput = ctx.addMutableState("scala.collection.Iterator", "rightInput",
v => s"$v = inputs[1];", forceInline = true)

// Create class member for next input row from both sides.
val leftInputRow = ctx.addMutableState("InternalRow", "leftInputRow", forceInline = true)
val rightInputRow = ctx.addMutableState("InternalRow", "rightInputRow", forceInline = true)

// Create variables for join keys from both sides.
val leftKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
val rightKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)
val rightAnyNull = rightKeyVars.map(_.isNull).mkString(" || ")
val matchedKeyVars = copyKeys(ctx, leftKeyVars)
val leftMatchedKeyVars = createJoinKey(ctx, leftInputRow, leftKeys, left.output)
val rightMatchedKeyVars = createJoinKey(ctx, rightInputRow, rightKeys, right.output)

// Create class member for next output row from both sides.
val leftOutputRow = ctx.addMutableState("InternalRow", "leftOutputRow", forceInline = true)
val rightOutputRow = ctx.addMutableState("InternalRow", "rightOutputRow", forceInline = true)

// Create class member for buffers of rows with same join keys from both sides.
val bufferClsName = "java.util.ArrayList<InternalRow>"
val leftBuffer = ctx.addMutableState(bufferClsName, "leftBuffer",
v => s"$v = new $bufferClsName();", forceInline = true)
val rightBuffer = ctx.addMutableState(bufferClsName, "rightBuffer",
v => s"$v = new $bufferClsName();", forceInline = true)
val matchedClsName = classOf[BitSet].getName
val leftMatched = ctx.addMutableState(matchedClsName, "leftMatched",
v => s"$v = new $matchedClsName(1);", forceInline = true)
val rightMatched = ctx.addMutableState(matchedClsName, "rightMatched",
v => s"$v = new $matchedClsName(1);", forceInline = true)
val leftIndex = ctx.freshName("leftIndex")
val rightIndex = ctx.freshName("rightIndex")

// Generate code for join condition
val leftResultVars = genOneSideJoinVars(
ctx, leftOutputRow, left, setDefaultValue = true)
val rightResultVars = genOneSideJoinVars(
ctx, rightOutputRow, right, setDefaultValue = true)
val resultVars = leftResultVars ++ rightResultVars
val (_, conditionCheck, _) =
getJoinCondition(ctx, leftResultVars, left, right, Some(rightOutputRow))

// Generate code for result output in separate function, as we need to output result from
// multiple places in join code.
val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
ctx.addNewFunction(consumeFullOuterJoinRow,
s"""
|private void $consumeFullOuterJoinRow() throws java.io.IOException {
| ${metricTerm(ctx, "numOutputRows")}.add(1);
| ${consume(ctx, resultVars)}
|}
""".stripMargin)

// Handle the case when input row has no match.
val outputLeftNoMatch =
s"""
|$leftOutputRow = $leftInputRow;
|$rightOutputRow = null;
|$leftInputRow = null;
|$consumeFullOuterJoinRow();
""".stripMargin
val outputRightNoMatch =
s"""
|$rightOutputRow = $rightInputRow;
|$leftOutputRow = null;
|$rightInputRow = null;
|$consumeFullOuterJoinRow();
""".stripMargin

// Generate a function to scan both sides to find rows with matched join keys.
// The matched rows from both sides are copied in buffers separately. This function assumes
// either non-empty `leftIter` and `rightIter`, or non-null `leftInputRow` and `rightInputRow`.
//
// The function has the following steps:
// - Step 1: Find the next `leftInputRow` and `rightInputRow` with non-null join keys.
// Output row with null join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
//
// - Step 2: Compare and find next same join keys from between `leftInputRow` and
// `rightInputRow`.
// Output row with smaller join keys (`outputLeftNoMatch` and `outputRightNoMatch`).
//
// - Step 3: Buffer rows with same join keys from both sides into `leftBuffer` and
// `rightBuffer`. Reset bit sets for both buffers accordingly (`leftMatched` and
// `rightMatched`).
val findNextJoinRowsFuncName = ctx.freshName("findNextJoinRows")
ctx.addNewFunction(findNextJoinRowsFuncName,
s"""
|private void $findNextJoinRowsFuncName(
| scala.collection.Iterator leftIter,
| scala.collection.Iterator rightIter) throws java.io.IOException {
| int comp = 0;
| $leftBuffer.clear();
| $rightBuffer.clear();
|
| if ($leftInputRow == null) {
| $leftInputRow = (InternalRow) leftIter.next();
| }
| if ($rightInputRow == null) {
| $rightInputRow = (InternalRow) rightIter.next();
| }
|
| ${leftKeyVars.map(_.code).mkString("\n")}
| if ($leftAnyNull) {
| // The left row join key is null, join it with null row
| $outputLeftNoMatch
| return;
| }
|
| ${rightKeyVars.map(_.code).mkString("\n")}
| if ($rightAnyNull) {
| // The right row join key is null, join it with null row
| $outputRightNoMatch
| return;
| }
|
| ${genComparison(ctx, leftKeyVars, rightKeyVars)}
| if (comp < 0) {
| // The left row join key is smaller, join it with null row
| $outputLeftNoMatch
| return;
| } else if (comp > 0) {
| // The right row join key is smaller, join it with null row
| $outputRightNoMatch
| return;
| }
|
| ${matchedKeyVars.map(_.code).mkString("\n")}
| $leftBuffer.add($leftInputRow.copy());
| $rightBuffer.add($rightInputRow.copy());
| $leftInputRow = null;
| $rightInputRow = null;
|
| // Buffer rows from both sides with same join key
| while (leftIter.hasNext()) {
| $leftInputRow = (InternalRow) leftIter.next();
| ${leftMatchedKeyVars.map(_.code).mkString("\n")}
| ${genComparison(ctx, leftMatchedKeyVars, matchedKeyVars)}
| if (comp == 0) {
|
| $leftBuffer.add($leftInputRow.copy());
| $leftInputRow = null;
| } else {
| break;
| }
| }
| while (rightIter.hasNext()) {
| $rightInputRow = (InternalRow) rightIter.next();
| ${rightMatchedKeyVars.map(_.code).mkString("\n")}
| ${genComparison(ctx, rightMatchedKeyVars, matchedKeyVars)}
| if (comp == 0) {
| $rightBuffer.add($rightInputRow.copy());
| $rightInputRow = null;
| } else {
| break;
| }
| }
|
| // Reset bit sets of buffers accordingly
| if ($leftBuffer.size() <= $leftMatched.capacity()) {
| $leftMatched.clearUntil($leftBuffer.size());
| } else {
| $leftMatched = new $matchedClsName($leftBuffer.size());
| }
| if ($rightBuffer.size() <= $rightMatched.capacity()) {
| $rightMatched.clearUntil($rightBuffer.size());
| } else {
| $rightMatched = new $matchedClsName($rightBuffer.size());
| }
|}
""".stripMargin)

// Scan the left and right buffers to find all matched rows.
val matchRowsInBuffer =
s"""
|int $leftIndex;
|int $rightIndex;
|
|for ($leftIndex = 0; $leftIndex < $leftBuffer.size(); $leftIndex++) {
| $leftOutputRow = (InternalRow) $leftBuffer.get($leftIndex);
| for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
| $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
| $conditionCheck {
| $consumeFullOuterJoinRow();
| $leftMatched.set($leftIndex);
| $rightMatched.set($rightIndex);
| }
| }
|
| if (!$leftMatched.get($leftIndex)) {
|
| $rightOutputRow = null;
| $consumeFullOuterJoinRow();
| }
|}
|
|$leftOutputRow = null;
|for ($rightIndex = 0; $rightIndex < $rightBuffer.size(); $rightIndex++) {
| if (!$rightMatched.get($rightIndex)) {
| // The right row has never matched any left row, join it with null row
| $rightOutputRow = (InternalRow) $rightBuffer.get($rightIndex);
| $consumeFullOuterJoinRow();
| }
|}
""".stripMargin

s"""
|while (($leftInputRow != null || $leftInput.hasNext()) &&
| ($rightInputRow != null || $rightInput.hasNext())) {
| $findNextJoinRowsFuncName($leftInput, $rightInput);
| $matchRowsInBuffer
| if (shouldStop()) return;
|}
|
|// The right iterator has no more rows, join left row with null
|while ($leftInputRow != null || $leftInput.hasNext()) {
| if ($leftInputRow == null) {
| $leftInputRow = (InternalRow) $leftInput.next();
| }
| $outputLeftNoMatch
| if (shouldStop()) return;
|}
|
|// The left iterator has no more rows, join right row with null
|while ($rightInputRow != null || $rightInput.hasNext()) {
| if ($rightInputRow == null) {
| $rightInputRow = (InternalRow) $rightInput.next();
| }
| $outputRightNoMatch
| if (shouldStop()) return;
|}
""".stripMargin
}

override protected def withNewChildrenInternal(
newLeft: SparkPlan, newRight: SparkPlan): SortMergeJoinExec =
copy(left = newLeft, right = newRight)
Expand Down
Expand Up @@ -5,7 +5,7 @@ TakeOrderedAndProject (37)
+- * Sort (34)
+- Exchange (33)
+- * Project (32)
+- SortMergeJoin FullOuter (31)
+- * SortMergeJoin FullOuter (31)
:- * Sort (15)
: +- Exchange (14)
: +- * Project (13)
Expand Down Expand Up @@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id=
Input [3]: [item_sk#25, d_date#20, cume_sales#28]
Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0

(31) SortMergeJoin
(31) SortMergeJoin [codegen id : 13]
Left keys [2]: [item_sk#11, d_date#6]
Right keys [2]: [item_sk#25, d_date#20]
Join condition: None
Expand Down
Expand Up @@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Exchange [item_sk] #1
WholeStageCodegen (13)
Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
InputAdapter
SortMergeJoin [item_sk,d_date,item_sk,d_date]
SortMergeJoin [item_sk,d_date,item_sk,d_date]
InputAdapter
WholeStageCodegen (6)
Sort [item_sk,d_date]
InputAdapter
Expand Down Expand Up @@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk,d_date] #5
InputAdapter
WholeStageCodegen (12)
Sort [item_sk,d_date]
InputAdapter
Expand Down
Expand Up @@ -5,7 +5,7 @@ TakeOrderedAndProject (37)
+- * Sort (34)
+- Exchange (33)
+- * Project (32)
+- SortMergeJoin FullOuter (31)
+- * SortMergeJoin FullOuter (31)
:- * Sort (15)
: +- Exchange (14)
: +- * Project (13)
Expand Down Expand Up @@ -176,7 +176,7 @@ Arguments: hashpartitioning(item_sk#25, d_date#20, 5), ENSURE_REQUIREMENTS, [id=
Input [3]: [item_sk#25, d_date#20, cume_sales#28]
Arguments: [item_sk#25 ASC NULLS FIRST, d_date#20 ASC NULLS FIRST], false, 0

(31) SortMergeJoin
(31) SortMergeJoin [codegen id : 13]
Left keys [2]: [item_sk#11, d_date#6]
Right keys [2]: [item_sk#25, d_date#20]
Join condition: None
Expand Down
Expand Up @@ -9,8 +9,8 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Exchange [item_sk] #1
WholeStageCodegen (13)
Project [item_sk,item_sk,d_date,d_date,cume_sales,cume_sales]
InputAdapter
SortMergeJoin [item_sk,d_date,item_sk,d_date]
SortMergeJoin [item_sk,d_date,item_sk,d_date]
InputAdapter
WholeStageCodegen (6)
Sort [item_sk,d_date]
InputAdapter
Expand Down Expand Up @@ -45,6 +45,7 @@ TakeOrderedAndProject [item_sk,d_date,web_sales,store_sales,web_cumulative,store
Scan parquet default.date_dim [d_date_sk,d_date,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk,d_date] #5
InputAdapter
WholeStageCodegen (12)
Sort [item_sk,d_date]
InputAdapter
Expand Down
Expand Up @@ -3,7 +3,7 @@
+- Exchange (22)
+- * HashAggregate (21)
+- * Project (20)
+- SortMergeJoin FullOuter (19)
+- * SortMergeJoin FullOuter (19)
:- * Sort (9)
: +- * HashAggregate (8)
: +- Exchange (7)
Expand Down Expand Up @@ -112,7 +112,7 @@ Results [2]: [cs_bill_customer_sk#9 AS customer_sk#14, cs_item_sk#10 AS item_sk#
Input [2]: [customer_sk#14, item_sk#15]
Arguments: [customer_sk#14 ASC NULLS FIRST, item_sk#15 ASC NULLS FIRST], false, 0

(19) SortMergeJoin
(19) SortMergeJoin [codegen id : 7]
Left keys [2]: [customer_sk#7, item_sk#8]
Right keys [2]: [customer_sk#14, item_sk#15]
Join condition: None
Expand Down
Expand Up @@ -5,8 +5,8 @@ WholeStageCodegen (8)
WholeStageCodegen (7)
HashAggregate [customer_sk,customer_sk] [sum,sum,sum,sum,sum,sum]
Project [customer_sk,customer_sk]
InputAdapter
SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
SortMergeJoin [customer_sk,item_sk,customer_sk,item_sk]
InputAdapter
WholeStageCodegen (3)
Sort [customer_sk,item_sk]
HashAggregate [ss_customer_sk,ss_item_sk] [customer_sk,item_sk]
Expand All @@ -29,6 +29,7 @@ WholeStageCodegen (8)
Scan parquet default.date_dim [d_date_sk,d_month_seq]
InputAdapter
ReusedExchange [d_date_sk] #3
InputAdapter
WholeStageCodegen (6)
Sort [customer_sk,item_sk]
HashAggregate [cs_bill_customer_sk,cs_item_sk] [customer_sk,item_sk]
Expand Down

0 comments on commit 2ef60f7

Please sign in to comment.