diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java index 2d5292b42cb07..555cccfb7ca1a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/blob/BlobWriter.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.io.InputStream; +import java.util.Optional; /** BlobWriter is used to upload data to the BLOB store. */ public interface BlobWriter { @@ -102,11 +103,13 @@ static Either, PermanentBlobKey> tryOffload( if (serializedValue.getByteArray().length < blobWriter.getMinOffloadingSize()) { return Either.Left(serializedValue); } else { - return offloadWithException(serializedValue, jobId, blobWriter); + return offloadWithException(serializedValue, jobId, blobWriter) + .map(Either::, PermanentBlobKey>Right) + .orElse(Either.Left(serializedValue)); } } - static Either, PermanentBlobKey> offloadWithException( + static Optional offloadWithException( SerializedValue serializedValue, JobID jobId, BlobWriter blobWriter) { Preconditions.checkNotNull(serializedValue); Preconditions.checkNotNull(jobId); @@ -114,10 +117,10 @@ static Either, PermanentBlobKey> offloadWithException( try { final PermanentBlobKey permanentBlobKey = blobWriter.putPermanent(jobId, serializedValue.getByteArray()); - return Either.Right(permanentBlobKey); + return Optional.of(permanentBlobKey); } catch (IOException e) { LOG.warn("Failed to offload value for job {} to BLOB store.", jobId, e); - return Either.Left(serializedValue); + return Optional.empty(); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java index b8e0b44006fd3..4ddacbd671a43 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptors.java @@ -87,7 +87,7 @@ public void serializeShuffleDescriptors( new ShuffleDescriptorGroup( toBeSerialized.toArray(new ShuffleDescriptorAndIndex[0])); MaybeOffloaded serializedShuffleDescriptorGroup = - shuffleDescriptorSerializer.serializeAndTryOffloadShuffleDescriptor( + shuffleDescriptorSerializer.trySerializeAndOffloadShuffleDescriptor( shuffleDescriptorGroup, numConsumers); toBeSerialized.clear(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java index 333a91e0a7320..4e02c6993313c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/InputGateDeploymentDescriptor.java @@ -23,7 +23,7 @@ import org.apache.flink.runtime.blob.PermanentBlobKey; import org.apache.flink.runtime.blob.PermanentBlobService; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup; @@ -98,9 +98,7 @@ public InputGateDeploymentDescriptor( new IndexRange(consumedSubpartitionIndex, consumedSubpartitionIndex), inputChannels.length, Collections.singletonList( - new NonOffloaded<>( - CompressedSerializedValue.fromObject( - new ShuffleDescriptorGroup(inputChannels))))); + new NonOffloadedRaw<>(new ShuffleDescriptorGroup(inputChannels)))); } public InputGateDeploymentDescriptor( @@ -147,18 +145,14 @@ public ShuffleDescriptor[] getShuffleDescriptors() { // This is only for testing scenarios, in a production environment we always call // tryLoadAndDeserializeShuffleDescriptors to deserialize ShuffleDescriptors first. inputChannels = new ShuffleDescriptor[numberOfInputChannels]; - try { - for (MaybeOffloaded serializedShuffleDescriptors : - serializedInputChannels) { - checkState( - serializedShuffleDescriptors instanceof NonOffloaded, - "Trying to work with offloaded serialized shuffle descriptors."); - NonOffloaded nonOffloadedSerializedValue = - (NonOffloaded) serializedShuffleDescriptors; - tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue); - } - } catch (ClassNotFoundException | IOException e) { - throw new RuntimeException("Could not deserialize shuffle descriptors.", e); + for (MaybeOffloaded rawShuffleDescriptors : + serializedInputChannels) { + checkState( + rawShuffleDescriptors instanceof NonOffloadedRaw, + "Trying to work with offloaded serialized shuffle descriptors."); + NonOffloadedRaw nonOffloadedRawValue = + (NonOffloadedRaw) rawShuffleDescriptors; + putOrReplaceShuffleDescriptors(nonOffloadedRawValue.value); } } return inputChannels; @@ -213,21 +207,12 @@ private void tryLoadAndDeserializeShuffleDescriptorGroup( } putOrReplaceShuffleDescriptors(shuffleDescriptorGroup); } else { - NonOffloaded nonOffloadedSerializedValue = - (NonOffloaded) serializedShuffleDescriptors; - tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue); + NonOffloadedRaw nonOffloadedSerializedValue = + (NonOffloadedRaw) serializedShuffleDescriptors; + putOrReplaceShuffleDescriptors(nonOffloadedSerializedValue.value); } } - private void tryDeserializeShuffleDescriptorGroup( - NonOffloaded nonOffloadedShuffleDescriptorGroup) - throws IOException, ClassNotFoundException { - ShuffleDescriptorGroup shuffleDescriptorGroup = - nonOffloadedShuffleDescriptorGroup.serializedValue.deserializeValue( - getClass().getClassLoader()); - putOrReplaceShuffleDescriptors(shuffleDescriptorGroup); - } - private void putOrReplaceShuffleDescriptors(ShuffleDescriptorGroup shuffleDescriptorGroup) { for (ShuffleDescriptorAndIndex shuffleDescriptorAndIndex : shuffleDescriptorGroup.getShuffleDescriptors()) { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java index 5684066735f03..016105d9aaca2 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptor.java @@ -80,6 +80,25 @@ public NonOffloaded(SerializedValue serializedValue) { } } + /** + * The raw value that is not offloaded to the {@link org.apache.flink.runtime.blob.BlobServer}. + * + * @param type of the raw value + */ + public static class NonOffloadedRaw extends MaybeOffloaded { + private static final long serialVersionUID = 1L; + + /** The raw value. */ + public T value; + + @SuppressWarnings("unused") + public NonOffloadedRaw() {} + + public NonOffloadedRaw(T value) { + this.value = Preconditions.checkNotNull(value); + } + } + /** * Reference to a serialized value that was offloaded to the {@link * org.apache.flink.runtime.blob.BlobServer}. diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java index 8b0498159a1e8..d6a2b16010bdc 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java @@ -47,11 +47,14 @@ import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor; import org.apache.flink.types.Either; import org.apache.flink.util.CompressedSerializedValue; +import org.apache.flink.util.InstantiationUtil; import org.apache.flink.util.SerializedValue; import javax.annotation.Nullable; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -449,7 +452,7 @@ public int getIndex() { public static class ShuffleDescriptorGroup implements Serializable { private static final long serialVersionUID = 1L; - private final ShuffleDescriptorAndIndex[] shuffleDescriptors; + private ShuffleDescriptorAndIndex[] shuffleDescriptors; public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) { this.shuffleDescriptors = checkNotNull(shuffleDescriptors); @@ -458,19 +461,31 @@ public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) { public ShuffleDescriptorAndIndex[] getShuffleDescriptors() { return shuffleDescriptors; } + + private void writeObject(ObjectOutputStream oos) throws IOException { + byte[] bytes = InstantiationUtil.serializeObjectAndCompress(shuffleDescriptors); + oos.writeObject(bytes); + } + + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + byte[] bytes = (byte[]) ois.readObject(); + shuffleDescriptors = + InstantiationUtil.decompressAndDeserializeObject( + bytes, ClassLoader.getSystemClassLoader()); + } } - /** Serialize shuffle descriptors. */ + /** Offload shuffle descriptors. */ interface ShuffleDescriptorSerializer { /** - * Serialize and try offload shuffle descriptors. + * Try to serialize and offload shuffle descriptors. * - * @param shuffleDescriptorGroup to serialize + * @param shuffleDescriptorGroup to serialize and offload * @param numConsumer consumers number of these shuffle descriptors, it means how many times * serialized shuffle descriptor should be sent - * @return offloaded or non-offloaded serialized shuffle descriptors + * @return offloaded serialized or non-offloaded raw shuffle descriptors */ - MaybeOffloaded serializeAndTryOffloadShuffleDescriptor( + MaybeOffloaded trySerializeAndOffloadShuffleDescriptor( ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException; } @@ -487,25 +502,24 @@ public DefaultShuffleDescriptorSerializer( } @Override - public MaybeOffloaded serializeAndTryOffloadShuffleDescriptor( + public MaybeOffloaded trySerializeAndOffloadShuffleDescriptor( ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException { - final CompressedSerializedValue compressedSerializedValue = - CompressedSerializedValue.fromObject(shuffleDescriptorGroup); - - final Either, PermanentBlobKey> - serializedValueOrBlobKey = - shouldOffload( - shuffleDescriptorGroup.getShuffleDescriptors(), - numConsumer) - ? BlobWriter.offloadWithException( - compressedSerializedValue, jobID, blobWriter) - : Either.Left(compressedSerializedValue); - - if (serializedValueOrBlobKey.isLeft()) { - return new TaskDeploymentDescriptor.NonOffloaded<>(serializedValueOrBlobKey.left()); + final Either rawValueOrBlobKey = + shouldOffload(shuffleDescriptorGroup.getShuffleDescriptors(), numConsumer) + ? BlobWriter.offloadWithException( + CompressedSerializedValue.fromObject( + shuffleDescriptorGroup), + jobID, + blobWriter) + .map(Either::Right) + .orElse(Either.Left(shuffleDescriptorGroup)) + : Either.Left(shuffleDescriptorGroup); + + if (rawValueOrBlobKey.isLeft()) { + return new TaskDeploymentDescriptor.NonOffloadedRaw<>(rawValueOrBlobKey.left()); } else { - return new TaskDeploymentDescriptor.Offloaded<>(serializedValueOrBlobKey.right()); + return new TaskDeploymentDescriptor.Offloaded<>(rawValueOrBlobKey.right()); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java index f9cd00e103bef..0160d180e23b4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/CachedShuffleDescriptorsTest.java @@ -20,7 +20,7 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup; import org.apache.flink.runtime.executiongraph.ExecutionGraph; @@ -39,12 +39,10 @@ import org.apache.flink.runtime.testtasks.NoOpInvokable; import org.apache.flink.testutils.TestingUtils; import org.apache.flink.testutils.executor.TestExecutorExtension; -import org.apache.flink.util.CompressedSerializedValue; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; -import java.io.IOException; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -90,7 +88,7 @@ void testCreateAndGet() throws Exception { assertThat(cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups()).hasSize(1); MaybeOffloaded maybeOffloadedShuffleDescriptor = cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups().get(0); - assertNonOffloadedShuffleDescriptorAndIndexEquals( + assertNonOffloadedRawShuffleDescriptorAndIndexEquals( maybeOffloadedShuffleDescriptor, Collections.singletonList(shuffleDescriptor), Collections.singletonList(0)); @@ -144,26 +142,22 @@ void testMarkPartitionFinishAndSerialize() throws Exception { intermediateResultPartition2, TaskDeploymentDescriptorFactory.PartitionLocationConstraint.MUST_BE_KNOWN, false); - assertNonOffloadedShuffleDescriptorAndIndexEquals( + assertNonOffloadedRawShuffleDescriptorAndIndexEquals( maybeOffloaded, Arrays.asList(expectedShuffleDescriptor1, expectedShuffleDescriptor2), Arrays.asList(0, 1)); } - private void assertNonOffloadedShuffleDescriptorAndIndexEquals( + private void assertNonOffloadedRawShuffleDescriptorAndIndexEquals( MaybeOffloaded maybeOffloaded, List expectedDescriptors, - List expectedIndices) - throws Exception { + List expectedIndices) { assertThat(expectedDescriptors).hasSameSizeAs(expectedIndices); - assertThat(maybeOffloaded).isInstanceOf(NonOffloaded.class); - NonOffloaded nonOffloaded = - (NonOffloaded) maybeOffloaded; + assertThat(maybeOffloaded).isInstanceOf(NonOffloadedRaw.class); + NonOffloadedRaw nonOffloadedRaw = + (NonOffloadedRaw) maybeOffloaded; ShuffleDescriptorAndIndex[] shuffleDescriptorAndIndices = - nonOffloaded - .serializedValue - .deserializeValue(getClass().getClassLoader()) - .getShuffleDescriptors(); + nonOffloadedRaw.value.getShuffleDescriptors(); assertThat(shuffleDescriptorAndIndices).hasSameSizeAs(expectedDescriptors); for (int i = 0; i < shuffleDescriptorAndIndices.length; i++) { assertThat(shuffleDescriptorAndIndices[i].getIndex()).isEqualTo(expectedIndices.get(i)); @@ -218,9 +212,9 @@ private static class TestingShuffleDescriptorSerializer implements TaskDeploymentDescriptorFactory.ShuffleDescriptorSerializer { @Override - public MaybeOffloaded serializeAndTryOffloadShuffleDescriptor( - ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException { - return new NonOffloaded<>(CompressedSerializedValue.fromObject(shuffleDescriptorGroup)); + public MaybeOffloaded trySerializeAndOffloadShuffleDescriptor( + ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) { + return new NonOffloadedRaw<>(shuffleDescriptorGroup); } } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java index 19fbefe29208a..04683d4489c0c 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTestUtils.java @@ -21,7 +21,7 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.runtime.blob.TestingBlobWriter; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded; -import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded; +import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex; import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup; @@ -47,11 +47,8 @@ public static ShuffleDescriptor[] deserializeShuffleDescriptors( int maxIndex = 0; for (MaybeOffloaded sd : maybeOffloaded) { ShuffleDescriptorGroup shuffleDescriptorGroup; - if (sd instanceof NonOffloaded) { - shuffleDescriptorGroup = - ((NonOffloaded) sd) - .serializedValue.deserializeValue( - ClassLoader.getSystemClassLoader()); + if (sd instanceof NonOffloadedRaw) { + shuffleDescriptorGroup = ((NonOffloadedRaw) sd).value; } else { final CompressedSerializedValue compressedSerializedValue = diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index 0e452f0dde770..21a64e00cf451 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -70,7 +70,6 @@ import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor; import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor; import org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder; -import org.apache.flink.util.CompressedSerializedValue; import org.apache.flink.shaded.guava31.com.google.common.io.Closer; @@ -1337,9 +1336,8 @@ partitionIds[2], createExecutionAttemptId())), subpartitionIndexRange, channelDescs.length, Collections.singletonList( - new TaskDeploymentDescriptor.NonOffloaded<>( - CompressedSerializedValue.fromObject( - new ShuffleDescriptorGroup(channelDescs))))); + new TaskDeploymentDescriptor.NonOffloadedRaw<>( + new ShuffleDescriptorGroup(channelDescs)))); final TaskMetricGroup taskMetricGroup = UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();