Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FLINK-15467][task] Wait for sourceTaskThread to finish before exiting from StreamTask.invoke #13000

Merged
merged 1 commit into from Jul 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -150,7 +150,12 @@ protected void cancelTask() {
}
}
finally {
sourceThread.interrupt();
if (sourceThread.isAlive()) {
sourceThread.interrupt();
} else if (!sourceThread.getCompletionFuture().isDone()) {
// source thread didn't start
sourceThread.getCompletionFuture().complete(null);
}
}
}

Expand All @@ -160,6 +165,11 @@ protected void finishTask() throws Exception {
cancelTask();
}

/**
 * Exposes the source thread's completion future as this task's completion future.
 *
 * @return whatever {@code sourceThread.getCompletionFuture()} returns
 */
@Override
protected CompletableFuture<Void> getCompletionFuture() {
return sourceThread.getCompletionFuture();
}

// ------------------------------------------------------------------------
// Checkpointing
// ------------------------------------------------------------------------
Expand Down Expand Up @@ -210,8 +220,12 @@ public void setTaskDescription(final String taskDescription) {
setName("Legacy Source Thread - " + taskDescription);
}

/**
 * Returns the future that tracks completion of this thread.
 *
 * <p>If the task {@link #isFailing()} and this thread is not alive (e.g. it was never started), a
 * fresh, normally completed future is returned instead of {@code completionFuture}.
 *
 * @return {@code completionFuture}, or an already completed future when the task is failing and
 *     this thread is not alive
 */
CompletableFuture<Void> getCompletionFuture() {
	// Single exit point: the unconditional "return completionFuture;" that used to precede this
	// statement made it unreachable and contradicted the contract documented above.
	return isFailing() && !isAlive() ? CompletableFuture.completedFuture(null) : completionFuture;
}
}
}
Expand Up @@ -90,6 +90,7 @@
import java.util.Map;
import java.util.OptionalLong;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -199,6 +200,9 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
/** Flag to mark this task as canceled. */
private volatile boolean canceled;

/** Flag to mark this task as failing, i.e. an exception has occurred inside {@link #invoke()} while the task was not canceled. */
private volatile boolean failing;

private boolean disposedOperators;

/** Thread pool for async snapshot workers. */
Expand Down Expand Up @@ -538,6 +542,7 @@ public final void invoke() throws Exception {
afterInvoke();
}
catch (Exception invokeException) {
failing = !canceled;
try {
cleanUpInvoke();
}
Expand All @@ -562,6 +567,7 @@ private void runMailboxLoop() throws Exception {

protected void afterInvoke() throws Exception {
LOG.debug("Finished task {}", getName());
getCompletionFuture().exceptionally(unused -> null).join();
pnowojski marked this conversation as resolved.
Show resolved Hide resolved

final CompletableFuture<Void> timersFinishedFuture = new CompletableFuture<>();

Expand Down Expand Up @@ -599,6 +605,7 @@ protected void afterInvoke() throws Exception {
}

protected void cleanUpInvoke() throws Exception {
getCompletionFuture().exceptionally(unused -> null).join();
// clean up everything we initialized
isRunning = false;

Expand Down Expand Up @@ -637,6 +644,10 @@ protected void cleanUpInvoke() throws Exception {
}
}

/**
 * Returns the future that {@link #afterInvoke()} and {@link #cleanUpInvoke()} join on, and that
 * {@link #cancel()} attaches its clean-up callback to.
 *
 * <p>This base implementation returns the result of {@code FutureUtils.completedVoidFuture()}.
 * NOTE(review): by its name that is an already completed future, i.e. nothing to wait for —
 * confirm against {@code FutureUtils}.
 *
 * @return the future described above
 */
protected CompletableFuture<Void> getCompletionFuture() {
return FutureUtils.completedVoidFuture();
}

@Override
public final void cancel() throws Exception {
isRunning = false;
Expand All @@ -648,8 +659,16 @@ public final void cancel() throws Exception {
cancelTask();
}
finally {
mailboxProcessor.allActionsCompleted();
cancelables.close();
getCompletionFuture()
.whenComplete((unusedResult, unusedError) -> {
// WARN: the method is called from the task thread but the callback can be invoked from a different thread
mailboxProcessor.allActionsCompleted();
try {
cancelables.close();
} catch (IOException e) {
throw new CompletionException(e);
}
});
}
}

Expand All @@ -665,6 +684,10 @@ public final boolean isCanceled() {
return canceled;
}

/**
 * Returns the current value of the {@code failing} flag.
 *
 * <p>{@link #invoke()} sets the flag to {@code !canceled} when it catches an exception.
 *
 * @return the value of the {@code failing} field
 */
public final boolean isFailing() {
return failing;
}

private void shutdownAsyncThreads() throws Exception {
if (!asyncOperationsThreadPool.isShutdown()) {
asyncOperationsThreadPool.shutdownNow();
Expand Down
Expand Up @@ -466,6 +466,24 @@ public void finishingIgnoresExceptions() throws Exception {
testHarness.waitForTaskCompletion();
}

/**
 * Verifies that cancelling the task does not let the task thread finish while the source is
 * still running: the task thread must stay alive until the source is force-cancelled.
 */
@Test
public void testWaitsForSourceThreadOnCancel() throws Exception {
	StreamTaskTestHarness<String> harness = new StreamTaskTestHarness<>(SourceStreamTask::new, BasicTypeInfo.STRING_TYPE_INFO);

	harness.setupOutputForSingletonOperatorChain();
	harness.getStreamConfig().setStreamOperator(new StreamSource<>(new NonStoppingSource()));

	harness.invoke();
	NonStoppingSource.waitForStart();

	try {
		harness.getTask().cancel();
		// join() returns after 500 ms without failing, so a task that exits too early is
		// detected by the liveness assertion below rather than by an exception here
		harness.waitForTaskCompletion(500, true);
		assertTrue(harness.taskThread.isAlive());
	} finally {
		// always release the source loop, even when the assertion fails; otherwise the
		// source keeps spinning after the test has ended
		NonStoppingSource.forceCancel();
	}
	harness.waitForTaskCompletion(Long.MAX_VALUE, true);
}

private static class MockSource implements SourceFunction<Tuple2<Long, Integer>>, ListCheckpointed<Serializable> {
private static final long serialVersionUID = 1;

Expand Down Expand Up @@ -572,6 +590,37 @@ public Boolean call() throws Exception {
}
}

/**
 * Source that ignores {@link #cancel()} as well as thread interruption; {@link #run} only returns
 * after {@link #forceCancel()} has been called.
 *
 * <p>All state lives in static fields and is driven through the static helpers
 * {@link #waitForStart()} and {@link #forceCancel()}. The state is never reset, so once
 * {@link #forceCancel()} has been called, any later {@link #run} returns immediately.
 */
private static class NonStoppingSource implements SourceFunction<String> {
	private static final long serialVersionUID = 1L;

	/**
	 * Loop condition of {@link #run}. It is cleared through {@link #forceCancel()} from a
	 * different thread than the one polling it, hence volatile so the write becomes visible.
	 */
	private static volatile boolean running = true;

	/** Completed as soon as {@link #run} has been entered. */
	private static final CompletableFuture<Void> startFuture = new CompletableFuture<>();

	@Override
	public void run(SourceContext<String> ctx) throws Exception {
		startFuture.complete(null);
		while (running) {
			try {
				Thread.sleep(500);
			} catch (InterruptedException ignored) {
				// deliberately swallowed: this source must keep running even when interrupted
			}
		}
	}

	@Override
	public void cancel() {
		// do nothing - ignore usual cancellation
	}

	/** Ends the loop in {@link #run}. */
	static void forceCancel() {
		running = false;
	}

	/** Blocks until {@link #run} has been entered. */
	static void waitForStart() {
		startFuture.join();
	}
}

private static class OpenCloseTestSource extends RichSourceFunction<String> {
private static final long serialVersionUID = 1L;

Expand Down
Expand Up @@ -26,6 +26,7 @@
import org.apache.flink.metrics.Metric;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.event.AbstractEvent;
import org.apache.flink.runtime.execution.CancelTaskException;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
Expand Down Expand Up @@ -53,6 +54,7 @@
import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
import org.apache.flink.streaming.runtime.streamrecord.StreamElement;
import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.FunctionWithException;
import org.apache.flink.util.function.SupplierWithException;
Expand Down Expand Up @@ -273,18 +275,24 @@ public void waitForTaskCompletion() throws Exception {
waitForTaskCompletion(Long.MAX_VALUE);
}

/**
 * Delegates to {@link #waitForTaskCompletion(long, boolean)} with
 * {@code ignoreCancellationException} set to {@code false}.
 *
 * @param timeout passed through unchanged to the two-argument overload
 */
public void waitForTaskCompletion(long timeout) throws Exception {
waitForTaskCompletion(timeout, false);
}

/**
* Waits for the task completion. If this does not happen within the timeout, then a
* TimeoutException is thrown.
*
* @param timeout Timeout for the task completion
*/
public void waitForTaskCompletion(long timeout) throws Exception {
public void waitForTaskCompletion(long timeout, boolean ignoreCancellationException) throws Exception {
Preconditions.checkState(taskThread != null, "Task thread was not started.");

taskThread.join(timeout);
if (taskThread.getError() != null) {
throw new Exception("error in task", taskThread.getError());
if (!ignoreCancellationException || !ExceptionUtils.findThrowable(taskThread.getError(), CancelTaskException.class).isPresent()) {
throw new Exception("error in task", taskThread.getError());
}
}
}

Expand Down Expand Up @@ -439,7 +447,7 @@ public StreamConfigChainer<StreamTaskTestHarness<OUT>> setupOperatorChain(Operat

// ------------------------------------------------------------------------

private class TaskThread extends Thread {
class TaskThread extends Thread {

private final SupplierWithException<? extends StreamTask<OUT, ?>, Exception> taskFactory;
private volatile StreamTask<OUT, ?> task;
Expand Down