diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index ed2a7bdfe2cce..1e7f7129f4b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.collection.unsafe.sort.PrefixComparator import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -274,7 +275,18 @@ case class UnsafeExternalSort( def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = { val ordering = newOrdering(sortOrder, child.output) val boundSortExpression = BindReferences.bindReference(sortOrder.head, child.output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + // Hack until we generate separate comparator implementations for ascending vs. descending + // (or choose to codegen them): + val prefixComparator = { + val comp = SortPrefixUtils.getPrefixComparator(boundSortExpression) + if (sortOrder.head.direction == Descending) { + new PrefixComparator { + override def compare(p1: Long, p2: Long): Int = -1 * comp.compare(p1, p2) + } + } else { + comp + } + } val prefixComputer = { val prefixComputer = SortPrefixUtils.getPrefixComputer(boundSortExpression) new UnsafeExternalRowSorter.PrefixComputer { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 5233c73638c85..1dd81ee4e9fb6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -38,9 +38,10 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { // Test sorting on different data types for ( - dataType <- DataTypeTestUtils.atomicTypes; // Disable null type for now due to bug in SqlSerializer2 ++ Set(NullType); + dataType <- DataTypeTestUtils.atomicTypes // Disable null type for now due to bug in SqlSerializer2 ++ Set(NullType); + if !dataType.isInstanceOf[DecimalType]; // Since we don't have an unsafe representation for decimals nullable <- Seq(true, false); - sortOrder <- Seq('a.asc :: Nil); + sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {