From d0cf483e0fb21010ed996935ffb14d62b34ee8ed Mon Sep 17 00:00:00 2001 From: Nico Kruber Date: Tue, 29 Aug 2017 17:32:52 +0200 Subject: [PATCH 1/3] [FLINK-7746][network] move ResultPartitionWriter#writeBufferToAllChannels implementation up into ResultPartition --- .../api/writer/ResultPartitionWriter.java | 12 +---------- .../io/network/partition/ResultPartition.java | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java index 57c7098a85767..5756bc776de1c 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java @@ -82,17 +82,7 @@ public void writeBuffer(Buffer buffer, int targetChannel) throws IOException { * @throws IOException */ public void writeBufferToAllChannels(final Buffer eventBuffer) throws IOException { - try { - for (int targetChannel = 0; targetChannel < partition.getNumberOfSubpartitions(); targetChannel++) { - // retain the buffer so that it can be recycled by each channel of targetPartition - eventBuffer.retain(); - writeBuffer(eventBuffer, targetChannel); - } - } finally { - // we do not need to further retain the eventBuffer - // (it will be recycled after the last channel stops using it) - eventBuffer.recycle(); - } + partition.addToAllChannels(eventBuffer); } // ------------------------------------------------------------------------ diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java index 9b02e4d603c30..4d8e71c04b698 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java @@ -292,6 +292,27 @@ public void add(Buffer buffer, int subpartitionIndex) throws IOException { } } + /** + * Writes the given buffer to all available target channels. + * + *

The buffer is taken over and used for each of the channels. It will be recycled afterwards. + * + * @param buffer the buffer to write + */ + public void addToAllChannels(Buffer buffer) throws IOException { + try { + for (int targetChannel = 0; targetChannel < subpartitions.length; targetChannel++) { + // retain the buffer so that it can be recycled by each channel of targetPartition + buffer.retain(); + add(buffer, targetChannel); + } + } finally { + // we do not need to further retain the buffer + // (it will be recycled after the last channel stops using it) + buffer.recycle(); + } + } + /** * Finishes the result partition. * From cfea094aa214612eaafef14554bc68626a6ff948 Mon Sep 17 00:00:00 2001 From: Nico Kruber Date: Tue, 29 Aug 2017 18:24:00 +0200 Subject: [PATCH 2/3] [FLINK-7748][network] properly use the TaskEventDispatcher for subscribing to events Previously, the ResultPartitionWriter implemented the EventListener interface and was used for event registration, although event publishing was already handled via the TaskEventDispatcher. Now, we use the TaskEventDispatcher for both, event registration and publishing. It also adds the TaskEventDispatcher to the Environment information for a task to be able to work with it (only IterationHeadTask so far). --- .../flink/runtime/execution/Environment.java | 3 + .../io/network/NetworkEnvironment.java | 4 +- .../io/network/TaskEventDispatcher.java | 109 +++++++++--- .../api/writer/ResultPartitionWriter.java | 20 +-- .../iterative/task/IterationHeadTask.java | 8 +- .../taskmanager/RuntimeEnvironment.java | 10 ++ .../flink/runtime/taskmanager/Task.java | 2 +- .../io/network/TaskEventDispatcherTest.java | 168 ++++++++++++++++++ .../consumer/SingleInputGateTest.java | 3 +- .../operators/testutils/DummyEnvironment.java | 5 + .../operators/testutils/MockEnvironment.java | 8 + .../taskexecutor/TaskExecutorITCase.java | 3 +- .../taskexecutor/TaskExecutorTest.java | 3 + .../taskmanager/TaskAsyncCallTest.java | 3 + .../flink/runtime/taskmanager/TaskTest.java | 7 + .../runtime/util/JvmExitOnFatalErrorTest.java | 3 + .../tasks/BlockingCheckpointsTest.java | 3 + .../tasks/InterruptSensitiveRestoreTest.java | 3 + .../runtime/tasks/StreamMockEnvironment.java | 8 + .../tasks/StreamTaskTerminationTest.java | 3 + .../runtime/tasks/StreamTaskTest.java | 3 + 21 files changed, 326 insertions(+), 53 deletions(-) create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TaskEventDispatcherTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index 203ee8547cf42..ad66c5703313a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -29,6 +29,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -209,4 +210,6 @@ public interface Environment { InputGate getInputGate(int index); InputGate[] getAllInputGates(); + + TaskEventDispatcher getTaskEventDispatcher(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 4269af6401551..03628a596d069 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -201,7 +201,7 @@ public void registerTask(Task task) throws IOException { } // Register writer with task event dispatcher - taskEventDispatcher.registerWriterForIncomingTaskEvents(writer.getPartitionId(), writer); + taskEventDispatcher.registerPartition(writer.getPartitionId()); } // Setup the buffer pool for each buffer reader @@ -251,7 +251,7 @@ public void unregisterTask(Task task) { ResultPartitionWriter[] writers = task.getAllWriters(); if (writers != null) { for (ResultPartitionWriter writer : writers) { - taskEventDispatcher.unregisterWriter(writer); + taskEventDispatcher.unregisterPartition(writer.getPartitionId()); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java index 8816e32cb27e6..38c0f18ba003a 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/TaskEventDispatcher.java @@ -19,70 +19,125 @@ package org.apache.flink.runtime.io.network; import org.apache.flink.runtime.event.TaskEvent; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.api.TaskEventHandler; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel; import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel; import org.apache.flink.runtime.util.event.EventListener; -import org.apache.flink.shaded.guava18.com.google.common.collect.Maps; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.HashMap; import java.util.Map; +import static org.apache.flink.util.Preconditions.checkNotNull; + /** * The task event dispatcher dispatches events flowing backwards from a consuming task to the task * producing the consumed result. * - *

Backwards events only work for tasks, which produce pipelined results, where both the + *

Backwards events only work for tasks, which produce pipelined results, where both the * producing and consuming task are running at the same time. */ public class TaskEventDispatcher { + private static final Logger LOG = LoggerFactory.getLogger(TaskEventDispatcher.class); - private final Map registeredWriters = Maps.newHashMap(); + private final Map registeredHandlers = new HashMap<>(); - public void registerWriterForIncomingTaskEvents(ResultPartitionID partitionId, ResultPartitionWriter writer) { - synchronized (registeredWriters) { - if (registeredWriters.put(partitionId, writer) != null) { - throw new IllegalStateException("Already registered at task event dispatcher."); + /** + * Registers the given partition for incoming task events allowing calls to {@link + * #subscribeToEvent(ResultPartitionID, EventListener, Class)}. + * + * @param partitionId + * the partition ID + */ + public void registerPartition(ResultPartitionID partitionId) { + checkNotNull(partitionId); + + synchronized (registeredHandlers) { + LOG.debug("registering {}", partitionId); + if (registeredHandlers.put(partitionId, new TaskEventHandler()) != null) { + throw new IllegalStateException( + "Partition " + partitionId + " already registered at task event dispatcher."); } } } - public void unregisterWriter(ResultPartitionWriter writer) { - synchronized (registeredWriters) { - registeredWriters.remove(writer.getPartitionId()); + /** + * Removes the given partition from listening to incoming task events, thus forbidding calls to + * {@link #subscribeToEvent(ResultPartitionID, EventListener, Class)}. + * + * @param partitionId + * the partition ID + */ + public void unregisterPartition(ResultPartitionID partitionId) { + checkNotNull(partitionId); + + synchronized (registeredHandlers) { + LOG.debug("unregistering {}", partitionId); + // NOTE: tolerate un-registration of non-registered task (unregister is always called + // in the cleanup phase of a task even if it never came to the registration - see + // Task.java) + registeredHandlers.remove(partitionId); } } /** - * Publishes the event to the registered {@link ResultPartitionWriter} instances. - *

- * This method is either called directly from a {@link LocalInputChannel} or the network I/O + * Subscribes a listener to this dispatcher for events on a partition. + * + * @param partitionId + * ID of the partition to subscribe for (must be registered via {@link + * #registerPartition(ResultPartitionID)} first!) + * @param eventListener + * the event listener to subscribe + * @param eventType + * event type to subscribe to + */ + public void subscribeToEvent( + ResultPartitionID partitionId, EventListener eventListener, + Class eventType) { + checkNotNull(partitionId); + checkNotNull(eventListener); + checkNotNull(eventType); + + TaskEventHandler taskEventHandler = registeredHandlers.get(partitionId); + if (taskEventHandler == null) { + throw new IllegalStateException( + "Partition " + partitionId + " not registered at task event dispatcher."); + } + taskEventHandler.subscribe(eventListener, eventType); + } + + /** + * Publishes the event to the registered {@link EventListener} instances. + * + *

This method is either called directly from a {@link LocalInputChannel} or the network I/O * thread on behalf of a {@link RemoteInputChannel}. + * + * @return whether the event was published to a registered event handler (initiated via {@link + * #registerPartition(ResultPartitionID)}) or not */ public boolean publish(ResultPartitionID partitionId, TaskEvent event) { - EventListener listener = registeredWriters.get(partitionId); + checkNotNull(partitionId); + checkNotNull(event); - if (listener != null) { - listener.onEvent(event); + TaskEventHandler taskEventHandler = registeredHandlers.get(partitionId); + + if (taskEventHandler != null) { + taskEventHandler.publish(event); return true; } return false; } - public void clearAll() { - synchronized (registeredWriters) { - registeredWriters.clear(); - } - } - /** - * Returns the number of currently registered writers. + * Removes all registered event handlers. */ - int getNumberOfRegisteredWriters() { - synchronized (registeredWriters) { - return registeredWriters.size(); + public void clearAll() { + synchronized (registeredHandlers) { + registeredHandlers.clear(); } } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java index 5756bc776de1c..225c2320cf028 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java @@ -18,13 +18,10 @@ package org.apache.flink.runtime.io.network.api.writer; -import org.apache.flink.runtime.event.TaskEvent; -import org.apache.flink.runtime.io.network.api.TaskEventHandler; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.util.event.EventListener; import java.io.IOException; @@ -34,12 +31,10 @@ * The {@link ResultPartitionWriter} is the runtime API for producing results. It * supports two kinds of data to be sent: buffers and events. */ -public class ResultPartitionWriter implements EventListener { +public class ResultPartitionWriter { private final ResultPartition partition; - private final TaskEventHandler taskEventHandler = new TaskEventHandler(); - public ResultPartitionWriter(ResultPartition partition) { this.partition = partition; } @@ -84,17 +79,4 @@ public void writeBuffer(Buffer buffer, int targetChannel) throws IOException { public void writeBufferToAllChannels(final Buffer eventBuffer) throws IOException { partition.addToAllChannels(eventBuffer); } - - // ------------------------------------------------------------------------ - // Event handling - // ------------------------------------------------------------------------ - - public void subscribeToEvent(EventListener eventListener, Class eventType) { - taskEventHandler.subscribe(eventListener, eventType); - } - - @Override - public void onEvent(TaskEvent event) { - taskEventHandler.publish(event); - } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java index b673ba09194cd..977600e7385e0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java @@ -27,10 +27,12 @@ import org.apache.flink.core.memory.DataInputView; import org.apache.flink.core.memory.MemorySegment; import org.apache.flink.runtime.io.disk.InputViewIterator; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.writer.RecordWriter; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannel; import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannelBroker; import org.apache.flink.runtime.iterative.concurrent.Broker; @@ -223,8 +225,10 @@ private void readInitialSolutionSet(JoinHashMap solutionSet, MutableObjectIte private SuperstepBarrier initSuperstepBarrier() { SuperstepBarrier barrier = new SuperstepBarrier(getUserCodeClassLoader()); - this.toSync.subscribeToEvent(barrier, AllWorkersDoneEvent.class); - this.toSync.subscribeToEvent(barrier, TerminationEvent.class); + TaskEventDispatcher taskEventDispatcher = getEnvironment().getTaskEventDispatcher(); + ResultPartitionID partitionId = this.toSync.getPartitionId(); + taskEventDispatcher.subscribeToEvent(partitionId, barrier, AllWorkersDoneEvent.class); + taskEventDispatcher.subscribeToEvent(partitionId, barrier, TerminationEvent.class); return barrier; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index 92b58868d666f..60738f02f2af6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -69,6 +70,8 @@ public class RuntimeEnvironment implements Environment { private final ResultPartitionWriter[] writers; private final InputGate[] inputGates; + + private final TaskEventDispatcher taskEventDispatcher; private final CheckpointResponder checkpointResponder; @@ -101,6 +104,7 @@ public RuntimeEnvironment( Map> distCacheEntries, ResultPartitionWriter[] writers, InputGate[] inputGates, + TaskEventDispatcher taskEventDispatcher, CheckpointResponder checkpointResponder, TaskManagerRuntimeInfo taskManagerInfo, TaskMetricGroup metrics, @@ -123,6 +127,7 @@ public RuntimeEnvironment( this.distCacheEntries = checkNotNull(distCacheEntries); this.writers = checkNotNull(writers); this.inputGates = checkNotNull(inputGates); + this.taskEventDispatcher = checkNotNull(taskEventDispatcher); this.checkpointResponder = checkNotNull(checkpointResponder); this.taskManagerInfo = checkNotNull(taskManagerInfo); this.containingTask = containingTask; @@ -236,6 +241,11 @@ public InputGate[] getAllInputGates() { return inputGates; } + @Override + public TaskEventDispatcher getTaskEventDispatcher() { + return taskEventDispatcher; + } + @Override public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics) { acknowledgeCheckpoint(checkpointId, checkpointMetrics, null); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index d62896054d40c..8f1727429ca86 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -665,7 +665,7 @@ else if (current == ExecutionState.CANCELING) { jobConfiguration, taskConfiguration, userCodeClassLoader, memoryManager, ioManager, broadcastVariableManager, accumulatorRegistry, kvStateRegistry, inputSplitProvider, - distributedCacheEntries, writers, inputGates, + distributedCacheEntries, writers, inputGates, network.getTaskEventDispatcher(), checkpointResponder, taskManagerConfig, metrics, this); // let the task code create its readers and writers diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TaskEventDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TaskEventDispatcherTest.java new file mode 100644 index 0000000000000..979a9575a299b --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/TaskEventDispatcherTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.io.network; + +import org.apache.flink.runtime.event.TaskEvent; +import org.apache.flink.runtime.io.network.partition.ResultPartitionID; +import org.apache.flink.runtime.iterative.event.AllWorkersDoneEvent; +import org.apache.flink.runtime.iterative.event.TerminationEvent; +import org.apache.flink.runtime.util.event.EventListener; +import org.apache.flink.util.TestLogger; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Basic tests for {@link TaskEventDispatcher}. + */ +public class TaskEventDispatcherTest extends TestLogger { + + @Rule + public ExpectedException expectedException = ExpectedException.none(); + + @Test + public void registerPartitionTwice() throws Exception { + ResultPartitionID partitionId = new ResultPartitionID(); + TaskEventDispatcher ted = new TaskEventDispatcher(); + ted.registerPartition(partitionId); + + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("already registered at task event dispatcher"); + + ted.registerPartition(partitionId); + } + + @Test + public void subscribeToEventNotRegistered() throws Exception { + TaskEventDispatcher ted = new TaskEventDispatcher(); + + expectedException.expect(IllegalStateException.class); + expectedException.expectMessage("not registered at task event dispatcher"); + + //noinspection unchecked + ted.subscribeToEvent(new ResultPartitionID(), mock(EventListener.class), TaskEvent.class); + } + + /** + * Tests {@link TaskEventDispatcher#publish(ResultPartitionID, TaskEvent)} and {@link TaskEventDispatcher#subscribeToEvent(ResultPartitionID, EventListener, Class)} methods. + */ + @Test + public void publishSubscribe() throws Exception { + ResultPartitionID partitionId1 = new ResultPartitionID(); + ResultPartitionID partitionId2 = new ResultPartitionID(); + TaskEventDispatcher ted = new TaskEventDispatcher(); + + AllWorkersDoneEvent event1 = new AllWorkersDoneEvent(); + assertFalse(ted.publish(partitionId1, event1)); + + ted.registerPartition(partitionId1); + ted.registerPartition(partitionId2); + + // no event listener subscribed yet, but the event is forwarded to a TaskEventHandler + assertTrue(ted.publish(partitionId1, event1)); + + //noinspection unchecked + EventListener eventListener1a = mock(EventListener.class); + //noinspection unchecked + EventListener eventListener1b = mock(EventListener.class); + //noinspection unchecked + EventListener eventListener2 = mock(EventListener.class); + //noinspection unchecked + EventListener eventListener3 = mock(EventListener.class); + ted.subscribeToEvent(partitionId1, eventListener1a, AllWorkersDoneEvent.class); + ted.subscribeToEvent(partitionId2, eventListener1b, AllWorkersDoneEvent.class); + ted.subscribeToEvent(partitionId1, eventListener2, TaskEvent.class); + ted.subscribeToEvent(partitionId1, eventListener3, TerminationEvent.class); + + assertTrue(ted.publish(partitionId1, event1)); + verify(eventListener1a, times(1)).onEvent(event1); + verify(eventListener1b, times(0)).onEvent(any()); + verify(eventListener2, times(0)).onEvent(any()); + verify(eventListener3, times(0)).onEvent(any()); + + // publish another event, verify that only the right subscriber is called + TerminationEvent event2 = new TerminationEvent(); + assertTrue(ted.publish(partitionId1, event2)); + verify(eventListener1a, times(1)).onEvent(event1); + verify(eventListener1b, times(0)).onEvent(any()); + verify(eventListener2, times(0)).onEvent(any()); + verify(eventListener3, times(1)).onEvent(event2); + } + + @Test + public void unregisterPartition() throws Exception { + ResultPartitionID partitionId1 = new ResultPartitionID(); + ResultPartitionID partitionId2 = new ResultPartitionID(); + TaskEventDispatcher ted = new TaskEventDispatcher(); + + AllWorkersDoneEvent event = new AllWorkersDoneEvent(); + assertFalse(ted.publish(partitionId1, event)); + + ted.registerPartition(partitionId1); + ted.registerPartition(partitionId2); + + //noinspection unchecked + EventListener eventListener1a = mock(EventListener.class); + //noinspection unchecked + EventListener eventListener1b = mock(EventListener.class); + //noinspection unchecked + EventListener eventListener2 = mock(EventListener.class); + ted.subscribeToEvent(partitionId1, eventListener1a, AllWorkersDoneEvent.class); + ted.subscribeToEvent(partitionId2, eventListener1b, AllWorkersDoneEvent.class); + ted.subscribeToEvent(partitionId1, eventListener2, AllWorkersDoneEvent.class); + + ted.unregisterPartition(partitionId2); + + // publis something for partitionId1 triggering all according listeners + assertTrue(ted.publish(partitionId1, event)); + verify(eventListener1a, times(1)).onEvent(event); + verify(eventListener1b, times(0)).onEvent(any()); + verify(eventListener2, times(1)).onEvent(event); + + // now publish something for partitionId2 which should not trigger any listeners + assertFalse(ted.publish(partitionId2, event)); + verify(eventListener1a, times(1)).onEvent(event); + verify(eventListener1b, times(0)).onEvent(any()); + verify(eventListener2, times(1)).onEvent(event); + } + + @Test + public void clearAll() throws Exception { + ResultPartitionID partitionId = new ResultPartitionID(); + TaskEventDispatcher ted = new TaskEventDispatcher(); + ted.registerPartition(partitionId); + + //noinspection unchecked + EventListener eventListener1 = mock(EventListener.class); + ted.subscribeToEvent(partitionId, eventListener1, AllWorkersDoneEvent.class); + + ted.clearAll(); + + assertFalse(ted.publish(partitionId, new AllWorkersDoneEvent())); + verify(eventListener1, times(0)).onEvent(any()); + } +} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java index 737f17b8ca34c..5ab85fcf57b0b 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java @@ -322,9 +322,10 @@ public void testRequestBackoffConfiguration() throws Exception { int initialBackoff = 137; int maxBackoff = 1001; + TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); NetworkEnvironment netEnv = mock(NetworkEnvironment.class); when(netEnv.getResultPartitionManager()).thenReturn(new ResultPartitionManager()); - when(netEnv.getTaskEventDispatcher()).thenReturn(new TaskEventDispatcher()); + when(netEnv.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); when(netEnv.getPartitionRequestInitialBackoff()).thenReturn(initialBackoff); when(netEnv.getPartitionRequestMaxBackoff()).thenReturn(maxBackoff); when(netEnv.getConnectionManager()).thenReturn(new LocalConnectionManager()); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index 8ed06b2ef3682..c888e72cfa6e4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -189,4 +190,8 @@ public InputGate[] getAllInputGates() { return null; } + @Override + public TaskEventDispatcher getTaskEventDispatcher() { + return null; + } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index 7514cc4200d74..c8ca6541a1497 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -32,6 +32,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; @@ -100,6 +101,8 @@ public class MockEnvironment implements Environment { private final ClassLoader userCodeClassLoader; + private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + public MockEnvironment(String taskName, long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { this(taskName, memorySize, inputSplitProvider, bufferSize, new Configuration(), new ExecutionConfig()); } @@ -323,6 +326,11 @@ public InputGate[] getAllInputGates() { return gates; } + @Override + public TaskEventDispatcher getTaskEventDispatcher() { + return taskEventDispatcher; + } + @Override public JobVertexID getJobVertexId() { return new JobVertexID(new byte[16]); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java index e448cccbcaaf7..07b1eb9874ec2 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorITCase.java @@ -30,6 +30,7 @@ import org.apache.flink.runtime.heartbeat.HeartbeatServices; import org.apache.flink.runtime.highavailability.TestingHighAvailabilityServices; import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.network.MockNetworkEnvironment; import org.apache.flink.runtime.io.network.NetworkEnvironment; import org.apache.flink.runtime.jobmaster.JMTMRegistrationSuccess; import org.apache.flink.runtime.jobmaster.JobMasterGateway; @@ -119,7 +120,7 @@ public void testSlotAllocation() throws Exception { final TaskManagerLocation taskManagerLocation = new TaskManagerLocation(taskManagerResourceId, InetAddress.getLocalHost(), 1234); final MemoryManager memoryManager = mock(MemoryManager.class); final IOManager ioManager = mock(IOManager.class); - final NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); + final NetworkEnvironment networkEnvironment = MockNetworkEnvironment.getMock(); final TaskManagerMetricGroup taskManagerMetricGroup = mock(TaskManagerMetricGroup.class); final BroadcastVariableManager broadcastVariableManager = mock(BroadcastVariableManager.class); final FileCache fileCache = mock(FileCache.class); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java index 714644514f166..b883cf0b970fe 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskExecutorTest.java @@ -49,6 +49,7 @@ import org.apache.flink.runtime.instance.InstanceID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -715,9 +716,11 @@ public void testTaskSubmission() throws Exception { when(taskSlotTable.existsActiveSlot(eq(jobId), eq(allocationId))).thenReturn(true); when(taskSlotTable.addTask(any(Task.class))).thenReturn(true); + TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); final NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); when(networkEnvironment.createKvStateTaskRegistry(eq(jobId), eq(jobVertexId))).thenReturn(mock(TaskKvStateRegistry.class)); + when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); final TaskManagerMetricGroup taskManagerMetricGroup = mock(TaskManagerMetricGroup.class); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java index 392dc29bf3ca8..1f1199237e10a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; @@ -154,11 +155,13 @@ private static Task createTask() throws Exception { ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); Executor executor = mock(Executor.class); + TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); when(networkEnvironment.getResultPartitionManager()).thenReturn(partitionManager); when(networkEnvironment.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); + when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); TaskMetricGroup taskMetricGroup = mock(TaskMetricGroup.class); when(taskMetricGroup.getIOMetricGroup()).thenReturn(mock(TaskIOMetricGroup.class)); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java index d4cd0cfcf64d5..e4318b1b48fa3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; @@ -271,10 +272,12 @@ public void testExecutionFailsInNetworkRegistration() { ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); + TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); Executor executor = mock(Executor.class); NetworkEnvironment network = mock(NetworkEnvironment.class); when(network.getResultPartitionManager()).thenReturn(partitionManager); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); doThrow(new RuntimeException("buffers")).when(network).registerTask(any(Task.class)); Task task = createTask(TestInvokableCorrect.class, blobCache, libCache, network, consumableNotifier, partitionProducerStateChecker, executor); @@ -625,6 +628,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader()); PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class); + TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); NetworkEnvironment network = mock(NetworkEnvironment.class); @@ -632,6 +636,7 @@ public void testTriggerPartitionStateUpdate() throws Exception { when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); createTask(InvokableBlockingInInvoke.class, blobCache, libCache, network, consumableNotifier, partitionChecker, Executors.directExecutor()); @@ -917,12 +922,14 @@ private Task createTask( ResultPartitionManager partitionManager = mock(ResultPartitionManager.class); ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); + TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); Executor executor = mock(Executor.class); NetworkEnvironment network = mock(NetworkEnvironment.class); when(network.getResultPartitionManager()).thenReturn(partitionManager); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); return createTask(invokable, blobCache, libCache, network, consumableNotifier, partitionProducerStateChecker, executor, config, execConfig); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java index 229f1eb08b859..2ccc2469c086e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java @@ -41,6 +41,7 @@ import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; @@ -156,6 +157,8 @@ public static void main(String[] args) throws Exception { final NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); when(networkEnvironment.createKvStateTaskRegistry(jid, jobVertexId)).thenReturn(mock(TaskKvStateRegistry.class)); + TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); final TaskManagerRuntimeInfo tmInfo = TaskManagerConfiguration.fromConfiguration(taskManagerConfig); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java index 82642eab4dcc7..6868a93b9edb2 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/BlockingCheckpointsTest.java @@ -43,6 +43,7 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -136,8 +137,10 @@ private static Task createTask(Configuration taskConfig) throws IOException { taskConfig); TaskKvStateRegistry mockKvRegistry = mock(TaskKvStateRegistry.class); + TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); NetworkEnvironment network = mock(NetworkEnvironment.class); when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))).thenReturn(mockKvRegistry); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); return new Task( jobInformation, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java index f73499c13162b..71a37fed1b928 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -174,9 +175,11 @@ private static Task createTask( StreamStateHandle state, int mode) throws IOException { + TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); NetworkEnvironment networkEnvironment = mock(NetworkEnvironment.class); when(networkEnvironment.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); + when(networkEnvironment.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); Collection keyedStateFromBackend = Collections.emptyList(); Collection keyedStateFromStream = Collections.emptyList(); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 231f59e97fb2a..6b6506ab043db 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -34,6 +34,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; @@ -107,6 +108,8 @@ public class StreamMockEnvironment implements Environment { private volatile boolean wasFailedExternally = false; + private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class); + public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, ExecutionConfig executionConfig, long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) { this.taskInfo = new TaskInfo( @@ -303,6 +306,11 @@ public InputGate[] getAllInputGates() { return gates; } + @Override + public TaskEventDispatcher getTaskEventDispatcher() { + return taskEventDispatcher; + } + @Override public JobVertexID getJobVertexId() { return new JobVertexID(new byte[16]); diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java index 79e9583a8be72..0943763803444 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java @@ -40,6 +40,7 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -133,8 +134,10 @@ public void testConcurrentAsyncCheckpointCannotFailFinishedStreamTask() throws E final TaskManagerRuntimeInfo taskManagerRuntimeInfo = new TestingTaskManagerRuntimeInfo(); + TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); final NetworkEnvironment networkEnv = mock(NetworkEnvironment.class); when(networkEnv.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))).thenReturn(mock(TaskKvStateRegistry.class)); + when(networkEnv.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); final Task task = new Task( jobInformation, diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java index 811d70019c9fa..9e366c5c0c6ce 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java @@ -48,6 +48,7 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; +import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; @@ -866,12 +867,14 @@ public static Task createTask( ResultPartitionConsumableNotifier consumableNotifier = mock(ResultPartitionConsumableNotifier.class); PartitionProducerStateChecker partitionProducerStateChecker = mock(PartitionProducerStateChecker.class); Executor executor = mock(Executor.class); + TaskEventDispatcher taskEventDispatcher = new TaskEventDispatcher(); NetworkEnvironment network = mock(NetworkEnvironment.class); when(network.getResultPartitionManager()).thenReturn(partitionManager); when(network.getDefaultIOMode()).thenReturn(IOManager.IOMode.SYNC); when(network.createKvStateTaskRegistry(any(JobID.class), any(JobVertexID.class))) .thenReturn(mock(TaskKvStateRegistry.class)); + when(network.getTaskEventDispatcher()).thenReturn(taskEventDispatcher); JobInformation jobInformation = new JobInformation( new JobID(), From 88a0d0efc3e4daabbb6ac50ca0d5fa0481a333b6 Mon Sep 17 00:00:00 2001 From: Nico Kruber Date: Tue, 29 Aug 2017 18:53:48 +0200 Subject: [PATCH 3/3] [FLINK-7749][network] remove the ResultPartitionWriter wrapper Previous tasks, i.e. task event notification and buffer writing, are now handled completely by the TaskEventDispatcher and the ResultPartition, respectively. --- .../flink/runtime/execution/Environment.java | 6 +- .../io/network/NetworkEnvironment.java | 21 +---- .../io/network/api/writer/RecordWriter.java | 29 +++---- .../api/writer/ResultPartitionWriter.java | 82 ------------------ .../iterative/task/IterationHeadTask.java | 8 +- .../metrics/groups/TaskIOMetricGroup.java | 4 +- .../flink/runtime/operators/BatchTask.java | 2 +- .../taskmanager/RuntimeEnvironment.java | 16 ++-- .../flink/runtime/taskmanager/Task.java | 14 +--- .../io/network/NetworkEnvironmentTest.java | 7 +- .../network/api/writer/RecordWriterTest.java | 49 +++++------ .../api/writer/ResultPartitionWriterTest.java | 84 ------------------- .../PartialConsumePipelinedResultTest.java | 5 +- .../partition/ResultPartitionTest.java | 32 +++++++ .../SlotCountExceedingParallelismTest.java | 2 +- .../ScheduleOrUpdateConsumersTest.java | 4 +- .../operators/chaining/ChainTaskTest.java | 4 +- .../chaining/ChainedAllReduceDriverTest.java | 4 +- .../operators/testutils/DummyEnvironment.java | 6 +- .../operators/testutils/MockEnvironment.java | 18 ++-- ...TaskCancelAsyncProducerConsumerITCase.java | 6 +- .../runtime/taskmanager/TaskManagerTest.java | 2 +- .../flink/runtime/jobmanager/Tasks.scala | 4 +- .../runtime/io/StreamRecordWriter.java | 10 +-- .../runtime/tasks/OperatorChain.java | 4 +- .../streaming/runtime/tasks/StreamTask.java | 6 +- .../api/streamtask/MockRecordWriter.java | 2 +- .../runtime/io/StreamRecordWriterTest.java | 15 ++-- .../runtime/tasks/StreamMockEnvironment.java | 20 ++--- .../runtime/tasks/StreamTaskTestHarness.java | 3 - .../runtime/NetworkStackThroughputITCase.java | 4 +- 31 files changed, 154 insertions(+), 319 deletions(-) delete mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java delete mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriterTest.java diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java index ad66c5703313a..7c13adfa4b76b 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java @@ -30,7 +30,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; @@ -203,9 +203,9 @@ public interface Environment { // Fields relevant to the I/O system. Should go into Task // -------------------------------------------------------------------------------------------- - ResultPartitionWriter getWriter(int index); + ResultPartition getOutputPartition(int index); - ResultPartitionWriter[] getAllWriters(); + ResultPartition[] getAllOutputPartitions(); InputGate getInputGate(int index); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java index 03628a596d069..afb369b815f70 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NetworkEnvironment.java @@ -22,7 +22,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.disk.iomanager.IOManager.IOMode; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.partition.ResultPartition; @@ -160,12 +159,7 @@ public TaskKvStateRegistry createKvStateTaskRegistry(JobID jobId, JobVertexID jo // -------------------------------------------------------------------------------------------- public void registerTask(Task task) throws IOException { - final ResultPartition[] producedPartitions = task.getProducedPartitions(); - final ResultPartitionWriter[] writers = task.getAllWriters(); - - if (writers.length != producedPartitions.length) { - throw new IllegalStateException("Unequal number of writers and partitions."); - } + final ResultPartition[] producedPartitions = task.getAllOutputPartitions(); synchronized (lock) { if (isShutdown) { @@ -174,7 +168,6 @@ public void registerTask(Task task) throws IOException { for (int i = 0; i < producedPartitions.length; i++) { final ResultPartition partition = producedPartitions[i]; - final ResultPartitionWriter writer = writers[i]; // Buffer pool for the partition BufferPool bufferPool = null; @@ -201,7 +194,7 @@ public void registerTask(Task task) throws IOException { } // Register writer with task event dispatcher - taskEventDispatcher.registerPartition(writer.getPartitionId()); + taskEventDispatcher.registerPartition(partition.getPartitionId()); } // Setup the buffer pool for each buffer reader @@ -248,16 +241,10 @@ public void unregisterTask(Task task) { resultPartitionManager.releasePartitionsProducedBy(executionId, task.getFailureCause()); } - ResultPartitionWriter[] writers = task.getAllWriters(); - if (writers != null) { - for (ResultPartitionWriter writer : writers) { - taskEventDispatcher.unregisterPartition(writer.getPartitionId()); - } - } - - ResultPartition[] partitions = task.getProducedPartitions(); + ResultPartition[] partitions = task.getAllOutputPartitions(); if (partitions != null) { for (ResultPartition partition : partitions) { + taskEventDispatcher.unregisterPartition(partition.getPartitionId()); partition.destroyBufferPool(); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java index c698ff5d7b839..1ca95c4705cb0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java @@ -22,6 +22,7 @@ import org.apache.flink.metrics.Counter; import org.apache.flink.metrics.SimpleCounter; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup; import org.apache.flink.runtime.event.AbstractEvent; import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer; @@ -36,11 +37,11 @@ /** * A record-oriented runtime result writer. - *

- * The RecordWriter wraps the runtime's {@link ResultPartitionWriter} and takes care of + * + *

The RecordWriter wraps the runtime's {@link ResultPartition} and takes care of * serializing records into buffers. - *

- * Important: it is necessary to call {@link #flush()} after + * + *

Important: it is necessary to call {@link #flush()} after * all records have been written with {@link #emit(IOReadableWritable)}. This * ensures that all produced records are written to the output stream (incl. * partially filled ones). @@ -49,7 +50,7 @@ */ public class RecordWriter { - protected final ResultPartitionWriter targetPartition; + protected final ResultPartition targetPartition; private final ChannelSelector channelSelector; @@ -62,16 +63,16 @@ public class RecordWriter { private Counter numBytesOut = new SimpleCounter(); - public RecordWriter(ResultPartitionWriter writer) { - this(writer, new RoundRobinChannelSelector()); + public RecordWriter(ResultPartition partition) { + this(partition, new RoundRobinChannelSelector()); } @SuppressWarnings("unchecked") - public RecordWriter(ResultPartitionWriter writer, ChannelSelector channelSelector) { - this.targetPartition = writer; + public RecordWriter(ResultPartition partition, ChannelSelector channelSelector) { + this.targetPartition = partition; this.channelSelector = channelSelector; - this.numChannels = writer.getNumberOfOutputChannels(); + this.numChannels = partition.getNumberOfSubpartitions(); /** * The runtime exposes a channel abstraction for the produced results @@ -154,7 +155,7 @@ public void broadcastEvent(AbstractEvent event) throws IOException, InterruptedE // retain the buffer so that it can be recycled by each channel of targetPartition eventBuffer.retain(); - targetPartition.writeBuffer(eventBuffer, targetChannel); + targetPartition.add(eventBuffer, targetChannel); } } } finally { @@ -174,7 +175,7 @@ public void flush() throws IOException { if (buffer != null) { numBytesOut.inc(buffer.getSize()); - targetPartition.writeBuffer(buffer, targetChannel); + targetPartition.add(buffer, targetChannel); } } finally { serializer.clear(); @@ -209,7 +210,7 @@ public void setMetricGroup(TaskIOMetricGroup metrics) { } /** - * Writes the buffer to the {@link ResultPartitionWriter} and removes the + * Writes the buffer to the {@link ResultPartition} and removes the * buffer from the serializer state. * * Needs to be synchronized on the serializer! @@ -220,7 +221,7 @@ private void writeAndClearBuffer( RecordSerializer serializer) throws IOException { try { - targetPartition.writeBuffer(buffer, targetChannel); + targetPartition.add(buffer, targetChannel); } finally { serializer.clearCurrentBuffer(); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java deleted file mode 100644 index 225c2320cf028..0000000000000 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.io.network.api.writer; - -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.buffer.BufferProvider; -import org.apache.flink.runtime.io.network.partition.ResultPartition; -import org.apache.flink.runtime.io.network.partition.ResultPartitionID; - -import java.io.IOException; - -/** - * A buffer-oriented runtime result writer. - *

- * The {@link ResultPartitionWriter} is the runtime API for producing results. It - * supports two kinds of data to be sent: buffers and events. - */ -public class ResultPartitionWriter { - - private final ResultPartition partition; - - public ResultPartitionWriter(ResultPartition partition) { - this.partition = partition; - } - - // ------------------------------------------------------------------------ - // Attributes - // ------------------------------------------------------------------------ - - public ResultPartitionID getPartitionId() { - return partition.getPartitionId(); - } - - public BufferProvider getBufferProvider() { - return partition.getBufferProvider(); - } - - public int getNumberOfOutputChannels() { - return partition.getNumberOfSubpartitions(); - } - - public int getNumTargetKeyGroups() { - return partition.getNumTargetKeyGroups(); - } - - // ------------------------------------------------------------------------ - // Data processing - // ------------------------------------------------------------------------ - - public void writeBuffer(Buffer buffer, int targetChannel) throws IOException { - partition.add(buffer, targetChannel); - } - - /** - * Writes the given buffer to all available target channels. - * - * The buffer is taken over and used for each of the channels. - * It will be recycled afterwards. - * - * @param eventBuffer the buffer to write - * @throws IOException - */ - public void writeBufferToAllChannels(final Buffer eventBuffer) throws IOException { - partition.addToAllChannels(eventBuffer); - } -} diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java index 977600e7385e0..2bd4b4aa74946 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationHeadTask.java @@ -31,7 +31,7 @@ import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.writer.RecordWriter; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionID; import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannel; import org.apache.flink.runtime.iterative.concurrent.BlockingBackChannelBroker; @@ -94,7 +94,7 @@ public class IterationHeadTask extends AbstractIte private TypeSerializerFactory solutionTypeSerializer; - private ResultPartitionWriter toSync; + private ResultPartition toSync; private int feedbackDataInput; // workset or bulk partial solution @@ -129,7 +129,7 @@ protected void initOutputs() throws Exception { throw new Exception("Error: Inconsistent head task setup - wrong mapping of output gates."); } // now, we can instantiate the sync gate - this.toSync = getEnvironment().getWriter(syncGateIndex); + this.toSync = getEnvironment().getOutputPartition(syncGateIndex); } /** @@ -441,6 +441,6 @@ private void sendEventToSync(WorkerDoneEvent event) throws IOException, Interrup log.info(formatLogString("sending " + WorkerDoneEvent.class.getSimpleName() + " to sync")); } - this.toSync.writeBufferToAllChannels(EventSerializer.toBuffer(event)); + this.toSync.addToAllChannels(EventSerializer.toBuffer(event)); } } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java index e12ecd7d25c02..9c9a978f929db 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/metrics/groups/TaskIOMetricGroup.java @@ -158,7 +158,7 @@ public OutputBuffersGauge(Task task) { public Integer getValue() { int totalBuffers = 0; - for (ResultPartition producedPartition : task.getProducedPartitions()) { + for (ResultPartition producedPartition : task.getAllOutputPartitions()) { totalBuffers += producedPartition.getNumberOfQueuedBuffers(); } @@ -211,7 +211,7 @@ public Float getValue() { int usedBuffers = 0; int bufferPoolSize = 0; - for (ResultPartition resultPartition : task.getProducedPartitions()) { + for (ResultPartition resultPartition : task.getAllOutputPartitions()) { usedBuffers += resultPartition.getBufferPool().bestEffortGetNumOfUsedBuffers(); bufferPoolSize += resultPartition.getBufferPool().getNumBuffers(); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java index 87b0a76d0e8e4..00e9af3489ce0 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/BatchTask.java @@ -1238,7 +1238,7 @@ public static Collector getOutputCollector(AbstractInvokable task, TaskCo } final RecordWriter> recordWriter = - new RecordWriter>(task.getEnvironment().getWriter(outputOffset + i), oe); + new RecordWriter>(task.getEnvironment().getOutputPartition(outputOffset + i), oe); recordWriter.setMetricGroup(task.getEnvironment().getMetricGroup().getIOMetricGroup()); diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java index 60738f02f2af6..846636912c007 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java @@ -31,7 +31,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; @@ -68,7 +68,7 @@ public class RuntimeEnvironment implements Environment { private final Map> distCacheEntries; - private final ResultPartitionWriter[] writers; + private final ResultPartition[] outputPartitions; private final InputGate[] inputGates; private final TaskEventDispatcher taskEventDispatcher; @@ -102,7 +102,7 @@ public RuntimeEnvironment( TaskKvStateRegistry kvStateRegistry, InputSplitProvider splitProvider, Map> distCacheEntries, - ResultPartitionWriter[] writers, + ResultPartition[] outputPartitions, InputGate[] inputGates, TaskEventDispatcher taskEventDispatcher, CheckpointResponder checkpointResponder, @@ -125,7 +125,7 @@ public RuntimeEnvironment( this.kvStateRegistry = checkNotNull(kvStateRegistry); this.splitProvider = checkNotNull(splitProvider); this.distCacheEntries = checkNotNull(distCacheEntries); - this.writers = checkNotNull(writers); + this.outputPartitions = checkNotNull(outputPartitions); this.inputGates = checkNotNull(inputGates); this.taskEventDispatcher = checkNotNull(taskEventDispatcher); this.checkpointResponder = checkNotNull(checkpointResponder); @@ -222,13 +222,13 @@ public Map> getDistributedCacheEntries() { } @Override - public ResultPartitionWriter getWriter(int index) { - return writers[index]; + public ResultPartition getOutputPartition(int index) { + return outputPartitions[index]; } @Override - public ResultPartitionWriter[] getAllWriters() { - return writers; + public ResultPartition[] getAllOutputPartitions() { + return outputPartitions; } @Override diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java index 8f1727429ca86..14ca34786d887 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java @@ -51,7 +51,6 @@ import org.apache.flink.runtime.filecache.FileCache; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.NetworkEnvironment; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.netty.PartitionProducerStateChecker; import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; @@ -186,8 +185,6 @@ public class Task implements Runnable, TaskActions { private final ResultPartition[] producedPartitions; - private final ResultPartitionWriter[] writers; - private final SingleInputGate[] inputGates; private final Map inputGatesById; @@ -353,7 +350,6 @@ public Task( // Produced intermediate result partitions this.producedPartitions = new ResultPartition[resultPartitionDeploymentDescriptors.size()]; - this.writers = new ResultPartitionWriter[resultPartitionDeploymentDescriptors.size()]; int counter = 0; @@ -373,8 +369,6 @@ public Task( ioManager, desc.sendScheduleOrUpdateConsumersMessage()); - writers[counter] = new ResultPartitionWriter(producedPartitions[counter]); - ++counter; } @@ -438,15 +432,11 @@ public Configuration getTaskConfiguration() { return this.taskConfiguration; } - public ResultPartitionWriter[] getAllWriters() { - return writers; - } - public SingleInputGate[] getAllInputGates() { return inputGates; } - public ResultPartition[] getProducedPartitions() { + public ResultPartition[] getAllOutputPartitions() { return producedPartitions; } @@ -665,7 +655,7 @@ else if (current == ExecutionState.CANCELING) { jobConfiguration, taskConfiguration, userCodeClassLoader, memoryManager, ioManager, broadcastVariableManager, accumulatorRegistry, kvStateRegistry, inputSplitProvider, - distributedCacheEntries, writers, inputGates, network.getTaskEventDispatcher(), + distributedCacheEntries, producedPartitions, inputGates, network.getTaskEventDispatcher(), checkpointResponder, taskManagerConfig, metrics, this); // let the task code create its readers and writers diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java index b95669122d85d..56af8d0867ed1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NetworkEnvironmentTest.java @@ -21,7 +21,6 @@ import org.apache.flink.api.common.JobID; import org.apache.flink.core.memory.MemoryType; import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; import org.apache.flink.runtime.io.network.partition.ResultPartition; @@ -77,9 +76,6 @@ public void testRegisterTaskUsesBoundedBuffers() throws Exception { ResultPartition rp3 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 2); ResultPartition rp4 = createResultPartition(ResultPartitionType.PIPELINED_BOUNDED, 8); final ResultPartition[] resultPartitions = new ResultPartition[] {rp1, rp2, rp3, rp4}; - final ResultPartitionWriter[] resultPartitionWriters = new ResultPartitionWriter[] { - new ResultPartitionWriter(rp1), new ResultPartitionWriter(rp2), - new ResultPartitionWriter(rp3), new ResultPartitionWriter(rp4)}; // input gates final SingleInputGate[] inputGates = new SingleInputGate[] { @@ -90,8 +86,7 @@ public void testRegisterTaskUsesBoundedBuffers() throws Exception { // overall task to register Task task = mock(Task.class); - when(task.getProducedPartitions()).thenReturn(resultPartitions); - when(task.getAllWriters()).thenReturn(resultPartitionWriters); + when(task.getAllOutputPartitions()).thenReturn(resultPartitions); when(task.getAllInputGates()).thenReturn(inputGates); network.registerTask(task); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java index 98d4f65dfadee..beef90bb7138e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java @@ -35,6 +35,7 @@ import org.apache.flink.runtime.io.network.buffer.BufferPool; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent; import org.apache.flink.runtime.io.network.util.TestBufferFactory; import org.apache.flink.runtime.io.network.util.TestInfiniteBufferProvider; @@ -74,7 +75,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -@PrepareForTest({ResultPartitionWriter.class, EventSerializer.class}) +@PrepareForTest({EventSerializer.class}) @RunWith(PowerMockRunner.class) public class RecordWriterTest { @@ -120,7 +121,7 @@ public Buffer answer(InvocationOnMock invocation) throws Throwable { BufferProvider bufferProvider = mock(BufferProvider.class); when(bufferProvider.requestBufferBlocking()).thenAnswer(request); - ResultPartitionWriter partitionWriter = createResultPartitionWriter(bufferProvider); + ResultPartition partitionWriter = createResultPartition(bufferProvider); final RecordWriter recordWriter = new RecordWriter(partitionWriter); @@ -156,7 +157,7 @@ public Void call() throws Exception { // Verify that buffer have been requested, but only one has been written out. verify(bufferProvider, times(2)).requestBufferBlocking(); - verify(partitionWriter, times(1)).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, times(1)).add(any(Buffer.class), anyInt()); // Verify that the written out buffer has only been recycled once // (by the partition writer). @@ -179,9 +180,9 @@ public void testClearBuffersAfterExceptionInPartitionWriter() throws Exception { buffers = new NetworkBufferPool(1, 1024, MemoryType.HEAP); bufferPool = spy(buffers.createBufferPool(1, Integer.MAX_VALUE)); - ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class); + ResultPartition partitionWriter = mock(ResultPartition.class); when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferPool)); - when(partitionWriter.getNumberOfOutputChannels()).thenReturn(1); + when(partitionWriter.getNumberOfSubpartitions()).thenReturn(1); // Recycle buffer and throw Exception doAnswer(new Answer() { @@ -192,7 +193,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { throw new RuntimeException("Expected test Exception"); } - }).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt()); + }).when(partitionWriter).add(any(Buffer.class), anyInt()); RecordWriter recordWriter = new RecordWriter<>(partitionWriter); @@ -213,7 +214,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } // Verify expected methods have been called - verify(partitionWriter, times(1)).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, times(1)).add(any(Buffer.class), anyInt()); verify(bufferPool, times(1)).requestBufferBlocking(); try { @@ -228,7 +229,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } // Verify expected methods have been called - verify(partitionWriter, times(2)).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, times(2)).add(any(Buffer.class), anyInt()); verify(bufferPool, times(2)).requestBufferBlocking(); try { @@ -242,7 +243,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } // Verify expected methods have been called - verify(partitionWriter, times(3)).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, times(3)).add(any(Buffer.class), anyInt()); verify(bufferPool, times(3)).requestBufferBlocking(); try { @@ -257,7 +258,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } // Verify expected methods have been called - verify(partitionWriter, times(4)).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, times(4)).add(any(Buffer.class), anyInt()); verify(bufferPool, times(4)).requestBufferBlocking(); try { @@ -272,7 +273,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } // Verify expected methods have been called - verify(partitionWriter, times(5)).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, times(5)).add(any(Buffer.class), anyInt()); verify(bufferPool, times(5)).requestBufferBlocking(); } finally { @@ -293,14 +294,14 @@ public void testSerializerClearedAfterClearBuffers() throws Exception { final Buffer buffer = TestBufferFactory.createBuffer(16); - ResultPartitionWriter partitionWriter = createResultPartitionWriter( + ResultPartition partitionWriter = createResultPartition( createBufferProvider(buffer)); RecordWriter recordWriter = new RecordWriter(partitionWriter); // Fill a buffer, but don't write it out. recordWriter.emit(new IntValue(0)); - verify(partitionWriter, never()).writeBuffer(any(Buffer.class), anyInt()); + verify(partitionWriter, never()).add(any(Buffer.class), anyInt()); // Clear all buffers. recordWriter.clearBuffers(); @@ -326,7 +327,7 @@ public void testBroadcastEventNoRecords() throws Exception { BufferProvider bufferProvider = createBufferProvider(bufferSize); - ResultPartitionWriter partitionWriter = createCollectingPartitionWriter(queues, bufferProvider); + ResultPartition partitionWriter = createCollectingPartitionWriter(queues, bufferProvider); RecordWriter writer = new RecordWriter<>(partitionWriter, new RoundRobin()); CheckpointBarrier barrier = new CheckpointBarrier(Integer.MAX_VALUE + 919192L, Integer.MAX_VALUE + 18828228L, CheckpointOptions.forFullCheckpoint()); @@ -362,7 +363,7 @@ public void testBroadcastEventMixedRecords() throws Exception { BufferProvider bufferProvider = createBufferProvider(bufferSize); - ResultPartitionWriter partitionWriter = createCollectingPartitionWriter(queues, bufferProvider); + ResultPartition partitionWriter = createCollectingPartitionWriter(queues, bufferProvider); RecordWriter writer = new RecordWriter<>(partitionWriter, new RoundRobin()); CheckpointBarrier barrier = new CheckpointBarrier(Integer.MAX_VALUE + 1292L, Integer.MAX_VALUE + 199L, CheckpointOptions.forFullCheckpoint()); @@ -420,7 +421,7 @@ public void testBroadcastEventBufferReferenceCounting() throws Exception { ArrayDeque[] queues = new ArrayDeque[]{new ArrayDeque(), new ArrayDeque()}; - ResultPartitionWriter partition = + ResultPartition partition = createCollectingPartitionWriter(queues, new TestInfiniteBufferProvider()); RecordWriter writer = new RecordWriter<>(partition); @@ -446,15 +447,15 @@ public void testBroadcastEventBufferReferenceCounting() throws Exception { * the mocking. Ideally, we will refactor all of this mess in order to make * our lives easier and test it better. */ - private ResultPartitionWriter createCollectingPartitionWriter( + private ResultPartition createCollectingPartitionWriter( final Queue[] queues, BufferProvider bufferProvider) throws IOException { int numChannels = queues.length; - ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class); + ResultPartition partitionWriter = mock(ResultPartition.class); when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider)); - when(partitionWriter.getNumberOfOutputChannels()).thenReturn(numChannels); + when(partitionWriter.getNumberOfSubpartitions()).thenReturn(numChannels); doAnswer(new Answer() { @Override @@ -472,7 +473,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { } return null; } - }).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt()); + }).when(partitionWriter).add(any(Buffer.class), anyInt()); return partitionWriter; } @@ -507,12 +508,12 @@ private BufferProvider createBufferProvider(Buffer... buffers) return bufferProvider; } - private ResultPartitionWriter createResultPartitionWriter(BufferProvider bufferProvider) + private ResultPartition createResultPartition(BufferProvider bufferProvider) throws IOException { - ResultPartitionWriter partitionWriter = mock(ResultPartitionWriter.class); + ResultPartition partitionWriter = mock(ResultPartition.class); when(partitionWriter.getBufferProvider()).thenReturn(checkNotNull(bufferProvider)); - when(partitionWriter.getNumberOfOutputChannels()).thenReturn(1); + when(partitionWriter.getNumberOfSubpartitions()).thenReturn(1); // Recycle each written buffer. doAnswer(new Answer() { @@ -522,7 +523,7 @@ public Void answer(InvocationOnMock invocation) throws Throwable { return null; } - }).when(partitionWriter).writeBuffer(any(Buffer.class), anyInt()); + }).when(partitionWriter).add(any(Buffer.class), anyInt()); return partitionWriter; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriterTest.java deleted file mode 100644 index 2e5816d7ca8a8..0000000000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriterTest.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.runtime.io.network.api.writer; - -import org.apache.flink.api.common.JobID; -import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent; -import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; -import org.apache.flink.runtime.io.network.buffer.Buffer; -import org.apache.flink.runtime.io.network.partition.ResultPartition; -import org.apache.flink.runtime.io.network.partition.ResultPartitionConsumableNotifier; -import org.apache.flink.runtime.io.network.partition.ResultPartitionID; -import org.apache.flink.runtime.io.network.partition.ResultPartitionManager; -import org.apache.flink.runtime.io.network.partition.ResultPartitionType; -import org.apache.flink.runtime.taskmanager.TaskActions; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.powermock.core.classloader.annotations.PrepareForTest; -import org.powermock.modules.junit4.PowerMockRunner; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; - -@PrepareForTest({ResultPartitionWriter.class}) -@RunWith(PowerMockRunner.class) -public class ResultPartitionWriterTest { - - // --------------------------------------------------------------------------------------------- - // Resource release tests - // --------------------------------------------------------------------------------------------- - - /** - * Tests that event buffers are properly recycled when broadcasting events - * to multiple channels. - * - * @throws Exception - */ - @Test - public void testWriteBufferToAllChannelsReferenceCounting() throws Exception { - Buffer buffer = EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE); - - ResultPartition partition = new ResultPartition( - "TestTask", - mock(TaskActions.class), - new JobID(), - new ResultPartitionID(), - ResultPartitionType.PIPELINED, - 2, - 2, - mock(ResultPartitionManager.class), - mock(ResultPartitionConsumableNotifier.class), - mock(IOManager.class), - false); - ResultPartitionWriter partitionWriter = - new ResultPartitionWriter( - partition); - - partitionWriter.writeBufferToAllChannels(buffer); - - // Verify added to all queues, i.e. two buffers in total - assertEquals(2, partition.getTotalNumberOfBuffers()); - // release the buffers in the partition - partition.release(); - - assertTrue(buffer.isRecycled()); - } -} diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java index 0346e483d4d38..cecefb789b541 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/PartialConsumePipelinedResultTest.java @@ -22,7 +22,6 @@ import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.configuration.TaskManagerOptions; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.DistributionPattern; @@ -110,11 +109,11 @@ public static class SlowBufferSender extends AbstractInvokable { @Override public void invoke() throws Exception { - final ResultPartitionWriter writer = getEnvironment().getWriter(0); + final ResultPartition writer = getEnvironment().getOutputPartition(0); for (int i = 0; i < 8; i++) { final Buffer buffer = writer.getBufferProvider().requestBufferBlocking(); - writer.writeBuffer(buffer, 0); + writer.add(buffer, 0); Thread.sleep(50); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java index 0cd359197efcb..9e0664334a9be 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java @@ -26,6 +26,8 @@ import org.junit.Assert; import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -121,6 +123,36 @@ public void testAddOnReleasedBlockingPartition() throws Exception { testAddOnReleasedPartition(ResultPartitionType.BLOCKING); } + /** + * Tests that buffers are properly recycled after {@link ResultPartition#addToAllChannels(Buffer)}. + */ + @Test + public void testWriteBufferToAllChannelsReferenceCounting() throws Exception { + Buffer buffer = TestBufferFactory.createBuffer(TestBufferFactory.BUFFER_SIZE); + + ResultPartition partition = new ResultPartition( + "TestTask", + mock(TaskActions.class), + new JobID(), + new ResultPartitionID(), + ResultPartitionType.PIPELINED, + 2, + 2, + mock(ResultPartitionManager.class), + mock(ResultPartitionConsumableNotifier.class), + mock(IOManager.class), + false); + + partition.addToAllChannels(buffer); + + // Verify added to all queues, i.e. two buffers in total + assertEquals(2, partition.getTotalNumberOfBuffers()); + // release the buffers in the partition + partition.release(); + + assertTrue(buffer.isRecycled()); + } + /** * Tests {@link ResultPartition#add} on a partition which has already been released. * diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/SlotCountExceedingParallelismTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/SlotCountExceedingParallelismTest.java index 49b11b587ff07..8f07e9498141e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/SlotCountExceedingParallelismTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/SlotCountExceedingParallelismTest.java @@ -124,7 +124,7 @@ public static class RoundRobinSubtaskIndexSender extends AbstractInvokable { @Override public void invoke() throws Exception { - RecordWriter writer = new RecordWriter<>(getEnvironment().getWriter(0)); + RecordWriter writer = new RecordWriter<>(getEnvironment().getOutputPartition(0)); final int numberOfTimesToSend = getTaskConfiguration().getInteger(CONFIG_KEY, 0); final IntValue subtaskIndex = new IntValue( diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java index 9c781ec101c18..bec2bc84bfe8d 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/ScheduleOrUpdateConsumersTest.java @@ -132,10 +132,10 @@ public void invoke() throws Exception { // The order of intermediate result creation in the job graph specifies which produced // result partition is pipelined/blocking. final RecordWriter pipelinedWriter = - new RecordWriter<>(getEnvironment().getWriter(0)); + new RecordWriter<>(getEnvironment().getOutputPartition(0)); final RecordWriter blockingWriter = - new RecordWriter<>(getEnvironment().getWriter(1)); + new RecordWriter<>(getEnvironment().getOutputPartition(1)); writers.add(pipelinedWriter); writers.add(blockingWriter); diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainTaskTest.java index fb8ed684a3215..b54312ba6e3b4 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainTaskTest.java @@ -25,11 +25,11 @@ import org.apache.flink.api.common.functions.GroupCombineFunction; import org.apache.flink.api.common.functions.GroupReduceFunction; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory; import org.apache.flink.runtime.testutils.recordutils.RecordSerializerFactory; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.operators.DriverStrategy; import org.apache.flink.runtime.operators.BatchTask; import org.apache.flink.runtime.operators.FlatMapDriver; @@ -51,7 +51,7 @@ import org.powermock.modules.junit4.PowerMockRunner; @RunWith(PowerMockRunner.class) -@PrepareForTest({Task.class, ResultPartitionWriter.class}) +@PrepareForTest({Task.class, ResultPartition.class}) @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"}) public class ChainTaskTest extends TaskTestBase { diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java index 43ec8b8f0dbf1..5f7e5fdbde5a0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/chaining/ChainedAllReduceDriverTest.java @@ -22,7 +22,7 @@ import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.configuration.Configuration; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.operators.BatchTask; import org.apache.flink.runtime.operators.DriverStrategy; import org.apache.flink.runtime.operators.FlatMapDriver; @@ -46,7 +46,7 @@ import java.util.List; @RunWith(PowerMockRunner.class) -@PrepareForTest({Task.class, ResultPartitionWriter.class}) +@PrepareForTest({Task.class, ResultPartition.class}) public class ChainedAllReduceDriverTest extends TaskTestBase { private static final int MEMORY_MANAGER_SIZE = 1024 * 1024 * 3; diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java index c888e72cfa6e4..e2e0de3e0b0fd 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java @@ -31,7 +31,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; import org.apache.flink.runtime.io.disk.iomanager.IOManager; import org.apache.flink.runtime.io.network.TaskEventDispatcher; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; @@ -171,12 +171,12 @@ public void failExternally(Throwable cause) { } @Override - public ResultPartitionWriter getWriter(int index) { + public ResultPartition getOutputPartition(int index) { return null; } @Override - public ResultPartitionWriter[] getAllWriters() { + public ResultPartition[] getAllOutputPartitions() { return null; } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java index c8ca6541a1497..342f539d7b0c0 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java @@ -35,10 +35,10 @@ import org.apache.flink.runtime.io.network.TaskEventDispatcher; import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.io.network.partition.consumer.IteratorWrappingTestSingleInputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; @@ -87,7 +87,7 @@ public class MockEnvironment implements Environment { private final List inputs; - private final List outputs; + private final List outputs; private final JobID jobID = new JobID(); @@ -159,7 +159,7 @@ public MockEnvironment( this.jobConfiguration = new Configuration(); this.taskConfiguration = taskConfiguration; this.inputs = new LinkedList(); - this.outputs = new LinkedList(); + this.outputs = new LinkedList<>(); this.memManager = new MemoryManager(memorySize, 1); this.ioManager = new IOManagerAsync(); @@ -203,8 +203,8 @@ public Buffer answer(InvocationOnMock invocationOnMock) throws Throwable { } }); - ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class); - when(mockWriter.getNumberOfOutputChannels()).thenReturn(1); + ResultPartition mockWriter = mock(ResultPartition.class); + when(mockWriter.getNumberOfSubpartitions()).thenReturn(1); when(mockWriter.getBufferProvider()).thenReturn(mockBufferProvider); final Record record = new Record(); @@ -234,7 +234,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { return null; } - }).when(mockWriter).writeBuffer(any(Buffer.class), anyInt()); + }).when(mockWriter).add(any(Buffer.class), anyInt()); outputs.add(mockWriter); } @@ -305,13 +305,13 @@ public Map> getDistributedCacheEntries() { } @Override - public ResultPartitionWriter getWriter(int index) { + public ResultPartition getOutputPartition(int index) { return outputs.get(index); } @Override - public ResultPartitionWriter[] getAllWriters() { - return outputs.toArray(new ResultPartitionWriter[outputs.size()]); + public ResultPartition[] getAllOutputPartitions() { + return outputs.toArray(new ResultPartition[outputs.size()]); } @Override diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskCancelAsyncProducerConsumerITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskCancelAsyncProducerConsumerITCase.java index 69f1a4998cf5e..01ef12ae927d3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskCancelAsyncProducerConsumerITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskCancelAsyncProducerConsumerITCase.java @@ -23,7 +23,7 @@ import org.apache.flink.configuration.TaskManagerOptions; import org.apache.flink.runtime.instance.ActorGateway; import org.apache.flink.runtime.io.network.api.writer.RecordWriter; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.ResultPartitionType; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.DistributionPattern; @@ -194,7 +194,7 @@ public static class AsyncProducer extends AbstractInvokable { @Override public void invoke() throws Exception { - Thread producer = new ProducerThread(getEnvironment().getWriter(0)); + Thread producer = new ProducerThread(getEnvironment().getOutputPartition(0)); // Publish the async producer for the main test Thread ASYNC_PRODUCER_THREAD = producer; @@ -218,7 +218,7 @@ private static class ProducerThread extends Thread { private final RecordWriter recordWriter; - public ProducerThread(ResultPartitionWriter partitionWriter) { + public ProducerThread(ResultPartition partitionWriter) { this.recordWriter = new RecordWriter<>(partitionWriter); } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java index fdae2516e5ea6..4239f67a519f3 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskManagerTest.java @@ -2074,7 +2074,7 @@ public static final class TestInvokableRecordCancel extends AbstractInvokable { @Override public void invoke() throws Exception { final Object o = new Object(); - RecordWriter recordWriter = new RecordWriter<>(getEnvironment().getWriter(0)); + RecordWriter recordWriter = new RecordWriter<>(getEnvironment().getOutputPartition(0)); for (int i = 0; i < 1024; i++) { recordWriter.emit(new IntValue(42)); diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala index fabd66b106dca..cae6d99565d9d 100644 --- a/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala +++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/jobmanager/Tasks.scala @@ -29,7 +29,7 @@ object Tasks { class Sender extends AbstractInvokable{ override def invoke(): Unit = { - val writer = new RecordWriter[IntValue](getEnvironment.getWriter(0)) + val writer = new RecordWriter[IntValue](getEnvironment.getOutputPartition(0)) try{ writer.emit(new IntValue(42)) @@ -49,7 +49,7 @@ object Tasks { classOf[IntValue], getEnvironment.getTaskManagerInfo.getTmpDirectories) - val writer = new RecordWriter[IntValue](getEnvironment.getWriter(0)) + val writer = new RecordWriter[IntValue](getEnvironment.getOutputPartition(0)) try { while (true) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamRecordWriter.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamRecordWriter.java index 6775bc4a9ba33..34031f5e23993 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamRecordWriter.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/StreamRecordWriter.java @@ -21,7 +21,7 @@ import org.apache.flink.core.io.IOReadableWritable; import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; import org.apache.flink.runtime.io.network.api.writer.RecordWriter; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import java.io.IOException; @@ -49,14 +49,14 @@ public class StreamRecordWriter extends RecordWrit /** The exception encountered in the flushing thread. */ private Throwable flusherException; - public StreamRecordWriter(ResultPartitionWriter writer, ChannelSelector channelSelector, long timeout) { - this(writer, channelSelector, timeout, null); + public StreamRecordWriter(ResultPartition partition, ChannelSelector channelSelector, long timeout) { + this(partition, channelSelector, timeout, null); } - public StreamRecordWriter(ResultPartitionWriter writer, ChannelSelector channelSelector, + public StreamRecordWriter(ResultPartition partition, ChannelSelector channelSelector, long timeout, String taskName) { - super(writer, channelSelector); + super(partition, channelSelector); checkArgument(timeout >= -1); diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java index 38279822ec2d4..b73732dcf7fe0 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java @@ -25,7 +25,7 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.metrics.groups.OperatorMetricGroup; import org.apache.flink.runtime.plugable.SerializationDelegate; import org.apache.flink.streaming.api.collector.selector.CopyingDirectedOutput; @@ -370,7 +370,7 @@ private RecordWriterOutput createStreamOutput( LOG.debug("Using partitioner {} for output {} of task ", outputPartitioner, outputIndex, taskName); - ResultPartitionWriter bufferWriter = taskEnvironment.getWriter(outputIndex); + ResultPartition bufferWriter = taskEnvironment.getOutputPartition(outputIndex); // we initialize the partitioner here with the number of key groups (aka max. parallelism) if (outputPartitioner instanceof ConfigurableStreamPartitioner) { diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java index 631cdfcd7a8c7..c1127557862f2 100644 --- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java +++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java @@ -32,7 +32,7 @@ import org.apache.flink.runtime.execution.Environment; import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.jobgraph.OperatorID; import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; import org.apache.flink.runtime.jobgraph.tasks.StatefulTask; @@ -599,9 +599,9 @@ private boolean performCheckpoint( final CancelCheckpointMarker message = new CancelCheckpointMarker(checkpointMetaData.getCheckpointId()); Exception exception = null; - for (ResultPartitionWriter output : getEnvironment().getAllWriters()) { + for (ResultPartition output : getEnvironment().getAllOutputPartitions()) { try { - output.writeBufferToAllChannels(EventSerializer.toBuffer(message)); + output.addToAllChannels(EventSerializer.toBuffer(message)); } catch (Exception e) { exception = ExceptionUtils.firstOrSuppressed( new Exception("Could not send cancel checkpoint marker to downstream tasks.", e), diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/streamtask/MockRecordWriter.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/streamtask/MockRecordWriter.java index 781a2167e2efd..b20060e4d5acf 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/streamtask/MockRecordWriter.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/streamtask/MockRecordWriter.java @@ -33,7 +33,7 @@ public class MockRecordWriter extends RecordWriter emittedRecords; public MockRecordWriter(DataSourceTask inputBase, Class>> outputClass) { - super(inputBase.getEnvironment().getWriter(0)); + super(inputBase.getEnvironment().getOutputPartition(0)); } public boolean initList() { diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java index d11413927bd93..565e1c88bdf53 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamRecordWriterTest.java @@ -21,11 +21,11 @@ import org.apache.flink.core.io.IOReadableWritable; import org.apache.flink.core.memory.MemorySegmentFactory; import org.apache.flink.runtime.io.network.api.writer.ChannelSelector; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.api.writer.RoundRobinChannelSelector; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.types.LongValue; import org.junit.Test; @@ -41,8 +41,7 @@ import static org.mockito.Mockito.when; /** - * This test uses the PowerMockRunner runner to work around the fact that the - * {@link ResultPartitionWriter} class is final. + * Tests for {@link StreamRecordWriter}. */ public class StreamRecordWriterTest { @@ -54,7 +53,7 @@ public class StreamRecordWriterTest { public void testPropagateAsyncFlushError() { FailingWriter testWriter = null; try { - ResultPartitionWriter mockResultPartitionWriter = getMockWriter(5); + ResultPartition mockResultPartitionWriter = getMockWriter(5); // test writer that flushes every 5ms and fails after 3 flushes testWriter = new FailingWriter(mockResultPartitionWriter, @@ -86,7 +85,7 @@ public void testPropagateAsyncFlushError() { } } - private static ResultPartitionWriter getMockWriter(int numPartitions) throws Exception { + private static ResultPartition getMockWriter(int numPartitions) throws Exception { BufferProvider mockProvider = mock(BufferProvider.class); when(mockProvider.requestBufferBlocking()).thenAnswer(new Answer() { @Override @@ -97,9 +96,9 @@ public Buffer answer(InvocationOnMock invocation) { } }); - ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class); + ResultPartition mockWriter = mock(ResultPartition.class); when(mockWriter.getBufferProvider()).thenReturn(mockProvider); - when(mockWriter.getNumberOfOutputChannels()).thenReturn(numPartitions); + when(mockWriter.getNumberOfSubpartitions()).thenReturn(numPartitions); return mockWriter; } @@ -110,7 +109,7 @@ private static class FailingWriter extends StreamR private int flushesBeforeException; - private FailingWriter(ResultPartitionWriter writer, ChannelSelector channelSelector, + private FailingWriter(ResultPartition writer, ChannelSelector channelSelector, long timeout, int flushesBeforeException) { super(writer, channelSelector, timeout); this.flushesBeforeException = flushesBeforeException; diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java index 6b6506ab043db..f38bcc2c03cf3 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java @@ -38,10 +38,10 @@ import org.apache.flink.runtime.io.network.api.serialization.AdaptiveSpanningRecordDeserializer; import org.apache.flink.runtime.io.network.api.serialization.EventSerializer; import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer; -import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; import org.apache.flink.runtime.io.network.buffer.Buffer; import org.apache.flink.runtime.io.network.buffer.BufferProvider; import org.apache.flink.runtime.io.network.buffer.BufferRecycler; +import org.apache.flink.runtime.io.network.partition.ResultPartition; import org.apache.flink.runtime.io.network.partition.consumer.InputGate; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.jobgraph.tasks.InputSplitProvider; @@ -92,7 +92,7 @@ public class StreamMockEnvironment implements Environment { private final List inputs; - private final List outputs; + private final List outputs; private final JobID jobID = new JobID(); @@ -121,7 +121,7 @@ public StreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, this.jobConfiguration = jobConfig; this.taskConfiguration = taskConfig; this.inputs = new LinkedList(); - this.outputs = new LinkedList(); + this.outputs = new LinkedList<>(); this.memManager = new MemoryManager(memorySize, 1); this.ioManager = new IOManagerAsync(); @@ -160,8 +160,8 @@ public Buffer answer(InvocationOnMock invocationOnMock) throws Throwable { } }); - ResultPartitionWriter mockWriter = mock(ResultPartitionWriter.class); - when(mockWriter.getNumberOfOutputChannels()).thenReturn(1); + ResultPartition mockWriter = mock(ResultPartition.class); + when(mockWriter.getNumberOfSubpartitions()).thenReturn(1); when(mockWriter.getBufferProvider()).thenReturn(mockBufferProvider); final RecordDeserializer> recordDeserializer = new AdaptiveSpanningRecordDeserializer>(); @@ -176,7 +176,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { addBufferToOutputList(recordDeserializer, delegate, buffer, outputList); return null; } - }).when(mockWriter).writeBuffer(any(Buffer.class), anyInt()); + }).when(mockWriter).add(any(Buffer.class), anyInt()); doAnswer(new Answer() { @@ -186,7 +186,7 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable { addBufferToOutputList(recordDeserializer, delegate, buffer, outputList); return null; } - }).when(mockWriter).writeBufferToAllChannels(any(Buffer.class)); + }).when(mockWriter).addToAllChannels(any(Buffer.class)); outputs.add(mockWriter); } @@ -285,13 +285,13 @@ public Map> getDistributedCacheEntries() { } @Override - public ResultPartitionWriter getWriter(int index) { + public ResultPartition getOutputPartition(int index) { return outputs.get(index); } @Override - public ResultPartitionWriter[] getAllWriters() { - return outputs.toArray(new ResultPartitionWriter[outputs.size()]); + public ResultPartition[] getAllOutputPartitions() { + return outputs.toArray(new ResultPartition[outputs.size()]); } @Override diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java index 19d48e195f2ef..8db43fd654f6e 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java @@ -57,9 +57,6 @@ *

After setting up everything the Task can be invoked using {@link #invoke()}. This will start * a new Thread to execute the Task. Use {@link #waitForTaskCompletion()} to wait for the Task * thread to finish. - * - *

When using this you need to add the following line to your test class to setup Powermock: - * {@code {@literal @}PrepareForTest({ResultPartitionWriter.class})} */ public class StreamTaskTestHarness { diff --git a/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java b/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java index 92bf6d60a8efe..0aa116c75c9e1 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/runtime/NetworkStackThroughputITCase.java @@ -174,7 +174,7 @@ public static class SpeedTestProducer extends AbstractInvokable { @Override public void invoke() throws Exception { - RecordWriter writer = new RecordWriter<>(getEnvironment().getWriter(0)); + RecordWriter writer = new RecordWriter<>(getEnvironment().getOutputPartition(0)); try { // Determine the amount of data to send per subtask @@ -219,7 +219,7 @@ public void invoke() throws Exception { SpeedTestRecord.class, getEnvironment().getTaskManagerInfo().getTmpDirectories()); - RecordWriter writer = new RecordWriter<>(getEnvironment().getWriter(0)); + RecordWriter writer = new RecordWriter<>(getEnvironment().getOutputPartition(0)); try { SpeedTestRecord record;