Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-34620][SQL] Code-gen broadcast nested loop join (inner/cross) #31736

Closed
wants to merge 5 commits into from

Conversation

c21
Copy link
Contributor

@c21 c21 commented Mar 4, 2021

What changes were proposed in this pull request?

BroadcastNestedLoopJoinExec does not have code-gen, and we can potentially boost the CPU performance for this operator if we add code-gen for it. https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html also showed the evidence in one fork.

The codegen for BroadcastNestedLoopJoinExec shared some code with HashJoin, and the interface JoinCodegenSupport is created to hold those common logic. This PR is only supporting inner and cross join. Other join types will be added later in followup PRs.

Example query and generated code:

val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 =!= $"k2").explain("codegen")
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, Inner, NOT ((k1#2L + 1) = k2#6L)
:- *(2) Project [id#0L AS k1#2L]
:  +- *(2) Range (0, 4, step=1, splits=2)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
   +- *(1) Project [id#4L AS k2#6L]
      +- *(1) Range (0, 3, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 038 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 039 */
/* 040 */       long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 041 */
/* 042 */       long bnlj_value_4 = -1L;
/* 043 */
/* 044 */       bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 045 */
/* 046 */       boolean bnlj_value_3 = false;
/* 047 */       bnlj_value_3 = bnlj_value_4 == bnlj_value_1;
/* 048 */       boolean bnlj_value_2 = false;
/* 049 */       bnlj_value_2 = !(bnlj_value_3);
/* 050 */       if (!(false || !bnlj_value_2))
/* 051 */       {
/* 052 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 053 */
/* 054 */         range_mutableStateArray_0[3].reset();
/* 055 */
/* 056 */         range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 057 */
/* 058 */         range_mutableStateArray_0[3].write(1, bnlj_value_1);
/* 059 */         append((range_mutableStateArray_0[3].getRow()).copy());
/* 060 */
/* 061 */       }
/* 062 */     }
/* 063 */
/* 064 */   }
/* 065 */
/* 066 */   private void initRange(int idx) {
/* 067 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 068 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 069 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 070 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 071 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 072 */     long partitionEnd;
/* 073 */
/* 074 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 075 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 076 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 077 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 078 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 079 */     } else {
/* 080 */       range_nextIndex_0 = st.longValue();
/* 081 */     }
/* 082 */     range_batchEnd_0 = range_nextIndex_0;
/* 083 */
/* 084 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 085 */     .multiply(step).add(start);
/* 086 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 087 */       partitionEnd = Long.MAX_VALUE;
/* 088 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 089 */       partitionEnd = Long.MIN_VALUE;
/* 090 */     } else {
/* 091 */       partitionEnd = end.longValue();
/* 092 */     }
/* 093 */
/* 094 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 095 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 096 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 097 */     if (range_numElementsTodo_0 < 0) {
/* 098 */       range_numElementsTodo_0 = 0;
/* 099 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 100 */       range_numElementsTodo_0++;
/* 101 */     }
/* 102 */   }
/* 103 */
/* 104 */   protected void processNext() throws java.io.IOException {
/* 105 */     // initialize Range
/* 106 */     if (!range_initRange_0) {
/* 107 */       range_initRange_0 = true;
/* 108 */       initRange(partitionIndex);
/* 109 */     }
/* 110 */
/* 111 */     while (true) {
/* 112 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 113 */         long range_nextBatchTodo_0;
/* 114 */         if (range_numElementsTodo_0 > 1000L) {
/* 115 */           range_nextBatchTodo_0 = 1000L;
/* 116 */           range_numElementsTodo_0 -= 1000L;
/* 117 */         } else {
/* 118 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 119 */           range_numElementsTodo_0 = 0;
/* 120 */           if (range_nextBatchTodo_0 == 0) break;
/* 121 */         }
/* 122 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 123 */       }
/* 124 */
/* 125 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 126 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 127 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 128 */
/* 129 */         // common sub-expressions
/* 130 */
/* 131 */         bnlj_doConsume_0(range_value_0);
/* 132 */
/* 133 */         if (shouldStop()) {
/* 134 */           range_nextIndex_0 = range_value_0 + 1L;
/* 135 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 136 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 137 */           return;
/* 138 */         }
/* 139 */
/* 140 */       }
/* 141 */       range_nextIndex_0 = range_batchEnd_0;
/* 142 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 143 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 144 */       range_taskContext_0.killTaskIfInterrupted();
/* 145 */     }
/* 146 */   }
/* 147 */
/* 148 */ }

Why are the changes needed?

Improve query CPU performance. Added a micro benchmark query in JoinBenchmark.scala.
Saw 1x of run time improvement:

OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
broadcast nested loop join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------
broadcast nested loop join wholestage off          62922          63052         184          0.3        3000.3       1.0X
broadcast nested loop join wholestage on           30946          30972          26          0.7        1475.6       2.0X

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Amazon AWS EC2
type: r3.xlarge
region: us-west-2 (Oregon)
OS: Linux

* Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
* and Left Anti joins.
*/
protected def getJoinCondition(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copied from HashJoin.scala: getJoinCondition().

/**
* Generates the code for variable of build side.
*/
protected def genBuildSideVars(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is copied from HashJoin.scala: genBuildSideVars().

@c21
Copy link
Contributor Author

c21 commented Mar 4, 2021

cc @cloud-fan and @maropu if you have time to take a look, thanks.

@github-actions github-actions bot added the SQL label Mar 4, 2021
@@ -178,6 +191,7 @@ object JoinBenchmark extends SqlBasedBenchmark {
sortMergeJoin()
sortMergeJoinWithDuplicates()
shuffleHashJoin()
broadcastNestedLoopJoin()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we update the benchmark results(JoinBenchmark-results.txt and JoinBenchmark-jdk11-results.txt)?

Generate result:

SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.JoinBenchmark"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wangyum - sure, I noticed the original benchmark was running on OpenJDK 64-Bit Server VM 1.8.0_232-8u232-b09-0ubuntu1~18.04.1-b09 on Linux 4.15.0-1044-aws in result file. Should I try to run benchmark in the same type of machine? or this is not a hard requirement. cc @maropu as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wangyum, @maropu - updated the join benchmark results with same type machines as previous PRs. The description of PR has more details, thanks.

val M = 1 << 4

val dim = broadcast(spark.range(M).selectExpr("id as k", "cast(id as string) as v"))
codegenBenchmark("broadcast nested loop join", N) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you update benchmarks/JoinBenchmark-results.txt, too?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - sure, replied back above.

assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : BroadcastNestedLoopJoinExec) => true
}.size === 1)
checkAnswer(oneJoinDF,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For better test coverage, could you run tests with codegen enabled/disabled?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - sure, updated.

streamed.asInstanceOf[CodegenSupport].inputRDDs()
}

override def needCopyResult: Boolean = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

always true? Is there a trade-off between the overheads of uniqueness checks and result copys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - in case of inner/cross broadcast nested loop join, one input row can potentially have multiple output rows, so I am following the comment here to set it true. btw sort merge join has same behavior. I think if we want to improve it, I can do in another followup PR. I am not very familiar with this part and I can check closer later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a hash join and we don't know the uniqueness. I think it needs to be always true.

case _: InnerLike => codegenInner(ctx, input)
case x =>
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin code-gen should not take $x as the JoinType")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

      case _ =>
        throw new IllegalArgumentException(
          s"BroadcastNestedLoopJoin code-gen should not take $joinType as the JoinType")

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - updated.

}

private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val arrayTerm = prepareBroadcast(ctx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: arrayTerm -> buildRowArrayTerm, buildRowsTerm, ...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - updated to buildRowArrayTerm.

val broadcastTerm = ctx.addReferenceObj("broadcastTerm", broadcastArray)

// Inline mutable state since not many join operations in a task
ctx.addMutableState("InternalRow[]", "broadcastArray",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: broadcastArray -> buildRows, buidlRowArray, ...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - updated to buildRowArray.

@SparkQA
Copy link

SparkQA commented Mar 4, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40336/

@SparkQA
Copy link

SparkQA commented Mar 4, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40336/

@SparkQA
Copy link

SparkQA commented Mar 5, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40384/


/**
* Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
* and Left Anti joins.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it's in a base trait now, can we add more doc to explain the return type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - makes sense, updated.

@SparkQA
Copy link

SparkQA commented Mar 5, 2021

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40384/

s"""
|int $arrayIndex = 0;
|UnsafeRow $buildRow;
|while ($arrayIndex < $buildRowArrayTerm.length) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make it more java style

for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) {
  UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex];
  ...
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remembered some places in one PR that I saw while is preferred over for (or my memory can be wrong and it's opposite). @cloud-fan - do you remember related discussion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while is preferred over for in Scala. I don't think it's true for Java, at least from my experience of Java development in the past.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - updated.

@c21
Copy link
Contributor Author

c21 commented Mar 6, 2021

Addressed all comments, and the PR is ready to review again, thanks @cloud-fan , @maropu and @wangyum .

@SparkQA
Copy link

SparkQA commented Mar 6, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40408/

@SparkQA
Copy link

SparkQA commented Mar 6, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40408/

* Generate the (non-equi) condition used to filter joined rows.
* This is used in Inner, Left Semi and Left Anti joins.
*
* @return Variable name for row of build side.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: How about this instead?

   * @return Tuple of variable name for row of build side, generated code for condition,
   *         and generated code for variables of build side.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu - sure, updated.

@SparkQA
Copy link

SparkQA commented Mar 6, 2021

Test build #135826 has finished for PR 31736 at commit 4e5341f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Mar 7, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40429/

@SparkQA
Copy link

SparkQA commented Mar 7, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40429/

@SparkQA
Copy link

SparkQA commented Mar 8, 2021

Test build #135847 has finished for PR 31736 at commit 4d26e49.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

*/
protected def getJoinCondition(
ctx: CodegenContext,
input: Seq[ExprCode],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: streamVars?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cloud-fan - updated.

@c21
Copy link
Contributor Author

c21 commented Mar 9, 2021

Addressed all comments for now, and the PR is ready for check again, thanks.

@c21 c21 closed this Mar 9, 2021
@c21 c21 reopened this Mar 9, 2021
@c21
Copy link
Contributor Author

c21 commented Mar 9, 2021

Close & reopen PR to trigger test rerun as some transient unit test failure happened.

@SparkQA
Copy link

SparkQA commented Mar 9, 2021

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40483/

@SparkQA
Copy link

SparkQA commented Mar 9, 2021

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/40483/

@SparkQA
Copy link

SparkQA commented Mar 9, 2021

Test build #135901 has finished for PR 31736 at commit 6f6d511.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@cloud-fan
Copy link
Contributor

thanks, merging to master!

@cloud-fan cloud-fan closed this in b5b1985 Mar 9, 2021
@c21
Copy link
Contributor Author

c21 commented Mar 9, 2021

Thank you @cloud-fan , @maropu , @viirya and @wangyum for review!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
6 participants