diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 0fdcbf7bfb3c3..96159ec1b1feb 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -232,4 +232,4 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
     spillMerger.addSpill(sorter.getSortedIterator());
     return spillMerger.getSortedIterator();
   }
-}
\ No newline at end of file
+}
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index 94ee6699ef101..2fb41fb2d402f 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -26,8 +26,8 @@ final class UnsafeSorterSpillMerger {
   private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
 
   public UnsafeSorterSpillMerger(
-    final RecordComparator recordComparator,
-    final PrefixComparator prefixComparator) {
+      final RecordComparator recordComparator,
+      final PrefixComparator prefixComparator) {
     final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
 
       @Override
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala
new file mode 100644
index 0000000000000..73a9bc6a2cf60
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/UnsafeSortMergeJoin.scala
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.NoSuchElementException
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan}
+import org.apache.spark.util.collection.CompactBuffer
+
+/**
+ * :: DeveloperApi ::
+ * Performs a sort merge join of two child relations.
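+ * Both inputs are (re-)sorted by their join keys using [[UnsafeExternalSort]] if they are not
+ * already sorted, and the per-partition iterators are then merged, buffering the right-side
+ * rows that share a join key so that every left row can be matched against them.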
+ * TODO(josh): Document
+ */
+@DeveloperApi
+case class UnsafeSortMergeJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def output: Seq[Attribute] = left.output ++ right.output
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  // This is to manually construct an ordering that can be used to compare keys from both sides.
+  private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
+
+  override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
+
+  @transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
+  @transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
+
+  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
+    keys.map(SortOrder(_, Ascending))
+
+  protected override def doExecute(): RDD[Row] = {
+    // Note that we purposely do not require our input to be sorted. Instead, we'll sort it
+    // ourselves using UnsafeExternalSorter. Not requiring the input to be sorted will prevent
+    // the Exchange from pushing the sort into the shuffle, which will allow the shuffle to
+    // benefit from Project Tungsten's shuffle optimizations, which currently cannot be applied
+    // to shuffles that specify a key ordering.
+
+    // Only sort if necessary:
+    val leftOrder = requiredOrders(leftKeys)
+    val leftResults = {
+      if (left.outputOrdering == leftOrder) {
+        left.execute().map(_.copy())
+      } else {
+        new UnsafeExternalSort(leftOrder, global = false, left).execute()
+      }
+    }
+    val rightOrder = requiredOrders(rightKeys)
+    val rightResults = {
+      if (right.outputOrdering == rightOrder) {
+        right.execute().map(_.copy())
+      } else {
+        new UnsafeExternalSort(rightOrder, global = false, right).execute()
+      }
+    }
+
+    leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
+      new Iterator[Row] {
+        // Mutable per row objects.
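+        // leftElement / rightElement hold the current row from each side, and leftKey /
+        // rightKey cache their projected join keys. rightMatches buffers the right rows whose
+        // key equals matchKey, and rightPosition is the index of the next buffered row to emit.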
+        private[this] val joinRow = new JoinedRow5
+        private[this] var leftElement: Row = _
+        private[this] var rightElement: Row = _
+        private[this] var leftKey: Row = _
+        private[this] var rightKey: Row = _
+        private[this] var rightMatches: CompactBuffer[Row] = _
+        private[this] var rightPosition: Int = -1
+        private[this] var stop: Boolean = false
+        private[this] var matchKey: Row = _
+
+        // Initialize the iterator.
+        initialize()
+
+        override final def hasNext: Boolean = nextMatchingPair()
+
+        override final def next(): Row = {
+          if (hasNext) {
+            // We are using the buffered right rows and running down the left iterator.
+            val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
+            rightPosition += 1
+            if (rightPosition >= rightMatches.size) {
+              rightPosition = 0
+              fetchLeft()
+              if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
+                stop = false
+                rightMatches = null
+              }
+            }
+            joinedRow
+          } else {
+            // No more results.
+            throw new NoSuchElementException
+          }
+        }
+
+        private def fetchLeft() = {
+          if (leftIter.hasNext) {
+            leftElement = leftIter.next()
+            leftKey = leftKeyGenerator(leftElement)
+          } else {
+            leftElement = null
+          }
+        }
+
+        private def fetchRight() = {
+          if (rightIter.hasNext) {
+            rightElement = rightIter.next()
+            rightKey = rightKeyGenerator(rightElement)
+          } else {
+            rightElement = null
+          }
+        }
+
+        private def initialize() = {
+          fetchLeft()
+          fetchRight()
+        }
+
+        /**
+         * Searches the right iterator for the next rows that have matches on the left side,
+         * and stores them in a buffer.
+         *
+         * @return true if the search is successful, and false if the right iterator runs out
+         *         of tuples.
+         */
+        private def nextMatchingPair(): Boolean = {
+          if (!stop && rightElement != null) {
+            // Run down both sides to get the first matching pair.
+            while (!stop && leftElement != null && rightElement != null) {
+              val comparing = keyOrdering.compare(leftKey, rightKey)
+              // For inner join, we need to filter out those null keys.
+              stop = comparing == 0 && !leftKey.anyNull
+              if (comparing > 0 || rightKey.anyNull) {
+                fetchRight()
+              } else if (comparing < 0 || leftKey.anyNull) {
+                fetchLeft()
+              }
+            }
+            rightMatches = new CompactBuffer[Row]()
+            if (stop) {
+              stop = false
+              // Iterate the right side to buffer all rows that match; since the records are
+              // ordered, we can exit when we meet the first row that does not match.
+              while (!stop && rightElement != null) {
+                rightMatches += rightElement
+                fetchRight()
+                stop = keyOrdering.compare(leftKey, rightKey) != 0
+              }
+              if (rightMatches.size > 0) {
+                rightPosition = 0
+                matchKey = leftKey
+              }
+            }
+          }
+          rightMatches != null && rightMatches.size > 0
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala
new file mode 100644
index 0000000000000..4ce0537f02418
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeSortMergeJoinSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.execution.UnsafeExternalSort
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.scalatest.BeforeAndAfterEach
+
+class UnsafeSortMergeJoinSuite extends QueryTest with BeforeAndAfterEach {
+  // Ensures tables are loaded.
+  TestData
+
+  conf.setConf(SQLConf.SORTMERGE_JOIN, "true")
+  conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
+  conf.setConf(SQLConf.UNSAFE_ENABLED, "true")
+  conf.setConf(SQLConf.EXTERNAL_SORT, "true")
+  conf.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, "-1")
+
+  test("basic sort merge join test") {
+    val df = upperCaseData.join(lowerCaseData, $"n" === $"N")
+    assert(df.queryExecution.sparkPlan.collect {
+      case smj: UnsafeSortMergeJoin => smj
+    }.nonEmpty)
+    checkAnswer(
+      df,
+      Seq(
+        Row(1, "A", 1, "a"),
+        Row(2, "B", 2, "b"),
+        Row(3, "C", 3, "c"),
+        Row(4, "D", 4, "d")
+      ))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
new file mode 100644
index 0000000000000..f5a6368a2b16a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark.sql.{SQLConf, SQLContext, Row}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+class UnsafeExternalSortSuite extends FunSuite with Matchers {
+
+  private def createRow(values: Any*): Row = {
+    new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
+  }
+
+  test("basic sorting") {
+    val sc = TestSQLContext.sparkContext
+    val sqlContext = new SQLContext(sc)
+    sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
+
+    val schema: StructType = StructType(
+      StructField("word", StringType, nullable = false) ::
+      StructField("number", IntegerType, nullable = false) :: Nil)
+    val sortOrder: Seq[SortOrder] = Seq(
+      SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
+      SortOrder(BoundReference(1, IntegerType, nullable = false), Descending))
+    val rowsToSort: Seq[Row] = Seq(
+      createRow("Hello", 9),
+      createRow("World", 4),
+      createRow("Hello", 7),
+      createRow("Skinny", 0),
+      createRow("Constantinople", 9))
+    SparkPlan.currentContext.set(sqlContext)
+    val input = new PhysicalRDD(schema.toAttributes, sc.parallelize(rowsToSort, 1))
+    // Treat the existing sort operators as the source of truth for this test.
+    val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect()
+    val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect()
+    val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect()
+    assert(defaultSorted === externalSorted)
+    assert(unsafeSorted === externalSorted)
+  }
+}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibilitySuite.scala
new file mode 100644
index 0000000000000..093ce3504e2c2
--- /dev/null
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/UnsafeSortMergeCompatibilitySuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.hive.test.TestHive
+
+/**
+ * Runs the test cases that are included in the Hive distribution with sort merge join and
+ * unsafe external sort enabled.
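+ *
+ * The relevant settings are switched on in beforeAll() and switched back off in afterAll() so
+ * that they do not leak into other test suites.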
+ */
+class UnsafeSortMergeCompatibilitySuite extends SortMergeCompatibilitySuite {
+  override def beforeAll() {
+    super.beforeAll()
+    TestHive.setConf(SQLConf.CODEGEN_ENABLED, "true")
+    TestHive.setConf(SQLConf.UNSAFE_ENABLED, "true")
+    TestHive.setConf(SQLConf.EXTERNAL_SORT, "true")
+  }
+
+  override def afterAll() {
+    TestHive.setConf(SQLConf.CODEGEN_ENABLED, "false")
+    TestHive.setConf(SQLConf.UNSAFE_ENABLED, "false")
+    TestHive.setConf(SQLConf.EXTERNAL_SORT, "false")
+    super.afterAll()
+  }
+}
\ No newline at end of file