[CELEBORN-1144] Batch OpenStream RPCs
### What changes were proposed in this pull request?
Batch OpenStream RPCs by Worker to avoid too many RPCs.
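The batching idea, as a hedged sketch (the `Location` type and `batchByWorker` helper below are illustrative stand-ins, not the Celeborn API): the reader groups every `PartitionLocation` it has to read by the worker's `hostAndFetchPort`, then sends a single `BATCH_OPEN_STREAM` request per worker instead of one OpenStream RPC per location.

    // Illustrative Scala sketch: bucket locations by worker so one batched
    // request per worker replaces one RPC per partition location.
    final case class Location(workerHostPort: String, fileName: String)

    def batchByWorker(locations: Seq[Location]): Map[String, Seq[String]] =
      locations
        .groupBy(_.workerHostPort)
        .map { case (worker, locs) => worker -> locs.map(_.fileName) }

With, say, 1000 locations spread across 10 workers, this reduces 1000 open-stream RPCs to 10 batched ones.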

### Why are the changes needed?
One OpenStream RPC per partition location creates an excessive number of RPCs when a shuffle has many partitions and locations; grouping the requests by worker cuts this down to one batched RPC per worker.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Passes GA and manual tests.

Closes #2362 from waitinfuture/1144.

Authored-by: zky.zhoukeyong <zky.zhoukeyong@alibaba-inc.com>
Signed-off-by: Shuang <lvshuang.xjs@alibaba-inc.com>
waitinfuture authored and RexXiong committed Mar 25, 2024
1 parent e29f013 commit fc23800
Showing 13 changed files with 448 additions and 189 deletions.
@@ -171,7 +171,7 @@ public CelebornBufferStream readBufferedPartition(
   }
 
   @Override
-  protected ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+  public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
       throws CelebornIOException {
     ReduceFileGroups reduceFileGroups =
         reduceFileGroupsMap.computeIfAbsent(shuffleId, (id) -> new ReduceFileGroups());
@@ -18,9 +18,12 @@
 package org.apache.spark.shuffle.celeborn
 
 import java.io.IOException
+import java.util
 import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.atomic.AtomicReference
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.{Aggregator, InterruptibleIterator, ShuffleDependency, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.SerializerInstance
@@ -33,7 +36,11 @@ import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.client.read.{CelebornInputStream, MetricsCallback}
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.exception.{CelebornIOException, PartitionUnRetryAbleException}
-import org.apache.celeborn.common.util.{ExceptionMaker, ThreadUtils}
+import org.apache.celeborn.common.network.client.TransportClient
+import org.apache.celeborn.common.network.protocol.TransportMessage
+import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, PbOpenStreamList, PbOpenStreamListResponse, PbStreamHandler}
+import org.apache.celeborn.common.protocol.message.StatusCode
+import org.apache.celeborn.common.util.{ExceptionMaker, JavaUtils, ThreadUtils, Utils}
 
 class CelebornShuffleReader[K, C](
     handle: CelebornShuffleHandle[K, _, C],
@@ -107,60 +114,139 @@ class CelebornShuffleReader[K, C](
       }
     }
 
-    val streams = new ConcurrentHashMap[Integer, CelebornInputStream]()
-    (startPartition until endPartition).map(partitionId => {
+    val startTime = System.currentTimeMillis()
+    val fetchTimeoutMs = conf.clientFetchTimeoutMs
+    val localFetchEnabled = conf.enableReadLocalShuffleFile
+    val shuffleKey = Utils.makeShuffleKey(handle.appUniqueId, shuffleId)
+    // startPartition is irrelevant
+    val fileGroups = shuffleClient.updateFileGroup(shuffleId, startPartition)
+    // host-port -> (TransportClient, PartitionLocation Array, PbOpenStreamList)
+    val workerRequestMap = new util.HashMap[
+      String,
+      (TransportClient, util.ArrayList[PartitionLocation], PbOpenStreamList.Builder)]()
+
+    var partCnt = 0
+
+    (startPartition until endPartition).foreach { partitionId =>
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        fileGroups.partitionGroups.get(partitionId).asScala.foreach { location =>
+          partCnt += 1
+          val hostPort = location.hostAndFetchPort
+          if (!workerRequestMap.containsKey(hostPort)) {
+            val client = shuffleClient.getDataClientFactory().createClient(
+              location.getHost,
+              location.getFetchPort)
+            val pbOpenStreamList = PbOpenStreamList.newBuilder()
+            pbOpenStreamList.setShuffleKey(shuffleKey)
+            workerRequestMap.put(
+              hostPort,
+              (client, new util.ArrayList[PartitionLocation], pbOpenStreamList))
+          }
+          val (_, locArr, pbOpenStreamListBuilder) = workerRequestMap.get(hostPort)
+
+          locArr.add(location)
+          pbOpenStreamListBuilder.addFileName(location.getFileName)
+            .addStartIndex(startMapIndex)
+            .addEndIndex(endMapIndex)
+          pbOpenStreamListBuilder.addReadLocalShuffle(localFetchEnabled)
+        }
+      }
+    }
+
+    val locationStreamHandlerMap: ConcurrentHashMap[PartitionLocation, PbStreamHandler] =
+      JavaUtils.newConcurrentHashMap()
+
+    val futures = workerRequestMap.values().asScala.map { entry =>
       streamCreatorPool.submit(new Runnable {
         override def run(): Unit = {
           if (exceptionRef.get() == null) {
+            val (client, locArr, pbOpenStreamListBuilder) = entry
+            val msg = new TransportMessage(
+              MessageType.BATCH_OPEN_STREAM,
+              pbOpenStreamListBuilder.build().toByteArray)
+            val pbOpenStreamListResponse =
             try {
-              val inputStream = shuffleClient.readPartition(
-                shuffleId,
-                handle.shuffleId,
-                partitionId,
-                context.attemptNumber(),
-                startMapIndex,
-                endMapIndex,
-                if (throwsFetchFailure) exceptionMaker else null,
-                metricsCallback)
-              streams.put(partitionId, inputStream)
+              val response = client.sendRpcSync(msg.toByteBuffer, fetchTimeoutMs)
+              TransportMessage.fromByteBuffer(response).getParsedPayload[PbOpenStreamListResponse]
             } catch {
-              case e: IOException =>
-                logError(s"Exception caught when readPartition $partitionId!", e)
-                exceptionRef.compareAndSet(null, e)
-              case e: Throwable =>
-                logError(s"Non IOException caught when readPartition $partitionId!", e)
-                exceptionRef.compareAndSet(null, new CelebornIOException(e))
+              case _: Exception => null
             }
+            if (pbOpenStreamListResponse != null) {
+              0 until locArr.size() foreach { idx =>
+                val streamHandlerOpt = pbOpenStreamListResponse.getStreamHandlerOptList.get(idx)
+                if (streamHandlerOpt.getStatus == StatusCode.SUCCESS.getValue) {
+                  locationStreamHandlerMap.put(locArr.get(idx), streamHandlerOpt.getStreamHandler)
+                }
+              }
+            }
           }
         }
       })
-    })
+    }.toList
+    // wait for all futures to complete
+    futures.foreach(f => f.get())
+    val end = System.currentTimeMillis()
+    logInfo(s"BatchOpenStream for $partCnt cost ${end - startTime}ms")
 
+    def createInputStream(partitionId: Int): CelebornInputStream = {
+      val locations =
+        if (fileGroups.partitionGroups.containsKey(partitionId)) {
+          new util.ArrayList(fileGroups.partitionGroups.get(partitionId))
+        } else new util.ArrayList[PartitionLocation]()
+      val streamHandlers =
+        if (locations != null) {
+          val streamHandlerArr = new util.ArrayList[PbStreamHandler](locations.size())
+          locations.asScala.foreach { loc =>
+            streamHandlerArr.add(locationStreamHandlerMap.get(loc))
+          }
+          streamHandlerArr
+        } else null
+      if (exceptionRef.get() == null) {
+        try {
+          shuffleClient.readPartition(
+            shuffleId,
+            handle.shuffleId,
+            partitionId,
+            context.attemptNumber(),
+            startMapIndex,
+            endMapIndex,
+            if (throwsFetchFailure) exceptionMaker else null,
+            locations,
+            streamHandlers,
+            fileGroups.mapAttempts,
+            metricsCallback)
+        } catch {
+          case e: IOException =>
+            logError(s"Exception caught when readPartition $partitionId!", e)
+            exceptionRef.compareAndSet(null, e)
+            null
+          case e: Throwable =>
+            logError(s"Non IOException caught when readPartition $partitionId!", e)
+            exceptionRef.compareAndSet(null, new CelebornIOException(e))
+            null
+        }
+      } else null
+    }
+
     val recordIter = (startPartition until endPartition).iterator.map(partitionId => {
       if (handle.numMappers > 0) {
         val startFetchWait = System.nanoTime()
-        var inputStream: CelebornInputStream = streams.get(partitionId)
-        while (inputStream == null) {
-          if (exceptionRef.get() != null) {
-            exceptionRef.get() match {
-              case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
-                if (throwsFetchFailure &&
-                  shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
-                  throw new FetchFailedException(
-                    null,
-                    handle.shuffleId,
-                    -1,
-                    -1,
-                    partitionId,
-                    SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
-                    ce)
-                } else
-                  throw ce
-              case e => throw e
-            }
-          }
-          Thread.sleep(50)
-          inputStream = streams.get(partitionId)
-        }
+        val inputStream: CelebornInputStream = createInputStream(partitionId)
+        if (exceptionRef.get() != null) {
+          exceptionRef.get() match {
+            case ce @ (_: CelebornIOException | _: PartitionUnRetryAbleException) =>
+              if (throwsFetchFailure &&
+                shuffleClient.reportShuffleFetchFailure(handle.shuffleId, shuffleId)) {
+                throw new FetchFailedException(
+                  null,
+                  handle.shuffleId,
+                  -1,
+                  -1,
+                  partitionId,
+                  SparkUtils.FETCH_FAILURE_ERROR_MSG + handle.shuffleId + "/" + shuffleId,
+                  ce)
+              } else
+                throw ce
+            case e => throw e
+          }
+        }
         metricsCallback.incReadTime(
           TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
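A design note on the batched response handling above, sketched here in hedged form: stream handlers come back positionally, so the i-th entry of `PbOpenStreamListResponse` belongs to the i-th location added to that worker's request, and only entries with `StatusCode.SUCCESS` are kept. The stand-in types below are illustrative, not the real protobuf classes.

    // Illustrative index-based matching of batched responses to requests.
    def matchByIndex[L, H](locations: Seq[L], handlers: Seq[Option[H]]): Map[L, H] =
      locations.zip(handlers).collect { case (loc, Some(handler)) => loc -> handler }.toMap

Locations whose open failed simply stay out of the map; `createInputStream` then passes whatever handlers it found down to `readPartition`, which can presumably still open the remaining streams itself.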
15 changes: 15 additions & 0 deletions client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -18,6 +18,7 @@
 package org.apache.celeborn.client;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
 
@@ -28,8 +29,11 @@
 import org.apache.celeborn.client.read.CelebornInputStream;
 import org.apache.celeborn.client.read.MetricsCallback;
 import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.identity.UserIdentifier;
+import org.apache.celeborn.common.network.client.TransportClientFactory;
 import org.apache.celeborn.common.protocol.PartitionLocation;
+import org.apache.celeborn.common.protocol.PbStreamHandler;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
 import org.apache.celeborn.common.util.CelebornHadoopUtils;
 import org.apache.celeborn.common.util.ExceptionMaker;
@@ -197,6 +201,9 @@ public abstract void mapPartitionMapperEnd(
   // Cleanup states of the map task
   public abstract void cleanup(int shuffleId, int mapId, int attemptId);
 
+  public abstract ShuffleClientImpl.ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+      throws CelebornIOException;
+
   // Reduce side read partition which is deduplicated by mapperId+mapperAttemptNum+batchId, batchId
   // is a self-incrementing variable hidden in the implementation when sending data.
   /**
Expand Down Expand Up @@ -227,6 +234,9 @@ public CelebornInputStream readPartition(
startMapIndex,
endMapIndex,
null,
null,
null,
null,
metricsCallback);
}

@@ -238,6 +248,9 @@ public abstract CelebornInputStream readPartition(
       int startMapIndex,
       int endMapIndex,
       ExceptionMaker exceptionMaker,
+      ArrayList<PartitionLocation> locations,
+      ArrayList<PbStreamHandler> streamHandlers,
+      int[] mapAttempts,
       MetricsCallback metricsCallback)
       throws IOException;
 
@@ -261,4 +274,6 @@ public abstract ConcurrentHashMap<Integer, PartitionLocation> getPartitionLocati
    * incorrect shuffle data can be fetched in re-run tasks
    */
   public abstract boolean reportShuffleFetchFailure(int appShuffleId, int shuffleId);
+
+  public abstract TransportClientFactory getDataClientFactory();
 }
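The abstract class keeps the old `readPartition` entry point source-compatible by delegating with nulls, as the hunk above shows. A minimal sketch of that overload pattern, with simplified stand-in names rather than the real Celeborn signatures:

    // Sketch of the delegation pattern: the legacy entry point forwards
    // nulls, and the extended variant owns the actual logic.
    abstract class Client {
      def read(partitionId: Int): String =
        read(partitionId, locations = null, handlers = null, mapAttempts = null)

      def read(
          partitionId: Int,
          locations: java.util.List[String],
          handlers: java.util.List[String],
          mapAttempts: Array[Int]): String
    }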
@@ -1646,7 +1646,7 @@ protected Tuple2<ReduceFileGroups, String> loadFileGroupInternal(int shuffleId)
     }
   }
 
-  protected ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
+  public ReduceFileGroups updateFileGroup(int shuffleId, int partitionId)
       throws CelebornIOException {
     if (reduceFileGroupsMap.containsKey(shuffleId)) {
       return reduceFileGroupsMap.get(shuffleId);
@@ -1679,16 +1679,28 @@ public CelebornInputStream readPartition(
       int startMapIndex,
       int endMapIndex,
       ExceptionMaker exceptionMaker,
+      ArrayList<PartitionLocation> locations,
+      ArrayList<PbStreamHandler> streamHandlers,
+      int[] mapAttempts,
       MetricsCallback metricsCallback)
       throws IOException {
     if (partitionId == Utils$.MODULE$.UNKNOWN_APP_SHUFFLE_ID()) {
       logger.warn("Shuffle data is empty for shuffle {}: UNKNOWN_APP_SHUFFLE_ID.", shuffleId);
       return CelebornInputStream.empty();
     }
-    ReduceFileGroups fileGroups = updateFileGroup(shuffleId, partitionId);
-
-    if (fileGroups.partitionGroups.isEmpty()
-        || !fileGroups.partitionGroups.containsKey(partitionId)) {
+    // When `mapAttempts` is not null, it's guaranteed that the code path comes from
+    // CelebornShuffleReader, which means `updateFileGroup` is already called and
+    // batch open stream has been tried
+    if (mapAttempts == null) {
+      ReduceFileGroups fileGroups = updateFileGroup(shuffleId, partitionId);
+      mapAttempts = fileGroups.mapAttempts;
+      if (fileGroups.partitionGroups.containsKey(partitionId)) {
+        locations = new ArrayList(fileGroups.partitionGroups.get(partitionId));
+      }
+    }
+
+    if (locations == null || locations.size() == 0) {
       logger.warn("Shuffle data is empty for shuffle {} partition {}.", shuffleId, partitionId);
       return CelebornInputStream.empty();
     } else {
@@ -1698,8 +1710,9 @@ public CelebornInputStream readPartition(
             conf,
             dataClientFactory,
             shuffleKey,
-            fileGroups.partitionGroups.get(partitionId).toArray(new PartitionLocation[0]),
-            fileGroups.mapAttempts,
+            locations,
+            streamHandlers,
+            mapAttempts,
             attemptNumber,
             startMapIndex,
             endMapIndex,
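The `mapAttempts == null` guard above is the compatibility hinge: legacy callers that arrive through the nulls-forwarding overload still get the file-group lookup inside `readPartition`, while `CelebornShuffleReader` supplies locations, handlers, and map attempts up front. A hedged sketch of that decision with simplified types (illustrative names only):

    // Illustrative fallback: a null mapAttempts marks a legacy caller,
    // so resolve the locations from the file group here instead.
    def resolveLocations(
        mapAttempts: Array[Int],
        locations: java.util.List[String],
        partitionId: Int,
        lookup: Int => java.util.List[String]): java.util.List[String] =
      if (mapAttempts == null) lookup(partitionId)
      else locations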