diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala new file mode 100644 index 0000000000000..a988c5e126a76 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import java.io.{Closeable, IOException, OutputStream} + +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} +import org.apache.spark.shuffle.api.ShufflePartitionWriter +import org.apache.spark.storage.BlockId +import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter + +/** + * A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an + * arbitrary partition writer instead of writing to local disk through the block manager. + */ +private[spark] class ShufflePartitionPairsWriter( + partitionWriter: ShufflePartitionWriter, + serializerManager: SerializerManager, + serializerInstance: SerializerInstance, + blockId: BlockId, + writeMetrics: ShuffleWriteMetricsReporter) + extends PairsWriter with Closeable { + + private var isClosed = false + private var partitionStream: OutputStream = _ + private var wrappedStream: OutputStream = _ + private var objOut: SerializationStream = _ + private var numRecordsWritten = 0 + private var curNumBytesWritten = 0L + + override def write(key: Any, value: Any): Unit = { + if (isClosed) { + throw new IOException("Partition pairs writer is already closed.") + } + if (objOut == null) { + open() + } + objOut.writeKey(key) + objOut.writeValue(value) + recordWritten() + } + + private def open(): Unit = { + try { + partitionStream = partitionWriter.openStream + wrappedStream = serializerManager.wrapStream(blockId, partitionStream) + objOut = serializerInstance.serializeStream(wrappedStream) + } catch { + case e: Exception => + Utils.tryLogNonFatalError { + close() + } + throw e + } + } + + override def close(): Unit = { + if (!isClosed) { + Utils.tryWithSafeFinally { + Utils.tryWithSafeFinally { + objOut = closeIfNonNull(objOut) + // Setting these to null will prevent the underlying streams from being closed twice + // just in case any stream's close() implementation is not idempotent. + wrappedStream = null + partitionStream = null + } { + // Normally closing objOut would close the inner streams as well, but just in case there + // was an error in initialization etc. we make sure we clean the other streams up too. 
+ Utils.tryWithSafeFinally { + wrappedStream = closeIfNonNull(wrappedStream) + // Same as above - if wrappedStream closes then assume it closes underlying + // partitionStream and don't close again in the finally + partitionStream = null + } { + partitionStream = closeIfNonNull(partitionStream) + } + } + updateBytesWritten() + } { + isClosed = true + } + } + } + + private def closeIfNonNull[T <: Closeable](closeable: T): T = { + if (closeable != null) { + closeable.close() + } + null.asInstanceOf[T] + } + + /** + * Notify the writer that a record worth of bytes has been written with OutputStream#write. + */ + private def recordWritten(): Unit = { + numRecordsWritten += 1 + writeMetrics.incRecordsWritten(1) + + if (numRecordsWritten % 16384 == 0) { + updateBytesWritten() + } + } + + private def updateBytesWritten(): Unit = { + val numBytesWritten = partitionWriter.getNumBytesWritten + val bytesWrittenDiff = numBytesWritten - curNumBytesWritten + writeMetrics.incBytesWritten(bytesWrittenDiff) + curNumBytesWritten = numBytesWritten + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 17719f516a0a1..2a99c93b32af4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager metrics, shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => - new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) + new SortShuffleWriter( + shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 16058de8bf3ff..a781b16252432 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -21,15 +21,15 @@ import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter} -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( shuffleBlockResolver: IndexShuffleBlockResolver, handle: BaseShuffleHandle[K, V, C], mapId: Int, - context: TaskContext) + context: TaskContext, + shuffleExecutorComponents: ShuffleExecutorComponents) extends ShuffleWriter[K, V] with Logging { private val dep = handle.dependency @@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) - val tmp = Utils.tempFileWith(output) - try { - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) - } finally { - if (tmp.exists() && !tmp.delete()) { - logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") - } - } + val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( + dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) + sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) + val partitionLengths = mapOutputWriter.commitAllPartitions() + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 17390f9c60e79..758621c52495b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.util.Utils +import org.apache.spark.util.collection.PairsWriter /** * A class for writing JVM objects directly to a file on disk. This class allows data to be appended @@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter( writeMetrics: ShuffleWriteMetricsReporter, val blockId: BlockId = null) extends OutputStream - with Logging { + with Logging + with PairsWriter { /** * Guards against close calls, e.g. from a wrapping stream. @@ -232,7 +234,7 @@ private[spark] class DiskBlockObjectWriter( /** * Writes a key-value pair. 
*/ - def write(key: Any, value: Any) { + override def write(key: Any, value: Any) { if (!streamOpen) { open() } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 3f3b7d20eb169..7a822e137e556 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -23,13 +23,16 @@ import java.util.Comparator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.ByteStreams +import com.google.common.io.{ByteStreams, Closeables} import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer._ -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} +import org.apache.spark.shuffle.ShufflePartitionPairsWriter +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId} +import org.apache.spark.util.{Utils => TryUtils} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -670,11 +673,9 @@ private[spark] class ExternalSorter[K, V, C]( } /** - * Write all the data added into this ExternalSorter into a file in the disk store. This is - * called by the SortShuffleWriter. - * - * @param blockId block ID to write to. The index file will be blockId.name + ".index". - * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + * TODO(SPARK-28764): remove this, as this is only used by UnsafeRowSerializerSuite in the SQL + * project. We should figure out an alternative way to test that so that we can remove this + * otherwise unused code path. */ def writePartitionedFile( blockId: BlockId, @@ -718,6 +719,77 @@ private[spark] class ExternalSorter[K, V, C]( lengths } + /** + * Write all the data added into this ExternalSorter into a map output writer that pushes bytes + * to some arbitrary backing store. This is called by the SortShuffleWriter. + * + * @return array of lengths, in bytes, of each partition of the file (used by map output tracker) + */ + def writePartitionedMapOutput( + shuffleId: Int, + mapId: Int, + mapOutputWriter: ShuffleMapOutputWriter): Unit = { + var nextPartitionId = 0 + if (spills.isEmpty) { + // Case where we only have in-memory data + val collection = if (aggregator.isDefined) map else buffer + val it = collection.destructiveSortedWritablePartitionedIterator(comparator) + while (it.hasNext()) { + val partitionId = it.nextPartition() + var partitionWriter: ShufflePartitionWriter = null + var partitionPairsWriter: ShufflePartitionPairsWriter = null + TryUtils.tryWithSafeFinally { + partitionWriter = mapOutputWriter.getPartitionWriter(partitionId) + val blockId = ShuffleBlockId(shuffleId, mapId, partitionId) + partitionPairsWriter = new ShufflePartitionPairsWriter( + partitionWriter, + serializerManager, + serInstance, + blockId, + context.taskMetrics().shuffleWriteMetrics) + while (it.hasNext && it.nextPartition() == partitionId) { + it.writeNext(partitionPairsWriter) + } + } { + if (partitionPairsWriter != null) { + partitionPairsWriter.close() + } + } + nextPartitionId = partitionId + 1 + } + } else { + // We must perform merge-sort; get an iterator by partition and write everything directly. 
+ for ((id, elements) <- this.partitionedIterator) { + val blockId = ShuffleBlockId(shuffleId, mapId, id) + var partitionWriter: ShufflePartitionWriter = null + var partitionPairsWriter: ShufflePartitionPairsWriter = null + TryUtils.tryWithSafeFinally { + partitionWriter = mapOutputWriter.getPartitionWriter(id) + partitionPairsWriter = new ShufflePartitionPairsWriter( + partitionWriter, + serializerManager, + serInstance, + blockId, + context.taskMetrics().shuffleWriteMetrics) + if (elements.hasNext) { + for (elem <- elements) { + partitionPairsWriter.write(elem._1, elem._2) + } + } + } { + if (partitionPairsWriter != null) { + partitionPairsWriter.close() + } + } + nextPartitionId = id + 1 + } + } + + context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) + } + def stop(): Unit = { spills.foreach(s => s.file.delete()) spills.clear() @@ -781,7 +853,7 @@ private[spark] class ExternalSorter[K, V, C]( val inMemoryIterator = new WritablePartitionedIterator { private[this] var cur = if (upstream.hasNext) upstream.next() else null - def writeNext(writer: DiskBlockObjectWriter): Unit = { + def writeNext(writer: PairsWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (upstream.hasNext) upstream.next() else null } diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala new file mode 100644 index 0000000000000..05ed72c3e3778 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/PairsWriter.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +/** + * An abstraction of a consumer of key-value pairs, primarily used when + * persisting partitioned data, either through the shuffle writer plugins + * or via DiskBlockObjectWriter. 
+ */ +private[spark] trait PairsWriter { + + def write(key: Any, value: Any): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala index dd7f68fd038d2..da8d58d05b6b9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala @@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] { new WritablePartitionedIterator { private[this] var cur = if (it.hasNext) it.next() else null - def writeNext(writer: DiskBlockObjectWriter): Unit = { + def writeNext(writer: PairsWriter): Unit = { writer.write(cur._1._2, cur._2) cur = if (it.hasNext) it.next() else null } @@ -89,7 +89,7 @@ private[spark] object WritablePartitionedPairCollection { * has an associated partition. */ private[spark] trait WritablePartitionedIterator { - def writeNext(writer: DiskBlockObjectWriter): Unit + def writeNext(writer: PairsWriter): Unit def hasNext(): Boolean diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 690bcd9905257..0dd6040808f9e 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -17,24 +17,32 @@ package org.apache.spark.shuffle.sort +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Mockito._ -import org.mockito.MockitoAnnotations import org.scalatest.Matchers import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite} import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver} +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents +import org.apache.spark.storage.BlockManager import org.apache.spark.util.Utils class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers { + @Mock(answer = RETURNS_SMART_NULLS) + private var blockManager: BlockManager = _ + private val shuffleId = 0 private val numMaps = 5 private var shuffleHandle: BaseShuffleHandle[Int, Int, Int] = _ private val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) private val serializer = new JavaSerializer(conf) + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ override def beforeEach(): Unit = { super.beforeEach() @@ -51,6 +59,8 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with when(dependency.keyOrdering).thenReturn(None) new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) } + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, blockManager, shuffleBlockResolver) } override def afterAll(): Unit = { @@ -67,7 +77,8 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with shuffleBlockResolver, shuffleHandle, mapId = 1, - context) + context, + shuffleExecutorComponents) writer.write(Iterator.empty) writer.stop(success = true) val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1) @@ -84,7 +95,8 @@ class SortShuffleWriterSuite 
extends SparkFunSuite with SharedSparkContext with shuffleBlockResolver, shuffleHandle, mapId = 2, - context) + context, + shuffleExecutorComponents) writer.write(records.toIterator) writer.stop(success = true) val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2)
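
The core abstraction introduced by this patch is PairsWriter: ExternalSorter's WritablePartitionedIterator now programs against write(key, value) rather than against a concrete DiskBlockObjectWriter, so sorted records can be pushed either to local disk or to a pluggable ShufflePartitionWriter. The following is a minimal sketch, not part of the patch, assuming a hypothetical CountingPairsWriter class (placed in org.apache.spark.util.collection only because the trait is private[spark]):

    package org.apache.spark.util.collection

    // Hypothetical sketch: a PairsWriter that only counts the records pushed into it.
    // Anything that satisfies write(key, value) can now sit behind
    // WritablePartitionedIterator.writeNext, which is what allows
    // ShufflePartitionPairsWriter to slot in alongside DiskBlockObjectWriter.
    private[spark] class CountingPairsWriter extends PairsWriter {
      var pairsWritten: Long = 0L

      override def write(key: Any, value: Any): Unit = {
        pairsWritten += 1
      }
    }

ShufflePartitionPairsWriter reports write metrics lazily: recordWritten() bumps a record counter, and every 16384 records (plus once more on close) updateBytesWritten() polls the partition writer's running byte total and reports only the delta since the last poll. A standalone sketch of that accounting pattern, with illustrative names and no dependency on the patch:

    // Hypothetical sketch of the periodic byte accounting used by
    // ShufflePartitionPairsWriter: poll a running byte total every N records
    // and report only the difference since the last poll.
    class ByteDeltaTracker(currentTotal: () => Long, report: Long => Unit) {
      private var records = 0L
      private var lastReportedTotal = 0L

      def onRecordWritten(): Unit = {
        records += 1
        if (records % 16384 == 0) {
          flush()
        }
      }

      // Called periodically, and once more when the writer is closed, so that
      // any bytes written since the last poll are still accounted for.
      def flush(): Unit = {
        val total = currentTotal()
        report(total - lastReportedTotal)
        lastReportedTotal = total
      }
    }

Polling a running total and reporting deltas keeps per-record overhead low while still attributing all bytes to the task's shuffle write metrics, which mirrors how DiskBlockObjectWriter has historically reported its own throughput.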