Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44325][SQL] Use PartitionEvaluator API in SortMergeJoinExec #41884

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.joins

import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Predicate, Projection, RowOrdering, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetric

class SortMergeJoinEvaluatorFactory(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
output: Seq[Attribute],
inMemoryThreshold: Int,
spillThreshold: Int,
numOutputRows: SQLMetric,
spillSize: SQLMetric,
onlyBufferFirstMatchedRow: Boolean)
extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
new SortMergeJoinEvaluator

private class SortMergeJoinEvaluator extends PartitionEvaluator[InternalRow, InternalRow] {

private def cleanupResources(): Unit = {
IndexedSeq(left, right).foreach(_.cleanupResources())
}
private def createLeftKeyGenerator(): Projection =
UnsafeProjection.create(leftKeys, left.output)

private def createRightKeyGenerator(): Projection =
UnsafeProjection.create(rightKeys, right.output)

override def eval(
partitionIndex: Int,
inputs: Iterator[InternalRow]*): Iterator[InternalRow] = {
assert(inputs.length == 2)
val leftIter = inputs(0)
val rightIter = inputs(1)

val boundCondition: InternalRow => Boolean = {
condition.map { cond =>
Predicate.create(cond, left.output ++ right.output).eval _
}.getOrElse {
(r: InternalRow) => true
}
}

// An ordering that can be used to compare keys from both sides.
val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType))
val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)

joinType match {
case _: InnerLike =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _
private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold,
spillSize,
cleanupResources)
private[this] val joinRow = new JoinedRow

if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
rightMatchesIterator = currentRightMatches.generateIterator()
}

override def advanceNext(): Boolean = {
while (rightMatchesIterator != null) {
if (!rightMatchesIterator.hasNext) {
if (smjScanner.findNextInnerJoinRows()) {
currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
rightMatchesIterator = currentRightMatches.generateIterator()
} else {
currentRightMatches = null
currentLeftRow = null
rightMatchesIterator = null
return false
}
}
joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
numOutputRows += 1
return true
}
}
false
}

override def getRow: InternalRow = resultProj(joinRow)
}.toScala

case LeftOuter =>
val smjScanner = new SortMergeJoinScanner(
streamedKeyGenerator = createLeftKeyGenerator(),
bufferedKeyGenerator = createRightKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(leftIter),
bufferedIter = RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold,
spillSize,
cleanupResources)
val rightNullRow = new GenericInternalRow(right.output.length)
new LeftOuterIterator(
smjScanner,
rightNullRow,
boundCondition,
resultProj,
numOutputRows).toScala

case RightOuter =>
val smjScanner = new SortMergeJoinScanner(
streamedKeyGenerator = createRightKeyGenerator(),
bufferedKeyGenerator = createLeftKeyGenerator(),
keyOrdering,
streamedIter = RowIterator.fromScala(rightIter),
bufferedIter = RowIterator.fromScala(leftIter),
inMemoryThreshold,
spillThreshold,
spillSize,
cleanupResources)
val leftNullRow = new GenericInternalRow(left.output.length)
new RightOuterIterator(
smjScanner,
leftNullRow,
boundCondition,
resultProj,
numOutputRows).toScala

case FullOuter =>
val leftNullRow = new GenericInternalRow(left.output.length)
val rightNullRow = new GenericInternalRow(right.output.length)
val smjScanner = new SortMergeFullOuterJoinScanner(
leftKeyGenerator = createLeftKeyGenerator(),
rightKeyGenerator = createRightKeyGenerator(),
keyOrdering,
leftIter = RowIterator.fromScala(leftIter),
rightIter = RowIterator.fromScala(rightIter),
boundCondition,
leftNullRow,
rightNullRow)

new FullOuterIterator(smjScanner, resultProj, numOutputRows).toScala

case LeftSemi =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold,
spillSize,
cleanupResources,
onlyBufferFirstMatchedRow)
private[this] val joinRow = new JoinedRow

override def advanceNext(): Boolean = {
while (smjScanner.findNextInnerJoinRows()) {
val currentRightMatches = smjScanner.getBufferedMatches
currentLeftRow = smjScanner.getStreamedRow
if (currentRightMatches != null && currentRightMatches.length > 0) {
val rightMatchesIterator = currentRightMatches.generateIterator()
while (rightMatchesIterator.hasNext) {
joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
numOutputRows += 1
return true
}
}
}
}
false
}

override def getRow: InternalRow = currentLeftRow
}.toScala

case LeftAnti =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold,
spillSize,
cleanupResources,
onlyBufferFirstMatchedRow)
private[this] val joinRow = new JoinedRow

override def advanceNext(): Boolean = {
while (smjScanner.findNextOuterJoinRows()) {
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
if (currentRightMatches == null || currentRightMatches.length == 0) {
numOutputRows += 1
return true
}
var found = false
val rightMatchesIterator = currentRightMatches.generateIterator()
while (!found && rightMatchesIterator.hasNext) {
joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
}
if (!found) {
numOutputRows += 1
return true
}
}
false
}

override def getRow: InternalRow = currentLeftRow
}.toScala

case j: ExistenceJoin =>
new RowIterator {
private[this] var currentLeftRow: InternalRow = _
private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null))
private[this] val smjScanner = new SortMergeJoinScanner(
createLeftKeyGenerator(),
createRightKeyGenerator(),
keyOrdering,
RowIterator.fromScala(leftIter),
RowIterator.fromScala(rightIter),
inMemoryThreshold,
spillThreshold,
spillSize,
cleanupResources,
onlyBufferFirstMatchedRow)
private[this] val joinRow = new JoinedRow

override def advanceNext(): Boolean = {
while (smjScanner.findNextOuterJoinRows()) {
currentLeftRow = smjScanner.getStreamedRow
val currentRightMatches = smjScanner.getBufferedMatches
var found = false
if (currentRightMatches != null && currentRightMatches.length > 0) {
val rightMatchesIterator = currentRightMatches.generateIterator()
while (!found && rightMatchesIterator.hasNext) {
joinRow(currentLeftRow, rightMatchesIterator.next())
if (boundCondition(joinRow)) {
found = true
}
}
}
result.setBoolean(0, found)
numOutputRows += 1
return true
}
false
}

override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result))
}.toScala

case x =>
throw new IllegalArgumentException(s"SortMergeJoin should not take $x as the JoinType")
}

}
}
}