diff --git a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
index 36f164d697e..e12d4b697f7 100644
--- a/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/protocol/message/ControlMessages.scala
@@ -1364,7 +1364,15 @@ object ControlMessages extends Logging {
       case HEARTBEAT_FROM_APPLICATION_RESPONSE_VALUE =>
         val pbHeartbeatFromApplicationResponse =
           PbHeartbeatFromApplicationResponse.parseFrom(message.getPayload)
-        val pbCheckQuotaResponse = pbHeartbeatFromApplicationResponse.getCheckQuotaResponse
+        // The quota field may be unset in the payload; an unset protobuf message reads back as
+        // its default instance (available = false), so default to "available" explicitly.
+        val checkQuotaResponse =
+          if (pbHeartbeatFromApplicationResponse.hasCheckQuotaResponse) {
+            val pbCheckQuotaResponse = pbHeartbeatFromApplicationResponse.getCheckQuotaResponse
+            CheckQuotaResponse(pbCheckQuotaResponse.getAvailable, pbCheckQuotaResponse.getReason)
+          } else {
+            CheckQuotaResponse(isAvailable = true, "")
+          }
         HeartbeatFromApplicationResponse(
           StatusCode.fromValue(pbHeartbeatFromApplicationResponse.getStatus),
           pbHeartbeatFromApplicationResponse.getExcludedWorkersList.asScala
@@ -1374,7 +1382,7 @@ object ControlMessages extends Logging {
           pbHeartbeatFromApplicationResponse.getShuttingWorkersList.asScala
             .map(PbSerDeUtils.fromPbWorkerInfo).toList.asJava,
           pbHeartbeatFromApplicationResponse.getRegisteredShufflesList,
-          CheckQuotaResponse(pbCheckQuotaResponse.getAvailable, pbCheckQuotaResponse.getReason))
+          checkQuotaResponse)
 
       case CHECK_QUOTA_VALUE =>
         val pbCheckAvailable = PbCheckQuota.parseFrom(message.getPayload)
diff --git a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
index 5b8fe9979a1..47b618b566e 100644
--- a/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
+++ b/common/src/test/scala/org/apache/celeborn/common/util/PbSerDeUtilsTest.scala
@@ -31,10 +31,11 @@ import org.apache.hadoop.shaded.org.apache.commons.lang3.RandomStringUtils
 import org.apache.celeborn.CelebornFunSuite
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.meta._
-import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType, PbFileInfo, PbPackedWorkerResource, PbWorkerResource, StorageInfo}
+import org.apache.celeborn.common.network.protocol.TransportMessage
+import org.apache.celeborn.common.protocol.{MessageType, PartitionLocation, PartitionType, PbFileInfo, PbHeartbeatFromApplicationResponse, PbPackedWorkerResource, PbWorkerResource, StorageInfo}
 import org.apache.celeborn.common.protocol.PartitionLocation.Mode
 import org.apache.celeborn.common.protocol.message.{ControlMessages, StatusCode}
-import org.apache.celeborn.common.protocol.message.ControlMessages.{GetReducerFileGroupResponse, WorkerResource}
+import org.apache.celeborn.common.protocol.message.ControlMessages.{CheckQuotaResponse, GetReducerFileGroupResponse, HeartbeatFromApplicationResponse, WorkerResource}
 import org.apache.celeborn.common.quota.ResourceConsumption
 import org.apache.celeborn.common.util.PbSerDeUtils.{fromPbPackedPartitionLocationsPair, toPbPackedPartitionLocationsPair, toPbUserIdentifier}
 import org.apache.celeborn.common.write.LocationPushFailedBatches
@@ -806,4 +807,55 @@ class PbSerDeUtilsTest extends CelebornFunSuite {
     assert(restoredFailedBatch.equals(failedBatch))
   }
 
+  test("fromAndToHeartbeatFromApplicationResponse") {
+    // Round trip through the transport encoding, including a non-default quota result.
+    val heartbeatFromApplicationResponse = HeartbeatFromApplicationResponse(
+      StatusCode.SUCCESS,
+      mockWorkers("host0").toList.asJava,
+      mockWorkers("host1").toList.asJava,
+      mockWorkers("host2").toList.asJava,
+      Array(Integer.valueOf(1)).toList.asJava,
+      CheckQuotaResponse(isAvailable = false, "test_reason"))
+    val toTransportHeartbeatFromApplicationResponse =
+      ControlMessages.toTransportMessage(heartbeatFromApplicationResponse)
+    val fromTransportHeartbeatFromApplicationResponse =
+      ControlMessages.fromTransportMessage(toTransportHeartbeatFromApplicationResponse)
+        .asInstanceOf[HeartbeatFromApplicationResponse]
+
+    assert(fromTransportHeartbeatFromApplicationResponse.equals(heartbeatFromApplicationResponse))
+  }
+
+  test("HeartbeatFromApplicationResponse backward compatibility without checkQuotaResponse") {
+    // Build the protobuf payload by hand so that the checkQuotaResponse field is never set.
+    val payload = PbHeartbeatFromApplicationResponse.newBuilder()
+      .setStatus(StatusCode.SUCCESS.getValue)
+      .addAllExcludedWorkers(
+        mockWorkers("host0").map(PbSerDeUtils.toPbWorkerInfo(
+          _,
+          true,
+          true)).toList.asJava)
+      .addAllUnknownWorkers(
+        mockWorkers("host1").map(PbSerDeUtils.toPbWorkerInfo(
+          _,
+          true,
+          true)).toList.asJava)
+      .addAllShuttingWorkers(
+        mockWorkers("host2").map(PbSerDeUtils.toPbWorkerInfo(
+          _,
+          true,
+          true)).toList.asJava)
+      .addAllRegisteredShuffles(Array(Integer.valueOf(1)).toList.asJava)
+      .build().toByteArray
+    val fromTransportHeartbeatFromApplicationResponse = ControlMessages.fromTransportMessage(
+      new TransportMessage(MessageType.HEARTBEAT_FROM_APPLICATION_RESPONSE, payload))
+      .asInstanceOf[HeartbeatFromApplicationResponse]
+    // An absent quota field must decode as "available" with an empty reason.
+    assert(fromTransportHeartbeatFromApplicationResponse.checkQuotaResponse.isAvailable)
+    assert(fromTransportHeartbeatFromApplicationResponse.checkQuotaResponse.reason.isEmpty)
+  }
+
+  // Single-element worker array for `host`, with every other constructor argument set to -1.
+  private def mockWorkers(host: String): Array[WorkerInfo] = {
+    Array(new WorkerInfo(host, -1, -1, -1, -1))
+  }
 }