
Commit

Keep sort order of rows after external sorter when writing.
viirya committed Jan 29, 2017
1 parent 40a4cfc commit 3c040b6
Showing 4 changed files with 172 additions and 24 deletions.
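A note on the approach, for orientation before the diff: FileFormatWriter now records the child plan's output ordering in the write job description and threads it into UnsafeKVExternalSorter, so that an upstream sort (e.g. sortWithinPartitions) survives the writer's external sort instead of being flattened into plain ascending key order. The core building block is an ordering generated from a list of SortOrder rather than from the key schema alone. A minimal sketch of that building block, using invented column ordinals and types (not code from this patch):

import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.types.{LongType, StringType}

// Hypothetical composite key: ordinal 0 is a partition column (compared ascending),
// ordinal 1 is a data column the query already sorted in descending order.
val keyOrdering = GenerateOrdering.generate(Seq(
  SortOrder(BoundReference(0, StringType, nullable = true), Ascending),
  SortOrder(BoundReference(1, LongType, nullable = true), Descending)))

// keyOrdering.compare(left, right) ranks rows the way the upstream sort did;
// the pre-patch sorter compared every key column ascending regardless.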
UnsafeKVExternalSorter.java
@@ -19,13 +19,18 @@

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.List;

import scala.collection.JavaConverters;
import scala.collection.Seq;

import com.google.common.annotations.VisibleForTesting;

import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.sql.catalyst.expressions.SortOrder;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.expressions.codegen.BaseOrdering;
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering;
@@ -58,7 +63,7 @@ public UnsafeKVExternalSorter(
long pageSizeBytes,
long numElementsForSpillThreshold) throws IOException {
this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
numElementsForSpillThreshold, null);
numElementsForSpillThreshold, null, null);
}

public UnsafeKVExternalSorter(
@@ -69,14 +74,34 @@ public UnsafeKVExternalSorter(
long pageSizeBytes,
long numElementsForSpillThreshold,
@Nullable BytesToBytesMap map) throws IOException {
this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
numElementsForSpillThreshold, map, null);
}

public UnsafeKVExternalSorter(
StructType keySchema,
StructType valueSchema,
BlockManager blockManager,
SerializerManager serializerManager,
long pageSizeBytes,
long numElementsForSpillThreshold,
@Nullable BytesToBytesMap map,
@Nullable List<SortOrder> ordering) throws IOException {
this.keySchema = keySchema;
this.valueSchema = valueSchema;
final TaskContext taskContext = TaskContext.get();

prefixComputer = SortPrefixUtils.createPrefixGenerator(keySchema);
PrefixComparator prefixComparator = SortPrefixUtils.getPrefixComparator(keySchema);
BaseOrdering ordering = GenerateOrdering.create(keySchema);
KVComparator recordComparator = new KVComparator(ordering, keySchema.length());
KVComparator recordComparator = null;
if (ordering == null) {
recordComparator = new KVComparator(GenerateOrdering.create(keySchema), keySchema.length());
} else {
Seq<SortOrder> orderingSeq =
JavaConverters.collectionAsScalaIterableConverter(ordering).asScala().toSeq();
recordComparator = new KVComparator((BaseOrdering) GenerateOrdering.generate(orderingSeq),
ordering.size());
}
boolean canUseRadixSort = keySchema.length() == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(keySchema.apply(0));

@@ -137,7 +162,7 @@ public UnsafeKVExternalSorter(
blockManager,
serializerManager,
taskContext,
new KVComparator(ordering, keySchema.length()),
recordComparator,
prefixComparator,
SparkEnv.get().conf().getInt("spark.shuffle.sort.initialBufferSize",
UnsafeExternalRowSorter.DEFAULT_INITIAL_SORT_BUFFER_SIZE),
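For context on the new branch above: when no ordering is supplied, the sorter keeps its old behaviour, GenerateOrdering.create(keySchema), which amounts to comparing every key column ascending and therefore drops direction and null-ordering information. A rough Scala equivalence of that fallback, stated as an assumption rather than a quote of the implementation:

import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, SortOrder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.types.StructType

// Roughly what the `ordering == null` fallback produces: every key column ascending.
def defaultKeyOrdering(keySchema: StructType) =
  GenerateOrdering.generate(keySchema.zipWithIndex.map { case (field, ordinal) =>
    SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
  })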
FileFormatWriter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources

import java.util.{Date, UUID}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.hadoop.conf.Configuration
@@ -68,7 +69,8 @@ object FileFormatWriter extends Logging {
val bucketSpec: Option[BucketSpec],
val path: String,
val customPartitionLocations: Map[TablePartitionSpec, String],
val maxRecordsPerFile: Long)
val maxRecordsPerFile: Long,
val orderingInPartition: Seq[SortOrder])
extends Serializable {

assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),
@@ -125,7 +127,8 @@ object FileFormatWriter extends Logging {
path = outputSpec.outputPath,
customPartitionLocations = outputSpec.customPartitionLocations,
maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
.getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
orderingInPartition = queryExecution.executedPlan.outputOrdering
)

SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
@@ -368,17 +371,58 @@ object FileFormatWriter extends Logging {
}

override def execute(iter: Iterator[InternalRow]): Set[String] = {
// We should first sort by partition columns, then bucket id, and finally sorting columns.
// If the incoming data already has a sort ordering, we need to preserve it.
val orderingExpressions: Seq[Expression] = if (description.orderingInPartition.isEmpty) {
Nil
} else {
description.orderingInPartition.map(_.child)
}

// We should sort first by partition columns, then bucket id, then the data's existing
// sort ordering, and finally the sorting columns.
val sortingExpressions: Seq[Expression] =
description.partitionColumns ++ bucketIdExpression ++ sortColumns
description.partitionColumns ++ bucketIdExpression ++ orderingExpressions ++ sortColumns
val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)

val sortingKeySchema = StructType(sortingExpressions.map {
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except bucket id.
case _ => StructField("bucketId", IntegerType, nullable = false)
val bucketIdExprIndex =
sortingExpressions.length - sortColumns.length - orderingExpressions.length - 1

val sortingKeySchema = StructType(sortingExpressions.zipWithIndex.map { case (e, index) =>
e match {
case a: Attribute => StructField(a.name, a.dataType, a.nullable)
// The sorting expressions are all `Attribute` except the bucket id and the
// sort ordering's child expressions.
case _ if index == bucketIdExprIndex =>
StructField("bucketId", IntegerType, nullable = false)
case _ if index > bucketIdExprIndex =>
StructField(s"_sortOrder_$index", e.dataType, e.nullable)
}
})

val beginSortingExpr =
sortingExpressions.length - sortColumns.length - orderingExpressions.length
val recordSortingOrder =
if (description.orderingInPartition.isEmpty) {
null
} else {
sortingExpressions.zipWithIndex.map { case (field, ordinal) =>
if (ordinal < beginSortingExpr ||
ordinal >= beginSortingExpr + orderingExpressions.length) {
// For partition columns, bucket id and sort-by columns, we sort ascending.
SortOrder(BoundReference(ordinal, field.dataType, nullable = true), Ascending)
} else {
// For the sort ordering of data, we need to keep its sort direction and
// null ordering.
val direction =
description.orderingInPartition(ordinal - beginSortingExpr).direction
val nullOrdering =
description.orderingInPartition(ordinal - beginSortingExpr).nullOrdering
SortOrder(BoundReference(ordinal, field.dataType, nullable = true),
direction, nullOrdering)
}
}.asJava
}

// Returns the data columns to be written given an input row
val getOutputRow = UnsafeProjection.create(
description.dataColumns, description.allColumns)
@@ -395,20 +439,25 @@
SparkEnv.get.serializerManager,
TaskContext.get().taskMemoryManager().pageSizeBytes,
SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
null,
recordSortingOrder)

while (iter.hasNext) {
val currentRow = iter.next()
sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow))
}

val getBucketingKey: InternalRow => InternalRow = if (sortColumns.isEmpty) {
identity
} else {
UnsafeProjection.create(sortingExpressions.dropRight(sortColumns.length).zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}
val getBucketingKey: InternalRow => InternalRow =
if (sortColumns.isEmpty && orderingExpressions.isEmpty) {
identity
} else {
val bucketingKeyExprs =
sortingExpressions.dropRight(sortColumns.length + orderingExpressions.length)
UnsafeProjection.create(bucketingKeyExprs.zipWithIndex.map {
case (expr, ordinal) => BoundReference(ordinal, expr.dataType, expr.nullable)
})
}

val sortedIterator = sorter.sortedIterator()

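To make the index arithmetic in this hunk concrete, here is a self-contained sketch of the composite key layout for a hypothetical write with one partition column, a bucket id, one ordered data column (descending, nulls last) and one bucket sort column; the names are invented for illustration, and the `>=` boundary keeps the trailing bucket sort column in the ascending branch:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, LongType, StringType}

// sortingExpressions = partitionColumns ++ bucketIdExpression ++ orderingExpressions ++ sortColumns
val partitionCols = Seq(AttributeReference("p", StringType)())
val bucketId = Seq(AttributeReference("bucketId", IntegerType, nullable = false)())
val orderingExprs = Seq(AttributeReference("value", LongType)())
val sortCols = Seq(AttributeReference("s", LongType)())
val orderingInPartition = Seq(SortOrder(orderingExprs.head, Descending, NullsLast))

val sortingExpressions = partitionCols ++ bucketId ++ orderingExprs ++ sortCols
// First ordinal occupied by the data ordering: 4 - 1 (sortCols) - 1 (ordering) = 2.
val beginSortingExpr = sortingExpressions.length - sortCols.length - orderingExprs.length

val recordSortingOrder = sortingExpressions.zipWithIndex.map { case (e, ordinal) =>
  if (ordinal < beginSortingExpr || ordinal >= beginSortingExpr + orderingExprs.length) {
    // Ordinals 0 (partition column), 1 (bucket id) and 3 (bucket sort column): plain ascending.
    SortOrder(BoundReference(ordinal, e.dataType, nullable = true), Ascending)
  } else {
    // Ordinal 2 keeps the query's direction and null ordering: DESC NULLS LAST.
    val o = orderingInPartition(ordinal - beginSortingExpr)
    SortOrder(BoundReference(ordinal, e.dataType, nullable = true), o.direction, o.nullOrdering)
  }
}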
UnsafeKVExternalSorterSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.sql.execution

import java.util.Properties

import scala.collection.JavaConverters._
import scala.util.Random

import org.apache.spark._
import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, InterpretedOrdering, SortOrder, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
@@ -110,7 +111,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
valueSchema: StructType,
inputData: Seq[(InternalRow, InternalRow)],
pageSize: Long,
spill: Boolean): Unit = {
spill: Boolean,
sortOrdering: java.util.List[SortOrder] = null): Unit = {
val memoryManager =
new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"))
val taskMemMgr = new TaskMemoryManager(memoryManager, 0)
@@ -125,7 +127,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {

val sorter = new UnsafeKVExternalSorter(
keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager,
pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)
pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
null, sortOrdering)

// Insert the keys and values into the sorter
inputData.foreach { case (k, v) =>
@@ -145,7 +148,11 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
}
sorter.cleanupResources()

val keyOrdering = InterpretedOrdering.forSchema(keySchema.map(_.dataType))
val keyOrdering = if (sortOrdering == null) {
InterpretedOrdering.forSchema(keySchema.map(_.dataType))
} else {
new InterpretedOrdering(sortOrdering.asScala)
}
val valueOrdering = InterpretedOrdering.forSchema(valueSchema.map(_.dataType))
val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
@@ -204,4 +211,41 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
spill = true
)
}

test("kv sorting with records that exceed page size: with specified sort order") {
val pageSize = 128

val keySchema = StructType(StructField("a", BinaryType) :: StructField("b", BinaryType) :: Nil)
val valueSchema = StructType(StructField("c", BinaryType) :: Nil)
val keyExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
val valueExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
val keyConverter = UnsafeProjection.create(keySchema)
val valueConverter = UnsafeProjection.create(valueSchema)

val rand = new Random()
val inputData = Seq.fill(1024) {
val kBytes1 = new Array[Byte](rand.nextInt(pageSize))
val kBytes2 = new Array[Byte](rand.nextInt(pageSize))
val vBytes = new Array[Byte](rand.nextInt(pageSize))
rand.nextBytes(kBytes1)
rand.nextBytes(kBytes2)
rand.nextBytes(vBytes)
val k =
keyConverter(keyExternalConverter.apply(Row(kBytes1, kBytes2)).asInstanceOf[InternalRow])
val v = valueConverter(valueExternalConverter.apply(Row(vBytes)).asInstanceOf[InternalRow])
(k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
}

val sortOrder = SortOrder(BoundReference(0, BinaryType, nullable = true), Ascending) ::
SortOrder(BoundReference(1, BinaryType, nullable = true), Descending) :: Nil

testKVSorter(
keySchema,
valueSchema,
inputData,
pageSize,
spill = true,
sortOrder.asJava
)
}
}
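A quick illustration of what the new sortOrdering parameter changes in the verification: with the new test's key shape (two binary columns, ascending then descending), a row with the larger second column must sort first. This is a standalone sketch, not part of the suite:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, InterpretedOrdering, SortOrder, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}

val keySchema = StructType(StructField("a", BinaryType) :: StructField("b", BinaryType) :: Nil)
val sortOrder = SortOrder(BoundReference(0, BinaryType, nullable = true), Ascending) ::
  SortOrder(BoundReference(1, BinaryType, nullable = true), Descending) :: Nil
val keyOrdering = new InterpretedOrdering(sortOrder)

val toUnsafe = UnsafeProjection.create(keySchema)
val row1 = toUnsafe(InternalRow(Array[Byte](1), Array[Byte](9))).copy()
val row2 = toUnsafe(InternalRow(Array[Byte](1), Array[Byte](2))).copy()

// Column "a" ties, so the descending column "b" decides: row1 (b = 9) comes
// before row2 (b = 2) under this ordering.
assert(keyOrdering.compare(row1, row2) < 0)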
FileSourceStrategySuite.scala
@@ -487,6 +487,36 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
}
}

test("SPARK-19352: Keep sort order of rows after external sorter when writing") {
spark.stop()
// Explicitly set memory configuration to force `UnsafeKVExternalSorter` to spill to files
// when inserting data.
val newSpark = SparkSession.builder()
.master("local")
.appName("test")
.config("spark.buffer.pageSize", "16b")
.config("spark.testing.memory", "1400")
.config("spark.memory.fraction", "0.1")
.config("spark.shuffle.sort.initialBufferSize", "2")
.config("spark.memory.offHeap.enabled", "false")
.getOrCreate()
withTempPath { path =>
val tempDir = path.getCanonicalPath
val df = newSpark.range(100)
.select($"id", explode(array(col("id") + 1, col("id") + 2, col("id") + 3)).as("value"))
.repartition($"id")
.sortWithinPartitions($"value".desc).toDF()

df.write
.partitionBy("id")
.parquet(tempDir)

val dfReadIn = newSpark.read.parquet(tempDir).select("id", "value")
checkAnswer(df.filter("id = 65"), dfReadIn.filter("id = 65"))
}
newSpark.stop()
}

// Helpers for checking the arguments passed to the FileFormat.

protected val checkPartitionSchema =
