Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-1495][SQL]add support for left semi join #837

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val OUTER = Keyword("OUTER")
protected val RIGHT = Keyword("RIGHT")
protected val SELECT = Keyword("SELECT")
protected val SEMI = Keyword("SEMI")
protected val STRING = Keyword("STRING")
protected val SUM = Keyword("SUM")
protected val TRUE = Keyword("TRUE")
Expand Down Expand Up @@ -238,6 +239,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {

protected lazy val joinType: Parser[JoinType] =
INNER ^^^ Inner |
LEFT ~ SEMI ^^^ LeftSemi |
LEFT ~ opt(OUTER) ^^^ LeftOuter |
RIGHT ~ opt(OUTER) ^^^ RightOuter |
FULL ~ opt(OUTER) ^^^ FullOuter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
// All predicates can be evaluated for left semi join (those that are in the WHERE
// clause can only from left table, so they can all be pushed down.)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in general we should avoid making too many assumptions in the planner about what optimizations have occurred. For example, in the future we might avoid pushing down predicates that are very expensive to evaluate as it might be cheaper to run them on an already filtered set of data. However, in the case of LEFT SEMI JOIN, I think it is actually okay to push all evaluation into the join condition, even if they only refer to the left table. Is that true?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think LEFT SEMI JOIN would not suffer by pushing down predicates.

case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
case join @ Join(left, right, joinType, condition) =>
logger.debug(s"Considering hash join on: $condition")
splitPredicates(condition.toSeq, join)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType
case object LeftSemi extends JoinType
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
import org.apache.spark.sql.catalyst.types._

case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
Expand Down Expand Up @@ -81,7 +81,12 @@ case class Join(
condition: Option[Expression]) extends BinaryNode {

def references = condition.map(_.references).getOrElse(Set.empty)
def output = left.output ++ right.output
def output = joinType match {
case LeftSemi =>
left.output
case _ =>
left.output ++ right.output
}
}

case class InsertIntoTable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
TakeOrdered ::
PartialAggregation ::
LeftSemiJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>

object LeftSemiJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find left semi joins where at least some predicates can be evaluated by matching hash
// keys using the HashFilteredJoin pattern.
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
val semiJoin = execution.LeftSemiJoinHash(
leftKeys, rightKeys, planLater(left), planLater(right))
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
execution.LeftSemiJoinBNL(
planLater(left), planLater(right), condition)(sparkContext) :: Nil
case _ => Nil
}
}

object HashJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
Expand Down
131 changes: 131 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,137 @@ case class HashJoin(
}
}

/**
* :: DeveloperApi ::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize that we aren't particularly good about this in most of the other physical operators, but could you add some Scala doc here about how this operator works and what the expected performance characteristics are? Same below. The goal of the Scala doc for physical operators should be to make it easy for people to understand query plans that are printed out by EXPLAIN.

* Build the right table's join keys into a HashSet, and iteratively go through the left
* table, to find the if join keys are in the Hash set.
*/
@DeveloperApi
case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {

override def outputPartitioning: Partitioning = left.outputPartitioning

override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil

val (buildPlan, streamedPlan) = (right, left)
val (buildKeys, streamedKeys) = (rightKeys, leftKeys)

def output = left.output

@transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)

def execute() = {

buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashTable = new java.util.HashSet[Row]()
var currentRow: Row = null

// Create a Hash set of buildKeys
while (buildIter.hasNext) {
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
val keyExists = hashTable.contains(rowKey)
if (!keyExists) {
hashTable.add(rowKey)
}
}
}

new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatched: Boolean = false

private[this] val joinKeys = streamSideKeyGenerator()

override final def hasNext: Boolean =
streamIter.hasNext && fetchNext()

override final def next() = {
currentStreamedRow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct if the operator is created with BuildLeft instead of BuildRight? I think that would turn it into a RightSemiJoin. Perhaps we should just remove the option to build on the other side. I think you can then also safely simplify this to use a HashSet instead of a HashMap, which will reduce memory consumption significantly.

}

/**
* Searches the streamed iterator for the next row that has at least one match in hashtable.
*
* @return true if the search is successful, and false the streamed iterator runs out of
* tuples.
*/
private final def fetchNext(): Boolean = {
currentHashMatched = false
while (!currentHashMatched && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
currentHashMatched = hashTable.contains(joinKeys.currentValue)
}
}
currentHashMatched
}
}
}
}
}

/**
* :: DeveloperApi ::
* Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
* for hash join.
*/
@DeveloperApi
case class LeftSemiJoinBNL(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this operator is exercised by the included test cases. We should add a test where the join condition can be calculated with hash keys.

streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
(@transient sc: SparkContext)
extends BinaryNode {
// TODO: Override requiredChildDistribution.

override def outputPartitioning: Partitioning = streamed.outputPartitioning

override def otherCopyArgs = sc :: Nil

def output = left.output

/** The Streamed Relation */
def left = streamed
/** The Broadcast relation */
def right = broadcast

@transient lazy val boundCondition =
InterpretedPredicate(
condition
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))


def execute() = {
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)

streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow

streamedIter.filter(streamedRow => {
var i = 0
var matched = false

while (i < broadcastedRelation.value.size && !matched) {
val broadcastedRow = broadcastedRelation.value(i)
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
matched = true
}
i += 1
}
matched
})
}
}
}

/**
* :: DeveloperApi ::
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
fail(
s"""
|Exception thrown while executing query:
|${rdd.logicalPlan}
|${rdd.queryExecution}
|== Exception ==
|$e
""".stripMargin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ class SQLQuerySuite extends QueryTest {
arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
}

test("left semi greater than predicate") {
checkAnswer(
sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
Seq((3,1), (3,2))
)
}

test("index into array of arrays") {
checkAnswer(
sql(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
DataSinks,
Scripts,
PartialAggregation,
LeftSemiJoin,
HashJoin,
BasicOperators,
CartesianProduct,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ private[hive] object HiveQl {
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
case "TOK_LEFTSEMIJOIN" => LeftSemi
}
assert(other.size <= 1, "Unhandled join clauses.")
Join(nodeToRelation(relation1),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Hank 2
Hank 2
Joe 2
Joe 2
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Hank 2
Hank 2
Joe 2
Joe 2
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
2 Tie
2 Tie
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1
1
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"lateral_view_cp",
"lateral_view_outer",
"lateral_view_ppd",
"leftsemijoin",
"leftsemijoin_mr",
"lineage1",
"literal_double",
"literal_ints",
Expand Down