From 24374e19ef4cb0f3578360c7f1386bee5c655890 Mon Sep 17 00:00:00 2001 From: tison Date: Thu, 27 Sep 2018 02:06:41 +0800 Subject: [PATCH] [FLINK-10426] Port TaskTest to new code base [FLINK-10426] (Part 1) testRegularExecution [FLINK-10426] (Part 2) porting fails 1. testCancelRightAway 2. testFailExternallyRightAway 3. testLibraryCacheRegistrationFailed 4. testExecutionFailsInNetworkRegistration 5. testInvokableInstantiationFailed 6. testExecutionFailsInInvoke 7. testFailWithWrappedException 8. testCancelDuringInvoke 9. testFailExternallyDuringInvoke 10. testCanceledAfterExecutionFailedInInvoke 11. testExecutionFailsAfterCanceling 12. testExecutionFailsAfterTaskMarkedFailed 13. testCancelTaskException 14. testCancelTaskExceptionAfterTaskMarkedFailed [FLINK-10426] (Part 3) partition state update tests See also FLINK-10319, some of these tests would be removed based on that. [FLINK-10426] (Part 4) watch dog [FLINK-10426] (Part 4) config --- .../flink/runtime/taskmanager/TaskTest.java | 1311 ++++++++--------- 1 file changed, 614 insertions(+), 697 deletions(-) diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index bb5cd17193a27..f4abf182d12fa 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -20,24 +20,29 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.blob.BlobCacheService; +import org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.blob.PermanentBlobCache; +import 
org.apache.flink.runtime.blob.PermanentBlobKey; import org.apache.flink.runtime.blob.TransientBlobCache; +import org.apache.flink.runtime.blob.VoidBlobStore; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; +import org.apache.flink.runtime.execution.librarycache.FlinkUserCodeClassLoaders; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.filecache.FileCache; -import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.TaskEventDispatcher; @@ -49,40 +54,35 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException; import org.apache.flink.runtime.memory.MemoryManager; -import org.apache.flink.runtime.messages.TaskManagerMessages; -import org.apache.flink.runtime.messages.TaskMessages; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import 
org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; import org.apache.flink.util.WrappingRuntimeException; - -import org.junit.After; import org.junit.Before; +import org.junit.ClassRule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import javax.annotation.Nonnull; - import java.io.IOException; import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import scala.concurrent.duration.FiniteDuration; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -90,9 +90,11 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -100,445 +102,402 @@ /** * Tests for the Task, which make sure that correct state transitions happen, * and failures are correctly handled. - * - *

All tests here have a set of mock actors for TaskManager, JobManager, and - * execution listener, which simply put the messages in a queue to be picked - * up by the test and validated. */ public class TaskTest extends TestLogger { + private static final long TIMEOUT = 1000L; + private static OneShotLatch awaitLatch; private static OneShotLatch triggerLatch; - private static OneShotLatch cancelLatch; - - private ActorGateway taskManagerGateway; - private ActorGateway jobManagerGateway; - private ActorGatewayTaskManagerActions taskManagerConnection; - - private BlockingQueue taskManagerMessages; - private BlockingQueue jobManagerMessages; - private BlockingQueue listenerMessages; + @ClassRule + public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); @Before - public void createQueuesAndActors() { - taskManagerMessages = new LinkedBlockingQueue<>(); - jobManagerMessages = new LinkedBlockingQueue<>(); - listenerMessages = new LinkedBlockingQueue<>(); - taskManagerGateway = new ForwardingActorGateway(taskManagerMessages); - jobManagerGateway = new ForwardingActorGateway(jobManagerMessages); - - taskManagerConnection = new ActorGatewayTaskManagerActions(taskManagerGateway) { - @Override - public void updateTaskExecutionState(TaskExecutionState taskExecutionState) { - super.updateTaskExecutionState(taskExecutionState); - listenerMessages.add(taskExecutionState); - } - }; - + public void setup() { awaitLatch = new OneShotLatch(); triggerLatch = new OneShotLatch(); - cancelLatch = new OneShotLatch(); } - @After - public void clearActorsAndMessages() { - jobManagerMessages = null; - taskManagerMessages = null; - listenerMessages = null; + @Test + public void testRegularExecution() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setTaskManagerActions(taskManagerActions) + .build(); + + // task should be new and perfect + assertEquals(ExecutionState.CREATED, 
task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + + // go into the run method. we should switch to DEPLOYING, RUNNING, then + // FINISHED, and all should be good + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - taskManagerGateway = null; - jobManagerGateway = null; - } + task.run(); - // ------------------------------------------------------------------------ - // Tests - // ------------------------------------------------------------------------ + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - @Test - public void testRegularExecution() { - try { - Task task = createTask(TestInvokableCorrect.class); - - // task should be new and perfect - assertEquals(ExecutionState.CREATED, task.getExecutionState()); - assertFalse(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); - - // go into the run method. we should switch to DEPLOYING, RUNNING, then - // FINISHED, and all should be good - task.run(); - - // verify final state - assertEquals(ExecutionState.FINISHED, task.getExecutionState()); - assertFalse(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); - assertNull(task.getInvokable()); - - // verify listener messages - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FINISHED, task, false); - - // make sure that the TaskManager received an message to unregister the task - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // verify final state + assertEquals(ExecutionState.FINISHED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + assertNull(task.getInvokable()); } @Test - public void testCancelRightAway() { 
- try { - Task task = createTask(TestInvokableCorrect.class); - task.cancelExecution(); + public void testCancelRightAway() throws Exception { + final Task task = new TaskBuilder().build(); + task.cancelExecution(); - assertEquals(ExecutionState.CANCELING, task.getExecutionState()); + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); - task.run(); + task.run(); - // verify final state - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - validateUnregisterTask(task.getExecutionId()); + // verify final state + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - assertNull(task.getInvokable()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertNull(task.getInvokable()); } @Test - public void testFailExternallyRightAway() { - try { - Task task = createTask(TestInvokableCorrect.class); - task.failExternally(new Exception("fail externally")); + public void testFailExternallyRightAway() throws Exception { + Task task = new TaskBuilder().build(); + task.failExternally(new Exception("fail externally")); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); - task.run(); + task.run(); - // verify final state - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - validateUnregisterTask(task.getExecutionId()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); } @Test - public void testLibraryCacheRegistrationFailed() { - try { - BlobCacheService blobService = createBlobCache(); - Task task = createTask(TestInvokableCorrect.class, blobService, - mock(LibraryCacheManager.class)); - - // task should be new and perfect - assertEquals(ExecutionState.CREATED, task.getExecutionState()); - assertFalse(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); + public void 
testLibraryCacheRegistrationFailed() throws Exception { + final Task task = new TaskBuilder() + .setLibraryCacheManager(mock(LibraryCacheManager.class)) // inactive manager + .build(); - // should fail - task.run(); + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); - // verify final state - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNotNull(task.getFailureCause()); - assertNotNull(task.getFailureCause().getMessage()); - assertTrue(task.getFailureCause().getMessage().contains("classloader")); - - // verify listener messages - validateListenerMessage(ExecutionState.FAILED, task, true); + // should fail + task.run(); - // make sure that the TaskManager received an message to unregister the task - validateUnregisterTask(task.getExecutionId()); + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertNotNull(task.getFailureCause().getMessage()); + assertTrue(task.getFailureCause().getMessage().contains("classloader")); - assertNull(task.getInvokable()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertNull(task.getInvokable()); } @Test - public void testExecutionFailsInNetworkRegistration() { - try { - BlobCacheService blobService = createBlobCache(); - // mock a working library cache - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - - // mock a network manager that rejects registration - ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); - ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); - PartitionProducerStateChecker 
partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); - TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - Executor executor = mock(Executor.class); - NetworkEnvironment network = mock(NetworkEnvironment.class); - when(network.getResultPartitionManager()).thenReturn(partitionManager); - when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); - when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); - - Task task = createTask(TestInvokableCorrect.class, blobService, libCache, network, consumableNotifier, partitionProducerStateChecker, executor); - - task.run(); + public void testExecutionFailsInBlobsMissing() throws Exception { + final PermanentBlobKey missingKey = new PermanentBlobKey(); + + final Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + TEMPORARY_FOLDER.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); + + final BlobServer blobServer = new BlobServer(config, new VoidBlobStore()); + blobServer.start(); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", blobServer.getPort()); + final PermanentBlobCache permanentBlobCache = new PermanentBlobCache(config, new VoidBlobStore(), serverAddress); + + final BlobLibraryCacheManager libraryCacheManager = + new BlobLibraryCacheManager( + permanentBlobCache, + FlinkUserCodeClassLoaders.ResolveOrder.CHILD_FIRST, + new String[0]); + + final Task task = new TaskBuilder() + .setRequiredJarFileBlobKeys(Collections.singletonList(missingKey)) + .setLibraryCacheManager(libraryCacheManager) + .build(); + + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + + // should fail + task.run(); - assertEquals(ExecutionState.FAILED, 
task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("buffers")); + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertNotNull(task.getFailureCause().getMessage()); + assertTrue(task.getFailureCause().getMessage().contains("Failed to fetch BLOB")); - validateUnregisterTask(task.getExecutionId()); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertNull(task.getInvokable()); } @Test - public void testInvokableInstantiationFailed() { - try { - Task task = createTask(InvokableNonInstantiable.class); + public void testExecutionFailsInNetworkRegistration() throws Exception { + // mock a network manager that rejects registration + final ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + final ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + final PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + + final NetworkEnvironment network = mock(NetworkEnvironment.class); + when(network.getResultPartitionManager()).thenReturn(partitionManager); + when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); + doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); - task.run(); + final Task task = new TaskBuilder() + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionProducerStateChecker) + .setNetworkEnvironment(network) + .build(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - 
assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("instantiate")); + // should fail + task.run(); - validateUnregisterTask(task.getExecutionId()); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("buffers")); } @Test - public void testExecutionFailsInInvoke() { - try { - Task task = createTask(InvokableWithExceptionInInvoke.class); + public void testInvokableInstantiationFailed() throws Exception { + final Task task = new TaskBuilder() + .setInvokable(InvokableNonInstantiable.class) + .build(); - task.run(); + // should fail + task.run(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNotNull(task.getFailureCause()); - assertNotNull(task.getFailureCause().getMessage()); - assertTrue(task.getFailureCause().getMessage().contains("test")); + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("instantiate")); + } - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + @Test + public void testExecutionFailsInInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - validateListenerMessage(ExecutionState.RUNNING, task, false); - 
validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + task.run(); + + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertNotNull(task.getFailureCause().getMessage()); + assertTrue(task.getFailureCause().getMessage().contains("test")); } @Test - public void testFailWithWrappedException() { - try { - Task task = createTask(FailingInvokableWithChainedException.class); + public void testFailWithWrappedException() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(FailingInvokableWithChainedException.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - task.run(); - - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); + task.run(); - Throwable cause = task.getFailureCause(); - assertTrue(cause instanceof IOException); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + final Throwable cause = task.getFailureCause(); + assertTrue(cause instanceof IOException); } @Test - public void testCancelDuringInvoke() { - try { - Task task = 
createTask(InvokableBlockingInInvoke.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in invoke - awaitLatch.await(); + public void testCancelDuringInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.cancelExecution(); - assertTrue(task.getExecutionState() == ExecutionState.CANCELING || - task.getExecutionState() == ExecutionState.CANCELED); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - task.getExecutingThread().join(); + // wait till the task is in invoke + awaitLatch.await(); - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); + task.cancelExecution(); + assertTrue(task.getExecutionState() == ExecutionState.CANCELING || + task.getExecutionState() == ExecutionState.CANCELED); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.CANCELED, task, false); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); } @Test - public void testFailExternallyDuringInvoke() { - try { - Task task = createTask(InvokableBlockingInInvoke.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in regInOut 
- awaitLatch.await(); + public void testFailExternallyDuringInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.failExternally(new Exception("test")); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - task.getExecutingThread().join(); + // wait till the task is in invoke + awaitLatch.await(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("test")); + task.failExternally(new Exception("test")); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); } @Test - public void testCanceledAfterExecutionFailedInInvoke() { - try { - Task task = createTask(InvokableWithExceptionInInvoke.class); + public void testCanceledAfterExecutionFailedInInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + 
final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - task.run(); + task.run(); - // this should not overwrite the failure state - task.cancelExecution(); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("test")); - - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + // this should not overwrite the failure state + task.cancelExecution(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); } @Test - public void testExecutionFailsAfterCanceling() { - try { - Task task = createTask(InvokableWithExceptionOnTrigger.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in invoke - awaitLatch.await(); + public void testExecutionFailsAfterCanceling() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionOnTrigger.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.cancelExecution(); - assertEquals(ExecutionState.CANCELING, task.getExecutionState()); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // this 
causes an exception - triggerLatch.trigger(); + // wait till the task is in invoke + awaitLatch.await(); - task.getExecutingThread().join(); + task.cancelExecution(); + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); - // we should still be in state canceled - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); + // this causes an exception + triggerLatch.trigger(); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.CANCELED, task, false); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // we should still be in state canceled + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); } @Test - public void testExecutionFailsAfterTaskMarkedFailed() { - try { - Task task = createTask(InvokableWithExceptionOnTrigger.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in invoke - awaitLatch.await(); + public void testExecutionFailsAfterTaskMarkedFailed() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionOnTrigger.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.failExternally(new Exception("external")); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // this 
causes an exception - triggerLatch.trigger(); + // wait till the task is in invoke + awaitLatch.await(); - task.getExecutingThread().join(); + task.failExternally(new Exception("external")); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("external")); + // this causes an exception + triggerLatch.trigger(); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("external")); } + @Test public void testCancelTaskException() throws Exception { - final Task task = createTask(InvokableWithCancelTaskExceptionInInvoke.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithCancelTaskExceptionInInvoke.class) + .build(); // Cause CancelTaskException. 
triggerLatch.trigger(); @@ -550,7 +509,9 @@ public void testCancelTaskException() throws Exception { @Test public void testCancelTaskExceptionAfterTaskMarkedFailed() throws Exception { - final Task task = createTask(InvokableWithCancelTaskExceptionInInvoke.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithCancelTaskExceptionInInvoke.class) + .build(); task.startTaskThread(); @@ -573,13 +534,15 @@ public void testCancelTaskExceptionAfterTaskMarkedFailed() throws Exception { @Test public void testOnPartitionStateUpdate() throws Exception { - IntermediateDataSetID resultId = new IntermediateDataSetID(); - ResultPartitionID partitionId = new ResultPartitionID(); + final IntermediateDataSetID resultId = new IntermediateDataSetID(); + final ResultPartitionID partitionId = new ResultPartitionID(); - SingleInputGate inputGate = mock(SingleInputGate.class); + final SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); - final Task task = createTask(InvokableBlockingInInvoke.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .build(); // Set the mock input gate setInputGate(task, inputGate); @@ -619,36 +582,36 @@ public void testOnPartitionStateUpdate() throws Exception { */ @Test public void testTriggerPartitionStateUpdate() throws Exception { - IntermediateDataSetID resultId = new IntermediateDataSetID(); - ResultPartitionID partitionId = new ResultPartitionID(); - - BlobCacheService blobService = createBlobCache(); - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); + final IntermediateDataSetID resultId = new IntermediateDataSetID(); + final ResultPartitionID partitionId = new ResultPartitionID(); - PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class); - TaskEventDispatcher taskEventDispatcher = 
mock(TaskEventDispatcher.class); + final PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class); + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); - NetworkEnvironment network = mock(NetworkEnvironment.class); + final ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + final NetworkEnvironment network = mock(NetworkEnvironment.class); when(network.getResultPartitionManager()).thenReturn(mock(ResultPartitionManager.class)); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); // Test all branches of trigger partition state check - { // Reset latches - createQueuesAndActors(); + setup(); // PartitionProducerDisposedException - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - - CompletableFuture<ExecutionState> promise = new CompletableFuture<>(); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final CompletableFuture<ExecutionState> promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); @@ -659,12 +622,18
@@ public void testTriggerPartitionStateUpdate() throws Exception { { // Reset latches - createQueuesAndActors(); + setup(); // Any other exception - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - - CompletableFuture promise = new CompletableFuture<>(); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); @@ -676,11 +645,19 @@ public void testTriggerPartitionStateUpdate() throws Exception { { // Reset latches - createQueuesAndActors(); + setup(); // TimeoutException handled special => retry - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - SingleInputGate inputGate = mock(SingleInputGate.class); + // Any other exception + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); try { @@ -707,11 +684,18 @@ public void testTriggerPartitionStateUpdate() throws Exception { { // Reset latches - createQueuesAndActors(); + setup(); // Success - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, 
consumableNotifier, partitionChecker, Executors.directExecutor()); - SingleInputGate inputGate = mock(SingleInputGate.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); try { @@ -744,22 +728,28 @@ public void testTriggerPartitionStateUpdate() throws Exception { */ @Test public void testWatchDogInterruptsTask() throws Exception { - Configuration config = new Configuration(); + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + + // guard no fatal error + doThrow(new RuntimeException("Unexpected FatalError message")). + when(taskManagerActions).notifyFatalError(anyString(), any(Throwable.class)); + + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL.key(), 5); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT.key(), 60 * 1000); - Task task = createTask(InvokableBlockingInCancel.class, config); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInCancel.class) + .setTaskManagerConfig(config) + .setTaskManagerActions(taskManagerActions) + .build(); + task.startTaskThread(); awaitLatch.await(); task.cancelExecution(); task.getExecutingThread().join(); - - // No fatal error - for (Object msg : taskManagerMessages) { - assertFalse("Unexpected FatalError message", msg instanceof TaskManagerMessages.FatalError); - } } /** @@ -768,23 +758,29 @@ public void testWatchDogInterruptsTask() throws Exception { * This is resolved by the watch dog, no fatal error. 
*/ @Test - public void testInterruptableSharedLockInInvokeAndCancel() throws Exception { - Configuration config = new Configuration(); + public void testInterruptibleSharedLockInInvokeAndCancel() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + + // guard no fatal error + doThrow(new RuntimeException("Unexpected FatalError message")). + when(taskManagerActions).notifyFatalError(anyString(), any(Throwable.class)); + + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50); - Task task = createTask(InvokableInterruptableSharedLockInInvokeAndCancel.class, config); + final Task task = new TaskBuilder() + .setInvokable(InvokableInterruptibleSharedLockInInvokeAndCancel.class) + .setTaskManagerConfig(config) + .setTaskManagerActions(taskManagerActions) + .build(); + task.startTaskThread(); awaitLatch.await(); task.cancelExecution(); task.getExecutingThread().join(); - - // No fatal error - for (Object msg : taskManagerMessages) { - assertFalse("Unexpected FatalError message", msg instanceof TaskManagerMessages.FatalError); - } } /** @@ -792,34 +788,39 @@ public void testInterruptableSharedLockInInvokeAndCancel() throws Exception { * resolved by a fatal error. 
*/ @Test - public void testFatalErrorAfterUninterruptibleInvoke() throws Exception { - Configuration config = new Configuration(); + public void testFatalErrorAfterUnInterruptibleInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50); - Task task = createTask(InvokableUninterruptibleBlockingInvoke.class, config); + final Task task = new TaskBuilder() + .setInvokable(InvokableUnInterruptibleBlockingInvoke.class) + .setTaskManagerConfig(config) + .setTaskManagerActions(taskManagerActions) + .build(); - try { - task.startTaskThread(); + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - awaitLatch.await(); + task.startTaskThread(); - task.cancelExecution(); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - for (int i = 0; i < 10; i++) { - Object msg = taskManagerMessages.poll(1, TimeUnit.SECONDS); - if (msg instanceof TaskManagerMessages.FatalError) { - return; // success - } - } + awaitLatch.await(); - fail("Did not receive expected task manager message"); - } finally { - // Interrupt again to clean up Thread - cancelLatch.trigger(); - task.getExecutingThread().interrupt(); - task.getExecutingThread().join(); - } + task.cancelExecution(); + + verify(taskManagerActions, timeout(TIMEOUT)).notifyFatalError( + anyString(), any(Throwable.class)); + + // Interrupt again to clean up Thread + triggerLatch.trigger(); + task.getExecutingThread().interrupt(); + task.getExecutingThread().join(); } /** @@ -830,15 +831,19 @@ public void testTaskConfig() throws Exception { long interval = 28218123; long timeout = interval + 19292; - Configuration config = new Configuration(); + final Configuration config = new Configuration(); 
config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, interval); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, timeout); - ExecutionConfig executionConfig = new ExecutionConfig(); + final ExecutionConfig executionConfig = new ExecutionConfig(); executionConfig.setTaskCancellationInterval(interval + 1337); executionConfig.setTaskCancellationTimeout(timeout - 1337); - Task task = createTask(InvokableBlockingInInvoke.class, config, executionConfig); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setTaskManagerConfig(config) + .setExecutionConfig(executionConfig) + .build(); assertEquals(interval, task.getTaskCancellationInterval()); assertEquals(timeout, task.getTaskCancellationTimeout()); @@ -854,6 +859,8 @@ public void testTaskConfig() throws Exception { task.getExecutingThread().join(); } + // ------------------------------------------------------------------------ + // helper functions // ------------------------------------------------------------------------ private void setInputGate(Task task, SingleInputGate inputGate) { @@ -885,232 +892,159 @@ private void setState(Task task, ExecutionState state) { } } - /** - * Creates a {@link BlobCacheService} mock that is suitable to be used in the tests above. 
- * - * @return BlobCache mock with the bare minimum of implemented functions that work - */ - private BlobCacheService createBlobCache() { - return new BlobCacheService( - mock(PermanentBlobCache.class), - mock(TransientBlobCache.class)); - } + private final class TaskBuilder { + private Class invokable; + private TaskManagerActions taskManagerActions; + private LibraryCacheManager libraryCacheManager; + private ResultPartitionConsumableNotifier consumableNotifier; + private PartitionProducerStateChecker partitionProducerStateChecker; + private NetworkEnvironment networkEnvironment; + private Executor executor; + private Configuration taskManagerConfig; + private ExecutionConfig executionConfig; + private Collection requiredJarFileBlobKeys; - private Task createTask(Class invokable) throws IOException { - return createTask(invokable, new Configuration(), new ExecutionConfig()); - } + { + invokable = TestInvokableCorrect.class; + taskManagerActions = mock(TaskManagerActions.class); - private Task createTask(Class invokable, Configuration config) throws IOException { - BlobCacheService blobService = createBlobCache(); - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - return createTask(invokable, blobService, libCache, config, new ExecutionConfig()); - } + libraryCacheManager = mock(LibraryCacheManager.class); + when(libraryCacheManager.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - private Task createTask(Class invokable, Configuration config, ExecutionConfig execConfig) throws IOException { - BlobCacheService blobService = createBlobCache(); - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - return createTask(invokable, blobService, libCache, config, execConfig); - } + consumableNotifier = 
mock(ResultPartitionConsumableNotifier.class); + partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache) throws IOException { + final ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + networkEnvironment = mock(NetworkEnvironment.class); + when(networkEnvironment.getResultPartitionManager()).thenReturn(partitionManager); + when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) + .thenReturn(mock(TaskKvStateRegistry.class)); + when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - return createTask(invokable, blobService, libCache, new Configuration(), new ExecutionConfig()); - } + executor = TestingUtils.defaultExecutor(); - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache, - Configuration config, - ExecutionConfig execConfig) throws IOException { - - ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); - ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); - PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); - TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - Executor executor = mock(Executor.class); - NetworkEnvironment network = mock(NetworkEnvironment.class); - when(network.getResultPartitionManager()).thenReturn(partitionManager); - when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); - when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) - .thenReturn(mock(TaskKvStateRegistry.class)); - 
when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); + taskManagerConfig = new Configuration(); + executionConfig = new ExecutionConfig(); - return createTask(invokable, blobService, libCache, network, consumableNotifier, partitionProducerStateChecker, executor, config, execConfig); - } + requiredJarFileBlobKeys = Collections.emptyList(); + } - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache, - NetworkEnvironment networkEnvironment, - ResultPartitionConsumableNotifier consumableNotifier, - PartitionProducerStateChecker partitionProducerStateChecker, - Executor executor) throws IOException { - return createTask(invokable, blobService, libCache, networkEnvironment, consumableNotifier, partitionProducerStateChecker, executor, new Configuration(), new ExecutionConfig()); - } + TaskBuilder setInvokable(Class invokable) { + this.invokable = invokable; + return this; + } - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache, - NetworkEnvironment networkEnvironment, - ResultPartitionConsumableNotifier consumableNotifier, - PartitionProducerStateChecker partitionProducerStateChecker, - Executor executor, - Configuration taskManagerConfig, - ExecutionConfig execConfig) throws IOException { - - JobID jobId = new JobID(); - JobVertexID jobVertexId = new JobVertexID(); - ExecutionAttemptID executionAttemptId = new ExecutionAttemptID(); - - InputSplitProvider inputSplitProvider = new TaskInputSplitProvider( - jobManagerGateway, - jobId, - jobVertexId, - executionAttemptId, - new FiniteDuration(60, TimeUnit.SECONDS)); - - CheckpointResponder checkpointResponder = new ActorGatewayCheckpointResponder(jobManagerGateway); - - SerializedValue serializedExecutionConfig = new SerializedValue<>(execConfig); - - JobInformation jobInformation = new JobInformation( - jobId, - "Test Job", - serializedExecutionConfig, - new Configuration(), - 
Collections.emptyList(), - Collections.emptyList()); - - TaskInformation taskInformation = new TaskInformation( - jobVertexId, - "Test Task", - 1, - 1, - invokable.getName(), - new Configuration()); - - TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); - when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); - - return new Task( - jobInformation, - taskInformation, - executionAttemptId, - new AllocationID(), - 0, - 0, - Collections.emptyList(), - Collections.emptyList(), - 0, - mock(MemoryManager.class), - mock(IOManager.class), - networkEnvironment, - mock(BroadcastVariableManager.class), - new TestTaskStateManager(), - taskManagerConnection, - inputSplitProvider, - checkpointResponder, - blobService, - libCache, - mock(FileCache.class), - new TestingTaskManagerRuntimeInfo(taskManagerConfig), - taskMetricGroup, - consumableNotifier, - partitionProducerStateChecker, - executor); - } + TaskBuilder setTaskManagerActions(TaskManagerActions taskManagerActions) { + this.taskManagerActions = taskManagerActions; + return this; + } - // ------------------------------------------------------------------------ - // Validation Methods - // ------------------------------------------------------------------------ + TaskBuilder setLibraryCacheManager(LibraryCacheManager libraryCacheManager) { + this.libraryCacheManager = libraryCacheManager; + return this; + } - private void validateUnregisterTask(ExecutionAttemptID id) { - try { - // we may have to wait for a bit to give the actors time to receive the message - // and put it into the queue - Object rawMessage = taskManagerMessages.take(); + TaskBuilder setConsumableNotifier(ResultPartitionConsumableNotifier consumableNotifier) { + this.consumableNotifier = consumableNotifier; + return this; + } - assertNotNull("There is no additional TaskManager message", rawMessage); - if (!(rawMessage instanceof TaskMessages.TaskInFinalState)) { - fail("TaskManager message is not 'UnregisterTask', but 
" + rawMessage.getClass()); - } + TaskBuilder setPartitionProducerStateChecker(PartitionProducerStateChecker partitionProducerStateChecker) { + this.partitionProducerStateChecker = partitionProducerStateChecker; + return this; + } + + TaskBuilder setNetworkEnvironment(NetworkEnvironment networkEnvironment) { + this.networkEnvironment = networkEnvironment; + return this; + } - TaskMessages.TaskInFinalState message = (TaskMessages.TaskInFinalState) rawMessage; - assertEquals(id, message.executionID()); + TaskBuilder setExecutor(Executor executor) { + this.executor = executor; + return this; } - catch (InterruptedException e) { - fail("interrupted"); + + TaskBuilder setTaskManagerConfig(Configuration taskManagerConfig) { + this.taskManagerConfig = taskManagerConfig; + return this; } - } - private void validateTaskManagerStateChange(ExecutionState state, Task task, boolean hasError) { - try { - // we may have to wait for a bit to give the actors time to receive the message - // and put it into the queue - Object rawMessage = taskManagerMessages.take(); + TaskBuilder setExecutionConfig(ExecutionConfig executionConfig) { + this.executionConfig = executionConfig; + return this; + } - assertNotNull("There is no additional TaskManager message", rawMessage); - if (!(rawMessage instanceof TaskMessages.UpdateTaskExecutionState)) { - fail("TaskManager message is not 'UpdateTaskExecutionState', but " + rawMessage.getClass()); - } + TaskBuilder setRequiredJarFileBlobKeys(Collection requiredJarFileBlobKeys) { + this.requiredJarFileBlobKeys = requiredJarFileBlobKeys; + return this; + } - TaskMessages.UpdateTaskExecutionState message = - (TaskMessages.UpdateTaskExecutionState) rawMessage; + private Task build() throws Exception { + final JobID jobId = new JobID(); + final JobVertexID jobVertexId = new JobVertexID(); + final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID(); - TaskExecutionState taskState = message.taskExecutionState(); + final SerializedValue 
serializedExecutionConfig = new SerializedValue<>(executionConfig); - assertEquals(task.getJobID(), taskState.getJobID()); - assertEquals(task.getExecutionId(), taskState.getID()); - assertEquals(state, taskState.getExecutionState()); + final JobInformation jobInformation = new JobInformation( + jobId, + "Test Job", + serializedExecutionConfig, + new Configuration(), + requiredJarFileBlobKeys, + Collections.emptyList()); - if (hasError) { - assertNotNull(taskState.getError(getClass().getClassLoader())); - } else { - assertNull(taskState.getError(getClass().getClassLoader())); - } - } - catch (InterruptedException e) { - fail("interrupted"); - } - } + final TaskInformation taskInformation = new TaskInformation( + jobVertexId, + "Test Task", + 1, + 1, + invokable.getName(), + new Configuration()); - private void validateListenerMessage(ExecutionState state, Task task, boolean hasError) { - try { - // we may have to wait for a bit to give the actors time to receive the message - // and put it into the queue - final TaskExecutionState taskState = listenerMessages.take(); - assertNotNull("There is no additional listener message", state); - - assertEquals(task.getJobID(), taskState.getJobID()); - assertEquals(task.getExecutionId(), taskState.getID()); - assertEquals(state, taskState.getExecutionState()); - - if (hasError) { - assertNotNull(taskState.getError(getClass().getClassLoader())); - } else { - assertNull(taskState.getError(getClass().getClassLoader())); - } - } - catch (InterruptedException e) { - fail("interrupted"); + final BlobCacheService blobCacheService = new BlobCacheService( + mock(PermanentBlobCache.class), + mock(TransientBlobCache.class)); + + final TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); + when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); + + return new Task( + jobInformation, + taskInformation, + executionAttemptId, + new AllocationID(), + 0, + 0, + Collections.emptyList(), + 
Collections.emptyList(), + 0, + mock(MemoryManager.class), + mock(IOManager.class), + networkEnvironment, + mock(BroadcastVariableManager.class), + new TestTaskStateManager(), + taskManagerActions, + new MockInputSplitProvider(), + new TestCheckpointResponder(), + blobCacheService, + libraryCacheManager, + mock(FileCache.class), + new TestingTaskManagerRuntimeInfo(taskManagerConfig), + taskMetricGroup, + consumableNotifier, + partitionProducerStateChecker, + executor); } } - // -------------------------------------------------------------------------------------------- - // Mock invokable code - // -------------------------------------------------------------------------------------------- - - /** Test task class. */ - public static final class TestInvokableCorrect extends AbstractInvokable { + // ------------------------------------------------------------------------ + // test task classes + // ------------------------------------------------------------------------ + private static final class TestInvokableCorrect extends AbstractInvokable { public TestInvokableCorrect(Environment environment) { super(environment); } @@ -1119,14 +1053,18 @@ public TestInvokableCorrect(Environment environment) { public void invoke() {} @Override - public void cancel() throws Exception { + public void cancel() { fail("This should not be called"); } } - /** Test task class. */ - public static final class InvokableWithExceptionInInvoke extends AbstractInvokable { + private abstract static class InvokableNonInstantiable extends AbstractInvokable { + public InvokableNonInstantiable(Environment environment) { + super(environment); + } + } + private static final class InvokableWithExceptionInInvoke extends AbstractInvokable { public InvokableWithExceptionInInvoke(Environment environment) { super(environment); } @@ -1137,44 +1075,21 @@ public void invoke() throws Exception { } } - /** Test task class. 
*/ - public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable { - - public InvokableWithExceptionOnTrigger(Environment environment) { + private static final class FailingInvokableWithChainedException extends AbstractInvokable { + public FailingInvokableWithChainedException(Environment environment) { super(environment); } @Override public void invoke() { - awaitLatch.trigger(); - - // make sure that the interrupt call does not - // grab us out of the lock early - while (true) { - try { - triggerLatch.await(); - break; - } - catch (InterruptedException e) { - // fall through the loop - } - } - - throw new RuntimeException("test"); + throw new TestWrappedException(new IOException("test")); } - } - /** Test task class. */ - public abstract static class InvokableNonInstantiable extends AbstractInvokable { - - public InvokableNonInstantiable(Environment environment) { - super(environment); - } + @Override + public void cancel() {} } - /** Test task class. */ - public static final class InvokableBlockingInInvoke extends AbstractInvokable { - + private static final class InvokableBlockingInInvoke extends AbstractInvokable { public InvokableBlockingInInvoke(Environment environment) { super(environment); } @@ -1190,64 +1105,60 @@ public void invoke() throws Exception { } } - /** Test task class. 
*/ - public static final class InvokableWithCancelTaskExceptionInInvoke extends AbstractInvokable { - - public InvokableWithCancelTaskExceptionInInvoke(Environment environment) { + public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable { + public InvokableWithExceptionOnTrigger(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { + public void invoke() { awaitLatch.trigger(); - try { - triggerLatch.await(); + // make sure that the interrupt call does not + // grab us out of the lock early + while (true) { + try { + triggerLatch.await(); + break; + } + catch (InterruptedException e) { + // fall through the loop + } } - catch (Throwable ignored) {} - throw new CancelTaskException(); + throw new RuntimeException("test"); } } - /** Test task class. */ - public static final class InvokableInterruptableSharedLockInInvokeAndCancel extends AbstractInvokable { - - private final Object lock = new Object(); - - public InvokableInterruptableSharedLockInInvokeAndCancel(Environment environment) { + public static final class InvokableWithCancelTaskExceptionInInvoke extends AbstractInvokable { + public InvokableWithCancelTaskExceptionInInvoke(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { - synchronized (lock) { - awaitLatch.trigger(); - wait(); - } - } + public void invoke() { + awaitLatch.trigger(); - @Override - public void cancel() throws Exception { - synchronized (lock) { - cancelLatch.trigger(); + try { + triggerLatch.await(); } + catch (Throwable ignored) {} + + throw new CancelTaskException(); } } - /** Test task class. 
*/ public static final class InvokableBlockingInCancel extends AbstractInvokable { - public InvokableBlockingInCancel(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { + public void invoke() { awaitLatch.trigger(); try { - cancelLatch.await(); + triggerLatch.await(); // await cancel synchronized (this) { wait(); } @@ -1261,51 +1172,57 @@ public void invoke() throws Exception { @Override public void cancel() throws Exception { synchronized (this) { - cancelLatch.trigger(); + triggerLatch.trigger(); wait(); } } } - /** Test task class. */ - public static final class InvokableUninterruptibleBlockingInvoke extends AbstractInvokable { + public static final class InvokableInterruptibleSharedLockInInvokeAndCancel extends AbstractInvokable { + private final Object lock = new Object(); - public InvokableUninterruptibleBlockingInvoke(Environment environment) { + public InvokableInterruptibleSharedLockInInvokeAndCancel(Environment environment) { super(environment); } @Override public void invoke() throws Exception { - while (!cancelLatch.isTriggered()) { - try { - synchronized (this) { - awaitLatch.trigger(); - wait(); - } - } catch (InterruptedException ignored) { - } + synchronized (lock) { + awaitLatch.trigger(); + wait(); } } @Override - public void cancel() throws Exception { + public void cancel() { + synchronized (lock) { + // do nothing but a placeholder + triggerLatch.trigger(); + } } } - /** Test task class. 
*/ - public static final class FailingInvokableWithChainedException extends AbstractInvokable { - - public FailingInvokableWithChainedException(Environment environment) { + public static final class InvokableUnInterruptibleBlockingInvoke extends AbstractInvokable { + public InvokableUnInterruptibleBlockingInvoke(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { - throw new TestWrappedException(new IOException("test")); + public void invoke() { + while (!triggerLatch.isTriggered()) { + try { + synchronized (this) { + awaitLatch.trigger(); + wait(); + } + } catch (InterruptedException ignored) { + } + } } @Override - public void cancel() {} + public void cancel() { + } } // ------------------------------------------------------------------------ @@ -1315,7 +1232,7 @@ public void cancel() {} private static class TestWrappedException extends WrappingRuntimeException { private static final long serialVersionUID = 1L; - public TestWrappedException(@Nonnull Throwable cause) { + TestWrappedException(@Nonnull Throwable cause) { super(cause); } }