diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 5acc66e12063..4b4f0a8760bf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -98,6 +98,7 @@ final class BypassMergeSortShuffleWriter private final long mapId; private final Serializer serializer; private final ShuffleExecutorComponents shuffleExecutorComponents; + private final boolean remoteWrites; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -136,6 +137,7 @@ final class BypassMergeSortShuffleWriter this.mapId = mapId; this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); + this.remoteWrites = dep.shuffleWriterProcessor() instanceof org.apache.spark.sql.execution.exchange.ConsolidationShuffleMarker; this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); @@ -149,12 +151,14 @@ public void write(Iterator> records) throws IOException { assert (partitionWriters == null); ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents .createMapOutputWriter(shuffleId, mapId, numPartitions); + BlockManagerId blockManagerId = remoteWrites ? + RemoteShuffleStorage.BLOCK_MANAGER_ID() : blockManager.shuffleServerId(); try { if (!records.hasNext()) { partitionLengths = mapOutputWriter.commitAllPartitions( ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths(); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); + blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue()); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -196,7 +200,7 @@ public void write(Iterator> records) throws IOException { partitionLengths = writePartitionedData(mapOutputWriter); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); + blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue()); } catch (Exception e) { try { mapOutputWriter.abort(e); @@ -236,8 +240,10 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); - ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); if (file.exists()) { + // TODO: Remove this comment: the line below was moved inside the check so that assertions + // can be added and the code is safer in general + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); if (transferToEnabled) { // Using WritableByteChannelWrapper to make resource closing consistent between // this implementation and UnsafeShuffleWriter.
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index e3ecfed32348..805bebce1e66 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -26,6 +26,8 @@ import java.nio.channels.WritableByteChannel; import java.util.Iterator; +import org.apache.spark.storage.BlockManagerId; +import org.apache.spark.storage.RemoteShuffleStorage; import scala.Option; import scala.Product2; import scala.jdk.javaapi.CollectionConverters; @@ -89,6 +91,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final boolean transferToEnabled; private final int initialSortBufferSize; private final int mergeBufferSizeInBytes; + private final boolean remoteWrites; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; @@ -135,6 +138,7 @@ public UnsafeShuffleWriter( this.shuffleId = dep.shuffleId(); this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); + this.remoteWrites = dep.shuffleWriterProcessor() instanceof org.apache.spark.sql.execution.exchange.ConsolidationShuffleMarker; this.writeMetrics = writeMetrics; this.shuffleExecutorComponents = shuffleExecutorComponents; this.taskContext = taskContext; @@ -247,8 +251,10 @@ void closeAndWriteOutput() throws IOException { } } } + BlockManagerId blockManagerId = remoteWrites ? + RemoteShuffleStorage.BLOCK_MANAGER_ID() : blockManager.shuffleServerId(); mapStatus = MapStatus$.MODULE$.apply( - blockManager.shuffleServerId(), partitionLengths, mapId, getAggregatedChecksumValue()); + blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue()); } @VisibleForTesting diff --git a/core/src/main/java/org/apache/spark/storage/FileSystemManagedBuffer.java b/core/src/main/java/org/apache/spark/storage/FileSystemManagedBuffer.java new file mode 100644 index 000000000000..03ff14137f47 --- /dev/null +++ b/core/src/main/java/org/apache/spark/storage/FileSystemManagedBuffer.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * A {@link ManagedBuffer} backed by a file using Hadoop FileSystem. + * This buffer creates an input stream with a 64MB buffer size for efficient reading. 
+ * Note: This implementation throws UnsupportedOperationException for methods that + * require loading the entire file into memory (nioByteBuffer, convertToNetty, convertToNettyForSsl) + * as files can be very large and loading them entirely into memory is not practical. + */ +public class FileSystemManagedBuffer extends ManagedBuffer { + private int bufferSize; // 64MB buffer size + private final Path filePath; + private final long fileSize; + private final Configuration hadoopConf; + + public FileSystemManagedBuffer(Path filePath, Configuration hadoopConf) throws IOException { + this.filePath = filePath; + this.hadoopConf = hadoopConf; + // Get file size using FileSystem.newInstance to avoid cached dependencies + FileSystem fileSystem = FileSystem.newInstance(filePath.toUri(), hadoopConf); + try { + this.fileSize = fileSystem.getFileStatus(filePath).getLen(); + } finally { + fileSystem.close(); + } + bufferSize = 64; + } + + public FileSystemManagedBuffer(Path filePath, Configuration hadoopConf, int bufferSize) + throws IOException { + this(filePath, hadoopConf); + this.bufferSize = bufferSize; + } + + @Override + public long size() { + return fileSize; + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + throw new UnsupportedOperationException( + "FileSystemManagedBuffer does not support nioByteBuffer() as it would require loading " + + "the entire file into memory, which is not practical for large files. " + + "Use createInputStream() instead."); + } + + @Override + public InputStream createInputStream() throws IOException { + // Create a new FileSystem instance to avoid cached dependencies + // and create a buffered input stream with 64MB buffer size for efficient reading + FileSystem fileSystem = FileSystem.newInstance(filePath.toUri(), hadoopConf); + return fileSystem.open(filePath, bufferSize * 1024 * 1024); + } + + @Override + public ManagedBuffer retain() { + // FileSystemManagedBuffer doesn't use reference counting, so just return this + return this; + } + + @Override + public ManagedBuffer release() { + // FileSystemManagedBuffer doesn't use reference counting, so just return this + return this; + } + + @Override + public Object convertToNetty() { + throw new UnsupportedOperationException( + "FileSystemManagedBuffer does not support convertToNetty() as it would require loading " + + "the entire file into memory, which is not practical for large files. " + + "Use createInputStream() instead."); + } + + @Override + public Object convertToNettyForSsl() { + throw new UnsupportedOperationException( + "FileSystemManagedBuffer does not support convertToNettyForSsl()" + + " as it would require loading " + + "the entire file into memory, which is not practical for large files. 
" + + "Use createInputStream() instead."); + } + + @Override + public String toString() { + return "FileSegmentManagedBuffer[file=" + filePath + ",length=" + fileSize + "]"; + } +} diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index c436025e06bb..9b52c9ae8b8d 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -90,7 +90,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val mapSideCombine: Boolean = false, val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS, - val checksumMismatchFullRetryEnabled: Boolean = false) + val checksumMismatchFullRetryEnabled: Boolean = false + ) extends Dependency[Product2[K, V]] with Logging { def this( @@ -249,7 +250,9 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( ) } - _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + if (!shuffleWriterProcessor.isInstanceOf[org.apache.spark.sql.execution.exchange.ConsolidationShuffleMarker]) { + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + } _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6f8be49e3959..e983ab4dee80 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -650,6 +650,8 @@ class SparkContext(config: SparkConf) extends Logging { _env.blockManager.initialize(_applicationId) FallbackStorage.registerBlockManagerIfNeeded( _env.blockManager.master, _conf, _hadoopConfiguration) + RemoteShuffleStorage.registerBlockManagerifNeeded(_env.blockManager.master, _conf, + _hadoopConfiguration) // The metrics system for Driver need to be set spark.app.id to app ID. // So it should start after we get app ID from the task scheduler and set spark.app.id. 
@@ -2377,6 +2379,11 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { FallbackStorage.cleanUp(_conf, _hadoopConfiguration) } + + Utils.tryLogNonFatalError { + RemoteShuffleStorage.cleanUp(_conf, _hadoopConfiguration) + } + Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e9a3780f0aaa..b747422275c8 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -626,6 +626,13 @@ package object config { .checkValue(_.endsWith(java.io.File.separator), "Path should end with separator.") .createOptional + private[spark] val SHUFFLE_REMOTE_STORAGE_CLEANUP = + ConfigBuilder("spark.shuffle.remote.storage.cleanUp") + .doc("If true, Spark cleans up its remote shuffle storage data during shutdown.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + private[spark] val STORAGE_DECOMMISSION_SHUFFLE_MAX_DISK_SIZE = ConfigBuilder("spark.storage.decommission.shuffleBlocks.maxDiskSize") .doc("Maximum disk space to use to store shuffle blocks before rejecting remote " + @@ -2905,4 +2912,32 @@ package object config { .checkValue(v => v.forall(Set("stdout", "stderr").contains), "The value only can be one or more of 'stdout, stderr'.") .createWithDefault(Seq("stdout", "stderr")) + + private[spark] val SHUFFLE_REMOTE_STORAGE_PATH = + ConfigBuilder("spark.shuffle.remote.storage.path") + .doc("The location for storing shuffle blocks on remote storage.") + .version("4.1.0") + .stringConf + .checkValue(_.endsWith(java.io.File.separator), "Path should end with separator.") + .createOptional + + private[spark] val REMOTE_SHUFFLE_BUFFER_SIZE = + ConfigBuilder("spark.shuffle.remote.buffer.size") + .version("4.1.0") + .stringConf + .createWithDefault("64M") + + private[spark] val START_REDUCERS_IN_PARALLEL_TO_MAPPER = + ConfigBuilder("spark.shuffle.consolidation.enabled") + .doc("If true, starts reducer stages in parallel with their mapper stages.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + + private[spark] val EAGERNESS_THRESHOLD_PERCENTAGE = + ConfigBuilder("spark.shuffle.remote.eagerness.percentage") + .doc("Percentage of mapper tasks that must complete before reducers are started.") + .version("4.1.0") + .intConf + .createWithDefault(20) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7c8bea31334b..ce7e5af82b55 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1755,7 +1755,13 @@ private[spark] class DAGScheduler( log"${MDC(STAGE, stage)} (${MDC(RDD_ID, stage.rdd)}) (first 15 tasks are " + log"for partitions ${MDC(PARTITION_IDS, tasks.take(15).map(_.partitionId))})") val shuffleId = stage match { - case s: ShuffleMapStage => Some(s.shuffleDep.shuffleId) + case s: ShuffleMapStage => + // Hack to let TaskSetManager prioritize remote (consolidation) shuffle writes + if (properties != null) { + properties.setProperty("remote", + s.shuffleDep.shuffleWriterProcessor.isInstanceOf[org.apache.spark.sql.execution.exchange.ConsolidationShuffleMarker].toString) + } + Some(s.shuffleDep.shuffleId) case _: ResultStage => None } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index
97edbb08c7c0..ced5eb3253d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -53,8 +53,11 @@ private[spark] class Pool( new FairSchedulingAlgorithm() case SchedulingMode.FIFO => new FIFOSchedulingAlgorithm() + case SchedulingMode.WEIGHTED_FIFO => + new WeightedFIFOSchedulingAlgorithm() case _ => - val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR or FIFO instead." + val msg = s"Unsupported scheduling mode: $schedulingMode. Use FAIR, FIFO," + + s" or WEIGHTED_FIFO instead." throw new IllegalArgumentException(msg) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala index 18ebbbe78a5b..48261f8ff014 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingAlgorithm.scala @@ -40,6 +40,25 @@ private[spark] class FIFOSchedulingAlgorithm extends SchedulingAlgorithm { } } +private[spark] class WeightedFIFOSchedulingAlgorithm extends SchedulingAlgorithm { + override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { + val priority1 = s1.priority + val priority2 = s2.priority + var res = math.signum(priority1 - priority2) + if (res == 0) { + if (s1.weight == s2.weight) { + val stageId1 = s1.stageId + val stageId2 = s2.stageId + res = math.signum(stageId1 - stageId2) + } else { + // The higher the weight, the earlier it should run (unlike priority) + res = math.signum(s2.weight - s1.weight) + } + } + res < 0 + } +} + private[spark] class FairSchedulingAlgorithm extends SchedulingAlgorithm { override def comparator(s1: Schedulable, s2: Schedulable): Boolean = { val minShare1 = s1.minShare diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala index 75186b6ba4a4..c08804db1e84 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala @@ -20,10 +20,11 @@ package org.apache.spark.scheduler /** * "FAIR" and "FIFO" determines which policy is used * to order tasks amongst a Schedulable's sub-queues + * "WEIGHTED_FIFO" is similar to FIFO but uses weight-based comparison in addition. * "NONE" is used when the a Schedulable has no sub-queues. */ object SchedulingMode extends Enumeration { type SchedulingMode = Value - val FAIR, FIFO, NONE = Value + val FAIR, FIFO, WEIGHTED_FIFO, NONE = Value } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 05bcafdb14d1..11d9aead3a7e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -96,6 +96,7 @@ private[spark] class ShuffleMapTask( val rdd = rddAndDep._1 val dep = rddAndDep._2 + // While we use the old shuffle fetch protocol, we use partitionId as mapId in the // ShuffleBlockId construction.
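To make the new WEIGHTED_FIFO ordering concrete, here is a small self-contained sketch (a plain case class rather than the Schedulable trait, for illustration only): at equal priority the higher-weight entry runs first, and equal weights fall back to stage id.

    // Standalone mirror of WeightedFIFOSchedulingAlgorithm.comparator, illustration only.
    case class Entry(priority: Int, weight: Int, stageId: Int)

    def runsBefore(s1: Entry, s2: Entry): Boolean = {
      var res = math.signum(s1.priority - s2.priority) // lower priority value wins, as in FIFO
      if (res == 0) {
        res = if (s1.weight == s2.weight) {
          math.signum(s1.stageId - s2.stageId)          // tie-break on stage id
        } else {
          math.signum(s2.weight - s1.weight)            // higher weight runs earlier (unlike priority)
        }
      }
      res < 0
    }

    // A consolidation task set (weight 1000, see the TaskSetManager change below) beats a regular
    // task set of the same priority even though it has a larger stage id.
    assert(runsBefore(Entry(priority = 5, weight = 1000, stageId = 9),
      Entry(priority = 5, weight = 1, stageId = 2)))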
val mapId = if (SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { @@ -115,3 +116,4 @@ private[spark] class ShuffleMapTask( override def toString: String = "ShuffleMapTask(%d, %d)".format(stageId, partitionId) } + diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1351d8c778b5..ffa537374544 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -213,6 +213,8 @@ private[spark] class TaskSchedulerImpl( schedulingMode match { case SchedulingMode.FIFO => new FIFOSchedulableBuilder(rootPool) + case SchedulingMode.WEIGHTED_FIFO => + new FIFOSchedulableBuilder(rootPool) case SchedulingMode.FAIR => new FairSchedulableBuilder(rootPool, sc) case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 69e0a10a34b2..ef19c44f71c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer +import java.util.Locale import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -133,7 +134,10 @@ private[spark] class TaskSetManager( val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 - val weight = 1 + val weight = { + val remote: String = taskSet.properties.getOrDefault("remote", "false").toString + if (remote.toLowerCase(Locale.ROOT).equals("true")) 1000 else 1 + } val minShare = 0 var priority = taskSet.priority val stageId = taskSet.stageId diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala index 47d54ae4f10b..d7fa8a762bf7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -48,6 +48,8 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { context: TaskContext): MapStatus = { var writer: ShuffleWriter[Any, Any] = null try { + context.getLocalProperties.setProperty("consolidation.write", + dep.shuffleWriterProcessor.isInstanceOf[org.apache.spark.sql.execution.exchange.ConsolidationShuffleMarker].toString) val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any]( dep.shuffleHandle, @@ -85,6 +87,7 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { } } } + mapStatus.get } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index a7ac20016a0e..12f52711b9de 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -24,6 +24,7 @@ import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter} import org.apache.spark.shuffle.ShuffleWriteMetricsReporter import org.apache.spark.shuffle.api.ShuffleExecutorComponents import org.apache.spark.shuffle.checksum.RowBasedChecksum +import 
org.apache.spark.storage.RemoteShuffleStorage import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( @@ -84,8 +85,13 @@ private[spark] class SortShuffleWriter[K, V, C]( dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter, writeMetrics) partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths + val blockManagerId = if (dep.shuffleWriterProcessor.isInstanceOf[org.apache.spark.sql.execution.exchange.ConsolidationShuffleMarker]) { + RemoteShuffleStorage.BLOCK_MANAGER_ID + } else { + blockManager.shuffleServerId + } mapStatus = - MapStatus(blockManager.shuffleServerId, partitionLengths, mapId, getAggregatedChecksumValue) + MapStatus(blockManagerId, partitionLengths, mapId, getAggregatedChecksumValue) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleDataIO.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleDataIO.scala new file mode 100644 index 000000000000..ca72cbe80bf8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleDataIO.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.shuffle.sort.remote + +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.api.ShuffleDataIO +import org.apache.spark.shuffle.api.ShuffleDriverComponents +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDriverComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents + +/** + * Implementation of the [[ShuffleDataIO]] plugin system that writes to the local and + * remote storage + */ +class HybridShuffleDataIO(sparkConf: SparkConf) extends ShuffleDataIO { + + override def executor(): ShuffleExecutorComponents = { + new HybridShuffleExecutorComponents(sparkConf, + new LocalDiskShuffleExecutorComponents(sparkConf)) + } + + override def driver(): ShuffleDriverComponents = { + new HybridShuffleDriverComponents(new LocalDiskShuffleDriverComponents()) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleDriverComponents.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleDriverComponents.scala new file mode 100644 index 000000000000..89bd9e0586c1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleDriverComponents.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.remote + +import java.util.Collections + +import org.apache.spark.shuffle.api.ShuffleDriverComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDriverComponents + +class HybridShuffleDriverComponents( + localDiskShuffleDriverComponents: LocalDiskShuffleDriverComponents) + extends ShuffleDriverComponents { + + override def initializeApplication(): java.util.Map[String, String] = { + localDiskShuffleDriverComponents.initializeApplication() + Collections.emptyMap() + } + + override def cleanupApplication(): Unit = { + localDiskShuffleDriverComponents.cleanupApplication() + } + + override def removeShuffle(shuffleId: Int, blocking: Boolean): Unit = { + localDiskShuffleDriverComponents.removeShuffle(shuffleId, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleExecutorComponents.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleExecutorComponents.scala new file mode 100644 index 000000000000..d889bab226f2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/HybridShuffleExecutorComponents.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.remote + +import java.util.Optional + +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.shuffle.api.{ShuffleExecutorComponents, ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents + +class HybridShuffleExecutorComponents(sparkConf: SparkConf, + localDiskShuffleExecutorComponents: LocalDiskShuffleExecutorComponents) + extends ShuffleExecutorComponents { + + override def initializeExecutor(appId: String, execId: String, + extraConfigs: java.util.Map[String, String]): Unit = { + localDiskShuffleExecutorComponents.initializeExecutor(appId, execId, extraConfigs) + } + + override def createMapOutputWriter( + shuffleId: Int, + mapTaskId: Long, + numPartitions: Int): ShuffleMapOutputWriter = { + val isRemote = TaskContext.get().getLocalProperties + .getOrDefault("consolidation.write", "false").toString.toBoolean + if (isRemote) { + new RemoteShuffleMapOutputWriter(sparkConf, shuffleId, mapTaskId, numPartitions) + } else { + localDiskShuffleExecutorComponents.createMapOutputWriter(shuffleId, mapTaskId, numPartitions) + } + } + + override def createSingleFileMapOutputWriter( + shuffleId: Int, + mapId: Long): Optional[SingleSpillShuffleMapOutputWriter] = { + Optional.empty() + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/remote/RemoteShuffleMapOutputWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/RemoteShuffleMapOutputWriter.scala new file mode 100644 index 000000000000..8a2165d121ae --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/remote/RemoteShuffleMapOutputWriter.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.remote + +import java.io.{BufferedOutputStream, IOException, OutputStream} +import java.nio.ByteBuffer +import java.nio.channels.{Channels, WritableByteChannel} +import java.util.Optional + +import org.apache.hadoop.fs.FSDataOutputStream + +import org.apache.spark.{SparkConf, SparkEnv, TaskContext} +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.config.REMOTE_SHUFFLE_BUFFER_SIZE +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter, WritableByteChannelWrapper} +import org.apache.spark.shuffle.api.metadata.MapOutputCommitMessage +import org.apache.spark.storage.{RemoteShuffleStorage, ShuffleChecksumBlockId, ShuffleDataBlockId} + + +/** Implements the ShuffleMapOutputWriter interface. + * It stores the shuffle output in one shuffle block. + * + * This file is based on Spark's LocalDiskShuffleMapOutputWriter.java. + */ + +class RemoteShuffleMapOutputWriter( + conf: SparkConf, + shuffleId: Int, + mapId: Long, + numPartitions: Int +) extends ShuffleMapOutputWriter + with Logging { + + /* Target block for writing */ + + private var stream: FSDataOutputStream = _ + private var bufferedStream: OutputStream = _ + private var bufferedStreamAsChannel: WritableByteChannel = _ + private var reduceIdStreamPosition: Long = 0 + + def initStream(): Unit = { + if (stream == null) { + val shuffleBlock = ShuffleDataBlockId(shuffleId, mapId, TaskContext.getPartitionId()) + stream = RemoteShuffleStorage.getStream(shuffleBlock) + val bufferSize: Long = conf.getSizeAsBytes(REMOTE_SHUFFLE_BUFFER_SIZE.key, "64M") + bufferedStream = new BufferedOutputStream(stream, bufferSize.toInt) + } + } + + def initChannel(): Unit = { + if (bufferedStreamAsChannel == null) { + initStream() + bufferedStreamAsChannel = Channels.newChannel(bufferedStream) + } + } + + private val partitionLengths = Array.fill[Long](numPartitions)(0) + private var totalBytesWritten: Long = 0 + + override def getPartitionWriter(reducePartitionId: Int): ShufflePartitionWriter = { + if (bufferedStream != null) { + bufferedStream.flush() + } + if (stream != null) { + stream.flush() + reduceIdStreamPosition = stream.getPos + } + new RemoteShufflePartitionWriter(reducePartitionId) + } + + /** + * Close all writers and the shuffle block. + */ + override def commitAllPartitions(checksums: Array[Long]): MapOutputCommitMessage = { + if (bufferedStream != null) { + bufferedStream.flush() + } + if (stream != null) { + if (stream.getPos != totalBytesWritten) { + throw new RuntimeException( + f"RemoteShuffleMapOutputWriter: Unexpected output length ${stream.getPos}," + + f" expected: $totalBytesWritten." + ) + } + } + if (bufferedStreamAsChannel != null) { + bufferedStreamAsChannel.close() + } + if (bufferedStream != null) { + // Closes the underlying stream as well! + bufferedStream.close() + } + + // Write checksum.
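+ // The checksum values (one per reduce partition) are stored as their own remote block, keyed by + // ShuffleChecksumBlockId with reduceId 0, alongside the data stream written above.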
+ if (SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED)) { + RemoteShuffleStorage.writeCheckSum(ShuffleChecksumBlockId(shuffleId = shuffleId, + mapId = mapId, reduceId = 0), checksums) + } + MapOutputCommitMessage.of(partitionLengths) + } + + override def abort(error: Throwable): Unit = { + cleanUp() + } + + private def cleanUp(): Unit = { + if (bufferedStreamAsChannel != null) { + bufferedStreamAsChannel.close() + } + if (bufferedStream != null) { + bufferedStream.close() + } + if (stream != null) { + stream.close() + } + } + + private class RemoteShufflePartitionWriter(reduceId: Int) extends ShufflePartitionWriter + with Logging { + private var partitionStream: RemoteShuffleOutputStream = _ + private var partitionChannel: RemoteShufflePartitionWriterChannel = _ + + override def openStream(): OutputStream = { + initStream() + if (partitionStream == null) { + partitionStream = new RemoteShuffleOutputStream(reduceId) + } + partitionStream + } + + override def openChannelWrapper(): Optional[WritableByteChannelWrapper] = { + if (partitionChannel == null) { + initChannel() + partitionChannel = new RemoteShufflePartitionWriterChannel(reduceId) + } + Optional.of(partitionChannel) + } + + override def getNumBytesWritten: Long = { + if (partitionChannel != null) { + return partitionChannel.numBytesWritten + } + if (partitionStream != null) { + return partitionStream.numBytesWritten + } + // The partition is empty. + 0 + } + } + + private class RemoteShuffleOutputStream(reduceId: Int) extends OutputStream { + private var byteCount: Long = 0 + private var isClosed = false + + def numBytesWritten: Long = byteCount + + override def write(b: Int): Unit = { + if (isClosed) { + throw new IOException("RemoteShuffleOutputStream is already closed.") + } + bufferedStream.write(b) + byteCount += 1 + } + + override def write(b: Array[Byte], off: Int, len: Int): Unit = { + if (isClosed) { + throw new IOException("RemoteShuffleOutputStream is already closed.") + } + bufferedStream.write(b, off, len) + byteCount += len + } + + override def flush(): Unit = { + if (isClosed) { + throw new IOException("RemoteShuffleOutputStream is already closed.") + } + bufferedStream.flush() + } + + override def close(): Unit = { + partitionLengths(reduceId) = byteCount + totalBytesWritten += byteCount + isClosed = true + } + } + + private class RemoteShufflePartitionWriterChannel(reduceId: Int) + extends WritableByteChannelWrapper { + private val partChannel = new RemotePartitionWritableByteChannel(bufferedStreamAsChannel) + + override def channel(): WritableByteChannel = { + partChannel + } + + def numBytesWritten: Long = { + partChannel.numBytesWritten() + } + + override def close(): Unit = { + partitionLengths(reduceId) = numBytesWritten + totalBytesWritten += numBytesWritten + } + } + + private class RemotePartitionWritableByteChannel(channel: WritableByteChannel) extends + WritableByteChannel { + + private var count: Long = 0 + + def numBytesWritten(): Long = { + count + } + + override def isOpen: Boolean = { + channel.isOpen + } + + override def close(): Unit = {} + + override def write(x: ByteBuffer): Int = { + var c = 0 + while (x.hasRemaining) { + c += channel.write(x) + } + count += c + c + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/RemoteShuffleStorage.scala b/core/src/main/scala/org/apache/spark/storage/RemoteShuffleStorage.scala new file mode 100644 index 000000000000..be9816e67688 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/RemoteShuffleStorage.scala @@ -0,0 +1,152 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.{BufferedOutputStream, DataOutputStream} + +import scala.concurrent.Future +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path} + +import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys._ +import org.apache.spark.internal.config.{REMOTE_SHUFFLE_BUFFER_SIZE, SHUFFLE_REMOTE_STORAGE_PATH} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout} +import org.apache.spark.storage.BlockManagerMessages.RemoveShuffle +import org.apache.spark.storage.RemoteShuffleStorage.{appId, remoteFileSystem, remoteStoragePath} + +private[storage] class RemoteStorageRpcEndpointRef(conf: SparkConf) extends RpcEndpointRef(conf) { + // scalastyle:off executioncontextglobal + import scala.concurrent.ExecutionContext.Implicits.global + // scalastyle:on executioncontextglobal + override def address: RpcAddress = null + override def name: String = "remoteStorageEndpoint" + override def send(message: Any): Unit = {} + override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { + message match { + case RemoveShuffle(shuffleId) => + val dataFile = new Path(remoteStoragePath, s"$appId/$shuffleId") + SparkEnv.get.mapOutputTracker.unregisterShuffle(shuffleId) + val shuffleManager = SparkEnv.get.shuffleManager + if (shuffleManager != null) { + shuffleManager.unregisterShuffle(shuffleId) + } else { + logDebug(log"Ignore remove shuffle ${MDC(SHUFFLE_ID, shuffleId)}") + } + Future { + remoteFileSystem.delete(dataFile, true).asInstanceOf[T] + } + case _ => + Future{true.asInstanceOf[T]} + } + } +} + + +private[spark] object RemoteShuffleStorage extends Logging { + + val blockManagerId = "remoteShuffleBlockStore" + lazy val hadoopConf: Configuration = SparkHadoopUtil.get.newConfiguration(SparkEnv.get.conf) + lazy val appId: String = SparkEnv.get.conf.getAppId + lazy val remoteStoragePath = new Path(SparkEnv.get.conf.get(SHUFFLE_REMOTE_STORAGE_PATH).get) + lazy val remoteFileSystem: FileSystem = FileSystem.get(remoteStoragePath.toUri, hadoopConf) + + /** We use one block manager id as a place holder. */ + val BLOCK_MANAGER_ID: BlockManagerId = BlockManagerId(blockManagerId, "remote", 7337) + + /** Register the remote shuffle block manager and its RPC endpoint. 
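+ * The placeholder BLOCK_MANAGER_ID is registered with the BlockManagerMaster (mirroring FallbackStorage) + * so the id carried in MapStatuses of consolidation shuffles is known to the driver, and RemoveShuffle + * messages sent to its endpoint delete the per-shuffle directory <spark.shuffle.remote.storage.path>/<appId>/<shuffleId>.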
*/ + def registerBlockManagerifNeeded(master: BlockManagerMaster, conf: SparkConf, + hadoopConf: Configuration): Unit = { + if (conf.get(SHUFFLE_REMOTE_STORAGE_PATH).isDefined) { + master.registerBlockManager( + BLOCK_MANAGER_ID, Array.empty[String], 0, 0, new RemoteStorageRpcEndpointRef(conf)) + } + } + + /** Clean up the generated remote shuffle location for this app. */ + def cleanUp(conf: SparkConf, hadoopConf: Configuration): Unit = { + if (conf.contains("spark.app.id") && conf.contains(SHUFFLE_REMOTE_STORAGE_PATH)) { + val shuffleRemotePath = + new Path(conf.get(SHUFFLE_REMOTE_STORAGE_PATH).get, conf.getAppId) + val remoteUri = shuffleRemotePath.toUri + val remoteFileSystem = FileSystem.get(remoteUri, hadoopConf) + if (remoteFileSystem.exists(shuffleRemotePath)) { + if (remoteFileSystem.delete(shuffleRemotePath, true)) { + logInfo(log"Succeed to clean up: ${MDC(URI, remoteUri)}") + } else { + // Clean-up can fail due to the permission issues. + logWarning(log"Failed to clean up: ${MDC(URI, remoteUri)}") + } + } + } + } + + /** + * Read a ManagedBuffer. + */ + def read(blockIds: Seq[BlockId], listener: BlockFetchingListener): Unit = { + blockIds.foreach { blockId => + // Use blockId.name to ensure the block name matches what's in remainingBlocks + // For batch blocks, this will be the batch name (e.g., "shuffle_0_0_5_8") + // For regular blocks, this will be the individual block name (e.g., "shuffle_0_0_5") + logInfo(log"Read ${MDC(BLOCK_ID, blockId)}") + listener.onBlockFetchSuccess(blockId.name, + new FileSystemManagedBuffer(getPath(blockId), hadoopConf, + SparkEnv.get.conf.getSizeAsMb(REMOTE_SHUFFLE_BUFFER_SIZE.key, "64M").toInt)) + } + } + + def getPath(blockId: BlockId): Path = { + val (shuffleId, name) = blockId match { + case ShuffleBlockId(shuffleId, mapId, reduceId) => + (shuffleId, ShuffleDataBlockId(shuffleId, mapId, reduceId).name) + case shuffleDataBlock@ ShuffleDataBlockId(shuffleId, _, _) => + (shuffleId, shuffleDataBlock.name) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, endReduceId) => + // For batches, we use the startReduceId to identify the block file. + // The batch range [startReduceId, endReduceId) will be read from the same file. + (shuffleId, ShuffleDataBlockId(shuffleId, mapId, startReduceId).name) + case shuffleCheckSumBlock@ ShuffleChecksumBlockId(shuffleId, _, _) => + (shuffleId, shuffleCheckSumBlock.name) + case _ => throw new SparkException(s"Unsupported block id type: ${blockId.name}") + } + val hash = JavaUtils.nonNegativeHash(name) + new Path(remoteStoragePath, s"$appId/$shuffleId/$hash/$name") + } + + def getStream(blockId: BlockId): FSDataOutputStream = { + val path = getPath(blockId) + remoteFileSystem.create(path) + } + + def writeCheckSum(blockId: BlockId, array: Array[Long]): Unit = { + if (array.nonEmpty) { + val out = new DataOutputStream(new BufferedOutputStream(getStream(blockId), + scala.math.min(8192, 8 * array.length))) + array.foreach(out.writeLong) + out.flush() + out.close() + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index b2f185bc590f..ef22a844d1e2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -364,15 +364,19 @@ final class ShuffleBlockFetcherIterator( } } - // Fetch remote shuffle blocks to disk when the request is too large. 
Since the shuffle data is - // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch - // the data and write it to file directly. - if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) + if (req.address == RemoteShuffleStorage.BLOCK_MANAGER_ID) { + RemoteShuffleStorage.read(req.blocks.map(_.blockId).toSeq, blockFetchingListener) } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) + // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data + // is already encrypted and compressed over the wire(w.r.t. the related configs), + // we can just fetch the data and write it to file directly. + if (req.size > maxReqSizeShuffleToMem) { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, this) + } else { + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, + blockFetchingListener, null) + } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index ba36f7549270..424498b406ab 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -81,7 +81,7 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null, rpId, None) + new TaskSet(tasks, stageId, stageAttemptId, priority = priority, new Properties(), rpId, None) } def createShuffleMapTaskSet( @@ -107,7 +107,7 @@ object FakeTask { }, 1, prefLocs(i), JobArtifactSet.defaultJobArtifactSet, new Properties, SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) } - new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null, + new TaskSet(tasks, stageId, stageAttemptId, priority = priority, new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, Some(0)) } @@ -137,6 +137,6 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true) } - new TaskSet(tasks, stageId, stageAttemptId, priority = priority, null, rpId, None) + new TaskSet(tasks, stageId, stageAttemptId, priority = priority, new Properties(), rpId, None) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala index 30ed80dbe848..e1f20adace3a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/PoolSuite.scala @@ -44,7 +44,7 @@ class PoolSuite extends SparkFunSuite with LocalSparkContext { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, Nil) } - new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, null, + new TaskSetManager(taskScheduler, new TaskSet(tasks, stageId, 0, 0, new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None), 0) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9007ae4e0990..34cad783916e 
100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -546,7 +546,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext val numFreeCores = 1 val taskSet = new TaskSet( Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), - 0, 0, 0, null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) + 0, 0, 0, new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -561,7 +561,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext taskScheduler.submitTasks(FakeTask.createTaskSet(1)) val taskSet2 = new TaskSet( Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), - 1, 0, 0, null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) + 1, 0, 0, new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) taskScheduler.submitTasks(taskSet2) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten assert(taskDescriptions.map(_.executorId) === Seq("executor0")) @@ -2111,7 +2111,7 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext }, 1, Seq(TaskLocation("host1", "executor1")), JobArtifactSet.getActiveOrDefault(sc), new Properties, null) - val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, null, 0, Some(0)) + val taskSet = new TaskSet(Array(task1, task2), 0, 0, 0, new Properties(), 0, Some(0)) taskScheduler.submitTasks(taskSet) val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index be1bc5fe3212..a282999dfa0d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -860,7 +860,7 @@ class TaskSetManagerSuite sched = new FakeTaskScheduler(sc, ("exec1", "host1")) val taskSet = new TaskSet(Array(new LargeTask(0)), 0, 0, 0, - null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) + new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) assert(!manager.emittedTaskSizeWarning) @@ -876,7 +876,7 @@ class TaskSetManagerSuite val taskSet = new TaskSet( Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), - 0, 0, 0, null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) + 0, 0, 0, new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) intercept[TaskNotSerializableException] { @@ -957,7 +957,7 @@ class TaskSetManagerSuite }, 1, Seq(TaskLocation("host1", "execA")), JobArtifactSet.getActiveOrDefault(sc), new Properties, null) val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, - null, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, Some(0)) + new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, Some(0)) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) // Offer host1, which should be accepted as a PROCESS_LOCAL location @@ -2287,8 +2287,8 @@ class TaskSetManagerSuite TestUtils.waitUntilExecutorsUp(sc, 2, 60000) val tasks = Array.tabulate[Task[_]](2)(partition 
=> new FakeLongTasks(stageId = 0, partition)) - val taskSet: TaskSet = new TaskSet(tasks, stageId = 0, stageAttemptId = 0, priority = 0, null, - ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) + val taskSet: TaskSet = new TaskSet(tasks, stageId = 0, stageAttemptId = 0, priority = 0, + new Properties(), ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID, None) val stageId = taskSet.stageId val stageAttemptId = taskSet.stageAttemptId sched.submitTasks(taskSet) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 5a8505dc6992..6bb47063b467 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -621,6 +622,54 @@ case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning { } } +/** + * A partitioning that wraps a child partitioning and passes its properties through unchanged. + * It lets a consolidation shuffle reuse the child's guarantees (satisfies checks, shuffle specs, + * number of partitions) without changing how the data is actually partitioned. + */ +case class PassThroughPartitioning(childPartitioning: Partitioning) + extends Expression with Partitioning with Unevaluable { + /** + * Delegates the satisfies check to the child partitioning.
+ */ + override def satisfies0(required: Distribution): Boolean = childPartitioning.satisfies(required) + + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { + childPartitioning.createShuffleSpec(distribution) + } + + /** Returns the number of partitions that the data is split across */ + override val numPartitions: Int = childPartitioning.numPartitions + + // Expression interface implementation + override def children: Seq[Expression] = childPartitioning match { + case e: Expression => e.children + case _ => Nil + } + + override def nullable: Boolean = false + override def dataType: DataType = IntegerType + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): PassThroughPartitioning = { + childPartitioning match { + case e: Expression => + copy(childPartitioning = e.withNewChildren(newChildren).asInstanceOf[Partitioning]) + case _ => + this + } + } + + override lazy val canonicalized: Expression = { + val canonicalizedChild = childPartitioning match { + case e: Expression => e.canonicalized.asInstanceOf[Partitioning] + case other => other + } + copy(childPartitioning = canonicalizedChild) + } +} + + /** * This is used in the scenario where an operator has multiple children (e.g., join) and one or more * of which have their own requirement regarding whether its data can be considered as diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4e9caa822997..eb116a9d2a2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3806,6 +3806,26 @@ object SQLConf { .version("4.1.0") .fallbackConf(SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED) + val ENABLE_SHUFFLE_CONSOLIDATION = + buildConf("spark.sql.shuffle.consolidation.enabled") + .doc("When enabled, creates a consolidation shuffle stage that consolidates shuffle data " + + "from earlier stages and uploads it to remote storage. This " + + "consolidation stage uses PassThroughPartitioning to satisfy distribution requirements " + + "without changing the actual data partitioning. The remote storage upload can improve " + + "shuffle performance and enable better resource utilization in distributed environments.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + + val SHUFFLE_CONSOLIDATION_SIZE_THRESHOLD = + buildConf("spark.sql.shuffle.consolidation.size.threshold") + .doc("Minimum shuffle size in bytes required to create a consolidation shuffle stage " + + "in adaptive execution. Only shuffles larger than this threshold will have a " + + "consolidation stage added. 
This helps avoid overhead for small shuffles.") + .version("4.1.0") + .bytesConf(ByteUnit.BYTE) + .createWithDefaultString("100MB") + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") .internal() @@ -7880,12 +7900,16 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def hadoopLineRecordReaderEnabled: Boolean = getConf(SQLConf.HADOOP_LINE_RECORD_READER_ENABLED) def legacyXMLParserEnabled: Boolean = getConf(SQLConf.LEGACY_XML_PARSER_ENABLED) + def shuffleConsolidationEnabled: Boolean = getConf(SQLConf.ENABLE_SHUFFLE_CONSOLIDATION) def coerceMergeNestedTypes: Boolean = getConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED) + def shuffleConsolidationSizeThreshold: Long = + getConf(SQLConf.SHUFFLE_CONSOLIDATION_SIZE_THRESHOLD) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index a5b5c399d4fc..4cb6c378fc93 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -344,7 +344,7 @@ org.scalatest scalatest-maven-plugin - -ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} ${extraJavaTestArgs} -Dio.netty.tryReflectionSetAccessible=true + -ea -Xmx16g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} ${extraJavaTestArgs} -Dio.netty.tryReflectionSetAccessible=true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/AddConsolidationShuffle.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/AddConsolidationShuffle.scala new file mode 100644 index 000000000000..25bdd7a0b56a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/AddConsolidationShuffle.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
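Pulling the new configuration surface together, enabling the consolidation path end to end would look roughly like the sketch below. The plugin binding through spark.shuffle.sort.io.plugin.class and the storage URI are assumptions based on the classes added in this patch, not settings spelled out in the diff itself.

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("consolidation-shuffle-sketch")
      // Assumed binding: route shuffle I/O through the hybrid local/remote plugin added above.
      .config("spark.shuffle.sort.io.plugin.class",
        "org.apache.spark.shuffle.sort.remote.HybridShuffleDataIO")
      .config("spark.shuffle.remote.storage.path", "s3a://my-bucket/spark-shuffle/") // must end with '/'
      .config("spark.shuffle.remote.storage.cleanUp", "true")
      .config("spark.sql.shuffle.consolidation.enabled", "true")
      .config("spark.sql.shuffle.consolidation.size.threshold", "256MB")
      .getOrCreate()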
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{DirectShufflePartitionID, SparkPartitionID} +import org.apache.spark.sql.catalyst.plans.physical.{PassThroughPartitioning, ShufflePartitionIdPassThrough} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.exchange.{SHUFFLE_CONSOLIDATION, ShuffleExchangeExec} +import org.apache.spark.sql.internal.SQLConf + +object AddConsolidationShuffle extends Rule[SparkPlan] { + + def apply(plan: SparkPlan): SparkPlan = { + if (!SQLConf.get.shuffleConsolidationEnabled) { + return plan + } + plan transformUp { + case plan@ShuffleExchangeExec(part, _, origin, _) => + val passThroughPartitioning = ShufflePartitionIdPassThrough( + DirectShufflePartitionID(SparkPartitionID()), + part.numPartitions + ) + // Non-adaptive: always add consolidation exchange + new ShuffleExchangeExec(PassThroughPartitioning(part), plan, SHUFFLE_CONSOLIDATION) { + override def doExecute(): RDD[InternalRow] = { + super.doExecute() + } + } + case p: ShuffleQueryStageExec + if p.shuffle.shuffleOrigin != SHUFFLE_CONSOLIDATION && p.isMaterialized => + // Add consolidation exchange only if: + // 1. Stage is materialized + // 2. Size exceeds consolidation threshold + val size = p.getRuntimeStatistics.sizeInBytes + val consolidationThreshold = SQLConf.get.shuffleConsolidationSizeThreshold + if (size > consolidationThreshold) { + val passThroughPartitioning = ShufflePartitionIdPassThrough( + DirectShufflePartitionID(SparkPartitionID()), + p.outputPartitioning.numPartitions + ) + ShuffleExchangeExec(passThroughPartitioning, p, + SHUFFLE_CONSOLIDATION) + } else { + p + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 12fce2f91dac..dbf046e5d482 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -588,6 +588,7 @@ object QueryExecution { PlanSubqueries(sparkSession), RemoveRedundantProjects, EnsureRequirements(), + AddConsolidationShuffle, // This rule must be run after `EnsureRequirements`. InsertSortForLimitAndOffset, // `ReplaceHashWithSortAgg` needs to be added after `EnsureRequirements` to guarantee the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 0e50c03b6cc9..6fbc47e539ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -120,6 +120,7 @@ case class AdaptiveSparkPlanExec( CoalesceBucketsInJoin, RemoveRedundantProjects, ensureRequirements, + AddConsolidationShuffle, // This rule must be run after `EnsureRequirements`. 
InsertSortForLimitAndOffset, AdjustShuffleExchangePosition, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala index 3fdcb17bdeae..291bf54c01d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.physical.SinglePartition import org.apache.spark.sql.execution.{ShufflePartitionSpec, SparkPlan, UnaryExecNode, UnionExec} -import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, ShuffleExchangeLike, ShuffleOrigin} +import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, SHUFFLE_CONSOLIDATION, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -37,14 +37,22 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe override val supportedShuffleOrigins: Seq[ShuffleOrigin] = Seq(ENSURE_REQUIREMENTS, REPARTITION_BY_COL, REBALANCE_PARTITIONS_BY_NONE, - REBALANCE_PARTITIONS_BY_COL) + REBALANCE_PARTITIONS_BY_COL, SHUFFLE_CONSOLIDATION) override def isSupported(shuffle: ShuffleExchangeLike): Boolean = { - shuffle.outputPartitioning != SinglePartition && super.isSupported(shuffle) + shuffle.outputPartitioning != SinglePartition && canApply(shuffle) && + super.isSupported(shuffle) + } + + def canApply(shuffle: ShuffleExchangeLike): Boolean = { + shuffle.shuffleOrigin match { + case ENSURE_REQUIREMENTS => !conf.shuffleConsolidationEnabled + case _ => true + } + } override def apply(plan: SparkPlan): SparkPlan = { - if (!conf.coalesceShufflePartitionsEnabled) { + if (!conf.coalesceShufflePartitionsEnabled || conf.shuffleConsolidationEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala index cf1c7ecedd5b..f45d938ceda0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeShuffleWithLocalRead.scala @@ -108,7 +108,8 @@ object OptimizeShuffleWithLocalRead extends AQEShuffleReadRule { } override def apply(plan: SparkPlan): SparkPlan = { - if (!conf.getConf(SQLConf.LOCAL_SHUFFLE_READER_ENABLED)) { + if (!conf.getConf(SQLConf.LOCAL_SHUFFLE_READER_ENABLED) || + conf.getConf(SQLConf.ENABLE_SHUFFLE_CONSOLIDATION)) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala index 9c9c8e13d2d5..dc4f55c5a01e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/simpleCosting.scala @@ -19,7 +19,7 @@ package
org.apache.spark.sql.execution.adaptive import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.exchange.{SHUFFLE_CONSOLIDATION, ShuffleExchangeLike} import org.apache.spark.sql.execution.joins.ShuffledJoin /** @@ -42,7 +42,7 @@ case class SimpleCost(value: Long) extends Cost { case class SimpleCostEvaluator(forceOptimizeSkewedJoin: Boolean) extends CostEvaluator { override def evaluateCost(plan: SparkPlan): Cost = { val numShuffles = plan.collect { - case s: ShuffleExchangeLike => s + case s: ShuffleExchangeLike if (s.shuffleOrigin != SHUFFLE_CONSOLIDATION) => s }.size if (forceOptimizeSkewedJoin) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index f052bd906880..acba7a8f4b36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -182,6 +182,11 @@ case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin // change it. case object REQUIRED_BY_STATEFUL_OPERATOR extends ShuffleOrigin +// Indicates that the shuffle operator was added by the consolidation shuffle optimization to +// consolidate shuffle data from earlier stages and upload it to remote storage. +case object SHUFFLE_CONSOLIDATION extends ShuffleOrigin + + /** * Performs a shuffle that will result in the desired partitioning. */ @@ -201,7 +206,13 @@ case class ShuffleExchangeExec( "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions") ) ++ readMetrics ++ writeMetrics - override def nodeName: String = "Exchange" + override def nodeName: String = { + if (shuffleOrigin == SHUFFLE_CONSOLIDATION) { + "Consolidation exchange" + } else { + "Exchange" + } + } private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) @@ -246,7 +257,8 @@ case class ShuffleExchangeExec( child.output, outputPartitioning, serializer, - writeMetrics) + writeMetrics, + shuffleOrigin) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -258,6 +270,7 @@ case class ShuffleExchangeExec( // The ShuffleRowRDD will be cached in SparkPlan.executeRDD and reused if this plan is used by // multiple plans. 
new ShuffledRowRDD(shuffleDependency, readMetrics) + } override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec = @@ -336,7 +349,8 @@ object ShuffleExchangeExec { outputAttributes: Seq[Attribute], newPartitioning: Partitioning, serializer: Serializer, - writeMetrics: Map[String, SQLMetric]) + writeMetrics: Map[String, SQLMetric], + shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS) : ShuffleDependency[Int, InternalRow, InternalRow] = { val part: Partitioner = newPartitioning match { case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions) @@ -493,7 +507,7 @@ object ShuffleExchangeExec { rddWithPartitionIds, new PartitionIdPassthrough(part.numPartitions), serializer, - shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), + shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics, shuffleOrigin), rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize), checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) @@ -503,13 +517,34 @@ object ShuffleExchangeExec { /** * Create a customized [[ShuffleWriteProcessor]] for SQL which wrap the default metrics reporter * with [[SQLShuffleWriteMetricsReporter]] as new reporter for [[ShuffleWriteProcessor]]. + * + * For consolidation shuffles, the processor is marked with [[ConsolidationShuffleMarker]] + * to indicate that remote storage should be used. */ - def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = { - new ShuffleWriteProcessor { - override protected def createMetricsReporter( - context: TaskContext): ShuffleWriteMetricsReporter = { - new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + def createShuffleWriteProcessor( + metrics: Map[String, SQLMetric], + shuffleOrigin: ShuffleOrigin): ShuffleWriteProcessor = { + if (shuffleOrigin == SHUFFLE_CONSOLIDATION) { + new ShuffleWriteProcessor with ConsolidationShuffleMarker { + override protected def createMetricsReporter( + context: TaskContext): ShuffleWriteMetricsReporter = { + new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + } + } + } else { + new ShuffleWriteProcessor { + override protected def createMetricsReporter( + context: TaskContext): ShuffleWriteMetricsReporter = { + new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics) + } } } } } + +/** + * Marker trait to identify shuffle write processors for consolidation shuffles. + * Shuffle writers can check if a processor implements this trait to determine + * whether to use remote storage for shuffle data. 
+ */ +private[spark] trait ConsolidationShuffleMarker diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala index 471b926dc012..01f289575500 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/reuse/ReuseExchangeAndSubquery.scala @@ -42,14 +42,14 @@ case object ReuseExchangeAndSubquery extends Rule[SparkPlan] { def reuse(plan: SparkPlan): SparkPlan = { plan.transformUpWithPruning(_.containsAnyPattern(EXCHANGE, PLAN_EXPRESSION)) { - case exchange: Exchange if conf.exchangeReuseEnabled => + case exchange: Exchange if conf.exchangeReuseEnabled && + !conf.shuffleConsolidationEnabled => val cachedExchange = exchanges.getOrElseUpdate(exchange.canonicalized, exchange) if (cachedExchange.ne(exchange)) { ReusedExchangeExec(exchange.output, cachedExchange) } else { cachedExchange } - case other => other.transformExpressionsUpWithPruning(_.containsPattern(PLAN_EXPRESSION)) { case sub: ExecSubqueryExpression =>
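PassThroughPartitioning delegates satisfies0, createShuffleSpec and numPartitions to the wrapped partitioning, so the consolidation exchange should be invisible to downstream distribution requirements. A minimal sketch of that delegation using standard catalyst types (illustrative only, not part of the diff):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, PassThroughPartitioning}
import org.apache.spark.sql.types.IntegerType

// A hash-partitioned child wrapped in PassThroughPartitioning should answer distribution
// questions exactly as the child does, so EnsureRequirements sees no change in partitioning.
val key = AttributeReference("k", IntegerType)()
val child = HashPartitioning(Seq(key), numPartitions = 200)
val passThrough = PassThroughPartitioning(child)

assert(passThrough.numPartitions == child.numPartitions)
assert(passThrough.satisfies(ClusteredDistribution(Seq(key))) ==
  child.satisfies(ClusteredDistribution(Seq(key))))
```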
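Taken together, the new configs and the AddConsolidationShuffle rule can be exercised end to end by enabling the flag in a session and inspecting the plan for the renamed exchange node. A hedged sketch follows; the config keys and the "Consolidation exchange" node name are the ones introduced in this diff, while the data, query and threshold value are arbitrary:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[4]")
  .appName("shuffle-consolidation-smoke")
  // Keys added by this patch; they default to false and 100MB respectively.
  .config("spark.sql.shuffle.consolidation.enabled", "true")
  .config("spark.sql.shuffle.consolidation.size.threshold", "1KB") // low threshold so the demo query qualifies
  .getOrCreate()

// Any shuffle-producing query. Per the rule above, a node named "Consolidation exchange" is
// expected to appear above shuffles: always in non-adaptive plans, and under AQE only for
// materialized stages whose runtime size exceeds the threshold.
spark.range(0, 1000000L)
  .selectExpr("id % 100 AS k", "id AS v")
  .groupBy("k")
  .count()
  .explain("formatted")
```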
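One detail worth noting about spark.sql.shuffle.consolidation.size.threshold: it is declared as a byte-size conf, so values such as the "100MB" default follow Spark's binary byte-string rules (MB means mebibytes). A quick check, assuming the standard JavaUtils parser that byte configs delegate to:

```scala
import org.apache.spark.network.util.JavaUtils

// "100MB" in Spark byte-string syntax is 100 mebibytes, not 10^8 bytes.
assert(JavaUtils.byteStringAsBytes("100MB") == 100L * 1024 * 1024)
```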
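The ConsolidationShuffleMarker scaladoc says shuffle writers can test the write processor for the trait to decide whether shuffle data goes to remote storage. Written out as a small Scala helper (the object and method names are illustrative; the trait is private[spark], so this only compiles inside an org.apache.spark package):

```scala
package org.apache.spark.sql.execution.exchange

import org.apache.spark.ShuffleDependency

object ConsolidationShuffleDetection {
  /** True if the dependency was produced by a SHUFFLE_CONSOLIDATION exchange. */
  def usesRemoteShuffleStorage(dep: ShuffleDependency[_, _, _]): Boolean =
    dep.shuffleWriterProcessor.isInstanceOf[ConsolidationShuffleMarker]
}
```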
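The two branches of createShuffleWriteProcessor differ only in the ConsolidationShuffleMarker mixin, which duplicates the anonymous reporter body. One possible way to fold that duplication, sketched as an untested suggestion that assumes the method stays in ShuffleExchangeExec.scala where SHUFFLE_CONSOLIDATION and ConsolidationShuffleMarker are already in scope:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleWriteMetricsReporter}

def createShuffleWriteProcessor(
    metrics: Map[String, SQLMetric],
    shuffleOrigin: ShuffleOrigin): ShuffleWriteProcessor = {
  // Shared metrics-reporting behaviour, defined once and mixed into both variants.
  trait SQLMetricsReporting extends ShuffleWriteProcessor {
    override protected def createMetricsReporter(
        context: TaskContext): ShuffleWriteMetricsReporter = {
      new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics)
    }
  }
  if (shuffleOrigin == SHUFFLE_CONSOLIDATION) {
    new SQLMetricsReporting with ConsolidationShuffleMarker {}
  } else {
    new SQLMetricsReporting {}
  }
}
```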