[SPARK-12798] [SQL] generated BroadcastHashJoin #10989

Closed
wants to merge 27 commits
@@ -18,9 +18,11 @@
package org.apache.spark.sql.execution;

import java.io.IOException;
import java.util.LinkedList;

import scala.collection.Iterator;

import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;

@@ -31,36 +33,50 @@
* TODO: replaced it by batched columnar format.
*/
public class BufferedRowIterator {
protected InternalRow currentRow;
protected LinkedList<InternalRow> currentRows = new LinkedList<>();
Contributor:

orthogonal to this pr -- my first reaction to this is that maybe we should spend a week or two to convert all operators to a push-based model. Otherwise performance is going to suck big time for some operators.
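
To make the suggestion above concrete, here is a rough, hypothetical sketch of what a push-based operator contract could look like, in contrast to the pull-based Iterator model. None of these names exist in Spark; they are invented purely for illustration.

import org.apache.spark.sql.catalyst.InternalRow;

// Hypothetical push-based operator contract (all names invented for illustration).
// Instead of a parent pulling rows via hasNext()/next(), the child drives the loop
// and pushes every row it produces into the parent's consume() callback, so no
// operator has to buffer rows just to answer a later next() call.
interface RowConsumer {
  void consume(InternalRow row);  // invoked once per produced row
  void close();                   // invoked when the producer is exhausted
}

interface PushBasedOperator {
  void produce(RowConsumer downstream);  // run this operator's loop, pushing rows downstream
}

In such a model a join operator would simply emit matched rows as it finds them, instead of buffering them so that a later next() call can hand them out.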

Contributor Author:

It's a huge topic, let's talk about this offline.

Contributor:

@cloud-fan @mgaido91 @viirya
When spark.sql.codegen.wholeStage = true, some joins caused an OOM; analyzing the heap dump showed that BufferedRowIterator#currentRows holds all of the matching rows.
If codegen is turned off, the same query runs just fine, because only one matching row is generated at a time.
Increasing the executor memory may let it succeed, but there is always a chance of failure, because it is not known how many rows match the current key.

example:

    val value = "x" * 1000 * 1000
    case class TestData(key: Int, value: String)
    val testData = spark.sparkContext.parallelize((1 to 1)
      .map(i => TestData(i, value))).toDF()
    var bigData = testData
    for (_ <- Range(0, 10)) {
      bigData = bigData.union(bigData)
    }
    val testDataX = testData.as("x").selectExpr("key as xkey", "value as xvalue")
    val bigDataY = bigData.as("y").selectExpr("key as ykey", "value as yvalue")
    testDataX.join(bigDataY).where("xkey = ykey").write.saveAsTable("test")

hprof: (heap dump screenshot omitted; it showed BufferedRowIterator#currentRows retaining all of the matched rows)

currently generated code snippet:

protected void processNext() throws java.io.IOException {
    while (findNextInnerJoinRows(smj_leftInput_0, smj_rightInput_0)) {
      scala.collection.Iterator<UnsafeRow> smj_iterator_0 = smj_matches_0.generateIterator();
      while (smj_iterator_0.hasNext()) {
        InternalRow smj_rightRow_1 = (InternalRow) smj_iterator_0.next();
        append(xxRow.copy());
      }
      if (shouldStop()) return;
    }
}

Is it possible to change it to code like the following, or is there a better way?

private scala.collection.Iterator<UnsafeRow> smj_iterator_0;

protected void processNext() throws java.io.IOException {
    // resume emitting the remaining matches of the previous key, one row per call
    if (smj_iterator_0 != null && smj_iterator_0.hasNext()) {
        InternalRow smj_rightRow_1 = (InternalRow) smj_iterator_0.next();
        append(xxRow.copy());
        if (!smj_iterator_0.hasNext()) {
            smj_iterator_0 = null;
        }
        return;
    }
    while (findNextInnerJoinRows(smj_leftInput_0, smj_rightInput_0)) {
      smj_iterator_0 = smj_matches_0.generateIterator();
      if (smj_iterator_0.hasNext()) {
        InternalRow smj_rightRow_1 = (InternalRow) smj_iterator_0.next();
        append(xxRow.copy());
        if (!smj_iterator_0.hasNext()) {
            smj_iterator_0 = null;
        }
        return;
      }
      if (shouldStop()) return;
    }
}

protected Iterator<InternalRow> input;
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);

public boolean hasNext() throws IOException {
if (currentRow == null) {
if (currentRows.isEmpty()) {
processNext();
}
return currentRow != null;
return !currentRows.isEmpty();
}

public InternalRow next() {
InternalRow r = currentRow;
currentRow = null;
return r;
return currentRows.remove();
}

public void setInput(Iterator<InternalRow> iter) {
input = iter;
}

/**
* Returns whether `processNext()` should stop processing next row from `input` or not.
*
* If it returns true, the caller should exit the loop (return from processNext()).
*/
protected boolean shouldStop() {
Contributor:

@rxin this seems like it could be used to support limit.

return !currentRows.isEmpty();
}
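
To illustrate the comment above: a LIMIT could in principle be enforced through shouldStop() by making it return true once enough rows have been produced, so that every generated produce loop bails out early. The class below is only a sketch of that idea; it is not part of this PR, and the name LimitedBufferedRowIterator is made up.

package org.apache.spark.sql.execution;

import org.apache.spark.sql.catalyst.InternalRow;

// Hypothetical sketch, not part of this PR: stop producing once `limit` rows were emitted.
public class LimitedBufferedRowIterator extends BufferedRowIterator {
  private final int limit;
  private int emitted = 0;

  public LimitedBufferedRowIterator(int limit) {
    this.limit = limit;
  }

  @Override
  public InternalRow next() {
    emitted++;
    return super.next();
  }

  @Override
  protected boolean shouldStop() {
    // Stop the produce loops as soon as a row is buffered or the limit has been reached.
    return super.shouldStop() || emitted + currentRows.size() >= limit;
  }
}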

/**
* Increase the peak execution memory for current task.
*/
protected void incPeakExecutionMemory(long size) {
TaskContext.get().taskMetrics().incPeakExecutionMemory(size);
}

/**
* Processes the input until have a row as output (currentRow).
*
* After it's called, if currentRow is still null, it means no more rows left.
*/
protected void processNext() throws IOException {
if (input.hasNext()) {
currentRow = input.next();
currentRows.add(input.next());
}
}
}
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.util.Utils

/**
@@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
| return;
| }
| }
""".stripMargin
}
@@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
if (row != null) {
// There is an UnsafeRow already
s"""
| currentRow = $row;
| return;
| currentRows.add($row.copy());
""".stripMargin
} else {
assert(input != null)
@@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
| currentRow = ${code.value};
| return;
| currentRows.add(${code.value}.copy());
""".stripMargin
} else {
// There is no columns
s"""
| currentRow = unsafeRow;
| return;
| currentRows.add(unsafeRow);
""".stripMargin
}
}
@@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru

var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
// The build side can't be compiled together
case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
@@ -471,6 +471,8 @@ case class TungstenAggregate(
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
$outputCode

if (shouldStop()) return;
Contributor:

Can you document the required behavior of shouldStop()? How does it need to behave so that the cleanup below (hashMapTerm.free()) is called?

Contributor Author:

Once shouldStop() returns true, the caller should exit the loop (via return).

map.free() is called only after the loop has consumed all the items (i.e. the loop finished without an early return).

Will add this to the doc string of shouldStop().
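
A minimal sketch of that contract, with identifiers simplified and partly hypothetical; it mirrors the shape of the generated aggregate code rather than reproducing it. The early shouldStop() return keeps the map iterator alive for the next processNext() call, and free() runs only once the loop has drained.

// Sketch only: the field types follow the real generated aggregate code, but the wiring of
// aggHashMap (the hashMapTerm in the template) and buildOutputRow is hypothetical.
private boolean aggMapIterInitialized = false;
private org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> aggMapIter;
private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap aggHashMap;

protected void processNext() throws java.io.IOException {
  if (!aggMapIterInitialized) {
    aggMapIterInitialized = true;
    aggMapIter = aggHashMap.iterator();             // iterate over the aggregation hash map
  }
  while (aggMapIter.next()) {
    UnsafeRow key = (UnsafeRow) aggMapIter.getKey();
    UnsafeRow buffer = (UnsafeRow) aggMapIter.getValue();
    currentRows.add(buildOutputRow(key, buffer));   // hypothetical projection of the result row
    if (shouldStop()) return;                       // early exit: the map is NOT freed yet
  }
  aggMapIter.close();
  aggHashMap.free();                                // cleanup runs only after the loop fully drained
}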

}

$iterTerm.close();
@@ -480,7 +482,7 @@ case class TungstenAggregate(
"""
}

private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = {
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {

// create grouping key
ctx.currentVars = input
@@ -237,6 +237,8 @@ case class Range(
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
|
| if (shouldStop()) return;
| }
""".stripMargin
}
@@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._

import org.apache.spark.{InternalAccumulator, TaskContext}
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.collection.CompactBuffer

/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -42,7 +45,7 @@ case class BroadcastHashJoin(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
extends BinaryNode with HashJoin {
extends BinaryNode with HashJoin with CodegenSupport {

override private[sql] lazy val metrics = Map(
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
@@ -117,6 +120,87 @@ case class BroadcastHashJoin(
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}

// the term for hash relation
private var relationTerm: String = _

override def upstream(): RDD[InternalRow] = {
streamedPlan.asInstanceOf[CodegenSupport].upstream()
}

override def doProduce(ctx: CodegenContext): String = {
// create a name for HashRelation
val broadcastRelation = Await.result(broadcastFuture, timeout)
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
relationTerm = ctx.freshName("relation")
// TODO: create specialized HashRelation for single join key
val clsName = classOf[UnsafeHashedRelation].getName
ctx.addMutableState(clsName, relationTerm,
s"""
| $relationTerm = ($clsName) $broadcast.value();
| incPeakExecutionMemory($relationTerm.getUnsafeSize());
""".stripMargin)

s"""
| ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
""".stripMargin
}

override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// generate the key as UnsafeRow
ctx.currentVars = input
val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
val keyTerm = keyVal.value
val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"

// find the matches from HashedRelation
val matches = ctx.freshName("matches")
val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
val i = ctx.freshName("i")
val size = ctx.freshName("size")
val row = ctx.freshName("row")

// create variables for output
ctx.currentVars = null
ctx.INPUT_ROW = row
val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
BoundReference(i, a.dataType, a.nullable).gen(ctx)
}
val resultVars = buildSide match {
case BuildLeft => buildColumns ++ input
case BuildRight => input ++ buildColumns
}

val ouputCode = if (condition.isDefined) {
// filter the output via condition
ctx.currentVars = resultVars
val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
s"""
| ${ev.code}
| if (!${ev.isNull} && ${ev.value}) {
| ${consume(ctx, resultVars)}
| }
""".stripMargin
} else {
consume(ctx, resultVars)
}

s"""
| // generate join key
| ${keyVal.code}
| // find matches from HashRelation
| $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
| if ($matches != null) {
| int $size = $matches.size();
| for (int $i = 0; $i < $size; $i++) {
Contributor (@cloud-fan, Oct 2, 2018):

I don't see a strong reason that we can't interrupt this loop. We can make i a global variable for example.

I don't mean to change anything, but just to verify my understanding. Also cc @hvanhovell @viirya @mgaido91 @rednaxelafx

Contributor:

Mmmh... this code seems rather outdated; I couldn't find it in the current codebase. Anyway, I don't understand why you would want to interrupt it. AFAIU, this generates the result rows from all the matches of a row, so if we interrupted it somehow we would end up returning a wrong result (we would omit some rows...).

Member:

Hmm, yeah, this code has changed a lot since this PR; it looks like at that time this BroadcastHashJoin only supported inner joins. I also don't really get the idea of interrupting this loop early, as it looks like we need to go through all matched rows here?

Contributor:

Yea, the code has changed a lot, but we still generate loops for broadcast join.

This PR changed BufferedRowIterator.currentRow to BufferedRowIterator.currentRows, to store multiple result rows instead of a single row. If we could interrupt the loop and resume it in the next call of processNext, we could still keep a single result row.

Contributor:

Mmmh, maybe I see your point now. I think it may be feasible but a bit complex. We might keep a global variable for $matches and read from it in the produce method. That is what you are saying, right? Just changing things here wouldn't work IMHO, because in the next iteration the keys have changed...

Member:

I see. Then we gotta keep matches as global state instead of local, so we can go over the remaining matched rows in the next iterations. And we shouldn't fetch the next row from the streaming side but reuse the previous row from that side. That would produce a single result row without buffering all matched rows into currentRows, though it might add some complexity to the generated code.
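
To make the idea in this thread concrete, here is a rough, hypothetical sketch (not code generated by this PR) of a processNext() that keeps the matches buffer, the position in it, and the current streamed row as instance state, so that each call emits a single joined row and resumes where it left off. The helpers join and buildJoinKey are made up and stand in for the generated projection code; the wiring of relation from the broadcast variable is omitted.

// Hypothetical sketch: resume iteration over the matched rows across processNext() calls.
// Assumes imports of InternalRow, UnsafeRow, CompactBuffer and UnsafeHashedRelation.
private UnsafeHashedRelation relation;                // broadcast hash relation (wired elsewhere)
private CompactBuffer<UnsafeRow> bhj_matches = null;  // matches for the current streamed row
private int bhj_pos = 0;                              // next index to emit from bhj_matches
private InternalRow bhj_streamedRow = null;           // streamed-side row currently being joined

protected void processNext() throws java.io.IOException {
  while (true) {
    // 1. If there are pending matches for the current streamed row, emit exactly one and return.
    if (bhj_matches != null && bhj_pos < bhj_matches.size()) {
      UnsafeRow buildRow = bhj_matches.apply(bhj_pos++);
      currentRows.add(join(bhj_streamedRow, buildRow));  // hypothetical helper building the joined row
      return;
    }
    // 2. Otherwise advance the streamed side and look up its matches in the hash relation.
    if (!input.hasNext()) return;
    bhj_streamedRow = input.next();
    UnsafeRow key = buildJoinKey(bhj_streamedRow);        // hypothetical key projection
    bhj_matches = key.anyNull() ? null : (CompactBuffer<UnsafeRow>) relation.get(key);
    bhj_pos = 0;
  }
}

The trade-off raised in this thread still applies: the streamed-side row and the match position become extra mutable state in the generated class, which is the added complexity being discussed.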

| UnsafeRow $row = (UnsafeRow) $matches.apply($i);
| ${buildColumns.map(_.code).mkString("\n")}
| $ouputCode
| }
| }
""".stripMargin
}
}

object BroadcastHashJoin {
@@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.functions._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
}

def testBroadcastHashJoin(values: Int): Unit = {
val benchmark = new Benchmark("BroadcastHashJoin", values)

val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))

benchmark.addCase("BroadcastHashJoin w/o codegen") { iter =>
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
}
benchmark.addCase(s"BroadcastHashJoin w codegen") { iter =>
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
}

/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X
BroadcastHashJoin w codegen 1028.40 10.20 2.97 X
Contributor:

Can you also run a benchmark using a larger range so we amortize the broadcast overhead? I'm interested in seeing what the improvement is for the join part of the benchmark.

e.g. fix the size of the dimension table, but increase the probe side by 10x.

Contributor Author:

Since the dimension table is pretty small, the overhead of the broadcast is also low. When I ran it with a larger range, the improvement did not change much, because looking up in BytesToBytesMap is the bottleneck. I will have another PR to improve joins with a small dimension table.

*/
benchmark.run()
}

def testBytesToBytesMap(values: Int): Unit = {
val benchmark = new Benchmark("BytesToBytesMap", values)

@@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// testWholeStage(200 << 20)
// testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
// testBytesToBytesMap(1024 * 1024 * 50)
// testBytesToBytesMap(50 << 20)
// testBroadcastHashJoin(10 << 20)
}
}
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.functions.{avg, col, max}
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

@@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
}

test("BroadcastHashJoin should be included in WholeStageCodegen") {
val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
val schema = new StructType().add("k", IntegerType).add("v", StringType)
val smallDF = sqlContext.createDataFrame(rdd, schema)
val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
assert(df.queryExecution.executedPlan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
}