Skip to content

Commit

Permalink
[FLINK-11250][streaming] Initialize recordWriter before recovery
Browse files Browse the repository at this point in the history
  • Loading branch information
akalash committed Sep 8, 2021
1 parent d2ec320 commit d726c03
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,8 @@ void restoreInternal() throws Exception {
closedOperators = false;
LOG.debug("Initializing {}.", getName());

recordWriter.init();

operatorChain =
getEnvironment().getTaskStateManager().isFinishedOnRestore()
? new FinishedOperatorChain<>(this, recordWriter)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel;
Expand Down Expand Up @@ -123,7 +122,6 @@
import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction;
import org.apache.flink.streaming.util.MockStreamConfig;
import org.apache.flink.streaming.util.MockStreamTaskBuilder;
import org.apache.flink.streaming.util.TestSequentialReadingStreamOperator;
import org.apache.flink.util.CloseableIterable;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FatalExitExceptionHandler;
Expand All @@ -134,9 +132,11 @@
import org.apache.flink.util.concurrent.FutureUtils;
import org.apache.flink.util.concurrent.TestingUncaughtExceptionHandler;
import org.apache.flink.util.function.BiConsumerWithException;
import org.apache.flink.util.function.FunctionWithException;
import org.apache.flink.util.function.RunnableWithException;
import org.apache.flink.util.function.SupplierWithException;

import org.hamcrest.CoreMatchers;
import org.hamcrest.Matchers;
import org.junit.Assert;
import org.junit.Rule;
Expand All @@ -151,12 +151,12 @@
import java.io.Closeable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.StreamCorruptedException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -183,10 +183,13 @@
import static org.apache.flink.configuration.TaskManagerOptions.BUFFER_DEBLOAT_TARGET;
import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.UNKNOWN_TASK_CHECKPOINT_NOTIFICATION_FAILURE;
import static org.apache.flink.runtime.checkpoint.StateObjectCollection.singleton;
import static org.apache.flink.runtime.io.network.api.writer.RecordWriter.DEFAULT_OUTPUT_FLUSH_THREAD_NAME;
import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
import static org.apache.flink.streaming.runtime.tasks.mailbox.TaskMailbox.MAX_PRIORITY;
import static org.apache.flink.streaming.util.StreamTaskUtil.waitTaskIsRunning;
import static org.apache.flink.util.Preconditions.checkState;
import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -1318,58 +1321,106 @@ public void testThreadInvariants() throws Throwable {
}
}

/**
* This test ensures that {@link RecordWriter} is correctly closed even if we fail to construct
* {@link OperatorChain}, for example because of user class deserialization error.
*/
@Test
public void testRecordWriterClosedOnStreamOperatorFactoryDeserializationError()
public void testRecordWriterClosedOnStreamTaskConstructorError() throws Exception {
    // The task manager configuration carries a single option whose value has the wrong
    // type; the intent is that this makes the stream task constructor throw.
    final Map<String, String> invalidOptions = new HashMap<>();
    invalidOptions.put(BUFFER_DEBLOAT_ENABLED.key(), "11");

    testRecordWriterClosedOnError(
            shuffleEnvironment ->
                    taskBuilderWithConfiguredRecordWriter(shuffleEnvironment)
                            .setTaskManagerConfig(Configuration.fromMap(invalidOptions))
                            .build());
}

/**
 * Runs the record-writer cleanup check with task manager actions that throw as soon as
 * {@link ExecutionState#DEPLOYING} is reported.
 */
@Test
public void testRecordWriterClosedOnTransitDeployingStateError() throws Exception {
    final ExecutionState failingState = ExecutionState.DEPLOYING;
    testRecordWriterClosedOnTransitStateError(failingState);
}

/**
 * Runs the record-writer cleanup check with task manager actions that throw as soon as
 * {@link ExecutionState#INITIALIZING} is reported.
 */
@Test
public void testRecordWriterClosedOnTransitInitializingStateError() throws Exception {
    final ExecutionState failingState = ExecutionState.INITIALIZING;
    testRecordWriterClosedOnTransitStateError(failingState);
}

/**
 * Runs the record-writer cleanup check with task manager actions that throw as soon as
 * {@link ExecutionState#RUNNING} is reported.
 */
@Test
public void testRecordWriterClosedOnTransitRunningStateError() throws Exception {
    final ExecutionState failingState = ExecutionState.RUNNING;
    testRecordWriterClosedOnTransitStateError(failingState);
}

/**
 * Runs the record-writer cleanup check against a task whose task manager actions throw an
 * {@link ExpectedTestException} when the given execution state is reported.
 *
 * @param executionState the reported state on which the state update must fail
 * @throws Exception if the shared cleanup check fails or the task cannot be built
 */
private void testRecordWriterClosedOnTransitStateError(ExecutionState executionState)
        throws Exception {
    // Fail the state update exactly when the targeted state is reported; every other
    // state passes through untouched.
    final NoOpTaskManagerActions failingActions =
            new NoOpTaskManagerActions() {
                @Override
                public void updateTaskExecutionState(TaskExecutionState reportedState) {
                    if (executionState == reportedState.getExecutionState()) {
                        throw new ExpectedTestException();
                    }
                }
            };

    testRecordWriterClosedOnError(
            shuffleEnvironment -> {
                return taskBuilderWithConfiguredRecordWriter(shuffleEnvironment)
                        .setTaskManagerActions(failingActions)
                        .build();
            });
}

/**
 * Starts the task produced by {@code taskProvider}, waits for its thread to finish and then
 * checks that the task ended up {@link ExecutionState#FAILED} and that no live thread has a
 * name containing {@code DEFAULT_OUTPUT_FLUSH_THREAD_NAME}.
 *
 * @param taskProvider builds the task under test from a freshly created shuffle environment
 * @throws Exception if the task cannot be built or joining its thread is interrupted
 */
private void testRecordWriterClosedOnError(
        FunctionWithException<NettyShuffleEnvironment, Task, Exception> taskProvider)
        throws Exception {
    try (NettyShuffleEnvironment environment = new NettyShuffleEnvironmentBuilder().build()) {
        final Task taskUnderTest = taskProvider.apply(environment);

        // Run the task to completion before inspecting its final state.
        taskUnderTest.startTaskThread();
        final Thread executingThread = taskUnderTest.getExecutingThread();
        executingThread.join();

        assertEquals(ExecutionState.FAILED, taskUnderTest.getExecutionState());

        // No output flusher thread may be left behind once the task has failed.
        Thread.getAllStackTraces()
                .keySet()
                .forEach(
                        liveThread ->
                                assertThat(
                                        liveThread.getName(),
                                        CoreMatchers.is(
                                                not(
                                                        containsString(
                                                                DEFAULT_OUTPUT_FLUSH_THREAD_NAME)))));
    }
}

/**
 * Creates a {@link TestTaskBuilder} for a {@link NoOpStreamTask} whose task configuration has
 * been prepared by {@code outputEdgeConfiguration} and which owns a single result partition.
 *
 * @param shuffleEnvironment the shuffle environment handed to the task builder
 * @return the pre-configured builder; callers may apply further settings before building
 */
private TestTaskBuilder taskBuilderWithConfiguredRecordWriter(
        NettyShuffleEnvironment shuffleEnvironment) {
    final Configuration taskConfig = new Configuration();
    outputEdgeConfiguration(taskConfig);

    // One locally built result partition for the task.
    final ResultPartitionDeploymentDescriptor partitionDescriptor =
            new ResultPartitionDeploymentDescriptor(
                    PartitionDescriptorBuilder.newBuilder().build(),
                    NettyShuffleDescriptorBuilder.newBuilder().buildLocal(),
                    1,
                    false);

    return new TestTaskBuilder(shuffleEnvironment)
            .setInvokable(NoOpStreamTask.class)
            .setTaskConfig(taskConfig)
            .setResultPartitions(Collections.singletonList(partitionDescriptor));
}

/**
 * Make sure that there is some output edge in the config so that some RecordWriter is created.
 *
 * <p>Wraps the given configuration in a {@code StreamConfig}, sets an
 * {@code UnusedOperatorFactory} as its operator factory and chains one more operator onto it
 * through a {@code StreamConfigChainer}.
 *
 * @param taskConfiguration the task configuration that is wrapped and written to
 */
private void outputEdgeConfiguration(Configuration taskConfiguration) {
StreamConfig streamConfig = new StreamConfig(taskConfiguration);
streamConfig.setStreamOperatorFactory(new UnusedOperatorFactory());

// Make sure that there is some output edge in the config so that some RecordWriter is
// created
StreamConfigChainer cfg =
new StreamConfigChainer(new OperatorID(42, 42), streamConfig, this, 1);
cfg.setBufferTimeout(1);
cfg.chain(
new OperatorID(44, 44),
new UnusedOperatorFactory(),
StringSerializer.INSTANCE,
StringSerializer.INSTANCE,
false);
cfg.finish();

// NOTE(review): everything from here down to the closing brace looks like the body of the
// removed test rather than part of this helper (the diff markers appear to have been lost
// in this view) -- TODO confirm against the actual commit before relying on it.

// Overwrite the serialized bytes to some garbage to induce deserialization exception
taskConfiguration.setBytes(StreamConfig.SERIALIZEDUDF, new byte[42]);

try (MockEnvironment mockEnvironment =
new MockEnvironmentBuilder().setTaskConfiguration(taskConfiguration).build()) {

mockEnvironment.addOutput(new ArrayList<>());
StreamTask<String, TestSequentialReadingStreamOperator> streamTask =
new NoOpStreamTask<>(mockEnvironment);

try {
streamTask.invoke();
fail("Should have failed with an exception!");
} catch (Exception ex) {
// Only a StreamCorruptedException somewhere in the cause chain is accepted here;
// any other failure is rethrown.
if (!ExceptionUtils.findThrowable(ex, StreamCorruptedException.class).isPresent()) {
throw ex;
}
}
}

// Fails if any live thread's name starts with the output flusher thread name.
assertTrue(
RecordWriter.DEFAULT_OUTPUT_FLUSH_THREAD_NAME + " thread is still running",
Thread.getAllStackTraces().keySet().stream()
.noneMatch(
thread ->
thread.getName()
.startsWith(
RecordWriter
.DEFAULT_OUTPUT_FLUSH_THREAD_NAME)));
}

@Test
Expand Down

0 comments on commit d726c03

Please sign in to comment.