diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java index 555cccfb7ca1a..2d5292b42cb07 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java @@ -28,7 +28,6 @@ import java.io.IOException; import java.io.InputStream; -import java.util.Optional; /** BlobWriter is used to upload data to the BLOB store. */ public interface BlobWriter { @@ -103,13 +102,11 @@ static Either, PermanentBlobKey> tryOffload( if (serializedValue.getByteArray().length < blobWriter.getMinOffloadingSize()) { return Either.Left(serializedValue); } else { - return offloadWithException(serializedValue, jobId, blobWriter) - .map(Either::, PermanentBlobKey>Right) - .orElse(Either.Left(serializedValue)); + return offloadWithException(serializedValue, jobId, blobWriter); } } - static Optional offloadWithException( + static Either, PermanentBlobKey> offloadWithException( SerializedValue serializedValue, JobID jobId, BlobWriter blobWriter) { Preconditions.checkNotNull(serializedValue); Preconditions.checkNotNull(jobId); @@ -117,10 +114,10 @@ static Optional offloadWithException( try { final PermanentBlobKey permanentBlobKey = blobWriter.putPermanent(jobId, serializedValue.getByteArray()); - return Optional.of(permanentBlobKey); + return Either.Right(permanentBlobKey); } catch (IOException e) { LOG.warn("Failed to offload value for job {} to BLOB store.", jobId, e); - return Optional.empty(); + return Either.Left(serializedValue); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java index 4ddacbd671a43..b8e0b44006fd3 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java @@ -87,7 +87,7 @@ public void serializeShuffleDescriptors( new ShuffleDescriptorGroup( toBeSerialized.toArray(new ShuffleDescriptorAndIndex[0])); MaybeOffloaded serializedShuffleDescriptorGroup = - shuffleDescriptorSerializer.trySerializeAndOffloadShuffleDescriptor( + shuffleDescriptorSerializer.serializeAndTryOffloadShuffleDescriptor( shuffleDescriptorGroup, numConsumers); toBeSerialized.clear(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java index 4e02c6993313c..333a91e0a7320 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java @@ -23,7 +23,7 @@ import org.apache.flink.runtime.blob.PermanentBlobKey; import org.apache.flink.runtime.blob.PermanentBlobService; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup; @@ -98,7 +98,9 @@ public InputGateDeploymentDescriptor( new IndexRange(consumedSubpartitionIndex, consumedSubpartitionIndex), inputChannels.length, Collections.singletonList( - new NonOffloadedRaw<>(new ShuffleDescriptorGroup(inputChannels)))); + new NonOffloaded<>( + CompressedSerializedValue.fromObject( + new ShuffleDescriptorGroup(inputChannels))))); } public InputGateDeploymentDescriptor( @@ -145,14 +147,18 @@ public ShuffleDescriptor[] getShuffleDescriptors() { // This is only for testing scenarios, in a production environment we always call // tryLoadAndDeserializeShuffleDescriptors to deserialize ShuffleDescriptors first. inputChannels = new ShuffleDescriptor[numberOfInputChannels]; - for (MaybeOffloaded rawShuffleDescriptors : - serializedInputChannels) { - checkState( - rawShuffleDescriptors instanceof NonOffloadedRaw, - "Trying to work with offloaded serialized shuffle descriptors."); - NonOffloadedRaw nonOffloadedRawValue = - (NonOffloadedRaw) rawShuffleDescriptors; - putOrReplaceShuffleDescriptors(nonOffloadedRawValue.value); + try { + for (MaybeOffloaded serializedShuffleDescriptors : + serializedInputChannels) { + checkState( + serializedShuffleDescriptors instanceof NonOffloaded, + "Trying to work with offloaded serialized shuffle descriptors."); + NonOffloaded nonOffloadedSerializedValue = + (NonOffloaded) serializedShuffleDescriptors; + tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue); + } + } catch (ClassNotFoundException | IOException e) { + throw new RuntimeException("Could not deserialize shuffle descriptors.", e); } } return inputChannels; @@ -207,12 +213,21 @@ private void tryLoadAndDeserializeShuffleDescriptorGroup( } putOrReplaceShuffleDescriptors(shuffleDescriptorGroup); } else { - NonOffloadedRaw nonOffloadedSerializedValue = - (NonOffloadedRaw) serializedShuffleDescriptors; - putOrReplaceShuffleDescriptors(nonOffloadedSerializedValue.value); + NonOffloaded nonOffloadedSerializedValue = + (NonOffloaded) serializedShuffleDescriptors; + tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue); } } + private void tryDeserializeShuffleDescriptorGroup( + NonOffloaded nonOffloadedShuffleDescriptorGroup) + throws IOException, ClassNotFoundException { + ShuffleDescriptorGroup shuffleDescriptorGroup = + nonOffloadedShuffleDescriptorGroup.serializedValue.deserializeValue( + getClass().getClassLoader()); + putOrReplaceShuffleDescriptors(shuffleDescriptorGroup); + } + private void putOrReplaceShuffleDescriptors(ShuffleDescriptorGroup shuffleDescriptorGroup) { for (ShuffleDescriptorAndIndex shuffleDescriptorAndIndex : shuffleDescriptorGroup.getShuffleDescriptors()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 016105d9aaca2..5684066735f03 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -80,25 +80,6 @@ public NonOffloaded(SerializedValue serializedValue) { } } - /** - * The raw value that is not offloaded to the {@link org.apache.flink.runtime.blob.BlobServer}. - * - * @param type of the raw value - */ - public static class NonOffloadedRaw extends MaybeOffloaded { - private static final long serialVersionUID = 1L; - - /** The raw value. */ - public T value; - - @SuppressWarnings("unused") - public NonOffloadedRaw() {} - - public NonOffloadedRaw(T value) { - this.value = Preconditions.checkNotNull(value); - } - } - /** * Reference to a serialized value that was offloaded to the {@link * org.apache.flink.runtime.blob.BlobServer}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java index d6a2b16010bdc..8b0498159a1e8 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java @@ -47,14 +47,11 @@ import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor; import org.apache.flink.types.Either; import org.apache.flink.util.CompressedSerializedValue; -import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.SerializedValue; import javax.annotation.Nullable; import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -452,7 +449,7 @@ public int getIndex() { public static class ShuffleDescriptorGroup implements Serializable { private static final long serialVersionUID = 1L; - private ShuffleDescriptorAndIndex[] shuffleDescriptors; + private final ShuffleDescriptorAndIndex[] shuffleDescriptors; public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) { this.shuffleDescriptors = checkNotNull(shuffleDescriptors); @@ -461,31 +458,19 @@ public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) { public ShuffleDescriptorAndIndex[] getShuffleDescriptors() { return shuffleDescriptors; } - - private void writeObject(ObjectOutputStream oos) throws IOException { - byte[] bytes = InstantiationUtil.serializeObjectAndCompress(shuffleDescriptors); - oos.writeObject(bytes); - } - - private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { - byte[] bytes = (byte[]) ois.readObject(); - shuffleDescriptors = - InstantiationUtil.decompressAndDeserializeObject( - bytes, ClassLoader.getSystemClassLoader()); - } } - /** Offload shuffle descriptors. */ + /** Serialize shuffle descriptors. */ interface ShuffleDescriptorSerializer { /** - * Try to serialize and offload shuffle descriptors. + * Serialize and try offload shuffle descriptors. * - * @param shuffleDescriptorGroup to serialize and offload + * @param shuffleDescriptorGroup to serialize * @param numConsumer consumers number of these shuffle descriptors, it means how many times * serialized shuffle descriptor should be sent - * @return offloaded serialized or non-offloaded raw shuffle descriptors + * @return offloaded or non-offloaded serialized shuffle descriptors */ - MaybeOffloaded trySerializeAndOffloadShuffleDescriptor( + MaybeOffloaded serializeAndTryOffloadShuffleDescriptor( ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException; } @@ -502,24 +487,25 @@ public DefaultShuffleDescriptorSerializer( } @Override - public MaybeOffloaded trySerializeAndOffloadShuffleDescriptor( + public MaybeOffloaded serializeAndTryOffloadShuffleDescriptor( ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException { - final Either rawValueOrBlobKey = - shouldOffload(shuffleDescriptorGroup.getShuffleDescriptors(), numConsumer) - ? BlobWriter.offloadWithException( - CompressedSerializedValue.fromObject( - shuffleDescriptorGroup), - jobID, - blobWriter) - .map(Either::Right) - .orElse(Either.Left(shuffleDescriptorGroup)) - : Either.Left(shuffleDescriptorGroup); - - if (rawValueOrBlobKey.isLeft()) { - return new TaskDeploymentDescriptor.NonOffloadedRaw<>(rawValueOrBlobKey.left()); + final CompressedSerializedValue compressedSerializedValue = + CompressedSerializedValue.fromObject(shuffleDescriptorGroup); + + final Either, PermanentBlobKey> + serializedValueOrBlobKey = + shouldOffload( + shuffleDescriptorGroup.getShuffleDescriptors(), + numConsumer) + ? BlobWriter.offloadWithException( + compressedSerializedValue, jobID, blobWriter) + : Either.Left(compressedSerializedValue); + + if (serializedValueOrBlobKey.isLeft()) { + return new TaskDeploymentDescriptor.NonOffloaded<>(serializedValueOrBlobKey.left()); } else { - return new TaskDeploymentDescriptor.Offloaded<>(rawValueOrBlobKey.right()); + return new TaskDeploymentDescriptor.Offloaded<>(serializedValueOrBlobKey.right()); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java index 0160d180e23b4..f9cd00e103bef 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java @@ -20,7 +20,7 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup; import org.apache.flink.runtime.executiongraph.ExecutionGraph; @@ -39,10 +39,12 @@ import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; +import org.apache.flink.util.CompressedSerializedValue; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -88,7 +90,7 @@ void testCreateAndGet() throws Exception { assertThat(cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups()).hasSize(1); MaybeOffloaded maybeOffloadedShuffleDescriptor = cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups().get(0); - assertNonOffloadedRawShuffleDescriptorAndIndexEquals( + assertNonOffloadedShuffleDescriptorAndIndexEquals( maybeOffloadedShuffleDescriptor, Collections.singletonList(shuffleDescriptor), Collections.singletonList(0)); @@ -142,22 +144,26 @@ void testMarkPartitionFinishAndSerialize() throws Exception { intermediateResultPartition2, TaskDeploymentDescriptorFactory.PartitionLocationConstraint.MUST_BE_KNOWN, false); - assertNonOffloadedRawShuffleDescriptorAndIndexEquals( + assertNonOffloadedShuffleDescriptorAndIndexEquals( maybeOffloaded, Arrays.asList(expectedShuffleDescriptor1, expectedShuffleDescriptor2), Arrays.asList(0, 1)); } - private void assertNonOffloadedRawShuffleDescriptorAndIndexEquals( + private void assertNonOffloadedShuffleDescriptorAndIndexEquals( MaybeOffloaded maybeOffloaded, List expectedDescriptors, - List expectedIndices) { + List expectedIndices) + throws Exception { assertThat(expectedDescriptors).hasSameSizeAs(expectedIndices); - assertThat(maybeOffloaded).isInstanceOf(NonOffloadedRaw.class); - NonOffloadedRaw nonOffloadedRaw = - (NonOffloadedRaw) maybeOffloaded; + assertThat(maybeOffloaded).isInstanceOf(NonOffloaded.class); + NonOffloaded nonOffloaded = + (NonOffloaded) maybeOffloaded; ShuffleDescriptorAndIndex[] shuffleDescriptorAndIndices = - nonOffloadedRaw.value.getShuffleDescriptors(); + nonOffloaded + .serializedValue + .deserializeValue(getClass().getClassLoader()) + .getShuffleDescriptors(); assertThat(shuffleDescriptorAndIndices).hasSameSizeAs(expectedDescriptors); for (int i = 0; i < shuffleDescriptorAndIndices.length; i++) { assertThat(shuffleDescriptorAndIndices[i].getIndex()).isEqualTo(expectedIndices.get(i)); @@ -212,9 +218,9 @@ private static class TestingShuffleDescriptorSerializer implements TaskDeploymentDescriptorFactory.ShuffleDescriptorSerializer { @Override - public MaybeOffloaded trySerializeAndOffloadShuffleDescriptor( - ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) { - return new NonOffloadedRaw<>(shuffleDescriptorGroup); + public MaybeOffloaded serializeAndTryOffloadShuffleDescriptor( + ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException { + return new NonOffloaded<>(CompressedSerializedValue.fromObject(shuffleDescriptorGroup)); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java index 04683d4489c0c..19fbefe29208a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.blob.TestingBlobWriter; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup; @@ -47,8 +47,11 @@ public static ShuffleDescriptor[] deserializeShuffleDescriptors( int maxIndex = 0; for (MaybeOffloaded sd : maybeOffloaded) { ShuffleDescriptorGroup shuffleDescriptorGroup; - if (sd instanceof NonOffloadedRaw) { - shuffleDescriptorGroup = ((NonOffloadedRaw) sd).value; + if (sd instanceof NonOffloaded) { + shuffleDescriptorGroup = + ((NonOffloaded) sd) + .serializedValue.deserializeValue( + ClassLoader.getSystemClassLoader()); } else { final CompressedSerializedValue compressedSerializedValue = diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index 071a23ad34bc3..0ae7420950708 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -71,6 +71,7 @@ import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor; import org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder; +import org.apache.flink.util.CompressedSerializedValue; import org.apache.flink.shaded.guava31.com.google.common.io.Closer; @@ -1312,8 +1313,9 @@ partitionIds[2], createExecutionAttemptId())), subpartitionIndexRange, channelDescs.length, Collections.singletonList( - new TaskDeploymentDescriptor.NonOffloadedRaw<>( - new ShuffleDescriptorGroup(channelDescs)))); + new TaskDeploymentDescriptor.NonOffloaded<>( + CompressedSerializedValue.fromObject( + new ShuffleDescriptorGroup(channelDescs))))); final TaskMetricGroup taskMetricGroup = UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();