Skip to content

Commit

Permalink
[FLINK-33532][network] Move the serialization of ShuffleDescriptorGroup out of the RPC main thread
Browse files Browse the repository at this point in the history
  • Loading branch information
caodizhou authored and KarmaGYZ committed Nov 16, 2023
1 parent 8ef71ba commit d18a4bf
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;

/** BlobWriter is used to upload data to the BLOB store. */
public interface BlobWriter {
Expand Down Expand Up @@ -102,22 +103,24 @@ static <T> Either<SerializedValue<T>, PermanentBlobKey> tryOffload(
if (serializedValue.getByteArray().length < blobWriter.getMinOffloadingSize()) {
return Either.Left(serializedValue);
} else {
return offloadWithException(serializedValue, jobId, blobWriter);
return offloadWithException(serializedValue, jobId, blobWriter)
.map(Either::<SerializedValue<T>, PermanentBlobKey>Right)
.orElse(Either.Left(serializedValue));
}
}

static <T> Either<SerializedValue<T>, PermanentBlobKey> offloadWithException(
static <T> Optional<PermanentBlobKey> offloadWithException(
SerializedValue<T> serializedValue, JobID jobId, BlobWriter blobWriter) {
Preconditions.checkNotNull(serializedValue);
Preconditions.checkNotNull(jobId);
Preconditions.checkNotNull(blobWriter);
try {
final PermanentBlobKey permanentBlobKey =
blobWriter.putPermanent(jobId, serializedValue.getByteArray());
return Either.Right(permanentBlobKey);
return Optional.of(permanentBlobKey);
} catch (IOException e) {
LOG.warn("Failed to offload value for job {} to BLOB store.", jobId, e);
return Either.Left(serializedValue);
return Optional.empty();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void serializeShuffleDescriptors(
new ShuffleDescriptorGroup(
toBeSerialized.toArray(new ShuffleDescriptorAndIndex[0]));
MaybeOffloaded<ShuffleDescriptorGroup> serializedShuffleDescriptorGroup =
shuffleDescriptorSerializer.serializeAndTryOffloadShuffleDescriptor(
shuffleDescriptorSerializer.trySerializeAndOffloadShuffleDescriptor(
shuffleDescriptorGroup, numConsumers);

toBeSerialized.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import org.apache.flink.runtime.blob.PermanentBlobKey;
import org.apache.flink.runtime.blob.PermanentBlobService;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup;
Expand Down Expand Up @@ -98,9 +98,7 @@ public InputGateDeploymentDescriptor(
new IndexRange(consumedSubpartitionIndex, consumedSubpartitionIndex),
inputChannels.length,
Collections.singletonList(
new NonOffloaded<>(
CompressedSerializedValue.fromObject(
new ShuffleDescriptorGroup(inputChannels)))));
new NonOffloadedRaw<>(new ShuffleDescriptorGroup(inputChannels))));
}

public InputGateDeploymentDescriptor(
Expand Down Expand Up @@ -147,18 +145,14 @@ public ShuffleDescriptor[] getShuffleDescriptors() {
// This path is only used in testing scenarios; in a production environment we always call
// tryLoadAndDeserializeShuffleDescriptors to deserialize the ShuffleDescriptors first.
inputChannels = new ShuffleDescriptor[numberOfInputChannels];
try {
for (MaybeOffloaded<ShuffleDescriptorGroup> serializedShuffleDescriptors :
serializedInputChannels) {
checkState(
serializedShuffleDescriptors instanceof NonOffloaded,
"Trying to work with offloaded serialized shuffle descriptors.");
NonOffloaded<ShuffleDescriptorGroup> nonOffloadedSerializedValue =
(NonOffloaded<ShuffleDescriptorGroup>) serializedShuffleDescriptors;
tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue);
}
} catch (ClassNotFoundException | IOException e) {
throw new RuntimeException("Could not deserialize shuffle descriptors.", e);
for (MaybeOffloaded<ShuffleDescriptorGroup> rawShuffleDescriptors :
serializedInputChannels) {
checkState(
rawShuffleDescriptors instanceof NonOffloadedRaw,
"Trying to work with offloaded serialized shuffle descriptors.");
NonOffloadedRaw<ShuffleDescriptorGroup> nonOffloadedRawValue =
(NonOffloadedRaw<ShuffleDescriptorGroup>) rawShuffleDescriptors;
putOrReplaceShuffleDescriptors(nonOffloadedRawValue.value);
}
}
return inputChannels;
Expand Down Expand Up @@ -213,21 +207,12 @@ private void tryLoadAndDeserializeShuffleDescriptorGroup(
}
putOrReplaceShuffleDescriptors(shuffleDescriptorGroup);
} else {
NonOffloaded<ShuffleDescriptorGroup> nonOffloadedSerializedValue =
(NonOffloaded<ShuffleDescriptorGroup>) serializedShuffleDescriptors;
tryDeserializeShuffleDescriptorGroup(nonOffloadedSerializedValue);
NonOffloadedRaw<ShuffleDescriptorGroup> nonOffloadedSerializedValue =
(NonOffloadedRaw<ShuffleDescriptorGroup>) serializedShuffleDescriptors;
putOrReplaceShuffleDescriptors(nonOffloadedSerializedValue.value);
}
}

private void tryDeserializeShuffleDescriptorGroup(
NonOffloaded<ShuffleDescriptorGroup> nonOffloadedShuffleDescriptorGroup)
throws IOException, ClassNotFoundException {
ShuffleDescriptorGroup shuffleDescriptorGroup =
nonOffloadedShuffleDescriptorGroup.serializedValue.deserializeValue(
getClass().getClassLoader());
putOrReplaceShuffleDescriptors(shuffleDescriptorGroup);
}

private void putOrReplaceShuffleDescriptors(ShuffleDescriptorGroup shuffleDescriptorGroup) {
for (ShuffleDescriptorAndIndex shuffleDescriptorAndIndex :
shuffleDescriptorGroup.getShuffleDescriptors()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,25 @@ public NonOffloaded(SerializedValue<T> serializedValue) {
}
}

/**
* The raw value that is not offloaded to the {@link org.apache.flink.runtime.blob.BlobServer}.
*
* <p>This holder keeps a direct reference to the object itself in {@link #value}; it does not
* wrap an already serialized representation of it.
*
* @param <T> type of the raw value
*/
public static class NonOffloadedRaw<T> extends MaybeOffloaded<T> {
private static final long serialVersionUID = 1L;

/**
* The raw value. Not final, because the no-argument constructor leaves it unset; the
* one-argument constructor never stores {@code null} here.
*/
public T value;

// NOTE(review): presumably kept for frameworks that instantiate the class reflectively
// (hence the "unused" suppression) -- confirm before removing.
@SuppressWarnings("unused")
public NonOffloadedRaw() {}

/**
* Wraps the given raw value.
*
* @param value the raw value to hold; passed through {@code Preconditions.checkNotNull}, so
*     it must not be {@code null}
*/
public NonOffloadedRaw(T value) {
this.value = Preconditions.checkNotNull(value);
}
}

/**
* Reference to a serialized value that was offloaded to the {@link
* org.apache.flink.runtime.blob.BlobServer}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
import org.apache.flink.types.Either;
import org.apache.flink.util.CompressedSerializedValue;
import org.apache.flink.util.InstantiationUtil;
import org.apache.flink.util.SerializedValue;

import javax.annotation.Nullable;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -449,7 +452,7 @@ public int getIndex() {
public static class ShuffleDescriptorGroup implements Serializable {
private static final long serialVersionUID = 1L;

private final ShuffleDescriptorAndIndex[] shuffleDescriptors;
private ShuffleDescriptorAndIndex[] shuffleDescriptors;

public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) {
this.shuffleDescriptors = checkNotNull(shuffleDescriptors);
Expand All @@ -458,19 +461,31 @@ public ShuffleDescriptorGroup(ShuffleDescriptorAndIndex[] shuffleDescriptors) {
public ShuffleDescriptorAndIndex[] getShuffleDescriptors() {
return shuffleDescriptors;
}

/**
* Custom Java serialization hook for this group.
*
* <p>Serializes and compresses the {@code shuffleDescriptors} array via {@link
* InstantiationUtil#serializeObjectAndCompress(Object)} and writes the resulting byte array to
* the stream as a single object. {@code defaultWriteObject()} is not called, so that byte
* array is the only payload this hook writes.
*
* @param oos stream to write the compressed bytes to
* @throws IOException if serializing, compressing, or writing to the stream fails
*/
private void writeObject(ObjectOutputStream oos) throws IOException {
byte[] bytes = InstantiationUtil.serializeObjectAndCompress(shuffleDescriptors);
oos.writeObject(bytes);
}

/**
* Custom Java deserialization hook for this group.
*
* <p>Reads a single byte array from the stream, then decompresses and deserializes it back
* into the {@code shuffleDescriptors} array. The field is assigned here, which is why it
* cannot be declared {@code final}.
*
* @param ois stream to read the compressed bytes from
* @throws IOException if reading from the stream or decompressing fails
* @throws ClassNotFoundException if a class needed for deserialization cannot be resolved
*/
private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException {
// Expects the stream to contain a byte[] at this position (a single compressed payload).
byte[] bytes = (byte[]) ois.readObject();
// NOTE(review): classes are resolved through the system class loader only; confirm that
// every ShuffleDescriptor implementation is loadable from it in all deployments.
shuffleDescriptors =
InstantiationUtil.decompressAndDeserializeObject(
bytes, ClassLoader.getSystemClassLoader());
}
}

/** Serialize shuffle descriptors. */
/** Offload shuffle descriptors. */
interface ShuffleDescriptorSerializer {
/**
* Serialize and try offload shuffle descriptors.
* Try to serialize and offload shuffle descriptors.
*
* @param shuffleDescriptorGroup to serialize
* @param shuffleDescriptorGroup to serialize and offload
* @param numConsumer number of consumers of these shuffle descriptors, i.e. how many times the
* serialized shuffle descriptors should be sent
* @return offloaded or non-offloaded serialized shuffle descriptors
* @return offloaded serialized or non-offloaded raw shuffle descriptors
*/
MaybeOffloaded<ShuffleDescriptorGroup> serializeAndTryOffloadShuffleDescriptor(
MaybeOffloaded<ShuffleDescriptorGroup> trySerializeAndOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException;
}

Expand All @@ -487,25 +502,24 @@ public DefaultShuffleDescriptorSerializer(
}

@Override
public MaybeOffloaded<ShuffleDescriptorGroup> serializeAndTryOffloadShuffleDescriptor(
public MaybeOffloaded<ShuffleDescriptorGroup> trySerializeAndOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException {

final CompressedSerializedValue<ShuffleDescriptorGroup> compressedSerializedValue =
CompressedSerializedValue.fromObject(shuffleDescriptorGroup);

final Either<SerializedValue<ShuffleDescriptorGroup>, PermanentBlobKey>
serializedValueOrBlobKey =
shouldOffload(
shuffleDescriptorGroup.getShuffleDescriptors(),
numConsumer)
? BlobWriter.offloadWithException(
compressedSerializedValue, jobID, blobWriter)
: Either.Left(compressedSerializedValue);

if (serializedValueOrBlobKey.isLeft()) {
return new TaskDeploymentDescriptor.NonOffloaded<>(serializedValueOrBlobKey.left());
final Either<ShuffleDescriptorGroup, PermanentBlobKey> rawValueOrBlobKey =
shouldOffload(shuffleDescriptorGroup.getShuffleDescriptors(), numConsumer)
? BlobWriter.offloadWithException(
CompressedSerializedValue.fromObject(
shuffleDescriptorGroup),
jobID,
blobWriter)
.map(Either::<ShuffleDescriptorGroup, PermanentBlobKey>Right)
.orElse(Either.Left(shuffleDescriptorGroup))
: Either.Left(shuffleDescriptorGroup);

if (rawValueOrBlobKey.isLeft()) {
return new TaskDeploymentDescriptor.NonOffloadedRaw<>(rawValueOrBlobKey.left());
} else {
return new TaskDeploymentDescriptor.Offloaded<>(serializedValueOrBlobKey.right());
return new TaskDeploymentDescriptor.Offloaded<>(rawValueOrBlobKey.right());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
Expand All @@ -39,12 +39,10 @@
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.apache.flink.util.CompressedSerializedValue;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -90,7 +88,7 @@ void testCreateAndGet() throws Exception {
assertThat(cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups()).hasSize(1);
MaybeOffloaded<ShuffleDescriptorGroup> maybeOffloadedShuffleDescriptor =
cachedShuffleDescriptors.getAllSerializedShuffleDescriptorGroups().get(0);
assertNonOffloadedShuffleDescriptorAndIndexEquals(
assertNonOffloadedRawShuffleDescriptorAndIndexEquals(
maybeOffloadedShuffleDescriptor,
Collections.singletonList(shuffleDescriptor),
Collections.singletonList(0));
Expand Down Expand Up @@ -144,26 +142,22 @@ void testMarkPartitionFinishAndSerialize() throws Exception {
intermediateResultPartition2,
TaskDeploymentDescriptorFactory.PartitionLocationConstraint.MUST_BE_KNOWN,
false);
assertNonOffloadedShuffleDescriptorAndIndexEquals(
assertNonOffloadedRawShuffleDescriptorAndIndexEquals(
maybeOffloaded,
Arrays.asList(expectedShuffleDescriptor1, expectedShuffleDescriptor2),
Arrays.asList(0, 1));
}

private void assertNonOffloadedShuffleDescriptorAndIndexEquals(
private void assertNonOffloadedRawShuffleDescriptorAndIndexEquals(
MaybeOffloaded<ShuffleDescriptorGroup> maybeOffloaded,
List<ShuffleDescriptor> expectedDescriptors,
List<Integer> expectedIndices)
throws Exception {
List<Integer> expectedIndices) {
assertThat(expectedDescriptors).hasSameSizeAs(expectedIndices);
assertThat(maybeOffloaded).isInstanceOf(NonOffloaded.class);
NonOffloaded<ShuffleDescriptorGroup> nonOffloaded =
(NonOffloaded<ShuffleDescriptorGroup>) maybeOffloaded;
assertThat(maybeOffloaded).isInstanceOf(NonOffloadedRaw.class);
NonOffloadedRaw<ShuffleDescriptorGroup> nonOffloadedRaw =
(NonOffloadedRaw<ShuffleDescriptorGroup>) maybeOffloaded;
ShuffleDescriptorAndIndex[] shuffleDescriptorAndIndices =
nonOffloaded
.serializedValue
.deserializeValue(getClass().getClassLoader())
.getShuffleDescriptors();
nonOffloadedRaw.value.getShuffleDescriptors();
assertThat(shuffleDescriptorAndIndices).hasSameSizeAs(expectedDescriptors);
for (int i = 0; i < shuffleDescriptorAndIndices.length; i++) {
assertThat(shuffleDescriptorAndIndices[i].getIndex()).isEqualTo(expectedIndices.get(i));
Expand Down Expand Up @@ -218,9 +212,9 @@ private static class TestingShuffleDescriptorSerializer
implements TaskDeploymentDescriptorFactory.ShuffleDescriptorSerializer {

@Override
public MaybeOffloaded<ShuffleDescriptorGroup> serializeAndTryOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) throws IOException {
return new NonOffloaded<>(CompressedSerializedValue.fromObject(shuffleDescriptorGroup));
public MaybeOffloaded<ShuffleDescriptorGroup> trySerializeAndOffloadShuffleDescriptor(
ShuffleDescriptorGroup shuffleDescriptorGroup, int numConsumer) {
return new NonOffloadedRaw<>(shuffleDescriptorGroup);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.blob.TestingBlobWriter;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.NonOffloadedRaw;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.Offloaded;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorGroup;
Expand All @@ -47,11 +47,8 @@ public static ShuffleDescriptor[] deserializeShuffleDescriptors(
int maxIndex = 0;
for (MaybeOffloaded<ShuffleDescriptorGroup> sd : maybeOffloaded) {
ShuffleDescriptorGroup shuffleDescriptorGroup;
if (sd instanceof NonOffloaded) {
shuffleDescriptorGroup =
((NonOffloaded<ShuffleDescriptorGroup>) sd)
.serializedValue.deserializeValue(
ClassLoader.getSystemClassLoader());
if (sd instanceof NonOffloadedRaw) {
shuffleDescriptorGroup = ((NonOffloadedRaw<ShuffleDescriptorGroup>) sd).value;

} else {
final CompressedSerializedValue<ShuffleDescriptorGroup> compressedSerializedValue =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor;
import org.apache.flink.runtime.shuffle.UnknownShuffleDescriptor;
import org.apache.flink.runtime.util.NettyShuffleDescriptorBuilder;
import org.apache.flink.util.CompressedSerializedValue;

import org.apache.flink.shaded.guava31.com.google.common.io.Closer;

Expand Down Expand Up @@ -1337,9 +1336,8 @@ partitionIds[2], createExecutionAttemptId())),
subpartitionIndexRange,
channelDescs.length,
Collections.singletonList(
new TaskDeploymentDescriptor.NonOffloaded<>(
CompressedSerializedValue.fromObject(
new ShuffleDescriptorGroup(channelDescs)))));
new TaskDeploymentDescriptor.NonOffloadedRaw<>(
new ShuffleDescriptorGroup(channelDescs))));

final TaskMetricGroup taskMetricGroup =
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup();
Expand Down

0 comments on commit d18a4bf

Please sign in to comment.