[SPARK-34620][SQL] Code-gen broadcast nested loop join (inner/cross)
### What changes were proposed in this pull request?

`BroadcastNestedLoopJoinExec` does not support code-gen, and adding code-gen for it can potentially boost the CPU performance of this operator. https://databricks.com/blog/2017/02/16/processing-trillion-rows-per-second-single-machine-can-nested-loop-joins-fast.html shows evidence of such a speedup in one Spark fork.

The codegen for `BroadcastNestedLoopJoinExec` shares some logic with `HashJoin`, so a new interface, `JoinCodegenSupport`, is introduced to hold the common logic. This PR supports only inner and cross joins; other join types will be added in follow-up PRs.

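The shared interface might look roughly like the sketch below (the method name, signature, and body are illustrative assumptions, not the exact trait added by this PR):

```scala
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}

// Illustrative sketch only: a trait holding codegen logic shared between
// HashJoin and BroadcastNestedLoopJoinExec.
trait JoinCodegenSupport extends CodegenSupport {

  // Bind the build-side output attributes to a concrete row variable in the
  // generated code, so the join condition and result projection can read
  // build-side columns (e.g. the `bnlj_buildRow_0.getLong(0)` call in the
  // generated code below).
  protected def genBuildSideVars(
      ctx: CodegenContext,
      buildRow: String,
      buildPlan: SparkPlan): Seq[ExprCode] = {
    ctx.currentVars = null
    ctx.INPUT_ROW = buildRow
    buildPlan.output.zipWithIndex.map { case (attr, i) =>
      BoundReference(i, attr.dataType, attr.nullable).genCode(ctx)
    }
  }
}
```
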
Example query and generated code:

```
val df1 = spark.range(4).select($"id".as("k1"))
val df2 = spark.range(3).select($"id".as("k2"))
df1.join(df2, $"k1" + 1 =!= $"k2").explain("codegen")
```

```
== Subtree 2 / 2 (maxMethodCodeSize:282; maxConstantPoolSize:203(0.31% used); numInnerClasses:0) ==
*(2) BroadcastNestedLoopJoin BuildRight, Inner, NOT ((k1#2L + 1) = k2#6L)
:- *(2) Project [id#0L AS k1#2L]
:  +- *(2) Range (0, 4, step=1, splits=2)
+- BroadcastExchange IdentityBroadcastMode, [id=#22]
   +- *(1) Project [id#4L AS k2#6L]
      +- *(1) Range (0, 3, step=1, splits=2)

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIteratorForCodegenStage2(references);
/* 003 */ }
/* 004 */
/* 005 */ // codegenStageId=2
/* 006 */ final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 007 */   private Object[] references;
/* 008 */   private scala.collection.Iterator[] inputs;
/* 009 */   private boolean range_initRange_0;
/* 010 */   private long range_nextIndex_0;
/* 011 */   private TaskContext range_taskContext_0;
/* 012 */   private InputMetrics range_inputMetrics_0;
/* 013 */   private long range_batchEnd_0;
/* 014 */   private long range_numElementsTodo_0;
/* 015 */   private InternalRow[] bnlj_buildRowArray_0;
/* 016 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] range_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[4];
/* 017 */
/* 018 */   public GeneratedIteratorForCodegenStage2(Object[] references) {
/* 019 */     this.references = references;
/* 020 */   }
/* 021 */
/* 022 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 023 */     partitionIndex = index;
/* 024 */     this.inputs = inputs;
/* 025 */
/* 026 */     range_taskContext_0 = TaskContext.get();
/* 027 */     range_inputMetrics_0 = range_taskContext_0.taskMetrics().inputMetrics();
/* 028 */     range_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 029 */     range_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 030 */     range_mutableStateArray_0[2] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
/* 031 */     bnlj_buildRowArray_0 = (InternalRow[]) ((org.apache.spark.broadcast.TorrentBroadcast) references[1] /* broadcastTerm */).value();
/* 032 */     range_mutableStateArray_0[3] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
/* 033 */
/* 034 */   }
/* 035 */
/* 036 */   private void bnlj_doConsume_0(long bnlj_expr_0_0) throws java.io.IOException {
/* 037 */     for (int bnlj_arrayIndex_0 = 0; bnlj_arrayIndex_0 < bnlj_buildRowArray_0.length; bnlj_arrayIndex_0++) {
/* 038 */       UnsafeRow bnlj_buildRow_0 = (UnsafeRow) bnlj_buildRowArray_0[bnlj_arrayIndex_0];
/* 039 */
/* 040 */       long bnlj_value_1 = bnlj_buildRow_0.getLong(0);
/* 041 */
/* 042 */       long bnlj_value_4 = -1L;
/* 043 */
/* 044 */       bnlj_value_4 = bnlj_expr_0_0 + 1L;
/* 045 */
/* 046 */       boolean bnlj_value_3 = false;
/* 047 */       bnlj_value_3 = bnlj_value_4 == bnlj_value_1;
/* 048 */       boolean bnlj_value_2 = false;
/* 049 */       bnlj_value_2 = !(bnlj_value_3);
/* 050 */       if (!(false || !bnlj_value_2))
/* 051 */       {
/* 052 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[2] /* numOutputRows */).add(1);
/* 053 */
/* 054 */         range_mutableStateArray_0[3].reset();
/* 055 */
/* 056 */         range_mutableStateArray_0[3].write(0, bnlj_expr_0_0);
/* 057 */
/* 058 */         range_mutableStateArray_0[3].write(1, bnlj_value_1);
/* 059 */         append((range_mutableStateArray_0[3].getRow()).copy());
/* 060 */
/* 061 */       }
/* 062 */     }
/* 063 */
/* 064 */   }
/* 065 */
/* 066 */   private void initRange(int idx) {
/* 067 */     java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 068 */     java.math.BigInteger numSlice = java.math.BigInteger.valueOf(2L);
/* 069 */     java.math.BigInteger numElement = java.math.BigInteger.valueOf(4L);
/* 070 */     java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 071 */     java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 072 */     long partitionEnd;
/* 073 */
/* 074 */     java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 075 */     if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 076 */       range_nextIndex_0 = Long.MAX_VALUE;
/* 077 */     } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 078 */       range_nextIndex_0 = Long.MIN_VALUE;
/* 079 */     } else {
/* 080 */       range_nextIndex_0 = st.longValue();
/* 081 */     }
/* 082 */     range_batchEnd_0 = range_nextIndex_0;
/* 083 */
/* 084 */     java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 085 */     .multiply(step).add(start);
/* 086 */     if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 087 */       partitionEnd = Long.MAX_VALUE;
/* 088 */     } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 089 */       partitionEnd = Long.MIN_VALUE;
/* 090 */     } else {
/* 091 */       partitionEnd = end.longValue();
/* 092 */     }
/* 093 */
/* 094 */     java.math.BigInteger startToEnd = java.math.BigInteger.valueOf(partitionEnd).subtract(
/* 095 */       java.math.BigInteger.valueOf(range_nextIndex_0));
/* 096 */     range_numElementsTodo_0  = startToEnd.divide(step).longValue();
/* 097 */     if (range_numElementsTodo_0 < 0) {
/* 098 */       range_numElementsTodo_0 = 0;
/* 099 */     } else if (startToEnd.remainder(step).compareTo(java.math.BigInteger.valueOf(0L)) != 0) {
/* 100 */       range_numElementsTodo_0++;
/* 101 */     }
/* 102 */   }
/* 103 */
/* 104 */   protected void processNext() throws java.io.IOException {
/* 105 */     // initialize Range
/* 106 */     if (!range_initRange_0) {
/* 107 */       range_initRange_0 = true;
/* 108 */       initRange(partitionIndex);
/* 109 */     }
/* 110 */
/* 111 */     while (true) {
/* 112 */       if (range_nextIndex_0 == range_batchEnd_0) {
/* 113 */         long range_nextBatchTodo_0;
/* 114 */         if (range_numElementsTodo_0 > 1000L) {
/* 115 */           range_nextBatchTodo_0 = 1000L;
/* 116 */           range_numElementsTodo_0 -= 1000L;
/* 117 */         } else {
/* 118 */           range_nextBatchTodo_0 = range_numElementsTodo_0;
/* 119 */           range_numElementsTodo_0 = 0;
/* 120 */           if (range_nextBatchTodo_0 == 0) break;
/* 121 */         }
/* 122 */         range_batchEnd_0 += range_nextBatchTodo_0 * 1L;
/* 123 */       }
/* 124 */
/* 125 */       int range_localEnd_0 = (int)((range_batchEnd_0 - range_nextIndex_0) / 1L);
/* 126 */       for (int range_localIdx_0 = 0; range_localIdx_0 < range_localEnd_0; range_localIdx_0++) {
/* 127 */         long range_value_0 = ((long)range_localIdx_0 * 1L) + range_nextIndex_0;
/* 128 */
/* 129 */         // common sub-expressions
/* 130 */
/* 131 */         bnlj_doConsume_0(range_value_0);
/* 132 */
/* 133 */         if (shouldStop()) {
/* 134 */           range_nextIndex_0 = range_value_0 + 1L;
/* 135 */           ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localIdx_0 + 1);
/* 136 */           range_inputMetrics_0.incRecordsRead(range_localIdx_0 + 1);
/* 137 */           return;
/* 138 */         }
/* 139 */
/* 140 */       }
/* 141 */       range_nextIndex_0 = range_batchEnd_0;
/* 142 */       ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(range_localEnd_0);
/* 143 */       range_inputMetrics_0.incRecordsRead(range_localEnd_0);
/* 144 */       range_taskContext_0.killTaskIfInterrupted();
/* 145 */     }
/* 146 */   }
/* 147 */
/* 148 */ }
```
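
The `bnlj_doConsume_0` loop above is the kind of code the operator's `doConsume` emits for inner/cross joins: iterate over the materialized broadcast array, evaluate the join condition, and feed matched rows to the parent operator. A simplified sketch of that codegen, assuming the operator's usual context (`buildPlan`, metrics) and the hypothetical `genBuildSideVars` helper sketched above; the real implementation handles condition evaluation and nullability in more detail:

```scala
// Illustrative sketch of inner/cross join codegen inside
// BroadcastNestedLoopJoinExec; not the verbatim implementation.
private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
  val buildRowArray = "bnlj_buildRowArray_0" // broadcast rows, populated in init()
  val arrayIndex = ctx.freshName("arrayIndex")
  val buildRow = ctx.freshName("buildRow")
  val numOutput = metricTerm(ctx, "numOutputRows")
  // Bind build-side columns to the current build row (shared helper).
  val buildVars = genBuildSideVars(ctx, buildRow, buildPlan)

  s"""
     |for (int $arrayIndex = 0; $arrayIndex < $buildRowArray.length; $arrayIndex++) {
     |  UnsafeRow $buildRow = (UnsafeRow) $buildRowArray[$arrayIndex];
     |  // ... evaluate the join condition on stream-side + build-side vars; on a match:
     |  $numOutput.add(1);
     |  ${consume(ctx, input ++ buildVars)}
     |}
   """.stripMargin
}
```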

### Why are the changes needed?

Improve query CPU performance. Added a micro-benchmark query in `JoinBenchmark.scala`.
Saw about a 2x run-time improvement (run time cut roughly in half) with whole-stage code-gen enabled:

```
OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
broadcast nested loop join:                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
-------------------------------------------------------------------------------------------------------------------------
broadcast nested loop join wholestage off          62922          63052         184          0.3        3000.3       1.0X
broadcast nested loop join wholestage on           30946          30972          26          0.7        1475.6       2.0X
```
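
The new benchmark case might look roughly like the following sketch (dataset sizes and the join condition are assumptions modeled on the existing cases in `JoinBenchmark.scala`, not the exact code from this PR):

```scala
import org.apache.spark.sql.functions.{broadcast, col}

// Hypothetical sketch of the added micro-benchmark. `codegenBenchmark`
// (from SqlBasedBenchmark) runs the body once with whole-stage codegen
// off and once with it on, producing the table above.
def broadcastNestedLoopJoin(): Unit = {
  val N = 20 << 20 // stream-side row count (assumed)
  val M = 1 << 6   // broadcast-side row count (assumed)
  val dim = broadcast(spark.range(M).selectExpr("id as k"))
  codegenBenchmark("broadcast nested loop join", N) {
    // A non-equi condition forces a BroadcastNestedLoopJoin.
    val df = spark.range(N).join(dim, col("id") * 2 =!= col("k"))
    df.noop()
  }
}
```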

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

* Added a unit test in `WholeStageCodegenSuite.scala` (a sketch of such a test appears below); existing unit tests for `BroadcastNestedLoopJoinExec` also cover the change.
* Updated golden files for several TPCDS query plans, as whole-stage code-gen for `BroadcastNestedLoopJoinExec` is now triggered.
* Updated `JoinBenchmark-jdk11-results.txt` and `JoinBenchmark-results.txt` with the new benchmark results. Followed previous benchmark PRs #27078 and #26003 and used the same type of machine:

```
Amazon AWS EC2
type: r3.xlarge
region: us-west-2 (Oregon)
OS: Linux
```
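
A minimal sketch of the kind of test added to `WholeStageCodegenSuite.scala` (the test name and assertion style are illustrative, not the verbatim test):

```scala
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoinExec

test("Inner BroadcastNestedLoopJoinExec should be included in WholeStageCodegen") {
  val df1 = spark.range(4).select($"id".as("k1"))
  val df2 = spark.range(3).select($"id".as("k2"))
  val joined = df1.join(df2, $"k1" + 1 =!= $"k2")

  // The join should now execute inside a whole-stage-codegen subtree.
  val codegenJoins = joined.queryExecution.executedPlan.collect {
    case w: WholeStageCodegenExec
        if w.collectFirst { case j: BroadcastNestedLoopJoinExec => j }.isDefined => w
  }
  assert(codegenJoins.nonEmpty)
}
```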

Closes #31736 from c21/nested-join-exec.

Authored-by: Cheng Su <chengsu@fb.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
c21 authored and cloud-fan committed Mar 9, 2021
1 parent 43b23fd commit b5b1985
Showing 36 changed files with 1,557 additions and 1,378 deletions.
71 changes: 39 additions & 32 deletions in sql/core/benchmarks/JoinBenchmark-jdk11-results.txt

```diff
@@ -2,74 +2,81 @@
 Join Benchmark
 ================================================================================================
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 Join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-Join w long wholestage off 4441 4572 185 4.7 211.8 1.0X
-Join w long wholestage on 1409 1500 96 14.9 67.2 3.2X
+Join w long wholestage off 3931 3998 95 5.3 187.4 1.0X
+Join w long wholestage on 1507 1769 178 13.9 71.9 2.6X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 Join w long duplicated: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-Join w long duplicated wholestage off 5111 5116 7 4.1 243.7 1.0X
-Join w long duplicated wholestage on 1493 1518 22 14.0 71.2 3.4X
+Join w long duplicated wholestage off 5582 5617 50 3.8 266.2 1.0X
+Join w long duplicated wholestage on 1435 1451 19 14.6 68.4 3.9X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 Join w 2 ints: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-Join w 2 ints wholestage off 171821 171906 121 0.1 8193.0 1.0X
-Join w 2 ints wholestage on 166559 166975 263 0.1 7942.1 1.0X
+Join w 2 ints wholestage off 171470 171478 11 0.1 8176.3 1.0X
+Join w 2 ints wholestage on 166612 166762 123 0.1 7944.7 1.0X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 Join w 2 longs: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-Join w 2 longs wholestage off 7511 7555 62 2.8 358.2 1.0X
-Join w 2 longs wholestage on 3776 4119 232 5.6 180.1 2.0X
+Join w 2 longs wholestage off 6065 6093 40 3.5 289.2 1.0X
+Join w 2 longs wholestage on 3285 3375 97 6.4 156.7 1.8X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 Join w 2 longs duplicated: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-Join w 2 longs duplicated wholestage off 13563 13617 77 1.5 646.7 1.0X
-Join w 2 longs duplicated wholestage on 7947 8053 71 2.6 378.9 1.7X
+Join w 2 longs duplicated wholestage off 14969 15027 82 1.4 713.8 1.0X
+Join w 2 longs duplicated wholestage on 7902 8151 406 2.7 376.8 1.9X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 outer join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-outer join w long wholestage off 3915 3923 12 5.4 186.7 1.0X
-outer join w long wholestage on 1421 1461 30 14.8 67.8 2.8X
+outer join w long wholestage off 2822 2823 1 7.4 134.6 1.0X
+outer join w long wholestage on 1419 1436 19 14.8 67.7 2.0X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 semi join w long: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-semi join w long wholestage off 2310 2332 30 9.1 110.2 1.0X
-semi join w long wholestage on 835 860 34 25.1 39.8 2.8X
+semi join w long wholestage off 1821 1832 15 11.5 86.8 1.0X
+semi join w long wholestage on 828 853 36 25.3 39.5 2.2X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 sort merge join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-sort merge join wholestage off 1846 1886 56 1.1 880.5 1.0X
-sort merge join wholestage on 1402 1654 234 1.5 668.3 1.3X
+sort merge join wholestage off 1371 1380 13 1.5 653.7 1.0X
+sort merge join wholestage on 1197 1244 37 1.8 570.9 1.1X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
-sort merge join with duplicates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
-------------------------------------------------------------------------------------------------------------------------
-sort merge join with duplicates wholestage off 2852 2879 38 0.7 1360.0 1.0X
-sort merge join with duplicates wholestage on 2645 2742 156 0.8 1261.0 1.1X
+sort merge join with duplicates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------------
+sort merge join with duplicates wholestage off 1920 1933 20 1.1 915.3 1.0X
+sort merge join with duplicates wholestage on 1871 1912 27 1.1 892.0 1.0X
 
-OpenJDK 64-Bit Server VM 11.0.5+10-post-Ubuntu-0ubuntu1.118.04 on Linux 4.15.0-1044-aws
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
 Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
 shuffle hash join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
 ------------------------------------------------------------------------------------------------------------------------
-shuffle hash join wholestage off 1506 1564 82 2.8 359.1 1.0X
-shuffle hash join wholestage on 1303 1330 23 3.2 310.6 1.2X
+shuffle hash join wholestage off 1102 1122 28 3.8 262.8 1.0X
+shuffle hash join wholestage on 657 674 13 6.4 156.6 1.7X
 
+OpenJDK 64-Bit Server VM 11.0.9+11-LTS on Linux 4.14.219-161.340.amzn2.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+broadcast nested loop join: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-------------------------------------------------------------------------------------------------------------------------
+broadcast nested loop join wholestage off 62922 63052 184 0.3 3000.3 1.0X
+broadcast nested loop join wholestage on 30946 30972 26 0.7 1475.6 2.0X
+
```
