@@ -127,23 +127,22 @@ class BlazeCelebornShuffleManager(conf: SparkConf, isDriver: Boolean)
       context: TaskContext,
       metrics: ShuffleWriteMetricsReporter): BlazeRssShuffleWriterBase[K, V] = {
 
-    // ensure celeborn client is initialized
-    assert(celebornShuffleManager.getWriter(handle, mapId, context, metrics) != null)
+    val celebornShuffleWriter = celebornShuffleManager.getWriter[K, V](handle, mapId, context, metrics)
     val shuffleClient = FieldUtils
       .readField(celebornShuffleManager, "shuffleClient", true)
       .asInstanceOf[ShuffleClient]
 
-    val celebornHandle = handle.asInstanceOf[CelebornShuffleHandle[_, _, _]]
+    val celebornHandle = handle.asInstanceOf[CelebornShuffleHandle[K, V, _]]
     val shuffleIdTracker = FieldUtils
       .readField(celebornShuffleManager, "shuffleIdTracker", true)
       .asInstanceOf[ExecutorShuffleIdTracker]
-    val writer = new BlazeCelebornShuffleWriter(
+    new BlazeCelebornShuffleWriter[K, V](
+      celebornShuffleWriter,
     shuffleClient,
     context,
     celebornHandle,
     metrics,
     shuffleIdTracker)
-    writer.asInstanceOf[BlazeRssShuffleWriterBase[K, V]]
   }
 
   override def getRssShuffleWriter[K, V](
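
Note on this hunk: the old assert existed only to force Celeborn's lazy client initialization; keeping the returned writer instead lets the Blaze writer delegate its stop() to the real Celeborn writer (see the next file). The private shuffleClient and shuffleIdTracker fields are reached via Commons Lang3 reflection. A minimal, self-contained sketch of that reflection idiom, with made-up names (not code from this PR):

import org.apache.commons.lang3.reflect.FieldUtils

object ReadPrivateField extends App {
  class Manager { private val shuffleClient: String = "initialized-client" }
  // forceAccess = true makes the private field readable (setAccessible under the hood)
  val client = FieldUtils.readField(new Manager, "shuffleClient", true).asInstanceOf[String]
  println(client) // prints: initialized-client
}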
@@ -22,16 +22,19 @@ import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle, ExecutorShuffleIdTracker}
 import org.apache.spark.sql.execution.blaze.shuffle.BlazeRssShuffleWriterBase
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
 import org.apache.spark.TaskContext
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.shuffle.ShuffleWriter
 
 import com.thoughtworks.enableIf
 
-class BlazeCelebornShuffleWriter[K, C](
+class BlazeCelebornShuffleWriter[K, V](
+    celebornShuffleWriter: ShuffleWriter[K, V],
     shuffleClient: ShuffleClient,
     taskContext: TaskContext,
-    handle: CelebornShuffleHandle[K, _, C],
+    handle: CelebornShuffleHandle[K, V, _],
     metrics: ShuffleWriteMetricsReporter,
     shuffleIdTracker: ExecutorShuffleIdTracker)
-  extends BlazeRssShuffleWriterBase[K, C](metrics) {
+  extends BlazeRssShuffleWriterBase[K, V](metrics) {
 
   private val numMappers = handle.numMappers
   private val encodedAttemptId = BlazeCelebornShuffleManager.getEncodedAttemptNumber(taskContext)
@@ -58,4 +61,8 @@ class BlazeCelebornShuffleWriter[K, C](
       System.getProperty("blaze.shim")))
   override def getPartitionLengths(): Array[Long] = partitionLengths
 
+  override def stop(success: Boolean): Option[MapStatus] = {
+    celebornShuffleWriter.write(Iterator.empty) // force flush
+    celebornShuffleWriter.stop(success)
+  }
 }
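
Why the empty write: Blaze's native engine pushes all partition bytes itself through CelebornPartitionWriter (next file), so the wrapped Celeborn writer never sees a record. Writing an empty iterator drives it through its normal end-of-map path (flush/commit) so that stop(success) can return the MapStatus Spark's scheduler expects. A minimal sketch of the pattern under that same assumption; note ShuffleWriter is private[spark], so, like the Blaze sources, such code must live under an org.apache.spark package:

package org.apache.spark.shuffle

import org.apache.spark.scheduler.MapStatus

object EmptyWriteFinish {
  // The wrapped writer received no records; an empty write() only triggers its
  // flush/finalize bookkeeping before stop() builds the MapStatus.
  def finish[K, V](inner: ShuffleWriter[K, V], success: Boolean): Option[MapStatus] = {
    inner.write(Iterator.empty)
    inner.stop(success)
  }
}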
@@ -19,9 +19,11 @@ import java.nio.ByteBuffer
 
 import org.apache.spark.sql.execution.blaze.shuffle.RssPartitionWriterBase
 import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.client.ShuffleClientImpl
 import org.apache.spark.internal.Logging
 import org.apache.spark.TaskContext
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
+import org.apache.spark.shuffle.ShuffleWriter
 
 class CelebornPartitionWriter(
     shuffleClient: ShuffleClient,
@@ -40,8 +42,7 @@
     val numBytes = buffer.limit()
     val bytes = new Array[Byte](numBytes)
     buffer.get(bytes)
-
-    val bytesWritten = shuffleClient.pushData(
+    val bytesWritten = shuffleClient.asInstanceOf[ShuffleClientImpl].pushOrMergeData(
       shuffleId,
       mapId,
       encodedAttemptId,
@@ -50,7 +51,10 @@
       0,
       numBytes,
       numMappers,
-      numPartitions)
+      numPartitions,
+      true, // doPush
+      true, // skipCompress
+    )
     metrics.incBytesWritten(bytesWritten)
     mapStatusLengths(partitionId) += bytesWritten
   }
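
Two things worth calling out in this last hunk. First, pushOrMergeData is not part of the ShuffleClient interface, hence the cast to ShuffleClientImpl; this assumes a Celeborn build where that method is publicly accessible. Second, skipCompress = true avoids double compression: the buffers arriving here were already compressed by Blaze's native shuffle writer, so Celeborn's client-side codec is bypassed, while doPush = true keeps the eager-push behavior of the old pushData call. A hedged helper sketch under those assumptions (names are illustrative, not from the PR):

import org.apache.celeborn.client.{ShuffleClient, ShuffleClientImpl}

object RawPush {
  // Push a block that is already compressed, skipping Celeborn's own codec.
  def pushPrecompressed(
      client: ShuffleClient,
      shuffleId: Int, mapId: Int, attemptId: Int, partitionId: Int,
      bytes: Array[Byte], numMappers: Int, numPartitions: Int): Int =
    client
      .asInstanceOf[ShuffleClientImpl]
      .pushOrMergeData(
        shuffleId, mapId, attemptId, partitionId,
        bytes, 0, bytes.length, numMappers, numPartitions,
        true, // doPush: send immediately rather than buffering for merge
        true) // skipCompress: payload is already compressed by the native engine
}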