New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-12798] [SQL] generated BroadcastHashJoin #10989
Changes from 25 commits
3e792f3
2f1a082
7d1bd43
7880786
407460d
ff04509
081a04d
c3c0b36
77ba890
9a42b52
37bc7f0
48e125c
efe7fa2
3bfdeb2
858c1e3
be2e53b
f234c21
9ae4bc2
89614a5
dcf4fdc
e665b9b
1ecce29
0139fde
c1c0588
4d75022
4fcf5d2
e0c8c65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,9 +18,11 @@ | |
package org.apache.spark.sql.execution; | ||
|
||
import java.io.IOException; | ||
import java.util.LinkedList; | ||
|
||
import scala.collection.Iterator; | ||
|
||
import org.apache.spark.TaskContext; | ||
import org.apache.spark.sql.catalyst.InternalRow; | ||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow; | ||
|
||
|
@@ -31,36 +33,48 @@ | |
* TODO: replaced it by batched columnar format. | ||
*/ | ||
public class BufferedRowIterator { | ||
protected InternalRow currentRow; | ||
protected LinkedList<InternalRow> currentRows = new LinkedList<>(); | ||
protected Iterator<InternalRow> input; | ||
// used when there is no column in output | ||
protected UnsafeRow unsafeRow = new UnsafeRow(0); | ||
|
||
public boolean hasNext() throws IOException { | ||
if (currentRow == null) { | ||
if (currentRows.isEmpty()) { | ||
processNext(); | ||
} | ||
return currentRow != null; | ||
return !currentRows.isEmpty(); | ||
} | ||
|
||
public InternalRow next() { | ||
InternalRow r = currentRow; | ||
currentRow = null; | ||
return r; | ||
return currentRows.remove(); | ||
} | ||
|
||
public void setInput(Iterator<InternalRow> iter) { | ||
input = iter; | ||
} | ||
|
||
/** | ||
* Returns whether `processNext()` should stop processing next row from `input` or not. | ||
*/ | ||
protected boolean shouldStop() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rxin this seems like it could be used to support limit. |
||
return !currentRows.isEmpty(); | ||
} | ||
|
||
/** | ||
* Increase the peak execution memory for current task. | ||
*/ | ||
protected void incPeakExecutionMemory(long size) { | ||
TaskContext.get().taskMetrics().incPeakExecutionMemory(size); | ||
} | ||
|
||
/** | ||
* Processes the input until have a row as output (currentRow). | ||
* | ||
* After it's called, if currentRow is still null, it means no more rows left. | ||
*/ | ||
protected void processNext() throws IOException { | ||
if (input.hasNext()) { | ||
currentRow = input.next(); | ||
currentRows.add(input.next()); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -286,16 +286,6 @@ case class TungstenAggregate( | |
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) | ||
} | ||
|
||
|
||
/** | ||
* Update peak execution memory, called in generated Java class. | ||
*/ | ||
def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { | ||
val mapMemory = hashMap.getPeakMemoryUsedBytes | ||
val metrics = TaskContext.get().taskMetrics() | ||
metrics.incPeakExecutionMemory(mapMemory) | ||
} | ||
|
||
private def doProduceWithKeys(ctx: CodegenContext): String = { | ||
val initAgg = ctx.freshName("initAgg") | ||
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") | ||
|
@@ -389,14 +379,16 @@ case class TungstenAggregate( | |
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); | ||
UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); | ||
$outputCode | ||
|
||
if (shouldStop()) return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you document the required behavior of shouldStop(). How does it need to behave so that the clean up below (hashMapTerm.free()) is called? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once
I will add these to the doc string of shouldStop(). |
||
} | ||
|
||
$thisPlan.updatePeakMemory($hashMapTerm); | ||
incPeakExecutionMemory($hashMapTerm.getPeakMemoryUsedBytes()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we bake the peak memory usage into hashMapTerm.free()? This seems like something we'll forget to do. |
||
$hashMapTerm.free(); | ||
""" | ||
} | ||
|
||
private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
|
||
// create grouping key | ||
ctx.currentVars = input | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins | |
import scala.concurrent._ | ||
import scala.concurrent.duration._ | ||
|
||
import org.apache.spark.{InternalAccumulator, TaskContext} | ||
import org.apache.spark.TaskContext | ||
import org.apache.spark.broadcast.Broadcast | ||
import org.apache.spark.rdd.RDD | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.expressions.Expression | ||
import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow} | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection} | ||
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} | ||
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} | ||
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution} | ||
import org.apache.spark.sql.execution.metric.SQLMetrics | ||
import org.apache.spark.util.ThreadUtils | ||
import org.apache.spark.util.collection.CompactBuffer | ||
|
||
/** | ||
* Performs an inner hash join of two child relations. When the output RDD of this operator is | ||
|
@@ -42,7 +45,7 @@ case class BroadcastHashJoin( | |
condition: Option[Expression], | ||
left: SparkPlan, | ||
right: SparkPlan) | ||
extends BinaryNode with HashJoin { | ||
extends BinaryNode with HashJoin with CodegenSupport { | ||
|
||
override private[sql] lazy val metrics = Map( | ||
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), | ||
|
@@ -117,6 +120,87 @@ case class BroadcastHashJoin( | |
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows) | ||
} | ||
} | ||
|
||
// the term for hash relation | ||
private var relationTerm: String = _ | ||
|
||
override def upstream(): RDD[InternalRow] = { | ||
streamedPlan.asInstanceOf[CodegenSupport].upstream() | ||
} | ||
|
||
override def doProduce(ctx: CodegenContext): String = { | ||
// create a name for HashRelation | ||
val broadcastRelation = Await.result(broadcastFuture, timeout) | ||
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation) | ||
relationTerm = ctx.freshName("relation") | ||
// TODO: create specialized HashRelation for single join key | ||
val clsName = classOf[UnsafeHashedRelation].getName | ||
ctx.addMutableState(clsName, relationTerm, | ||
s""" | ||
| $relationTerm = ($clsName) $broadcast.value(); | ||
| incPeakExecutionMemory($relationTerm.getUnsafeSize()); | ||
""".stripMargin) | ||
|
||
s""" | ||
| ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)} | ||
""".stripMargin | ||
} | ||
|
||
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { | ||
// generate the key as UnsafeRow | ||
ctx.currentVars = input | ||
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) | ||
val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr) | ||
val keyTerm = keyVal.value | ||
val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false" | ||
|
||
// find the matches from HashedRelation | ||
val matches = ctx.freshName("matches") | ||
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName | ||
val i = ctx.freshName("i") | ||
val size = ctx.freshName("size") | ||
val row = ctx.freshName("row") | ||
|
||
// create variables for output | ||
ctx.currentVars = null | ||
ctx.INPUT_ROW = row | ||
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) => | ||
BoundReference(i, a.dataType, a.nullable).gen(ctx) | ||
} | ||
val resultVars = buildSide match { | ||
case BuildLeft => buildColumns ++ input | ||
case BuildRight => input ++ buildColumns | ||
} | ||
|
||
val ouputCode = if (condition.isDefined) { | ||
// filter the output via condition | ||
ctx.currentVars = resultVars | ||
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx) | ||
s""" | ||
| ${ev.code} | ||
| if (!${ev.isNull} && ${ev.value}) { | ||
| ${consume(ctx, resultVars)} | ||
| } | ||
""".stripMargin | ||
} else { | ||
consume(ctx, resultVars) | ||
} | ||
|
||
s""" | ||
| // generate join key | ||
| ${keyVal.code} | ||
| // find matches from HashRelation | ||
| $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm); | ||
| if ($matches != null) { | ||
| int $size = $matches.size(); | ||
| for (int $i = 0; $i < $size; $i++) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see a strong reason that we can't interrupt this loop. We can make I don't mean to change anything, but just to verify my understanding. Also cc @hvanhovell @viirya @mgaido91 @rednaxelafx There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mmmh... this code seems rather outdated...I couldn't find it in the current codebase. Anyway, I don't understand why you want to interrupt it. AFAIU, this is generating the result from all the matches of a row, hence if we interrupt it somehow we would end up returning a wrong result (in the result we would omit some rows...). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm, yeah, this code is changed a lot since this PR, looks like at that moment this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea the code change a lot but we still generate loops for broadcast join. This PR made There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mmmh, maybe I see your point now. I think it may be feasible but a bit complex. We might keep a global variable for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. Then we gotta to keep |
||
| UnsafeRow $row = (UnsafeRow) $matches.apply($i); | ||
| ${buildColumns.map(_.code).mkString("\n")} | ||
| $ouputCode | ||
| } | ||
| } | ||
""".stripMargin | ||
} | ||
} | ||
|
||
object BroadcastHashJoin { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} | |
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} | ||
import org.apache.spark.sql.SQLContext | ||
import org.apache.spark.sql.catalyst.expressions.UnsafeRow | ||
import org.apache.spark.sql.functions._ | ||
import org.apache.spark.unsafe.Platform | ||
import org.apache.spark.unsafe.hash.Murmur3_x86_32 | ||
import org.apache.spark.unsafe.map.BytesToBytesMap | ||
|
@@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { | |
benchmark.run() | ||
} | ||
|
||
def testBroadcastHashJoin(values: Int): Unit = { | ||
val benchmark = new Benchmark("BroadcastHashJoin", values) | ||
|
||
val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v")) | ||
|
||
benchmark.addCase("BroadcastHashJoin w/o codegen") { iter => | ||
sqlContext.setConf("spark.sql.codegen.wholeStage", "false") | ||
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() | ||
} | ||
benchmark.addCase(s"BroadcastHashJoin w codegen") { iter => | ||
sqlContext.setConf("spark.sql.codegen.wholeStage", "true") | ||
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count() | ||
} | ||
|
||
/* | ||
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz | ||
BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate | ||
------------------------------------------------------------------------------- | ||
BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X | ||
BroadcastHashJoin w codegen 1028.40 10.20 2.97 X | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you also run a benchmark using a larger range so we amortize the broadcast overhead? i'm interested in seeing what the improvement is for the join part of the benchmark. e.g. fix the size of the dimension table, but increase the probe side by 10x. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since the dimension table is pretty small, overhead of broadcast is also low, when I ran it with larger range, the improvements did not change much, because looking up in BytesToBytes is the bottleneck. I will have another PR to improve the join with small dimension table. |
||
*/ | ||
benchmark.run() | ||
} | ||
|
||
def testBytesToBytesMap(values: Int): Unit = { | ||
val benchmark = new Benchmark("BytesToBytesMap", values) | ||
|
||
|
@@ -199,8 +224,9 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { | |
// These benchmark are skipped in normal build | ||
ignore("benchmark") { | ||
// testWholeStage(200 << 20) | ||
// testStddev(20 << 20) | ||
// testStatFunctions(20 << 20) | ||
// testAggregateWithKey(20 << 20) | ||
// testBytesToBytesMap(1024 * 1024 * 50) | ||
// testBytesToBytesMap(50 << 20) | ||
// testBroadcastHashJoin(10 << 20) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
orthogonal to this pr -- my first reaction to this is that maybe we should spend a week or two to convert all operators to a push-based model. Otherwise performance is going to suck big time for some operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a huge topic, let's talk about this offline.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cloud-fan @mgaido91 @viirya
When
spark.sql.codegen.wholeStage = true
, some joins caused OOM. Analyzing the heap dump, we found that BufferedRowIterator#currentRows
holds all matching rows. If codegen is turned off, it runs just fine; only one matching row is generated at a time.
Increasing the executor memory may let the query run successfully, but there is still a chance of failure, because the number of rows matching the current key is not known in advance.
example:
hprof:
currently generated code snippet:
Is it possible to change to code like this, or is there any other better way?