diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index bb5cd17193a27..f4abf182d12fa 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -20,24 +20,29 @@ import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.JobID; +import org.apache.flink.configuration.BlobServerOptions; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.core.testutils.OneShotLatch; import org.apache.flink.runtime.blob.BlobCacheService; +import org.apache.flink.runtime.blob.BlobServer; import org.apache.flink.runtime.blob.PermanentBlobCache; +import org.apache.flink.runtime.blob.PermanentBlobKey; import org.apache.flink.runtime.blob.TransientBlobCache; +import org.apache.flink.runtime.blob.VoidBlobStore; import org.apache.flink.runtime.broadcast.BroadcastVariableManager; import org.apache.flink.runtime.clusterframework.types.AllocationID; import org.apache.flink.runtime.concurrent.Executors; import org.apache.flink.runtime.execution.CancelTaskException; import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.execution.ExecutionState; +import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager; +import org.apache.flink.runtime.execution.librarycache.FlinkUserCodeClassLoaders; import org.apache.flink.runtime.execution.librarycache.LibraryCacheManager; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.executiongraph.JobInformation; import org.apache.flink.runtime.executiongraph.TaskInformation; import org.apache.flink.runtime.filecache.FileCache; -import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.io.network.TaskEventDispatcher; @@ -49,40 +54,35 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; import org.apache.flink.runtime.jobmanager.PartitionProducerDisposedException; import org.apache.flink.runtime.memory.MemoryManager; -import org.apache.flink.runtime.messages.TaskManagerMessages; -import org.apache.flink.runtime.messages.TaskMessages; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.metrics.groups.TaskMetricGroup; +import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; import org.apache.flink.runtime.query.TaskKvStateRegistry; import org.apache.flink.runtime.state.TestTaskStateManager; +import org.apache.flink.runtime.testingUtils.TestingUtils; import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo; import org.apache.flink.util.SerializedValue; import org.apache.flink.util.TestLogger; import org.apache.flink.util.WrappingRuntimeException; - -import org.junit.After; import org.junit.Before; +import org.junit.ClassRule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; import javax.annotation.Nonnull; - import java.io.IOException; import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; -import java.util.concurrent.LinkedBlockingQueue; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import scala.concurrent.duration.FiniteDuration; - import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -90,9 +90,11 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -100,445 +102,402 @@ /** * Tests for the Task, which make sure that correct state transitions happen, * and failures are correctly handled. - * - *

All tests here have a set of mock actors for TaskManager, JobManager, and - * execution listener, which simply put the messages in a queue to be picked - * up by the test and validated. */ public class TaskTest extends TestLogger { + private static final long TIMEOUT = 1000L; + private static OneShotLatch awaitLatch; private static OneShotLatch triggerLatch; - private static OneShotLatch cancelLatch; - - private ActorGateway taskManagerGateway; - private ActorGateway jobManagerGateway; - private ActorGatewayTaskManagerActions taskManagerConnection; - - private BlockingQueue taskManagerMessages; - private BlockingQueue jobManagerMessages; - private BlockingQueue listenerMessages; + @ClassRule + public static final TemporaryFolder TEMPORARY_FOLDER = new TemporaryFolder(); @Before - public void createQueuesAndActors() { - taskManagerMessages = new LinkedBlockingQueue<>(); - jobManagerMessages = new LinkedBlockingQueue<>(); - listenerMessages = new LinkedBlockingQueue<>(); - taskManagerGateway = new ForwardingActorGateway(taskManagerMessages); - jobManagerGateway = new ForwardingActorGateway(jobManagerMessages); - - taskManagerConnection = new ActorGatewayTaskManagerActions(taskManagerGateway) { - @Override - public void updateTaskExecutionState(TaskExecutionState taskExecutionState) { - super.updateTaskExecutionState(taskExecutionState); - listenerMessages.add(taskExecutionState); - } - }; - + public void setup() { awaitLatch = new OneShotLatch(); triggerLatch = new OneShotLatch(); - cancelLatch = new OneShotLatch(); } - @After - public void clearActorsAndMessages() { - jobManagerMessages = null; - taskManagerMessages = null; - listenerMessages = null; + @Test + public void testRegularExecution() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setTaskManagerActions(taskManagerActions) + .build(); + + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + + // go into the run method. we should switch to DEPLOYING, RUNNING, then + // FINISHED, and all should be good + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - taskManagerGateway = null; - jobManagerGateway = null; - } + task.run(); - // ------------------------------------------------------------------------ - // Tests - // ------------------------------------------------------------------------ + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - @Test - public void testRegularExecution() { - try { - Task task = createTask(TestInvokableCorrect.class); - - // task should be new and perfect - assertEquals(ExecutionState.CREATED, task.getExecutionState()); - assertFalse(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); - - // go into the run method. we should switch to DEPLOYING, RUNNING, then - // FINISHED, and all should be good - task.run(); - - // verify final state - assertEquals(ExecutionState.FINISHED, task.getExecutionState()); - assertFalse(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); - assertNull(task.getInvokable()); - - // verify listener messages - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FINISHED, task, false); - - // make sure that the TaskManager received an message to unregister the task - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // verify final state + assertEquals(ExecutionState.FINISHED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + assertNull(task.getInvokable()); } @Test - public void testCancelRightAway() { - try { - Task task = createTask(TestInvokableCorrect.class); - task.cancelExecution(); + public void testCancelRightAway() throws Exception { + final Task task = new TaskBuilder().build(); + task.cancelExecution(); - assertEquals(ExecutionState.CANCELING, task.getExecutionState()); + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); - task.run(); + task.run(); - // verify final state - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - validateUnregisterTask(task.getExecutionId()); + // verify final state + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - assertNull(task.getInvokable()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertNull(task.getInvokable()); } @Test - public void testFailExternallyRightAway() { - try { - Task task = createTask(TestInvokableCorrect.class); - task.failExternally(new Exception("fail externally")); + public void testFailExternallyRightAway() throws Exception { + Task task = new TaskBuilder().build(); + task.failExternally(new Exception("fail externally")); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); - task.run(); + task.run(); - // verify final state - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - validateUnregisterTask(task.getExecutionId()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); } @Test - public void testLibraryCacheRegistrationFailed() { - try { - BlobCacheService blobService = createBlobCache(); - Task task = createTask(TestInvokableCorrect.class, blobService, - mock(LibraryCacheManager.class)); - - // task should be new and perfect - assertEquals(ExecutionState.CREATED, task.getExecutionState()); - assertFalse(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); + public void testLibraryCacheRegistrationFailed() throws Exception { + final Task task = new TaskBuilder() + .setLibraryCacheManager(mock(LibraryCacheManager.class)) // inactive manager + .build(); - // should fail - task.run(); + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); - // verify final state - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNotNull(task.getFailureCause()); - assertNotNull(task.getFailureCause().getMessage()); - assertTrue(task.getFailureCause().getMessage().contains("classloader")); - - // verify listener messages - validateListenerMessage(ExecutionState.FAILED, task, true); + // should fail + task.run(); - // make sure that the TaskManager received an message to unregister the task - validateUnregisterTask(task.getExecutionId()); + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertNotNull(task.getFailureCause().getMessage()); + assertTrue(task.getFailureCause().getMessage().contains("classloader")); - assertNull(task.getInvokable()); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertNull(task.getInvokable()); } @Test - public void testExecutionFailsInNetworkRegistration() { - try { - BlobCacheService blobService = createBlobCache(); - // mock a working library cache - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - - // mock a network manager that rejects registration - ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); - ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); - PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); - TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - Executor executor = mock(Executor.class); - NetworkEnvironment network = mock(NetworkEnvironment.class); - when(network.getResultPartitionManager()).thenReturn(partitionManager); - when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); - when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); - - Task task = createTask(TestInvokableCorrect.class, blobService, libCache, network, consumableNotifier, partitionProducerStateChecker, executor); - - task.run(); + public void testExecutionFailsInBlobsMissing() throws Exception { + final PermanentBlobKey missingKey = new PermanentBlobKey(); + + final Configuration config = new Configuration(); + config.setString(BlobServerOptions.STORAGE_DIRECTORY, + TEMPORARY_FOLDER.newFolder().getAbsolutePath()); + config.setLong(BlobServerOptions.CLEANUP_INTERVAL, 1L); + + final BlobServer blobServer = new BlobServer(config, new VoidBlobStore()); + blobServer.start(); + InetSocketAddress serverAddress = new InetSocketAddress("localhost", blobServer.getPort()); + final PermanentBlobCache permanentBlobCache = new PermanentBlobCache(config, new VoidBlobStore(), serverAddress); + + final BlobLibraryCacheManager libraryCacheManager = + new BlobLibraryCacheManager( + permanentBlobCache, + FlinkUserCodeClassLoaders.ResolveOrder.CHILD_FIRST, + new String[0]); + + final Task task = new TaskBuilder() + .setRequiredJarFileBlobKeys(Collections.singletonList(missingKey)) + .setLibraryCacheManager(libraryCacheManager) + .build(); + + // task should be new and perfect + assertEquals(ExecutionState.CREATED, task.getExecutionState()); + assertFalse(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); + + // should fail + task.run(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("buffers")); + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertNotNull(task.getFailureCause().getMessage()); + assertTrue(task.getFailureCause().getMessage().contains("Failed to fetch BLOB")); - validateUnregisterTask(task.getExecutionId()); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertNull(task.getInvokable()); } @Test - public void testInvokableInstantiationFailed() { - try { - Task task = createTask(InvokableNonInstantiable.class); + public void testExecutionFailsInNetworkRegistration() throws Exception { + // mock a network manager that rejects registration + final ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + final ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + final PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + + final NetworkEnvironment network = mock(NetworkEnvironment.class); + when(network.getResultPartitionManager()).thenReturn(partitionManager); + when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); + doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); - task.run(); + final Task task = new TaskBuilder() + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionProducerStateChecker) + .setNetworkEnvironment(network) + .build(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("instantiate")); + // should fail + task.run(); - validateUnregisterTask(task.getExecutionId()); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("buffers")); } @Test - public void testExecutionFailsInInvoke() { - try { - Task task = createTask(InvokableWithExceptionInInvoke.class); + public void testInvokableInstantiationFailed() throws Exception { + final Task task = new TaskBuilder() + .setInvokable(InvokableNonInstantiable.class) + .build(); - task.run(); + // should fail + task.run(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNotNull(task.getFailureCause()); - assertNotNull(task.getFailureCause().getMessage()); - assertTrue(task.getFailureCause().getMessage().contains("test")); + // verify final state + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("instantiate")); + } - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + @Test + public void testExecutionFailsInInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + task.run(); + + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); + + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNotNull(task.getFailureCause()); + assertNotNull(task.getFailureCause().getMessage()); + assertTrue(task.getFailureCause().getMessage().contains("test")); } @Test - public void testFailWithWrappedException() { - try { - Task task = createTask(FailingInvokableWithChainedException.class); + public void testFailWithWrappedException() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(FailingInvokableWithChainedException.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - task.run(); - - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); + task.run(); - Throwable cause = task.getFailureCause(); - assertTrue(cause instanceof IOException); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + final Throwable cause = task.getFailureCause(); + assertTrue(cause instanceof IOException); } @Test - public void testCancelDuringInvoke() { - try { - Task task = createTask(InvokableBlockingInInvoke.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in invoke - awaitLatch.await(); + public void testCancelDuringInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.cancelExecution(); - assertTrue(task.getExecutionState() == ExecutionState.CANCELING || - task.getExecutionState() == ExecutionState.CANCELED); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - task.getExecutingThread().join(); + // wait till the task is in invoke + awaitLatch.await(); - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); + task.cancelExecution(); + assertTrue(task.getExecutionState() == ExecutionState.CANCELING || + task.getExecutionState() == ExecutionState.CANCELED); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.CANCELED, task, false); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); } @Test - public void testFailExternallyDuringInvoke() { - try { - Task task = createTask(InvokableBlockingInInvoke.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in regInOut - awaitLatch.await(); + public void testFailExternallyDuringInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.failExternally(new Exception("test")); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - task.getExecutingThread().join(); + // wait till the task is in invoke + awaitLatch.await(); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("test")); + task.failExternally(new Exception("test")); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); } @Test - public void testCanceledAfterExecutionFailedInInvoke() { - try { - Task task = createTask(InvokableWithExceptionInInvoke.class); + public void testCanceledAfterExecutionFailedInInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionInInvoke.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - task.run(); + task.run(); - // this should not overwrite the failure state - task.cancelExecution(); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("test")); - - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + // this should not overwrite the failure state + task.cancelExecution(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("test")); } @Test - public void testExecutionFailsAfterCanceling() { - try { - Task task = createTask(InvokableWithExceptionOnTrigger.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in invoke - awaitLatch.await(); + public void testExecutionFailsAfterCanceling() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionOnTrigger.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.cancelExecution(); - assertEquals(ExecutionState.CANCELING, task.getExecutionState()); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // this causes an exception - triggerLatch.trigger(); + // wait till the task is in invoke + awaitLatch.await(); - task.getExecutingThread().join(); + task.cancelExecution(); + assertEquals(ExecutionState.CANCELING, task.getExecutionState()); - // we should still be in state canceled - assertEquals(ExecutionState.CANCELED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertNull(task.getFailureCause()); + // this causes an exception + triggerLatch.trigger(); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.CANCELED, task, false); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + // we should still be in state canceled + assertEquals(ExecutionState.CANCELED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertNull(task.getFailureCause()); } @Test - public void testExecutionFailsAfterTaskMarkedFailed() { - try { - Task task = createTask(InvokableWithExceptionOnTrigger.class); - - // run the task asynchronous - task.startTaskThread(); - - // wait till the task is in invoke - awaitLatch.await(); + public void testExecutionFailsAfterTaskMarkedFailed() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithExceptionOnTrigger.class) + .setTaskManagerActions(taskManagerActions) + .build(); + + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); + + // run the task asynchronous + task.startTaskThread(); - task.failExternally(new Exception("external")); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - // this causes an exception - triggerLatch.trigger(); + // wait till the task is in invoke + awaitLatch.await(); - task.getExecutingThread().join(); + task.failExternally(new Exception("external")); + assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertEquals(ExecutionState.FAILED, task.getExecutionState()); - assertTrue(task.isCanceledOrFailed()); - assertTrue(task.getFailureCause().getMessage().contains("external")); + // this causes an exception + triggerLatch.trigger(); - validateTaskManagerStateChange(ExecutionState.RUNNING, task, false); - validateUnregisterTask(task.getExecutionId()); + task.getExecutingThread().join(); - validateListenerMessage(ExecutionState.RUNNING, task, false); - validateListenerMessage(ExecutionState.FAILED, task, true); - } - catch (Exception e) { - e.printStackTrace(); - fail(e.getMessage()); - } + assertEquals(ExecutionState.FAILED, task.getExecutionState()); + assertTrue(task.isCanceledOrFailed()); + assertTrue(task.getFailureCause().getMessage().contains("external")); } + @Test public void testCancelTaskException() throws Exception { - final Task task = createTask(InvokableWithCancelTaskExceptionInInvoke.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithCancelTaskExceptionInInvoke.class) + .build(); // Cause CancelTaskException. triggerLatch.trigger(); @@ -550,7 +509,9 @@ public void testCancelTaskException() throws Exception { @Test public void testCancelTaskExceptionAfterTaskMarkedFailed() throws Exception { - final Task task = createTask(InvokableWithCancelTaskExceptionInInvoke.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableWithCancelTaskExceptionInInvoke.class) + .build(); task.startTaskThread(); @@ -573,13 +534,15 @@ public void testCancelTaskExceptionAfterTaskMarkedFailed() throws Exception { @Test public void testOnPartitionStateUpdate() throws Exception { - IntermediateDataSetID resultId = new IntermediateDataSetID(); - ResultPartitionID partitionId = new ResultPartitionID(); + final IntermediateDataSetID resultId = new IntermediateDataSetID(); + final ResultPartitionID partitionId = new ResultPartitionID(); - SingleInputGate inputGate = mock(SingleInputGate.class); + final SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); - final Task task = createTask(InvokableBlockingInInvoke.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .build(); // Set the mock input gate setInputGate(task, inputGate); @@ -619,36 +582,36 @@ public void testOnPartitionStateUpdate() throws Exception { */ @Test public void testTriggerPartitionStateUpdate() throws Exception { - IntermediateDataSetID resultId = new IntermediateDataSetID(); - ResultPartitionID partitionId = new ResultPartitionID(); - - BlobCacheService blobService = createBlobCache(); - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); + final IntermediateDataSetID resultId = new IntermediateDataSetID(); + final ResultPartitionID partitionId = new ResultPartitionID(); - PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class); - TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + final PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class); + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); - NetworkEnvironment network = mock(NetworkEnvironment.class); + final ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + final NetworkEnvironment network = mock(NetworkEnvironment.class); when(network.getResultPartitionManager()).thenReturn(mock(ResultPartitionManager.class)); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); // Test all branches of trigger partition state check - { // Reset latches - createQueuesAndActors(); + setup(); // PartitionProducerDisposedException - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - - CompletableFuture promise = new CompletableFuture<>(); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); @@ -659,12 +622,18 @@ public void testTriggerPartitionStateUpdate() throws Exception { { // Reset latches - createQueuesAndActors(); + setup(); // Any other exception - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - - CompletableFuture promise = new CompletableFuture<>(); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final CompletableFuture promise = new CompletableFuture<>(); when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise); task.triggerPartitionProducerStateCheck(task.getJobID(), resultId, partitionId); @@ -676,11 +645,19 @@ public void testTriggerPartitionStateUpdate() throws Exception { { // Reset latches - createQueuesAndActors(); + setup(); // TimeoutException handled special => retry - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - SingleInputGate inputGate = mock(SingleInputGate.class); + // Any other exception + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); try { @@ -707,11 +684,18 @@ public void testTriggerPartitionStateUpdate() throws Exception { { // Reset latches - createQueuesAndActors(); + setup(); // Success - Task task = createTask(InvokableBlockingInInvoke.class, blobService, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); - SingleInputGate inputGate = mock(SingleInputGate.class); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setNetworkEnvironment(network) + .setConsumableNotifier(consumableNotifier) + .setPartitionProducerStateChecker(partitionChecker) + .setExecutor(Executors.directExecutor()) + .build(); + + final SingleInputGate inputGate = mock(SingleInputGate.class); when(inputGate.getConsumedResultId()).thenReturn(resultId); try { @@ -744,22 +728,28 @@ public void testTriggerPartitionStateUpdate() throws Exception { */ @Test public void testWatchDogInterruptsTask() throws Exception { - Configuration config = new Configuration(); + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + + // guard no fatal error + doThrow(new RuntimeException("Unexpected FatalError message")). + when(taskManagerActions).notifyFatalError(anyString(), any(Throwable.class)); + + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL.key(), 5); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT.key(), 60 * 1000); - Task task = createTask(InvokableBlockingInCancel.class, config); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInCancel.class) + .setTaskManagerConfig(config) + .setTaskManagerActions(taskManagerActions) + .build(); + task.startTaskThread(); awaitLatch.await(); task.cancelExecution(); task.getExecutingThread().join(); - - // No fatal error - for (Object msg : taskManagerMessages) { - assertFalse("Unexpected FatalError message", msg instanceof TaskManagerMessages.FatalError); - } } /** @@ -768,23 +758,29 @@ public void testWatchDogInterruptsTask() throws Exception { * This is resolved by the watch dog, no fatal error. */ @Test - public void testInterruptableSharedLockInInvokeAndCancel() throws Exception { - Configuration config = new Configuration(); + public void testInterruptibleSharedLockInInvokeAndCancel() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + + // guard no fatal error + doThrow(new RuntimeException("Unexpected FatalError message")). + when(taskManagerActions).notifyFatalError(anyString(), any(Throwable.class)); + + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50); - Task task = createTask(InvokableInterruptableSharedLockInInvokeAndCancel.class, config); + final Task task = new TaskBuilder() + .setInvokable(InvokableInterruptibleSharedLockInInvokeAndCancel.class) + .setTaskManagerConfig(config) + .setTaskManagerActions(taskManagerActions) + .build(); + task.startTaskThread(); awaitLatch.await(); task.cancelExecution(); task.getExecutingThread().join(); - - // No fatal error - for (Object msg : taskManagerMessages) { - assertFalse("Unexpected FatalError message", msg instanceof TaskManagerMessages.FatalError); - } } /** @@ -792,34 +788,39 @@ public void testInterruptableSharedLockInInvokeAndCancel() throws Exception { * resolved by a fatal error. */ @Test - public void testFatalErrorAfterUninterruptibleInvoke() throws Exception { - Configuration config = new Configuration(); + public void testFatalErrorAfterUnInterruptibleInvoke() throws Exception { + final TaskManagerActions taskManagerActions = mock(TaskManagerActions.class); + + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, 5); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, 50); - Task task = createTask(InvokableUninterruptibleBlockingInvoke.class, config); + final Task task = new TaskBuilder() + .setInvokable(InvokableUnInterruptibleBlockingInvoke.class) + .setTaskManagerConfig(config) + .setTaskManagerActions(taskManagerActions) + .build(); - try { - task.startTaskThread(); + final TaskExecutionState state = new TaskExecutionState( + task.getJobID(), + task.getExecutionId(), + ExecutionState.RUNNING); - awaitLatch.await(); + task.startTaskThread(); - task.cancelExecution(); + verify(taskManagerActions, timeout(TIMEOUT)).updateTaskExecutionState(eq(state)); - for (int i = 0; i < 10; i++) { - Object msg = taskManagerMessages.poll(1, TimeUnit.SECONDS); - if (msg instanceof TaskManagerMessages.FatalError) { - return; // success - } - } + awaitLatch.await(); - fail("Did not receive expected task manager message"); - } finally { - // Interrupt again to clean up Thread - cancelLatch.trigger(); - task.getExecutingThread().interrupt(); - task.getExecutingThread().join(); - } + task.cancelExecution(); + + verify(taskManagerActions, timeout(TIMEOUT)).notifyFatalError( + anyString(), any(Throwable.class)); + + // Interrupt again to clean up Thread + triggerLatch.trigger(); + task.getExecutingThread().interrupt(); + task.getExecutingThread().join(); } /** @@ -830,15 +831,19 @@ public void testTaskConfig() throws Exception { long interval = 28218123; long timeout = interval + 19292; - Configuration config = new Configuration(); + final Configuration config = new Configuration(); config.setLong(TaskManagerOptions.TASK_CANCELLATION_INTERVAL, interval); config.setLong(TaskManagerOptions.TASK_CANCELLATION_TIMEOUT, timeout); - ExecutionConfig executionConfig = new ExecutionConfig(); + final ExecutionConfig executionConfig = new ExecutionConfig(); executionConfig.setTaskCancellationInterval(interval + 1337); executionConfig.setTaskCancellationTimeout(timeout - 1337); - Task task = createTask(InvokableBlockingInInvoke.class, config, executionConfig); + final Task task = new TaskBuilder() + .setInvokable(InvokableBlockingInInvoke.class) + .setTaskManagerConfig(config) + .setExecutionConfig(executionConfig) + .build(); assertEquals(interval, task.getTaskCancellationInterval()); assertEquals(timeout, task.getTaskCancellationTimeout()); @@ -854,6 +859,8 @@ public void testTaskConfig() throws Exception { task.getExecutingThread().join(); } + // ------------------------------------------------------------------------ + // helper functions // ------------------------------------------------------------------------ private void setInputGate(Task task, SingleInputGate inputGate) { @@ -885,232 +892,159 @@ private void setState(Task task, ExecutionState state) { } } - /** - * Creates a {@link BlobCacheService} mock that is suitable to be used in the tests above. - * - * @return BlobCache mock with the bare minimum of implemented functions that work - */ - private BlobCacheService createBlobCache() { - return new BlobCacheService( - mock(PermanentBlobCache.class), - mock(TransientBlobCache.class)); - } + private final class TaskBuilder { + private Class invokable; + private TaskManagerActions taskManagerActions; + private LibraryCacheManager libraryCacheManager; + private ResultPartitionConsumableNotifier consumableNotifier; + private PartitionProducerStateChecker partitionProducerStateChecker; + private NetworkEnvironment networkEnvironment; + private Executor executor; + private Configuration taskManagerConfig; + private ExecutionConfig executionConfig; + private Collection requiredJarFileBlobKeys; - private Task createTask(Class invokable) throws IOException { - return createTask(invokable, new Configuration(), new ExecutionConfig()); - } + { + invokable = TestInvokableCorrect.class; + taskManagerActions = mock(TaskManagerActions.class); - private Task createTask(Class invokable, Configuration config) throws IOException { - BlobCacheService blobService = createBlobCache(); - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - return createTask(invokable, blobService, libCache, config, new ExecutionConfig()); - } + libraryCacheManager = mock(LibraryCacheManager.class); + when(libraryCacheManager.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - private Task createTask(Class invokable, Configuration config, ExecutionConfig execConfig) throws IOException { - BlobCacheService blobService = createBlobCache(); - LibraryCacheManager libCache = mock(LibraryCacheManager.class); - when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); - return createTask(invokable, blobService, libCache, config, execConfig); - } + consumableNotifier = mock(ResultPartitionConsumableNotifier.class); + partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache) throws IOException { + final ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); + final TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + networkEnvironment = mock(NetworkEnvironment.class); + when(networkEnvironment.getResultPartitionManager()).thenReturn(partitionManager); + when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) + .thenReturn(mock(TaskKvStateRegistry.class)); + when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); - return createTask(invokable, blobService, libCache, new Configuration(), new ExecutionConfig()); - } + executor = TestingUtils.defaultExecutor(); - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache, - Configuration config, - ExecutionConfig execConfig) throws IOException { - - ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); - ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); - PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); - TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); - Executor executor = mock(Executor.class); - NetworkEnvironment network = mock(NetworkEnvironment.class); - when(network.getResultPartitionManager()).thenReturn(partitionManager); - when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); - when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) - .thenReturn(mock(TaskKvStateRegistry.class)); - when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); + taskManagerConfig = new Configuration(); + executionConfig = new ExecutionConfig(); - return createTask(invokable, blobService, libCache, network, consumableNotifier, partitionProducerStateChecker, executor, config, execConfig); - } + requiredJarFileBlobKeys = Collections.emptyList(); + } - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache, - NetworkEnvironment networkEnvironment, - ResultPartitionConsumableNotifier consumableNotifier, - PartitionProducerStateChecker partitionProducerStateChecker, - Executor executor) throws IOException { - return createTask(invokable, blobService, libCache, networkEnvironment, consumableNotifier, partitionProducerStateChecker, executor, new Configuration(), new ExecutionConfig()); - } + TaskBuilder setInvokable(Class invokable) { + this.invokable = invokable; + return this; + } - private Task createTask( - Class invokable, - BlobCacheService blobService, - LibraryCacheManager libCache, - NetworkEnvironment networkEnvironment, - ResultPartitionConsumableNotifier consumableNotifier, - PartitionProducerStateChecker partitionProducerStateChecker, - Executor executor, - Configuration taskManagerConfig, - ExecutionConfig execConfig) throws IOException { - - JobID jobId = new JobID(); - JobVertexID jobVertexId = new JobVertexID(); - ExecutionAttemptID executionAttemptId = new ExecutionAttemptID(); - - InputSplitProvider inputSplitProvider = new TaskInputSplitProvider( - jobManagerGateway, - jobId, - jobVertexId, - executionAttemptId, - new FiniteDuration(60, TimeUnit.SECONDS)); - - CheckpointResponder checkpointResponder = new ActorGatewayCheckpointResponder(jobManagerGateway); - - SerializedValue serializedExecutionConfig = new SerializedValue<>(execConfig); - - JobInformation jobInformation = new JobInformation( - jobId, - "Test Job", - serializedExecutionConfig, - new Configuration(), - Collections.emptyList(), - Collections.emptyList()); - - TaskInformation taskInformation = new TaskInformation( - jobVertexId, - "Test Task", - 1, - 1, - invokable.getName(), - new Configuration()); - - TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); - when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); - - return new Task( - jobInformation, - taskInformation, - executionAttemptId, - new AllocationID(), - 0, - 0, - Collections.emptyList(), - Collections.emptyList(), - 0, - mock(MemoryManager.class), - mock(IOManager.class), - networkEnvironment, - mock(BroadcastVariableManager.class), - new TestTaskStateManager(), - taskManagerConnection, - inputSplitProvider, - checkpointResponder, - blobService, - libCache, - mock(FileCache.class), - new TestingTaskManagerRuntimeInfo(taskManagerConfig), - taskMetricGroup, - consumableNotifier, - partitionProducerStateChecker, - executor); - } + TaskBuilder setTaskManagerActions(TaskManagerActions taskManagerActions) { + this.taskManagerActions = taskManagerActions; + return this; + } - // ------------------------------------------------------------------------ - // Validation Methods - // ------------------------------------------------------------------------ + TaskBuilder setLibraryCacheManager(LibraryCacheManager libraryCacheManager) { + this.libraryCacheManager = libraryCacheManager; + return this; + } - private void validateUnregisterTask(ExecutionAttemptID id) { - try { - // we may have to wait for a bit to give the actors time to receive the message - // and put it into the queue - Object rawMessage = taskManagerMessages.take(); + TaskBuilder setConsumableNotifier(ResultPartitionConsumableNotifier consumableNotifier) { + this.consumableNotifier = consumableNotifier; + return this; + } - assertNotNull("There is no additional TaskManager message", rawMessage); - if (!(rawMessage instanceof TaskMessages.TaskInFinalState)) { - fail("TaskManager message is not 'UnregisterTask', but " + rawMessage.getClass()); - } + TaskBuilder setPartitionProducerStateChecker(PartitionProducerStateChecker partitionProducerStateChecker) { + this.partitionProducerStateChecker = partitionProducerStateChecker; + return this; + } + + TaskBuilder setNetworkEnvironment(NetworkEnvironment networkEnvironment) { + this.networkEnvironment = networkEnvironment; + return this; + } - TaskMessages.TaskInFinalState message = (TaskMessages.TaskInFinalState) rawMessage; - assertEquals(id, message.executionID()); + TaskBuilder setExecutor(Executor executor) { + this.executor = executor; + return this; } - catch (InterruptedException e) { - fail("interrupted"); + + TaskBuilder setTaskManagerConfig(Configuration taskManagerConfig) { + this.taskManagerConfig = taskManagerConfig; + return this; } - } - private void validateTaskManagerStateChange(ExecutionState state, Task task, boolean hasError) { - try { - // we may have to wait for a bit to give the actors time to receive the message - // and put it into the queue - Object rawMessage = taskManagerMessages.take(); + TaskBuilder setExecutionConfig(ExecutionConfig executionConfig) { + this.executionConfig = executionConfig; + return this; + } - assertNotNull("There is no additional TaskManager message", rawMessage); - if (!(rawMessage instanceof TaskMessages.UpdateTaskExecutionState)) { - fail("TaskManager message is not 'UpdateTaskExecutionState', but " + rawMessage.getClass()); - } + TaskBuilder setRequiredJarFileBlobKeys(Collection requiredJarFileBlobKeys) { + this.requiredJarFileBlobKeys = requiredJarFileBlobKeys; + return this; + } - TaskMessages.UpdateTaskExecutionState message = - (TaskMessages.UpdateTaskExecutionState) rawMessage; + private Task build() throws Exception { + final JobID jobId = new JobID(); + final JobVertexID jobVertexId = new JobVertexID(); + final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID(); - TaskExecutionState taskState = message.taskExecutionState(); + final SerializedValue serializedExecutionConfig = new SerializedValue<>(executionConfig); - assertEquals(task.getJobID(), taskState.getJobID()); - assertEquals(task.getExecutionId(), taskState.getID()); - assertEquals(state, taskState.getExecutionState()); + final JobInformation jobInformation = new JobInformation( + jobId, + "Test Job", + serializedExecutionConfig, + new Configuration(), + requiredJarFileBlobKeys, + Collections.emptyList()); - if (hasError) { - assertNotNull(taskState.getError(getClass().getClassLoader())); - } else { - assertNull(taskState.getError(getClass().getClassLoader())); - } - } - catch (InterruptedException e) { - fail("interrupted"); - } - } + final TaskInformation taskInformation = new TaskInformation( + jobVertexId, + "Test Task", + 1, + 1, + invokable.getName(), + new Configuration()); - private void validateListenerMessage(ExecutionState state, Task task, boolean hasError) { - try { - // we may have to wait for a bit to give the actors time to receive the message - // and put it into the queue - final TaskExecutionState taskState = listenerMessages.take(); - assertNotNull("There is no additional listener message", state); - - assertEquals(task.getJobID(), taskState.getJobID()); - assertEquals(task.getExecutionId(), taskState.getID()); - assertEquals(state, taskState.getExecutionState()); - - if (hasError) { - assertNotNull(taskState.getError(getClass().getClassLoader())); - } else { - assertNull(taskState.getError(getClass().getClassLoader())); - } - } - catch (InterruptedException e) { - fail("interrupted"); + final BlobCacheService blobCacheService = new BlobCacheService( + mock(PermanentBlobCache.class), + mock(TransientBlobCache.class)); + + final TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); + when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); + + return new Task( + jobInformation, + taskInformation, + executionAttemptId, + new AllocationID(), + 0, + 0, + Collections.emptyList(), + Collections.emptyList(), + 0, + mock(MemoryManager.class), + mock(IOManager.class), + networkEnvironment, + mock(BroadcastVariableManager.class), + new TestTaskStateManager(), + taskManagerActions, + new MockInputSplitProvider(), + new TestCheckpointResponder(), + blobCacheService, + libraryCacheManager, + mock(FileCache.class), + new TestingTaskManagerRuntimeInfo(taskManagerConfig), + taskMetricGroup, + consumableNotifier, + partitionProducerStateChecker, + executor); } } - // -------------------------------------------------------------------------------------------- - // Mock invokable code - // -------------------------------------------------------------------------------------------- - - /** Test task class. */ - public static final class TestInvokableCorrect extends AbstractInvokable { + // ------------------------------------------------------------------------ + // test task classes + // ------------------------------------------------------------------------ + private static final class TestInvokableCorrect extends AbstractInvokable { public TestInvokableCorrect(Environment environment) { super(environment); } @@ -1119,14 +1053,18 @@ public TestInvokableCorrect(Environment environment) { public void invoke() {} @Override - public void cancel() throws Exception { + public void cancel() { fail("This should not be called"); } } - /** Test task class. */ - public static final class InvokableWithExceptionInInvoke extends AbstractInvokable { + private abstract static class InvokableNonInstantiable extends AbstractInvokable { + public InvokableNonInstantiable(Environment environment) { + super(environment); + } + } + private static final class InvokableWithExceptionInInvoke extends AbstractInvokable { public InvokableWithExceptionInInvoke(Environment environment) { super(environment); } @@ -1137,44 +1075,21 @@ public void invoke() throws Exception { } } - /** Test task class. */ - public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable { - - public InvokableWithExceptionOnTrigger(Environment environment) { + private static final class FailingInvokableWithChainedException extends AbstractInvokable { + public FailingInvokableWithChainedException(Environment environment) { super(environment); } @Override public void invoke() { - awaitLatch.trigger(); - - // make sure that the interrupt call does not - // grab us out of the lock early - while (true) { - try { - triggerLatch.await(); - break; - } - catch (InterruptedException e) { - // fall through the loop - } - } - - throw new RuntimeException("test"); + throw new TestWrappedException(new IOException("test")); } - } - /** Test task class. */ - public abstract static class InvokableNonInstantiable extends AbstractInvokable { - - public InvokableNonInstantiable(Environment environment) { - super(environment); - } + @Override + public void cancel() {} } - /** Test task class. */ - public static final class InvokableBlockingInInvoke extends AbstractInvokable { - + private static final class InvokableBlockingInInvoke extends AbstractInvokable { public InvokableBlockingInInvoke(Environment environment) { super(environment); } @@ -1190,64 +1105,60 @@ public void invoke() throws Exception { } } - /** Test task class. */ - public static final class InvokableWithCancelTaskExceptionInInvoke extends AbstractInvokable { - - public InvokableWithCancelTaskExceptionInInvoke(Environment environment) { + public static final class InvokableWithExceptionOnTrigger extends AbstractInvokable { + public InvokableWithExceptionOnTrigger(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { + public void invoke() { awaitLatch.trigger(); - try { - triggerLatch.await(); + // make sure that the interrupt call does not + // grab us out of the lock early + while (true) { + try { + triggerLatch.await(); + break; + } + catch (InterruptedException e) { + // fall through the loop + } } - catch (Throwable ignored) {} - throw new CancelTaskException(); + throw new RuntimeException("test"); } } - /** Test task class. */ - public static final class InvokableInterruptableSharedLockInInvokeAndCancel extends AbstractInvokable { - - private final Object lock = new Object(); - - public InvokableInterruptableSharedLockInInvokeAndCancel(Environment environment) { + public static final class InvokableWithCancelTaskExceptionInInvoke extends AbstractInvokable { + public InvokableWithCancelTaskExceptionInInvoke(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { - synchronized (lock) { - awaitLatch.trigger(); - wait(); - } - } + public void invoke() { + awaitLatch.trigger(); - @Override - public void cancel() throws Exception { - synchronized (lock) { - cancelLatch.trigger(); + try { + triggerLatch.await(); } + catch (Throwable ignored) {} + + throw new CancelTaskException(); } } - /** Test task class. */ public static final class InvokableBlockingInCancel extends AbstractInvokable { - public InvokableBlockingInCancel(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { + public void invoke() { awaitLatch.trigger(); try { - cancelLatch.await(); + triggerLatch.await(); // await cancel synchronized (this) { wait(); } @@ -1261,51 +1172,57 @@ public void invoke() throws Exception { @Override public void cancel() throws Exception { synchronized (this) { - cancelLatch.trigger(); + triggerLatch.trigger(); wait(); } } } - /** Test task class. */ - public static final class InvokableUninterruptibleBlockingInvoke extends AbstractInvokable { + public static final class InvokableInterruptibleSharedLockInInvokeAndCancel extends AbstractInvokable { + private final Object lock = new Object(); - public InvokableUninterruptibleBlockingInvoke(Environment environment) { + public InvokableInterruptibleSharedLockInInvokeAndCancel(Environment environment) { super(environment); } @Override public void invoke() throws Exception { - while (!cancelLatch.isTriggered()) { - try { - synchronized (this) { - awaitLatch.trigger(); - wait(); - } - } catch (InterruptedException ignored) { - } + synchronized (lock) { + awaitLatch.trigger(); + wait(); } } @Override - public void cancel() throws Exception { + public void cancel() { + synchronized (lock) { + // do nothing but a placeholder + triggerLatch.trigger(); + } } } - /** Test task class. */ - public static final class FailingInvokableWithChainedException extends AbstractInvokable { - - public FailingInvokableWithChainedException(Environment environment) { + public static final class InvokableUnInterruptibleBlockingInvoke extends AbstractInvokable { + public InvokableUnInterruptibleBlockingInvoke(Environment environment) { super(environment); } @Override - public void invoke() throws Exception { - throw new TestWrappedException(new IOException("test")); + public void invoke() { + while (!triggerLatch.isTriggered()) { + try { + synchronized (this) { + awaitLatch.trigger(); + wait(); + } + } catch (InterruptedException ignored) { + } + } } @Override - public void cancel() {} + public void cancel() { + } } // ------------------------------------------------------------------------ @@ -1315,7 +1232,7 @@ public void cancel() {} private static class TestWrappedException extends WrappingRuntimeException { private static final long serialVersionUID = 1L; - public TestWrappedException(@Nonnull Throwable cause) { + TestWrappedException(@Nonnull Throwable cause) { super(cause); } }