-
Notifications
You must be signed in to change notification settings - Fork 28k
/
ShuffleExchangeExec.scala
438 lines (392 loc) · 19.7 KB
/
ShuffleExchangeExec.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.exchange
import java.util.function.Supplier
import scala.collection.mutable
import scala.concurrent.Future
import org.apache.spark._
import org.apache.spark.internal.config
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.MutablePair
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom
/**
 * Base trait shared by every shuffle exchange implementation, so that callers can pattern match
 * on shuffle exchanges without naming a concrete operator.
 */
trait ShuffleExchangeLike extends Exchange {

  /** The number of mappers of this shuffle. */
  def numMappers: Int

  /** The shuffle partition number. */
  def numPartitions: Int

  /** The advisory partition size, if any. */
  def advisoryPartitionSize: Option[Long]

  /** The origin of this shuffle operator. */
  def shuffleOrigin: ShuffleOrigin

  /** The shuffle ID. */
  def shuffleId: Int

  /**
   * The future that implementations provide for the shuffle's map output statistics. It is
   * evaluated by `submitShuffleJob`.
   */
  protected def mapOutputStatisticsFuture: Future[MapOutputStatistics]

  /**
   * Submits the asynchronous job that materializes the shuffle. The future is obtained inside
   * `executeQuery`, which also takes care of the preparation work, such as waiting for the
   * subqueries.
   */
  final def submitShuffleJob: Future[MapOutputStatistics] =
    executeQuery(mapOutputStatisticsFuture)

  /**
   * Returns the shuffle RDD for the given partition specs.
   *
   * @param partitionSpecs the partition specs the returned RDD is built with
   */
  def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[_]

  /** The runtime statistics, available after the shuffle has been materialized. */
  def runtimeStatistics: Statistics
}
/** Describes where a shuffle operator comes from. */
sealed trait ShuffleOrigin

/**
 * The shuffle operator was added by the internal `EnsureRequirements` rule. Such a shuffle only
 * exists to ensure internal data partitioning requirements, so Spark is free to optimize it as
 * long as those requirements are still ensured.
 */
case object ENSURE_REQUIREMENTS extends ShuffleOrigin

/**
 * The shuffle operator was added by the user-specified repartition operator. Spark can still
 * optimize it by changing the shuffle partition number, because the data partitioning does not
 * change.
 */
case object REPARTITION_BY_COL extends ShuffleOrigin

/**
 * The shuffle operator was added by the user-specified repartition operator with a certain
 * partition number. Spark can't optimize it.
 */
case object REPARTITION_BY_NUM extends ShuffleOrigin

/**
 * The shuffle operator was added by the user-specified rebalance operator. Spark will try to
 * rebalance the partitions so that the per-partition size is neither too small nor too big.
 * Local shuffle read will be used if possible to reduce network traffic.
 */
case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin

/**
 * The shuffle operator was added by the user-specified rebalance operator with columns. Spark
 * will try to rebalance the partitions so that the per-partition size is neither too small nor
 * too big.
 *
 * Unlike `REBALANCE_PARTITIONS_BY_NONE`, local shuffle read cannot be used here, because the
 * output needs to be partitioned by the given columns.
 */
case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin
/**
 * Physical operator that shuffles the output of `child` so that the result has the desired
 * `outputPartitioning`.
 */
case class ShuffleExchangeExec(
    override val outputPartitioning: Partitioning,
    child: SparkPlan,
    shuffleOrigin: ShuffleOrigin = ENSURE_REQUIREMENTS,
    advisoryPartitionSize: Option[Long] = None)
  extends ShuffleExchangeLike {

  // Write-side shuffle metrics; passed on to the shuffle dependency.
  private lazy val writeMetrics =
    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)

  // Read-side shuffle metrics; passed on to every ShuffledRowRDD created by this operator.
  private[sql] lazy val readMetrics =
    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)

  // The operator's own metrics are created first, then merged with the read and write metrics.
  override lazy val metrics = {
    val exchangeMetrics = Map(
      "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
      "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"))
    exchangeMetrics ++ readMetrics ++ writeMetrics
  }

  override def nodeName: String = "Exchange"

  // Row serializer for the shuffle; it is given the "dataSize" metric.
  private lazy val serializer: Serializer = {
    val numFields = child.output.size
    new UnsafeRowSerializer(numFields, longMetric("dataSize"))
  }

  @transient lazy val inputRDD: RDD[InternalRow] = child.execute()

  // 'mapOutputStatisticsFuture' is only needed when AQE is enabled.
  // An input without partitions submits no map stage and completes with `null` statistics.
  @transient
  override lazy val mapOutputStatisticsFuture: Future[MapOutputStatistics] =
    inputRDD.getNumPartitions match {
      case 0 => Future.successful(null)
      case _ => sparkContext.submitMapStage(shuffleDependency)
    }

  override def numMappers: Int = shuffleDependency.rdd.getNumPartitions

  override def numPartitions: Int = shuffleDependency.partitioner.numPartitions

  override def getShuffleRDD(partitionSpecs: Array[ShufflePartitionSpec]): RDD[InternalRow] =
    new ShuffledRowRDD(shuffleDependency, readMetrics, partitionSpecs)

  // Statistics are read from the "dataSize" metric and the shuffle-records-written metric.
  override def runtimeStatistics: Statistics = {
    val shuffledBytes = metrics("dataSize").value
    val shuffledRows = metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN).value
    Statistics(shuffledBytes, Some(shuffledRows))
  }

  override def shuffleId: Int = shuffleDependency.shuffleId

  /**
   * A [[ShuffleDependency]] that partitions the rows of `child` according to
   * `outputPartitioning`. The partitions of this dependency are the input of the shuffle.
   *
   * Creating it also records the partition count in the "numPartitions" metric and posts that
   * metric as a driver-side update for the current SQL execution.
   */
  @transient
  lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] = {
    val dependency = ShuffleExchangeExec.prepareShuffleDependency(
      inputRDD, child.output, outputPartitioning, serializer, writeMetrics)
    val numPartitionsMetric = metrics("numPartitions")
    numPartitionsMetric.set(dependency.partitioner.numPartitions)
    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, Seq(numPartitionsMetric))
    dependency
  }

  /**
   * Caches the created ShuffledRowRDD so that it can be reused.
   */
  private var cachedShuffleRDD: ShuffledRowRDD = null

  protected override def doExecute(): RDD[InternalRow] = {
    // Hand out the same ShuffledRowRDD if this plan is used by multiple plans.
    Option(cachedShuffleRDD).getOrElse {
      val shuffledRDD = new ShuffledRowRDD(shuffleDependency, readMetrics)
      cachedShuffleRDD = shuffledRDD
      shuffledRDD
    }
  }

  override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec =
    copy(child = newChild)
}
object ShuffleExchangeExec {
/**
 * Decides whether rows have to be defensively copied before they are handed to the shuffle.
 *
 * Several of Spark's shuffle components buffer deserialized Java objects in memory and assume
 * those objects are immutable, so they do no copying of their own. Spark SQL operators,
 * however, return the same mutable `Row` object from their iterators, which means SQL has to
 * copy rows itself whenever the shuffle path may buffer them. Copying is expensive, so it is
 * skipped whenever the shuffle path is known not to need it.
 *
 * In the long run this logic might be pushed into core's shuffle APIs so that SQL does not
 * have to rely on knowledge of core internals.
 *
 * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
 *
 * @param partitioner the partitioner for the shuffle
 * @return true if rows should be copied before being shuffled, false otherwise
 */
private def needToCopyObjectsBeforeShuffle(partitioner: Partitioner): Boolean = {
  // Note: only `numPartitions` is read from the partitioner, but the partitioner itself is
  // required (rather than a plain partition count) to guard against corner-cases where a
  // partitioner constructed with `numPartitions` partitions may output fewer partitions
  // (like RangePartitioner, for example).
  val conf = SparkEnv.get.conf
  val shuffleManager = SparkEnv.get.shuffleManager
  val bypassMergeThreshold = conf.get(config.SHUFFLE_SORT_BYPASS_MERGE_THRESHOLD)
  val numParts = partitioner.numPartitions

  shuffleManager match {
    case _: SortShuffleManager =>
      // With the original SortShuffleManager and a sufficiently small number of output
      // partitions, Spark falls back to the hash-based shuffle write path, which doesn't
      // buffer deserialized records.
      // Note that this case has to go if SPARK-6026 is fixed and the bypass is removed.
      val usesBypassMergePath = numParts <= bypassMergeThreshold

      // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
      // prior to sorting them. That optimization only applies when the shuffle dependency
      // specifies no aggregator or ordering, the record serializer has certain properties and
      // the number of partitions doesn't exceed the limitation. Exchange never configures its
      // ShuffledRDDs with aggregators or key orderings, and the serializer in Spark SQL always
      // satisfies the properties, so only the partition count needs checking.
      val usesSerializedMode =
        numParts <= SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE

      // Otherwise SortShuffleManager uses `ExternalSorter` to buffer records in memory, so the
      // rows must be copied.
      !(usesBypassMergePath || usesSerializedMode)

    case _ =>
      // Catch-all case to safely handle any future ShuffleManager implementations.
      true
  }
}
/**
 * Returns a [[ShuffleDependency]] that will partition the rows of `rdd` based on the
 * partitioning scheme defined in `newPartitioning`. Those partitions of the returned
 * ShuffleDependency will be the input of shuffle.
 *
 * Every input row is first paired with the ID of its target partition, so the returned
 * dependency always uses a `PartitionIdPassthrough` partitioner, whatever `newPartitioning` is.
 *
 * Fails with the error built by `SparkException.internalError` when `newPartitioning` is not
 * one of the partitionings handled below.
 *
 * @param rdd the rows to shuffle
 * @param outputAttributes the attributes of the rows in `rdd`; partitioning expressions are
 *                         bound against them
 * @param newPartitioning the partitioning scheme the shuffle has to produce
 * @param serializer the serializer passed to the shuffle dependency
 * @param writeMetrics the SQL metrics used to build the shuffle write processor
 * @return the shuffle dependency over `(partitionId, row)` pairs
 */
def prepareShuffleDependency(
    rdd: RDD[InternalRow],
    outputAttributes: Seq[Attribute],
    newPartitioning: Partitioning,
    serializer: Serializer,
    writeMetrics: Map[String, SQLMetric])
  : ShuffleDependency[Int, InternalRow, InternalRow] = {
  // The partitioner that maps a partitioning key (see `getPartitionKeyExtractor`) to a
  // partition ID.
  val part: Partitioner = newPartitioning match {
    case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
    case HashPartitioning(_, n) =>
      // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
      // `HashPartitioning.partitionIdExpression` to produce partitioning key.
      new PartitionIdPassthrough(n)
    case RangePartitioning(sortingExpressions, numPartitions) =>
      // Extract only the fields used for sorting, to avoid collecting large fields that do not
      // affect the sorting result when deciding partition bounds in RangePartitioner.
      val rddForSampling = rdd.mapPartitionsInternal { iter =>
        val projection =
          UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
        val mutablePair = new MutablePair[InternalRow, Null]()
        // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
        // partition bounds. To get accurate samples, we need to copy the mutable keys.
        iter.map(row => mutablePair.update(projection(row).copy(), null))
      }
      // Construct ordering on extracted sort key.
      val orderingAttributes = sortingExpressions.zipWithIndex.map { case (ord, i) =>
        ord.copy(child = BoundReference(i, ord.dataType, ord.nullable))
      }
      implicit val ordering = new LazilyGeneratedOrdering(orderingAttributes)
      new RangePartitioner(
        numPartitions,
        rddForSampling,
        ascending = true,
        samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
    case SinglePartition => new ConstantPartitioner
    case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
      // Map every unique partition value (as a Seq of column values) to its index.
      val valueMap = k.uniquePartitionValues.zipWithIndex.map {
        case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
      }.toMap
      new KeyGroupedPartitioner(mutable.Map(valueMap.toSeq: _*), n)
    case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
    // TODO: Handle BroadcastPartitioning.
  }

  // Builds the function that extracts the partitioning key of a row. It is invoked inside the
  // per-partition closures below, so any state it creates (projections, bound expressions, the
  // round-robin position) belongs to a single partition's iterator.
  def getPartitionKeyExtractor(): InternalRow => Any = newPartitioning match {
    case RoundRobinPartitioning(numPartitions) =>
      // Distributes elements evenly across output partitions, starting from a random partition.
      // nextInt(numPartitions) implementation has a special case when bound is a power of 2,
      // which is basically taking several highest bits from the initial seed, with only a
      // minimal scrambling. Due to deterministic seed, using the generator only once,
      // and lack of scrambling, the position values for power-of-two numPartitions always
      // end up being almost the same regardless of the index. Substantially scrambling the
      // seed by hashing will help. Refer to SPARK-21782 for more details.
      val partitionId = TaskContext.get().partitionId()
      var position = new XORShiftRandom(partitionId).nextInt(numPartitions)
      (_: InternalRow) => {
        // The HashPartitioner will handle the `mod` by the number of partitions
        position += 1
        position
      }
    case h: HashPartitioning =>
      val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
      row => projection(row).getInt(0)
    case RangePartitioning(sortingExpressions, _) =>
      val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
      row => projection(row)
    case SinglePartition => identity
    case KeyGroupedPartitioning(expressions, _, _, _) =>
      // Binding does not depend on the row, so do it once per extractor instead of once per
      // row.
      val boundExpressions = bindReferences(expressions, outputAttributes)
      row => boundExpressions.map(_.eval(row))
    case _ => throw SparkException.internalError(s"Exchange not implemented for $newPartitioning")
  }

  val isRoundRobin = newPartitioning.isInstanceOf[RoundRobinPartitioning] &&
    newPartitioning.numPartitions > 1

  val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
    // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic,
    // otherwise a retry task may output different rows and thus lead to data loss.
    //
    // Currently we follow the most straightforward way, which is to perform a local sort
    // before partitioning.
    //
    // Note that we don't perform the local sort if the new partitioning has only 1 partition;
    // in that case all output rows go to the same partition.
    val newRdd = if (isRoundRobin && SQLConf.get.sortBeforeRepartition) {
      rdd.mapPartitionsInternal { iter =>
        val recordComparatorSupplier = new Supplier[RecordComparator] {
          override def get: RecordComparator = new RecordBinaryComparator()
        }
        // The comparator for the sort prefixes. Each prefix is a row hash code (see
        // `prefixComputer` below), compared with the LONG prefix comparator.
        val prefixComparator = PrefixComparators.LONG

        // The prefix computer generates row hashcode as the prefix, so we may decrease the
        // probability that the prefixes are equal when input rows choose column values from a
        // limited range.
        val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
          private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
          override def computePrefix(row: InternalRow):
          UnsafeExternalRowSorter.PrefixComputer.Prefix = {
            // The hashcode generated from the binary form of a [[UnsafeRow]] should not be null.
            result.isNull = false
            result.value = row.hashCode()
            result
          }
        }
        val pageSize = SparkEnv.get.memoryManager.pageSizeBytes

        val sorter = UnsafeExternalRowSorter.createWithRecordComparator(
          DataTypeUtils.fromAttributes(outputAttributes),
          recordComparatorSupplier,
          prefixComparator,
          prefixComputer,
          pageSize,
          // We are comparing binary here, which does not support radix sort.
          // See more details in SPARK-28699.
          false)
        sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]])
      }
    } else {
      rdd
    }

    // round-robin function is order sensitive if we don't sort the input.
    val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
    if (needToCopyObjectsBeforeShuffle(part)) {
      // The shuffle may buffer the rows, so every row is copied into a fresh pair.
      newRdd.mapPartitionsWithIndexInternal((_, iter) => {
        val getPartitionKey = getPartitionKeyExtractor()
        iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
      }, isOrderSensitive = isOrderSensitive)
    } else {
      // No copy needed: a single mutable pair is reused for all rows of the partition.
      newRdd.mapPartitionsWithIndexInternal((_, iter) => {
        val getPartitionKey = getPartitionKeyExtractor()
        val mutablePair = new MutablePair[Int, InternalRow]()
        iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
      }, isOrderSensitive = isOrderSensitive)
    }
  }

  // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
  // are in the form of (partitionId, row) and every partitionId is in the expected range
  // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough.
  new ShuffleDependency[Int, InternalRow, InternalRow](
    rddWithPartitionIds,
    new PartitionIdPassthrough(part.numPartitions),
    serializer,
    shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))
}
/**
 * Builds the [[ShuffleWriteProcessor]] used for SQL shuffles. Its metrics reporter is a
 * [[SQLShuffleWriteMetricsReporter]] that wraps the task's default shuffle write metrics
 * together with the given SQL `metrics`.
 *
 * @param metrics the SQL metrics passed to the [[SQLShuffleWriteMetricsReporter]]
 */
def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor =
  new ShuffleWriteProcessor {
    override protected def createMetricsReporter(
        context: TaskContext): ShuffleWriteMetricsReporter = {
      // The reporter the task would use by default; the SQL reporter wraps it.
      val defaultReporter = context.taskMetrics().shuffleWriteMetrics
      new SQLShuffleWriteMetricsReporter(defaultReporter, metrics)
    }
  }
}