[SPARK-28571][CORE][SHUFFLE] Use the shuffle writer plugin for the SortShuffleWriter

## What changes were proposed in this pull request?

Use the shuffle writer APIs introduced in SPARK-28209 in the sort shuffle writer.
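In short, `SortShuffleWriter#write` no longer writes and commits its own data and index files through `IndexShuffleBlockResolver`; it delegates to the pluggable `ShuffleExecutorComponents`. A condensed sketch of the new path, taken from the diff below with explanatory comments added:

```scala
// Sketch of the new SortShuffleWriter.write() body (see the diff below); `dep`, `mapId`,
// `sorter`, `context` and `blockManager` are the writer's existing fields and parameters.
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
  dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)
// The sorter streams each partition through a ShufflePartitionPairsWriter that wraps
// the plugin's ShufflePartitionWriter, instead of writing to a local temp file.
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
// Committing all partitions yields the per-partition lengths needed for the MapStatus.
val partitionLengths = mapOutputWriter.commitAllPartitions()
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
```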

## How was this patch tested?

Existing unit tests were changed to use the plugin, backed by the local-disk implementation, to verify that there were no regressions.

Closes #25342 from mccheah/shuffle-writer-refactor-sort-shuffle-writer.

Lead-authored-by: mcheah <mcheah@palantir.com>
Co-authored-by: mccheah <mcheah@palantir.com>
Signed-off-by: Marcelo Vanzin <vanzin@cloudera.com>
mccheah authored and Marcelo Vanzin committed Aug 30, 2019
1 parent 92cabf6 commit ea90ea6
Showing 8 changed files with 265 additions and 31 deletions.
@@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

import java.io.{Closeable, IOException, OutputStream}

import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.api.ShufflePartitionWriter
import org.apache.spark.storage.BlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.PairsWriter

/**
 * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
 * arbitrary partition writer instead of writing to local disk through the block manager.
 */
private[spark] class ShufflePartitionPairsWriter(
    partitionWriter: ShufflePartitionWriter,
    serializerManager: SerializerManager,
    serializerInstance: SerializerInstance,
    blockId: BlockId,
    writeMetrics: ShuffleWriteMetricsReporter)
  extends PairsWriter with Closeable {

  private var isClosed = false
  private var partitionStream: OutputStream = _
  private var wrappedStream: OutputStream = _
  private var objOut: SerializationStream = _
  private var numRecordsWritten = 0
  private var curNumBytesWritten = 0L

  override def write(key: Any, value: Any): Unit = {
    if (isClosed) {
      throw new IOException("Partition pairs writer is already closed.")
    }
    if (objOut == null) {
      open()
    }
    objOut.writeKey(key)
    objOut.writeValue(value)
    recordWritten()
  }

  private def open(): Unit = {
    try {
      partitionStream = partitionWriter.openStream
      wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
      objOut = serializerInstance.serializeStream(wrappedStream)
    } catch {
      case e: Exception =>
        Utils.tryLogNonFatalError {
          close()
        }
        throw e
    }
  }

  override def close(): Unit = {
    if (!isClosed) {
      Utils.tryWithSafeFinally {
        Utils.tryWithSafeFinally {
          objOut = closeIfNonNull(objOut)
          // Setting these to null will prevent the underlying streams from being closed twice
          // just in case any stream's close() implementation is not idempotent.
          wrappedStream = null
          partitionStream = null
        } {
          // Normally closing objOut would close the inner streams as well, but just in case there
          // was an error in initialization etc. we make sure we clean the other streams up too.
          Utils.tryWithSafeFinally {
            wrappedStream = closeIfNonNull(wrappedStream)
            // Same as above - if wrappedStream closes then assume it closes underlying
            // partitionStream and don't close again in the finally
            partitionStream = null
          } {
            partitionStream = closeIfNonNull(partitionStream)
          }
        }
        updateBytesWritten()
      } {
        isClosed = true
      }
    }
  }

  private def closeIfNonNull[T <: Closeable](closeable: T): T = {
    if (closeable != null) {
      closeable.close()
    }
    null.asInstanceOf[T]
  }

  /**
   * Notify the writer that a record worth of bytes has been written with OutputStream#write.
   */
  private def recordWritten(): Unit = {
    numRecordsWritten += 1
    writeMetrics.incRecordsWritten(1)

    if (numRecordsWritten % 16384 == 0) {
      updateBytesWritten()
    }
  }

  private def updateBytesWritten(): Unit = {
    val numBytesWritten = partitionWriter.getNumBytesWritten
    val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
    writeMetrics.incBytesWritten(bytesWrittenDiff)
    curNumBytesWritten = numBytesWritten
  }
}
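For illustration only (not part of this commit): a minimal in-memory `ShufflePartitionWriter` that a pairs writer like the one above could push bytes into. It assumes only the two methods the pairs writer actually calls, `openStream` and `getNumBytesWritten`; a real plugin would persist the bytes somewhere durable.

```scala
import java.io.{ByteArrayOutputStream, OutputStream}

import org.apache.spark.shuffle.api.ShufflePartitionWriter

// Hypothetical plugin-side writer that buffers one partition's serialized bytes in memory.
class InMemoryPartitionWriter extends ShufflePartitionWriter {
  private val buffer = new ByteArrayOutputStream()

  // ShufflePartitionPairsWriter wraps this stream with serialization/compression.
  override def openStream(): OutputStream = buffer

  // Polled by updateBytesWritten() above to report shuffle write metrics.
  override def getNumBytesWritten(): Long = buffer.size().toLong
}
```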
@@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
metrics,
shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
new SortShuffleWriter(
shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents)
}
}

@@ -21,15 +21,15 @@ import org.apache.spark._
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.Utils
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.util.collection.ExternalSorter

private[spark] class SortShuffleWriter[K, V, C](
shuffleBlockResolver: IndexShuffleBlockResolver,
handle: BaseShuffleHandle[K, V, C],
mapId: Int,
context: TaskContext)
context: TaskContext,
shuffleExecutorComponents: ShuffleExecutorComponents)
extends ShuffleWriter[K, V] with Logging {

private val dep = handle.dependency
@@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
val partitionLengths = mapOutputWriter.commitAllPartitions()
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}

/** Close this writer, passing along whether the map completed */
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.PairsWriter

/**
* A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
writeMetrics: ShuffleWriteMetricsReporter,
val blockId: BlockId = null)
extends OutputStream
with Logging {
with Logging
with PairsWriter {

/**
* Guards against close calls, e.g. from a wrapping stream.
@@ -232,7 +234,7 @@ private[spark] class DiskBlockObjectWriter(
/**
* Writes a key-value pair.
*/
def write(key: Any, value: Any) {
override def write(key: Any, value: Any) {
if (!streamOpen) {
open()
}
@@ -23,13 +23,16 @@ import java.util.Comparator
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import com.google.common.io.ByteStreams
import com.google.common.io.{ByteStreams, Closeables}

import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer._
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
import org.apache.spark.shuffle.ShufflePartitionPairsWriter
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
import org.apache.spark.util.{Utils => TryUtils}

/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -670,11 +673,9 @@ private[spark] class ExternalSorter[K, V, C](
}

/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
* TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL
* project. We should figure out an alternative way to test that so that we can remove this
* otherwise unused code path.
*/
def writePartitionedFile(
blockId: BlockId,
@@ -718,6 +719,77 @@ private[spark] class ExternalSorter[K, V, C](
lengths
}

  /**
   * Write all the data added into this ExternalSorter into a map output writer that pushes bytes
   * to some arbitrary backing store. This is called by the SortShuffleWriter.
   *
   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
   */
  def writePartitionedMapOutput(
      shuffleId: Int,
      mapId: Int,
      mapOutputWriter: ShuffleMapOutputWriter): Unit = {
    var nextPartitionId = 0
    if (spills.isEmpty) {
      // Case where we only have in-memory data
      val collection = if (aggregator.isDefined) map else buffer
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
      while (it.hasNext()) {
        val partitionId = it.nextPartition()
        var partitionWriter: ShufflePartitionWriter = null
        var partitionPairsWriter: ShufflePartitionPairsWriter = null
        TryUtils.tryWithSafeFinally {
          partitionWriter = mapOutputWriter.getPartitionWriter(partitionId)
          val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
          partitionPairsWriter = new ShufflePartitionPairsWriter(
            partitionWriter,
            serializerManager,
            serInstance,
            blockId,
            context.taskMetrics().shuffleWriteMetrics)
          while (it.hasNext && it.nextPartition() == partitionId) {
            it.writeNext(partitionPairsWriter)
          }
        } {
          if (partitionPairsWriter != null) {
            partitionPairsWriter.close()
          }
        }
        nextPartitionId = partitionId + 1
      }
    } else {
      // We must perform merge-sort; get an iterator by partition and write everything directly.
      for ((id, elements) <- this.partitionedIterator) {
        val blockId = ShuffleBlockId(shuffleId, mapId, id)
        var partitionWriter: ShufflePartitionWriter = null
        var partitionPairsWriter: ShufflePartitionPairsWriter = null
        TryUtils.tryWithSafeFinally {
          partitionWriter = mapOutputWriter.getPartitionWriter(id)
          partitionPairsWriter = new ShufflePartitionPairsWriter(
            partitionWriter,
            serializerManager,
            serInstance,
            blockId,
            context.taskMetrics().shuffleWriteMetrics)
          if (elements.hasNext) {
            for (elem <- elements) {
              partitionPairsWriter.write(elem._1, elem._2)
            }
          }
        } {
          if (partitionPairsWriter != null) {
            partitionPairsWriter.close()
          }
        }
        nextPartitionId = id + 1
      }
    }

    context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
    context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
    context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)
  }

def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
@@ -781,7 +853,7 @@ private[spark] class ExternalSorter[K, V, C](
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (upstream.hasNext) upstream.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
def writeNext(writer: PairsWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (upstream.hasNext) upstream.next() else null
}
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

/**
 * An abstraction of a consumer of key-value pairs, primarily used when
 * persisting partitioned data, either through the shuffle writer plugins
 * or via DiskBlockObjectWriter.
 */
private[spark] trait PairsWriter {

  def write(key: Any, value: Any): Unit
}
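As a quick illustration (not part of this commit), anything that consumes key-value pairs can now stand in for `DiskBlockObjectWriter` wherever a `PairsWriter` is expected, for example a trivial counting writer in a test:

```scala
import org.apache.spark.util.collection.PairsWriter

// Hypothetical test helper: counts the pairs handed to it by a WritablePartitionedIterator.
class CountingPairsWriter extends PairsWriter {
  var pairsWritten: Long = 0L

  override def write(key: Any, value: Any): Unit = {
    pairsWritten += 1
  }
}
```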
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
def writeNext(writer: PairsWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
@@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection {
* has an associated partition.
*/
private[spark] trait WritablePartitionedIterator {
def writeNext(writer: DiskBlockObjectWriter): Unit
def writeNext(writer: PairsWriter): Unit

def hasNext(): Boolean
