diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala new file mode 100644 index 0000000000000..57ca135407d4a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinEvaluatorFactory.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.{PartitionEvaluator, PartitionEvaluatorFactory} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Predicate, Projection, RowOrdering, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetric + +class SortMergeJoinEvaluatorFactory( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + output: Seq[Attribute], + inMemoryThreshold: Int, + spillThreshold: Int, + numOutputRows: SQLMetric, + spillSize: SQLMetric, + onlyBufferFirstMatchedRow: Boolean) + extends PartitionEvaluatorFactory[InternalRow, InternalRow] { + override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] = + new SortMergeJoinEvaluator + + private class SortMergeJoinEvaluator extends PartitionEvaluator[InternalRow, InternalRow] { + + private def cleanupResources(): Unit = { + IndexedSeq(left, right).foreach(_.cleanupResources()) + } + private def createLeftKeyGenerator(): Projection = + UnsafeProjection.create(leftKeys, left.output) + + private def createRightKeyGenerator(): Projection = + UnsafeProjection.create(rightKeys, right.output) + + override def eval( + partitionIndex: Int, + inputs: Iterator[InternalRow]*): Iterator[InternalRow] = { + assert(inputs.length == 2) + val leftIter = inputs(0) + val rightIter = inputs(1) + + val boundCondition: InternalRow => Boolean = { + condition.map { cond => + Predicate.create(cond, left.output ++ right.output).eval _ + }.getOrElse { + (r: InternalRow) => true + } + } + + // An ordering that can be used to compare keys from both sides. + val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType)) + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) + + joinType match { + case _: InnerLike => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ + private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + spillSize, + cleanupResources) + private[this] val joinRow = new JoinedRow + + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + rightMatchesIterator = currentRightMatches.generateIterator() + } + + override def advanceNext(): Boolean = { + while (rightMatchesIterator != null) { + if (!rightMatchesIterator.hasNext) { + if (smjScanner.findNextInnerJoinRows()) { + currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + rightMatchesIterator = currentRightMatches.generateIterator() + } else { + currentRightMatches = null + currentLeftRow = null + rightMatchesIterator = null + return false + } + } + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } + } + false + } + + override def getRow: InternalRow = resultProj(joinRow) + }.toScala + + case LeftOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createLeftKeyGenerator(), + bufferedKeyGenerator = createRightKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(leftIter), + bufferedIter = RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + spillSize, + cleanupResources) + val rightNullRow = new GenericInternalRow(right.output.length) + new LeftOuterIterator( + smjScanner, + rightNullRow, + boundCondition, + resultProj, + numOutputRows).toScala + + case RightOuter => + val smjScanner = new SortMergeJoinScanner( + streamedKeyGenerator = createRightKeyGenerator(), + bufferedKeyGenerator = createLeftKeyGenerator(), + keyOrdering, + streamedIter = RowIterator.fromScala(rightIter), + bufferedIter = RowIterator.fromScala(leftIter), + inMemoryThreshold, + spillThreshold, + spillSize, + cleanupResources) + val leftNullRow = new GenericInternalRow(left.output.length) + new RightOuterIterator( + smjScanner, + leftNullRow, + boundCondition, + resultProj, + numOutputRows).toScala + + case FullOuter => + val leftNullRow = new GenericInternalRow(left.output.length) + val rightNullRow = new GenericInternalRow(right.output.length) + val smjScanner = new SortMergeFullOuterJoinScanner( + leftKeyGenerator = createLeftKeyGenerator(), + rightKeyGenerator = createRightKeyGenerator(), + keyOrdering, + leftIter = RowIterator.fromScala(leftIter), + rightIter = RowIterator.fromScala(rightIter), + boundCondition, + leftNullRow, + rightNullRow) + + new FullOuterIterator(smjScanner, resultProj, numOutputRows).toScala + + case LeftSemi => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + spillSize, + cleanupResources, + onlyBufferFirstMatchedRow) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextInnerJoinRows()) { + val currentRightMatches = smjScanner.getBufferedMatches + currentLeftRow = smjScanner.getStreamedRow + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + numOutputRows += 1 + return true + } + } + } + } + false + } + + override def getRow: InternalRow = currentLeftRow + }.toScala + + case LeftAnti => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + spillSize, + cleanupResources, + onlyBufferFirstMatchedRow) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextOuterJoinRows()) { + currentLeftRow = smjScanner.getStreamedRow + val currentRightMatches = smjScanner.getBufferedMatches + if (currentRightMatches == null || currentRightMatches.length == 0) { + numOutputRows += 1 + return true + } + var found = false + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + found = true + } + } + if (!found) { + numOutputRows += 1 + return true + } + } + false + } + + override def getRow: InternalRow = currentLeftRow + }.toScala + + case j: ExistenceJoin => + new RowIterator { + private[this] var currentLeftRow: InternalRow = _ + private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null)) + private[this] val smjScanner = new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + spillSize, + cleanupResources, + onlyBufferFirstMatchedRow) + private[this] val joinRow = new JoinedRow + + override def advanceNext(): Boolean = { + while (smjScanner.findNextOuterJoinRows()) { + currentLeftRow = smjScanner.getStreamedRow + val currentRightMatches = smjScanner.getBufferedMatches + var found = false + if (currentRightMatches != null && currentRightMatches.length > 0) { + val rightMatchesIterator = currentRightMatches.generateIterator() + while (!found && rightMatchesIterator.hasNext) { + joinRow(currentLeftRow, rightMatchesIterator.next()) + if (boundCondition(joinRow)) { + found = true + } + } + } + result.setBoolean(0, found) + numOutputRows += 1 + return true + } + false + } + + override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result)) + }.toScala + + case x => + throw new IllegalArgumentException(s"SortMergeJoin should not take $x as the JoinType") + } + + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ad2e179d6c221..0241f683d6902 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -99,12 +99,6 @@ case class SortMergeJoinExec( keys.map(SortOrder(_, Ascending)) } - private def createLeftKeyGenerator(): Projection = - UnsafeProjection.create(leftKeys, left.output) - - private def createRightKeyGenerator(): Projection = - UnsafeProjection.create(rightKeys, right.output) - private def getSpillThreshold: Int = { conf.sortMergeJoinExecBufferSpillThreshold } @@ -128,249 +122,27 @@ case class SortMergeJoinExec( val spillSize = longMetric("spillSize") val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => - val boundCondition: (InternalRow) => Boolean = { - condition.map { cond => - Predicate.create(cond, left.output ++ right.output).eval _ - }.getOrElse { - (r: InternalRow) => true - } - } - - // An ordering that can be used to compare keys from both sides. - val keyOrdering = RowOrdering.createNaturalAscendingOrdering(leftKeys.map(_.dataType)) - val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) - - joinType match { - case _: InnerLike => - new RowIterator { - private[this] var currentLeftRow: InternalRow = _ - private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ - private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null - private[this] val smjScanner = new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold, - spillSize, - cleanupResources - ) - private[this] val joinRow = new JoinedRow - - if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - rightMatchesIterator = currentRightMatches.generateIterator() - } - - override def advanceNext(): Boolean = { - while (rightMatchesIterator != null) { - if (!rightMatchesIterator.hasNext) { - if (smjScanner.findNextInnerJoinRows()) { - currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - rightMatchesIterator = currentRightMatches.generateIterator() - } else { - currentRightMatches = null - currentLeftRow = null - rightMatchesIterator = null - return false - } - } - joinRow(currentLeftRow, rightMatchesIterator.next()) - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true - } - } - false - } - - override def getRow: InternalRow = resultProj(joinRow) - }.toScala - - case LeftOuter => - val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator = createLeftKeyGenerator(), - bufferedKeyGenerator = createRightKeyGenerator(), - keyOrdering, - streamedIter = RowIterator.fromScala(leftIter), - bufferedIter = RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold, - spillSize, - cleanupResources - ) - val rightNullRow = new GenericInternalRow(right.output.length) - new LeftOuterIterator( - smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala - - case RightOuter => - val smjScanner = new SortMergeJoinScanner( - streamedKeyGenerator = createRightKeyGenerator(), - bufferedKeyGenerator = createLeftKeyGenerator(), - keyOrdering, - streamedIter = RowIterator.fromScala(rightIter), - bufferedIter = RowIterator.fromScala(leftIter), - inMemoryThreshold, - spillThreshold, - spillSize, - cleanupResources - ) - val leftNullRow = new GenericInternalRow(left.output.length) - new RightOuterIterator( - smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala - - case FullOuter => - val leftNullRow = new GenericInternalRow(left.output.length) - val rightNullRow = new GenericInternalRow(right.output.length) - val smjScanner = new SortMergeFullOuterJoinScanner( - leftKeyGenerator = createLeftKeyGenerator(), - rightKeyGenerator = createRightKeyGenerator(), - keyOrdering, - leftIter = RowIterator.fromScala(leftIter), - rightIter = RowIterator.fromScala(rightIter), - boundCondition, - leftNullRow, - rightNullRow) - - new FullOuterIterator( - smjScanner, - resultProj, - numOutputRows).toScala - - case LeftSemi => - new RowIterator { - private[this] var currentLeftRow: InternalRow = _ - private[this] val smjScanner = new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold, - spillSize, - cleanupResources, - onlyBufferFirstMatchedRow - ) - private[this] val joinRow = new JoinedRow - - override def advanceNext(): Boolean = { - while (smjScanner.findNextInnerJoinRows()) { - val currentRightMatches = smjScanner.getBufferedMatches - currentLeftRow = smjScanner.getStreamedRow - if (currentRightMatches != null && currentRightMatches.length > 0) { - val rightMatchesIterator = currentRightMatches.generateIterator() - while (rightMatchesIterator.hasNext) { - joinRow(currentLeftRow, rightMatchesIterator.next()) - if (boundCondition(joinRow)) { - numOutputRows += 1 - return true - } - } - } - } - false - } - - override def getRow: InternalRow = currentLeftRow - }.toScala - - case LeftAnti => - new RowIterator { - private[this] var currentLeftRow: InternalRow = _ - private[this] val smjScanner = new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold, - spillSize, - cleanupResources, - onlyBufferFirstMatchedRow - ) - private[this] val joinRow = new JoinedRow - - override def advanceNext(): Boolean = { - while (smjScanner.findNextOuterJoinRows()) { - currentLeftRow = smjScanner.getStreamedRow - val currentRightMatches = smjScanner.getBufferedMatches - if (currentRightMatches == null || currentRightMatches.length == 0) { - numOutputRows += 1 - return true - } - var found = false - val rightMatchesIterator = currentRightMatches.generateIterator() - while (!found && rightMatchesIterator.hasNext) { - joinRow(currentLeftRow, rightMatchesIterator.next()) - if (boundCondition(joinRow)) { - found = true - } - } - if (!found) { - numOutputRows += 1 - return true - } - } - false - } - - override def getRow: InternalRow = currentLeftRow - }.toScala - - case j: ExistenceJoin => - new RowIterator { - private[this] var currentLeftRow: InternalRow = _ - private[this] val result: InternalRow = new GenericInternalRow(Array[Any](null)) - private[this] val smjScanner = new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold, - spillSize, - cleanupResources, - onlyBufferFirstMatchedRow - ) - private[this] val joinRow = new JoinedRow - - override def advanceNext(): Boolean = { - while (smjScanner.findNextOuterJoinRows()) { - currentLeftRow = smjScanner.getStreamedRow - val currentRightMatches = smjScanner.getBufferedMatches - var found = false - if (currentRightMatches != null && currentRightMatches.length > 0) { - val rightMatchesIterator = currentRightMatches.generateIterator() - while (!found && rightMatchesIterator.hasNext) { - joinRow(currentLeftRow, rightMatchesIterator.next()) - if (boundCondition(joinRow)) { - found = true - } - } - } - result.setBoolean(0, found) - numOutputRows += 1 - return true - } - false - } - - override def getRow: InternalRow = resultProj(joinRow(currentLeftRow, result)) - }.toScala - - case x => - throw new IllegalArgumentException( - s"SortMergeJoin should not take $x as the JoinType") + val evaluatorFactory = new SortMergeJoinEvaluatorFactory( + leftKeys, + rightKeys, + joinType, + condition, + left, + right, + output, + inMemoryThreshold, + spillThreshold, + numOutputRows, + spillSize, + onlyBufferFirstMatchedRow + ) + if (conf.usePartitionEvaluator) { + left.execute().zipPartitionsWithEvaluator(right.execute(), evaluatorFactory) + } else { + left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => + val evaluator = evaluatorFactory.createEvaluator() + evaluator.eval(0, leftIter, rightIter) } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 0654aea5105c2..c496a5ae5d80e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -182,10 +182,14 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession { testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ => extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + Seq(true, false).foreach { enable => + withSQLConf(SQLConf.USE_PARTITION_EVALUATOR.key -> enable.toString) { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } } }